File size: 5,241 Bytes
233caeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
"""

Flask web application for CIFAR-10 image classification

"""
import os
import io
import base64
import torch
from PIL import Image
from flask import Flask, render_template, request, jsonify
import torchvision.transforms as transforms
import numpy as np

import config
from model import get_model
from utils import load_checkpoint

# Flask application instance; every @app.route handler below is registered on it.
app = Flask(__name__)

# Global model, populated by load_model(). It stays None until a checkpoint has
# been loaded; while it is None, /predict answers with a 500 "Model not loaded" error.
model = None


def load_model():
    """Load the trained model from ``config.BEST_MODEL_PATH`` into the global ``model``.

    The network is built and loaded into a local variable first and only
    published to the global once the checkpoint has been loaded and the
    network switched to eval mode. A failure part-way through therefore never
    leaves ``predict`` serving an untrained network.

    Returns:
        bool: True if the checkpoint was found and loaded, False if the
        checkpoint file does not exist.

    Raises:
        Whatever ``get_model`` / ``load_checkpoint`` raise (e.g. for a corrupt
        checkpoint); the global ``model`` is left unchanged in that case.
    """
    global model

    if not os.path.exists(config.BEST_MODEL_PATH):
        print(f"Warning: Model checkpoint not found at {config.BEST_MODEL_PATH}")
        return False

    net = get_model(num_classes=config.NUM_CLASSES, device=config.DEVICE)
    # NOTE(review): load_checkpoint appears to return (epoch, accuracy) with a
    # 0-based epoch and accuracy in percent (see the message below) - confirm in utils.
    epoch, accuracy = load_checkpoint(net, None, config.BEST_MODEL_PATH)
    net.eval()
    # Publish only after the weights are in place and inference mode is set.
    model = net

    print(f"Model loaded from epoch {epoch + 1} with accuracy: {accuracy:.2f}%")
    return True


# Preprocessing pipeline for model input, built once at import time. Resize,
# ToTensor and Normalize keep no per-call state, so there is no need to rebuild
# the pipeline for every request.
_PREPROCESS_TRANSFORM = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    # CIFAR-10 per-channel mean / std (RGB order).
    transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2470, 0.2435, 0.2616]
    )
])


def preprocess_image(image):
    """
    Preprocess image for model prediction.

    The image is converted to RGB if needed, resized to 32x32 and normalised
    with the CIFAR-10 channel statistics. A batch dimension is then added.

    Args:
        image: PIL Image in any mode; non-RGB images (e.g. grayscale, RGBA,
            palette) are converted to RGB first.

    Returns:
        torch.Tensor: Preprocessed image tensor of shape (1, 3, 32, 32).
    """
    # Normalize expects exactly 3 channels; without this conversion grayscale
    # or RGBA inputs would fail inside the transform.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    return _PREPROCESS_TRANSFORM(image).unsqueeze(0)


@app.route('/')
def index():
    """Serve the landing page, giving the template the list of CIFAR-10 class names."""
    template_context = {'class_names': config.CLASS_NAMES}
    return render_template('index.html', **template_context)


@app.route('/predict', methods=['POST'])
def predict():
    """Classify an uploaded image and return the prediction as JSON.

    Expects a multipart form upload with the image in the ``file`` field.

    Returns:
        On success, a JSON object with:
          - ``predicted_class``: the most probable class name;
          - ``confidence``: that class's probability, in percent;
          - ``top5_predictions``: up to five ``{'class', 'probability'}``
            dicts, highest first, probabilities in percent.

        On failure, ``{'error': message}`` with status:
          - 400 for a missing, unselected or undecodable upload;
          - 500 when the model is not loaded or inference fails.
    """
    if model is None:
        return jsonify({'error': 'Model not loaded'}), 500

    try:
        # Get image from request
        if 'file' not in request.files:
            return jsonify({'error': 'No file provided'}), 400

        file = request.files['file']

        if file.filename == '':
            return jsonify({'error': 'No file selected'}), 400

        # A corrupt or non-image upload is a client error, not a server error.
        # PIL signals these (unidentified format, truncated data) with OSError
        # subclasses, raised either by open() or by the lazy load in convert().
        try:
            image = Image.open(file.stream).convert('RGB')
        except OSError as e:
            return jsonify({'error': f'Invalid image file: {e}'}), 400

        input_tensor = preprocess_image(image).to(config.DEVICE)

        # Make prediction
        with torch.no_grad():
            output = model(input_tensor)
            probabilities = torch.nn.functional.softmax(output[0], dim=0)
            confidence, predicted = torch.max(probabilities, 0)

        # Top-k predictions; k is capped so a model with fewer than five
        # classes does not make torch.topk raise.
        k = min(5, probabilities.numel())
        top_prob, top_idx = torch.topk(probabilities, k)
        top5_predictions = [
            {
                'class': config.CLASS_NAMES[idx],
                'probability': float(prob * 100)
            }
            for idx, prob in zip(top_idx.cpu().numpy(), top_prob.cpu().numpy())
        ]

        # Prepare response
        response = {
            'predicted_class': config.CLASS_NAMES[predicted.item()],
            'confidence': float(confidence.item() * 100),
            'top5_predictions': top5_predictions
        }

        return jsonify(response)

    except Exception as e:
        # Top-level request boundary: record the traceback server-side instead
        # of losing it, then report the failure to the client.
        app.logger.exception('Prediction failed')
        return jsonify({'error': str(e)}), 500


# Cache of the CIFAR-10 test dataset used by /random_sample. Building the data
# loaders is expensive, so it happens once, on the first request that finds the
# dataset on disk, instead of on every request.
_test_dataset = None


def _get_test_dataset():
    """Return the test dataset from data_loader, building the loaders on first use."""
    global _test_dataset
    if _test_dataset is None:
        from data_loader import get_data_loaders
        _, test_loader = get_data_loaders()
        _test_dataset = test_loader.dataset
    return _test_dataset


@app.route('/random_sample', methods=['GET'])
def random_sample():
    """Return a random CIFAR-10 test image (or a dummy one) as a base64 PNG.

    If the extracted dataset directory exists under ``config.DATA_DIR``, a random
    test-set image is denormalised and returned with its true label. Otherwise
    a 32x32 random-noise image is returned with a placeholder label.

    Returns:
        JSON ``{'image': 'data:image/png;base64,...', 'true_label': str}``, or
        ``{'error': message}`` with status 500 if anything fails.
    """
    try:
        # Check if dataset exists
        dataset_path = os.path.join(config.DATA_DIR, 'cifar-10-batches-py')

        if os.path.exists(dataset_path):
            dataset = _get_test_dataset()
            idx = np.random.randint(0, len(dataset))
            image, label = dataset[idx]

            # Undo the test-time normalisation.
            # NOTE(review): assumes data_loader normalises with these same
            # CIFAR-10 mean/std values and yields a (3, H, W) tensor - confirm.
            mean = torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
            std = torch.tensor([0.2470, 0.2435, 0.2616]).view(3, 1, 1)
            image_denorm = torch.clamp(image * std + mean, 0, 1)

            # CHW floats in [0, 1] -> HWC uint8, the layout PIL expects.
            image_np = (image_denorm.numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
            label_name = config.CLASS_NAMES[label]
        else:
            # Generate dummy image for demonstration
            image_np = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
            label_name = "Dummy Sample (Dataset still downloading)"

        pil_image = Image.fromarray(image_np)

        # Convert to base64
        buffered = io.BytesIO()
        pil_image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()

        return jsonify({
            'image': f'data:image/png;base64,{img_str}',
            'true_label': label_name
        })

    except Exception as e:
        # Top-level request boundary: keep the traceback in the server log.
        app.logger.exception('Random sample generation failed')
        return jsonify({'error': str(e)}), 500


if __name__ == '__main__':
    # Load the model once, before starting the server.
    if load_model():
        print("Starting Flask application...")
        # Debug mode is opt-in (FLASK_DEBUG=1/true/yes), for two reasons:
        #   - The Werkzeug interactive debugger can execute arbitrary code, and
        #     binding to 0.0.0.0 would make it reachable from the network.
        #   - Its reloader re-executes this module in a child process, which
        #     would load the model a second time.
        debug = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes')
        app.run(debug=debug, host='0.0.0.0', port=5000)
    else:
        print("Failed to load model. Please train the model first using train.py")
        # Signal the failure to shells/supervisors with a non-zero exit status.
        raise SystemExit(1)