# Commit 4fae684 (author: intisarhasnain):
# align app.py with original run.py: BGR->RGB fix, lower wall threshold
import os
import json
import numpy as np
from typing import Dict, Any
import cv2
import torch
import logging
import base64
import uvicorn
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import HTMLResponse, JSONResponse
from mmdet.apis import init_detector, inference_detector
import sys
# NOTE(review): this insert executes *after* `from mmdet.apis import ...`
# above, so it cannot influence which mmdet package gets imported; it only
# affects imports performed later. If the vendored ./mmdetection copy is meant
# to be used, the insert has to move above the mmdet import — confirm that
# copy is compatible with the installed mmcv/mmengine before doing so.
# Also: dirname(__file__) is not passed through abspath (unlike BASE_DIR
# below), so this entry may be a relative path.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "mmdetection"))
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model artefacts are resolved relative to this file (absolute path), so the
# app does not depend on the current working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
CONFIG_FILE = os.path.join(BASE_DIR, "configs", "faster_rcnn.py")
CHECKPOINT_FILE = os.path.join(BASE_DIR, "weights", "faster_rcnn_latest.pth")

# Upload size cap in bytes: 10 MiB.
MAX_FILE_SIZE = 10 << 20
# Minimum detection score kept in the JSON output and the drawn overlay;
# set low on purpose to catch more walls.
SCORE_THRESH = 0.3

# Model label id -> human-readable class name.
CLASS_NAMES = {0: "wall", 1: "room"}
# Model label id -> box colour, in RGB channel order (images are drawn in RGB).
CLASS_COLORS = {
    0: (220, 60, 60),  # wall — red
    1: (50, 200, 80),  # room — green
}
# ── Device ───────────────────────────────────────────────────────────────────
def determine_device():
    """Choose the torch device string used to load the detector.

    Returns "cuda:0" when torch reports CUDA as available and
    ``torch.cuda.init()`` succeeds. Returns "cpu" when CUDA is unavailable,
    or when initialisation raises (after logging a warning).
    """
    if not torch.cuda.is_available():
        return "cpu"
    try:
        torch.cuda.init()
    except Exception as e:
        # Broken driver/runtime: degrade gracefully instead of failing startup.
        logger.warning(f"CUDA failed: {e}. Using CPU.")
        return "cpu"
    return "cuda:0"
# ── Model load ───────────────────────────────────────────────────────────────
# Runs once at import time: pick the device, then build the detector that all
# requests share (referenced as the module-level `model` by the /detect path).
device = determine_device()
logger.info(f"Loading Faster R-CNN on {device}…")
# Loads CONFIG_FILE + CHECKPOINT_FILE; presumably raises if either is missing,
# which would abort app startup rather than the first request — confirm.
model = init_detector(CONFIG_FILE, CHECKPOINT_FILE, device=device)
logger.info("Model ready.")
# ── Result processing (mirrors original run.py exactly) ──────────────────────
def process_inference_result(result) -> Dict[str, Any]:
    """Turn an mmdet result (``result.pred_instances``) into the floor-plan JSON.

    Detections with a score below SCORE_THRESH are dropped. Label 0 is
    reported as a wall; every other label is reported as a room. Each kept
    detection becomes::

        {"id": "<wall|room>_<n>",
         "position": {"start": {"x", "y"}, "end": {"x", "y"}},
         "confidence": score}

    where ``n`` is the 1-based position in the *unfiltered* prediction list,
    so ids can skip numbers. ``start``/``end`` are the box's (x1, y1) and
    (x2, y2) corners.

    The top-level ``confidence`` is the mean score of the kept detections,
    or 0.0 when none pass the threshold.
    """
    preds = result.pred_instances
    boxes = preds.bboxes.cpu().numpy()
    classes = preds.labels.cpu().numpy()
    confs = preds.scores.cpu().numpy()

    buckets: Dict[str, list] = {"walls": [], "rooms": []}
    for idx, (box, cls, conf) in enumerate(zip(boxes, classes, confs), start=1):
        if conf < SCORE_THRESH:
            continue
        is_wall = cls == 0
        kind = "wall" if is_wall else "room"
        x_min, y_min, x_max, y_max = box
        start = {"x": float(x_min), "y": float(y_min)}
        end = {"x": float(x_max), "y": float(y_max)}
        buckets["walls" if is_wall else "rooms"].append(
            {
                "id": f"{kind}_{idx}",
                "position": {"start": start, "end": end},
                "confidence": float(conf),
            }
        )

    kept = confs[confs >= SCORE_THRESH]
    overall = float(np.mean(kept)) if len(kept) else 0.0
    return {
        "type": "floor_plan",
        "confidence": overall,
        "detectionResults": {"walls": buckets["walls"], "rooms": buckets["rooms"]},
    }
# ── Visualisation ─────────────────────────────────────────────────────────────
def draw_detections(img_rgb: np.ndarray, result) -> np.ndarray:
    """Return a copy of *img_rgb* with the kept detections drawn on it.

    Detections scoring below SCORE_THRESH, or whose label has no entry in
    CLASS_NAMES, are skipped. Each kept box is drawn in its CLASS_COLORS
    colour (RGB order, matching the input image) and gets:
      * a 15%-opacity fill,
      * a 2 px border,
      * a "<name> <score>" tag.

    The tag sits above the box when there is room. For boxes touching the top
    of the image it is drawn just inside the box rather than being clipped off.
    The input array is not modified.
    """
    annotated = img_rgb.copy()
    img_h, img_w = annotated.shape[:2]
    preds = result.pred_instances
    bboxes = preds.bboxes.cpu().numpy()
    labels = preds.labels.cpu().numpy()
    scores = preds.scores.cpu().numpy()
    for bbox, label, score in zip(bboxes, labels, scores):
        if score < SCORE_THRESH or label not in CLASS_NAMES:
            continue
        color = CLASS_COLORS[label]
        name = CLASS_NAMES[label]
        x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])

        # Semi-transparent fill. Only pixels inside the (inclusive) box change,
        # so blend just that image-clipped region instead of copying and
        # re-blending the whole image once per detection.
        fx1, fx2 = max(min(x1, x2), 0), min(max(x1, x2), img_w - 1)
        fy1, fy2 = max(min(y1, y2), 0), min(max(y1, y2), img_h - 1)
        if fx1 <= fx2 and fy1 <= fy2:
            roi = annotated[fy1:fy2 + 1, fx1:fx2 + 1]
            fill = np.full_like(roi, color)
            annotated[fy1:fy2 + 1, fx1:fx2 + 1] = cv2.addWeighted(
                fill, 0.15, roi, 0.85, 0
            )

        # Border
        cv2.rectangle(annotated, (x1, y1), (x2, y2), color, 2)

        # Label tag: above the box when it fits, otherwise inside its top edge
        # (a negative y would push the whole tag off-image).
        lbl = f"{name} {score:.2f}"
        (tw, th), _ = cv2.getTextSize(lbl, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
        tag_top = y1 - th - 6
        if tag_top < 0:
            tag_top = max(y1, 0)
        cv2.rectangle(annotated, (x1, tag_top), (x1 + tw + 4, tag_top + th + 6), color, -1)
        cv2.putText(annotated, lbl, (x1 + 2, tag_top + th + 2),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
    return annotated
# ── FastAPI ───────────────────────────────────────────────────────────────────
app = FastAPI()

# Single-page UI served at "/". The `__SCORE_THRESH__` placeholder is filled in
# from SCORE_THRESH (see the .replace below), so the threshold shown to users
# always matches the one the server applies. Plain str.replace is used because
# the embedded JavaScript relies on `${...}` template literals.
HTML = """<!DOCTYPE html>
<html>
<head>
<title>Floor Plan Detection</title>
<style>
*{box-sizing:border-box;margin:0;padding:0}
body{font-family:monospace;background:#0f0f0f;color:#e0e0e0;padding:32px 24px}
h1{color:#7eb8f7;margin-bottom:8px}
p.sub{color:#888;margin-bottom:24px;font-size:.9rem}
.controls{display:flex;gap:12px;align-items:center;flex-wrap:wrap;margin-bottom:24px}
input[type=file]{display:none}
label.btn{padding:9px 18px;background:#1e3a5f;color:#7eb8f7;border:1px solid #7eb8f7;border-radius:4px;cursor:pointer}
label.btn:hover{background:#2a4f7f}
button{padding:9px 24px;background:#7eb8f7;color:#0f0f0f;border:none;border-radius:4px;cursor:pointer;font-weight:bold;font-size:.95rem}
button:hover{background:#5a9ee0}
button:disabled{opacity:.5;cursor:wait}
#fname{color:#555;font-size:.85rem}
.row{display:flex;gap:20px;flex-wrap:wrap;margin-bottom:16px}
.col{flex:1;min-width:280px}
.col p{color:#888;font-size:.8rem;margin-bottom:6px}
.imgbox{background:#1a1a1a;border:1px solid #2a2a2a;border-radius:6px;min-height:220px;
  display:flex;align-items:center;justify-content:center;color:#444;overflow:hidden}
.imgbox img{max-width:100%;display:block}
#summary{background:#1a1a1a;border:1px solid #2a2a2a;border-radius:6px;padding:14px;
  white-space:pre-wrap;font-size:.85rem;min-height:60px;color:#ccc}
.legend{margin-top:12px;font-size:.85rem;color:#888}
.dot{display:inline-block;width:10px;height:10px;border-radius:2px;margin-right:4px;vertical-align:middle}
.loading{color:#7eb8f7;animation:pulse 1.2s infinite}
@keyframes pulse{0%,100%{opacity:1}50%{opacity:.4}}
</style>
</head>
<body>
<h1>🏠 Floor Plan Detection</h1>
<p class="sub">Faster R-CNN · ResNet-101 · FPN · fine-tuned on CubiCasa5k</p>
<div class="controls">
  <label class="btn" for="fi">📂 Choose Image</label>
  <input type="file" id="fi" accept="image/jpeg,image/png">
  <button id="run" onclick="detect()">▶ Run Detection</button>
  <span id="fname">No file chosen</span>
</div>
<div class="row">
  <div class="col">
    <p>Input</p>
    <div class="imgbox" id="preview">No image loaded</div>
  </div>
  <div class="col">
    <p>Detections</p>
    <div class="imgbox" id="result">Run detection to see results</div>
  </div>
</div>
<div id="summary">Upload an image and click Run Detection.</div>
<div class="legend">
  <span class="dot" style="background:#dc3c3c"></span>Wall &nbsp;
  <span class="dot" style="background:#32c850"></span>Room
</div>
<script>
let file = null;
document.getElementById('fi').addEventListener('change', e => {
  file = e.target.files[0];
  if (!file) return;
  document.getElementById('fname').textContent = file.name;
  const r = new FileReader();
  r.onload = ev => document.getElementById('preview').innerHTML = `<img src="${ev.target.result}">`;
  r.readAsDataURL(file);
});
function showError(msg) {
  document.getElementById('result').innerHTML = 'Error';
  document.getElementById('summary').textContent = msg;
}
async function detect() {
  if (!file) { alert('Choose an image first.'); return; }
  const btn = document.getElementById('run');
  btn.disabled = true;  // avoid queueing duplicate long-running jobs
  document.getElementById('result').innerHTML = '<span class="loading">Running… (30–60s on CPU)</span>';
  document.getElementById('summary').textContent = 'Processing…';
  const fd = new FormData();
  fd.append('image', file);
  try {
    const r = await fetch('/detect', {method:'POST', body:fd});
    let d;
    try { d = await r.json(); } catch (_) { d = {}; }
    // FastAPI reports failures as a non-2xx status with {"detail": ...}.
    if (!r.ok || d.error) {
      const msg = d.error || d.detail || `HTTP ${r.status}`;
      showError(typeof msg === 'string' ? msg : JSON.stringify(msg));
      return;
    }
    document.getElementById('result').innerHTML = `<img src="data:image/jpeg;base64,${d.image}">`;
    const w = d.json.detectionResults.walls.length;
    const rm = d.json.detectionResults.rooms.length;
    let txt = `Detected: ${w} wall(s) | ${rm} room(s) (conf threshold: __SCORE_THRESH__)\n`;
    txt += `Overall confidence: ${(d.json.confidence*100).toFixed(1)}%\n\n`;
    d.json.detectionResults.walls.forEach(x => txt += ` • Wall ${x.id} conf=${x.confidence.toFixed(3)}\n`);
    d.json.detectionResults.rooms.forEach(x => txt += ` • Room ${x.id} conf=${x.confidence.toFixed(3)}\n`);
    document.getElementById('summary').textContent = txt;
  } catch(e) {
    showError(String(e));
  } finally {
    btn.disabled = false;
  }
}
</script>
</body>
</html>""".replace("__SCORE_THRESH__", f"{SCORE_THRESH:.2f}")
@app.get("/", response_class=HTMLResponse)
def index():
    """Serve the single-page upload/detection UI held in ``HTML``."""
    page = HTML
    return page
def _run_detection(contents: bytes) -> Dict[str, Any]:
    """Decode *contents*, run the detector and build the /detect response body.

    Blocking and CPU-bound. ``detect`` calls it in a worker thread so the
    event loop keeps serving other requests while inference runs.

    Returns:
        ``{"image": <base64 JPEG with detections drawn>,
        "json": <process_inference_result payload>}``.

    Raises:
        HTTPException(400): the bytes are empty or cannot be decoded as an image.
        HTTPException(500): the annotated image cannot be JPEG-encoded.
    """
    # cv2.imdecode asserts on an empty buffer (-> 500); report it as bad input.
    if not contents:
        raise HTTPException(status_code=400, detail="Could not decode image.")
    nparr = np.frombuffer(contents, np.uint8)
    img_bgr = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Could not decode image.")
    # Original run.py converts BGR→RGB before inference
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    result = inference_detector(model, img_rgb)
    # JSON output — matches original run.py exactly
    processed = process_inference_result(result)
    # Visual output — draw on RGB image, convert back to BGR for OpenCV's encoder
    annotated_rgb = draw_detections(img_rgb, result)
    annotated_bgr = cv2.cvtColor(annotated_rgb, cv2.COLOR_RGB2BGR)
    ok, buf = cv2.imencode(".jpg", annotated_bgr, [cv2.IMWRITE_JPEG_QUALITY, 90])
    if not ok:
        raise HTTPException(status_code=500, detail="Could not encode result image.")
    b64 = base64.b64encode(buf).decode()
    logger.info(f"Inference done: {len(processed['detectionResults']['walls'])} walls, "
                f"{len(processed['detectionResults']['rooms'])} rooms")
    return {"image": b64, "json": processed}


@app.post("/detect")
async def detect(image: UploadFile = File(...)):
    """Run floor-plan detection on an uploaded JPEG or PNG.

    Returns ``{"image": <base64 JPEG with boxes drawn>, "json": <detections>}``.

    Raises:
        HTTPException(400): unsupported content type, an upload larger than
            MAX_FILE_SIZE, or bytes that cannot be decoded as an image.
        HTTPException(500): the annotated image cannot be encoded.
    """
    # FastAPI re-exports Starlette's helper; local import keeps this edit
    # self-contained.
    from fastapi.concurrency import run_in_threadpool

    if image.content_type not in ["image/jpeg", "image/png"]:
        raise HTTPException(status_code=400, detail="Only JPEG and PNG supported.")
    # Read at most one byte past the limit: enough to detect an oversize
    # upload without pulling the whole file into memory.
    contents = await image.read(MAX_FILE_SIZE + 1)
    if len(contents) > MAX_FILE_SIZE:
        raise HTTPException(status_code=400, detail="File exceeds 10 MB limit.")
    # Inference takes tens of seconds on CPU. Running it directly inside this
    # coroutine would block the event loop (and every other request) for that
    # long, so hand it to the threadpool instead.
    return await run_in_threadpool(_run_detection, contents)
if __name__ == "__main__":
    # Direct `python app.py` launch: serve on all interfaces, port 7860.
    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)