"""Streamlit app: turn an uploaded image into a short audio story.

Pipeline stages (each backed by a cached Hugging Face pipeline):
  1. image  -> caption   (BLIP image-to-text)
  2. caption -> story    (TinyLlama text-generation)
  3. story  -> speech    (HelpingAI text-to-speech)
"""

import streamlit as st
from PIL import Image
from transformers import pipeline

# NOTE(review): the original did `from transformers import set_caching_enabled`
# and called it here, but `transformers` does not export that name (it belongs
# to the `datasets` package, where it is deprecated) — the import raised
# ImportError before the app could start. It has been removed; model reuse is
# already handled by the `st.cache_resource` decorators below.


@st.cache_resource
def load_image_captioning_model():
    """Load (once per server process) the BLIP image-captioning pipeline."""
    return pipeline("image-to-text", model="sooh-j/blip-image-captioning-base")


@st.cache_resource
def load_text_generator():
    """Load (once per server process) the TinyLlama text-generation pipeline."""
    return pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")


@st.cache_resource
def load_tts_model():
    """Load (once per server process) the text-to-speech pipeline."""
    return pipeline("text-to-speech", model="HelpingAI/HelpingAI-TTS-v1")


def _truncate_at_sentence(text, max_chars, min_keep=0):
    """Cut `text` to at most `max_chars`, preferring to end on a period.

    Falls back to a hard cut at `max_chars` when no period lies past
    `min_keep` characters. Returns `text` unchanged if already short enough.
    """
    if len(text) <= max_chars:
        return text
    last_period = text[:max_chars].rfind('.')
    if last_period > min_keep:
        return text[:last_period + 1]
    return text[:max_chars]


def img2text(image):
    """Caption a PIL image.

    Output is capped at 15 new tokens to keep latency low; returns the
    generated caption string.
    """
    image_to_text = load_image_captioning_model()
    return image_to_text(image, max_new_tokens=15)[0]["generated_text"]


def text2story(text):
    """Generate a very short story seeded by `text` (an image caption).

    Generation is tightly constrained (60 new tokens, low top_k) for speed.
    The prompt prefix is stripped from the result, and the story is trimmed
    back to its last full sentence when one exists past 30 characters.
    """
    generator = load_text_generator()
    prompt = f"Short story about {text}: Once upon a time, "
    story_result = generator(
        prompt,
        max_new_tokens=60,       # much shorter output for speed
        num_return_sequences=1,
        temperature=0.7,
        top_k=10,                # lower value = faster
        top_p=0.9,               # lower value = faster
        do_sample=True,
    )
    # The pipeline echoes the prompt; replace it with just the story opener.
    story_text = story_result[0]['generated_text']
    story_text = story_text.replace(prompt, "Once upon a time, ")
    # End on a natural sentence boundary when there is enough content.
    last_period = story_text.rfind('.')
    if last_period > 30:
        story_text = story_text[:last_period + 1]
    return story_text


def text2audio(story_text):
    """Synthesize speech for `story_text`.

    The text is aggressively capped at 200 characters (ending on a sentence
    when possible) to keep TTS latency down. Returns the raw pipeline output
    dict on success, or None after surfacing the error in the UI.
    """
    try:
        synthesizer = load_tts_model()
        story_text = _truncate_at_sentence(story_text, max_chars=200)
        speech = synthesizer(story_text)
        return speech
    except Exception as e:
        # Best-effort: report in the UI rather than crash the app.
        st.error(f"Error generating audio: {str(e)}")
        return None


# ---------------------------------------------------------------------------
# Streamlined main UI
# ---------------------------------------------------------------------------
st.set_page_config(page_title="Image to Story", page_icon="📚")
st.header("Image to Audio Story")

st.info("Note: Processing may take some time as the models are loading. Please be patient.")

# Cache the file uploader state across reruns.
if "uploaded_file" not in st.session_state:
    st.session_state["uploaded_file"] = None

uploaded_file = st.file_uploader("Select an Image...", key="file_uploader")

if uploaded_file is not None:
    st.session_state["uploaded_file"] = uploaded_file

    # NOTE(review): `use_column_width` is deprecated in newer Streamlit in
    # favour of `use_container_width`; kept as-is pending a version check.
    st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
    image = Image.open(uploaded_file)

    # Let the user decide when to run the (slow) model pipeline.
    if st.button("Generate Story and Audio"):
        col1, col2 = st.columns(2)

        # Stage 1: image -> caption.
        with col1:
            with st.spinner('Captioning image...'):
                caption = img2text(image)
            st.write(f"**Caption:** {caption}")

        # Stage 2: caption -> story.
        with col2:
            with st.spinner('Creating story...'):
                story = text2story(caption)
            st.write(f"**Story:** {story}")

        # Stage 3: story -> audio.
        with st.spinner('Generating audio...'):
            speech_output = text2audio(story)

        if speech_output is not None:
            try:
                # TTS pipelines differ in output schema; probe the known
                # key combinations before falling back to a generic scan.
                if 'audio' in speech_output and 'sampling_rate' in speech_output:
                    st.audio(speech_output['audio'],
                             sample_rate=speech_output['sampling_rate'])
                elif 'audio_array' in speech_output and 'sampling_rate' in speech_output:
                    st.audio(speech_output['audio_array'],
                             sample_rate=speech_output['sampling_rate'])
                elif 'waveform' in speech_output and 'sample_rate' in speech_output:
                    st.audio(speech_output['waveform'],
                             sample_rate=speech_output['sample_rate'])
                else:
                    # Fallback: take the first sufficiently large array-like
                    # value; default to 24 kHz when no rate key is present.
                    for key, value in speech_output.items():
                        if hasattr(value, '__len__') and len(value) > 1000:
                            sample_rate = speech_output.get(
                                'sampling_rate',
                                speech_output.get('sample_rate', 24000))
                            st.audio(value, sample_rate=sample_rate)
                            break
                    else:
                        st.error("Could not find audio data in the output")
            except Exception as e:
                st.error(f"Error playing audio: {str(e)}")
        else:
            st.error("Audio generation failed")