"""
Inference script for making predictions with trained MNIST models.

Usage:
    python inference.py --model-path checkpoints/best_model.pth --image-path my_digit.png
    python inference.py --model-path checkpoints/best_model.pth --image-dir my_digits/
"""

import argparse
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms

# Normalization statistics applied to every input (same values as training).
MNIST_MEAN = (0.1307,)
MNIST_STD = (0.3081,)

# File suffixes picked up by --image-dir; compared case-insensitively.
IMAGE_SUFFIXES = frozenset({'.png', '.jpg', '.jpeg'})


# Model architectures (must match training).
# Attribute names and layer order define the state_dict keys that
# load_model() restores, so they must not be renamed or reordered.

class ConvNet(nn.Module):
    """Convolutional Neural Network for MNIST.

    Two conv stages (each: two 3x3 conv + BatchNorm + ReLU layers, then 2x2
    max-pool and spatial dropout) followed by a three-layer classifier head.
    fc1 expects 128 * 7 * 7 features, i.e. single-channel 28x28 inputs
    (28 -> 14 -> 7 after the two pooling steps).
    """

    def __init__(self, dropout_rate=0.3, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(128)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout_conv = nn.Dropout2d(dropout_rate * 0.5)

        self.fc1 = nn.Linear(128 * 7 * 7, 256)
        self.bn5 = nn.BatchNorm1d(256)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(256, 128)
        self.bn6 = nn.BatchNorm1d(128)
        self.dropout2 = nn.Dropout(dropout_rate * 0.5)
        self.fc3 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return unnormalized class logits of shape (batch, num_classes)."""
        # Stage 1: 28x28 -> 14x14
        x = torch.relu(self.bn1(self.conv1(x)))
        x = torch.relu(self.bn2(self.conv2(x)))
        x = self.dropout_conv(self.pool(x))

        # Stage 2: 14x14 -> 7x7
        x = torch.relu(self.bn3(self.conv3(x)))
        x = torch.relu(self.bn4(self.conv4(x)))
        x = self.dropout_conv(self.pool(x))

        # Classifier head
        x = torch.flatten(x, 1)
        x = self.dropout1(torch.relu(self.bn5(self.fc1(x))))
        x = self.dropout2(torch.relu(self.bn6(self.fc2(x))))
        return self.fc3(x)


class ImprovedNN(nn.Module):
    """Enhanced fully connected network.

    Each hidden layer is Linear -> BatchNorm1d -> ReLU -> Dropout; the last
    hidden layer uses half the dropout rate. Inputs are flattened per sample,
    so any tensor with ``input_size`` elements per sample is accepted.
    """

    def __init__(self, input_size=784, hidden_sizes=(512, 256, 128), num_classes=10, dropout_rate=0.3):
        super().__init__()
        layers = []
        prev_size = input_size
        last_hidden = len(hidden_sizes) - 1
        for i, hidden_size in enumerate(hidden_sizes):
            layers.extend([
                nn.Linear(prev_size, hidden_size),
                nn.BatchNorm1d(hidden_size),
                nn.ReLU(),
                nn.Dropout(dropout_rate if i < last_hidden else dropout_rate * 0.5),
            ])
            prev_size = hidden_size
        layers.append(nn.Linear(prev_size, num_classes))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten each sample and return class logits."""
        return self.network(torch.flatten(x, 1))


def _saved_model_type(saved_args):
    """Return ``model_type`` from checkpoint args (dict or Namespace), else None."""
    if isinstance(saved_args, dict):
        return saved_args.get('model_type')
    return getattr(saved_args, 'model_type', None)


def _format_accuracy(value):
    """Format an accuracy percentage as 'NN.NN%', or 'unknown' if not numeric."""
    try:
        return f"{float(value):.2f}%"
    except (TypeError, ValueError):
        # Missing (None) or non-numeric metadata must not crash model loading.
        return 'unknown'


def load_model(model_path, model_type='cnn', device='cpu'):
    """Load a trained model from a checkpoint file.

    Args:
        model_path: File written by ``torch.save``. It is either a training
            checkpoint dict with a 'model_state_dict' entry (plus optional
            'args', 'epoch', 'val_acc') or a bare state_dict.
        model_type: 'cnn' selects ConvNet; any other value selects ImprovedNN.
            Overridden by the checkpoint's saved ``args.model_type`` when
            present (args may be stored as a dict or an argparse Namespace).
        device: Device to map the weights to and move the model onto.

    Returns:
        The model on ``device``, switched to eval mode.
    """
    # SECURITY: torch.load may unpickle arbitrary objects (depending on the
    # torch version's weights_only default); only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)

    saved_type = _saved_model_type(checkpoint.get('args'))
    if saved_type:
        model_type = saved_type

    model = ConvNet() if model_type == 'cnn' else ImprovedNN()

    # Accept both full training checkpoints and bare state_dicts.
    model.load_state_dict(checkpoint.get('model_state_dict', checkpoint))
    model.to(device)
    model.eval()

    print(f"✓ Loaded {model_type.upper()} model from {model_path}")
    print(f" - Trained for {checkpoint.get('epoch', 'unknown')} epochs")
    print(f" - Validation accuracy: {_format_accuracy(checkpoint.get('val_acc'))}")
    return model


def preprocess_image(image_path):
    """Load an image file and convert it into a normalized model input.

    The image is converted to grayscale, resized to 28x28 (LANCZOS) and
    normalized with MNIST_MEAN / MNIST_STD. No color inversion is applied.
    MNIST-style PNGs are already light digits on a dark background. Images of
    dark ink on a light background would need inverting first.

    Returns:
        (img_tensor, img_array): the normalized tensor fed to the model, and
        the resized grayscale image as a numpy array for display.
    """
    # Context manager releases the file handle; convert() returns a new,
    # fully loaded image, so it stays valid after the source is closed.
    with Image.open(image_path) as source:
        img = source.convert('L')
    img = img.resize((28, 28), Image.Resampling.LANCZOS)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(MNIST_MEAN, MNIST_STD),
    ])
    return transform(img), np.array(img)


def predict(model, image_tensor, device):
    """Classify a single preprocessed (unbatched) image tensor.

    Returns:
        (predicted_class, confidence, probabilities): the argmax class index,
        its softmax probability, and a 1-D numpy array of all class
        probabilities.
    """
    batch = image_tensor.unsqueeze(0).to(device)  # add batch dimension

    with torch.no_grad():
        probabilities = torch.softmax(model(batch), dim=1)
        confidence, predicted = torch.max(probabilities, dim=1)

    return predicted.item(), confidence.item(), probabilities.squeeze(0).cpu().numpy()


def visualize_prediction(image, predicted_digit, confidence, probabilities, save_path=None):
    """Show the input image next to a bar chart of class probabilities.

    Args:
        image: 2-D grayscale array to display.
        predicted_digit: Predicted class index; its bar is drawn in green.
        confidence: Probability of the predicted class (0-1).
        probabilities: Per-class probabilities (0-1); one bar per entry.
        save_path: When truthy, the figure is also saved to this path.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel: the (resized) input image
    ax1.imshow(image, cmap='gray')
    ax1.set_title(f'Input Image\nPredicted: {predicted_digit} ({confidence*100:.1f}%)',
                  fontsize=14, fontweight='bold')
    ax1.axis('off')

    # Right panel: probability distribution with value labels
    classes = np.arange(len(probabilities))
    percentages = np.asarray(probabilities) * 100
    colors = ['green' if c == predicted_digit else 'gray' for c in classes]
    bars = ax2.bar(classes, percentages, color=colors, alpha=0.7)
    for bar, pct in zip(bars, percentages):
        ax2.text(bar.get_x() + bar.get_width() / 2., bar.get_height(),
                 f'{pct:.1f}%', ha='center', va='bottom', fontsize=9)

    ax2.set_xlabel('Digit', fontsize=12)
    ax2.set_ylabel('Confidence (%)', fontsize=12)
    ax2.set_title('Class Probabilities', fontsize=14, fontweight='bold')
    ax2.set_xticks(classes)
    ax2.set_ylim([0, 105])
    ax2.grid(True, alpha=0.3, axis='y')

    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"✓ Visualization saved to {save_path}")
    plt.show()
    # Free the figure so repeated calls don't accumulate open figures.
    plt.close(fig)


def predict_batch(model, image_paths, device):
    """Classify several image files, printing a short report for each.

    Files that cannot be opened or decoded are reported and skipped, so one
    bad file does not abort the whole batch.

    Returns:
        A list of dicts with keys 'image_path', 'predicted', 'confidence' and
        'probabilities', one per successfully processed image, in input order.
    """
    results = []
    for image_path in image_paths:
        print(f"\nProcessing: {image_path}")

        try:
            img_tensor, _ = preprocess_image(image_path)
        except OSError as err:
            # Missing files and PIL's UnidentifiedImageError are both OSErrors.
            print(f" Skipped: could not read image ({err})")
            continue

        predicted, confidence, probabilities = predict(model, img_tensor, device)
        results.append({
            'image_path': image_path,
            'predicted': predicted,
            'confidence': confidence,
            'probabilities': probabilities,
        })

        print(f" Prediction: {predicted} (Confidence: {confidence*100:.2f}%)")
        top3_idx = np.argsort(probabilities)[::-1][:3]
        print(" Top 3: " + " ".join(f"{idx}({probabilities[idx]*100:.1f}%)" for idx in top3_idx))

    return results


def _build_parser():
    """Create the command-line argument parser."""
    parser = argparse.ArgumentParser(description='MNIST Digit Recognition Inference')
    parser.add_argument('--model-path', type=str, required=True,
                        help='Path to trained model checkpoint')
    parser.add_argument('--image-path', type=str,
                        help='Path to input image (28x28 recommended, grayscale)')
    parser.add_argument('--image-dir', type=str,
                        help='Directory containing multiple images to predict')
    parser.add_argument('--model-type', type=str, default='cnn', choices=['cnn', 'fc'],
                        help='Model architecture type (auto-detected from checkpoint if available)')
    parser.add_argument('--save-viz', type=str,
                        help='Path to save visualization')
    parser.add_argument('--use-gpu', action='store_true',
                        help='Use GPU if available')
    return parser


def _find_images(image_dir):
    """Return image files directly inside ``image_dir``, sorted by path.

    Suffixes are matched case-insensitively, so 'digit.PNG' is found on every
    platform. A missing or non-directory path yields an empty list.
    """
    if not image_dir.is_dir():
        return []
    return sorted(
        p for p in image_dir.iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
    )


def _run_single(model, image_path, device, save_viz):
    """Predict one image, print all probabilities and save/show the plot."""
    print(f"\nProcessing single image: {image_path}")
    img_tensor, img_array = preprocess_image(image_path)
    predicted, confidence, probabilities = predict(model, img_tensor, device)

    separator = '=' * 50
    print(f"\n{separator}")
    print(f"Prediction: {predicted}")
    print(f"Confidence: {confidence*100:.2f}%")
    print(separator)

    print("\nAll class probabilities:")
    for digit, prob in enumerate(probabilities):
        print(f" {digit}: {prob*100:.2f}%")

    save_path = save_viz or 'prediction_visualization.png'
    visualize_prediction(img_array, predicted, confidence, probabilities, save_path)


def _run_directory(model, image_dir, device):
    """Predict every image in ``image_dir`` and print a per-file summary."""
    print(f"\nProcessing directory: {image_dir}")
    image_paths = _find_images(Path(image_dir))
    if not image_paths:
        print("No images found in directory!")
        return

    print(f"Found {len(image_paths)} images")
    results = predict_batch(model, [str(p) for p in image_paths], device)

    separator = '=' * 50
    print(f"\n{separator}")
    print("Summary:")
    print(separator)
    for result in results:
        print(f"{Path(result['image_path']).name}: {result['predicted']} "
              f"({result['confidence']*100:.1f}%)")


def main():
    """Command-line entry point: classify one image or a directory of images."""
    args = _build_parser().parse_args()

    # Fail fast, before the potentially slow checkpoint load, when there is
    # nothing to classify.
    if not args.image_path and not args.image_dir:
        print("Please provide either --image-path or --image-dir")
        return

    device = torch.device('cuda' if torch.cuda.is_available() and args.use_gpu else 'cpu')
    print(f"Using device: {device}")

    model = load_model(args.model_path, args.model_type, device)

    # --image-path takes precedence when both options are given.
    if args.image_path:
        _run_single(model, args.image_path, device, args.save_viz)
    else:
        _run_directory(model, args.image_dir, device)


if __name__ == '__main__':
    main()