import json
import re
from collections import Counter

import torch
from torch.utils.data import Dataset


class ChatTokenizer:
    def __init__(self, num_merges=1000):
        # Number of BPE merge operations to learn; stored under a distinct
        # name because vocab_size is a read-only property defined below.
        self.num_merges = num_merges
        self.token2id = {}
        self.id2token = {}
        self.bpe_ranks = {}

    def tokenize(self, text):
        # Lowercase, split into words and punctuation, then space-separate
        # each word's characters and append a </w> end-of-word marker
        # (standard BPE pre-tokenization).
        words = re.findall(r"\w+|\S", text.lower())
        return [' '.join(word) + ' </w>' for word in words]

    def get_stats(self, vocab):
        # Count adjacent symbol pairs weighted by token frequency; iterating
        # only the Counter's keys would ignore how often each token occurs
        # and pick the wrong merges.
        pairs = Counter()
        for token, freq in vocab.items():
            symbols = token.split()
            for i in range(len(symbols) - 1):
                pairs[(symbols[i], symbols[i + 1])] += freq
        return pairs

    def merge_pairs(self, tokens, pair):
        # Merge every space-separated occurrence of `pair`. Lookarounds are
        # used instead of \b, which fails next to non-word symbols like </w>.
        pattern = re.compile(r'(?<!\S)' + re.escape(' '.join(pair)) + r'(?!\S)')
        replacement = ''.join(pair)
        return [pattern.sub(replacement, token) for token in tokens]

    def train(self, texts):
        tokens = []
        for text in texts:
            tokens.extend(self.tokenize(text))
        vocab = Counter(tokens)

        for rank in range(self.num_merges):
            pairs = self.get_stats(vocab)
            if not pairs:
                break
            best = pairs.most_common(1)[0][0]
            # Apply the merge to every distinct token, preserving counts.
            merged = Counter()
            for token, freq in zip(self.merge_pairs(list(vocab), best), vocab.values()):
                merged[token] += freq
            vocab = merged
            self.bpe_ranks[best] = rank

        final_tokens = set()
        for token in vocab:
            final_tokens.update(token.split())
        # Reserved markers. Note that tokenize() splits "^user:" and
        # "minigpt:" into smaller pieces, so those two entries are only
        # reachable if their ids are injected into a sequence manually.
        final_tokens.update(["<PAD>", "<UNK>", "<END>", "^user:", "minigpt:"])
        self.token2id = {tok: i for i, tok in enumerate(sorted(final_tokens))}
        self.id2token = {i: tok for tok, i in self.token2id.items()}

    def encode(self, text):
        # Apply the learned merges in rank order, then map symbols to ids,
        # falling back to <UNK>. Every sequence is terminated with <END>.
        tokenized = self.tokenize(text)
        for pair, _ in sorted(self.bpe_ranks.items(), key=lambda x: x[1]):
            tokenized = self.merge_pairs(tokenized, pair)
        ids = []
        for token in tokenized:
            for part in token.split():
                ids.append(self.token2id.get(part, self.token2id["<UNK>"]))
        ids.append(self.token2id["<END>"])
        return ids

    def decode(self, token_ids):
        # Stop at <END> and skip padding/unknowns. The </w> marker usually
        # ends up merged inside word-final symbols, so it is replaced in the
        # joined string rather than matched as a standalone token.
        pieces = []
        for tid in token_ids:
            tok = self.id2token.get(tid, "<UNK>")
            if tok == "<END>":
                break
            if tok in {"<PAD>", "<UNK>"}:
                continue
            pieces.append(tok)
        return "".join(pieces).replace("</w>", " ").strip()

    def save(self, path):
        with open(path, "w", encoding="utf-8") as f:
            json.dump({
                "token2id": self.token2id,
                "bpe_ranks": {f"{a} {b}": r for (a, b), r in self.bpe_ranks.items()},
            }, f)

    def load(self, path):
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        self.token2id = {k: int(v) for k, v in data["token2id"].items()}
        self.id2token = {v: k for k, v in self.token2id.items()}
        self.bpe_ranks = {tuple(k.split()): v for k, v in data["bpe_ranks"].items()}

    def __len__(self):
        return len(self.token2id)

    @property
    def stoi(self):
        return self.token2id

    @property
    def itos(self):
        return self.id2token

    @property
    def vocab_size(self):
        # Actual vocabulary size after training (distinct from num_merges).
        return len(self.token2id)

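
# Usage sketch (hypothetical corpus; the exact merges and ids depend
# entirely on the training data):
#
#   tokenizer = ChatTokenizer(num_merges=500)
#   tokenizer.train(["^User: hi there\nMiniGPT: hello!"])
#   ids = tokenizer.encode("hi there")   # ends with the <END> id
#   text = tokenizer.decode(ids)         # -> "hi there"
#   tokenizer.save("tokenizer.json")     # round-trips via load()
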
class ChatDataset(Dataset):
    def __init__(self, file_path, tokenizer, block_size=64):
        self.samples = []
        with open(file_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                data = json.loads(line)
                text = data["text"].strip()

                # Normalize each record into the "^User: ... MiniGPT:" format.
                if not text.lower().startswith("^user:"):
                    text = "^User: " + text
                if "MiniGPT:" not in text:
                    text += "\nMiniGPT:"

                # encode() already appends <END>, so no extra end token here.
                tokens = tokenizer.encode(text)

                # Sliding window of next-token prediction pairs; records with
                # fewer than block_size + 1 tokens yield no samples.
                for i in range(len(tokens) - block_size):
                    x = tokens[i:i + block_size]
                    y = tokens[i + 1:i + block_size + 1]
                    self.samples.append((x, y))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        x, y = self.samples[idx]
        return torch.tensor(x), torch.tensor(y)
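

if __name__ == "__main__":
    # Minimal end-to-end sketch. "chat_data.jsonl" is a hypothetical path:
    # a JSON-lines file whose records each carry a "text" field, matching
    # what ChatDataset.__init__ reads above.
    from torch.utils.data import DataLoader

    corpus = [
        "^User: hello\nMiniGPT: hi, how can I help?",
        "^User: what is BPE?\nMiniGPT: byte pair encoding.",
    ]
    tokenizer = ChatTokenizer(num_merges=200)
    tokenizer.train(corpus)
    print(f"vocab size: {tokenizer.vocab_size}")

    dataset = ChatDataset("chat_data.jsonl", tokenizer, block_size=64)
    loader = DataLoader(dataset, batch_size=8, shuffle=True)
    for x, y in loader:
        print(x.shape, y.shape)  # (batch, block_size) tensors
        break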