# Story_generator / app.py
# Author: AdarshCodes98 — commit 17a864c ("Upload 2 files")
import os

# Point the Hugging Face Hub client at a mirror to work around connection
# issues (WinError 10054 / ISP blocks). This MUST run before gradio or
# transformers is imported: both import huggingface_hub, which reads
# HF_ENDPOINT once at import time, so setting it afterwards has no effect.
# setdefault keeps any HF_ENDPOINT the user has already exported.
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")

import gradio as gr  # noqa: E402  (must follow the HF_ENDPOINT setup above)
from transformers import pipeline  # noqa: E402

# 1. Initialize the Generator
# MBZUAI/LaMini-GPT-124M is a small model (~124M parameters, per its name),
# so it downloads quickly and runs on a standard CPU.
print("Loading model... this may take a minute on the first run.")
generator = pipeline("text-generation", model="MBZUAI/LaMini-GPT-124M")
def generate_story(prompt, length, temp, top_p):
    """Generate a story for *prompt* using the module-level ``generator`` pipeline.

    The user's text is wrapped in an "### Instruction: / ### Response:"
    template, and the response is seeded with "Once upon a time," so the
    model begins telling a story.

    Args:
        prompt: Story idea / starting text from the UI textbox.
        length: Maximum number of new tokens to generate. It is cast to int
            because UI sliders may deliver a float.
        temp: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The story text (starting with "Once upon a time,"), or a hint
        message when *prompt* is empty, blank or None.
    """
    if not prompt or not prompt.strip():
        return "Please enter a prompt to start the story!"

    story_opening = "Once upon a time,"
    # Present the user's text as an instruction and force the opening words
    # of the response so the model continues a story.
    formatted_prompt = (
        f"### Instruction:\n{prompt}\n\n### Response:\n{story_opening}"
    )

    output = generator(
        formatted_prompt,
        max_new_tokens=int(length),
        temperature=temp,
        top_p=top_p,
        do_sample=True,
        truncation=True,
        # Return only the newly generated text. We then never have to search
        # the echoed prompt for the "### Response:" marker, which picked the
        # wrong segment whenever the user's prompt itself contained it.
        return_full_text=False,
        # Pad with the tokenizer's end-of-sequence token (presumably because
        # this GPT-style tokenizer defines no dedicated pad token).
        pad_token_id=generator.tokenizer.eos_token_id,
    )
    continuation = output[0]["generated_text"]

    # If the model starts another "###" template section (e.g. a new
    # instruction), cut it off so template markup never reaches the user.
    continuation = continuation.split("###", 1)[0]

    return (story_opening + continuation).strip()
# 2. Build the Gradio Interface
# Layout: prompt box + (collapsed) sampling sliders + button on the left,
# generated story on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# ✍️ Lightweight AI Story Generator")
    # Name the model that generate_story actually uses (LaMini-GPT-124M).
    gr.Markdown("Enter a starting sentence and let LaMini-GPT-124M finish the story.")

    with gr.Row():
        with gr.Column():
            # Inputs
            input_text = gr.Textbox(
                label="Story Prompt",
                placeholder="Once upon a time, a small robot found a glowing seed...",
                lines=4,
            )
            # Sampling parameters passed straight through to generate_story.
            with gr.Accordion("Advanced Settings (Parameters)", open=False):
                slider_len = gr.Slider(10, 200, value=80, step=10, label="Max New Tokens")
                slider_temp = gr.Slider(0.1, 1.5, value=0.8, step=0.1, label="Temperature (Creativity)")
                slider_top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P (Nucleus Sampling)")
            btn = gr.Button("Generate Story", variant="primary")

        with gr.Column():
            # Output
            output_text = gr.Textbox(label="Generated Story", lines=12)

    # Wire the button: the input order must match generate_story's
    # (prompt, length, temp, top_p) parameters.
    btn.click(
        fn=generate_story,
        inputs=[input_text, slider_len, slider_temp, slider_top_p],
        outputs=output_text,
    )

# 3. Launch the Application
if __name__ == "__main__":
    demo.launch()