# AgGPT21.py — AgGPT-21: word-level GRU language model.
# Trains on *.txt corpora and saves the checkpoint (weights + vocab) to AgGPT21.pt.
import os
import random
from collections import Counter
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import glob
# --- Artifacts and data locations ---
MODEL_FILE = "AgGPT21.pt"          # checkpoint path (model weights + vocab maps)
DATA_FOLDER = "training_corpora/"  # folder scanned for *.txt training files

# Reproducibility: seed both Python's and torch's RNGs.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# --- Model / training hyperparameters ---
SEQ_LEN = 64        # tokens per training sequence
STRIDE = 64         # step between window starts (== SEQ_LEN -> no overlap)
EMBED_SIZE = 128    # word-embedding width (also output width: fc weight is tied to it)
HIDDEN_SIZE = 128   # GRU hidden width
NUM_LAYERS = 1      # GRU layer count
DROPOUT = 0.2       # dropout on embeddings and GRU outputs
BATCH_SIZE = 8
EPOCHS = 6
LR = 2e-3
WEIGHT_DECAY = 1e-4
CLIP_NORM = 1.0     # gradient-clipping max norm

# --- Generation / data-budget knobs ---
GENERATE_LENGTH = 200   # words sampled per prompt in the demo
DATA_PERCENT = 0.1      # fraction of training files actually read
MAX_TOKENS = 1_000_000  # hard cap on corpus size (in words)
MAX_VOCAB = 30000       # vocabulary cap, including the <unk> slot
TEMPERATURE = 0.9       # sampling temperature (<1 sharpens the distribution)
TOP_K = 50              # top-k sampling cutoff
TOP_P = 0.9             # nucleus (top-p) sampling cutoff

# Device preference: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
def build_vocab_and_ids(folder_path, percent=1.0, max_tokens=None, max_vocab=None):
    """Build a word-level vocabulary and encode the corpus as token IDs.

    Scans *folder_path* for ``*.txt`` files, lowercases and whitespace-tokenizes
    their contents, and maps every word to an integer ID. ID 0 is reserved for
    ``<unk>``, which also absorbs out-of-vocabulary words.

    Args:
        folder_path: Directory containing the training ``.txt`` files.
        percent: Fraction (0-1] of the found files to read.
        max_tokens: Optional hard cap on the number of words kept.
        max_vocab: Optional cap on vocabulary size (including ``<unk>``).

    Returns:
        Tuple ``(vocab, stoi, itos, ids)``: the vocabulary list, word->id map,
        id->word map, and the encoded corpus.

    Raises:
        FileNotFoundError: If the folder contains no ``.txt`` files.
    """
    txt_files = glob.glob(os.path.join(folder_path, "*.txt"))
    if not txt_files:
        raise FileNotFoundError(f"No .txt files found in {folder_path}")
    print(f"Found {len(txt_files)} training files")
    total_files = len(txt_files)  # remember the full count instead of re-globbing
    if percent < 1.0:
        num_files_to_use = max(1, int(total_files * percent))
        txt_files = txt_files[:num_files_to_use]
        print(f"Using {percent*100}% of files: {num_files_to_use}/{total_files} files")
    all_words = []
    for file_path in sorted(txt_files):
        # Stop reading files once the token budget is already met.
        if max_tokens is not None and len(all_words) >= max_tokens:
            break
        print(f"Reading {os.path.basename(file_path)}...")
        with open(file_path, "r", encoding="utf-8") as f:
            text = f.read().lower()
        # str.split() with no argument already drops empty strings.
        all_words.extend(text.split())
    print(f"Total words loaded: {len(all_words):,}")
    if max_tokens is not None and len(all_words) > max_tokens:
        all_words = all_words[:max_tokens]
        print(f"Truncated to max_tokens: {len(all_words):,} words")
    counts = Counter(all_words)
    if max_vocab is not None:
        # Reserve one slot for <unk>; a literal "<unk>" in the corpus is dropped
        # so it cannot appear twice in the vocabulary.
        keep = max(1, max_vocab - 1)
        common = [w for w, _ in counts.most_common(keep) if w != "<unk>"]
        vocab = ["<unk>"] + common
    else:
        vocab = ["<unk>"] + [w for w in counts if w != "<unk>"]
    stoi = {w: i for i, w in enumerate(vocab)}
    itos = {i: w for w, i in stoi.items()}
    # Out-of-vocabulary words map to <unk> (id 0).
    ids = [stoi.get(w, 0) for w in all_words]
    print(f"Vocabulary size: {len(vocab):,}")
    return vocab, stoi, itos, ids
class WordDataset(Dataset):
    """Fixed-length next-word prediction windows over a token-id stream."""

    def __init__(self, ids, seq_len, stride=None):
        self.ids = ids
        self.seq_len = seq_len
        # Default to non-overlapping windows when no stride is given.
        self.stride = stride if stride else seq_len
        # Number of full (input, target) windows that fit; a target needs one
        # extra trailing token, hence the seq_len + 1 requirement.
        self.n = max(0, (len(self.ids) - self.seq_len - 1) // self.stride + 1)

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        lo = idx * self.stride
        hi = lo + self.seq_len
        window = self.ids[lo:hi]
        labels = self.ids[lo + 1:hi + 1]  # shifted by one: next-word targets
        return (
            torch.tensor(window, dtype=torch.long),
            torch.tensor(labels, dtype=torch.long),
        )
class WordRNN(nn.Module):
    """GRU language model with tied input/output embeddings.

    When the GRU hidden width differs from the embedding width, a bias-free
    linear projection maps hidden states back to embedding space so the
    output layer can share the embedding matrix.
    """

    def __init__(self, vocab_size, embed_size=EMBED_SIZE, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYERS, dropout=DROPOUT):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.drop = nn.Dropout(dropout)
        self.gru = nn.GRU(embed_size, hidden_size, num_layers=num_layers, batch_first=True)
        # Bridge hidden -> embedding space only when the widths differ.
        self.proj = nn.Linear(hidden_size, embed_size, bias=False) if hidden_size != embed_size else None
        fc_width = hidden_size if self.proj is None else embed_size
        self.fc = nn.Linear(fc_width, vocab_size, bias=False)
        self.fc.weight = self.embed.weight  # weight tying

    def forward(self, x, hidden=None):
        emb = self.drop(self.embed(x))
        states, h = self.gru(emb, hidden)
        states = self.drop(states)
        if self.proj is not None:
            states = self.proj(states)
        return self.fc(states), h
def count_parameters(model):
    """Return the number of trainable (requires_grad) parameters in *model*."""
    total = 0
    for p in model.parameters():
        if p.requires_grad:
            total += p.numel()
    return total
def evaluate(model, dataloader, device, use_amp):
model.eval()
criterion = nn.CrossEntropyLoss(ignore_index=0)
total_loss = 0.0
with torch.no_grad():
for x, y in dataloader:
x = x.to(device)
y = y.to(device)
with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp):
logits, _ = model(x)
loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))
total_loss += loss.item()
return total_loss / max(1, len(dataloader))
def train(model, train_loader, val_loader, epochs, lr, device, weight_decay, clip_norm, stoi, itos):
    """Train *model* with AdamW, early stopping, and best-checkpoint restore.

    The best state (lowest validation loss) is saved to MODEL_FILE together
    with the vocabulary maps; training stops after `patience` epochs without
    improvement, and the best checkpoint is reloaded before returning.
    """
    model.to(device)
    opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    # ignore_index=0 skips <unk> tokens in the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    # Mixed precision only on GPU-like backends.
    # NOTE(review): float16 autocast without a GradScaler can underflow
    # gradients on CUDA — confirm this is acceptable for this model size.
    use_amp = device.type in {"mps", "cuda"}
    best_val = float("inf")
    patience = 2              # epochs of no improvement tolerated before stopping
    epochs_no_improve = 0
    print(f"Train batches per epoch: {len(train_loader)} | Val batches: {len(val_loader)}")
    epoch_bar = tqdm(range(epochs), desc="Epochs")
    for epoch in epoch_bar:
        model.train()
        total_loss = 0.0
        batch_bar = tqdm(train_loader, desc=f"Train {epoch+1}/{epochs}", leave=False)
        for x, y in batch_bar:
            x = x.to(device)
            y = y.to(device)
            opt.zero_grad()
            with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp):
                logits, _ = model(x)
                loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))
            loss.backward()
            # Clip gradients to stabilize RNN training.
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
            opt.step()
            total_loss += loss.item()
        batch_bar.close()
        train_loss = total_loss / max(1, len(train_loader))
        val_loss = evaluate(model, val_loader, device, use_amp)
        epoch_bar.set_postfix(train=f"{train_loss:.4f}", val=f"{val_loss:.4f}")
        # Require a minimal (1e-4) improvement to reset the patience counter;
        # every improvement checkpoints model + vocab maps.
        if val_loss < best_val - 1e-4:
            best_val = val_loss
            epochs_no_improve = 0
            torch.save({"model_state": model.state_dict(), "stoi": stoi, "itos": itos}, MODEL_FILE)
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print("Early stopping.")
                break
    # Restore the best checkpoint (saved at least once: epoch 1 always beats inf).
    ckpt = torch.load(MODEL_FILE, map_location=device)
    model.load_state_dict(ckpt["model_state"])
    return model
def _sample_next_id(probs_1d, top_k=None, top_p=None, temperature=1.0, forbid_ids=None):
probs = probs_1d.clone()
if forbid_ids:
for i in forbid_ids:
if 0 <= i < probs.numel():
probs[i] = 0
if temperature != 1.0:
logits = torch.log(probs + 1e-9) / temperature
probs = torch.softmax(logits, dim=-1)
if probs.sum() <= 0:
probs = torch.ones_like(probs)
if forbid_ids:
for i in forbid_ids:
if 0 <= i < probs.numel():
probs[i] = 0
probs = probs / probs.sum()
if top_k is not None and top_k > 0:
k = min(top_k, probs.size(-1))
values, indices = torch.topk(probs, k)
values = values / values.sum()
idx = indices[torch.multinomial(values, 1)]
return idx.item()
if top_p is not None and 0 < top_p < 1.0:
sorted_probs, sorted_indices = torch.sort(probs, descending=True)
cumulative = torch.cumsum(sorted_probs, dim=-1)
keep_mask = cumulative <= top_p
keep = int(keep_mask.nonzero()[-1].item()) + 1 if keep_mask.any() else 1
sorted_probs = sorted_probs[:keep]
sorted_indices = sorted_indices[:keep]
sorted_probs = sorted_probs / sorted_probs.sum()
idx_pos = torch.multinomial(sorted_probs, 1)
return sorted_indices[idx_pos].item()
probs = probs / probs.sum()
return torch.multinomial(probs, 1).item()
def generate_text(model, stoi, itos, prompt, length=GENERATE_LENGTH, seq_len=SEQ_LEN, device=DEVICE, temperature=TEMPERATURE, top_k=TOP_K, top_p=TOP_P):
    """Autoregressively extend *prompt* by *length* sampled words.

    The prompt is lowercased, encoded (unknown words -> id 0), and
    left-padded or truncated to *seq_len* for the first forward pass; after
    that only the newly sampled token is fed while the GRU hidden state
    carries the context. Id 0 (<unk>) is never sampled. Returns the prompt
    plus the generated words as one space-joined string.
    """
    model.to(device)
    model.eval()
    prompt_words = prompt.lower().split()
    prompt_ids = [stoi.get(w, 0) for w in prompt_words]
    # Seed window: keep the last seq_len ids, or left-pad with id 0.
    if len(prompt_ids) >= seq_len:
        seed = prompt_ids[-seq_len:]
    else:
        seed = [0] * (seq_len - len(prompt_ids)) + prompt_ids
    step_input = torch.tensor(seed, dtype=torch.long).unsqueeze(0).to(device)
    state = None
    output_words = list(prompt_words)
    use_amp = device.type in {"mps", "cuda"}
    with torch.no_grad():
        progress = tqdm(range(length), desc="Generating", leave=False)
        for _ in progress:
            with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp):
                logits, state = model(step_input, state)
            probs = torch.softmax(logits[:, -1, :], dim=-1).squeeze(0)
            chosen = _sample_next_id(probs, top_k=top_k, top_p=top_p, temperature=temperature, forbid_ids=[0])
            output_words.append(itos.get(chosen, "<unk>"))
            # Feed only the new token; the hidden state remembers the rest.
            step_input = torch.tensor([[chosen]], dtype=torch.long).to(device)
    return " ".join(output_words)
if __name__ == "__main__":
    # Load an existing checkpoint if present; otherwise build the corpus,
    # train from scratch, and save the result.
    if os.path.exists(MODEL_FILE):
        ckpt = torch.load(MODEL_FILE, map_location=DEVICE)
        stoi = ckpt["stoi"]
        itos = ckpt["itos"]
        model = WordRNN(len(stoi))
        model.load_state_dict(ckpt["model_state"])
        print(f"Loaded model {MODEL_FILE} | device={DEVICE} | params={count_parameters(model):,}")
    else:
        if not os.path.exists(DATA_FOLDER):
            raise FileNotFoundError(f"Training folder not found: {DATA_FOLDER}")
        vocab, stoi, itos, ids = build_vocab_and_ids(DATA_FOLDER, percent=DATA_PERCENT, max_tokens=MAX_TOKENS, max_vocab=MAX_VOCAB)
        print(f"Vocab size: {len(vocab):,} | Tokens used: {len(ids):,} | device={DEVICE}")
        # Hold out the tail of the corpus (at least 5 sequences, ~5%) for validation.
        val_tokens = max(SEQ_LEN * 5, int(0.05 * len(ids)))
        train_ids = ids[:-val_tokens]
        val_ids = ids[-val_tokens:]
        train_dataset = WordDataset(train_ids, SEQ_LEN, stride=STRIDE)
        val_dataset = WordDataset(val_ids, SEQ_LEN, stride=STRIDE)
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)
        model = WordRNN(len(vocab))
        print(f"Model params: {count_parameters(model):,}")
        model = train(model, train_loader, val_loader, EPOCHS, LR, DEVICE, WEIGHT_DECAY, CLIP_NORM, stoi, itos)
        # train() restored the best checkpoint; persist the returned state too.
        torch.save({"model_state": model.state_dict(), "stoi": stoi, "itos": itos}, MODEL_FILE)
        print(f"Saved {MODEL_FILE}")
    # Quick qualitative demo on a few fixed prompts.
    print("\n=== AgGPT-21 Demo ===")
    prompts = ["hello world", "how are you", "once upon a time", "tell me about"]
    for p in prompts:
        print(f"\nPrompt: {p}")
        print(f"Generated: {generate_text(model, stoi, itos, p)}")
    print("\nTraining complete! Use chat.py for interactive conversation.")