ProjectX / app.py
john5050's picture
Update app.py
65cfbb9 verified
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
from torchvision import transforms
from PIL import Image
import numpy as np
import base64
import io
import cv2
from flask import Flask, request, jsonify, send_from_directory
import os
app = Flask(__name__, static_folder="static")

# ── Model — exact same as training ────────────────────────────
# Use the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# U-Net with an EfficientNet-B3 encoder: 3-channel input, 1 output channel.
model = smp.Unet(
    encoder_name="efficientnet-b3",
    in_channels=3,
    classes=1,
    encoder_weights=None,  # no pretrained needed at inference
    activation=None,  # raw logits, same as training
)

MODEL_PATH = "best_model.pth"
if not os.path.exists(MODEL_PATH):
    # No checkpoint on disk: the model keeps its freshly constructed weights.
    print("⚠️ best_model.pth not found β€” running in demo mode")
else:
    # map_location places the checkpoint tensors on the selected device.
    model.load_state_dict(
        torch.load(MODEL_PATH, map_location=device, weights_only=True)
    )
    print(f"βœ… Model loaded β€” device: {device}")

# Move the network to the target device and switch it to evaluation mode.
model.to(device)
model.eval()
# ── Preprocessing — matches val_transform from training ───────
# Per-channel normalization statistics, one value per colour channel of the
# RGB arrays they are applied to. Stored as float64 arrays so they broadcast
# over (H, W, 3) images.
mean = np.asarray((0.485, 0.456, 0.406), dtype=np.float64)
std = np.asarray((0.229, 0.224, 0.225), dtype=np.float64)
def preprocess(pil_img, patch_size=256):
    """Zero-pad an image up to a multiple of ``patch_size`` and normalize it.

    The image is converted to RGB and copied into the top-left corner of a
    zero canvas whose sides are rounded up to the next multiple of
    ``patch_size``. The canvas is then scaled by 1/255 and standardized with
    the module-level ``mean`` and ``std``. Padding happens before
    normalization, so padded pixels end up at ``(0 - mean) / std``.

    Args:
        pil_img: Image object exposing ``convert("RGB")`` (a PIL image).
        patch_size: Value the padded height and width must be a multiple of.

    Returns:
        ``(tensor, (h, w))``: a float32 tensor of shape
        ``(1, 3, padded_h, padded_w)`` and the original height and width.
    """
    rgb = np.array(pil_img.convert("RGB"))
    h, w = rgb.shape[:2]

    # Number of whole patches needed to cover each side (ceiling division).
    patch_rows = (h + patch_size - 1) // patch_size
    patch_cols = (w + patch_size - 1) // patch_size
    canvas = np.zeros(
        (patch_rows * patch_size, patch_cols * patch_size, 3), dtype=np.float32
    )
    canvas[:h, :w] = rgb

    # Scale by 1/255, then standardize per channel (broadcast over H and W).
    normalized = (canvas / 255.0 - mean) / std

    # HWC -> CHW, cast to float32, then add a leading batch dimension.
    chw = torch.tensor(normalized).permute(2, 0, 1).float()
    return chw.unsqueeze(0), (h, w)
def run_inference(pil_img, patch_size=256):
    """Segment an image tile by tile and stitch the probabilities together.

    The RGB image is zero-padded (bottom/right) to a multiple of
    ``patch_size`` and cut into non-overlapping ``patch_size`` x
    ``patch_size`` tiles. Each tile is scaled by 1/255, standardized with the
    module-level ``mean``/``std`` and passed through ``model`` on ``device``.
    A sigmoid turns the logits into probabilities, which are written back
    into the matching region of the full-size mask.

    Args:
        pil_img: Image object exposing ``convert("RGB")`` (a PIL image).
        patch_size: Side length of the square tiles fed to the model.

    Returns:
        ``(prob_mask, rgb)``: a float32 array of shape ``(h, w)`` cropped back
        to the original size, and the un-padded RGB array of the input.
    """
    rgb = np.array(pil_img.convert("RGB"))
    h, w = rgb.shape[:2]

    # Round both sides up to the next multiple of patch_size.
    padded_h = ((h + patch_size - 1) // patch_size) * patch_size
    padded_w = ((w + patch_size - 1) // patch_size) * patch_size
    canvas = np.zeros((padded_h, padded_w, 3), dtype=np.uint8)
    canvas[:h, :w] = rgb

    probabilities = np.zeros((padded_h, padded_w), dtype=np.float32)

    # Gradients are never needed here, so disable them for the whole sweep.
    with torch.no_grad():
        for top in range(0, padded_h, patch_size):
            rows = slice(top, top + patch_size)
            for left in range(0, padded_w, patch_size):
                cols = slice(left, left + patch_size)

                tile = canvas[rows, cols].astype(np.float32) / 255.0
                tile = (tile - mean) / std

                # HWC -> CHW, float32, batch of one, on the model's device.
                batch = (
                    torch.tensor(tile)
                    .permute(2, 0, 1)
                    .float()
                    .unsqueeze(0)
                    .to(device)
                )
                logits = model(batch)
                tile_prob = torch.sigmoid(logits).squeeze().cpu().numpy()
                probabilities[rows, cols] = tile_prob

    # Drop the padding before handing the mask back.
    return probabilities[:h, :w], rgb
# ── Zoning + illegal detection (from your notebook) ───────────
def create_zoning_mask(shape):
    """Build the hard-coded zoning map that flags the right half of the image.

    Args:
        shape: ``(height, width)`` of the mask to create.

    Returns:
        A ``uint8`` array of that shape holding 0 in the columns left of
        ``width // 2`` and 1 from column ``width // 2`` onward.
    """
    height, width = shape
    # One row marking the columns at or beyond the midpoint, repeated for
    # every image row.
    row = (np.arange(width) >= width // 2).astype(np.uint8)
    return np.tile(row, (height, 1))
def detect_illegal_buildings(binary_mask, zoning_mask):
    """Split connected building regions into illegal and legal ones.

    Connected components of ``binary_mask`` are labelled with
    ``cv2.connectedComponents``. A component counts as illegal as soon as at
    least one of its pixels lies where ``zoning_mask == 1``; every other
    component is legal.

    Args:
        binary_mask: Array that is non-zero where a building was detected;
            it is cast to ``uint8`` before labelling.
        zoning_mask: Array broadcastable against the label image, with 1
            marking the restricted zone.

    Returns:
        ``(illegal, legal, labels)``: two lists of component ids in ascending
        order (label 0 is skipped), and the label image returned by OpenCV.
    """
    num_labels, labels = cv2.connectedComponents(binary_mask.astype(np.uint8))

    # Single pass over the image: keep the labels only inside the restricted
    # zone (0 elsewhere) and collect the distinct ids that remain. This
    # replaces one full-image comparison per component.
    labels_in_zone = np.where(zoning_mask == 1, labels, 0)
    touching = set(np.unique(labels_in_zone).tolist())

    component_ids = range(1, num_labels)
    illegal = [label for label in component_ids if label in touching]
    legal = [label for label in component_ids if label not in touching]
    return illegal, legal, labels
def to_base64(arr_or_img):
    """Encode an image as a base64 string of its PNG bytes.

    Args:
        arr_or_img: Either a NumPy array, which is cast to ``uint8`` and
            wrapped with ``Image.fromarray``, or an object that already
            offers a PIL-style ``save(fp, format=...)`` method.

    Returns:
        The PNG file contents, base64-encoded and decoded to ``str``.
    """
    image = (
        Image.fromarray(arr_or_img.astype(np.uint8))
        if isinstance(arr_or_img, np.ndarray)
        else arr_or_img
    )
    # Write the PNG into an in-memory buffer rather than a file on disk.
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode()
# ── Routes ────────────────────────────────────────────────────
@app.route("/")
def index():
    """Serve ``index.html`` from the ``static`` directory for the root URL."""
    landing_page = "index.html"
    return send_from_directory("static", landing_page)
@app.route("/predict", methods=["POST"])
def predict():
    """Segment an uploaded image and report buildings in the restricted zone.

    Expects a multipart upload with the file under the ``image`` field.

    Returns:
        A JSON response with these keys:

        - ``verdict``: ``"ILLEGAL CONSTRUCTION DETECTED"`` when at least one
          component is illegal, otherwise ``"NO VIOLATION DETECTED"``.
        - ``illegal_count`` / ``legal_count`` / ``total_count``: number of
          connected building components in each class and overall.
        - ``illegal_percent``: share of image pixels that belong to illegal
          components, in percent, rounded to two decimals.
        - ``device``: the torch device the model runs on.
        - ``original`` / ``mask`` / ``overlay``: base64-encoded PNGs of the
          input, the probability mask scaled to 0-255, and the input with
          illegal components painted red and legal ones green.

        When the ``image`` field is missing or the file cannot be decoded,
        a JSON ``{"error": ...}`` body is returned with status 400.
    """
    if "image" not in request.files:
        return jsonify({"error": "No image provided"}), 400

    upload = request.files["image"]
    # The upload is untrusted input: a corrupt or non-image file is a client
    # error, so answer 400 instead of letting Pillow's exception become a 500.
    # (Pillow's UnidentifiedImageError is a subclass of OSError.)
    try:
        pil_img = Image.open(upload.stream).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError):
        return jsonify({"error": "Invalid or unsupported image"}), 400

    # Run patch-based segmentation and threshold the probabilities at 0.5.
    prob_mask, orig_rgb = run_inference(pil_img, patch_size=256)
    binary_mask = (prob_mask > 0.5).astype(np.uint8)

    # Zoning-based illegal detection.
    zoning_mask = create_zoning_mask(binary_mask.shape)
    illegal, legal, labels = detect_illegal_buildings(binary_mask, zoning_mask)

    # Pixel masks per class, computed once and reused for overlay and stats.
    illegal_pixels = np.isin(labels, illegal)
    legal_pixels = np.isin(labels, legal)

    # Build overlay: illegal=red, legal=green.
    overlay = orig_rgb.copy()
    overlay[illegal_pixels] = [255, 0, 0]
    overlay[legal_pixels] = [0, 200, 100]

    total = len(illegal) + len(legal)
    # Count only pixels of illegal components. Using the whole building mask
    # here would report a non-zero "illegal" share even when the verdict is
    # that nothing is in violation.
    illegal_pct = round(float(illegal_pixels.mean() * 100), 2)
    verdict = "ILLEGAL CONSTRUCTION DETECTED" if illegal else "NO VIOLATION DETECTED"

    return jsonify(
        {
            "verdict": verdict,
            "illegal_count": len(illegal),
            "legal_count": len(legal),
            "total_count": total,
            "illegal_percent": illegal_pct,
            "device": str(device),
            "original": to_base64(orig_rgb),
            "mask": to_base64((prob_mask * 255).astype(np.uint8)),
            "overlay": to_base64(overlay),
        }
    )
if __name__ == "__main__":
    print(f"Running on: {device}")
    # The server listens on every interface, so debug mode must not be on by
    # default: the interactive debugger lets whoever can reach it run
    # arbitrary Python code. Opt in for local development with FLASK_DEBUG=1.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").lower() in {
        "1",
        "true",
        "yes",
        "on",
    }
    app.run(host="0.0.0.0", port=7860, debug=debug_enabled)