import base64
import json
import logging
import os
import sys
import threading
from typing import Any, Dict, Optional

import numpy as np
import cv2
import torch
import uvicorn
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import HTMLResponse, JSONResponse
from mmdet.apis import init_detector, inference_detector
|
|
# Absolute directory of this file. Every bundled resource path (including the
# vendored mmdetection checkout) is anchored here, so paths do not depend on
# the current working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# Make a vendored ``mmdetection`` checkout next to this file importable.
# NOTE(review): this executes after the top-of-file ``from mmdet.apis import
# ...`` line, so it cannot change which mmdet that import resolves to; it only
# affects imports performed afterwards. Confirm whether it was meant to run
# before the mmdet import.
sys.path.insert(0, os.path.join(BASE_DIR, "mmdetection"))

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Detector config and trained weights loaded at startup.
CONFIG_FILE = os.path.join(BASE_DIR, "configs", "faster_rcnn.py")
CHECKPOINT_FILE = os.path.join(BASE_DIR, "weights", "faster_rcnn_latest.pth")
# Largest accepted upload, in bytes (10 MiB).
MAX_FILE_SIZE = 10 * 1024 * 1024
# Detections scoring below this are dropped from the JSON and the drawing.
SCORE_THRESH = 0.3

# Per-class box colours; applied to the RGB image that gets annotated.
CLASS_COLORS = {
    0: (220, 60, 60),
    1: (50, 200, 80),
}
# Model label index -> human-readable class name.
CLASS_NAMES = {0: "wall", 1: "room"}
|
|
| |
def determine_device():
    """Choose the torch device string the detector is loaded onto.

    Returns ``"cuda:0"`` when CUDA reports itself available and
    ``torch.cuda.init()`` succeeds. If CUDA is unavailable, returns ``"cpu"``.
    If initialisation raises, logs a warning and returns ``"cpu"``.
    """
    if not torch.cuda.is_available():
        return "cpu"
    try:
        torch.cuda.init()
    except Exception as err:
        # Fall back rather than crash: a broken CUDA setup still lets the
        # service run (slowly) on the CPU.
        logger.warning(f"CUDA failed: {err}. Using CPU.")
        return "cpu"
    return "cuda:0"
|
|
| |
# Build the detector once, at import time, so every request reuses the same
# model object. There is no error handling here: any exception raised by
# init_detector (e.g. an unreadable config or checkpoint) propagates and
# aborts module import, so the server never starts with a missing model.
device = determine_device()
logger.info(f"Loading Faster R-CNN on {device}…")
model = init_detector(CONFIG_FILE, CHECKPOINT_FILE, device=device)
logger.info("Model ready.")
|
|
| |
def process_inference_result(
    result,
    score_thresh: Optional[float] = None,
    class_names: Optional[Dict[int, str]] = None,
) -> Dict[str, Any]:
    """Convert a single-image detection result into the API's JSON payload.

    Args:
        result: Object exposing ``pred_instances.bboxes``, ``.labels`` and
            ``.scores`` as tensors (what ``inference_detector`` returns for
            one image).
        score_thresh: Minimum score for a detection to be reported. Defaults
            to the module-level ``SCORE_THRESH``.
        class_names: Label index -> class name. Defaults to the module-level
            ``CLASS_NAMES``. Only labels named ``"wall"`` or ``"room"`` are
            reported; any other label is skipped, the same way
            ``draw_detections`` skips labels it has no name for.

    Returns:
        ``{"type": "floor_plan", "confidence": float, "detectionResults":
        {"walls": [...], "rooms": [...]}}``.

        - ``confidence`` is the mean score of the reported detections, or
          0.0 when none are reported.
        - Each item has an ``id`` of the form ``"<class>_<k>"``, where ``k``
          is the 1-based index of the detection in the model's raw output.
        - Each item has a ``position`` whose ``start``/``end`` corners are
          (x1, y1) and (x2, y2).
        - Each item carries its own ``confidence``.
    """
    thresh = SCORE_THRESH if score_thresh is None else score_thresh
    names = CLASS_NAMES if class_names is None else class_names

    preds = result.pred_instances
    bboxes = preds.bboxes.cpu().numpy()
    labels = preds.labels.cpu().numpy()
    scores = preds.scores.cpu().numpy()

    walls, rooms = [], []
    buckets = {"wall": walls, "room": rooms}
    kept_scores = []
    for i, (bbox, label, score) in enumerate(zip(bboxes, labels, scores)):
        if score < thresh:
            continue
        name = names.get(int(label))
        bucket = buckets.get(name)
        if bucket is None:
            # Unknown class: previously every non-zero label was reported as
            # a "room" even though it was never drawn.
            continue
        x1, y1, x2, y2 = (float(v) for v in bbox[:4])
        bucket.append({
            "id": f"{name}_{i + 1}",
            "position": {
                "start": {"x": x1, "y": y1},
                "end": {"x": x2, "y": y2},
            },
            "confidence": float(score),
        })
        kept_scores.append(float(score))

    return {
        "type": "floor_plan",
        # Average only over what is actually reported.
        "confidence": float(np.mean(kept_scores)) if kept_scores else 0.0,
        "detectionResults": {"walls": walls, "rooms": rooms},
    }
|
|
| |
def draw_detections(img_rgb: np.ndarray, result) -> np.ndarray:
    """Return a copy of ``img_rgb`` with the detections drawn on it.

    Each detection is drawn only if it scores at least ``SCORE_THRESH`` and
    its label is in ``CLASS_NAMES``. It gets three marks:

    - a translucent fill (15% class colour, 85% image);
    - a 2-px outline in the class colour;
    - a filled tag reading ``"<name> <score>"``.

    The input array is not modified.

    Args:
        img_rgb: Image array. Assumed H x W x 3 uint8 (the caller passes a
            cv2-decoded colour image converted to RGB). ``CLASS_COLORS`` are
            applied as-is, so the channel order must match them.
        result: Detection result exposing ``pred_instances.bboxes``,
            ``.labels`` and ``.scores`` tensors.

    Returns:
        The annotated copy.
    """
    annotated = img_rgb.copy()
    height, width = annotated.shape[:2]
    preds = result.pred_instances
    bboxes = preds.bboxes.cpu().numpy()
    labels = preds.labels.cpu().numpy()
    scores = preds.scores.cpu().numpy()

    font = cv2.FONT_HERSHEY_SIMPLEX
    for bbox, raw_label, score in zip(bboxes, labels, scores):
        cls_id = int(raw_label)
        if score < SCORE_THRESH or cls_id not in CLASS_NAMES:
            continue
        color = CLASS_COLORS[cls_id]
        name = CLASS_NAMES[cls_id]
        x1, y1, x2, y2 = (int(v) for v in bbox[:4])

        # Translucent fill, blended only inside the image-clipped box. The old
        # version copied and re-blended the whole image for every detection.
        # The +1 matches cv2.rectangle's inclusive corners.
        rx1, ry1 = max(x1, 0), max(y1, 0)
        rx2, ry2 = min(x2 + 1, width), min(y2 + 1, height)
        if rx1 < rx2 and ry1 < ry2:
            roi = annotated[ry1:ry2, rx1:rx2]
            fill = np.full_like(roi, color)
            annotated[ry1:ry2, rx1:rx2] = cv2.addWeighted(fill, 0.15, roi, 0.85, 0)

        cv2.rectangle(annotated, (x1, y1), (x2, y2), color, 2)

        text = f"{name} {score:.2f}"
        (text_w, text_h), _ = cv2.getTextSize(text, font, 0.5, 1)
        # Put the tag above the box; if that would leave the image (box at the
        # top edge), put it just inside the box instead.
        tag_top = y1 - text_h - 6
        if tag_top < 0:
            tag_top = y1
        tag_bottom = tag_top + text_h + 6
        cv2.rectangle(annotated, (x1, tag_top), (x1 + text_w + 4, tag_bottom), color, -1)
        cv2.putText(annotated, text, (x1 + 2, tag_bottom - 4),
                    font, 0.5, (255, 255, 255), 1)

    return annotated
|
|
| |
# ASGI application: the "/" and "/detect" routes below register on it, and the
# __main__ guard at the bottom serves it with uvicorn.
app = FastAPI()
|
|
# Single-page UI served at "/".
# - Lets the user pick a JPEG/PNG and POSTs it to /detect as multipart field
#   "image".
# - Shows the returned annotated image plus a text summary.
# - Reports failures in the summary panel: non-2xx status, FastAPI's
#   {"detail": ...} body, or a body that is not JSON.
HTML = """<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>Floor Plan Detection</title>
<style>
*{box-sizing:border-box;margin:0;padding:0}
body{font-family:monospace;background:#0f0f0f;color:#e0e0e0;padding:32px 24px}
h1{color:#7eb8f7;margin-bottom:8px}
p.sub{color:#888;margin-bottom:24px;font-size:.9rem}
.controls{display:flex;gap:12px;align-items:center;flex-wrap:wrap;margin-bottom:24px}
input[type=file]{display:none}
label.btn{padding:9px 18px;background:#1e3a5f;color:#7eb8f7;border:1px solid #7eb8f7;border-radius:4px;cursor:pointer}
label.btn:hover{background:#2a4f7f}
button{padding:9px 24px;background:#7eb8f7;color:#0f0f0f;border:none;border-radius:4px;cursor:pointer;font-weight:bold;font-size:.95rem}
button:hover{background:#5a9ee0}
button:disabled{opacity:.5;cursor:wait}
#fname{color:#555;font-size:.85rem}
.row{display:flex;gap:20px;flex-wrap:wrap;margin-bottom:16px}
.col{flex:1;min-width:280px}
.col p{color:#888;font-size:.8rem;margin-bottom:6px}
.imgbox{background:#1a1a1a;border:1px solid #2a2a2a;border-radius:6px;min-height:220px;
display:flex;align-items:center;justify-content:center;color:#444;overflow:hidden}
.imgbox img{max-width:100%;display:block}
#summary{background:#1a1a1a;border:1px solid #2a2a2a;border-radius:6px;padding:14px;
white-space:pre-wrap;font-size:.85rem;min-height:60px;color:#ccc}
.legend{margin-top:12px;font-size:.85rem;color:#888}
.dot{display:inline-block;width:10px;height:10px;border-radius:2px;margin-right:4px;vertical-align:middle}
.loading{color:#7eb8f7;animation:pulse 1.2s infinite}
@keyframes pulse{0%,100%{opacity:1}50%{opacity:.4}}
</style>
</head>
<body>
<h1>🏠 Floor Plan Detection</h1>
<p class="sub">Faster R-CNN · ResNet-101 · FPN · fine-tuned on CubiCasa5k</p>

<div class="controls">
<label class="btn" for="fi">📂 Choose Image</label>
<input type="file" id="fi" accept="image/jpeg,image/png">
<button id="run" onclick="detect()">▶ Run Detection</button>
<span id="fname">No file chosen</span>
</div>

<div class="row">
<div class="col">
<p>Input</p>
<div class="imgbox" id="preview">No image loaded</div>
</div>
<div class="col">
<p>Detections</p>
<div class="imgbox" id="result">Run detection to see results</div>
</div>
</div>

<div id="summary">Upload an image and click Run Detection.</div>

<div class="legend">
<span class="dot" style="background:#dc3c3c"></span>Wall
<span class="dot" style="background:#32c850"></span>Room
</div>

<script>
let file = null;
document.getElementById('fi').addEventListener('change', e => {
  file = e.target.files[0];
  if (!file) return;
  document.getElementById('fname').textContent = file.name;
  const r = new FileReader();
  r.onload = ev => document.getElementById('preview').innerHTML = `<img src="${ev.target.result}">`;
  r.readAsDataURL(file);
});

async function detect() {
  if (!file) { alert('Choose an image first.'); return; }
  const btn = document.getElementById('run');
  const out = document.getElementById('result');
  const summary = document.getElementById('summary');
  btn.disabled = true;
  out.innerHTML = '<span class="loading">Running… (30–60s on CPU)</span>';
  summary.textContent = 'Processing…';
  const fd = new FormData();
  fd.append('image', file);
  try {
    const r = await fetch('/detect', {method:'POST', body:fd});
    // FastAPI reports errors as {"detail": ...}; some failures are not JSON at all.
    let d = null;
    try { d = await r.json(); } catch (_) { d = null; }
    if (!r.ok || !d || d.error) {
      const msg = (d && (d.error || d.detail)) || `HTTP ${r.status} ${r.statusText}`;
      out.innerHTML = 'Error';
      summary.textContent = typeof msg === 'string' ? msg : JSON.stringify(msg);
      return;
    }
    out.innerHTML = `<img src="data:image/jpeg;base64,${d.image}">`;
    const w = d.json.detectionResults.walls.length;
    const rm = d.json.detectionResults.rooms.length;
    let txt = `Detected: ${w} wall(s) | ${rm} room(s) (conf threshold: 0.30)\n`;
    txt += `Overall confidence: ${(d.json.confidence*100).toFixed(1)}%\n\n`;
    d.json.detectionResults.walls.forEach(x => txt += ` • Wall ${x.id} conf=${x.confidence.toFixed(3)}\n`);
    d.json.detectionResults.rooms.forEach(x => txt += ` • Room ${x.id} conf=${x.confidence.toFixed(3)}\n`);
    summary.textContent = txt;
  } catch(e) {
    out.innerHTML = 'Error';
    summary.textContent = String(e);
  } finally {
    btn.disabled = false;
  }
}
</script>
</body>
</html>"""
|
|
@app.get("/", response_class=HTMLResponse)
def index():
    """Serve the single-page upload/detection UI."""
    page = HTML
    return page
|
|
# The detector is shared module state. The handler below runs in FastAPI's
# worker threadpool, so this lock keeps inference one request at a time, as
# it was when inference ran directly on the event loop. Without it, parallel
# requests would contend for the same model and device memory.
_INFERENCE_LOCK = threading.Lock()


@app.post("/detect")
def detect(image: UploadFile = File(...)):
    """Run floor-plan detection on an uploaded JPEG or PNG.

    Declared as a plain ``def`` so FastAPI executes it in its threadpool.
    Inference is blocking (the UI warns of 30–60 s on CPU); in the previous
    ``async def`` form it stalled the event loop and every other request.

    Returns:
        ``{"image": <base64 JPEG of the annotated image>,
        "json": <process_inference_result output>}``.

    Raises:
        HTTPException(400): unsupported content type, upload larger than
            ``MAX_FILE_SIZE``, or bytes OpenCV cannot decode.
        HTTPException(500): the annotated image could not be JPEG-encoded.
    """
    if image.content_type not in ("image/jpeg", "image/png"):
        raise HTTPException(status_code=400, detail="Only JPEG and PNG supported.")

    # Read at most one byte past the limit, so an oversized upload is rejected
    # without pulling all of it into memory.
    contents = image.file.read(MAX_FILE_SIZE + 1)
    if len(contents) > MAX_FILE_SIZE:
        raise HTTPException(status_code=400, detail="File exceeds 10 MB limit.")

    img_bgr = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Could not decode image.")

    # mmdet's ndarray input path expects images as cv2/mmcv load them (BGR).
    # Any BGR->RGB step is the config's data preprocessor's job, so the
    # decoded array is passed through unchanged.
    # NOTE(review): confirm against configs/faster_rcnn.py (bgr_to_rgb).
    with _INFERENCE_LOCK:
        result = inference_detector(model, img_bgr)

    processed = process_inference_result(result)

    # Drawing works in RGB (CLASS_COLORS are RGB), then back to BGR for JPEG.
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    annotated_rgb = draw_detections(img_rgb, result)
    annotated_bgr = cv2.cvtColor(annotated_rgb, cv2.COLOR_RGB2BGR)
    ok, buf = cv2.imencode(".jpg", annotated_bgr, [cv2.IMWRITE_JPEG_QUALITY, 90])
    if not ok:
        raise HTTPException(status_code=500, detail="Could not encode annotated image.")
    b64 = base64.b64encode(buf).decode()

    logger.info(f"Inference done: {len(processed['detectionResults']['walls'])} walls, "
                f"{len(processed['detectionResults']['rooms'])} rooms")

    return {"image": b64, "json": processed}
|
|
if __name__ == "__main__":
    # Running the file directly serves the app with uvicorn on all interfaces.
    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)
|
|