# Author: Amrender
# Commit message: Create app.py
# Commit: 7b9d2a5 (verified)
import os
import io
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchvision.transforms.functional as TF
import gradio as gr
from fastapi import FastAPI
from transformers import SegformerForSemanticSegmentation
from huggingface_hub import hf_hub_download
print("🚀 BOOTING MONOLITHIC FASTAPI + GRADIO ENGINE...")
# ==========================================
# 1. AI CONFIG & DOWNLOAD
# ==========================================
# ⚠️ CHANGE THIS to your exact Model repository!
REPO_ID = "Amrender/b5-cartography-weights"
# Checkpoint file name inside REPO_ID (the space and parentheses are part of the name).
FILENAME = "best_model (3).pth"
# Device the model and input tensors are placed on.
DEVICE = "cpu"
# Optional Hugging Face access token, passed straight through to hf_hub_download;
# None when the HF_TOKEN environment variable is unset.
hf_token = os.environ.get("HF_TOKEN")

# Resolve the checkpoint to a local file path. Failure is fatal at import time:
# the app cannot serve predictions without weights.
try:
    print(f"⬇️ Fetching B5 Weights from {REPO_ID}...")
    MODEL_PATH = hf_hub_download(
        repo_id=REPO_ID,
        filename=FILENAME,
        repo_type="model",
        token=hf_token,
    )
except Exception as e:
    # Chain the original error so its full traceback is kept as __cause__.
    raise RuntimeError(
        f"❌ Failed to download weights. Check REPO_ID and HF_TOKEN! Error: {e}"
    ) from e
else:
    print("✅ Weights downloaded!")
# ==========================================
# 2. LOAD PYTORCH MODEL
# ==========================================
class UnifiedCartographer(nn.Module):
    """SegFormer-B5 segmenter whose logits match the input's spatial size.

    Wraps ``SegformerForSemanticSegmentation`` loaded from the
    ``nvidia/segformer-b5-finetuned-cityscapes-1024-1024`` checkpoint, with the
    label count overridden to ``num_classes``. ``ignore_mismatched_sizes=True``
    lets the classification head differ in shape from the pretrained one.
    """

    def __init__(self, num_classes=5):
        super().__init__()
        pretrained_id = "nvidia/segformer-b5-finetuned-cityscapes-1024-1024"
        self.model = SegformerForSemanticSegmentation.from_pretrained(
            pretrained_id,
            ignore_mismatched_sizes=True,
            num_labels=num_classes,
        )

    def forward(self, x):
        """Return per-pixel class logits resized to the height/width of ``x``.

        ``x`` is fed to the wrapped model as ``pixel_values``; it must be 4-D
        (batch, channels, H, W) because bilinear interpolation requires it.
        The result has shape (batch, num_classes, H, W).
        """
        out_height, out_width = x.shape[-2:]
        coarse_logits = self.model(pixel_values=x).logits
        # The model's logits may come out at a lower resolution than the input;
        # bilinearly resize them so they align pixel-for-pixel with x.
        return F.interpolate(
            coarse_logits,
            size=(out_height, out_width),
            mode="bilinear",
            align_corners=False,
        )
def _normalize_checkpoint_keys(state_dict, model_keys):
    """Rename checkpoint keys so they line up with the target model's keys.

    Rules, applied in order to every key:
    * a leading ``module.`` prefix is stripped (such prefixes are typically
      added when a model is saved while wrapped in DataParallel/DDP);
    * if the key does not start with ``model.`` but ``model.<key>`` is one of
      ``model_keys``, the ``model.`` prefix is added (a checkpoint saved from
      the bare SegFormer model rather than the ``UnifiedCartographer`` wrapper).
    Keys matching neither rule are passed through unchanged.

    Args:
        state_dict: Mapping of parameter/buffer name -> tensor from the checkpoint.
        model_keys: Set of names the target model expects.

    Returns:
        A new dict with normalized keys and the original values.
    """
    prefix = "module."
    normalized = {}
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        # Checked after stripping ``module.`` so a DataParallel-wrapped bare
        # SegFormer checkpoint is mapped onto the wrapper's names as well.
        if not key.startswith("model.") and f"model.{key}" in model_keys:
            key = f"model.{key}"
        normalized[key] = value
    return normalized


print("🧠 Loading B5 Model into Memory...")
ai_model = UnifiedCartographer(num_classes=5)
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted repository (REPO_ID).
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
# Accept either a training checkpoint containing 'model_state_dict' or a bare state dict.
state_dict = checkpoint.get('model_state_dict', checkpoint)
# Build the expected-key set once rather than calling state_dict() per checkpoint key.
clean_state_dict = _normalize_checkpoint_keys(state_dict, set(ai_model.state_dict()))
# strict=False tolerates name mismatches, so surface them: any missing key keeps
# whatever value UnifiedCartographer() initialised it with.
load_result = ai_model.load_state_dict(clean_state_dict, strict=False)
if load_result.missing_keys:
    print(
        f"WARNING: {len(load_result.missing_keys)} model keys missing from checkpoint, "
        f"e.g. {load_result.missing_keys[:5]}"
    )
if load_result.unexpected_keys:
    print(
        f"WARNING: {len(load_result.unexpected_keys)} checkpoint keys not used by the model, "
        f"e.g. {load_result.unexpected_keys[:5]}"
    )
ai_model.to(DEVICE)
ai_model.eval()
print("✅ AI Engine Online!")
# ==========================================
# 3. LOCAL INFERENCE & MATH LOGIC
# ==========================================
def extract_buildings_locally(img_array, max_size=1024, building_class=1):
    """Segment one image in memory and return its building mask.

    Runs inference directly in memory (no network delays).

    Args:
        img_array: Image as a NumPy array (H x W x C), e.g. from the Gradio
            ``gr.Image(type="numpy")`` inputs. Grayscale (H x W or H x W x 1)
            and RGBA (H x W x 4) arrays are converted to 3-channel RGB.
            Assumed to be uint8 in [0, 255] (it is divided by 255 below).
        max_size: Longest side allowed before downscaling; larger images are
            shrunk with their aspect ratio preserved to bound CPU memory use.
        building_class: Predicted class index treated as "building".

    Returns:
        ``(building_mask, img_array)``: a uint8 mask that is 255 where the model
        predicted ``building_class`` and 0 elsewhere, and the (possibly
        converted / resized) image the mask is aligned with, for overlays.
    """
    # Normalization below uses 3-channel statistics, so coerce grayscale and
    # RGBA inputs to RGB up front instead of failing there.
    if img_array.ndim == 2 or img_array.shape[2] == 1:
        img_array = cv2.cvtColor(img_array, cv2.COLOR_GRAY2RGB)
    elif img_array.shape[2] == 4:
        img_array = cv2.cvtColor(img_array, cv2.COLOR_RGBA2RGB)

    # Auto-resize to prevent CPU RAM crashes
    h, w = img_array.shape[:2]
    if max(h, w) > max_size:
        scale = max_size / max(h, w)
        # Clamp to >= 1 px: a very thin image could otherwise produce a
        # zero-sized target, which cv2.resize rejects.
        new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
        img_array = cv2.resize(img_array, new_size)

    # Preprocess: HWC -> CHW float in [0, 1], ImageNet mean/std, add batch dim.
    input_tensor = torch.from_numpy(img_array.transpose(2, 0, 1)).float() / 255.0
    input_tensor = TF.normalize(
        input_tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
    ).unsqueeze(0).to(DEVICE)

    # AI Prediction
    with torch.no_grad():
        logits = ai_model(input_tensor)
    # Take batch element 0 explicitly: squeeze() would also drop a spatial
    # axis of length 1 and yield a 1-D mask.
    pred_mask = torch.argmax(logits, dim=1)[0].cpu().numpy()

    # Isolate Buildings
    building_mask = np.zeros_like(pred_mask, dtype=np.uint8)
    building_mask[pred_mask == building_class] = 255
    return building_mask, img_array  # Return resized image too for overlay
def process_temporal_change(img_past, img_present):
    """Highlight buildings present in ``img_present`` but absent from ``img_past``.

    Both images are segmented with ``extract_buildings_locally``. The past
    building mask is subtracted from the present one, cleaned with a
    morphological opening, dilated slightly, and painted cyan over the
    present image.

    Args:
        img_past: Earlier image as a NumPy array, or None.
        img_present: Later image as a NumPy array, or None.

    Returns:
        ``(final_dashboard, mask_display)``: the present image (as returned,
        possibly resized, by ``extract_buildings_locally``) blended 40/60 with
        the cyan highlight, and the new-construction mask as a 3-channel image.
        ``(None, None)`` if either input is missing.
    """
    if img_past is None or img_present is None:
        return None, None
    print("📡 Processing Year 1...")
    mask_y1, _ = extract_buildings_locally(img_past)
    print("📡 Processing Year 2...")
    mask_y2, resized_present = extract_buildings_locally(img_present)
    # Ensure masks are the exact same size for subtraction (in case of weird crops).
    # NOTE(review): stretching assumes both images cover the same ground area;
    # inputs with different aspect ratios will be spatially misaligned.
    if mask_y1.shape != mask_y2.shape:
        # Nearest-neighbour keeps the mask strictly 0/255. The default bilinear
        # interpolation would create grey edge values that leak through the
        # subtraction and disagree with the ==255 overlay test below.
        mask_y1 = cv2.resize(
            mask_y1,
            (mask_y2.shape[1], mask_y2.shape[0]),
            interpolation=cv2.INTER_NEAREST,
        )
    print("🧮 Calculating Urban Growth...")
    # Saturating uint8 subtraction: building-now-but-not-before -> 255;
    # everything else (including demolitions) clamps to 0.
    raw_new_construction = cv2.subtract(mask_y2, mask_y1)
    # Morphological opening drops regions too small to contain the 5x5 kernel;
    # a 3x3 dilation then grows the surviving regions slightly.
    kernel = np.ones((5, 5), np.uint8)
    clean_new_construction = cv2.morphologyEx(raw_new_construction, cv2.MORPH_OPEN, kernel)
    clean_new_construction = cv2.dilate(clean_new_construction, np.ones((3, 3), np.uint8), iterations=1)
    # Overlay Neon Cyan ([0, 255, 255], presumably in RGB channel order).
    overlay = resized_present.copy()
    overlay[clean_new_construction == 255] = [0, 255, 255]
    final_dashboard = cv2.addWeighted(resized_present, 0.4, overlay, 0.6, 0)
    mask_display = cv2.cvtColor(clean_new_construction, cv2.COLOR_GRAY2RGB)
    print("✅ Done!")
    return final_dashboard, mask_display
# ==========================================
# 4. FASTAPI & GRADIO INTEGRATION
# ==========================================
# Initialize FastAPI.
# NOTE(review): this module only builds `app`; nothing here starts a server, so
# it is presumably served by an external ASGI runner (e.g. `uvicorn app:app`).
# Confirm against the deployment config.
app = FastAPI(title="Monolithic Temporal Cartography API")


# Define an API health route. It is registered before the Gradio mount at "/"
# below, presumably so it is matched ahead of that catch-all mount.
@app.get("/health")
def read_root():
    """Return a static status payload; does not touch the model."""
    return {"status": "Online", "architecture": "Monolith FastAPI + Gradio"}


# Build the Gradio UI: two image inputs, an analyze button, two image outputs.
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    gr.Markdown("# 🏙️ Temporal Urban Growth Tracker (Direct AI Engine)")
    gr.Markdown("Upload a past and present satellite image. The AI processes these locally in memory, subtracts the footprint history, and highlights brand new construction.")
    with gr.Row():
        with gr.Column():
            img_past = gr.Image(label="1. Past (Year 1)", type="numpy")
        with gr.Column():
            img_present = gr.Image(label="2. Present (Year 2)", type="numpy")
    btn_detect = gr.Button("Analyze Growth", variant="primary")
    with gr.Row():
        with gr.Column():
            output_mask = gr.Image(label="3. Extracted New Construction Mask")
        with gr.Column():
            output_overlay = gr.Image(label="4. Growth Highlighted (Neon Cyan)")
    # process_temporal_change returns (overlay, mask); outputs follow that order.
    btn_detect.click(
        fn=process_temporal_change,
        inputs=[img_past, img_present],
        outputs=[output_overlay, output_mask],
    )

# Mount Gradio onto the root path of the FastAPI server
app = gr.mount_gradio_app(app, demo, path="/")