import tempfile

import numpy as np
import soundfile as sf
import streamlit as st
from PIL import Image
from transformers import pipeline, set_seed


# --- Model initialization (cached via st.cache_resource) ---
@st.cache_resource
def load_models():
    """Build the three Hugging Face pipelines used by the app.

    Returns:
        A ``(caption_pipeline, story_pipeline, tts_pipeline)`` tuple, in that
        order: image-to-text, text-generation and text-to-speech.
    """
    # Imported locally on purpose: the script only imports torch under its
    # ``__main__`` guard, so a module-level lookup would raise NameError
    # whenever this module is imported rather than executed directly.
    import torch

    # All three pipelines share the same device choice; compute it once.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    caption_pipeline = pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device=device,
    )
    story_pipeline = pipeline(
        "text-generation",
        model="pranavpsv/gpt2-genre-story-generator",
        device=device,
    )
    tts_pipeline = pipeline(
        "text-to-speech",
        model="speechbrain/tts-tacotron2-ljspeech",
        device=device,
    )
    return caption_pipeline, story_pipeline, tts_pipeline


# --- Stage 1: Image → Caption ---
def generate_caption(image, pipeline):
    """Return the caption text produced for ``image``.

    Args:
        image: The image handed straight to ``pipeline``.
        pipeline: Callable returning a list of dicts; the first dict's
            ``"generated_text"`` entry is used as the caption.

    Returns:
        The generated caption string.
    """
    return pipeline(image)[0]["generated_text"]


# --- Stage 2: Caption (keyword) → Story (strictly length-limited) ---
def generate_story(caption, pipeline, max_sentences=5):
    """Generate a short children's story about ``caption``.

    Args:
        caption: Text describing what the story should be about.
        pipeline: Text-generation callable returning a list of dicts with a
            ``"generated_text"`` entry.
        max_sentences: Maximum number of sentences kept from the generated
            text (defaults to 5).

    Returns:
        The story with the prompt removed, truncated to ``max_sentences``
        sentences and ending in exactly one period. An empty string is
        returned when the model produced no text beyond the prompt.
    """
    prompt = f"Generate a children's story in 50-100 words about: {caption}"
    # NOTE(review): in transformers, max_length/min_length are documented as
    # counting the prompt tokens too — consider max_new_tokens/min_new_tokens
    # if the budget is meant for the story alone.
    generated = pipeline(
        prompt,
        max_length=150,  # token budget (roughly 100 words)
        min_length=80,  # roughly 50 words
        do_sample=True,
        temperature=0.7,
        top_k=50,
        num_return_sequences=1,
    )[0]["generated_text"]

    # Drop the echoed prompt, then keep only non-empty sentences. Filtering
    # the empty fragments matters: text that already ends with "." splits
    # into a trailing "", which previously produced a doubled "..".
    body = generated.replace(prompt, "").strip()
    fragments = (piece.strip() for piece in body.split("."))
    sentences = [piece for piece in fragments if piece][:max_sentences]
    if not sentences:
        return ""
    # Re-join and make sure the story ends with a period.
    return ". ".join(sentences) + "."
# ^ The join above terminates the story with a period.


# --- Stage 3: Story → Audio (Spaces-compatible) ---
def generate_audio(story_text, pipeline):
    """Synthesize ``story_text`` to a temporary WAV file.

    Args:
        story_text: The text to speak.
        pipeline: Text-to-speech callable returning a dict with ``"audio"``
            and ``"sampling_rate"`` entries.

    Returns:
        Path of the WAV file that was written. The file is created with
        ``delete=False``, so the caller is responsible for removing it.
    """
    speech = pipeline(story_text)
    audio = speech["audio"]
    # The pipeline may hand back either a torch tensor or a NumPy array;
    # only tensors have .numpy(), so convert conditionally instead of
    # calling it unconditionally.
    if hasattr(audio, "detach"):
        audio = audio.detach().cpu().numpy()
    audio_array = np.asarray(audio).squeeze()
    sample_rate = speech["sampling_rate"]

    # Reserve a unique path, then close the handle before writing: reopening
    # a still-open NamedTemporaryFile by name fails on Windows.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        wav_path = tmp.name
    sf.write(wav_path, audio_array, sample_rate)
    return wav_path


# --- Streamlit UI ---
def main():
    """Render the page: upload an image, then show caption, story and audio."""
    import os  # local import: only needed for temp-file cleanup below

    st.title("📖 AI Storyteller for Kids")
    caption_pipeline, story_pipeline, tts_pipeline = load_models()

    uploaded_image = st.file_uploader(
        "Upload a child-friendly image", type=["jpg", "jpeg", "png"]
    )
    if not uploaded_image:
        return

    image = Image.open(uploaded_image)
    # NOTE(review): newer Streamlit releases may deprecate use_column_width —
    # confirm against the pinned Streamlit version before changing it.
    st.image(image, use_column_width=True)

    with st.spinner("🔍 Analyzing the image..."):
        caption = generate_caption(image, caption_pipeline)
    st.success(f"📝 Caption: {caption}")

    with st.spinner("✨ Creating a magical story..."):
        story = generate_story(caption, story_pipeline)
    st.subheader("📚 Your Story")
    st.write(story)
    st.info(f"Word count: {len(story.split())}")  # show the word count

    # Nothing speakable (empty text or punctuation only): skip the TTS stage
    # rather than feeding it meaningless input.
    if not any(ch.isalnum() for ch in story):
        st.warning("The story came out empty, so there is no audio to play.")
        return

    with st.spinner("🔊 Generating audio..."):
        audio_path = generate_audio(story, tts_pipeline)
    # Hand the bytes to Streamlit and delete the temp file; otherwise every
    # rerun leaves another WAV behind in the temp directory.
    try:
        with open(audio_path, "rb") as wav_file:
            st.audio(wav_file.read(), format="audio/wav")
    finally:
        os.remove(audio_path)


if __name__ == "__main__":
    import torch  # deferred import to avoid preload issues on Spaces

    main()