Spaces:
Sleeping
Sleeping
| """ | |
| Inference module for MNIST digit classification. | |
| Provides a clean API for making predictions with the trained model. | |
| Handles image preprocessing and returns predictions with confidence scores. | |
| """ | |
| import torch | |
| from PIL import Image | |
| import numpy as np | |
| from pathlib import Path | |
| from typing import Union, Dict | |
class DigitClassifier:
    """Production inference wrapper for MNIST digit classifier.

    The checkpoint is loaded once in ``__init__``; ``predict`` and
    ``predict_batch`` then accept PIL images or numpy arrays and return plain
    Python dictionaries.
    """

    def __init__(self, model_path: Union[str, Path], device: Union[str, None] = None):
        """
        Initialize the digit classifier.

        Args:
            model_path: Path to model checkpoint (.pt file)
            device: Device to run inference on ('cuda' or 'cpu').
                If None, auto-detects CUDA availability.

        Raises:
            FileNotFoundError: If nothing exists at ``model_path``.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device

        self.model_path = Path(model_path)
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model not found at {model_path}")

        self.model = self._load_model()
        # Inference only: switch the module to evaluation mode.
        self.model.eval()

        # Normalization values (same as training)
        self.mean = 0.1307
        self.std = 0.3081

    def _load_model(self) -> torch.nn.Module:
        """Build a ``BaselineCNN``, load the checkpoint weights and move it to the device.

        Returns:
            The model with weights loaded, placed on ``self.device``.
        """
        # Imported here rather than at module level, so the ``scripts`` package
        # only has to be importable by the time a classifier is constructed.
        from scripts.models import BaselineCNN

        model = BaselineCNN()

        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from a trusted source (or pass
        # weights_only=True on PyTorch versions that support it).
        checkpoint = torch.load(self.model_path, map_location=self.device)

        # Two checkpoint layouts are accepted: a dict wrapping the weights
        # under 'model_state_dict', or the state dict itself.
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint
        model.load_state_dict(state_dict)

        return model.to(self.device)

    @staticmethod
    def _to_uint8(array: np.ndarray) -> np.ndarray:
        """Coerce an image array into a uint8 layout that ``Image.fromarray`` accepts.

        - ``(H, W, 1)`` arrays lose their trailing singleton channel.
        - Boolean arrays map to 0 / 255.
        - Integer arrays are clipped to ``[0, 255]``.
        - Float arrays whose maximum is <= 1.0 are assumed to be in ``[0, 1]``
          and are rescaled to ``[0, 255]``; other float arrays are clipped.
        - Any other dtype is returned unchanged, so ``Image.fromarray``
          reports the unsupported type itself.
        """
        if array.ndim == 3 and array.shape[2] == 1:
            array = array[:, :, 0]

        dtype = array.dtype
        if dtype == np.uint8:
            return array
        if dtype == np.bool_:
            return np.where(array, 255, 0).astype(np.uint8)
        if np.issubdtype(dtype, np.floating):
            if array.size and array.max() <= 1.0:
                array = array * 255.0
            return np.clip(array, 0, 255).astype(np.uint8)
        if np.issubdtype(dtype, np.integer):
            return np.clip(array, 0, 255).astype(np.uint8)
        return array

    def preprocess(self, image: "Union[Image.Image, np.ndarray]") -> torch.Tensor:
        """
        Preprocess image for model input.

        Handles:
        - numpy dtype/shape coercion (see ``_to_uint8``)
        - RGB to grayscale conversion
        - Resizing to 28x28
        - Normalization
        - Inversion if needed (white digit on black background)

        Args:
            image: PIL Image or numpy array

        Returns:
            Preprocessed tensor of shape (1, 1, 28, 28) on ``self.device``
        """
        # Convert numpy array to PIL Image if needed
        if isinstance(image, np.ndarray):
            image = Image.fromarray(self._to_uint8(image))

        # Convert to grayscale if RGB
        if image.mode != 'L':
            image = image.convert('L')

        # Resize to 28x28 if needed
        if image.size != (28, 28):
            image = image.resize((28, 28), Image.Resampling.LANCZOS)

        # Scale pixel values to [0, 1]
        pixels = np.array(image).astype(np.float32) / 255.0

        # MNIST is a white digit on a black background. If most pixels are
        # bright, it is likely a black digit on a white background: invert.
        if pixels.mean() > 0.5:
            pixels = 1.0 - pixels

        # Apply normalization (same as training)
        pixels = (pixels - self.mean) / self.std

        # Add batch and channel dimensions: (28, 28) -> (1, 1, 28, 28)
        tensor = torch.from_numpy(pixels).unsqueeze(0).unsqueeze(0)
        return tensor.to(self.device)

    def _predict_tensor(self, batch: torch.Tensor) -> list:
        """Run the model on a preprocessed batch and build one result dict per row."""
        with torch.no_grad():
            probabilities = torch.softmax(self.model(batch), dim=1)
        confidences, digits = torch.max(probabilities, dim=1)

        return [
            {
                'digit': int(digit),
                'confidence': float(confidence),
                'probabilities': row,
            }
            for digit, confidence, row in zip(
                digits.tolist(), confidences.tolist(), probabilities.cpu().tolist()
            )
        ]

    def predict(self, image: "Union[Image.Image, np.ndarray]") -> Dict:
        """
        Predict digit from image.

        Args:
            image: PIL Image or numpy array containing digit

        Returns:
            Dictionary with:
            - digit: Predicted digit (0-9)
            - confidence: Confidence score (0-1)
            - probabilities: List of probabilities for each digit
        """
        return self._predict_tensor(self.preprocess(image))[0]

    def predict_batch(self, images: list) -> list:
        """
        Predict digits for a batch of images.

        All images are preprocessed, stacked and sent through the model in a
        single forward pass.

        Args:
            images: List of PIL Images or numpy arrays

        Returns:
            List of prediction dictionaries (same layout as ``predict``), in
            input order. An empty input yields an empty list.
        """
        # Materialize first so any iterable works and emptiness is well defined.
        images = list(images)
        if not images:
            return []

        batch = torch.cat([self.preprocess(img) for img in images], dim=0)
        return self._predict_tensor(batch)
def test_inference():
    """Smoke-test the inference module on MNIST test samples and synthetic inputs.

    Prints per-image predictions for 10 randomly chosen test images, the
    resulting accuracy, and predictions for a blank image, an all-white image
    and an upscaled (56x56) test image. Returns early with a message when
    the model checkpoint is missing.
    """
    import sys
    from pathlib import Path

    # Add project root to path so that `scripts.*` can be imported
    root_dir = Path(__file__).parent.parent
    sys.path.insert(0, str(root_dir))
    from scripts.data_loader import MnistDataloader

    def as_pil(sample):
        """Turn a raw sample (list or array) into a grayscale PIL image."""
        if isinstance(sample, list):
            sample = np.array(sample)
        return Image.fromarray(sample.astype(np.uint8), mode='L')

    divider = "-" * 50

    print("Testing Inference Module")
    print("=" * 50)

    # Bail out early when there is no trained checkpoint to test
    checkpoint_file = root_dir / 'models' / 'best_model.pt'
    if not checkpoint_file.exists():
        print(f"Error: Model not found at {checkpoint_file}")
        print("Please train a model first.")
        return

    # Load MNIST test data (the training split is discarded)
    raw_dir = root_dir / 'data' / 'raw'
    loader = MnistDataloader(
        training_images_filepath=str(raw_dir / 'train-images.idx3-ubyte'),
        training_labels_filepath=str(raw_dir / 'train-labels.idx1-ubyte'),
        test_images_filepath=str(raw_dir / 't10k-images.idx3-ubyte'),
        test_labels_filepath=str(raw_dir / 't10k-labels.idx1-ubyte')
    )
    _, (x_test, y_test) = loader.load_data()

    # Initialize classifier
    print(f"\n1. Loading model from: {checkpoint_file}")
    classifier = DigitClassifier(str(checkpoint_file))
    print(f" Device: {classifier.device}")

    # Random test images
    print("\n2. Testing predictions on 10 random test images:")
    print(divider)
    sample_ids = np.random.choice(len(x_test), 10, replace=False)
    hits = 0
    for position, sample_id in enumerate(sample_ids, 1):
        expected = y_test[sample_id]
        outcome = classifier.predict(as_pil(x_test[sample_id]))
        matched = outcome['digit'] == expected
        hits += matched
        marker = '✓' if matched else '✗'
        print(f" Image {position}: True={expected}, Pred={outcome['digit']}, "
              f"Conf={outcome['confidence']:.4f} {marker}")
    accuracy = hits / len(sample_ids) * 100
    print(f"\nAccuracy on {len(sample_ids)} samples: {accuracy:.1f}%")

    # Edge cases
    print("\n3. Testing edge cases:")
    print(divider)

    # Uniform images: all black, then all white
    uniform_cases = (
        ("Blank image", 0),
        ("White image", 255),
    )
    for label, fill_value in uniform_cases:
        pixels = np.full((28, 28), fill_value, dtype=np.uint8)
        outcome = classifier.predict(Image.fromarray(pixels, mode='L'))
        print(f" {label}: Pred={outcome['digit']}, Conf={outcome['confidence']:.4f}")

    # Different size image: first test sample upscaled to 56x56
    upscaled = as_pil(x_test[0]).resize((56, 56))
    outcome = classifier.predict(upscaled)
    print(
        f" Resized image (56x56): "
        f"Pred={outcome['digit']}, Conf={outcome['confidence']:.4f}"
    )

    print("\n✓ Inference module test complete!")
# Run the smoke test only when this file is executed directly as a script.
if __name__ == "__main__":
    test_inference()