"""FastAPI backend serving maize-leaf disease predictions with Grad-CAM heatmaps.

Loads four classifiers (ResNet50 / EfficientNet-B0, each trained with and
without augmentation) at startup, then exposes a single POST /analyze
endpoint that returns, for every model, its predicted label, confidence,
latency, and a base64-encoded Grad-CAM overlay image.
"""

import base64
import io
import time

import numpy as np
import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from torchvision import models, transforms

app = FastAPI()

# Enable CORS so the React frontend can talk to this backend.
# NOTE(review): wildcard origins together with allow_credentials=True is
# acceptable for a demo only — restrict allow_origins before production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins for the demo
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --- 1. CONFIGURATION ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Index order must match the label encoding used at training time.
CLASSES = ['Gray Leaf Spot', 'Healthy']

# Paths to the checkpoint weights.
# Upload these files to your HF Space manually via the "Files" tab.
MODEL_PATHS = {
    "resnet_base": "models/ResNet50_Aug_False.pth",
    "resnet_aug": "models/ResNet50_Aug_True.pth",
    "effnet_base": "models/EfficientNet_Aug_False.pth",
    "effnet_aug": "models/EfficientNet_Aug_True.pth",
}

# --- 2. LOAD MODELS ---
# key -> {"model": nn.Module, "cam": GradCAM} or None when loading failed.
loaded_models = {}


def load_architecture(model_name, num_classes=2):
    """Rebuild the architecture to match training.

    Args:
        model_name: Key from MODEL_PATHS; a name containing "resnet"
            selects ResNet50, anything else selects EfficientNet-B0.
        num_classes: Output dimension of the replaced classifier head.

    Returns:
        Tuple ``(model, target_layer)`` where ``target_layer`` is the
        convolutional layer Grad-CAM should hook.
    """
    if "resnet" in model_name:
        model = models.resnet50(weights=None)
        model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
        # Target layer for Grad-CAM in ResNet: last bottleneck block.
        target_layer = model.layer4[-1]
    else:
        model = models.efficientnet_b0(weights=None)
        model.classifier[1] = torch.nn.Linear(
            model.classifier[1].in_features, num_classes
        )
        # Target layer for Grad-CAM in EfficientNet: last conv stage.
        target_layer = model.features[-1]
    return model, target_layer


print("Loading models... this might take a minute...")
for key, path in MODEL_PATHS.items():
    try:
        # Create architecture, then restore the trained weights.
        model, layer = load_architecture(key)
        # weights_only=True: checkpoints are plain state_dicts, and this
        # prevents arbitrary-code execution from an untrusted pickle file.
        # (If testing without weights, comment out the next line.)
        state_dict = torch.load(path, map_location=DEVICE, weights_only=True)
        model.load_state_dict(state_dict)
        model.to(DEVICE)
        model.eval()
        # One Grad-CAM instance per model; its hooks stay registered for
        # the lifetime of the process so it can be reused per request.
        cam = GradCAM(model=model, target_layers=[layer])
        loaded_models[key] = {"model": model, "cam": cam}
        print(f"Loaded {key}")
    except Exception as e:
        print(f"Error loading {key}: {e}")
        # Placeholder so the demo still starts if weights are missing;
        # the endpoint reports an "Error" entry for this model.
        loaded_models[key] = None

# --- 3. UTILITIES ---
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    # ImageNet normalization, matching the torchvision backbones.
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def run_inference_and_gradcam(key, image_tensor, original_image_np):
    """Run prediction and generate a Grad-CAM heatmap for a single model.

    Args:
        key: Model key in ``loaded_models``.
        image_tensor: Normalized ``(1, 3, 256, 256)`` float tensor on DEVICE.
        original_image_np: ``(256, 256, 3)`` float32 RGB array in [0, 1]
            used as the background for the Grad-CAM overlay.

    Returns:
        Dict with ``label`` / ``confidence`` / ``time`` / ``heatmap``
        (a JPEG data-URI), or None when this model failed to load.
    """
    item = loaded_models[key]
    if item is None:
        return None
    model = item["model"]
    cam = item["cam"]

    start_time = time.time()
    # 1. Prediction (no gradients needed for the classification pass).
    with torch.no_grad():
        outputs = model(image_tensor)
        probs = torch.nn.functional.softmax(outputs, dim=1)
        conf, pred_idx = torch.max(probs, 1)
    inference_time = (time.time() - start_time) * 1000  # ms

    # 2. Grad-CAM. cam() performs its own forward/backward pass internally;
    # targets=None explains the top-scoring class.
    grayscale_cam = cam(input_tensor=image_tensor, targets=None)[0, :]
    visualization = show_cam_on_image(original_image_np, grayscale_cam, use_rgb=True)

    # Encode the heatmap as a base64 JPEG data-URI for the frontend.
    pil_img = Image.fromarray(visualization)
    buff = io.BytesIO()
    pil_img.save(buff, format="JPEG")
    img_str = base64.b64encode(buff.getvalue()).decode("utf-8")

    return {
        "label": CLASSES[pred_idx.item()],
        "confidence": float(conf.item()),
        "time": f"{inference_time:.2f}ms",
        "heatmap": f"data:image/jpeg;base64,{img_str}",
    }


# --- 4. API ENDPOINT ---
@app.post("/analyze")
async def analyze_leaf(file: UploadFile = File(...)):
    """Classify an uploaded leaf image with all four models.

    Returns a list with one result dict per model (id, label, confidence,
    time, heatmap data-URI); a model that failed to load at startup yields
    a fallback "Error" entry so the frontend always gets four items.
    """
    # Read and decode the uploaded image.
    contents = await file.read()
    image = Image.open(io.BytesIO(contents)).convert("RGB")

    # Normalized tensor for the network, plus a 0-1 float array resized to
    # 256x256 so it lines up with the Grad-CAM map for the overlay.
    tensor = transform(image).unsqueeze(0).to(DEVICE)
    img_resized = image.resize((256, 256))
    img_np = np.array(img_resized, dtype=np.float32) / 255.0

    # Mapping frontend card IDs to backend model keys.
    definitions = [
        {"id": 1, "key": "resnet_base", "name": "ResNet50 Base"},
        {"id": 2, "key": "resnet_aug", "name": "ResNet50 Aug"},
        {"id": 3, "key": "effnet_base", "name": "EffNet Base"},
        {"id": 4, "key": "effnet_aug", "name": "EffNet Aug"},
    ]

    results = []
    for definition in definitions:
        data = run_inference_and_gradcam(definition["key"], tensor, img_np)
        if data:
            results.append({"id": definition["id"], **data})
        else:
            # Fallback entry if the model failed to load.
            results.append({
                "id": definition["id"],
                "label": "Error",
                "confidence": 0.0,
                "time": "0ms",
                "heatmap": "",
            })
    return results


@app.get("/")
def home():
    """Health-check endpoint."""
    return {"message": "Maize Ablation Backend is Running"}