Spaces:
Sleeping
Sleeping
| import torch | |
| from torch.utils.data import Dataset | |
class CustomDataset(Dataset):
    """Dataset pairing one shared image with per-sample texts and labels.

    Every item carries the same ``image`` (optionally passed through
    ``transforms``) together with the tokenized text and the label found at
    the requested index.

    Args:
        image: Image object returned with every sample; it is not indexed.
        texts: Indexable collection of texts; its length is the dataset size.
        labels: Indexable collection of labels, read with the same index as
            ``texts``.
        tokenizer: Object exposing ``encode_plus`` whose result provides
            ``input_ids`` and ``attention_mask`` entries (presumably a
            Hugging Face tokenizer -- confirm against the caller).
        max_len: Forwarded to the tokenizer as ``max_length``.
        transforms: Optional callable applied to the image on every access.
    """

    def __init__(self, image, texts, labels, tokenizer, max_len, transforms=None):
        self.image = image
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.transforms = transforms

    def __len__(self):
        """Return the number of samples, i.e. the number of texts."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Build the sample at position ``idx``.

        Returns:
            A dict with ``input_ids`` and ``attention_mask`` as long tensors,
            ``labels`` as a float tensor, and ``images`` holding the
            (optionally transformed) shared image.
        """
        # The same image accompanies every text; only texts/labels are indexed.
        image = self.image
        text = str(self.texts[idx])
        label = self.labels[idx]

        # Compare against None explicitly: a callable transform may evaluate
        # as falsy (e.g. a container-like pipeline whose length is 0), and a
        # plain truthiness test would then skip it silently.
        if self.transforms is not None:
            image = self.transforms(image)

        # Second positional argument is the (unused) text pair.
        inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
        )
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']

        return {
            'input_ids': torch.tensor(input_ids, dtype=torch.long),
            'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
            'labels': torch.tensor(label, dtype=torch.float),
            'images': image,
        }