# FastAPI backend for the maize leaf disease ablation demo (Hugging Face Space).
import torch
import base64
import io
import time
import numpy as np
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from torchvision import models, transforms
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
app = FastAPI()
# Enable CORS so your React app can talk to this backend
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins for the demo
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# NOTE(review): wildcard origins together with allow_credentials=True is
# acceptable for a demo, but browsers reject credentialed requests against
# "*" — pin the origin list before any production deployment.
# --- 1. CONFIGURATION ---
# Run inference on GPU when available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Output-index -> label mapping. Index order must match the label order
# used during training (presumably ImageFolder's alphabetical order —
# TODO confirm against the training script).
CLASSES = ['Gray Leaf Spot', 'Healthy']
# Define paths to your uploaded weights
# Upload these files to your HF Space manually via "Files" tab
MODEL_PATHS = {
    "resnet_base": "models/ResNet50_Aug_False.pth",
    "resnet_aug": "models/ResNet50_Aug_True.pth",
    "effnet_base": "models/EfficientNet_Aug_False.pth",
    "effnet_aug": "models/EfficientNet_Aug_True.pth"
}
# --- 2. LOAD MODELS ---
# Registry: key -> {"model": nn.Module, "cam": GradCAM}, or None when
# loading failed (e.g. weights not uploaded yet).
loaded_models = {}
def load_architecture(model_name, num_classes=2):
    """Rebuild a classifier architecture matching the training setup.

    Args:
        model_name: Key identifying the backbone; must contain
            "resnet" or "effnet"/"efficientnet".
        num_classes: Width of the replacement classification head
            (default 2, matching CLASSES).

    Returns:
        Tuple ``(model, target_layer)`` where ``target_layer`` is the
        final convolutional block Grad-CAM should hook into.

    Raises:
        ValueError: If ``model_name`` matches neither known backbone.
    """
    if "resnet" in model_name:
        model = models.resnet50(weights=None)
        model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
        # Last bottleneck of layer4 holds the final conv features ResNet
        # Grad-CAM visualizations are usually taken from.
        target_layer = model.layer4[-1]
    elif "effnet" in model_name or "efficientnet" in model_name:
        model = models.efficientnet_b0(weights=None)
        model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, num_classes)
        # Last feature block of EfficientNet-B0 for Grad-CAM.
        target_layer = model.features[-1]
    else:
        # Previously any unrecognized name silently built an EfficientNet,
        # which would mask typos in MODEL_PATHS with a shape-mismatch (or
        # worse, a wrong-backbone) error later. Fail fast instead.
        raise ValueError(f"Unknown model architecture: {model_name!r}")
    return model, target_layer
print("Loading models... this might take a minute...")
for model_key, weight_path in MODEL_PATHS.items():
    try:
        # Rebuild the architecture, then restore the trained weights.
        net, cam_layer = load_architecture(model_key)
        # Load weights (Ensure you upload the files!)
        # If testing without weights, comment out the next line
        checkpoint = torch.load(weight_path, map_location=DEVICE)
        net.load_state_dict(checkpoint)
        net.to(DEVICE)
        net.eval()
        # One Grad-CAM explainer per model, hooked to its target layer.
        explainer = GradCAM(model=net, target_layers=[cam_layer])
        loaded_models[model_key] = {"model": net, "cam": explainer}
        print(f"Loaded {model_key}")
    except Exception as e:
        print(f"Error loading {model_key}: {e}")
        # Placeholder for demo if weights are missing
        loaded_models[model_key] = None
# --- 3. UTILITIES ---
# Preprocessing pipeline; presumably mirrors the transforms used at
# training time (TODO confirm). The mean/std values are the standard
# ImageNet normalization statistics.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
def run_inference_and_gradcam(key, image_tensor, original_image_np):
    """Run one model's prediction and Grad-CAM heatmap.

    Args:
        key: Entry in ``loaded_models`` identifying which model to run.
        image_tensor: Normalized input batch on DEVICE (1 x 3 x 256 x 256,
            produced by ``transform`` — see endpoint code).
        original_image_np: float32 RGB array scaled to [0, 1], resized to
            match the tensor; used as the heatmap overlay background.

    Returns:
        Dict with ``label``, ``confidence``, formatted ``time`` string and
        a base64 data-URI ``heatmap`` — or None if the model failed to load.
    """
    item = loaded_models[key]
    if item is None:
        return None
    model = item["model"]
    cam = item["cam"]
    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # so the reported latency can't jump or go negative on clock updates.
    start_time = time.perf_counter()
    # 1. Prediction (no gradients needed for the class probabilities)
    with torch.no_grad():
        outputs = model(image_tensor)
        probs = torch.nn.functional.softmax(outputs, dim=1)
        conf, pred_idx = torch.max(probs, 1)
    inference_time = (time.perf_counter() - start_time) * 1000  # ms
    # 2. Grad-CAM
    # We need gradients, so we run cam() which handles the forward/backward
    # internally. targets=None -> explain the top-predicted class.
    grayscale_cam = cam(input_tensor=image_tensor, targets=None)[0, :]
    visualization = show_cam_on_image(original_image_np, grayscale_cam, use_rgb=True)
    # Convert Grad-CAM numpy to Base64 String for frontend
    pil_img = Image.fromarray(visualization)
    buff = io.BytesIO()
    pil_img.save(buff, format="JPEG")
    img_str = base64.b64encode(buff.getvalue()).decode("utf-8")
    return {
        "label": CLASSES[pred_idx.item()],
        "confidence": float(conf.item()),
        "time": f"{inference_time:.2f}ms",
        "heatmap": f"data:image/jpeg;base64,{img_str}"
    }
# --- 4. API ENDPOINT ---
@app.post("/analyze")
async def analyze_leaf(file: UploadFile = File(...)):
    """Score an uploaded leaf image with all four models and return a
    list of per-model results (prediction + Grad-CAM heatmap)."""
    # Decode the upload into an RGB PIL image.
    raw_bytes = await file.read()
    image = Image.open(io.BytesIO(raw_bytes)).convert("RGB")
    # Normalized tensor input for the networks.
    tensor = transform(image).unsqueeze(0).to(DEVICE)
    # Grad-CAM overlay needs a float array in [0, 1], resized to 256x256
    # so it lines up with the tensor's spatial dimensions.
    img_np = np.array(image.resize((256, 256)), dtype=np.float32) / 255.0
    # Frontend card IDs mapped to backend model keys ("name" is not part
    # of the response; it documents which card each key drives).
    definitions = [
        {"id": 1, "key": "resnet_base", "name": "ResNet50 Base"},
        {"id": 2, "key": "resnet_aug", "name": "ResNet50 Aug"},
        {"id": 3, "key": "effnet_base", "name": "EffNet Base"},
        {"id": 4, "key": "effnet_aug", "name": "EffNet Aug"},
    ]
    results = []
    for spec in definitions:
        payload = run_inference_and_gradcam(spec["key"], tensor, img_np)
        if payload:
            results.append({"id": spec["id"], **payload})
        else:
            # Fallback entry when the model failed to load at startup.
            results.append({
                "id": spec["id"],
                "label": "Error",
                "confidence": 0.0,
                "time": "0ms",
                "heatmap": ""
            })
    return results
@app.get("/")
def home():
    """Health-check endpoint confirming the API is up."""
    return {"message": "Maize Ablation Backend is Running"}