import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# --- Load tokenizer & model ---
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",   # use GPU if available
    torch_dtype="auto"   # pick the best dtype
)
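# Note: device_map="auto" requires the `accelerate` package to be installed;
# without it, drop the argument and place the model yourself (e.g. model.to("cuda")).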
# Ensure pad token is set for safe generation
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
if getattr(model.config, "pad_token_id", None) is None:
    model.config.pad_token_id = tokenizer.pad_token_id
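# (Llama-family tokenizers like TinyLlama's typically ship without a dedicated
# pad token, which is why EOS is reused above.)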
# --- Pipeline ---
generator = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    return_full_text=False   # return only the newly generated continuation
)
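# Optional smoke test (uncomment to verify the model loads and generates):
# print(generator("Once upon a time", max_new_tokens=5)[0]["generated_text"])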
# --- Prompt builder ---
def build_prompt(user_text: str) -> str:
    user_text = (user_text or "").strip()
    messages = [
        {"role": "system", "content": "You are a helpful storyteller that writes engaging prose."},
        {"role": "user", "content": user_text}
    ]
    # Use the model's chat template if available, else fall back to a plain format
    if hasattr(tokenizer, "apply_chat_template"):
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
    return (
        "System: You are a helpful storyteller that writes engaging prose.\n"
        f"User: {user_text}\n"
        "Assistant:"
    )
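# For this checkpoint, the chat template produces Zephyr-style tags, roughly
# (the exact text comes from the tokenizer's bundled template):
#   <|system|>
#   You are a helpful storyteller that writes engaging prose.</s>
#   <|user|>
#   {user_text}</s>
#   <|assistant|>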
# --- Generation function ---
def generate_story(prompt, max_tokens=300, temperature=0.8):
    try:
        prompt_str = build_prompt(prompt)  # distinct name; `prompt` is also the UI component below
        outputs = generator(
            prompt_str,
            max_new_tokens=int(max_tokens),
            temperature=float(temperature),
            do_sample=True,
            top_p=0.95,
            top_k=50,
            repetition_penalty=1.05,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
        # Text-generation pipelines return a list of dicts with "generated_text"
        return outputs[0].get("generated_text", "")
    except Exception as e:
        return f"Error during generation: {type(e).__name__}: {e}"
# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("# 📖 Interactive Story Generator (TinyLlama/TinyLlama-1.1B-Chat-v1.0)")
    gr.Markdown("Type a prompt and let the AI continue your story with a compact chat model.")
    prompt = gr.Textbox(
        label="My Story Prompt",
        placeholder="e.g., In the far future, humanity discovered a hidden planet...",
        lines=3
    )
    max_length = gr.Slider(50, 1000, value=300, step=50, label="Story length (new tokens)")
    temperature = gr.Slider(0.1, 1.5, value=0.8, step=0.1, label="Creativity (temperature)")
    generate_btn = gr.Button("✨ Generate Story")
    output = gr.Textbox(label="Generated Story", lines=20)
    generate_btn.click(
        fn=generate_story,
        inputs=[prompt, max_length, temperature],
        outputs=output
    )

demo.launch()
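# launch() needs no arguments on Hugging Face Spaces; running locally you could
# pass e.g. demo.launch(share=True) for a temporary public link.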