File size: 10,416 Bytes
21f4ad5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
"""
Inference script for making predictions with trained MNIST models
Usage: python inference.py --model-path checkpoints/best_model.pth --image-path my_digit.png
"""

import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import argparse
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# Model architectures (must match training)
class ConvNet(nn.Module):
    """Convolutional classifier for single-channel 28x28 digit images.

    Layout: two convolutional stages, each made of two 3x3 conv + BatchNorm +
    ReLU layers followed by 2x2 max-pooling and Dropout2d, then a three-layer
    fully connected head (6272 -> 256 -> 128 -> num_classes).

    The attribute names below are the ``state_dict`` keys that ``load_model``
    restores; renaming or re-nesting them breaks existing checkpoints.

    Args:
        dropout_rate: Dropout probability after ``fc1``; the conv stages and
            ``fc2`` use half of this value.
        num_classes: Number of output logits.
    """

    def __init__(self, dropout_rate=0.3, num_classes=10):
        super().__init__()
        half_rate = dropout_rate * 0.5

        # Stage 1: 1 -> 32 -> 64 channels; pooling halves 28x28 to 14x14.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)

        # Stage 2: 64 -> 128 -> 128 channels; pooling halves 14x14 to 7x7.
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(128)

        # Reused by both stages; neither module holds parameters.
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout_conv = nn.Dropout2d(half_rate)

        # Classifier head on the flattened 128x7x7 feature map.
        self.fc1 = nn.Linear(128 * 7 * 7, 256)
        self.bn5 = nn.BatchNorm1d(256)
        self.dropout1 = nn.Dropout(dropout_rate)

        self.fc2 = nn.Linear(256, 128)
        self.bn6 = nn.BatchNorm1d(128)
        self.dropout2 = nn.Dropout(half_rate)

        self.fc3 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes)."""
        conv_stages = (
            ((self.conv1, self.bn1), (self.conv2, self.bn2)),
            ((self.conv3, self.bn3), (self.conv4, self.bn4)),
        )
        for stage in conv_stages:
            for conv, norm in stage:
                x = torch.relu(norm(conv(x)))
            x = self.dropout_conv(self.pool(x))

        # Flatten everything except the batch dimension.
        x = x.view(x.shape[0], -1)

        hidden_layers = (
            (self.fc1, self.bn5, self.dropout1),
            (self.fc2, self.bn6, self.dropout2),
        )
        for linear, norm, drop in hidden_layers:
            x = drop(torch.relu(norm(linear(x))))

        return self.fc3(x)

class ImprovedNN(nn.Module):
    """Fully connected classifier built from Linear-BatchNorm-ReLU-Dropout blocks.

    Args:
        input_size: Number of features after flattening each sample
            (784 = 28 * 28 by default).
        hidden_sizes: Widths of the hidden layers, in order. Any sequence of
            ints is accepted; the default is a tuple so it cannot be mutated and
            shared across instances.
        num_classes: Number of output logits.
        dropout_rate: Dropout probability after every hidden layer except the
            last one, which uses half of it.

    All layers live in ``self.network`` (an ``nn.Sequential``), so the
    ``state_dict`` keys are ``network.<index>.*``. Keep the layer order
    unchanged, or existing checkpoints will no longer load.
    """

    def __init__(self, input_size=784, hidden_sizes=(512, 256, 128),
                 num_classes=10, dropout_rate=0.3):
        super().__init__()

        layers = []
        prev_size = input_size
        last_index = len(hidden_sizes) - 1

        for i, hidden_size in enumerate(hidden_sizes):
            # Lighter dropout on the final hidden layer, just before the classifier.
            rate = dropout_rate if i < last_index else dropout_rate * 0.5
            layers.extend([
                nn.Linear(prev_size, hidden_size),
                nn.BatchNorm1d(hidden_size),
                nn.ReLU(),
                nn.Dropout(rate),
            ])
            prev_size = hidden_size

        layers.append(nn.Linear(prev_size, num_classes))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten each sample and return logits of shape (batch, num_classes)."""
        # reshape (unlike view) also works on non-contiguous inputs such as
        # transposed or permuted tensors; it only copies when it has to.
        x = x.reshape(x.size(0), -1)
        return self.network(x)

def load_model(model_path, model_type='cnn', device='cpu'):
    """Load a trained model from a checkpoint file and switch it to eval mode.

    Args:
        model_path: Path to a file written by ``torch.save``. It may be a dict
            with a ``'model_state_dict'`` entry (and optionally ``'args'``,
            ``'epoch'``, ``'val_acc'``), or a bare state dict.
        model_type: ``'cnn'`` selects ConvNet; any other value selects
            ImprovedNN. This is overridden by ``args['model_type']`` when the
            checkpoint records it.
        device: Device the weights are mapped to and the model is moved to.

    Returns:
        The model on ``device``, in eval mode.

    Raises:
        RuntimeError: If the weights do not match the selected architecture
            (raised by ``load_state_dict``).
    """
    # SECURITY NOTE(review): torch.load unpickles the file. Only load trusted
    # checkpoints. torch >= 2.6 defaults to weights_only=True, which rejects
    # non-tensor metadata such as an argparse.Namespace stored under 'args'.
    checkpoint = torch.load(model_path, map_location=device)

    # Accept either a training checkpoint dict or a bare state dict saved with
    # torch.save(model.state_dict(), ...).
    if 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
        metadata = checkpoint
    else:
        state_dict = checkpoint
        metadata = {}

    # Prefer the architecture recorded at training time. 'args' may have been
    # saved as a plain dict or as an argparse.Namespace.
    saved_args = metadata.get('args')
    if isinstance(saved_args, argparse.Namespace):
        saved_args = vars(saved_args)
    if isinstance(saved_args, dict) and 'model_type' in saved_args:
        model_type = saved_args['model_type']

    model = ConvNet() if model_type == 'cnn' else ImprovedNN()
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    # Format the accuracy only when present: applying ':.2f' to the
    # 'unknown' placeholder would raise ValueError.
    val_acc = metadata.get('val_acc')
    acc_text = 'unknown' if val_acc is None else f"{float(val_acc):.2f}%"

    print(f"✓ Loaded {model_type.upper()} model from {model_path}")
    print(f"  - Trained for {metadata.get('epoch', 'unknown')} epochs")
    print(f"  - Validation accuracy: {acc_text}")

    return model

def preprocess_image(image_path):
    """Load an image file and turn it into a normalized 1x28x28 model input.

    Processing steps:
    1. Convert to single-channel grayscale.
    2. Resize to 28x28 with Lanczos resampling. Non-square images are
       stretched, not padded.
    3. Scale to [0, 1] with ``ToTensor``.
    4. Standardize with mean 0.1307 / std 0.3081, the same constants as
       training.

    NOTE: No color inversion is performed. The model expects MNIST-style input
    (a light digit on a dark background). A dark digit on a white page must be
    inverted by the caller, or predictions will be unreliable.

    Args:
        image_path: Path (str or path-like) to any image Pillow can open.

    Returns:
        ``(img_tensor, img_array)``: the normalized tensor fed to the model, and
        the resized 28x28 grayscale image as a NumPy array for display.

    Raises:
        OSError: If Pillow cannot open or decode the file.
    """
    # Pillow >= 9.1 exposes filters under Image.Resampling; older releases
    # only provide the module-level constants.
    resampling = getattr(Image, 'Resampling', Image)

    # The context manager closes the file handle. convert() returns a new,
    # fully loaded image, so it stays valid after the block exits.
    with Image.open(image_path) as raw:
        img = raw.convert('L')

    img = img.resize((28, 28), resampling.LANCZOS)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    img_tensor = transform(img)

    # Un-normalized pixels, kept for visualization.
    img_array = np.array(img)

    return img_tensor, img_array

def predict(model, image_tensor, device):
    """Classify a single preprocessed image.

    Args:
        model: Classifier that returns one row of logits per input. It should
            already be in eval mode; ``load_model`` takes care of that.
        image_tensor: One image without a batch dimension, e.g. the tensor
            returned by ``preprocess_image``.
        device: Device the input is moved to; it should match the model's.

    Returns:
        ``(predicted_class, confidence, probabilities)``: the argmax class as an
        int, its softmax probability as a float, and all class probabilities
        as a NumPy array.
    """
    batch = image_tensor.unsqueeze(0).to(device)  # prepend a batch dim of size 1

    with torch.no_grad():
        logits = model(batch)
        probs = torch.softmax(logits, dim=1)
        top_prob, top_class = probs.max(dim=1)

    return top_class.item(), top_prob.item(), probs.squeeze().cpu().numpy()

def visualize_prediction(image, predicted_digit, confidence, probabilities, save_path=None):
    """Show the input image next to a bar chart of per-class probabilities.

    Args:
        image: Image array drawn in grayscale, e.g. the 28x28 array returned
            by ``preprocess_image``.
        predicted_digit: Index of the predicted class; its bar is drawn green.
        confidence: Probability of the predicted class, in [0, 1].
        probabilities: Per-class probabilities in [0, 1]. One bar is drawn per
            entry, so the chart adapts to the number of classes.
        save_path: If given, the figure is also saved there at 150 dpi.
    """
    # Accept lists too; with a plain list, `* 100` below would repeat it
    # instead of scaling it.
    probabilities = np.asarray(probabilities)
    classes = np.arange(len(probabilities))

    fig, (ax_img, ax_bar) = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel: the (preprocessed) input image.
    ax_img.imshow(image, cmap='gray')
    ax_img.set_title(f'Input Image\nPredicted: {predicted_digit} ({confidence*100:.1f}%)',
                     fontsize=14, fontweight='bold')
    ax_img.axis('off')

    # Right panel: probability per class, with the prediction highlighted.
    colors = ['green' if c == predicted_digit else 'gray' for c in classes]
    bars = ax_bar.bar(classes, probabilities * 100, color=colors, alpha=0.7)

    # Percentage label just above each bar.
    for bar, prob in zip(bars, probabilities):
        ax_bar.text(bar.get_x() + bar.get_width() / 2., bar.get_height(),
                    f'{prob*100:.1f}%',
                    ha='center', va='bottom', fontsize=9)

    ax_bar.set_xlabel('Digit', fontsize=12)
    ax_bar.set_ylabel('Confidence (%)', fontsize=12)
    ax_bar.set_title('Class Probabilities', fontsize=14, fontweight='bold')
    ax_bar.set_xticks(classes)
    ax_bar.set_ylim([0, 105])  # headroom for the labels on a 100% bar
    ax_bar.grid(True, alpha=0.3, axis='y')

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"✓ Visualization saved to {save_path}")

    plt.show()
    # Release the figure. With a non-interactive backend, show() returns
    # immediately and the figure would otherwise stay registered with pyplot.
    plt.close(fig)

def predict_batch(model, image_paths, device, *, skip_errors=True):
    """Predict the digit in each image and print a short report per image.

    Args:
        model: Classifier in eval mode, e.g. as returned by ``load_model``.
        image_paths: Iterable of image file paths.
        device: Device the input tensors are moved to.
        skip_errors: When True (the default), an image that cannot be opened or
            decoded is reported and skipped, so one bad file does not abort the
            batch. When False, the ``OSError`` propagates.

    Returns:
        A list of dicts with keys ``'image_path'``, ``'predicted'``,
        ``'confidence'`` and ``'probabilities'``. There is one entry per
        successfully processed image, in input order.
    """
    results = []

    for image_path in image_paths:
        print(f"\nProcessing: {image_path}")

        try:
            img_tensor, _ = preprocess_image(image_path)
        except OSError as err:  # missing file, or a format Pillow cannot decode
            if not skip_errors:
                raise
            print(f"  Skipped: could not read image ({err})")
            continue

        predicted, confidence, probabilities = predict(model, img_tensor, device)

        results.append({
            'image_path': image_path,
            'predicted': predicted,
            'confidence': confidence,
            'probabilities': probabilities,
        })

        print(f"  Prediction: {predicted} (Confidence: {confidence*100:.2f}%)")

        # Three most likely classes, highest probability first.
        top3_idx = np.argsort(probabilities)[-3:][::-1]
        top3_text = " ".join(f"{idx}({probabilities[idx]*100:.1f}%)" for idx in top3_idx)
        print(f"  Top 3: {top3_text}")

    return results

def _find_images(image_dir):
    """Return the .png/.jpg/.jpeg files (any letter case) in image_dir, sorted by path."""
    suffixes = {'.png', '.jpg', '.jpeg'}
    return sorted(
        path for path in Path(image_dir).iterdir()
        if path.is_file() and path.suffix.lower() in suffixes
    )


def _run_single(model, image_path, device, save_viz):
    """Classify one image, print every class probability and plot the result."""
    print(f"\nProcessing single image: {image_path}")

    img_tensor, img_array = preprocess_image(image_path)
    predicted, confidence, probabilities = predict(model, img_tensor, device)

    print(f"\n{'='*50}")
    print(f"Prediction: {predicted}")
    print(f"Confidence: {confidence*100:.2f}%")
    print(f"{'='*50}")

    print("\nAll class probabilities:")
    for digit, prob in enumerate(probabilities):
        print(f"  {digit}: {prob*100:.2f}%")

    # The plot is always written to disk; --save-viz only chooses the file name.
    save_path = save_viz or 'prediction_visualization.png'
    visualize_prediction(img_array, predicted, confidence, probabilities, save_path)


def _run_directory(model, image_dir, device):
    """Classify every image in image_dir and print a one-line-per-image summary."""
    print(f"\nProcessing directory: {image_dir}")

    image_paths = _find_images(image_dir)
    if not image_paths:
        print("No images found in directory!")
        return

    print(f"Found {len(image_paths)} images")

    results = predict_batch(model, [str(p) for p in image_paths], device)

    print(f"\n{'='*50}")
    print("Summary:")
    print(f"{'='*50}")
    for result in results:
        print(f"{Path(result['image_path']).name}: {result['predicted']} ({result['confidence']*100:.1f}%)")


def main():
    """Command-line entry point: classify one image or every image in a directory."""
    parser = argparse.ArgumentParser(description='MNIST Digit Recognition Inference')
    parser.add_argument('--model-path', type=str, required=True,
                        help='Path to trained model checkpoint')
    parser.add_argument('--image-path', type=str,
                        help='Path to input image (28x28 recommended, grayscale)')
    parser.add_argument('--image-dir', type=str,
                        help='Directory containing multiple images to predict')
    parser.add_argument('--model-type', type=str, default='cnn', choices=['cnn', 'fc'],
                        help='Model architecture type (auto-detected from checkpoint if available)')
    parser.add_argument('--save-viz', type=str,
                        help='Path to save visualization')
    parser.add_argument('--use-gpu', action='store_true',
                        help='Use GPU if available')

    args = parser.parse_args()

    # Validate inputs before loading the model, so usage errors fail fast
    # with a non-zero exit status.
    if not args.image_path and not args.image_dir:
        parser.error("Please provide either --image-path or --image-dir")
    if args.image_path and args.image_dir:
        print("Note: --image-path takes precedence; ignoring --image-dir")
    elif args.image_dir and not Path(args.image_dir).is_dir():
        parser.error(f"--image-dir is not a directory: {args.image_dir}")

    device = torch.device('cuda' if args.use_gpu and torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    model = load_model(args.model_path, args.model_type, device)

    if args.image_path:
        _run_single(model, args.image_path, device, args.save_viz)
    else:
        _run_directory(model, args.image_dir, device)

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()