# File: image_captioning/dataset_git.py
# Hosting-page residue: "pchandragrid's picture" - commit "Deploy Streamlit app" (a745a5e)
import os
import random
import torch
from torch.utils.data import Dataset
from PIL import Image
import json
class COCODatasetGIT(Dataset):
    """Image-captioning dataset backed by a JSONL annotation file.

    Every non-blank line of the annotation file is parsed as one JSON object.
    ``__getitem__`` reads two keys from each object: ``"image"`` (a path
    joined onto ``image_folder``) and ``"captions"`` (a sequence of caption
    strings, one of which is sampled at random per access).

    Args:
        annotation_file: Path to the JSONL annotation file.
        image_folder: Directory that each record's ``"image"`` path is
            joined onto.
        processor: Callable invoked with ``images=``, ``text=``,
            ``padding=``, ``truncation=``, ``max_length=`` and
            ``return_tensors="pt"``. Its result must provide
            ``"input_ids"``, ``"attention_mask"`` and ``"pixel_values"``.
        mode: ``"short"`` samples only captions of at most 10 words,
            ``"long"`` only captions of more than 10 words. Any other value
            (such as the default ``"mixed"``) samples from all captions.
        max_length: Length the processor pads/truncates the caption to.
    """

    # Captions with at most this many whitespace-separated words are "short".
    SHORT_CAPTION_MAX_WORDS = 10

    # Label value for padded positions. -100 is the default ``ignore_index``
    # of torch.nn.CrossEntropyLoss, so these positions add nothing to the loss.
    IGNORE_INDEX = -100

    def __init__(self, annotation_file, image_folder, processor,
                 mode: str = "mixed", max_length: int = 30):
        self.image_folder = image_folder
        self.processor = processor
        self.mode = mode
        self.max_length = max_length

        # JSONL: one JSON document per line. Blank lines (a trailing newline
        # at the end of the file is common) are skipped instead of being
        # handed to json.loads, which would raise on an empty string.
        with open(annotation_file, "r", encoding="utf-8") as f:
            self.annotations = [json.loads(line) for line in f if line.strip()]

    def __len__(self) -> int:
        """Return the number of annotation records that were loaded."""
        return len(self.annotations)

    def select_caption(self, captions) -> str:
        """Pick one caption at random, honouring ``self.mode`` when possible.

        If the length filter for the current mode leaves nothing, the choice
        is made among all of the image's own captions, so the returned
        caption always belongs to ``captions``.

        Raises:
            ValueError: If ``captions`` is empty.
        """
        limit = self.SHORT_CAPTION_MAX_WORDS
        if self.mode == "short":
            candidates = [c for c in captions if len(c.split()) <= limit]
        elif self.mode == "long":
            candidates = [c for c in captions if len(c.split()) > limit]
        else:
            candidates = captions

        if not candidates:
            # Never borrow captions from another annotation: that would pair
            # this image with text describing a different picture.
            candidates = captions
        if not candidates:
            raise ValueError("annotation has no captions to choose from")
        return random.choice(candidates)

    @staticmethod
    def _mask_padding(input_ids, attention_mask, ignore_index=-100):
        """Return a copy of ``input_ids`` with masked-out positions replaced.

        Positions where ``attention_mask`` is 0 are set to ``ignore_index``.
        ``input_ids`` itself is left unmodified.
        """
        return input_ids.masked_fill(attention_mask == 0, ignore_index)

    def __getitem__(self, idx):
        """Load the image at ``idx``, sample a caption and run the processor.

        Returns:
            A dict with ``"pixel_values"``, ``"input_ids"``,
            ``"attention_mask"`` and ``"labels"``. The first three are the
            processor outputs with their leading dimension squeezed away;
            ``"labels"`` is a separate tensor equal to ``"input_ids"`` except
            that positions with attention mask 0 hold ``IGNORE_INDEX``.
        """
        ann = self.annotations[idx]
        image_path = os.path.join(self.image_folder, ann["image"])

        # convert() returns a new, fully loaded image, so the underlying file
        # can be closed as soon as the with-block exits.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        caption = self.select_caption(ann["captions"])

        encoding = self.processor(
            images=image,
            text=caption,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        # return_tensors="pt" yields a leading dimension of size 1 for the
        # single example; drop it so each item is unbatched.
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)
        pixel_values = encoding["pixel_values"].squeeze(0)

        # Labels mirror input_ids, but as an independent tensor with padding
        # masked out so that pad positions are not scored by the loss.
        labels = self._mask_padding(input_ids, attention_mask, self.IGNORE_INDEX)

        return {
            "pixel_values": pixel_values,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }