Spaces:
Sleeping
Sleeping
File size: 10,594 Bytes
8fe6281 90bef38 8d5fabf ab8ead3 8fe6281 fc13d66 15c1038 fc13d66 862568a ce9aea5 862568a fc13d66 862568a fc13d66 ce9aea5 fc13d66 ce9aea5 fc13d66 118cd25 15c1038 fc13d66 15c1038 fc13d66 8d5fabf 5518670 fc13d66 5f21a2d fc13d66 5518670 5f21a2d 9d38390 fc13d66 5f21a2d 5518670 5f21a2d 5518670 7c4bc18 5518670 7c4bc18 5518670 7c4bc18 5518670 7c4bc18 9d38390 5f21a2d 15c1038 fc13d66 15c1038 fc13d66 15c1038 fc13d66 15c1038 fc13d66 ab8ead3 ad4186a f006a50 ad4186a ab8ead3 f006a50 15c1038 fc13d66 15c1038 fc13d66 8fe6281 fc13d66 8fe6281 fc13d66 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 |
# Imports
import streamlit as st
from transformers import pipeline
from PIL import Image
import torch
import os
import tempfile
import time
import numpy as np
# Use Streamlit's caching mechanisms to optimize model loading
@st.cache_resource
def load_image_to_text_pipeline():
    """Build the BLIP image-captioning pipeline, cached across app reruns."""
    captioner = pipeline("image-to-text", model="sooh-j/blip-image-captioning-base")
    return captioner
@st.cache_resource
def load_text_generation_pipeline():
    """Build the TinyLlama text-generation pipeline, cached across app reruns."""
    generator = pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    return generator
@st.cache_resource
def load_tts_pipeline():
    """Load and cache the text-to-speech pipeline used as a fallback.

    Returns:
        The transformers TTS pipeline, or None if loading fails — callers
        treat None as "this backend is unavailable" and degrade gracefully.
    """
    try:
        return pipeline("text-to-speech", model="facebook/mms-tts-eng")
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate; any genuine load failure disables this fallback.
        return None
# Initialize all models at app startup.
# These module-level names (img2text_model, story_generator_model,
# tts_fallback_model, has_gtts) are read by the functions and UI code below.
with st.spinner("Loading models (this may take a moment the first time)..."):
    # Load all models at startup and cache them (the @st.cache_resource
    # loaders make repeat reruns of this script effectively free).
    img2text_model = load_image_to_text_pipeline()
    story_generator_model = load_text_generation_pipeline()
    tts_fallback_model = load_tts_pipeline()
# For TTS, try multiple options in order of preference: gTTS first,
# then the transformers fallback pipeline loaded above.
try:
    # Try importing gTTS; its presence is recorded in has_gtts.
    from gtts import gTTS
    has_gtts = True
except ImportError:
    has_gtts = False
    # Only warn when BOTH backends are missing — with gTTS absent, the
    # transformers pipeline is the last resort.
    if tts_fallback_model is None:
        st.warning("No text-to-speech capability available. Audio generation will be disabled.")
# Cache the text-to-audio conversion
@st.cache_data
def text2audio(story_text):
    """Convert text to audio with caching to avoid regenerating the same audio.

    Returns a 2-tuple whose second element disambiguates the backend:
      - gTTS path:        (mp3 bytes, 'audio/mp3')
      - transformers TTS: (waveform data, sampling rate int)
    Callers tell the two apart by checking whether the second item is a
    string starting with 'audio/'.

    Raises:
        RuntimeError: if no TTS backend is available.
    """
    if has_gtts:
        # gTTS only exposes save(), so synthesize into a temp file.
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp3')
        temp_filename = temp_file.name
        temp_file.close()
        try:
            tts = gTTS(text=story_text, lang='en', slow=False)
            tts.save(temp_filename)
            # Read the synthesized audio back as bytes.
            with open(temp_filename, 'rb') as audio_file:
                audio_bytes = audio_file.read()
        finally:
            # Always remove the temp file — the original leaked it when
            # gTTS raised (e.g. no network), filling the temp directory.
            os.unlink(temp_filename)
        return audio_bytes, 'audio/mp3'
    if tts_fallback_model is not None:
        # Use the transformers TTS fallback.
        speech = tts_fallback_model(story_text)
        # Different transformers versions name the waveform key differently.
        if 'audio' in speech:
            return speech['audio'], speech.get('sampling_rate', 16000)
        if 'audio_array' in speech:
            return speech['audio_array'], speech.get('sampling_rate', 16000)
    # If we got here, no TTS method worked.
    raise RuntimeError("No text-to-speech capability available")
# Convert PIL Image to bytes for hashing in cache
def get_image_bytes(pil_img):
    """Serialize a PIL image to JPEG bytes so st.cache_data can hash it."""
    import io
    buffer = io.BytesIO()
    pil_img.save(buffer, format='JPEG')
    return buffer.getvalue()
# Simple image-to-text function using cached model
@st.cache_data
def img2text(image_bytes):
    """Caption an image supplied as raw bytes (bytes are hashable, so the
    result caches even though PIL images themselves are not)."""
    import io
    from PIL import Image
    # Rehydrate the bytes into a PIL image for the captioning model.
    img = Image.open(io.BytesIO(image_bytes))
    captions = img2text_model(img)
    return captions[0]["generated_text"]
# Helper function to count words
def count_words(text):
    """Return the number of whitespace-separated tokens in *text*."""
    tokens = text.split()
    return len(tokens)
# Improved text-to-story function without "Once upon a time" constraint
@st.cache_data
def text2story(text):
    """Generate a roughly 100-word children's story from a caption, cached.

    The raw LM output echoes the prompt, so we strip it (verbatim match
    first, then a heuristic split on the first break marker), then trim
    the story at the sentence boundary whose word count is closest to
    target_word_count, accepting no fewer than min_acceptable_words.
    Falls back to the full stripped text when no boundary qualifies.
    """
    # Ask for a story without specifying how to start.
    prompt = f"""Write a children's story based on this: {text}.
The story should have a clear beginning, middle, and end.
Make the story approximately 150-200 words long with descriptive language.
"""
    # Generate a longer text to ensure we get a complete story.
    story_result = story_generator_model(
        prompt,
        max_length=500,
        num_return_sequences=1,
        temperature=0.7,
        do_sample=True
    )
    full_text = story_result[0]['generated_text']
    story_text = full_text
    # First remove the exact prompt if it appears verbatim.
    if prompt in story_text:
        story_text = story_text.replace(prompt, "")
    else:
        # Otherwise look for the first break that plausibly separates the
        # echoed instructions from the story proper.
        for start_marker in ("\n\n", "\n", ". ", "! ", "? "):
            if start_marker in story_text:
                head, tail = story_text.split(start_marker, 1)
                # Only accept the split if the discarded head is short.
                if len(head) < len(story_text) * 0.5:
                    story_text = tail
                    break
    story_text = story_text.strip()
    # Indices of all sentence-ending punctuation, already in order
    # (replaces three separate scans plus a sort in the original).
    all_endings = [i for i, ch in enumerate(story_text) if ch in '.?!']
    # Target approximately 100 words, never fewer than 80.
    target_word_count = 100
    min_acceptable_words = 80
    # Pick the sentence boundary whose word count is closest to the target,
    # considering only candidates of acceptable length.
    closest_ending = None
    closest_word_diff = float('inf')
    for ending_idx in all_endings:
        candidate_word_count = count_words(story_text[:ending_idx + 1])
        if candidate_word_count >= min_acceptable_words:
            word_diff = abs(candidate_word_count - target_word_count)
            if word_diff < closest_word_diff:
                closest_ending = ending_idx
                closest_word_diff = word_diff
    if closest_ending is not None:
        return story_text[:closest_ending + 1]
    # NOTE: the original had a second backward pass over all_endings here,
    # but it re-applied the same >= min_acceptable_words filter that the
    # loop above just proved unsatisfiable, so it could never return.
    # Removed as dead code — behavior is unchanged.
    # If no good ending is found, return as is.
    return story_text
# Function to reset progress when a new file is uploaded
def reset_progress():
    """Clear all pipeline progress so a newly uploaded image starts fresh."""
    st.session_state.progress = dict(
        caption_generated=False,
        story_generated=False,
        audio_generated=False,
        caption='',
        story='',
        audio_data=None,
        audio_format=None,
    )
# Basic Streamlit interface
st.title("Image to Audio Story")
# Add processing status indicator (st.empty lets us overwrite it in place).
status_container = st.empty()
# Initialize session state for tracking progress; this mirrors the dict
# built in reset_progress so both paths produce the same shape.
if 'progress' not in st.session_state:
    st.session_state.progress = {
        'caption_generated': False,
        'story_generated': False,
        'audio_generated': False,
        'caption': '',
        'story': '',
        'audio_data': None,
        'audio_format': None
    }
# File uploader — on_change resets progress so a new image restarts the pipeline.
uploaded_file = st.file_uploader("Upload an image", on_change=reset_progress)
# Process the image if uploaded
if uploaded_file is not None:
    # Display image
    st.image(uploaded_file, caption="Uploaded Image")
    # Convert to PIL Image
    image = Image.open(uploaded_file)
    # Convert image to bytes for caching compatibility (img2text caches on bytes).
    image_bytes = get_image_bytes(image)
    # Image to Text (if not already done this session)
    if not st.session_state.progress['caption_generated']:
        status_container.info("Generating caption...")
        st.session_state.progress['caption'] = img2text(image_bytes)
        st.session_state.progress['caption_generated'] = True
    st.write(f"Caption: {st.session_state.progress['caption']}")
    # Text to Story (if not already done this session)
    if not st.session_state.progress['story_generated']:
        status_container.info("Creating story...")
        st.session_state.progress['story'] = text2story(st.session_state.progress['caption'])
        st.session_state.progress['story_generated'] = True
    # Display word count for transparency
    word_count = count_words(st.session_state.progress['story'])
    st.write(f"Story ({word_count} words):")
    st.write(st.session_state.progress['story'])
    # Pre-generate audio before the user presses Play (only if a TTS backend exists).
    if not st.session_state.progress['audio_generated'] and (has_gtts or tts_fallback_model is not None):
        status_container.info("Pre-generating audio in background...")
        try:
            st.session_state.progress['audio_data'], st.session_state.progress['audio_format'] = text2audio(st.session_state.progress['story'])
            st.session_state.progress['audio_generated'] = True
            status_container.success("Ready to play audio!")
        except Exception as e:
            status_container.error(f"Error pre-generating audio: {e}")
    # Button to play audio
    if st.button("Play the audio"):
        if st.session_state.progress['audio_generated']:
            # audio_format is either a MIME string ('audio/mp3', gTTS path)
            # or a sampling-rate int (transformers path) — branch accordingly.
            if isinstance(st.session_state.progress['audio_format'], str) and st.session_state.progress['audio_format'].startswith('audio/'):
                st.audio(st.session_state.progress['audio_data'], format=st.session_state.progress['audio_format'])
            else:
                st.audio(st.session_state.progress['audio_data'], sample_rate=st.session_state.progress['audio_format'])
        else:
            # Handle case where audio generation failed or is not available
            st.error("Unable to play audio. Audio generation was not successful.")
else:
    status_container.info("Upload an image to begin")