Spaces:
Sleeping
Sleeping
| import logging | |
| import json | |
| from typing import Optional, List, Any, Union, Tuple, Dict | |
| from services.story_generator import generate_story | |
| from services.pdf_text_extractor import extract_text_from_pdf | |
| from services.streaming_chapter_processor import process_story_into_chapters_streaming | |
| from services.audio_generator import generate_audio, generate_melody_from_story | |
| from services.mesh_service import get_mesh_base64, transform_base64_to_glb_file | |
| import gradio as gr | |
| from config import constants | |
| from util.mistral_api_client import MistralAPI | |
| logger = logging.getLogger(__name__) | |
def process_story_generation(
    story_type: str,
    tone: str,
    kid_interests: str,
    subject: str,
    kid_age: Union[int, float] = constants.DEFAULT_KID_AGE,
    kid_language: str = constants.DEFAULT_LANGUAGE,
    reading_time: int = constants.DEFAULT_READING_TIME,
    pdf_file: Optional[Any] = None,
    model_selector: str = constants.DEFAULT_MODEL,
) -> Tuple[str, str, Any]:
    """Process the story generation request from the UI.

    Args:
        story_type: Type of story to generate
        tone: Tone of the story
        kid_interests: Child's interests
        subject: Subject of the story
        kid_age: Age of the target child
        kid_language: Language the child speaks
        reading_time: Approximate reading time in minutes
        pdf_file: Optional PDF file upload
        model_selector: Selected AI model

    Returns:
        Tuple[str, str, Any]: (title, story_or_error, button_update).
        On success the third element enables and shows the downstream
        button; on any failure the title is empty, the second element is
        an "Error: ..." message, and the update disables the button.
    """
    try:
        logger.info(
            f"Generating story with type: {story_type}, tone: {tone}, subject: {subject}"
        )

        # Process PDF if provided: extract raw text, then condense it to a
        # single sentence so the story prompt stays small.
        pdf_content = ""
        summarized_pdf = ""  # Empty string means "no PDF context" downstream
        if pdf_file:
            logger.info("Extracting text from PDF")
            pdf_content = extract_text_from_pdf(pdf_file)
            # Summarize the PDF content for better prompting using Mistral
            if pdf_content and not pdf_content.startswith("Error:"):
                mistral_api = MistralAPI()
                summarized_pdf = mistral_api.send_request(
                    f"Summarize the following Text content into a single-sentence children's story without any explanations, tags, or formatting—just plain text in one line.: {pdf_content}"
                )["choices"][0]["message"]["content"]
                logger.info(f"summarized_pdf: {summarized_pdf}")
            else:
                # Extraction failed: log it and continue without PDF context
                logger.error(f"PDF extraction error: {pdf_content}")

        # Generate story
        story_response = generate_story(
            story_type=story_type,
            tone=tone,
            kid_age=kid_age,
            kid_language=kid_language,
            kid_interests=kid_interests,
            subject=subject,
            reading_time=reading_time,
            pdf_content=summarized_pdf,
            model_name=model_selector,
        )

        if story_response.startswith("Error:"):
            logger.error(f"Story generation error: {story_response}")
            return "", story_response, gr.update(interactive=False)

        try:
            # Parse JSON response; generate_story is expected to return a JSON
            # object with "title" and "story" keys — TODO confirm contract.
            story_data = json.loads(story_response)
            title = story_data.get("title", "Untitled Story")
            story = story_data.get("story", "")
            logger.info("Story generated successfully")
            return (title, story, gr.update(interactive=True, visible=True))
        except json.JSONDecodeError:
            logger.error("Failed to parse story JSON response")
            return (
                "",
                f"Error: Failed to parse story response: {story_response}",
                gr.update(interactive=False),
            )
    except Exception as e:
        error_msg = f"Unexpected error during story generation: {str(e)}"
        logger.error(error_msg, exc_info=True)
        return "", f"Error: {error_msg}", gr.update(interactive=False)
def process_chapters(
    story_content: str, story_title: str, progress=gr.Progress()
) -> Union[dict, str]:
    """
    Process the generated story into chapters with image prompts.

    Args:
        story_content: The full story text to process
        story_title: The title of the story
        progress: Optional Gradio progress indicator

    Returns:
        dict: {"title": ..., "chapters": [...]} on success.
        str: An "Error: ..." message if validation or processing fails.
    """
    if not story_content or story_content.startswith("Error:"):
        return "Error: Please generate a valid story first."

    logger.info("Processing story into chapters with streaming image generation")
    try:
        # Store for the current chapters data; replaced by the callback each
        # time the streaming processor emits an update.
        current_data = {"title": story_title, "chapters": []}

        # Callback function to update the UI with each new chapter image
        def update_callback(chapters_json):
            nonlocal current_data
            try:
                chapters_data = json.loads(chapters_json)
                if "error" in chapters_data:
                    return f"Error: {chapters_data['error']}"
                chapters = chapters_data.get("chapters", [])
                current_data = {"title": story_title, "chapters": chapters}

                # Count completed images (chapters with a non-empty image_b64)
                total_chapters = len(chapters)
                completed_images = sum(
                    1 for chapter in chapters if chapter.get("image_b64", "")
                )

                # Update progress
                if total_chapters > 0:
                    # First 50% is chapter creation, second 50% is image generation
                    chapter_progress = 0.5  # Chapters are already created at this point
                    image_progress = 0.5 * (completed_images / total_chapters)
                    progress(
                        (chapter_progress + image_progress),
                        f"Generated {completed_images}/{total_chapters} chapter images",
                    )
            except Exception as e:
                logger.error(f"Error in update callback: {e}")
                return f"Error in update: {str(e)}"
            return current_data

        # Start the streaming process
        progress(0.05, "Splitting story into chapters...")
        process_story_into_chapters_streaming(
            story_content, story_title, update_callback=update_callback
        )

        # Return the final data structure
        return current_data
    except Exception as e:
        logger.error(f"Failed to process chapters: {e}", exc_info=True)
        return f"Error processing chapters: {str(e)}"
# Add chapter processing functionality
def handle_chapter_processing(story_content, story_title, progress=gr.Progress()):
    """Handle chapter processing and update state"""
    # Reject missing or previously-failed stories up front.
    if not story_content or story_content.startswith("Error:"):
        return {"error": "Please generate a valid story first."}

    gr.Info(
        message="Processing story into chapters... <br> Go to the Chapters tab to see updates.",
        title="Processing",
    )

    logger.info("Starting chapter processing...")
    progress(0.01, "Starting chapter processing...")
    try:
        # Latest known state; the streaming callback below rebinds this as
        # chapter/image updates arrive.
        state = {"title": story_title, "chapters": [], "processing": True}

        def on_stream_update(payload_json):
            nonlocal state
            try:
                payload = json.loads(payload_json)

                # Forward progress information to the Gradio progress bar.
                if "progress" in payload:
                    info = payload["progress"]
                    fraction = info.get("completed", 0) / info.get("total", 1)
                    progress(fraction, info.get("message", "Processing chapters..."))

                # An error payload replaces the state entirely.
                if "error" in payload:
                    state = {"title": story_title, "error": payload["error"]}
                    return state

                # Fresh chapter data: rebuild the state around it.
                if "chapters" in payload:
                    state = {
                        "title": story_title,
                        "chapters": payload.get("chapters", []),
                        "processing": True,
                    }

                # The "complete" stage marker flips off the processing flag.
                if (
                    "progress" in payload
                    and payload["progress"].get("stage") == "complete"
                ):
                    state["processing"] = False
            except Exception as e:
                logger.error(f"Error in update callback: {e}")
                state = {
                    "title": story_title,
                    "error": f"Error in update: {str(e)}",
                }
            return state

        # Kick off the streaming chapter/image pipeline.
        process_story_into_chapters_streaming(
            story_content, story_title, update_callback=on_stream_update
        )

        # Whatever the callback last recorded is the final result.
        return state
    except Exception as e:
        logger.error(f"Failed to process chapters: {e}", exc_info=True)
        return {"title": story_title, "error": f"Error processing chapters: {str(e)}"}
def generate_audio_with_status(text):
    """
    Generate audio from text with status updates for better user experience.

    Args:
        text (str): Text to convert to audio

    Returns:
        tuple: (audio_file_path, status_message)
    """
    try:
        if not text or not text.strip():
            return None, gr.HTML(
                "<p class='text'>β οΈ Please provide text to generate audio</p>",
                visible=True,
            )

        # Flatten newlines and strip double quotes that can trip up the TTS input.
        text = text.replace("\n", " ").replace("\r", " ").strip().replace('"', "")
        logger.info(f"Generating audio for text: {text[:50]}...")

        # "[S1]" prefixes a speaker tag — presumably required by the TTS
        # backend; confirm against generate_audio's contract.
        generated_path = generate_audio(
            f"[S1] {text}",
        )
        logger.info("Audio generation completed successfully")
        return generated_path, gr.HTML(
            "<p class='text'>β Audio generated successfully!</p>", visible=True
        )
    except Exception as e:
        logger.error(f"Error in audio generation controller: {e}")
        return None, gr.HTML(
            "<p class='text'>β Audio generation failed</p>", visible=True
        )
def generate_melody_from_story_with_status(story_text):
    """
    Generate a melody based on story text with status updates for better user experience.

    Args:
        story_text (str): The story text to generate a melody for.

    Returns:
        tuple: (audio_file_path, status_message) where status_message is a
        gr.HTML component; audio_file_path is None on failure.
    """
    try:
        if not story_text or not story_text.strip():
            return None, gr.HTML(
                "<p class='text'>β οΈ Please provide a story to generate melody</p>",
                visible=True,
            )

        # Clean text to avoid issues with special characters
        story_text = (
            story_text.replace("\n", " ").replace("\r", " ").strip().replace('"', "")
        )
        logger.info(f"Generating melody for story: {story_text[:50]}...")

        # Generate melody from story text
        audio_file_path = generate_melody_from_story(story_text)
        logger.info("Melody generation completed successfully")
        return audio_file_path, gr.HTML(
            "<p class='text'>β Story melody generated successfully!</p>", visible=True
        )
    except Exception as e:
        logger.error(f"Error in melody generation controller: {e}")
        error_msg = "<p class='text'>β Melody generation failed</p>"
        return None, gr.HTML(error_msg, visible=True)
def generate_3d_model(story_text):
    """
    Generate a 3D model (GLB file) from the story text via the mesh service.

    Args:
        story_text: Text prompt sent to the mesh-generation service.

    Returns:
        tuple: (glb_file_path, "generate") on success, or
        (None, "Error: ...") on any failure.
    """
    if not story_text or not story_text.strip():
        return None, "Error: Please provide a story to generate a 3D model"

    # The service call itself can raise (network/API failure); keep the
    # error-tuple contract instead of letting the exception escape.
    try:
        model_response = get_mesh_base64(
            text=story_text, apply_texture=False, output_format="glb"
        )
    except Exception as e:
        return None, f"Error: 3D model request failed: {str(e)}"

    # Check if response contains an error
    if "error" in model_response:
        return None, f"Error: {model_response['error']}"

    # Check if the expected data structure exists
    if (
        "model_data" not in model_response
        or "mesh_base64" not in model_response["model_data"]
    ):
        return (
            None,
            "Error: Received unexpected response format from 3D model API",
        )

    try:
        # Decode the base64 mesh payload into a .glb file on disk
        glb_file_path = transform_base64_to_glb_file(
            model_response["model_data"]["mesh_base64"]
        )
        # "generate" is the success status token — presumably consumed by the
        # UI layer; confirm against the caller.
        return (glb_file_path, "generate")
    except Exception as e:
        return None, f"Error processing model data: {str(e)}"
def clear_fields() -> List[Any]:
    """
    Clear the subject and story text fields.

    Returns:
        List[Any]: Three empty strings for the text fields plus a
        gr.update(...) object that disables the downstream button
        (hence List[Any], not List[str]).
    """
    return ["", "", "", gr.update(interactive=False)]