Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| from transformers import pipeline | |
| import os | |
| from huggingface_hub import login | |
| # --- App Configuration --- | |
| TITLE = "✍️ AI Story Outliner" | |
| DESCRIPTION = """ | |
| Enter a prompt and get 10 unique story outlines from a powerful AI model. | |
| The app uses **Mistral-7B-v0.1**, a popular and capable open-source model, to generate creative outlines. | |
| **How it works:** | |
| 1. Enter your story idea. | |
| 2. The AI will generate 10 different story outlines. | |
| 3. Each outline has a dramatic beginning and is concise, like a song. | |
| """ | |
| # --- Example Prompts for Storytelling --- | |
| examples = [ | |
| ["The old lighthouse keeper stared into the storm. He'd seen many tempests, but this one was different. This one had eyes..."], | |
| ["In a city powered by dreams, a young inventor creates a machine that can record them. His first recording reveals a nightmare that doesn't belong to him."], | |
| ["The knight adjusted his helmet, the dragon's roar echoing in the valley. He was ready for the fight, but for what the dragon said when it finally spoke."], | |
| ["She found the old leather-bound journal in her grandfather's attic. The first entry read: 'To relieve stress, I walk in the woods. But today, the woods walked with me.'"], | |
| ["The meditation app promised to help her 'delete unhelpful thoughts.' She tapped the button, and to her horror, the memory of her own name began to fade..."] | |
| ] | |
| # --- Model Initialization --- | |
| # This section loads the Mistral-7B model, which requires authentication. | |
| # It will automatically use the HF_TOKEN secret when deployed on Hugging Face Spaces. | |
| generator = None | |
| model_error = None | |
| try: | |
| print("Initializing model... This may take a moment.") | |
| # Explicitly load the token from environment variables (for HF Spaces secrets). | |
| hf_token = os.environ.get("HF_TOKEN") | |
| if hf_token: | |
| print("✅ HF_TOKEN secret found. Logging in...") | |
| # Programmatically log in to Hugging Face. This is a more robust method. | |
| login(token=hf_token) | |
| print("✅ Login successful.") | |
| else: | |
| # If no token is found, raise an error to prevent the app from crashing later. | |
| raise ValueError("Hugging Face token not found. Please set the HF_TOKEN secret in your Space settings.") | |
| # Using 'mistralai/Mistral-7B-v0.1'. | |
| # After login(), the token argument is no longer needed here as the session is authenticated. | |
| generator = pipeline( | |
| "text-generation", | |
| model="mistralai/Mistral-7B-v0.1", | |
| torch_dtype=torch.bfloat16, # More performant data type | |
| device_map="auto" # Will use GPU if available, otherwise CPU | |
| ) | |
| print("✅ mistralai/Mistral-7B-v0.1 model loaded successfully!") | |
| except Exception as e: | |
| model_error = e | |
| print(f"--- 🚨 Error loading model ---") | |
| print(f"Error: {model_error}") | |
| # --- App Logic --- | |
| def generate_stories(prompt: str) -> list[str]: | |
| """ | |
| Generates 10 story outlines from the loaded model based on the user's prompt. | |
| """ | |
| print("--- Button clicked. Attempting to generate stories... ---") | |
| # If the model failed to load during startup, display that error. | |
| if model_error: | |
| error_message = f"**Model failed to load during startup.**\n\nPlease check the console logs for details.\n\n**Error:**\n`{str(model_error)}`" | |
| print(f"Returning startup error: {error_message}") | |
| return [error_message] * 10 | |
| if not prompt: | |
| # Return a list of 10 empty strings to clear the outputs | |
| return [""] * 10 | |
| try: | |
| # This prompt format is optimized for Mistral instruct models. | |
| story_prompt = f"""[INST] Create a short story outline based on this idea: "{prompt}" | |
| The outline must have three parts: a dramatic hook, a concise ballad, and a satisfying finale. Use emojis for each section header. [/INST] | |
| ### 🎬 The Hook | |
| """ | |
| # Parameters for the pipeline to generate 10 diverse results. | |
| params = { | |
| "max_new_tokens": 250, | |
| "num_return_sequences": 10, | |
| "do_sample": True, | |
| "temperature": 0.8, | |
| "top_p": 0.95, | |
| "pad_token_id": generator.tokenizer.eos_token_id | |
| } | |
| print("Generating text with the model...") | |
| # Generate 10 different story variations | |
| outputs = generator(story_prompt, **params) | |
| print("✅ Text generation complete.") | |
| # Extract the generated text. | |
| stories = [] | |
| for out in outputs: | |
| # The model will generate the prompt plus the continuation. We extract just the new part. | |
| full_text = out['generated_text'] | |
| # Split by the instruction closing tag to get only the model's response | |
| generated_part = full_text.split("[/INST]")[-1].strip() | |
| stories.append(generated_part) | |
| # Ensure we return exactly 10 stories, padding if necessary. | |
| while len(stories) < 10: | |
| stories.append("Failed to generate a story for this slot.") | |
| return stories | |
| except Exception as e: | |
| # Catch any errors that happen DURING generation and display them in the UI. | |
| print(f"--- 🚨 Error during story generation ---") | |
| print(f"Error: {e}") | |
| runtime_error_message = f"**An error occurred during story generation.**\n\nPlease check the console logs for details.\n\n**Error:**\n`{str(e)}`" | |
| return [runtime_error_message] * 10 | |
| # --- Gradio Interface --- | |
| with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 95% !important;}") as demo: | |
| gr.Markdown(f"<h1 style='text-align: center;'>{TITLE}</h1>") | |
| gr.Markdown(DESCRIPTION) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| input_area = gr.TextArea( | |
| lines=5, | |
| label="Your Story Prompt 👇", | |
| placeholder="e.g., 'The last dragon on Earth lived not in a cave, but in a library...'" | |
| ) | |
| generate_button = gr.Button("Generate 10 Outlines ✨", variant="primary") | |
| gr.Markdown("---") | |
| gr.Markdown("## 📖 Your 10 Story Outlines") | |
| # Create 10 markdown components to display the stories in two columns | |
| story_outputs = [] | |
| with gr.Row(): | |
| with gr.Column(): | |
| for i in range(5): | |
| md = gr.Markdown(label=f"Story Outline {i + 1}") | |
| story_outputs.append(md) | |
| with gr.Column(): | |
| for i in range(5, 10): | |
| md = gr.Markdown(label=f"Story Outline {i + 1}") | |
| story_outputs.append(md) | |
| gr.Examples( | |
| examples=examples, | |
| inputs=input_area, | |
| label="Example Story Starters (Click to use)" | |
| ) | |
| generate_button.click( | |
| fn=generate_stories, | |
| inputs=input_area, | |
| outputs=story_outputs, | |
| api_name="generate" | |
| ) | |
| if __name__ == "__main__": | |
| demo.launch() | |