Spaces:
Sleeping
Sleeping
"""
dataset.py
----------
Loads and converts the CubiCasa5k dataset into YOLOv8 segmentation format.

CubiCasa5k provides:
  - Floor plan images (PNG)
  - SVG annotations with labelled polygons per element class

We convert SVG → YOLO segmentation format:
  <class_id> <x1> <y1> <x2> <y2> ...   (normalised 0-1 polygon coords)

Class map (14 internal classes). Background (0) is never written to label
files, so the YOLO class id on each label line is the internal id minus one
(0-12):
  0  Background
  1  OuterWall
  2  InnerWall
  3  Window
  4  Door
  5  Stairs
  6  Railing
  7  Kitchen
  8  LivingRoom
  9  Bedroom
  10 Bathroom
  11 Corridor
  12 Balcony
  13 Garage

Usage:
    from src.segmentation.dataset import CubiCasaDataset
    ds = CubiCasaDataset("data/cubicasa5k")
    ds.prepare(output_dir="data/yolo_dataset")
"""
import os
import random
import re
import shutil
import xml.etree.ElementTree as ET
from pathlib import Path
from typing import Optional

import cv2
import numpy as np
from PIL import Image
# ── Class definitions ─────────────────────────────────────────────────────────
# Internal class ids are list positions:
#   0 Background, 1 OuterWall, 2 InnerWall, 3 Window, 4 Door, 5 Stairs,
#   6 Railing, 7 Kitchen, 8 LivingRoom, 9 Bedroom, 10 Bathroom,
#   11 Corridor, 12 Balcony, 13 Garage
CLASS_NAMES = (
    "Background OuterWall InnerWall Window Door Stairs Railing "
    "Kitchen LivingRoom Bedroom Bathroom Corridor Balcony Garage"
).split()

# SVG label (exact string) → internal class id. Several SVG spellings collapse
# onto one id (e.g. "Toilet" is treated as a Bathroom, "CarPort" as a Garage).
# Background (0) has no SVG label and therefore never appears here.
SVG_CLASS_MAP = dict(
    Wall=1,
    OuterWall=1,
    InnerWall=2,
    Window=3,
    Door=4,
    Stairs=5,
    Railing=6,
    Kitchen=7,
    LivingRoom=8,
    Living=8,
    Bedroom=9,
    Bathroom=10,
    Toilet=10,
    Corridor=11,
    Hallway=11,
    Balcony=12,
    Terrace=12,
    Garage=13,
    CarPort=13,
)

# Total number of internal classes, Background included.
NUM_CLASSES = len(CLASS_NAMES)
# ── Dataset class ─────────────────────────────────────────────────────────────
class CubiCasaDataset:
    """
    Converts CubiCasa5k dataset to YOLOv8 segmentation format.

    CubiCasa5k download:
        https://zenodo.org/record/2613548

    Args:
        root_dir: Path to the extracted CubiCasa5k folder.
        val_split: Fraction of data to use for validation.
        test_split: Fraction of data to use for testing.
        seed: Random seed for reproducible splits.

    Raises:
        FileNotFoundError: If ``root_dir`` does not exist.
        ValueError: If a split fraction is negative or ``val_split +
            test_split`` exceeds 1.
    """

    def __init__(
        self,
        root_dir: str,
        val_split: float = 0.15,
        test_split: float = 0.10,
        seed: int = 42,
    ):
        self.root = Path(root_dir)
        self.val_split = val_split
        self.test_split = test_split
        self.seed = seed
        if not self.root.exists():
            raise FileNotFoundError(
                f"Dataset root not found: {root_dir}\n"
                "Download CubiCasa5k from: https://zenodo.org/record/2613548"
            )
        # Reject fractions that would produce overlapping / negative slices.
        if val_split < 0 or test_split < 0 or val_split + test_split > 1:
            raise ValueError(
                f"Invalid split fractions: val_split={val_split}, "
                f"test_split={test_split} (each must be >= 0, sum <= 1)"
            )

    def prepare(self, output_dir: str = "data/yolo_dataset") -> str:
        """
        Convert and split the dataset into train/val/test sets.

        Args:
            output_dir: Where to write the YOLO-formatted dataset.

        Returns:
            Path to the generated dataset.yaml file.
        """
        out = Path(output_dir)
        print(f"Preparing CubiCasa5k -> YOLO format in: {out}")

        samples = self._discover_samples()
        print(f" Found {len(samples)} annotated floor plans")

        splits = self._split(samples)
        for split_name, split_samples in splits.items():
            print(f" {split_name}: {len(split_samples)} samples")

        for split_name, split_samples in splits.items():
            self._write_split(split_samples, out, split_name)

        yaml_path = self._write_yaml(out)
        print(f"\nDataset ready. Config: {yaml_path}")
        return str(yaml_path)

    # ── Internal helpers ──────────────────────────────────────────────────────
    def _discover_samples(self) -> list[dict]:
        """
        Find all (image, annotation) pairs under ``self.root``.

        A sample is any ``F1_scaled.png`` found recursively whose directory
        also contains a ``model.svg`` (e.g. root/high_quality/<id>/). Results
        are sorted by image path so the seeded split is deterministic.
        """
        samples = []
        for img_path in sorted(self.root.rglob("F1_scaled.png")):
            svg_path = img_path.parent / "model.svg"
            if svg_path.exists():
                samples.append({
                    "image": str(img_path),
                    "annotation": str(svg_path),
                })
        return samples

    def _split(self, samples: list[dict]) -> dict[str, list[dict]]:
        """
        Reproducible train/val/test split.

        Uses a private ``random.Random(self.seed)`` so the module-level
        ``random`` state is left untouched; the shuffle order is the same as
        seeding the global generator with the same seed.
        """
        rng = random.Random(self.seed)
        shuffled = list(samples)
        rng.shuffle(shuffled)

        n = len(shuffled)
        n_test = int(n * self.test_split)
        n_val = int(n * self.val_split)
        return {
            "test": shuffled[:n_test],
            "val": shuffled[n_test:n_test + n_val],
            "train": shuffled[n_test + n_val:],
        }

    def _write_split(
        self, samples: list[dict], out: Path, split: str
    ) -> None:
        """
        Convert samples and write images + labels to the split directories.

        For each sample writes ``out/images/<split>/<stem>.png`` and
        ``out/labels/<split>/<stem>.txt``, where ``<stem>`` is the name of the
        sample's parent directory. A sample is skipped, with nothing left on
        disk for it, when its annotation yields no polygons or when reading,
        converting or writing it raises.
        """
        img_dir = out / "images" / split
        lbl_dir = out / "labels" / split
        img_dir.mkdir(parents=True, exist_ok=True)
        lbl_dir.mkdir(parents=True, exist_ok=True)

        ok, skipped = 0, 0
        for sample in samples:
            # NOTE(review): the stem is only the parent directory name; if two
            # samples share it, the later one overwrites the earlier — confirm
            # ids are unique across dataset subsets.
            stem = Path(sample["image"]).parent.name
            dst_img = img_dir / f"{stem}.png"
            dst_lbl = lbl_dir / f"{stem}.txt"

            # Parse before copying: an image without a label file would be
            # picked up by YOLO as a background-only training image.
            try:
                with Image.open(sample["image"]) as img:
                    w, h = img.size
                polygons = parse_svg_annotations(sample["annotation"], w, h)
            except Exception as e:  # best-effort: report and move on
                print(f" Warning: skipping {sample['image']}: {e}")
                skipped += 1
                continue

            if not polygons:
                skipped += 1
                continue

            try:
                shutil.copy2(sample["image"], dst_img)
                write_yolo_labels(polygons, dst_lbl)
            except Exception as e:
                print(f" Warning: skipping {sample['image']}: {e}")
                # Don't leave a half-written sample behind.
                dst_img.unlink(missing_ok=True)
                dst_lbl.unlink(missing_ok=True)
                skipped += 1
                continue
            ok += 1

        print(f" {split}: wrote {ok} labels, skipped {skipped}")

    def _write_yaml(self, out: Path) -> Path:
        """Write the YOLO dataset configuration YAML and return its path."""
        yaml_path = out / "dataset.yaml"
        content = "\n".join([
            "# CubiCasa5k -> YOLOv8 segmentation dataset",
            f"path: {out.resolve()}",
            "train: images/train",
            "val: images/val",
            "test: images/test",
            f"nc: {NUM_CLASSES - 1} # exclude background (class 0)",
            f"names: {CLASS_NAMES[1:]}",
            "",
        ])
        # Explicit encoding: the platform default may not be UTF-8.
        yaml_path.write_text(content, encoding="utf-8")
        return yaml_path
# ── SVG parsing ───────────────────────────────────────────────────────────────
def parse_svg_annotations(
    svg_path: str, img_w: int, img_h: int
) -> list[dict]:
    """
    Parse a CubiCasa5k SVG annotation file.

    An element is labelled when its ``class`` attribute, or failing that the
    part of its ``id`` before the first ``-``, is a key of ``SVG_CLASS_MAP``.
    ``<polygon>``, ``<polyline>``, ``<rect>`` and ``<path>`` elements are
    converted; shapes with fewer than three vertices are dropped.

    NOTE(review): labels are read from the shape element itself, and ``transform``
    attributes are ignored. If the annotations put the label on an enclosing
    ``<g>`` with unlabelled child shapes, those shapes are skipped — confirm
    against the dataset.

    Args:
        svg_path: Path to model.svg (anything ``ET.parse`` accepts).
        img_w: Image width in pixels; fallback viewBox width when the SVG has
            no ``viewBox`` attribute.
        img_h: Image height in pixels; fallback viewBox height.

    Returns:
        List of dicts: [{"class_id": int, "polygon": [(x, y), ...]}, ...]
        Coordinates are normalised to the viewBox, clamped to [0, 1] and
        rounded to 6 decimals.

    Raises:
        ValueError: If the file is not valid XML, or its viewBox is malformed
            or has a non-positive width/height.
    """
    try:
        root = ET.parse(svg_path).getroot()
    except ET.ParseError as e:
        raise ValueError(f"Invalid SVG: {svg_path}: {e}") from e

    # viewBox is "min-x min-y width height"; SVG allows commas and/or
    # whitespace as separators.
    view_box = root.get("viewBox", f"0 0 {img_w} {img_h}")
    try:
        vb_x, vb_y, svg_w, svg_h = (
            float(v) for v in view_box.replace(",", " ").split()
        )
    except ValueError as e:
        raise ValueError(
            f"Malformed viewBox {view_box!r} in {svg_path}"
        ) from e
    if svg_w <= 0 or svg_h <= 0:
        raise ValueError(
            f"viewBox {view_box!r} in {svg_path} has non-positive size"
        )

    def _normalise(value: float, origin: float, size: float) -> float:
        # Shift by the viewBox origin, scale to [0, 1], and clamp: YOLO
        # rejects label coordinates outside [0, 1].
        return round(min(max((value - origin) / size, 0.0), 1.0), 6)

    polygons = []
    for elem in root.iter():
        tag = elem.tag.split("}")[-1]  # strip XML namespace

        # Try the class first, then the id prefix ("Door-3" -> "Door"); an
        # unmapped class must not hide a recognised id.
        class_id = SVG_CLASS_MAP.get(elem.get("class", ""))
        if class_id is None:
            class_id = SVG_CLASS_MAP.get(elem.get("id", "").split("-")[0])
        if class_id is None:
            continue

        if tag in ("polygon", "polyline"):
            pts = _parse_polygon_points(elem.get("points", ""))
        elif tag == "rect":
            pts = _rect_to_polygon(elem)
        elif tag == "path":
            pts = _path_to_polygon(elem.get("d", ""))
        else:
            continue

        if len(pts) >= 3:
            norm = [
                (_normalise(x, vb_x, svg_w), _normalise(y, vb_y, svg_h))
                for x, y in pts
            ]
            polygons.append({"class_id": class_id, "polygon": norm})
    return polygons
def write_yolo_labels(polygons: list[dict], output_path: Path) -> None:
    """
    Write a YOLO segmentation label file.

    Format per line: <class_id> <x1> <y1> <x2> <y2> ...
    Internal ids start at 1 (0 is Background in CLASS_NAMES), so each id is
    shifted down by one to YOLO's 0-based class index. Every line, including
    the last, ends with a newline; an empty ``polygons`` list gives an empty
    file.

    Args:
        polygons: [{"class_id": int, "polygon": [(x, y), ...]}, ...] with
            normalised coordinates.
        output_path: Destination ``.txt`` file (overwritten).

    Raises:
        ValueError: If a polygon's class_id is < 1 (it would become a negative
            YOLO class index).
    """
    lines = []
    for ann in polygons:
        yolo_id = ann["class_id"] - 1
        if yolo_id < 0:
            raise ValueError(
                f"class_id {ann['class_id']} has no YOLO class index "
                "(Background / invalid ids cannot be written)"
            )
        coords = " ".join(f"{x} {y}" for x, y in ann["polygon"])
        lines.append(f"{yolo_id} {coords}\n")
    output_path.write_text("".join(lines), encoding="utf-8")
# ── Geometry helpers ──────────────────────────────────────────────────────────
def _parse_polygon_points(points_str: str) -> list[tuple]:
    """
    Turn an SVG ``points`` attribute into a list of ``(x, y)`` float pairs.

    Commas and whitespace are both accepted as separators. An unpaired
    trailing value is dropped, and any token that is not a number makes the
    whole result ``[]``.
    """
    try:
        numbers = list(map(float, points_str.replace(",", " ").split()))
    except ValueError:
        return []
    # Even positions are x, odd positions are y; zip drops a dangling x.
    return list(zip(numbers[0::2], numbers[1::2]))
def _rect_to_polygon(elem) -> list[tuple]:
    """
    Return the corners of an SVG ``<rect>`` as a 4-point polygon.

    Missing ``x``/``y``/``width``/``height`` attributes count as 0. Corners are
    ordered (x, y), (x+w, y), (x+w, y+h), (x, y+h). A non-numeric attribute
    yields ``[]``.
    """
    try:
        left, top, width, height = (
            float(elem.get(attr, 0)) for attr in ("x", "y", "width", "height")
        )
    except (ValueError, TypeError):
        return []
    right = left + width
    bottom = top + height
    return [(left, top), (right, top), (right, bottom), (left, bottom)]
# One SVG path-data token: a command letter, or a number in SVG syntax
# (optional sign, optional fraction, optional exponent). Matching tokens
# instead of splitting on whitespace also handles compact data such as
# "M10,20L30-40".
_PATH_TOKEN_RE = re.compile(
    r"[MmLlHhVvZzCcSsQqTtAa]|[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?"
)

# Number of numeric parameters consumed by each (upper-cased) path command.
_PATH_ARITY = {
    "M": 2, "L": 2, "H": 1, "V": 1, "T": 2,
    "S": 4, "Q": 4, "C": 6, "A": 7, "Z": 0,
}


def _path_to_polygon(d: str) -> list[tuple]:
    """
    Approximate SVG path data as a polygon of absolute ``(x, y)`` vertices.

    Supported features:
      - absolute and relative move/line commands (M/m, L/l, H/h, V/v);
      - close-path (Z/z), which resets the current point to the sub-path start;
      - implicit repetition (extra coordinate pairs after ``M`` are line-tos);
      - compact syntax without separators ("M10,20L30,40Z").

    Curve and arc commands (C/S/Q/T/A) are reduced to their end points, so
    curved edges become straight chords. Arc flags written without separators
    (e.g. "0150") are not split. This is sufficient for simple architectural
    polygons.

    Parsing stops at the first truncated or malformed command and returns the
    vertices collected so far. Empty input yields ``[]``.
    """
    tokens = _PATH_TOKEN_RE.findall(d)
    pts: list[tuple] = []
    cur_x = cur_y = 0.0      # current point, needed for relative commands
    start_x = start_y = 0.0  # start of the current sub-path, restored by Z
    cmd = None
    i = 0
    while i < len(tokens):
        tok = tokens[i]
        if tok.isalpha():
            cmd = tok
            i += 1
            if cmd in "Zz":
                cur_x, cur_y = start_x, start_y
                continue
        elif cmd is None or cmd in "Zz":
            # A number with no command to attach it to: ignore it.
            i += 1
            continue

        upper = cmd.upper()
        arity = _PATH_ARITY[upper]
        params = tokens[i:i + arity]
        if len(params) < arity or any(p.isalpha() for p in params):
            break  # truncated / malformed command
        i += arity
        vals = [float(p) for p in params]
        rel = cmd.islower()

        if upper == "H":
            x, y = vals[0] + (cur_x if rel else 0.0), cur_y
        elif upper == "V":
            x, y = cur_x, vals[0] + (cur_y if rel else 0.0)
        else:
            # The segment end point is always the last coordinate pair.
            x, y = vals[-2], vals[-1]
            if rel:
                x, y = x + cur_x, y + cur_y

        pts.append((x, y))
        cur_x, cur_y = x, y
        if upper == "M":
            start_x, start_y = x, y
            # Per the SVG spec, further pairs after a moveto are line-tos.
            cmd = "l" if rel else "L"
    return pts