Upload folder using huggingface_hub
Browse files- .gitattributes +2 -0
- __pycache__/dataset.cpython-312.pyc +0 -0
- __pycache__/filter.cpython-312.pyc +0 -0
- __pycache__/model.cpython-312.pyc +0 -0
- colab-scripts/dataset.py +274 -0
- customgen.py +165 -0
- data/filtered_data.jsonl +3 -0
- data/overfit_data.jsonl +10 -0
- data/reason_data.jsonl +0 -0
- data/reasoned2_data.jsonl +0 -0
- data/reasoned_data.jsonl +0 -0
- data/unfiltered_data.jsonl +3 -0
- dataset.py +142 -90
- datasetgen-synthetic.py +75 -0
- datasetgen.py +35 -10
- datasetgen2.py +64 -0
- datasets/5k_synthetic_dataset.jsonl +0 -0
- filter.py +6 -5
- minigpt.py +11 -7
- model.py +18 -5
- train_custom.py +44 -13
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
data/filtered_data.jsonl filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
data/unfiltered_data.jsonl filter=lfs diff=lfs merge=lfs -text
|
__pycache__/dataset.cpython-312.pyc
CHANGED
|
Binary files a/__pycache__/dataset.cpython-312.pyc and b/__pycache__/dataset.cpython-312.pyc differ
|
|
|
__pycache__/filter.cpython-312.pyc
CHANGED
|
Binary files a/__pycache__/filter.cpython-312.pyc and b/__pycache__/filter.cpython-312.pyc differ
|
|
|
__pycache__/model.cpython-312.pyc
CHANGED
|
Binary files a/__pycache__/model.cpython-312.pyc and b/__pycache__/model.cpython-312.pyc differ
|
|
|
colab-scripts/dataset.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from torch.utils.data import Dataset, DataLoader
|
| 6 |
+
from tokenizers import Tokenizer
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
import os
|
| 9 |
+
import re
|
| 10 |
+
from collections import Counter
|
| 11 |
+
import multiprocessing
|
| 12 |
+
from torch.utils.data import random_split
|
| 13 |
+
# Force the "spawn" start method for worker processes.
# NOTE(review): presumably so CUDA works safely inside DataLoader workers — confirm.
multiprocessing.set_start_method("spawn", force=True)
|
| 14 |
+
|
| 15 |
+
class ChatDataset(Dataset):
    """Dataset of fixed-length token windows built from an instruction/output JSONL file.

    Each record is rendered as "^User: ... MiniGPT: ... <END>", tokenized with
    the supplied tokenizer, and sliced into non-overlapping windows of
    ``block_size`` tokens.  ``__getitem__`` yields (input, target) pairs for
    next-token prediction.
    """

    def __init__(self, data, tokenizer, block_size=64):
        self.tokenizer = tokenizer
        self.block_size = block_size
        self.data = self.tokenize_data(data)

    def tokenize_data(self, data):
        """Read the JSONL file at *data* and return a list of token windows."""
        windows = []
        with open(data, "r", encoding="utf-8") as handle:
            for raw in handle:
                record = json.loads(raw.strip())
                # Fix duplicated instruction
                text = "^User: " + record["instruction"].strip() + " MiniGPT: " + record["output"].strip() + " <END>"
                ids = self.tokenizer.encode(text).ids
                # Samples shorter than one full window are dropped entirely.
                if len(ids) < self.block_size:
                    continue
                limit = len(ids) - self.block_size + 1
                # Non-overlapping windows; the range bound guarantees full windows.
                windows.extend(
                    ids[start:start + self.block_size]
                    for start in range(0, limit, self.block_size)
                )
        return windows

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        window = self.data[idx]
        # Inputs are the window minus its last token; targets are the window
        # shifted left by one (teacher forcing for next-token prediction).
        return torch.tensor(window[:-1]), torch.tensor(window[1:])
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class MiniBPETokenizr:
    """Minimal byte-pair-encoding style tokenizer.

    Words are split into characters terminated by an end-of-word marker
    '</w>'; training repeatedly fuses the most frequent adjacent symbol pair
    until ``merge_limit`` merges have been applied or no pairs remain.
    Encoding uses greedy longest-match against the learned vocabulary.
    """

    def __init__(self):
        self.stoi = {}
        self.itos = {}
        self.vocab_size = 0

    def tokenize(self, text):
        """Lowercase *text* and split it into per-word symbol lists."""
        text = text.lower().strip()
        words = re.findall(r"[a-zA-Z0-9]+|[^\w\s]", text)
        pieces = []
        for word in words:
            if word.isalnum():
                pieces.append(list(word) + ['</w>'])
            else:
                pieces.append([word])
        return pieces

    def get_stats(self, corpus):
        """Count adjacent symbol pairs across every sequence in *corpus*."""
        counts = Counter()
        for seq in corpus:
            counts.update(zip(seq, seq[1:]))
        return counts

    def merge_vocab(self, corpus, pair_to_merge):
        """Return *corpus* with every occurrence of *pair_to_merge* fused."""
        fused = ''.join(pair_to_merge)
        bigram = re.escape(' '.join(pair_to_merge))
        pattern = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
        return [pattern.sub(fused, ' '.join(seq)).split() for seq in corpus]

    def train(self, texts, merge_limit=1000):
        """Learn merges from *texts* and build the id mappings."""
        corpus = [sum(self.tokenize(t), []) for t in texts]
        progress = tqdm(total=merge_limit, desc="Training BPE")
        for _ in range(merge_limit):
            stats = self.get_stats(corpus)
            if not stats:
                break
            corpus = self.merge_vocab(corpus, max(stats, key=stats.get))
            progress.update(1)

        symbols = set(tok for seq in corpus for tok in seq)
        # Reserved control tokens are always part of the vocabulary.
        symbols.update(["<PAD>", "<UNK>", "<END>", "^user:", "minigpt:"])
        self.stoi = {tok: i for i, tok in enumerate(sorted(symbols))}
        self.itos = {i: tok for tok, i in self.stoi.items()}
        self.vocab_size = len(self.stoi)

    def encode(self, text):
        """Greedy longest-match encoding of *text* into token ids."""
        symbols = sum(self.tokenize(text), [])
        ids = []
        pos = 0
        while pos < len(symbols):
            end = len(symbols)
            while end > pos:
                piece = ''.join(symbols[pos:end])
                if piece in self.stoi:
                    ids.append(self.stoi[piece])
                    pos = end
                    break
                end -= 1
            else:
                # No prefix matched: emit <UNK> and advance one symbol.
                ids.append(self.stoi.get("<UNK>", 1))
                pos += 1
        return ids

    def decode(self, token_ids):
        """Map ids back to text, dropping control tokens and word markers."""
        pieces = [self.itos.get(i, "<UNK>") for i in token_ids]
        kept = (p.replace('</w>', '') for p in pieces if p not in {"<PAD>", "<END>", "<UNK>"})
        text = ' '.join(kept)
        # Remove the space the join inserts before punctuation.
        text = re.sub(r'\s([?.!,:;])', r'\1', text)
        return text.strip()

    def save(self, path):
        """Serialize the vocabulary mappings to *path* as JSON."""
        with open(path, "w", encoding="utf-8") as fh:
            json.dump({"stoi": self.stoi, "itos": self.itos}, fh)

    def load(self, path):
        """Restore the vocabulary mappings written by :meth:`save`."""
        with open(path, "r", encoding="utf-8") as fh:
            payload = json.load(fh)
        self.stoi = {tok: int(idx) for tok, idx in payload["stoi"].items()}
        self.itos = {idx: tok for tok, idx in self.stoi.items()}
        self.vocab_size = len(self.stoi)
|
| 131 |
+
|
| 132 |
+
class SimpleTokenizr:
    """Word-level tokenizer with a fixed vocabulary learned from training text."""

    def __init__(self):
        self.stoi = {}
        self.itos = {}

    def tokenize(self, text):
        """Split lowercased *text* into words (with apostrophes), numbers, and punctuation."""
        return re.findall(r"[a-zA-Z']+|\d+|[^\w\s]", text.lower())

    def train(self, texts):
        """Build the id mappings from every token seen in *texts*."""
        vocab = set()
        for sample in texts:
            vocab.update(self.tokenize(sample))
        # Reserved/control tokens are always included.
        vocab.update(["<PAD>", "<UNK>", "<END>", "^user :", "minigpt :", "MiniGPT :", ":"])
        ordered = sorted(vocab)
        self.stoi = {tok: i for i, tok in enumerate(ordered)}
        self.itos = {i: tok for tok, i in self.stoi.items()}

    def encode(self, text):
        """Encode *text* to ids, mapping unknowns to <UNK> and appending <END>."""
        unk = self.stoi["<UNK>"]
        ids = [self.stoi.get(tok, unk) for tok in self.tokenize(text)]
        ids.append(self.stoi["<END>"])
        return ids

    def decode(self, token_ids):
        """Render ids as text, gluing punctuation to the preceding word and capitalizing."""
        words = [self.itos.get(i, "<UNK>") for i in token_ids]
        words = [w for w in words if w not in {"<PAD>", "<UNK>", "<END>"}]
        parts = []
        for pos, word in enumerate(words):
            if re.match(r"[.,!?;:]", word):
                parts.append(word)          # punctuation attaches to the previous word
            elif pos > 0:
                parts.append(' ' + word)
            else:
                parts.append(word)
        return ''.join(parts).strip().capitalize()

    def save(self, path):
        """Write the stoi/itos mappings to *path* as JSON."""
        with open(path, "w", encoding="utf-8") as fh:
            json.dump({"stoi": self.stoi, "itos": self.itos}, fh)

    def load(self, path):
        """Load mappings written by :meth:`save`, rebuilding itos from stoi."""
        with open(path, "r", encoding="utf-8") as fh:
            payload = json.load(fh)
        self.stoi = {tok: int(idx) for tok, idx in payload["stoi"].items()}
        self.itos = {idx: tok for tok, idx in self.stoi.items()}

    def __len__(self):
        return len(self.stoi)

    @property
    def vocab_size(self):
        """Number of entries in the vocabulary."""
        return len(self.stoi)
|
| 183 |
+
|
| 184 |
+
def validate(model, dataloader, device):
    """Evaluate *model* on *dataloader* without gradient tracking.

    Args:
        model: module mapping a (B, T) long tensor to (B, T, V) logits.
        dataloader: iterable of (x, y) batch tuples; must support ``len()``.
        device: torch device the batches are moved to.

    Returns:
        (average cross-entropy loss per batch, token-level accuracy in percent).
        An empty dataloader yields (0.0, 0.0) instead of raising
        ZeroDivisionError (fix: previously divided by len(dataloader)/total
        unconditionally).
    """
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
            total_loss += loss.item()

            preds = torch.argmax(logits, dim=-1)
            correct += (preds == y).sum().item()
            total += y.numel()

    if total == 0:
        # Empty validation set: nothing to average over.
        return 0.0, 0.0
    avg_loss = total_loss / len(dataloader)
    accuracy = 100 * correct / total
    return avg_loss, accuracy
|
| 201 |
+
|
| 202 |
+
def train(model, dataset, tokenizer, epochs, filepathh, start_epoch=0, start_step=0):
    """Train *model* on *dataset* with a 90/10 train/validation split.

    Resumes from ./trained-mini-gpt/checkpoint-mini-gpt.pth when present
    (accepting both full checkpoint dicts and bare state_dicts), validates
    after every epoch, saves a resumable checkpoint per epoch, and writes a
    final weights-only file on completion.

    NOTE(review): *tokenizer* and *filepathh* are currently unused; kept for
    interface compatibility with existing callers.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # 🔀 Proper train/val split
    val_size = int(0.1 * len(dataset))
    train_size = len(dataset) - val_size
    train_set, val_set = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_set, batch_size=10, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_set, batch_size=10, shuffle=False, num_workers=2)

    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

    # Fix: create the checkpoint directory up front so torch.save cannot fail.
    checkpoint_dir = "./trained-mini-gpt"
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, "checkpoint-mini-gpt.pth")
    if os.path.exists(checkpoint_path):
        # map_location lets a GPU-saved checkpoint load on a CPU-only host.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        if "model_state_dict" in checkpoint:
            model.load_state_dict(checkpoint["model_state_dict"])
            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            start_epoch = checkpoint["epoch"]
            start_step = checkpoint["step"]
        else:
            # Legacy checkpoint: bare state_dict with no optimizer/progress info.
            model.load_state_dict(checkpoint)
    else:
        print("🚀 Starting from scratch.")

    total_steps = start_step

    for epoch in range(start_epoch, epochs):
        model.train()
        total_loss, correct, total = 0, 0, 0

        loop = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{epochs}")
        for step, (x, y) in loop:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Fix: total_steps was never incremented, so the "step" recorded
            # in checkpoints was always the stale start value.
            total_steps += 1

            total_loss += loss.item()
            preds = torch.argmax(logits, dim=-1)
            correct += (preds == y).sum().item()
            total += y.numel()
            acc = 100 * correct / total

            loop.set_postfix(loss=loss.item(), acc=acc)

        # 🔍 Validate after each epoch
        val_loss, val_acc = validate(model, val_loader, device)
        print(f"✅ Val Loss: {val_loss:.4f} | Val Accuracy: {val_acc:.2f}%")

        # 💾 Save checkpoint
        torch.save({
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "epoch": epoch,
            "step": total_steps
        }, checkpoint_path)

    torch.save(model.state_dict(), os.path.join(checkpoint_dir, "mini-gpt.pth"))
    print("🎉 Training complete.")
|
customgen.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import random
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
from transformers import AutoTokenizer
|
| 5 |
+
|
| 6 |
+
# CONFIG
# GPT-2 BPE tokenizer, used only to measure how many tokens a sample takes.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Reject any generated sample whose combined Q/A text exceeds this token budget.
MAX_TOKENS = 27
# Total number of Q&A records to generate.
NUM_SAMPLES = 50000
# Output file, one JSON record per line (NOTE(review): directory must exist).
SAVE_PATH = "./customgens/mini_qna_dataset.jsonl"
|
| 11 |
+
|
| 12 |
+
# Extended Templates with Paraphrasing
# Each entry is a (question_template, answer_template) pair; the {placeholder}
# slots are filled from a topic's value lists by fill_template().
TEMPLATES = [
    # WHY
    ("Why do {subject} {action}?", "Because {reason}."),
    ("What makes {subject} {action}?", "It's because {reason}."),
    ("Explain why {subject} {action}.", "{reason} is the reason."),

    # WHAT IS
    ("What is {thing}?", "{thing} is {definition}."),
    ("Define {thing}.", "{thing} refers to {definition}."),
    ("Can you tell me what {thing} means?", "Sure! It's {definition}."),

    # HOW
    ("How does {thing} work?", "It works by {mechanism}."),
    ("What's the mechanism behind {thing}?", "It involves {mechanism}."),
    ("Explain how {thing} functions.", "{mechanism} is how it works."),

    # WHEN / CONDITION
    ("What happens when {condition}?", "{result}."),
    ("Describe what occurs if {condition}.", "Usually, {result}."),
    ("When {condition}, what takes place?", "The result is {result}."),

    # IMPORTANCE
    ("Why is {thing} important?", "Because {importance}."),
    ("What makes {thing} important?", "{importance} is why."),
    ("Is {thing} important? Why?", "Yes, because {importance}."),
]

# Knowledge Bank
# Per-topic value lists consumed by the templates above; topics may omit keys,
# in which case fill_template() falls back to generic defaults.
DATA = {
    "animals": {
        "subjects": ["cats", "dogs", "birds", "fish"],
        "actions": ["sleep a lot", "bark", "fly", "swim"],
        "reasons": [
            "they conserve energy",
            "they are nocturnal",
            "it's in their nature",
            "they communicate that way"
        ]
    },
    "science": {
        "things": ["gravity", "photosynthesis", "a star", "an atom"],
        "definitions": [
            "a force that pulls objects together",
            "the process plants use to make food",
            "a burning ball of gas",
            "the smallest unit of matter"
        ],
        "mechanisms": [
            "converting sunlight into energy",
            "attracting objects with mass",
            "splitting light into colors",
            "colliding particles"
        ],
        "conditions": ["you heat ice", "a star dies"],
        "results": ["it melts", "it becomes a black hole"],
        "importance": [
            "it keeps us on Earth",
            "it enables life on Earth"
        ]
    },
    "food": {
        "things": ["a waffle", "chocolate", "rice", "milk"],
        "definitions": [
            "a sweet, crispy batter cake",
            "a sweet made from cocoa",
            "a grain eaten daily in Asia",
            "a white liquid from cows"
        ],
        "importance": [
            "it provides energy",
            "it’s part of daily nutrition"
        ]
    }
}

# Quota bookkeeping: each topic is capped at an equal share of NUM_SAMPLES.
TOPIC_COUNT = {k: 0 for k in DATA}
MAX_PER_TOPIC = NUM_SAMPLES // len(DATA)
|
| 90 |
+
|
| 91 |
+
def sample_topic():
    """Pick a random topic that has not yet reached its per-topic quota.

    Returns None once every topic's quota (MAX_PER_TOPIC) is exhausted.
    """
    available = [topic for topic in DATA if TOPIC_COUNT[topic] < MAX_PER_TOPIC]
    if not available:
        return None
    return random.choice(available)
|
| 94 |
+
|
| 95 |
+
def fill_template(template_pair, topic_data):
    """Instantiate a (question, answer) template pair with random topic values.

    Every known placeholder is resolved up front — using a generic fallback
    when *topic_data* lacks the matching list — and then substituted into both
    template strings.  Returns the stripped (question, answer) pair.
    """
    question_tpl, answer_tpl = template_pair

    def pick(key, fallback):
        # Draw one value for *key*, falling back when the topic omits it.
        return random.choice(topic_data.get(key, fallback))

    replacements = {
        "{subject}": pick("subjects", topic_data.get("things", ["something"])),
        "{action}": pick("actions", ["do things"]),
        "{reason}": pick("reasons", ["that’s how they survive"]),
        "{thing}": pick("things", ["a thing"]),
        "{definition}": pick("definitions", ["an object used every day"]),
        "{mechanism}": pick("mechanisms", ["processing energy"]),
        "{condition}": pick("conditions", ["a change occurs"]),
        "{result}": pick("results", ["it transforms"]),
        "{importance}": pick("importance", ["it is vital to survival"])
    }

    question, answer = question_tpl, answer_tpl
    for placeholder, value in replacements.items():
        question = question.replace(placeholder, value)
        answer = answer.replace(placeholder, value)
    return question.strip(), answer.strip()
|
| 115 |
+
|
| 116 |
+
def maybe_add_noise(q, a):
    """Occasionally perturb a Q/A pair to add variety.

    A single uniform draw decides the outcome: ~5% of pairs get an
    "I'm not sure." answer, the next ~5% get a chatty suffix/prefix, and the
    remaining ~90% pass through unchanged.
    """
    roll = random.random()
    if roll < 0.05:
        return q, "I'm not sure."
    if roll < 0.10:
        return q + " Just wondering.", "Well, " + a
    return q, a
|
| 124 |
+
|
| 125 |
+
def token_count(text):
    """Number of GPT-2 tokens in *text*, per the module-level tokenizer."""
    encoded = tokenizer.encode(text)
    return len(encoded)
|
| 127 |
+
|
| 128 |
+
def main():
    """Generate up to NUM_SAMPLES templated Q&A records and write them as JSONL.

    Topics are drawn under a per-topic quota (sample_topic); each record is a
    filled template, optionally noised, and kept only when its combined text
    fits within MAX_TOKENS.  A progress sample is printed every 5000 accepted
    records.
    """
    # Fix: SAVE_PATH points into ./customgens/, which may not exist yet —
    # previously open() would raise FileNotFoundError on a fresh checkout.
    import os
    os.makedirs(os.path.dirname(SAVE_PATH) or ".", exist_ok=True)

    with open(SAVE_PATH, "w", encoding="utf-8") as f:
        total = 0
        pbar = tqdm(total=NUM_SAMPLES)

        while total < NUM_SAMPLES:
            topic = sample_topic()
            if not topic:
                # Every topic quota exhausted before reaching NUM_SAMPLES.
                break
            template = random.choice(TEMPLATES)
            topic_data = DATA[topic]

            question, answer = fill_template(template, topic_data)
            question, answer = maybe_add_noise(question, answer)

            combined = f"Q: {question} A: {answer}"
            if token_count(combined) <= MAX_TOKENS:
                record = {
                    "question": question,
                    "answer": answer,
                    "text": combined
                }
                f.write(json.dumps(record, ensure_ascii=False) + "\n")
                total += 1
                TOPIC_COUNT[topic] += 1
                pbar.update(1)

                if total % 5000 == 0:
                    print(f"\n[Sample {total}]")
                    print("Q:", question)
                    print("A:", answer)
                    print("Tokens:", token_count(combined))

        pbar.close()
        print(f"\n✅ Saved {total} samples to {SAVE_PATH}")

if __name__ == "__main__":
    main()
|
data/filtered_data.jsonl
CHANGED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:32618c1059ded33dbd43d56601000555b92ddd39e256aa1347c283720060ab59
|
| 3 |
+
size 12276354
|
data/overfit_data.jsonl
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"instruction": "Describe the structure of an atom.", "input": "", "output": "An atom is made up of a nucleus, which contains protons and neutrons, surrounded by electrons that travel in orbits around the nucleus. The protons and neutrons have a positive charge, while the electrons have a negative charge, resulting in an overall neutral atom. The number of each particle determines the atomic number and the type of atom.", "text": "Describe the structure of an atom.\nAn atom is made up of a nucleus, which contains protons and neutrons, surrounded by electrons that travel in orbits around the nucleus. The protons and neutrons have a positive charge, while the electrons have a negative charge, resulting in an overall neutral atom. The number of each particle determines the atomic number and the type of atom."}
|
| 2 |
+
{"instruction": "Describe the structure of an atom.", "input": "", "output": "An atom is made up of a nucleus, which contains protons and neutrons, surrounded by electrons that travel in orbits around the nucleus. The protons and neutrons have a positive charge, while the electrons have a negative charge, resulting in an overall neutral atom. The number of each particle determines the atomic number and the type of atom.", "text": "Describe the structure of an atom.\nAn atom is made up of a nucleus, which contains protons and neutrons, surrounded by electrons that travel in orbits around the nucleus. The protons and neutrons have a positive charge, while the electrons have a negative charge, resulting in an overall neutral atom. The number of each particle determines the atomic number and the type of atom."}
|
| 3 |
+
{"instruction": "Describe the structure of an atom.", "input": "", "output": "An atom is made up of a nucleus, which contains protons and neutrons, surrounded by electrons that travel in orbits around the nucleus. The protons and neutrons have a positive charge, while the electrons have a negative charge, resulting in an overall neutral atom. The number of each particle determines the atomic number and the type of atom.", "text": "Describe the structure of an atom.\nAn atom is made up of a nucleus, which contains protons and neutrons, surrounded by electrons that travel in orbits around the nucleus. The protons and neutrons have a positive charge, while the electrons have a negative charge, resulting in an overall neutral atom. The number of each particle determines the atomic number and the type of atom."}
|
| 4 |
+
{"instruction": "Describe the structure of an atom.", "input": "", "output": "An atom is made up of a nucleus, which contains protons and neutrons, surrounded by electrons that travel in orbits around the nucleus. The protons and neutrons have a positive charge, while the electrons have a negative charge, resulting in an overall neutral atom. The number of each particle determines the atomic number and the type of atom.", "text": "Describe the structure of an atom.\nAn atom is made up of a nucleus, which contains protons and neutrons, surrounded by electrons that travel in orbits around the nucleus. The protons and neutrons have a positive charge, while the electrons have a negative charge, resulting in an overall neutral atom. The number of each particle determines the atomic number and the type of atom."}
|
| 5 |
+
{"instruction": "Describe the structure of an atom.", "input": "", "output": "An atom is made up of a nucleus, which contains protons and neutrons, surrounded by electrons that travel in orbits around the nucleus. The protons and neutrons have a positive charge, while the electrons have a negative charge, resulting in an overall neutral atom. The number of each particle determines the atomic number and the type of atom.", "text": "Describe the structure of an atom.\nAn atom is made up of a nucleus, which contains protons and neutrons, surrounded by electrons that travel in orbits around the nucleus. The protons and neutrons have a positive charge, while the electrons have a negative charge, resulting in an overall neutral atom. The number of each particle determines the atomic number and the type of atom."}
|
| 6 |
+
{"instruction": "Describe the structure of an atom.", "input": "", "output": "An atom is made up of a nucleus, which contains protons and neutrons, surrounded by electrons that travel in orbits around the nucleus. The protons and neutrons have a positive charge, while the electrons have a negative charge, resulting in an overall neutral atom. The number of each particle determines the atomic number and the type of atom.", "text": "Describe the structure of an atom.\nAn atom is made up of a nucleus, which contains protons and neutrons, surrounded by electrons that travel in orbits around the nucleus. The protons and neutrons have a positive charge, while the electrons have a negative charge, resulting in an overall neutral atom. The number of each particle determines the atomic number and the type of atom."}
|
| 7 |
+
{"instruction": "Describe the structure of an atom.", "input": "", "output": "An atom is made up of a nucleus, which contains protons and neutrons, surrounded by electrons that travel in orbits around the nucleus. The protons and neutrons have a positive charge, while the electrons have a negative charge, resulting in an overall neutral atom. The number of each particle determines the atomic number and the type of atom.", "text": "Describe the structure of an atom.\nAn atom is made up of a nucleus, which contains protons and neutrons, surrounded by electrons that travel in orbits around the nucleus. The protons and neutrons have a positive charge, while the electrons have a negative charge, resulting in an overall neutral atom. The number of each particle determines the atomic number and the type of atom."}
|
| 8 |
+
{"instruction": "Describe the structure of an atom.", "input": "", "output": "An atom is made up of a nucleus, which contains protons and neutrons, surrounded by electrons that travel in orbits around the nucleus. The protons and neutrons have a positive charge, while the electrons have a negative charge, resulting in an overall neutral atom. The number of each particle determines the atomic number and the type of atom.", "text": "Describe the structure of an atom.\nAn atom is made up of a nucleus, which contains protons and neutrons, surrounded by electrons that travel in orbits around the nucleus. The protons and neutrons have a positive charge, while the electrons have a negative charge, resulting in an overall neutral atom. The number of each particle determines the atomic number and the type of atom."}
|
| 9 |
+
{"instruction": "Describe the structure of an atom.", "input": "", "output": "An atom is made up of a nucleus, which contains protons and neutrons, surrounded by electrons that travel in orbits around the nucleus. The protons and neutrons have a positive charge, while the electrons have a negative charge, resulting in an overall neutral atom. The number of each particle determines the atomic number and the type of atom.", "text": "Describe the structure of an atom.\nAn atom is made up of a nucleus, which contains protons and neutrons, surrounded by electrons that travel in orbits around the nucleus. The protons and neutrons have a positive charge, while the electrons have a negative charge, resulting in an overall neutral atom. The number of each particle determines the atomic number and the type of atom."}
|
| 10 |
+
{"instruction": "Describe the structure of an atom.", "input": "", "output": "An atom is made up of a nucleus, which contains protons and neutrons, surrounded by electrons that travel in orbits around the nucleus. The protons and neutrons have a positive charge, while the electrons have a negative charge, resulting in an overall neutral atom. The number of each particle determines the atomic number and the type of atom.", "text": "Describe the structure of an atom.\nAn atom is made up of a nucleus, which contains protons and neutrons, surrounded by electrons that travel in orbits around the nucleus. The protons and neutrons have a positive charge, while the electrons have a negative charge, resulting in an overall neutral atom. The number of each particle determines the atomic number and the type of atom."}
|
data/reason_data.jsonl
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/reasoned2_data.jsonl
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/reasoned_data.jsonl
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/unfiltered_data.jsonl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0cb453e936376c150b34bc9424acb98744c669e9a8f9373aefab26c03a50b691
|
| 3 |
+
size 40249782
|
dataset.py
CHANGED
|
@@ -1,46 +1,67 @@
|
|
| 1 |
-
from concurrent.futures import thread
|
| 2 |
import json
|
| 3 |
-
import threading
|
| 4 |
import torch
|
|
|
|
| 5 |
import torch.nn.functional as F
|
| 6 |
from torch.utils.data import Dataset, DataLoader
|
| 7 |
-
from
|
| 8 |
from tqdm import tqdm
|
| 9 |
-
import re
|
| 10 |
-
import time
|
| 11 |
import os
|
|
|
|
| 12 |
from collections import Counter
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
class ChatDataset(Dataset):
|
| 15 |
-
def __init__(self,
|
| 16 |
-
self.
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
def __len__(self):
|
| 29 |
-
return len(self.
|
| 30 |
|
| 31 |
def __getitem__(self, idx):
|
| 32 |
-
|
| 33 |
-
|
|
|
|
|
|
|
| 34 |
|
|
|
|
| 35 |
class MiniBPETokenizr:
|
| 36 |
def __init__(self):
|
| 37 |
-
self.stoi = {}
|
| 38 |
-
self.itos = {}
|
| 39 |
self.vocab_size = 0
|
| 40 |
|
| 41 |
-
def __len__(self):
|
| 42 |
-
return len(self.stoi)
|
| 43 |
-
|
| 44 |
def tokenize(self, text):
|
| 45 |
text = text.lower().strip()
|
| 46 |
words = re.findall(r"[a-zA-Z0-9]+|[^\w\s]", text)
|
|
@@ -49,21 +70,20 @@ class MiniBPETokenizr:
|
|
| 49 |
def get_stats(self, corpus):
|
| 50 |
pairs = Counter()
|
| 51 |
for tokens in corpus:
|
| 52 |
-
for i in range(len(tokens)-1):
|
| 53 |
-
pairs[(tokens[i], tokens[i+1])] += 1
|
| 54 |
return pairs
|
| 55 |
|
| 56 |
def merge_vocab(self, corpus, pair_to_merge):
|
| 57 |
-
merged = []
|
| 58 |
bigram = re.escape(' '.join(pair_to_merge))
|
| 59 |
pattern = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
|
| 60 |
-
|
| 61 |
for tokens in corpus:
|
| 62 |
token_str = ' '.join(tokens)
|
| 63 |
token_str = pattern.sub(''.join(pair_to_merge), token_str)
|
| 64 |
merged.append(token_str.split())
|
| 65 |
return merged
|
| 66 |
-
|
| 67 |
def train(self, texts, merge_limit=1000):
|
| 68 |
corpus = [sum(self.tokenize(t), []) for t in texts]
|
| 69 |
merges_done = 0
|
|
@@ -72,22 +92,16 @@ class MiniBPETokenizr:
|
|
| 72 |
while merges_done < merge_limit:
|
| 73 |
pairs = self.get_stats(corpus)
|
| 74 |
if not pairs:
|
| 75 |
-
tqdm.write("⚠️ No more pairs to merge.")
|
| 76 |
break
|
| 77 |
best = max(pairs, key=pairs.get)
|
| 78 |
corpus = self.merge_vocab(corpus, best)
|
| 79 |
merges_done += 1
|
| 80 |
-
loop.
|
| 81 |
-
loop.refresh()
|
| 82 |
-
#tqdm.write(f"best: {best}")
|
| 83 |
-
#tqdm.write(f"corpus: {corpus}")
|
| 84 |
|
| 85 |
vocab = set(tok for seq in corpus for tok in seq)
|
| 86 |
-
vocab.update(
|
| 87 |
self.stoi = {tok: i for i, tok in enumerate(sorted(vocab))}
|
| 88 |
self.itos = {i: tok for tok, i in self.stoi.items()}
|
| 89 |
-
print(f"stoi: {len(self.stoi)}")
|
| 90 |
-
print(f"itos: {len(self.itos)}")
|
| 91 |
self.vocab_size = len(self.stoi)
|
| 92 |
|
| 93 |
def encode(self, text):
|
|
@@ -107,12 +121,11 @@ class MiniBPETokenizr:
|
|
| 107 |
output.append(self.stoi.get("<UNK>", 1))
|
| 108 |
i += 1
|
| 109 |
return output
|
| 110 |
-
|
| 111 |
def decode(self, token_ids):
|
| 112 |
tokens = [self.itos.get(i, "<UNK>") for i in token_ids]
|
| 113 |
-
# Join tokens and remove </w> markers, then fix spacing before punctuation
|
| 114 |
text = ' '.join(t.replace('</w>', '') for t in tokens if t not in {"<PAD>", "<END>", "<UNK>"})
|
| 115 |
-
text = re.sub(r'\s([?.!,:;])', r'\1', text)
|
| 116 |
return text.strip()
|
| 117 |
|
| 118 |
def save(self, path):
|
|
@@ -125,24 +138,21 @@ class MiniBPETokenizr:
|
|
| 125 |
self.stoi = {k: int(v) for k, v in data["stoi"].items()}
|
| 126 |
self.itos = {int(v): k for k, v in self.stoi.items()}
|
| 127 |
self.vocab_size = len(self.stoi)
|
| 128 |
-
|
| 129 |
class SimpleTokenizr:
|
| 130 |
def __init__(self):
|
| 131 |
self.stoi = {}
|
| 132 |
self.itos = {}
|
| 133 |
|
| 134 |
def tokenize(self, text):
|
| 135 |
-
|
| 136 |
-
#return re.findall(r"[a-zA-Z]+|\d+|[^\w\s]", text.lower()) -- somewhat good
|
| 137 |
-
return re.findall(r"[a-zA-Z']+|\d+|[^\w\s]",text.lower())
|
| 138 |
|
| 139 |
def train(self, texts):
|
| 140 |
vocab = set()
|
| 141 |
for text in texts:
|
| 142 |
tokens = self.tokenize(text)
|
| 143 |
vocab.update(tokens)
|
| 144 |
-
|
| 145 |
-
vocab.update(["<PAD>", "<UNK>", "<END>","^user :","minigpt :","Minigpt :","MiniGPT :",":","Minigpt"])
|
| 146 |
sorted_vocab = sorted(vocab)
|
| 147 |
self.stoi = {token: idx for idx, token in enumerate(sorted_vocab)}
|
| 148 |
self.itos = {idx: token for token, idx in self.stoi.items()}
|
|
@@ -153,13 +163,10 @@ class SimpleTokenizr:
|
|
| 153 |
|
| 154 |
def decode(self, token_ids):
|
| 155 |
tokens = [self.itos.get(i, "<UNK>") for i in token_ids]
|
| 156 |
-
|
| 157 |
-
clean_tokens = [tok for tok in tokens if tok not in {"<PAD>", "<UNK>", "<END>","^user :","minigpt :","Minigpt :","MiniGPT :",":"}]
|
| 158 |
-
|
| 159 |
-
# Join with proper formatting
|
| 160 |
text = ''
|
| 161 |
for i, tok in enumerate(clean_tokens):
|
| 162 |
-
if re.match(r"[.,!?;:]", tok):
|
| 163 |
text += tok
|
| 164 |
elif i > 0:
|
| 165 |
text += ' ' + tok
|
|
@@ -184,16 +191,44 @@ class SimpleTokenizr:
|
|
| 184 |
def vocab_size(self):
|
| 185 |
return len(self.stoi)
|
| 186 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
|
| 188 |
-
|
|
|
|
| 189 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 190 |
model.to(device)
|
| 191 |
|
| 192 |
-
|
| 193 |
-
|
|
|
|
|
|
|
| 194 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
|
| 196 |
-
|
|
|
|
|
|
|
|
|
|
| 197 |
if os.path.exists(checkpoint_path):
|
| 198 |
checkpoint = torch.load(checkpoint_path)
|
| 199 |
if "model_state_dict" in checkpoint:
|
|
@@ -202,19 +237,34 @@ def train(model, dataset, tokenizer, epochs, filepathh, start_epoch=0, start_ste
|
|
| 202 |
start_epoch = checkpoint["epoch"]
|
| 203 |
start_step = checkpoint["step"]
|
| 204 |
else:
|
| 205 |
-
print("⚠️ Legacy checkpoint detected. Loading only model weights.")
|
| 206 |
model.load_state_dict(checkpoint)
|
| 207 |
else:
|
| 208 |
print("🚀 Starting from scratch.")
|
| 209 |
|
| 210 |
total_steps = start_step
|
| 211 |
-
|
| 212 |
-
#scheduler = OneCycleLR(optimizer,max_lr=1e-4,total_steps=epochs * len(dataloader),pct_start=0.1,anneal_strategy="linear")
|
| 213 |
for epoch in range(start_epoch, epochs):
|
| 214 |
-
|
| 215 |
-
|
|
|
|
|
|
|
| 216 |
for step, (x, y) in loop:
|
| 217 |
x, y = x.to(device), y.to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
logits = model(x)
|
| 219 |
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
|
| 220 |
|
|
@@ -223,32 +273,34 @@ def train(model, dataset, tokenizer, epochs, filepathh, start_epoch=0, start_ste
|
|
| 223 |
optimizer.step()
|
| 224 |
|
| 225 |
total_loss += loss.item()
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import json
|
|
|
|
| 2 |
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
import torch.nn.functional as F
|
| 5 |
from torch.utils.data import Dataset, DataLoader
|
| 6 |
+
from tokenizers import Tokenizer
|
| 7 |
from tqdm import tqdm
|
|
|
|
|
|
|
| 8 |
import os
|
| 9 |
+
import re
|
| 10 |
from collections import Counter
|
| 11 |
+
import multiprocessing
|
| 12 |
+
from torch.utils.data import random_split
|
| 13 |
+
|
| 14 |
+
multiprocessing.set_start_method("spawn", force=True)
|
| 15 |
|
| 16 |
class ChatDataset(Dataset):
|
| 17 |
+
def __init__(self, data, tokenizer, block_size=64):
|
| 18 |
+
self.tokenizer = tokenizer
|
| 19 |
+
self.block_size = block_size
|
| 20 |
+
self.data = self.tokenize_data(data)
|
| 21 |
+
|
| 22 |
+
def tokenize_data(self, data):
|
| 23 |
+
chunks = []
|
| 24 |
+
with open(data, "r", encoding="utf-8") as f:
|
| 25 |
+
for d in f:
|
| 26 |
+
line = json.loads(d.strip())
|
| 27 |
+
# Fix duplicated instruction
|
| 28 |
+
text = "^User: " + line["instruction"].strip() + " MiniGPT: " + line["output"].strip() + " <END>"
|
| 29 |
+
encoding = self.tokenizer.encode(text)
|
| 30 |
+
tokens = encoding.ids
|
| 31 |
+
|
| 32 |
+
# You confirmed your 10 examples are long enough, so no change to this filter.
|
| 33 |
+
# If you were to use shorter data later, you'd need to reconsider this.
|
| 34 |
+
if len(tokens) < self.block_size:
|
| 35 |
+
print(f"Skipping short example (length {len(tokens)} < block_size {self.block_size}): {text[:50]}...")
|
| 36 |
+
continue
|
| 37 |
+
|
| 38 |
+
# 🎯 CHANGE 3: Use overlapping chunks (stride = 1)
|
| 39 |
+
# This drastically increases the effective number of training samples
|
| 40 |
+
# derived from your limited raw data.
|
| 41 |
+
stride = 1 # Change this to 1 for max overlap, or self.block_size // 2 for moderate
|
| 42 |
+
for i in range(0, len(tokens) - self.block_size + 1, stride):
|
| 43 |
+
chunk = tokens[i:i + self.block_size]
|
| 44 |
+
if len(chunk) == self.block_size: # Ensures only full blocks are added
|
| 45 |
+
chunks.append(chunk)
|
| 46 |
+
print(f"Dataset created with {len(chunks)} total training chunks.") # Added print
|
| 47 |
+
return chunks
|
| 48 |
|
| 49 |
def __len__(self):
|
| 50 |
+
return len(self.data)
|
| 51 |
|
| 52 |
def __getitem__(self, idx):
|
| 53 |
+
chunk = self.data[idx]
|
| 54 |
+
x = torch.tensor(chunk[:-1], dtype=torch.long) # Ensure dtype is long
|
| 55 |
+
y = torch.tensor(chunk[1:], dtype=torch.long) # Ensure dtype is long
|
| 56 |
+
return x, y
|
| 57 |
|
| 58 |
+
# MiniBPETokenizr and SimpleTokenizr classes (no changes, but included for completeness)
|
| 59 |
class MiniBPETokenizr:
|
| 60 |
def __init__(self):
|
| 61 |
+
self.stoi = {}
|
| 62 |
+
self.itos = {}
|
| 63 |
self.vocab_size = 0
|
| 64 |
|
|
|
|
|
|
|
|
|
|
| 65 |
def tokenize(self, text):
|
| 66 |
text = text.lower().strip()
|
| 67 |
words = re.findall(r"[a-zA-Z0-9]+|[^\w\s]", text)
|
|
|
|
| 70 |
def get_stats(self, corpus):
|
| 71 |
pairs = Counter()
|
| 72 |
for tokens in corpus:
|
| 73 |
+
for i in range(len(tokens) - 1):
|
| 74 |
+
pairs[(tokens[i], tokens[i + 1])] += 1
|
| 75 |
return pairs
|
| 76 |
|
| 77 |
def merge_vocab(self, corpus, pair_to_merge):
|
|
|
|
| 78 |
bigram = re.escape(' '.join(pair_to_merge))
|
| 79 |
pattern = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
|
| 80 |
+
merged = []
|
| 81 |
for tokens in corpus:
|
| 82 |
token_str = ' '.join(tokens)
|
| 83 |
token_str = pattern.sub(''.join(pair_to_merge), token_str)
|
| 84 |
merged.append(token_str.split())
|
| 85 |
return merged
|
| 86 |
+
|
| 87 |
def train(self, texts, merge_limit=1000):
|
| 88 |
corpus = [sum(self.tokenize(t), []) for t in texts]
|
| 89 |
merges_done = 0
|
|
|
|
| 92 |
while merges_done < merge_limit:
|
| 93 |
pairs = self.get_stats(corpus)
|
| 94 |
if not pairs:
|
|
|
|
| 95 |
break
|
| 96 |
best = max(pairs, key=pairs.get)
|
| 97 |
corpus = self.merge_vocab(corpus, best)
|
| 98 |
merges_done += 1
|
| 99 |
+
loop.update(1)
|
|
|
|
|
|
|
|
|
|
| 100 |
|
| 101 |
vocab = set(tok for seq in corpus for tok in seq)
|
| 102 |
+
vocab.update(["<PAD>", "<UNK>", "<END>", "^user:", "minigpt:"])
|
| 103 |
self.stoi = {tok: i for i, tok in enumerate(sorted(vocab))}
|
| 104 |
self.itos = {i: tok for tok, i in self.stoi.items()}
|
|
|
|
|
|
|
| 105 |
self.vocab_size = len(self.stoi)
|
| 106 |
|
| 107 |
def encode(self, text):
|
|
|
|
| 121 |
output.append(self.stoi.get("<UNK>", 1))
|
| 122 |
i += 1
|
| 123 |
return output
|
| 124 |
+
|
| 125 |
def decode(self, token_ids):
|
| 126 |
tokens = [self.itos.get(i, "<UNK>") for i in token_ids]
|
|
|
|
| 127 |
text = ' '.join(t.replace('</w>', '') for t in tokens if t not in {"<PAD>", "<END>", "<UNK>"})
|
| 128 |
+
text = re.sub(r'\s([?.!,:;])', r'\1', text)
|
| 129 |
return text.strip()
|
| 130 |
|
| 131 |
def save(self, path):
|
|
|
|
| 138 |
self.stoi = {k: int(v) for k, v in data["stoi"].items()}
|
| 139 |
self.itos = {int(v): k for k, v in self.stoi.items()}
|
| 140 |
self.vocab_size = len(self.stoi)
|
| 141 |
+
|
| 142 |
class SimpleTokenizr:
|
| 143 |
def __init__(self):
|
| 144 |
self.stoi = {}
|
| 145 |
self.itos = {}
|
| 146 |
|
| 147 |
def tokenize(self, text):
|
| 148 |
+
return re.findall(r"[a-zA-Z']+|\d+|[^\w\s]", text.lower())
|
|
|
|
|
|
|
| 149 |
|
| 150 |
def train(self, texts):
|
| 151 |
vocab = set()
|
| 152 |
for text in texts:
|
| 153 |
tokens = self.tokenize(text)
|
| 154 |
vocab.update(tokens)
|
| 155 |
+
vocab.update(["<PAD>", "<UNK>", "<END>", "^user :", "minigpt :", "MiniGPT :", ":"])
|
|
|
|
| 156 |
sorted_vocab = sorted(vocab)
|
| 157 |
self.stoi = {token: idx for idx, token in enumerate(sorted_vocab)}
|
| 158 |
self.itos = {idx: token for token, idx in self.stoi.items()}
|
|
|
|
| 163 |
|
| 164 |
def decode(self, token_ids):
|
| 165 |
tokens = [self.itos.get(i, "<UNK>") for i in token_ids]
|
| 166 |
+
clean_tokens = [tok for tok in tokens if tok not in {"<PAD>", "<UNK>", "<END>"}]
|
|
|
|
|
|
|
|
|
|
| 167 |
text = ''
|
| 168 |
for i, tok in enumerate(clean_tokens):
|
| 169 |
+
if re.match(r"[.,!?;:]", tok):
|
| 170 |
text += tok
|
| 171 |
elif i > 0:
|
| 172 |
text += ' ' + tok
|
|
|
|
| 191 |
def vocab_size(self):
|
| 192 |
return len(self.stoi)
|
| 193 |
|
| 194 |
+
def validate(model, dataloader, device):
|
| 195 |
+
model.eval()
|
| 196 |
+
total_loss, correct, total = 0, 0, 0
|
| 197 |
+
with torch.no_grad():
|
| 198 |
+
for x, y in dataloader:
|
| 199 |
+
x, y = x.to(device), y.to(device)
|
| 200 |
+
logits = model(x)
|
| 201 |
+
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
|
| 202 |
+
total_loss += loss.item()
|
| 203 |
+
|
| 204 |
+
preds = torch.argmax(logits, dim=-1)
|
| 205 |
+
correct += (preds == y).sum().item()
|
| 206 |
+
total += y.numel()
|
| 207 |
+
|
| 208 |
+
avg_loss = total_loss / len(dataloader)
|
| 209 |
+
accuracy = 100 * correct / total
|
| 210 |
+
return avg_loss, accuracy
|
| 211 |
|
| 212 |
+
# 🎯 CHANGE 4: Add learning_rate parameter to the train function
|
| 213 |
+
def train(model, dataset, tokenizer, epochs, filepathh, start_epoch=0, start_step=0, learning_rate=5e-5):
|
| 214 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 215 |
model.to(device)
|
| 216 |
|
| 217 |
+
# 🔀 Proper train/val split
|
| 218 |
+
val_size = int(0.1 * len(dataset))
|
| 219 |
+
train_size = len(dataset) - val_size
|
| 220 |
+
train_set, val_set = random_split(dataset, [train_size, val_size])
|
| 221 |
|
| 222 |
+
# 🎯 CHANGE 5: Reduce batch_size and num_workers for debugging tiny datasets
|
| 223 |
+
# Batch size 1 or equal to len(train_set) is ideal for testing memorization
|
| 224 |
+
# num_workers=0 simplifies debugging.
|
| 225 |
+
train_loader = DataLoader(train_set, batch_size=1, shuffle=True, num_workers=0)
|
| 226 |
+
val_loader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=0)
|
| 227 |
|
| 228 |
+
# Use the passed learning_rate
|
| 229 |
+
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
|
| 230 |
+
|
| 231 |
+
checkpoint_path = "./trained-mini-gpt/checkpoint-mini-gpt.pth"
|
| 232 |
if os.path.exists(checkpoint_path):
|
| 233 |
checkpoint = torch.load(checkpoint_path)
|
| 234 |
if "model_state_dict" in checkpoint:
|
|
|
|
| 237 |
start_epoch = checkpoint["epoch"]
|
| 238 |
start_step = checkpoint["step"]
|
| 239 |
else:
|
|
|
|
| 240 |
model.load_state_dict(checkpoint)
|
| 241 |
else:
|
| 242 |
print("🚀 Starting from scratch.")
|
| 243 |
|
| 244 |
total_steps = start_step
|
| 245 |
+
|
|
|
|
| 246 |
for epoch in range(start_epoch, epochs):
|
| 247 |
+
model.train()
|
| 248 |
+
total_loss, correct, total = 0, 0, 0
|
| 249 |
+
|
| 250 |
+
loop = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{epochs}")
|
| 251 |
for step, (x, y) in loop:
|
| 252 |
x, y = x.to(device), y.to(device)
|
| 253 |
+
|
| 254 |
+
# 🎯 CHANGE 6: Add detailed print statements to observe learning
|
| 255 |
+
# This is CRUCIAL for debugging underfitting on tiny data.
|
| 256 |
+
if step % 1 == 0: # Print every step for tiny datasets
|
| 257 |
+
input_ids_cpu = x[0].cpu().tolist()
|
| 258 |
+
target_ids_cpu = y[0].cpu().tolist()
|
| 259 |
+
|
| 260 |
+
decoded_input = tokenizer.decode(input_ids_cpu)
|
| 261 |
+
decoded_target = tokenizer.decode(target_ids_cpu)
|
| 262 |
+
|
| 263 |
+
print(f"\n--- Epoch {epoch+1}, Step {step} ---")
|
| 264 |
+
print(f"Input (decoded): '{decoded_input}'")
|
| 265 |
+
print(f"Target (decoded): '{decoded_target}'")
|
| 266 |
+
|
| 267 |
+
|
| 268 |
logits = model(x)
|
| 269 |
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
|
| 270 |
|
|
|
|
| 273 |
optimizer.step()
|
| 274 |
|
| 275 |
total_loss += loss.item()
|
| 276 |
+
preds = torch.argmax(logits, dim=-1)
|
| 277 |
+
correct += (preds == y).sum().item()
|
| 278 |
+
total += y.numel()
|
| 279 |
+
acc = 100 * correct / total
|
| 280 |
+
|
| 281 |
+
loop.set_postfix(loss=loss.item(), acc=acc)
|
| 282 |
+
|
| 283 |
+
# After optimizer.step(), print predicted output to see if it matches target
|
| 284 |
+
if step % 1 == 0:
|
| 285 |
+
predicted_logits_cpu = logits[0, :, :].cpu() # For first example in batch
|
| 286 |
+
predicted_ids = torch.argmax(predicted_logits_cpu, dim=-1).tolist()
|
| 287 |
+
decoded_predicted = tokenizer.decode(predicted_ids)
|
| 288 |
+
print(f"Predicted (decoded): '{decoded_predicted}'")
|
| 289 |
+
print(f"Current Batch Loss: {loss.item():.4f}")
|
| 290 |
+
print(f"Current Batch Accuracy: {100 * (preds == y).float().mean().item():.2f}%") # Accuracy for current batch
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
# 🔍 Validate after each epoch
|
| 294 |
+
val_loss, val_acc = validate(model, val_loader, device)
|
| 295 |
+
print(f"✅ Val Loss: {val_loss:.4f} | Val Accuracy: {val_acc:.2f}%")
|
| 296 |
+
|
| 297 |
+
# 💾 Save checkpoint
|
| 298 |
+
torch.save({
|
| 299 |
+
"model_state_dict": model.state_dict(),
|
| 300 |
+
"optimizer_state_dict": optimizer.state_dict(),
|
| 301 |
+
"epoch": epoch,
|
| 302 |
+
"step": total_steps
|
| 303 |
+
}, checkpoint_path)
|
| 304 |
+
|
| 305 |
+
torch.save(model.state_dict(), "./trained-mini-gpt/mini-gpt.pth")
|
| 306 |
+
print("🎉 Training complete.")
|
datasetgen-synthetic.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
import json
|
| 3 |
+
|
| 4 |
+
topics = {
|
| 5 |
+
"Math Reasoning": [
|
| 6 |
+
("What is {a} + {b}?", "{a} + {b} is {sum}."),
|
| 7 |
+
("If you have {a} apples and get {b} more, how many?", "{a} + {b} = {sum} apples."),
|
| 8 |
+
("Solve: {a} + {b}", "The answer is {sum}.")
|
| 9 |
+
],
|
| 10 |
+
"Causality": [
|
| 11 |
+
("If it rains, what might happen?", "If it rains, the ground may become wet."),
|
| 12 |
+
("Why do plants grow towards light?", "Because light is a stimulus and plants respond by growing towards it."),
|
| 13 |
+
("What happens if you drop a glass?", "It will likely break due to gravity.")
|
| 14 |
+
],
|
| 15 |
+
"Grammar Correction": [
|
| 16 |
+
("Correct this: 'He go to school everyday.'", "'He goes to school every day.'"),
|
| 17 |
+
("Fix this sentence: 'I has two cat.'", "'I have two cats.'"),
|
| 18 |
+
("Can you fix this sentence: 'he have two taco.'", "'He has two tacos.'"),
|
| 19 |
+
("What’s the correct form of: 'She don't like it.'", "'She doesn't like it.'")
|
| 20 |
+
],
|
| 21 |
+
"Common Sense": [
|
| 22 |
+
("Can a person eat soup with a fork?", "No, it is impractical to eat soup with a fork."),
|
| 23 |
+
("Should you touch fire?", "No, touching fire can cause burns."),
|
| 24 |
+
("If you're tired, what should you do?", "You should rest or sleep.")
|
| 25 |
+
],
|
| 26 |
+
"World Knowledge": [
|
| 27 |
+
("What is the capital of France?", "Paris is the capital of France."),
|
| 28 |
+
("Who was the first president of the USA?", "George Washington."),
|
| 29 |
+
("What currency is used in Japan?", "The Japanese Yen.")
|
| 30 |
+
],
|
| 31 |
+
"Instruction Following": [
|
| 32 |
+
("Open the window and turn off the light.", "Opening the window. Turning off the light."),
|
| 33 |
+
("Sort these numbers in ascending order: 5, 2, 8.", "2, 5, 8."),
|
| 34 |
+
("Sort these numbers in descending order: 5, 2, 8.", "8, 5, 2."),
|
| 35 |
+
("Describe how to make a sandwich.", "Take two slices of bread, add your fillings, and place one slice on top.")
|
| 36 |
+
]
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
def generate_sample(id, topic):
|
| 40 |
+
pattern = random.choice(topics[topic])
|
| 41 |
+
if topic == "Math Reasoning":
|
| 42 |
+
a = random.randint(1, 20)
|
| 43 |
+
b = random.randint(1, 20)
|
| 44 |
+
sum_ab = a + b
|
| 45 |
+
input_str = pattern[0].format(a=a, b=b, sum=sum_ab)
|
| 46 |
+
output_str = pattern[1].format(a=a, b=b, sum=sum_ab)
|
| 47 |
+
else:
|
| 48 |
+
input_str = pattern[0]
|
| 49 |
+
output_str = pattern[1]
|
| 50 |
+
return {
|
| 51 |
+
"id": id,
|
| 52 |
+
"topic": topic,
|
| 53 |
+
"input": input_str,
|
| 54 |
+
"output": output_str
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
def generate_dataset(n=10000):
|
| 58 |
+
dataset = []
|
| 59 |
+
topic_list = list(topics.keys())
|
| 60 |
+
for i in range(n):
|
| 61 |
+
topic = random.choice(topic_list)
|
| 62 |
+
sample = generate_sample(i, topic)
|
| 63 |
+
dataset.append(sample)
|
| 64 |
+
return dataset
|
| 65 |
+
|
| 66 |
+
def save_as_jsonl(data, path="./data/reasoned_data.jsonl"):
|
| 67 |
+
with open(path, "w", encoding="utf-8") as f:
|
| 68 |
+
for item in data:
|
| 69 |
+
json.dump(item, f, ensure_ascii=False)
|
| 70 |
+
f.write("\n")
|
| 71 |
+
|
| 72 |
+
if __name__ == "__main__":
|
| 73 |
+
data = generate_dataset(10000)
|
| 74 |
+
save_as_jsonl(data)
|
| 75 |
+
print("Saved to ./data/reasoned_data.jsonl")
|
datasetgen.py
CHANGED
|
@@ -2,19 +2,44 @@ from datasets import load_dataset
|
|
| 2 |
import json
|
| 3 |
import re
|
| 4 |
from tqdm import tqdm
|
| 5 |
-
from filter import filterdata
|
| 6 |
|
| 7 |
-
|
|
|
|
|
|
|
| 8 |
|
| 9 |
convo = []
|
| 10 |
-
buffer = {}
|
| 11 |
|
| 12 |
-
print("
|
| 13 |
-
for entry in tqdm(ds):
|
| 14 |
-
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
-
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import json
|
| 3 |
import re
|
| 4 |
from tqdm import tqdm
|
| 5 |
+
from filter import filterdata # Custom filtering logic
|
| 6 |
|
| 7 |
+
# Load 110k samples from OpenWebText
|
| 8 |
+
print("📦 Loading dataset (110k samples)...")
|
| 9 |
+
ds = load_dataset("OpenAssistant/oasst1",split="train")
|
| 10 |
|
| 11 |
convo = []
|
|
|
|
| 12 |
|
| 13 |
+
print("⚙️ Processing dataset into Q&A pairs...")
|
| 14 |
+
for entry in tqdm(ds, unit='samples'):
|
| 15 |
+
if entry.get("role") == "assistant" and entry.get("text") and entry.get("parent_id"):
|
| 16 |
+
parent = next((x for x in ds if x["message_id"] == entry["parent_id"]), None)
|
| 17 |
+
if parent and parent.get("role") == "user":
|
| 18 |
+
convo.append({
|
| 19 |
+
"input": parent["text"],
|
| 20 |
+
"output": entry["text"]
|
| 21 |
+
})
|
| 22 |
|
| 23 |
+
#convo.append({
|
| 24 |
+
# "instruction": instruction,
|
| 25 |
+
# "input": user_input,
|
| 26 |
+
# "output": bot_response,
|
| 27 |
+
# "text": full_instruction + "\n" + bot_response
|
| 28 |
+
#})
|
| 29 |
+
|
| 30 |
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
print(f"✅ Got {len(convo)} usable Q&A pairs.")
|
| 34 |
+
|
| 35 |
+
# Save unfiltered data
|
| 36 |
+
unfiltered_path = "./data/unfiltered_data.jsonl"
|
| 37 |
+
with open(unfiltered_path, "w", encoding="utf-8") as f:
|
| 38 |
+
for line in convo:
|
| 39 |
+
f.write(json.dumps(line, ensure_ascii=False) + "\n")
|
| 40 |
+
|
| 41 |
+
print(f"📝 Saved unfiltered data to {unfiltered_path}")
|
| 42 |
+
|
| 43 |
+
# Run filtering
|
| 44 |
+
print("🚿 Starting filtering...")
|
| 45 |
+
filterdata(convo)
|
datasetgen2.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# datasetgen.py
|
| 2 |
+
import json
|
| 3 |
+
import random
|
| 4 |
+
from faker import Faker
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
fake = Faker()
|
| 9 |
+
|
| 10 |
+
OUTPUT_PATH = "data/filtered_data.jsonl"
|
| 11 |
+
#os.makedirs("datasets", exist_ok=True)
|
| 12 |
+
|
| 13 |
+
def generate_example():
|
| 14 |
+
"""Generates a single GPT-like QA pair"""
|
| 15 |
+
q_templates = [
|
| 16 |
+
"What is {}?",
|
| 17 |
+
"How do you {}?",
|
| 18 |
+
"Why is {} important?",
|
| 19 |
+
"Give me an example of {}.",
|
| 20 |
+
"Explain {} in simple terms.",
|
| 21 |
+
"Compare {} and {}.",
|
| 22 |
+
"What happens if {}?",
|
| 23 |
+
"Can you summarize {}?"
|
| 24 |
+
]
|
| 25 |
+
|
| 26 |
+
concepts = [
|
| 27 |
+
"machine learning", "quantum physics", "natural selection",
|
| 28 |
+
"photosynthesis", "neural networks", "global warming",
|
| 29 |
+
"black holes", "economic inflation", "probability", "blockchain"
|
| 30 |
+
]
|
| 31 |
+
|
| 32 |
+
actions = [
|
| 33 |
+
"train a neural network", "reduce carbon emissions", "make bread",
|
| 34 |
+
"calculate probability", "grow tomatoes", "optimize code",
|
| 35 |
+
"write a resume", "design a logo", "encrypt data", "learn Python"
|
| 36 |
+
]
|
| 37 |
+
|
| 38 |
+
concept = random.choice(concepts)
|
| 39 |
+
action = random.choice(actions)
|
| 40 |
+
|
| 41 |
+
template = random.choice(q_templates)
|
| 42 |
+
|
| 43 |
+
if '{}' in template and template.count('{}') == 1:
|
| 44 |
+
question = template.format(random.choice([concept, action]))
|
| 45 |
+
else:
|
| 46 |
+
question = template.format(concept, random.choice(concepts))
|
| 47 |
+
|
| 48 |
+
# Simulate an answer (in real GPT training you'd use real completions)
|
| 49 |
+
answer = f"{fake.paragraph(nb_sentences=4)}"
|
| 50 |
+
|
| 51 |
+
return {
|
| 52 |
+
"text": "^User: "+ question + "\nMiniGPT: " + answer + " <END>",
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
def generate_dataset(n=5000):
|
| 56 |
+
with open(OUTPUT_PATH, "w", encoding="utf-8") as f:
|
| 57 |
+
for _ in tqdm(range(n), desc="Generating Examples"):
|
| 58 |
+
example = generate_example()
|
| 59 |
+
f.write(json.dumps(example, ensure_ascii=False) + "\n")
|
| 60 |
+
|
| 61 |
+
print(f"\n✅ Dataset saved to: {OUTPUT_PATH}")
|
| 62 |
+
|
| 63 |
+
if __name__ == "__main__":
|
| 64 |
+
generate_dataset(5000)
|
datasets/5k_synthetic_dataset.jsonl
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
filter.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
import json
|
| 2 |
import re
|
| 3 |
from dataset import SimpleTokenizr
|
|
|
|
| 4 |
|
| 5 |
tokenizer = SimpleTokenizr()
|
| 6 |
|
|
@@ -13,7 +14,7 @@ def filterdata(data):
|
|
| 13 |
unused_lines = 0
|
| 14 |
low_quality_lines = 0
|
| 15 |
long_lines = 0
|
| 16 |
-
for line in data:
|
| 17 |
decoded = json.dumps(line)
|
| 18 |
data = json.loads(decoded)
|
| 19 |
text = data.get("text","")
|
|
@@ -23,20 +24,20 @@ def filterdata(data):
|
|
| 23 |
unused_lines += 1
|
| 24 |
unused.append(line)
|
| 25 |
else:
|
| 26 |
-
if len(encoded)
|
| 27 |
filtered_lines += 1
|
| 28 |
filtered.append(line)
|
| 29 |
-
if len(encoded)
|
| 30 |
long_lines += 1
|
| 31 |
long.append(text)
|
| 32 |
|
| 33 |
print(f"Filtered {filtered_lines} successfully!")
|
| 34 |
print(f"Removed {unused_lines} from data.")
|
| 35 |
-
print(f"Removed {long_lines} from data (too
|
| 36 |
#print(f"Removed {low_quality} from data (low quality).")
|
| 37 |
|
| 38 |
|
| 39 |
-
with open("./
|
| 40 |
for lines in filtered:
|
| 41 |
dump = json.dumps(lines)
|
| 42 |
decoded = json.loads(dump)
|
|
|
|
| 1 |
import json
|
| 2 |
import re
|
| 3 |
from dataset import SimpleTokenizr
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
|
| 6 |
tokenizer = SimpleTokenizr()
|
| 7 |
|
|
|
|
| 14 |
unused_lines = 0
|
| 15 |
low_quality_lines = 0
|
| 16 |
long_lines = 0
|
| 17 |
+
for line in tqdm(data, unit='B', unit_scale=True, unit_divisor=1024):
|
| 18 |
decoded = json.dumps(line)
|
| 19 |
data = json.loads(decoded)
|
| 20 |
text = data.get("text","")
|
|
|
|
| 24 |
unused_lines += 1
|
| 25 |
unused.append(line)
|
| 26 |
else:
|
| 27 |
+
if len(encoded) >= 64:
|
| 28 |
filtered_lines += 1
|
| 29 |
filtered.append(line)
|
| 30 |
+
if len(encoded) < 64:
|
| 31 |
long_lines += 1
|
| 32 |
long.append(text)
|
| 33 |
|
| 34 |
print(f"Filtered {filtered_lines} successfully!")
|
| 35 |
print(f"Removed {unused_lines} from data.")
|
| 36 |
+
print(f"Removed {long_lines} from data (too short).")
|
| 37 |
#print(f"Removed {low_quality} from data (low quality).")
|
| 38 |
|
| 39 |
|
| 40 |
+
with open("./data/filtered_data.jsonl", "w", encoding="utf-8") as f:
|
| 41 |
for lines in filtered:
|
| 42 |
dump = json.dumps(lines)
|
| 43 |
decoded = json.loads(dump)
|
minigpt.py
CHANGED
|
@@ -4,16 +4,20 @@ from model import MiniGPT
|
|
| 4 |
from dataset import MiniBPETokenizr,SimpleTokenizr
|
| 5 |
import json
|
| 6 |
import os
|
|
|
|
| 7 |
|
| 8 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 9 |
|
| 10 |
# Load tokenizer
|
| 11 |
-
tokenizer = SimpleTokenizr()
|
| 12 |
-
tokenizer.load("./customchatbot-v1/trained-mini-gpt/tokenizer.json")
|
|
|
|
| 13 |
|
| 14 |
# Load model
|
| 15 |
-
model = MiniGPT(vocab_size=
|
| 16 |
-
model.load_state_dict(torch.load("./customchatbot-v1/trained-mini-gpt/mini-gpt.pth", map_location=device) if os.path.exists("./customchatbot-v1/trained-mini-gpt/mini-gpt.pth") else torch.load("./customchatbot-v1/trained-mini-gpt/checkpoint-mini-gpt.pth", map_location=device)["model_state_dict"] )
|
|
|
|
|
|
|
| 17 |
model.eval().to(device)
|
| 18 |
totalparams = sum(p.numel() for p in model.parameters())
|
| 19 |
print(f"Model total params: {totalparams:,}")
|
|
@@ -30,7 +34,7 @@ def sample_token(logits, temperature=1.0):
|
|
| 30 |
return torch.multinomial(probs, num_samples=1).item()
|
| 31 |
|
| 32 |
def generate_reply(prompt, max_tokens=100):
|
| 33 |
-
tokens = tokenizer.encode(prompt)
|
| 34 |
if not tokens:
|
| 35 |
print("⚠️ Empty prompt after encoding.")
|
| 36 |
return
|
|
@@ -44,8 +48,8 @@ def generate_reply(prompt, max_tokens=100):
|
|
| 44 |
next_token = sample_token(logits)
|
| 45 |
generated.append(next_token)
|
| 46 |
|
| 47 |
-
next_str = tokenizer.
|
| 48 |
-
encoded_text = tokenizer.encode(next_str)
|
| 49 |
decoded_text = tokenizer.decode(encoded_text)
|
| 50 |
print(decoded_text, end=" ", flush=True)
|
| 51 |
|
|
|
|
| 4 |
from dataset import MiniBPETokenizr,SimpleTokenizr
|
| 5 |
import json
|
| 6 |
import os
|
| 7 |
+
from tokenizers import Tokenizer
|
| 8 |
|
| 9 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 10 |
|
| 11 |
# Load tokenizer
|
| 12 |
+
#tokenizer = SimpleTokenizr()
|
| 13 |
+
#tokenizer.load("./customchatbot-v1/trained-mini-gpt/tokenizer.json")
|
| 14 |
+
tokenizer = Tokenizer.from_file("./trained-mini-gpt/tokenizer.json")
|
| 15 |
|
| 16 |
# Load model
|
| 17 |
+
model = MiniGPT(vocab_size=tokenizer.get_vocab_size())
|
| 18 |
+
#model.load_state_dict(torch.load("./customchatbot-v1/trained-mini-gpt/mini-gpt.pth", map_location=device) if os.path.exists("./customchatbot-v1/trained-mini-gpt/mini-gpt.pth") else torch.load("./customchatbot-v1/trained-mini-gpt/checkpoint-mini-gpt.pth", map_location=device)["model_state_dict"] )
|
| 19 |
+
checkpoint = torch.load("./trained-mini-gpt/mini-gpt.pth", map_location=device)
|
| 20 |
+
model.load_state_dict(checkpoint)
|
| 21 |
model.eval().to(device)
|
| 22 |
totalparams = sum(p.numel() for p in model.parameters())
|
| 23 |
print(f"Model total params: {totalparams:,}")
|
|
|
|
| 34 |
return torch.multinomial(probs, num_samples=1).item()
|
| 35 |
|
| 36 |
def generate_reply(prompt, max_tokens=100):
|
| 37 |
+
tokens = tokenizer.encode(prompt).ids
|
| 38 |
if not tokens:
|
| 39 |
print("⚠️ Empty prompt after encoding.")
|
| 40 |
return
|
|
|
|
| 48 |
next_token = sample_token(logits)
|
| 49 |
generated.append(next_token)
|
| 50 |
|
| 51 |
+
next_str = tokenizer.id_to_token(next_token)
|
| 52 |
+
encoded_text = tokenizer.encode(next_str).ids
|
| 53 |
decoded_text = tokenizer.decode(encoded_text)
|
| 54 |
print(decoded_text, end=" ", flush=True)
|
| 55 |
|
model.py
CHANGED
|
@@ -2,26 +2,39 @@ import torch
|
|
| 2 |
import torch.nn as nn
|
| 3 |
|
| 4 |
class MiniGPT(nn.Module):
|
| 5 |
-
def __init__(self, vocab_size, d_model=
|
| 6 |
super().__init__()
|
|
|
|
| 7 |
self.token_embed = nn.Embedding(vocab_size, d_model)
|
| 8 |
self.pos_embed = nn.Embedding(max_len, d_model)
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
| 10 |
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
|
|
|
|
| 11 |
self.ln = nn.LayerNorm(d_model)
|
| 12 |
self.fc_out = nn.Linear(d_model, vocab_size)
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
def forward(self, input_ids):
|
| 15 |
B, T = input_ids.shape
|
| 16 |
pos = torch.arange(0, T, device=input_ids.device).unsqueeze(0)
|
| 17 |
x = self.token_embed(input_ids) + self.pos_embed(pos)
|
| 18 |
x = x.transpose(0, 1) # [T, B, D]
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
x = x.transpose(0, 1) # [B, T, D]
|
| 21 |
x = self.ln(x)
|
| 22 |
return self.fc_out(x)
|
| 23 |
-
|
| 24 |
def reset_params(self):
|
| 25 |
for layer in self.children():
|
| 26 |
-
if hasattr(layer,'reset_parameters'):
|
| 27 |
layer.reset_parameters()
|
|
|
|
| 2 |
import torch.nn as nn
|
| 3 |
|
| 4 |
class MiniGPT(nn.Module):
    """Minimal GPT-style language model.

    Token + positional embeddings feed a causally masked
    ``nn.TransformerEncoder`` stack, followed by a LayerNorm and a linear
    projection back to vocabulary logits.
    """

    def __init__(self, vocab_size, d_model=1024, n_heads=16, n_layers=24, max_len=512):
        """
        Args:
            vocab_size: number of tokens in the vocabulary (logit dimension).
            d_model: embedding / hidden width.
            n_heads: attention heads per layer (must divide d_model).
            n_layers: number of transformer encoder layers.
            max_len: maximum sequence length the positional embedding supports.
        """
        super().__init__()

        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)

        # Dropout is deliberately 0.0: this configuration is used to verify
        # the model can memorize a tiny dataset (overfitting sanity check).
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_heads, dropout=0.0, batch_first=False)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        self.ln = nn.LayerNorm(d_model)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def generate_causal_mask(self, T, device):
        """Return a [T, T] bool mask where True marks future positions to block.

        This is the convention nn.TransformerEncoder expects for a boolean
        attention mask: True entries are disallowed.
        """
        return torch.triu(torch.ones(T, T, device=device), diagonal=1).bool()

    def forward(self, input_ids):
        """Map [B, T] token ids to [B, T, vocab_size] next-token logits.

        Note: T must be <= max_len, or the positional embedding lookup
        will raise an index error.
        """
        B, T = input_ids.shape
        pos = torch.arange(0, T, device=input_ids.device).unsqueeze(0)
        x = self.token_embed(input_ids) + self.pos_embed(pos)  # [B, T, D]
        x = x.transpose(0, 1)  # [T, B, D] (the encoder was built with batch_first=False)

        # Causal mask so position t only attends to positions <= t.
        mask = self.generate_causal_mask(T, input_ids.device)
        x = self.transformer(x, mask=mask)

        x = x.transpose(0, 1)  # back to [B, T, D]
        x = self.ln(x)
        return self.fc_out(x)

    def reset_params(self):
        """Re-initialize every submodule that supports re-initialization.

        BUGFIX: the original iterated ``self.children()``, which only sees
        top-level modules; ``nn.TransformerEncoder`` itself has no
        ``reset_parameters``, so the entire transformer stack was silently
        never re-initialized. ``self.modules()`` recurses into the encoder
        layers' Linear/LayerNorm submodules. (MultiheadAttention's fused
        in-projection weights are still not covered — it only exposes a
        private ``_reset_parameters``.)
        """
        for module in self.modules():
            if hasattr(module, 'reset_parameters'):
                module.reset_parameters()
|
train_custom.py
CHANGED
|
@@ -1,19 +1,50 @@
|
|
| 1 |
import torch
|
| 2 |
-
from dataset import MiniBPETokenizr, ChatDataset, train,SimpleTokenizr
|
| 3 |
from model import MiniGPT
|
| 4 |
import json
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
#
|
| 7 |
-
|
| 8 |
-
texts = [json.loads(line)["text"] for line in f if line.strip()]
|
| 9 |
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
model.reset_params()
|
| 16 |
-
#model.load_state_dict(torch.load(ch_path))
|
| 17 |
|
| 18 |
-
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import torch
from dataset import MiniBPETokenizr, ChatDataset, train, SimpleTokenizr  # SimpleTokenizr might be unused now
from model import MiniGPT
import json
from tokenizers import Tokenizer, models, trainers, pre_tokenizers, normalizers
from tokenizers.trainers import BpeTrainer
from tokenizers.normalizers import Lowercase, NFD, StripAccents
from tokenizers.pre_tokenizers import Whitespace

# Anomaly detection helps pinpoint NaN/Inf gradients while debugging,
# but it is slow — disable for real training runs.
torch.autograd.set_detect_anomaly(True)

# Load training data as one "<input> <output>" string per JSONL record.
# NOTE: for the underfitting/overfitting sanity check, ensure this file
# *only* contains the small (e.g. 10-example) dataset.
with open("./data/overfit_data.jsonl", "r", encoding="utf-8") as f:
    # BUGFIX: the original called json.loads(line) twice per line;
    # parse each record exactly once.
    texts = [
        rec["input"] + ' ' + rec["output"]
        for rec in (json.loads(line) for line in f if line.strip())
    ]
|
|
|
|
|
|
|
| 18 |
|
| 19 |
+
def main():
    """Train a BPE tokenizer on the loaded corpus, then overfit MiniGPT on it."""
    # 🧠 Fresh HuggingFace BPE tokenizer with lowercasing and accent stripping.
    tok = Tokenizer(models.BPE(unk_token="<UNK>"))
    tok.normalizer = normalizers.Sequence([Lowercase(), NFD(), StripAccents()])
    tok.pre_tokenizer = Whitespace()

    # 🛠️ Learn the BPE merges from the in-memory corpus.
    bpe_trainer = BpeTrainer(
        vocab_size=28517,
        special_tokens=["<PAD>", "<UNK>", "<END>", "^User:", "MiniGPT:"]
    )
    tok.train_from_iterator(texts, bpe_trainer)

    # 💾 Persist the tokenizer, then reload it so training uses exactly
    # what was written to disk.
    tok.save("./trained-mini-gpt/tokenizer.json")
    hf_tokenizer = Tokenizer.from_file("./trained-mini-gpt/tokenizer.json")

    # 🧾 Dataset and a freshly re-initialized model sized to the vocabulary.
    chat_dataset = ChatDataset(
        data="./data/overfit_data.jsonl",  # point this at the small test dataset
        tokenizer=hf_tokenizer
    )
    gpt = MiniGPT(vocab_size=hf_tokenizer.get_vocab_size())
    gpt.reset_params()

    # 🚂 Train with a raised learning rate (1e-4) and a high epoch count so
    # memorization of the tiny dataset is unambiguous.
    # ("filepathh" is the keyword spelling the train() API expects.)
    train(gpt, chat_dataset, hf_tokenizer, epochs=200, filepathh="./data/merged_data.jsonl", learning_rate=1e-4)


if __name__ == "__main__":
    main()
|