"""
dataset.py
----------
Loads and converts the CubiCasa5k dataset into YOLOv8 segmentation format.
CubiCasa5k provides:
- Floor plan images (PNG)
- SVG annotations with labelled polygons per element class
We convert SVG → YOLO segmentation format:
<class_id> <x1> <y1> <x2> <y2> ... (normalised 0-1 polygon coords)
Class map (14 classes; YOLO label files use id - 1 because Background is excluded):
0 Background
1 OuterWall
2 InnerWall
3 Window
4 Door
5 Stairs
6 Railing
7 Kitchen
8 LivingRoom
9 Bedroom
10 Bathroom
11 Corridor
12 Balcony
13 Garage
Usage:
from src.segmentation.dataset import CubiCasaDataset
ds = CubiCasaDataset("data/cubicasa5k")
ds.prepare(output_dir="data/yolo_dataset")
"""
import os
import shutil
import random
import xml.etree.ElementTree as ET
from pathlib import Path
from typing import Optional
import cv2
import numpy as np
from PIL import Image
# ── Class definitions ─────────────────────────────────────────────────────────
# Position in this list is the module-internal class id (0 = Background).
CLASS_NAMES = (
    "Background OuterWall InnerWall Window Door Stairs Railing "
    "Kitchen LivingRoom Bedroom Bathroom Corridor Balcony Garage"
).split()

# SVG class name (or element-id prefix) -> module-internal class id.
# Several SVG spellings collapse onto the same id (e.g. "Toilet" -> Bathroom).
SVG_CLASS_MAP = {
    svg_name: class_id
    for class_id, svg_names in (
        (1, ("Wall", "OuterWall")),
        (2, ("InnerWall",)),
        (3, ("Window",)),
        (4, ("Door",)),
        (5, ("Stairs",)),
        (6, ("Railing",)),
        (7, ("Kitchen",)),
        (8, ("LivingRoom", "Living")),
        (9, ("Bedroom",)),
        (10, ("Bathroom", "Toilet")),
        (11, ("Corridor", "Hallway")),
        (12, ("Balcony", "Terrace")),
        (13, ("Garage", "CarPort")),
    )
    for svg_name in svg_names
}

NUM_CLASSES = len(CLASS_NAMES)
# ── Dataset class ─────────────────────────────────────────────────────────────
class CubiCasaDataset:
    """
    Converts CubiCasa5k dataset to YOLOv8 segmentation format.

    CubiCasa5k download:
        https://zenodo.org/record/2613548

    Args:
        root_dir: Path to the extracted CubiCasa5k folder.
        val_split: Fraction of data to use for validation.
        test_split: Fraction of data to use for testing.
        seed: Random seed for reproducible splits.

    Raises:
        ValueError: If either split fraction is negative, or if together they
            leave no room for a training split (val_split + test_split >= 1).
        FileNotFoundError: If ``root_dir`` does not exist.
    """

    def __init__(
        self,
        root_dir: str,
        val_split: float = 0.15,
        test_split: float = 0.10,
        seed: int = 42,
    ):
        # Fail fast on fractions that would silently produce an empty train set.
        if val_split < 0 or test_split < 0 or val_split + test_split >= 1:
            raise ValueError(
                "val_split and test_split must be non-negative and sum to less "
                f"than 1 (got val_split={val_split}, test_split={test_split})"
            )
        self.root = Path(root_dir)
        self.val_split = val_split
        self.test_split = test_split
        self.seed = seed
        if not self.root.exists():
            raise FileNotFoundError(
                f"Dataset root not found: {root_dir}\n"
                "Download CubiCasa5k from: https://zenodo.org/record/2613548"
            )

    def prepare(self, output_dir: str = "data/yolo_dataset") -> str:
        """
        Convert and split the dataset into train/val/test sets.

        Args:
            output_dir: Where to write the YOLO-formatted dataset.

        Returns:
            Path to the generated dataset.yaml file.
        """
        out = Path(output_dir)
        print(f"Preparing CubiCasa5k β YOLO format in: {out}")
        # Discover all floor plan samples
        samples = self._discover_samples()
        print(f" Found {len(samples)} annotated floor plans")
        # Split into train / val / test
        splits = self._split(samples)
        for split_name, split_samples in splits.items():
            print(f" {split_name}: {len(split_samples)} samples")
        # Convert and write each split
        for split_name, split_samples in splits.items():
            self._write_split(split_samples, out, split_name)
        # Write dataset.yaml
        yaml_path = self._write_yaml(out)
        print(f"\nDataset ready. Config: {yaml_path}")
        return str(yaml_path)

    # ── Internal helpers ──────────────────────────────────────────────────────

    def _discover_samples(self) -> list[dict]:
        """
        Find all (image, annotation) pairs under the dataset root.

        Every ``F1_scaled.png`` found recursively is paired with the
        ``model.svg`` in the same directory; images without one are ignored.
        Results are sorted by image path for a deterministic order.
        """
        samples = []
        for img_path in sorted(self.root.rglob("F1_scaled.png")):
            svg_path = img_path.parent / "model.svg"
            if svg_path.exists():
                samples.append({
                    "image": str(img_path),
                    "annotation": str(svg_path),
                })
        return samples

    def _split(self, samples: list[dict]) -> dict[str, list[dict]]:
        """
        Reproducible train/val/test split.

        ``samples`` is not modified. Split sizes are floored, so any remainder
        goes to ``train``.
        """
        shuffled = samples.copy()
        # A private RNG gives the same permutation as random.seed(seed) +
        # random.shuffle, without resetting the global random state of callers.
        random.Random(self.seed).shuffle(shuffled)
        n = len(shuffled)
        n_test = int(n * self.test_split)
        n_val = int(n * self.val_split)
        return {
            "test": shuffled[:n_test],
            "val": shuffled[n_test:n_test + n_val],
            "train": shuffled[n_test + n_val:],
        }

    def _write_split(
        self, samples: list[dict], out: Path, split: str
    ) -> None:
        """
        Convert samples and write images + labels to the split directories.

        An image is only copied once its label file has been written, so every
        image in ``images/<split>`` has a matching ``labels/<split>`` file.
        Samples without any recognised polygon, or that fail to convert, are
        counted as skipped and leave no files behind.
        """
        img_dir = out / "images" / split
        lbl_dir = out / "labels" / split
        img_dir.mkdir(parents=True, exist_ok=True)
        lbl_dir.mkdir(parents=True, exist_ok=True)
        ok, skipped = 0, 0
        for sample in samples:
            stem = Path(sample["image"]).parent.name
            dst_img = img_dir / f"{stem}.png"
            dst_lbl = lbl_dir / f"{stem}.txt"
            try:
                # Only the header is needed for the size; close the file promptly.
                with Image.open(sample["image"]) as img:
                    w, h = img.size
                polygons = parse_svg_annotations(sample["annotation"], w, h)
                if not polygons:
                    skipped += 1
                    continue
                write_yolo_labels(polygons, dst_lbl)
                shutil.copy2(sample["image"], dst_img)
                ok += 1
            except Exception as e:
                # Best-effort conversion: drop half-written outputs so images
                # and labels stay paired, then move on to the next sample.
                dst_lbl.unlink(missing_ok=True)
                dst_img.unlink(missing_ok=True)
                print(f" Warning: skipping {sample['image']}: {e}")
                skipped += 1
        print(f" {split}: wrote {ok} labels, skipped {skipped}")

    def _write_yaml(self, out: Path) -> Path:
        """
        Write the YOLO dataset configuration YAML and return its path.

        ``nc`` and ``names`` exclude Background (class 0); ``names`` is the
        Python list repr, which YAML reads as a flow sequence of strings.
        """
        yaml_path = out / "dataset.yaml"
        content = f"""# CubiCasa5k β YOLOv8 segmentation dataset
path: {out.resolve()}
train: images/train
val: images/val
test: images/test
nc: {NUM_CLASSES - 1} # exclude background (class 0)
names: {CLASS_NAMES[1:]}
"""
        yaml_path.write_text(content)
        return yaml_path
# ── SVG parsing ───────────────────────────────────────────────────────────────
def parse_svg_annotations(
    svg_path: str, img_w: int, img_h: int
) -> list[dict]:
    """
    Parse a CubiCasa5k SVG annotation file.

    Coordinates are normalised against the SVG ``viewBox``
    (min-x, min-y, width, height). When the root element has no ``viewBox``,
    ``0 0 img_w img_h`` is used, i.e. SVG user units are taken to be pixels.

    Args:
        svg_path: Path to model.svg
        img_w: Image width in pixels (fallback viewBox width)
        img_h: Image height in pixels (fallback viewBox height)

    Returns:
        List of dicts: [{"class_id": int, "polygon": [(x, y), ...]}, ...]
        All coordinates normalised and clamped to [0, 1], rounded to 6 dp.
        Only <polygon>, <polyline>, <rect> and <path> elements whose ``class``
        attribute (or, when that is empty, the part of ``id`` before the first
        "-") is a key of SVG_CLASS_MAP are kept, and only if they yield at
        least 3 points.

    Raises:
        ValueError: If the file is not well-formed XML, or its viewBox is
            malformed or has a non-positive width/height.
    """
    # NOTE(review): xml.etree is not hardened against malicious XML
    # (entity expansion etc.); only parse trusted dataset files.
    try:
        root = ET.parse(svg_path).getroot()
    except ET.ParseError as e:
        raise ValueError(f"Invalid SVG: {svg_path}: {e}") from e

    # viewBox = "min-x min-y width height"; numbers may be separated by
    # whitespace and/or commas.
    view_box = root.get("viewBox", f"0 0 {img_w} {img_h}")
    try:
        vb_x, vb_y, svg_w, svg_h = (
            float(v) for v in view_box.replace(",", " ").split()
        )
    except ValueError as e:
        raise ValueError(f"Invalid viewBox {view_box!r} in {svg_path}") from e
    if svg_w <= 0 or svg_h <= 0:
        raise ValueError(
            f"Non-positive viewBox size {view_box!r} in {svg_path}"
        )

    polygons = []
    for elem in root.iter():
        # NOTE(review): only the element's own class/id is consulted; a class
        # set on an ancestor <g>, a multi-token class attribute, or a
        # `transform` attribute is not taken into account.
        class_name = elem.get("class", "") or elem.get("id", "").split("-")[0]
        class_id = SVG_CLASS_MAP.get(class_name)
        if class_id is None:
            continue

        tag = elem.tag.split("}")[-1]  # strip "{namespace}" prefix
        if tag in ("polygon", "polyline"):
            pts = _parse_polygon_points(elem.get("points", ""))
        elif tag == "rect":
            pts = _rect_to_polygon(elem)
        elif tag == "path":
            pts = _path_to_polygon(elem.get("d", ""))
        else:
            continue  # e.g. a group element: no geometry of its own
        if len(pts) < 3:
            continue

        # Map viewBox user units to [0, 1]; clamp so points slightly outside
        # the viewBox still produce a valid YOLO label.
        norm = [
            (
                round(min(max((x - vb_x) / svg_w, 0.0), 1.0), 6),
                round(min(max((y - vb_y) / svg_h, 0.0), 1.0), 6),
            )
            for x, y in pts
        ]
        polygons.append({"class_id": class_id, "polygon": norm})
    return polygons
def write_yolo_labels(polygons: list[dict], output_path: Path) -> None:
    """
    Write YOLO segmentation label file.

    Format per line: <class_id> <x1> <y1> <x2> <y2> ...

    The ``class_id`` values in *polygons* are this module's ids (index into
    CLASS_NAMES). They are shifted down by one on output, because Background
    (id 0) is not a YOLO class.

    Args:
        polygons: [{"class_id": int, "polygon": [(x, y), ...]}, ...] as
            returned by parse_svg_annotations.
        output_path: Destination .txt file (``str`` or ``Path``); it is
            overwritten. Every line, including the last, ends with "\\n".

    Raises:
        ValueError: If a polygon has ``class_id < 1``, which would produce a
            negative (invalid) YOLO class index. Nothing is written then.
    """
    lines = []
    for ann in polygons:
        class_id = ann["class_id"]
        if class_id < 1:
            raise ValueError(
                f"class_id {class_id} has no YOLO index (Background/invalid)"
            )
        coords = " ".join(f"{x} {y}" for x, y in ann["polygon"])
        lines.append(f"{class_id - 1} {coords}")  # YOLO is 0-indexed
    Path(output_path).write_text("".join(f"{line}\n" for line in lines))
# ── Geometry helpers ──────────────────────────────────────────────────────────
def _parse_polygon_points(points_str: str) -> list[tuple]:
"""Parse SVG 'points' attribute into list of (x, y) tuples."""
try:
vals = [float(v) for v in points_str.replace(",", " ").split()]
return [(vals[i], vals[i + 1]) for i in range(0, len(vals) - 1, 2)]
except (ValueError, IndexError):
return []
def _rect_to_polygon(elem) -> list[tuple]:
"""Convert SVG <rect> to 4-point polygon."""
try:
x = float(elem.get("x", 0))
y = float(elem.get("y", 0))
w = float(elem.get("width", 0))
h = float(elem.get("height", 0))
return [(x, y), (x + w, y), (x + w, y + h), (x, y + h)]
except (ValueError, TypeError):
return []
def _path_to_polygon(d: str) -> list[tuple]:
"""
Naive SVG path β polygon converter.
Only handles M/L/Z commands (absolute move/line/close).
Sufficient for simple architectural polygons.
"""
pts = []
try:
tokens = d.replace(",", " ").split()
i = 0
while i < len(tokens):
cmd = tokens[i]
if cmd in ("M", "L"):
x, y = float(tokens[i + 1]), float(tokens[i + 2])
pts.append((x, y))
i += 3
elif cmd == "Z":
i += 1
else:
i += 1
except (IndexError, ValueError):
pass
return pts