# Provenance (copied from the hosting page): uploaded by mulasagg, commit a9640f8 ("final"), 1.34 kB.
import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer
from PIL import Image
import numpy as np
from typing import List
class CombinedDataset(Dataset):
    """Map-style dataset pairing tokenized API-call sequences with an image.

    Each entry of ``api_call_list`` is joined with single spaces, encoded with a
    BERT tokenizer (special tokens added) and padded/truncated to ``max_len``
    token ids. ``dataset[idx]`` returns ``(token_ids, image)``:

    * ``token_ids`` is a 1-D ``torch.long`` tensor holding the ids of sequence ``idx``.
    * ``image`` is the PIL image loaded from the item's image path, passed
      through ``transforms`` when one is given.

    Args:
        api_call_list: Sequence of API-call sequences. Each entry is passed to
            ``" ".join``, so it presumably holds API-name strings (TODO confirm
            with callers; a plain ``str`` entry would be split into characters).
        img_path: Either a single image path used for every item, or a
            list/tuple of paths with one entry per API-call sequence.
        sequence_length: Stored as ``self.sequence_length``; not used by this
            class itself.
        max_len: Number of token ids kept per encoded sequence.
        transforms: Optional callable applied to the loaded PIL image.
        tokenizer_name: Name or path passed to ``BertTokenizer.from_pretrained``.
    """

    def __init__(self, api_call_list, img_path, sequence_length, max_len=128, transforms=None, tokenizer_name='bert-base-uncased'):
        self.image_path = img_path
        self.transforms = transforms
        self.max_len = max_len
        self.sequence_length = sequence_length
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer_name)
        self.api_calls = api_call_list
        self.encoded_calls = [
            self.tokenizer.encode(
                " ".join(call),
                add_special_tokens=True,
                max_length=self.max_len,
                padding='max_length',
                truncation=True,
            )
            for call in self.api_calls
        ]
        # encode() is already asked to pad/truncate to max_len. This pass is a
        # defensive normalization so every row has exactly max_len ids and
        # np.array builds a rectangular (num_sequences, max_len) matrix.
        self.padded_calls = np.array(
            [self._pad_or_truncate(ids, self.max_len) for ids in self.encoded_calls]
        )
        print("Dataset initialized")

    @staticmethod
    def _pad_or_truncate(ids, max_len, pad_id=0):
        """Return *ids* as a list right-padded with *pad_id* or cut to *max_len* items."""
        ids = list(ids)
        if len(ids) < max_len:
            return ids + [pad_id] * (max_len - len(ids))
        return ids[:max_len]

    def __len__(self):
        """Return the number of encoded API-call sequences."""
        return len(self.padded_calls)

    def _image_path_for(self, idx):
        """Return the image path for item *idx* (per-item list entry or the shared path)."""
        if isinstance(self.image_path, (list, tuple)):
            return self.image_path[idx]
        return self.image_path

    @staticmethod
    def _load_image(path):
        """Open *path*, force the pixel data into memory, and close the file."""
        # Image.open is lazy and keeps the file handle open; loading inside the
        # context manager lets the handle be released before the image is used.
        with Image.open(path) as image:
            image.load()
        return image

    def _tokens_at(self, idx):
        """Return the token ids of sequence *idx* as a ``torch.long`` tensor."""
        return torch.tensor(self.padded_calls[idx], dtype=torch.long)

    def __getitem__(self, idx):
        """Return ``(token_ids, image)`` for item *idx*."""
        image = self._load_image(self._image_path_for(idx))
        if self.transforms is not None:
            image = self.transforms(image)
        # Index the row for this item; previously the whole matrix was returned.
        return self._tokens_at(idx), image