"""Picture to Story Magic: a Streamlit app for kids.

Upload a picture -> caption it (BLIP) -> write a short story about the
caption (FLAN-T5) -> read the story aloud (MMS-TTS) as an in-browser WAV.
"""

import html
import io
import wave

import numpy as np
import streamlit as st
from PIL import Image
from transformers import pipeline

# ——— 1) MODEL LOADING (cached) ————————————————
# Each loader is wrapped in st.cache_resource so Streamlit reruns reuse the
# already-built pipeline instead of reloading the model weights.


@st.cache_resource
def get_image_captioner(model_name="Salesforce/blip-image-captioning-base"):
    """Return a CPU ``image-to-text`` pipeline for *model_name*."""
    return pipeline("image-to-text", model=model_name, device="cpu")


@st.cache_resource
def get_story_pipe(model_name="google/flan-t5-base"):
    """Return a CPU ``text2text-generation`` pipeline for *model_name*."""
    return pipeline("text2text-generation", model=model_name, device="cpu")


@st.cache_resource
def get_tts_pipe(model_name="facebook/mms-tts-eng"):
    """Return a CPU ``text-to-speech`` pipeline for *model_name*."""
    return pipeline("text-to-speech", model=model_name, device="cpu")


# ——— 2) TRANSFORM FUNCTIONS ————————————————


def part1_image_to_text(pil_img, captioner):
    """Caption *pil_img* with *captioner*.

    Returns the ``generated_text`` of the first result, or ``""`` when the
    pipeline returns no results.
    """
    results = captioner(pil_img)
    return results[0].get("generated_text", "") if results else ""


def part2_text_to_story(
    caption: str,
    story_pipe,
    target_words: int = 100,
    max_length: int = 100,
    min_length: int = 80,
    do_sample: bool = True,
    top_k: int = 100,
    top_p: float = 0.9,
    temperature: float = 0.7,
    repetition_penalty: float = 1.1,
    no_repeat_ngram_size: int = 4,
) -> str:
    """Generate a short story describing *caption*.

    Args:
        caption: Scene description to build the story around.
        story_pipe: A ``text2text-generation`` pipeline.
        target_words: Approximate story length requested in the prompt.
        max_length, min_length: Generation length bounds. These count model
            tokens, not words, so the result is usually shorter than
            *target_words*.
        do_sample, top_k, top_p, temperature, repetition_penalty,
        no_repeat_ngram_size: Passed straight through to generation.

    Returns:
        The story trimmed to its last complete sentence, or ``""`` if the
        model produced nothing.
    """
    prompt = (
        f"Write a vivid, imaginative short story of about {target_words} words "
        f"describing this scene: {caption}"
    )
    out = story_pipe(
        prompt,
        max_length=max_length,
        min_length=min_length,
        do_sample=do_sample,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        early_stopping=False,
    )
    raw = out[0].get("generated_text", "").strip() if out else ""
    if not raw:
        return ""

    # Strip an echo of the prompt, should the model repeat it.
    if raw.lower().startswith(prompt.lower()):
        story = raw[len(prompt):].strip()
    else:
        story = raw

    # Generation can stop mid-sentence at max_length: cut after the last
    # sentence terminator (".", "!" or "?") so the story ends cleanly.
    cut = max(story.rfind(mark) for mark in ".!?")
    if cut != -1:
        story = story[: cut + 1]
    return story


def part3_text_to_speech_bytes(text: str, tts_pipe) -> bytes:
    """Synthesize *text* with *tts_pipe* and return it as 16-bit PCM WAV bytes.

    Assumes the pipeline output has ``"audio"`` (float samples, nominally in
    [-1, 1], either 1-D or (channels, samples) — TODO confirm for other
    models) and ``"sampling_rate"``.
    """
    out = tts_pipe(text)
    if isinstance(out, list):
        out = out[0]

    audio = np.asarray(out["audio"], dtype=np.float32)
    rate = int(out["sampling_rate"])

    # WAV frames are (samples, channels); transpose a (channels, samples) array.
    frames = audio.T if audio.ndim == 2 else audio
    channels = 1 if frames.ndim == 1 else frames.shape[1]

    # Clip first: samples slightly outside [-1, 1] would otherwise wrap around
    # when cast to int16, producing loud clicks.
    pcm = (np.clip(frames, -1.0, 1.0) * 32767).astype(np.int16)

    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wf:
        wf.setnchannels(channels)
        wf.setsampwidth(2)  # 2 bytes per sample = int16
        wf.setframerate(rate)
        # tobytes() emits C order, i.e. interleaved samples, regardless of the
        # memory layout left by the transpose above.
        wf.writeframes(pcm.tobytes())
    return buffer.getvalue()


# ——— 3) STREAMLIT UI ————————————————————————————

# NOTE(review): the "kid-friendly styling" stylesheet is empty in the current
# source; put the intended CSS rules inside this <style> tag.
_CUSTOM_CSS = "<style></style>"


def _render_html(markup: str) -> None:
    """Render *markup* as raw HTML.

    ``unsafe_allow_html=True`` renders the markup as-is, so callers must pass
    any model-generated text through ``html.escape`` first.
    """
    st.markdown(markup, unsafe_allow_html=True)


def main() -> None:
    """Render the app: upload -> caption -> (on click) story + narration."""
    # Set page config as the first Streamlit command
    st.set_page_config(
        page_title="Picture to Story Magic",
        page_icon="✨",
        layout="centered",
    )
    _render_html(_CUSTOM_CSS)

    # NOTE(review): the original HTML wrappers around these headings were lost
    # (the string literals had been left unterminated); minimal tags are used
    # here — restore the intended markup/classes if they are available.
    _render_html("<h1>Picture to Story Magic! ✨</h1>")

    # Image upload section
    with st.container():
        _render_html("<h3>1️⃣ Pick a Fun Picture! 🖼️</h3>")
        uploaded = st.file_uploader(
            "Choose a picture to start the magic! 😊",
            type=["jpg", "jpeg", "png"],
        )
    if not uploaded:
        st.info("Upload a picture, and let's make a story! 🎉")
        st.stop()

    # Show image
    with st.spinner("Looking at your picture..."):
        try:
            # Image.open is lazy; convert() forces a full decode (surfacing a
            # corrupt file here) and normalizes alpha/palette modes to RGB.
            pil_img = Image.open(uploaded).convert("RGB")
        except OSError:
            st.error("Oops! That file doesn't look like a picture. Try another one! 🙈")
            st.stop()
        st.image(pil_img, use_container_width=True)

    # Caption section
    with st.container():
        captioner = get_image_captioner()
        with st.spinner("Figuring out what's in your picture..."):
            caption = part1_image_to_text(pil_img, captioner)
        _render_html(
            "<h3>What's in the Picture? 🧐</h3>"
            f"<p>{html.escape(caption)}</p>"
        )

    # Story and audio section
    with st.container():
        _render_html("<h3>2️⃣ Make a Story and Hear It! 🎵</h3>")
        if st.button("Create My Story! 🎉"):
            # Story
            story_pipe = get_story_pipe()
            with st.spinner("Writing a super cool story..."):
                story = part2_text_to_story(caption, story_pipe)
            if not story:
                # Nothing to show or narrate; TTS on empty text is pointless.
                st.warning("Hmm, the story didn't come out. Please try again! 🔄")
                return
            _render_html(
                "<h3>Your Cool Story! 📚</h3>"
                f"<p>{html.escape(story)}</p>"
            )

            # TTS
            tts_pipe = get_tts_pipe()
            with st.spinner("Turning your story into sound..."):
                audio_bytes = part3_text_to_speech_bytes(story, tts_pipe)
            st.audio(audio_bytes, format="audio/wav")


# `streamlit run` executes the script as __main__, so this runs on every rerun.
if __name__ == "__main__":
    main()