File size: 10,452 Bytes
4fae684
 
9b99d13
4fae684
9b99d13
4fae684
 
037fe03
 
4fae684
 
 
 
9b99d13
 
 
4fae684
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b99d13
4fae684
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b99d13
4fae684
 
 
037fe03
 
4fae684
037fe03
 
 
 
4fae684
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
037fe03
 
 
 
4fae684
 
 
 
 
 
 
037fe03
 
 
 
4fae684
 
037fe03
 
4fae684
 
037fe03
 
 
 
 
4fae684
 
 
 
037fe03
 
4fae684
 
 
 
 
 
 
 
037fe03
 
4fae684
 
 
037fe03
4fae684
 
037fe03
4fae684
 
 
 
 
 
 
 
 
 
 
 
 
 
037fe03
 
 
 
4fae684
037fe03
 
 
 
 
 
4fae684
 
 
 
 
 
 
037fe03
4fae684
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
037fe03
 
4fae684
 
 
 
f875a1b
9b99d13
037fe03
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
import asyncio
import os
import json
import logging
import base64
import sys
import threading
from typing import Dict, Any

import numpy as np
import cv2
import torch
import uvicorn
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import HTMLResponse, JSONResponse
from mmdet.apis import init_detector, inference_detector

# Put the "mmdetection" directory that sits next to this file at the front of
# the import path.
# NOTE(review): `mmdet.apis` is imported at the top of the file, before this
# line runs, so this entry can only affect imports that happen later — confirm
# that is the intent.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "mmdetection"))

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model config and checkpoint paths are resolved relative to this file.
BASE_DIR       = os.path.dirname(os.path.abspath(__file__))
CONFIG_FILE    = os.path.join(BASE_DIR, "configs", "faster_rcnn.py")
CHECKPOINT_FILE = os.path.join(BASE_DIR, "weights", "faster_rcnn_latest.pth")
# Upload size limit checked by the /detect handler.
MAX_FILE_SIZE  = 10 * 1024 * 1024  # 10 MB
# Detections scoring below this value are dropped from both the JSON output
# and the drawn image.
SCORE_THRESH   = 0.3  # lower than default to catch more walls

# Drawing colour and display name for each predicted label id.
CLASS_COLORS = {
    0: (220, 60,  60),   # wall — red (RGB)
    1: (50,  200, 80),   # room — green (RGB)
}
CLASS_NAMES = {0: "wall", 1: "room"}

# ── Device ───────────────────────────────────────────────────────────────────
def determine_device():
    """Pick the torch device string used to load the model.

    Returns:
        "cuda:0" when CUDA is reported available and initialises without
        raising; otherwise "cpu". An initialisation failure is logged as a
        warning rather than propagated.
    """
    # Guard clause: nothing to probe when CUDA is not available at all.
    if not torch.cuda.is_available():
        return "cpu"

    try:
        torch.cuda.init()
    except Exception as e:
        # Best effort: any CUDA start-up problem falls back to the CPU.
        logger.warning(f"CUDA failed: {e}. Using CPU.")
        return "cpu"

    return "cuda:0"

# ── Model load ───────────────────────────────────────────────────────────────
# The detector is built once, at import time, from CONFIG_FILE and
# CHECKPOINT_FILE; the /detect handler reuses this single `model` object for
# every request.
device = determine_device()
# The trailing ellipsis was mis-encoded ("…") in the log message; restored.
logger.info(f"Loading Faster R-CNN on {device}…")
model = init_detector(CONFIG_FILE, CHECKPOINT_FILE, device=device)
logger.info("Model ready.")

# ── Result processing (mirrors original run.py exactly) ──────────────────────
def process_inference_result(result) -> Dict[str, Any]:
    """Convert a detector result into the JSON-serialisable response payload.

    Args:
        result: Object exposing ``pred_instances`` with ``bboxes``, ``labels``
            and ``scores`` attributes, each supporting ``.cpu().numpy()``.

    Returns:
        A dict with ``"type"`` fixed to ``"floor_plan"``, ``"confidence"`` set
        to the mean score of all detections at or above ``SCORE_THRESH``
        (``0.0`` when there are none), and ``"detectionResults"`` holding the
        ``"walls"`` and ``"rooms"`` lists. Label ``0`` is a wall; every other
        label is reported as a room.
    """
    predictions = result.pred_instances
    boxes = predictions.bboxes.cpu().numpy()
    classes = predictions.labels.cpu().numpy()
    confidences = predictions.scores.cpu().numpy()

    walls = []
    rooms = []
    # Ids are 1-based positions in the unfiltered prediction list, so they
    # stay stable regardless of how many low-score entries are skipped.
    for position, (box, cls, conf) in enumerate(
        zip(boxes, classes, confidences), start=1
    ):
        if conf < SCORE_THRESH:
            continue

        is_wall = cls == 0
        kind = "wall" if is_wall else "room"
        left, top, right, bottom = box
        entry = {
            "id": f"{kind}_{position}",
            "position": {
                "start": {"x": float(left), "y": float(top)},
                "end": {"x": float(right), "y": float(bottom)},
            },
            "confidence": float(conf),
        }
        target = walls if is_wall else rooms
        target.append(entry)

    kept = confidences[confidences >= SCORE_THRESH]
    if len(kept):
        overall = float(np.mean(kept))
    else:
        overall = 0.0

    return {
        "type": "floor_plan",
        "confidence": overall,
        "detectionResults": {"walls": walls, "rooms": rooms},
    }

# ── Visualisation ─────────────────────────────────────────────────────────────
def draw_detections(img_rgb: np.ndarray, result) -> np.ndarray:
    """Return a copy of *img_rgb* with the kept detections drawn on it.

    For every prediction whose score is at least ``SCORE_THRESH`` and whose
    label is a key of ``CLASS_NAMES``, this draws a lightly tinted fill, a
    2-pixel border and a "<name> <score>" tag in the class colour.

    Args:
        img_rgb: Image to annotate; it is copied, not modified. The colours in
            ``CLASS_COLORS`` are written as-is, so they come out as intended
            when the array is in RGB channel order.
        result: Object exposing ``pred_instances`` with ``bboxes``, ``labels``
            and ``scores`` attributes, each supporting ``.cpu().numpy()``.

    Returns:
        The annotated copy of the image.
    """
    annotated = img_rgb.copy()
    bboxes = result.pred_instances.bboxes.cpu().numpy()
    labels = result.pred_instances.labels.cpu().numpy()
    scores = result.pred_instances.scores.cpu().numpy()

    for bbox, label, score in zip(bboxes, labels, scores):
        label = int(label)  # plain int for the dict lookups below
        if score < SCORE_THRESH or label not in CLASS_NAMES:
            continue
        color = CLASS_COLORS[label]
        name = CLASS_NAMES[label]
        x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])

        # Semi-transparent fill: paint the box solid on a copy, then blend the
        # copy back at 15% opacity.
        overlay = annotated.copy()
        cv2.rectangle(overlay, (x1, y1), (x2, y2), color, -1)
        cv2.addWeighted(overlay, 0.15, annotated, 0.85, 0, annotated)
        # Border
        cv2.rectangle(annotated, (x1, y1), (x2, y2), color, 2)

        # Label tag. It normally sits directly above the box; when the box is
        # so close to the top edge that the tag would land at negative y (and
        # be clipped out of the image), place it just inside the box instead.
        lbl = f"{name} {score:.2f}"
        (tw, th), _ = cv2.getTextSize(lbl, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
        tag_top = y1 - th - 6
        if tag_top < 0:
            tag_top = y1
        tag_bottom = tag_top + th + 6
        cv2.rectangle(annotated, (x1, tag_top), (x1 + tw + 4, tag_bottom), color, -1)
        cv2.putText(annotated, lbl, (x1 + 2, tag_bottom - 4),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

    return annotated

# ── FastAPI ───────────────────────────────────────────────────────────────────
# Application instance: the route decorators further down register their
# handlers on it, and the __main__ guard at the bottom passes it to uvicorn.
app = FastAPI()

# Single-page UI returned by the "/" route. It previews the chosen image,
# POSTs it to /detect as multipart field "image", then shows the returned
# base64 JPEG and a text summary of the detections.
#
# Two fixes relative to the previous version of this page:
#   * Mis-encoded (mojibake) characters in the visible text are restored.
#   * Error responses are recognised. FastAPI serialises HTTPException as
#     {"detail": ...}, which the old `d.error` check never matched, so server
#     errors surfaced as a TypeError instead of the real message.
HTML = """<!DOCTYPE html>
<html>
<head>
  <meta charset="utf-8">
  <title>Floor Plan Detection</title>
  <style>
    *{box-sizing:border-box;margin:0;padding:0}
    body{font-family:monospace;background:#0f0f0f;color:#e0e0e0;padding:32px 24px}
    h1{color:#7eb8f7;margin-bottom:8px}
    p.sub{color:#888;margin-bottom:24px;font-size:.9rem}
    .controls{display:flex;gap:12px;align-items:center;flex-wrap:wrap;margin-bottom:24px}
    input[type=file]{display:none}
    label.btn{padding:9px 18px;background:#1e3a5f;color:#7eb8f7;border:1px solid #7eb8f7;border-radius:4px;cursor:pointer}
    label.btn:hover{background:#2a4f7f}
    button{padding:9px 24px;background:#7eb8f7;color:#0f0f0f;border:none;border-radius:4px;cursor:pointer;font-weight:bold;font-size:.95rem}
    button:hover{background:#5a9ee0}
    #fname{color:#555;font-size:.85rem}
    .row{display:flex;gap:20px;flex-wrap:wrap;margin-bottom:16px}
    .col{flex:1;min-width:280px}
    .col p{color:#888;font-size:.8rem;margin-bottom:6px}
    .imgbox{background:#1a1a1a;border:1px solid #2a2a2a;border-radius:6px;min-height:220px;
            display:flex;align-items:center;justify-content:center;color:#444;overflow:hidden}
    .imgbox img{max-width:100%;display:block}
    #summary{background:#1a1a1a;border:1px solid #2a2a2a;border-radius:6px;padding:14px;
             white-space:pre-wrap;font-size:.85rem;min-height:60px;color:#ccc}
    .legend{margin-top:12px;font-size:.85rem;color:#888}
    .dot{display:inline-block;width:10px;height:10px;border-radius:2px;margin-right:4px;vertical-align:middle}
    .loading{color:#7eb8f7;animation:pulse 1.2s infinite}
    @keyframes pulse{0%,100%{opacity:1}50%{opacity:.4}}
  </style>
</head>
<body>
  <h1>🏠 Floor Plan Detection</h1>
  <p class="sub">Faster R-CNN · ResNet-101 · FPN · fine-tuned on CubiCasa5k</p>

  <div class="controls">
    <label class="btn" for="fi">📂 Choose Image</label>
    <input type="file" id="fi" accept="image/jpeg,image/png">
    <button onclick="detect()">▶ Run Detection</button>
    <span id="fname">No file chosen</span>
  </div>

  <div class="row">
    <div class="col">
      <p>Input</p>
      <div class="imgbox" id="preview">No image loaded</div>
    </div>
    <div class="col">
      <p>Detections</p>
      <div class="imgbox" id="result">Run detection to see results</div>
    </div>
  </div>

  <div id="summary">Upload an image and click Run Detection.</div>

  <div class="legend">
    <span class="dot" style="background:#dc3c3c"></span>Wall &nbsp;
    <span class="dot" style="background:#32c850"></span>Room
  </div>

<script>
  let file = null;
  document.getElementById('fi').addEventListener('change', e => {
    file = e.target.files[0];
    if (!file) return;
    document.getElementById('fname').textContent = file.name;
    const r = new FileReader();
    r.onload = ev => document.getElementById('preview').innerHTML = `<img src="${ev.target.result}">`;
    r.readAsDataURL(file);
  });

  async function detect() {
    if (!file) { alert('Choose an image first.'); return; }
    document.getElementById('result').innerHTML = '<span class="loading">Running… (30–60s on CPU)</span>';
    document.getElementById('summary').textContent = 'Processing…';
    const fd = new FormData();
    fd.append('image', file);
    try {
      const r = await fetch('/detect', {method:'POST', body:fd});
      const d = await r.json();
      // FastAPI reports HTTPException as {"detail": ...}; "error" is still honoured.
      const err = d.error || d.detail;
      if (!r.ok || err) {
        document.getElementById('result').innerHTML = 'Error';
        document.getElementById('summary').textContent =
          typeof err === 'string' ? err : JSON.stringify(err || d);
        return;
      }
      document.getElementById('result').innerHTML = `<img src="data:image/jpeg;base64,${d.image}">`;
      const w = d.json.detectionResults.walls.length;
      const rm = d.json.detectionResults.rooms.length;
      let txt = `Detected: ${w} wall(s)  |  ${rm} room(s)  (conf threshold: 0.30)\n`;
      txt += `Overall confidence: ${(d.json.confidence*100).toFixed(1)}%\n\n`;
      d.json.detectionResults.walls.forEach(x => txt += `  • Wall   ${x.id}  conf=${x.confidence.toFixed(3)}\n`);
      d.json.detectionResults.rooms.forEach(x => txt += `  • Room   ${x.id}  conf=${x.confidence.toFixed(3)}\n`);
      document.getElementById('summary').textContent = txt;
    } catch(e) {
      document.getElementById('result').innerHTML = 'Error';
      document.getElementById('summary').textContent = String(e);
    }
  }
</script>
</body>
</html>"""

# GET "/" serves the single-page UI: the handler returns the module-level HTML
# string and the route is declared with HTMLResponse as its response class.
@app.get("/", response_class=HTMLResponse)
def index():
    return HTML

# Only one inference runs at a time. Before the work was moved onto a worker
# thread, the blocked event loop enforced that implicitly; the lock keeps the
# same one-at-a-time use of the shared model.
_INFERENCE_LOCK = threading.Lock()


def _run_detection(contents: bytes) -> Dict[str, Any]:
    """Decode *contents*, run the detector and build the /detect payload.

    Synchronous and CPU-bound; intended to be called from a worker thread.

    Args:
        contents: Raw bytes of the uploaded image file (non-empty).

    Returns:
        ``{"image": <base64 JPEG of the annotated image>, "json": <dict from
        process_inference_result>}``.

    Raises:
        HTTPException: 400 if the bytes cannot be decoded as an image, 500 if
            the annotated image cannot be JPEG-encoded.
    """
    nparr = np.frombuffer(contents, np.uint8)
    img_bgr = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    if img_bgr is None:
        raise HTTPException(status_code=400, detail="Could not decode image.")

    # Original run.py converts BGR→RGB before inference
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

    with _INFERENCE_LOCK:
        result = inference_detector(model, img_rgb)

    # JSON output — matches original run.py exactly
    processed = process_inference_result(result)

    # Visual output — draw on RGB image, encode as JPEG
    annotated_rgb = draw_detections(img_rgb, result)
    annotated_bgr = cv2.cvtColor(annotated_rgb, cv2.COLOR_RGB2BGR)
    ok, buf = cv2.imencode(".jpg", annotated_bgr, [cv2.IMWRITE_JPEG_QUALITY, 90])
    if not ok:
        # Previously the success flag was discarded and whatever was in `buf`
        # got base64-encoded and returned.
        raise HTTPException(status_code=500, detail="Could not encode result image.")
    b64 = base64.b64encode(buf).decode()

    logger.info(
        "Inference done: %d walls, %d rooms",
        len(processed["detectionResults"]["walls"]),
        len(processed["detectionResults"]["rooms"]),
    )

    return {"image": b64, "json": processed}


@app.post("/detect")
async def detect(image: UploadFile = File(...)):
    """Run floor-plan detection on an uploaded JPEG or PNG.

    Args:
        image: Multipart upload sent under the form field ``image``.

    Returns:
        ``{"image": <base64 JPEG with detections drawn>, "json": <structured
        detection results>}``.

    Raises:
        HTTPException: 400 for an unsupported content type, an empty or
            oversized file, or undecodable image data; 500 if the annotated
            image cannot be encoded.
    """
    if image.content_type not in ["image/jpeg", "image/png"]:
        raise HTTPException(status_code=400, detail="Only JPEG and PNG supported.")

    # Read at most one byte past the limit: enough to detect an oversized
    # upload without pulling the whole file into memory first.
    contents = await image.read(MAX_FILE_SIZE + 1)
    if len(contents) > MAX_FILE_SIZE:
        raise HTTPException(status_code=400, detail="File exceeds 10 MB limit.")
    if not contents:
        # cv2.imdecode raises on an empty buffer instead of returning None,
        # which used to surface as an unhandled 500.
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")

    # Decoding, inference and drawing are CPU-bound and can take tens of
    # seconds; running them in the default executor keeps the event loop free
    # to serve other requests (e.g. the index page) in the meantime.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _run_detection, contents)

# When executed as a script, serve the app with uvicorn on port 7860, bound to
# 0.0.0.0 (all network interfaces).
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)