import json
import re
from collections import Counter

import torch
from torch.utils.data import Dataset


class ChatTokenizer:
    """Byte-pair-encoding (BPE) tokenizer for a small chat corpus.

    NOTE(review): the special-token string literals in the original source
    were garbled to empty strings (three identical "" entries, an
    unreachable ``decode`` branch, and a one-element "set of two" tokens).
    They are reconstructed here as <pad>/<unk>/<eos>/<space> based on how
    each one is used — confirm against the original training data before
    loading any previously saved vocabulary files.
    """

    PAD = "<pad>"
    UNK = "<unk>"
    EOS = "<eos>"
    SPACE = "<space>"

    def __init__(self, vocab_size=1000):
        # Number of BPE merge operations to learn.  Stored under a private
        # name because ``vocab_size`` is exposed below as a read-only
        # property; the original assigned to ``self.vocab_size`` directly,
        # which raised AttributeError at construction time.
        self.num_merges = vocab_size
        self.token2id = {}
        self.id2token = {}
        # (left_symbol, right_symbol) -> merge rank; lower rank = earlier merge.
        self.bpe_ranks = {}

    def tokenize(self, text):
        """Lowercase *text* and return each word as space-separated chars.

        ``\\w+|\\S`` splits words from individual punctuation marks; each
        word becomes e.g. ``"h e l l o "`` ready for pair merging.
        """
        words = re.findall(r"\w+|\S", text.lower())
        return [' '.join(list(word)) + ' ' for word in words]

    def get_stats(self, tokens):
        """Count adjacent symbol pairs across an iterable of word strings."""
        pairs = Counter()
        for token in tokens:
            symbols = token.split()
            for i in range(len(symbols) - 1):
                pairs[(symbols[i], symbols[i + 1])] += 1
        return pairs

    def merge_pairs(self, tokens, pair):
        """Merge every occurrence of *pair* ("a b" -> "ab") in each token."""
        pattern = re.escape(' '.join(pair))
        replacement = ''.join(pair)
        return [re.sub(rf'\b{pattern}\b', replacement, token) for token in tokens]

    def train(self, texts):
        """Learn up to ``self.num_merges`` BPE merges and build the vocab."""
        tokens = []
        for text in texts:
            tokens.extend(self.tokenize(text))
        vocab = Counter(tokens)
        for rank in range(self.num_merges):
            # ``elements()`` expands each word by its frequency so pair
            # statistics are weighted by word counts (standard BPE).  The
            # original passed the Counter itself, iterating unique keys only.
            pairs = self.get_stats(vocab.elements())
            if not pairs:
                break
            best = pairs.most_common(1)[0][0]
            vocab = Counter(self.merge_pairs(vocab.elements(), best))
            self.bpe_ranks[best] = rank
        final_tokens = set()
        for token in vocab:
            final_tokens.update(token.split())
        # NOTE(review): "^user:"/"minigpt:" are added to the vocab, but
        # ``tokenize`` splits "^" and ":" off as separate symbols, so these
        # markers are never produced by encoding — verify intended handling.
        final_tokens.update([self.PAD, self.UNK, self.EOS, self.SPACE,
                             "^user:", "minigpt:"])
        self.token2id = {tok: i for i, tok in enumerate(sorted(final_tokens))}
        self.id2token = {i: tok for tok, i in self.token2id.items()}

    def encode(self, text):
        """Encode *text* into a list of token ids.

        Applies learned merges in rank order, maps unknown symbols to
        ``<unk>``, and appends a ``<space>`` id after every word so that
        decoding can restore word boundaries.
        """
        tokenized = self.tokenize(text)
        for pair, _rank in sorted(self.bpe_ranks.items(), key=lambda x: x[1]):
            tokenized = self.merge_pairs(tokenized, pair)
        ids = []
        for token in tokenized:
            for part in token.split():
                ids.append(self.token2id.get(part, self.token2id[self.UNK]))
            ids.append(self.token2id[self.SPACE])
        return ids

    def decode(self, token_ids):
        """Decode ids to text: stop at <eos>, skip <pad>/<unk>, expand <space>."""
        sentence = ""
        for tid in token_ids:
            tok = self.id2token.get(tid, self.UNK)
            if tok == self.EOS:
                break
            elif tok == self.SPACE:
                sentence += " "
            elif tok in {self.PAD, self.UNK}:
                continue
            else:
                sentence += tok
        return sentence.strip()

    def save(self, path):
        """Persist vocab and merge ranks as JSON (pair keys joined by a space)."""
        with open(path, "w", encoding="utf-8") as f:
            json.dump({
                "token2id": self.token2id,
                "bpe_ranks": {f"{a} {b}": r
                              for (a, b), r in self.bpe_ranks.items()},
            }, f)

    def load(self, path):
        """Restore vocab and merge ranks written by :meth:`save`."""
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        self.token2id = {k: int(v) for k, v in data["token2id"].items()}
        self.id2token = {v: k for k, v in self.token2id.items()}
        self.bpe_ranks = {tuple(k.split()): v
                          for k, v in data["bpe_ranks"].items()}

    def __len__(self):
        return len(self.token2id)

    @property
    def stoi(self):
        """Alias for ``token2id`` (string -> id)."""
        return self.token2id

    @property
    def itos(self):
        """Alias for ``id2token`` (id -> string)."""
        return self.id2token

    @property
    def vocab_size(self):
        """Size of the trained vocabulary (0 before training)."""
        return len(self.token2id)


class ChatDataset(Dataset):
    """Sliding-window next-token-prediction dataset over a JSONL chat file.

    Each input line is a JSON object with a ``"text"`` field.  The text is
    wrapped in the ``^User: ... MiniGPT:`` prompt format, encoded, and an
    ``<eos>`` id is appended; every ``block_size``-length window becomes one
    (input, shifted-by-one target) sample.

    NOTE(review): the original file defined this class twice; the second
    definition (which appends the <eos> id) shadowed the first, so only
    that effective version is kept here.
    """

    def __init__(self, file_path, tokenizer, block_size=64):
        self.samples = []
        with open(file_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                data = json.loads(line)
                text = data["text"].strip()
                # Wrap in format: ^User: ... MiniGPT: ...
                if not text.lower().startswith("^user:"):
                    text = "^User: " + text
                if "MiniGPT:" not in text:
                    text += "\nMiniGPT:"
                tokens = tokenizer.encode(text) + [tokenizer.stoi["<eos>"]]
                # Overlapping windows; sequences shorter than block_size+1
                # produce no samples.
                for i in range(len(tokens) - block_size):
                    x = tokens[i:i + block_size]
                    y = tokens[i + 1:i + block_size + 1]
                    self.samples.append((x, y))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        x, y = self.samples[idx]
        return torch.tensor(x), torch.tensor(y)