# Image-to-Audio-Story — Streamlit app (Hugging Face Space)
# Pipeline: uploaded image -> caption (BLIP) -> short story (TinyLlama) -> speech (gTTS / MMS-TTS)
| # Imports | |
| import streamlit as st | |
| from transformers import pipeline | |
| from PIL import Image | |
| import torch | |
| import os | |
| import tempfile | |
| import time | |
| import numpy as np | |
# Use Streamlit's caching mechanisms to optimize model loading
@st.cache_resource
def load_image_to_text_pipeline():
    """Load and cache the image-to-text (BLIP captioning) pipeline.

    The comment above promises Streamlit caching, but without the
    decorator the model was rebuilt on every script rerun;
    @st.cache_resource keeps a single instance for the server process.
    """
    return pipeline("image-to-text", model="sooh-j/blip-image-captioning-base")
@st.cache_resource
def load_text_generation_pipeline():
    """Load and cache the TinyLlama text-generation pipeline.

    @st.cache_resource was missing, so the 1.1B-parameter model was
    reloaded on every Streamlit rerun; now it is created once and reused.
    """
    return pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
@st.cache_resource
def load_tts_pipeline():
    """Load and cache the transformers text-to-speech fallback pipeline.

    Returns:
        The "text-to-speech" pipeline, or None when loading fails (the
        app then relies on gTTS or disables audio generation).
    """
    try:
        return pipeline("text-to-speech", model="facebook/mms-tts-eng")
    except Exception:
        # Broad on purpose: any load failure (missing weights, version
        # mismatch, OOM) just disables the fallback. The original bare
        # `except:` also swallowed KeyboardInterrupt/SystemExit.
        return None
# Initialize all models at app startup so reruns reuse cached instances.
with st.spinner("Loading models (this may take a moment the first time)..."):
    # Load all models at startup and cache them
    img2text_model = load_image_to_text_pipeline()
    story_generator_model = load_text_generation_pipeline()
    tts_fallback_model = load_tts_pipeline()

# For TTS, try multiple options in order of preference
try:
    # Try importing gTTS (preferred backend; produces MP3 bytes)
    from gtts import gTTS
    has_gtts = True
except ImportError:
    has_gtts = False

# Fix: the warning used to fire whenever the fallback model was None,
# even when gTTS had imported fine. Warn only when NEITHER backend works.
if not has_gtts and tts_fallback_model is None:
    st.warning("No text-to-speech capability available. Audio generation will be disabled.")
# Convert story text to audio using whichever TTS backend is available.
def text2audio(story_text):
    """Synthesize speech for *story_text*.

    Prefers gTTS; falls back to the cached transformers TTS pipeline.

    Returns:
        tuple: (mp3_bytes, 'audio/mp3') when gTTS is used, otherwise
        (waveform, sampling_rate) from the transformers pipeline.

    Raises:
        RuntimeError: when no TTS backend produced audio.
    """
    if has_gtts:
        # gTTS writes to a file, so round-trip through a named temp file.
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp3')
        temp_filename = temp_file.name
        temp_file.close()
        try:
            tts = gTTS(text=story_text, lang='en', slow=False)
            tts.save(temp_filename)
            with open(temp_filename, 'rb') as audio_file:
                audio_bytes = audio_file.read()
        finally:
            # Fix: delete the temp file even when synthesis fails
            # (previously a failed tts.save() leaked the file).
            if os.path.exists(temp_filename):
                os.unlink(temp_filename)
        return audio_bytes, 'audio/mp3'
    elif tts_fallback_model is not None:
        # Use the transformers TTS fallback; output key varies by version.
        speech = tts_fallback_model(story_text)
        if 'audio' in speech:
            return speech['audio'], speech.get('sampling_rate', 16000)
        elif 'audio_array' in speech:
            return speech['audio_array'], speech.get('sampling_rate', 16000)
    # If we got here, no TTS method worked
    raise RuntimeError("No text-to-speech capability available")
# Convert PIL Image to bytes for hashing in cache
def get_image_bytes(pil_img):
    """Serialize *pil_img* to bytes for use as a hashable cache key.

    PNG is used instead of JPEG because PNG supports every PIL mode:
    saving an RGBA or palette ('P') image — common for uploaded PNGs —
    as JPEG raises "cannot write mode RGBA as JPEG".
    """
    import io
    buf = io.BytesIO()
    pil_img.save(buf, format='PNG')
    return buf.getvalue()
# Caption an image (passed as raw bytes for cache-key compatibility).
def img2text(image_bytes):
    """Return the model-generated caption for the encoded image."""
    import io
    from PIL import Image

    # Rebuild a PIL image from the serialized bytes before inference.
    decoded_img = Image.open(io.BytesIO(image_bytes))
    captions = img2text_model(decoded_img)
    return captions[0]["generated_text"]
# Helper function to count words
def count_words(text):
    """Return the number of whitespace-separated tokens in *text*."""
    return sum(1 for _ in text.split())
# Improved text-to-story function without "Once upon a time" constraint
def text2story(text):
    """Generate a short children's story (~100 words) seeded by *text*.

    Args:
        text: seed text, typically the image caption.

    Returns:
        str: the generated story, stripped of the echoed prompt and
        trimmed to end on sentence punctuation near 100 words when a
        suitable boundary exists.
    """
    # Ask for a story without specifying how to start
    prompt = f"""Write a children's story based on this: {text}.
    The story should have a clear beginning, middle, and end.
    Make the story approximately 150-200 words long with descriptive language.
    """
    # Generate generously so a complete story is available for trimming.
    story_result = story_generator_model(
        prompt,
        max_length=500,
        num_return_sequences=1,
        temperature=0.7,
        do_sample=True
    )
    full_text = story_result[0]['generated_text']
    story_text = _strip_prompt(full_text, prompt)
    return _trim_to_sentence(story_text)


def _strip_prompt(full_text, prompt):
    """Remove the echoed prompt/instructions from the model output."""
    # First remove the exact prompt if it appears verbatim.
    if prompt in full_text:
        return full_text.replace(prompt, "").strip()
    # Otherwise look for a paragraph break or sentence ending that likely
    # separates the instructions from the story; only accept a split
    # whose leading part is reasonably short (< 50% of the text).
    for marker in ("\n\n", "\n", ". ", "! ", "? "):
        if marker in full_text:
            head, tail = full_text.split(marker, 1)
            if len(head) < len(full_text) * 0.5:
                return tail.strip()
    return full_text.strip()


def _trim_to_sentence(story_text, target_word_count=100, min_acceptable_words=80):
    """Cut *story_text* at the sentence ending closest to the word target.

    Returns the untrimmed text when no sentence boundary yields at least
    *min_acceptable_words* words.
    """
    # Indices of sentence-ending punctuation, naturally in sorted order.
    endings = [i for i, ch in enumerate(story_text) if ch in '.?!']
    closest_ending = None
    closest_word_diff = float('inf')
    for ending_idx in endings:
        candidate = story_text[:ending_idx + 1]
        words = count_words(candidate)
        # Only consider endings that reach the minimum length.
        if words >= min_acceptable_words:
            diff = abs(words - target_word_count)
            if diff < closest_word_diff:
                closest_ending = ending_idx
                closest_word_diff = diff
    if closest_ending is not None:
        return story_text[:closest_ending + 1]
    # NOTE: the original also scanned endings in reverse here, but that
    # branch was unreachable — it used the same >= min_acceptable_words
    # test the search above would already have satisfied — so it was removed.
    return story_text
# Clear all pipeline state whenever a new file is uploaded.
def reset_progress():
    """Reset session-state progress flags and cached outputs."""
    st.session_state.progress = dict(
        caption_generated=False,
        story_generated=False,
        audio_generated=False,
        caption='',
        story='',
        audio_data=None,
        audio_format=None,
    )
# Basic Streamlit interface
st.title("Image to Audio Story")

# Add processing status indicator (a placeholder updated as stages run)
status_container = st.empty()

# Initialize session state for tracking progress; flags let completed
# stages survive Streamlit's rerun-on-interaction model.
if 'progress' not in st.session_state:
    st.session_state.progress = {
        'caption_generated': False,
        'story_generated': False,
        'audio_generated': False,
        'caption': '',
        'story': '',
        'audio_data': None,
        'audio_format': None
    }

# File uploader; on_change resets all progress so a new image restarts
# the caption -> story -> audio pipeline from scratch.
uploaded_file = st.file_uploader("Upload an image", on_change=reset_progress)

# Process the image if uploaded
if uploaded_file is not None:
    # Display image
    st.image(uploaded_file, caption="Uploaded Image")
    # Convert to PIL Image
    image = Image.open(uploaded_file)
    # Convert image to bytes for caching compatibility
    image_bytes = get_image_bytes(image)

    # Image to Text (if not already done)
    if not st.session_state.progress['caption_generated']:
        status_container.info("Generating caption...")
        st.session_state.progress['caption'] = img2text(image_bytes)
        st.session_state.progress['caption_generated'] = True
    st.write(f"Caption: {st.session_state.progress['caption']}")

    # Text to Story (if not already done)
    if not st.session_state.progress['story_generated']:
        status_container.info("Creating story...")
        st.session_state.progress['story'] = text2story(st.session_state.progress['caption'])
        st.session_state.progress['story_generated'] = True
    # Display word count for transparency
    word_count = count_words(st.session_state.progress['story'])
    st.write(f"Story ({word_count} words):")
    st.write(st.session_state.progress['story'])

    # Pre-generate audio in background (if not already done), only when
    # at least one TTS backend is available.
    if not st.session_state.progress['audio_generated'] and (has_gtts or tts_fallback_model is not None):
        status_container.info("Pre-generating audio in background...")
        try:
            st.session_state.progress['audio_data'], st.session_state.progress['audio_format'] = text2audio(st.session_state.progress['story'])
            st.session_state.progress['audio_generated'] = True
            status_container.success("Ready to play audio!")
        except Exception as e:
            status_container.error(f"Error pre-generating audio: {e}")

    # Button to play audio
    if st.button("Play the audio"):
        if st.session_state.progress['audio_generated']:
            # Display the audio player. audio_format is either a MIME
            # string ('audio/mp3', from gTTS) or an int sampling rate
            # (from the transformers fallback) — branch accordingly.
            if isinstance(st.session_state.progress['audio_format'], str) and st.session_state.progress['audio_format'].startswith('audio/'):
                st.audio(st.session_state.progress['audio_data'], format=st.session_state.progress['audio_format'])
            else:
                st.audio(st.session_state.progress['audio_data'], sample_rate=st.session_state.progress['audio_format'])
        else:
            # Handle case where audio generation failed or is not available
            st.error("Unable to play audio. Audio generation was not successful.")
else:
    status_container.info("Upload an image to begin")