# (web-page residue preserved as comments so the module parses)
# mulasagg's picture
# final
# a9640f8
import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer
from PIL import Image
import numpy as np
from typing import List
class CombinedDataset(Dataset):
    """Map-style dataset pairing BERT-tokenized API-call sequences with an image.

    Each item is ``(token_ids, image)``:

    * ``token_ids`` is a 1-D ``torch.long`` tensor of length ``max_len``
      holding the encoding of ``api_call_list[idx]``.
    * ``image`` is the PIL image at the sample's path, passed through
      ``transforms`` when one is given.

    Args:
        api_call_list: One entry per sample. Each entry is either a string or
            an iterable of strings (presumably API-call names -- confirm with
            the caller). Iterables are joined with single spaces before
            tokenization.
        img_path: Either a single image path shared by every sample, or a
            list/tuple of paths indexed in parallel with ``api_call_list``.
        sequence_length: Stored as ``self.sequence_length``; not used by this
            class itself.
        max_len: Length every token sequence is padded/truncated to.
        transforms: Optional callable applied to the loaded PIL image.
        tokenizer_name: Name or path passed to ``BertTokenizer.from_pretrained``.
    """

    def __init__(self, api_call_list, img_path, sequence_length, max_len=128,
                 transforms=None, tokenizer_name='bert-base-uncased'):
        self.image_path = img_path
        self.transforms = transforms
        self.max_len = max_len
        self.sequence_length = sequence_length
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer_name)
        self.api_calls = api_call_list
        self.encoded_calls = [
            self.tokenizer.encode(
                self._call_to_text(call),
                add_special_tokens=True,
                max_length=self.max_len,
                padding='max_length',
                truncation=True,
            )
            for call in self.api_calls
        ]
        # encode() already pads/truncates to max_len; this defensive pass only
        # guarantees a rectangular array if that ever stops holding.
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = 0
        self.padded_calls = np.array(
            [
                ids[: self.max_len] + [pad_id] * (self.max_len - len(ids))
                for ids in self.encoded_calls
            ],
            dtype=np.int64,
        )
        print("Dataset initialized")

    @staticmethod
    def _call_to_text(call):
        """Return one sample's API calls as a single string for the tokenizer."""
        # A bare string must not go through " ".join, which would insert a
        # space between every character.
        if isinstance(call, str):
            return call
        return " ".join(call)

    def __len__(self):
        """Return the number of samples (rows of encoded API calls)."""
        return len(self.padded_calls)

    def __getitem__(self, idx):
        """Return ``(token_ids, image)`` for sample ``idx``."""
        # Select only this sample's row (the whole array was returned before).
        tokens = torch.tensor(self.padded_calls[idx], dtype=torch.long)
        image = self._load_image(self._image_path_for(idx))
        if self.transforms is not None:
            image = self.transforms(image)
        return tokens, image

    def _image_path_for(self, idx):
        """Return the image path for sample ``idx`` (shared path or per-sample list)."""
        if isinstance(self.image_path, (list, tuple)):
            return self.image_path[idx]
        return self.image_path

    @staticmethod
    def _load_image(path):
        """Open ``path`` with PIL, read its pixels into memory, and close the file."""
        # Image.open is lazy; load() inside the context manager reads the data
        # so the file handle can be released instead of leaking per item.
        with Image.open(path) as img:
            img.load()
        return img