File size: 4,582 Bytes
90bef38
5b9e396
90bef38
 
 
 
b038974
 
e1ee436
90bef38
 
e1ee436
90bef38
 
 
5b9e396
b038974
90bef38
 
b038974
 
 
 
 
 
90bef38
b038974
 
 
 
 
 
 
90bef38
 
 
b038974
 
 
 
 
 
 
 
90bef38
1e8cc2c
90bef38
b038974
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90bef38
 
 
b038974
 
 
 
 
 
 
90bef38
 
 
 
b038974
90bef38
b038974
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b9e396
90bef38
b038974
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import streamlit as st
from transformers import pipeline
from PIL import Image
import io
from gtts import gTTS
import time
import os
import traceback

# Configure the browser tab title for the app.
# NOTE(review): Streamlit has traditionally expected set_page_config to run
# before any other st.* command — keep it first unless that is confirmed
# unnecessary for the deployed Streamlit version.
st.set_page_config(page_title="Kids Story Generator")

# Page heading, followed by a one-line instruction for the user.
st.title("Kids Story Generator")
st.write("Upload a picture and let's create a magical story!")

# Model initialisation: only a *successful* load is cached.
@st.cache_resource
def _build_pipelines():
    """Create the captioning and story-generation pipelines.

    This is the cached part of model loading. It deliberately lets exceptions
    propagate: ``st.cache_resource`` stores return values, so if the failure
    were returned as a value (as the previous single-function version did) the
    failure itself would be cached and never retried.

    Returns:
        A ``(image_to_text, story_generator)`` tuple of pipelines.
    """
    image_to_text = pipeline("image-to-text", model="microsoft/git-base-coco")
    story_generator = pipeline("text-generation", model="gpt2")
    return image_to_text, story_generator


def load_models():
    """Load both pipelines, reporting failure as a value instead of raising.

    Returns:
        ``(image_to_text, story_generator, error)``. On success ``error`` is
        ``None``. On failure both pipelines are ``None`` and ``error`` holds
        the exception message; the next call tries the load again because
        nothing was cached.
    """
    try:
        image_to_text, story_generator = _build_pipelines()
    except Exception as e:
        return None, None, str(e)
    return image_to_text, story_generator, None


# Keep the cache-reset hook that the decorated function used to expose.
load_models.clear = _build_pipelines.clear

# Fetch the pipelines while a spinner is shown, then report the outcome.
# The three names bound here are module globals used by the helpers and the
# upload flow further down.
with st.spinner("Loading models..."):
    image_to_text, story_generator, error = load_models()
    if not error:
        st.success("Models loaded successfully!")
    else:
        st.error(f"Failed to load models: {error}")

# Describe an image in words using the captioning pipeline
def generate_caption(image):
    """Return ``(caption, error)`` for *image*.

    The module-level ``image_to_text`` pipeline is called on the image and the
    ``'generated_text'`` of its first result is used as the caption. If the
    pipeline yields nothing, or anything raises, a generic fallback caption is
    returned together with a message describing the problem; ``error`` is
    ``None`` only when a real caption was produced.
    """
    fallback = "An interesting image"
    try:
        predictions = image_to_text(image)
        has_prediction = bool(predictions) and len(predictions) > 0
        if not has_prediction:
            return fallback, "No caption generated"
        return predictions[0]['generated_text'], None
    except Exception as exc:
        return fallback, str(exc)

# Function to generate story from caption (capped at 100 words)
def generate_story(caption):
    """Generate a short story that continues from an image caption.

    The caption is embedded in a "Once upon a time, ..." prompt and passed to
    the module-level ``story_generator`` pipeline. Like the other helpers in
    this file, this function renders nothing itself: failures are reported
    through the returned error value so the caller decides how to show them
    (previously the error was displayed here *and* by the caller, along with
    raw debug dumps of the prompt and pipeline output).

    Args:
        caption: Text describing the image; inserted verbatim into the prompt.

    Returns:
        A ``(story, error)`` tuple. ``error`` is ``None`` on success. When the
        pipeline returns nothing or raises, ``story`` is a placeholder
        sentence and ``error`` describes what went wrong.
    """
    max_words = 100
    try:
        prompt = f"Once upon a time, {caption} "

        # Sampled generation (temperature 0.9, nucleus sampling at 0.95).
        # NOTE(review): per the transformers generation docs, max_length is a
        # token budget rather than a word count, hence the explicit word cap
        # below.
        result = story_generator(
            prompt,
            max_length=100,
            do_sample=True,
            temperature=0.9,
            top_p=0.95,
        )

        if not result:
            return "Story generation failed.", "No story generated"

        story = result[0]['generated_text']

        # Enforce the word cap; a cut-off story gets closing punctuation so
        # it does not end mid-air.
        words = story.split()
        if len(words) > max_words:
            story = " ".join(words[:max_words])
            if not story.endswith(('.', '!', '?')):
                story += '.'

        return story, None
    except Exception as e:
        # Keep the full traceback in the server log (stderr) instead of
        # showing it in the page; the caller surfaces the short message.
        traceback.print_exc()
        return "Once upon a time... (Story generation failed)", str(e)

# Turn the story text into a spoken-word MP3
def text_to_speech(text):
    """Synthesize *text* as English speech and save it as an MP3 file.

    Returns:
        ``(path, error)``. On success ``path`` is ``"story_audio.mp3"`` (the
        same file name on every call) and ``error`` is ``None``. If synthesis
        or saving raises, ``path`` is ``None`` and ``error`` is the exception
        message.
    """
    output_path = "story_audio.mp3"
    try:
        speech = gTTS(text=text, lang='en', slow=False)
        speech.save(output_path)
    except Exception as exc:
        return None, str(exc)
    return output_path, None

# Image upload widget (jpg / jpeg / png)
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

# The flow below needs an uploaded file and both pipelines to be available.
models_ready = not (image_to_text is None or story_generator is None)

if models_ready and uploaded_file is not None:
    try:
        # Preview the picture the user picked
        picture = Image.open(uploaded_file)
        st.image(picture, caption='Uploaded Image', use_container_width=True)

        # Caption -> story -> audio runs only on the run where the button is pressed
        generate_clicked = st.button("Generate Story")
        if generate_clicked:
            with st.spinner("Generating your story..."):
                # Step 1: describe the picture
                caption, caption_error = generate_caption(picture)
                if caption_error:
                    st.warning(f"Caption generation issue: {caption_error}")
                st.write("Image caption:", caption)

                # Step 2: turn the caption into a story and show it with its word count
                story, story_error = generate_story(caption)
                if story_error:
                    st.warning(f"Story generation issue: {story_error}")
                word_count = len(story.split())
                st.write(f"### Your Story ({word_count} words)")
                st.write(story)

                # Step 3: read the story aloud, or explain why no audio is available
                audio_file, audio_error = text_to_speech(story)
                if not audio_error:
                    st.write("### Listen to your story")
                    st.audio(audio_file)
                else:
                    st.warning(f"Audio generation issue: {audio_error}")
    except Exception as e:
        # Anything unexpected in the flow above ends up here, with its traceback
        st.error(f"Error processing image: {str(e)}")
        st.error(traceback.format_exc())

# Footer: horizontal rule and attribution line
st.markdown("---")
st.write("Created for ISOM5240_Assignment1")