# acd23's picture
# Upload folder using huggingface_hub
# 3cc53ab verified
"""
PyTorch Dataset for facade segmentation with SAM.
Generates bounding-box prompts from ground-truth masks.
"""
import os
import json
import random
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset
from transformers import SamProcessor
class FacadeDataset(Dataset):
    """
    Dataset for facade segmentation.

    Reads ``<data_dir>/<split>/metadata.json`` (a list of records with
    ``"image"`` and ``"mask"`` entries), loads each RGB image and its binary
    mask, and derives a bounding-box prompt from the mask.

    Args:
        data_dir: Root directory containing one sub-directory per split.
        split: Name of the split sub-directory (default ``"train"``).
        processor: Optional callable invoked as
            ``processor(images=..., input_boxes=..., return_tensors="pt")``
            (e.g. a ``SamProcessor``). When ``None`` the raw image is
            converted to a tensor without resizing or normalisation.
        augment: When true, the box prompt is randomly jittered.
    """

    # Maximum per-coordinate jitter (pixels) applied to the box when augmenting.
    BBOX_JITTER = 10
    # (width, height) the ground-truth mask is resized to; presumably the
    # model's mask output resolution -- confirm against the training code.
    GT_MASK_SIZE = (256, 256)

    def __init__(self, data_dir, split="train", processor=None, augment=False):
        self.data_dir = data_dir
        self.split = split
        self.processor = processor
        self.augment = augment

        split_dir = os.path.join(data_dir, split)
        with open(
            os.path.join(split_dir, "metadata.json"), "r", encoding="utf-8"
        ) as f:
            self.items = json.load(f)

        self.img_dir = os.path.join(split_dir, "images")
        self.mask_dir = os.path.join(split_dir, "masks_binary")

    def __len__(self):
        """Return the number of records listed in the split's metadata."""
        return len(self.items)

    def _resolve_image_path(self, recorded_path):
        """Return the path of an image, preferring the split's image directory.

        Masks are always looked up by basename inside ``mask_dir``; images get
        the same treatment so the dataset keeps working when the paths stored
        in the metadata are stale. The recorded path is used as a fallback
        when no such file exists in ``img_dir``.
        """
        local_path = os.path.join(self.img_dir, os.path.basename(recorded_path))
        if os.path.exists(local_path):
            return local_path
        return recorded_path

    def _get_bbox_from_mask(self, mask_np):
        """Compute an ``[x1, y1, x2, y2]`` box from a 2-D binary mask.

        The box spans the min/max coordinates of the non-zero pixels. An empty
        mask yields the full-image box ``[0, 0, width, height]``. When
        ``self.augment`` is set, each coordinate is shifted by a random offset
        in ``[-BBOX_JITTER, BBOX_JITTER]`` and clamped to the image bounds; if
        the jitter would invert the box along an axis, the un-jittered extent
        is kept for that axis.
        """
        ys, xs = np.where(mask_np > 0)
        if len(xs) == 0:
            return [0, 0, mask_np.shape[1], mask_np.shape[0]]

        x1, y1 = int(xs.min()), int(ys.min())
        x2, y2 = int(xs.max()), int(ys.max())

        if self.augment:
            jitter = self.BBOX_JITTER
            h, w = mask_np.shape
            # Draw in x1, y1, x2, y2 order so seeded runs stay reproducible.
            dx1 = random.randint(-jitter, jitter)
            dy1 = random.randint(-jitter, jitter)
            dx2 = random.randint(-jitter, jitter)
            dy2 = random.randint(-jitter, jitter)

            # Clamp on BOTH sides: a one-sided clamp lets x1 run past the
            # image width or x2 go negative.
            jx1 = min(max(x1 + dx1, 0), w)
            jx2 = min(max(x2 + dx2, 0), w)
            jy1 = min(max(y1 + dy1, 0), h)
            jy2 = min(max(y2 + dy2, 0), h)

            # Small masks can be inverted by the jitter; keep the exact
            # extent on that axis rather than emit an invalid box.
            if jx1 <= jx2:
                x1, x2 = jx1, jx2
            if jy1 <= jy2:
                y1, y2 = jy1, jy2

        return [x1, y1, x2, y2]

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            A mapping with ``"pixel_values"``, ``"input_boxes"`` (plus any
            other fields produced by the processor) and
            ``"ground_truth_mask"``: a float32 tensor of 0/1 values with the
            mask resized to ``GT_MASK_SIZE`` using nearest-neighbour sampling.
        """
        item = self.items[idx]

        # convert() returns an independent copy, so the file handles can be
        # released immediately instead of waiting for garbage collection.
        with Image.open(self._resolve_image_path(item["image"])) as raw_img:
            img = raw_img.convert("RGB")
        mask_path = os.path.join(self.mask_dir, os.path.basename(item["mask"]))
        with Image.open(mask_path) as raw_mask:
            mask = raw_mask.convert("L")

        mask_np = np.array(mask)
        bbox = self._get_bbox_from_mask(mask_np)

        if self.processor is not None:
            inputs = self.processor(
                images=img,
                input_boxes=[[bbox]],
                return_tensors="pt",
            )
            # Drop the leading batch dimension the processor adds; batching
            # is done later by the collate function.
            for k in inputs:
                if isinstance(inputs[k], torch.Tensor):
                    inputs[k] = inputs[k].squeeze(0)
        else:
            inputs = {
                # HWC uint8 image -> CHW float tensor (values left in 0-255).
                "pixel_values": torch.from_numpy(np.array(img)).permute(2, 0, 1).float(),
                "input_boxes": torch.tensor([bbox], dtype=torch.float32),
            }

        gt_mask = np.array(mask.resize(self.GT_MASK_SIZE, Image.NEAREST))
        gt_mask = (gt_mask > 0).astype(np.float32)
        inputs["ground_truth_mask"] = torch.from_numpy(gt_mask)
        return inputs
def collate_fn(batch):
    """Stack per-sample tensors into batch tensors for a DataLoader.

    Only ``pixel_values``, ``input_boxes`` and ``ground_truth_mask`` are
    kept; any other fields present on the samples are dropped. Each output
    tensor gains a new leading batch dimension.
    """
    collated = {}
    for field in ("pixel_values", "input_boxes", "ground_truth_mask"):
        per_sample = [sample[field] for sample in batch]
        collated[field] = torch.stack(per_sample)
    return collated