Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import pipeline, set_seed | |
# Single-slot cache for the text-generation pipeline. It is created on the
# first request (not at import time) and rebuilt when a different model is
# requested, so only the most recently used pipeline stays referenced here.
_generator = None


def get_generator(model_name: str):
    """Return a text-generation pipeline for ``model_name``, building it on demand.

    The pipeline is stored in the module-level ``_generator`` and tagged with
    a ``model_name`` attribute. A later call with the same name reuses it; a
    call with a different name replaces it with a freshly built pipeline.
    """
    global _generator
    if _generator is not None and getattr(_generator, "model_name", None) == model_name:
        # Cache hit: the requested model is already loaded.
        return _generator

    _generator = pipeline(
        "text-generation",
        model=model_name,
        # device_map="auto" is intentionally omitted to avoid a GPU
        # requirement on CPU-only Spaces.
    )
    # Tag the pipeline so the next call can tell which model it serves.
    _generator.model_name = model_name
    return _generator
def generate_text(prompt, model_name, max_new_tokens, temperature, top_p, seed, num_return_sequences):
    """Sample completions for ``prompt`` and return them as a single string.

    Args:
        prompt: Text to continue. ``None``, empty, or whitespace-only prompts
            are rejected with a message instead of being sent to the model.
        model_name: Model id handed to ``get_generator``.
        max_new_tokens: Upper bound on generated tokens (converted with ``int``).
        temperature: Sampling temperature (converted with ``float``).
        top_p: Nucleus-sampling threshold (converted with ``float``).
        seed: Optional RNG seed. ``None``, ``""`` or a whitespace-only string
            means "do not reseed"; otherwise it must convert with ``int`` to a
            value in ``[0, 2**32 - 1]``.
        num_return_sequences: How many completions to sample (converted with ``int``).

    Returns:
        The generated texts joined by a ``---`` separator, or a user-facing
        message explaining why the input was rejected.
    """
    if not prompt or not prompt.strip():
        return "Please enter a non-empty prompt."

    # transformers.set_seed also seeds NumPy, whose seeding only accepts values
    # in [0, 2**32 - 1]. Validate before loading anything so a bad seed is
    # reported to the user instead of being silently ignored.
    max_seed = 2**32 - 1
    invalid_seed_msg = f"Seed must be an integer between 0 and {max_seed}."
    if isinstance(seed, str):
        seed = seed.strip()
    seed_value = None
    if seed is not None and seed != "":
        try:
            seed_value = int(seed)
        except (TypeError, ValueError, OverflowError):
            return invalid_seed_msg
        if not 0 <= seed_value <= max_seed:
            return invalid_seed_msg

    generator = get_generator(model_name)

    # Seed only after the (possibly first-time) model load, so any RNG use
    # during loading cannot make a seeded run differ between a cold and a
    # warm pipeline cache.
    if seed_value is not None:
        set_seed(seed_value)

    outputs = generator(
        prompt,
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
        do_sample=True,
        num_return_sequences=int(num_return_sequences),
        # Use EOS for padding (GPT-2 family tokenizers ship without a pad token).
        pad_token_id=generator.tokenizer.eos_token_id,
    )
    return "\n\n---\n\n".join(o["generated_text"] for o in outputs)
# ---------------------------------------------------------------------------
# User interface. Components render in the order they are created, so the
# statement order below defines the page layout.
# ---------------------------------------------------------------------------
with gr.Blocks(title="Text Generation Demo") as demo:
    gr.Markdown(
        """
# Text Generation Demo
Minimal, education-focused demo using 🤗 Transformers.
- **Models**: pick from lightweight, CPU-friendly models (default: `gpt2`).
- **Use case**: learning and experimentation in NLP (no harmful or restricted use).
"""
    )

    # Row 1: the text to continue.
    with gr.Row():
        prompt = gr.Textbox(
            label="Prompt",
            placeholder="Artificial intelligence is transforming the world because...",
            lines=4,
        )

    # Row 2: model choice and completion length.
    with gr.Row():
        model_name = gr.Dropdown(
            label="Model",
            choices=["gpt2", "distilgpt2", "gpt2-medium"],
            value="gpt2",
        )
        max_new_tokens = gr.Slider(
            minimum=16, maximum=256, value=80, step=1, label="Max new tokens"
        )

    # Row 3: sampling controls.
    with gr.Row():
        temperature = gr.Slider(
            minimum=0.1, maximum=1.5, value=0.8, step=0.05, label="Temperature"
        )
        top_p = gr.Slider(
            minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus)"
        )

    # Row 4: reproducibility and number of samples.
    with gr.Row():
        seed = gr.Textbox(label="Seed (optional)", placeholder="e.g., 42")
        num_return_sequences = gr.Slider(
            minimum=1, maximum=3, value=1, step=1, label="# of completions"
        )

    generate_btn = gr.Button("Generate")
    output = gr.Textbox(label="Output", lines=12)

    # The inputs list is passed positionally, so its order must match the
    # parameter order of generate_text.
    generate_btn.click(
        fn=generate_text,
        inputs=[
            prompt,
            model_name,
            max_new_tokens,
            temperature,
            top_p,
            seed,
            num_return_sequences,
        ],
        outputs=output,
    )

if __name__ == "__main__":
    # Start the Gradio server only when this file is run as a script.
    demo.launch()