File size: 2,637 Bytes
c4d25fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# app.py

"""
app.py

Streamlit application for the Image-to-Story demo.
Lets the user upload an image or use the demo image, generates a caption
and a trimmed story from it, and plays the story back as MP3 audio via gTTS.
"""
import streamlit as st
from PIL import Image
import warnings
from modules import (
    load_captioner, load_story_gen,
    generate_caption, generate_story_simple,
    generate_audio
)

# Suppress DeprecationWarning messages process-wide so they do not clutter
# the app's output.
warnings.filterwarnings("ignore", category=DeprecationWarning)

def reset_state():
    """Drop the cached caption, story and audio entries from session state.

    Keys that are not currently stored are skipped, so calling this on a
    fresh session is a no-op.
    """
    cached_keys = ("caption", "story", "audio_bytes", "audio_mime")
    # Work out which entries exist first, then delete exactly those.
    present = [name for name in cached_keys if name in st.session_state]
    for name in present:
        del st.session_state[name]

def main():
    """Render the Image-to-Story page: image -> caption -> story -> audio.

    The function is written for Streamlit's rerun model: it executes top to
    bottom on every interaction and keeps expensive results (loaded models,
    caption, story, audio bytes) in ``st.session_state`` so they are computed
    only once per selected image.

    Cached results are cleared through ``reset_state`` whenever the image
    source radio changes *or* the uploaded file changes, so a newly uploaded
    image never shows the caption/story/audio of the previous one.

    Returns early (rendering nothing further) when no file has been uploaded
    yet or when the chosen image cannot be read.
    """
    st.title("🖼️ → 📖 Image-to-Story App for Kids")
    st.write("Upload or demo an image to get a 50–100 word story and audio!")

    source = st.radio("Image source:",
                      ("Upload my own image", "Use demo image"),
                      on_change=reset_state)

    # Load pipelines once per session; later reruns reuse the stored objects.
    if "models_loaded" not in st.session_state:
        with st.spinner("Loading models…"):
            st.session_state.captioner = load_captioner()
            st.session_state.story_gen = load_story_gen()
        st.session_state.models_loaded = True

    captioner = st.session_state.captioner
    story_gen = st.session_state.story_gen

    # Acquire image
    if source == "Use demo image":
        image_source = "test_kids_playing.jpg"
    else:
        # on_change clears the cached caption/story/audio when the user
        # uploads, replaces or removes a file; without it the results of the
        # previous image would be shown for the new one.
        uploaded = st.file_uploader("Upload an image",
                                    type=["png", "jpg", "jpeg"],
                                    on_change=reset_state)
        if not uploaded:
            return
        image_source = uploaded

    try:
        img = Image.open(image_source).convert("RGB")
    except OSError as exc:
        # Covers a missing demo file as well as unreadable/corrupt uploads
        # (PIL's UnidentifiedImageError is an OSError subclass).
        st.error(f"Could not read the image: {exc}")
        return

    st.image(img, use_container_width=True)

    # Caption
    if "caption" not in st.session_state:
        with st.spinner("Captioning image…"):
            st.session_state.caption = generate_caption(captioner, img)
    st.markdown(f"**Caption:** {st.session_state.caption}")

    # Story
    if "story" not in st.session_state:
        with st.spinner("Creating story…"):
            # 50 and 100 are presumably the min/max word counts advertised in
            # the page intro — confirm against generate_story_simple.
            st.session_state.story = generate_story_simple(
                story_gen, st.session_state.caption, 50, 100
            )
    st.markdown(f"**Story:** {st.session_state.story}")

    # Audio: generated eagerly and cached, played only on button press.
    if "audio_bytes" not in st.session_state:
        with st.spinner("Generating audio…"):
            audio_bytes, mime = generate_audio(st.session_state.story)
            st.session_state.audio_bytes = audio_bytes
            st.session_state.audio_mime = mime

    if st.button("🔊 Play Story Audio"):
        st.audio(data=st.session_state.audio_bytes,
                 format=st.session_state.audio_mime)

# Call main() only when this file is executed as the main module,
# not when it is imported.
if __name__ == "__main__":
    main()