| """PyTorch Dataset for CLIPSeg fine-tuning.""" |
|
|
| import json |
| import random |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| from PIL import Image |
| from torch.utils.data import Dataset |
| from transformers import CLIPSegProcessor |
|
|
|
|
class DrywallSegDataset(Dataset):
    """Dataset that yields (image, mask, prompt) tuples for CLIPSeg.

    Each record in the split JSON is a dict expected to provide
    ``image_path``, ``mask_path``, ``prompts`` (non-empty list of strings),
    ``dataset``, ``width`` and ``height`` keys.
    """

    # NOTE: the processor annotation is a string so this module can be
    # imported (e.g. for tooling/tests) without `transformers` installed;
    # the annotation is never evaluated at runtime.
    def __init__(self, split_json: str, processor: "CLIPSegProcessor", image_size: int = 352):
        """Load the split manifest.

        Args:
            split_json: Path to a JSON file holding a list of record dicts.
            processor: CLIPSeg processor used to tokenize the prompt and
                preprocess the image in ``__getitem__``.
            image_size: Side length (pixels) the ground-truth mask is resized
                to; should match the model's output resolution (352 for CLIPSeg).
        """
        with open(split_json) as f:
            self.records = json.load(f)
        self.processor = processor
        self.image_size = image_size

    def __len__(self) -> int:
        return len(self.records)

    def __getitem__(self, idx):
        rec = self.records[idx]

        # Context managers close the underlying file handles promptly;
        # leaving them open leaks file descriptors, which can exhaust the
        # OS limit when several DataLoader workers iterate the dataset.
        # convert()/resize() return new, fully-loaded images, so using the
        # results after the `with` block is safe.
        with Image.open(rec["image_path"]) as img:
            image = img.convert("RGB")

        # NEAREST keeps the mask hard-edged (no interpolated gray values).
        with Image.open(rec["mask_path"]) as m:
            mask = m.convert("L").resize(
                (self.image_size, self.image_size), Image.NEAREST
            )
        # Scale 8-bit values into [0, 1]; assumes masks are stored as 0/255
        # binary images — TODO confirm against the data-prep pipeline.
        mask_tensor = torch.from_numpy(np.array(mask)).float() / 255.0

        # Randomly sample one prompt per access as text augmentation.
        prompt = random.choice(rec["prompts"])

        inputs = self.processor(
            text=[prompt],
            images=[image],
            return_tensors="pt",
            padding=True,
        )

        # squeeze(0) drops the batch dim the processor adds for single inputs;
        # collate_fn re-batches downstream.
        return {
            "pixel_values": inputs["pixel_values"].squeeze(0),
            "input_ids": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0),
            "labels": mask_tensor,
            "dataset": rec["dataset"],
            "image_path": rec["image_path"],
            "mask_path": rec["mask_path"],
            "prompt": prompt,
            "orig_width": rec["width"],
            "orig_height": rec["height"],
        }
|
|
|
|
def collate_fn(batch):
    """Collate dataset items into a model-ready batch.

    ``pixel_values`` and ``labels`` are fixed-size per item and simply
    stacked. ``input_ids``/``attention_mask`` may differ in length because
    each prompt is tokenized independently, so both are right-padded with
    zeros to the longest sequence in the batch; padded positions carry
    attention 0 and are therefore ignored by the model.

    Args:
        batch: List of item dicts produced by ``DrywallSegDataset.__getitem__``.

    Returns:
        Dict with batched tensors plus per-item metadata lists.
    """
    # pad_sequence does exactly the manual zero-padding this function used
    # to hand-roll: right-pad each 1-D tensor with `padding_value` to the
    # batch maximum, preserving dtype, then stack with batch first.
    from torch.nn.utils.rnn import pad_sequence

    return {
        "pixel_values": torch.stack([item["pixel_values"] for item in batch]),
        "input_ids": pad_sequence(
            [item["input_ids"] for item in batch],
            batch_first=True,
            padding_value=0,
        ),
        "attention_mask": pad_sequence(
            [item["attention_mask"] for item in batch],
            batch_first=True,
            padding_value=0,
        ),
        "labels": torch.stack([item["labels"] for item in batch]),
        "dataset": [item["dataset"] for item in batch],
        "image_path": [item["image_path"] for item in batch],
        "mask_path": [item["mask_path"] for item in batch],
        "prompt": [item["prompt"] for item in batch],
        "orig_width": [item["orig_width"] for item in batch],
        "orig_height": [item["orig_height"] for item in batch],
    }
|
|