"""Text-generation demo for a trained ChatGCLM checkpoint.

Locates a ``ChatGCLM_*.pt`` checkpoint in the current directory, loads it,
streams sampled text for a few fixed demo prompts, then enters an
interactive prompt loop until the user types 'exit'.
"""
import os

import torch
import torch.nn.functional as F
import tiktoken

from model import ChatGCLM, MAX_SEQ_LEN

# Pick the first file matching ChatGCLM_*.pt in the working directory.
MODEL_PATH = next(
    (f for f in os.listdir(".") if f.startswith("ChatGCLM_") and f.endswith(".pt")),
    None,
)
if MODEL_PATH is None:
    print("Error: No model checkpoint found!")
    print("Please train the model first with: python3 train_chatgclm.py")
    exit(1)

TOKENIZER_NAME = "gpt2"
# NOTE(review): token id 2 is NOT the gpt2 <|endoftext|> id (50256) —
# confirm this matches the EOS id actually used during training.
EOS_ID = 2


def load_model(device):
    """Load the ChatGCLM checkpoint and its tokenizer.

    Args:
        device: torch device string ("cuda", "mps", or "cpu").

    Returns:
        ``(model, tokenizer)`` on success, ``(None, None)`` if the
        checkpoint file is missing or empty.
    """
    tok = tiktoken.get_encoding(TOKENIZER_NAME)
    vocab_size = tok.n_vocab
    model = ChatGCLM(vocab_size).to(device)

    # Guard against a missing or zero-byte (truncated) checkpoint file.
    if not (os.path.exists(MODEL_PATH) and os.path.getsize(MODEL_PATH) > 0):
        print(f"Error: Could not load model from {MODEL_PATH}")
        return None, None

    print(f"Loading model from: {MODEL_PATH}")
    # NOTE(review): plain torch.load unpickles arbitrary code; consider
    # torch.load(..., weights_only=True) if the installed torch supports it.
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    model.eval()
    return model, tok


@torch.no_grad()
def generate(model, prompt, tokenizer, device, max_new_tokens=200,
             temperature=0.8, top_k=50):
    """Autoregressively sample text from the model, streaming to stdout.

    Args:
        model: the ChatGCLM language model (already on ``device``).
        prompt: seed text to condition generation on.
        tokenizer: tiktoken encoding used for encode/decode.
        device: torch device string the input tensor is placed on.
        max_new_tokens: maximum number of tokens to sample.
        temperature: softmax temperature; lower is greedier.
        top_k: if not None, sample only from the k highest-probability
            tokens at each step.

    Returns:
        The decoded generated text (prompt excluded).
    """
    model.eval()  # redundant after load_model, but harmless and defensive
    input_ids = tokenizer.encode(prompt)
    x = torch.tensor([input_ids], dtype=torch.long, device=device)

    print(f"\n{'='*70}")
    print(f"PROMPT: {prompt}")
    print(f"{'='*70}")
    print("GENERATED TEXT:")
    print(prompt, end="", flush=True)

    generated_tokens = []
    for _ in range(max_new_tokens):
        # Keep only the most recent MAX_SEQ_LEN tokens as model context.
        ctx = x[:, -MAX_SEQ_LEN:] if x.size(1) > MAX_SEQ_LEN else x
        logits = model(ctx)
        # Score only the final position; temperature rescales the logits.
        next_token_logits = logits[:, -1, :] / temperature

        if top_k is not None:
            # Mask everything below the k-th largest logit to -inf so
            # softmax assigns it zero probability.
            v, _ = torch.topk(next_token_logits,
                              min(top_k, next_token_logits.size(-1)))
            next_token_logits[next_token_logits < v[:, [-1]]] = -float('Inf')

        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        idx = next_token.item()

        # Stop before emitting the end-of-sequence token.
        if idx == EOS_ID:
            break

        x = torch.cat((x, next_token), dim=1)
        generated_tokens.append(idx)
        # Stream each token as it is sampled.
        print(tokenizer.decode([idx]), end="", flush=True)

    print(f"\n{'='*70}\n")
    return tokenizer.decode(generated_tokens)


if __name__ == "__main__":
    # Prefer CUDA, then Apple-silicon MPS, then fall back to CPU.
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    print(f"Using device: {device}")

    model, tokenizer = load_model(device)
    if model is None:
        exit(1)

    test_prompts = [
        "Once upon a time",
        "The future of AI is",
        "In a world where",
    ]

    print("\n" + "="*70)
    print("ChatGCLM Text Generation Demo")
    print("="*70)

    for prompt in test_prompts:
        generate(model, prompt, tokenizer, device,
                 max_new_tokens=150, temperature=0.8, top_k=50)

    print("\n" + "="*70)
    print("Interactive Mode - Enter your own prompts!")
    print("="*70)

    while True:
        user_prompt = input("\nEnter prompt (or 'exit' to quit): ")
        if user_prompt.lower() == 'exit':
            break
        if user_prompt.strip():
            generate(model, user_prompt, tokenizer, device,
                     max_new_tokens=200, temperature=0.8, top_k=50)