File size: 3,280 Bytes
d55766c
5e5ea3c
d55766c
 
 
a7bd32c
 
d55766c
5e5ea3c
 
 
 
d55766c
 
 
 
5e5ea3c
d55766c
 
 
 
5e5ea3c
d55766c
 
5e5ea3c
d55766c
 
5e5ea3c
 
 
d55766c
 
 
 
a7bd32c
d55766c
 
 
a7bd32c
 
 
 
 
 
 
 
 
 
d55766c
a7bd32c
d55766c
 
5e5ea3c
d55766c
 
5e5ea3c
a7bd32c
5e5ea3c
d55766c
 
a7bd32c
d55766c
a7bd32c
d55766c
 
a7bd32c
d55766c
 
5e5ea3c
d55766c
 
a7bd32c
 
 
d55766c
5e5ea3c
 
a7bd32c
 
5e5ea3c
d55766c
 
5e5ea3c
a7bd32c
5e5ea3c
a7bd32c
d55766c
 
a7bd32c
5e5ea3c
d55766c
 
5e5ea3c
d55766c
a7bd32c
d55766c
a7bd32c
 
5e5ea3c
a7bd32c
5e5ea3c
 
d55766c
 
a7bd32c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import streamlit as st
from transformers import pipeline
from PIL import Image
import tempfile
import torch
from TTS.api import TTS  # Coqui TTS
import os

# ======================
# Stage 1: Image Captioning
# ======================
@st.cache_resource
def load_image_captioner():
    """Create the image-captioning pipeline used in stage 1.

    Loads ``Salesforce/blip-image-captioning-base`` for the
    ``image-to-text`` task, placing it on CUDA when ``torch`` reports a
    usable GPU and on the CPU otherwise. The ``st.cache_resource``
    decorator lets Streamlit reuse the returned object instead of
    rebuilding it on every call.
    """
    target_device = "cpu"
    if torch.cuda.is_available():
        target_device = "cuda"

    captioner = pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device=target_device,
    )
    return captioner

def generate_caption(_pipeline, image):
    """Return the caption produced for ``image``, or ``None`` on failure.

    ``_pipeline`` is called with the image and ``max_new_tokens=50``; the
    ``generated_text`` entry of its first result is returned. Any
    exception raised along the way is shown through ``st.error`` and
    turned into a ``None`` return value.
    """
    try:
        outputs = _pipeline(image, max_new_tokens=50)
        first_result = outputs[0]
        return first_result['generated_text']
    except Exception as err:
        st.error(f"Caption generation failed: {err!s}")
        return None

# ======================
# Stage 2: Story Generation
# ======================
@st.cache_resource
def load_story_generator():
    """Create the text-generation pipeline used in stage 2.

    Loads ``pranavpsv/gpt2-genre-story-generator`` on CUDA when ``torch``
    reports a usable GPU and on the CPU otherwise. The
    ``st.cache_resource`` decorator lets Streamlit reuse the returned
    object instead of rebuilding it on every call.
    """
    target_device = "cpu"
    if torch.cuda.is_available():
        target_device = "cuda"

    # This checkpoint can be swapped for a stronger model.
    story_model = "pranavpsv/gpt2-genre-story-generator"

    generator = pipeline(
        "text-generation",
        model=story_model,
        device=target_device,
    )
    return generator

def generate_story(_pipeline, caption):
    """Generate a short children's story from an image caption.

    Args:
        _pipeline: Callable text-generation pipeline. It is invoked with the
            prompt plus generation keyword arguments and must return a list
            whose first item is a dict with a ``generated_text`` key.
        caption: Image description that is embedded in the prompt.

    Returns:
        The generated story with the echoed prompt removed and surrounding
        whitespace stripped, or ``None`` if generation raised an exception
        (the error is reported through ``st.error``).
    """
    prompt = (
        "You are a children's storyteller. "
        f"Based on the following image description: \"{caption}\", "
        "write a short children's story (80 words max). \n"
        "The story should:\n"
        "- Use simple and friendly language\n"
        "- Be related to the content of the image\n"
        "- Include a magical or fun twist\n"
        "- End happily\n"
        "\n"
        "Story:"
    )

    try:
        outputs = _pipeline(
            prompt,
            # Budget only the continuation: ``max_length`` would also count
            # the (long) prompt tokens, shrinking the room left for the
            # story as the caption grows.
            max_new_tokens=120,
            # ``temperature`` is only honoured in sampling mode, so enable
            # sampling explicitly instead of relying on the model config.
            do_sample=True,
            temperature=0.7,
        )
        story = outputs[0]['generated_text']
    except Exception as e:
        st.error(f"Story generation failed: {str(e)}")
        return None

    # The pipeline echoes the prompt in front of the continuation. Strip it
    # only as a prefix so story text is never removed by accident.
    if story.startswith(prompt):
        story = story[len(prompt):]
    return story.strip()

# ======================
# Stage 3: Text-to-Speech using Coqui TTS
# ======================
@st.cache_resource
def load_tts():
    """Create the Coqui TTS model used in stage 3.

    Loads ``tts_models/en/ljspeech/tacotron2-DDC`` with the progress bar
    disabled, requesting the GPU only when ``torch`` reports CUDA as
    available. The ``st.cache_resource`` decorator lets Streamlit reuse
    the returned object instead of rebuilding it on every call.
    """
    use_gpu = torch.cuda.is_available()
    return TTS(
        model_name="tts_models/en/ljspeech/tacotron2-DDC",
        progress_bar=False,
        gpu=use_gpu,
    )

def text_to_speech(tts_model, story_text):
    """Synthesize ``story_text`` into a temporary WAV file.

    Args:
        tts_model: Object exposing ``tts_to_file(text=..., file_path=...)``.
        story_text: Text to be spoken.

    Returns:
        Path of the generated ``.wav`` file, or ``None`` if synthesis failed
        (the error is reported through ``st.error``). The caller owns the
        returned file and is responsible for deleting it.
    """
    wav_path = None
    try:
        # Only reserve a unique file name here. The handle is closed before
        # the TTS engine writes to the path, because a file that is still
        # open cannot be reopened by name on Windows.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
            wav_path = f.name

        tts_model.tts_to_file(text=story_text, file_path=wav_path)
        return wav_path
    except Exception as e:
        # ``delete=False`` means nothing cleans up after a failed synthesis,
        # so remove the orphaned file ourselves (best effort).
        if wav_path is not None:
            try:
                os.remove(wav_path)
            except OSError:
                pass
        st.error(f"Audio generation failed: {str(e)}")
        return None

# ======================
# Main Streamlit App
# ======================
def main():
    """Run the Streamlit page: upload an image, caption it, tell and read a story.

    The flow stops early (without raising) when no image has been uploaded,
    when the upload cannot be decoded, or when any stage returns nothing;
    the stage helpers report their own errors through ``st.error``.
    """
    st.set_page_config(page_title="Magic Story Generator", layout="wide")
    st.title("🧚 Magic Story Generator")

    uploaded_image = st.file_uploader("Upload a photo", type=["jpg", "jpeg", "png"])
    if not uploaded_image:
        return

    try:
        image = Image.open(uploaded_image)
        # ``Image.open`` is lazy; force decoding now so a corrupt upload is
        # reported here instead of failing later inside a pipeline stage.
        image.load()
    except (OSError, ValueError) as e:
        st.error(f"Could not read the uploaded image: {e}")
        return

    st.image(image, use_container_width=True)

    with st.spinner("Processing your magical story..."):
        caption_pipe = load_image_captioner()
        story_pipe = load_story_generator()
        tts_model = load_tts()

        caption = generate_caption(caption_pipe, image)
        if not caption:
            return
        st.success(f"Image description: {caption}")

        story = generate_story(story_pipe, caption)
        if not story:
            return
        st.subheader("Your Magical Story")
        st.markdown(story)

        audio_path = text_to_speech(tts_model, story)
        if not audio_path:
            return

        # Load the audio into memory and delete the temporary file right
        # away; otherwise each run leaves another WAV file on disk.
        try:
            with open(audio_path, "rb") as wav_file:
                audio_bytes = wav_file.read()
        finally:
            try:
                os.remove(audio_path)
            except OSError:
                pass
        st.audio(audio_bytes, format="audio/wav")

# Start the app only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()