Assignment1 / app.py
shingguy1's picture
Update app.py
e8cddda verified
raw
history blame
7.08 kB
import html
import io
import wave

import numpy as np
import streamlit as st
from PIL import Image
from transformers import pipeline
# ——— 1) MODEL LOADING (cached) ————————————————
@st.cache_resource
def get_image_captioner(model_name="Salesforce/blip-image-captioning-base"):
    """Build the image-captioning pipeline on CPU.

    Decorated with ``st.cache_resource`` so the pipeline object is built
    once and reused instead of being recreated on every script rerun.

    Args:
        model_name: Model identifier handed to ``transformers.pipeline``.

    Returns:
        The ``"image-to-text"`` pipeline created for ``model_name``.
    """
    captioner = pipeline("image-to-text", model=model_name, device="cpu")
    return captioner
@st.cache_resource
def get_story_pipe(model_name="google/flan-t5-base"):
    """Build the story-generation pipeline on CPU.

    Decorated with ``st.cache_resource`` so the pipeline object is built
    once and reused instead of being recreated on every script rerun.

    Args:
        model_name: Model identifier handed to ``transformers.pipeline``.

    Returns:
        The ``"text2text-generation"`` pipeline created for ``model_name``.
    """
    story_generator = pipeline("text2text-generation", model=model_name, device="cpu")
    return story_generator
@st.cache_resource
def get_tts_pipe(model_name="facebook/mms-tts-eng"):
    """Build the text-to-speech pipeline on CPU.

    Decorated with ``st.cache_resource`` so the pipeline object is built
    once and reused instead of being recreated on every script rerun.

    Args:
        model_name: Model identifier handed to ``transformers.pipeline``.

    Returns:
        The ``"text-to-speech"`` pipeline created for ``model_name``.
    """
    speech_synthesizer = pipeline("text-to-speech", model=model_name, device="cpu")
    return speech_synthesizer
# ——— 2) TRANSFORM FUNCTIONS ————————————————
def part1_image_to_text(pil_img, captioner):
    """Caption an image and return the generated text.

    Args:
        pil_img: Image object passed to ``captioner`` unchanged.
        captioner: Callable that takes the image and returns a sequence of
            dicts; the caption is read from the first dict's
            ``"generated_text"`` key.

    Returns:
        The caption string, or ``""`` when ``captioner`` returns nothing
        (empty/falsy result) or the first dict has no ``"generated_text"``.
    """
    predictions = captioner(pil_img)
    if not predictions:
        return ""
    first_prediction = predictions[0]
    return first_prediction.get("generated_text", "")
def part2_text_to_story(
    caption: str,
    story_pipe,
    target_words: int = 100,
    max_length: int = 100,
    min_length: int = 80,
    do_sample: bool = True,
    top_k: int = 100,
    top_p: float = 0.9,
    temperature: float = 0.7,
    repetition_penalty: float = 1.1,
    no_repeat_ngram_size: int = 4,
) -> str:
    """Generate a short story from an image caption.

    Builds a prompt around ``caption``, runs ``story_pipe`` on it, removes a
    leading echo of the prompt if present, and trims any unfinished trailing
    sentence.

    Args:
        caption: Scene description inserted into the prompt.
        story_pipe: Callable invoked as ``story_pipe(prompt, **generation_kwargs)``
            that returns a sequence of dicts with a ``"generated_text"`` key.
        target_words: Word count mentioned in the prompt text only; it is
            not enforced on the output.
        max_length: Forwarded to ``story_pipe`` unchanged.
        min_length: Forwarded to ``story_pipe`` unchanged.
        do_sample: Forwarded to ``story_pipe`` unchanged.
        top_k: Forwarded to ``story_pipe`` unchanged.
        top_p: Forwarded to ``story_pipe`` unchanged.
        temperature: Forwarded to ``story_pipe`` unchanged.
        repetition_penalty: Forwarded to ``story_pipe`` unchanged.
        no_repeat_ngram_size: Forwarded to ``story_pipe`` unchanged.

    Returns:
        The story text ending at its last ``.``, ``!`` or ``?``. If the text
        contains none of these it is returned whole. Returns ``""`` when the
        pipeline yields no result or only whitespace.
    """
    prompt = (
        f"Write a vivid, imaginative short story of about {target_words} words "
        f"describing this scene: {caption}"
    )
    out = story_pipe(
        prompt,
        max_length=max_length,
        min_length=min_length,
        do_sample=do_sample,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        early_stopping=False,
    )
    # An empty result would otherwise fail with IndexError on out[0].
    if not out:
        return ""
    raw = out[0].get("generated_text", "").strip()
    if not raw:
        return ""

    # Strip an echo of the prompt (compared case-insensitively).
    if raw.lower().startswith(prompt.lower()):
        story = raw[len(prompt):].strip()
    else:
        story = raw

    # Drop a trailing unfinished sentence. Looking only for "." would also
    # discard complete sentences that end in "!" or "?", so use the last of
    # any of the three terminators.
    last_sentence_end = max(story.rfind(mark) for mark in ".!?")
    if last_sentence_end != -1:
        story = story[: last_sentence_end + 1]
    return story
def part3_text_to_speech_bytes(text: str, tts_pipe) -> bytes:
    """Synthesize ``text`` and return it as an in-memory 16-bit PCM WAV file.

    Args:
        text: Text passed to ``tts_pipe``.
        tts_pipe: Callable returning a dict (or a list whose first item is
            such a dict) with an ``"audio"`` array and an integer
            ``"sampling_rate"``.

    Returns:
        The complete WAV file contents (header plus frames) as ``bytes``.
    """
    out = tts_pipe(text)
    if isinstance(out, list):
        out = out[0]
    audio_array = np.asarray(out["audio"])
    rate = out["sampling_rate"]

    # A 2-D array is taken to be (channels, samples) -- TODO confirm for other
    # TTS models. WAV frames interleave channels, so flip to
    # (samples, channels) before serializing in row-major order.
    data = audio_array.T if audio_array.ndim == 2 else audio_array
    channels = 1 if data.ndim == 1 else data.shape[1]

    # Assumes float samples nominally in [-1, 1]. Clip before scaling:
    # without it, any sample outside that range overflows int16 and wraps,
    # which is heard as loud clicks. "<i2" pins the little-endian byte order
    # that the WAV format requires.
    pcm = (np.clip(data, -1.0, 1.0) * 32767).astype("<i2")

    buffer = io.BytesIO()
    # The context manager finalizes the WAV header even if a write fails.
    with wave.open(buffer, "wb") as wf:
        wf.setnchannels(channels)
        wf.setsampwidth(2)  # 2 bytes per sample = 16-bit PCM
        wf.setframerate(rate)
        wf.writeframes(pcm.tobytes())
    return buffer.getvalue()
# ——— 3) STREAMLIT UI ————————————————————————————
# Page setup: browser-tab title, icon and centered layout.
# Kept ahead of the other Streamlit UI calls below.
st.set_page_config(page_title="Picture to Story Magic", page_icon="✨", layout="centered")
# Kid-friendly theme: one stylesheet for the page background, buttons, the
# file uploader, images, and the section-header / caption / story boxes.
_PAGE_CSS = """
<style>
.main {
background-color: #e6f3ff;
padding: 20px;
border-radius: 15px;
}
.stButton>button {
background-color: #ffcccb;
color: #000000; /* Black text */
border-radius: 10px;
border: 2px solid #ff9999;
font-size: 18px;
font-weight: bold;
padding: 10px 20px;
transition: all 0.3s;
}
.stButton>button:hover {
background-color: #ff9999;
color: #ffffff; /* White text on hover for contrast */
transform: scale(1.05);
}
.stFileUploader {
background-color: #ffb300; /* Darker yellow for better contrast with white label text */
border: 2px dashed #ff8c00; /* Darker orange border to match */
border-radius: 10px;
padding: 10px;
}
/* Style for the file uploader's inner text */
.stFileUploader div[role="button"] {
background-color: #f0f0f0; /* Very light gray background for contrast with black text */
border-radius: 10px;
padding: 10px;
}
.stFileUploader div[role="button"] > div {
color: #000000 !important; /* Black text */
font-size: 16px;
}
/* Style for the "Browse files" button inside the file uploader */
.stFileUploader button {
background-color: #ffca28 !important; /* Yellow button background */
color: #000000 !important; /* Black text */
border-radius: 8px !important;
border: 2px solid #ffb300 !important; /* Match the container background */
padding: 5px 15px !important;
font-weight: bold !important;
box-shadow: 0 2px 4px rgba(0,0,0,0.2) !important; /* Subtle shadow to make button stand out */
}
.stFileUploader button:hover {
background-color: #ff8c00 !important; /* Slightly darker yellow on hover */
color: #000000 !important; /* Keep black text */
}
.stImage {
border: 3px solid #81c784;
border-radius: 10px;
box-shadow: 0 4px 8px rgba(0,0,0,0.1);
}
.section-header {
background-color: #b3e5fc;
padding: 10px;
border-radius: 10px;
text-align: center;
font-size: 24px;
font-weight: bold;
color: #000000; /* Black text */
margin-bottom: 10px;
}
.caption-box, .story-box {
background-color: #f0f4c3;
padding: 15px;
border-radius: 10px;
border: 2px solid #d4e157;
margin-bottom: 20px;
color: #000000; /* Black text */
}
.caption-box b, .story-box b {
color: #000000; /* Black text for bold headers */
}
</style>
"""
# unsafe_allow_html lets the <style> tag through instead of showing it as text.
st.markdown(_PAGE_CSS, unsafe_allow_html=True)
# Main title
st.markdown("<div class='section-header'>Picture to Story Magic! ✨</div>", unsafe_allow_html=True)

# Image upload section
with st.container():
    # The emoji in these two messages were stored as mojibake ("πŸ–ΌοΈ",
    # "πŸŽ‰") and showed up garbled in the UI; restored to the real characters.
    st.markdown("<div class='section-header'>1️⃣ Pick a Fun Picture! 🖼️</div>", unsafe_allow_html=True)
    uploaded = st.file_uploader("Choose a picture to start the magic! 😊", type=["jpg", "jpeg", "png"])
    if not uploaded:
        st.info("Upload a picture, and let's make a story! 🎉")
        # Nothing to caption yet, so end this script run here.
        st.stop()
# Show image
with st.spinner("Looking at your picture..."):
    try:
        pil_img = Image.open(uploaded)
    except OSError:
        # The uploader only filters by file extension, so the content may
        # still be unreadable. Pillow reports that with an OSError
        # (UnidentifiedImageError is a subclass); show a friendly message
        # instead of a traceback and end this script run.
        st.error("Oops! That file doesn't look like a picture. Please try another one.")
        st.stop()
    st.image(pil_img, use_container_width=True)
# Caption section
with st.container():
    captioner = get_image_captioner()
    with st.spinner("Figuring out what's in your picture..."):
        caption = part1_image_to_text(pil_img, captioner)
    # The caption is model output rendered as raw HTML, so escape it; any
    # "<", ">" or "&" it contains would otherwise be treated as markup.
    # `caption` itself stays unescaped for use as the story prompt.
    st.markdown(
        f"<div class='caption-box'><b>What's in the Picture? 🧐</b><br>{html.escape(caption)}</div>",
        unsafe_allow_html=True,
    )
# Story and audio section
with st.container():
    # The emoji in the header, button label and story title were stored as
    # mojibake ("🎡", "πŸŽ‰", "πŸ“š"); restored to the real characters.
    st.markdown("<div class='section-header'>2️⃣ Make a Story and Hear It! 🎵</div>", unsafe_allow_html=True)
    if st.button("Create My Story! 🎉"):
        # Story
        story_pipe = get_story_pipe()
        with st.spinner("Writing a super cool story..."):
            story = part2_text_to_story(caption, story_pipe)
        if not story:
            # part2_text_to_story returns "" when the model produced no text;
            # there is nothing to show or to read aloud.
            st.warning("Hmm, no story came out this time. Please try again!")
        else:
            # The story is model output rendered as raw HTML, so escape it.
            st.markdown(
                f"<div class='story-box'><b>Your Cool Story! 📚</b><br>{html.escape(story)}</div>",
                unsafe_allow_html=True,
            )
            # TTS (uses the unescaped story text)
            tts_pipe = get_tts_pipe()
            with st.spinner("Turning your story into sound..."):
                audio_bytes = part3_text_to_speech_bytes(story, tts_pipe)
            st.audio(audio_bytes, format="audio/wav")