File size: 5,452 Bytes
ab8ead3
90bef38
8d5fabf
ab8ead3
118cd25
ab8ead3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd245d5
8d5fabf
ab8ead3
5f21a2d
ab8ead3
 
5f21a2d
ab8ead3
 
5f21a2d
ab8ead3
5f21a2d
 
ab8ead3
5f21a2d
 
ab8ead3
 
5f21a2d
 
 
ab8ead3
5f21a2d
 
 
ab8ead3
 
 
 
5f21a2d
 
 
ab8ead3
cd245d5
7df9b81
ab8ead3
 
7df9b81
ab8ead3
 
76abf5e
 
3fd88eb
 
 
76abf5e
7df9b81
a79c9ac
 
 
7df9b81
 
 
3fd88eb
8d5fabf
ab8ead3
 
 
4e37056
ab8ead3
 
 
 
 
 
e77741a
ab8ead3
 
 
 
 
f006a50
ab8ead3
 
f006a50
ab8ead3
 
f006a50
ab8ead3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a79c9ac
 
 
 
 
 
 
 
 
ab8ead3
1ebc71c
 
ab8ead3
 
1ebc71c
 
ab8ead3
a79c9ac
 
7df9b81
ab8ead3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# import part
import streamlit as st
from transformers import pipeline
from PIL import Image

# Set global caching options for Transformers.
# NOTE(review): `set_caching_enabled` is not exported by current `transformers`
# releases (the caching toggle of this name lives in the `datasets` package),
# so an unguarded import crashes the whole app at startup. Guard it so the
# app still runs when the symbol is missing; caching is merely best-effort.
try:
    from transformers import set_caching_enabled
    set_caching_enabled(True)
except ImportError:
    pass  # symbol unavailable — proceed without the caching toggle

# function part with caching for better performance
@st.cache_resource
def load_image_captioning_model():
    """Build once (cached by Streamlit) and return the BLIP captioning pipeline."""
    captioner = pipeline(
        "image-to-text",
        model="sooh-j/blip-image-captioning-base",
    )
    return captioner

@st.cache_resource
def load_text_generator():
    """Build once (cached by Streamlit) and return the TinyLlama chat pipeline."""
    generator = pipeline(
        "text-generation",
        model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    )
    return generator

@st.cache_resource
def load_tts_model():
    """Build once (cached by Streamlit) and return the text-to-speech pipeline."""
    synthesizer = pipeline(
        "text-to-speech",
        model="HelpingAI/HelpingAI-TTS-v1",
    )
    return synthesizer

# img2text - Using the original model with more constraints
def img2text(image):
    """Caption *image* with the cached BLIP pipeline, capped at 15 new tokens."""
    captioner = load_image_captioning_model()  # cached via st.cache_resource

    # Strongly limit output length for speed
    results = captioner(image, max_new_tokens=15)
    return results[0]["generated_text"]

# text2story - Much more constrained for speed
def text2story(text):
    """Expand a short caption into a brief story using the cached generator.

    The prompt prefix is stripped from the output and the result is trimmed
    back to the last full sentence when enough content is available.
    """
    # Load the model (cached)
    generator = load_text_generator()

    # Very brief prompt to minimize work
    prompt = f"Short story about {text}: Once upon a time, "

    # Very constrained parameters for maximum speed
    outputs = generator(
        prompt,
        max_new_tokens=60,  # Much shorter output
        num_return_sequences=1,
        temperature=0.7,
        top_k=10,  # Lower value = faster
        top_p=0.9,  # Lower value = faster
        do_sample=True,
    )

    # Swap the working prompt for a clean story opener
    story = outputs[0]['generated_text'].replace(prompt, "Once upon a time, ")

    # Find a natural ending point (last sentence boundary), keeping at least
    # some content before cutting
    cutoff = story.rfind('.')
    if cutoff > 30:
        story = story[:cutoff + 1]

    return story

# text2audio - Minimal text for faster processing
def text2audio(story_text):
    """Synthesize speech for *story_text* (truncated to ~200 chars).

    Returns the raw pipeline output on success, or None after reporting the
    error through the Streamlit UI.
    """
    try:
        # Load the model (cached)
        synthesizer = load_tts_model()

        # Aggressively limit text length to speed up TTS
        max_chars = 200  # Much shorter than before
        if len(story_text) > max_chars:
            head = story_text[:max_chars]
            cutoff = head.rfind('.')
            # Prefer ending on a sentence boundary; otherwise hard-truncate
            story_text = head[:cutoff + 1] if cutoff > 0 else head

        # Generate speech
        return synthesizer(story_text)

    except Exception as e:
        st.error(f"Error generating audio: {str(e)}")
        return None

# Streamlined main UI
st.set_page_config(page_title="Image to Story", page_icon="📚")
st.header("Image to Audio Story")

# Add info about processing time
st.info("Note: Processing may take some time as the models are loading. Please be patient.")

# Seed the uploader slot in session state once per session
st.session_state.setdefault("uploaded_file", None)

uploaded_file = st.file_uploader("Select an Image...", key="file_uploader")

def _render_audio(speech_output):
    """Locate waveform data in a TTS output dict and play it via st.audio."""
    try:
        # Known (data key, rate key) layouts, checked in the original order
        for data_key, rate_key in (
            ('audio', 'sampling_rate'),
            ('audio_array', 'sampling_rate'),
            ('waveform', 'sample_rate'),
        ):
            if data_key in speech_output and rate_key in speech_output:
                st.audio(speech_output[data_key], sample_rate=speech_output[rate_key])
                return
        # Try any array-like data
        for key, value in speech_output.items():
            if hasattr(value, '__len__') and len(value) > 1000:
                sample_rate = speech_output.get('sampling_rate', speech_output.get('sample_rate', 24000))
                st.audio(value, sample_rate=sample_rate)
                return
        st.error("Could not find audio data in the output")
    except Exception as e:
        st.error(f"Error playing audio: {str(e)}")


# Process the image if uploaded
if uploaded_file is not None:
    st.session_state["uploaded_file"] = uploaded_file

    # Display the uploaded image
    st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)

    # Convert to PIL image
    image = Image.open(uploaded_file)

    # Optional processing toggle to let user decide
    if st.button("Generate Story and Audio"):
        caption_col, story_col = st.columns(2)

        # Stage 1: Image to Text with minimal output
        with caption_col:
            with st.spinner('Captioning image...'):
                caption = img2text(image)
            st.write(f"**Caption:** {caption}")

        # Stage 2: Text to Story with minimal length
        with story_col:
            with st.spinner('Creating story...'):
                story = text2story(caption)
            st.write(f"**Story:** {story}")

        # Stage 3: Audio with minimal text
        with st.spinner('Generating audio...'):
            speech_output = text2audio(story)

        # Display audio immediately
        if speech_output is None:
            st.error("Audio generation failed")
        else:
            _render_audio(speech_output)