Assignment1 / app.py
CR7CAD's picture
Update app.py
b8a0fec verified
raw
history blame
4.61 kB
import streamlit as st
from transformers import pipeline
from PIL import Image
import io
from gtts import gTTS
import time
import os
import traceback
# Browser-tab title and on-page heading share one string
_APP_TITLE = "Story Generator for Kids"
st.set_page_config(page_title=_APP_TITLE)

# Heading followed by a one-line invitation to the user
st.title(_APP_TITLE)
st.write("Upload a picture and let's create a magical story!")
# Initialize models with better error handling
@st.cache_resource
def _load_pipelines():
    """Build the captioning and story pipelines; the result is cached.

    Exceptions are deliberately allowed to propagate. Only a returned value
    is stored by ``st.cache_resource``, so a failed attempt (for example a
    model download that timed out) is not remembered and the next script
    rerun tries again.
    """
    image_to_text = pipeline("image-to-text", model="microsoft/git-base-coco")
    story_generator = pipeline("text-generation", model="gpt2")
    return image_to_text, story_generator


def load_models():
    """Load both pipelines without ever raising.

    Returns:
        ``(image_to_text, story_generator, error)``. On success ``error`` is
        ``None``. On failure both pipelines are ``None`` and ``error`` is a
        non-empty description of the exception.
    """
    try:
        image_to_text, story_generator = _load_pipelines()
    except Exception as e:
        # Some exceptions have an empty message; fall back to the type name
        # so callers that test ``if error:`` still notice the failure.
        return None, None, str(e) or type(e).__name__
    return image_to_text, story_generator, None
# Obtain the pipelines while a progress indicator is visible
with st.spinner("Loading models..."):
    _loaded = load_models()
image_to_text, story_generator, error = _loaded

# Report a loading failure on the page
if error:
    st.error(f"Failed to load models: {error}")
# Function to generate caption from image
def generate_caption(image):
    """Describe *image* using the module-level ``image_to_text`` pipeline.

    Args:
        image: The picture to caption; passed unchanged to ``image_to_text``.

    Returns:
        A ``(caption, error)`` tuple. On success ``error`` is ``None`` and
        ``caption`` is the generated text with surrounding whitespace
        removed. If the pipeline returns nothing usable, or raises, the
        caption is the fallback ``"An interesting image"`` and ``error``
        explains why. This function never raises.
    """
    fallback = "An interesting image"
    try:
        result = image_to_text(image)
        if result:
            # A blank caption would produce a prompt with nothing after
            # "Once upon a time," so it is treated as "no caption".
            caption = (result[0]['generated_text'] or "").strip()
            if caption:
                return caption, None
        return fallback, "No caption generated"
    except Exception as e:
        # UI boundary: any pipeline failure degrades to the fallback caption.
        return fallback, str(e)
# Function to generate story from caption (at most 100 words by default)
def generate_story(caption, max_words=100):
    """Write a short story that starts from *caption*.

    Args:
        caption: Text inserted after "Once upon a time," to seed the story.
        max_words: Upper bound on the number of whitespace-separated words
            in the returned story. Defaults to 100.

    Returns:
        A ``(story, error)`` tuple. On success ``error`` is ``None`` and
        ``story`` ends with ``.``, ``!`` or ``?``. If the generator returns
        nothing, or raises, ``story`` is a placeholder sentence and
        ``error`` describes the problem. This function never raises.
    """
    try:
        prompt = f"Once upon a time, {caption} "
        # Sampling (temperature / top_p) makes each run produce a different
        # story. NOTE: max_length is a token budget, not a word count; per
        # the transformers docs it also includes the prompt's tokens.
        result = story_generator(
            prompt,
            max_length=100,
            do_sample=True,
            temperature=0.9,
            top_p=0.95,
        )
        if result:
            # Strip first: trailing whitespace would hide the real final
            # character from endswith() below and cause a stray "." to be
            # appended after a newline.
            story = result[0]['generated_text'].strip()
            # Ensure story doesn't exceed max_words
            words = story.split()
            if len(words) > max_words:
                story = " ".join(words[:max_words])
            # Add period to the end if needed
            if not story.endswith(('.', '!', '?')):
                story += '.'
            return story, None
        return "Story generation failed.", "No story generated"
    except Exception as e:
        # UI boundary: report the failure instead of crashing the page.
        return "Once upon a time... (Story generation failed)", str(e)
# Function to convert text to speech
def text_to_speech(text):
    """Synthesize *text* to an MP3 file with gTTS.

    Each call writes to its own temporary file. A single fixed filename
    would be shared by every session of the app, so one user's audio could
    be overwritten by another's before it is played.

    Args:
        text: The text to read aloud (English).

    Returns:
        ``(path, None)`` with the path of the MP3 file on success, or
        ``(None, message)`` if synthesis failed. This function never raises.
        Successful files are left in the system temp directory; they are
        not removed here.
    """
    import tempfile

    try:
        tts = gTTS(text=text, lang='en', slow=False)
        fd, audio_file = tempfile.mkstemp(prefix="story_audio_", suffix=".mp3")
        os.close(fd)  # gTTS reopens the path itself; only the name is needed
        try:
            tts.save(audio_file)
        except Exception:
            # Do not leave an empty or partial file behind on failure.
            if os.path.exists(audio_file):
                os.remove(audio_file)
            raise
        return audio_file, None
    except Exception as e:
        return None, str(e)
# File uploader
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

# Detailed diagnostics are shown only when a session explicitly sets
# st.session_state["deployed"] = False; the flag is treated as True otherwise.
show_details = not st.session_state.get("deployed", True)
models_ready = image_to_text is not None and story_generator is not None

if uploaded_file is not None and not models_ready:
    # Without this branch an upload would be ignored with no feedback.
    st.warning(
        "The story models are not available, so a story cannot be created right now.",
        icon="⚠️",
    )
elif uploaded_file is not None:
    try:
        # Display the uploaded image
        image = Image.open(uploaded_file)
        st.image(image, caption='Uploaded Image', use_container_width=True)

        # Generate button
        if st.button("Generate Story"):
            with st.spinner("Generating your story..."):
                # Generate caption
                caption, caption_error = generate_caption(image)
                if caption_error:
                    st.warning(f"Caption generation issue: {caption_error}", icon="⚠️")
                # Display the caption (without debug information)
                st.write("Image caption:", caption)

                # Generate story
                story, story_error = generate_story(caption)
                if story_error and show_details:
                    st.warning(f"Story generation issue: {story_error}", icon="⚠️")
                # Display the story (without debug information)
                word_count = len(story.split())
                st.write(f"### Your Story ({word_count} words)")
                st.write(story)

                # Generate audio
                audio_file, audio_error = text_to_speech(story)
                if audio_file:
                    # Display audio
                    st.write("### Listen to your story")
                    st.audio(audio_file)
                else:
                    # Always tell the user why no player appears; the
                    # technical reason stays behind the diagnostics flag.
                    st.info("Audio is not available for this story.")
                    if audio_error and show_details:
                        st.warning(f"Audio generation issue: {audio_error}", icon="⚠️")
    except Exception as e:
        # Always give the user some feedback; expose the exception text only
        # when detailed diagnostics are enabled.
        st.error("Sorry, something went wrong while processing your picture. Please try a different image.")
        if show_details:
            st.error(f"Error processing image: {str(e)}")

# Footer
st.markdown("---")
st.write("Created for ISOM5240 Assignment 1")