File size: 2,847 Bytes
2d0a45d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import gradio as gr
from transformers import pipeline, set_seed

# Lazily-built, per-model pipeline cache. Keeping one entry per model name
# (instead of a single slot) means switching between models in the UI does
# not throw away and reload a pipeline that was already built.
_pipelines = {}

def get_generator(model_name: str):
    """Return a cached text-generation pipeline for ``model_name``.

    The pipeline is constructed on the first request for a given model and
    reused on every subsequent call, so Spaces can warm it up on first run.
    Device placement is left at the default (CPU) to avoid a GPU requirement
    in CPU Spaces.
    """
    if model_name not in _pipelines:
        _pipelines[model_name] = pipeline(
            "text-generation",
            model=model_name,
            # device_map="auto"  # Commented to avoid GPU requirement in CPU Spaces
        )
    return _pipelines[model_name]

def generate_text(prompt, model_name, max_new_tokens, temperature, top_p, seed, num_return_sequences):
    """Generate one or more completions for ``prompt`` and return them joined as text.

    Parameters mirror the UI controls: ``model_name`` selects the cached
    pipeline, ``max_new_tokens``/``temperature``/``top_p`` steer sampling,
    ``seed`` (optional free text) makes sampling reproducible, and
    ``num_return_sequences`` controls how many completions are produced.
    Returns a user-facing error string for an empty prompt.
    """
    if not prompt or not prompt.strip():
        return "Please enter a non-empty prompt."
    # Seed comes from a free-text box: strip whitespace so " 42 " still
    # works, and only catch parse errors — a non-numeric seed degrades to
    # unseeded sampling instead of masking unrelated failures.
    if seed is not None and str(seed).strip():
        try:
            set_seed(int(str(seed).strip()))
        except (ValueError, TypeError):
            pass
    generator = get_generator(model_name)
    outputs = generator(
        prompt,
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
        do_sample=True,
        num_return_sequences=int(num_return_sequences),
        # Explicit pad token silences the "pad_token_id not set" warning for GPT-2.
        pad_token_id=generator.tokenizer.eos_token_id,
    )
    # Multiple completions are separated by a visible horizontal rule.
    return "\n\n---\n\n".join(o["generated_text"] for o in outputs)

# UI layout: prompt box, sampling controls, and an output box wired to
# generate_text via the Generate button.
with gr.Blocks(title="Text Generation Demo") as demo:
    gr.Markdown(
        """
        # Text Generation Demo
        Minimal, education-focused demo using 🤗 Transformers.

        - **Models**: pick from lightweight, CPU-friendly models (default: `gpt2`).
        - **Use case**: learning and experimentation in NLP (no harmful or restricted use).
        """
    )
    # Prompt input on its own row.
    with gr.Row():
        prompt_box = gr.Textbox(
            label="Prompt",
            placeholder="Artificial intelligence is transforming the world because...",
            lines=4,
        )
    # Model choice and output-length control.
    with gr.Row():
        model_dropdown = gr.Dropdown(
            label="Model",
            choices=["gpt2", "distilgpt2", "gpt2-medium"],
            value="gpt2",
        )
        tokens_slider = gr.Slider(16, 256, value=80, step=1, label="Max new tokens")
    # Sampling hyper-parameters.
    with gr.Row():
        temp_slider = gr.Slider(0.1, 1.5, value=0.8, step=0.05, label="Temperature")
        topp_slider = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p (nucleus)")
    # Reproducibility seed and completion count.
    with gr.Row():
        seed_box = gr.Textbox(label="Seed (optional)", placeholder="e.g., 42")
        n_sequences = gr.Slider(1, 3, value=1, step=1, label="# of completions")
    run_button = gr.Button("Generate")
    output_box = gr.Textbox(label="Output", lines=12)
    # Clicking the button feeds every control into generate_text.
    run_button.click(
        generate_text,
        inputs=[prompt_box, model_dropdown, tokens_slider, temp_slider, topp_slider, seed_box, n_sequences],
        outputs=output_box,
    )

# Start the local Gradio web server only when this file is executed directly.
if __name__ == "__main__":
    demo.launch()