File size: 4,814 Bytes
2cbbfe6
e36e817
2cbbfe6
53719e5
e36e817
53719e5
2cbbfe6
3718f37
 
af192ef
2cbbfe6
e36e817
 
 
2cbbfe6
53719e5
a0254c7
af192ef
3718f37
53719e5
ca604ad
53719e5
ca604ad
 
 
2cbbfe6
53719e5
89d0d20
 
3718f37
53719e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e36e817
 
 
ca604ad
e36e817
2cbbfe6
e36e817
2cbbfe6
 
 
3718f37
2cbbfe6
 
639df53
53719e5
 
 
 
 
 
 
 
 
 
 
 
 
6c53b18
e36e817
53719e5
 
 
 
6c53b18
e36e817
 
6c53b18
 
e36e817
53719e5
6c53b18
2cbbfe6
3718f37
53719e5
ca604ad
e36e817
2cbbfe6
 
3718f37
53719e5
ca604ad
e36e817
 
 
 
2cbbfe6
3718f37
53719e5
e36e817
 
 
 
53719e5
3718f37
2cbbfe6
e36e817
 
af192ef
53719e5
 
 
 
 
 
 
 
 
 
2cbbfe6
4862507
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# app.py

"""
Streamlit application for Image-to-Story demo with history sidebar.
Allows demo/upload image, generates a caption, a trimmed story,
and plays back as MP3 via gTTS. Keeps history of all runs.
"""
import streamlit as st
from PIL import Image
import warnings
from modules import (
    load_captioner, load_story_gen,
    generate_caption, generate_story_simple,
    generate_audio
)
import io

# Ignore every DeprecationWarning; this filter is installed once, when the
# module is first executed.
warnings.filterwarnings("ignore", category=DeprecationWarning)

# Reset state when switching image source
def reset_state():
    """Forget the per-image results and the active history selection.

    Removes ``caption``, ``story``, ``audio_bytes``, ``audio_mime`` and
    ``selected_index`` from ``st.session_state`` when they are present;
    keys that are absent are skipped, and every other key is left alone.
    """
    per_image_keys = (
        "caption",
        "story",
        "audio_bytes",
        "audio_mime",
        "selected_index",
    )
    state = st.session_state
    for name in per_image_keys:
        if name not in state:
            continue
        del state[name]

# Session-state keys holding results derived from the image currently shown.
_GENERATED_KEYS = ("caption", "story", "audio_bytes", "audio_mime")


def _clear_history():
    """Button callback: drop every history entry and any active selection."""
    st.session_state.history = []
    st.session_state.selected_index = None


def _render_history_sidebar():
    """Draw the history sidebar: one thumbnail + "View" button per entry."""
    st.sidebar.header("History")
    if "history" not in st.session_state:
        st.session_state.history = []           # list of dicts
    if "selected_index" not in st.session_state:
        st.session_state.selected_index = None

    for idx, entry in enumerate(st.session_state.history):
        # Inside the ``with`` block plain ``st.*`` calls land in the container,
        # keeping each thumbnail grouped with its button.
        with st.sidebar.container():
            st.image(entry["image_bytes"], width=100)
            if st.button(f"View #{idx+1}", key=f"view_{idx}"):
                st.session_state.selected_index = idx

    # A callback runs before the script is re-executed, so the thumbnails
    # above are already gone on the rerun triggered by this click.
    st.sidebar.button("Clear History", on_click=_clear_history)


def _show_history_entry(entry):
    """Display the image, caption, story and audio button of a history entry."""
    img = Image.open(io.BytesIO(entry["image_bytes"])).convert("RGB")
    st.image(img, use_container_width=True)
    st.markdown(f"**Caption:** {entry['caption']}")
    st.markdown(f"**Story:** {entry['story']}")
    if st.button("🔊 Play Story Audio"):
        st.audio(data=entry["audio_bytes"], format=entry["audio_mime"])


def _load_image(source):
    """Return ``(image, image_bytes)`` for the chosen source.

    Returns ``None`` when the upload source is selected but no file has been
    provided yet.
    """
    if source == "Use demo image":
        # Close the file handle as soon as the pixel data has been copied.
        with Image.open("test_kids_playing.jpg") as demo_file:
            img = demo_file.convert("RGB")
        # grab raw bytes for history
        buf = io.BytesIO()
        img.save(buf, format="JPEG")
        return img, buf.getvalue()

    uploaded = st.file_uploader("Upload an image",
                                type=["png", "jpg", "jpeg"])
    if not uploaded:
        return None
    img = Image.open(uploaded).convert("RGB")
    return img, uploaded.getvalue()


def main():
    """Run the Streamlit page.

    Renders the history sidebar, lets the user pick an image source, loads the
    caption/story pipelines once per session, then either shows a selected
    history entry or generates caption, story and audio for the current image
    and records the result in the history.
    """
    st.set_page_config(layout="wide")
    st.title("🎨 Magic Picture Story Time!")
    st.write("Pick or upload a picture, and watch it turn into a fun story with voice! Ready for a magical tale?")

    # --- Sidebar: History ---
    _render_history_sidebar()

    # --- Main panel: image selection ---
    source = st.radio("Image source:",
                      ("Upload my own image", "Use demo image"),
                      on_change=reset_state)

    # Load pipelines once
    if "models_loaded" not in st.session_state:
        with st.spinner("Loading models…"):
            st.session_state.captioner = load_captioner()
            st.session_state.story_gen = load_story_gen()
        st.session_state.models_loaded = True

    captioner = st.session_state.captioner
    story_gen = st.session_state.story_gen

    # If user clicked a history entry, load it.
    # NOTE: the only way out of this view is changing the image source, whose
    # on_change callback (reset_state) removes selected_index.
    history = st.session_state.history
    sel = st.session_state.selected_index
    if sel is not None:
        if 0 <= sel < len(history):
            _show_history_entry(history[sel])
            return
        # Selection no longer matches an entry: fall back to the fresh flow.
        st.session_state.selected_index = None

    # Otherwise, handle a fresh upload/demo
    loaded = _load_image(source)
    if loaded is None:
        return
    img, img_bytes = loaded

    st.image(img, use_container_width=True)

    # Cached results belong to one specific image. When a different image
    # arrives (e.g. a second upload without touching the radio), discard them
    # so the new image is not paired with the previous caption/story/audio.
    if st.session_state.get("current_image_bytes") != img_bytes:
        for key in _GENERATED_KEYS:
            if key in st.session_state:
                del st.session_state[key]
        st.session_state.current_image_bytes = img_bytes

    # Step 1: Caption
    if "caption" not in st.session_state:
        with st.spinner("Captioning image…"):
            st.session_state.caption = generate_caption(captioner, img)
    st.markdown(f"**Caption:** {st.session_state.caption}")

    # Step 2: Story
    if "story" not in st.session_state:
        with st.spinner("Creating story…"):
            st.session_state.story = generate_story_simple(
                story_gen, st.session_state.caption, 50, 100
            )
    st.markdown(f"**Story:** {st.session_state.story}")

    # Step 3: Audio
    if "audio_bytes" not in st.session_state:
        with st.spinner("Generating audio…"):
            audio_bytes, mime = generate_audio(st.session_state.story)
            st.session_state.audio_bytes = audio_bytes
            st.session_state.audio_mime = mime

    if st.button("🔊 Play Story Audio"):
        st.audio(data=st.session_state.audio_bytes,
                 format=st.session_state.audio_mime)

    # Step 4: Append to history (only once per new run): skip when the most
    # recent entry already holds this exact image.
    if not history or history[-1]["image_bytes"] != img_bytes:
        history.append({
            "image_bytes": img_bytes,
            "caption":     st.session_state.caption,
            "story":       st.session_state.story,
            "audio_bytes": st.session_state.audio_bytes,
            "audio_mime":  st.session_state.audio_mime,
        })

# Call main() only when this file is executed as the entry script
# (__name__ == "__main__"), not when it is imported as a module.
if __name__ == "__main__":
    main()