# import part
import streamlit as st
from transformers import pipeline
from PIL import Image
# function part with caching for better performance:
# st.cache_resource keeps each pipeline loaded across Streamlit reruns
@st.cache_resource
def load_image_captioning_model():
    return pipeline("image-to-text", model="sooh-j/blip-image-captioning-base")

@st.cache_resource
def load_text_generator():
    return pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")

@st.cache_resource
def load_tts_model():
    return pipeline("text-to-speech", model="HelpingAI/HelpingAI-TTS-v1")
# img2text - Using the original model with more constraints
def img2text(image):
    # Load the model (cached)
    image_to_text = load_image_captioning_model()
    # Strongly limit output length for speed
    text = image_to_text(image, max_new_tokens=15)[0]["generated_text"]
    return text
# text2story - Much more constrained for speed
def text2story(text):
    # Load the model (cached)
    generator = load_text_generator()
    # Very brief prompt to minimize work
    prompt = f"Short story about {text}: Once upon a time, "
    # Very constrained parameters for maximum speed
    story_result = generator(
        prompt,
        max_new_tokens=60,       # Much shorter output
        num_return_sequences=1,
        temperature=0.7,
        top_k=10,                # Lower value = faster
        top_p=0.9,               # Lower value = faster
        do_sample=True
    )
    # Extract and clean text
    story_text = story_result[0]['generated_text']
    story_text = story_text.replace(prompt, "Once upon a time, ")
    # Find a natural ending point
    last_period = story_text.rfind('.')
    if last_period > 30:  # Ensure we have at least some content
        story_text = story_text[:last_period + 1]
    return story_text
# text2audio - Minimal text for faster processing
def text2audio(story_text):
    try:
        # Load the model (cached)
        synthesizer = load_tts_model()
        # Aggressively limit text length to speed up TTS
        max_chars = 200  # Much shorter than before
        if len(story_text) > max_chars:
            last_period = story_text[:max_chars].rfind('.')
            if last_period > 0:
                story_text = story_text[:last_period + 1]
            else:
                story_text = story_text[:max_chars]
        # Generate speech
        speech = synthesizer(story_text)
        return speech
    except Exception as e:
        st.error(f"Error generating audio: {str(e)}")
        return None
# Streamlined main UI
st.set_page_config(page_title="Image to Story", page_icon="π")
st.header("Image to Audio Story")

# Add info about processing time
st.info("Note: Processing may take some time as the models are loading. Please be patient.")

# Cache the file uploader state
if "uploaded_file" not in st.session_state:
    st.session_state["uploaded_file"] = None

uploaded_file = st.file_uploader("Select an Image...", key="file_uploader")

# Process the image if uploaded
if uploaded_file is not None:
    st.session_state["uploaded_file"] = uploaded_file

    # Display the uploaded image
    st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)

    # Convert to PIL image
    image = Image.open(uploaded_file)
    # Optional processing toggle to let user decide
    if st.button("Generate Story and Audio"):
        col1, col2 = st.columns(2)

        # Stage 1: Image to Text with minimal output
        with col1:
            with st.spinner('Captioning image...'):
                caption = img2text(image)
                st.write(f"**Caption:** {caption}")

        # Stage 2: Text to Story with minimal length
        with col2:
            with st.spinner('Creating story...'):
                story = text2story(caption)
                st.write(f"**Story:** {story}")

        # Stage 3: Audio with minimal text
        with st.spinner('Generating audio...'):
            speech_output = text2audio(story)
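        # Assumption: the transformers text-to-speech pipeline typically returns a
        # dict like {"audio": <numpy array>, "sampling_rate": <int>}; the branches
        # below also probe alternative key names in case a model uses different ones.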
        # Display audio immediately
        if speech_output is not None:
            try:
                if 'audio' in speech_output and 'sampling_rate' in speech_output:
                    st.audio(speech_output['audio'], sample_rate=speech_output['sampling_rate'])
                elif 'audio_array' in speech_output and 'sampling_rate' in speech_output:
                    st.audio(speech_output['audio_array'], sample_rate=speech_output['sampling_rate'])
                elif 'waveform' in speech_output and 'sample_rate' in speech_output:
                    st.audio(speech_output['waveform'], sample_rate=speech_output['sample_rate'])
                else:
                    # Try any array-like data
                    for key, value in speech_output.items():
                        if hasattr(value, '__len__') and len(value) > 1000:
                            sample_rate = speech_output.get('sampling_rate', speech_output.get('sample_rate', 24000))
                            st.audio(value, sample_rate=sample_rate)
                            break
                    else:
                        # for-else: runs only if no array-like value was found
                        st.error("Could not find audio data in the output")
            except Exception as e:
                st.error(f"Error playing audio: {str(e)}")
        else:
            st.error("Audio generation failed")