"""PyTorch Dataset for facade segmentation with SAM.

Generates bounding-box prompts from ground-truth masks.
"""

import os
import json
import random

import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset
from transformers import SamProcessor


class FacadeDataset(Dataset):
    """Dataset of facade images paired with binary masks and box prompts.

    The split directory ``<data_dir>/<split>`` must contain ``metadata.json``
    (a sequence of records with ``"image"`` and ``"mask"`` path entries), an
    ``images`` directory and a ``masks_binary`` directory.

    Args:
        data_dir: Root directory that contains one sub-directory per split.
        split: Name of the split sub-directory to load.
        processor: Optional callable invoked as
            ``processor(images=..., input_boxes=..., return_tensors="pt")``
            (presumably a ``SamProcessor`` -- confirm with the caller). When
            ``None`` the raw image is returned as a float CHW tensor.
        augment: When true, the box prompt is randomly jittered.
        bbox_jitter: Maximum absolute per-coordinate jitter, in pixels, that
            is applied when ``augment`` is true.
        mask_size: Size the ground-truth mask is resized to, either an int
            (square) or a ``(width, height)`` pair. The default of 256
            presumably matches the model's mask output -- confirm.

    Raises:
        ValueError: If ``bbox_jitter`` is negative.
    """

    # Class-level defaults keep instances that bypass ``__init__`` usable.
    bbox_jitter = 10
    mask_size = (256, 256)

    def __init__(self, data_dir, split="train", processor=None, augment=False,
                 bbox_jitter=10, mask_size=256):
        if bbox_jitter < 0:
            raise ValueError("bbox_jitter must be non-negative")

        self.data_dir = data_dir
        self.split = split
        self.processor = processor
        self.augment = augment
        self.bbox_jitter = bbox_jitter
        if isinstance(mask_size, int):
            mask_size = (mask_size, mask_size)
        self.mask_size = tuple(mask_size)

        split_dir = os.path.join(data_dir, split)
        metadata_path = os.path.join(split_dir, "metadata.json")
        with open(metadata_path, "r", encoding="utf-8") as f:
            self.items = json.load(f)

        self.img_dir = os.path.join(split_dir, "images")
        self.mask_dir = os.path.join(split_dir, "masks_binary")

    def __len__(self):
        """Return the number of records listed in the split metadata."""
        return len(self.items)

    @staticmethod
    def _load_image(path, mode):
        """Open ``path``, convert it to ``mode`` and release the file handle."""
        # ``convert`` returns a fully loaded copy, so the file can be closed
        # right away instead of leaking a descriptor per sample.
        with Image.open(path) as handle:
            return handle.convert(mode)

    def _resolve_image_path(self, recorded_path):
        """Return the path to open for an image record.

        The path stored in the metadata is used when it exists. Otherwise
        the file is looked up by basename inside the split's ``images``
        directory, mirroring how mask paths are resolved.
        """
        if os.path.exists(recorded_path):
            return recorded_path
        return os.path.join(self.img_dir, os.path.basename(recorded_path))

    def _get_bbox_from_mask(self, mask_np):
        """Compute an ``[x1, y1, x2, y2]`` box prompt from a binary mask.

        Args:
            mask_np: 2-D array; pixels greater than zero are foreground.

        Returns:
            A list of four Python ints. For an empty mask the box covers the
            whole image. Otherwise it is the tight box around the foreground
            (max coordinates are inclusive pixel indices), randomly jittered
            by up to ``bbox_jitter`` pixels per coordinate when ``augment``
            is enabled.
        """
        ys, xs = np.where(mask_np > 0)
        h, w = mask_np.shape
        if len(xs) == 0:
            return [0, 0, w, h]

        x1, y1 = int(xs.min()), int(ys.min())
        x2, y2 = int(xs.max()), int(ys.max())

        if self.augment:
            jitter = self.bbox_jitter
            # Offsets are drawn in (x1, y1, x2, y2) order so that seeded runs
            # consume the random stream in a fixed sequence.
            jx1 = max(0, x1 + random.randint(-jitter, jitter))
            jy1 = max(0, y1 + random.randint(-jitter, jitter))
            jx2 = min(w, x2 + random.randint(-jitter, jitter))
            jy2 = min(h, y2 + random.randint(-jitter, jitter))
            # On small masks or masks near the border the jitter can invert
            # or collapse the box (or push the max corner below zero); keep
            # the tight box on any axis where that happens.
            if jx1 < jx2:
                x1, x2 = jx1, jx2
            if jy1 < jy2:
                y1, y2 = jy1, jy2

        return [x1, y1, x2, y2]

    def __getitem__(self, idx):
        """Load one sample and build the model inputs.

        Returns:
            A mapping with ``"pixel_values"``, ``"input_boxes"`` and
            ``"ground_truth_mask"`` (float32 tensor of zeros and ones at
            ``mask_size``). With a processor, the mapping is whatever the
            processor returned, with the leading batch dimension squeezed
            out of every tensor entry.
        """
        item = self.items[idx]

        img = self._load_image(self._resolve_image_path(item["image"]), "RGB")
        mask_path = os.path.join(self.mask_dir, os.path.basename(item["mask"]))
        mask = self._load_image(mask_path, "L")

        bbox = self._get_bbox_from_mask(np.array(mask))

        if self.processor is not None:
            inputs = self.processor(
                images=img,
                input_boxes=[[bbox]],
                return_tensors="pt",
            )
            # The processor is called on a single image; drop the batch
            # dimension so the DataLoader can re-batch the samples.
            for key in list(inputs.keys()):
                if isinstance(inputs[key], torch.Tensor):
                    inputs[key] = inputs[key].squeeze(0)
        else:
            # HWC uint8 image -> CHW float tensor, values left unscaled.
            pixels = torch.from_numpy(np.array(img)).permute(2, 0, 1).float()
            inputs = {
                "pixel_values": pixels,
                "input_boxes": torch.tensor([bbox], dtype=torch.float32),
            }

        # Nearest-neighbour resampling keeps the mask strictly binary.
        gt_mask = np.array(mask.resize(self.mask_size, Image.NEAREST))
        gt_mask = (gt_mask > 0).astype(np.float32)
        inputs["ground_truth_mask"] = torch.from_numpy(gt_mask)
        return inputs


def collate_fn(batch):
    """Stack per-sample tensors into a batch for a DataLoader.

    Only ``"pixel_values"``, ``"input_boxes"`` and ``"ground_truth_mask"``
    are kept; any other entries in the samples are dropped. ``torch.stack``
    requires every sample to have the same shape for each of these keys.
    """
    keys = ("pixel_values", "input_boxes", "ground_truth_mask")
    return {key: torch.stack([sample[key] for sample in batch]) for key in keys}