Spaces:
Sleeping
Sleeping
import io

import cv2
import numpy as np
import torch
import torch.nn as nn
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import Response
from PIL import Image
from torchvision import models, transforms
app = FastAPI(title="Falcon Change Detection API", version="1.0")

# ==========================================
# 1. LOAD MODEL ON SERVER STARTUP
# ==========================================
# Pick the compute device once at import time: use the GPU whenever PyTorch
# reports CUDA as available, otherwise fall back to the CPU. The model weights
# and every request tensor are later moved to this device.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Server booting on: {DEVICE}")
| class SiameseResNetUNet(nn.Module): | |
| def __init__(self, n_classes=1): | |
| super().__init__() | |
| base_model = models.resnet50(weights=None) | |
| self.encoder = nn.Sequential(*list(base_model.children())[:-2]) | |
| self.up1 = nn.ConvTranspose2d(2048 * 2, 512, kernel_size=2, stride=2) | |
| self.conv1 = nn.Sequential(nn.Conv2d(512, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(inplace=True)) | |
| self.up2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2) | |
| self.conv2 = nn.Sequential(nn.Conv2d(256, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU(inplace=True)) | |
| self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2) | |
| self.conv3 = nn.Sequential(nn.Conv2d(128, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.ReLU(inplace=True)) | |
| self.final_up = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True) | |
| self.final_conv = nn.Conv2d(128, n_classes, kernel_size=1) | |
| def forward(self, x1, x2): | |
| f1, f2 = self.encoder(x1), self.encoder(x2) | |
| x = torch.cat([f1, f2], dim=1) | |
| x = self.conv1(self.up1(x)) | |
| x = self.conv2(self.up2(x)) | |
| x = self.conv3(self.up3(x)) | |
| return self.final_conv(self.final_up(x)) | |
# Build the architecture on the chosen device, then replace its random initial
# weights with the fine-tuned checkpoint.
model = SiameseResNetUNet().to(DEVICE)

# ⚠️ Ensure this matches your uploaded file name perfectly
MODEL_PATH = "falcon_india_finetuned.pth"
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code; only ever point MODEL_PATH at a checkpoint you trust.
state_dict = torch.load(MODEL_PATH, map_location=DEVICE)

# Checkpoints saved from a (Distributed)DataParallel wrapper typically prefix
# every key with "module.". Strip the prefix key by key. The previous version
# decided from the first key alone (IndexError on an empty dict) and then
# chopped 7 characters off every key, including keys without the prefix.
_DATA_PARALLEL_PREFIX = "module."
state_dict = {
    (key[len(_DATA_PARALLEL_PREFIX):] if key.startswith(_DATA_PARALLEL_PREFIX) else key): value
    for key, value in state_dict.items()
}
model.load_state_dict(state_dict)
# Inference mode: BatchNorm layers use their stored running statistics.
model.eval()
| # ========================================== | |
| # 2. THE API ENDPOINTS | |
| # ========================================== | |
def health_check():
    """Report that the service is up, plus which model and device it uses.

    NOTE(review): no ``@app.get(...)`` route decorator is visible in this
    source, so as written this function is not registered with ``app`` --
    confirm the decorator was not lost.
    """
    report = {"status": "online"}
    report["model"] = "Falcon Siamese-UNet"
    report["device"] = DEVICE
    return report
# Both uploads are resized to this (width, height) before inference.
_INPUT_SIZE = (512, 512)

# ToTensor + ImageNet mean/std normalisation. Built once at import time rather
# than on every request.
_TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])


def _decode_upload(data, field_name):
    """Decode raw upload bytes into an RGB PIL image resized to ``_INPUT_SIZE``.

    Raises:
        HTTPException: 400 if Pillow cannot decode ``data`` as an image
            (unknown format, truncated data, or decompression-bomb size).
    """
    try:
        # convert() forces the full decode, so truncated files fail here too.
        image = Image.open(io.BytesIO(data)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        raise HTTPException(
            status_code=400, detail=f"'{field_name}' is not a readable image"
        ) from err
    return image.resize(_INPUT_SIZE, Image.BILINEAR)


def _detect_changes_png(bytes_past, bytes_present, volume_knob, threshold):
    """Synchronous core of ``detect_changes``: decode, infer, outline, encode.

    Returns:
        PNG-encoded bytes of the resized present image, with red outlines
        drawn around every external contour of the detected-change mask.

    Raises:
        HTTPException: 400 for an undecodable upload, 500 if PNG encoding fails.
    """
    img_past = _decode_upload(bytes_past, "image_past")
    img_present = _decode_upload(bytes_present, "image_present")

    # Add a batch dimension and move to the model's device.
    t_past = _TRANSFORM(img_past).unsqueeze(0).to(DEVICE)
    t_present = _TRANSFORM(img_present).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = model(t_past, t_present)
    # First batch item, first output channel -> 2-D map of change probabilities.
    probs = torch.sigmoid(logits).cpu().numpy()[0][0]

    # "Night vision amplifier": scale probabilities up (capped at 1) before
    # thresholding, so weak responses can still cross ``threshold``.
    amplified = np.clip(probs * volume_knob, 0, 1)
    mask = (amplified > threshold).astype(np.uint8) * 255

    # findContours returns (image, contours, hierarchy) in OpenCV 3 and
    # (contours, hierarchy) in OpenCV 4; index -2 selects the contours in both.
    contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

    # np.array copies the PIL pixels into a writable RGB array, so (255, 0, 0)
    # is red; the image is converted to BGR only because cv2.imencode expects BGR.
    canvas = np.array(img_present)
    cv2.drawContours(canvas, contours, -1, (255, 0, 0), 2)
    ok, encoded = cv2.imencode(".png", cv2.cvtColor(canvas, cv2.COLOR_RGB2BGR))
    if not ok:
        raise HTTPException(status_code=500, detail="Failed to encode result image")
    return encoded.tobytes()


async def detect_changes(
    image_past: UploadFile = File(...),
    image_present: UploadFile = File(...),
    volume_knob: float = Form(11.0),  # Default to your best value
    threshold: float = Form(0.85),  # Default to your best value
):
    """Outline the regions that changed between a past and a present image.

    Args:
        image_past: Earlier image of the scene.
        image_present: Later image of the scene; the outlines are drawn on it.
        volume_knob: Multiplier applied to the change probabilities before
            thresholding (the result is clipped to [0, 1]).
        threshold: A pixel is marked as changed when its amplified probability
            is strictly greater than this value.

    Returns:
        A PNG ``Response`` showing the present image, resized to 512x512, with
        red contour outlines.

    Raises:
        HTTPException: 400 if either upload is not a decodable image; 500 if
            the result cannot be encoded as PNG.

    NOTE(review): no ``@app.post(...)`` route decorator is visible in this
    source; confirm it was not lost, otherwise the endpoint is unregistered.
    """
    bytes_past = await image_past.read()
    bytes_present = await image_present.read()
    # Decoding and inference are blocking CPU/GPU work. Running them directly
    # in this coroutine would stall the event loop, and every other request,
    # for the whole inference, so hand them to FastAPI's thread pool.
    # Concurrent requests may now run the model in parallel threads; it is
    # called under no_grad, and the model-loading code puts it in eval mode.
    png_bytes = await run_in_threadpool(
        _detect_changes_png, bytes_past, bytes_present, volume_knob, threshold
    )
    return Response(content=png_bytes, media_type="image/png")