"""Train a small GPT-style chat model ("AgLM") on ``train.txt``, or chat with it.

Usage::

    python <script> train   # default: build tokenizer + train, writes ./aglm/
    python <script> chat    # interactive chat with the model stored in ./aglm/
"""
import argparse
import json
import os
from pathlib import Path

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from tokenizers import Tokenizer, decoders, pre_tokenizers
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import Whitespace  # noqa: F401 - no longer used by the tokenizer built here
from tokenizers.trainers import BpeTrainer

# Special token that terminates each conversation in the training text.
END_OF_TEXT = "<|endoftext|>"


def _select_device():
    """Return the preferred torch device name: "mps", then "cuda", otherwise "cpu"."""
    if torch.backends.mps.is_available():
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


class LightweightGPT(nn.Module):
    """Small decoder-only (GPT-style) language model.

    Each block is an ``nn.TransformerDecoderLayer`` that receives its own input
    as the cross-attention memory. Both the self-attention and the
    cross-attention are causally masked, so position ``t`` only ever sees
    positions ``<= t``. The decoder-layer type is kept (rather than switching
    to an encoder layer) so previously saved ``state_dict`` checkpoints load.

    Args:
        vocab_size: Number of token ids.
        block_size: Maximum context length (size of the position embedding).
        n_embd: Model width.
        n_head: Number of attention heads.
        n_layer: Number of transformer blocks.
    """

    def __init__(self, vocab_size, block_size, n_embd, n_head, n_layer):
        super().__init__()
        self.block_size = block_size
        self.token_embedding = nn.Embedding(vocab_size, n_embd)
        self.position_embedding = nn.Embedding(block_size, n_embd)
        self.blocks = nn.ModuleList([
            nn.TransformerDecoderLayer(
                d_model=n_embd,
                nhead=n_head,
                dim_feedforward=4 * n_embd,
                dropout=0.1,
                activation="gelu",
                batch_first=True,
                norm_first=True,
            )
            for _ in range(n_layer)
        ])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: Long tensor of token ids, shape ``(batch, seq_len)`` with
                ``seq_len <= block_size``.
            targets: Optional long tensor with the same shape as ``idx``.
                Positions equal to ``-1`` are excluded from the loss.

        Returns:
            ``(logits, loss)`` where ``logits`` has shape
            ``(batch, seq_len, vocab_size)`` and ``loss`` is ``None`` when no
            targets are given.

        Raises:
            ValueError: if ``seq_len`` exceeds ``block_size``.
        """
        _, seq_len = idx.shape
        if seq_len > self.block_size:
            raise ValueError(
                f"Sequence length {seq_len} exceeds block_size {self.block_size}"
            )
        device = idx.device
        # True above the diagonal = "may not attend" (future positions).
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=device, dtype=torch.bool), diagonal=1
        )
        pos = torch.arange(seq_len, dtype=torch.long, device=device)
        x = self.token_embedding(idx) + self.position_embedding(pos)
        for block in self.blocks:
            # memory is x itself, so the cross-attention must be causally masked
            # as well; otherwise every position can read future tokens.
            x = block(x, x, tgt_mask=causal_mask, memory_mask=causal_mask)
        logits = self.lm_head(self.ln_f(x))

        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
                ignore_index=-1,
            )
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=0.8, top_k=50, stop_token=None):
        """Autoregressively extend ``idx`` by up to ``max_new_tokens`` tokens.

        Args:
            idx: Long tensor of shape ``(batch, seq_len)`` holding the prompt.
            max_new_tokens: Maximum number of tokens to append.
            temperature: Softmax temperature; ``<= 0`` selects greedy decoding.
            top_k: Keep only the ``top_k`` most likely tokens before sampling;
                ``None`` or ``0`` disables the filter.
            stop_token: Optional token id. Generation stops once every row has
                produced it. The final stop token is not appended; rows that
                finished earlier are padded with ``stop_token``.

        Returns:
            The prompt followed by the generated tokens.
        """
        finished = torch.zeros(idx.size(0), dtype=torch.bool, device=idx.device)
        for _ in range(max_new_tokens):
            # Only the last block_size tokens fit the position embedding.
            idx_cond = idx[:, -self.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]
            if temperature <= 0:
                idx_next = logits.argmax(dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                if top_k:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float("inf")
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
            if stop_token is not None:
                finished |= idx_next.squeeze(-1) == stop_token
                if bool(finished.all()):
                    break
                idx_next = idx_next.masked_fill(finished.unsqueeze(-1), stop_token)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx


class ConversationDataset(Dataset):
    """Training samples cut from a token stream split on an end-of-text token.

    The stream is split after every ``end_token_id``. Any trailing tokens
    without a terminator form one more segment. Segments shorter than two
    tokens are dropped, because they contain nothing to predict. Segments
    shorter than ``block_size + 1`` are padded with ``end_token_id``, and
    targets that fall on padding are set to ``-1`` so the loss ignores them.

    Args:
        tokens: Sequence of token ids.
        block_size: Context length of each sample.
        end_token_id: Id of the segment terminator (also used as padding).
    """

    def __init__(self, tokens, block_size, end_token_id):
        self.end_token = end_token_id
        self.block_size = block_size
        self.segments = []  # padded token lists, each at least block_size + 1 long
        self.lengths = []   # number of real (unpadded) tokens in each segment
        start = 0
        for i, token in enumerate(tokens):
            if token == end_token_id:
                self._add_segment(tokens[start:i + 1])
                start = i + 1
        # Keep a final conversation that is not terminated by the end token.
        self._add_segment(tokens[start:])
        print(f"Created {len(self.segments)} conversation segments.")

    def _add_segment(self, segment):
        """Pad ``segment`` to ``block_size + 1`` tokens and record its real length."""
        segment = list(segment)
        if len(segment) < 2:
            return
        self.lengths.append(len(segment))
        if len(segment) < self.block_size + 1:
            segment.extend([self.end_token] * (self.block_size + 1 - len(segment)))
        self.segments.append(segment)

    def __len__(self):
        return len(self.segments)

    def __getitem__(self, idx):
        """Return ``(x, y)``: a random ``block_size`` window and its shifted targets."""
        segment = self.segments[idx]
        real_len = self.lengths[idx]
        # randint's upper bound is exclusive: start <= len(segment) - block_size - 1.
        start_pos = torch.randint(0, max(1, len(segment) - self.block_size), (1,)).item()
        chunk = segment[start_pos:start_pos + self.block_size + 1]
        x = torch.tensor(chunk[:-1], dtype=torch.long)
        y = torch.tensor(chunk[1:], dtype=torch.long)
        # y[j] is segment[start_pos + j + 1]; it is real only if that index < real_len.
        n_real_targets = max(real_len - start_pos - 1, 0)
        y[n_real_targets:] = -1
        return x, y


class AIBuilder:
    """Train a tokenizer and a LightweightGPT model on raw conversational text.

    Artifacts (``tokenizer.json``, ``config.json``, ``model.pt``) are written to
    a folder named after ``model_name``, lower-cased with spaces replaced by
    underscores.
    """

    def __init__(self, model_name: str):
        self.model_name = model_name
        self.output_folder = model_name.replace(" ", "_").lower()
        self.device = _select_device()
        print(f"Using device: {self.device}")
        self.model_config = {
            "block_size": 128,
            "n_embd": 128,
            "n_head": 4,
            "n_layer": 4,
            "vocab_size": 8000,
            "batch_size": 8,
            "grad_accum": 4,
            "max_epochs": 3,
        }

    def _build_tokenizer(self, training_data: str):
        """Train a byte-level BPE tokenizer on ``training_data``.

        The byte-level pre-tokenizer and matching decoder let ``decode``
        restore the original spacing. A BPE model with a ``Whitespace``
        pre-tokenizer and no decoder decodes by joining every sub-word with a
        space, producing output like ``"hel lo"``.
        """
        tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
        tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
        tokenizer.decoder = decoders.ByteLevel()
        trainer = BpeTrainer(
            special_tokens=["[UNK]", "[PAD]", "user:", "ai:", END_OF_TEXT],
            vocab_size=self.model_config["vocab_size"],
            # Seed the vocabulary with every byte symbol so no input byte is unknown.
            initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
        )
        tokenizer.train_from_iterator(self._get_text_iterator(training_data), trainer)
        return tokenizer

    def _get_text_iterator(self, text, chunk_size=1000):
        """Yield ``text`` in pieces of roughly ``chunk_size`` characters.

        Pieces end on line boundaries, so words and special tokens are never
        cut in half. Concatenating all pieces reproduces ``text`` exactly.
        """
        buffer = []
        size = 0
        for line in text.splitlines(keepends=True):
            buffer.append(line)
            size += len(line)
            if size >= chunk_size:
                yield "".join(buffer)
                buffer, size = [], 0
        if buffer:
            yield "".join(buffer)

    def _prepare_dataloader(self, tokenizer, text):
        """Tokenize ``text`` and wrap it in a shuffled DataLoader of ``(x, y)`` batches.

        Raises:
            ValueError: if the text yields no trainable segment.
        """
        tokens = tokenizer.encode(text).ids
        end_token_id = tokenizer.token_to_id(END_OF_TEXT)
        dataset = ConversationDataset(tokens, self.model_config["block_size"], end_token_id)
        if len(dataset) == 0:
            # DataLoader(shuffle=True) would otherwise fail with an obscure sampler error.
            raise ValueError(
                "Training data produced no usable conversation segments; "
                f"separate conversations with {END_OF_TEXT}."
            )

        def collate_fn(batch):
            xs, ys = zip(*batch)
            return torch.stack(xs), torch.stack(ys)

        return DataLoader(
            dataset,
            batch_size=self.model_config["batch_size"],
            shuffle=True,
            collate_fn=collate_fn,
        )

    def train(self, training_data: str):
        """Build the tokenizer, train the model and save every artifact.

        The config is written before training so the output folder is always
        consistent with ``tokenizer.json``. ``model.pt`` is overwritten whenever
        an epoch's mean training loss improves.

        Raises:
            ValueError: if ``training_data`` yields no trainable segment.
        """
        os.makedirs(self.output_folder, exist_ok=True)

        print("Building and saving tokenizer...")
        tokenizer = self._build_tokenizer(training_data)
        tokenizer.save(os.path.join(self.output_folder, "tokenizer.json"))

        print("Saving configuration file...")
        self._save_config(tokenizer)

        print("Preparing data for training...")
        dataloader = self._prepare_dataloader(tokenizer, training_data)

        model = LightweightGPT(
            vocab_size=tokenizer.get_vocab_size(),
            block_size=self.model_config["block_size"],
            n_embd=self.model_config["n_embd"],
            n_head=self.model_config["n_head"],
            n_layer=self.model_config["n_layer"],
        ).to(self.device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
        model_path = os.path.join(self.output_folder, "model.pt")
        grad_accum = self.model_config["grad_accum"]

        print("\n--- Starting Model Training ---")
        model.train()
        best_loss = float("inf")
        for epoch in range(self.model_config["max_epochs"]):
            optimizer.zero_grad()
            total_loss = 0.0
            num_batches = 0
            has_pending_grads = False
            for batch_idx, (x, y) in enumerate(dataloader):
                x, y = x.to(self.device), y.to(self.device)
                _, loss = model(x, y)
                # Scale so accumulated gradients average over grad_accum batches.
                (loss / grad_accum).backward()
                has_pending_grads = True
                if (batch_idx + 1) % grad_accum == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    has_pending_grads = False

                current_loss = loss.detach().item()
                total_loss += current_loss
                num_batches += 1
                if batch_idx % 50 == 0:
                    print(f"Epoch {epoch+1} | Batch {batch_idx} | Loss: {current_loss:.4f}")

            # Apply gradients left over when the batch count is not a multiple of grad_accum.
            if has_pending_grads:
                optimizer.step()
                optimizer.zero_grad()

            epoch_loss = total_loss / num_batches
            print(f"Epoch {epoch+1} complete | Mean loss: {epoch_loss:.4f}")
            if epoch_loss < best_loss:
                best_loss = epoch_loss
                torch.save(model.state_dict(), model_path)
                print(f"🎉 New best model saved with loss: {best_loss:.4f}")

        print(f"✅ Training complete. Final best loss: {best_loss:.4f}")

    def _save_config(self, tokenizer):
        """Write ``config.json`` with the model hyper-parameters and tokenizer info."""
        config = {
            "model_name": self.model_name,
            **self.model_config,
            "vocab_size": tokenizer.get_vocab_size(),
            "end_token_id": tokenizer.token_to_id(END_OF_TEXT),
        }
        config_path = os.path.join(self.output_folder, "config.json")
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)
        print(f"Configuration saved to {config_path}")


class ChatInterface:
    """Interactive command-line chat with a model directory written by AIBuilder."""

    def __init__(self, model_dir="aglm"):
        self.model_dir = Path(model_dir)
        self.device = _select_device()
        self.load_model()

    def load_model(self):
        """Load ``config.json``, ``tokenizer.json`` and ``model.pt`` from ``model_dir``."""
        with open(self.model_dir / "config.json", "r", encoding="utf-8") as f:
            self.config = json.load(f)
        self.tokenizer = Tokenizer.from_file(str(self.model_dir / "tokenizer.json"))
        self.end_token_id = self.config.get("end_token_id")
        if self.end_token_id is None:
            # Fall back to the tokenizer when the config does not record the id.
            self.end_token_id = self.tokenizer.token_to_id(END_OF_TEXT)

        self.model = LightweightGPT(
            vocab_size=self.config["vocab_size"],
            block_size=self.config["block_size"],
            n_embd=self.config["n_embd"],
            n_head=self.config["n_head"],
            n_layer=self.config["n_layer"],
        ).to(self.device)
        # weights_only=True restricts unpickling to tensors and plain containers,
        # so a tampered model.pt cannot execute arbitrary code on load.
        state_dict = torch.load(
            self.model_dir / "model.pt", map_location=self.device, weights_only=True
        )
        self.model.load_state_dict(state_dict)
        self.model.eval()
        print("✅ Model loaded successfully!")

    def chat(self):
        """Run the read-generate-print loop until 'quit', 'exit', EOF or Ctrl-C."""
        print("\n===== AI Assistant Ready =====")
        print("Type 'quit' or 'exit' to end the chat.\n")
        while True:
            try:
                user_input = input("user: ")
            except (EOFError, KeyboardInterrupt):
                print()
                break
            if user_input.strip().lower() in ("quit", "exit"):
                break
            if not user_input.strip():
                continue

            prompt = f"user: {user_input}\nai:"
            input_ids = self.tokenizer.encode(prompt).ids
            input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)
            output_ids = self.model.generate(
                input_tensor,
                max_new_tokens=150,
                temperature=0.7,
                top_k=40,
                stop_token=self.end_token_id,
            )
            # generate returns prompt + continuation; keep only the continuation.
            response_ids = output_ids[0, len(input_ids):].tolist()
            response = self.tokenizer.decode(response_ids)
            response = response.replace(END_OF_TEXT, "").strip()
            print(f"ai: {response}")


def main():
    """Parse the command line and either train on ``train.txt`` or start a chat."""
    parser = argparse.ArgumentParser(description="Train or chat with an AgLM model.")
    parser.add_argument(
        "action",
        choices=["train", "chat"],
        nargs="?",
        default="train",
        help="Choose 'train' (default) or 'chat'.",
    )
    args = parser.parse_args()
    model_folder = "aglm"

    if args.action == "train":
        print("--- Starting Setup for AgLM ---")
        builder = AIBuilder("AgLM")
        # Only the file read is guarded, so unrelated FileNotFoundErrors raised
        # during training are not misreported as a missing train.txt.
        try:
            with open("train.txt", "r", encoding="utf-8") as f:
                data = f.read()
        except FileNotFoundError:
            print("\nERROR: train.txt not found. Please create train.txt with your conversational data to train the model.")
            return
        builder.train(data)
        print("\n✅ Training finished. You can now run with the 'chat' argument.")
        print(f"To chat, run: python {os.path.basename(__file__)} chat")
    else:
        print("--- Starting Chat Interface for AgLM ---")
        model_path = os.path.join(model_folder, "model.pt")
        if os.path.exists(model_path):
            ChatInterface(model_dir=model_folder).chat()
        else:
            print(f"\nERROR: Model file '{model_path}' not found. Please run training first.")


if __name__ == "__main__":
    main()