File size: 928 Bytes
364daa0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from datasets import load_dataset
import pandas as pd
from torch.utils.data import Dataset
import torch

def load_vqa_dataset(split="train", cache_dir="./hf_cache"):
    """Load one split of ``flaviagiammarino/vqa-rad`` as a pandas DataFrame.

    Args:
        split: Split name forwarded to ``datasets.load_dataset``.
        cache_dir: Cache directory forwarded to ``datasets.load_dataset``.

    Returns:
        A DataFrame restricted to the ``image``, ``question`` and ``answer``
        columns, one row per example. A missing column raises ``KeyError``.
    """
    wanted_columns = ["image", "question", "answer"]
    hf_split = load_dataset(
        "flaviagiammarino/vqa-rad", split=split, cache_dir=cache_dir
    )
    # The frame is built from the Dataset object itself (row by row), not
    # via an Arrow export. VQADataset calls ``.convert("RGB")`` on the image
    # values, so they are expected to be decoded image objects. Keep this
    # construction if that consumer is unchanged.
    return pd.DataFrame(hf_split)[wanted_columns]


class VQADataset(Dataset):
    """Map-style PyTorch dataset over a VQA DataFrame.

    Each item is an ``(image, question, answer)`` triple taken from the row
    at the requested *position* (``DataFrame.iloc``), so the frame's index
    labels do not matter:

    * ``image``: the row's ``"image"`` value after ``.convert("RGB")``, then
      passed through ``transform``.
    * ``question``: ``torch.tensor`` built from ``"question_encoded"``.
    * ``answer``: ``torch.tensor`` built from ``"answer_encoded"``.

    The frame must already contain the ``question_encoded`` and
    ``answer_encoded`` columns. The encoding step is not part of this class.
    """

    def __init__(self, df, transform):
        """Keep a reference to the source frame and the image transform.

        Args:
            df: DataFrame with ``image``, ``question_encoded`` and
                ``answer_encoded`` columns.
            transform: Callable applied to every RGB-converted image.
        """
        self.df = df
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the underlying DataFrame."""
        return self.df.shape[0]

    def __getitem__(self, idx):
        """Return the ``(image, question, answer)`` triple at position ``idx``."""
        record = self.df.iloc[idx]
        rgb_image = record["image"].convert("RGB")
        # Tuple elements are evaluated left to right: transform, then the
        # question tensor, then the answer tensor.
        return (
            self.transform(rgb_image),
            torch.tensor(record["question_encoded"]),
            torch.tensor(record["answer_encoded"]),
        )