from typing import Callable, List, Optional

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from transformers import BertTokenizer


class CombinedDataset(Dataset):
    """Map-style dataset pairing tokenized API-call sequences with an image.

    Every API-call sequence is joined with spaces, encoded with a BERT
    tokenizer and padded/truncated to ``max_len`` token ids once, at
    construction time. Each item is ``(token_ids, image)``.
    """

    def __init__(
        self,
        api_call_list: List[List[str]],
        img_path,
        sequence_length,
        max_len: int = 128,
        transforms: Optional[Callable] = None,
        tokenizer_name: str = 'bert-base-uncased',
    ):
        """Tokenize all API-call sequences up front.

        Args:
            api_call_list: One entry per sample; each entry is a sequence of
                strings that is joined with single spaces before encoding.
            img_path: Path handed to ``PIL.Image.open`` for every item.
            sequence_length: Stored on the instance as ``sequence_length``;
                not otherwise used by this class.
            max_len: Exact number of token ids per encoded sequence.
            transforms: Optional callable applied to the opened image.
            tokenizer_name: Name passed to ``BertTokenizer.from_pretrained``.
        """
        super().__init__()
        self.image_path = img_path
        self.transforms = transforms
        self.max_len = max_len
        # NOTE(review): stored but never read in this class -- confirm whether
        # callers rely on it or whether it was meant to drive windowing.
        self.sequence_length = sequence_length
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer_name)
        self.api_calls = api_call_list

        # padding='max_length' together with truncation=True makes every
        # encoded sequence exactly `max_len` ids long.
        self.encoded_calls = [
            self.tokenizer.encode(
                " ".join(call),
                add_special_tokens=True,
                max_length=self.max_len,
                padding='max_length',
                truncation=True,
            )
            for call in self.api_calls
        ]
        # The tokenizer already produced fixed-length rows, so no second
        # padding pass is needed (the previous one hard-coded pad id 0, which
        # is only correct for some vocabularies). Rows stack directly into a
        # (num_samples, max_len) integer matrix.
        self.padded_calls = np.array(self.encoded_calls, dtype=np.int64)
        print("Dataset initialized")

    def __len__(self):
        """Return the number of encoded API-call sequences."""
        return len(self.padded_calls)

    def __getitem__(self, idx):
        """Return the ``idx``-th token sequence together with the image.

        Args:
            idx: Index of the sample to fetch.

        Returns:
            A tuple ``(token_ids, image)`` where ``token_ids`` is a
            ``torch.long`` tensor holding the ids of sample ``idx`` and
            ``image`` is the opened image after ``transforms`` (if any).

        Raises:
            IndexError: If ``idx`` is out of range.
        """
        # NOTE(review): the same image file is used for every index -- confirm
        # that a single shared image per dataset is the intended pairing.
        with Image.open(self.image_path) as image:
            # Image.open is lazy; decode the pixels now so the image stays
            # usable after the underlying file handle is closed.
            image.load()
        if self.transforms is not None:
            image = self.transforms(image)

        # Select only the requested row; returning the whole matrix would
        # make every sample contain every sequence in the dataset.
        tokenized_seq = self.padded_calls[idx]
        return torch.tensor(tokenized_seq, dtype=torch.long), image