File size: 3,036 Bytes
3cc53ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
"""
PyTorch Dataset for facade segmentation with SAM.
Generates bounding-box prompts from ground-truth masks.
"""
import os
import json
import random
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset
from transformers import SamProcessor


class FacadeDataset(Dataset):
    """
    Dataset for facade segmentation with box prompts.

    Each item pairs an RGB image with a binary mask. A bounding-box prompt is
    derived from the mask's non-zero pixels (optionally jittered when
    ``augment`` is True). A square ``mask_size`` x ``mask_size`` binary
    ground-truth mask is returned alongside the model inputs.

    Directory layout read by this class::

        <data_dir>/<split>/metadata.json
        <data_dir>/<split>/images/
        <data_dir>/<split>/masks_binary/

    ``metadata.json`` is loaded with ``json.load`` and indexed as a sequence
    of records that have at least ``"image"`` and ``"mask"`` keys.
    NOTE(review): schema inferred from usage in ``__getitem__``; confirm.

    Args:
        data_dir: Root directory containing one sub-directory per split.
        split: Split sub-directory name (default ``"train"``).
        processor: Optional callable invoked as
            ``processor(images=..., input_boxes=..., return_tensors="pt")``,
            e.g. a HuggingFace ``SamProcessor``. When None, raw tensors are
            returned instead.
        augment: If True, randomly jitter the bbox prompt.
        bbox_jitter: Maximum per-coordinate jitter, in pixels, applied when
            ``augment`` is True (default 10, the previous hard-coded value).
        mask_size: Side length of the returned ground-truth mask (default
            256, the previous hard-coded value).
    """

    def __init__(self, data_dir, split="train", processor=None, augment=False,
                 bbox_jitter=10, mask_size=256):
        self.data_dir = data_dir
        self.split = split
        self.processor = processor
        self.augment = augment
        self.bbox_jitter = bbox_jitter
        self.mask_size = mask_size

        split_dir = os.path.join(data_dir, split)
        with open(os.path.join(split_dir, "metadata.json"), "r") as f:
            self.items = json.load(f)

        self.img_dir = os.path.join(split_dir, "images")
        self.mask_dir = os.path.join(split_dir, "masks_binary")

    def __len__(self):
        """Number of records loaded from ``metadata.json``."""
        return len(self.items)

    def _get_bbox_from_mask(self, mask_np):
        """
        Compute an ``[x1, y1, x2, y2]`` integer box around non-zero mask pixels.

        Args:
            mask_np: 2-D array (H, W); pixels ``> 0`` count as foreground.

        Returns:
            ``[x1, y1, x2, y2]`` as Python ints. For an empty mask, this is the
            full-image box ``[0, 0, W, H]``. When ``augment`` is True, each
            coordinate is shifted by a random integer in
            ``[-bbox_jitter, bbox_jitter]``, clamped to ``[0, W]`` / ``[0, H]``,
            and re-ordered so that ``x1 <= x2`` and ``y1 <= y2`` always hold.
        """
        h, w = mask_np.shape
        ys, xs = np.where(mask_np > 0)
        if xs.size == 0:
            return [0, 0, w, h]

        x1, y1 = int(xs.min()), int(ys.min())
        x2, y2 = int(xs.max()), int(ys.max())

        if self.augment:
            j = self.bbox_jitter
            # Draw order (x1, y1, x2, y2) matches the previous implementation.
            # Each coordinate is clamped on both sides so that none can leave
            # the image.
            x1 = min(max(0, x1 + random.randint(-j, j)), w)
            y1 = min(max(0, y1 + random.randint(-j, j)), h)
            x2 = min(max(0, x2 + random.randint(-j, j)), w)
            y2 = min(max(0, y2 + random.randint(-j, j)), h)
            # Masks thinner than 2*jitter can invert after jitter. Re-order the
            # corners so the prompt is always a valid box.
            x1, x2 = min(x1, x2), max(x1, x2)
            y1, y2 = min(y1, y2), max(y1, y2)

        return [x1, y1, x2, y2]

    def _resolve_image_path(self, path):
        """
        Return ``path`` if it exists as given.

        Otherwise, fall back to the file with the same basename inside
        ``self.img_dir``. This mirrors how masks are looked up in ``mask_dir``.
        """
        if os.path.exists(path):
            return path
        return os.path.join(self.img_dir, os.path.basename(path))

    def __getitem__(self, idx):
        """
        Load one sample and build its model inputs.

        Returns:
            A dict-like object. With a processor, it holds the processor's
            outputs, with the leading batch dimension of every tensor squeezed
            away. Without a processor, it holds ``"pixel_values"`` as a float
            (3, H, W) tensor of raw 0-255 values and ``"input_boxes"`` as a
            (1, 4) float tensor. In both cases ``"ground_truth_mask"`` is added:
            a float32 (mask_size, mask_size) tensor of 0/1 values.
        """
        item = self.items[idx]
        mask_path = os.path.join(self.mask_dir, os.path.basename(item["mask"]))

        # convert() returns a new, fully loaded image, so it stays valid after
        # the context manager closes the underlying file.
        with Image.open(self._resolve_image_path(item["image"])) as src:
            img = src.convert("RGB")
        with Image.open(mask_path) as src:
            mask = src.convert("L")

        # The box is computed in original-image pixel coordinates.
        bbox = self._get_bbox_from_mask(np.array(mask))

        if self.processor is not None:
            inputs = self.processor(
                images=img,
                input_boxes=[[bbox]],
                return_tensors="pt",
            )
            # The processor returns batched tensors (batch of 1); drop that
            # axis so the DataLoader can re-batch the samples.
            for k in inputs:
                if isinstance(inputs[k], torch.Tensor):
                    inputs[k] = inputs[k].squeeze(0)
        else:
            inputs = {
                "pixel_values": torch.from_numpy(np.array(img)).permute(2, 0, 1).float(),
                "input_boxes": torch.tensor([bbox], dtype=torch.float32),
            }

        # NEAREST resampling keeps the mask binary. NOTE(review): the default
        # size of 256 presumably matches the model's low-res mask output; confirm.
        gt_mask = np.array(mask.resize((self.mask_size, self.mask_size), Image.NEAREST))
        gt_mask = (gt_mask > 0).astype(np.float32)
        inputs["ground_truth_mask"] = torch.from_numpy(gt_mask)

        return inputs


def collate_fn(batch):
    """
    Collate FacadeDataset samples into a batch for a DataLoader.

    Stacks only ``pixel_values``, ``input_boxes`` and ``ground_truth_mask``
    along a new leading batch axis. Any other keys present in the samples
    are dropped.
    """
    stacked_keys = ("pixel_values", "input_boxes", "ground_truth_mask")
    return {
        key: torch.stack([sample[key] for sample in batch])
        for key in stacked_keys
    }