Spaces:
Sleeping
Sleeping
| import os | |
| import uuid | |
| import tempfile | |
| import time | |
| import re | |
| import asyncio | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.responses import FileResponse | |
| from pydantic import BaseModel | |
| # Import the custom modules | |
| from llm import get_llm | |
| from prompt import story_request, generate_story, image_request, generate_image_prompt | |
| from flux import generate_image | |
| from docx import Document | |
| from docx.shared import Inches | |
| from dotenv import load_dotenv | |
# Populate os.environ from a local .env file (e.g. OPENAI_API_KEY, which
# generate_story_docx reads at request time).
load_dotenv()

# The FastAPI application; title/description/version are the API metadata.
app = FastAPI(
    title="Bedtime Story Generator API",
    description="API to generate a bedtime story with images and save as a docx document.",
    version="1.0.0",
)
| # --------------------------------------------------------------------------- | |
| # Pydantic model for validating the incoming story parameters | |
| # --------------------------------------------------------------------------- | |
class StoryParams(BaseModel):
    """Request body accepted by the story-generation endpoint.

    The endpoint converts this model to a dict and the pipeline forwards
    every field, unchanged, as a keyword argument of the same name to
    ``story_request`` (see ``inference``). That is why the field names are
    capitalised. No range or format validation is applied beyond the
    declared types.
    """

    Age: str  # NOTE(review): presumably the listener's age; kept as free text
    Theme: str
    Pages: int  # requested page count; not range-checked here
    Time: int  # NOTE(review): presumably reading time (minutes?) — confirm with prompt.story_request
    Tone: str
    Setting: str
    Moral: str
| # --------------------------------------------------------------------------- | |
| # Helper functions for the story-generation pipeline | |
| # --------------------------------------------------------------------------- | |
def inference(llm_instance, story_params: dict) -> str:
    """Ask the LLM to write a bedtime story for the given parameters.

    Args:
        llm_instance: Model object exposing ``invoke(prompt)``. The
            ``content`` attribute of the value it returns is the result.
        story_params: Mapping with the keys Age, Theme, Pages, Time, Tone,
            Setting and Moral (e.g. ``StoryParams.dict()``).

    Returns:
        The ``content`` of the model's reply, i.e. the story text.

    Raises:
        KeyError: If ``story_params`` is missing one of the required keys.
    """
    # Same keys, same lookup order as the keyword arguments of story_request.
    field_names = ("Age", "Theme", "Pages", "Time", "Tone", "Setting", "Moral")
    request = story_request(**{name: story_params[name] for name in field_names})
    prompt_str = generate_story(request)

    print("\nGenerating story. Please wait...\n")
    reply = llm_instance.invoke(prompt_str)
    return reply.content
def parse_story_sections(story_text: str) -> list:
    """Split LLM story text into sections delimited by ``**marker**`` spans.

    Every ``**...**`` span starts a new section. A section is the stripped
    marker, followed by a blank line and the stripped text up to the next
    marker. When that text is empty, the section is just the marker. Text
    before the first marker is discarded.

    If the text contains no markers at all, the whole stripped story is
    returned as a single section instead of being silently dropped. This
    means a reply that ignores the requested format still yields a document.
    Empty or whitespace-only input yields ``[]``.

    Args:
        story_text: Raw text returned by the LLM.

    Returns:
        List of section strings in document order.
    """
    matches = list(re.finditer(r"\*\*(.*?)\*\*\s*", story_text, flags=re.DOTALL))
    if not matches:
        # No markers: keep the story rather than returning nothing.
        stripped = story_text.strip()
        return [stripped] if stripped else []

    sections = []
    # Pair each marker with the one after it (None for the last) so the
    # section body is the text between consecutive markers.
    for match, next_match in zip(matches, matches[1:] + [None]):
        marker = match.group(1).strip()
        end = next_match.start() if next_match is not None else len(story_text)
        content = story_text[match.end():end].strip()
        sections.append(f"{marker}\n\n{content}" if content else marker)
    return sections
def generate_images_for_sections(sections: list, style: str = "sketch", *, delay: float = 1.0) -> list:
    """Generate one illustration per story section.

    For each section, an image prompt is built with ``image_request`` and
    ``generate_image_prompt``, then rendered with ``generate_image``. Each
    successful image is saved as a uniquely named PNG in the system temp
    directory.

    Args:
        sections: Section texts, e.g. as returned by ``parse_story_sections``.
        style: Art style forwarded to ``image_request``.
        delay: Seconds to pause *between* consecutive generations. There is
            no pause before the first or after the last one. Defaults to the
            original 1 second; ``0`` disables the pause.

    Returns:
        A list aligned index-for-index with ``sections``. Each entry is the
        saved file path, or ``None`` where ``generate_image`` returned a
        falsy value.
    """
    temp_dir = tempfile.gettempdir()  # loop-invariant, so look it up once
    image_paths = []
    for idx, section in enumerate(sections):
        if idx > 0 and delay > 0:
            # Optional pause between image generations. Sleeping only between
            # images avoids a wasted pause after the final one.
            time.sleep(delay)
        print(f"Generating image for section {idx+1}...")
        img_req = image_request(style=style, bedtime_story_content=section)
        img_prompt = generate_image_prompt(img_req)
        image = generate_image(img_prompt)
        if image:
            image_filename = os.path.join(temp_dir, f"section_{idx+1}_{uuid.uuid4().hex}.png")
            image.save(image_filename)
            image_paths.append(image_filename)
            print(f"Image for section {idx+1} saved as {image_filename}\n")
        else:
            print(f"Failed to generate image for section {idx+1}.\n")
            image_paths.append(None)
    return image_paths
def _extract_title(title_section: str) -> str:
    """Return the story title from a section that starts with ``"Title:"``."""
    lines = [line.strip() for line in title_section.splitlines()]
    title = lines[0].removeprefix("Title:").strip()
    if not title:
        # A "**Title:** Name" marker is parsed into "Title:\n\nName", so the
        # title is on the first non-empty line after the marker.
        title = next((line for line in lines[1:] if line), "")
    return title


def save_story_to_docx(sections: list, image_paths: list, output_filename: str) -> None:
    """Write the story sections and their images to a Word document.

    Title handling: if the first section starts with ``"Title:"``, it becomes
    the document title (core property and level-1 heading). That section and
    its image are then removed from the body.

    Body sections are written in order:
      - If the first line starts with "Opening Hook:", "Page", "Ending" or
        "The End", that line becomes a level-2 heading and the rest of the
        section follows as one paragraph.
      - Any other section is written as a single paragraph.

    After each non-empty section, the image at the same index in
    ``image_paths`` is inserted, 4 inches wide, when it is present.

    Args:
        sections: Section texts, e.g. from ``parse_story_sections``.
        image_paths: Image file paths aligned with ``sections``. Entries may
            be ``None``; a shorter list, or ``None``, just means fewer images.
        output_filename: Path the ``.docx`` file is saved to.
    """
    heading_markers = ("Opening Hook:", "Page", "Ending", "The End")
    image_paths = list(image_paths) if image_paths else []
    document = Document()

    if sections and sections[0].startswith("Title:"):
        title_text = _extract_title(sections[0])
        if title_text:  # avoid an empty title property / blank heading
            document.core_properties.title = title_text
            document.add_heading(title_text, level=1)
        # Drop the title section and its image so the two lists stay aligned.
        sections = sections[1:]
        image_paths = image_paths[1:]

    for idx, section in enumerate(sections):
        lines = section.splitlines()
        if not lines:
            continue
        first_line = lines[0].strip()
        if first_line.startswith(heading_markers):
            document.add_heading(first_line, level=2)
            remaining_text = "\n".join(lines[1:]).strip()
            if remaining_text:
                document.add_paragraph(remaining_text)
        else:
            document.add_paragraph(section)

        image_path = image_paths[idx] if idx < len(image_paths) else None
        if image_path:
            try:
                document.add_picture(image_path, width=Inches(4))
            except Exception as e:
                # Best effort: one unreadable image must not abort the whole document.
                print(f"Error inserting image for section {idx+1}: {e}")

    document.save(output_filename)
    print(f"\n📖 Story saved to: {output_filename}")
def _remove_temp_files(paths: list) -> None:
    """Best-effort deletion of temporary files; falsy entries are skipped."""
    for path in paths:
        if not path:
            continue
        try:
            os.remove(path)
        except OSError as e:
            # Cleanup failure should not fail the request; just report it.
            print(f"Could not remove temporary file {path}: {e}")


def generate_story_docx(story_params: dict) -> str:
    """Run the full story pipeline and return the path of the ``.docx`` file.

    Steps:
      - Read ``OPENAI_API_KEY`` from the environment (presence check only).
      - Generate the story text via the LLM (``inference``).
      - Split it into sections (``parse_story_sections``).
      - Generate one image per section (``generate_images_for_sections``).
      - Write sections and images to a Word document in the temp directory.

    The per-section image files are deleted once the document has been
    written, or if writing it fails. The returned ``.docx`` file itself is
    left for the caller to serve and clean up.

    Args:
        story_params: Mapping with the keys expected by ``inference``.

    Returns:
        Path of the saved document inside ``tempfile.gettempdir()``.

    Raises:
        RuntimeError: If ``OPENAI_API_KEY`` is not set. This is a subclass of
            ``Exception``, so existing ``except Exception`` handlers still
            catch it.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        raise RuntimeError("Error: OPENAI_API_KEY not found in environment variables.")
    llm_instance = get_llm(api_key)

    story_text = inference(llm_instance, story_params)
    print("\nStory generated successfully!\n")

    sections = parse_story_sections(story_text)
    image_paths = generate_images_for_sections(sections, style="sketch")

    # Unique name so concurrent requests never overwrite each other's output.
    output_filename = os.path.join(tempfile.gettempdir(), f"bedtime_story_{uuid.uuid4().hex}.docx")
    try:
        save_story_to_docx(sections, image_paths, output_filename=output_filename)
    finally:
        # The images are embedded in the document at this point (or the save
        # failed); either way the PNGs would otherwise pile up in the temp directory.
        _remove_temp_files(image_paths)
    return output_filename
| # --------------------------------------------------------------------------- | |
| # API Endpoints | |
| # --------------------------------------------------------------------------- | |
async def root():
    """Greet the caller and point them to the interactive API documentation.

    NOTE(review): no route decorator (e.g. ``@app.get("/")``) is visible on
    this function here, so it may not be registered with ``app``. Confirm
    against the deployed source.
    """
    payload = dict(
        message="Welcome to the Bedtime Story Generator API!",
        documentation="/docs",
    )
    return payload
async def generate_story_endpoint(story_params: StoryParams):
    """Run the story pipeline and return the result as a Word document download.

    The request body is validated against ``StoryParams``. The blocking
    pipeline runs in a worker thread via ``asyncio.to_thread``, so the event
    loop is not blocked while it runs. Any failure is reported as HTTP 500,
    with the error message as the detail. The temporary ``.docx`` file is
    deleted by a background task after the response has been sent, so files
    do not accumulate in the temp directory.

    NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible on
    this function here, so it may not be registered with ``app``. Confirm
    against the deployed source.
    """
    # Local import keeps this change self-contained; fastapi is already a dependency.
    from fastapi import BackgroundTasks

    try:
        # Run the blocking story generation in a separate thread.
        docx_file = await asyncio.to_thread(generate_story_docx, story_params.dict())
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    # A response's background task runs only after the body has been sent,
    # so the file can be removed safely at that point.
    cleanup = BackgroundTasks()
    cleanup.add_task(os.remove, docx_file)

    return FileResponse(
        path=docx_file,
        media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        filename=os.path.basename(docx_file),
        background=cleanup,
    )
async def health():
    """Report a static OK status for health checks.

    NOTE(review): no route decorator is visible on this function here;
    confirm how it is registered with ``app``.
    """
    status_body = dict(status="ok")
    return status_body
| # --------------------------------------------------------------------------- | |
| # Run the server with: uvicorn main:app --reload | |
| # --------------------------------------------------------------------------- | |
if __name__ == "__main__":
    # Imported here because uvicorn is only needed when the file is run directly.
    import uvicorn

    # Equivalent to `uvicorn main:app --reload`, listening on 0.0.0.0:8000.
    # NOTE(review): "main:app" assumes this module is saved as main.py.
    uvicorn.run(app="main:app", host="0.0.0.0", port=8000, reload=True)