File size: 4,108 Bytes
edd4268
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import io

import cv2
import numpy as np
import torch
import torch.nn as nn
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import Response
from PIL import Image
from torchvision import models, transforms

# ASGI application object; the route decorators in this file register on it.
app = FastAPI(
    title="Falcon Change Detection API",
    version="1.0",
)

# ==========================================
# 1. LOAD MODEL ON SERVER STARTUP
# ==========================================
# Use the GPU when PyTorch reports CUDA as available; otherwise use the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Server booting on: {DEVICE}")

class SiameseResNetUNet(nn.Module):
    """Two-branch network with a shared ResNet-50 encoder and an upsampling decoder.

    Both inputs go through the same encoder. Their feature maps are
    concatenated along the channel axis and decoded by three stride-2
    transposed convolutions (each followed by a conv/batch-norm/ReLU stage),
    a 4x bilinear upsample, and a 1x1 convolution producing ``n_classes``
    output channels.
    """

    @staticmethod
    def _refine_stage(channels):
        """Build a 3x3 conv -> batch-norm -> ReLU stage that keeps ``channels``."""
        return nn.Sequential(
            nn.Conv2d(channels, channels, 3, 1, 1),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
        )

    def __init__(self, n_classes=1):
        super().__init__()
        # Keep every child of torchvision's ResNet-50 except the last two,
        # so the encoder emits a spatial feature map, not class scores.
        backbone = models.resnet50(weights=None)
        backbone_layers = list(backbone.children())[:-2]
        self.encoder = nn.Sequential(*backbone_layers)

        # The decoder input has 2048 * 2 channels: one 2048-channel map per branch.
        self.up1 = nn.ConvTranspose2d(2048 * 2, 512, kernel_size=2, stride=2)
        self.conv1 = self._refine_stage(512)
        self.up2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.conv2 = self._refine_stage(256)
        self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv3 = self._refine_stage(128)

        self.final_up = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
        self.final_conv = nn.Conv2d(128, n_classes, kernel_size=1)

    def forward(self, x1, x2):
        """Return raw (pre-sigmoid) logits for the image pair ``x1``, ``x2``."""
        features_first = self.encoder(x1)
        features_second = self.encoder(x2)
        fused = torch.cat((features_first, features_second), dim=1)

        decoder_stages = (
            (self.up1, self.conv1),
            (self.up2, self.conv2),
            (self.up3, self.conv3),
        )
        for upsample, refine in decoder_stages:
            fused = refine(upsample(fused))

        return self.final_conv(self.final_up(fused))

def _strip_key_prefix(state_dict, prefix="module."):
    """Return a copy of ``state_dict`` with ``prefix`` removed from keys that carry it.

    Each key is checked individually, so an empty dict or a dict where only
    some keys are prefixed is handled without raising or mangling keys.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


model = SiameseResNetUNet().to(DEVICE)
# ⚠️ Ensure this matches your uploaded file name perfectly
# NOTE(review): torch.load unpickles the file, so only load checkpoints you
# trust; consider weights_only=True if the installed PyTorch supports it.
state_dict = torch.load("falcon_india_finetuned.pth", map_location=DEVICE)
# Checkpoint keys may carry a "module." prefix (presumably from saving a
# wrapped model -- TODO confirm); drop it so keys match this bare model.
state_dict = _strip_key_prefix(state_dict)
model.load_state_dict(state_dict)
# Inference mode: batch-norm layers use their stored running statistics.
model.eval()

# ==========================================
# 2. THE API ENDPOINTS
# ==========================================
@app.get("/")
def health_check():
    """Return a small JSON payload naming the model and the active device."""
    payload = {"status": "online", "model": "Falcon Siamese-UNet"}
    payload["device"] = DEVICE
    return payload

# Built once at import time instead of on every request. The mean/std values
# are the ones torchvision documents for its ImageNet-pretrained models.
_PREPROCESS = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])


def _decode_upload(raw, field_name):
    """Decode uploaded bytes into a 512x512 RGB PIL image.

    Args:
        raw: The raw bytes of the uploaded file.
        field_name: Form field name, used only in the error message.

    Raises:
        HTTPException: 400 if Pillow cannot decode ``raw`` as an image.
    """
    try:
        picture = Image.open(io.BytesIO(raw)).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as err:
        # Empty, truncated, non-image or oversized uploads are client errors.
        raise HTTPException(
            status_code=400,
            detail=f"'{field_name}' is not a readable image.",
        ) from err
    return picture.resize((512, 512), Image.BILINEAR)


@app.post("/detect")
async def detect_changes(
    image_past: UploadFile = File(...),
    image_present: UploadFile = File(...),
    volume_knob: float = Form(11.0),   # Default to your best value
    threshold: float = Form(0.85)      # Default to your best value
):
    """Outline detected changes between two images on the present image.

    Both uploads are resized to 512x512, run through the model, and the
    thresholded result is drawn as red contours on the resized present image.

    Args:
        image_past: Upload of the earlier image.
        image_present: Upload of the later image; contours are drawn on it.
        volume_knob: Multiplier applied to the sigmoid probabilities before
            they are clipped to [0, 1].
        threshold: Amplified probabilities strictly above this value are
            marked as changed.

    Returns:
        A ``Response`` with the annotated image encoded as PNG.

    Raises:
        HTTPException: 400 if either upload cannot be decoded as an image,
            500 if the result cannot be encoded as PNG.
    """
    # 1. Read and decode both uploaded images
    bytes_past = await image_past.read()
    bytes_present = await image_present.read()
    picture_past = _decode_upload(bytes_past, "image_past")
    picture_present = _decode_upload(bytes_present, "image_present")

    # 2. Prepare for the AI: add a batch dimension and move to the model's device
    tensor_past = _PREPROCESS(picture_past).unsqueeze(0).to(DEVICE)
    tensor_present = _PREPROCESS(picture_present).unsqueeze(0).to(DEVICE)

    # 3. Run inference
    # NOTE(review): this call is synchronous inside an async handler, so other
    # requests on the same event loop wait until it finishes.
    with torch.no_grad():
        logits = model(tensor_past, tensor_present)
        probs = torch.sigmoid(logits).cpu().numpy()[0][0]

    # 4. The Night Vision Amplifier: scale, clip to [0, 1], then threshold
    amplified_probs = np.clip(probs * volume_knob, 0, 1)
    binary_mask = (amplified_probs > threshold).astype(np.uint8) * 255

    # 5. Draw red boundaries on a numpy copy of the present image (RGB order)
    contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    img_with_boundaries = np.array(picture_present)
    cv2.drawContours(img_with_boundaries, contours, -1, (255, 0, 0), 2)

    # 6. Encode as PNG; OpenCV expects BGR channel order, so convert first
    is_success, buffer = cv2.imencode(".png", cv2.cvtColor(img_with_boundaries, cv2.COLOR_RGB2BGR))
    if not is_success:
        raise HTTPException(status_code=500, detail="Failed to encode the result image as PNG.")

    return Response(content=buffer.tobytes(), media_type="image/png")