# MiniGPT / dataset.py
import json
import os
import re
from collections import Counter

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import OneCycleLR  # only used by the commented-out scheduler in train()
from tqdm import tqdm
class ChatDataset(Dataset):
def __init__(self, file_path, tokenizer, block_size=16):
self.samples = []
with open(file_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line:
data = json.loads(line)
tokens = tokenizer.encode(data["text"]) + [tokenizer.stoi["<END>"]]
for i in range(0, len(tokens) - block_size):
x = tokens[i:i + block_size]
y = tokens[i + 1:i + block_size + 1]
self.samples.append((x, y))
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
x, y = self.samples[idx]
return torch.tensor(x), torch.tensor(y)
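# ChatDataset turns each JSONL line's "text" field into next-token training pairs
# with a sliding window: x = tokens[i:i+block_size], y = tokens[i+1:i+block_size+1],
# so y is x shifted one position to the right. Texts with block_size or fewer tokens
# produce no samples. Note that __init__ appends <END> after tokenizer.encode();
# SimpleTokenizr.encode() also appends <END>, so that tokenizer yields two trailing
# <END> ids, while MiniBPETokenizr.encode() does not append one.
# Minimal usage sketch (file names are illustrative):
# tokenizer = SimpleTokenizr()
# tokenizer.load("tokenizer.json")
# dataset = ChatDataset("chat_data.jsonl", tokenizer, block_size=16)
# x, y = dataset[0]  # two LongTensors of length block_size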
class MiniBPETokenizr:
def __init__(self):
self.stoi = {} # string to index
self.itos = {} # index to string
self.vocab_size = 0
def __len__(self):
return len(self.stoi)
def tokenize(self, text):
text = text.lower().strip()
words = re.findall(r"[a-zA-Z0-9]+|[^\w\s]", text)
return [list(w) + ['</w>'] if w.isalnum() else [w] for w in words]
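# Example: tokenize("Hello, world!") returns
# [['h', 'e', 'l', 'l', 'o', '</w>'], [','], ['w', 'o', 'r', 'l', 'd', '</w>'], ['!']]
# Alphanumeric words become character lists ending in the '</w>' end-of-word marker;
# punctuation stays as a single-element list.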
def get_stats(self, corpus):
pairs = Counter()
for tokens in corpus:
for i in range(len(tokens)-1):
pairs[(tokens[i], tokens[i+1])] += 1
return pairs
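# Example: get_stats([['h', 'e', 'l', 'l', 'o', '</w>']]) counts adjacent pairs:
# {('h', 'e'): 1, ('e', 'l'): 1, ('l', 'l'): 1, ('l', 'o'): 1, ('o', '</w>'): 1}.
# Because train() flattens each text into one sequence, pairs that span word
# boundaries (across '</w>') are counted as well.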
def merge_vocab(self, corpus, pair_to_merge):
merged = []
bigram = re.escape(' '.join(pair_to_merge))
pattern = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
for tokens in corpus:
token_str = ' '.join(tokens)
token_str = pattern.sub(''.join(pair_to_merge), token_str)
merged.append(token_str.split())
return merged
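# Example: merge_vocab([['h', 'e', 'l', 'l', 'o', '</w>']], ('l', 'l')) rewrites the
# sequence as ['h', 'e', 'll', 'o', '</w>']: every whitespace-delimited occurrence
# of the pair is fused into a single token.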
def train(self, texts, merge_limit=1000):
corpus = [sum(self.tokenize(t), []) for t in texts]
merges_done = 0
loop = tqdm(total=merge_limit, desc="Training BPE")
while merges_done < merge_limit:
pairs = self.get_stats(corpus)
if not pairs:
tqdm.write("⚠️ No more pairs to merge.")
break
best = max(pairs, key=pairs.get)
corpus = self.merge_vocab(corpus, best)
merges_done += 1
loop.update(1)
#tqdm.write(f"best: {best}")
#tqdm.write(f"corpus: {corpus}")
loop.close()
vocab = set(tok for seq in corpus for tok in seq)
vocab.update({"<PAD>", "<UNK>", "<END>", "^user:", "minigpt:"})
self.stoi = {tok: i for i, tok in enumerate(sorted(vocab))}
self.itos = {i: tok for tok, i in self.stoi.items()}
print(f"stoi: {len(self.stoi)}")
print(f"itos: {len(self.itos)}")
self.vocab_size = len(self.stoi)
def encode(self, text):
tokens = sum(self.tokenize(text), [])
output = []
i = 0
while i < len(tokens):
j = len(tokens)
while j > i:
candidate = ''.join(tokens[i:j])
if candidate in self.stoi:
output.append(self.stoi[candidate])
i = j
break
j -= 1
else:
output.append(self.stoi.get("<UNK>", 1))
i += 1
return output
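# encode() is greedy longest-match: at each position it tries the longest run of
# characters whose concatenation is in the vocabulary, and falls back to <UNK> and
# advances by one character if nothing matches. For example, if the learned vocab
# happens to contain 'hell' and 'o</w>', encode("hello") emits those two ids
# (the actual merges depend on the training data).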
def decode(self, token_ids):
tokens = [self.itos.get(i, "<UNK>") for i in token_ids]
# Join tokens and remove </w> markers, then fix spacing before punctuation
text = ' '.join(t.replace('</w>', '') for t in tokens if t not in {"<PAD>", "<END>", "<UNK>"})
text = re.sub(r'\s([?.!,:;])', r'\1', text) # Remove space before punctuation
return text.strip()
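# decode() drops <PAD>/<END>/<UNK>, strips the '</w>' markers, joins the remaining
# tokens with spaces, and removes the space before punctuation, e.g.
# ['hi</w>', '!'] -> "hi!". Because tokens are space-joined, a word that is still
# split across several sub-word pieces keeps internal spaces unless it was merged
# into one token during training.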
def save(self, path):
with open(path, "w", encoding="utf-8") as f:
json.dump({"stoi": self.stoi, "itos": self.itos}, f)
def load(self, path):
with open(path, "r", encoding="utf-8") as f:
data = json.load(f)
self.stoi = {k: int(v) for k, v in data["stoi"].items()}
self.itos = {int(v): k for k, v in self.stoi.items()}
self.vocab_size = len(self.stoi)
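# Minimal MiniBPETokenizr usage sketch (training texts, merge_limit, and paths are illustrative):
# bpe = MiniBPETokenizr()
# bpe.train(["hello world!", "hello there, how are you?"], merge_limit=50)
# ids = bpe.encode("hello there!")
# print(bpe.decode(ids))
# bpe.save("bpe_tokenizer.json")
# bpe2 = MiniBPETokenizr()
# bpe2.load("bpe_tokenizer.json")  # rebuilds stoi, itos, and vocab_size from stoi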
class SimpleTokenizr:
def __init__(self):
self.stoi = {}
self.itos = {}
def tokenize(self, text):
# Lowercase and split into words, digits, and punctuation
#return re.findall(r"[a-zA-Z]+|\d+|[^\w\s]", text.lower()) -- somewhat good
return re.findall(r"[a-zA-Z']+|\d+|[^\w\s]", text.lower())
def train(self, texts):
vocab = set()
for text in texts:
tokens = self.tokenize(text)
vocab.update(tokens)
# Add special tokens
vocab.update(["<PAD>", "<UNK>", "<END>","^user :","minigpt :","Minigpt :","MiniGPT :",":","Minigpt"])
sorted_vocab = sorted(vocab)
self.stoi = {token: idx for idx, token in enumerate(sorted_vocab)}
self.itos = {idx: token for token, idx in self.stoi.items()}
def encode(self, text):
tokens = self.tokenize(text)
return [self.stoi.get(tok, self.stoi["<UNK>"]) for tok in tokens] + [self.stoi["<END>"]]
def decode(self, token_ids):
tokens = [self.itos.get(i, "<UNK>") for i in token_ids]
# Filter special/utility tokens
clean_tokens = [tok for tok in tokens if tok not in {"<PAD>", "<UNK>", "<END>","^user :","minigpt :","Minigpt :","MiniGPT :",":"}]
# Join with proper formatting
text = ''
for i, tok in enumerate(clean_tokens):
if re.match(r"[.,!?;:]", tok): # no space before punctuation
text += tok
elif i > 0:
text += ' ' + tok
else:
text += tok
return text.strip().capitalize()
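# Example round trip: decode(encode("Hello, world!")) -> "Hello, world!".
# Punctuation is attached to the preceding word, and capitalize() upper-cases the
# first character (everything else was already lower-cased by tokenize()).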
def save(self, path):
with open(path, "w", encoding="utf-8") as f:
json.dump({"stoi": self.stoi, "itos": self.itos}, f)
def load(self, path):
with open(path, "r", encoding="utf-8") as f:
data = json.load(f)
self.stoi = {k: int(v) for k, v in data["stoi"].items()}
self.itos = {int(k): v for v, k in self.stoi.items()}
def __len__(self):
return len(self.stoi)
@property
def vocab_size(self):
return len(self.stoi)
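# Minimal SimpleTokenizr usage sketch (training texts and path are illustrative):
# tok = SimpleTokenizr()
# tok.train(["hello world!", "how are you?"])
# ids = tok.encode("hello?")   # word-level ids with <END> appended
# print(tok.decode(ids))       # -> "Hello?"
# tok.save("tokenizer.json")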
def train(model, dataset, tokenizer, epochs, filepath, start_epoch=0, start_step=0):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4,weight_decay=0.001)
checkpoint_path = "./customchatbot-v1/trained-mini-gpt/checkpoint-mini-gpt.pth"
if os.path.exists(checkpoint_path):
checkpoint = torch.load(checkpoint_path)
if "model_state_dict" in checkpoint:
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"]
start_step = checkpoint["step"]
else:
print("⚠️ Legacy checkpoint detected. Loading only model weights.")
model.load_state_dict(checkpoint)
else:
print("🚀 Starting from scratch.")
total_steps = start_step
sreq = 0
#scheduler = OneCycleLR(optimizer,max_lr=1e-4,total_steps=epochs * len(dataloader),pct_start=0.1,anneal_strategy="linear")
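# sreq counts optimizer steps since the last checkpoint and is reset after each save.
# If the OneCycleLR scheduler above is re-enabled, it should be stepped once per batch,
# i.e. call scheduler.step() right after optimizer.step() in the loop below.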
for epoch in range(start_epoch, epochs):
total_loss = 0
loop = tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Epoch {epoch+1}/{epochs} Training")
for step, (x, y) in loop:
x, y = x.to(device), y.to(device)
logits = model(x)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
total_steps += 1
sreq += 1
# Save every 4 steps
if sreq >= 4:
tqdm.write("💾 Saving checkpoint.")
torch.save({
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"epoch": epoch,
"step": total_steps
}, checkpoint_path)
tokenizer.save("./customchatbot-v1/trained-mini-gpt/tokenizer.json")
sreq = 0
loop.set_postfix(loss=loss.item())
print(f"✅ Final Loss: {total_loss / len(dataloader):.4f}")
torch.save(model.state_dict(), "./customchatbot-v1/trained-mini-gpt/mini-gpt.pth")
tokenizer.save("./customchatbot-v1/trained-mini-gpt/tokenizer.json")
print("🎉 Training complete.")
# 🔧 Example usage
# tokenizer = SimpleTokenizr()
# tokenizer.load("path/to/tokenizer.json")
# dataset = ChatDataset("your_dataset.jsonl", tokenizer)
# model = YourModelClass(...) # your GPT-like model
# train(model, dataset, tokenizer, epochs=2, filepath="your_dataset.jsonl")
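# To reload the final weights later for inference (model class is hypothetical):
# model = YourModelClass(...)
# state = torch.load("./customchatbot-v1/trained-mini-gpt/mini-gpt.pth", map_location="cpu")
# model.load_state_dict(state)
# model.eval()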