Spaces:
Sleeping
Sleeping
import os
import tempfile

import numpy as np
import soundfile as sf
import streamlit as st
from PIL import Image
from transformers import pipeline, set_seed
# --- Model initialization (caching intended) ---
@st.cache_resource
def load_models():
    """Build the three Hugging Face pipelines used by the app.

    The result is cached with ``st.cache_resource`` so the models are
    constructed once per server process instead of on every Streamlit rerun
    (each widget interaction re-executes the whole script).

    Returns:
        A ``(caption_pipeline, story_pipeline, tts_pipeline)`` tuple:
        image-to-text, text-generation and text-to-speech pipelines.
    """
    # Function-scope import: keeps torch out of module import time and makes
    # this function independent of where else in the file torch is imported.
    import torch

    # Decide the device once; all three pipelines share it.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    caption_pipeline = pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device=device,
    )
    story_pipeline = pipeline(
        "text-generation",
        model="pranavpsv/gpt2-genre-story-generator",
        device=device,
    )
    # NOTE(review): this is a SpeechBrain checkpoint; confirm that the
    # transformers "text-to-speech" pipeline can actually load it.
    tts_pipeline = pipeline(
        "text-to-speech",
        model="speechbrain/tts-tacotron2-ljspeech",
        device=device,
    )
    return caption_pipeline, story_pipeline, tts_pipeline
| # --- Stage 1: Image → Caption --- | |
def generate_caption(image, pipeline):
    """Return the caption text the captioning pipeline produces for ``image``.

    Args:
        image: The image object handed straight to ``pipeline``.
        pipeline: Callable returning a sequence of dicts, each with a
            ``'generated_text'`` entry.

    Returns:
        The ``'generated_text'`` value of the first result.
    """
    results = pipeline(image)
    first_result = results[0]
    return first_result["generated_text"]
# --- Stage 2: Caption (keyword) → Story (strict word limit) ---
def generate_story(caption, pipeline):
    """Generate a short children's story about ``caption``.

    Args:
        caption: Text describing the image; embedded into the prompt.
        pipeline: Text-generation callable. It is called with the prompt plus
            sampling keyword arguments and must return a sequence of dicts
            with a ``'generated_text'`` entry.

    Returns:
        At most the first five period-delimited fragments of the generated
        text, ending in exactly one period. Returns an empty string when the
        model produced nothing besides the prompt.
    """
    prompt = f"Generate a children's story in 50-100 words about: {caption}"
    generated = pipeline(
        prompt,
        # Token budgets intended to approximate 50-100 words.
        # NOTE(review): in transformers these limits normally count the
        # prompt tokens as well - confirm the resulting story length.
        max_length=150,
        min_length=80,
        do_sample=True,
        temperature=0.7,
        top_k=50,
        num_return_sequences=1,
    )[0]["generated_text"]

    # The model echoes the prompt; drop it and keep the first five fragments.
    text = generated.replace(prompt, "").strip()
    fragments = text.split(".")[:5]

    # A text that already ends with "." leaves a trailing empty fragment;
    # drop it so re-appending the final period does not produce "..".
    while fragments and not fragments[-1].strip():
        fragments.pop()
    if not fragments:
        return ""
    return ".".join(fragments) + "."
# --- Stage 3: Story → Audio (Spaces-compatible) ---
def generate_audio(story_text, pipeline):
    """Synthesize ``story_text`` to speech and save it as a WAV file.

    Args:
        story_text: The text to speak.
        pipeline: Text-to-speech callable returning a mapping with an
            ``'audio'`` waveform and a ``'sampling_rate'``.

    Returns:
        Path of a newly created temporary ``.wav`` file. The file is created
        with ``delete=False``, so the caller is responsible for removing it.
    """
    speech = pipeline(story_text)
    audio = speech["audio"]

    # The waveform may be a NumPy array (which has no ``.numpy()``) or a
    # torch tensor; normalize both to an ndarray.
    if hasattr(audio, "detach"):
        audio = audio.detach().cpu().numpy()
    audio_array = np.squeeze(np.asarray(audio))
    sample_rate = speech["sampling_rate"]

    # Only reserve the file name here; writing happens after the handle is
    # closed so the file is not opened twice at once.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
        audio_path = f.name
    sf.write(audio_path, audio_array, sample_rate)
    return audio_path
| # --- Streamlit UI --- | |
def main():
    """Render the Streamlit page: upload an image, then caption it, write a story and read it aloud."""
    st.title("📖 AI Storyteller for Kids")
    caption_pipeline, story_pipeline, tts_pipeline = load_models()

    uploaded_image = st.file_uploader("Upload a child-friendly image", type=["jpg", "jpeg", "png"])
    if not uploaded_image:
        # Nothing to process until the user provides an image.
        return

    image = Image.open(uploaded_image)
    st.image(image, use_column_width=True)

    with st.spinner("🔍 Analyzing the image..."):
        caption = generate_caption(image, caption_pipeline)
    st.success(f"📝 Caption: {caption}")

    with st.spinner("✨ Creating a magical story..."):
        story = generate_story(caption, story_pipeline)
    st.subheader("📚 Your Story")
    st.write(story)
    st.info(f"Word count: {len(story.split())}")  # show the word count

    if not story.strip():
        # There is nothing to synthesize for an empty story.
        return

    with st.spinner("🔊 Generating audio..."):
        audio_path = generate_audio(story, tts_pipeline)
    # generate_audio leaves a delete=False temp file behind; load its bytes
    # and remove it so reruns do not accumulate WAV files on disk.
    try:
        with open(audio_path, "rb") as audio_file:
            audio_bytes = audio_file.read()
    finally:
        os.remove(audio_path)
    st.audio(audio_bytes, format="audio/wav")
if __name__ == "__main__":
    # Deferred import to avoid preloading issues on Spaces.
    import torch

    main()