# Commit e77a25a (author: faizan) — fix: resolve all 468 ruff linting errors
# (code quality enforcement complete)
"""
Test training pipeline with small subset of data.
Verifies:
- Training loop runs without errors
- Validation metrics computed correctly
- Model checkpoints saved
- History tracking works
"""
import sys
from pathlib import Path
# Make the project root importable so the `scripts.*` modules resolve when
# this file is executed directly (e.g. `python tests/test_training.py`).
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
import torch
from scripts.data_loader import MnistDataloader
from scripts.preprocessing import MnistDataset, create_dataloaders, split_train_val
from scripts.models import BaselineCNN
from scripts.train import train_model, evaluate_model, save_training_history
def main():
    """Smoke-test the training pipeline end-to-end on a small data subset.

    Loads a 1,000-sample slice of MNIST, trains BaselineCNN for 3 epochs,
    verifies that checkpoints were written, saves the training history, and
    evaluates on a 200-sample test slice, printing per-class metrics.

    Returns:
        int: 0 on success, 1 if an expected checkpoint is missing
        (suitable for passing to sys.exit()).
    """
    print("=" * 60)
    print("Testing Training Pipeline")
    print("=" * 60)
    print()

    # --- 1. Load raw MNIST from IDX files -------------------------------
    print("1. Loading MNIST data...")
    data_path = project_root / "data" / "raw"
    loader = MnistDataloader(
        str(data_path / "train-images.idx3-ubyte"),
        str(data_path / "train-labels.idx1-ubyte"),
        str(data_path / "t10k-images.idx3-ubyte"),
        str(data_path / "t10k-labels.idx1-ubyte")
    )
    (x_train, y_train), (x_test, y_test) = loader.load_data()
    print(f"βœ“ Loaded {len(x_train):,} training samples")
    print()

    # --- 2. Split a small subset (1,000 samples) for a quick run --------
    # Fixed seed keeps the split reproducible across test runs.
    print("2. Creating train/val split...")
    (x_train_split, y_train_split), (x_val, y_val) = split_train_val(
        x_train[:1000], y_train[:1000], val_split=0.2, random_seed=42
    )
    print(f"βœ“ Train: {len(x_train_split)} samples")
    print(f"βœ“ Val: {len(x_val)} samples")
    print()

    # --- 3. Build datasets and loaders ----------------------------------
    print("3. Creating datasets and loaders...")
    train_dataset = MnistDataset(x_train_split, y_train_split, transform=None)
    val_dataset = MnistDataset(x_val, y_val, transform=None)
    # Small test set (200 samples) for a quick evaluation pass.
    test_dataset = MnistDataset(x_test[:200], y_test[:200], transform=None)
    train_loader, val_loader = create_dataloaders(
        train_dataset, val_dataset, batch_size=32, num_workers=0
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=32, shuffle=False, num_workers=0
    )
    print(f"βœ“ Train batches: {len(train_loader)}")
    print(f"βœ“ Val batches: {len(val_loader)}")
    print(f"βœ“ Test batches: {len(test_loader)}")
    print()

    # --- 4. Instantiate model -------------------------------------------
    print("4. Creating model...")
    model = BaselineCNN()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"βœ“ Model: {model.__class__.__name__}")
    print(f"βœ“ Device: {device}")
    print()

    # --- 5. Short training run ------------------------------------------
    # FIX: anchor the checkpoint directory at project_root instead of the
    # bare relative path 'models', so the later existence checks (which
    # already used project_root) agree with where train_model writes,
    # regardless of the current working directory.
    checkpoint_dir = project_root / "models"
    print("5. Training model (3 epochs for testing)...")
    print("-" * 60)
    history = train_model(
        model,
        train_loader,
        val_loader,
        num_epochs=3,
        learning_rate=0.001,
        patience=10,  # high patience so early stopping never triggers here
        checkpoint_dir=str(checkpoint_dir),
        device=device,
        use_scheduler=True,
        verbose=True
    )
    print("-" * 60)
    print()

    # --- 6. Verify checkpoints ------------------------------------------
    # Explicit checks instead of `assert`: asserts are stripped under -O,
    # and a failed check should surface as a nonzero exit code.
    print("6. Verifying checkpoints...")
    best_model_path = checkpoint_dir / "best_model.pt"
    last_model_path = checkpoint_dir / "last_model.pt"
    if not best_model_path.exists():
        print("βœ— Best model checkpoint not found")
        return 1
    if not last_model_path.exists():
        print("βœ— Last model checkpoint not found")
        return 1
    print("βœ“ Best model saved")
    print("βœ“ Last model saved")
    print()

    # --- 7. Persist training history ------------------------------------
    print("7. Saving training history...")
    history_path = project_root / "experiments" / "test_training_history.json"
    # Ensure the target directory exists; save_training_history may not
    # create it itself.
    history_path.parent.mkdir(parents=True, exist_ok=True)
    save_training_history(history, str(history_path))
    print()

    # --- 8. Evaluate on the held-out test slice -------------------------
    print("8. Evaluating on test set...")
    results = evaluate_model(model, test_loader, device=device)
    print(f"βœ“ Test Accuracy: {results['accuracy']:.2f}%")
    print()

    # Per-class precision/recall/F1 from the sklearn-style report dict.
    print("Per-class metrics:")
    report = results['classification_report']
    for digit in range(10):
        if str(digit) in report:
            metrics = report[str(digit)]
            print(f" Digit {digit}: "
                  f"Precision={metrics['precision']:.3f}, "
                  f"Recall={metrics['recall']:.3f}, "
                  f"F1={metrics['f1-score']:.3f}")
    print()

    # --- Summary ---------------------------------------------------------
    print("=" * 60)
    print("βœ… ALL TESTS PASSED")
    print("=" * 60)
    print("\nTraining pipeline is working correctly!")
    print(f"Final validation accuracy: {history['val_accuracy'][-1]:.2f}%")
    print(f"Test accuracy: {results['accuracy']:.2f}%")
    print("\nNote: These are quick test results with limited data.")
    print("For full training, use complete dataset and more epochs.")
    return 0
# Script entry point: propagate main()'s return value as the process
# exit status so CI can detect a failed smoke test.
if __name__ == "__main__":
    sys.exit(main())