PulinduVR commited on
Commit
c72189e
·
verified ·
1 Parent(s): 8fb0645

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +168 -0
main.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import base64
3
+ import io
4
+ import time
5
+ import numpy as np
6
+ from fastapi import FastAPI, File, UploadFile
7
+ from fastapi.middleware.cors import CORSMiddleware
8
+ from torchvision import models, transforms
9
+ from PIL import Image
10
+ from pytorch_grad_cam import GradCAM
11
+ from pytorch_grad_cam.utils.image import show_cam_on_image
12
+
13
+ app = FastAPI()
14
+
15
+ # Enable CORS so your React app can talk to this backend
16
+ app.add_middleware(
17
+ CORSMiddleware,
18
+ allow_origins=["*"], # Allow all origins for the demo
19
+ allow_credentials=True,
20
+ allow_methods=["*"],
21
+ allow_headers=["*"],
22
+ )
23
+
24
+ # --- 1. CONFIGURATION ---
25
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
+ CLASSES = ['Gray Leaf Spot', 'Healthy']
27
+
28
+ # Define paths to your uploaded weights
29
+ # Upload these files to your HF Space manually via "Files" tab
30
+ MODEL_PATHS = {
31
+ "resnet_base": "models/ResNet50_Aug_False.pth",
32
+ "resnet_aug": "models/ResNet50_Aug_True.pth",
33
+ "effnet_base": "models/EfficientNet_Aug_False.pth",
34
+ "effnet_aug": "models/EfficientNet_Aug_True.pth"
35
+ }
36
+
37
+ # --- 2. LOAD MODELS ---
38
+ loaded_models = {}
39
+
40
+ def load_architecture(model_name, num_classes=2):
41
+ """Rebuilds the architecture to match your training"""
42
+ if "resnet" in model_name:
43
+ model = models.resnet50(weights=None)
44
+ model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
45
+ # Target layer for Grad-CAM in ResNet
46
+ target_layer = model.layer4[-1]
47
+ else:
48
+ model = models.efficientnet_b0(weights=None)
49
+ model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, num_classes)
50
+ # Target layer for Grad-CAM in EfficientNet
51
+ target_layer = model.features[-1]
52
+ return model, target_layer
53
+
54
+ print("Loading models... this might take a minute...")
55
+ for key, path in MODEL_PATHS.items():
56
+ try:
57
+ # Create architecture
58
+ model, layer = load_architecture(key)
59
+ # Load weights (Ensure you upload the files!)
60
+ # If testing without weights, comment out the next line
61
+ state_dict = torch.load(path, map_location=DEVICE)
62
+ model.load_state_dict(state_dict)
63
+
64
+ model.to(DEVICE)
65
+ model.eval()
66
+
67
+ # Initialize Grad-CAM for this model
68
+ cam = GradCAM(model=model, target_layers=[layer])
69
+
70
+ loaded_models[key] = {"model": model, "cam": cam}
71
+ print(f"Loaded {key}")
72
+ except Exception as e:
73
+ print(f"Error loading {key}: {e}")
74
+ # Placeholder for demo if weights are missing
75
+ loaded_models[key] = None
76
+
77
+ # --- 3. UTILITIES ---
78
+ transform = transforms.Compose([
79
+ transforms.Resize((256, 256)),
80
+ transforms.ToTensor(),
81
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
82
+ ])
83
+
84
+ def run_inference_and_gradcam(key, image_tensor, original_image_np):
85
+ """Runs prediction and generates heatmap for a single model"""
86
+ item = loaded_models[key]
87
+ if item is None:
88
+ return None
89
+
90
+ model = item["model"]
91
+ cam = item["cam"]
92
+
93
+ start_time = time.time()
94
+
95
+ # 1. Prediction
96
+ with torch.no_grad():
97
+ outputs = model(image_tensor)
98
+ probs = torch.nn.functional.softmax(outputs, dim=1)
99
+ conf, pred_idx = torch.max(probs, 1)
100
+
101
+ inference_time = (time.time() - start_time) * 1000 # ms
102
+
103
+ # 2. Grad-CAM
104
+ # We need gradients, so we run cam() which handles the forward/backward internally
105
+ grayscale_cam = cam(input_tensor=image_tensor, targets=None)[0, :]
106
+ visualization = show_cam_on_image(original_image_np, grayscale_cam, use_rgb=True)
107
+
108
+ # Convert Grad-CAM numpy to Base64 String for frontend
109
+ pil_img = Image.fromarray(visualization)
110
+ buff = io.BytesIO()
111
+ pil_img.save(buff, format="JPEG")
112
+ img_str = base64.b64encode(buff.getvalue()).decode("utf-8")
113
+
114
+ return {
115
+ "label": CLASSES[pred_idx.item()],
116
+ "confidence": float(conf.item()),
117
+ "time": f"{inference_time:.2f}ms",
118
+ "heatmap": f"data:image/jpeg;base64,{img_str}"
119
+ }
120
+
121
+ # --- 4. API ENDPOINT ---
122
+ @app.post("/analyze")
123
+ async def analyze_leaf(file: UploadFile = File(...)):
124
+ # Read Image
125
+ contents = await file.read()
126
+ image = Image.open(io.BytesIO(contents)).convert("RGB")
127
+
128
+ # Prepare Inputs
129
+ tensor = transform(image).unsqueeze(0).to(DEVICE)
130
+
131
+ # For Grad-CAM visualization, we need a normalized float numpy array (0-1)
132
+ # Resize original image to 256x256 to match tensor
133
+ img_resized = image.resize((256, 256))
134
+ img_np = np.array(img_resized, dtype=np.float32) / 255.0
135
+
136
+ results = []
137
+
138
+ # Process all 4 models
139
+ # Mapping frontend IDs to backend keys
140
+ definitions = [
141
+ {"id": 1, "key": "resnet_base", "name": "ResNet50 Base"},
142
+ {"id": 2, "key": "resnet_aug", "name": "ResNet50 Aug"},
143
+ {"id": 3, "key": "effnet_base", "name": "EffNet Base"},
144
+ {"id": 4, "key": "effnet_aug", "name": "EffNet Aug"},
145
+ ]
146
+
147
+ for definition in definitions:
148
+ data = run_inference_and_gradcam(definition["key"], tensor, img_np)
149
+ if data:
150
+ results.append({
151
+ "id": definition["id"],
152
+ **data
153
+ })
154
+ else:
155
+ # Fallback if model failed to load
156
+ results.append({
157
+ "id": definition["id"],
158
+ "label": "Error",
159
+ "confidence": 0.0,
160
+ "time": "0ms",
161
+ "heatmap": ""
162
+ })
163
+
164
+ return results
165
+
166
+ @app.get("/")
167
+ def home():
168
+ return {"message": "Maize Ablation Backend is Running"}