|
|
import os |
|
|
import random |
|
|
from collections import Counter |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.optim as optim |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
from tqdm import tqdm |
|
|
import glob |
|
|
|
|
|
# Checkpoint file where the trained weights + vocab mappings are stored.
MODEL_FILE = "AgGPT21.pt"
# Folder of *.txt corpora used for training.
DATA_FOLDER = "training_corpora/"

# Seed Python's and torch's RNGs for reproducibility.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Sequence windowing.
SEQ_LEN = 64   # tokens per training window
STRIDE = 64    # window step; equal to SEQ_LEN => non-overlapping windows

# Model size.
EMBED_SIZE = 128
HIDDEN_SIZE = 128
NUM_LAYERS = 1
DROPOUT = 0.2

# Optimization.
BATCH_SIZE = 8
EPOCHS = 6
LR = 2e-3
WEIGHT_DECAY = 1e-4
CLIP_NORM = 1.0   # max gradient norm for clipping

# Generation / data limits.
GENERATE_LENGTH = 200    # words generated per prompt
DATA_PERCENT = 0.1       # fraction of corpus files to use
MAX_TOKENS = 1_000_000   # cap on total tokens kept
MAX_VOCAB = 30000        # cap on vocabulary size (incl. <unk>)

# Sampling hyperparameters for generation.
TEMPERATURE = 0.9
TOP_K = 50
TOP_P = 0.9

# Device preference: Apple MPS, then CUDA, else CPU.
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
|
|
|
|
|
def build_vocab_and_ids(folder_path, percent=1.0, max_tokens=None, max_vocab=None):
    """Build a word-level vocabulary and encode the corpus as token IDs.

    Reads every ``*.txt`` file under *folder_path*, lowercases and
    whitespace-tokenizes the text, then builds a frequency-ranked
    vocabulary with ``<unk>`` reserved at index 0.

    Args:
        folder_path: Directory containing the ``*.txt`` corpora.
        percent: Fraction (0..1] of the files to use.
        max_tokens: Optional cap on the total number of tokens kept.
        max_vocab: Optional cap on vocabulary size (including ``<unk>``).

    Returns:
        ``(vocab, stoi, itos, ids)``: vocabulary list, word->id dict,
        id->word dict, and the encoded corpus as a list of ids.

    Raises:
        FileNotFoundError: If no ``*.txt`` file exists in *folder_path*.
    """
    # Sort up front so the subset chosen by `percent` is deterministic;
    # glob order is filesystem-dependent, which previously made the
    # "first N files" selection non-reproducible despite the fixed SEED.
    txt_files = sorted(glob.glob(os.path.join(folder_path, "*.txt")))
    if not txt_files:
        raise FileNotFoundError(f"No .txt files found in {folder_path}")

    total_files = len(txt_files)
    print(f"Found {total_files} training files")

    if percent < 1.0:
        num_files_to_use = max(1, int(total_files * percent))
        txt_files = txt_files[:num_files_to_use]
        # Reuse the count captured above instead of re-globbing the folder.
        print(f"Using {percent*100}% of files: {num_files_to_use}/{total_files} files")

    all_words = []
    for file_path in txt_files:
        print(f"Reading {os.path.basename(file_path)}...")
        with open(file_path, "r", encoding="utf-8") as f:
            text = f.read().lower()
        # str.split() with no argument already drops empty strings.
        all_words.extend(text.split())

    print(f"Total words loaded: {len(all_words):,}")

    if max_tokens is not None:
        all_words = all_words[:max_tokens]
        print(f"Truncated to max_tokens: {len(all_words):,} words")

    counts = Counter(all_words)
    if max_vocab is not None:
        keep = max(1, max_vocab - 1)  # reserve one slot for <unk>
        common = [w for w, _ in counts.most_common(keep) if w != "<unk>"]
        vocab = ["<unk>"] + common
    else:
        vocab = ["<unk>"] + [w for w in counts if w != "<unk>"]

    stoi = {w: i for i, w in enumerate(vocab)}
    itos = {i: w for w, i in stoi.items()}
    # Out-of-vocabulary words map to <unk> (id 0).
    ids = [stoi.get(w, 0) for w in all_words]

    print(f"Vocabulary size: {len(vocab):,}")
    return vocab, stoi, itos, ids
|
|
|
|
|
class WordDataset(Dataset):
    """Sliding-window dataset over a flat list of token ids.

    Each item is a pair ``(x, y)`` of LongTensors of length ``seq_len``
    where ``y`` is ``x`` shifted one position to the right (next-token
    prediction targets).
    """

    def __init__(self, ids, seq_len, stride=None):
        self.ids = ids
        self.seq_len = seq_len
        # A falsy stride (None/0) falls back to non-overlapping windows.
        self.stride = stride if stride else seq_len
        # Number of full windows that fit; each needs seq_len + 1 tokens.
        usable = len(self.ids) - self.seq_len - 1
        self.n = max(0, usable // self.stride + 1)

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        lo = idx * self.stride
        hi = lo + self.seq_len
        x = torch.tensor(self.ids[lo:hi], dtype=torch.long)
        y = torch.tensor(self.ids[lo + 1:hi + 1], dtype=torch.long)
        return x, y
|
|
|
|
|
class WordRNN(nn.Module):
    """GRU language model with tied input/output embeddings.

    When ``hidden_size`` differs from ``embed_size``, a bias-free linear
    projection maps GRU outputs back to the embedding dimension so the
    output layer can share the embedding matrix.
    """

    def __init__(self, vocab_size, embed_size=EMBED_SIZE, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYERS, dropout=DROPOUT):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.drop = nn.Dropout(dropout)
        self.gru = nn.GRU(embed_size, hidden_size, num_layers=num_layers, batch_first=True)
        # Project hidden states to the embedding size only when needed.
        if hidden_size != embed_size:
            self.proj = nn.Linear(hidden_size, embed_size, bias=False)
        else:
            self.proj = None
        out_size = hidden_size if self.proj is None else embed_size
        self.fc = nn.Linear(out_size, vocab_size, bias=False)
        # Weight tying: output projection shares the embedding matrix.
        self.fc.weight = self.embed.weight

    def forward(self, x, hidden=None):
        """Return ``(logits, hidden)`` for a batch of token ids.

        Args:
            x: LongTensor of shape (batch, seq) — batch_first GRU input.
            hidden: Optional GRU hidden state to continue from.
        """
        embedded = self.drop(self.embed(x))
        outputs, new_hidden = self.gru(embedded, hidden)
        outputs = self.drop(outputs)
        if self.proj is not None:
            outputs = self.proj(outputs)
        return self.fc(outputs), new_hidden
|
|
|
|
|
def count_parameters(model):
    """Return the number of trainable (requires_grad) parameters in *model*."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
|
|
|
|
|
def evaluate(model, dataloader, device, use_amp): |
|
|
model.eval() |
|
|
criterion = nn.CrossEntropyLoss(ignore_index=0) |
|
|
total_loss = 0.0 |
|
|
with torch.no_grad(): |
|
|
for x, y in dataloader: |
|
|
x = x.to(device) |
|
|
y = y.to(device) |
|
|
with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp): |
|
|
logits, _ = model(x) |
|
|
loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1)) |
|
|
total_loss += loss.item() |
|
|
return total_loss / max(1, len(dataloader)) |
|
|
|
|
|
def train(model, train_loader, val_loader, epochs, lr, device, weight_decay, clip_norm, stoi, itos):
    """Train *model* with AdamW, early stopping, and best-checkpoint restore.

    Saves the best (lowest validation loss) checkpoint to MODEL_FILE,
    stops after `patience` epochs without improvement, and returns the
    model reloaded from that best checkpoint.
    """
    model.to(device)
    opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    # Targets equal to 0 do not contribute to the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    # NOTE(review): float16 autocast is used without a GradScaler; on CUDA
    # this can underflow gradients — confirm this is intentional.
    use_amp = device.type in {"mps", "cuda"}
    best_val = float("inf")
    patience = 2  # epochs of no validation improvement before early stop
    epochs_no_improve = 0
    print(f"Train batches per epoch: {len(train_loader)} | Val batches: {len(val_loader)}")
    epoch_bar = tqdm(range(epochs), desc="Epochs")
    for epoch in epoch_bar:
        model.train()
        total_loss = 0.0
        batch_bar = tqdm(train_loader, desc=f"Train {epoch+1}/{epochs}", leave=False)
        for x, y in batch_bar:
            x = x.to(device)
            y = y.to(device)
            opt.zero_grad()
            with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp):
                logits, _ = model(x)
                loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))
            loss.backward()
            # Clip gradient norm to stabilize RNN training.
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
            opt.step()
            total_loss += loss.item()
        batch_bar.close()
        train_loss = total_loss / max(1, len(train_loader))
        val_loss = evaluate(model, val_loader, device, use_amp)
        epoch_bar.set_postfix(train=f"{train_loss:.4f}", val=f"{val_loss:.4f}")
        # Require a small absolute improvement before saving a new best.
        if val_loss < best_val - 1e-4:
            best_val = val_loss
            epochs_no_improve = 0
            torch.save({"model_state": model.state_dict(), "stoi": stoi, "itos": itos}, MODEL_FILE)
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print("Early stopping.")
                break
    # Restore the best checkpoint. The first epoch always saves one
    # (best_val starts at +inf), so MODEL_FILE is guaranteed to exist here.
    ckpt = torch.load(MODEL_FILE, map_location=device)
    model.load_state_dict(ckpt["model_state"])
    return model
|
|
|
|
|
def _sample_next_id(probs_1d, top_k=None, top_p=None, temperature=1.0, forbid_ids=None):
    """Sample one token id from a 1-D probability vector.

    Applies, in order: forbidden-id masking, temperature rescaling,
    re-masking + renormalization, then top-k sampling (if top_k is set),
    else nucleus (top-p) sampling (if top_p is set), else plain
    multinomial sampling over the full distribution.

    Args:
        probs_1d: 1-D tensor of probabilities (already softmaxed).
        top_k: If given and > 0, sample only among the k most likely ids.
        top_p: If given in (0, 1), sample from the smallest prefix of
            descending-probability ids whose cumulative mass is <= top_p.
        temperature: != 1.0 rescales the distribution (>1 flattens,
            <1 sharpens).
        forbid_ids: Ids whose probability is forced to zero.

    Returns:
        The sampled token id as a Python int.
    """
    # Work on a copy so the caller's tensor is never mutated.
    probs = probs_1d.clone()
    # Zero out forbidden ids before any rescaling.
    if forbid_ids:
        for i in forbid_ids:
            if 0 <= i < probs.numel():
                probs[i] = 0
    if temperature != 1.0:
        # Rescale in log space; the epsilon guards log(0).
        logits = torch.log(probs + 1e-9) / temperature
        probs = torch.softmax(logits, dim=-1)
    if probs.sum() <= 0:
        # Everything was masked out — fall back to a uniform distribution.
        probs = torch.ones_like(probs)
    # The softmax above leaks a tiny mass back onto forbidden ids
    # (log(1e-9) is finite), so mask them again before renormalizing.
    if forbid_ids:
        for i in forbid_ids:
            if 0 <= i < probs.numel():
                probs[i] = 0
    probs = probs / probs.sum()
    if top_k is not None and top_k > 0:
        k = min(top_k, probs.size(-1))
        values, indices = torch.topk(probs, k)
        # Renormalize over the kept ids only.
        values = values / values.sum()
        idx = indices[torch.multinomial(values, 1)]
        return idx.item()
    if top_p is not None and 0 < top_p < 1.0:
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
        cumulative = torch.cumsum(sorted_probs, dim=-1)
        keep_mask = cumulative <= top_p
        # Always keep at least the single most likely id.
        keep = int(keep_mask.nonzero()[-1].item()) + 1 if keep_mask.any() else 1
        sorted_probs = sorted_probs[:keep]
        sorted_indices = sorted_indices[:keep]
        sorted_probs = sorted_probs / sorted_probs.sum()
        idx_pos = torch.multinomial(sorted_probs, 1)
        return sorted_indices[idx_pos].item()
    # No truncation requested: sample from the full (renormalized) distribution.
    probs = probs / probs.sum()
    return torch.multinomial(probs, 1).item()
|
|
|
|
|
def generate_text(model, stoi, itos, prompt, length=GENERATE_LENGTH, seq_len=SEQ_LEN, device=DEVICE, temperature=TEMPERATURE, top_k=TOP_K, top_p=TOP_P):
    """Autoregressively generate *length* words continuing *prompt*.

    The prompt is lowercased and whitespace-tokenized, then left-padded
    with id 0 (or truncated) to exactly ``seq_len`` ids. Id 0 is never
    sampled (passed via forbid_ids). Returns prompt + generated words
    joined by spaces.
    """
    model.to(device)
    model.eval()
    prompt_words = prompt.lower().split()
    prompt_ids = [stoi.get(w, 0) for w in prompt_words]
    if len(prompt_ids) >= seq_len:
        context = prompt_ids[-seq_len:]
    else:
        context = [0] * (seq_len - len(prompt_ids)) + prompt_ids
    input_ids = torch.tensor(context, dtype=torch.long).unsqueeze(0).to(device)
    hidden = None
    generated = list(prompt_words)
    use_amp = device.type in {"mps", "cuda"}
    with torch.no_grad():
        gen_bar = tqdm(range(length), desc="Generating", leave=False)
        for _ in gen_bar:
            with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp):
                logits, hidden = model(input_ids, hidden)
            # Distribution over the vocabulary for the last position.
            probs = torch.softmax(logits[:, -1, :], dim=-1).squeeze(0)
            next_id = _sample_next_id(probs, top_k=top_k, top_p=top_p, temperature=temperature, forbid_ids=[0])
            generated.append(itos.get(next_id, "<unk>"))
            # Feed only the new token; the GRU hidden state carries context.
            input_ids = torch.tensor([[next_id]], dtype=torch.long).to(device)
    return " ".join(generated)
|
|
|
|
|
if __name__ == "__main__":
    # Load an existing checkpoint if present; otherwise train from scratch.
    if os.path.exists(MODEL_FILE):
        ckpt = torch.load(MODEL_FILE, map_location=DEVICE)
        stoi = ckpt["stoi"]
        itos = ckpt["itos"]
        model = WordRNN(len(stoi))
        model.load_state_dict(ckpt["model_state"])
        print(f"Loaded model {MODEL_FILE} | device={DEVICE} | params={count_parameters(model):,}")
    else:
        if not os.path.exists(DATA_FOLDER):
            raise FileNotFoundError(f"Training folder not found: {DATA_FOLDER}")
        vocab, stoi, itos, ids = build_vocab_and_ids(DATA_FOLDER, percent=DATA_PERCENT, max_tokens=MAX_TOKENS, max_vocab=MAX_VOCAB)
        print(f"Vocab size: {len(vocab):,} | Tokens used: {len(ids):,} | device={DEVICE}")
        # Hold out the last ~5% of tokens (at least 5 windows' worth) for validation.
        val_tokens = max(SEQ_LEN * 5, int(0.05 * len(ids)))
        train_ids = ids[:-val_tokens]
        val_ids = ids[-val_tokens:]
        train_dataset = WordDataset(train_ids, SEQ_LEN, stride=STRIDE)
        val_dataset = WordDataset(val_ids, SEQ_LEN, stride=STRIDE)
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)
        model = WordRNN(len(vocab))
        print(f"Model params: {count_parameters(model):,}")
        model = train(model, train_loader, val_loader, EPOCHS, LR, DEVICE, WEIGHT_DECAY, CLIP_NORM, stoi, itos)
        # train() already saved the best checkpoint; this overwrites it with
        # the reloaded best weights plus the vocab mappings.
        torch.save({"model_state": model.state_dict(), "stoi": stoi, "itos": itos}, MODEL_FILE)
        print(f"Saved {MODEL_FILE}")

    # Quick generation demo over a few fixed prompts.
    print("\n=== AgGPT-21 Demo ===")
    prompts = ["hello world", "how are you", "once upon a time", "tell me about"]
    for p in prompts:
        print(f"\nPrompt: {p}")
        print(f"Generated: {generate_text(model, stoi, itos, p)}")
    print("\nTraining complete! Use chat.py for interactive conversation.")