"""PyTorch Dataset for CLIPSeg fine-tuning.""" import json import random from pathlib import Path import numpy as np import torch from PIL import Image from torch.utils.data import Dataset from transformers import CLIPSegProcessor class DrywallSegDataset(Dataset): """Dataset that yields (image, mask, prompt) tuples for CLIPSeg.""" def __init__(self, split_json: str, processor: CLIPSegProcessor, image_size: int = 352): with open(split_json) as f: self.records = json.load(f) self.processor = processor self.image_size = image_size def __len__(self): return len(self.records) def __getitem__(self, idx): rec = self.records[idx] # Load image image = Image.open(rec["image_path"]).convert("RGB") # Load mask and resize to CLIPSeg resolution mask = Image.open(rec["mask_path"]).convert("L") mask = mask.resize((self.image_size, self.image_size), Image.NEAREST) mask_tensor = torch.from_numpy(np.array(mask)).float() / 255.0 # {0.0, 1.0} # Random prompt synonym prompt = random.choice(rec["prompts"]) # Process through CLIPSeg processor inputs = self.processor( text=[prompt], images=[image], return_tensors="pt", padding=True, ) return { "pixel_values": inputs["pixel_values"].squeeze(0), "input_ids": inputs["input_ids"].squeeze(0), "attention_mask": inputs["attention_mask"].squeeze(0), "labels": mask_tensor, "dataset": rec["dataset"], "image_path": rec["image_path"], "mask_path": rec["mask_path"], "prompt": prompt, "orig_width": rec["width"], "orig_height": rec["height"], } def collate_fn(batch): """Custom collation: pad input_ids and attention_mask to max length in batch.""" max_len = max(item["input_ids"].shape[0] for item in batch) pixel_values = torch.stack([item["pixel_values"] for item in batch]) labels = torch.stack([item["labels"] for item in batch]) input_ids = [] attention_masks = [] for item in batch: ids = item["input_ids"] mask = item["attention_mask"] pad_len = max_len - ids.shape[0] if pad_len > 0: ids = torch.cat([ids, torch.zeros(pad_len, dtype=ids.dtype)]) mask = torch.cat([mask, torch.zeros(pad_len, dtype=mask.dtype)]) input_ids.append(ids) attention_masks.append(mask) return { "pixel_values": pixel_values, "input_ids": torch.stack(input_ids), "attention_mask": torch.stack(attention_masks), "labels": labels, "dataset": [item["dataset"] for item in batch], "image_path": [item["image_path"] for item in batch], "mask_path": [item["mask_path"] for item in batch], "prompt": [item["prompt"] for item in batch], "orig_width": [item["orig_width"] for item in batch], "orig_height": [item["orig_height"] for item in batch], }