# Assignment1 / app.py
# Author: CR7CAD. Last commit ab8ead3 ("Update app.py", verified); file size 5.45 kB.
# import part
import streamlit as st
from transformers import pipeline
from PIL import Image
# Optional global caching toggle.
# NOTE(review): `transformers` does not appear to export `set_caching_enabled`
# (a function of that name lives in the separate `datasets` library), so an
# unconditional import would raise ImportError and stop the app before any UI
# renders. Import it defensively and only call it when it exists; reuse of the
# loaded models is handled by the @st.cache_resource loaders below regardless.
try:
    from transformers import set_caching_enabled
except ImportError:
    set_caching_enabled = None
if set_caching_enabled is not None:
    set_caching_enabled(True)
# Model loaders. Each is wrapped in st.cache_resource so the pipeline object is
# built on first use and reused afterwards instead of being rebuilt on every rerun.
@st.cache_resource
def load_image_captioning_model():
    """Return the "image-to-text" pipeline for sooh-j/blip-image-captioning-base."""
    model_id = "sooh-j/blip-image-captioning-base"
    return pipeline(task="image-to-text", model=model_id)
@st.cache_resource
def load_text_generator():
    """Return the cached "text-generation" pipeline (TinyLlama 1.1B chat checkpoint)."""
    model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    return pipeline(task="text-generation", model=model_id)
@st.cache_resource
def load_tts_model():
    """Return the cached "text-to-speech" pipeline for HelpingAI/HelpingAI-TTS-v1."""
    model_id = "HelpingAI/HelpingAI-TTS-v1"
    return pipeline(task="text-to-speech", model=model_id)
# img2text - caption an image with the cached captioning pipeline
def img2text(image):
    """Return a short caption for *image*.

    Generation is capped at 15 new tokens to keep captioning quick. The
    pipeline result is indexed as a list whose first element carries the
    caption under the "generated_text" key.
    """
    captioner = load_image_captioning_model()
    outputs = captioner(image, max_new_tokens=15)
    first_output = outputs[0]
    return first_output["generated_text"]
# text2story - expand a caption into a very short story
def text2story(text):
    """Generate a brief story seeded by an image caption.

    The prompt ends with "Once upon a time, " so the model continues from
    that opening. Output is capped at 60 new tokens, and the result is then
    trimmed back to the last sentence terminator (., ! or ?) when that still
    leaves more than 30 characters, so the story does not stop mid-sentence.

    Args:
        text: Caption describing the image (e.g. the output of ``img2text``).

    Returns:
        The story text as a string.
    """
    generator = load_text_generator()
    opening = "Once upon a time, "
    # Very brief prompt to keep the input short.
    prompt = f"Short story about {text}: {opening}"
    story_result = generator(
        prompt,
        max_new_tokens=60,  # hard cap on output length; this is what bounds generation time
        num_return_sequences=1,
        temperature=0.7,
        top_k=10,  # sample only among the 10 most likely tokens (a quality knob, not a speed one)
        top_p=0.9,  # nucleus-sampling threshold (a quality knob, not a speed one)
        do_sample=True,
    )
    story_text = story_result[0]["generated_text"]
    # generated_text is expected to begin with the prompt. Swap just that
    # leading prompt for the story opening instead of rewriting every
    # occurrence; fall back to the old replace if the prefix is absent.
    if story_text.startswith(prompt):
        story_text = opening + story_text[len(prompt):]
    else:
        story_text = story_text.replace(prompt, opening)
    # Find a natural ending point: the last '.', '!' or '?' in the text.
    last_end = max(story_text.rfind(mark) for mark in ".!?")
    if last_end > 30:  # ensure some content survives the trim
        story_text = story_text[: last_end + 1]
    return story_text
# text2audio - voice a shortened story with the cached TTS pipeline
def text2audio(story_text):
    """Synthesize speech for (at most the first 200 characters of) *story_text*.

    Longer text is cut back to the last '.' inside the first 200 characters.
    If there is no '.' after position 0 in that window, it is hard-cut at 200
    characters.

    Returns:
        The text-to-speech pipeline's output, or None if any step raised (the
        error is reported in the UI with st.error).
    """
    try:
        tts = load_tts_model()
        limit = 200  # keep the spoken text short so synthesis stays quick
        if len(story_text) > limit:
            window = story_text[:limit]
            cut = window.rfind('.')
            story_text = story_text[:cut + 1] if cut > 0 else window
        return tts(story_text)
    except Exception as e:
        st.error(f"Error generating audio: {str(e)}")
        return None
# Streamlined main UI
# The page icon was a mis-encoded UTF-8 emoji ("πŸ“š"); it is the books emoji.
st.set_page_config(page_title="Image to Story", page_icon="📚")
st.header("Image to Audio Story")
# Add info about processing time
st.info("Note: Processing may take some time as the models are loading. Please be patient.")
# Remember the most recent upload in session state.
# NOTE(review): nothing visible in this file reads this key back; kept in case
# other code depends on it.
if "uploaded_file" not in st.session_state:
    st.session_state["uploaded_file"] = None
# Restrict the picker to common image formats so users are not offered
# files that PIL cannot decode.
uploaded_file = st.file_uploader(
    "Select an Image...",
    type=["png", "jpg", "jpeg", "bmp", "gif", "webp"],
    key="file_uploader",
)
# Process the image if uploaded
if uploaded_file is not None:
    st.session_state["uploaded_file"] = uploaded_file
    # Decode with PIL before anything else so a corrupt or non-image upload
    # is reported in the UI instead of crashing the script with a traceback.
    # PIL's UnidentifiedImageError is a subclass of OSError.
    try:
        # Normalise palette / alpha / greyscale uploads to 3-channel RGB
        # before handing the image to the captioning model.
        image = Image.open(uploaded_file).convert("RGB")
    except OSError as e:
        image = None
        st.error(f"Could not read the uploaded file as an image: {e}")

    if image is not None:
        # Display the uploaded image
        # NOTE(review): newer Streamlit releases deprecate use_column_width in
        # favour of use_container_width; switch once the pinned version allows it.
        st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
        # Optional processing toggle to let user decide
        if st.button("Generate Story and Audio"):
            col1, col2 = st.columns(2)
            # Stage 1: Image to Text with minimal output
            with col1:
                with st.spinner('Captioning image...'):
                    caption = img2text(image)
                st.write(f"**Caption:** {caption}")
            # Stage 2: Text to Story with minimal length
            with col2:
                with st.spinner('Creating story...'):
                    story = text2story(caption)
                st.write(f"**Story:** {story}")
            # Stage 3: Audio with minimal text
            with st.spinner('Generating audio...'):
                speech_output = text2audio(story)
            # Display audio immediately. The TTS output is probed for several
            # key conventions because the exact dict layout depends on the model.
            if speech_output is not None:
                try:
                    if 'audio' in speech_output and 'sampling_rate' in speech_output:
                        st.audio(speech_output['audio'], sample_rate=speech_output['sampling_rate'])
                    elif 'audio_array' in speech_output and 'sampling_rate' in speech_output:
                        st.audio(speech_output['audio_array'], sample_rate=speech_output['sampling_rate'])
                    elif 'waveform' in speech_output and 'sample_rate' in speech_output:
                        st.audio(speech_output['waveform'], sample_rate=speech_output['sample_rate'])
                    else:
                        # Fall back to the first value that looks like a sizeable array.
                        for key, value in speech_output.items():
                            if hasattr(value, '__len__') and len(value) > 1000:
                                sample_rate = speech_output.get('sampling_rate', speech_output.get('sample_rate', 24000))
                                st.audio(value, sample_rate=sample_rate)
                                break
                        else:
                            # for-else: runs only when no candidate array was found.
                            st.error("Could not find audio data in the output")
                except Exception as e:
                    st.error(f"Error playing audio: {str(e)}")
            else:
                st.error("Audio generation failed")