import io
import os

import cv2
import numpy as np
import torch
import torch.nn as nn
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.responses import Response
from PIL import Image
from torchvision import models, transforms

app = FastAPI(title="Falcon Change Detection API", version="1.0")

# ==========================================
# 1. LOAD MODEL ON SERVER STARTUP
# ==========================================
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Server booting on: {DEVICE}")

# Checkpoint to serve. Defaults to the original hard-coded file name; set the
# FALCON_MODEL_PATH environment variable to serve another checkpoint without
# editing code.
MODEL_PATH = os.environ.get("FALCON_MODEL_PATH", "falcon_india_finetuned.pth")

# Key prefix added by nn.DataParallel / DistributedDataParallel when a wrapped
# model's state_dict is saved.
_DATA_PARALLEL_PREFIX = "module."


class SiameseResNetUNet(nn.Module):
    """Siamese change-detection network: shared ResNet-50 encoder + upsampling decoder.

    Both images of a pair go through the *same* encoder (shared weights). Their
    final feature maps are concatenated along the channel axis and decoded back
    to input resolution as ``n_classes`` per-pixel logit channels.

    Note: the layer attribute names below are the state_dict keys of the served
    checkpoint; renaming any of them breaks ``load_state_dict``.
    """

    def __init__(self, n_classes=1):
        super().__init__()
        # Architecture only; the trained weights come from the checkpoint.
        base_model = models.resnet50(weights=None)
        # Drop avgpool + fc so the encoder returns ResNet-50's spatial
        # 2048-channel feature map instead of classification logits.
        self.encoder = nn.Sequential(*list(base_model.children())[:-2])

        # Decoder input = both images' features concatenated (2 x 2048 channels).
        self.up1 = nn.ConvTranspose2d(2048 * 2, 512, kernel_size=2, stride=2)
        self.conv1 = nn.Sequential(nn.Conv2d(512, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(inplace=True))
        self.up2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.conv2 = nn.Sequential(nn.Conv2d(256, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU(inplace=True))
        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv3 = nn.Sequential(nn.Conv2d(128, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.ReLU(inplace=True))
        # Three x2 transposed convs followed by this x4 upsample give x32 in
        # total, undoing ResNet-50's overall stride of 32.
        self.final_up = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
        self.final_conv = nn.Conv2d(128, n_classes, kernel_size=1)

    def forward(self, x1, x2):
        """Return per-pixel change logits for the image pair (``x1``, ``x2``).

        Presumably ``x1``/``x2`` are normalized (N, 3, H, W) batches with H and
        W divisible by 32 (the /detect endpoint feeds 512x512 images). In that
        case the result is (N, n_classes, H, W). Apply a sigmoid to get
        probabilities.
        """
        f1, f2 = self.encoder(x1), self.encoder(x2)
        x = torch.cat([f1, f2], dim=1)
        x = self.conv1(self.up1(x))
        x = self.conv2(self.up2(x))
        x = self.conv3(self.up3(x))
        return self.final_conv(self.final_up(x))


def _strip_data_parallel_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``"module."`` removed from each key.

    Only keys that actually carry the prefix are changed, so plain, wrapped and
    empty checkpoints are all handled. An empty checkpoint then fails in
    ``load_state_dict`` with a clear missing-keys error, not an ``IndexError``.
    """
    n = len(_DATA_PARALLEL_PREFIX)
    return {
        (key[n:] if key.startswith(_DATA_PARALLEL_PREFIX) else key): value
        for key, value in state_dict.items()
    }


model = SiameseResNetUNet().to(DEVICE)
# NOTE(review): torch.load unpickles the file, and a malicious checkpoint can
# execute arbitrary code. Only point MODEL_PATH at trusted checkpoints. For a
# plain tensor state_dict on torch>=1.13, weights_only=True also works.
state_dict = _strip_data_parallel_prefix(torch.load(MODEL_PATH, map_location=DEVICE))
model.load_state_dict(state_dict)
model.eval()  # inference mode: BatchNorm uses its running statistics
# ==========================================
# 2. THE API ENDPOINTS
# ==========================================
import threading

from fastapi import HTTPException
from fastapi.concurrency import run_in_threadpool

# Side length (width, height) both images are resized to before inference.
_INPUT_SIZE = (512, 512)

# Model preprocessing: PIL image -> [0, 1] float tensor, then ImageNet mean/std
# normalization. Built once at import instead of on every request.
_TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

# Serializes forward passes. Inference now runs in worker threads, and this lock
# keeps peak model memory at one request's worth, as it was when inference ran
# directly on the event loop.
_INFERENCE_LOCK = threading.Lock()


@app.get("/")
def health_check():
    """Report liveness, the served model's name, and the inference device."""
    return {"status": "online", "model": "Falcon Siamese-UNet", "device": DEVICE}


def _decode_rgb(data: bytes, field: str) -> Image.Image:
    """Decode uploaded bytes into an RGB image resized to ``_INPUT_SIZE``.

    Raises:
        HTTPException: 400 when *data* is not a readable image, so a bad upload
            is reported as a client error instead of an internal 500.
    """
    try:
        # convert() forces the decode, so truncated/corrupt data fails here too.
        img = Image.open(io.BytesIO(data)).convert("RGB")
    except (OSError, ValueError) as err:  # PIL.UnidentifiedImageError is an OSError
        raise HTTPException(status_code=400, detail=f"'{field}' is not a valid image") from err
    return img.resize(_INPUT_SIZE, Image.BILINEAR)


def _run_change_detection(bytes_past: bytes, bytes_present: bytes,
                          volume_knob: float, threshold: float) -> bytes:
    """Run the synchronous decode -> inference -> annotate pipeline; return PNG bytes."""
    # 1. Decode both uploads.
    img_past = _decode_rgb(bytes_past, "image_past")
    img_present = _decode_rgb(bytes_present, "image_present")

    # 2. Prepare model inputs: add a batch axis and move to the model's device.
    t_past = _TRANSFORM(img_past).unsqueeze(0).to(DEVICE)
    t_present = _TRANSFORM(img_present).unsqueeze(0).to(DEVICE)

    # 3. Inference -> per-pixel change probabilities of the first (only) class.
    with _INFERENCE_LOCK, torch.no_grad():
        logits = model(t_past, t_present)
    probs = torch.sigmoid(logits).cpu().numpy()[0, 0]

    # 4. "Night vision" amplifier: boost weak probabilities, cap at 1, threshold.
    amplified = np.clip(probs * volume_knob, 0, 1)
    binary_mask = (amplified > threshold).astype(np.uint8) * 255

    # 5. Outline changed regions on the present image. findContours returns
    #    3 values on OpenCV 3.x and 2 on 4.x; [-2] is the contour list on both.
    contours = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    annotated = np.array(img_present)  # fresh, writable RGB uint8 copy
    cv2.drawContours(annotated, contours, -1, (255, 0, 0), 2)  # red in RGB order

    # 6. OpenCV encodes BGR, so convert back before writing the PNG.
    ok, buffer = cv2.imencode(".png", cv2.cvtColor(annotated, cv2.COLOR_RGB2BGR))
    if not ok:
        raise RuntimeError("PNG encoding of the annotated image failed")
    return buffer.tobytes()


@app.post("/detect")
async def detect_changes(
    image_past: UploadFile = File(...),
    image_present: UploadFile = File(...),
    volume_knob: float = Form(11.0),  # default: the author's best-found value
    threshold: float = Form(0.85),  # default: the author's best-found value
):
    """Outline regions of *image_present* that changed relative to *image_past*.

    Both uploads are resized to 512x512 and fed to the Siamese model. The change
    probabilities are multiplied by *volume_knob* and clipped to [0, 1]. Pixels
    strictly above *threshold* form the change mask, so a threshold >= 1 never
    fires. The mask's outer contours are drawn in red (2 px) on the resized
    present image.

    Returns:
        A ``image/png`` response with the annotated 512x512 present image.

    Raises:
        HTTPException: 400 if either upload cannot be decoded as an image.
    """
    bytes_past = await image_past.read()
    bytes_present = await image_present.read()
    # Decoding, inference and OpenCV drawing are blocking CPU/GPU work. Running
    # them directly in this coroutine would stall the event loop, and with it
    # every other request (including health checks), for the whole inference.
    png = await run_in_threadpool(
        _run_change_detection, bytes_past, bytes_present, volume_knob, threshold
    )
    return Response(content=png, media_type="image/png")