| from datasets import load_dataset | |
| import pandas as pd | |
| from torch.utils.data import Dataset | |
| import torch | |
# Load the VQA-RAD dataset into a pandas DataFrame.
def load_vqa_dataset(split="train", cache_dir="./hf_cache"):
    """Load one split of ``flaviagiammarino/vqa-rad`` as a DataFrame.

    Args:
        split: Split name forwarded to ``load_dataset``.
        cache_dir: Cache directory forwarded to ``load_dataset``.

    Returns:
        A ``pandas.DataFrame`` restricted to the ``image``, ``question``
        and ``answer`` columns, in that order.
    """
    wanted_columns = ["image", "question", "answer"]
    records = load_dataset(
        "flaviagiammarino/vqa-rad",
        split=split,
        cache_dir=cache_dir,
    )
    # Build the frame from the dataset rows, then keep only the needed columns.
    return pd.DataFrame(records)[wanted_columns]
# Torch Dataset over a DataFrame; reads the "image", "question_encoded" and "answer_encoded" columns.
class VQADataset(Dataset):
    """Map-style dataset yielding ``(image, question, answer)`` per row.

    Each row of ``df`` must provide an ``image`` object exposing
    ``.convert("RGB")`` plus ``question_encoded`` and ``answer_encoded``
    values accepted by ``torch.tensor``.

    NOTE(review): ``load_vqa_dataset`` in this file returns only the raw
    ``question`` / ``answer`` columns; the ``*_encoded`` columns are
    presumably added by a separate encoding step -- verify against the caller.
    """

    def __init__(self, df, transform):
        """Store the source frame and the callable applied to each RGB image."""
        self.df = df
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the underlying frame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(transformed_image, question_tensor, answer_tensor)`` for row ``idx``.

        ``idx`` is positional (``iloc``), independent of the frame's index labels.
        """
        record = self.df.iloc[idx]
        rgb_image = record["image"].convert("RGB")
        return (
            self.transform(rgb_image),
            torch.tensor(record["question_encoded"]),
            torch.tensor(record["answer_encoded"]),
        )