"""PyTorch Dataset for CLIPSeg fine-tuning."""
import json
import random
from pathlib import Path

import numpy as np
import torch
from PIL import Image
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from transformers import CLIPSegProcessor
class DrywallSegDataset(Dataset):
    """Dataset that yields (image, mask, prompt) tuples for CLIPSeg."""

    def __init__(self, split_json: str, processor: CLIPSegProcessor, image_size: int = 352):
        """Load the split manifest and keep the processor for per-item encoding.

        Args:
            split_json: path to a JSON list of records; each record carries
                image_path, mask_path, prompts, dataset, width, height.
            processor: CLIPSeg processor used to encode image + text.
            image_size: side length the mask is resized to (CLIPSeg resolution).
        """
        with open(split_json) as f:
            self.records = json.load(f)
        self.processor = processor
        self.image_size = image_size

    def __len__(self):
        """Number of records in the split."""
        return len(self.records)

    def __getitem__(self, idx):
        """Encode one (image, prompt) pair plus its target mask.

        Returns a dict of unbatched tensors (leading batch dim squeezed off)
        alongside metadata fields used downstream for evaluation.
        """
        record = self.records[idx]

        # The processor handles image resizing/normalization; load at native size.
        rgb_image = Image.open(record["image_path"]).convert("RGB")

        # Bring the mask to CLIPSeg's working resolution up front. NEAREST avoids
        # interpolated grey values, so a binary mask stays binary after resizing.
        target_size = (self.image_size, self.image_size)
        resized_mask = Image.open(record["mask_path"]).convert("L").resize(target_size, Image.NEAREST)
        label_tensor = torch.from_numpy(np.array(resized_mask)).float() / 255.0  # {0.0, 1.0}

        # Sample a prompt synonym on every access — cheap text augmentation.
        text_prompt = random.choice(record["prompts"])

        encoded = self.processor(
            text=[text_prompt],
            images=[rgb_image],
            return_tensors="pt",
            padding=True,
        )

        return {
            "pixel_values": encoded["pixel_values"].squeeze(0),
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "labels": label_tensor,
            "dataset": record["dataset"],
            "image_path": record["image_path"],
            "mask_path": record["mask_path"],
            "prompt": text_prompt,
            "orig_width": record["width"],
            "orig_height": record["height"],
        }
def collate_fn(batch):
    """Collate CLIPSeg samples into a batch.

    Fixed-shape tensors (``pixel_values``, ``labels``) are stacked directly.
    Variable-length ``input_ids`` / ``attention_mask`` are right-padded to the
    longest sequence in the batch via ``pad_sequence`` (padding value 0 for
    both, so padded positions carry attention-mask 0 and are ignored).
    Metadata fields are gathered into plain Python lists.

    Args:
        batch: list of per-sample dicts as produced by
            ``DrywallSegDataset.__getitem__``.

    Returns:
        dict with batched tensors and list-valued metadata.
    """
    pixel_values = torch.stack([item["pixel_values"] for item in batch])
    labels = torch.stack([item["labels"] for item in batch])
    # pad_sequence replaces the hand-rolled torch.cat padding loop; it right-pads
    # with 0, matching the original behavior exactly.
    input_ids = pad_sequence(
        [item["input_ids"] for item in batch], batch_first=True, padding_value=0
    )
    attention_mask = pad_sequence(
        [item["attention_mask"] for item in batch], batch_first=True, padding_value=0
    )
    return {
        "pixel_values": pixel_values,
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
        "dataset": [item["dataset"] for item in batch],
        "image_path": [item["image_path"] for item in batch],
        "mask_path": [item["mask_path"] for item in batch],
        "prompt": [item["prompt"] for item in batch],
        "orig_width": [item["orig_width"] for item in batch],
        "orig_height": [item["orig_height"] for item in batch],
    }
|