# i_like_purple / chat.py
# Uploaded by dasdasddds ("Upload 16 files", commit 93783dd, verified)
"""
GPT-300M Chatbot Interface
============================
Interactive terminal chatbot using a trained GPT-300M model.
Usage:
python chat.py --checkpoint ./checkpoints/best_model.pt
# Or with custom generation parameters:
python chat.py --checkpoint ./checkpoints/best_model.pt \
--temperature 0.8 --top_k 40 --max_tokens 256
"""
import argparse
import os
import sys
import time
from typing import Dict, List, Optional

import torch

from config import GPT300MConfig
from model import GPT300M
from tokenizer import BPETokenizer
class ChatBot:
    """
    Interactive chatbot powered by GPT-300M.

    Maintains conversation history, handles tokenization/detokenization,
    and performs autoregressive generation with KV-caching.
    """

    def __init__(
        self,
        model: GPT300M,
        tokenizer: BPETokenizer,
        config: GPT300MConfig,
        device: str = "auto",
    ):
        """
        Args:
            model: GPT300M instance; moved to the resolved device and set to eval mode.
            tokenizer: BPE tokenizer used for encoding/decoding.
            config: Generation defaults (temperature, top_k, top_p, max_new_tokens,
                repetition_penalty, max_seq_len, system_prompt).
            device: "auto" picks CUDA > MPS > CPU; any other value is used verbatim.
        """
        self.config = config
        self.tokenizer = tokenizer
        # Resolve "auto" to the best available backend (CUDA > MPS > CPU).
        if device == "auto":
            if torch.cuda.is_available():
                self.device = "cuda"
            elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
                self.device = "mps"
            else:
                self.device = "cpu"
        else:
            self.device = device
        self.model = model.to(self.device)
        self.model.eval()
        # Conversation state: list of {"role", "content"} dicts; the system
        # prompt is stored separately and prepended at encode time.
        self.history: List[Dict[str, str]] = []
        self.system_prompt = config.system_prompt

    def set_system_prompt(self, prompt: str):
        """Set the system prompt for the conversation."""
        self.system_prompt = prompt

    def reset(self):
        """Clear conversation history."""
        self.history = []
        print("\n✦ Conversation reset.\n")

    def _encode_conversation(self, user_message: str) -> torch.Tensor:
        """Build system + history + new user turn and encode to a (1, seq) tensor.

        Extracted helper: the same prompt-assembly logic previously appeared
        twice in chat() (initial build and post-truncation rebuild).
        """
        messages: List[Dict[str, str]] = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})
        messages.extend(self.history)
        messages.append({"role": "user", "content": user_message})
        input_ids = self.tokenizer.encode_chat(messages, add_generation_prompt=True)
        return torch.tensor([input_ids], dtype=torch.long, device=self.device)

    def chat(
        self,
        user_message: str,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        max_new_tokens: Optional[int] = None,
        stream: bool = True,
    ) -> str:
        """
        Send a message and get a response.

        Args:
            user_message: The user's input
            temperature: Override sampling temperature
            top_k: Override top-k
            top_p: Override top-p
            max_new_tokens: Override max generation length
            stream: Whether to stream tokens to stdout

        Returns:
            The assistant's response text
        """
        # BUGFIX: use explicit `is not None` checks instead of `x or default`
        # so legitimate falsy overrides (temperature=0.0 for greedy decoding,
        # top_k=0 to disable top-k filtering) are not silently replaced by
        # the config defaults.
        temp = temperature if temperature is not None else self.config.temperature
        k = top_k if top_k is not None else self.config.top_k
        p = top_p if top_p is not None else self.config.top_p
        max_tokens = (
            max_new_tokens if max_new_tokens is not None else self.config.max_new_tokens
        )
        input_tensor = self._encode_conversation(user_message)
        # If the prompt leaves no room for generation, drop the oldest
        # history entries one at a time until it fits (or history is empty).
        while (
            len(self.history) > 0
            and input_tensor.size(1) > self.config.max_seq_len - max_tokens
        ):
            self.history.pop(0)
            input_tensor = self._encode_conversation(user_message)
        # Generate
        t0 = time.time()
        if stream:
            response_text = self._generate_streaming(
                input_tensor, max_tokens, temp, k, p
            )
        else:
            with torch.no_grad():
                output_ids = self.model.generate(
                    input_tensor,
                    max_new_tokens=max_tokens,
                    temperature=temp,
                    top_k=k,
                    top_p=p,
                    repetition_penalty=self.config.repetition_penalty,
                    eos_token_id=self.tokenizer.special_tokens.get("<|end|>"),
                )
            # Decode only the newly generated tokens.
            new_ids = output_ids[0, input_tensor.size(1):].tolist()
            response_text = self.tokenizer.decode(new_ids, skip_special=True)
        dt = time.time() - t0
        # Update history
        self.history.append({"role": "user", "content": user_message})
        self.history.append({"role": "assistant", "content": response_text.strip()})
        if stream:
            n_tokens = len(self.tokenizer.encode(response_text))
            # BUGFIX: guard against ZeroDivisionError on near-instant generations.
            rate = n_tokens / dt if dt > 0 else float("inf")
            print(f"\n [{n_tokens} tokens, {dt:.1f}s, {rate:.1f} tok/s]")
        return response_text.strip()

    @torch.no_grad()
    def _generate_streaming(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int,
        temperature: float,
        top_k: int,
        top_p: float,
    ) -> str:
        """Generate tokens one at a time, printing each as soon as its bytes
        form complete UTF-8.

        NOTE(review): top_p is accepted for interface symmetry but is not
        applied in this streaming path — only temperature, top-k and the
        repetition penalty are used here.

        Returns:
            The decoded assistant response (special tokens stripped).
        """
        import torch.nn.functional as F

        model = self.model
        model.eval()
        # Either stop token ends generation; .get returns None when the
        # token is absent, and None can never equal an int token id.
        eos_id = self.tokenizer.special_tokens.get("<|end|>")
        end_id = self.tokenizer.special_tokens.get("<eos>")
        # Prime the KV-cache with the full prompt in one forward pass.
        logits, _, kv_caches = model(input_ids, use_cache=True)
        generated_ids = []
        # A BPE token may be a partial UTF-8 sequence, so bytes accumulate
        # here until they decode cleanly.
        buffer = b""
        for step in range(max_new_tokens):
            next_logits = logits[:, -1, :]
            # Repetition penalty: dampen logits of tokens already generated
            # in this response (divide positives, multiply negatives).
            if self.config.repetition_penalty != 1.0:
                for tid in set(generated_ids):
                    if next_logits[0, tid] > 0:
                        next_logits[0, tid] /= self.config.repetition_penalty
                    else:
                        next_logits[0, tid] *= self.config.repetition_penalty
            # temperature > 0: temperature + top-k sampling; else greedy argmax.
            if temperature > 0:
                next_logits = next_logits / temperature
                if top_k > 0:
                    topk_vals, _ = torch.topk(next_logits, min(top_k, next_logits.size(-1)))
                    next_logits[next_logits < topk_vals[:, -1:]] = float("-inf")
                probs = F.softmax(next_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = next_logits.argmax(dim=-1, keepdim=True)
            token_id = next_token.item()
            # Check for stop tokens
            if token_id in (eos_id, end_id):
                break
            generated_ids.append(token_id)
            # Stream the token's bytes out once they form valid UTF-8.
            token_bytes = self.tokenizer.vocab.get(token_id, b"")
            buffer += token_bytes
            try:
                text = buffer.decode("utf-8")
                sys.stdout.write(text)
                sys.stdout.flush()
                buffer = b""
            except UnicodeDecodeError:
                pass  # Wait for more bytes
            # Feed only the new token back; the KV-cache holds the prefix.
            position_offset = input_ids.size(1) + step
            logits, _, kv_caches = model(
                next_token,
                kv_caches=kv_caches,
                use_cache=True,
                position_offset=position_offset,
            )
        # Flush any leftover bytes (e.g. a truncated multi-byte character).
        if buffer:
            sys.stdout.write(buffer.decode("utf-8", errors="replace"))
            sys.stdout.flush()
        return self.tokenizer.decode(generated_ids, skip_special=True)
def interactive_chat(chatbot: ChatBot):
    """Run an interactive chat session in the terminal."""
    divider = "=" * 60
    for banner_line in (
        divider,
        " GPT-300M Chatbot",
        " Type 'quit' to exit, 'reset' to clear history",
        " Type 'system: <prompt>' to set system prompt",
        divider,
        "",
    ):
        print(banner_line)
    while True:
        try:
            user_input = input("You: ").strip()
        except (KeyboardInterrupt, EOFError):
            print("\n\nGoodbye!")
            return
        if not user_input:
            continue
        lowered = user_input.lower()
        if lowered == "quit":
            print("Goodbye!")
            return
        if lowered == "reset":
            chatbot.reset()
            continue
        if lowered.startswith("system:"):
            new_prompt = user_input[7:].strip()
            chatbot.set_system_prompt(new_prompt)
            print(f"✦ System prompt set: {new_prompt}\n")
            continue
        print("\nAssistant: ", end="", flush=True)
        chatbot.chat(user_input, stream=True)
        print()
def load_model(checkpoint_path: str, device: str = "auto"):
    """Load a trained model, tokenizer, and config from a checkpoint.

    Args:
        checkpoint_path: Path to a checkpoint file containing "config" and
            "model_state_dict" entries.
        device: Accepted for interface compatibility; device placement is
            done later by ChatBot.__init__, not here.

    Returns:
        (model, tokenizer, config) tuple. The tokenizer is loaded from a
        "tokenizer.json" next to the checkpoint when present, otherwise a
        fresh (untrained) tokenizer is returned with a warning.
    """
    # NOTE(security): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    # Reconstruct config from the saved dict.
    config = GPT300MConfig(**checkpoint["config"])
    # Rebuild the model skeleton and restore trained weights.
    model = GPT300M(config)
    model.load_state_dict(checkpoint["model_state_dict"])
    # The tokenizer is expected to live alongside the checkpoint file.
    # (Requires the module-level `import os`; previously `os` was only
    # imported inside the __main__ guard, so calling this function from an
    # importing module raised NameError.)
    tokenizer_path = os.path.join(
        os.path.dirname(checkpoint_path), "tokenizer.json"
    )
    if os.path.exists(tokenizer_path):
        tokenizer = BPETokenizer.load(tokenizer_path)
    else:
        tokenizer = BPETokenizer(vocab_size=config.vocab_size)
        print("Warning: Tokenizer not found, using untrained tokenizer")
    return model, tokenizer, config
# ═══════════════════════════════════════════════════════════════════════
# MAIN
# ═══════════════════════════════════════════════════════════════════════
if __name__ == "__main__":
    import os

    # --- CLI definition -------------------------------------------------
    parser = argparse.ArgumentParser(description="GPT-300M Chatbot")
    parser.add_argument("--checkpoint", type=str, default=None,
                        help="Path to model checkpoint")
    # Generation knobs share the same shape, so declare them table-style.
    for flag, flag_type, flag_default in (
        ("--temperature", float, 0.7),
        ("--top_k", int, 50),
        ("--top_p", float, 0.9),
        ("--max_tokens", int, 512),
    ):
        parser.add_argument(flag, type=flag_type, default=flag_default)
    parser.add_argument("--device", type=str, default="auto")
    args = parser.parse_args()

    # --- Model setup ----------------------------------------------------
    if args.checkpoint and os.path.exists(args.checkpoint):
        model, tokenizer, config = load_model(args.checkpoint, args.device)
    else:
        print("No checkpoint provided. Initializing random model for demo...")
        from config import gpt_tiny

        config = gpt_tiny()
        model = GPT300M(config)
        tokenizer = BPETokenizer(vocab_size=config.vocab_size)
        # Quick train on minimal data
        tokenizer.train("Hello! How are you? I am fine. " * 100)

    # CLI flags override the config's generation defaults.
    config.temperature = args.temperature
    config.top_k = args.top_k
    config.top_p = args.top_p
    config.max_new_tokens = args.max_tokens

    chatbot = ChatBot(model, tokenizer, config, device=args.device)
    interactive_chat(chatbot)