File size: 3,059 Bytes
33de3ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import streamlit as st
from PIL import Image
import tempfile
import numpy as np
from transformers import pipeline, set_seed
import soundfile as sf

# --- Model initialization (cached) ---
@st.cache_resource
def load_models():
    """Build the three Hugging Face pipelines used by the app.

    The function is wrapped in ``st.cache_resource`` so the pipelines are
    created once and reused instead of being rebuilt on every call.

    Returns:
        A ``(caption_pipeline, story_pipeline, tts_pipeline)`` tuple:
        an ``image-to-text`` pipeline, a ``text-generation`` pipeline and a
        ``text-to-speech`` pipeline, in that order.
    """
    # Import locally instead of relying on the ``import torch`` that only runs
    # inside the ``__main__`` guard; otherwise this function raises NameError
    # whenever the module is imported rather than executed as a script.
    import torch

    # Decide the device once so all three pipelines are guaranteed to agree.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    caption_pipeline = pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device=device,
    )
    story_pipeline = pipeline(
        "text-generation",
        model="pranavpsv/gpt2-genre-story-generator",
        device=device,
    )
    tts_pipeline = pipeline(
        "text-to-speech",
        model="speechbrain/tts-tacotron2-ljspeech",
        device=device,
    )
    return caption_pipeline, story_pipeline, tts_pipeline

# --- Stage 1: Image → Caption ---
def generate_caption(image, pipeline):
    """Return the caption text for *image*.

    Args:
        image: The image object handed to *pipeline* unchanged.
        pipeline: A callable that takes the image and returns a sequence of
            dicts, each carrying a ``'generated_text'`` entry.

    Returns:
        The ``'generated_text'`` value of the first result.
    """
    results = pipeline(image)
    first_result = results[0]
    return first_result['generated_text']

# --- Stage 2: Caption (keyword) → Story (strict word limit) ---
def generate_story(caption, pipeline):
    """Generate a short children's story about *caption*.

    A fixed prompt is built around the caption and sent to *pipeline*. The
    prompt text is then removed from the generated output and the result is
    trimmed to at most five sentences.

    Args:
        caption: Text describing what the story should be about.
        pipeline: A callable accepting the prompt plus generation keyword
            arguments and returning a sequence of dicts with a
            ``'generated_text'`` entry.

    Returns:
        Up to five non-empty, period-delimited sentences joined back together
        and terminated by exactly one period. If nothing usable was generated
        the result is ``"."``.
    """
    prompt = f"Generate a children's story in 50-100 words about: {caption}"
    generated = pipeline(
        prompt,
        # NOTE(review): per the transformers generation docs, max_length and
        # min_length presumably count the prompt tokens as well - confirm
        # these bounds leave the intended room for the story itself.
        max_length=150,  # token budget (roughly 100 words)
        min_length=80,   # roughly 50 words
        do_sample=True,
        temperature=0.7,
        top_k=50,
        num_return_sequences=1,
    )[0]['generated_text']

    # Drop the echoed prompt so only the story text remains.
    text = generated.replace(prompt, "").strip()

    # Skip empty fragments: text that already ends with "." would otherwise
    # yield a trailing "" and a doubled period after the join, and empty
    # fragments would also consume part of the five-sentence budget.
    sentences = [part for part in text.split(".") if part.strip()][:5]

    # Always finish with a single period.
    return ".".join(sentences).strip() + "."

# --- Stage 3: Story → Audio (Spaces-compatible) ---
def generate_audio(story_text, pipeline):
    """Synthesize *story_text* to a WAV file and return the file's path.

    Args:
        story_text: The text to be spoken.
        pipeline: A callable returning a mapping with an ``"audio"`` entry
            (the waveform) and a ``"sampling_rate"`` entry.

    Returns:
        Path of a newly created ``.wav`` temporary file. The file is created
        with ``delete=False``, so the caller is responsible for removing it.
    """
    speech = pipeline(story_text)

    audio = speech["audio"]
    # The transformers text-to-audio pipeline documents "audio" as a NumPy
    # array, which has no ``.numpy()`` method. Convert only when the value is
    # actually a tensor-like object.
    if hasattr(audio, "detach"):
        audio = audio.detach().cpu().numpy()
    audio_array = np.asarray(audio).squeeze()
    sample_rate = speech["sampling_rate"]

    # Reserve a unique path and close the handle before soundfile reopens the
    # file by name; writing while the handle is still open fails on Windows.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
        wav_path = f.name
    sf.write(wav_path, audio_array, sample_rate)
    return wav_path

# --- Streamlit UI ---
def main():
    """Render the Streamlit page and run the image → caption → story → audio flow.

    After an image is uploaded it is displayed, captioned, turned into a story
    (shown together with its word count) and finally read aloud through an
    audio player.
    """
    import os  # local import: only needed to clean up the temporary audio file

    st.title("📖 AI Storyteller for Kids")
    caption_pipeline, story_pipeline, tts_pipeline = load_models()

    uploaded_image = st.file_uploader("Upload a child-friendly image", type=["jpg", "jpeg", "png"])
    if uploaded_image:
        image = Image.open(uploaded_image)
        st.image(image, use_column_width=True)

        with st.spinner("🔍 Analyzing the image..."):
            caption = generate_caption(image, caption_pipeline)
            st.success(f"📝 Caption: {caption}")

        with st.spinner("✨ Creating a magical story..."):
            story = generate_story(caption, story_pipeline)
            st.subheader("📚 Your Story")
            st.write(story)
            st.info(f"Word count: {len(story.split())}")  # show the word count

        with st.spinner("🔊 Generating audio..."):
            audio_path = generate_audio(story, tts_pipeline)
            # generate_audio() leaves a delete=False temp file behind. Load its
            # bytes and remove it right away so a file does not accumulate on
            # disk for every run.
            try:
                with open(audio_path, "rb") as audio_file:
                    audio_bytes = audio_file.read()
            finally:
                os.remove(audio_path)
            st.audio(audio_bytes, format="audio/wav")

# Script entry point: runs only when this file is executed directly.
if __name__ == "__main__":
    # This import binds ``torch`` as a module-level name before main() runs.
    import torch  # deferred import to avoid preloading issues on Spaces
    main()