| """ | |
| Data Quality Analysis Module | |
| This module provides functions to systematically check MNIST dataset quality: | |
| - Missing values detection | |
| - Outlier detection (invalid pixel values) | |
| - Class balance analysis | |
| - Image dimension verification | |
| - Comprehensive quality reporting | |
| Usage: | |
| from scripts.data_quality import generate_quality_report | |
| report = generate_quality_report( | |
| (x_train, y_train), | |
| (x_test, y_test) | |
| ) | |
| """ | |
from collections import Counter
from typing import Any, Dict, List, Tuple

import numpy as np
from numpy.typing import NDArray


def check_missing_values(
    images: List[NDArray[np.uint8]],
    labels: List[int]
) -> Dict[str, Any]:
    """
    Check for NaN or missing values in images and labels.

    Args:
        images: List of image arrays (each 28x28)
        labels: List of integer labels (0-9)

    Returns:
        dict: Contains 'has_missing_values', 'missing_count', 'details'
    """
    # Check images for NaN. NaN can only occur in floating-point arrays,
    # and np.isnan raises a TypeError on integer dtypes such as uint8,
    # so guard on the dtype first.
    images_with_nan = []
    for idx, img in enumerate(images):
        img_array = np.asarray(img)
        if np.issubdtype(img_array.dtype, np.floating) and np.isnan(img_array).any():
            images_with_nan.append(idx)

    # Check labels for None
    labels_with_none = [idx for idx, label in enumerate(labels) if label is None]

    has_missing = len(images_with_nan) > 0 or len(labels_with_none) > 0

    return {
        'has_missing_values': has_missing,
        'missing_count': len(images_with_nan) + len(labels_with_none),
        'details': {
            'images_with_nan': len(images_with_nan),
            'labels_with_none': len(labels_with_none),
            'affected_indices': {
                'images': images_with_nan[:10],  # First 10 only
                'labels': labels_with_none[:10]
            }
        }
    }


def check_outliers(images: List[NDArray[np.uint8]]) -> Dict[str, Any]:
    """
    Identify pixels outside valid range [0, 255] for uint8 images.

    Note: arrays that are genuinely uint8 cannot fall outside [0, 255],
    so in practice this check flags images that arrived with a wider
    dtype (e.g. after a faulty conversion step).

    Args:
        images: List of image arrays (each 28x28)

    Returns:
        dict: Contains 'has_outliers', 'outlier_count', 'pixel_range', 'details'
    """
    outlier_images = []
    pixel_min = 255
    pixel_max = 0

    for idx, img in enumerate(images):
        img_array = np.asarray(img)
        img_min = img_array.min()
        img_max = img_array.max()
        pixel_min = min(pixel_min, img_min)
        pixel_max = max(pixel_max, img_max)

        # Check for values outside [0, 255]
        if img_min < 0 or img_max > 255:
            outlier_images.append({
                'index': idx,
                'min': int(img_min),
                'max': int(img_max)
            })

    return {
        'has_outliers': len(outlier_images) > 0,
        'outlier_count': len(outlier_images),
        'pixel_range': {
            'min': int(pixel_min),
            'max': int(pixel_max)
        },
        'details': {
            'affected_images': outlier_images[:10]  # First 10 only
        }
    }


def check_class_balance(labels: List[int]) -> Dict[str, Any]:
    """
    Compute samples per class and calculate imbalance ratio.

    Imbalance ratio = max_count / min_count
    A ratio < 1.2 indicates good balance (< 20% difference)

    Args:
        labels: List of integer labels (0-9)

    Returns:
        dict: Contains 'is_balanced', 'imbalance_ratio', 'class_counts', 'details'
    """
    class_counts = Counter(labels)

    # Ensure all 10 digits present
    for digit in range(10):
        if digit not in class_counts:
            class_counts[digit] = 0

    counts = list(class_counts.values())
    max_count = max(counts)
    min_count = min(counts) if min(counts) > 0 else 1  # Avoid division by zero
    imbalance_ratio = max_count / min_count
    is_balanced = imbalance_ratio < 1.2  # Less than 20% difference

    # Per-class percentages
    total = len(labels)
    class_percentages = {
        digit: (count / total) * 100
        for digit, count in class_counts.items()
    }

    return {
        'is_balanced': is_balanced,
        'imbalance_ratio': round(imbalance_ratio, 3),
        'threshold': 1.2,
        'class_counts': dict(sorted(class_counts.items())),
        'class_percentages': {
            k: round(v, 2)
            for k, v in sorted(class_percentages.items())
        },
        'details': {
            'max_count': max_count,
            'min_count': min_count,
            'total_samples': total,
            'most_common_class': class_counts.most_common(1)[0][0],
            'least_common_class': min(class_counts, key=class_counts.get)
        }
    }
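

# For reference, the commonly cited per-class counts for the standard MNIST
# training split range from 5,421 samples (digit 5) to 6,742 (digit 1), giving
# an imbalance ratio of 6742 / 5421 ≈ 1.244. Under the 1.2 threshold above,
# real MNIST would be reported as having minor imbalance rather than being
# perfectly balanced.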


def check_image_dimensions(images: List[NDArray[np.uint8]]) -> Dict[str, Any]:
    """
    Verify all images are 28x28 as expected for MNIST.

    Args:
        images: List of image arrays

    Returns:
        dict: Contains 'all_correct_shape', 'expected_shape', 'invalid_count', 'details'
    """
    expected_shape = (28, 28)
    invalid_images = []

    for idx, img in enumerate(images):
        img_array = np.asarray(img)
        if img_array.shape != expected_shape:
            invalid_images.append({
                'index': idx,
                'shape': img_array.shape
            })

    return {
        'all_correct_shape': len(invalid_images) == 0,
        'expected_shape': expected_shape,
        'invalid_count': len(invalid_images),
        'details': {
            'total_checked': len(images),
            'invalid_images': invalid_images[:10]  # First 10 only
        }
    }


def check_label_validity(labels: List[int]) -> Dict[str, Any]:
    """
    Verify all labels are valid integers in range [0, 9].

    Args:
        labels: List of labels

    Returns:
        dict: Contains 'all_valid', 'invalid_count', 'unique_labels', 'details'
    """
    valid_range = set(range(10))
    invalid_labels = []

    for idx, label in enumerate(labels):
        # Accept numpy integer scalars as well as Python ints: labels taken
        # from a numpy array (e.g. keras.datasets.mnist) are np.uint8, which
        # is not an instance of the built-in int.
        if not isinstance(label, (int, np.integer)) or label not in valid_range:
            invalid_labels.append({
                'index': idx,
                'value': label,
                'type': type(label).__name__
            })

    unique_labels = sorted(set(labels))

    return {
        'all_valid': len(invalid_labels) == 0,
        'expected_range': [0, 9],
        'invalid_count': len(invalid_labels),
        'unique_labels': unique_labels,
        'details': {
            'total_checked': len(labels),
            'invalid_labels': invalid_labels[:10]  # First 10 only
        }
    }


def generate_quality_report(
    train_data: Tuple[List[NDArray[np.uint8]], List[int]],
    test_data: Tuple[List[NDArray[np.uint8]], List[int]]
) -> Dict[str, Any]:
    """
    Run all quality checks on training and test sets.

    Args:
        train_data: Tuple of (train_images, train_labels)
        test_data: Tuple of (test_images, test_labels)

    Returns:
        dict: Comprehensive quality report with all check results
    """
    x_train, y_train = train_data
    x_test, y_test = test_data

    report = {
        'dataset_info': {
            'train_samples': len(x_train),
            'test_samples': len(x_test),
            'total_samples': len(x_train) + len(x_test)
        },
        'training_set': {
            'missing_values': check_missing_values(x_train, y_train),
            'outliers': check_outliers(x_train),
            'class_balance': check_class_balance(y_train),
            'image_dimensions': check_image_dimensions(x_train),
            'label_validity': check_label_validity(y_train)
        },
        'test_set': {
            'missing_values': check_missing_values(x_test, y_test),
            'outliers': check_outliers(x_test),
            'class_balance': check_class_balance(y_test),
            'image_dimensions': check_image_dimensions(x_test),
            'label_validity': check_label_validity(y_test)
        }
    }

    # Overall quality assessment
    all_checks_pass = (
        not report['training_set']['missing_values']['has_missing_values'] and
        not report['training_set']['outliers']['has_outliers'] and
        report['training_set']['image_dimensions']['all_correct_shape'] and
        report['training_set']['label_validity']['all_valid'] and
        not report['test_set']['missing_values']['has_missing_values'] and
        not report['test_set']['outliers']['has_outliers'] and
        report['test_set']['image_dimensions']['all_correct_shape'] and
        report['test_set']['label_validity']['all_valid']
    )

    report['summary'] = {
        'all_checks_pass': all_checks_pass,
        'quality_rating': 'EXCELLENT' if all_checks_pass else 'ISSUES_FOUND',
        'train_balanced': report['training_set']['class_balance']['is_balanced'],
        'test_balanced': report['test_set']['class_balance']['is_balanced'],
        'recommendations': _generate_recommendations(report)
    }

    return report
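

# The report is a plain dict of scalars, lists, and nested dicts, so it can be
# persisted directly, e.g. `json.dump(report, f, indent=2, default=int)`. The
# `default=int` converts any numpy scalar values; note that if labels were
# numpy scalars, the class_counts dict keys would also need converting to
# Python ints first, since json's default hook does not apply to keys.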


def _generate_recommendations(report: Dict[str, Any]) -> List[str]:
    """
    Generate recommendations based on quality check results.

    Args:
        report: Quality report dictionary

    Returns:
        list: List of recommendation strings
    """
    recommendations = []

    # Check missing values
    if report['training_set']['missing_values']['has_missing_values']:
        recommendations.append(
            "Remove or impute samples with missing values in training set"
        )
    if report['test_set']['missing_values']['has_missing_values']:
        recommendations.append(
            "Remove or impute samples with missing values in test set"
        )

    # Check outliers
    if report['training_set']['outliers']['has_outliers']:
        recommendations.append(
            "Clip or remove training images with pixel values outside [0, 255]"
        )
    if report['test_set']['outliers']['has_outliers']:
        recommendations.append(
            "Clip or remove test images with pixel values outside [0, 255]"
        )

    # Check class balance
    train_imbalance = report['training_set']['class_balance']['imbalance_ratio']
    if train_imbalance >= 1.5:
        recommendations.append(
            f"Consider class rebalancing (imbalance ratio: {train_imbalance:.2f}) "
            "using oversampling or weighted loss"
        )
    elif train_imbalance >= 1.2:
        recommendations.append(
            f"Minor class imbalance detected (ratio: {train_imbalance:.2f}). "
            "Monitor per-class performance during training."
        )

    # Check dimensions
    if not report['training_set']['image_dimensions']['all_correct_shape']:
        recommendations.append(
            "Resize or remove training images with incorrect dimensions"
        )
    if not report['test_set']['image_dimensions']['all_correct_shape']:
        recommendations.append(
            "Resize or remove test images with incorrect dimensions"
        )

    # Check labels
    if not report['training_set']['label_validity']['all_valid']:
        recommendations.append(
            "Remove or correct training samples with invalid labels"
        )
    if not report['test_set']['label_validity']['all_valid']:
        recommendations.append(
            "Remove or correct test samples with invalid labels"
        )

    # If all checks pass
    if not recommendations:
        recommendations.append(
            "Dataset is high quality - proceed with preprocessing and normalization"
        )

    return recommendations


def print_quality_summary(report: Dict[str, Any]) -> None:
    """
    Print a human-readable summary of the quality report.

    Args:
        report: Quality report dictionary from generate_quality_report()
    """
    print("=" * 60)
    print("MNIST DATASET QUALITY REPORT")
    print("=" * 60)
    print()

    # Dataset info
    info = report['dataset_info']
    print("Dataset Size:")
    print(f"  Training: {info['train_samples']:,} samples")
    print(f"  Test: {info['test_samples']:,} samples")
    print(f"  Total: {info['total_samples']:,} samples")
    print()

    # Training set checks
    print("Training Set Quality Checks:")
    train = report['training_set']
    _print_check("Missing Values", not train['missing_values']['has_missing_values'])
    _print_check("Outliers", not train['outliers']['has_outliers'])
    _print_check("Image Dimensions", train['image_dimensions']['all_correct_shape'])
    _print_check("Label Validity", train['label_validity']['all_valid'])
    _print_check(
        f"Class Balance (ratio: {train['class_balance']['imbalance_ratio']})",
        train['class_balance']['is_balanced']
    )
    print()

    # Test set checks
    print("Test Set Quality Checks:")
    test = report['test_set']
    _print_check("Missing Values", not test['missing_values']['has_missing_values'])
    _print_check("Outliers", not test['outliers']['has_outliers'])
    _print_check("Image Dimensions", test['image_dimensions']['all_correct_shape'])
    _print_check("Label Validity", test['label_validity']['all_valid'])
    _print_check(
        f"Class Balance (ratio: {test['class_balance']['imbalance_ratio']})",
        test['class_balance']['is_balanced']
    )
    print()

    # Overall summary
    summary = report['summary']
    print("=" * 60)
    print(f"Overall Quality: {summary['quality_rating']}")
    print("=" * 60)
    print()

    # Recommendations
    print("Recommendations:")
    for i, rec in enumerate(summary['recommendations'], 1):
        print(f"  {i}. {rec}")
    print()


def _print_check(name: str, passed: bool) -> None:
    """Helper function to print a check result with a pass/fail marker."""
    status = "✓ PASS" if passed else "✗ FAIL"
    print(f"  {name:<40} {status}")