# Assignment1 / app.py — "Picture to Story Magic" Streamlit demo
# (uploaded by shingguy1, commit f97d7d0, 7.64 kB; Hugging Face Spaces page
# header removed so the file parses as Python)
import io
import wave
import streamlit as st
from transformers import pipeline
from PIL import Image
import numpy as np
import time
import threading
# β€”β€”β€” 1) MODEL LOADING (cached) β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
@st.cache_resource
def get_image_captioner(model_name="Salesforce/blip-image-captioning-base"):
    """Build (once per session, via Streamlit's resource cache) the BLIP
    image-to-text pipeline on CPU and return it."""
    captioner = pipeline("image-to-text", model=model_name, device="cpu")
    return captioner
@st.cache_resource
def get_story_pipe(model_name="google/flan-t5-base"):
    """Build (once per session, via Streamlit's resource cache) the FLAN-T5
    text2text-generation pipeline on CPU and return it."""
    generator = pipeline("text2text-generation", model=model_name, device="cpu")
    return generator
@st.cache_resource
def get_tts_pipe(model_name="facebook/mms-tts-eng"):
    """Build (once per session, via Streamlit's resource cache) the MMS
    text-to-speech pipeline on CPU and return it."""
    synthesizer = pipeline("text-to-speech", model=model_name, device="cpu")
    return synthesizer
# β€”β€”β€” 2) TRANSFORM FUNCTIONS β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
def part1_image_to_text(pil_img, captioner):
    """Caption *pil_img* with the image-to-text pipeline *captioner*.

    Returns the first generated caption string, or "" when the pipeline
    produces no results or the result carries no "generated_text" key.
    """
    results = captioner(pil_img)
    if not results:
        return ""
    return results[0].get("generated_text", "")
def part2_text_to_story(
    caption: str,
    story_pipe,
    target_words: int = 100,
    max_length: int = 100,
    min_length: int = 80,
    do_sample: bool = True,
    top_k: int = 100,
    top_p: float = 0.9,
    temperature: float = 0.7,
    repetition_penalty: float = 1.1,
    no_repeat_ngram_size: int = 4
) -> str:
    """Turn an image *caption* into a short story via *story_pipe*.

    Builds a story prompt around the caption, generates with the supplied
    sampling parameters, strips any leading echo of the prompt from the
    output, and trims a trailing unfinished sentence at the last period.
    Returns "" when the pipeline produces no (non-whitespace) text.
    """
    prompt = (
        f"Write a vivid, imaginative short story of about {target_words} words "
        f"describing this scene: {caption}"
    )
    generation = story_pipe(
        prompt,
        max_length=max_length,
        min_length=min_length,
        do_sample=do_sample,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        early_stopping=False
    )
    raw = generation[0].get("generated_text", "").strip()
    if not raw:
        return ""
    # Drop a leading, case-insensitive echo of the prompt if the model
    # repeated it verbatim.
    story = raw[len(prompt):].strip() if raw.lower().startswith(prompt.lower()) else raw
    # Cut after the last full stop so we never show a dangling half-sentence.
    last_stop = story.rfind(".")
    return story if last_stop == -1 else story[:last_stop + 1]
def part3_text_to_speech_bytes(text: str, tts_pipe) -> bytes:
    """Synthesize *text* with *tts_pipe* and return 16-bit PCM WAV bytes.

    The pipeline result (dict, or a single-element list of dicts) must carry
    an "audio" float ndarray with samples in [-1, 1] — 1-D, or 2-D as
    (channels, samples) — and an integer "sampling_rate".
    """
    out = tts_pipe(text)
    if isinstance(out, list):
        out = out[0]
    audio_array = out["audio"]      # np.ndarray, float samples
    rate = out["sampling_rate"]     # frames per second
    # wave expects interleaved frames: (channels, samples) -> (samples, channels).
    data = audio_array.T if audio_array.ndim == 2 else audio_array
    # Clip before scaling: values even slightly outside [-1, 1] would
    # otherwise wrap around when cast to int16 and produce loud pops.
    pcm = (np.clip(data, -1.0, 1.0) * 32767).astype(np.int16)
    buffer = io.BytesIO()
    # Context manager finalizes the WAV header even if a write fails.
    with wave.open(buffer, "wb") as wf:
        wf.setnchannels(1 if data.ndim == 1 else data.shape[1])
        wf.setsampwidth(2)  # 16-bit samples
        wf.setframerate(int(rate))
        wf.writeframes(pcm.tobytes())
    return buffer.getvalue()
# β€”β€”β€” 3) STREAMLIT UI β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
# Set page config as the first Streamlit command
st.set_page_config(
    page_title="Picture to Story Magic",
    page_icon="✨",
    layout="centered"
)

# Custom CSS for kid-friendly styling with improved readability.
# NOTE: the button text color previously used the invalid property
# "button-color"; CSS has no such property, so it is "color" here.
st.markdown("""
<style>
.main {
    background-color: #e6f3ff;
    padding: 20px;
    border-radius: 15px;
}
.stButton>button {
    background-color: #ffcccb;
    color: #000000;
    border-radius: 10px;
    border: 2px solid #ff9999;
    font-size: 18px;
    font-weight: bold;
    padding: 10px 20px;
    transition: all 0.3s;
}
.stButton>button:hover {
    background-color: #ff9999;
    color: #ffffff;
    transform: scale(1.05);
}
.stFileUploader {
    background-color: #ffb300;
    border: 2px dashed #ff8c00;
    border-radius: 10px;
    padding: 10px;
}
.stFileUploader div[role="button"] {
    background-color: #f0f0f0;
    border-radius: 10px;
    padding: 10px;
}
.stFileUploader div[role="button"] > div {
    color: #000000 !important;
    font-size: 16px;
}
.stFileUploader button {
    background-color: #ffca28 !important;
    color: #000000 !important;
    border-radius: 8px !important;
    border: 2px solid #ffb300 !important;
    padding: 5px 15px !important;
    font-weight: bold !important;
    box-shadow: 0 2px 4px rgba(0,0,0,0.2) !important;
}
.stFileUploader button:hover {
    background-color: #ff8c00 !important;
    color: #000000 !important;
}
.stImage {
    border: 3px solid #81c784;
    border-radius: 10px;
    box-shadow: 0 4px 8px rgba(0,0,0,0.1);
}
.section-header {
    background-color: #b3e5fc;
    padding: 10px;
    border-radius: 10px;
    text-align: center;
    font-size: 24px;
    font-weight: bold;
    color: #000000;
    margin-bottom: 10px;
}
.caption-box, .story-box {
    background-color: #f0f4c3;
    padding: 15px;
    border-radius: 10px;
    border: 2px solid #d4e157;
    margin-bottom: 20px;
    color: #000000;
}
.caption-box b, .story-box b {
    color: #000000;
}
.stProgress > div > div {
    background-color: #81c784;
}
</style>
""", unsafe_allow_html=True)


def _run_with_progress(task, message, step_delay):
    """Run *task()* on a worker thread while animating a progress bar.

    The bar advances one tick per *step_delay* seconds (up to 100 ticks),
    but jumps straight to 100% as soon as the worker finishes — unlike a
    fixed-length animation, the user never waits longer than the model
    actually takes. Returns the task's result.
    """
    progress_bar = st.progress(0)
    result = [None]  # single-slot holder so the worker can hand back its value

    def _worker():
        result[0] = task()

    with st.spinner(message):
        thread = threading.Thread(target=_worker)
        thread.start()
        for i in range(100):
            progress_bar.progress(i + 1)
            if not thread.is_alive():
                progress_bar.progress(100)
                break
            time.sleep(step_delay)
        thread.join()  # make sure result[0] is fully written
    progress_bar.empty()
    return result[0]


# Main title
st.markdown("<div class='section-header'>Picture to Story Magic! ✨</div>", unsafe_allow_html=True)

# Image upload section
with st.container():
    st.markdown("<div class='section-header'>1️⃣ Pick a Fun Picture! πŸ–ΌοΈ</div>", unsafe_allow_html=True)
    uploaded = st.file_uploader("Choose a picture to start the magic! 😊", type=["jpg", "jpeg", "png"])
    if not uploaded:
        st.info("Upload a picture, and let's make a story! πŸŽ‰")
        st.stop()  # halt the script until a file is provided

# Show image
with st.spinner("Looking at your picture..."):
    pil_img = Image.open(uploaded)
    st.image(pil_img, use_container_width=True)

# Caption section
with st.container():
    st.markdown("<div class='section-header'>2️⃣ What's in the Picture? 🧐</div>", unsafe_allow_html=True)
    captioner = get_image_captioner()
    caption = _run_with_progress(
        lambda: part1_image_to_text(pil_img, captioner),
        "Figuring out what's in your picture...",
        0.05,
    )
    st.markdown(f"<div class='caption-box'><b>Picture Description:</b><br>{caption}</div>", unsafe_allow_html=True)

# Story and audio section
with st.container():
    st.markdown("<div class='section-header'>3️⃣ Your Story and Audio! 🎡</div>", unsafe_allow_html=True)
    # Story
    story_pipe = get_story_pipe()
    story = _run_with_progress(
        lambda: part2_text_to_story(caption, story_pipe),
        "Writing a super cool story...",
        0.07,
    )
    st.markdown(f"<div class='story-box'><b>Your Cool Story! πŸ“š</b><br>{story}</div>", unsafe_allow_html=True)
    # TTS
    tts_pipe = get_tts_pipe()
    audio_bytes = _run_with_progress(
        lambda: part3_text_to_speech_bytes(story, tts_pipe),
        "Turning your story into sound...",
        0.10,
    )
    st.audio(audio_bytes, format="audio/wav")