import os import random from collections import Counter import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import Dataset, DataLoader from tqdm import tqdm import glob MODEL_FILE = "AgGPT21.pt" DATA_FOLDER = "training_corpora/" SEED = 42 random.seed(SEED) torch.manual_seed(SEED) SEQ_LEN = 64 STRIDE = 64 EMBED_SIZE = 128 HIDDEN_SIZE = 128 NUM_LAYERS = 1 DROPOUT = 0.2 BATCH_SIZE = 8 EPOCHS = 6 LR = 2e-3 WEIGHT_DECAY = 1e-4 CLIP_NORM = 1.0 GENERATE_LENGTH = 200 DATA_PERCENT = 0.1 MAX_TOKENS = 1_000_000 MAX_VOCAB = 30000 TEMPERATURE = 0.9 TOP_K = 50 TOP_P = 0.9 if torch.backends.mps.is_available(): DEVICE = torch.device("mps") elif torch.cuda.is_available(): DEVICE = torch.device("cuda") else: DEVICE = torch.device("cpu") def build_vocab_and_ids(folder_path, percent=1.0, max_tokens=None, max_vocab=None): """Build vocabulary and token IDs from all text files in a folder.""" all_words = [] # Get all .txt files in the folder txt_files = glob.glob(os.path.join(folder_path, "*.txt")) if not txt_files: raise FileNotFoundError(f"No .txt files found in {folder_path}") print(f"Found {len(txt_files)} training files") # Limit number of files to process based on percent if percent < 1.0: num_files_to_use = max(1, int(len(txt_files) * percent)) txt_files = txt_files[:num_files_to_use] print(f"Using {percent*100}% of files: {num_files_to_use}/{len(glob.glob(os.path.join(folder_path, '*.txt')))} files") # Read and combine selected files for file_path in sorted(txt_files): print(f"Reading {os.path.basename(file_path)}...") with open(file_path, "r", encoding="utf-8") as f: text = f.read().lower() # Split by whitespace and filter out empty strings words = [w for w in text.split() if w] all_words.extend(words) print(f"Total words loaded: {len(all_words):,}") if max_tokens is not None: all_words = all_words[:max_tokens] print(f"Truncated to max_tokens: {len(all_words):,} words") counts = Counter(all_words) if max_vocab is not None: keep = max(1, max_vocab - 1) common = [w for w, _ in counts.most_common(keep) if w != ""] vocab = [""] + common else: vocab = [""] + [w for w in counts if w != ""] stoi = {w: i for i, w in enumerate(vocab)} itos = {i: w for w, i in stoi.items()} ids = [stoi.get(w, 0) for w in all_words] print(f"Vocabulary size: {len(vocab):,}") return vocab, stoi, itos, ids class WordDataset(Dataset): def __init__(self, ids, seq_len, stride=None): self.ids = ids self.seq_len = seq_len self.stride = stride or seq_len self.n = max(0, (len(self.ids) - self.seq_len - 1) // self.stride + 1) def __len__(self): return self.n def __getitem__(self, idx): start = idx * self.stride x = torch.tensor(self.ids[start:start + self.seq_len], dtype=torch.long) y = torch.tensor(self.ids[start + 1:start + self.seq_len + 1], dtype=torch.long) return x, y class WordRNN(nn.Module): def __init__(self, vocab_size, embed_size=EMBED_SIZE, hidden_size=HIDDEN_SIZE, num_layers=NUM_LAYERS, dropout=DROPOUT): super().__init__() self.embed = nn.Embedding(vocab_size, embed_size) self.drop = nn.Dropout(dropout) self.gru = nn.GRU(embed_size, hidden_size, num_layers=num_layers, batch_first=True) self.proj = None if hidden_size != embed_size: self.proj = nn.Linear(hidden_size, embed_size, bias=False) out_size = embed_size if self.proj else hidden_size self.fc = nn.Linear(out_size, vocab_size, bias=False) self.fc.weight = self.embed.weight def forward(self, x, hidden=None): e = self.drop(self.embed(x)) out, h = self.gru(e, hidden) out = self.drop(out) if self.proj is not None: out = self.proj(out) logits = self.fc(out) return logits, h def count_parameters(model): return sum(p.numel() for p in model.parameters() if p.requires_grad) def evaluate(model, dataloader, device, use_amp): model.eval() criterion = nn.CrossEntropyLoss(ignore_index=0) total_loss = 0.0 with torch.no_grad(): for x, y in dataloader: x = x.to(device) y = y.to(device) with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp): logits, _ = model(x) loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1)) total_loss += loss.item() return total_loss / max(1, len(dataloader)) def train(model, train_loader, val_loader, epochs, lr, device, weight_decay, clip_norm, stoi, itos): model.to(device) opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay) criterion = nn.CrossEntropyLoss(ignore_index=0) use_amp = device.type in {"mps", "cuda"} best_val = float("inf") patience = 2 epochs_no_improve = 0 print(f"Train batches per epoch: {len(train_loader)} | Val batches: {len(val_loader)}") epoch_bar = tqdm(range(epochs), desc="Epochs") for epoch in epoch_bar: model.train() total_loss = 0.0 batch_bar = tqdm(train_loader, desc=f"Train {epoch+1}/{epochs}", leave=False) for x, y in batch_bar: x = x.to(device) y = y.to(device) opt.zero_grad() with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp): logits, _ = model(x) loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1)) loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm) opt.step() total_loss += loss.item() batch_bar.close() train_loss = total_loss / max(1, len(train_loader)) val_loss = evaluate(model, val_loader, device, use_amp) epoch_bar.set_postfix(train=f"{train_loss:.4f}", val=f"{val_loss:.4f}") if val_loss < best_val - 1e-4: best_val = val_loss epochs_no_improve = 0 torch.save({"model_state": model.state_dict(), "stoi": stoi, "itos": itos}, MODEL_FILE) else: epochs_no_improve += 1 if epochs_no_improve >= patience: print("Early stopping.") break ckpt = torch.load(MODEL_FILE, map_location=device) model.load_state_dict(ckpt["model_state"]) return model def _sample_next_id(probs_1d, top_k=None, top_p=None, temperature=1.0, forbid_ids=None): probs = probs_1d.clone() if forbid_ids: for i in forbid_ids: if 0 <= i < probs.numel(): probs[i] = 0 if temperature != 1.0: logits = torch.log(probs + 1e-9) / temperature probs = torch.softmax(logits, dim=-1) if probs.sum() <= 0: probs = torch.ones_like(probs) if forbid_ids: for i in forbid_ids: if 0 <= i < probs.numel(): probs[i] = 0 probs = probs / probs.sum() if top_k is not None and top_k > 0: k = min(top_k, probs.size(-1)) values, indices = torch.topk(probs, k) values = values / values.sum() idx = indices[torch.multinomial(values, 1)] return idx.item() if top_p is not None and 0 < top_p < 1.0: sorted_probs, sorted_indices = torch.sort(probs, descending=True) cumulative = torch.cumsum(sorted_probs, dim=-1) keep_mask = cumulative <= top_p keep = int(keep_mask.nonzero()[-1].item()) + 1 if keep_mask.any() else 1 sorted_probs = sorted_probs[:keep] sorted_indices = sorted_indices[:keep] sorted_probs = sorted_probs / sorted_probs.sum() idx_pos = torch.multinomial(sorted_probs, 1) return sorted_indices[idx_pos].item() probs = probs / probs.sum() return torch.multinomial(probs, 1).item() def generate_text(model, stoi, itos, prompt, length=GENERATE_LENGTH, seq_len=SEQ_LEN, device=DEVICE, temperature=TEMPERATURE, top_k=TOP_K, top_p=TOP_P): model.to(device) model.eval() words = prompt.lower().split() ids = [stoi.get(w, 0) for w in words] context = ids[-seq_len:] if len(ids) >= seq_len else [0] * (seq_len - len(ids)) + ids input_ids = torch.tensor(context, dtype=torch.long).unsqueeze(0).to(device) hidden = None generated = words.copy() use_amp = device.type in {"mps", "cuda"} with torch.no_grad(): gen_bar = tqdm(range(length), desc="Generating", leave=False) for _ in gen_bar: with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp): logits, hidden = model(input_ids, hidden) probs = torch.softmax(logits[:, -1, :], dim=-1).squeeze(0) next_id = _sample_next_id(probs, top_k=top_k, top_p=top_p, temperature=temperature, forbid_ids=[0]) next_word = itos.get(next_id, "") generated.append(next_word) input_ids = torch.tensor([[next_id]], dtype=torch.long).to(device) return " ".join(generated) if __name__ == "__main__": if os.path.exists(MODEL_FILE): ckpt = torch.load(MODEL_FILE, map_location=DEVICE) stoi = ckpt["stoi"] itos = ckpt["itos"] model = WordRNN(len(stoi)) model.load_state_dict(ckpt["model_state"]) print(f"Loaded model {MODEL_FILE} | device={DEVICE} | params={count_parameters(model):,}") else: if not os.path.exists(DATA_FOLDER): raise FileNotFoundError(f"Training folder not found: {DATA_FOLDER}") vocab, stoi, itos, ids = build_vocab_and_ids(DATA_FOLDER, percent=DATA_PERCENT, max_tokens=MAX_TOKENS, max_vocab=MAX_VOCAB) print(f"Vocab size: {len(vocab):,} | Tokens used: {len(ids):,} | device={DEVICE}") val_tokens = max(SEQ_LEN * 5, int(0.05 * len(ids))) train_ids = ids[:-val_tokens] val_ids = ids[-val_tokens:] train_dataset = WordDataset(train_ids, SEQ_LEN, stride=STRIDE) val_dataset = WordDataset(val_ids, SEQ_LEN, stride=STRIDE) train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True) val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True) model = WordRNN(len(vocab)) print(f"Model params: {count_parameters(model):,}") model = train(model, train_loader, val_loader, EPOCHS, LR, DEVICE, WEIGHT_DECAY, CLIP_NORM, stoi, itos) torch.save({"model_state": model.state_dict(), "stoi": stoi, "itos": itos}, MODEL_FILE) print(f"Saved {MODEL_FILE}") print("\n=== AgGPT-21 Demo ===") prompts = ["hello world", "how are you", "once upon a time", "tell me about"] for p in prompts: print(f"\nPrompt: {p}") print(f"Generated: {generate_text(model, stoi, itos, p)}") print("\nTraining complete! Use chat.py for interactive conversation.")