# Scraped from a Hugging Face Space file view (Space status: Sleeping; file size: 7,286 bytes).
import io
import base64
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
from PIL import Image
import torchvision.transforms.functional as TF
from skimage.morphology import skeletonize
from transformers import SegformerForSemanticSegmentation
from fastapi import FastAPI, UploadFile, File
from huggingface_hub import hf_hub_download
print("🚀 BOOTING FASTAPI PRODUCTION B5 ENGINE...")
# ==========================================
# 1. CONFIG & MODEL DOWNLOAD
# ==========================================
import os
from huggingface_hub import hf_hub_download

# Hub repository and file that hold the fine-tuned checkpoint, and the torch
# device string used for loading the weights and running inference.
REPO_ID = "Amrender/b5-cartography-weights"
FILENAME = "best_model (3).pth"
DEVICE = "cpu"

# Access token read from the HF_TOKEN environment variable; None when unset.
hf_token = os.environ.get("HF_TOKEN")

print(f"⬇️ Fetching B5 Weights from {REPO_ID}...")
try:
    # Only the download call sits inside the try so that unrelated failures
    # are not reported as download errors.
    MODEL_PATH = hf_hub_download(
        repo_id=REPO_ID,
        filename=FILENAME,
        repo_type="model",
        token=hf_token,
    )
except Exception as e:
    # Startup boundary: abort boot with a readable message and keep the
    # original exception attached as the explicit cause.
    raise RuntimeError(f"❌ Failed to download weights. Check your REPO_ID! Error: {e}") from e
print(f"✅ Weights successfully downloaded to: {MODEL_PATH}")
# ==========================================
# 2. POST-PROCESSING ENGINES (Unchanged)
# ==========================================
def split_plots(binary_mask):
    """Split touching regions of a binary mask with a marker-based watershed.

    Args:
        binary_mask: Single-channel mask handed directly to OpenCV. The caller
            in this file passes a uint8 array containing 0 and 255.

    Returns:
        A ``(split_mask, boundaries)`` tuple. ``split_mask`` is a copy of
        ``binary_mask`` with every watershed-line pixel cleared to 0;
        ``boundaries`` has the same shape and dtype as ``binary_mask`` and is
        255 on those watershed-line pixels, 0 elsewhere.
    """
    small_kernel = np.ones((3, 3), np.uint8)
    shrunk = cv2.erode(binary_mask, small_kernel, iterations=1)

    # Distance of each pixel to the nearest zero pixel, rescaled in place to [0, 1].
    distance = cv2.distanceTransform(shrunk, cv2.DIST_L2, 5)
    cv2.normalize(distance, distance, 0, 1.0, cv2.NORM_MINMAX)

    # Seeds: pixels equal to the maximum of their 15x15 neighbourhood whose
    # normalised distance also exceeds 0.05.
    neighbourhood_max = cv2.dilate(distance, np.ones((15, 15), np.uint8))
    is_seed = (distance == neighbourhood_max) & (distance > 0.05)
    seeds = np.zeros_like(distance, dtype=np.uint8)
    seeds[is_seed] = 255
    seeds = cv2.dilate(seeds, small_kernel, iterations=1)

    # Pixels inside the re-grown mask that are not seeds are left undecided
    # (label 0) so the watershed assigns them.
    regrown = cv2.dilate(shrunk, small_kernel, iterations=2)
    undecided = cv2.subtract(regrown, seeds)

    _, labels = cv2.connectedComponents(seeds)
    labels = labels + 1  # shift so that 0 is free to mean "undecided"
    labels[undecided == 255] = 0

    # cv2.watershed needs a 3-channel image; the mask itself is used as that image.
    labels = cv2.watershed(cv2.cvtColor(binary_mask, cv2.COLOR_GRAY2BGR), labels)
    # NOTE(review): cv2.watershed appears to label the outer 1-px image frame
    # as -1 as well — confirm whether the frame should be excluded here.
    on_ridge = labels == -1

    boundaries = np.zeros_like(binary_mask)
    boundaries[on_ridge] = 255
    split_mask = binary_mask.copy()
    split_mask[on_ridge] = 0
    return split_mask, boundaries
def regularize_roads(binary_road_mask, avg_width=10, gap_bridge=20, smooth_factor=0.003):
    """Redraw a road mask as uniform-width strokes along its skeleton.

    Args:
        binary_road_mask: Single-channel mask handed directly to OpenCV. The
            caller in this file passes a uint8 array containing 0 and 255.
        avg_width: Side length of the elliptical kernel used to thicken the
            skeleton.
        gap_bridge: Side length of the elliptical kernel used for the
            morphological closing.
        smooth_factor: Fraction of each contour's perimeter used as the
            polygon-approximation tolerance.

    Returns:
        A uint8 mask: the skeleton of the simplified road shapes (as 0/255)
        dilated once with the ``avg_width`` kernel.
    """
    bridge_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (gap_bridge, gap_bridge))
    bridged = cv2.morphologyEx(binary_road_mask, cv2.MORPH_CLOSE, bridge_kernel)

    outlines, tree = cv2.findContours(bridged, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)
    simplified = np.zeros_like(bridged)
    if tree is not None:
        # tree[0][i] is the hierarchy record of outlines[i]; entry 3 is its parent index.
        for outline, node in zip(outlines, tree[0]):
            tolerance = smooth_factor * cv2.arcLength(outline, True)
            polygon = cv2.approxPolyDP(outline, tolerance, True)
            # Contours without a parent are filled with 255, the others with 0.
            fill_value = 255 if node[3] == -1 else 0
            cv2.drawContours(simplified, [polygon], -1, fill_value, -1)

    centerlines = skeletonize(simplified > 127)
    centerline_img = (centerlines * 255).astype(np.uint8)
    brush = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (avg_width, avg_width))
    return cv2.dilate(centerline_img, brush, iterations=1)
# ==========================================
# 3. GLOBAL MODEL LOADER
# ==========================================
class UnifiedCartographer(nn.Module):
    """SegFormer wrapper whose forward pass returns logits at the input's spatial size."""

    def __init__(self, num_classes=5):
        """Build the wrapped SegFormer model.

        Args:
            num_classes: Passed as ``num_labels`` to ``from_pretrained``.
        """
        super().__init__()
        # ignore_mismatched_sizes=True is passed together with num_labels so
        # that from_pretrained does not reject weights whose shapes differ.
        backbone = SegformerForSemanticSegmentation.from_pretrained(
            "nvidia/segformer-b5-finetuned-cityscapes-1024-1024",
            num_labels=num_classes,
            ignore_mismatched_sizes=True,
        )
        self.model = backbone

    def forward(self, x):
        """Run the model on ``x`` and bilinearly resize the logits to ``x``'s last two dims."""
        raw_logits = self.model(pixel_values=x).logits
        target_size = x.shape[-2:]
        return F.interpolate(raw_logits, size=target_size, mode="bilinear", align_corners=False)
print("🧠 Loading B5 Model into Memory...")
ai_model = UnifiedCartographer(num_classes=5)

# NOTE(security): torch.load unpickles the downloaded file; only point REPO_ID
# at a repository you trust (consider weights_only=True where the installed
# torch version and the checkpoint contents allow it).
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
# Accept either a training checkpoint holding 'model_state_dict' or a bare state dict.
state_dict = checkpoint.get('model_state_dict', checkpoint)

# Snapshot the expected parameter names once instead of rebuilding the
# model's state dict for every checkpoint key.
expected_keys = set(ai_model.state_dict())

clean_state_dict = {}
for k, v in state_dict.items():
    # Strip the 'module.' prefix first, then apply the 'model.' prefix check to
    # the stripped name as well, so keys such as 'module.segformer...' still
    # map onto this wrapper's 'model.segformer...' parameters.
    if k.startswith('module.'):
        k = k[7:]
    if not k.startswith('model.') and f"model.{k}" in expected_keys:
        k = f"model.{k}"
    clean_state_dict[k] = v

# strict=False tolerates partial matches, so inspect the outcome instead of
# announcing success unconditionally.
load_result = ai_model.load_state_dict(clean_state_dict, strict=False)
if not expected_keys.intersection(clean_state_dict):
    raise RuntimeError(
        "❌ No checkpoint key matched the model's parameters; "
        "the custom weights were not applied."
    )
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        f"⚠️ Partial weight load: {len(load_result.missing_keys)} missing, "
        f"{len(load_result.unexpected_keys)} unexpected keys."
    )

ai_model.to(DEVICE)
ai_model.eval()
print("✅ Custom Satellite Weights successfully loaded!")
# ==========================================
# 4. FASTAPI APP & ROUTES
# ==========================================
# Application instance on which the route decorators in this file register.
app = FastAPI(
    title="AI Cartography API",
    version="1.0",
)
def encode_image_to_base64(img_array):
    """Encode an RGB image array as a base64 JPEG string.

    Args:
        img_array: Three-channel RGB image array accepted by ``cv2.cvtColor``.

    Returns:
        The JPEG-encoded image, base64-encoded and returned as ``str``.

    Raises:
        RuntimeError: If OpenCV reports that JPEG encoding did not succeed.
    """
    # cv2.imencode expects BGR channel order, so swap from RGB first.
    img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
    ok, buffer = cv2.imencode('.jpg', img_bgr)
    if not ok:
        # Fail loudly rather than base64-encoding an invalid buffer.
        raise RuntimeError("JPEG encoding failed")
    return base64.b64encode(buffer).decode('utf-8')
@app.get("/")
def read_root():
    """Return a static payload with the service status and model name."""
    payload = {"status": "Online", "model": "SegFormer B5"}
    return payload
@app.post("/predict")
async def predict_map(file: UploadFile = File(...)):
    """Segment an uploaded image and return two rendered maps as base64 JPEGs.

    Args:
        file: The uploaded image file.

    Returns:
        A dict with ``status`` set to ``"success"``, ``master_map_base64``
        (the input blended with the post-processed road and building overlay)
        and ``raw_mask_base64`` (the unprocessed class-1 / class-2 masks
        painted on a black canvas).

    Raises:
        HTTPException: With status 400 when the upload cannot be decoded as
            an image.
    """
    # Local import: HTTPException is not among this file's top-level imports.
    from fastapi import HTTPException

    # 1. Read the uploaded file into an RGB numpy array
    contents = await file.read()
    try:
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as exc:
        # A bad upload is a client error, not an internal server error.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a readable image.",
        ) from exc
    raw_img_rgb = np.array(image)

    # 2. Preprocess: HWC uint8 -> CHW float in [0, 1], then per-channel normalisation
    input_tensor = torch.from_numpy(raw_img_rgb.transpose(2, 0, 1)).float() / 255.0
    input_tensor = TF.normalize(
        input_tensor,
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ).unsqueeze(0).to(DEVICE)

    # 3. Inference
    # NOTE(review): the model runs inline inside this coroutine — confirm that
    # blocking the event loop for the duration of a request is acceptable.
    with torch.no_grad():
        logits = ai_model(input_tensor)
        # squeeze(0) drops only the batch axis; a bare squeeze() would also
        # collapse a height or width of 1 and break the 2-D masks below.
        pred_mask = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy()

    # 4. Math Post-Processing
    building_mask = np.zeros_like(pred_mask, dtype=np.uint8)
    building_mask[pred_mask == 1] = 255
    # The boundary mask returned by split_plots is not rendered, so it is discarded.
    clean_buildings, _ = split_plots(building_mask)

    road_mask = np.zeros_like(pred_mask, dtype=np.uint8)
    road_mask[pred_mask == 2] = 255
    clean_roads = regularize_roads(road_mask, avg_width=10, gap_bridge=20, smooth_factor=0.003)

    # 5. Render Final Maps
    # Buildings are painted after roads, so they win where the two overlap.
    master_overlay = raw_img_rgb.copy()
    master_overlay[clean_roads == 255] = [244, 162, 97]
    master_overlay[clean_buildings == 255] = [230, 57, 70]
    master_blended = cv2.addWeighted(raw_img_rgb, 0.4, master_overlay, 0.6, 0)

    raw_semantic_view = np.zeros_like(raw_img_rgb)
    raw_semantic_view[building_mask == 255] = [230, 57, 70]
    raw_semantic_view[road_mask == 255] = [244, 162, 97]

    # 6. Return as JSON
    return {
        "status": "success",
        "master_map_base64": encode_image_to_base64(master_blended),
        "raw_mask_base64": encode_image_to_base64(raw_semantic_view),
    }