import torch
from torch.utils.data import IterableDataset


class BilingualDataset(IterableDataset):
    """Streams fixed-length, overlapping next-token-prediction windows
    from an iterable of raw-text records."""

    def __init__(self, ds_stream, tokenizer, seq_len):
        self.ds_stream = ds_stream
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.stride = seq_len // 2  # 50% overlap between consecutive windows
        # Ids of the special tokens the tokenizer was trained with.
        self.sos_token = tokenizer.token_to_id('<s>')
        self.eos_token = tokenizer.token_to_id('</s>')
        self.pad_token = tokenizer.token_to_id('<pad>')

    def process_text(self, text):
        # Tokenize the document and mark its end with </s>.
        token_ids = self.tokenizer.encode(text).ids + [self.eos_token]

        # Slide a seq_len-sized window over the token stream; max(1, ...)
        # guarantees at least one (padded) chunk for short documents.
        for i in range(0, max(1, len(token_ids) - self.seq_len + 1), self.stride):
            # <s> + window, padded to seq_len + 1 so input and label are
            # both exactly seq_len long after the shift below.
            chunk = [self.sos_token] + token_ids[i:i + self.seq_len]
            if len(chunk) < self.seq_len + 1:
                chunk += [self.pad_token] * (self.seq_len + 1 - len(chunk))

            # Next-token prediction: the label is the input shifted left by one.
            input_tensor = torch.tensor(chunk[:-1], dtype=torch.long)  # (seq_len,)
            label_tensor = torch.tensor(chunk[1:], dtype=torch.long)   # (seq_len,)
            yield {
                "input": input_tensor,
                "label": label_tensor,
            }

    def __iter__(self):
        # Walk the stream lazily, yielding windows document by document.
        for item in self.ds_stream:
            yield from self.process_text(item["text"])
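
# Usage sketch (an illustration, not part of the module's API): wiring the
# dataset into a DataLoader. It assumes a `tokenizers.Tokenizer` trained with
# the '<s>', '</s>' and '<pad>' specials, plus any iterable of {"text": ...}
# records such as a streaming Hugging Face dataset; the tokenizer file name
# below is hypothetical.
#
#     from datasets import load_dataset
#     from tokenizers import Tokenizer
#     from torch.utils.data import DataLoader
#
#     tokenizer = Tokenizer.from_file("tokenizer.json")
#     stream = load_dataset("wikitext", "wikitext-2-raw-v1",
#                           split="train", streaming=True)
#     dataset = BilingualDataset(stream, tokenizer, seq_len=256)
#     loader = DataLoader(dataset, batch_size=8)  # no shuffling: IterableDataset
#     batch = next(iter(loader))
#     # batch["input"].shape == batch["label"].shape == (8, 256)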


# ---------------------------------------------------------------------------
# Legacy map-style implementation, kept for reference. It pre-tokenizes the
# whole corpus into one flat id stream and serves overlapping windows by index.
# ---------------------------------------------------------------------------
"""
import torch
from torch.utils.data import Dataset


class BilingualDataset(Dataset):
    def __init__(self, ds, tokenizer, seq_len):
        super().__init__()
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.ds = ds
        self.stride = seq_len // 2
        # Store plain ids so they mix cleanly with tokenizer.encode(...).ids.
        self.sos_token = tokenizer.token_to_id('<s>')
        self.eos_token = tokenizer.token_to_id('</s>')
        self.pad_token = tokenizer.token_to_id('<pad>')

        # Flatten the whole corpus into a single token-id stream,
        # with </s> separating documents.
        self.data_tokens = []
        for item in self.ds:
            # Earlier experiments concatenated instruction/chat fields, e.g.:
            #   text = item['instruction'] + " ### " + item['text'] + " \n" + item['output']
            #   text = item['user'] + " ### " + item['ai']
            text = item['text']
            tokens = tokenizer.encode(text).ids
            self.data_tokens.extend(tokens + [self.eos_token])

    def __len__(self):
        return (len(self.data_tokens) - self.seq_len) // self.stride

    def __getitem__(self, index):
        start = index * self.stride
        window = self.data_tokens[start:start + self.seq_len]

        # <s> + window, padded to seq_len + 1 so input and label are
        # both exactly seq_len long after the shift below.
        tokens = [self.sos_token] + window
        if len(tokens) < self.seq_len + 1:
            tokens += [self.pad_token] * (self.seq_len + 1 - len(tokens))
        tokens = torch.tensor(tokens, dtype=torch.long)

        return {
            "input": tokens[:-1],
            # "input_mask": (tokens[:-1] != self.pad_token).unsqueeze(0).int()
            #               & causal_mask(tokens[:-1].size(0)),  # (1, seq_len) & (1, seq_len, seq_len)
            "label": tokens[1:],
        }


def causal_mask(size):
    # Ones strictly above the diagonal mark future positions; `== 0` flips
    # them so True means "attention allowed".
    mask = torch.triu(torch.ones(1, size, size), diagonal=1).type(torch.int)
    return mask == 0
"""
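
# A minimal runnable sketch (an addition for illustration) of the causal-mask
# pattern referenced in the legacy code above: each query position may attend
# only to itself and earlier positions.
if __name__ == "__main__":
    size = 4
    demo_mask = torch.triu(torch.ones(1, size, size), diagonal=1).type(torch.int) == 0
    print(demo_mask)
    # tensor([[[ True, False, False, False],
    #          [ True,  True, False, False],
    #          [ True,  True,  True, False],
    #          [ True,  True,  True,  True]]])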