import torch
class CharacterTokenizer:
    """Character-level tokenizer: maps each distinct character of a corpus
    to a stable integer id (and back)."""

    def __init__(self, content):
        # The vocabulary is the sorted set of characters observed in `content`;
        # sorting makes the char<->id assignment deterministic.
        self.vocab = sorted(set(content))
        self.vocab_size = len(self.vocab)
        self.char_to_idx = {}
        self.idx_to_char = {}
        for idx, ch in enumerate(self.vocab):
            self.char_to_idx[ch] = idx
            self.idx_to_char[idx] = ch

    def encode(self, xs):
        """Convert a string (or any iterable of chars) to a list of ids."""
        table = self.char_to_idx
        return [table[ch] for ch in xs]

    def decode(self, xs):
        """Convert an iterable of ids back into a string."""
        table = self.idx_to_char
        return ''.join([table[i] for i in xs])
class BPETokenizer:
    """Subword BPE tokenizer backed by sentencepiece.

    If ``model_path`` already exists it is loaded directly.
    Otherwise sentencepiece is trained on ``input_path`` and the resulting
    model is saved at ``model_path``.
    """

    def __init__(self, input_path: str, model_path: str = "bpe.model", vocab_size: int = 2000):
        import sentencepiece as spm
        import os
        if not os.path.exists(model_path):
            print(f"Training BPE tokenizer (vocab_size={vocab_size}) → {model_path}")
            # sentencepiece writes "<prefix>.model", so only a *trailing*
            # ".model" may be stripped. The previous replace(".model", "")
            # removed every occurrence (breaking paths like "my.model.dir/bpe.model")
            # and left other extensions untouched, so the trained file could
            # end up at a different path than the one loaded below.
            if model_path.endswith(".model"):
                prefix = model_path[:-len(".model")]
            else:
                prefix = model_path
            spm.SentencePieceTrainer.train(
                input=input_path,
                model_prefix=prefix,
                vocab_size=vocab_size,
                character_coverage=1.0,
                model_type="bpe",
                pad_id=3,
            )
            # Ensure the artifact lives exactly at model_path even when the
            # given name does not end in ".model".
            trained = prefix + ".model"
            if trained != model_path:
                os.replace(trained, model_path)
            print("BPE tokenizer ready.")
        self.sp = spm.SentencePieceProcessor(model_file=model_path)
        self.vocab_size = self.sp.get_piece_size()

    def encode(self, text: str):
        """Encode a string into a list of subword ids."""
        return self.sp.encode(text)

    def decode(self, ids):
        """Decode ids (list or tensor-like with .tolist()) back to a string."""
        return self.sp.decode(ids.tolist() if hasattr(ids, 'tolist') else list(ids))
class Dataset:
    """Holds a 1-D tensor of token ids and serves random (x, y) batches
    for next-token-prediction training."""

    def __init__(self, content, context_size, batch_size, split_factor=0.9):
        """
        Args:
            content: 1-D torch tensor of token ids (must support slicing).
            context_size: length of each training window.
            batch_size: number of windows per batch.
            split_factor: fraction of tokens assigned to the train split
                (strictly between 0 and 1); the remainder is validation.

        Raises:
            ValueError: if split_factor is not in (0, 1).  (Was an `assert`,
                which silently disappears under `python -O`.)
        """
        if not (0 < split_factor < 1):
            raise ValueError("split_factor must be strictly between 0 and 1")
        self.context_size = context_size
        self.batch_size = batch_size
        self.data = content
        n = int(len(self.data) * split_factor)
        self.train_data, self.val_data = self.data[:n], self.data[n:]

    def get_batch(self, split, device, y_shift=1):
        """Sample a random batch from the given split.

        Args:
            split: 'train' for the train split; anything else uses validation.
            device: torch device the batch tensors are moved to.
            y_shift: offset of targets relative to inputs (default 1 →
                standard next-token prediction).

        Returns:
            (x, y) tensors of shape (batch_size, context_size), where
            y is x shifted forward by y_shift positions in the data.
        """
        data = self.train_data if split == 'train' else self.val_data
        # A window needs context_size + y_shift tokens starting at i.
        # +1 because randint's upper bound is exclusive — the previous bound
        # was off by one and made the split's final token unreachable.
        max_start = len(data) - self.context_size - y_shift + 1
        ix = torch.randint(max_start, (self.batch_size,))
        x = torch.stack([data[i:i + self.context_size] for i in ix])
        y = torch.stack([data[i + y_shift:i + self.context_size + y_shift] for i in ix])
        return x.to(device), y.to(device)
|