# Assignment1 / app.py
# Uploaded by CR7CAD ("Update app.py", commit 15c1038, verified) — 10.6 kB
# Imports
import streamlit as st
from transformers import pipeline
from PIL import Image
import torch
import os
import tempfile
import time
import numpy as np
# Use Streamlit's caching mechanisms to optimize model loading
@st.cache_resource
def load_image_to_text_pipeline():
    """Build the BLIP image-captioning pipeline.

    Wrapped in st.cache_resource so the model weights are loaded only
    once per Streamlit process and shared across reruns/sessions.
    """
    captioner = pipeline(
        "image-to-text",
        model="sooh-j/blip-image-captioning-base",
    )
    return captioner
@st.cache_resource
def load_text_generation_pipeline():
    """Build the TinyLlama text-generation pipeline.

    Cached with st.cache_resource so the (large) model is instantiated
    a single time per process.
    """
    generator = pipeline(
        "text-generation",
        model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    )
    return generator
@st.cache_resource
def load_tts_pipeline():
    """Load and cache the transformers text-to-speech pipeline (fallback TTS).

    Returns:
        The "text-to-speech" pipeline, or None when loading fails
        (missing optional dependencies, no network, ...). Callers treat
        None as "fallback TTS unavailable".
    """
    try:
        return pipeline("text-to-speech", model="facebook/mms-tts-eng")
    except Exception:
        # Was a bare `except:`, which would also swallow KeyboardInterrupt
        # and SystemExit; Exception is broad enough for any load failure.
        return None
# Initialize all models at app startup.
# The spinner is only visible while the cached loaders actually run
# (i.e. the first time per process); later reruns hit the cache.
with st.spinner("Loading models (this may take a moment the first time)..."):
    # Load all models at startup and cache them
    img2text_model = load_image_to_text_pipeline()
    story_generator_model = load_text_generation_pipeline()
    tts_fallback_model = load_tts_pipeline()
# For TTS, try multiple options in order of preference:
# gTTS (online, mp3) first, then the local transformers pipeline.
try:
    # Try importing gTTS
    from gtts import gTTS
    has_gtts = True
except ImportError:
    has_gtts = False
    # Warn only when neither backend is available (gTTS missing AND the
    # fallback pipeline failed to load). NOTE(review): original indentation
    # was lost in extraction; this nesting is the only reading under which
    # the warning text is accurate — confirm against upstream history.
    if tts_fallback_model is None:
        st.warning("No text-to-speech capability available. Audio generation will be disabled.")
# Cache the text-to-audio conversion
@st.cache_data
def text2audio(story_text):
    """Convert text to audio, preferring gTTS over the local TTS pipeline.

    Cached with st.cache_data so the same story is never synthesized twice.

    Returns a 2-tuple whose second element identifies the backend:
      * gTTS:         (mp3_bytes, 'audio/mp3')        -- MIME string
      * transformers: (audio_array, sampling_rate)    -- integer rate
    Callers distinguish the two by checking for a str starting 'audio/'.

    Raises:
        RuntimeError: when no TTS backend is available. (RuntimeError is a
        subclass of Exception, so existing `except Exception` callers
        still catch it.)
    """
    if has_gtts:
        # gTTS can only write to a file path, so round-trip through a
        # named temp file. delete=False because gTTS reopens the path.
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp3')
        temp_filename = temp_file.name
        temp_file.close()
        try:
            tts = gTTS(text=story_text, lang='en', slow=False)
            tts.save(temp_filename)
            with open(temp_filename, 'rb') as audio_file:
                audio_bytes = audio_file.read()
        finally:
            # Previously the file leaked if gTTS raised; always clean up.
            os.unlink(temp_filename)
        return audio_bytes, 'audio/mp3'

    if tts_fallback_model is not None:
        # Use transformers TTS
        speech = tts_fallback_model(story_text)
        # Different pipeline versions key the waveform differently.
        if 'audio' in speech:
            return speech['audio'], speech.get('sampling_rate', 16000)
        if 'audio_array' in speech:
            return speech['audio_array'], speech.get('sampling_rate', 16000)

    # If we got here, no TTS method worked
    raise RuntimeError("No text-to-speech capability available")
# Convert PIL Image to bytes for hashing in cache
def get_image_bytes(pil_img):
    """Serialize a PIL image to JPEG bytes (a hashable cache key)."""
    import io
    with io.BytesIO() as buffer:
        pil_img.save(buffer, format='JPEG')
        return buffer.getvalue()
# Simple image-to-text function using cached model
@st.cache_data
def img2text(image_bytes):
    """Caption an image supplied as raw bytes.

    Takes bytes rather than a PIL Image so st.cache_data can hash the
    argument. Returns the model's generated caption string.
    """
    import io
    from PIL import Image

    # Rehydrate the bytes into a PIL image the pipeline understands.
    image = Image.open(io.BytesIO(image_bytes))
    predictions = img2text_model(image)
    return predictions[0]["generated_text"]
# Helper function to count words
def count_words(text):
    """Count the whitespace-delimited words in *text*."""
    tokens = text.split()
    return len(tokens)
# Target word count for the trimmed story, and the minimum we'll accept.
_TARGET_WORDS = 100
_MIN_WORDS = 80


def _strip_prompt(full_text, prompt):
    """Best-effort removal of the echoed instruction prompt from model output.

    If the prompt appears verbatim it is removed directly; otherwise we
    assume the instructions end at the first paragraph break or sentence
    boundary that falls within the first half of the text.
    """
    if prompt in full_text:
        return full_text.replace(prompt, "")
    for marker in ("\n\n", "\n", ". ", "! ", "? "):
        if marker in full_text:
            head, tail = full_text.split(marker, 1)
            # Only treat the marker as the prompt/story boundary when the
            # part before it is reasonably short (under half the text).
            if len(head) < len(full_text) * 0.5:
                return tail
    return full_text


def _trim_to_sentence(story_text):
    """Cut *story_text* at the sentence boundary closest to _TARGET_WORDS.

    Only boundaries yielding at least _MIN_WORDS words are considered
    (single pass over the text instead of three separate punctuation
    scans). If no boundary qualifies, the text is returned untouched —
    the original's second fallback loop applied the same >=_MIN_WORDS
    filter and therefore could never return (dead code, removed).
    """
    best_cut = None
    best_diff = float('inf')
    for idx, char in enumerate(story_text):
        if char in '.?!':
            word_count = count_words(story_text[:idx + 1])
            if word_count >= _MIN_WORDS:
                diff = abs(word_count - _TARGET_WORDS)
                # Strict '<' keeps the earliest boundary on ties,
                # matching the original behavior.
                if diff < best_diff:
                    best_cut, best_diff = idx, diff
    if best_cut is not None:
        return story_text[:best_cut + 1]
    return story_text


# Improved text-to-story function without "Once upon a time" constraint
@st.cache_data
def text2story(text):
    """Generate a ~100-word children's story from a caption, with caching.

    The raw generation is post-processed to drop the echoed prompt and
    trimmed at a natural sentence ending near the 100-word target.
    """
    # Ask for a story without specifying how to start
    prompt = f"""Write a children's story based on this: {text}.
    The story should have a clear beginning, middle, and end.
    Make the story approximately 150-200 words long with descriptive language.
    """
    # Generate a longer text to ensure we get a complete story
    story_result = story_generator_model(
        prompt,
        max_length=500,
        num_return_sequences=1,
        temperature=0.7,
        do_sample=True,
    )
    full_text = story_result[0]['generated_text']
    # Strip the prompt, tidy whitespace, then trim to a natural ending.
    story_text = _strip_prompt(full_text, prompt).strip()
    return _trim_to_sentence(story_text)
# Function to reset progress when a new file is uploaded
def reset_progress():
    """Reset all pipeline progress in session state (new upload)."""
    st.session_state.progress = dict(
        caption_generated=False,
        story_generated=False,
        audio_generated=False,
        caption='',
        story='',
        audio_data=None,
        audio_format=None,
    )
# Basic Streamlit interface
st.title("Image to Audio Story")
# Add processing status indicator (single slot reused for all messages)
status_container = st.empty()
# Initialize session state for tracking progress.
# Mirrors the dict built by reset_progress(); survives Streamlit reruns.
if 'progress' not in st.session_state:
    st.session_state.progress = {
        'caption_generated': False,
        'story_generated': False,
        'audio_generated': False,
        'caption': '',
        'story': '',
        'audio_data': None,
        'audio_format': None
    }
# File uploader — on_change clears progress so a new image restarts the pipeline
uploaded_file = st.file_uploader("Upload an image", on_change=reset_progress)
# Process the image if uploaded
if uploaded_file is not None:
    # Display image
    st.image(uploaded_file, caption="Uploaded Image")
    # Convert to PIL Image
    image = Image.open(uploaded_file)
    # Convert image to bytes for caching compatibility
    image_bytes = get_image_bytes(image)
    # Image to Text (if not already done this session)
    if not st.session_state.progress['caption_generated']:
        status_container.info("Generating caption...")
        st.session_state.progress['caption'] = img2text(image_bytes)
        st.session_state.progress['caption_generated'] = True
    # Shown on every rerun, not only the one that generated the caption
    st.write(f"Caption: {st.session_state.progress['caption']}")
    # Text to Story (if not already done this session)
    if not st.session_state.progress['story_generated']:
        status_container.info("Creating story...")
        st.session_state.progress['story'] = text2story(st.session_state.progress['caption'])
        st.session_state.progress['story_generated'] = True
    # Display word count for transparency
    word_count = count_words(st.session_state.progress['story'])
    st.write(f"Story ({word_count} words):")
    st.write(st.session_state.progress['story'])
    # Pre-generate audio (if not already done and some TTS backend exists)
    if not st.session_state.progress['audio_generated'] and (has_gtts or tts_fallback_model is not None):
        status_container.info("Pre-generating audio in background...")
        try:
            st.session_state.progress['audio_data'], st.session_state.progress['audio_format'] = text2audio(st.session_state.progress['story'])
            st.session_state.progress['audio_generated'] = True
            status_container.success("Ready to play audio!")
        except Exception as e:
            status_container.error(f"Error pre-generating audio: {e}")
    # Button to play audio
    if st.button("Play the audio"):
        if st.session_state.progress['audio_generated']:
            # Display the audio player. gTTS stores a MIME string
            # ('audio/...') in audio_format; the transformers fallback
            # stores an integer sampling rate instead.
            if isinstance(st.session_state.progress['audio_format'], str) and st.session_state.progress['audio_format'].startswith('audio/'):
                st.audio(st.session_state.progress['audio_data'], format=st.session_state.progress['audio_format'])
            else:
                st.audio(st.session_state.progress['audio_data'], sample_rate=st.session_state.progress['audio_format'])
        else:
            # Handle case where audio generation failed or is not available
            st.error("Unable to play audio. Audio generation was not successful.")
else:
    status_container.info("Upload an image to begin")