import html
import io
import threading
import time
import wave

import numpy as np
import streamlit as st
from PIL import Image
from transformers import pipeline


# ——— 1) MODEL LOADING (cached) ————————————————
@st.cache_resource
def get_image_captioner(model_name="Salesforce/blip-image-captioning-base"):
    """Return a cached CPU ``image-to-text`` pipeline for *model_name*."""
    return pipeline("image-to-text", model=model_name, device="cpu")


@st.cache_resource
def get_story_pipe(model_name="google/flan-t5-base"):
    """Return a cached CPU ``text2text-generation`` pipeline for *model_name*."""
    return pipeline("text2text-generation", model=model_name, device="cpu")


@st.cache_resource
def get_tts_pipe(model_name="facebook/mms-tts-eng"):
    """Return a cached CPU ``text-to-speech`` pipeline for *model_name*."""
    return pipeline("text-to-speech", model=model_name, device="cpu")


# ——— 2) TRANSFORM FUNCTIONS ————————————————
def part1_image_to_text(pil_img, captioner) -> str:
    """Caption *pil_img* with *captioner*.

    Returns the ``generated_text`` of the first result, or an empty string
    when the captioner returns nothing or the key is missing.
    """
    results = captioner(pil_img)
    return results[0].get("generated_text", "") if results else ""


def part2_text_to_story(
    caption: str,
    story_pipe,
    target_words: int = 100,
    max_length: int = 100,
    min_length: int = 80,
    do_sample: bool = True,
    top_k: int = 100,
    top_p: float = 0.9,
    temperature: float = 0.7,
    repetition_penalty: float = 1.1,
    no_repeat_ngram_size: int = 4,
) -> str:
    """Generate a short story about *caption* with *story_pipe*.

    *target_words* is only mentioned in the prompt; the remaining keyword
    arguments are forwarded unchanged to the generation call.

    Returns the generated text with any echoed prompt removed and truncated
    at the last full stop (left untruncated when it contains none), or an
    empty string when the pipeline produced no text.
    """
    prompt = (
        f"Write a vivid, imaginative short story of about {target_words} words "
        f"describing this scene: {caption}"
    )
    out = story_pipe(
        prompt,
        max_length=max_length,
        min_length=min_length,
        do_sample=do_sample,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        early_stopping=False,
    )
    raw = out[0].get("generated_text", "").strip()
    if not raw:
        return ""

    # Strip an echo of the prompt if the model repeated it.
    if raw.lower().startswith(prompt.lower()):
        story = raw[len(prompt):].strip()
    else:
        story = raw

    # Cut at the last full stop so the story does not end mid-sentence.
    idx = story.rfind(".")
    if idx != -1:
        story = story[: idx + 1]
    return story


def part3_text_to_speech_bytes(text: str, tts_pipe) -> bytes:
    """Synthesize *text* with *tts_pipe* and return 16-bit PCM WAV bytes.

    The pipeline output (or its first element when it is a list) must
    provide ``"audio"`` and ``"sampling_rate"`` entries. A 2-D audio array
    is treated as (channels, samples) and transposed so frames are
    interleaved; a 1-D array is written as mono.
    """
    out = tts_pipe(text)
    if isinstance(out, list):
        out = out[0]

    audio_array = np.asarray(out["audio"])  # (channels, samples) or (samples,)
    rate = out["sampling_rate"]

    data = audio_array.T if audio_array.ndim == 2 else audio_array
    channels = 1 if data.ndim == 1 else data.shape[1]

    # Clip before scaling: samples outside [-1, 1] would otherwise wrap
    # around when cast to int16 and produce loud clicks.
    pcm = (np.clip(data, -1.0, 1.0) * 32767).astype(np.int16)

    buffer = io.BytesIO()
    # The context manager finalizes the WAV header even if a write fails.
    with wave.open(buffer, "wb") as wf:
        wf.setnchannels(channels)
        wf.setsampwidth(2)
        wf.setframerate(rate)
        wf.writeframes(pcm.tobytes())
    return buffer.getvalue()


# ——— 3) STREAMLIT UI ————————————————————————————
def _run_with_progress(task, message: str, step_delay: float):
    """Run ``task()`` in a worker thread while animating a progress bar.

    The bar advances one percent every *step_delay* seconds under a spinner
    showing *message*. The animation stops as soon as the worker finishes,
    and the call always waits for the worker before returning its result.
    An exception raised by *task* is re-raised here on the calling thread
    instead of being lost inside the worker.
    """
    outcome = {}

    def worker():
        try:
            outcome["value"] = task()
        except Exception as exc:  # handed back to the main thread below
            outcome["error"] = exc

    progress_bar = st.progress(0)
    with st.spinner(message):
        thread = threading.Thread(target=worker)
        thread.start()
        for i in range(100):
            if not thread.is_alive():
                break  # work finished early; no need to keep the user waiting
            progress_bar.progress(i + 1)
            time.sleep(step_delay)
        thread.join()
    progress_bar.empty()

    if "error" in outcome:
        raise outcome["error"]
    return outcome["value"]


# Set page config as the first Streamlit command
st.set_page_config(
    page_title="Picture to Story Magic",
    page_icon="✨",
    layout="centered",
)

# Custom CSS for kid-friendly styling with improved readability
st.markdown(""" """, unsafe_allow_html=True)

# Main title
st.markdown("\nPicture to Story Magic! ✨\n", unsafe_allow_html=True)

# Image upload section
with st.container():
    st.markdown("\n1️⃣ Pick a Fun Picture! 🖼️\n", unsafe_allow_html=True)
    uploaded = st.file_uploader(
        "Choose a picture to start the magic! 😊",
        type=["jpg", "jpeg", "png"],
    )

if not uploaded:
    st.info("Upload a picture, and let's make a story! 🎉")
    st.stop()

# Show image
with st.spinner("Looking at your picture..."):
    pil_img = Image.open(uploaded)
    st.image(pil_img, use_container_width=True)

# Caption section
with st.container():
    st.markdown("\n2️⃣ What's in the Picture? 🧐\n", unsafe_allow_html=True)
    captioner = get_image_captioner()
    caption = _run_with_progress(
        lambda: part1_image_to_text(pil_img, captioner),
        "Figuring out what's in your picture...",
        0.05,  # ~5 seconds at most
    )
    # Model output is escaped because the markdown is rendered as raw HTML.
    st.markdown(
        f"\nPicture Description:\n{html.escape(caption, quote=False)}\n",
        unsafe_allow_html=True,
    )

# Story and audio section
with st.container():
    st.markdown("\n3️⃣ Your Story and Audio! 🎵\n", unsafe_allow_html=True)

    # Story
    story_pipe = get_story_pipe()
    story = _run_with_progress(
        lambda: part2_text_to_story(caption, story_pipe),
        "Writing a super cool story...",
        0.07,  # ~7 seconds at most
    )
    st.markdown(
        f"\nYour Cool Story! 📚\n{html.escape(story, quote=False)}\n",
        unsafe_allow_html=True,
    )

    # There is nothing to read aloud when the story came back empty.
    if not story:
        st.warning("Oops! The story came out empty. Try another picture! 🙈")
        st.stop()

    # TTS (uses the raw story, not the HTML-escaped display text)
    tts_pipe = get_tts_pipe()
    audio_bytes = _run_with_progress(
        lambda: part3_text_to_speech_bytes(story, tts_pipe),
        "Turning your story into sound...",
        0.10,  # ~10 seconds at most
    )
    st.audio(audio_bytes, format="audio/wav")