# Maize leaf disease ablation backend (FastAPI) — runs as a Hugging Face Space.
| import torch | |
| import base64 | |
| import io | |
| import time | |
| import numpy as np | |
| from fastapi import FastAPI, File, UploadFile | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from torchvision import models, transforms | |
| from PIL import Image | |
| from pytorch_grad_cam import GradCAM | |
| from pytorch_grad_cam.utils.image import show_cam_on_image | |
app = FastAPI()

# CORS is wide open so the browser-based React frontend can reach this
# backend from any origin. Demo-only setting; lock down origins for prod.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins for the demo
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- 1. CONFIGURATION ---
# Run on GPU when available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Class index -> human-readable label (order must match training).
CLASSES = ['Gray Leaf Spot', 'Healthy']

# Weight files for the 2x2 ablation grid (architecture x augmentation).
# Upload these manually to the Space via the "Files" tab.
MODEL_PATHS = {
    "resnet_base": "models/ResNet50_Aug_False.pth",
    "resnet_aug": "models/ResNet50_Aug_True.pth",
    "effnet_base": "models/EfficientNet_Aug_False.pth",
    "effnet_aug": "models/EfficientNet_Aug_True.pth",
}
| # --- 2. LOAD MODELS --- | |
# Registry of ready-to-serve models, keyed by MODEL_PATHS key.
# A value of None marks a model whose weights failed to load at startup.
loaded_models = {}


def load_architecture(model_name, num_classes=2):
    """Rebuild the architecture used during training for *model_name*.

    Args:
        model_name: A MODEL_PATHS key; must contain "resnet" or "effnet".
        num_classes: Output size of the replaced classification head.

    Returns:
        (model, target_layer) where target_layer is the convolutional
        layer Grad-CAM should hook into for this architecture.

    Raises:
        ValueError: If model_name matches neither known family. (The
            original code silently built an EfficientNet for any
            non-resnet name, which would mask typos in MODEL_PATHS.)
    """
    if "resnet" in model_name:
        model = models.resnet50(weights=None)
        model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
        # Last bottleneck of layer4: the conventional Grad-CAM target
        # for ResNet50.
        target_layer = model.layer4[-1]
    elif "effnet" in model_name:
        model = models.efficientnet_b0(weights=None)
        model.classifier[1] = torch.nn.Linear(
            model.classifier[1].in_features, num_classes
        )
        # Final feature block: the conventional Grad-CAM target for
        # EfficientNet.
        target_layer = model.features[-1]
    else:
        raise ValueError(f"Unknown model family: {model_name!r}")
    return model, target_layer
| print("Loading models... this might take a minute...") | |
| for key, path in MODEL_PATHS.items(): | |
| try: | |
| # Create architecture | |
| model, layer = load_architecture(key) | |
| # Load weights (Ensure you upload the files!) | |
| # If testing without weights, comment out the next line | |
| state_dict = torch.load(path, map_location=DEVICE) | |
| model.load_state_dict(state_dict) | |
| model.to(DEVICE) | |
| model.eval() | |
| # Initialize Grad-CAM for this model | |
| cam = GradCAM(model=model, target_layers=[layer]) | |
| loaded_models[key] = {"model": model, "cam": cam} | |
| print(f"Loaded {key}") | |
| except Exception as e: | |
| print(f"Error loading {key}: {e}") | |
| # Placeholder for demo if weights are missing | |
| loaded_models[key] = None | |
| # --- 3. UTILITIES --- | |
# --- 3. UTILITIES ---
# Preprocessing must mirror training: 256x256 resize + ImageNet statistics.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])
def run_inference_and_gradcam(key, image_tensor, original_image_np):
    """Classify one image with one model and build its Grad-CAM overlay.

    Args:
        key: Key into loaded_models selecting the model to run.
        image_tensor: Normalized input batch on DEVICE — presumably
            shaped (1, 3, 256, 256) by the transform pipeline.
        original_image_np: float32 HxWx3 array scaled to [0, 1], used by
            show_cam_on_image to blend the heatmap over the photo.

    Returns:
        dict with label / confidence / time / heatmap (base64 data URI)
        fields, or None when this model failed to load at startup.
    """
    entry = loaded_models[key]
    if entry is None:
        return None
    model, cam = entry["model"], entry["cam"]

    # Forward pass only — no gradients needed for the prediction itself.
    t0 = time.time()
    with torch.no_grad():
        logits = model(image_tensor)
        probabilities = torch.nn.functional.softmax(logits, dim=1)
        confidence, predicted = torch.max(probabilities, 1)
    elapsed_ms = (time.time() - t0) * 1000  # ms

    # Grad-CAM needs gradients; cam() re-runs forward/backward internally.
    # targets=None means "explain the top-scoring class".
    heatmap = cam(input_tensor=image_tensor, targets=None)[0, :]
    overlay = show_cam_on_image(original_image_np, heatmap, use_rgb=True)

    # Encode the overlay as a base64 JPEG data URI for the frontend.
    buffer = io.BytesIO()
    Image.fromarray(overlay).save(buffer, format="JPEG")
    encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")

    return {
        "label": CLASSES[predicted.item()],
        "confidence": float(confidence.item()),
        "time": f"{elapsed_ms:.2f}ms",
        "heatmap": f"data:image/jpeg;base64,{encoded}",
    }
# --- 4. API ENDPOINT ---
# NOTE(review): the route decorator was missing, so FastAPI never exposed
# this handler. "/analyze" is assumed — confirm against the React client.
@app.post("/analyze")
async def analyze_leaf(file: UploadFile = File(...)):
    """Run all four ablation models on an uploaded leaf image.

    Args:
        file: Multipart image upload from the frontend.

    Returns:
        A list of four result dicts (one per model), each carrying id,
        label, confidence, time and heatmap fields. Models that failed
        to load at startup yield an "Error" placeholder entry instead.
    """
    # Decode the upload; convert("RGB") normalizes grayscale/RGBA inputs.
    contents = await file.read()
    image = Image.open(io.BytesIO(contents)).convert("RGB")

    # Normalized tensor input for the networks.
    tensor = transform(image).unsqueeze(0).to(DEVICE)

    # Grad-CAM blending needs the raw image as float [0, 1], resized to
    # the same 256x256 geometry as the tensor.
    img_np = np.array(image.resize((256, 256)), dtype=np.float32) / 255.0

    # Frontend card id -> backend model key.
    definitions = [
        {"id": 1, "key": "resnet_base", "name": "ResNet50 Base"},
        {"id": 2, "key": "resnet_aug", "name": "ResNet50 Aug"},
        {"id": 3, "key": "effnet_base", "name": "EffNet Base"},
        {"id": 4, "key": "effnet_aug", "name": "EffNet Aug"},
    ]

    results = []
    for definition in definitions:
        data = run_inference_and_gradcam(definition["key"], tensor, img_np)
        if data:
            results.append({"id": definition["id"], **data})
        else:
            # Fallback entry when the model's weights failed to load.
            results.append({
                "id": definition["id"],
                "label": "Error",
                "confidence": 0.0,
                "time": "0ms",
                "heatmap": "",
            })
    return results
# NOTE(review): the route decorator was missing, so this health-check was
# never reachable; "/" assumed as the intended path.
@app.get("/")
def home():
    """Health-check endpoint confirming the backend is up."""
    return {"message": "Maize Ablation Backend is Running"}