Spaces:
Sleeping
Sleeping
File size: 4,108 Bytes
edd4268 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 | from fastapi import FastAPI, UploadFile, File, Form
from fastapi.responses import Response
import torch
import torch.nn as nn
import numpy as np
import cv2
from torchvision import models, transforms
from PIL import Image
import io
# FastAPI application object; the endpoints in this module are registered on it.
app = FastAPI(title="Falcon Change Detection API", version="1.0")

# ==========================================
# 1. LOAD MODEL ON SERVER STARTUP
# ==========================================
# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Server booting on: {DEVICE}")
class SiameseResNetUNet(nn.Module):
    """Siamese change-detection network: shared ResNet-50 encoder + upsampling decoder.

    Both input images pass through the *same* encoder (shared weights). Their
    deepest feature maps (2048 channels each) are concatenated along the
    channel axis and decoded by three "transposed conv (x2 spatial) ->
    conv/BN/ReLU" stages, a fixed x4 bilinear upsample and a 1x1 conv that
    emits ``n_classes`` logit channels.

    Note: despite the name, the decoder uses no encoder skip connections.
    Attribute names (``encoder``, ``up1``, ``conv1``, ...) must not change:
    they are the ``state_dict`` keys the checkpoint is loaded against.
    """

    def __init__(self, n_classes=1):
        super().__init__()
        resnet = models.resnet50(weights=None)
        # Drop the final two children (torchvision's global avgpool and fc head)
        # so the encoder returns a spatial feature map instead of class scores.
        trunk = list(resnet.children())[:-2]
        self.encoder = nn.Sequential(*trunk)

        # Decoder: each ConvTranspose2d(kernel_size=2, stride=2) doubles H and W.
        self.up1 = nn.ConvTranspose2d(2 * 2048, 512, kernel_size=2, stride=2)
        self.conv1 = self._conv_bn_relu(512)
        self.up2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.conv2 = self._conv_bn_relu(256)
        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv3 = self._conv_bn_relu(128)

        # 2 * 2 * 2 * 4 = 32x total upsampling, matching ResNet-50's overall
        # stride of 32, so the output has the input's resolution (for sizes
        # divisible by 32).
        self.final_up = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
        self.final_conv = nn.Conv2d(128, n_classes, kernel_size=1)

    @staticmethod
    def _conv_bn_relu(channels):
        """Build a 3x3 same-padding conv -> BatchNorm -> ReLU stage that keeps the channel count."""
        return nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x1, x2):
        """Return per-pixel change logits for the image pair ``(x1, x2)``."""
        # Same encoder module for both inputs -> shared weights.
        feats_a = self.encoder(x1)
        feats_b = self.encoder(x2)
        x = torch.cat((feats_a, feats_b), dim=1)
        stages = ((self.up1, self.conv1), (self.up2, self.conv2), (self.up3, self.conv3))
        for upsample, refine in stages:
            x = refine(upsample(x))
        return self.final_conv(self.final_up(x))
def _strip_state_dict_prefix(state_dict, prefix="module."):
    """Return a copy of *state_dict* with *prefix* removed from every key that starts with it.

    Checkpoints saved from a model wrapped in ``nn.DataParallel`` (or DDP)
    store parameters under ``module.<name>``, while the bare model expects
    ``<name>``. Keys without the prefix are kept unchanged, so this is safe
    to apply to any checkpoint (including an empty one).
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Build the network and load the fine-tuned weights once, at import time.
model = SiameseResNetUNet().to(DEVICE)
# ⚠️ Ensure this matches your uploaded file name perfectly
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust
# (recent PyTorch versions accept weights_only=True for plain state_dicts).
state_dict = _strip_state_dict_prefix(
    torch.load("falcon_india_finetuned.pth", map_location=DEVICE)
)
# Strict load (the default): missing or unexpected keys raise instead of
# silently leaving layers randomly initialised.
model.load_state_dict(state_dict)
# Inference mode: BatchNorm layers use their running statistics.
model.eval()
# ==========================================
# 2. THE API ENDPOINTS
# ==========================================
@app.get("/")
def health_check():
    """Report that the API is up, which model it serves, and the device it runs on."""
    return dict(status="online", model="Falcon Siamese-UNet", device=DEVICE)
# Resolution both images are resized to before inference; the returned PNG has this size.
_MODEL_INPUT_SIZE = (512, 512)

# Tensor conversion + per-channel mean/std normalisation (the standard ImageNet
# values). Built once at import instead of on every request.
_PREPROCESS = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])


def _decode_upload(data, field_name):
    """Decode uploaded bytes into an RGB PIL image resized to ``_MODEL_INPUT_SIZE``.

    Raises:
        HTTPException: 400 if the bytes cannot be decoded as an image.
    """
    from fastapi import HTTPException

    try:
        image = Image.open(io.BytesIO(data)).convert("RGB")
    except OSError as err:  # PIL's UnidentifiedImageError and truncated-file errors are OSErrors
        raise HTTPException(
            status_code=400, detail=f"{field_name} is not a readable image"
        ) from err
    return image.resize(_MODEL_INPUT_SIZE, Image.BILINEAR)


def _predict_change_probs(img_past, img_present):
    """Run the model on an image pair and return the sigmoid probabilities of its first output map."""
    t_past = _PREPROCESS(img_past).unsqueeze(0).to(DEVICE)
    t_present = _PREPROCESS(img_present).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = model(t_past, t_present)
    # First batch item, first output channel -> 2-D probability map.
    return torch.sigmoid(logits).cpu().numpy()[0][0]


def _draw_change_boundaries(rgb, probs, volume_knob, threshold):
    """Return a copy of *rgb* with red outlines around regions whose amplified probability exceeds *threshold*."""
    # "Night vision" amplifier: scale weak responses up, then clip back into [0, 1].
    amplified = np.clip(probs * volume_knob, 0, 1)
    mask = (amplified > threshold).astype(np.uint8) * 255
    # findContours returns (contours, hierarchy) in OpenCV 4 but
    # (image, contours, hierarchy) in OpenCV 3; [-2] is the contour list in both.
    contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    annotated = np.array(rgb)  # fresh, writable copy to draw on
    # (255, 0, 0) is red because the array is in RGB order.
    cv2.drawContours(annotated, contours, -1, (255, 0, 0), 2)
    return annotated


def _encode_png(rgb):
    """Encode an RGB uint8 array as PNG bytes.

    Raises:
        HTTPException: 500 if OpenCV fails to encode the image.
    """
    from fastapi import HTTPException

    # OpenCV writes images in BGR channel order.
    ok, buffer = cv2.imencode(".png", cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))
    if not ok:
        raise HTTPException(status_code=500, detail="Failed to encode result image as PNG")
    return buffer.tobytes()


@app.post("/detect")
async def detect_changes(
    image_past: UploadFile = File(...),
    image_present: UploadFile = File(...),
    volume_knob: float = Form(11.0),  # default: best value found so far
    threshold: float = Form(0.85),  # default: best value found so far
):
    """Detect changes between two images and return the present image with changes outlined.

    Both uploads are decoded, converted to RGB and resized to 512x512. The
    model's per-pixel probabilities are multiplied by ``volume_knob`` and
    clipped to [0, 1]. Pixels strictly above ``threshold`` form the change
    mask, and its outer contours are drawn in red on the resized present image.

    Returns:
        An ``image/png`` response containing the annotated 512x512 present image.

    Raises:
        HTTPException: 400 if an upload is not a readable image; 500 if the
            result cannot be PNG-encoded.
    """
    # NOTE(review): inference runs synchronously inside this async endpoint, so it
    # blocks the event loop (including "/" health checks) while the model runs.
    img_past = _decode_upload(await image_past.read(), "image_past")
    img_present = _decode_upload(await image_present.read(), "image_present")
    probs = _predict_change_probs(img_past, img_present)
    annotated = _draw_change_boundaries(img_present, probs, volume_knob, threshold)
    return Response(content=_encode_png(annotated), media_type="image/png")