import gradio as gr
import torch
import time
from typing import Dict, Tuple
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# =============================================================================
# LLM Decoding Strategy Analyzer
# =============================================================================
# This application demonstrates 5 different text generation decoding strategies
# using GPT-2, allowing users to compare outputs side-by-side.
#
# Research Foundation:
# - Holtzman et al. (2019) "The Curious Case of Neural Text Degeneration"
#   https://arxiv.org/abs/1904.09751
# - Meister et al. (2020) "If beam search is the answer, what was the question?"
#   https://arxiv.org/abs/2010.02650
# - Basu et al. (2020) "Mirostat: A Neural Text Decoding Algorithm"
#   https://arxiv.org/abs/2007.14966
# =============================================================================

# ----- Model Loading -----

print("Loading GPT-2 model...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

model = GPT2LMHeadModel.from_pretrained("gpt2")
model = model.to(device)
model.eval()
print(f"✅ Model loaded on {device}")

# ----- Strategy Information -----

STRATEGY_INFO = {
    "greedy": {
        "name": "Greedy Decoding",
        "description": "Always selects the highest-probability token. Deterministic but often repetitive.",
        "params": "do_sample=False",
    },
    "beam": {
        "name": "Beam Search",
        "description": "Explores multiple hypotheses simultaneously. Deterministic; better than greedy but still conservative.",
        "params": "num_beams=5, no_repeat_ngram_size=2",
    },
    "top_k": {
        "name": "Top-K Sampling",
        "description": "Randomly samples from the K most likely tokens. Adds variety, but K is fixed regardless of the distribution.",
        "params": "top_k=50, temperature=1.0",
    },
    "top_p": {
        "name": "Top-P (Nucleus) Sampling",
        "description": "Samples from the smallest set of tokens whose cumulative probability exceeds P. Adapts to the distribution's shape.",
        "params": "top_p=0.95, temperature=1.0",
    },
    "temperature": {
        "name": "Temperature + Top-P",
        "description": "Scales logits before sampling. Lower temperature = more focused. Combined with top-p for quality.",
        "params": "temperature=0.7, top_p=0.95",
    },
}

# ----- Generation Functions -----

def generate_with_strategy(prompt: str, strategy: str, max_new_tokens: int = 100) -> Tuple[str, float]:
    """Generate text using a specified decoding strategy."""
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    input_length = inputs["input_ids"].shape[1]

    generation_configs = {
        "greedy": {
            "do_sample": False,
            "num_beams": 1,
        },
        "beam": {
            "do_sample": False,
            "num_beams": 5,
            "early_stopping": True,
            "no_repeat_ngram_size": 2,
        },
        "top_k": {
            "do_sample": True,
            "top_k": 50,
            "top_p": 1.0,
            "temperature": 1.0,
        },
        "top_p": {
            "do_sample": True,
            "top_k": 0,  # disable top-k so only nucleus filtering applies
            "top_p": 0.95,
            "temperature": 1.0,
        },
        "temperature": {
            "do_sample": True,
            "top_k": 0,
            "top_p": 0.95,
            "temperature": 0.7,
        },
    }
    config = generation_configs[strategy]

    if device.type == "cuda":
        torch.cuda.synchronize()
    start_time = time.perf_counter()

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
            **config,
        )

    if device.type == "cuda":
        torch.cuda.synchronize()
    end_time = time.perf_counter()

    generated_tokens = outputs[0][input_length:]
    generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    generation_time = end_time - start_time
    return generated_text, generation_time

def generate_all_strategies(prompt: str, max_new_tokens: int = 100) -> Dict[str, Dict]:
    """Generate text using all 5 strategies and return results."""
    strategies = ["greedy", "beam", "top_k", "top_p", "temperature"]
    results = {}
    for strategy in strategies:
        text, gen_time = generate_with_strategy(prompt, strategy, max_new_tokens)
        tokens_generated = len(tokenizer.encode(text))
        results[strategy] = {
            "text": text,
            "time": gen_time,
            "tokens": tokens_generated,
            "tokens_per_second": tokens_generated / gen_time if gen_time > 0 else 0,
        }
    return results

def run_all_strategies(prompt: str, max_tokens: int) -> tuple:
    """Run all 5 decoding strategies and return formatted outputs for Gradio."""
    if not prompt.strip():
        empty_msg = "⚠️ Please enter a prompt."
        return (empty_msg,) * 6

    try:
        results = generate_all_strategies(prompt, max_new_tokens=int(max_tokens))

        outputs = []
        summary_lines = ["| Strategy | Time | Tokens | Speed |", "|---|---|---|---|"]
        for strategy in ["greedy", "beam", "top_k", "top_p", "temperature"]:
            data = results[strategy]
            info = STRATEGY_INFO[strategy]
            output_text = f"**{info['name']}**\n\n"
            output_text += f"Parameters: `{info['params']}`\n\n"
            output_text += f"⏱️ {data['time']:.2f}s | 📝 {data['tokens']} tokens | ⚡ {data['tokens_per_second']:.1f} tok/s\n\n"
            output_text += "---\n\n"
            output_text += f"{data['text']}"
            outputs.append(output_text)
            summary_lines.append(
                f"| {info['name']} | {data['time']:.2f}s | {data['tokens']} | {data['tokens_per_second']:.1f} tok/s |"
            )

        summary = "\n".join(summary_lines)
        outputs.append(summary)
        return tuple(outputs)
    except Exception as e:
        error_msg = f"❌ Error: {str(e)}"
        return (error_msg,) * 6

# ----- Gradio Interface -----

demo = gr.Blocks(theme=gr.themes.Soft())

with demo:
    gr.Markdown("""
    # 🔬 LLM Decoding Strategy Analyzer

    Compare 5 text generation decoding strategies side-by-side using GPT-2.

    **Research Foundation:**
    - Holtzman et al. (2019) - [The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751)
    - Meister et al. (2020) - [If beam search is the answer, what was the question?](https://arxiv.org/abs/2010.02650)
    """)

    with gr.Row():
        with gr.Column(scale=3):
            prompt_input = gr.Textbox(
                label="Enter your prompt",
                placeholder="In a distant galaxy, a lone astronaut discovered",
                lines=2,
            )
        with gr.Column(scale=1):
            max_tokens_slider = gr.Slider(
                minimum=20, maximum=200, value=100, step=10,
                label="Max New Tokens",
            )

    generate_btn = gr.Button("🚀 Generate All", variant="primary")

    gr.Examples(
        examples=[
            ["In a distant galaxy, a lone astronaut discovered"],
            ["The secret to happiness is"],
            ["In the year 2050, artificial intelligence"],
            ["She opened the ancient book and read the first line:"],
            ["The most important scientific discovery of the century was"],
        ],
        inputs=prompt_input,
        label="Example Prompts",
    )

    gr.Markdown("---")
    gr.Markdown("## 📊 Generation Results")

    gr.Markdown("### Deterministic Methods")
    with gr.Row():
        greedy_output = gr.Markdown(label="Greedy Decoding")
        beam_output = gr.Markdown(label="Beam Search")

    gr.Markdown("### Stochastic Sampling Methods")
    with gr.Row():
        topk_output = gr.Markdown(label="Top-K Sampling")
        topp_output = gr.Markdown(label="Top-P (Nucleus) Sampling")
        temp_output = gr.Markdown(label="Temperature + Top-P")

    gr.Markdown("### ⏱️ Performance Summary")
    summary_output = gr.Markdown()

    with gr.Accordion("📚 Strategy Explanations", open=False):
        gr.Markdown("""
        | Strategy | How It Works | Pros | Cons |
        |----------|--------------|------|------|
        | **Greedy** | Always picks highest probability token | Fast, deterministic | Repetitive, boring |
        | **Beam Search** | Tracks top-k hypotheses simultaneously | More coherent than greedy | Still conservative, slow |
        | **Top-K Sampling** | Samples from K most likely tokens | Adds creativity | Fixed K ignores distribution shape |
        | **Top-P (Nucleus)** | Samples from smallest set with cumulative prob ≥ p | Adapts to context | Slightly slower |
        | **Temperature + Top-P** | Scales logits then applies top-p | Best quality/creativity balance | Requires tuning |

        **Key Insight:** Deterministic methods (greedy, beam) maximize probability but produce dull text.
        Sampling methods introduce controlled randomness for more human-like output.
        """)

    generate_btn.click(
        fn=run_all_strategies,
        inputs=[prompt_input, max_tokens_slider],
        outputs=[greedy_output, beam_output, topk_output, topp_output, temp_output, summary_output],
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
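
# -----------------------------------------------------------------------------
# Reference sketch (not used by the app above): a minimal, hand-rolled version
# of the nucleus (top-p) filtering that `model.generate(top_p=...)` performs
# internally. `nucleus_filter` is a hypothetical helper added purely for
# illustration; the app itself relies on Hugging Face's built-in implementation.
# -----------------------------------------------------------------------------
def nucleus_filter(probs: torch.Tensor, p: float = 0.95) -> torch.Tensor:
    """Zero out tokens outside the smallest set whose cumulative probability
    reaches p, then renormalize. `probs` is a 1-D tensor summing to 1."""
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Keep every token up to and including the first one that crosses p.
    cutoff = int(torch.searchsorted(cumulative, p).item()) + 1
    keep = sorted_idx[:cutoff]
    filtered = torch.zeros_like(probs)
    filtered[keep] = probs[keep]
    return filtered / filtered.sum()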