""" Data Quality Analysis Module This module provides functions to systematically check MNIST dataset quality: - Missing values detection - Outlier detection (invalid pixel values) - Class balance analysis - Image dimension verification - Comprehensive quality reporting Usage: from scripts.data_quality import generate_quality_report report = generate_quality_report( (x_train, y_train), (x_test, y_test) ) """ from typing import Tuple, Dict, Any, List import numpy as np from numpy.typing import NDArray from collections import Counter def check_missing_values( images: List[NDArray[np.uint8]], labels: List[int] ) -> Dict[str, Any]: """ Check for NaN or missing values in images and labels. Args: images: List of image arrays (each 28x28) labels: List of integer labels (0-9) Returns: dict: Contains 'has_missing_values', 'missing_count', 'details' """ # Check images for NaN images_with_nan = [] for idx, img in enumerate(images): img_array = np.array(img) if np.isnan(img_array).any(): images_with_nan.append(idx) # Check labels for None labels_with_none = [idx for idx, label in enumerate(labels) if label is None] has_missing = len(images_with_nan) > 0 or len(labels_with_none) > 0 return { 'has_missing_values': has_missing, 'missing_count': len(images_with_nan) + len(labels_with_none), 'details': { 'images_with_nan': len(images_with_nan), 'labels_with_none': len(labels_with_none), 'affected_indices': { 'images': images_with_nan[:10], # First 10 only 'labels': labels_with_none[:10] } } } def check_outliers(images: List[NDArray[np.uint8]]) -> Dict[str, Any]: """ Identify pixels outside valid range [0, 255] for uint8 images. Args: images: List of image arrays (each 28x28) Returns: dict: Contains 'has_outliers', 'outlier_count', 'pixel_range', 'details' """ outlier_images = [] pixel_min = 255 pixel_max = 0 for idx, img in enumerate(images): img_array = np.array(img) img_min = img_array.min() img_max = img_array.max() pixel_min = min(pixel_min, img_min) pixel_max = max(pixel_max, img_max) # Check for values outside [0, 255] if img_min < 0 or img_max > 255: outlier_images.append({ 'index': idx, 'min': int(img_min), 'max': int(img_max) }) return { 'has_outliers': len(outlier_images) > 0, 'outlier_count': len(outlier_images), 'pixel_range': { 'min': int(pixel_min), 'max': int(pixel_max) }, 'details': { 'affected_images': outlier_images[:10] # First 10 only } } def check_class_balance(labels: List[int]) -> Dict[str, Any]: """ Compute samples per class and calculate imbalance ratio. Imbalance ratio = max_count / min_count A ratio < 1.2 indicates good balance (< 20% difference) Args: labels: List of integer labels (0-9) Returns: dict: Contains 'is_balanced', 'imbalance_ratio', 'class_counts', 'details' """ class_counts = Counter(labels) # Ensure all 10 digits present for digit in range(10): if digit not in class_counts: class_counts[digit] = 0 counts = list(class_counts.values()) max_count = max(counts) min_count = min(counts) if min(counts) > 0 else 1 # Avoid division by zero imbalance_ratio = max_count / min_count is_balanced = imbalance_ratio < 1.2 # Less than 20% difference # Per-class percentages total = len(labels) class_percentages = { digit: (count / total) * 100 for digit, count in class_counts.items() } return { 'is_balanced': is_balanced, 'imbalance_ratio': round(imbalance_ratio, 3), 'threshold': 1.2, 'class_counts': dict(sorted(class_counts.items())), 'class_percentages': { k: round(v, 2) for k, v in sorted(class_percentages.items()) }, 'details': { 'max_count': max_count, 'min_count': min_count, 'total_samples': total, 'most_common_class': class_counts.most_common(1)[0][0], 'least_common_class': min(class_counts, key=class_counts.get) } } def check_image_dimensions(images: List[NDArray[np.uint8]]) -> Dict[str, Any]: """ Verify all images are 28x28 as expected for MNIST. Args: images: List of image arrays Returns: dict: Contains 'all_correct_shape', 'expected_shape', 'invalid_count', 'details' """ expected_shape = (28, 28) invalid_images = [] for idx, img in enumerate(images): img_array = np.array(img) if img_array.shape != expected_shape: invalid_images.append({ 'index': idx, 'shape': img_array.shape }) return { 'all_correct_shape': len(invalid_images) == 0, 'expected_shape': expected_shape, 'invalid_count': len(invalid_images), 'details': { 'total_checked': len(images), 'invalid_images': invalid_images[:10] # First 10 only } } def check_label_validity(labels: List[int]) -> Dict[str, Any]: """ Verify all labels are valid integers in range [0, 9]. Args: labels: List of labels Returns: dict: Contains 'all_valid', 'invalid_count', 'unique_labels', 'details' """ valid_range = set(range(10)) invalid_labels = [] for idx, label in enumerate(labels): if not isinstance(label, int) or label not in valid_range: invalid_labels.append({ 'index': idx, 'value': label, 'type': type(label).__name__ }) unique_labels = sorted(set(labels)) return { 'all_valid': len(invalid_labels) == 0, 'expected_range': [0, 9], 'invalid_count': len(invalid_labels), 'unique_labels': unique_labels, 'details': { 'total_checked': len(labels), 'invalid_labels': invalid_labels[:10] # First 10 only } } def generate_quality_report( train_data: Tuple[List[NDArray[np.uint8]], List[int]], test_data: Tuple[List[NDArray[np.uint8]], List[int]] ) -> Dict[str, Any]: """ Run all quality checks on training and test sets. Args: train_data: Tuple of (train_images, train_labels) test_data: Tuple of (test_images, test_labels) Returns: dict: Comprehensive quality report with all check results """ x_train, y_train = train_data x_test, y_test = test_data report = { 'dataset_info': { 'train_samples': len(x_train), 'test_samples': len(x_test), 'total_samples': len(x_train) + len(x_test) }, 'training_set': { 'missing_values': check_missing_values(x_train, y_train), 'outliers': check_outliers(x_train), 'class_balance': check_class_balance(y_train), 'image_dimensions': check_image_dimensions(x_train), 'label_validity': check_label_validity(y_train) }, 'test_set': { 'missing_values': check_missing_values(x_test, y_test), 'outliers': check_outliers(x_test), 'class_balance': check_class_balance(y_test), 'image_dimensions': check_image_dimensions(x_test), 'label_validity': check_label_validity(y_test) } } # Overall quality assessment all_checks_pass = ( not report['training_set']['missing_values']['has_missing_values'] and not report['training_set']['outliers']['has_outliers'] and report['training_set']['image_dimensions']['all_correct_shape'] and report['training_set']['label_validity']['all_valid'] and not report['test_set']['missing_values']['has_missing_values'] and not report['test_set']['outliers']['has_outliers'] and report['test_set']['image_dimensions']['all_correct_shape'] and report['test_set']['label_validity']['all_valid'] ) report['summary'] = { 'all_checks_pass': all_checks_pass, 'quality_rating': 'EXCELLENT' if all_checks_pass else 'ISSUES_FOUND', 'train_balanced': report['training_set']['class_balance']['is_balanced'], 'test_balanced': report['test_set']['class_balance']['is_balanced'], 'recommendations': _generate_recommendations(report) } return report def _generate_recommendations(report: Dict[str, Any]) -> List[str]: """ Generate recommendations based on quality check results. Args: report: Quality report dictionary Returns: list: List of recommendation strings """ recommendations = [] # Check missing values if report['training_set']['missing_values']['has_missing_values']: recommendations.append( "Remove or impute samples with missing values in training set" ) if report['test_set']['missing_values']['has_missing_values']: recommendations.append( "Remove or impute samples with missing values in test set" ) # Check outliers if report['training_set']['outliers']['has_outliers']: recommendations.append( "Clip or remove training images with pixel values outside [0, 255]" ) if report['test_set']['outliers']['has_outliers']: recommendations.append( "Clip or remove test images with pixel values outside [0, 255]" ) # Check class balance train_imbalance = report['training_set']['class_balance']['imbalance_ratio'] if train_imbalance >= 1.5: recommendations.append( f"Consider class rebalancing (imbalance ratio: {train_imbalance:.2f}) " "using oversampling or weighted loss" ) elif train_imbalance >= 1.2: recommendations.append( f"Minor class imbalance detected (ratio: {train_imbalance:.2f}). " "Monitor per-class performance during training." ) # Check dimensions if not report['training_set']['image_dimensions']['all_correct_shape']: recommendations.append( "Resize or remove training images with incorrect dimensions" ) if not report['test_set']['image_dimensions']['all_correct_shape']: recommendations.append( "Resize or remove test images with incorrect dimensions" ) # Check labels if not report['training_set']['label_validity']['all_valid']: recommendations.append( "Remove or correct training samples with invalid labels" ) if not report['test_set']['label_validity']['all_valid']: recommendations.append( "Remove or correct test samples with invalid labels" ) # If all checks pass if not recommendations: recommendations.append( "Dataset is high quality - proceed with preprocessing and normalization" ) return recommendations def print_quality_summary(report: Dict[str, Any]) -> None: """ Print a human-readable summary of the quality report. Args: report: Quality report dictionary from generate_quality_report() """ print("=" * 60) print("MNIST DATASET QUALITY REPORT") print("=" * 60) print() # Dataset info info = report['dataset_info'] print("Dataset Size:") print(f" Training: {info['train_samples']:,} samples") print(f" Test: {info['test_samples']:,} samples") print(f" Total: {info['total_samples']:,} samples") print() # Training set checks print("Training Set Quality Checks:") train = report['training_set'] _print_check("Missing Values", not train['missing_values']['has_missing_values']) _print_check("Outliers", not train['outliers']['has_outliers']) _print_check("Image Dimensions", train['image_dimensions']['all_correct_shape']) _print_check("Label Validity", train['label_validity']['all_valid']) _print_check( f"Class Balance (ratio: {train['class_balance']['imbalance_ratio']})", train['class_balance']['is_balanced'] ) print() # Test set checks print("Test Set Quality Checks:") test = report['test_set'] _print_check("Missing Values", not test['missing_values']['has_missing_values']) _print_check("Outliers", not test['outliers']['has_outliers']) _print_check("Image Dimensions", test['image_dimensions']['all_correct_shape']) _print_check("Label Validity", test['label_validity']['all_valid']) _print_check( f"Class Balance (ratio: {test['class_balance']['imbalance_ratio']})", test['class_balance']['is_balanced'] ) print() # Overall summary summary = report['summary'] print("=" * 60) print(f"Overall Quality: {summary['quality_rating']}") print("=" * 60) print() # Recommendations print("Recommendations:") for i, rec in enumerate(summary['recommendations'], 1): print(f" {i}. {rec}") print() def _print_check(name: str, passed: bool) -> None: """Helper function to print check results with colored status.""" status = "✓ PASS" if passed else "✗ FAIL" print(f" {name:<40} {status}")