The `TokenizedDataset` below wraps an existing dataset and tokenizes its string fields on the fly, routing any key prefixed with `label` into a separate label dictionary:

```python
import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer


class TokenizedDataset(Dataset):
    def __init__(self, custom_dataset, tokenizer, max_seq_len):
        """
        custom_dataset: an instance of CustomDataset
        tokenizer: an instance of the tokenizer
        max_seq_len: maximum sequence length used for padding/truncation
        """
        self.dataset = custom_dataset
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len

    def __len__(self):
        # The length is inherited from the wrapped dataset
        return len(self.dataset)

    def tokenize_and_pad(self, text):
        """
        Tokenize a string (or list of strings), padding/truncating to
        max_seq_len and returning PyTorch tensors.
        """
        return self.tokenizer(
            text,
            padding="max_length",
            max_length=self.max_seq_len,
            truncation=True,
            return_tensors="pt",
        )

    def __getitem__(self, idx):
        # Fetch the transformed record from the wrapped dataset
        transformed_data = self.dataset[idx]

        # Containers for model inputs and (optional) labels
        tokenized_inputs = {}
        tokenized_labels = {}

        # Dynamically process each field in the record
        for key, value in transformed_data.items():
            if isinstance(value, int):
                # Integers become scalar tensors; route to labels or
                # inputs based on the key prefix
                if key.startswith("label"):
                    tokenized_labels[key] = torch.tensor(value)
                else:
                    tokenized_inputs[key] = torch.tensor(value)
            elif isinstance(value, str):
                tokenized_data = self.tokenize_and_pad(value)
                # squeeze(0) drops the batch dimension the tokenizer adds
                # for a single string, so the DataLoader can collate items
                # into [batch, max_seq_len] rather than [batch, 1, max_seq_len]
                if key.startswith("label"):
                    tokenized_labels[key] = tokenized_data["input_ids"].squeeze(0)
                    tokenized_labels["attention_mask_" + key] = tokenized_data["attention_mask"].squeeze(0)
                else:
                    tokenized_inputs[key] = tokenized_data["input_ids"].squeeze(0)
                    tokenized_inputs["attention_mask_" + key] = tokenized_data["attention_mask"].squeeze(0)

        # Conditionally include "label" only when labels are present
        output = {"inputs": tokenized_inputs}
        if tokenized_labels:
            output["label"] = tokenized_labels
        return output
```
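To see the output structure end to end, here is a minimal usage sketch. The `ToyDataset` class and the `bert-base-uncased` checkpoint are illustrative stand-ins (the original code assumes some `CustomDataset` whose items are dicts of ints and strings); any dataset and tokenizer with the same shape of data should behave the same way.

```python
from torch.utils.data import DataLoader

# Hypothetical stand-in for CustomDataset: each item is a dict of
# string and int fields, with label-like keys prefixed "label"
class ToyDataset(Dataset):
    def __init__(self):
        self.records = [
            {"text": "the movie was great", "label": 1},
            {"text": "the movie was terrible", "label": 0},
        ]

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        return self.records[idx]


tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # any HF checkpoint
ds = TokenizedDataset(ToyDataset(), tokenizer, max_seq_len=32)

# The default collate function recurses into the nested dicts,
# stacking each field across the batch
loader = DataLoader(ds, batch_size=2)
batch = next(iter(loader))
print(batch["inputs"]["text"].shape)                 # torch.Size([2, 32]) input_ids
print(batch["inputs"]["attention_mask_text"].shape)  # torch.Size([2, 32])
print(batch["label"]["label"])                       # tensor([1, 0])
```

Note that thanks to the `squeeze(0)` in `__getitem__`, each field collates to `[batch, max_seq_len]`, which is what downstream `transformers` models expect.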