import os
import io

import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchvision.transforms.functional as TF
import gradio as gr
from fastapi import FastAPI
from transformers import SegformerForSemanticSegmentation
from huggingface_hub import hf_hub_download

print("🚀 BOOTING MONOLITHIC FASTAPI + GRADIO ENGINE...")

# ==========================================
# 1. AI CONFIG & DOWNLOAD
# ==========================================
# ⚠️ CHANGE THIS to your exact Model repository!
REPO_ID = "Amrender/b5-cartography-weights"
FILENAME = "best_model (3).pth"   # checkpoint file name inside REPO_ID
DEVICE = "cpu"                    # all inference in this app runs on CPU
# None when HF_TOKEN is unset; presumably a token is only needed for a private/gated repo.
hf_token = os.environ.get("HF_TOKEN")

print(f"⬇️ Fetching B5 Weights from {REPO_ID}...")
try:
    # Local filesystem path of the downloaded checkpoint, consumed by the model loader.
    MODEL_PATH = hf_hub_download(
        repo_id=REPO_ID,
        filename=FILENAME,
        repo_type="model",
        token=hf_token,
    )
except Exception as e:
    # Fail fast at boot: nothing below can run without the weights. Chain the
    # original exception so its traceback (auth error, 404, network) is preserved.
    raise RuntimeError(f"❌ Failed to download weights. Check REPO_ID and HF_TOKEN! Error: {e}") from e
print("✅ Weights downloaded!")
# ==========================================
# 2. LOAD PYTORCH MODEL
# ==========================================
class UnifiedCartographer(nn.Module):
    """SegFormer-B5 segmentation wrapper returning logits at the input resolution.

    The backbone is initialised from the Cityscapes-finetuned B5 checkpoint with
    the classification head resized to ``num_classes``
    (``ignore_mismatched_sizes=True``). The project's fine-tuned weights are then
    loaded on top of it below.
    """

    def __init__(self, num_classes=5):
        super().__init__()
        self.model = SegformerForSemanticSegmentation.from_pretrained(
            "nvidia/segformer-b5-finetuned-cityscapes-1024-1024",
            num_labels=num_classes,
            ignore_mismatched_sizes=True,
        )

    def forward(self, x):
        """Return per-class logits for ``x`` with the same height/width as ``x``.

        ``x`` is passed to SegFormer as ``pixel_values``. It is presumably a
        normalised (batch, 3, H, W) float tensor, as built in
        extract_buildings_locally.
        """
        outputs = self.model(pixel_values=x)
        # Upsample the logits to the input's spatial size so the argmax mask
        # lines up pixel-for-pixel with the image.
        return F.interpolate(outputs.logits, size=x.shape[-2:], mode="bilinear", align_corners=False)


def _remap_checkpoint_keys(state_dict, model_keys):
    """Rename checkpoint keys so they match the ``UnifiedCartographer`` parameter names.

    Args:
        state_dict: mapping of checkpoint parameter names to tensors.
        model_keys: set of parameter names expected by the target model.

    Returns:
        A new dict with the same values. A leading ``module.`` prefix is stripped
        first. The wrapper's ``model.`` prefix is then added when only the
        prefixed name exists in ``model_keys``.
    """
    remapped = {}
    for key, value in state_dict.items():
        if key.startswith("module."):
            key = key[len("module."):]
        # Checked after stripping "module." so a "module."-prefixed raw SegFormer
        # key also gains the wrapper's "model." prefix.
        if not key.startswith("model.") and f"model.{key}" in model_keys:
            key = f"model.{key}"
        remapped[key] = value
    return remapped


print("🧠 Loading B5 Model into Memory...")
ai_model = UnifiedCartographer(num_classes=5)
# NOTE(review): torch.load unpickles MODEL_PATH, which comes from a remote Hub repo.
# Only point REPO_ID at a repository you trust. Consider passing weights_only=True
# once the checkpoint is confirmed to hold only tensors and plain containers.
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
# The checkpoint may be a training bundle with a 'model_state_dict' entry or a bare state dict.
state_dict = checkpoint.get('model_state_dict', checkpoint)

# Build the target key set once instead of calling state_dict() for every key.
clean_state_dict = _remap_checkpoint_keys(state_dict, set(ai_model.state_dict()))
load_result = ai_model.load_state_dict(clean_state_dict, strict=False)
# strict=False never raises on mismatches. Without this report, weights that failed
# to map would silently keep their pretrained/freshly-initialised values.
if load_result.missing_keys:
    print(f"⚠️ {len(load_result.missing_keys)} model keys not found in checkpoint, e.g. {load_result.missing_keys[:5]}")
if load_result.unexpected_keys:
    print(f"⚠️ {len(load_result.unexpected_keys)} checkpoint keys unused by the model, e.g. {load_result.unexpected_keys[:5]}")
ai_model.to(DEVICE)
ai_model.eval()
print("✅ AI Engine Online!")
# ==========================================
# 3. LOCAL INFERENCE & MATH LOGIC
# ==========================================
def extract_buildings_locally(img_array, max_size=1024, building_class=1):
    """Runs inference directly in memory (no network delays).

    Args:
        img_array: image array, assumed to be H x W x 3 uint8 RGB (it is scaled
            by 1/255 and normalised with 3-channel statistics below).
        max_size: images whose longest side exceeds this are downscaled first,
            to keep CPU RAM bounded.
        building_class: class index in the model output treated as "building".

    Returns:
        ``(building_mask, img_array)``. ``building_mask`` is a uint8 mask
        (255 = building, 0 = anything else) with the same height and width as
        the returned image. ``img_array`` is the possibly downscaled image that
        was actually segmented, returned so callers can draw overlays on it.
    """
    h, w = img_array.shape[:2]
    if max(h, w) > max_size:
        scale = max_size / max(h, w)
        # max(1, ...) guards against a zero-sized side on extremely thin images.
        new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
        # INTER_AREA is OpenCV's recommended filter for shrinking (less aliasing).
        img_array = cv2.resize(img_array, new_size, interpolation=cv2.INTER_AREA)

    # Preprocess: HWC uint8 -> CHW float in [0, 1], then mean/std normalisation
    # (standard ImageNet statistics, presumably matching the training pipeline).
    input_tensor = torch.from_numpy(img_array.transpose(2, 0, 1)).float() / 255.0
    input_tensor = TF.normalize(input_tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    input_tensor = input_tensor.unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        logits = ai_model(input_tensor)
    # Take batch item 0 explicitly: .squeeze() would also collapse a height or width of 1.
    pred_mask = torch.argmax(logits, dim=1)[0].cpu().numpy()

    building_mask = np.where(pred_mask == building_class, 255, 0).astype(np.uint8)
    return building_mask, img_array


def _new_construction_mask(mask_past, mask_present):
    """Return a cleaned 0/255 mask of building pixels in ``mask_present`` but not ``mask_past``."""
    if mask_past.shape != mask_present.shape:
        # Nearest-neighbour keeps the mask strictly binary. The default bilinear
        # filter would blend edges into intermediate values.
        mask_past = cv2.resize(
            mask_past,
            (mask_present.shape[1], mask_present.shape[0]),
            interpolation=cv2.INTER_NEAREST,
        )
    # Saturating uint8 subtraction: pixels that were already buildings drop to 0.
    raw_new_construction = cv2.subtract(mask_present, mask_past)
    # Morphological opening removes small speckles. A light dilation then
    # slightly grows the regions that survive.
    cleaned = cv2.morphologyEx(raw_new_construction, cv2.MORPH_OPEN, np.ones((5, 5), np.uint8))
    return cv2.dilate(cleaned, np.ones((3, 3), np.uint8), iterations=1)


def process_temporal_change(img_past, img_present):
    """Highlight building footprints present in ``img_present`` but not in ``img_past``.

    Returns:
        ``(final_dashboard, mask_display)``. ``final_dashboard`` is the
        (possibly downscaled) present image with new construction blended in
        cyan. ``mask_display`` is the new-construction mask as a 3-channel
        image. Returns ``(None, None)`` if either input is missing.
    """
    if img_past is None or img_present is None:
        return None, None

    print("📡 Processing Year 1...")
    mask_y1, _ = extract_buildings_locally(img_past)
    print("📡 Processing Year 2...")
    mask_y2, resized_present = extract_buildings_locally(img_present)

    print("🧮 Calculating Urban Growth...")
    clean_new_construction = _new_construction_mask(mask_y1, mask_y2)

    # Paint new construction cyan ([0, 255, 255] in RGB order), then blend with the original.
    overlay = resized_present.copy()
    overlay[clean_new_construction > 0] = [0, 255, 255]  # Neon Cyan
    final_dashboard = cv2.addWeighted(resized_present, 0.4, overlay, 0.6, 0)
    mask_display = cv2.cvtColor(clean_new_construction, cv2.COLOR_GRAY2RGB)
    print("✅ Done!")
    return final_dashboard, mask_display


# ==========================================
# 4. FASTAPI & GRADIO INTEGRATION
# ==========================================
# Initialize FastAPI
app = FastAPI(title="Monolithic Temporal Cartography API")


# Define an API health route. It is registered before Gradio is mounted at "/"
# below, so it is matched first and not shadowed by the UI.
@app.get("/health")
def read_root():
    return {"status": "Online", "architecture": "Monolith FastAPI + Gradio"}


# Build the Gradio UI
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    gr.Markdown("# 🏙️ Temporal Urban Growth Tracker (Direct AI Engine)")
    gr.Markdown("Upload a past and present satellite image. The AI processes these locally in memory, subtracts the footprint history, and highlights brand new construction.")

    with gr.Row():
        with gr.Column():
            img_past = gr.Image(label="1. Past (Year 1)", type="numpy")
        with gr.Column():
            img_present = gr.Image(label="2. Present (Year 2)", type="numpy")

    btn_detect = gr.Button("Analyze Growth", variant="primary")

    with gr.Row():
        with gr.Column():
            output_mask = gr.Image(label="3. Extracted New Construction Mask")
        with gr.Column():
            output_overlay = gr.Image(label="4. Growth Highlighted (Neon Cyan)")

    # process_temporal_change returns (overlay, mask), hence this output order.
    btn_detect.click(
        fn=process_temporal_change,
        inputs=[img_past, img_present],
        outputs=[output_overlay, output_mask],
    )

# Mount Gradio onto the root path of the FastAPI server
app = gr.mount_gradio_app(app, demo, path="/")