vqa_project / data /dataset_loader.py
PRUTHVIn's picture
Upload folder using huggingface_hub
364daa0 verified
from datasets import load_dataset
import pandas as pd
from torch.utils.data import Dataset
import torch
# Load one split of the VQA-RAD dataset into a pandas DataFrame.
def load_vqa_dataset(split="train", cache_dir="./hf_cache"):
    """Load one split of the ``flaviagiammarino/vqa-rad`` dataset as a DataFrame.

    Args:
        split: Name of the split forwarded to ``load_dataset``.
        cache_dir: Directory forwarded to ``load_dataset`` as its cache location.

    Returns:
        A ``pandas.DataFrame`` restricted to the ``image``, ``question`` and
        ``answer`` columns, in that order.
    """
    wanted_columns = ["image", "question", "answer"]
    hf_split = load_dataset(
        "flaviagiammarino/vqa-rad",
        split=split,
        cache_dir=cache_dir,
    )
    # Build the frame from the dataset object itself, then keep only the
    # columns the rest of the pipeline reads.
    return pd.DataFrame(hf_split)[wanted_columns]
# Torch Dataset that serves (image, question, answer) samples from a DataFrame.
class VQADataset(Dataset):
    """Map-style dataset yielding ``(image, question, answer)`` triples.

    Each sample is read from one row of the wrapped DataFrame. The row must
    provide an ``image`` value exposing ``.convert("RGB")`` (presumably a PIL
    image -- confirm against the caller) plus ``question_encoded`` and
    ``answer_encoded`` values that ``torch.tensor`` can consume.
    """

    def __init__(self, df, transform):
        """Store the source DataFrame and the callable applied to each image."""
        self.transform = transform
        self.df = df

    def __len__(self):
        """Return the number of rows in the wrapped DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return the transformed image and encoded question/answer at ``idx``.

        ``idx`` is positional (``DataFrame.iloc``), not a label lookup.
        """
        record = self.df.iloc[idx]
        # Force RGB before handing the image to the transform.
        pixels = self.transform(record["image"].convert("RGB"))
        question_ids, answer_ids = (
            torch.tensor(record[column])
            for column in ("question_encoded", "answer_encoded")
        )
        return pixels, question_ids, answer_ids