"""Picture to Story Magic — a kid-friendly Streamlit app.

Pipeline: uploaded image -> caption (BLIP) -> short story (FLAN-T5)
-> spoken audio (MMS-TTS), all on CPU with cached model loading.
"""

import io
import threading
import time
import wave

import numpy as np
import streamlit as st
from PIL import Image
from transformers import pipeline

# ——— 1) MODEL LOADING (cached) ————————————————
# st.cache_resource keeps each pipeline alive across Streamlit reruns,
# so the (slow) model load happens only once per session/server.


@st.cache_resource
def get_image_captioner(model_name="Salesforce/blip-image-captioning-base"):
    """Load and cache the image-to-text captioning pipeline (CPU)."""
    return pipeline("image-to-text", model=model_name, device="cpu")


@st.cache_resource
def get_story_pipe(model_name="google/flan-t5-base"):
    """Load and cache the text-to-text story-generation pipeline (CPU)."""
    return pipeline("text2text-generation", model=model_name, device="cpu")


@st.cache_resource
def get_tts_pipe(model_name="facebook/mms-tts-eng"):
    """Load and cache the text-to-speech pipeline (CPU)."""
    return pipeline("text-to-speech", model=model_name, device="cpu")


# ——— 2) TRANSFORM FUNCTIONS ————————————————


def part1_image_to_text(pil_img, captioner):
    """Return a caption for *pil_img*, or "" if the captioner yields nothing.

    Args:
        pil_img: image accepted by the captioning pipeline (e.g. PIL.Image).
        captioner: callable returning a list of {"generated_text": str} dicts.
    """
    results = captioner(pil_img)
    return results[0].get("generated_text", "") if results else ""


def part2_text_to_story(
    caption: str,
    story_pipe,
    target_words: int = 100,
    max_length: int = 100,
    min_length: int = 80,
    do_sample: bool = True,
    top_k: int = 100,
    top_p: float = 0.9,
    temperature: float = 0.7,
    repetition_penalty: float = 1.1,
    no_repeat_ngram_size: int = 4,
) -> str:
    """Generate a short story from *caption* with a text2text pipeline.

    The prompt asks for roughly *target_words* words; sampling parameters
    (top_k/top_p/temperature) trade determinism for creativity, while
    repetition_penalty and no_repeat_ngram_size curb repeated phrases.

    Returns:
        The cleaned story text ("" if the model produced nothing).
    """
    prompt = (
        f"Write a vivid, imaginative short story of about {target_words} words "
        f"describing this scene: {caption}"
    )
    out = story_pipe(
        prompt,
        max_length=max_length,
        min_length=min_length,
        do_sample=do_sample,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        early_stopping=False,  # run to max_length rather than stopping early
    )
    raw = out[0].get("generated_text", "").strip()
    if not raw:
        return ""
    # Some models echo the prompt back; strip it if present.
    if raw.lower().startswith(prompt.lower()):
        story = raw[len(prompt):].strip()
    else:
        story = raw
    # Truncate at the last full stop so the story ends on a complete sentence.
    idx = story.rfind(".")
    if idx != -1:
        story = story[: idx + 1]
    return story


def part3_text_to_speech_bytes(text: str, tts_pipe) -> bytes:
    """Synthesize *text* to speech and return it as in-memory WAV bytes.

    Args:
        text: the text to speak.
        tts_pipe: TTS pipeline returning {"audio": np.ndarray,
            "sampling_rate": int} (possibly wrapped in a one-element list).

    Returns:
        A complete 16-bit PCM WAV file as bytes.
    """
    out = tts_pipe(text)
    if isinstance(out, list):
        out = out[0]
    audio_array = out["audio"]  # float samples; (samples,) or (channels, samples)
    rate = out["sampling_rate"]
    # Transpose multi-channel audio to (samples, channels) for frame interleaving.
    data = audio_array.T if audio_array.ndim == 2 else audio_array
    # Clip to [-1, 1] BEFORE scaling: without it, out-of-range samples wrap
    # around on the int16 cast and produce loud crackling artifacts.
    pcm = (np.clip(data, -1.0, 1.0) * 32767).astype(np.int16)
    buffer = io.BytesIO()
    # Context manager guarantees the WAV header is finalized and the
    # writer is closed even if a write fails.
    with wave.open(buffer, "wb") as wf:
        wf.setnchannels(1 if data.ndim == 1 else data.shape[1])
        wf.setsampwidth(2)  # 16-bit = 2 bytes per sample
        wf.setframerate(rate)
        wf.writeframes(pcm.tobytes())
    buffer.seek(0)
    return buffer.read()


# ——— 3) STREAMLIT UI ————————————————————————————


def _run_with_progress(task, message, step_delay):
    """Run *task* on a worker thread while animating a fake progress bar.

    Streamlit widgets must be driven from the main thread, so the worker
    only computes; the main thread animates 100 progress steps of
    *step_delay* seconds each, then joins the worker.

    Returns the task's return value.
    """
    progress_bar = st.progress(0)
    result = [None]  # one-slot box so the worker can hand back its result

    def _worker():
        result[0] = task()

    thread = threading.Thread(target=_worker)
    with st.spinner(message):
        thread.start()
        for i in range(100):
            progress_bar.progress(i + 1)
            time.sleep(step_delay)
        thread.join()  # ensure the real work is done, not just the animation
    progress_bar.empty()
    return result[0]


# Kid-friendly, centered page layout.
st.set_page_config(
    page_title="Picture to Story Magic",
    page_icon="✨",
    layout="centered",
)

# NOTE(review): the original inline CSS was lost in extraction; this is a
# minimal stand-in — restore the project's full stylesheet if available.
st.markdown(
    """
    <style>
    h1, h2 { color: #7b2cbf; text-align: center; }
    .story-box {
        background: #fff3cd;
        border-radius: 12px;
        padding: 1em;
        font-size: 1.1em;
    }
    </style>
    """,
    unsafe_allow_html=True,
)

st.markdown("<h1>Picture to Story Magic! ✨</h1>", unsafe_allow_html=True)

# Step 1 — image upload.
with st.container():
    st.markdown("<h2>1️⃣ Pick a Fun Picture! 🖼️</h2>", unsafe_allow_html=True)
    uploaded = st.file_uploader(
        "Choose a picture to start the magic! 😊",
        type=["jpg", "jpeg", "png"],
    )
    if not uploaded:
        # Halt the script with a friendly nudge until an image arrives.
        st.info("Upload a picture, and let's make a story! 🎉")
        st.stop()
    with st.spinner("Looking at your picture..."):
        pil_img = Image.open(uploaded)
        st.image(pil_img, use_container_width=True)

# Step 2 — caption the image.
with st.container():
    st.markdown("<h2>2️⃣ What's in the Picture? 🧐</h2>", unsafe_allow_html=True)
    captioner = get_image_captioner()
    caption = _run_with_progress(
        lambda: part1_image_to_text(pil_img, captioner),
        "Figuring out what's in your picture...",
        0.05,  # ~5 s animation
    )
    st.markdown(
        f'<div class="story-box"><b>Picture Description:</b><br>{caption}</div>',
        unsafe_allow_html=True,
    )

# Step 3 — story text, then spoken audio.
with st.container():
    st.markdown("<h2>3️⃣ Your Story and Audio! 🎵</h2>", unsafe_allow_html=True)

    story_pipe = get_story_pipe()
    story = _run_with_progress(
        lambda: part2_text_to_story(caption, story_pipe),
        "Writing a super cool story...",
        0.07,  # ~7 s animation
    )
    st.markdown(
        f'<div class="story-box"><b>Your Cool Story! 📚</b><br>{story}</div>',
        unsafe_allow_html=True,
    )

    tts_pipe = get_tts_pipe()
    audio_bytes = _run_with_progress(
        lambda: part3_text_to_speech_bytes(story, tts_pipe),
        "Turning your story into sound...",
        0.10,  # ~10 s animation
    )
    st.audio(audio_bytes, format="audio/wav")