"""
Inference module for MNIST digit classification.

Provides a clean API for making predictions with the trained model.
Handles image preprocessing and returns predictions with confidence scores.
"""

import torch
from PIL import Image
import numpy as np
from pathlib import Path
from typing import Union, Dict, Optional


class DigitClassifier:
    """Production inference wrapper for MNIST digit classifier."""

    def __init__(self, model_path: str, device: Optional[str] = None):
        """
        Initialize the digit classifier.

        Args:
            model_path: Path to model checkpoint (.pt file)
            device: Device to run inference on ('cuda' or 'cpu').
                If None, auto-detects CUDA availability.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        if device is None:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self.device = device

        self.model_path = Path(model_path)
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model not found at {model_path}")

        self.model = self._load_model()
        self.model.eval()

        # Normalization values (same as training)
        self.mean = 0.1307
        self.std = 0.3081

    def _load_model(self) -> torch.nn.Module:
        """
        Build a ``BaselineCNN`` and load the checkpoint weights into it.

        Accepts either a bare state dict or a dict that stores the weights
        under the ``'model_state_dict'`` key.

        Returns:
            The model, moved to ``self.device``.
        """
        # Imported here rather than at module top; presumably so the project
        # root can be put on sys.path first (as test_inference does).
        from scripts.models import BaselineCNN

        model = BaselineCNN()

        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code -- only load checkpoints from trusted sources.
        checkpoint = torch.load(self.model_path, map_location=self.device)

        # Handle different checkpoint formats
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            model.load_state_dict(checkpoint)

        return model.to(self.device)

    def preprocess(self, image: Union[Image.Image, np.ndarray]) -> torch.Tensor:
        """
        Preprocess image for model input.

        Handles:
            - Conversion of integer numpy arrays to uint8
            - RGB (or any non-'L' mode) to grayscale conversion
            - Resizing to 28x28 (aspect ratio is not preserved)
            - Inversion if needed (MNIST is white digit on black background)
            - Normalization with the training mean/std

        Args:
            image: PIL Image or numpy array. Integer arrays that are not
                uint8 are clipped to [0, 255] and cast to uint8 first.

        Returns:
            Preprocessed float tensor of shape (1, 1, 28, 28) on
            ``self.device``.
        """
        # Convert numpy array to PIL Image if needed
        if isinstance(image, np.ndarray):
            if np.issubdtype(image.dtype, np.integer) and image.dtype != np.uint8:
                # Image.fromarray rejects some integer dtypes (e.g. the int64
                # arrays np.array() builds from lists of Python ints), so
                # bring them into the 8-bit range first.
                image = np.clip(image, 0, 255).astype(np.uint8)
            image = Image.fromarray(image)

        # Convert to grayscale if RGB
        if image.mode != 'L':
            image = image.convert('L')

        # Resize to 28x28 if needed
        if image.size != (28, 28):
            image = image.resize((28, 28), Image.Resampling.LANCZOS)

        # Scale pixel values to [0, 1]
        img_array = np.asarray(image, dtype=np.float32) / 255.0

        # MNIST digits are white on black. If most pixels are bright, the
        # input is likely a dark digit on a light background: invert it.
        if img_array.mean() > 0.5:
            img_array = 1.0 - img_array

        # Apply normalization (same as training)
        img_array = (img_array - self.mean) / self.std

        # Add batch and channel dimensions: (28, 28) -> (1, 1, 28, 28)
        img_tensor = torch.from_numpy(img_array).unsqueeze(0).unsqueeze(0)
        return img_tensor.to(self.device)

    @staticmethod
    def _format_prediction(probabilities: torch.Tensor) -> Dict:
        """
        Turn one image's class probabilities into a prediction dictionary.

        Args:
            probabilities: 1-D softmax output for a single image, one entry
                per digit.

        Returns:
            Dictionary with 'digit', 'confidence' and 'probabilities' keys.
        """
        confidence, predicted = torch.max(probabilities, dim=0)
        return {
            'digit': int(predicted.item()),
            'confidence': float(confidence.item()),
            'probabilities': probabilities.cpu().numpy().tolist(),
        }

    def predict(self, image: Union[Image.Image, np.ndarray]) -> Dict:
        """
        Predict digit from image.

        Args:
            image: PIL Image or numpy array containing digit

        Returns:
            Dictionary with:
                - digit: Predicted digit (0-9)
                - confidence: Confidence score (0-1)
                - probabilities: List of probabilities for each digit
        """
        img_tensor = self.preprocess(image)

        with torch.no_grad():
            outputs = self.model(img_tensor)
            probabilities = torch.softmax(outputs, dim=1)[0]

        return self._format_prediction(probabilities)

    def predict_batch(self, images: list, batch_size: int = 256) -> list:
        """
        Predict digits for a batch of images.

        Each image is preprocessed individually; the resulting tensors are
        then concatenated and run through the model in chunks of at most
        ``batch_size`` images, so one forward pass serves many images.

        Args:
            images: List (or any iterable) of PIL Images or numpy arrays
            batch_size: Maximum number of images per forward pass.

        Returns:
            List of prediction dictionaries (same format as ``predict``), in
            the same order as ``images``. An empty input yields ``[]``.

        Raises:
            ValueError: If ``batch_size`` is less than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        # Materialize so generators/tuples work as before and can be sliced.
        images = list(images)
        results = []
        for start in range(0, len(images), batch_size):
            chunk = images[start:start + batch_size]
            # Each preprocess() result is (1, 1, 28, 28); cat -> (B, 1, 28, 28)
            batch = torch.cat([self.preprocess(img) for img in chunk], dim=0)
            with torch.no_grad():
                probabilities = torch.softmax(self.model(batch), dim=1)
            results.extend(self._format_prediction(row) for row in probabilities)
        return results


def test_inference():
    """
    Smoke-test the inference module on MNIST test images.

    Expects a checkpoint at ``<project_root>/models/best_model.pt`` and the
    raw MNIST IDX files under ``<project_root>/data/raw``. Results are
    printed to stdout rather than asserted.
    """
    import sys

    # Add project root to path so the ``scripts`` package is importable
    project_root = Path(__file__).parent.parent
    sys.path.insert(0, str(project_root))

    from scripts.data_loader import MnistDataloader

    print("Testing Inference Module")
    print("=" * 50)

    # Check if model exists
    model_path = project_root / 'models' / 'best_model.pt'
    if not model_path.exists():
        print(f"Error: Model not found at {model_path}")
        print("Please train a model first.")
        return

    # Load MNIST test data
    data_path = project_root / 'data' / 'raw'
    loader = MnistDataloader(
        training_images_filepath=str(data_path / 'train-images.idx3-ubyte'),
        training_labels_filepath=str(data_path / 'train-labels.idx1-ubyte'),
        test_images_filepath=str(data_path / 't10k-images.idx3-ubyte'),
        test_labels_filepath=str(data_path / 't10k-labels.idx1-ubyte')
    )
    _, (x_test, y_test) = loader.load_data()

    # Initialize classifier
    print(f"\n1. Loading model from: {model_path}")
    classifier = DigitClassifier(str(model_path))
    print(f" Device: {classifier.device}")

    # Test on a few images
    print("\n2. Testing predictions on 10 random test images:")
    print("-" * 50)

    indices = np.random.choice(len(x_test), 10, replace=False)
    correct = 0

    for i, idx in enumerate(indices, 1):
        image = x_test[idx]
        true_label = y_test[idx]

        # Convert list to numpy array if needed
        if isinstance(image, list):
            image = np.array(image)

        # Convert to PIL Image; a 2-D uint8 array maps to mode 'L'
        img = Image.fromarray(image.astype(np.uint8))

        # Predict
        result = classifier.predict(img)
        is_correct = result['digit'] == true_label
        correct += is_correct

        print(f" Image {i}: True={true_label}, Pred={result['digit']}, "
              f"Conf={result['confidence']:.4f} {'✓' if is_correct else '✗'}")

    accuracy = correct / len(indices) * 100
    print(f"\nAccuracy on {len(indices)} samples: {accuracy:.1f}%")

    # Test edge cases
    print("\n3. Testing edge cases:")
    print("-" * 50)

    # Blank image
    blank_img = Image.fromarray(np.zeros((28, 28), dtype=np.uint8))
    result = classifier.predict(blank_img)
    print(f" Blank image: Pred={result['digit']}, Conf={result['confidence']:.4f}")

    # All white image
    white_img = Image.fromarray(np.full((28, 28), 255, dtype=np.uint8))
    result = classifier.predict(white_img)
    print(f" White image: Pred={result['digit']}, Conf={result['confidence']:.4f}")

    # Different size image
    test_img = x_test[0]
    if isinstance(test_img, list):
        test_img = np.array(test_img)
    large = Image.fromarray(test_img.astype(np.uint8)).resize((56, 56))
    result = classifier.predict(large)
    print(
        f" Resized image (56x56): "
        f"Pred={result['digit']}, Conf={result['confidence']:.4f}"
    )

    print("\n✓ Inference module test complete!")


if __name__ == '__main__':
    test_inference()