| import torch |
| import torch.nn as nn |
| from torch.utils.data import Dataset |
|
|
| import json |
|
|
class BilingualDataset(Dataset):
    """Sliding-window next-token dataset built from instruction records.

    Each record of ``ds`` must provide the keys ``'instruction'``,
    ``'input'`` and ``'output'``.  A record is flattened to

        <user> instruction+" "+input tokens <ai> output tokens </s>

    and all records are concatenated into one token stream
    (``self.data_tokens``, a list of plain ``int`` ids).  Items are
    overlapping windows over that stream, ``self.stride`` tokens apart.

    Args:
        ds: Iterable of mappings with the three keys listed above.
        tokenizer: Object exposing ``token_to_id(str)`` and
            ``encode(str).ids``; the special tokens ``<s>``, ``</s>``,
            ``<pad>``, ``<user>`` and ``<ai>`` must be in its vocabulary.
        seq_len: Length of the ``"input"`` and ``"label"`` tensors
            returned for every item.
    """

    def __init__(self, ds, tokenizer, seq_len):
        super().__init__()

        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.ds = ds
        # Consecutive windows overlap by half a sequence.  Clamp to 1 so
        # that seq_len == 1 cannot cause a division by zero in __len__.
        self.stride = max(1, seq_len // 2)

        # Special tokens stay available as one-element int64 tensors.
        self.sos_token = torch.tensor(
            [tokenizer.token_to_id('<s>')], dtype=torch.int64)
        self.eos_token = torch.tensor(
            [tokenizer.token_to_id('</s>')], dtype=torch.int64)
        self.pad_token = torch.tensor(
            [tokenizer.token_to_id('<pad>')], dtype=torch.int64)
        self.user_token = torch.tensor(
            [tokenizer.token_to_id('<user>')], dtype=torch.int64)
        self.ai_token = torch.tensor(
            [tokenizer.token_to_id('<ai>')], dtype=torch.int64)

        # Plain-int copies: the token stream must be homogeneous so it can
        # be sliced and turned into a tensor without any round-trips.
        self._sos_id = self.sos_token.item()
        self._pad_id = self.pad_token.item()
        eos_id = self.eos_token.item()
        user_id = self.user_token.item()
        ai_id = self.ai_token.item()

        self.data_tokens = []
        for text in self.ds:
            user_tokens = tokenizer.encode(
                text['instruction'] + " " + text['input']).ids
            ai_tokens = tokenizer.encode(text['output']).ids
            self.data_tokens.extend(
                [user_id, *user_tokens, ai_id, *ai_tokens, eos_id])

    def __len__(self):
        """Return the number of windows; 0 when the stream is too short."""
        # Floor division of a negative numerator would yield a negative
        # length, which len() rejects, hence the clamp.
        return max(0, (len(self.data_tokens) - self.seq_len) // self.stride)

    def __getitem__(self, index):
        """Return the ``index``-th window as an input/label pair.

        Returns:
            dict with two int64 tensors of length ``seq_len``:
            ``"input"`` is ``<s>`` followed by ``seq_len - 1`` stream
            tokens; ``"label"`` is the same tokens shifted left by one,
            so its last position always holds the ``<pad>`` id.
            NOTE(review): presumably the loss ignores ``<pad>`` targets;
            confirm against the training loop.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``.
        """
        size = len(self)
        if index < 0:
            index += size
        if not 0 <= index < size:
            # Also what terminates plain ``for item in dataset`` loops.
            raise IndexError(
                f"index out of range for dataset of length {size}")

        start = index * self.stride
        window = self.data_tokens[start:start + self.seq_len - 1]

        # A full item is <s> + (seq_len - 1) tokens + <pad>.  In-range
        # windows are always full; pad defensively so the shape stays
        # fixed even if a window were ever short.
        missing = (self.seq_len - 1) - len(window)
        tokens = torch.tensor(
            [self._sos_id, *window, *([self._pad_id] * (missing + 1))],
            dtype=torch.int64,
        )

        return {
            "input": tokens[:-1],
            "label": tokens[1:],
        }
| |
def causal_mask(size):
    """Build a boolean mask of shape ``(1, size, size)``.

    Entry ``[0, i, j]`` is ``True`` when ``j <= i`` (on or below the main
    diagonal) and ``False`` when ``j > i``.
    """
    # Ones strictly above the diagonal, zeros elsewhere.
    above_diagonal = torch.ones(1, size, size).triu(diagonal=1).to(torch.int)
    # Invert: keep every position that is NOT above the diagonal.
    return above_diagonal.eq(0)