import torch
from torch.utils.data import Dataset


class OpenMPDataset(Dataset):
    """Map-style dataset that tokenizes paired input/output sequences.

    Each item is encoded on access with ``tokenizer.encode`` and returned
    as ``torch.long`` tensors together with the number of ids that precede
    the first padding id in each sequence.

    Args:
        inputs: Indexable collection of source items handed to
            ``tokenizer.encode``.
        outputs: Indexable collection of target items, indexed in parallel
            with ``inputs``.
        tokenizer: Object exposing a ``char2idx`` mapping and an
            ``encode(item, max_len, add_special_tokens=...)`` method.
        max_input_len: Length limit passed to ``encode`` for inputs.
        max_output_len: Length limit passed to ``encode`` for outputs.
    """

    def __init__(
        self,
        inputs,
        outputs,
        tokenizer,
        max_input_len=500,
        max_output_len=100,
    ):
        self.inputs = inputs
        self.outputs = outputs
        self.tokenizer = tokenizer
        self.max_input_len = max_input_len
        self.max_output_len = max_output_len
        # NOTE(review): the padding id is looked up under the empty-string
        # key; confirm this matches the tokenizer's padding entry.
        self.pad_idx = tokenizer.char2idx['']

    def __len__(self):
        """Return the number of examples (the length of ``inputs``)."""
        return len(self.inputs)

    def _unpadded_len(self, token_ids):
        """Return the index of the first padding id in ``token_ids``.

        When no padding id is present the whole sequence is content, so
        its actual length is returned. Using the real length instead of
        the configured maximum keeps the value correct even if the
        tokenizer returns fewer ids than the maximum without padding.
        """
        return next(
            (pos for pos, tok in enumerate(token_ids) if tok == self.pad_idx),
            len(token_ids),
        )

    def __getitem__(self, idx):
        """Encode example ``idx`` and return it as a dict of tensors.

        Returns:
            A dict with keys ``'input'`` and ``'output'`` (encoded ids as
            ``torch.long`` tensors) and ``'input_len'`` / ``'output_len'``
            (scalar ``torch.long`` tensors holding the count of ids before
            the first padding id).
        """
        input_ids = self.tokenizer.encode(
            self.inputs[idx],
            self.max_input_len,
            add_special_tokens=True,
        )
        output_ids = self.tokenizer.encode(
            self.outputs[idx],
            self.max_output_len,
            add_special_tokens=True,
        )

        return {
            'input': torch.tensor(input_ids, dtype=torch.long),
            'output': torch.tensor(output_ids, dtype=torch.long),
            'input_len': torch.tensor(
                self._unpadded_len(input_ids), dtype=torch.long
            ),
            'output_len': torch.tensor(
                self._unpadded_len(output_ids), dtype=torch.long
            ),
        }