File size: 5,539 Bytes
c72189e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import base64
import io
import time

import numpy as np
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from torchvision import models, transforms

app = FastAPI()

# Enable CORS so the React frontend (served from another origin) can call this
# backend. Any origin is accepted, which is fine for a public demo.
# Credentials are deliberately disabled: with allow_credentials=True and a
# wildcard origin, Starlette echoes back whatever Origin the browser sends
# together with "Access-Control-Allow-Credentials: true", letting any site make
# credentialed requests. The endpoints here read no cookies or auth headers,
# so nothing is lost by turning credentials off.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins for the demo
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --- 1. CONFIGURATION ---
# Use the GPU when PyTorch can see one; otherwise run on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Output labels, indexed by the classifier's output position.
CLASSES = ["Gray Leaf Spot", "Healthy"]

# Checkpoint file for each model variant (paths are relative to the working
# directory). These files have to be uploaded to the HF Space by hand through
# its "Files" tab.
MODEL_PATHS = dict(
    resnet_base="models/ResNet50_Aug_False.pth",
    resnet_aug="models/ResNet50_Aug_True.pth",
    effnet_base="models/EfficientNet_Aug_False.pth",
    effnet_aug="models/EfficientNet_Aug_True.pth",
)

# --- 2. LOAD MODELS ---
# Filled at startup: model key -> {"model": ..., "cam": ...}, or None when that
# model could not be loaded.
loaded_models = {}

def load_architecture(model_name, num_classes=2):
    """Rebuild an untrained network with the same layout used in training.

    A name containing "resnet" yields a ResNet-50; any other name yields an
    EfficientNet-B0. Either way, the final linear layer is replaced with one
    that has ``num_classes`` outputs.

    Returns:
        A ``(model, target_layer)`` pair, where ``target_layer`` is the module
        Grad-CAM should attach to.
    """
    if "resnet" not in model_name:
        net = models.efficientnet_b0(weights=None)
        old_head = net.classifier[1]
        net.classifier[1] = torch.nn.Linear(old_head.in_features, num_classes)
        # Grad-CAM hooks the last stage of the convolutional feature extractor.
        return net, net.features[-1]

    net = models.resnet50(weights=None)
    net.fc = torch.nn.Linear(net.fc.in_features, num_classes)
    # Grad-CAM hooks the last block of ResNet's final stage (layer4).
    return net, net.layer4[-1]

print("Loading models... this might take a minute...")
for key, path in MODEL_PATHS.items():
    try:
        # Create the architecture that the checkpoint's state_dict expects.
        model, layer = load_architecture(key)
        # weights_only=True limits unpickling to tensors and plain containers,
        # so a tampered .pth file cannot run arbitrary code during loading.
        # The checkpoint is only used as a state_dict, so nothing else is needed.
        # (To test without weights, remove this line AND the load_state_dict
        # call below together; the models then keep their random init.)
        state_dict = torch.load(path, map_location=DEVICE, weights_only=True)
        model.load_state_dict(state_dict)

        model.to(DEVICE)
        model.eval()

        # One Grad-CAM wrapper per model, created once and reused per request.
        cam = GradCAM(model=model, target_layers=[layer])

        loaded_models[key] = {"model": model, "cam": cam}
        print(f"Loaded {key}")
    except Exception as e:
        # Best effort: a missing or incompatible checkpoint must not stop the
        # server. The None entry makes the /analyze endpoint return an "Error"
        # placeholder for this model instead of a prediction.
        print(f"Error loading {key}: {e}")
        loaded_models[key] = None

# --- 3. UTILITIES ---
# Preprocessing applied to each upload before inference:
#   1. resize to a fixed 256x256,
#   2. convert to a float tensor scaled to [0, 1],
#   3. normalize each channel with the standard ImageNet mean/std.
# Presumably this mirrors the training-time pipeline. TODO: confirm the
# checkpoints were trained at 256x256.
transform = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

def run_inference_and_gradcam(key, image_tensor, original_image_np):
    """Classify one image with one model and render its Grad-CAM overlay.

    Args:
        key: Key into ``loaded_models`` (e.g. "resnet_base").
        image_tensor: Preprocessed input batch already on ``DEVICE``; the
            endpoint builds it with ``transform(image).unsqueeze(0)``.
        original_image_np: float32 RGB array scaled to [0, 1], used as the
            background of the heatmap overlay.

    Returns:
        A dict with "label", "confidence" (float), "time" (e.g. "12.34ms") and
        "heatmap" (a base64 JPEG data URI), or None if the model failed to load.
    """
    item = loaded_models[key]
    if item is None:
        # The model could not be loaded at startup.
        return None

    model = item["model"]
    cam = item["cam"]

    # 1. Prediction (timed)
    start_time = time.perf_counter()
    with torch.no_grad():
        outputs = model(image_tensor)
        probs = torch.nn.functional.softmax(outputs, dim=1)
        conf, pred_idx = torch.max(probs, 1)
        # .item() copies the result to the host, which waits for any queued
        # CUDA work to finish. Reading it before stopping the clock makes the
        # timing cover the forward pass itself, not just the kernel launches.
        confidence = float(conf.item())
        pred = pred_idx.item()
    inference_time = (time.perf_counter() - start_time) * 1000  # ms

    # 2. Grad-CAM
    # Grad-CAM needs gradients, so it runs outside no_grad; cam() performs its
    # own forward/backward pass. With targets=None the library picks the
    # target class itself (per pytorch-grad-cam docs: the top-scoring class).
    # [0, :] selects the map for the first image in the batch.
    grayscale_cam = cam(input_tensor=image_tensor, targets=None)[0, :]
    visualization = show_cam_on_image(original_image_np, grayscale_cam, use_rgb=True)

    # Encode the overlay as a base64 JPEG data URI for the frontend.
    buff = io.BytesIO()
    Image.fromarray(visualization).save(buff, format="JPEG")
    img_str = base64.b64encode(buff.getvalue()).decode("utf-8")

    return {
        "label": CLASSES[pred],
        "confidence": confidence,
        "time": f"{inference_time:.2f}ms",
        "heatmap": f"data:image/jpeg;base64,{img_str}",
    }

# --- 4. API ENDPOINT ---
@app.post("/analyze")
async def analyze_leaf(file: UploadFile = File(...)):
    """Run every configured model on an uploaded leaf image.

    Returns a list with one entry per model (ids 1-4, in fixed order). Each
    entry holds the label, confidence, inference time and Grad-CAM heatmap
    produced by ``run_inference_and_gradcam``. A model that failed to load
    at startup gets a placeholder entry labelled "Error".

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    contents = await file.read()
    try:
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        # UnidentifiedImageError (not an image) and truncated-file errors are
        # OSError subclasses. Report them as a client error instead of a 500.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from exc

    # Prepare inputs: a preprocessed tensor with a batch dimension of 1.
    tensor = transform(image).unsqueeze(0).to(DEVICE)

    # Background for the Grad-CAM overlay: the same image resized to 256x256
    # to match the tensor, as float32 RGB scaled to [0, 1].
    img_resized = image.resize((256, 256))
    img_np = np.array(img_resized, dtype=np.float32) / 255.0

    # Map frontend card ids to backend model keys ("name" is not sent back).
    definitions = [
        {"id": 1, "key": "resnet_base", "name": "ResNet50 Base"},
        {"id": 2, "key": "resnet_aug", "name": "ResNet50 Aug"},
        {"id": 3, "key": "effnet_base", "name": "EffNet Base"},
        {"id": 4, "key": "effnet_aug", "name": "EffNet Aug"},
    ]

    # NOTE(review): this handler is async, but the inference below is
    # synchronous, so it blocks the event loop while it runs. Within one
    # process, that also serializes requests, which keeps the shared GradCAM
    # objects from being used concurrently. Keep this in mind before moving
    # the work to a thread pool.
    results = []
    for definition in definitions:
        data = run_inference_and_gradcam(definition["key"], tensor, img_np)
        if data:
            results.append({"id": definition["id"], **data})
        else:
            # Fallback if the model failed to load.
            results.append({
                "id": definition["id"],
                "label": "Error",
                "confidence": 0.0,
                "time": "0ms",
                "heatmap": "",
            })

    return results

@app.get("/")
def home():
    """Health-check route: reports that the service is up."""
    status = {"message": "Maize Ablation Backend is Running"}
    return status