| | """ |
| | GPT-300M Chatbot Interface |
| | ============================ |
| | Interactive terminal chatbot using a trained GPT-300M model. |
| | |
| | Usage: |
| | python chat.py --checkpoint ./checkpoints/best_model.pt |
| | |
| | # Or with custom generation parameters: |
| | python chat.py --checkpoint ./checkpoints/best_model.pt \ |
| | --temperature 0.8 --top_k 40 --max_tokens 256 |
| | """ |
| |
|
import argparse
import os
import sys
import time
from typing import Dict, List, Optional

import torch

from config import GPT300MConfig
from model import GPT300M
from tokenizer import BPETokenizer
| |
|
| |
|
class ChatBot:
    """
    Interactive chatbot powered by GPT-300M.

    Maintains conversation history, handles tokenization/detokenization,
    and performs autoregressive generation with KV-caching.
    """

    def __init__(
        self,
        model: GPT300M,
        tokenizer: BPETokenizer,
        config: GPT300MConfig,
        device: str = "auto",
    ):
        """
        Args:
            model: A GPT300M instance (moved to the resolved device, eval mode).
            tokenizer: Tokenizer providing encode / decode / encode_chat.
            config: Model + generation configuration (defaults for sampling).
            device: "cuda", "mps", "cpu", or "auto" to pick the best available.
        """
        self.config = config
        self.tokenizer = tokenizer

        # Resolve "auto" to the best available backend: CUDA > Apple MPS > CPU.
        if device == "auto":
            if torch.cuda.is_available():
                self.device = "cuda"
            elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                self.device = "mps"
            else:
                self.device = "cpu"
        else:
            self.device = device

        self.model = model.to(self.device)
        self.model.eval()

        # Conversation transcript (user/assistant turns only). The system
        # prompt is kept separately so it can be swapped without rewriting
        # history.
        self.history: List[Dict[str, str]] = []
        self.system_prompt = config.system_prompt

    def set_system_prompt(self, prompt: str):
        """Set the system prompt for the conversation."""
        self.system_prompt = prompt

    def reset(self):
        """Clear conversation history."""
        self.history = []
        print("\n⦠Conversation reset.\n")

    def _build_input(self, user_message: str) -> torch.Tensor:
        """
        Build the full prompt tensor: system prompt + history + new message.

        Returns:
            A (1, seq_len) LongTensor on self.device.
        """
        messages = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})
        messages.extend(self.history)
        messages.append({"role": "user", "content": user_message})
        input_ids = self.tokenizer.encode_chat(messages, add_generation_prompt=True)
        return torch.tensor([input_ids], dtype=torch.long, device=self.device)

    def chat(
        self,
        user_message: str,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        stream: bool = True,
    ) -> str:
        """
        Send a message and get a response.

        Args:
            user_message: The user's input
            temperature: Override sampling temperature
            top_k: Override top-k
            top_p: Override top-p
            max_new_tokens: Override max generation length
            stream: Whether to stream tokens to stdout

        Returns:
            The assistant's response text
        """
        # BUGFIX: explicit None checks so valid-but-falsy overrides
        # (temperature=0.0 for greedy decoding, top_k=0 to disable top-k)
        # are not silently replaced by the config defaults, as the previous
        # `x or default` idiom did.
        temp = temperature if temperature is not None else self.config.temperature
        k = top_k if top_k is not None else self.config.top_k
        p = top_p if top_p is not None else self.config.top_p
        max_tokens = (
            max_new_tokens if max_new_tokens is not None else self.config.max_new_tokens
        )

        input_tensor = self._build_input(user_message)

        # If the prompt leaves no room for generation, drop the oldest turns
        # one at a time and rebuild until it fits (or history is empty).
        while (
            self.history
            and input_tensor.size(1) > self.config.max_seq_len - max_tokens
        ):
            self.history.pop(0)
            input_tensor = self._build_input(user_message)

        t0 = time.time()

        if stream:
            response_text = self._generate_streaming(
                input_tensor, max_tokens, temp, k, p
            )
        else:
            with torch.no_grad():
                output_ids = self.model.generate(
                    input_tensor,
                    max_new_tokens=max_tokens,
                    temperature=temp,
                    top_k=k,
                    top_p=p,
                    repetition_penalty=self.config.repetition_penalty,
                    eos_token_id=self.tokenizer.special_tokens.get("<|end|>"),
                )
            # Only the newly generated suffix (past the prompt) is the reply.
            new_ids = output_ids[0, input_tensor.size(1):].tolist()
            response_text = self.tokenizer.decode(new_ids, skip_special=True)

        # Guard against ZeroDivisionError in the tok/s stat when the
        # response is empty or returned in under clock resolution.
        dt = max(time.time() - t0, 1e-6)
        n_tokens = len(self.tokenizer.encode(response_text))

        self.history.append({"role": "user", "content": user_message})
        self.history.append({"role": "assistant", "content": response_text.strip()})

        if stream:
            print(f"\n [{n_tokens} tokens, {dt:.1f}s, {n_tokens/dt:.1f} tok/s]")

        return response_text.strip()

    @torch.no_grad()
    def _generate_streaming(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int,
        temperature: float,
        top_k: int,
        top_p: float,
    ) -> str:
        """
        Generate tokens one at a time, printing to stdout as they arrive.

        Uses the model's KV cache: the full prompt is run once, then each
        step feeds back only the newly sampled token.
        """
        import torch.nn.functional as F

        model = self.model
        model.eval()

        # Either stop token ends generation (chat end-of-turn or global EOS).
        eos_id = self.tokenizer.special_tokens.get("<|end|>")
        end_id = self.tokenizer.special_tokens.get("<eos>")

        # Prime the KV cache with the whole prompt in one forward pass.
        logits, _, kv_caches = model(input_ids, use_cache=True)

        generated_ids = []
        buffer = b""  # pending UTF-8 bytes: multi-byte chars can span tokens

        for step in range(max_new_tokens):
            next_logits = logits[:, -1, :]

            # Repetition penalty over previously generated tokens only.
            if self.config.repetition_penalty != 1.0:
                for tid in set(generated_ids):
                    if next_logits[0, tid] > 0:
                        next_logits[0, tid] /= self.config.repetition_penalty
                    else:
                        next_logits[0, tid] *= self.config.repetition_penalty

            # temperature <= 0 means greedy decoding.
            if temperature > 0:
                next_logits = next_logits / temperature
                if top_k > 0:
                    topk_vals, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))
                    next_logits[next_logits < topk_vals[:, -1:]] = float("-inf")
                # BUGFIX: top_p was accepted but ignored on the streaming
                # path (the non-streaming path passes it to model.generate).
                # Standard nucleus filtering: keep the smallest prefix of
                # tokens whose cumulative probability exceeds top_p.
                if 0.0 < top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(next_logits, descending=True)
                    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_mask = cum_probs > top_p
                    # Shift right so the first token crossing the threshold
                    # is kept (the nucleus is never empty).
                    sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
                    sorted_mask[..., 0] = False
                    remove_mask = sorted_mask.scatter(1, sorted_indices, sorted_mask)
                    next_logits[remove_mask] = float("-inf")
                probs = F.softmax(next_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = next_logits.argmax(dim=-1, keepdim=True)

            token_id = next_token.item()

            if token_id in (eos_id, end_id):
                break

            generated_ids.append(token_id)

            # Stream decode: accumulate raw bytes and flush only once they
            # form valid UTF-8, so multi-byte characters are never printed
            # half-way through.
            token_bytes = self.tokenizer.vocab.get(token_id, b"")
            buffer += token_bytes
            try:
                text = buffer.decode("utf-8")
                sys.stdout.write(text)
                sys.stdout.flush()
                buffer = b""
            except UnicodeDecodeError:
                pass

            # Feed back only the new token; the KV cache carries the rest.
            position_offset = input_ids.size(1) + step
            logits, _, kv_caches = model(
                next_token,
                kv_caches=kv_caches,
                use_cache=True,
                position_offset=position_offset,
            )

        # Flush leftover (truncated/invalid) bytes rather than drop them.
        if buffer:
            text = buffer.decode("utf-8", errors="replace")
            sys.stdout.write(text)
            sys.stdout.flush()

        return self.tokenizer.decode(generated_ids, skip_special=True)
| |
|
| |
|
def interactive_chat(chatbot: ChatBot):
    """Run an interactive chat session in the terminal."""
    banner = "=" * 60
    print(banner)
    print(" GPT-300M Chatbot")
    print(" Type 'quit' to exit, 'reset' to clear history")
    print(" Type 'system: <prompt>' to set system prompt")
    print(banner)
    print()

    while True:
        # Ctrl-C / Ctrl-D both end the session cleanly.
        try:
            user_input = input("You: ").strip()
        except (KeyboardInterrupt, EOFError):
            print("\n\nGoodbye!")
            return

        if not user_input:
            continue

        lowered = user_input.lower()

        if lowered == "quit":
            print("Goodbye!")
            return

        if lowered == "reset":
            chatbot.reset()
            continue

        if lowered.startswith("system:"):
            # Everything after the "system:" tag becomes the new prompt.
            prompt = user_input[7:].strip()
            chatbot.set_system_prompt(prompt)
            print(f"β¦ System prompt set: {prompt}\n")
            continue

        print("\nAssistant: ", end="", flush=True)
        chatbot.chat(user_input, stream=True)
        print()
| |
|
| |
|
def load_model(checkpoint_path: str, device: str = "auto"):
    """
    Load a trained model, tokenizer, and config from a checkpoint.

    Args:
        checkpoint_path: Path to a torch checkpoint containing "config"
            and "model_state_dict" entries.
        device: Accepted for interface compatibility; the model is loaded
            on CPU here and moved to its device by ChatBot.

    Returns:
        (model, tokenizer, config) tuple.
    """
    # BUGFIX: `os` was only imported under the __main__ guard, so calling
    # this function from an importing module raised NameError. A local
    # import keeps the function self-contained.
    import os

    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources (consider weights_only=True).
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    config = GPT300MConfig(**checkpoint["config"])

    model = GPT300M(config)
    model.load_state_dict(checkpoint["model_state_dict"])

    # The tokenizer is expected to live next to the checkpoint file.
    tokenizer_path = os.path.join(
        os.path.dirname(checkpoint_path), "tokenizer.json"
    )
    if os.path.exists(tokenizer_path):
        tokenizer = BPETokenizer.load(tokenizer_path)
    else:
        tokenizer = BPETokenizer(vocab_size=config.vocab_size)
        print("Warning: Tokenizer not found, using untrained tokenizer")

    return model, tokenizer, config
| |
|
| |
|
| | |
| | |
| | |
| |
|
if __name__ == "__main__":
    import os

    # Command-line interface for the chatbot.
    cli = argparse.ArgumentParser(description="GPT-300M Chatbot")
    cli.add_argument("--checkpoint", type=str, default=None,
                     help="Path to model checkpoint")
    cli.add_argument("--temperature", type=float, default=0.7)
    cli.add_argument("--top_k", type=int, default=50)
    cli.add_argument("--top_p", type=float, default=0.9)
    cli.add_argument("--max_tokens", type=int, default=512)
    cli.add_argument("--device", type=str, default="auto")
    args = cli.parse_args()

    if args.checkpoint and os.path.exists(args.checkpoint):
        model, tokenizer, config = load_model(args.checkpoint, args.device)
    else:
        # No usable checkpoint: fall back to a small random model so the
        # interface can still be demoed.
        print("No checkpoint provided. Initializing random model for demo...")
        from config import gpt_tiny

        config = gpt_tiny()
        model = GPT300M(config)
        tokenizer = BPETokenizer(vocab_size=config.vocab_size)
        # Tiny throwaway corpus so the demo tokenizer has some merges.
        tokenizer.train("Hello! How are you? I am fine. " * 100)

    # CLI flags override the config's stored generation defaults.
    config.temperature = args.temperature
    config.top_k = args.top_k
    config.top_p = args.top_p
    config.max_new_tokens = args.max_tokens

    chatbot = ChatBot(model, tokenizer, config, device=args.device)
    interactive_chat(chatbot)
| |
|