Amrender commited on
Commit
edd4268
·
verified ·
1 Parent(s): 45340e7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -0
app.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File, Form
2
+ from fastapi.responses import Response
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ import cv2
7
+ from torchvision import models, transforms
8
+ from PIL import Image
9
+ import io
10
+
11
+ app = FastAPI(title="Falcon Change Detection API", version="1.0")
12
+
13
+ # ==========================================
14
+ # 1. LOAD MODEL ON SERVER STARTUP
15
+ # ==========================================
16
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
17
+ print(f"Server booting on: {DEVICE}")
18
+
19
+ class SiameseResNetUNet(nn.Module):
20
+ def __init__(self, n_classes=1):
21
+ super().__init__()
22
+ base_model = models.resnet50(weights=None)
23
+ self.encoder = nn.Sequential(*list(base_model.children())[:-2])
24
+ self.up1 = nn.ConvTranspose2d(2048 * 2, 512, kernel_size=2, stride=2)
25
+ self.conv1 = nn.Sequential(nn.Conv2d(512, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(inplace=True))
26
+ self.up2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
27
+ self.conv2 = nn.Sequential(nn.Conv2d(256, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU(inplace=True))
28
+ self.up3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
29
+ self.conv3 = nn.Sequential(nn.Conv2d(128, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.ReLU(inplace=True))
30
+ self.final_up = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
31
+ self.final_conv = nn.Conv2d(128, n_classes, kernel_size=1)
32
+
33
+ def forward(self, x1, x2):
34
+ f1, f2 = self.encoder(x1), self.encoder(x2)
35
+ x = torch.cat([f1, f2], dim=1)
36
+ x = self.conv1(self.up1(x))
37
+ x = self.conv2(self.up2(x))
38
+ x = self.conv3(self.up3(x))
39
+ return self.final_conv(self.final_up(x))
40
+
41
+ model = SiameseResNetUNet().to(DEVICE)
42
+ # ⚠️ Ensure this matches your uploaded file name perfectly
43
+ state_dict = torch.load("falcon_india_finetuned.pth", map_location=DEVICE)
44
+ if list(state_dict.keys())[0].startswith('module.'):
45
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
46
+ model.load_state_dict(state_dict)
47
+ model.eval()
48
+
49
+ # ==========================================
50
+ # 2. THE API ENDPOINTS
51
+ # ==========================================
52
+ @app.get("/")
53
+ def health_check():
54
+ return {"status": "online", "model": "Falcon Siamese-UNet", "device": DEVICE}
55
+
56
+ @app.post("/detect")
57
+ async def detect_changes(
58
+ image_past: UploadFile = File(...),
59
+ image_present: UploadFile = File(...),
60
+ volume_knob: float = Form(11.0), # Default to your best value
61
+ threshold: float = Form(0.85) # Default to your best value
62
+ ):
63
+ # 1. Read both uploaded images
64
+ bytes_past = await image_past.read()
65
+ bytes_present = await image_present.read()
66
+
67
+ imgA_raw = Image.open(io.BytesIO(bytes_past)).convert("RGB").resize((512, 512), Image.BILINEAR)
68
+ imgB_raw = Image.open(io.BytesIO(bytes_present)).convert("RGB").resize((512, 512), Image.BILINEAR)
69
+
70
+ imgB_resized = np.array(imgB_raw) # Keep a numpy copy to draw on
71
+
72
+ # 2. Prepare for the AI
73
+ transform = transforms.Compose([
74
+ transforms.ToTensor(),
75
+ transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
76
+ ])
77
+
78
+ tA = transform(imgA_raw).unsqueeze(0).to(DEVICE)
79
+ tB = transform(imgB_raw).unsqueeze(0).to(DEVICE)
80
+
81
+ # 3. Run Inference
82
+ with torch.no_grad():
83
+ preds = model(tA, tB)
84
+ probs = torch.sigmoid(preds).cpu().numpy()[0][0]
85
+
86
+ # 4. The Night Vision Amplifier
87
+ amplified_probs = np.clip(probs * volume_knob, 0, 1)
88
+ binary_mask = (amplified_probs > threshold).astype(np.uint8) * 255
89
+
90
+ # 5. Draw Red Boundaries on the Present Image
91
+ contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
92
+ img_with_boundaries = imgB_resized.copy()
93
+ cv2.drawContours(img_with_boundaries, contours, -1, (255, 0, 0), 2)
94
+
95
+ # 6. Compress back to a PNG and send to the user
96
+ is_success, buffer = cv2.imencode(".png", cv2.cvtColor(img_with_boundaries, cv2.COLOR_RGB2BGR))
97
+ io_buf = io.BytesIO(buffer)
98
+
99
+ return Response(content=io_buf.getvalue(), media_type="image/png")