File size: 2,328 Bytes
fed242e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# app.py
import gradio as gr
from transformers import pipeline, set_seed
import torch

# Run on the first CUDA device when one is visible; -1 tells the pipeline to use the CPU.
device = -1
if torch.cuda.is_available():
    device = 0
# Build the GPT-2 text-generation pipeline (weights are downloaded the first time).
generator = pipeline(task="text-generation", model="gpt2", device=device)
# Seed the RNGs once at import time so a fresh process samples the same
# sequence of outputs for the same sequence of requests.
set_seed(42)

def generate_blog(title, keywords, max_length, temperature, num_return_sequences):
    """Generate one or more blog-post drafts with the module-level GPT-2 pipeline.

    Args:
        title: Blog title typed by the user; embedded verbatim in the prompt.
        keywords: Keywords / short brief; embedded verbatim in the prompt.
        max_length: Maximum number of *new* tokens to generate. The prompt does
            not consume this budget. The value is additionally capped so that
            prompt + output fit in the model's context window.
        temperature: Sampling temperature; higher values give more varied text.
        num_return_sequences: How many independent drafts to sample.

    Returns:
        The generated continuation(s) without the prompt. Multiple drafts are
        joined with a ``-----`` divider line.

    Raises:
        gr.Error: If the prompt alone already fills the model's context window.
    """
    prompt = (
        f"Blog Title: {title}\nKeywords: {keywords}\n\n"
        "Write a clear, friendly blog post about the above:"
    )
    tokenizer = generator.tokenizer

    # Slider values may arrive as floats; generation needs integer counts.
    num_sequences = max(1, int(num_return_sequences))

    # The model has a fixed number of positions. Prompt + new tokens must fit,
    # so shrink the budget when a long title/brief eats into it.
    prompt_tokens = len(tokenizer(prompt)["input_ids"])
    room = tokenizer.model_max_length - prompt_tokens
    if room <= 0:
        raise gr.Error("Title and keywords are too long for the model's context window.")
    new_tokens = max(1, min(int(max_length), room))

    outputs = generator(
        prompt,
        # max_new_tokens excludes the prompt, unlike max_length, so the slider
        # value maps directly onto the length of the generated post.
        max_new_tokens=new_tokens,
        temperature=float(temperature),
        do_sample=True,
        top_k=50,
        top_p=0.95,
        num_return_sequences=num_sequences,
        # Ask the pipeline for the continuation only, so no string-prefix
        # stripping of the prompt is needed.
        return_full_text=False,
        # GPT-2 defines no pad token; reuse EOS to avoid the per-call warning.
        pad_token_id=tokenizer.eos_token_id,
    )

    posts = [out["generated_text"].strip() for out in outputs]
    # join() of a single item returns that item unchanged, so one draft comes
    # back as-is and several are separated by a divider.
    return "\n\n-----\n\n".join(posts)

# Minimal Gradio UI: three rows of inputs, a button, and one output box.
with gr.Blocks() as demo:
    gr.Markdown("# Simple GPT-2 Blog Generator")

    # Row 1: free-text fields that are embedded in the prompt.
    with gr.Row():
        title_input = gr.Textbox(
            lines=1,
            placeholder="Enter your blog title...",
            label="Blog Title",
        )
        keywords_input = gr.Textbox(
            lines=1,
            placeholder="e.g., sustainable travel, packing tips",
            label="Keywords / Short brief",
        )

    # Row 2: generation length and sampling temperature.
    with gr.Row():
        max_len = gr.Slider(
            value=250,
            minimum=50,
            maximum=800,
            step=10,
            label="Max tokens (approx.)",
        )
        temp = gr.Slider(
            value=0.8,
            minimum=0.1,
            maximum=1.5,
            step=0.1,
            label="Temperature (creativity)",
        )

    # Row 3: how many drafts to sample per click.
    with gr.Row():
        nseq = gr.Slider(value=1, minimum=1, maximum=3, step=1, label="Number of outputs")

    generate_btn = gr.Button("Generate Blog Post")
    output = gr.Textbox(lines=20, label="Generated Blog Post")

    # Order must match generate_blog's positional parameters.
    blog_inputs = [title_input, keywords_input, max_len, temp, nseq]
    generate_btn.click(generate_blog, inputs=blog_inputs, outputs=output)

if __name__ == "__main__":
    # Bind to all network interfaces (e.g. when run inside a container) and
    # do not create a public share link.
    demo.launch(share=False, server_name="0.0.0.0")