#!/usr/bin/env python3 """ Interactive script to test your trained ismAIl model. Load a checkpoint and generate text with custom prompts. """ import torch import json from pathlib import Path import sys from model import ismail, ModelArgs from generation import ( generate_text_simple, generate_text_with_sampling, text_to_token_ids, token_ids_to_text, get_tokenizer, load_checkpoint ) def interactive_generation(model, tokenizer, args): """Interactive mode: continuously prompt for text and generate responses.""" print("\n" + "="*60) print("🎤 INTERACTIVE GENERATION MODE") print("="*60) print("Commands:") print(" - Type your prompt and press Enter to generate") print(" - Type 'quit' or 'exit' to stop") print(" - Type 'params' to change generation parameters") print("="*60 + "\n") # Default generation parameters temperature = 0.8 top_k = 50 max_tokens = 50 use_sampling = True while True: try: prompt = input("\n💬 Prompt: ").strip() if prompt.lower() in ['quit', 'exit', 'q']: print("👋 Goodbye!") break if prompt.lower() == 'params': print("\n⚙️ Current parameters:") print(f" Temperature: {temperature}") print(f" Top-k: {top_k}") print(f" Max tokens: {max_tokens}") print(f" Use sampling: {use_sampling}") try: temp_input = input(f" New temperature (current: {temperature}): ").strip() if temp_input: temperature = float(temp_input) topk_input = input(f" New top-k (current: {top_k}): ").strip() if topk_input: top_k = int(topk_input) tokens_input = input(f" New max tokens (current: {max_tokens}): ").strip() if tokens_input: max_tokens = int(tokens_input) sampling_input = input(f" Use sampling? (y/n, current: {'y' if use_sampling else 'n'}): ").strip() if sampling_input: use_sampling = sampling_input.lower() in ['y', 'yes', 't', 'true'] print("✅ Parameters updated!") except ValueError as e: print(f"❌ Invalid input: {e}") continue if not prompt: print("⚠️ Empty prompt, try again") continue # Tokenize token_ids = text_to_token_ids(prompt, tokenizer) print(f"📝 Input tokens: {token_ids.shape[1]}") # Generate print("🤖 Generating...", end='', flush=True) if use_sampling: generated_ids = generate_text_with_sampling( model=model, idx=token_ids, max_new_tokens=max_tokens, context_size=args.max_seq_len, temperature=temperature, top_k=top_k ) else: generated_ids = generate_text_simple( model=model, idx=token_ids, max_new_tokens=max_tokens, context_size=args.max_seq_len ) # Decode generated_text = token_ids_to_text(generated_ids, tokenizer) print(f"\r🤖 Generated ({generated_ids.shape[1]} tokens):") print(f"\n{generated_text}\n") except KeyboardInterrupt: print("\n\n👋 Interrupted. Goodbye!") break except Exception as e: print(f"\n❌ Error: {e}") import traceback traceback.print_exc() def batch_generation(model, tokenizer, args, prompts): """Generate text for a list of prompts.""" print("\n" + "="*60) print("📋 BATCH GENERATION MODE") print("="*60 + "\n") for i, prompt in enumerate(prompts, 1): print(f"\n--- Prompt {i}/{len(prompts)} ---") print(f"Input: {prompt}") token_ids = text_to_token_ids(prompt, tokenizer) # Generate with sampling generated_ids = generate_text_with_sampling( model=model, idx=token_ids, max_new_tokens=50, context_size=args.max_seq_len, temperature=0.8, top_k=50 ) generated_text = token_ids_to_text(generated_ids, tokenizer) print(f"Output: {generated_text}\n") def main(): # Parse command line arguments if len(sys.argv) < 2: print("Usage: python test_model.py [--interactive] [--prompts \"prompt1\" \"prompt2\" ...]") print("\nExample:") print(" python test_model.py checkpoints/step_55000_expert_2.pt --interactive") print(" python test_model.py checkpoints/step_55000_expert_2.pt --prompts \"Merhaba\" \"Yapay zeka\"") sys.exit(1) checkpoint_path = sys.argv[1] interactive_mode = '--interactive' in sys.argv or '-i' in sys.argv # Extract prompts from command line custom_prompts = [] if '--prompts' in sys.argv: idx = sys.argv.index('--prompts') custom_prompts = [arg for arg in sys.argv[idx+1:] if not arg.startswith('--')] print("="*60) print("🧠 ismAIl Model Testing Script") print("="*60) # Load config config_path = Path("config.json") if config_path.exists(): with open(config_path) as f: config = json.load(f) print(f"✅ Loaded config from {config_path}") args = ModelArgs(**config["model"]) else: print("❌ config.json not found!") sys.exit(1) # Initialize tokenizer tokenizer_name = getattr(args, "tokenizer_name", "gpt2") use_turkish = tokenizer_name.lower() == "turkish" tokenizer = get_tokenizer( use_turkish=use_turkish, tokenizer_name="gpt2" if use_turkish else tokenizer_name ) # Update vocab size if using Turkish tokenizer if use_turkish: from data import TurkishTokenizerWrapper if isinstance(tokenizer, TurkishTokenizerWrapper): if args.vocab_size != tokenizer.n_vocab: print(f"⚠️ Updating vocab_size: {args.vocab_size:,} -> {tokenizer.n_vocab:,}") args.vocab_size = tokenizer.n_vocab # Initialize model print("\n🚀 Initializing model...") model = ismail(args) # Load checkpoint checkpoint_file = Path(checkpoint_path) if checkpoint_file.exists(): load_checkpoint(model, checkpoint_file) else: print(f"❌ Checkpoint not found: {checkpoint_path}") sys.exit(1) model.eval() # Run appropriate mode if interactive_mode: interactive_generation(model, tokenizer, args) elif custom_prompts: batch_generation(model, tokenizer, args, custom_prompts) else: # Default: use some Turkish prompts default_prompts = [ "Merhaba, ben", "Yapay zekanın geleceği", "Bir varmış bir yokmuş", "Türkiye'nin başkenti" ] batch_generation(model, tokenizer, args, default_prompts) if __name__ == "__main__": main()