Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import time | |
| from typing import Dict, Tuple | |
| from transformers import GPT2LMHeadModel, GPT2Tokenizer | |
| # ============================================================================= | |
| # LLM Decoding Strategy Analyzer | |
| # ============================================================================= | |
| # This application demonstrates 5 different text generation decoding strategies | |
| # using GPT-2, allowing users to compare outputs side-by-side. | |
| # | |
| # Research Foundation: | |
| # - Holtzman et al. (2019) "The Curious Case of Neural Text Degeneration" | |
| # https://arxiv.org/abs/1904.09751 | |
| # - Meister et al. (2020) "If beam search is the answer, what was the question?" | |
| # https://arxiv.org/abs/2010.02650 | |
| # - Basu et al. (2020) "Mirostat: A Neural Text Decoding Algorithm" | |
| # https://arxiv.org/abs/2007.14966 | |
| # ============================================================================= | |
# ----- Model Loading -----
# Loaded once at module import so every Gradio request reuses the same
# model/tokenizer instances.
print("Loading GPT-2 model...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# GPT-2 ships without a pad token; reuse EOS so batched calls can pad.
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained("gpt2")
model = model.to(device)
model.eval()  # inference only: disables dropout
print(f"β Model loaded on {device}")
# ----- Strategy Information -----
# Display metadata for each decoding strategy, keyed by the same strategy
# identifiers used in generate_with_strategy(). The "params" strings are
# human-readable summaries shown in the UI; the actual generation kwargs
# live in generate_with_strategy's generation_configs.
STRATEGY_INFO = {
    "greedy": {
        "name": "Greedy Decoding",
        "description": "Always selects the highest probability token. Deterministic but often repetitive.",
        "params": "do_sample=False"
    },
    "beam": {
        "name": "Beam Search",
        "description": "Explores multiple hypotheses simultaneously. Deterministic, better than greedy but still conservative.",
        "params": "num_beams=5, no_repeat_ngram_size=2"
    },
    "top_k": {
        "name": "Top-K Sampling",
        "description": "Randomly samples from the K most likely tokens. Adds variety but K is fixed regardless of distribution.",
        "params": "top_k=50, temperature=1.0"
    },
    "top_p": {
        "name": "Top-P (Nucleus) Sampling",
        "description": "Samples from the smallest set of tokens whose cumulative probability exceeds P. Adapts to distribution shape.",
        "params": "top_p=0.95, temperature=1.0"
    },
    "temperature": {
        "name": "Temperature + Top-P",
        "description": "Scales logits before sampling. Lower temperature = more focused. Combined with top-p for quality.",
        "params": "temperature=0.7, top_p=0.95"
    }
}
# ----- Generation Functions -----
def generate_with_strategy(prompt: str, strategy: str, max_new_tokens: int = 100) -> Tuple[str, float]:
    """Generate text from *prompt* with one named decoding strategy.

    Args:
        prompt: Input text the model continues from.
        strategy: One of "greedy", "beam", "top_k", "top_p", "temperature".
        max_new_tokens: Upper bound on tokens generated beyond the prompt.

    Returns:
        Tuple of ``(generated_text, seconds_elapsed)``; the text excludes
        the prompt tokens.

    Raises:
        ValueError: If *strategy* is not a recognized strategy name
            (previously this surfaced as an uninformative KeyError).
    """
    generation_configs = {
        "greedy": {
            "do_sample": False,
            "num_beams": 1,
        },
        "beam": {
            "do_sample": False,
            "num_beams": 5,
            "early_stopping": True,
            "no_repeat_ngram_size": 2,  # suppress verbatim bigram repeats
        },
        "top_k": {
            "do_sample": True,
            "top_k": 50,
            "top_p": 1.0,   # disable nucleus filtering
            "temperature": 1.0,
        },
        "top_p": {
            "do_sample": True,
            "top_k": 0,     # disable top-k filtering
            "top_p": 0.95,
            "temperature": 1.0,
        },
        "temperature": {
            "do_sample": True,
            "top_k": 0,
            "top_p": 0.95,
            "temperature": 0.7,
        },
    }
    config = generation_configs.get(strategy)
    if config is None:
        raise ValueError(
            f"Unknown strategy {strategy!r}; expected one of {sorted(generation_configs)}"
        )
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    input_length = inputs["input_ids"].shape[1]
    # CUDA kernels launch asynchronously; synchronize before and after so the
    # wall-clock timer measures the generation itself, not kernel enqueueing.
    if device.type == "cuda":
        torch.cuda.synchronize()
    start_time = time.perf_counter()
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token
            **config
        )
    if device.type == "cuda":
        torch.cuda.synchronize()
    end_time = time.perf_counter()
    # Slice off the prompt so only newly generated tokens are decoded.
    generated_tokens = outputs[0][input_length:]
    generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    generation_time = end_time - start_time
    return generated_text, generation_time
def generate_all_strategies(prompt: str, max_new_tokens: int = 100) -> Dict[str, Dict]:
    """Run every decoding strategy on *prompt* and collect per-strategy metrics.

    Returns a dict keyed by strategy name; each value holds the generated
    text, elapsed seconds, token count, and throughput in tokens/second.
    """
    results: Dict[str, Dict] = {}
    for name in ("greedy", "beam", "top_k", "top_p", "temperature"):
        text, elapsed = generate_with_strategy(prompt, name, max_new_tokens)
        # Re-encode the output to count tokens actually produced.
        token_count = len(tokenizer.encode(text))
        throughput = token_count / elapsed if elapsed > 0 else 0
        results[name] = {
            "text": text,
            "time": elapsed,
            "tokens": token_count,
            "tokens_per_second": throughput,
        }
    return results
def run_all_strategies(prompt: str, max_tokens: int) -> tuple:
    """Drive all 5 decoding strategies and format the results for Gradio.

    Returns a 6-tuple: one markdown panel per strategy (greedy, beam,
    top_k, top_p, temperature) plus a markdown performance table. On a
    blank prompt or any generation error, all six slots carry the same
    warning/error message.
    """
    if not prompt.strip():
        empty_msg = "β οΈ Please enter a prompt."
        return (empty_msg,) * 6
    try:
        results = generate_all_strategies(prompt, max_new_tokens=int(max_tokens))
        panels = []
        table_rows = ["| Strategy | Time | Tokens | Speed |", "|---|---|---|---|"]
        for key in ("greedy", "beam", "top_k", "top_p", "temperature"):
            data = results[key]
            info = STRATEGY_INFO[key]
            panels.append(
                f"**{info['name']}**\n\n"
                f"Parameters: `{info['params']}`\n\n"
                f"β±οΈ {data['time']:.2f}s | π {data['tokens']} tokens | β‘ {data['tokens_per_second']:.1f} tok/s\n\n"
                "---\n\n"
                f"{data['text']}"
            )
            table_rows.append(
                f"| {info['name']} | {data['time']:.2f}s | {data['tokens']} | {data['tokens_per_second']:.1f} tok/s |"
            )
        # Sixth output slot is the aggregated performance table.
        panels.append("\n".join(table_rows))
        return tuple(panels)
    except Exception as e:
        error_msg = f"β Error: {str(e)}"
        return (error_msg,) * 6
# ----- Gradio Interface -----
# Layout: header, prompt + token-count controls, example prompts, five
# markdown panels (one per strategy), a performance summary table, and a
# collapsible explanation accordion.
demo = gr.Blocks(theme=gr.themes.Soft())
with demo:
    gr.Markdown("""
# π¬ LLM Decoding Strategy Analyzer
Compare 5 text generation decoding strategies side-by-side using GPT-2.
**Research Foundation:**
- Holtzman et al. (2019) - [The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751)
- Meister et al. (2020) - [If beam search is the answer, what was the question?](https://arxiv.org/abs/2010.02650)
""")
    with gr.Row():
        with gr.Column(scale=3):
            prompt_input = gr.Textbox(
                label="Enter your prompt",
                placeholder="In a distant galaxy, a lone astronaut discovered",
                lines=2
            )
        with gr.Column(scale=1):
            # Bounds match generate_with_strategy's max_new_tokens usage.
            max_tokens_slider = gr.Slider(
                minimum=20,
                maximum=200,
                value=100,
                step=10,
                label="Max New Tokens"
            )
    generate_btn = gr.Button("π Generate All", variant="primary")
    gr.Examples(
        examples=[
            ["In a distant galaxy, a lone astronaut discovered"],
            ["The secret to happiness is"],
            ["In the year 2050, artificial intelligence"],
            ["She opened the ancient book and read the first line:"],
            ["The most important scientific discovery of the century was"],
        ],
        inputs=prompt_input,
        label="Example Prompts"
    )
    gr.Markdown("---")
    gr.Markdown("## π Generation Results")
    gr.Markdown("### Deterministic Methods")
    with gr.Row():
        greedy_output = gr.Markdown(label="Greedy Decoding")
        beam_output = gr.Markdown(label="Beam Search")
    gr.Markdown("### Stochastic Sampling Methods")
    with gr.Row():
        topk_output = gr.Markdown(label="Top-K Sampling")
        topp_output = gr.Markdown(label="Top-P (Nucleus) Sampling")
        temp_output = gr.Markdown(label="Temperature + Top-P")
    gr.Markdown("### β±οΈ Performance Summary")
    summary_output = gr.Markdown()
    with gr.Accordion("π Strategy Explanations", open=False):
        gr.Markdown("""
| Strategy | How It Works | Pros | Cons |
|----------|--------------|------|------|
| **Greedy** | Always picks highest probability token | Fast, deterministic | Repetitive, boring |
| **Beam Search** | Tracks top-k hypotheses simultaneously | More coherent than greedy | Still conservative, slow |
| **Top-K Sampling** | Samples from K most likely tokens | Adds creativity | Fixed K ignores distribution shape |
| **Top-P (Nucleus)** | Samples from smallest set with cumulative prob β₯ p | Adapts to context | Slightly slower |
| **Temperature + Top-P** | Scales logits then applies top-p | Best quality/creativity balance | Requires tuning |
**Key Insight:** Deterministic methods (greedy, beam) maximize probability but produce dull text.
Sampling methods introduce controlled randomness for more human-like output.
""")
    # Wire the button: outputs order must match run_all_strategies' 6-tuple
    # (greedy, beam, top_k, top_p, temperature, summary).
    generate_btn.click(
        fn=run_all_strategies,
        inputs=[prompt_input, max_tokens_slider],
        outputs=[greedy_output, beam_output, topk_output, topp_output, temp_output, summary_output]
    )
if __name__ == "__main__":
    # 0.0.0.0 makes the app reachable from outside a container;
    # 7860 is the conventional Gradio / HF Spaces port.
    demo.launch(server_name="0.0.0.0", server_port=7860)