# write-with-gpt2 / app.py
# Author: Tralalabs (commit bf42056, "Upload app.py")
import gradio as gr
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
# One-time startup work: fetch the GPT-2 Small tokenizer and weights so every
# request shares them. No device is requested here, so the weights stay on the
# default device.
print("Loading GPT-2 Small (124M)...")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# nn.Module.eval() returns the module itself, so it can be chained; it switches
# off training-only behaviour such as dropout.
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()
print("Model loaded on CPU!")
def generate_text(
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_k: int,
    top_p: float,
    repetition_penalty: float,
    do_sample: bool,
):
    """Continue ``prompt`` with GPT-2 and return the prompt plus its continuation.

    Args:
        prompt: Text to continue. A blank, whitespace-only or ``None`` prompt
            short-circuits with a warning message and the model is not run.
        max_new_tokens: Maximum number of tokens to generate. It is clamped so
            at least one prompt token still fits in the model's context window.
        temperature: Sampling temperature; only used when ``do_sample`` is true.
        top_k: Top-k cutoff; only used when ``do_sample`` is true.
        top_p: Nucleus (top-p) cutoff; only used when ``do_sample`` is true.
        repetition_penalty: Penalty on already-generated tokens (1.0 = none).
        do_sample: Sample from the distribution when true; greedy otherwise.

    Returns:
        The original, untruncated prompt followed by the decoded continuation,
        or a warning string when the prompt is blank.
    """
    if not prompt or not prompt.strip():
        return "⚠️ Please enter a prompt!"

    # UI sliders can hand back whole numbers as floats. generate() and the
    # top-k logits warper expect genuine ints.
    context_size = model.config.n_positions
    max_new_tokens = min(int(max_new_tokens), context_size - 1)
    top_k = int(top_k)

    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    # GPT-2 has a fixed context window of n_positions tokens (1024 for the
    # stock "gpt2" checkpoint). Keep only the most recent prompt tokens so the
    # prompt plus the generation fits. Otherwise long prompts index past the
    # position embeddings and generation fails.
    max_prompt_tokens = context_size - max_new_tokens
    input_ids = input_ids[:, -max_prompt_tokens:]
    attention_mask = attention_mask[:, -max_prompt_tokens:]

    with torch.no_grad():
        output = model.generate(
            input_ids,
            # Explicit mask: pad == eos here, so the mask can't be inferred.
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            # For greedy decoding, pass the GenerationConfig defaults so
            # transformers does not warn about ignored sampling-only settings.
            temperature=temperature if do_sample else 1.0,
            top_k=top_k if do_sample else 50,
            top_p=top_p if do_sample else 1.0,
            repetition_penalty=repetition_penalty,
            do_sample=do_sample,
            # GPT-2's tokenizer defines no pad token; reuse EOS for padding.
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens (everything after the prompt).
    generated_ids = output[0][input_ids.shape[-1]:]
    generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
    # Echo the user's full prompt even if the model only saw its tail.
    return prompt + generated_text
# Gradio UI
# Markdown bodies are kept flush-left at module level so code indentation can
# never be rendered as a Markdown code block.
_HEADER_MARKDOWN = """
# ✍️ Write with GPT-2
**Model:** [openai-community/gpt2](https://huggingface.co/openai-community/gpt2)  | 
**Hardware:** CPU only  | 
**Space by:** [Tralalabs](https://huggingface.co/Tralalabs)

> Classic GPT-2 Small running fully on CPU. Enter a prompt and let it continue!
"""

_TIPS_MARKDOWN = """
---
**Tips:**

- Lower temperature (0.5–0.7) → more focused, coherent text
- Higher temperature (1.0–1.5) → more creative, unpredictable text
- Greedy decoding (uncheck Sampling) → always picks the most likely next token
- Repetition Penalty > 1.0 helps avoid loops
"""

with gr.Blocks(title="Write with GPT-2") as demo:
    gr.Markdown(_HEADER_MARKDOWN)

    with gr.Row():
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Once upon a time...",
                lines=4,
            )
            with gr.Accordion("⚙️ Generation Settings", open=False):
                max_new_tokens = gr.Slider(
                    minimum=10, maximum=300, value=100, step=10,
                    label="Max New Tokens",
                )
                do_sample = gr.Checkbox(
                    value=True, label="Sampling (uncheck for greedy decoding)",
                )
                temperature = gr.Slider(
                    minimum=0.1, maximum=2.0, value=0.9, step=0.05,
                    label="Temperature",
                )
                top_k = gr.Slider(
                    minimum=0, maximum=100, value=50, step=1,
                    label="Top-K",
                )
                top_p = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.95, step=0.01,
                    label="Top-P (nucleus sampling)",
                )
                repetition_penalty = gr.Slider(
                    minimum=1.0, maximum=2.0, value=1.1, step=0.05,
                    label="Repetition Penalty",
                )
            generate_btn = gr.Button("🚀 Generate", variant="primary")
        with gr.Column(scale=1):
            output_text = gr.Textbox(
                label="Generated Text",
                lines=12,
                interactive=False,
            )

    # One shared list, in generate_text's positional parameter order, used by
    # both the examples and the button so the two cannot drift apart.
    generation_inputs = [
        prompt_input,
        max_new_tokens,
        temperature,
        top_k,
        top_p,
        repetition_penalty,
        do_sample,
    ]

    gr.Examples(
        # Each row follows the same order as generation_inputs.
        examples=[
            ["The future of artificial intelligence is", 120, 0.9, 50, 0.95, 1.1, True],
            ["Once upon a time in a land far away,", 150, 1.0, 40, 0.92, 1.2, True],
            ["Scientists recently discovered that", 100, 0.8, 50, 0.95, 1.1, True],
            ["Dear Claude, I wanted to tell you that", 100, 0.95, 60, 0.9, 1.0, True],
        ],
        inputs=generation_inputs,
        outputs=output_text,
        fn=generate_text,
        cache_examples=False,
    )

    generate_btn.click(
        fn=generate_text,
        inputs=generation_inputs,
        outputs=output_text,
    )

    gr.Markdown(_TIPS_MARKDOWN)

if __name__ == "__main__":
    demo.launch()