# Imports import streamlit as st from transformers import pipeline from PIL import Image import torch import os import tempfile import time import numpy as np # Use Streamlit's caching mechanisms to optimize model loading @st.cache_resource def load_image_to_text_pipeline(): """Load and cache the image-to-text model""" return pipeline("image-to-text", model="sooh-j/blip-image-captioning-base") @st.cache_resource def load_text_generation_pipeline(): """Load and cache the text generation model""" return pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0") @st.cache_resource def load_tts_pipeline(): """Load and cache the text-to-speech pipeline as fallback""" try: return pipeline("text-to-speech", model="facebook/mms-tts-eng") except: # Return None if loading fails return None # Initialize all models at app startup with st.spinner("Loading models (this may take a moment the first time)..."): # Load all models at startup and cache them img2text_model = load_image_to_text_pipeline() story_generator_model = load_text_generation_pipeline() tts_fallback_model = load_tts_pipeline() # For TTS, try multiple options in order of preference try: # Try importing gTTS from gtts import gTTS has_gtts = True except ImportError: has_gtts = False if tts_fallback_model is None: st.warning("No text-to-speech capability available. Audio generation will be disabled.") # Cache the text-to-audio conversion @st.cache_data def text2audio(story_text): """Convert text to audio with caching to avoid regenerating the same audio""" if has_gtts: # Use gTTS temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp3') temp_filename = temp_file.name temp_file.close() # Use gTTS to convert text to speech tts = gTTS(text=story_text, lang='en', slow=False) tts.save(temp_filename) # Read the audio file with open(temp_filename, 'rb') as audio_file: audio_bytes = audio_file.read() # Clean up the temporary file os.unlink(temp_filename) return audio_bytes, 'audio/mp3' elif tts_fallback_model is not None: # Use transformers TTS speech = tts_fallback_model(story_text) # Return the audio data if 'audio' in speech: return speech['audio'], speech.get('sampling_rate', 16000) elif 'audio_array' in speech: return speech['audio_array'], speech.get('sampling_rate', 16000) # If we got here, no TTS method worked raise Exception("No text-to-speech capability available") # Convert PIL Image to bytes for hashing in cache def get_image_bytes(pil_img): """Convert PIL image to bytes for hashing""" import io buf = io.BytesIO() pil_img.save(buf, format='JPEG') return buf.getvalue() # Simple image-to-text function using cached model @st.cache_data def img2text(image_bytes): """Convert image to text with caching - using bytes for caching compatibility""" # Convert bytes back to PIL image for processing import io from PIL import Image pil_img = Image.open(io.BytesIO(image_bytes)) # Process with the model result = img2text_model(pil_img) return result[0]["generated_text"] # Helper function to count words def count_words(text): return len(text.split()) # Improved text-to-story function without "Once upon a time" constraint @st.cache_data def text2story(text): """Generate a story from text with caching""" # Ask for a story without specifying how to start prompt = f"""Write a children's story based on this: {text}. The story should have a clear beginning, middle, and end. Make the story approximately 150-200 words long with descriptive language. """ # Generate a longer text to ensure we get a complete story story_result = story_generator_model( prompt, max_length=500, num_return_sequences=1, temperature=0.7, do_sample=True ) full_text = story_result[0]['generated_text'] # Try to extract just the story part (after the prompt) # Look for paragraph breaks or clear story beginnings potential_starts = [ "\n\n", "\n", ". ", "! ", "? " ] # Find where the prompt ends and the actual story begins story_text = full_text # First remove the exact prompt if it appears verbatim if prompt in story_text: story_text = story_text.replace(prompt, "") else: # Look for paragraph breaks or sentence endings that might indicate # where the prompt instructions end and the story begins for start_marker in potential_starts: if start_marker in story_text: parts = story_text.split(start_marker, 1) if len(parts[0]) < len(story_text) * 0.5: # If the first part is reasonably short story_text = parts[1] break # Clean up any leading/trailing whitespace story_text = story_text.strip() # Find natural ending points (end of sentences) periods = [i for i, char in enumerate(story_text) if char == '.'] question_marks = [i for i, char in enumerate(story_text) if char == '?'] exclamation_marks = [i for i, char in enumerate(story_text) if char == '!'] # Combine all ending punctuation and sort all_endings = sorted(periods + question_marks + exclamation_marks) # Target approximately 100 words target_word_count = 100 min_acceptable_words = 80 # If we have any sentence endings if all_endings: # Find the sentence ending that gets us closest to 100 words closest_ending = None closest_word_diff = float('inf') for ending_idx in all_endings: candidate_text = story_text[:ending_idx+1] candidate_word_count = count_words(candidate_text) # Only consider endings that give us at least min_acceptable_words if candidate_word_count >= min_acceptable_words: word_diff = abs(candidate_word_count - target_word_count) if word_diff < closest_word_diff: closest_ending = ending_idx closest_word_diff = word_diff # If we found a suitable ending, use it if closest_ending is not None: return story_text[:closest_ending+1] # If we couldn't find a good ending near 100 words, but we have some sentence endings, # use the last one that results in a story with at least min_acceptable_words words if all_endings: for ending_idx in reversed(all_endings): candidate_text = story_text[:ending_idx+1] if count_words(candidate_text) >= min_acceptable_words: return candidate_text # If no good ending is found, return as is return story_text # Function to reset progress when a new file is uploaded def reset_progress(): st.session_state.progress = { 'caption_generated': False, 'story_generated': False, 'audio_generated': False, 'caption': '', 'story': '', 'audio_data': None, 'audio_format': None } # Basic Streamlit interface st.title("Image to Audio Story") # Add processing status indicator status_container = st.empty() # Initialize session state for tracking progress if 'progress' not in st.session_state: st.session_state.progress = { 'caption_generated': False, 'story_generated': False, 'audio_generated': False, 'caption': '', 'story': '', 'audio_data': None, 'audio_format': None } # File uploader uploaded_file = st.file_uploader("Upload an image", on_change=reset_progress) # Process the image if uploaded if uploaded_file is not None: # Display image st.image(uploaded_file, caption="Uploaded Image") # Convert to PIL Image image = Image.open(uploaded_file) # Convert image to bytes for caching compatibility image_bytes = get_image_bytes(image) # Image to Text (if not already done) if not st.session_state.progress['caption_generated']: status_container.info("Generating caption...") st.session_state.progress['caption'] = img2text(image_bytes) st.session_state.progress['caption_generated'] = True st.write(f"Caption: {st.session_state.progress['caption']}") # Text to Story (if not already done) if not st.session_state.progress['story_generated']: status_container.info("Creating story...") st.session_state.progress['story'] = text2story(st.session_state.progress['caption']) st.session_state.progress['story_generated'] = True # Display word count for transparency word_count = count_words(st.session_state.progress['story']) st.write(f"Story ({word_count} words):") st.write(st.session_state.progress['story']) # Pre-generate audio in background (if not already done) if not st.session_state.progress['audio_generated'] and (has_gtts or tts_fallback_model is not None): status_container.info("Pre-generating audio in background...") try: st.session_state.progress['audio_data'], st.session_state.progress['audio_format'] = text2audio(st.session_state.progress['story']) st.session_state.progress['audio_generated'] = True status_container.success("Ready to play audio!") except Exception as e: status_container.error(f"Error pre-generating audio: {e}") # Button to play audio if st.button("Play the audio"): if st.session_state.progress['audio_generated']: # Display the audio player if isinstance(st.session_state.progress['audio_format'], str) and st.session_state.progress['audio_format'].startswith('audio/'): st.audio(st.session_state.progress['audio_data'], format=st.session_state.progress['audio_format']) else: st.audio(st.session_state.progress['audio_data'], sample_rate=st.session_state.progress['audio_format']) else: # Handle case where audio generation failed or is not available st.error("Unable to play audio. Audio generation was not successful.") else: status_container.info("Upload an image to begin")