""" VicAI Text Generation Interactive text generation and sampling utilities. """ import argparse import sys import torch from model import VicAIModel, VicAIConfig, create_vicai_5b from tokenizer import ByteLevelBPETokenizer, BPETokenizer from utils import get_logger def generate_interactive( model, tokenizer, device, max_new_tokens: int = 256, temperature: float = 0.8, top_k: int = 50, top_p: float = 0.9, repetition_penalty: float = 1.1, ): """Interactive text generation loop.""" print("\n" + "=" * 60) print("VicAI Interactive Generation") print("=" * 60) print("Commands:") print(" /quit - Exit the program") print(" /config - Show current generation settings") print(" /temp X - Set temperature (0.1 - 2.0)") print(" /topk X - Set top-k (1 - 100)") print(" /topp X - Set top-p (0.0 - 1.0)") print(" /reppen X - Set repetition penalty (1.0 - 2.0)") print(" /maxlen X - Set max new tokens") print("=" * 60 + "\n") # Current settings settings = { 'temperature': temperature, 'top_k': top_k, 'top_p': top_p, 'repetition_penalty': repetition_penalty, 'max_new_tokens': max_new_tokens, } while True: try: # Get prompt prompt = input("\nPrompt: ").strip() # Handle commands if prompt == '/quit': print("Goodbye!") break if prompt == '/config': print("\nCurrent settings:") for key, value in settings.items(): print(f" {key}: {value}") continue if prompt.startswith('/temp '): try: settings['temperature'] = float(prompt.split()[1]) print(f"Temperature set to {settings['temperature']}") except (ValueError, IndexError): print("Invalid temperature value") continue if prompt.startswith('/topk '): try: settings['top_k'] = int(prompt.split()[1]) print(f"Top-k set to {settings['top_k']}") except (ValueError, IndexError): print("Invalid top-k value") continue if prompt.startswith('/topp '): try: settings['top_p'] = float(prompt.split()[1]) print(f"Top-p set to {settings['top_p']}") except (ValueError, IndexError): print("Invalid top-p value") continue if prompt.startswith('/reppen '): try: settings['repetition_penalty'] = float(prompt.split()[1]) print(f"Repetition penalty set to {settings['repetition_penalty']}") except (ValueError, IndexError): print("Invalid repetition penalty value") continue if prompt.startswith('/maxlen '): try: settings['max_new_tokens'] = int(prompt.split()[1]) print(f"Max new tokens set to {settings['max_new_tokens']}") except (ValueError, IndexError): print("Invalid max new tokens value") continue if not prompt: continue # Encode prompt input_ids = torch.tensor([tokenizer.encode(prompt)], device=device) # Generate print("\nGenerating...") with torch.no_grad(): output_ids = model.generate( input_ids, max_new_tokens=settings['max_new_tokens'], temperature=settings['temperature'], top_k=settings['top_k'], top_p=settings['top_p'], repetition_penalty=settings['repetition_penalty'], eos_token_id=tokenizer.eos_token_id, ) # Decode and print generated_text = tokenizer.decode(output_ids[0].tolist()) # Remove the original prompt from output prompt_text = tokenizer.decode(input_ids[0].tolist()) if generated_text.startswith(prompt_text): generated_text = generated_text[len(prompt_text):].strip() print("\n" + "-" * 60) print("Generated:") print("-" * 60) print(generated_text) print("-" * 60) # Print token info num_tokens = output_ids.shape[1] - input_ids.shape[1] print(f"\nTokens generated: {num_tokens}") except KeyboardInterrupt: print("\n\nInterrupted by user. 
Type /quit to exit.") except Exception as e: print(f"\nError: {e}") def generate_batch( model, tokenizer, prompts: list, device, max_new_tokens: int = 256, temperature: float = 0.8, top_k: int = 50, top_p: float = 0.9, ): """Generate completions for multiple prompts.""" results = [] for prompt in prompts: input_ids = torch.tensor([tokenizer.encode(prompt)], device=device) with torch.no_grad(): output_ids = model.generate( input_ids, max_new_tokens=max_new_tokens, temperature=temperature, top_k=top_k, top_p=top_p, eos_token_id=tokenizer.eos_token_id, ) generated_text = tokenizer.decode(output_ids[0].tolist()) prompt_text = tokenizer.decode(input_ids[0].tolist()) if generated_text.startswith(prompt_text): generated_text = generated_text[len(prompt_text):].strip() results.append({ 'prompt': prompt, 'completion': generated_text, }) return results def benchmark_generation( model, tokenizer, device, num_runs: int = 10, max_new_tokens: int = 128, prompt: str = "The future of artificial intelligence is", ): """Benchmark generation speed.""" import time print(f"\nBenchmarking generation ({num_runs} runs)...") input_ids = torch.tensor([tokenizer.encode(prompt)], device=device) # Warmup with torch.no_grad(): _ = model.generate(input_ids, max_new_tokens=10) torch.cuda.synchronize() # Benchmark times = [] tokens_generated = [] for i in range(num_runs): start = time.time() with torch.no_grad(): output = model.generate( input_ids, max_new_tokens=max_new_tokens, temperature=1.0, ) torch.cuda.synchronize() elapsed = time.time() - start num_tokens = output.shape[1] - input_ids.shape[1] times.append(elapsed) tokens_generated.append(num_tokens) print(f" Run {i+1}: {num_tokens} tokens in {elapsed:.2f}s ({num_tokens/elapsed:.1f} tok/s)") avg_time = sum(times) / len(times) avg_tokens = sum(tokens_generated) / len(tokens_generated) avg_speed = avg_tokens / avg_time print(f"\nAverage: {avg_tokens:.1f} tokens in {avg_time:.2f}s ({avg_speed:.1f} tok/s)") def main(): parser = argparse.ArgumentParser(description='Generate text with VicAI') parser.add_argument('--checkpoint', type=str, required=True, help='Path to model checkpoint') parser.add_argument('--tokenizer', type=str, default='tokenizer.pkl', help='Path to tokenizer') parser.add_argument('--prompt', type=str, default=None, help='Single prompt to generate from') parser.add_argument('--interactive', action='store_true', help='Interactive mode') parser.add_argument('--max-new-tokens', type=int, default=256, help='Maximum tokens to generate') parser.add_argument('--temperature', type=float, default=0.8, help='Sampling temperature') parser.add_argument('--top-k', type=int, default=50, help='Top-k sampling') parser.add_argument('--top-p', type=float, default=0.9, help='Top-p (nucleus) sampling') parser.add_argument('--repetition-penalty', type=float, default=1.1, help='Repetition penalty') parser.add_argument('--benchmark', action='store_true', help='Run generation benchmark') parser.add_argument('--device', type=str, default='cuda', help='Device to use') args = parser.parse_args() # Setup device device = torch.device(args.device if torch.cuda.is_available() else 'cpu') print(f"Using device: {device}") # Load tokenizer print(f"Loading tokenizer from {args.tokenizer}...") # Use ByteLevelBPETokenizer by default (our trained tokenizer) tokenizer = ByteLevelBPETokenizer() tokenizer.load(args.tokenizer) print(f"Tokenizer loaded: {len(tokenizer)} tokens") # Load model print(f"Loading model from {args.checkpoint}...") checkpoint = torch.load(args.checkpoint, 


def main():
    parser = argparse.ArgumentParser(description='Generate text with VicAI')
    parser.add_argument('--checkpoint', type=str, required=True,
                        help='Path to model checkpoint')
    parser.add_argument('--tokenizer', type=str, default='tokenizer.pkl',
                        help='Path to tokenizer')
    parser.add_argument('--prompt', type=str, default=None,
                        help='Single prompt to generate from')
    parser.add_argument('--interactive', action='store_true',
                        help='Interactive mode')
    parser.add_argument('--max-new-tokens', type=int, default=256,
                        help='Maximum tokens to generate')
    parser.add_argument('--temperature', type=float, default=0.8,
                        help='Sampling temperature')
    parser.add_argument('--top-k', type=int, default=50,
                        help='Top-k sampling')
    parser.add_argument('--top-p', type=float, default=0.9,
                        help='Top-p (nucleus) sampling')
    parser.add_argument('--repetition-penalty', type=float, default=1.1,
                        help='Repetition penalty')
    parser.add_argument('--benchmark', action='store_true',
                        help='Run generation benchmark')
    parser.add_argument('--device', type=str, default='cuda',
                        help='Device to use')
    args = parser.parse_args()

    # Setup device; fall back to CPU only when CUDA was requested but is
    # unavailable (so explicit non-CUDA devices are honored as given)
    if args.device.startswith('cuda') and not torch.cuda.is_available():
        print("CUDA not available, falling back to CPU")
        args.device = 'cpu'
    device = torch.device(args.device)
    print(f"Using device: {device}")

    # Load tokenizer
    print(f"Loading tokenizer from {args.tokenizer}...")

    # Use ByteLevelBPETokenizer by default (our trained tokenizer)
    tokenizer = ByteLevelBPETokenizer()
    tokenizer.load(args.tokenizer)
    print(f"Tokenizer loaded: {len(tokenizer)} tokens")

    # Load model
    print(f"Loading model from {args.checkpoint}...")
    checkpoint = torch.load(args.checkpoint, map_location=device)

    # Create model (assuming 5B config)
    model = create_vicai_5b(vocab_size=len(tokenizer))

    # Load weights (training checkpoints nest the state dict under 'model';
    # a bare state dict is used as-is)
    state_dict = checkpoint.get('model', checkpoint)
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()

    print(f"Model loaded: ~{model.get_num_params() / 1e9:.2f}B parameters")

    # Run benchmark if requested
    if args.benchmark:
        benchmark_generation(model, tokenizer, device)
        return

    # Interactive mode
    if args.interactive or args.prompt is None:
        generate_interactive(
            model,
            tokenizer,
            device,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            repetition_penalty=args.repetition_penalty,
        )
    else:
        # Single prompt generation
        print(f"\nPrompt: {args.prompt}")
        print("-" * 60)

        input_ids = torch.tensor([tokenizer.encode(args.prompt)], device=device)

        with torch.no_grad():
            output_ids = model.generate(
                input_ids,
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                top_k=args.top_k,
                top_p=args.top_p,
                repetition_penalty=args.repetition_penalty,
                eos_token_id=tokenizer.eos_token_id,
            )

        generated_text = tokenizer.decode(output_ids[0].tolist())
        prompt_text = tokenizer.decode(input_ids[0].tolist())
        if generated_text.startswith(prompt_text):
            generated_text = generated_text[len(prompt_text):].strip()

        print(generated_text)
        print("-" * 60)

        num_tokens = output_ids.shape[1] - input_ids.shape[1]
        print(f"\nGenerated {num_tokens} tokens")


if __name__ == '__main__':
    main()
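
# Example invocations (assuming this module is saved as generate.py; the
# checkpoint path is a placeholder for your own trained checkpoint):
#
#   python generate.py --checkpoint checkpoints/latest.pt --interactive
#   python generate.py --checkpoint checkpoints/latest.pt --prompt "Hello" --top-p 0.95
#   python generate.py --checkpoint checkpoints/latest.pt --benchmark --device cuda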