faizan
fix: resolve all 468 ruff linting errors (code quality enforcement complete)
e77a25a
"""
Inference module for MNIST digit classification.
Provides a clean API for making predictions with the trained model.
Handles image preprocessing and returns predictions with confidence scores.
"""
import torch
from PIL import Image
import numpy as np
from pathlib import Path
from typing import Union, Dict
class DigitClassifier:
    """Production inference wrapper for the MNIST digit classifier.

    Loads a ``BaselineCNN`` checkpoint once, then exposes ``predict`` and
    ``predict_batch`` which take PIL images or numpy arrays, preprocess them
    to the model's input format and return the predicted digit together with
    its softmax confidence.
    """

    def __init__(self, model_path: str, device: Union[str, None] = None):
        """
        Initialize the digit classifier.

        Args:
            model_path: Path to model checkpoint (.pt file)
            device: Device to run inference on ('cuda' or 'cpu').
                If None, auto-detects CUDA availability.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device

        self.model_path = Path(model_path)
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model not found at {model_path}")

        self.model = self._load_model()
        # Inference only: put the network in evaluation mode once, up front.
        self.model.eval()

        # Normalization values (same as training)
        self.mean = 0.1307
        self.std = 0.3081

    def _load_model(self) -> torch.nn.Module:
        """Build a ``BaselineCNN`` and load its weights from ``self.model_path``.

        Accepts either a bare state dict or a checkpoint dict that stores the
        weights under the ``'model_state_dict'`` key.

        Returns:
            The model with weights loaded, moved to ``self.device``.
        """
        # Imported lazily so the project package is only required when a
        # classifier is actually constructed.
        from scripts.models import BaselineCNN

        model = BaselineCNN()

        # SECURITY NOTE(review): torch.load is pickle-based; on PyTorch
        # versions where ``weights_only`` defaults to False, loading an
        # untrusted checkpoint can execute arbitrary code. Only load
        # checkpoints from trusted sources, or pass ``weights_only=True``
        # once all checkpoints are confirmed to contain plain tensors.
        checkpoint = torch.load(self.model_path, map_location=self.device)

        # Handle different checkpoint formats
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint
        model.load_state_dict(state_dict)

        return model.to(self.device)

    @staticmethod
    def _coerce_array(pixels: np.ndarray) -> np.ndarray:
        """Convert array dtypes that ``Image.fromarray`` mishandles to uint8.

        - Float arrays whose values all lie in [0, 1] are treated as
          normalized intensities and rescaled to 0-255. Without this, the
          later conversion to 8-bit grayscale reads them as if they were
          already on a 0-255 scale and the digit collapses to near-black.
        - 64-bit integer arrays (the default dtype of ``np.array`` on a list
          of ints on most platforms) are rejected by ``Image.fromarray``;
          they are clipped to 0-255 and cast to uint8.

        Every other array is returned unchanged.
        """
        if pixels.size == 0:
            return pixels
        if np.issubdtype(pixels.dtype, np.floating):
            # NaNs make both comparisons False, so such arrays pass through.
            if pixels.min() >= 0.0 and pixels.max() <= 1.0:
                return np.rint(pixels * 255.0).astype(np.uint8)
            return pixels
        if pixels.dtype in (np.int64, np.uint64):
            return np.clip(pixels, 0, 255).astype(np.uint8)
        return pixels

    def preprocess(self, image: Union[Image.Image, np.ndarray]) -> torch.Tensor:
        """
        Preprocess image for model input.

        Handles:
            - numpy arrays (uint8, 64-bit integer, or floats in [0, 1])
            - RGB to grayscale conversion
            - Resizing to 28x28
            - Inversion if needed (white digit on black background)
            - Normalization

        Args:
            image: PIL Image or numpy array

        Returns:
            Preprocessed tensor of shape (1, 1, 28, 28) on ``self.device``
        """
        # Convert numpy array to PIL Image if needed
        if isinstance(image, np.ndarray):
            image = Image.fromarray(self._coerce_array(image))

        # Convert to grayscale if RGB
        if image.mode != 'L':
            image = image.convert('L')

        # Resize to 28x28 if needed
        if image.size != (28, 28):
            image = image.resize((28, 28), Image.Resampling.LANCZOS)

        # Convert to numpy array and normalize to [0, 1]
        img_array = np.array(image).astype(np.float32) / 255.0

        # Check if inversion is needed (MNIST is white digit on black background)
        # If most pixels are bright, it's likely a black digit on white background
        if img_array.mean() > 0.5:
            img_array = 1.0 - img_array

        # Apply normalization (same as training)
        img_array = (img_array - self.mean) / self.std

        # Convert to tensor and add batch and channel dimensions
        img_tensor = torch.tensor(img_array).unsqueeze(0).unsqueeze(0)
        return img_tensor.to(self.device)

    def _classify(self, batch: torch.Tensor) -> list:
        """Run the model on a preprocessed batch and format each prediction.

        Args:
            batch: Tensor of preprocessed images, one per row of the batch
                dimension (as produced by ``preprocess``).

        Returns:
            One dictionary per image with keys ``digit``, ``confidence`` and
            ``probabilities``.
        """
        with torch.no_grad():
            logits = self.model(batch)
            probabilities = torch.softmax(logits, dim=1)
            confidences, digits = torch.max(probabilities, dim=1)

        rows = probabilities.cpu().tolist()
        return [
            {
                'digit': int(digit),
                'confidence': float(confidence),
                'probabilities': row
            }
            for digit, confidence, row in zip(
                digits.tolist(), confidences.tolist(), rows
            )
        ]

    def predict(self, image: Union[Image.Image, np.ndarray]) -> Dict:
        """
        Predict digit from image.

        Args:
            image: PIL Image or numpy array containing digit

        Returns:
            Dictionary with:
                - digit: Predicted digit (0-9)
                - confidence: Confidence score (0-1)
                - probabilities: List of probabilities for each digit
        """
        return self._classify(self.preprocess(image))[0]

    def predict_batch(self, images: list, batch_size: int = 64) -> list:
        """
        Predict digits for a batch of images.

        Images are preprocessed individually, then stacked and sent through
        the model in chunks of ``batch_size`` so that a single forward pass
        serves many images while memory use stays bounded.

        Args:
            images: List of PIL Images or numpy arrays
            batch_size: Maximum number of images per forward pass (>= 1)

        Returns:
            List of prediction dictionaries, in the same order as ``images``
            (empty if ``images`` is empty)

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        # Materialize so any iterable can be sliced into chunks.
        images = list(images)
        results = []
        for start in range(0, len(images), batch_size):
            chunk = images[start:start + batch_size]
            batch = torch.cat([self.preprocess(img) for img in chunk], dim=0)
            results.extend(self._classify(batch))
        return results
def test_inference():
    """Smoke-test the inference module end to end.

    Loads the trained model and the MNIST test split from the project tree,
    prints predictions for 10 randomly chosen test images with the resulting
    accuracy, then prints predictions for three edge cases: an all-black
    image, an all-white image and a test digit upscaled to 56x56.
    Returns early with a message when no trained model is found.
    """
    import sys
    from pathlib import Path

    def _to_pil(pixels):
        """Turn a list or array of pixel rows into an 8-bit grayscale image."""
        if isinstance(pixels, list):
            pixels = np.array(pixels)
        return Image.fromarray(pixels.astype(np.uint8), mode='L')

    def _rule():
        """Print the separator line used under each section heading."""
        print("-" * 50)

    # Make the project root importable before pulling in project modules.
    project_root = Path(__file__).parent.parent
    sys.path.insert(0, str(project_root))
    from scripts.data_loader import MnistDataloader

    print("Testing Inference Module")
    print("=" * 50)

    # Nothing to test without a trained checkpoint.
    model_path = project_root / 'models' / 'best_model.pt'
    if not model_path.exists():
        print(f"Error: Model not found at {model_path}")
        print("Please train a model first.")
        return

    # Load the MNIST files; only the test split is used below.
    raw_dir = project_root / 'data' / 'raw'
    idx_files = {
        'training_images_filepath': 'train-images.idx3-ubyte',
        'training_labels_filepath': 'train-labels.idx1-ubyte',
        'test_images_filepath': 't10k-images.idx3-ubyte',
        'test_labels_filepath': 't10k-labels.idx1-ubyte',
    }
    loader = MnistDataloader(
        **{argument: str(raw_dir / filename)
           for argument, filename in idx_files.items()}
    )
    _train_split, (x_test, y_test) = loader.load_data()

    # --- 1. Model loading -------------------------------------------------
    print(f"\n1. Loading model from: {model_path}")
    classifier = DigitClassifier(str(model_path))
    print(f" Device: {classifier.device}")

    # --- 2. Random sample of test images ----------------------------------
    print("\n2. Testing predictions on 10 random test images:")
    _rule()
    indices = np.random.choice(len(x_test), 10, replace=False)
    correct = 0
    for i, idx in enumerate(indices, 1):
        true_label = y_test[idx]
        result = classifier.predict(_to_pil(x_test[idx]))
        is_correct = result['digit'] == true_label
        if is_correct:
            correct += 1
        print(f" Image {i}: True={true_label}, Pred={result['digit']}, "
              f"Conf={result['confidence']:.4f} {'✓' if is_correct else '✗'}")

    accuracy = correct / len(indices) * 100
    print(f"\nAccuracy on {len(indices)} samples: {accuracy:.1f}%")

    # --- 3. Edge cases ----------------------------------------------------
    print("\n3. Testing edge cases:")
    _rule()

    # Every pixel black.
    all_black = Image.fromarray(np.zeros((28, 28), dtype=np.uint8), mode='L')
    result = classifier.predict(all_black)
    print(f" Blank image: Pred={result['digit']}, Conf={result['confidence']:.4f}")

    # Every pixel white.
    all_white = Image.fromarray(np.full((28, 28), 255, dtype=np.uint8), mode='L')
    result = classifier.predict(all_white)
    print(f" White image: Pred={result['digit']}, Conf={result['confidence']:.4f}")

    # First test digit upscaled, to exercise the resize path.
    upscaled = _to_pil(x_test[0]).resize((56, 56))
    result = classifier.predict(upscaled)
    print(
        f" Resized image (56x56): "
        f"Pred={result['digit']}, Conf={result['confidence']:.4f}"
    )

    print("\n✓ Inference module test complete!")
# Run the smoke test only when this file is executed as a script.
if __name__ == "__main__":
    test_inference()