"""Interactive story generator: a small Gradio UI around TinyLlama-1.1B-Chat.

Loads the tokenizer/model once at import time, wraps them in a
text-generation pipeline, and exposes a prompt box plus length/creativity
sliders through a Gradio Blocks interface.
"""

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Single source of truth for the system instruction (used by both the
# chat-template path and the plain-text fallback in build_prompt).
SYSTEM_PROMPT = "You are a helpful storyteller that writes engaging prose."

# --- Load tokenizer & model ---
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",   # use GPU if available
    torch_dtype="auto",  # pick the best dtype for the device
)

# Ensure pad token is set for safe generation — many causal LMs (including
# TinyLlama) ship without a dedicated pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
if getattr(model.config, "pad_token_id", None) is None:
    model.config.pad_token_id = tokenizer.pad_token_id

# --- Pipeline ---
generator = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    return_full_text=False,  # return only the newly generated continuation
)


# --- Prompt builder (named to avoid shadowing the UI's `prompt` widget) ---
def build_prompt(user_text: str) -> str:
    """Format *user_text* as a chat-style prompt for the model.

    Prefers the tokenizer's chat template so the model sees the exact role
    markers it was trained with; falls back to a plain
    System/User/Assistant transcript otherwise.

    Args:
        user_text: Raw text from the UI; ``None`` is treated as empty.

    Returns:
        The fully formatted prompt string, ending at the assistant turn.
    """
    user_text = (user_text or "").strip()
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_text},
    ]
    # Use chat template if available, else fallback to a manual transcript.
    if hasattr(tokenizer, "apply_chat_template"):
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
    return (
        f"System: {SYSTEM_PROMPT}\n"
        f"User: {user_text}\n"
        "Assistant:"
    )


# --- Generation function ---
def generate_story(prompt, max_tokens=300, temperature=0.8):
    """Generate a story continuation for *prompt*.

    Args:
        prompt: The user's story seed text.
        max_tokens: Maximum number of NEW tokens to sample.
        temperature: Sampling temperature (higher = more creative).

    Returns:
        The generated text, or a human-readable error message — this is a
        UI boundary, so exceptions are reported rather than propagated.
    """
    try:
        prompt_str = build_prompt(prompt)  # no name shadowing with the widget
        outputs = generator(
            prompt_str,
            max_new_tokens=int(max_tokens),
            temperature=float(temperature),
            do_sample=True,
            top_p=0.95,
            top_k=50,
            repetition_penalty=1.05,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
        # Pipelines return a list of dicts with "generated_text".
        return outputs[0].get("generated_text", "")
    except Exception as e:  # boundary: surface errors in the UI, never crash it
        return f"Error during generation: {type(e).__name__}: {e}"


# --- Gradio UI ---
with gr.Blocks() as demo:
    # NOTE: the original source had a raw line break inside this string
    # literal (a SyntaxError); it is now an explicit \n escape.
    gr.Markdown(
        "# 📖 Interactive Story Generator \n(TinyLlama/TinyLlama-1.1B-Chat-v1.0)"
    )
    gr.Markdown(
        "Type a prompt and let the AI continue your story with a compact chat model."
    )
    prompt = gr.Textbox(
        label="My Story Prompt",
        placeholder="e.g., In the far future, humanity discovered a hidden planet...",
        lines=3,
    )
    max_length = gr.Slider(
        50, 1000, value=300, step=50, label="Story Length in new tokens"
    )
    temperature = gr.Slider(0.1, 1.5, value=0.8, step=0.1, label="Creativity")
    generate_btn = gr.Button("✨ Generate Story")
    output = gr.Textbox(label="Generated Story", lines=20)

    generate_btn.click(
        fn=generate_story,
        inputs=[prompt, max_length, temperature],
        outputs=output,
    )

# Guard the launch so importing this module (e.g. for testing) does not
# start a web server.
if __name__ == "__main__":
    demo.launch()