r"""
GPT-300M Chatbot Interface
============================

Interactive terminal chatbot using a trained GPT-300M model.

Usage:
    python chat.py --checkpoint ./checkpoints/best_model.pt

    # Or with custom generation parameters:
    python chat.py --checkpoint ./checkpoints/best_model.pt \
        --temperature 0.8 --top_k 40 --max_tokens 256

Inside the session: ``quit`` exits, ``reset`` clears the history and
``system: <prompt>`` replaces the system prompt.
"""

import argparse
import codecs
import os
import sys
import time
from typing import List, Dict, Optional, Tuple

import torch
import torch.nn.functional as F

from config import GPT300MConfig
from model import GPT300M
from tokenizer import BPETokenizer


class ChatBot:
    """
    Interactive chatbot powered by GPT-300M.

    Maintains conversation history, handles tokenization/detokenization,
    and performs autoregressive generation with KV-caching.
    """

    def __init__(
        self,
        model: GPT300M,
        tokenizer: BPETokenizer,
        config: GPT300MConfig,
        device: str = "auto",
    ):
        """
        Args:
            model: The language model; it is moved to the chosen device and
                put in eval mode.
            tokenizer: Tokenizer providing ``encode_chat``, ``encode``,
                ``decode``, ``vocab`` and ``special_tokens``.
            config: Supplies default sampling parameters, ``max_seq_len``,
                ``repetition_penalty`` and the initial ``system_prompt``.
            device: "auto" picks cuda, then mps, then cpu; any other value
                is used as-is.
        """
        self.config = config
        self.tokenizer = tokenizer

        # Device
        if device == "auto":
            if torch.cuda.is_available():
                self.device = "cuda"
            elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                self.device = "mps"
            else:
                self.device = "cpu"
        else:
            self.device = device

        self.model = model.to(self.device)
        self.model.eval()

        # Conversation state: alternating user/assistant messages, appended in pairs.
        self.history: List[Dict[str, str]] = []
        self.system_prompt = config.system_prompt

    def set_system_prompt(self, prompt: str):
        """Set the system prompt for the conversation."""
        self.system_prompt = prompt

    def reset(self):
        """Clear conversation history."""
        self.history = []
        print("\n✦ Conversation reset.\n")

    def _build_prompt_ids(self, user_message: str) -> List[int]:
        """Encode system prompt + history + ``user_message`` as a generation prompt."""
        messages: List[Dict[str, str]] = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})
        messages.extend(self.history)
        messages.append({"role": "user", "content": user_message})
        return list(self.tokenizer.encode_chat(messages, add_generation_prompt=True))

    def _fit_context(self, user_message: str, max_tokens: int) -> Tuple[List[int], int]:
        """
        Build prompt ids that, together with the generation budget, fit in
        ``config.max_seq_len``.

        Oldest history turns are dropped first (this mutates ``self.history``).
        If the prompt alone is still too long, only its most recent tokens are
        kept. Finally the generation budget is clamped so that
        ``len(prompt) + budget <= max_seq_len``.

        Returns:
            (prompt token ids, possibly reduced max_tokens)
        """
        max_seq_len = self.config.max_seq_len
        input_ids = self._build_prompt_ids(user_message)

        # Drop whole (user, assistant) pairs so history never starts mid-turn.
        while self.history and len(input_ids) + max_tokens > max_seq_len:
            del self.history[:2]
            input_ids = self._build_prompt_ids(user_message)

        # Even without history the prompt can overflow the context window;
        # keep its tail, leaving at least one position free for generation.
        keep = max(max_seq_len - 1, 1)
        if len(input_ids) > keep:
            input_ids = input_ids[-keep:]

        # Never generate past the model's context window (positions would overflow).
        max_tokens = max(0, min(max_tokens, max_seq_len - len(input_ids)))
        return input_ids, max_tokens

    def chat(
        self,
        user_message: str,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        stream: bool = True,
    ) -> str:
        """
        Send a message and get a response.

        Args:
            user_message: The user's input
            temperature: Override sampling temperature (None = use config;
                0 selects greedy decoding)
            top_k: Override top-k (None = use config; 0 disables top-k)
            top_p: Override top-p (None = use config)
            max_new_tokens: Override max generation length (None = use config);
                it is clamped to what still fits in ``config.max_seq_len``
            stream: Whether to stream tokens to stdout

        Returns:
            The assistant's response text (stripped). The user message and
            the response are appended to ``self.history``.
        """
        cfg = self.config
        # Explicit `is None` checks: 0 / 0.0 are meaningful overrides.
        temp = cfg.temperature if temperature is None else temperature
        k = cfg.top_k if top_k is None else top_k
        p = cfg.top_p if top_p is None else top_p
        max_tokens = cfg.max_new_tokens if max_new_tokens is None else max_new_tokens

        input_ids, max_tokens = self._fit_context(user_message, max_tokens)
        input_tensor = torch.tensor([input_ids], dtype=torch.long, device=self.device)

        # Generate
        t0 = time.perf_counter()
        if stream:
            response_text = self._generate_streaming(input_tensor, max_tokens, temp, k, p)
        else:
            with torch.no_grad():
                output_ids = self.model.generate(
                    input_tensor,
                    max_new_tokens=max_tokens,
                    temperature=temp,
                    top_k=k,
                    top_p=p,
                    repetition_penalty=cfg.repetition_penalty,
                    eos_token_id=self.tokenizer.special_tokens.get("<|end|>"),
                )
            # Decode only the new tokens
            new_ids = output_ids[0, input_tensor.size(1):].tolist()
            response_text = self.tokenizer.decode(new_ids, skip_special=True)
        dt = time.perf_counter() - t0

        # Approximate count: re-encodes the decoded text.
        n_tokens = len(self.tokenizer.encode(response_text))
        response_text = response_text.strip()

        # Update history
        self.history.append({"role": "user", "content": user_message})
        self.history.append({"role": "assistant", "content": response_text})

        if stream:
            rate = n_tokens / dt if dt > 0 else 0.0
            print(f"\n [{n_tokens} tokens, {dt:.1f}s, {rate:.1f} tok/s]")

        return response_text

    @staticmethod
    def _sample(
        logits: torch.Tensor,
        temperature: float,
        top_k: int,
        top_p: float,
    ) -> torch.Tensor:
        """
        Choose the next token from last-position logits.

        ``temperature <= 0`` means greedy argmax. Otherwise the logits are
        temperature-scaled, optionally top-k and top-p (nucleus) filtered,
        and sampled.

        Returns:
            Token ids of shape (batch, 1).
        """
        if temperature <= 0:
            return logits.argmax(dim=-1, keepdim=True)

        logits = logits / temperature

        if top_k and top_k > 0:
            kth = torch.topk(logits, min(top_k, logits.size(-1))).values[:, -1:]
            logits = logits.masked_fill(logits < kth, float("-inf"))

        if top_p is not None and 0.0 < top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(logits, dim=-1, descending=True)
            sorted_probs = F.softmax(sorted_logits, dim=-1)
            # Drop a token once the mass of strictly more likely tokens already
            # exceeds top_p; the most likely token is therefore always kept.
            mass_before = sorted_probs.cumsum(dim=-1) - sorted_probs
            drop_sorted = mass_before > top_p
            drop = torch.zeros_like(drop_sorted).scatter(-1, sorted_idx, drop_sorted)
            logits = logits.masked_fill(drop, float("-inf"))

        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    @torch.no_grad()
    def _generate_streaming(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int,
        temperature: float,
        top_k: int,
        top_p: float,
    ) -> str:
        """
        Generate tokens one at a time with the KV-cache, printing as we go.

        Args:
            input_ids: Prompt ids as a (1, prompt_len) long tensor on the model device.
            max_new_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature (<= 0 means greedy).
            top_k: Top-k cutoff (0 disables).
            top_p: Nucleus cutoff (values outside (0, 1) disable it).

        Returns:
            The decoded generated text with special tokens skipped; the stop
            token itself is not included.
        """
        model = self.model
        model.eval()
        penalty = self.config.repetition_penalty

        eos_id = self.tokenizer.special_tokens.get("<|end|>")
        # NOTE(review): this key is empty, so it very likely matches nothing.
        # It looks like a special-token name (perhaps "</s>") was lost; confirm
        # against the tokenizer's special_tokens.
        end_id = self.tokenizer.special_tokens.get("")
        stop_ids = {tid for tid in (eos_id, end_id) if tid is not None}

        # Incremental UTF-8 decoder: holds back only incomplete multi-byte
        # sequences and replaces invalid bytes instead of stalling the stream.
        decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")

        prompt_len = input_ids.size(1)

        # Initial forward pass over the whole prompt
        logits, _, kv_caches = model(input_ids, use_cache=True)

        generated_ids: List[int] = []
        for step in range(max_new_tokens):
            next_logits = logits[:, -1, :]

            # Repetition penalty on previously generated tokens (vectorized):
            # shrink positive logits, push negative logits further down.
            if penalty != 1.0 and generated_ids:
                seen = torch.tensor(
                    sorted(set(generated_ids)), dtype=torch.long, device=next_logits.device
                )
                scores = next_logits[0, seen]
                next_logits[0, seen] = torch.where(
                    scores > 0, scores / penalty, scores * penalty
                )

            next_token = self._sample(next_logits, temperature, top_k, top_p)
            token_id = next_token.item()

            # Check for stop tokens
            if token_id in stop_ids:
                break
            generated_ids.append(token_id)

            # Decode and print the new token's bytes
            text = decoder.decode(self.tokenizer.vocab.get(token_id, b""))
            if text:
                sys.stdout.write(text)
                sys.stdout.flush()

            if step == max_new_tokens - 1:
                break  # budget exhausted; the next forward pass would be wasted

            # Forward only the new token, reusing the KV-cache; it occupies
            # position prompt_len + step.
            logits, _, kv_caches = model(
                next_token,
                kv_caches=kv_caches,
                use_cache=True,
                position_offset=prompt_len + step,
            )

        # Flush any bytes left in the decoder
        tail = decoder.decode(b"", final=True)
        if tail:
            sys.stdout.write(tail)
            sys.stdout.flush()

        return self.tokenizer.decode(generated_ids, skip_special=True)


def interactive_chat(chatbot: ChatBot):
    """
    Run an interactive chat session in the terminal.

    Commands: 'quit' exits, 'reset' clears history, 'system: <prompt>' sets
    the system prompt. Ctrl+C or EOF at the prompt exits; Ctrl+C during
    generation abandons only the current reply.
    """
    print("=" * 60)
    print(" GPT-300M Chatbot")
    print(" Type 'quit' to exit, 'reset' to clear history")
    print(" Type 'system: <prompt>' to set system prompt")
    print("=" * 60)
    print()

    while True:
        try:
            user_input = input("You: ").strip()
        except (KeyboardInterrupt, EOFError):
            print("\n\nGoodbye!")
            break

        if not user_input:
            continue

        command = user_input.lower()
        if command == "quit":
            print("Goodbye!")
            break
        if command == "reset":
            chatbot.reset()
            continue
        if command.startswith("system:"):
            prompt = user_input[len("system:"):].strip()
            chatbot.set_system_prompt(prompt)
            print(f"✦ System prompt set: {prompt}\n")
            continue

        print("\nAssistant: ", end="", flush=True)
        try:
            chatbot.chat(user_input, stream=True)
        except KeyboardInterrupt:
            # Keep the session alive; the interrupted turn is not added to history.
            print("\n[generation interrupted]")
        print()


def load_model(checkpoint_path: str, device: str = "auto"):
    """
    Load a trained model, its config and its tokenizer from a checkpoint.

    Args:
        checkpoint_path: Checkpoint file containing "config" (kwargs for
            GPT300MConfig) and "model_state_dict". A "tokenizer.json" next to
            it is loaded if present.
        device: Currently unused. Weights are loaded on CPU and the caller
            moves the model (ChatBot does this).

    Returns:
        (model, tokenizer, config)
    """
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    # Reconstruct config
    config = GPT300MConfig(**checkpoint["config"])

    # Load model
    model = GPT300M(config)
    model.load_state_dict(checkpoint["model_state_dict"])

    # Load tokenizer
    tokenizer_path = os.path.join(os.path.dirname(checkpoint_path), "tokenizer.json")
    if os.path.exists(tokenizer_path):
        tokenizer = BPETokenizer.load(tokenizer_path)
    else:
        tokenizer = BPETokenizer(vocab_size=config.vocab_size)
        print("Warning: Tokenizer not found, using untrained tokenizer")

    return model, tokenizer, config


# ═══════════════════════════════════════════════════════════════════════
# MAIN
# ═══════════════════════════════════════════════════════════════════════

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="GPT-300M Chatbot")
    parser.add_argument("--checkpoint", type=str, default=None, help="Path to model checkpoint")
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--top_k", type=int, default=50)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--max_tokens", type=int, default=512)
    parser.add_argument("--device", type=str, default="auto")
    args = parser.parse_args()

    if args.checkpoint and os.path.exists(args.checkpoint):
        model, tokenizer, config = load_model(args.checkpoint, args.device)
    else:
        if args.checkpoint:
            print(f"Warning: checkpoint not found: {args.checkpoint}")
        print("No checkpoint provided. Initializing random model for demo...")
        from config import gpt_tiny

        config = gpt_tiny()
        model = GPT300M(config)
        tokenizer = BPETokenizer(vocab_size=config.vocab_size)
        # Quick train on minimal data
        tokenizer.train("Hello! How are you? I am fine. " * 100)

    # Command-line sampling settings override whatever the checkpoint stored.
    config.temperature = args.temperature
    config.top_k = args.top_k
    config.top_p = args.top_p
    config.max_new_tokens = args.max_tokens

    chatbot = ChatBot(model, tokenizer, config, device=args.device)
    interactive_chat(chatbot)