Spaces:
Sleeping
Sleeping
| """ | |
| predictor.py | |
| ------------ | |
| Runs YOLOv8 segmentation inference on new floor plan images. | |
| Outputs structured predictions consumed by Phase 3 (geometry reconstruction). | |
| Usage: | |
| from src.segmentation.predictor import FloorPlanPredictor | |
| predictor = FloorPlanPredictor("models/segmentation/best.pt") | |
| result = predictor.predict("outputs/plan_4_cleaned.png") | |
    print(result.walls())    # list of wall polygons
    print(result.rooms())    # list of room regions with class labels
| """ | |
| from __future__ import annotations | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import Optional | |
| import cv2 | |
| import numpy as np | |
| # ── Result data structures ──────────────────────────────────────────────────── | |
@dataclass
class DetectedElement:
    """A single detected floor plan element (one detection from the model).

    Attributes:
        class_id: Integer class index assigned by the model.
        class_name: Human-readable class label.
        confidence: Detection confidence in [0, 1].
        bbox: (x1, y1, x2, y2) corner coordinates in pixels.
        mask: Binary mask (0/255), same size as the source image, if available.
        polygon: Simplified contour as (x, y) pixel points, if available.
    """

    # BUG FIX: the @dataclass decorator was missing — without it there is no
    # generated __init__, so keyword construction (as done in
    # FloorPlanPredictor._parse_detections) raises TypeError.
    class_id: int
    class_name: str
    confidence: float
    bbox: tuple[int, int, int, int]  # (x1, y1, x2, y2) in pixels
    mask: Optional[np.ndarray] = None  # Binary mask, same size as image
    polygon: Optional[list[tuple]] = None  # Contour polygon points
@dataclass
class SegmentationResult:
    """Full segmentation result for one floor plan image.

    Attributes:
        image_path: Path of the image that was segmented.
        image_shape: (H, W) of that image in pixels.
        elements: All detected elements above the confidence threshold.
    """

    # BUG FIX: the @dataclass decorator was missing — without it
    # field(default_factory=list) is inert and keyword construction
    # (as done in FloorPlanPredictor.predict) raises TypeError.
    image_path: str
    image_shape: tuple[int, int]  # (H, W)
    elements: list[DetectedElement] = field(default_factory=list)

    def walls(self) -> list[DetectedElement]:
        """Return all elements whose class name contains 'wall'."""
        return [e for e in self.elements if 'wall' in e.class_name.lower()]

    def doors(self) -> list[DetectedElement]:
        """Return all elements whose class name contains 'door'."""
        return [e for e in self.elements if 'door' in e.class_name.lower()]

    def windows(self) -> list[DetectedElement]:
        """Return all elements whose class name contains 'window'."""
        return [e for e in self.elements if 'window' in e.class_name.lower()]

    def rooms(self) -> list[DetectedElement]:
        """Return all elements classified as a room-like region."""
        room_keywords = {'kitchen', 'living', 'bedroom', 'bathroom',
                         'corridor', 'balcony', 'garage', 'room'}
        return [e for e in self.elements
                if any(k in e.class_name.lower() for k in room_keywords)]

    def summary(self) -> dict:
        """Return counts of detected elements, overall and per class."""
        from collections import Counter
        counts = Counter(e.class_name for e in self.elements)
        return {
            "total_elements": len(self.elements),
            "by_class": dict(counts),
            "image": self.image_path,
        }
| # ── Predictor ───────────────────────────────────────────────────────────────── | |
class FloorPlanPredictor:
    """
    Runs YOLOv8 segmentation on cleaned floor plan images.

    Args:
        model_path: Path to trained best.pt weights.
        confidence: Minimum confidence threshold (0-1).
        iou: NMS IoU threshold (0-1).
        device: Inference device ('mps', 'cpu', '0'). Auto-detected if None.

    Raises:
        FileNotFoundError: If the weights file does not exist.
    """

    # Class names — loaded dynamically from model on first inference; this
    # list is only a fallback until then.
    CLASS_NAMES = [
        "OuterWall", "InnerWall", "Window", "Door",
        "Stairs", "Railing", "Kitchen", "LivingRoom",
        "Bedroom", "Bathroom", "Corridor", "Balcony", "Garage",
    ]

    def __init__(
        self,
        model_path: str,
        confidence: float = 0.35,
        iou: float = 0.45,
        device: Optional[str] = None,
    ):
        model_path = Path(model_path)
        if not model_path.exists():
            raise FileNotFoundError(f"Model weights not found: {model_path}")
        self.model_path = model_path
        self.confidence = confidence
        self.iou = iou
        self.device = device
        self._model = None  # Lazy load on first predict() call
        self._device: Optional[str] = None  # Resolved in _load_model()

    def _load_model(self):
        """Lazy-load the YOLO model on first inference call."""
        if self._model is None:
            from ultralytics import YOLO
            from src.segmentation.trainer import get_best_device

            device = self.device or get_best_device()
            self._model = YOLO(str(self.model_path))
            # Use model's own class names instead of the hard-coded fallback
            self.CLASS_NAMES = list(self._model.names.values())
            print(f"Model loaded: {self.model_path.name} on {device}")
            print(f"  Classes: {self.CLASS_NAMES}")
            self._device = device

    def predict(self, image_path: str) -> SegmentationResult:
        """
        Run segmentation on a single floor plan image.

        Args:
            image_path: Path to a preprocessed floor plan (Phase 1 output).

        Returns:
            SegmentationResult with all detected elements.

        Raises:
            FileNotFoundError: If the image file does not exist.
            ValueError: If the file exists but cannot be decoded as an image.
        """
        self._load_model()
        image_path = Path(image_path)
        if not image_path.exists():
            raise FileNotFoundError(f"Image not found: {image_path}")
        img = cv2.imread(str(image_path))
        if img is None:
            # cv2.imread signals failure by returning None (no exception);
            # fail loudly instead of crashing on img.shape below.
            raise ValueError(f"Could not read image: {image_path}")
        h, w = img.shape[:2]
        raw = self._model.predict(
            source=str(image_path),
            conf=self.confidence,
            iou=self.iou,
            device=self._device,
            verbose=False,
            retina_masks=True,  # Higher quality masks
        )
        result = SegmentationResult(
            image_path=str(image_path),
            image_shape=(h, w),
        )
        if raw and len(raw[0].boxes) > 0:
            result.elements = self._parse_detections(raw[0], h, w)
        print(f"Detected {len(result.elements)} elements in {image_path.name}")
        # BUG FIX: summary is a method — it must be called before subscripting
        # (result.summary["by_class"] raised TypeError on every prediction).
        for cls, count in result.summary()["by_class"].items():
            print(f"  {cls}: {count}")
        return result

    def predict_batch(
        self, image_paths: list[str]
    ) -> list[SegmentationResult]:
        """Run prediction on multiple images, skipping ones that fail."""
        self._load_model()
        results = []
        for i, path in enumerate(image_paths, 1):
            print(f"\n── [{i}/{len(image_paths)}] {Path(path).name}")
            try:
                results.append(self.predict(path))
            except Exception as e:
                # Best-effort batch: report and continue with remaining images
                print(f"  ERROR: {e}")
        return results

    def _parse_detections(self, raw_result, img_h: int, img_w: int) -> list[DetectedElement]:
        """Convert raw YOLO output to a DetectedElement list.

        Args:
            raw_result: One ultralytics Results object (boxes + optional masks).
            img_h: Original image height in pixels.
            img_w: Original image width in pixels.
        """
        elements = []
        boxes = raw_result.boxes
        masks = raw_result.masks
        for i in range(len(boxes)):
            class_id = int(boxes.cls[i].item())
            confidence = float(boxes.conf[i].item())
            x1, y1, x2, y2 = [int(v) for v in boxes.xyxy[i].tolist()]
            # Get binary mask — from segmentation if available, else from bbox
            if masks is not None:
                mask_data = masks.data[i].cpu().numpy()
                # Model masks may be at network resolution; resize to image size.
                mask = cv2.resize(
                    (mask_data * 255).astype(np.uint8),
                    (img_w, img_h),
                    interpolation=cv2.INTER_NEAREST,
                )
            else:
                # Build mask from bounding box, clamped to image bounds
                mask = np.zeros((img_h, img_w), dtype=np.uint8)
                x1c, y1c = max(0, x1), max(0, y1)
                x2c, y2c = min(img_w, x2), min(img_h, y2)
                mask[y1c:y2c, x1c:x2c] = 255
            # A mask is always built above, so extract the polygon directly
            # (the previous `if mask is not None` guard was dead code).
            polygon = self._mask_to_polygon(mask)
            class_name = (
                self.CLASS_NAMES[class_id]
                if class_id < len(self.CLASS_NAMES)
                else f"class_{class_id}"
            )
            elements.append(DetectedElement(
                class_id=class_id,
                class_name=class_name,
                confidence=confidence,
                bbox=(x1, y1, x2, y2),
                mask=mask,
                polygon=polygon,
            ))
        return elements

    def _mask_to_polygon(
        self, mask: np.ndarray, epsilon_factor: float = 0.005
    ) -> Optional[list[tuple]]:
        """
        Convert binary mask to simplified polygon.

        Args:
            mask: Binary mask (0/255).
            epsilon_factor: Contour approximation accuracy (fraction of perimeter).

        Returns:
            List of (x, y) pixel coordinates or None if no contour found.
        """
        contours, _ = cv2.findContours(
            mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        if not contours:
            return None
        # Use the largest contour (ignores small disconnected fragments)
        contour = max(contours, key=cv2.contourArea)
        epsilon = epsilon_factor * cv2.arcLength(contour, closed=True)
        approx = cv2.approxPolyDP(contour, epsilon, closed=True)
        return [(int(pt[0][0]), int(pt[0][1])) for pt in approx]