Assignment1 / app.py
CR7CAD's picture
Update app.py
1fb1e8e verified
raw
history blame
4.6 kB
import streamlit as st
from transformers import pipeline
from PIL import Image
import io
from gtts import gTTS
import time
import os
import traceback
# Page metadata and header. The same title is used for the browser tab and
# for the on-page heading, so it is spelled once.
APP_TITLE = "Image to Audio Story Generator"

st.set_page_config(page_title=APP_TITLE)

# Heading followed by a one-line invitation to upload a picture
st.title(APP_TITLE)
st.write("Upload a picture and let's create a magical story!")
# Model loading. Only a *successful* load is cached (see load_models).
@st.cache_resource
def _load_pipelines():
    """Build the captioning and story-generation pipelines.

    Errors are deliberately not caught here: ``st.cache_resource`` stores
    whatever the function returns, so an error returned as a value would be
    cached and replayed on every rerun. Letting the exception propagate
    keeps failed loads out of the cache.

    Returns:
        ``(image_to_text, story_generator)`` pipeline objects.
    """
    captioner = pipeline("image-to-text", model="microsoft/git-base-coco")
    storyteller = pipeline("text-generation", model="gpt2")
    return captioner, storyteller


def load_models():
    """Load both pipelines, reporting failure as a value instead of raising.

    Returns:
        ``(image_to_text, story_generator, error)``. On success ``error`` is
        ``None``. On failure both pipelines are ``None`` and ``error`` is the
        exception message; nothing is cached in that case, so the next rerun
        attempts the load again.
    """
    try:
        captioner, storyteller = _load_pipelines()
    except Exception as exc:  # UI boundary: surface the message, don't crash
        return None, None, str(exc)
    return captioner, storyteller, None


# Keep the cache-clearing hook that the decorated function used to expose.
load_models.clear = _load_pipelines.clear
# Obtain both pipelines, showing a spinner while load_models() runs
with st.spinner("Loading models..."):
    image_to_text, story_generator, error = load_models()
    # Tell the user how the load went; `error` is None/empty on success
    if not error:
        st.success("Models loaded successfully!")
    else:
        st.error(f"Failed to load models: {error}")
# Caption the uploaded image with the image-to-text pipeline
def generate_caption(image):
    """Describe ``image`` using the module-level ``image_to_text`` pipeline.

    Returns:
        A ``(caption, error)`` pair. ``error`` is ``None`` when the pipeline
        produced a caption. Otherwise ``caption`` is a generic placeholder
        and ``error`` is either "No caption generated" (empty result) or the
        message of the exception that was raised.
    """
    placeholder = "An interesting image"
    try:
        outputs = image_to_text(image)
        # Guard clause: bail out early when the pipeline returned nothing
        if not outputs or len(outputs) == 0:
            return placeholder, "No caption generated"
        # Only the first candidate is used
        return outputs[0]['generated_text'], None
    except Exception as exc:
        return placeholder, str(exc)
# Turn the caption into a short story (at most 100 words)
def generate_story(caption):
    """Generate a short story that opens with ``caption``.

    The module-level ``story_generator`` pipeline is prompted with
    "Once upon a time, <caption> ". The text of its first candidate is capped
    at 100 whitespace-separated words and finished with a period when it does
    not already end in sentence punctuation.

    Args:
        caption: Text inserted into the story prompt.

    Returns:
        A ``(story, error)`` pair. ``error`` is ``None`` on success. If the
        pipeline returns nothing, or raises, ``story`` is a fixed fallback
        sentence and ``error`` describes the problem. Exceptions are also
        reported in the UI together with their traceback.
    """
    try:
        prompt = f"Once upon a time, {caption} "
        # Debug output
        st.write(f"Prompt: {prompt}")
        # Sample instead of decoding greedily so repeated runs give different
        # stories. NOTE(review): max_length is a token budget, not a word
        # count, which is why the word cap below is still applied.
        result = story_generator(
            prompt,
            max_length=100,
            do_sample=True,
            temperature=0.9,
            top_p=0.95,
        )
        # Debug output
        st.write(f"Generation result: {result}")
        if not result:
            return "Story generation failed.", "No story generated"

        story = result[0]['generated_text']
        # Ensure story doesn't exceed 100 words
        words = story.split()
        if len(words) > 100:
            story = " ".join(words[:100])

        # Strip trailing whitespace first: generated text frequently ends in
        # a space or newline, which used to defeat the punctuation check and
        # leave a stray "." after an already finished sentence.
        story = story.rstrip()
        # Closing quotes/brackets may follow the real sentence terminator
        # (e.g. ...said, "Goodbye!"), so ignore them when checking.
        closers = '"\'\u201d\u2019)'
        if not story.rstrip(closers).endswith(('.', '!', '?')):
            story += '.'
        return story, None
    except Exception as e:
        st.error(f"Error in story generation: {str(e)}")
        st.error(traceback.format_exc())
        return "Once upon a time... (Story generation failed)", str(e)
# Narrate the story with gTTS
def text_to_speech(text):
    """Synthesize English speech for ``text`` and save it to disk.

    Returns:
        ``(path, None)`` with the path of the saved audio file on success,
        or ``(None, message)`` carrying the exception text on failure.
    """
    # NOTE(review): the file name is fixed, so each call overwrites the
    # audio written by the previous one.
    output_path = "story_audio.mp3"
    try:
        gTTS(text=text, lang='en', slow=False).save(output_path)
    except Exception as exc:
        return None, str(exc)
    return output_path, None
# Image upload widget (JPEG and PNG only)
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

# The pipeline below only runs once a file is present and both models loaded
models_ready = image_to_text is not None and story_generator is not None
if uploaded_file is not None and models_ready:
    try:
        # Preview the picture the user just uploaded
        image = Image.open(uploaded_file)
        st.image(image, caption='Uploaded Image', use_container_width=True)

        # Nothing is generated until the user asks for it
        if st.button("Generate Story"):
            with st.spinner("Generating your story..."):
                # Step 1: image -> caption
                caption, caption_error = generate_caption(image)
                if caption_error:
                    st.warning(f"Caption generation issue: {caption_error}")
                st.write("Image caption:", caption)

                # Step 2: caption -> story
                story, story_error = generate_story(caption)
                if story_error:
                    st.warning(f"Story generation issue: {story_error}")
                st.write(f"### Your Story ({len(story.split())} words)")
                st.write(story)

                # Step 3: story -> audio; the player is shown only when
                # synthesis reported no error
                audio_file, audio_error = text_to_speech(story)
                if not audio_error:
                    st.write("### Listen to your story")
                    st.audio(audio_file)
                else:
                    st.warning(f"Audio generation issue: {audio_error}")
    except Exception as exc:
        # Anything unexpected is shown in the page along with its traceback
        st.error(f"Error processing image: {exc!s}")
        st.error(traceback.format_exc())

# Footer
st.markdown("---")
st.write("Created for ISOM5240 Assignment 1")