import io import wave import streamlit as st from transformers import pipeline from PIL import Image import numpy as np # ——— 1) MODEL LOADING (cached) ———————————————— @st.cache_resource def get_image_captioner(model_name="Salesforce/blip-image-captioning-base"): return pipeline("image-to-text", model=model_name, device="cpu") @st.cache_resource def get_story_pipe(model_name="google/flan-t5-base"): return pipeline("text2text-generation", model=model_name, device="cpu") @st.cache_resource def get_tts_pipe(model_name="facebook/mms-tts-eng"): return pipeline("text-to-speech", model=model_name, device="cpu") # ——— 2) TRANSFORM FUNCTIONS ———————————————— def part1_image_to_text(pil_img, captioner): results = captioner(pil_img) return results[0].get("generated_text", "") if results else "" def part2_text_to_story( caption: str, story_pipe, target_words: int = 100, max_length: int = 100, min_length: int = 80, do_sample: bool = True, top_k: int = 100, top_p: float= 0.9, temperature: float= 0.7, repetition_penalty: float = 1.1, no_repeat_ngram_size: int = 4 ) -> str: prompt = ( f"Write a vivid, imaginative short story of about {target_words} words " f"describing this scene: {caption}" ) out = story_pipe( prompt, max_length=max_length, min_length=min_length, do_sample=do_sample, top_k=top_k, top_p=top_p, temperature=temperature, repetition_penalty=repetition_penalty, no_repeat_ngram_size=no_repeat_ngram_size, early_stopping=False ) raw = out[0].get("generated_text", "").strip() if not raw: return "" # strip echo of prompt if raw.lower().startswith(prompt.lower()): story = raw[len(prompt):].strip() else: story = raw # cut at last full stop idx = story.rfind(".") if idx != -1: story = story[:idx+1] return story def part3_text_to_speech_bytes(text: str, tts_pipe) -> bytes: out = tts_pipe(text) if isinstance(out, list): out = out[0] audio_array = out["audio"] # np.ndarray (channels, samples) rate = out["sampling_rate"] # int data = audio_array.T if audio_array.ndim == 2 else audio_array pcm = (data * 32767).astype(np.int16) buffer = io.BytesIO() wf = wave.open(buffer, "wb") channels = 1 if data.ndim == 1 else data.shape[1] wf.setnchannels(channels) wf.setsampwidth(2) wf.setframerate(rate) wf.writeframes(pcm.tobytes()) wf.close() buffer.seek(0) return buffer.read() # ——— 3) STREAMLIT UI ———————————————————————————— # Set page config as the first Streamlit command st.set_page_config( page_title="Picture to Story Magic", page_icon="✨", layout="centered" ) # Custom CSS for kid-friendly styling with improved readability st.markdown(""" """, unsafe_allow_html=True) # Main title st.markdown("