# Hugging Face Spaces page-status header (scrape artifact): "Spaces: Running"
| """ | |
| GPT-2 Text Generation — Autoregressive decoding with sampling controls | |
| Courses: 100 ch4, 200 ch4 | |
| """ | |
| import torch | |
| import gradio as gr | |
| from transformers import GPT2LMHeadModel, GPT2Tokenizer | |
| # Load model once | |
| tokenizer = GPT2Tokenizer.from_pretrained("gpt2") | |
| model = GPT2LMHeadModel.from_pretrained("gpt2").eval() | |
| device = torch.device("cpu") | |
| model.to(device) | |
def generate(
    prompt: str,
    max_tokens: int,
    temperature: float,
    top_k: int,
    top_p: float,
    repetition_penalty: float,
    show_token_probs: bool,
):
    """Generate a sampled continuation of *prompt* with GPT-2.

    Parameters mirror the UI controls:
        prompt: seed text; blank/whitespace-only input short-circuits.
        max_tokens: number of new tokens to sample.
        temperature: softmax temperature (clamped to >= 0.01).
        top_k: top-k filtering; 0 disables the limit.
        top_p: nucleus-sampling cumulative-probability cutoff.
        repetition_penalty: >1.0 penalizes already-seen tokens.
        show_token_probs: also build a top-10 next-token probability table.

    Returns:
        (display_markdown, probs_markdown) — the generated text with the
        prompt bolded, and an optional markdown table (empty string if
        show_token_probs is False).
    """
    if not prompt.strip():
        return "", ""

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    input_len = inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            # Sliders deliver floats; generate() expects ints here.
            max_new_tokens=int(max_tokens),
            temperature=max(temperature, 0.01),  # avoid div by zero at T=0
            # BUGFIX: pass 0 (not None) to disable top-k. With top_k=None,
            # transformers falls back to the GenerationConfig default (50),
            # so "0 = no limit" silently behaved as top-k=50.
            top_k=int(top_k),
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token
        )

    # Decode only the newly generated tokens and bold the original prompt.
    # (Decoding the full sequence separately was dead code — removed.)
    generated_part = tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True)
    display = f"**{prompt}**{generated_part}"

    # Token probabilities for the first generated position (teaching aid).
    probs_text = ""
    if show_token_probs:
        with torch.no_grad():
            logits = model(**inputs).logits[0, -1, :]  # last prompt position
            probs = torch.softmax(logits / max(temperature, 0.01), dim=0)
            top10 = torch.topk(probs, 10)
            probs_text = "**Next-token probabilities (first position):**\n\n"
            probs_text += "| Token | Probability |\n|---|---|\n"
            for prob, idx in zip(top10.values, top10.indices):
                token = tokenizer.decode([int(idx)])
                bar_len = int(float(prob) * 40)  # scale prob to a 40-char text bar
                bar = "█" * bar_len
                probs_text += f"| `{token}` | {bar} {float(prob):.2%} |\n"
    return display, probs_text
| with gr.Blocks(title="GPT-2 Text Generation") as demo: | |
| gr.Markdown( | |
| "# GPT-2 Text Generation\n" | |
| "Enter a prompt and experiment with decoding parameters to see how they affect output.\n" | |
| "*Courses: 100 Deep Learning ch4, 200 Transformer ch4*" | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| prompt_input = gr.Textbox( | |
| label="Prompt", | |
| placeholder="Once upon a time, in a land far away...", | |
| lines=3, | |
| ) | |
| max_tokens = gr.Slider(10, 200, value=80, step=10, label="Max New Tokens") | |
| temperature = gr.Slider(0.1, 2.0, value=0.8, step=0.1, label="Temperature") | |
| top_k = gr.Slider(0, 100, value=50, step=5, label="Top-k (0 = no limit)") | |
| top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p (nucleus)") | |
| rep_penalty = gr.Slider(1.0, 2.0, value=1.2, step=0.1, label="Repetition Penalty") | |
| show_probs = gr.Checkbox(value=True, label="Show token probabilities") | |
| btn = gr.Button("Generate", variant="primary") | |
| with gr.Column(scale=2): | |
| output_md = gr.Markdown(label="Generated Text") | |
| probs_md = gr.Markdown(label="Token Probabilities") | |
| btn.click( | |
| generate, | |
| [prompt_input, max_tokens, temperature, top_k, top_p, rep_penalty, show_probs], | |
| [output_md, probs_md], | |
| ) | |
| gr.Examples( | |
| examples=[ | |
| ["The meaning of life is", 80, 0.8, 50, 0.9, 1.2, True], | |
| ["In the year 2050, artificial intelligence", 100, 0.7, 40, 0.95, 1.1, True], | |
| ["def fibonacci(n):\n ", 60, 0.5, 30, 0.9, 1.3, True], | |
| ["Once upon a time, a robot learned to", 100, 1.0, 0, 0.9, 1.0, True], | |
| ], | |
| inputs=[prompt_input, max_tokens, temperature, top_k, top_p, rep_penalty, show_probs], | |
| ) | |
| with gr.Accordion("Parameter Guide", open=False): | |
| gr.Markdown(""" | |
| **Temperature**: Controls randomness. Low (0.1-0.3) = focused/repetitive. High (1.0-2.0) = creative/chaotic. | |
| **Top-k**: Only consider the top-k most likely tokens. Lower = more focused. 0 = no limit. | |
| **Top-p (nucleus sampling)**: Only consider tokens whose cumulative probability exceeds p. Lower = more focused. | |
| **Repetition Penalty**: Penalizes tokens that already appeared. >1.0 reduces repetition. | |
| Try these experiments: | |
| 1. Set temperature=0.1 → very deterministic, same output each time | |
| 2. Set temperature=2.0 → chaotic, often incoherent | |
| 3. Set top_k=5 → very restricted vocabulary | |
| 4. Compare top_p=0.5 vs top_p=1.0 | |
| """) | |
| if __name__ == "__main__": | |
| demo.launch() | |