# Uploaded by maldons77 in commit 2d0a45d ("Upload 4 files", verified).
import gradio as gr
from transformers import pipeline, set_seed
# Lazy-load the pipeline so Spaces can warm it up on first run.
# Only one pipeline is cached at a time; requesting a different model replaces it.
_generator = None
_generator_model_name = None  # model id that _generator was built from


def get_generator(model_name: str):
    """Return a cached ``text-generation`` pipeline for *model_name*.

    The pipeline is built on first use and reused while the same model keeps
    being requested. Asking for a different model builds a new pipeline and
    replaces the cached one.

    Args:
        model_name: Model identifier forwarded to ``transformers.pipeline``
            (the UI offers e.g. ``"gpt2"``).

    Returns:
        The cached text-generation pipeline.
    """
    global _generator, _generator_model_name
    if _generator is not None and _generator_model_name == model_name:
        return _generator

    # Release our reference to the previous pipeline before loading the next
    # one, so the old model can be garbage-collected instead of both being
    # held at once (memory on CPU Spaces is limited).
    _generator = None
    _generator_model_name = None

    # device_map="auto" is intentionally not passed, to avoid a GPU
    # requirement on CPU-only Spaces (per the original author's note).
    _generator = pipeline("text-generation", model=model_name)
    _generator_model_name = model_name
    return _generator
# transformers.set_seed also seeds NumPy, which only accepts seeds in
# [0, 2**32 - 1]; validating against that range up front means seeding can
# never fail part-way (after Python's `random` was already reseeded).
_MAX_SEED = 2**32 - 1


def _parse_seed(seed):
    """Convert the optional seed input into an int.

    Returns None when no seed was given (None, or an empty/blank string).
    Raises ValueError when a seed was given but is not an integer in
    [0, _MAX_SEED].
    """
    if seed is None:
        return None
    if isinstance(seed, str):
        seed = seed.strip()
        if not seed:
            return None
    try:
        value = int(seed)
    except (TypeError, ValueError, OverflowError) as err:
        raise ValueError(f"invalid seed: {seed!r}") from err
    if not 0 <= value <= _MAX_SEED:
        raise ValueError(f"seed out of range: {value}")
    return value


def generate_text(prompt, model_name, max_new_tokens, temperature, top_p, seed, num_return_sequences):
    """Generate one or more sampled continuations of *prompt*.

    Args:
        prompt: Text to continue; blank prompts are rejected with a message.
        model_name: Model id passed to ``get_generator``.
        max_new_tokens: Number of tokens to generate (converted with ``int``).
        temperature: Sampling temperature (converted with ``float``).
        top_p: Nucleus-sampling probability mass (converted with ``float``).
        seed: Optional seed (string or number). Empty/None means "unseeded";
            a non-integer or out-of-range value is rejected with a message
            instead of being silently ignored.
        num_return_sequences: How many completions to return (``int``).

    Returns:
        The completions joined by a ``---`` separator, or a user-facing
        validation message when the prompt or seed is invalid.
    """
    if not prompt or not prompt.strip():
        return "Please enter a non-empty prompt."

    try:
        seed_value = _parse_seed(seed)
    except ValueError:
        return f"Seed must be a whole number between 0 and {_MAX_SEED}, or left empty."
    if seed_value is not None:
        set_seed(seed_value)

    generator = get_generator(model_name)
    outputs = generator(
        prompt,
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
        do_sample=True,
        num_return_sequences=int(num_return_sequences),
        pad_token_id=generator.tokenizer.eos_token_id,
        # When prompt + new tokens would exceed the model's maximum length,
        # truncate the prompt from the left instead of failing generation.
        handle_long_generation="hole",
    )
    return "\n\n---\n\n".join(o["generated_text"] for o in outputs)
# Models selectable in the dropdown; the first entry is the default.
MODEL_CHOICES = ["gpt2", "distilgpt2", "gpt2-medium"]

# Gradio layout. Components appear in the order they are created below, and
# the Generate button feeds every control into generate_text().
with gr.Blocks(title="Text Generation Demo") as demo:
    gr.Markdown(
        """
# Text Generation Demo
Minimal, education-focused demo using 🤗 Transformers.
- **Models**: pick from lightweight, CPU-friendly models (default: `gpt2`).
- **Use case**: learning and experimentation in NLP (no harmful or restricted use).
"""
    )

    # Row 1: the prompt.
    with gr.Row():
        prompt = gr.Textbox(
            label="Prompt",
            placeholder="Artificial intelligence is transforming the world because...",
            lines=4,
        )

    # Row 2: model selection and output length.
    with gr.Row():
        model_name = gr.Dropdown(
            label="Model",
            choices=list(MODEL_CHOICES),
            value=MODEL_CHOICES[0],
        )
        max_new_tokens = gr.Slider(
            minimum=16, maximum=256, value=80, step=1, label="Max new tokens"
        )

    # Row 3: sampling controls.
    with gr.Row():
        temperature = gr.Slider(
            minimum=0.1, maximum=1.5, value=0.8, step=0.05, label="Temperature"
        )
        top_p = gr.Slider(
            minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus)"
        )

    # Row 4: reproducibility and number of completions.
    with gr.Row():
        seed = gr.Textbox(label="Seed (optional)", placeholder="e.g., 42")
        num_return_sequences = gr.Slider(
            minimum=1, maximum=3, value=1, step=1, label="# of completions"
        )

    generate_btn = gr.Button("Generate")
    output = gr.Textbox(label="Output", lines=12)

    # Order must match generate_text's positional parameters.
    generation_inputs = [
        prompt,
        model_name,
        max_new_tokens,
        temperature,
        top_p,
        seed,
        num_return_sequences,
    ]
    generate_btn.click(fn=generate_text, inputs=generation_inputs, outputs=output)

if __name__ == "__main__":
    demo.launch()