# Assignment1 / app.py
# shingguy1's picture
# Update app.py
# 4816af6 verified
import io
import wave
import streamlit as st
from transformers import pipeline
from PIL import Image
import numpy as np
import time
import threading
# --- 1) MODEL LOADING (cached) ---
# Cache the model loading to avoid reloading on every rerun, improving performance
@st.cache_resource
def get_image_captioner(model_name="Salesforce/blip-image-captioning-base"):
    """Build (once per session) the image-to-text captioning pipeline on CPU."""
    captioner = pipeline("image-to-text", model=model_name, device="cpu")
    return captioner
@st.cache_resource
def get_story_pipe(model_name="google/flan-t5-base"):
    """Build (once per session) the text2text pipeline used for story writing, on CPU."""
    story_generator = pipeline("text2text-generation", model=model_name, device="cpu")
    return story_generator
@st.cache_resource
def get_tts_pipe(model_name="facebook/mms-tts-eng"):
    """Build (once per session) the text-to-speech pipeline on CPU."""
    speech_synth = pipeline("text-to-speech", model=model_name, device="cpu")
    return speech_synth
# --- 2) TRANSFORM FUNCTIONS ---
def part1_image_to_text(pil_img, captioner):
    """Return a caption for *pil_img* produced by the captioning pipeline.

    Falls back to an empty string when the model yields no predictions or
    no "generated_text" field.
    """
    predictions = captioner(pil_img)
    if not predictions:
        # Model produced nothing usable; signal with an empty caption.
        return ""
    return predictions[0].get("generated_text", "")
def part2_text_to_story(
    caption: str,
    story_pipe,
    target_words: int = 100,
    max_length: int = 100,
    min_length: int = 80,
    do_sample: bool = True,
    top_k: int = 100,
    top_p: float = 0.9,
    temperature: float = 0.7,
    repetition_penalty: float = 1.1,
    no_repeat_ngram_size: int = 4
) -> str:
    """Turn an image *caption* into a short story via the text2text pipeline.

    The prompt asks for roughly ``target_words`` words; all remaining
    parameters are forwarded to the generation call. The returned story is
    stripped of any echoed prompt prefix and cut at the last full stop so it
    ends on a complete sentence. Returns "" when generation yields nothing.
    """
    prompt = (
        f"Write a vivid, imaginative short story of about {target_words} words "
        f"describing this scene: {caption}"
    )
    # Collect the decoding knobs in one place before the pipeline call.
    gen_kwargs = dict(
        max_length=max_length,               # Hard cap on generated tokens
        min_length=min_length,               # Floor so output isn't trivially short
        do_sample=do_sample,                 # Sampling for creative variety
        top_k=top_k,                         # Restrict sampling to top-k tokens
        top_p=top_p,                         # Nucleus sampling threshold
        temperature=temperature,             # Randomness of the distribution
        repetition_penalty=repetition_penalty,       # Discourage repeats
        no_repeat_ngram_size=no_repeat_ngram_size,   # Block repeated n-grams
        early_stopping=False,                # Keep going until max_length
    )
    generated = story_pipe(prompt, **gen_kwargs)
    text = generated[0].get("generated_text", "").strip()
    if not text:
        return ""
    # Some models echo the prompt; drop that prefix (case-insensitive match).
    if text.lower().startswith(prompt.lower()):
        story = text[len(prompt):].strip()
    else:
        story = text
    # End on the last complete sentence when a full stop exists.
    last_stop = story.rfind(".")
    return story[:last_stop + 1] if last_stop != -1 else story
def part3_text_to_speech_bytes(text: str, tts_pipe) -> bytes:
    """Synthesize *text* with the TTS pipeline and return WAV bytes (16-bit PCM).

    Accepts pipeline output as either a dict or a one-element list of dicts
    with keys "audio" (np.ndarray, assumed (channels, samples) when 2-D) and
    "sampling_rate" (int).
    """
    out = tts_pipe(text)
    if isinstance(out, list):
        out = out[0]
    audio_array = out["audio"]      # np.ndarray of float samples
    rate = out["sampling_rate"]     # int, frames per second
    # Reorder multi-channel output from (channels, samples) to (samples, channels)
    data = audio_array.T if audio_array.ndim == 2 else audio_array
    # Clip to [-1, 1] before scaling: without the clip, out-of-range float
    # samples wrap around when cast to int16, producing loud artifacts.
    pcm = (np.clip(data, -1.0, 1.0) * 32767).astype(np.int16)
    buffer = io.BytesIO()
    # Context manager finalizes the WAV header even if a write raises.
    with wave.open(buffer, "wb") as wf:
        wf.setnchannels(1 if data.ndim == 1 else data.shape[1])  # mono or multi-channel
        wf.setsampwidth(2)       # 2 bytes per sample = 16-bit audio
        wf.setframerate(rate)
        wf.writeframes(pcm.tobytes())
    return buffer.getvalue()
# --- 3) STREAMLIT UI ---
# Configure the Streamlit page for a kid-friendly, centered layout
st.set_page_config(
    page_icon="✨",
    page_title="Picture to Story Magic",
    layout="centered",
)
# Apply custom CSS for a colorful, engaging, and readable interface
st.markdown("""
<style>
.main {
background-color: #e6f3ff; /* Light blue background for main area */
padding: 20px;
border-radius: 15px;
}
.stButton>button {
background-color: #ffcccb; /* Pink button background */
button-color: #000000;
border-radius: 10px;
border: 2px solid #ff9999; /* Red border */
font-size: 18px;
font-weight: bold;
padding: 10px 20px;
transition: all 0.3s; /* Smooth hover effect */
}
.stButton>button:hover {
background-color: #ff9999; /* Darker pink on hover */
color: #ffffff;
transform: scale(1.05); /* Slight zoom on hover */
}
.stFileUploader {
background-color: #ffb300; /* Orange uploader background */
border: 2px dashed #ff8c00; /* Dashed orange border */
border-radius: 10px;
padding: 10px;
}
.stFileUploader div[role="button"] {
background-color: #f0f0f0; /* Light gray button */
border-radius: 10px;
padding: 10px;
}
.stFileUploader div[role="button"] > div {
color: #000000 !important; /* Black text for readability */
font-size: 16px;
}
.stFileUploader button {
background-color: #ffca28 !important; /* Yellow button */
color: #000000 !important;
border-radius: 8px !important;
border: 2px solid #ffb300 !important; /* Orange border */
padding: 5px 15px !important;
font-weight: bold !important;
box-shadow: 0 2px 4px rgba(0,0,0,0.2) !important; /* Subtle shadow */
}
.stFileUploader button:hover {
background-color: #ff8c00 !important; /* Orange on hover */
color: #000000 !important;
}
.stImage {
border: 3px solid #81c784; /* Green border for images */
border-radius: 10px;
box-shadow: 0 4px 8px rgba(0,0,0,0.1); /* Soft shadow */
}
.section-header {
background-color: #b3e5fc; /* Light blue header background */
padding: 10px;
border-radius: 10px;
text-align: center;
font-size: 24px;
font-weight: bold;
color: #000000;
margin-bottom: 10px;
}
.caption-box, .story-box {
background-color: #f0f4c3; /* Light yellow for text boxes */
padding: 15px;
border-radius: 10px;
border: 2px solid #d4e157; /* Green-yellow border */
margin-bottom: 20px;
color: #000000;
}
.caption-box b, .story-box b {
color: #000000; /* Black for bold text */
}
.stProgress > div > div {
background-color: #81c784; /* Green progress bar */
}
</style>
""", unsafe_allow_html=True)
# Main title banner with the app's magical theme
st.markdown("<div class='section-header'>Picture to Story Magic! ✨</div>", unsafe_allow_html=True)

# --- Image upload section ---
with st.container():
    st.markdown("<div class='section-header'>1️⃣ Pick a Fun Picture! πŸ–ΌοΈ</div>", unsafe_allow_html=True)
    uploaded = st.file_uploader("Choose a picture to start the magic! 😊", type=["jpg", "jpeg", "png"])
    if uploaded is None:
        # Nothing to work with yet: show a friendly hint and halt this script run.
        st.info("Upload a picture, and let's make a story! πŸŽ‰")
        st.stop()
    with st.spinner("Looking at your picture..."):
        pil_img = Image.open(uploaded)
        st.image(pil_img, use_container_width=True)  # Scale image to the page width
# Caption generation section
with st.container():
    st.markdown("<div class='section-header'>2️⃣ What's in the Picture? 🧐</div>", unsafe_allow_html=True)
    captioner = get_image_captioner()  # Cached captioning model
    progress_bar = st.progress(0)
    result = [None]  # Mutable cell so the worker thread can hand back its output

    def run_caption():
        # Runs in a worker thread so the main thread can animate progress.
        result[0] = part1_image_to_text(pil_img, captioner)

    with st.spinner("Figuring out what's in your picture..."):
        thread = threading.Thread(target=run_caption)
        thread.start()
        # Animate for up to ~5 s, but stop as soon as the model is done
        # (the original always slept the full duration even after completion).
        for i in range(100):
            progress_bar.progress(i + 1)
            if not thread.is_alive():
                break
            time.sleep(0.05)
        thread.join()  # Make sure result[0] is published before reading it
        progress_bar.progress(100)  # Snap to full before clearing
        progress_bar.empty()
    caption = result[0]
    # Display the generated caption in a styled box
    st.markdown(f"<div class='caption-box'><b>Picture Description:</b><br>{caption}</div>", unsafe_allow_html=True)
# Story and audio generation section
with st.container():
    st.markdown("<div class='section-header'>3️⃣ Your Story and Audio! 🎡</div>", unsafe_allow_html=True)

    # --- Story generation ---
    story_pipe = get_story_pipe()  # Cached text2text model
    progress_bar = st.progress(0)
    result = [None]  # Mutable cell the worker thread writes its output into

    def run_story():
        # Runs off the main thread so the progress bar can animate.
        result[0] = part2_text_to_story(caption, story_pipe)

    with st.spinner("Writing a super cool story..."):
        thread = threading.Thread(target=run_story)
        thread.start()
        # Animate for up to ~7 s, but finish early once generation completes
        # (the original always slept the full duration even after completion).
        for i in range(100):
            progress_bar.progress(i + 1)
            if not thread.is_alive():
                break
            time.sleep(0.07)
        thread.join()  # Ensure result[0] is visible before reading it
        progress_bar.progress(100)
        progress_bar.empty()
    story = result[0]
    # Display the generated story in a styled box
    st.markdown(f"<div class='story-box'><b>Your Cool Story! πŸ“š</b><br>{story}</div>", unsafe_allow_html=True)

    # --- Text-to-speech conversion ---
    tts_pipe = get_tts_pipe()  # Cached TTS model
    progress_bar = st.progress(0)
    result = [None]

    def run_tts():
        # Synthesize audio off the main thread.
        result[0] = part3_text_to_speech_bytes(story, tts_pipe)

    with st.spinner("Turning your story into sound..."):
        thread = threading.Thread(target=run_tts)
        thread.start()
        # Animate for up to ~10 s, exiting early when synthesis finishes.
        for i in range(100):
            progress_bar.progress(i + 1)
            if not thread.is_alive():
                break
            time.sleep(0.10)
        thread.join()
        progress_bar.progress(100)
        progress_bar.empty()
    audio_bytes = result[0]
    # Play the generated audio in the UI
    st.audio(audio_bytes, format="audio/wav")