Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from transformers import pipeline | |
| from PIL import Image | |
| import io | |
| from gtts import gTTS | |
| import time | |
| import os | |
| import traceback | |
# Single source for the app name so the tab title and heading cannot drift.
APP_TITLE = "Image to Audio Story Generator"

# Configure the browser-tab title, then render the heading and intro line.
st.set_page_config(page_title=APP_TITLE)
st.title(APP_TITLE)
st.write("Upload a picture and let's create a magical story!")
@st.cache_resource(show_spinner=False)
def _load_pipelines():
    """Build the captioning and story pipelines once and reuse them.

    Streamlit re-executes the whole script on every widget interaction, so
    an uncached loader would re-instantiate both models on each rerun.
    Only a successful return value is stored; an exception propagates to
    the caller, so a later rerun gets to retry.
    """
    captioner = pipeline("image-to-text", model="microsoft/git-base-coco")
    storyteller = pipeline("text-generation", model="gpt2")
    return captioner, storyteller


def load_models():
    """Load the image-captioning and text-generation pipelines.

    Returns:
        A 3-tuple ``(image_to_text, story_generator, error)``. On success
        ``error`` is ``None``. On failure both pipelines are ``None`` and
        ``error`` is the exception message as a string.
    """
    try:
        captioner, storyteller = _load_pipelines()
    except Exception as e:  # UI boundary: report the failure, don't crash
        return None, None, str(e)
    return captioner, storyteller, None
# Load the models behind a spinner, then report the outcome.
with st.spinner("Loading models..."):
    image_to_text, story_generator, error = load_models()

if not error:
    st.success("Models loaded successfully!")
else:
    st.error(f"Failed to load models: {error}")
def generate_caption(image):
    """Caption *image* using the module-level ``image_to_text`` pipeline.

    Returns:
        ``(caption, error)``. ``error`` is ``None`` when a caption was
        produced. If the pipeline returns nothing, or anything raises, the
        caption falls back to "An interesting image" and ``error`` says why.
    """
    fallback = "An interesting image"
    try:
        outputs = image_to_text(image)
        if not outputs:
            return fallback, "No caption generated"
        first = outputs[0]
        return first['generated_text'], None
    except Exception as exc:
        return fallback, str(exc)
def _tidy_story(text, max_words=100):
    """Cap *text* at *max_words* words and end it with sentence punctuation."""
    words = text.split()
    if len(words) > max_words:
        text = " ".join(words[:max_words])
    # Strip trailing whitespace before the punctuation check; otherwise text
    # ending in ".\n" (or the prompt's trailing space) fails the check and
    # gains a stray "." after the whitespace.
    text = text.rstrip()
    if not text.endswith(('.', '!', '?')):
        text += '.'
    return text


def generate_story(caption):
    """Generate a short story (at most 100 words) seeded by *caption*.

    Uses the module-level ``story_generator`` pipeline and writes the prompt
    and the raw generation result to the page as debug output.

    Returns:
        ``(story, error)``. ``error`` is ``None`` on success. If the pipeline
        returns nothing, or anything raises, ``story`` is a fixed fallback
        message and ``error`` describes the failure. Exceptions are also
        shown on the page together with their traceback.
    """
    try:
        prompt = f"Once upon a time, {caption} "
        # Debug output
        st.write(f"Prompt: {prompt}")
        # Sampled generation. No timeout is configured here.
        # NOTE(review): max_length is presumably a token budget that includes
        # the prompt, not a word count - confirm against the transformers
        # docs. The word cap itself is enforced by _tidy_story.
        result = story_generator(
            prompt,
            max_length=100,
            do_sample=True,
            temperature=0.9,
            top_p=0.95,
        )
        # Debug output
        st.write(f"Generation result: {result}")
        if not result:
            return "Story generation failed.", "No story generated"
        return _tidy_story(result[0]['generated_text']), None
    except Exception as e:
        st.error(f"Error in story generation: {str(e)}")
        st.error(traceback.format_exc())
        return "Once upon a time... (Story generation failed)", str(e)
def text_to_speech(text):
    """Synthesize *text* to an MP3 file with gTTS.

    Returns:
        ``(path, error)``. On success ``path`` is the MP3 file name and
        ``error`` is ``None``. On failure ``path`` is ``None`` and ``error``
        is the exception message. The caller owns the returned file; nothing
        here deletes it after a successful save.
    """
    import tempfile  # function-scope import keeps this block self-contained

    audio_file = None
    try:
        tts = gTTS(text=text, lang='en', slow=False)
        # Reserve a unique path per call: a single fixed file name would be
        # overwritten when two sessions generate audio at the same time.
        with tempfile.NamedTemporaryFile(
            prefix="story_audio_", suffix=".mp3", delete=False
        ) as tmp:
            audio_file = tmp.name
        # The handle is closed before saving so gTTS can reopen the path.
        tts.save(audio_file)
        return audio_file, None
    except Exception as e:
        # Don't leave an empty or partial file behind on failure.
        if audio_file is not None:
            try:
                os.remove(audio_file)
            except OSError:
                pass
        return None, str(e)
def _show_story_for(image):
    """Run caption -> story -> audio for *image*, rendering each step.

    A failed step is reported with ``st.warning``. The story and its word
    count are always shown; the audio player only when synthesis succeeded.
    """
    caption, caption_error = generate_caption(image)
    if caption_error:
        st.warning(f"Caption generation issue: {caption_error}")
    st.write("Image caption:", caption)

    story, story_error = generate_story(caption)
    if story_error:
        st.warning(f"Story generation issue: {story_error}")
    word_count = len(story.split())
    st.write(f"### Your Story ({word_count} words)")
    st.write(story)

    audio_file, audio_error = text_to_speech(story)
    if audio_error:
        st.warning(f"Audio generation issue: {audio_error}")
        return
    st.write("### Listen to your story")
    st.audio(audio_file)


# File uploader
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

# Only offer generation when both pipelines loaded.
models_ready = image_to_text is not None and story_generator is not None
if uploaded_file is not None and models_ready:
    try:
        # Show the upload back to the user before offering the button.
        image = Image.open(uploaded_file)
        st.image(image, caption='Uploaded Image', use_container_width=True)
        if st.button("Generate Story"):
            with st.spinner("Generating your story..."):
                _show_story_for(image)
    except Exception as exc:
        st.error(f"Error processing image: {exc}")
        st.error(traceback.format_exc())

# Footer
st.markdown("---")
st.write("Created for ISOM5240 Assignment 1")