mnist-digit-classifier / scripts /test_data_loader.py
faizan
fix: resolve all 468 ruff linting errors (code quality enforcement complete)
e77a25a
"""
Test script for data_loader module.
Run from project root:
conda activate ai_engg
python scripts/test_data_loader.py
"""
import sys
from pathlib import Path
# Add project root to path
sys.path.insert(0, str(Path(__file__).parent.parent))
from scripts.data_loader import MnistDataloader
def test_data_loader(data_dir: "Path | str | None" = None) -> bool:
    """Smoke-test ``MnistDataloader`` against the uncompressed MNIST IDX files.

    Checks that the four IDX files exist, loads them through
    ``MnistDataloader``, and verifies three things: every split has as many
    labels as images, the training set is non-empty, and all labels are
    digits 0-9. It also prints a short summary: counts, the first image's
    shape/dtype, the label set, and the pixel range of the first 100
    training images.

    Args:
        data_dir: Directory holding the uncompressed IDX files. Defaults to
            ``<project root>/data/raw``, resolved relative to this script.

    Returns:
        True if every check passed. False if files are missing or if loading
        or verification raised an exception.

    Note:
        The ``test_`` prefix means pytest may collect this function. It
        reports through its return value (used as the script's exit code)
        rather than through assertions.
    """
    print("Testing MNIST Data Loader...")
    print("-" * 50)

    # The loader reads raw IDX files, so downloaded .gz archives must be
    # decompressed first (e.g. `gunzip data/raw/*.gz`).
    if data_dir is None:
        base_path = Path(__file__).parent.parent / "data" / "raw"
    else:
        base_path = Path(data_dir)
    train_images = base_path / "train-images.idx3-ubyte"
    train_labels = base_path / "train-labels.idx1-ubyte"
    test_images = base_path / "t10k-images.idx3-ubyte"
    test_labels = base_path / "t10k-labels.idx1-ubyte"

    # Report every missing file at once rather than stopping at the first.
    missing_files = [
        str(path)
        for path in (train_images, train_labels, test_images, test_labels)
        if not path.exists()
    ]
    if missing_files:
        print("⚠️ Missing uncompressed data files:")
        for f in missing_files:
            print(f" - {f}")
        print("\nTo uncompress .gz files, run:")
        print(f" cd {base_path} && gunzip *.gz")
        return False

    try:
        import numpy as np

        loader = MnistDataloader(
            str(train_images),
            str(train_labels),
            str(test_images),
            str(test_labels),
        )
        print("✓ Loader initialized successfully")

        print("\nLoading MNIST dataset...")
        (x_train, y_train), (x_test, y_test) = loader.load_data()

        print(f"\n✓ Training set: {len(x_train):,} images, {len(y_train):,} labels")
        print(f"✓ Test set: {len(x_test):,} images, {len(y_test):,} labels")
        # Actually verify the counts instead of only printing them. The
        # checks below index into and take max()/min() over the training
        # set, so it must also be non-empty.
        if len(x_train) != len(y_train) or len(x_test) != len(y_test):
            raise ValueError("Image and label counts do not match")
        if len(x_train) == 0:
            raise ValueError("Training set is empty")

        # Convert the first image to a NumPy array to inspect it.
        first_img = np.array(x_train[0])
        print(f"\n✓ Image shape: {first_img.shape}")
        print(f"✓ Image dtype: {first_img.dtype}")
        print(f"✓ Label type: {type(y_train[0])}")

        # Build each split's label set and union them. `y_train + y_test`
        # concatenates lists / array.array objects but would add NumPy
        # arrays element-wise. int() turns NumPy scalars into plain ints.
        unique_labels = set(map(int, y_train)) | set(map(int, y_test))
        print(f"\n✓ Unique labels: {sorted(unique_labels)}")
        unexpected = unique_labels - set(range(10))
        if unexpected:
            raise ValueError(f"Labels outside 0-9: {sorted(unexpected)}")

        # Pixel range over a small sample keeps this check fast.
        sample_images = [np.asarray(img) for img in x_train[:100]]
        max_val = max(img.max() for img in sample_images)
        min_val = min(img.min() for img in sample_images)
        print(f"✓ Pixel value range (sample): [{min_val}, {max_val}]")
    except Exception as e:
        # Top-level boundary of a CLI smoke test: report any failure (I/O,
        # parsing, or the verification errors above) and signal failure.
        print(f"\n❌ Error: {e}")
        import traceback

        traceback.print_exc()
        return False

    print("\n" + "=" * 50)
    print("✅ All tests passed!")
    print("=" * 50)
    return True
if __name__ == "__main__":
    # Turn the smoke-test result into the process exit status so shells and
    # CI can react to it: 0 when every check passed, 1 otherwise.
    if test_data_loader():
        sys.exit(0)
    sys.exit(1)