| """ |
| PyTorch Dataset for facade segmentation with SAM. |
| Generates bounding-box prompts from ground-truth masks. |
| """ |
| import os |
| import json |
| import random |
| import numpy as np |
| from PIL import Image |
| import torch |
| from torch.utils.data import Dataset |
| from transformers import SamProcessor |
|
|
|
|
class FacadeDataset(Dataset):
    """Dataset of facade images and binary masks with box prompts for SAM.

    Each sample is loaded from the entries of ``<data_dir>/<split>/metadata.json``.
    A bounding-box prompt is derived from the non-zero pixels of the binary
    mask, optionally jittered when ``augment`` is enabled.

    Args:
        data_dir: Root directory containing one sub-directory per split.
        split: Name of the split sub-directory (default ``"train"``).
        processor: Optional callable invoked as
            ``processor(images=..., input_boxes=..., return_tensors="pt")``.
            When ``None``, the raw image is converted to a float CHW tensor.
        augment: If true, randomly jitter the box prompt.
    """

    # Maximum per-coordinate shift (in pixels) applied to the box when augmenting.
    BBOX_JITTER = 10
    # (width, height) the ground-truth mask is resized to.
    GT_MASK_SIZE = (256, 256)

    def __init__(self, data_dir, split="train", processor=None, augment=False):
        self.data_dir = data_dir
        self.split = split
        self.processor = processor
        self.augment = augment

        split_dir = os.path.join(data_dir, split)
        metadata_path = os.path.join(split_dir, "metadata.json")
        with open(metadata_path, "r", encoding="utf-8") as f:
            self.items = json.load(f)

        self.img_dir = os.path.join(split_dir, "images")
        self.mask_dir = os.path.join(split_dir, "masks_binary")

    def __len__(self):
        """Return the number of metadata entries."""
        return len(self.items)

    @staticmethod
    def _jitter_span(lo, hi, d_lo, d_hi, limit):
        """Shift a 1-D span by the given offsets, keeping it inside [0, limit].

        Falls back to the original span when the shifted one would be empty
        or inverted, so the caller never receives ``hi < lo``.
        """
        new_lo = min(max(0, lo + d_lo), limit)
        new_hi = min(max(0, hi + d_hi), limit)
        if new_hi <= new_lo:
            return lo, hi
        return new_lo, new_hi

    def _get_bbox_from_mask(self, mask_np):
        """Compute an ``[x1, y1, x2, y2]`` box around the non-zero mask pixels.

        Args:
            mask_np: 2-D array; pixels ``> 0`` are treated as foreground.

        Returns:
            A list of four ints. For an empty mask the full image extent
            ``[0, 0, width, height]`` is returned. With ``augment`` enabled,
            each coordinate is shifted by up to ``BBOX_JITTER`` pixels while
            staying inside the image and keeping ``x1 <= x2`` and ``y1 <= y2``.
        """
        h, w = mask_np.shape
        ys, xs = np.where(mask_np > 0)
        if len(xs) == 0:
            return [0, 0, w, h]

        x1, y1 = int(xs.min()), int(ys.min())
        x2, y2 = int(xs.max()), int(ys.max())

        if self.augment:
            jitter = self.BBOX_JITTER
            # Draw in x1, y1, x2, y2 order so a seeded RNG yields the same
            # offsets as before for boxes that need no correction.
            dx1 = random.randint(-jitter, jitter)
            dy1 = random.randint(-jitter, jitter)
            dx2 = random.randint(-jitter, jitter)
            dy2 = random.randint(-jitter, jitter)
            x1, x2 = self._jitter_span(x1, x2, dx1, dx2, w)
            y1, y2 = self._jitter_span(y1, y2, dy1, dy2, h)

        return [x1, y1, x2, y2]

    def __getitem__(self, idx):
        """Load one sample and return its model inputs plus ground-truth mask.

        Returns:
            A mapping with the processor outputs (or ``pixel_values`` and
            ``input_boxes`` when no processor is set) and a float32
            ``ground_truth_mask`` tensor of size ``GT_MASK_SIZE``.
        """
        item = self.items[idx]

        # Prefer the path stored in the metadata; if it does not exist (e.g.
        # the dataset was moved), look for the same file name under img_dir.
        image_path = item["image"]
        if not os.path.exists(image_path):
            image_path = os.path.join(self.img_dir, os.path.basename(image_path))
        mask_path = os.path.join(self.mask_dir, os.path.basename(item["mask"]))

        # convert() returns a loaded copy, so the file handles can be closed
        # right away instead of waiting for garbage collection.
        with Image.open(image_path) as raw_img:
            img = raw_img.convert("RGB")
        with Image.open(mask_path) as raw_mask:
            mask = raw_mask.convert("L")

        mask_np = np.array(mask)
        bbox = self._get_bbox_from_mask(mask_np)

        if self.processor is not None:
            inputs = self.processor(
                images=img,
                input_boxes=[[bbox]],
                return_tensors="pt",
            )
            # Drop the leading batch dimension added by the processor; the
            # DataLoader collate step adds it back.
            for k in inputs:
                if isinstance(inputs[k], torch.Tensor):
                    inputs[k] = inputs[k].squeeze(0)
        else:
            inputs = {
                "pixel_values": torch.from_numpy(np.array(img)).permute(2, 0, 1).float(),
                "input_boxes": torch.tensor([bbox], dtype=torch.float32),
            }

        # NOTE(review): the mask is resized directly to GT_MASK_SIZE regardless
        # of aspect ratio; if the processor pads non-square images, this target
        # may not align with the model's mask output -- confirm.
        gt_mask = np.array(mask.resize(self.GT_MASK_SIZE, Image.NEAREST))
        gt_mask = (gt_mask > 0).astype(np.float32)
        inputs["ground_truth_mask"] = torch.from_numpy(gt_mask)

        return inputs
|
|
|
|
def collate_fn(batch):
    """Stack per-sample tensors into batched tensors for a DataLoader.

    Only ``pixel_values``, ``input_boxes`` and ``ground_truth_mask`` are
    kept; any other keys present in the samples are dropped.
    """
    stacked_keys = ("pixel_values", "input_boxes", "ground_truth_mask")
    return {
        key: torch.stack([sample[key] for sample in batch])
        for key in stacked_keys
    }
|
|