"""PyTorch Dataset for CLIPSeg fine-tuning."""

import json
import random
from pathlib import Path

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from transformers import CLIPSegProcessor


class DrywallSegDataset(Dataset):
    """Dataset that yields dicts of processed image, tokenized prompt, and target mask for CLIPSeg."""

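    def __init__(self, split_json: str, processor: CLIPSegProcessor, image_size: int = 352):
        """
        Args:
            split_json: Path to a JSON file containing a list of records. Each record must
                provide "image_path", "mask_path", "prompts" (a non-empty list of prompt
                synonyms), "dataset", and the original image "width" and "height".
            processor: CLIPSegProcessor used to preprocess the image and tokenize the prompt.
            image_size: Side length, in pixels, that target masks are resized to. It should
                match the spatial size of the predicted logits so masks and predictions align.
        """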
        with open(split_json) as f:
            self.records = json.load(f)
        self.processor = processor
        self.image_size = image_size

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        rec = self.records[idx]

        # Load image
        image = Image.open(rec["image_path"]).convert("RGB")

        # Load mask and resize to CLIPSeg resolution
        mask = Image.open(rec["mask_path"]).convert("L")
        mask = mask.resize((self.image_size, self.image_size), Image.NEAREST)
        # Scale to [0, 1]; assumes masks are stored as 0/255, so labels end up in {0.0, 1.0}
        mask_tensor = torch.from_numpy(np.array(mask)).float() / 255.0

        # Sample one of the record's prompt synonyms on every access (text augmentation)
        prompt = random.choice(rec["prompts"])

        # Process through CLIPSeg processor
        inputs = self.processor(
            text=[prompt],
            images=[image],
            return_tensors="pt",
            padding=True,
        )

        return {
            "pixel_values": inputs["pixel_values"].squeeze(0),
            "input_ids": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0),
            "labels": mask_tensor,
            "dataset": rec["dataset"],
            "image_path": rec["image_path"],
            "mask_path": rec["mask_path"],
            "prompt": prompt,
            "orig_width": rec["width"],
            "orig_height": rec["height"],
        }
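

# Hypothetical helper, not part of the original module: a minimal sanity check,
# assuming each split-JSON record is a dict carrying the keys that
# DrywallSegDataset.__getitem__ reads. Running it over a split before training
# surfaces missing keys or empty prompt lists early, instead of as a KeyError
# (or an IndexError from random.choice) inside a DataLoader worker.
REQUIRED_RECORD_KEYS = ("image_path", "mask_path", "prompts", "dataset", "width", "height")


def validate_record(rec: dict) -> list:
    """Return a list of problems found in one split record (empty if it looks usable)."""
    problems = [f"missing key: {key}" for key in REQUIRED_RECORD_KEYS if key not in rec]
    prompts = rec.get("prompts")
    # random.choice raises IndexError on an empty sequence, so require at least one prompt.
    if "prompts" in rec and (not isinstance(prompts, list) or not prompts):
        problems.append("prompts must be a non-empty list")
    return problems


# Example usage (split_json is whatever path you pass to DrywallSegDataset):
#     with open(split_json) as f:
#         for i, rec in enumerate(json.load(f)):
#             if problems := validate_record(rec):
#                 print(i, problems)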


def collate_fn(batch):
    """Custom collation: pad input_ids and attention_mask to max length in batch."""
    max_len = max(item["input_ids"].shape[0] for item in batch)

    pixel_values = torch.stack([item["pixel_values"] for item in batch])
    labels = torch.stack([item["labels"] for item in batch])

    input_ids = []
    attention_masks = []
    for item in batch:
        ids = item["input_ids"]
        mask = item["attention_mask"]
        pad_len = max_len - ids.shape[0]
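        if pad_len > 0:
            # Right-pad with zeros; padded positions get attention_mask 0, so real
            # tokens do not attend to them.
            ids = torch.cat([ids, torch.zeros(pad_len, dtype=ids.dtype)])
            mask = torch.cat([mask, torch.zeros(pad_len, dtype=mask.dtype)])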
        input_ids.append(ids)
        attention_masks.append(mask)

    return {
        "pixel_values": pixel_values,
        "input_ids": torch.stack(input_ids),
        "attention_mask": torch.stack(attention_masks),
        "labels": labels,
        "dataset": [item["dataset"] for item in batch],
        "image_path": [item["image_path"] for item in batch],
        "mask_path": [item["mask_path"] for item in batch],
        "prompt": [item["prompt"] for item in batch],
        "orig_width": [item["orig_width"] for item in batch],
        "orig_height": [item["orig_height"] for item in batch],
    }
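

# Alternative sketch, not the original implementation: the manual padding loop in
# collate_fn can also be written with torch.nn.utils.rnn.pad_sequence, which
# right-pads a list of 1-D tensors to the length of the longest one. Using
# padding_value=0 mirrors collate_fn, which pads both input_ids and attention_mask
# with zeros. The name pad_text_inputs is hypothetical. The imports repeat the
# module-level torch import so this snippet also runs on its own.
import torch
from torch.nn.utils.rnn import pad_sequence


def pad_text_inputs(batch):
    """Right-pad per-item input_ids and attention_mask into (batch_size, max_len) tensors."""
    input_ids = pad_sequence(
        [item["input_ids"] for item in batch], batch_first=True, padding_value=0
    )
    attention_mask = pad_sequence(
        [item["attention_mask"] for item in batch], batch_first=True, padding_value=0
    )
    return input_ids, attention_mask


# Example with two toy items of different lengths (toy ids, not real CLIP vocabulary):
#     batch = [
#         {"input_ids": torch.tensor([1, 2, 3]), "attention_mask": torch.ones(3, dtype=torch.long)},
#         {"input_ids": torch.tensor([4, 5]), "attention_mask": torch.ones(2, dtype=torch.long)},
#     ]
#     ids, mask = pad_text_inputs(batch)
#     ids.tolist()   # [[1, 2, 3], [4, 5, 0]]
#     mask.tolist()  # [[1, 1, 1], [1, 1, 0]]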