hskwon7 committed on
Commit
e36e817
·
verified ·
1 Parent(s): 3745c6c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -66
app.py CHANGED
@@ -1,115 +1,82 @@
1
  # app.py
 
2
  """
3
  app.py
4
 
5
  Streamlit application for Image-to-Story demo.
6
- Allows users to upload an image or use a demo image, generates a caption, creates a child-friendly story,
7
- and plays it back as audio.
8
- Suitable for deployment on Hugging Face Spaces.
9
  """
10
  import streamlit as st
11
  from PIL import Image
12
  import warnings
13
  from modules import (
14
- load_captioner,
15
- load_story_gen,
16
- load_tts,
17
- generate_caption,
18
- generate_story,
19
- generate_audio,
20
- generate_story_simple
21
  )
22
 
23
- # Suppress deprecation warnings for cleaner UI
24
  warnings.filterwarnings("ignore", category=DeprecationWarning)
25
 
26
-
27
  def reset_state():
28
- """
29
- Clear generated caption, story, and audio when image source changes.
30
- """
31
- for key in ["caption", "story", "audio_data", "audio_sr"]:
32
  if key in st.session_state:
33
  del st.session_state[key]
34
 
35
-
36
  def main():
37
  st.title("🖼️ → 📖 Image-to-Story App for Kids")
38
- st.write(
39
- "Upload an image or use the demo image to get an engaging story suitable for 3–10 year-olds, "
40
- "with audio playback powered by Hugging Face pipelines!"
41
- )
42
 
43
- # Choose image source with callback to reset state
44
- source = st.radio(
45
- "Choose image source:",
46
- ("Upload my own image", "Use demo image"),
47
- on_change=reset_state
48
- )
49
 
50
- # Load pipelines (cached) with a friendly spinner
51
  if "models_loaded" not in st.session_state:
52
- with st.spinner("Getting things ready, please wait…"):
53
  st.session_state.captioner = load_captioner()
54
  st.session_state.story_gen = load_story_gen()
55
- st.session_state.tts_pipe = load_tts()
56
  st.session_state.models_loaded = True
57
 
58
  captioner = st.session_state.captioner
59
  story_gen = st.session_state.story_gen
60
- tts_pipe = st.session_state.tts_pipe
61
 
62
- # Get image object
63
  if source == "Use demo image":
64
- img = Image.open("test_kids_playing.jpg")
65
  else:
66
- uploaded = st.file_uploader(
67
- "Upload an image (PNG, JPG, JPEG)",
68
- type=["png", "jpg", "jpeg"],
69
- key="upload"
70
- )
71
  if not uploaded:
72
  return
73
- img = Image.open(uploaded)
74
 
75
- # Display image
76
  st.image(img, use_container_width=True)
77
 
78
- # Generate caption
79
  if "caption" not in st.session_state:
80
- with st.spinner("Generating image caption…"):
81
  st.session_state.caption = generate_caption(captioner, img)
82
  st.markdown(f"**Caption:** {st.session_state.caption}")
83
 
84
- # Generate story
85
  if "story" not in st.session_state:
86
- with st.spinner("Creating story for kids…"):
87
- st.session_state.story = generate_story_simple(story_gen, st.session_state.caption,
88
- min_words=50, max_words=100)
89
-
90
- # st.session_state.story = generate_story(
91
- # story_gen,
92
- # st.session_state.caption,
93
- # min_words=50,
94
- # max_words=100
95
- # )
96
  st.markdown(f"**Story:** {st.session_state.story}")
97
 
98
- # Generate audio
99
- if "audio_data" not in st.session_state:
100
- with st.spinner("Synthesizing speech…"):
101
- audio_array, sr = generate_audio(tts_pipe, st.session_state.story)
102
- st.session_state.audio_data = audio_array
103
- st.session_state.audio_sr = sr
104
 
105
- # Audio playback button
106
  if st.button("🔊 Play Story Audio"):
107
- st.audio(
108
- data=st.session_state.audio_data,
109
- format="audio/wav",
110
- sample_rate=st.session_state.audio_sr
111
- )
112
-
113
 
114
  if __name__ == "__main__":
115
  main()
 
1
  # app.py
2
+
3
  """
4
  app.py
5
 
6
  Streamlit application for Image-to-Story demo.
7
+ Allows demo/upload image, generates a caption, a trimmed story,
8
+ and plays back as MP3 via gTTS.
 
9
  """
10
  import streamlit as st
11
  from PIL import Image
12
  import warnings
13
  from modules import (
14
+ load_captioner, load_story_gen,
15
+ generate_caption, generate_story_simple,
16
+ generate_audio
 
 
 
 
17
  )
18
 
 
19
  warnings.filterwarnings("ignore", category=DeprecationWarning)
20
 
 
21
  def reset_state():
22
+ for key in ["caption", "story", "audio_bytes", "audio_mime"]:
 
 
 
23
  if key in st.session_state:
24
  del st.session_state[key]
25
 
 
26
  def main():
27
  st.title("🖼️ → 📖 Image-to-Story App for Kids")
28
+ st.write("Upload or demo an image to get a 50–100 word story and audio!")
 
 
 
29
 
30
+ source = st.radio("Image source:",
31
+ ("Upload my own image", "Use demo image"),
32
+ on_change=reset_state)
 
 
 
33
 
34
+ # Load pipelines once
35
  if "models_loaded" not in st.session_state:
36
+ with st.spinner("Loading models…"):
37
  st.session_state.captioner = load_captioner()
38
  st.session_state.story_gen = load_story_gen()
 
39
  st.session_state.models_loaded = True
40
 
41
  captioner = st.session_state.captioner
42
  story_gen = st.session_state.story_gen
 
43
 
44
+ # Acquire image
45
  if source == "Use demo image":
46
+ img = Image.open("test_kids_playing.jpg").convert("RGB")
47
  else:
48
+ uploaded = st.file_uploader("Upload an image",
49
+ type=["png", "jpg", "jpeg"])
 
 
 
50
  if not uploaded:
51
  return
52
+ img = Image.open(uploaded).convert("RGB")
53
 
 
54
  st.image(img, use_container_width=True)
55
 
56
+ # Caption
57
  if "caption" not in st.session_state:
58
+ with st.spinner("Captioning image…"):
59
  st.session_state.caption = generate_caption(captioner, img)
60
  st.markdown(f"**Caption:** {st.session_state.caption}")
61
 
62
+ # Story
63
  if "story" not in st.session_state:
64
+ with st.spinner("Creating story…"):
65
+ st.session_state.story = generate_story_simple(
66
+ story_gen, st.session_state.caption, 50, 100
67
+ )
 
 
 
 
 
 
68
  st.markdown(f"**Story:** {st.session_state.story}")
69
 
70
+ # Audio
71
+ if "audio_bytes" not in st.session_state:
72
+ with st.spinner("Generating audio…"):
73
+ audio_bytes, mime = generate_audio(st.session_state.story)
74
+ st.session_state.audio_bytes = audio_bytes
75
+ st.session_state.audio_mime = mime
76
 
 
77
  if st.button("🔊 Play Story Audio"):
78
+ st.audio(data=st.session_state.audio_bytes,
79
+ format=st.session_state.audio_mime)
 
 
 
 
80
 
81
  if __name__ == "__main__":
82
  main()