# corn-gls-backend/main.py
import torch
import base64
import io
import time
import numpy as np
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from torchvision import models, transforms
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
app = FastAPI()
# Enable CORS so your React app can talk to this backend
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # allow all origins for the demo
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
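# NOTE: for anything beyond a demo, the wildcard would normally be replaced
# with the frontend's actual origin (placeholder URL below):
#
#   allow_origins=["https://your-frontend.example.com"]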
# --- 1. CONFIGURATION ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CLASSES = ['Gray Leaf Spot', 'Healthy']
# Paths to the trained weights.
# Upload these files to your HF Space manually via the "Files" tab.
MODEL_PATHS = {
    "resnet_base": "models/ResNet50_Aug_False.pth",
    "resnet_aug": "models/ResNet50_Aug_True.pth",
    "effnet_base": "models/EfficientNet_Aug_False.pth",
    "effnet_aug": "models/EfficientNet_Aug_True.pth",
}
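# If the weights live in a separate model repo rather than in this Space's
# files, they could instead be fetched at startup with huggingface_hub
# (the repo id below is a placeholder):
#
#   from huggingface_hub import hf_hub_download
#   path = hf_hub_download(repo_id="<user>/corn-gls-weights",
#                          filename="ResNet50_Aug_False.pth")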
# --- 2. LOAD MODELS ---
loaded_models = {}
def load_architecture(model_name, num_classes=2):
    """Rebuilds the architecture to match the training setup."""
    if "resnet" in model_name:
        model = models.resnet50(weights=None)
        model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
        # Target layer for Grad-CAM in ResNet: last bottleneck block
        target_layer = model.layer4[-1]
    else:
        model = models.efficientnet_b0(weights=None)
        model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, num_classes)
        # Target layer for Grad-CAM in EfficientNet: last feature block
        target_layer = model.features[-1]
    return model, target_layer
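# Quick sanity check (no weights required): both architectures should map a
# 1x3x256x256 batch to two logits.
#
#   m, _ = load_architecture("resnet_base")
#   assert m(torch.randn(1, 3, 256, 256)).shape == (1, 2)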
print("Loading models... this might take a minute...")
for key, path in MODEL_PATHS.items():
    try:
        # Create architecture
        model, layer = load_architecture(key)
        # Load weights (make sure the .pth files are uploaded!)
        # If testing without weights, comment out the next two lines
        state_dict = torch.load(path, map_location=DEVICE)
        model.load_state_dict(state_dict)
        model.to(DEVICE)
        model.eval()
        # Initialize Grad-CAM once per model so it can be reused per request
        cam = GradCAM(model=model, target_layers=[layer])
        loaded_models[key] = {"model": model, "cam": cam}
        print(f"Loaded {key}")
    except Exception as e:
        print(f"Error loading {key}: {e}")
        # Placeholder so the API can still respond when weights are missing
        loaded_models[key] = None
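# On a clean start the log should show one "Loaded <key>" line per entry in
# MODEL_PATHS; any "Error loading <key>: ..." line means that model will be
# reported as "Error" by the /analyze endpoint below.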
# --- 3. UTILITIES ---
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
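# The transform yields a 3x256x256 float tensor normalized with ImageNet
# statistics; the endpoint adds the batch dimension with .unsqueeze(0), e.g.:
#
#   transform(Image.new("RGB", (640, 480))).shape  # torch.Size([3, 256, 256])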
def run_inference_and_gradcam(key, image_tensor, original_image_np):
    """Runs prediction and generates a heatmap for a single model."""
    item = loaded_models[key]
    if item is None:
        return None
    model = item["model"]
    cam = item["cam"]

    start_time = time.time()
    # 1. Prediction (no gradients needed for the forward pass)
    with torch.no_grad():
        outputs = model(image_tensor)
        probs = torch.nn.functional.softmax(outputs, dim=1)
        conf, pred_idx = torch.max(probs, 1)
    inference_time = (time.time() - start_time) * 1000  # ms

    # 2. Grad-CAM
    # Gradients are needed here; cam() handles the forward/backward pass internally
    grayscale_cam = cam(input_tensor=image_tensor, targets=None)[0, :]
    visualization = show_cam_on_image(original_image_np, grayscale_cam, use_rgb=True)

    # Convert the Grad-CAM numpy array to a Base64 string for the frontend
    pil_img = Image.fromarray(visualization)
    buff = io.BytesIO()
    pil_img.save(buff, format="JPEG")
    img_str = base64.b64encode(buff.getvalue()).decode("utf-8")

    return {
        "label": CLASSES[pred_idx.item()],
        "confidence": float(conf.item()),
        "time": f"{inference_time:.2f}ms",
        "heatmap": f"data:image/jpeg;base64,{img_str}",
    }
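# Example of calling the helper directly (outside the endpoint), with a
# hypothetical local test image:
#
#   img = Image.open("leaf.jpg").convert("RGB")
#   t = transform(img).unsqueeze(0).to(DEVICE)
#   arr = np.array(img.resize((256, 256)), dtype=np.float32) / 255.0
#   print(run_inference_and_gradcam("resnet_base", t, arr))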
# --- 4. API ENDPOINT ---
@app.post("/analyze")
async def analyze_leaf(file: UploadFile = File(...)):
    # Read image
    contents = await file.read()
    image = Image.open(io.BytesIO(contents)).convert("RGB")

    # Prepare inputs
    tensor = transform(image).unsqueeze(0).to(DEVICE)
    # For the Grad-CAM overlay we need a float numpy array scaled to [0, 1],
    # resized to 256x256 to match the tensor
    img_resized = image.resize((256, 256))
    img_np = np.array(img_resized, dtype=np.float32) / 255.0

    results = []
    # Process all 4 models, mapping frontend IDs to backend keys
    definitions = [
        {"id": 1, "key": "resnet_base", "name": "ResNet50 Base"},
        {"id": 2, "key": "resnet_aug", "name": "ResNet50 Aug"},
        {"id": 3, "key": "effnet_base", "name": "EffNet Base"},
        {"id": 4, "key": "effnet_aug", "name": "EffNet Aug"},
    ]
    for definition in definitions:
        data = run_inference_and_gradcam(definition["key"], tensor, img_np)
        if data:
            results.append({"id": definition["id"], **data})
        else:
            # Fallback if the model failed to load
            results.append({
                "id": definition["id"],
                "label": "Error",
                "confidence": 0.0,
                "time": "0ms",
                "heatmap": "",
            })
    return results
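# Example request against a local server (values in the response are
# illustrative):
#
#   curl -X POST -F "file=@leaf.jpg" http://localhost:7860/analyze
#
# returns a JSON list with one object per model:
#
#   [{"id": 1, "label": "Healthy", "confidence": 0.97,
#     "time": "41.20ms", "heatmap": "data:image/jpeg;base64,..."}, ...]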
@app.get("/")
def home():
    return {"message": "Maize Ablation Backend is Running"}
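# Conventional local entry point. HF Spaces usually launch uvicorn itself,
# but this allows `python main.py` for testing (7860 is the default Spaces
# port).
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)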