# mnist-digit-classifier / scripts / data_quality.py
# Commit e77a25a by faizan — "fix: resolve all 468 ruff linting errors
# (code quality enforcement complete)"
"""
Data Quality Analysis Module
This module provides functions to systematically check MNIST dataset quality:
- Missing values detection
- Outlier detection (invalid pixel values)
- Class balance analysis
- Image dimension verification
- Comprehensive quality reporting
Usage:
from scripts.data_quality import generate_quality_report
report = generate_quality_report(
(x_train, y_train),
(x_test, y_test)
)
"""
from typing import Tuple, Dict, Any, List
import numpy as np
from numpy.typing import NDArray
from collections import Counter
def check_missing_values(
    images: List[NDArray[np.uint8]],
    labels: List[int]
) -> Dict[str, Any]:
    """
    Scan images for NaN pixels and labels for None entries.

    Args:
        images: List of image arrays (each 28x28)
        labels: List of integer labels (0-9)

    Returns:
        dict: Contains 'has_missing_values', 'missing_count', 'details'
    """
    nan_image_indices = [
        idx for idx, img in enumerate(images)
        if np.isnan(np.array(img)).any()
    ]
    none_label_indices = [
        idx for idx, label in enumerate(labels) if label is None
    ]
    total_missing = len(nan_image_indices) + len(none_label_indices)
    return {
        'has_missing_values': total_missing > 0,
        'missing_count': total_missing,
        'details': {
            'images_with_nan': len(nan_image_indices),
            'labels_with_none': len(none_label_indices),
            'affected_indices': {
                # Truncated so the report stays small on badly broken data.
                'images': nan_image_indices[:10],
                'labels': none_label_indices[:10]
            }
        }
    }
def check_outliers(images: List[NDArray[np.uint8]]) -> Dict[str, Any]:
    """
    Identify pixels outside valid range [0, 255] for uint8 images.

    Args:
        images: List of image arrays (each 28x28)

    Returns:
        dict: Contains 'has_outliers', 'outlier_count', 'pixel_range', 'details'
    """
    flagged = []
    # Global extremes across the whole dataset, seeded at the valid bounds.
    global_min, global_max = 255, 0
    for position, image in enumerate(images):
        arr = np.array(image)
        low, high = arr.min(), arr.max()
        if low < global_min:
            global_min = low
        if high > global_max:
            global_max = high
        # Anything below 0 or above 255 cannot be a valid uint8 pixel.
        if low < 0 or high > 255:
            flagged.append({
                'index': position,
                'min': int(low),
                'max': int(high)
            })
    return {
        'has_outliers': len(flagged) > 0,
        'outlier_count': len(flagged),
        'pixel_range': {
            'min': int(global_min),
            'max': int(global_max)
        },
        'details': {
            'affected_images': flagged[:10]  # First 10 only
        }
    }
def check_class_balance(labels: List[int]) -> Dict[str, Any]:
    """
    Compute samples per class and calculate imbalance ratio.

    Imbalance ratio = max_count / min_count.
    A ratio < 1.2 indicates good balance (< 20% difference).

    Handles an empty label list without raising (percentages report 0.0),
    and clamps min_count to 1 so a completely missing class yields a large
    but finite ratio instead of dividing by zero.

    Args:
        labels: List of integer labels (0-9)

    Returns:
        dict: Contains 'is_balanced', 'imbalance_ratio', 'class_counts', 'details'
    """
    class_counts = Counter(labels)
    # Ensure all 10 digits are present so missing classes show up as 0.
    for digit in range(10):
        if digit not in class_counts:
            class_counts[digit] = 0
    counts = list(class_counts.values())
    max_count = max(counts)
    # Clamp to 1: avoids division by zero when a class is entirely absent
    # (and avoids computing min(counts) twice, as the original did).
    min_count = max(min(counts), 1)
    imbalance_ratio = max_count / min_count
    is_balanced = imbalance_ratio < 1.2  # Less than 20% difference
    # Per-class percentages; guard total == 0 (the original raised
    # ZeroDivisionError on an empty label list).
    total = len(labels)
    class_percentages = {
        digit: (count / total) * 100 if total else 0.0
        for digit, count in class_counts.items()
    }
    return {
        'is_balanced': is_balanced,
        'imbalance_ratio': round(imbalance_ratio, 3),
        'threshold': 1.2,
        'class_counts': dict(sorted(class_counts.items())),
        'class_percentages': {
            k: round(v, 2)
            for k, v in sorted(class_percentages.items())
        },
        'details': {
            'max_count': max_count,
            'min_count': min_count,
            'total_samples': total,
            'most_common_class': class_counts.most_common(1)[0][0],
            'least_common_class': min(class_counts, key=class_counts.get)
        }
    }
def check_image_dimensions(images: List[NDArray[np.uint8]]) -> Dict[str, Any]:
    """
    Verify all images are 28x28 as expected for MNIST.

    Args:
        images: List of image arrays

    Returns:
        dict: Contains 'all_correct_shape', 'expected_shape', 'invalid_count', 'details'
    """
    expected_shape = (28, 28)
    mismatched = []
    for position, image in enumerate(images):
        actual = np.asarray(image).shape
        if actual != expected_shape:
            mismatched.append({
                'index': position,
                'shape': actual
            })
    return {
        'all_correct_shape': len(mismatched) == 0,
        'expected_shape': expected_shape,
        'invalid_count': len(mismatched),
        'details': {
            'total_checked': len(images),
            'invalid_images': mismatched[:10]  # First 10 only
        }
    }
def check_label_validity(labels: List[int]) -> Dict[str, Any]:
    """
    Verify all labels are valid integers in range [0, 9].

    Accepts both built-in ints and NumPy integer scalars (the type MNIST
    loaders actually yield when labels come from a NumPy array); rejects
    bools even though bool is a subclass of int.

    Args:
        labels: List of labels

    Returns:
        dict: Contains 'all_valid', 'invalid_count', 'unique_labels', 'details'
    """
    valid_range = set(range(10))
    invalid_labels = []
    for idx, label in enumerate(labels):
        # np.integer is included because labels sliced from a NumPy array
        # are np.uint8/np.int64 scalars, not Python ints — the original
        # isinstance(label, int) check wrongly flagged every one of them.
        is_integer = (
            isinstance(label, (int, np.integer))
            and not isinstance(label, (bool, np.bool_))
        )
        if not is_integer or int(label) not in valid_range:
            invalid_labels.append({
                'index': idx,
                'value': label,
                'type': type(label).__name__
            })
    # Build unique_labels from integer-like values only: mixing None/str
    # into sorted() would raise TypeError in Python 3 (original bug).
    unique_labels = sorted({
        int(label) for label in labels
        if isinstance(label, (int, np.integer))
        and not isinstance(label, (bool, np.bool_))
    })
    return {
        'all_valid': len(invalid_labels) == 0,
        'expected_range': [0, 9],
        'invalid_count': len(invalid_labels),
        'unique_labels': unique_labels,
        'details': {
            'total_checked': len(labels),
            'invalid_labels': invalid_labels[:10]  # First 10 only
        }
    }
def generate_quality_report(
    train_data: Tuple[List[NDArray[np.uint8]], List[int]],
    test_data: Tuple[List[NDArray[np.uint8]], List[int]]
) -> Dict[str, Any]:
    """
    Run all quality checks on training and test sets.

    Args:
        train_data: Tuple of (train_images, train_labels)
        test_data: Tuple of (test_images, test_labels)

    Returns:
        dict: Comprehensive quality report with all check results
    """
    x_train, y_train = train_data
    x_test, y_test = test_data

    def run_checks(images, labels):
        # Bundle every per-split quality check into one result dict.
        return {
            'missing_values': check_missing_values(images, labels),
            'outliers': check_outliers(images),
            'class_balance': check_class_balance(labels),
            'image_dimensions': check_image_dimensions(images),
            'label_validity': check_label_validity(labels)
        }

    report = {
        'dataset_info': {
            'train_samples': len(x_train),
            'test_samples': len(x_test),
            'total_samples': len(x_train) + len(x_test)
        },
        'training_set': run_checks(x_train, y_train),
        'test_set': run_checks(x_test, y_test)
    }

    def split_is_clean(split):
        # Class balance is reported separately and does not gate the
        # overall pass/fail verdict, matching the summary fields below.
        return (
            not split['missing_values']['has_missing_values']
            and not split['outliers']['has_outliers']
            and split['image_dimensions']['all_correct_shape']
            and split['label_validity']['all_valid']
        )

    all_checks_pass = (
        split_is_clean(report['training_set'])
        and split_is_clean(report['test_set'])
    )
    report['summary'] = {
        'all_checks_pass': all_checks_pass,
        'quality_rating': 'EXCELLENT' if all_checks_pass else 'ISSUES_FOUND',
        'train_balanced': report['training_set']['class_balance']['is_balanced'],
        'test_balanced': report['test_set']['class_balance']['is_balanced'],
        'recommendations': _generate_recommendations(report)
    }
    return report
def _generate_recommendations(report: Dict[str, Any]) -> List[str]:
    """
    Generate recommendations based on quality check results.

    Args:
        report: Quality report dictionary

    Returns:
        list: List of recommendation strings
    """
    train = report['training_set']
    test = report['test_set']
    recommendations: List[str] = []

    if train['missing_values']['has_missing_values']:
        recommendations.append(
            "Remove or impute samples with missing values in training set"
        )
    if test['missing_values']['has_missing_values']:
        recommendations.append(
            "Remove or impute samples with missing values in test set"
        )
    if train['outliers']['has_outliers']:
        recommendations.append(
            "Clip or remove training images with pixel values outside [0, 255]"
        )
    if test['outliers']['has_outliers']:
        recommendations.append(
            "Clip or remove test images with pixel values outside [0, 255]"
        )

    # Only the training split's balance drives rebalancing advice.
    ratio = train['class_balance']['imbalance_ratio']
    if ratio >= 1.5:
        recommendations.append(
            f"Consider class rebalancing (imbalance ratio: {ratio:.2f}) "
            "using oversampling or weighted loss"
        )
    elif ratio >= 1.2:
        recommendations.append(
            f"Minor class imbalance detected (ratio: {ratio:.2f}). "
            "Monitor per-class performance during training."
        )

    if not train['image_dimensions']['all_correct_shape']:
        recommendations.append(
            "Resize or remove training images with incorrect dimensions"
        )
    if not test['image_dimensions']['all_correct_shape']:
        recommendations.append(
            "Resize or remove test images with incorrect dimensions"
        )
    if not train['label_validity']['all_valid']:
        recommendations.append(
            "Remove or correct training samples with invalid labels"
        )
    if not test['label_validity']['all_valid']:
        recommendations.append(
            "Remove or correct test samples with invalid labels"
        )

    # A clean bill of health still yields one actionable line.
    return recommendations or [
        "Dataset is high quality - proceed with preprocessing and normalization"
    ]
def print_quality_summary(report: Dict[str, Any]) -> None:
    """
    Print a human-readable summary of the quality report.

    Args:
        report: Quality report dictionary from generate_quality_report()
    """
    divider = "=" * 60
    print(divider)
    print("MNIST DATASET QUALITY REPORT")
    print(divider)
    print()

    # Dataset sizes
    info = report['dataset_info']
    print("Dataset Size:")
    print(f" Training: {info['train_samples']:,} samples")
    print(f" Test: {info['test_samples']:,} samples")
    print(f" Total: {info['total_samples']:,} samples")
    print()

    # Per-split check results (same layout for training and test)
    for heading, key in (
        ("Training Set Quality Checks:", 'training_set'),
        ("Test Set Quality Checks:", 'test_set'),
    ):
        print(heading)
        split = report[key]
        _print_check("Missing Values", not split['missing_values']['has_missing_values'])
        _print_check("Outliers", not split['outliers']['has_outliers'])
        _print_check("Image Dimensions", split['image_dimensions']['all_correct_shape'])
        _print_check("Label Validity", split['label_validity']['all_valid'])
        _print_check(
            f"Class Balance (ratio: {split['class_balance']['imbalance_ratio']})",
            split['class_balance']['is_balanced']
        )
        print()

    # Overall verdict
    summary = report['summary']
    print(divider)
    print(f"Overall Quality: {summary['quality_rating']}")
    print(divider)
    print()

    # Numbered recommendations
    print("Recommendations:")
    for number, recommendation in enumerate(summary['recommendations'], 1):
        print(f" {number}. {recommendation}")
    print()
def _print_check(name: str, passed: bool) -> None:
    """Print one check result: the padded check name plus a pass/fail marker."""
    if passed:
        status = "✓ PASS"
    else:
        status = "✗ FAIL"
    print(f" {name:<40} {status}")