File size: 1,341 Bytes
a9640f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
from torch.utils.data import Dataset

from transformers import BertTokenizer
from PIL import Image
import numpy as np

from typing import List


class CombinedDataset(Dataset):
    """Dataset yielding ``(token_ids, image)`` pairs.

    Every entry of ``api_call_list`` is joined with spaces, encoded with a
    BERT tokenizer and stored as a fixed-length row of ``max_len`` token ids.
    Item ``idx`` pairs row ``idx`` with an image loaded from ``img_path``.
    """

    def __init__(self, api_call_list, img_path, sequence_length, max_len=128, transforms=None, tokenizer_name='bert-base-uncased'):
        """Tokenize all API-call sequences up front.

        Args:
            api_call_list: Sequence of API-call sequences; each inner
                sequence of strings is joined with single spaces before
                tokenization.
            img_path: Either one image path shared by every sample, or a
                list/tuple holding one path per entry of ``api_call_list``.
            sequence_length: Stored on the instance; not used by this class.
            max_len: Number of token ids kept per sample (pad/truncate).
            transforms: Optional callable applied to the opened image.
            tokenizer_name: Name or path passed to
                ``BertTokenizer.from_pretrained``.

        Raises:
            ValueError: If ``img_path`` is a list/tuple whose length differs
                from ``api_call_list``.
        """
        super().__init__()

        if isinstance(img_path, (list, tuple)) and len(img_path) != len(api_call_list):
            raise ValueError(
                "img_path has %d entries but api_call_list has %d"
                % (len(img_path), len(api_call_list))
            )

        self.image_path = img_path

        self.transforms = transforms
        self.max_len = max_len
        self.sequence_length = sequence_length

        self.tokenizer = BertTokenizer.from_pretrained(tokenizer_name)
        self.api_calls = api_call_list
        self.encoded_calls = [
            self.tokenizer.encode(
                " ".join(call),
                add_special_tokens=True,
                max_length=self.max_len,
                padding='max_length',
                truncation=True,
            )
            for call in self.api_calls
        ]
        # The tokenizer is already asked to pad/truncate to max_len; this
        # second pass is a safety net that guarantees rectangular rows so the
        # result is a 2-D array rather than a ragged object array.
        self.padded_calls = np.array([
            x + [0] * (self.max_len - len(x)) if len(x) < self.max_len else x[:self.max_len]
            for x in self.encoded_calls
        ])
        print("Dataset initialized")

    def __len__(self):
        """Return the number of tokenized API-call sequences."""
        return len(self.padded_calls)

    def __getitem__(self, idx):
        """Return ``(token_ids, image)`` for sample ``idx``.

        ``token_ids`` is a ``torch.long`` tensor built from row ``idx`` of the
        padded token matrix. ``image`` is the opened image, passed through
        ``transforms`` when one was supplied.

        Raises:
            IndexError: If ``idx`` is out of range.
        """
        # Select this sample's row only; returning the whole matrix would
        # hand every item the full dataset and break batch collation.
        token_ids = self.padded_calls[idx]

        if isinstance(self.image_path, (list, tuple)):
            img_path = self.image_path[idx]
        else:
            img_path = self.image_path

        image = Image.open(img_path)
        # Image.open is lazy; load() reads the pixel data now and lets Pillow
        # release the underlying file handle for single-frame images.
        image.load()

        if self.transforms:
            image = self.transforms(image)

        return torch.tensor(token_ids, dtype=torch.long), image