File size: 10,594 Bytes
8fe6281
90bef38
8d5fabf
ab8ead3
8fe6281
 
 
fc13d66
15c1038
fc13d66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
862568a
ce9aea5
862568a
fc13d66
862568a
fc13d66
 
 
 
 
 
 
 
 
 
 
 
ce9aea5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc13d66
 
 
ce9aea5
 
 
 
 
 
fc13d66
 
 
118cd25
15c1038
 
 
 
 
 
 
 
fc13d66
 
15c1038
 
 
 
 
 
 
 
 
fc13d66
8d5fabf
5518670
 
 
 
 
fc13d66
5f21a2d
fc13d66
5518670
 
 
 
 
5f21a2d
9d38390
fc13d66
5f21a2d
5518670
5f21a2d
 
 
 
 
5518670
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c4bc18
 
 
 
 
 
 
 
 
5518670
 
 
 
7c4bc18
 
5518670
 
 
7c4bc18
5518670
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c4bc18
9d38390
5f21a2d
 
15c1038
 
fc13d66
 
 
 
 
 
 
 
 
 
15c1038
 
fc13d66
15c1038
 
 
 
 
fc13d66
 
 
 
 
 
 
 
 
 
15c1038
 
 
fc13d66
ab8ead3
ad4186a
 
f006a50
ad4186a
ab8ead3
f006a50
15c1038
 
 
fc13d66
 
 
15c1038
fc13d66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8fe6281
fc13d66
 
 
8fe6281
fc13d66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
# Imports
import streamlit as st
from transformers import pipeline
from PIL import Image
import torch
import os
import tempfile
import time
import numpy as np

# Use Streamlit's caching mechanisms to optimize model loading
@st.cache_resource
def load_image_to_text_pipeline():
    """Build the BLIP image-captioning pipeline; cached for the app's lifetime."""
    captioner = pipeline("image-to-text", model="sooh-j/blip-image-captioning-base")
    return captioner

@st.cache_resource
def load_text_generation_pipeline():
    """Build the TinyLlama text-generation pipeline; cached for the app's lifetime."""
    generator = pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    return generator

@st.cache_resource
def load_tts_pipeline():
    """Load and cache the text-to-speech pipeline as fallback.

    Returns:
        The transformers TTS pipeline, or None if it cannot be loaded
        (callers treat None as "TTS fallback unavailable").
    """
    try:
        return pipeline("text-to-speech", model="facebook/mms-tts-eng")
    except Exception:
        # Was a bare `except:`, which would also swallow KeyboardInterrupt
        # and SystemExit; catch Exception so those still propagate.
        return None

# Initialize all models at app startup
with st.spinner("Loading models (this may take a moment the first time)..."):
    # Load all models at startup and cache them
    # (st.cache_resource makes these near-instant on subsequent Streamlit reruns)
    img2text_model = load_image_to_text_pipeline()
    story_generator_model = load_text_generation_pipeline()
    tts_fallback_model = load_tts_pipeline()  # may be None if the TTS model failed to load

# For TTS, try multiple options in order of preference
try:
    # Try importing gTTS (preferred engine; requires network at synthesis time)
    from gtts import gTTS
    has_gtts = True
except ImportError:
    has_gtts = False
    # Only warn if the transformers fallback is also unavailable;
    # otherwise audio still works via tts_fallback_model
    if tts_fallback_model is None:
        st.warning("No text-to-speech capability available. Audio generation will be disabled.")

# Cache the text-to-audio conversion
@st.cache_data
def text2audio(story_text):
    """Convert text to audio with caching to avoid regenerating the same audio.

    Returns a 2-tuple whose second element's type identifies the engine:
      - gTTS path:         (mp3_bytes, 'audio/mp3')        -- str MIME type
      - transformers path: (audio_array, sampling_rate)    -- int rate
    The caller distinguishes the two with an isinstance(str) check.

    Raises:
        RuntimeError: if no TTS engine is available or produced no audio.
    """
    if has_gtts:
        # gTTS writes to a file, so round-trip through a named temp file
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp3')
        temp_filename = temp_file.name
        temp_file.close()

        try:
            # Use gTTS to convert text to speech
            tts = gTTS(text=story_text, lang='en', slow=False)
            tts.save(temp_filename)

            # Read the audio file back into memory
            with open(temp_filename, 'rb') as audio_file:
                audio_bytes = audio_file.read()
        finally:
            # Always remove the temp file -- the original leaked it when
            # gTTS raised before the unlink
            if os.path.exists(temp_filename):
                os.unlink(temp_filename)

        return audio_bytes, 'audio/mp3'
    elif tts_fallback_model is not None:
        # Use transformers TTS
        speech = tts_fallback_model(story_text)

        # Different transformers versions expose the waveform under
        # different keys; try both before giving up
        if 'audio' in speech:
            return speech['audio'], speech.get('sampling_rate', 16000)
        elif 'audio_array' in speech:
            return speech['audio_array'], speech.get('sampling_rate', 16000)

    # No engine available, or the fallback returned an unrecognized payload.
    # RuntimeError (not bare Exception) so intentional catches stay meaningful;
    # the caller's `except Exception` still handles it.
    raise RuntimeError("No text-to-speech capability available")

# Convert PIL Image to bytes for hashing in cache
def get_image_bytes(pil_img):
    """Serialize an image to JPEG bytes so it can serve as a stable cache key."""
    import io

    with io.BytesIO() as buffer:
        pil_img.save(buffer, format='JPEG')
        return buffer.getvalue()

# Simple image-to-text function using cached model
@st.cache_data
def img2text(image_bytes):
    """Caption an image supplied as raw bytes (bytes are hashable for st.cache_data)."""
    import io
    from PIL import Image

    # Rehydrate the bytes into a PIL image before running the captioner
    image = Image.open(io.BytesIO(image_bytes))
    captions = img2text_model(image)
    return captions[0]["generated_text"]

# Helper function to count words
def count_words(text):
    """Return the number of whitespace-separated tokens in *text*."""
    return sum(1 for _ in text.split())

# Improved text-to-story function without "Once upon a time" constraint
@st.cache_data
def text2story(text: str) -> str:
    """Generate a story from text with caching.

    Args:
        text: Seed text for the story (here: the image caption).

    Returns:
        The generated story, trimmed at a sentence boundary targeting ~100
        words when possible; otherwise the cleaned model output as-is.
    """
    # Ask for a story without specifying how to start
    prompt = f"""Write a children's story based on this: {text}. 
    The story should have a clear beginning, middle, and end.
    Make the story approximately 150-200 words long with descriptive language.
    """
    
    # Generate a longer text to ensure we get a complete story
    # (max_length counts prompt + completion tokens, so 500 leaves headroom
    # for the ~150-200 words requested in the prompt)
    story_result = story_generator_model(
        prompt,
        max_length=500,
        num_return_sequences=1,
        temperature=0.7,
        do_sample=True
    )
   
    full_text = story_result[0]['generated_text']
    
    # Try to extract just the story part (after the prompt)
    # Look for paragraph breaks or clear story beginnings
    # (ordered by preference: paragraph break first, then line break,
    # then sentence-ending punctuation)
    potential_starts = [
        "\n\n",
        "\n",
        ". ",
        "! ",
        "? "
    ]
    
    # Find where the prompt ends and the actual story begins
    story_text = full_text
    
    # First remove the exact prompt if it appears verbatim
    if prompt in story_text:
        story_text = story_text.replace(prompt, "")
    else:
        # Look for paragraph breaks or sentence endings that might indicate
        # where the prompt instructions end and the story begins
        for start_marker in potential_starts:
            if start_marker in story_text:
                parts = story_text.split(start_marker, 1)
                if len(parts[0]) < len(story_text) * 0.5:  # If the first part is reasonably short
                    story_text = parts[1]
                    break
    
    # Clean up any leading/trailing whitespace
    story_text = story_text.strip()
    
    # Find natural ending points (end of sentences)
    periods = [i for i, char in enumerate(story_text) if char == '.']
    question_marks = [i for i, char in enumerate(story_text) if char == '?']
    exclamation_marks = [i for i, char in enumerate(story_text) if char == '!']
    
    # Combine all ending punctuation and sort
    all_endings = sorted(periods + question_marks + exclamation_marks)
    
    # Target approximately 100 words
    target_word_count = 100
    min_acceptable_words = 80
    
    # If we have any sentence endings
    if all_endings:
        # Find the sentence ending that gets us closest to 100 words
        # (only endings yielding at least min_acceptable_words qualify)
        closest_ending = None
        closest_word_diff = float('inf')
        
        for ending_idx in all_endings:
            candidate_text = story_text[:ending_idx+1]
            candidate_word_count = count_words(candidate_text)
            
            # Only consider endings that give us at least min_acceptable_words
            if candidate_word_count >= min_acceptable_words:
                word_diff = abs(candidate_word_count - target_word_count)
                
                if word_diff < closest_word_diff:
                    closest_ending = ending_idx
                    closest_word_diff = word_diff
        
        # If we found a suitable ending, use it
        if closest_ending is not None:
            return story_text[:closest_ending+1]
    
    # If we couldn't find a good ending near 100 words, but we have some sentence endings,
    # use the last one that results in a story with at least min_acceptable_words words
    # NOTE(review): reaching here means the pass above found NO ending with
    # >= min_acceptable_words, so this pass (same condition) cannot find one
    # either -- effectively dead code; kept byte-identical for safety
    if all_endings:
        for ending_idx in reversed(all_endings):
            candidate_text = story_text[:ending_idx+1]
            if count_words(candidate_text) >= min_acceptable_words:
                return candidate_text
    
    # If no good ending is found, return as is
    return story_text

# Function to reset progress when a new file is uploaded
def reset_progress():
    """Clear all pipeline state so a newly uploaded image starts from scratch."""
    fresh_state = {
        'caption_generated': False,
        'story_generated': False,
        'audio_generated': False,
        'caption': '',
        'story': '',
        'audio_data': None,
        'audio_format': None,
    }
    st.session_state.progress = fresh_state

# Basic Streamlit interface
st.title("Image to Audio Story")

# Add processing status indicator
status_container = st.empty()

# Initialize session state for tracking progress
# (Streamlit reruns this script top-to-bottom on every interaction, so
# pipeline results are stashed in session_state to avoid recomputation)
if 'progress' not in st.session_state:
    st.session_state.progress = {
        'caption_generated': False,
        'story_generated': False,
        'audio_generated': False,
        'caption': '',
        'story': '',
        'audio_data': None,
        'audio_format': None
    }

# File uploader
# on_change=reset_progress wipes cached results so a new image restarts the pipeline
uploaded_file = st.file_uploader("Upload an image", on_change=reset_progress)

# Process the image if uploaded
if uploaded_file is not None:
    # Display image
    st.image(uploaded_file, caption="Uploaded Image")
    
    # Convert to PIL Image
    image = Image.open(uploaded_file)
    
    # Convert image to bytes for caching compatibility
    # (st.cache_data can hash bytes, but not a PIL Image object)
    image_bytes = get_image_bytes(image)
    
    # Image to Text (if not already done)
    if not st.session_state.progress['caption_generated']:
        status_container.info("Generating caption...")
        st.session_state.progress['caption'] = img2text(image_bytes)
        st.session_state.progress['caption_generated'] = True
    
    st.write(f"Caption: {st.session_state.progress['caption']}")
    
    # Text to Story (if not already done)
    if not st.session_state.progress['story_generated']:
        status_container.info("Creating story...")
        st.session_state.progress['story'] = text2story(st.session_state.progress['caption'])
        st.session_state.progress['story_generated'] = True
    
    # Display word count for transparency
    word_count = count_words(st.session_state.progress['story'])
    st.write(f"Story ({word_count} words):")
    st.write(st.session_state.progress['story'])
    
    # Pre-generate audio in background (if not already done)
    # Skipped entirely when neither gTTS nor the transformers TTS model is available
    if not st.session_state.progress['audio_generated'] and (has_gtts or tts_fallback_model is not None):
        status_container.info("Pre-generating audio in background...")
        try:
            st.session_state.progress['audio_data'], st.session_state.progress['audio_format'] = text2audio(st.session_state.progress['story'])
            st.session_state.progress['audio_generated'] = True
            status_container.success("Ready to play audio!")
        except Exception as e:
            status_container.error(f"Error pre-generating audio: {e}")
    
    # Button to play audio
    if st.button("Play the audio"):
        if st.session_state.progress['audio_generated']:
            # Display the audio player
            # gTTS path stores a MIME-type string ('audio/mp3'); the
            # transformers path stores an int sampling rate -- the
            # isinstance check tells the two apart
            if isinstance(st.session_state.progress['audio_format'], str) and st.session_state.progress['audio_format'].startswith('audio/'):
                st.audio(st.session_state.progress['audio_data'], format=st.session_state.progress['audio_format'])
            else:
                st.audio(st.session_state.progress['audio_data'], sample_rate=st.session_state.progress['audio_format'])
        else:
            # Handle case where audio generation failed or is not available
            st.error("Unable to play audio. Audio generation was not successful.")
else:
    status_container.info("Upload an image to begin")