import json
import os
import re
from collections import Counter

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

class ChatDataset(Dataset):
    """Sliding-window language-modeling dataset built from a JSONL chat corpus."""

    def __init__(self, file_path, tokenizer, block_size=16):
        self.samples = []
        with open(file_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:
                    data = json.loads(line)
                    tokens = tokenizer.encode(data["text"]) + [tokenizer.stoi["<END>"]]
                    # Slide a block_size window over the token stream; the target y
                    # is the input x shifted one position to the right.
                    for i in range(0, len(tokens) - block_size):
                        x = tokens[i:i + block_size]
                        y = tokens[i + 1:i + block_size + 1]
                        self.samples.append((x, y))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        x, y = self.samples[idx]
        return torch.tensor(x), torch.tensor(y)
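
# Each line of the training file is expected to be a JSON object with a "text"
# field. Illustrative example only (the exact prompt format is an assumption
# based on the special tokens registered by the tokenizers below):
#   {"text": "^user: hello minigpt: hi there!"}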

class MiniBPETokenizr:
    """Minimal BPE-style tokenizer: characters are merged into sub-word units."""

    def __init__(self):
        self.stoi = {}
        self.itos = {}
        self.vocab_size = 0

    def __len__(self):
        return len(self.stoi)

    def tokenize(self, text):
        # Lowercase, then split into alphanumeric words and single punctuation marks.
        text = text.lower().strip()
        words = re.findall(r"[a-zA-Z0-9]+|[^\w\s]", text)
        # Words become character lists with a '</w>' end-of-word marker;
        # punctuation stays as a single symbol.
        return [list(w) + ['</w>'] if w.isalnum() else [w] for w in words]

    def get_stats(self, corpus):
        pairs = Counter()
        for tokens in corpus:
            for i in range(len(tokens) - 1):
                pairs[(tokens[i], tokens[i + 1])] += 1
        return pairs

    def merge_vocab(self, corpus, pair_to_merge):
        merged = []
        bigram = re.escape(' '.join(pair_to_merge))
        pattern = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')

        for tokens in corpus:
            token_str = ' '.join(tokens)
            token_str = pattern.sub(''.join(pair_to_merge), token_str)
            merged.append(token_str.split())
        return merged

    def train(self, texts, merge_limit=1000):
        # Flatten each text into one symbol sequence, then repeatedly merge the
        # most frequent adjacent pair until merge_limit merges (or no pairs remain).
        corpus = [sum(self.tokenize(t), []) for t in texts]
        merges_done = 0
        loop = tqdm(total=merge_limit, desc="Training BPE")

        while merges_done < merge_limit:
            pairs = self.get_stats(corpus)
            if not pairs:
                tqdm.write("⚠️ No more pairs to merge.")
                break
            best = max(pairs, key=pairs.get)
            corpus = self.merge_vocab(corpus, best)
            merges_done += 1
            loop.n = merges_done
            loop.refresh()
        loop.close()

        # Build the vocabulary from the merged corpus plus the special tokens.
        vocab = set(tok for seq in corpus for tok in seq)
        vocab.update({"<PAD>", "<UNK>", "<END>", "^user:", "minigpt:"})
        self.stoi = {tok: i for i, tok in enumerate(sorted(vocab))}
        self.itos = {i: tok for tok, i in self.stoi.items()}
        print(f"stoi: {len(self.stoi)}")
        print(f"itos: {len(self.itos)}")
        self.vocab_size = len(self.stoi)

    def encode(self, text):
        # Greedy longest-match: at each position, take the longest run of symbols
        # whose concatenation is a known vocabulary entry.
        tokens = sum(self.tokenize(text), [])
        output = []
        i = 0
        while i < len(tokens):
            j = len(tokens)
            while j > i:
                candidate = ''.join(tokens[i:j])
                if candidate in self.stoi:
                    output.append(self.stoi[candidate])
                    i = j
                    break
                j -= 1
            else:
                # No vocabulary entry matched this symbol; fall back to <UNK>.
                output.append(self.stoi.get("<UNK>", 1))
                i += 1
        return output

    def decode(self, token_ids):
        tokens = [self.itos.get(i, "<UNK>") for i in token_ids]
        # Drop special tokens, strip the word-end marker, and re-attach punctuation.
        text = ' '.join(t.replace('</w>', '') for t in tokens if t not in {"<PAD>", "<END>", "<UNK>"})
        text = re.sub(r'\s([?.!,:;])', r'\1', text)
        return text.strip()

    def save(self, path):
        with open(path, "w", encoding="utf-8") as f:
            json.dump({"stoi": self.stoi, "itos": self.itos}, f)

    def load(self, path):
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        self.stoi = {k: int(v) for k, v in data["stoi"].items()}
        self.itos = {int(v): k for k, v in self.stoi.items()}
        self.vocab_size = len(self.stoi)
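
# Example round-trip (a minimal sketch; the sample sentences are illustrative only):
#   bpe = MiniBPETokenizr()
#   bpe.train(["hello world!", "hello there"], merge_limit=50)
#   ids = bpe.encode("hello world!")
#   print(bpe.decode(ids))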

class SimpleTokenizr:
    """Word-level tokenizer: lowercase words, digit runs, and punctuation marks."""

    def __init__(self):
        self.stoi = {}
        self.itos = {}

    def tokenize(self, text):
        return re.findall(r"[a-zA-Z']+|\d+|[^\w\s]", text.lower())

    def train(self, texts):
        vocab = set()
        for text in texts:
            tokens = self.tokenize(text)
            vocab.update(tokens)

        vocab.update(["<PAD>", "<UNK>", "<END>", "^user :", "minigpt :", "Minigpt :", "MiniGPT :", ":", "Minigpt"])
        sorted_vocab = sorted(vocab)
        self.stoi = {token: idx for idx, token in enumerate(sorted_vocab)}
        self.itos = {idx: token for token, idx in self.stoi.items()}

    def encode(self, text):
        # Unknown words map to <UNK>; an <END> token is always appended.
        tokens = self.tokenize(text)
        return [self.stoi.get(tok, self.stoi["<UNK>"]) for tok in tokens] + [self.stoi["<END>"]]

    def decode(self, token_ids):
        tokens = [self.itos.get(i, "<UNK>") for i in token_ids]
        # Drop special and role tokens before joining back into text.
        clean_tokens = [tok for tok in tokens if tok not in {"<PAD>", "<UNK>", "<END>", "^user :", "minigpt :", "Minigpt :", "MiniGPT :", ":"}]

        # Join with spaces, but attach punctuation directly to the preceding word.
        text = ''
        for i, tok in enumerate(clean_tokens):
            if re.match(r"[.,!?;:]", tok):
                text += tok
            elif i > 0:
                text += ' ' + tok
            else:
                text += tok
        return text.strip().capitalize()

    def save(self, path):
        with open(path, "w", encoding="utf-8") as f:
            json.dump({"stoi": self.stoi, "itos": self.itos}, f)

    def load(self, path):
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        self.stoi = {k: int(v) for k, v in data["stoi"].items()}
        self.itos = {int(idx): tok for tok, idx in self.stoi.items()}

    def __len__(self):
        return len(self.stoi)

    @property
    def vocab_size(self):
        return len(self.stoi)
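
# Example (illustrative only). Note that encode() already appends <END> here,
# unlike MiniBPETokenizr.encode, so ChatDataset adds a second <END> when used
# with this tokenizer:
#   tok = SimpleTokenizr()
#   tok.train(["hello world!"])
#   print(tok.decode(tok.encode("hello world!")))  # -> "Hello world!"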


def train(model, dataset, tokenizer, epochs, filepath, start_epoch=0, start_step=0):
    """Train `model` on `dataset`, checkpointing model, optimizer, and tokenizer.

    `filepath` is accepted for compatibility with existing callers but is not used here.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.001)

    save_dir = "./customchatbot-v1/trained-mini-gpt"
    os.makedirs(save_dir, exist_ok=True)
    checkpoint_path = os.path.join(save_dir, "checkpoint-mini-gpt.pth")

    # Resume from a checkpoint if one exists; otherwise start fresh.
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        if "model_state_dict" in checkpoint:
            model.load_state_dict(checkpoint["model_state_dict"])
            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            start_epoch = checkpoint["epoch"]
            start_step = checkpoint["step"]
        else:
            print("⚠️ Legacy checkpoint detected. Loading only model weights.")
            model.load_state_dict(checkpoint)
    else:
        print("🚀 Starting from scratch.")

    total_steps = start_step
    steps_since_save = 0

    for epoch in range(start_epoch, epochs):
        total_loss = 0
        loop = tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Epoch {epoch + 1}/{epochs} Training")
        for step, (x, y) in loop:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            # Flatten (batch, block, vocab) logits and (batch, block) targets for cross-entropy.
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_steps += 1
            steps_since_save += 1

            # Checkpoint every 4 optimizer steps.
            if steps_since_save >= 4:
                torch.save({
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "epoch": epoch,
                    "step": total_steps
                }, checkpoint_path)
                tokenizer.save(os.path.join(save_dir, "tokenizer.json"))
                tqdm.write("💾 Saved checkpoint.")
                steps_since_save = 0

            loop.set_postfix(loss=loss.item())

    print(f"✅ Final epoch average loss: {total_loss / len(dataloader):.4f}")
    torch.save(model.state_dict(), os.path.join(save_dir, "mini-gpt.pth"))
    tokenizer.save(os.path.join(save_dir, "tokenizer.json"))
    print("🎉 Training complete.")