testing_testing / app.py
mssaidat's picture
Update app.py
47c889a verified
raw
history blame
1.65 kB
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Checkpoint used for story generation; TinyLlama-1.1B is small enough to
# run on CPU-only Spaces while still producing coherent continuations.
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Load tokenizer & model once at import time so every request reuses them.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",   # Automatically place weights on GPU if available
    torch_dtype="auto",  # Use the checkpoint's native dtype (fp16/bf16/fp32)
)

# Shared text-generation pipeline wrapping the model + tokenizer pair.
story_generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)
# Function to generate stories
def generate_story(prompt, max_length=300, temperature=0.8):
    """Continue *prompt* into a story using sampled decoding.

    Args:
        prompt: Seed text the model should continue.
        max_length: Total token budget (prompt tokens + continuation).
        temperature: Sampling temperature; higher values are more creative.

    Returns:
        The generated text, which includes the original prompt.
    """
    outputs = story_generator(
        prompt,
        # Gradio sliders may deliver floats; transformers requires an int.
        max_length=int(max_length),
        temperature=float(temperature),
        do_sample=True,  # sample instead of greedy decoding
        top_p=0.95,      # nucleus sampling: keep the 95% probability mass
        top_k=50,        # restrict sampling to the 50 most likely tokens
    )
    return outputs[0]["generated_text"]
# Gradio UI: prompt box + generation controls wired to generate_story.
with gr.Blocks() as demo:
    # NOTE(fix): the headers previously advertised a 20B model
    # ("open-gpt-oss-20b"), but MODEL_NAME loads TinyLlama-1.1B-Chat.
    gr.Markdown("# 📖 Interactive Story Generator (TinyLlama-1.1B-Chat)")
    gr.Markdown("Type a prompt and let the AI continue your story.")
    prompt = gr.Textbox(
        label="Your Story Prompt",
        placeholder="e.g., In the far future, humanity discovered a hidden planet...",
        lines=3,
    )
    # Generation controls; defaults mirror generate_story's defaults.
    max_length = gr.Slider(50, 1000, value=300, step=50, label="Story Length")
    temperature = gr.Slider(0.1, 1.5, value=0.8, step=0.1, label="Creativity")
    generate_btn = gr.Button("✨ Generate Story")
    output = gr.Textbox(label="Generated Story", lines=20)
    generate_btn.click(
        fn=generate_story,
        inputs=[prompt, max_length, temperature],
        outputs=output,
    )
demo.launch()