| | |
| |
|
| | """ |
Streamlit application for the Image-to-Story demo with a history sidebar.
Lets the user pick the demo image or upload one, generates a caption and a
trimmed story from it, and plays the story back as MP3 audio via gTTS.
Keeps a history of all runs.
| | """ |
| | import streamlit as st |
| | from PIL import Image |
| | import warnings |
| | from modules import ( |
| | load_captioner, load_story_gen, |
| | generate_caption, generate_story_simple, |
| | generate_audio |
| | ) |
| | import io |
| |
|
# Ignore every DeprecationWarning raised from this point on.
warnings.filterwarnings(action="ignore", category=DeprecationWarning)
| |
|
| | |
def reset_state():
    """Forget the results generated for the current image.

    Removes the caption, story, audio payload, audio MIME type and the
    selected history index from ``st.session_state``. Keys that are not
    present are skipped. Used as the ``on_change`` callback of the
    image-source radio in ``main``.
    """
    stale_keys = ("caption", "story", "audio_bytes", "audio_mime", "selected_index")
    present = [name for name in stale_keys if name in st.session_state]
    for name in present:
        del st.session_state[name]
| |
|
def _render_history_sidebar():
    """Draw the sidebar: a thumbnail and "View" button per run, plus "Clear History"."""
    st.sidebar.header("History")
    if "history" not in st.session_state:
        st.session_state.history = []
    # reset_state() deletes this key, so it must be re-created on every run.
    if "selected_index" not in st.session_state:
        st.session_state.selected_index = None

    for idx, entry in enumerate(st.session_state.history):
        # Put the thumbnail and its button into the same sidebar container.
        row = st.sidebar.container()
        row.image(entry["image_bytes"], width=100)
        if row.button(f"View #{idx+1}", key=f"view_{idx}"):
            st.session_state.selected_index = idx

    if st.sidebar.button("Clear History"):
        st.session_state.history = []
        st.session_state.selected_index = None
        # The thumbnails above were already drawn from the old list; rerun so
        # the sidebar reflects the cleared history right away.
        st.rerun()


def _show_history_entry(entry):
    """Display one stored run: image, caption, story and an audio play button."""
    img = Image.open(io.BytesIO(entry["image_bytes"])).convert("RGB")
    st.image(img, use_container_width=True)
    st.markdown(f"**Caption:** {entry['caption']}")
    st.markdown(f"**Story:** {entry['story']}")
    if st.button("🔊 Play Story Audio"):
        st.audio(data=entry["audio_bytes"], format=entry["audio_mime"])


def _load_image(source):
    """Return ``(image, image_bytes)`` for the chosen source, or ``None`` if nothing is uploaded yet."""
    if source == "Use demo image":
        # Close the demo file as soon as the pixels have been converted.
        with Image.open("test_kids_playing.jpg") as demo:
            img = demo.convert("RGB")
        # Re-encode so the demo image has a byte representation for the history.
        buf = io.BytesIO()
        img.save(buf, format="JPEG")
        return img, buf.getvalue()

    uploaded = st.file_uploader("Upload an image",
                                type=["png", "jpg", "jpeg"])
    if not uploaded:
        return None
    return Image.open(uploaded).convert("RGB"), uploaded.getvalue()


def main():
    """Run the Image-to-Story page.

    Renders the history sidebar and the image-source selector, loads the
    captioning and story models once per session, and then either shows a
    selected history entry or processes the current image: caption, story
    and audio are generated, cached in ``st.session_state`` and recorded in
    the history.

    The cached caption/story/audio are tied to the bytes of the image they
    were generated from, so choosing a different image always triggers a
    fresh generation instead of reusing the previous results.
    """
    st.set_page_config(layout="wide")
    st.title("🎨 Magic Picture Story Time!")
    st.write("Pick or upload a picture, and watch it turn into a fun story with voice! Ready for a magical tale?")

    _render_history_sidebar()

    source = st.radio("Image source:",
                      ("Upload my own image", "Use demo image"),
                      on_change=reset_state)

    # Load the models only once per session; later reruns reuse them.
    if "models_loaded" not in st.session_state:
        with st.spinner("Loading models…"):
            st.session_state.captioner = load_captioner()
            st.session_state.story_gen = load_story_gen()
            st.session_state.models_loaded = True

    captioner = st.session_state.captioner
    story_gen = st.session_state.story_gen

    # History view: show the selected run and stop.
    sel = st.session_state.selected_index
    if sel is not None:
        if 0 <= sel < len(st.session_state.history):
            _show_history_entry(st.session_state.history[sel])
            return
        # Stale index (the entry no longer exists): fall back to the normal flow.
        st.session_state.selected_index = None

    loaded = _load_image(source)
    if loaded is None:
        return
    img, img_bytes = loaded

    # Drop cached results that belong to a different image; otherwise a newly
    # uploaded picture would be shown with the previous caption, story and audio.
    if st.session_state.get("current_image_bytes") != img_bytes:
        for key in ("caption", "story", "audio_bytes", "audio_mime"):
            if key in st.session_state:
                del st.session_state[key]
        st.session_state.current_image_bytes = img_bytes

    st.image(img, use_container_width=True)

    if "caption" not in st.session_state:
        with st.spinner("Captioning image…"):
            st.session_state.caption = generate_caption(captioner, img)
    st.markdown(f"**Caption:** {st.session_state.caption}")

    if "story" not in st.session_state:
        with st.spinner("Creating story…"):
            st.session_state.story = generate_story_simple(
                story_gen, st.session_state.caption, 50, 100
            )
    st.markdown(f"**Story:** {st.session_state.story}")

    if "audio_bytes" not in st.session_state:
        with st.spinner("Generating audio…"):
            audio_bytes, mime = generate_audio(st.session_state.story)
            st.session_state.audio_bytes = audio_bytes
            st.session_state.audio_mime = mime

            # Record the run only when it has just been generated, so that
            # "Clear History" is not undone by the next rerun. Consecutive
            # duplicates of the same image are still skipped.
            history = st.session_state.history
            if not history or history[-1]["image_bytes"] != img_bytes:
                history.append({
                    "image_bytes": img_bytes,
                    "caption": st.session_state.caption,
                    "story": st.session_state.story,
                    "audio_bytes": audio_bytes,
                    "audio_mime": mime,
                })

    if st.button("🔊 Play Story Audio"):
        st.audio(data=st.session_state.audio_bytes,
                 format=st.session_state.audio_mime)
| |
|
# Script entry point: build the page only when this file is executed directly.
if __name__ == "__main__":
    main()
| |
|