Spaces:
Sleeping
Sleeping
File size: 2,328 Bytes
fed242e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
# app.py
import gradio as gr
from transformers import pipeline, set_seed
import torch
# Build the text-generation pipeline once, at import time, so every request
# reuses the same loaded model (the "gpt2" weights are downloaded on first run).
# Use CUDA device 0 when a GPU is available; otherwise -1 (run on CPU).
device = 0 if torch.cuda.is_available() else -1
generator = pipeline("text-generation", model="gpt2", device=device)
# Seed the RNGs a single time at startup: a fresh process samples reproducibly,
# while successive requests still differ because the RNG state keeps advancing.
set_seed(42)
def generate_blog(title, keywords, max_length, temperature, num_return_sequences):
    """Generate one or more blog-post drafts with the module-level GPT-2 pipeline.

    Args:
        title: Blog title entered by the user.
        keywords: Free-text keywords / short brief.
        max_length: Requested number of *new* tokens to generate. Gradio
            sliders may deliver this as a float, so it is coerced to int.
        temperature: Sampling temperature; higher values are more random.
        num_return_sequences: How many independent drafts to sample
            (coerced to int for the same reason as ``max_length``).

    Returns:
        The generated continuation(s) without the prompt. Multiple drafts are
        joined with a "-----" separator line. If the prompt alone already
        fills the model's context window, a short message asking the user to
        shorten the input is returned instead of calling the model.
    """
    # Build a short prompt for the model
    prompt = f"Blog Title: {title}\nKeywords: {keywords}\n\nWrite a clear, friendly blog post about the above:"

    tokenizer = generator.tokenizer
    # In transformers, `max_length` counts the prompt tokens as well, so a long
    # title/brief used to eat the budget (or leave nothing to generate). Ask for
    # `max_new_tokens` instead, clamped so prompt + continuation still fits the
    # model's context window (tokenizer.model_max_length, 1024 for GPT-2).
    prompt_tokens = len(tokenizer.encode(prompt))
    new_tokens = min(int(max_length), tokenizer.model_max_length - prompt_tokens)
    if new_tokens < 1:
        return "The title and keywords are too long for the model. Please shorten them and try again."

    outputs = generator(
        prompt,
        max_new_tokens=new_tokens,
        temperature=float(temperature),
        do_sample=True,
        top_k=50,
        top_p=0.95,
        num_return_sequences=int(num_return_sequences),
        # Return only the continuation, so no manual prompt stripping is needed.
        return_full_text=False,
        # GPT-2 has no pad token; setting it explicitly avoids a runtime warning.
        pad_token_id=tokenizer.eos_token_id,
    )

    posts = [out["generated_text"].strip() for out in outputs]
    # A single draft comes back unchanged; several are separated by a divider.
    return "\n\n-----\n\n".join(posts)
# Minimal Gradio UI: inputs on top, generation controls below, one output box.
with gr.Blocks() as demo:
    gr.Markdown("# Simple GPT-2 Blog Generator")

    # Row 1: what the post is about.
    with gr.Row():
        title_input = gr.Textbox(
            lines=1,
            label="Blog Title",
            placeholder="Enter your blog title...",
        )
        keywords_input = gr.Textbox(
            lines=1,
            label="Keywords / Short brief",
            placeholder="e.g., sustainable travel, packing tips",
        )

    # Row 2: sampling controls.
    with gr.Row():
        max_len = gr.Slider(
            minimum=50,
            maximum=800,
            step=10,
            value=250,
            label="Max tokens (approx.)",
        )
        temp = gr.Slider(
            minimum=0.1,
            maximum=1.5,
            step=0.1,
            value=0.8,
            label="Temperature (creativity)",
        )

    # Row 3: how many drafts, plus the trigger button.
    with gr.Row():
        nseq = gr.Slider(
            minimum=1,
            maximum=3,
            step=1,
            value=1,
            label="Number of outputs",
        )
        generate_btn = gr.Button(value="Generate Blog Post")

    output = gr.Textbox(lines=20, label="Generated Blog Post")

    # Order must match generate_blog's positional parameters.
    blog_inputs = [title_input, keywords_input, max_len, temp, nseq]
    generate_btn.click(generate_blog, blog_inputs, output)
if __name__ == "__main__":
    # Listen on every interface (needed inside a container) without creating
    # a public Gradio share link.
    demo.launch(share=False, server_name="0.0.0.0")
|