Spaces:
Sleeping
Sleeping
| import torch | |
| from torch.utils.data import Dataset | |
| from transformers import BertTokenizer | |
| from PIL import Image | |
| import numpy as np | |
| from typing import List | |
class CombinedDataset(Dataset):
    """Dataset pairing BERT-tokenized API-call sequences with one image.

    Every entry of ``api_call_list`` is joined with spaces, encoded with a
    ``BertTokenizer`` and forced to exactly ``max_len`` token ids. The
    resulting ids are stored as one array in ``padded_calls``.

    NOTE(review): ``__getitem__`` returns the *whole* ``padded_calls`` array
    and the same image for every valid index; ``idx`` is only range-checked.
    This looks like it should select ``padded_calls[idx]``, but changing it
    would alter the returned tensor shape -- confirm against the callers.
    """

    def __init__(self, api_call_list, img_path, sequence_length, max_len=128, transforms=None, tokenizer_name='bert-base-uncased'):
        """Tokenize all API-call sequences up front.

        Args:
            api_call_list: Iterable of call sequences; each sequence is an
                iterable of strings that gets joined with single spaces.
            img_path: Path of the image opened on every ``__getitem__`` call.
            sequence_length: Stored on the instance; not used by this class.
            max_len: Number of token ids kept per encoded sequence.
            transforms: Optional callable applied to the opened image.
            tokenizer_name: Name passed to ``BertTokenizer.from_pretrained``.
        """
        self.image_path = img_path
        self.transforms = transforms
        self.max_len = max_len
        self.sequence_length = sequence_length
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer_name)
        self.api_calls = api_call_list
        self.encoded_calls = [
            self.tokenizer.encode(
                " ".join(call),
                add_special_tokens=True,
                max_length=self.max_len,
                padding='max_length',
                truncation=True,
            )
            for call in self.api_calls
        ]
        # The tokenizer is already asked to pad/truncate; this second pass is
        # kept as a safety net so every row ends up with exactly max_len ids.
        self.padded_calls = np.array(
            [self._pad_or_truncate(ids, self.max_len) for ids in self.encoded_calls]
        )
        print("Dataset initialized")

    @staticmethod
    def _pad_or_truncate(ids, max_len):
        """Return ``ids`` right-padded with 0 or cut so it has ``max_len`` items."""
        if len(ids) < max_len:
            return ids + [0] * (max_len - len(ids))
        return ids[:max_len]

    def __len__(self):
        """Return the number of encoded API-call sequences."""
        return len(self.padded_calls)

    def __getitem__(self, idx):
        """Return ``(token_tensor, image)``.

        Args:
            idx: Sample index. Integer indices are range-checked; any other
                index type is accepted unchecked, as before.

        Returns:
            A ``torch.long`` tensor built from the full ``padded_calls``
            array, and the image (after ``transforms`` when one is set).

        Raises:
            IndexError: If ``idx`` is an integer outside
                ``[-len(self), len(self))``.
        """
        # Without an IndexError, plain iteration (`for x in dataset`), which
        # falls back to __getitem__ with 0, 1, 2, ..., would never terminate.
        if isinstance(idx, (int, np.integer)):
            count = len(self.padded_calls)
            if not -count <= idx < count:
                raise IndexError(
                    f"index {idx} is out of range for dataset of size {count}"
                )

        # Image.open is lazy and keeps the file open; load the pixel data and
        # release the handle so repeated calls do not leak file descriptors.
        with Image.open(self.image_path) as image:
            image.load()
        if self.transforms:
            image = self.transforms(image)

        return torch.tensor(self.padded_calls, dtype=torch.long), image