# segformer_b5 / app.py
# Last commit b30e7f1 ("Update app.py", verified) by Amrender.
import base64
import io
import os
import threading

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from huggingface_hub import hf_hub_download
from PIL import Image
from skimage.morphology import skeletonize
from transformers import SegformerForSemanticSegmentation
print("🚀 BOOTING FASTAPI PRODUCTION B5 ENGINE...")
# ==========================================
# 1. CONFIG & MODEL DOWNLOAD
# ==========================================
REPO_ID = "Amrender/b5-cartography-weights"
FILENAME = "best_model (3).pth"
DEVICE = "cpu"
# Access token for the weights repo, supplied as an environment variable
# (a Space secret). When unset this is None and huggingface_hub falls back to
# its default credential lookup.
hf_token = os.environ.get("HF_TOKEN")
print(f"⬇️ Fetching B5 Weights from {REPO_ID}...")
try:
    MODEL_PATH = hf_hub_download(
        repo_id=REPO_ID,
        filename=FILENAME,
        repo_type="model",
        token=hf_token,
    )
except Exception as e:
    # Startup boundary: any download failure (wrong repo/file name, bad or
    # missing token, network) must abort the boot. Chain the original error
    # so its traceback is preserved in the logs.
    raise RuntimeError(f"❌ Failed to download weights. Check your REPO_ID! Error: {e}") from e
print(f"✅ Weights successfully downloaded to: {MODEL_PATH}")
# ==========================================
# 2. POST-PROCESSING ENGINES (Unchanged)
# ==========================================
def split_plots(binary_mask, peak_window=15, peak_threshold=0.05):
    """Split touching blobs in a binary mask using a distance-transform watershed.

    Seeds are local maxima of the (normalized) distance transform of the
    slightly eroded mask. A marker-based watershed then finds ridge lines
    between neighbouring seeds. Those ridges are cleared from the mask.

    Args:
        binary_mask: single-channel uint8 mask with foreground pixels at 255
            (as built by ``predict_map``). NOTE(review): cv2.distanceTransform
            and cv2.cvtColor(GRAY2BGR) require 8-bit single-channel input.
        peak_window: side length of the square window used to detect local
            maxima of the distance map. Larger values give fewer seeds, and
            therefore fewer splits.
        peak_threshold: minimum normalized distance (0..1) a local maximum
            must exceed to become a seed.

    Returns:
        ``(split_mask, boundaries)``:
        - ``split_mask`` is a copy of ``binary_mask`` with watershed ridge
          pixels set to 0.
        - ``boundaries`` is a uint8 image that is 255 on those ridge pixels
          and 0 elsewhere.
    """
    kernel = np.ones((3, 3), np.uint8)
    # Erode once so narrow necks between touching blobs shrink before distances are measured.
    eroded = cv2.erode(binary_mask, kernel, iterations=1)
    dist_transform = cv2.distanceTransform(eroded, cv2.DIST_L2, 5)
    # Normalize in place to [0, 1] so peak_threshold does not depend on object size.
    cv2.normalize(dist_transform, dist_transform, 0, 1.0, cv2.NORM_MINMAX)

    # A pixel is a seed if it equals the maximum of its neighbourhood and is deep enough.
    local_max = cv2.dilate(dist_transform, np.ones((peak_window, peak_window), np.uint8))
    peaks = (dist_transform == local_max) & (dist_transform > peak_threshold)
    sure_fg = np.zeros_like(dist_transform, dtype=np.uint8)
    sure_fg[peaks] = 255
    sure_fg = cv2.dilate(sure_fg, kernel, iterations=1)

    sure_bg = cv2.dilate(eroded, kernel, iterations=2)
    unknown = cv2.subtract(sure_bg, sure_fg)

    # Label the seeds, then shift by 1 so the background label becomes 1 and
    # 0 is free to mean "unknown - let watershed decide".
    _, markers = cv2.connectedComponents(sure_fg)
    markers = markers + 1
    markers[unknown == 255] = 0

    fake_rgb = cv2.cvtColor(binary_mask, cv2.COLOR_GRAY2BGR)
    markers = cv2.watershed(fake_rgb, markers)

    ridges = markers == -1
    # cv2.watershed unconditionally marks the outermost 1-pixel frame of the
    # image as -1. That frame is not a split line, so exclude it; otherwise
    # blobs touching the image edge would lose their edge pixels.
    ridges[0, :] = False
    ridges[-1, :] = False
    ridges[:, 0] = False
    ridges[:, -1] = False

    boundaries = np.zeros_like(binary_mask)
    boundaries[ridges] = 255
    split_mask = binary_mask.copy()
    split_mask[ridges] = 0
    return split_mask, boundaries
def regularize_roads(binary_road_mask, avg_width=10, gap_bridge=20, smooth_factor=0.003):
    """Straighten a noisy road mask and redraw it at a uniform width.

    Pipeline:
    1. A morphological closing bridges small gaps.
    2. Every region boundary is simplified with Douglas-Peucker and the
       polygons are refilled.
    3. The result is skeletonized to centerlines.
    4. The centerlines are dilated back ("paved") to a constant width.

    Args:
        binary_road_mask: single-channel uint8 mask, 255 on road pixels.
        avg_width: diameter in pixels of the elliptical brush used to re-pave
            the centerlines.
        gap_bridge: diameter in pixels of the elliptical closing kernel.
        smooth_factor: Douglas-Peucker tolerance, as a fraction of each
            contour's perimeter.

    Returns:
        uint8 mask with the same shape as the input, 255 on the regularized roads.
    """
    close_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (gap_bridge, gap_bridge))
    closed_roads = cv2.morphologyEx(binary_road_mask, cv2.MORPH_CLOSE, close_kernel)

    # RETR_LIST returns every boundary (outer borders and hole borders); nesting
    # information is not needed because the fill below uses parity.
    contours, _ = cv2.findContours(closed_roads, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
    straight_roads = np.zeros_like(closed_roads)
    if contours:
        polygons = [
            cv2.approxPolyDP(cnt, smooth_factor * cv2.arcLength(cnt, True), True)
            for cnt in contours
        ]
        # One fillPoly call over all polygons uses parity filling:
        # - holes stay empty;
        # - road regions nested inside holes are filled again;
        # - the result is independent of findContours' output order.
        # Painting outers with 255 and holes with 0 one by one would let a hole
        # erase a road component drawn earlier inside it.
        cv2.fillPoly(straight_roads, polygons, 255)

    bool_mask = straight_roads > 127
    skeleton = skeletonize(bool_mask)
    skeleton_img = (skeleton * 255).astype(np.uint8)
    pave_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (avg_width, avg_width))
    uniform_roads = cv2.dilate(skeleton_img, pave_kernel, iterations=1)
    return uniform_roads
# ==========================================
# 3. GLOBAL MODEL LOADER
# ==========================================
class UnifiedCartographer(nn.Module):
    """SegFormer-B5 segmenter whose logits are resized back to the input resolution.

    Wraps ``SegformerForSemanticSegmentation`` loaded from the
    ``nvidia/segformer-b5-finetuned-cityscapes-1024-1024`` checkpoint, with the
    classification head sized to ``num_classes`` outputs.
    """

    def __init__(self, num_classes=5):
        super().__init__()
        backbone_id = "nvidia/segformer-b5-finetuned-cityscapes-1024-1024"
        # ignore_mismatched_sizes lets the pretrained head be replaced by a
        # num_classes head instead of failing on the shape mismatch.
        self.model = SegformerForSemanticSegmentation.from_pretrained(
            backbone_id,
            ignore_mismatched_sizes=True,
            num_labels=num_classes,
        )

    def forward(self, x):
        """Return class logits bilinearly resized to the spatial size of ``x``.

        Assumes ``x`` is a (N, 3, H, W) normalized image batch, as built in
        ``predict_map``.
        """
        segformer_out = self.model(pixel_values=x)
        target_hw = tuple(x.shape[-2:])
        return F.interpolate(
            segformer_out.logits,
            size=target_hw,
            mode="bilinear",
            align_corners=False,
        )
print("🧠 Loading B5 Model into Memory...")
ai_model = UnifiedCartographer(num_classes=5)
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
# The checkpoint comes from our own repo. Still, pass weights_only=True once
# it is confirmed to contain only tensors and plain containers.
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
# Accept either a training checkpoint ({'model_state_dict': ...}) or a bare state_dict.
state_dict = checkpoint.get('model_state_dict', checkpoint)

# Build the key set once: state_dict() rebuilds the whole mapping on every call.
model_keys = set(ai_model.state_dict())
clean_state_dict = {}
for k, v in state_dict.items():
    # First strip the DataParallel/DDP wrapper prefix...
    if k.startswith('module.'):
        k = k[len('module.'):]
    # ...then re-root keys saved from the bare Segformer under this wrapper's
    # `model.` attribute. Doing this after the strip also remaps
    # `module.segformer...`-style checkpoints.
    if not k.startswith('model.') and f"model.{k}" in model_keys:
        k = f"model.{k}"
    clean_state_dict[k] = v

# strict=False tolerates a few non-matching keys. A load in which nothing
# matched, however, would silently serve the pretrained backbone with an
# untrained head, so refuse to start in that case.
load_result = ai_model.load_state_dict(clean_state_dict, strict=False)
if len(load_result.unexpected_keys) == len(clean_state_dict):
    raise RuntimeError("❌ No checkpoint keys matched the model; refusing to start with untrained weights.")
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        f"⚠️ Partial weight load: {len(load_result.missing_keys)} missing, "
        f"{len(load_result.unexpected_keys)} unexpected keys."
    )
ai_model.to(DEVICE)
ai_model.eval()
print("✅ Custom Satellite Weights successfully loaded!")
# ==========================================
# 4. FASTAPI APP & ROUTES
# ==========================================
# FastAPI application object; the routes below register on it via decorators.
app = FastAPI(
    version="1.0",
    title="AI Cartography API",
)
def encode_image_to_base64(img_array, ext=".jpg"):
    """Encode an RGB image array as a base64 string of the compressed image bytes.

    Args:
        img_array: H x W x 3 uint8 array in RGB channel order.
        ext: file extension selecting the OpenCV encoder. ``".jpg"`` (default)
            is lossy; ``".png"`` is lossless and better suited to flat-colour
            masks.

    Returns:
        ASCII base64 string of the encoded image (no ``data:`` URI prefix).

    Raises:
        ValueError: if OpenCV reports that encoding failed.
    """
    # OpenCV encoders expect BGR channel order.
    img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
    ok, buffer = cv2.imencode(ext, img_bgr)
    if not ok:
        raise ValueError(f"cv2.imencode failed to encode image as {ext!r}")
    return base64.b64encode(buffer).decode('utf-8')
@app.get("/")
def read_root():
    """Health check: report that the service is up and which model it serves."""
    return dict(status="Online", model="SegFormer B5")
# Model output class indices (as used by the masks below) and their RGB display colours.
_BUILDING_CLASS = 1
_ROAD_CLASS = 2
_BUILDING_RGB = [230, 57, 70]
_ROAD_RGB = [244, 162, 97]
# Serializes model forward passes across worker threads, so concurrent
# requests do not run several B5 forwards at once (bounds peak memory).
_INFERENCE_LOCK = threading.Lock()


@app.post("/predict")
async def predict_map(file: UploadFile = File(...)):
    """Receive an image, segment it, and return base64-encoded rendered maps.

    Returns:
        JSON object with:
        - ``status``;
        - ``master_map_base64``: regularized buildings/roads blended over the photo;
        - ``raw_mask_base64``: unprocessed class mask on black.
        Both maps are base64-encoded JPEGs.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    contents = await file.read()
    # Decoding, inference and OpenCV post-processing are CPU-bound. Running
    # them directly in this coroutine would block the event loop, and with it
    # every other request (including GET /), for the whole prediction.
    return await run_in_threadpool(_predict_from_bytes, contents)


def _predict_from_bytes(contents):
    """Run the full synchronous pipeline: decode -> segment -> post-process -> encode."""
    raw_img_rgb = _decode_rgb_upload(contents)
    pred_mask = _segment(raw_img_rgb)
    master_blended, raw_semantic_view = _render_maps(raw_img_rgb, pred_mask)
    return {
        "status": "success",
        "master_map_base64": encode_image_to_base64(master_blended),
        "raw_mask_base64": encode_image_to_base64(raw_semantic_view),
    }


def _decode_rgb_upload(contents):
    """Decode raw upload bytes into an H x W x 3 uint8 RGB array, or raise HTTP 400."""
    try:
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        # PIL raises UnidentifiedImageError (an OSError subclass) for empty or
        # non-image uploads. That is a client error, not a server failure.
        raise HTTPException(status_code=400, detail="Uploaded file is not a readable image.") from exc
    return np.array(image)


def _segment(raw_img_rgb):
    """Run the model on an RGB array and return the per-pixel argmax class map (H x W)."""
    # HWC uint8 -> CHW float in [0, 1], then standard ImageNet mean/std normalization.
    input_tensor = torch.from_numpy(raw_img_rgb.transpose(2, 0, 1)).float() / 255.0
    input_tensor = TF.normalize(
        input_tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
    ).unsqueeze(0).to(DEVICE)
    with _INFERENCE_LOCK, torch.no_grad():
        logits = ai_model(input_tensor)
    # Index the batch dimension rather than squeeze(), so 1-pixel-high or
    # 1-pixel-wide images keep their 2-D shape.
    return torch.argmax(logits, dim=1)[0].cpu().numpy()


def _render_maps(raw_img_rgb, pred_mask):
    """Post-process the class map and return ``(master_blended, raw_semantic_view)`` RGB images."""
    building_mask = np.zeros_like(pred_mask, dtype=np.uint8)
    building_mask[pred_mask == _BUILDING_CLASS] = 255
    clean_buildings, _ = split_plots(building_mask)

    road_mask = np.zeros_like(pred_mask, dtype=np.uint8)
    road_mask[pred_mask == _ROAD_CLASS] = 255
    clean_roads = regularize_roads(road_mask, avg_width=10, gap_bridge=20, smooth_factor=0.003)

    # Master map: paint regularized roads, then buildings on top, onto a copy
    # of the photo; then blend 40% photo with 60% overlay.
    master_overlay = raw_img_rgb.copy()
    master_overlay[clean_roads == 255] = _ROAD_RGB
    master_overlay[clean_buildings == 255] = _BUILDING_RGB
    master_blended = cv2.addWeighted(raw_img_rgb, 0.4, master_overlay, 0.6, 0)

    # Raw view: the unprocessed model classes on a black background.
    raw_semantic_view = np.zeros_like(raw_img_rgb)
    raw_semantic_view[building_mask == 255] = _BUILDING_RGB
    raw_semantic_view[road_mask == 255] = _ROAD_RGB
    return master_blended, raw_semantic_view