File size: 11,739 Bytes
fc895f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
"""

dataset.py

----------

Loads and converts the CubiCasa5k dataset into YOLOv8 segmentation format.



CubiCasa5k provides:

  - Floor plan images (PNG)

  - SVG annotations with labelled polygons per element class



We convert SVG → YOLO segmentation format:

  <class_id> <x1> <y1> <x2> <y2> ... (normalised 0-1 polygon coords)



Class map (14 classes):

  0  Background

  1  OuterWall

  2  InnerWall

  3  Window

  4  Door

  5  Stairs

  6  Railing

  7  Kitchen

  8  LivingRoom

  9  Bedroom

  10 Bathroom

  11 Corridor

  12 Balcony

  13 Garage



Usage:

    from src.segmentation.dataset import CubiCasaDataset

    ds = CubiCasaDataset("data/cubicasa5k")

    ds.prepare(output_dir="data/yolo_dataset")

"""

import os
import random
import re
import shutil
import xml.etree.ElementTree as ET
from pathlib import Path
from typing import Optional

import cv2
import numpy as np
from PIL import Image


# ── Class definitions ─────────────────────────────────────────────────────────

# Index in this list == integer class ID used throughout the module.
CLASS_NAMES = [
    "Background",
    "OuterWall",
    "InnerWall",
    "Window",
    "Door",
    "Stairs",
    "Railing",
    "Kitchen",
    "LivingRoom",
    "Bedroom",
    "Bathroom",
    "Corridor",
    "Balcony",
    "Garage",
]

NUM_CLASSES = len(CLASS_NAMES)

# Map SVG class names → our integer IDs.
# Each row lists every SVG spelling that collapses onto one class ID.
SVG_CLASS_MAP = {
    svg_name: class_id
    for class_id, svg_names in (
        (1, ("Wall", "OuterWall")),
        (2, ("InnerWall",)),
        (3, ("Window",)),
        (4, ("Door",)),
        (5, ("Stairs",)),
        (6, ("Railing",)),
        (7, ("Kitchen",)),
        (8, ("LivingRoom", "Living")),
        (9, ("Bedroom",)),
        (10, ("Bathroom", "Toilet")),
        (11, ("Corridor", "Hallway")),
        (12, ("Balcony", "Terrace")),
        (13, ("Garage", "CarPort")),
    )
    for svg_name in svg_names
}


# ── Dataset class ─────────────────────────────────────────────────────────────

class CubiCasaDataset:
    """
    Converts the CubiCasa5k dataset to YOLOv8 segmentation format.

    CubiCasa5k download:
        https://zenodo.org/record/2613548

    Args:
        root_dir:    Path to the extracted CubiCasa5k folder.
        val_split:   Fraction of data to use for validation.
        test_split:  Fraction of data to use for testing.
        seed:        Random seed for reproducible splits.

    Raises:
        FileNotFoundError: If ``root_dir`` does not exist.
        ValueError: If a split fraction is negative or the two fractions
            together exceed 1 (the slices would silently overlap or
            truncate otherwise).
    """

    def __init__(
        self,
        root_dir: str,
        val_split: float = 0.15,
        test_split: float = 0.10,
        seed: int = 42,
    ):
        self.root = Path(root_dir)
        self.val_split = val_split
        self.test_split = test_split
        self.seed = seed

        if not self.root.exists():
            raise FileNotFoundError(
                f"Dataset root not found: {root_dir}\n"
                "Download CubiCasa5k from: https://zenodo.org/record/2613548"
            )
        if val_split < 0 or test_split < 0 or val_split + test_split > 1:
            raise ValueError(
                "val_split and test_split must be non-negative and sum to "
                f"at most 1 (got val_split={val_split}, "
                f"test_split={test_split})"
            )

    def prepare(self, output_dir: str = "data/yolo_dataset") -> str:
        """
        Convert and split the dataset into train/val/test sets.

        Args:
            output_dir: Where to write the YOLO-formatted dataset.

        Returns:
            Path to the generated dataset.yaml file.
        """
        out = Path(output_dir)
        print(f"Preparing CubiCasa5k → YOLO format in: {out}")

        # Discover all floor plan samples
        samples = self._discover_samples()
        print(f"  Found {len(samples)} annotated floor plans")

        # Split into train / val / test
        splits = self._split(samples)
        for split_name, split_samples in splits.items():
            print(f"  {split_name}: {len(split_samples)} samples")

        # Convert and write each split
        for split_name, split_samples in splits.items():
            self._write_split(split_samples, out, split_name)

        # Write dataset.yaml
        yaml_path = self._write_yaml(out)
        print(f"\nDataset ready. Config: {yaml_path}")
        return str(yaml_path)

    # ── Internal helpers ──────────────────────────────────────────────────────

    def _discover_samples(self) -> list[dict]:
        """
        Find all (image, annotation) pairs in the dataset.

        Every ``F1_scaled.png`` below the root is paired with the
        ``model.svg`` in the same directory; images without that file are
        ignored. Results are sorted by image path so the later shuffle is
        reproducible.

        Returns:
            List of ``{"image": str, "annotation": str}`` dicts.
        """
        samples = []
        # CubiCasa5k structure: root/high_quality/<id>/F1_scaled.png + model.svg
        for img_path in sorted(self.root.rglob("F1_scaled.png")):
            svg_path = img_path.parent / "model.svg"
            if svg_path.exists():
                samples.append({
                    "image": str(img_path),
                    "annotation": str(svg_path),
                })
        return samples

    def _split(self, samples: list[dict]) -> dict[str, list[dict]]:
        """
        Reproducible train/val/test split.

        Uses a private ``random.Random`` seeded with ``self.seed`` so the
        process-wide RNG state is left untouched; the resulting order is the
        same as seeding the global generator with that value.

        Returns:
            Dict with keys ``"test"``, ``"val"`` and ``"train"``; the three
            lists are disjoint and together contain every sample.
        """
        rng = random.Random(self.seed)
        shuffled = list(samples)
        rng.shuffle(shuffled)

        n = len(shuffled)
        n_test = int(n * self.test_split)
        n_val = int(n * self.val_split)

        return {
            "test":  shuffled[:n_test],
            "val":   shuffled[n_test:n_test + n_val],
            "train": shuffled[n_test + n_val:],
        }

    def _write_split(
        self, samples: list[dict], out: Path, split: str
    ) -> None:
        """
        Convert samples and write images + labels to the split directory.

        A sample is written only when its SVG yields at least one polygon.
        The label file is written before the image is copied, so a sample
        that fails part-way never leaves an image without a label in the
        split (which a trainer would otherwise read as an empty floor plan).

        Args:
            samples: Entries produced by ``_discover_samples``.
            out:     Root of the YOLO dataset being generated.
            split:   Split name used as sub-directory (train/val/test).
        """
        img_dir = out / "images" / split
        lbl_dir = out / "labels" / split
        img_dir.mkdir(parents=True, exist_ok=True)
        lbl_dir.mkdir(parents=True, exist_ok=True)

        ok, skipped = 0, 0
        for sample in samples:
            try:
                src_img = Path(sample["image"])
                # The per-plan directory name is used as the output stem.
                stem = src_img.parent.name

                # Only the size is needed; close the file handle right away.
                with Image.open(src_img) as img:
                    w, h = img.size

                # Parse SVG → YOLO label file
                polygons = parse_svg_annotations(sample["annotation"], w, h)
                if not polygons:
                    skipped += 1
                    continue

                write_yolo_labels(polygons, lbl_dir / f"{stem}.txt")
                shutil.copy2(src_img, img_dir / f"{stem}.png")
                ok += 1

            except Exception as e:
                # Deliberate best-effort: one broken sample must not abort
                # the whole conversion, so report it and move on.
                print(f"  Warning: skipping {sample['image']}: {e}")
                skipped += 1

        print(f"    {split}: wrote {ok} labels, skipped {skipped}")

    def _write_yaml(self, out: Path) -> Path:
        """
        Write the YOLO dataset configuration YAML.

        Background (class 0) is excluded from ``nc`` and ``names``.

        Returns:
            Path to the written ``dataset.yaml``.
        """
        out.mkdir(parents=True, exist_ok=True)
        yaml_path = out / "dataset.yaml"
        content = f"""# CubiCasa5k — YOLOv8 segmentation dataset

path: {out.resolve()}

train: images/train

val:   images/val

test:  images/test



nc: {NUM_CLASSES - 1}  # exclude background (class 0)

names: {CLASS_NAMES[1:]}

"""
        # Explicit encoding: the header contains a non-ASCII character and
        # the platform default codec may not be able to represent it.
        yaml_path.write_text(content, encoding="utf-8")
        return yaml_path


# ── SVG parsing ───────────────────────────────────────────────────────────────

def parse_svg_annotations(
    svg_path: str, img_w: int, img_h: int
) -> list[dict]:
    """
    Parse a CubiCasa5k SVG annotation file.

    Elements are matched against ``SVG_CLASS_MAP`` using their ``class``
    attribute (the whole value first, then each space-separated token) or,
    when no ``class`` is present, the part of ``id`` before the first ``-``.
    Supported shapes are ``polygon``, ``polyline``, ``rect`` and ``path``;
    shapes with fewer than three points are dropped.

    NOTE(review): only the element's own ``class``/``id`` is inspected; a
    class set on an enclosing ``<g>`` is not inherited by its children.
    Confirm this matches the layout of the dataset's SVG files.

    Args:
        svg_path: Path to model.svg
        img_w:    Image width in pixels (used when the SVG has no viewBox)
        img_h:    Image height in pixels (used when the SVG has no viewBox)

    Returns:
        List of dicts: [{"class_id": int, "polygon": [(x, y), ...]}, ...]
        All coordinates are normalised relative to the viewBox, rounded to
        6 decimals and clamped to [0, 1].

    Raises:
        ValueError: If the file is not well-formed XML, the viewBox is
            malformed, or the coordinate system has a non-positive size.
    """
    try:
        root = ET.parse(svg_path).getroot()
    except ET.ParseError as e:
        raise ValueError(f"Invalid SVG: {svg_path}: {e}") from e

    # SVG viewBox gives us the coordinate system: "min-x min-y width height",
    # separated by whitespace and/or commas.
    view_box = root.get("viewBox")
    if view_box:
        try:
            vb = [float(v) for v in view_box.replace(",", " ").split()]
        except ValueError as e:
            raise ValueError(
                f"Invalid viewBox in {svg_path}: {view_box!r}"
            ) from e
        if len(vb) != 4:
            raise ValueError(f"Invalid viewBox in {svg_path}: {view_box!r}")
        min_x, min_y, svg_w, svg_h = vb
    else:
        min_x, min_y, svg_w, svg_h = 0.0, 0.0, float(img_w), float(img_h)

    if svg_w <= 0 or svg_h <= 0:
        raise ValueError(
            f"Non-positive coordinate system size in {svg_path}: "
            f"{svg_w} x {svg_h}"
        )

    polygons = []
    for elem in root.iter():
        tag = elem.tag.split("}")[-1]  # strip namespace

        class_attr = elem.get("class", "")
        if class_attr:
            # A class attribute is a space-separated list of names; try the
            # raw value first so exact single-name matches behave as before.
            candidates = [class_attr, *class_attr.split()]
        else:
            candidates = [elem.get("id", "").split("-")[0]]

        class_id = next(
            (SVG_CLASS_MAP[c] for c in candidates if c in SVG_CLASS_MAP),
            None,
        )
        if class_id is None:
            continue

        pts = None
        if tag in ("polygon", "polyline"):
            pts = _parse_polygon_points(elem.get("points", ""))
        elif tag == "rect":
            pts = _rect_to_polygon(elem)
        elif tag == "path":
            pts = _path_to_polygon(elem.get("d", ""))

        if pts and len(pts) >= 3:
            # Shift by the viewBox origin, scale to the viewBox size, and
            # clamp so shapes poking outside the canvas stay within [0, 1].
            norm = [
                (
                    min(1.0, max(0.0, round((x - min_x) / svg_w, 6))),
                    min(1.0, max(0.0, round((y - min_y) / svg_h, 6))),
                )
                for x, y in pts
            ]
            polygons.append({"class_id": class_id, "polygon": norm})

    return polygons


def write_yolo_labels(polygons: list[dict], output_path: Path) -> None:
    """
    Serialise polygon annotations into a YOLO segmentation label file.

    One line is produced per annotation:
        <class_id> <x1> <y1> <x2> <y2> ...

    Args:
        polygons:    Dicts with a "class_id" int and a "polygon" list of
                     (x, y) pairs.
        output_path: File to (over)write; rows are newline-separated with no
                     trailing newline.
    """
    rows = [
        "{} {}".format(
            # Stored IDs start at 1 here; YOLO is 0-indexed.
            entry["class_id"] - 1,
            " ".join(f"{px} {py}" for px, py in entry["polygon"]),
        )
        for entry in polygons
    ]
    output_path.write_text("\n".join(rows))


# ── Geometry helpers ──────────────────────────────────────────────────────────

def _parse_polygon_points(points_str: str) -> list[tuple]:
    """Parse SVG 'points' attribute into list of (x, y) tuples."""
    try:
        vals = [float(v) for v in points_str.replace(",", " ").split()]
        return [(vals[i], vals[i + 1]) for i in range(0, len(vals) - 1, 2)]
    except (ValueError, IndexError):
        return []


def _rect_to_polygon(elem) -> list[tuple]:
    """Convert SVG <rect> to 4-point polygon."""
    try:
        x = float(elem.get("x", 0))
        y = float(elem.get("y", 0))
        w = float(elem.get("width", 0))
        h = float(elem.get("height", 0))
        return [(x, y), (x + w, y), (x + w, y + h), (x, y + h)]
    except (ValueError, TypeError):
        return []


def _path_to_polygon(d: str) -> list[tuple]:
    """

    Naive SVG path β†’ polygon converter.

    Only handles M/L/Z commands (absolute move/line/close).

    Sufficient for simple architectural polygons.

    """
    pts = []
    try:
        tokens = d.replace(",", " ").split()
        i = 0
        while i < len(tokens):
            cmd = tokens[i]
            if cmd in ("M", "L"):
                x, y = float(tokens[i + 1]), float(tokens[i + 2])
                pts.append((x, y))
                i += 3
            elif cmd == "Z":
                i += 1
            else:
                i += 1
    except (IndexError, ValueError):
        pass
    return pts