File size: 5,391 Bytes
cd245d5
90bef38
8d5fabf
cd245d5
76abf5e
118cd25
cd245d5
 
 
b2cad31
cd245d5
 
8d5fabf
cd245d5
 
f006a50
6f17888
cd245d5
 
f006a50
d996989
 
 
 
 
 
 
 
 
 
 
6f17888
f006a50
 
 
 
 
b9c5fcd
 
 
 
f006a50
cd245d5
8d5fabf
e5f2129
6706f05
cd245d5
7df9b81
6706f05
e5f2129
7df9b81
76abf5e
 
 
 
3fd88eb
 
 
76abf5e
7df9b81
e5f2129
 
7df9b81
6706f05
 
 
 
76abf5e
 
6706f05
7df9b81
6706f05
 
e5f2129
6706f05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76abf5e
7df9b81
 
 
3fd88eb
8d5fabf
cd245d5
 
 
 
 
 
 
 
 
 
 
8d5fabf
cd245d5
f006a50
 
 
4e37056
 
f006a50
a084b90
cd245d5
f006a50
 
8d5fabf
f006a50
 
 
 
 
 
 
 
 
 
 
 
e5f2129
f006a50
 
 
76abf5e
 
 
7df9b81
76abf5e
f006a50
76abf5e
f006a50
 
e5f2129
f006a50
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# import part
import streamlit as st
from transformers import pipeline
import os
import tempfile

# function part
# img2text
def img2text(image_path):
    """Caption an image with the image-to-text pipeline.

    Args:
        image_path: Path of the image file handed to the pipeline.

    Returns:
        The ``"generated_text"`` field of the first pipeline result.
    """
    captioner = pipeline("image-to-text", model="sooh-j/blip-image-captioning-base")
    first_result = captioner(image_path)[0]
    return first_result["generated_text"]

# text2story
def text2story(text, max_words=100):
    """Generate a short children's story from an image caption.

    Args:
        text: Caption or scenario the story should be based on.
        max_words: Upper bound on the number of whitespace-separated words
            in the returned story (default 100).

    Returns:
        The generated story. When the pipeline echoes the prompt back (its
        default behaviour), the instruction part of the prompt is dropped so
        the story starts with "Once upon a time, ".
    """
    # Using a smaller text generation model
    generator = pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")

    # Create a prompt for the story generation
    story_opening = "Once upon a time, "
    prompt = f"Write a fun children's story based on this: {text}. {story_opening}"

    # Generate the story. max_new_tokens bounds only the generated part, so a
    # long caption no longer eats into the story's token budget the way
    # max_length (prompt + generation) did.
    story_result = generator(
        prompt,
        max_new_tokens=150,
        num_return_sequences=1,
        temperature=0.7,
        top_k=50,
        top_p=0.95,
        do_sample=True,
    )

    # Extract the generated text and strip the echoed prompt. Only the leading
    # occurrence is removed so text inside the story is never rewritten.
    story_text = story_result[0]['generated_text']
    if story_text.startswith(prompt):
        story_text = story_opening + story_text[len(prompt):]

    # Cap the story at max_words words (it may be shorter).
    words = story_text.split()
    if len(words) > max_words:
        story_text = " ".join(words[:max_words])

    return story_text

# text2audio - revised to handle the pipeline's audio output field and format
def _truncate_at_sentence(text, max_chars):
    """Cut text to at most max_chars, ending on the last full stop if any."""
    if len(text) <= max_chars:
        return text
    head = text[:max_chars]
    last_period = head.rfind('.')
    return head[:last_period + 1] if last_period > 0 else head


def _find_audio_payload(speech):
    """Return the audio data (bytes or array-like) from a TTS result dict."""
    candidates = [speech.get(key) for key in ("audio", "raw", "wav")]
    # Prefer data that is already encoded as bytes, then array-like samples.
    for value in candidates:
        if isinstance(value, (bytes, bytearray)):
            return value
    for value in candidates:
        if hasattr(value, 'tobytes'):
            return value
    raise ValueError(f"No suitable audio data found in keys: {list(speech.keys())}")


def _write_wav(path, samples, sampling_rate):
    """Write audio samples to path as a 16-bit mono PCM WAV file."""
    import wave

    import numpy as np

    # Flatten a possible leading batch/channel axis; assumes mono output
    # (TODO confirm if a multi-channel model is ever used).
    data = np.asarray(samples).reshape(-1)
    if np.issubdtype(data.dtype, np.floating):
        # Float waveforms are scaled from [-1, 1] to the int16 range.
        data = np.clip(data, -1.0, 1.0) * 32767.0
    pcm = data.astype("<i2")

    with wave.open(path, "wb") as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)
        wav_file.setframerate(int(sampling_rate))
        wav_file.writeframes(pcm.tobytes())


def text2audio(story_text):
    """Synthesize speech for a story and save it as a temporary WAV file.

    The text is first limited to 500 characters (cut at the last full stop
    when possible) to keep synthesis short. Array output from the pipeline is
    written as a proper WAV file (header + 16-bit PCM at the pipeline's
    sampling rate); output that is already bytes is written unchanged.

    Args:
        story_text: The text to read aloud.

    Returns:
        Path of the generated ``.wav`` file, or ``None`` if anything failed
        (the error is reported through ``st.error``). The caller owns the
        file and is responsible for deleting it.
    """
    try:
        # Use the facebook TTS model
        synthesizer = pipeline("text-to-speech", model="facebook/mms-tts-eng")

        # Limit text length to avoid timeouts
        speech = synthesizer(_truncate_at_sentence(story_text, 500))
        payload = _find_audio_payload(speech)

        # Reserve a file name; the handle is closed right away so the file can
        # be reopened for writing on every platform.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_file:
            temp_filename = temp_file.name

        try:
            if isinstance(payload, (bytes, bytearray)):
                # Presumably an already-encoded audio file; store it as-is.
                with open(temp_filename, 'wb') as f:
                    f.write(payload)
            else:
                sampling_rate = speech.get("sampling_rate")
                if not sampling_rate:
                    raise ValueError("Speech output has no sampling rate")
                _write_wav(temp_filename, payload, sampling_rate)
        except Exception:
            # Do not leave a broken temp file behind.
            os.remove(temp_filename)
            raise

        return temp_filename

    except Exception as e:
        st.error(f"Error generating audio: {str(e)}")
        return None

# Function to save temporary image file
def save_uploaded_image(uploaded_file):
    """Store an uploaded file inside the local ``temp`` directory.

    Args:
        uploaded_file: Object exposing ``name`` (the client-supplied file
            name) and ``getvalue()`` (the file content as bytes).

    Returns:
        Path of the written file, always directly inside ``temp``.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs("temp", exist_ok=True)

    # The name comes from the client: keep only its final component so that
    # values such as "../x" or "C:\\dir\\x" cannot escape the temp directory.
    safe_name = os.path.basename(uploaded_file.name.replace("\\", "/"))
    if safe_name in ("", ".", ".."):
        safe_name = "upload"

    image_path = os.path.join("temp", safe_name)

    with open(image_path, "wb") as f:
        f.write(uploaded_file.getvalue())

    return image_path

# main part
# Page flow: upload an image -> caption it -> turn the caption into a story
# -> synthesize the story as audio -> offer playback.
st.set_page_config(page_title="Your Image to Audio Story", page_icon="🦜")
st.header("Turn Your Image to Audio Story")
uploaded_file = st.file_uploader("Select an Image...")

if uploaded_file is not None:
    # Display the uploaded image
    st.image(uploaded_file, caption="Uploaded Image", use_container_width=True)

    # Save the image temporarily so the captioning pipeline can read it by path
    image_path = save_uploaded_image(uploaded_file)

    # Stage 1: Image to Text
    st.text('Processing img2text...')
    caption = img2text(image_path)
    st.write(caption)

    # Stage 2: Text to Story
    st.text('Generating a story...')
    story = text2story(caption)
    st.write(story)

    # Stage 3: Story to Audio data
    st.text('Generating audio data...')
    audio_file = text2audio(story)  # None when synthesis failed

    # Play button
    # NOTE(review): a button click presumably re-runs this whole script, so
    # the stages above run again before playback - confirm and consider caching.
    if st.button("Play Audio"):
        if audio_file and os.path.exists(audio_file):
            # Play the audio file
            st.audio(audio_file)
        else:
            st.error("Audio generation failed. Please try again.")

    # Clean up the temporary image; this is best effort, so file-system errors
    # are ignored, but nothing else (e.g. KeyboardInterrupt) is swallowed.
    try:
        os.remove(image_path)
        # Don't delete audio file immediately as it might still be playing
    except OSError:
        pass