from datasets import load_dataset
import pandas as pd
from torch.utils.data import Dataset
import torch


def load_vqa_dataset(split="train", cache_dir="./hf_cache"):
    """Load one split of the VQA-RAD dataset as a pandas DataFrame.

    Args:
        split: Name of the Hugging Face split to load (e.g. ``"train"``).
        cache_dir: Directory passed to ``datasets.load_dataset`` as its cache.

    Returns:
        A new DataFrame containing only the ``image``, ``question`` and
        ``answer`` columns, in that order.
    """
    dataset = load_dataset(
        "flaviagiammarino/vqa-rad",
        split=split,
        cache_dir=cache_dir,
    )
    df = pd.DataFrame(dataset)
    # Return an independent copy: callers are expected to add encoded
    # columns (e.g. "question_encoded") to this frame, and assigning new
    # columns to a column-slice of another DataFrame is fragile in pandas.
    return df[["image", "question", "answer"]].copy()


class VQADataset(Dataset):
    """Map-style PyTorch dataset yielding ``(image, question, answer)`` samples.

    ``df`` must provide, per row:

    * ``image``: an object with a ``convert("RGB")`` method (presumably a
      PIL image, as produced by ``load_vqa_dataset`` -- TODO confirm);
    * ``question_encoded`` / ``answer_encoded``: values accepted by
      ``torch.tensor``. These columns are NOT created by
      ``load_vqa_dataset``; the caller must add them before indexing.

    Args:
        df: Source DataFrame, indexed positionally via ``iloc``.
        transform: Optional callable applied to the RGB image. When ``None``,
            the RGB image is returned untransformed.
    """

    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Samplers may hand us a 0-d tensor; pandas ``iloc`` needs a plain int.
        if torch.is_tensor(idx):
            idx = idx.item()
        row = self.df.iloc[idx]

        # Force 3 channels so grayscale/RGBA inputs share one layout.
        image = row["image"].convert("RGB")
        if self.transform is not None:
            image = self.transform(image)

        question = torch.tensor(row["question_encoded"])
        answer = torch.tensor(row["answer_encoded"])
        return image, question, answer