hskwon7 commited on
Commit
ca604ad
·
verified ·
1 Parent(s): edd0e17

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -16
app.py CHANGED
@@ -23,6 +23,15 @@ from modules import (
23
  warnings.filterwarnings("ignore", category=DeprecationWarning)
24
 
25
 
 
 
 
 
 
 
 
 
 
26
  def main():
27
  st.title("🖼️ → 📖 Image-to-Story App for Kids")
28
  st.write(
@@ -30,6 +39,13 @@ def main():
30
  "with audio playback powered by Hugging Face pipelines!"
31
  )
32
 
 
 
 
 
 
 
 
33
  # Load pipelines (cached) with a friendly spinner
34
  if "models_loaded" not in st.session_state:
35
  with st.spinner("Loading AI models, please wait…"):
@@ -42,14 +58,8 @@ def main():
42
  story_gen = st.session_state.story_gen
43
  tts_pipe = st.session_state.tts_pipe
44
 
45
- # Choose image source
46
- source = st.radio(
47
- "Choose image source:",
48
- ("Upload my own image", "Use demo image")
49
- )
50
-
51
  if source == "Use demo image":
52
- # Load the bundled demo image
53
  img = Image.open("test_kids_playing.jpg")
54
  else:
55
  uploaded = st.file_uploader(
@@ -64,15 +74,14 @@ def main():
64
  # Display image
65
  st.image(img, use_container_width=True)
66
 
67
- # Generate caption once
68
- if "caption" not in st.session_state or source != st.session_state.get("last_source"):
69
  with st.spinner("Generating image caption…"):
70
  st.session_state.caption = generate_caption(captioner, img)
71
- st.session_state.last_source = source
72
  st.markdown(f"**Caption:** {st.session_state.caption}")
73
 
74
- # Generate story once
75
- if "story" not in st.session_state or source != st.session_state.get("last_source"):
76
  with st.spinner("Creating story for kids…"):
77
  st.session_state.story = generate_story(
78
  story_gen,
@@ -80,16 +89,14 @@ def main():
80
  min_words=50,
81
  max_words=100
82
  )
83
- st.session_state.last_source = source
84
  st.markdown(f"**Story:** {st.session_state.story}")
85
 
86
- # Generate TTS audio once
87
- if "audio_data" not in st.session_state or source != st.session_state.get("last_source"):
88
  with st.spinner("Synthesizing speech…"):
89
  audio_array, sr = generate_audio(tts_pipe, st.session_state.story)
90
  st.session_state.audio_data = audio_array
91
  st.session_state.audio_sr = sr
92
- st.session_state.last_source = source
93
 
94
  # Audio playback button
95
  if st.button("🔊 Play Story Audio"):
 
23
  warnings.filterwarnings("ignore", category=DeprecationWarning)
24
 
25
 
26
+ def reset_state():
27
+ """
28
+ Clear generated caption, story, and audio when image source changes.
29
+ """
30
+ for key in ["caption", "story", "audio_data", "audio_sr"]:
31
+ if key in st.session_state:
32
+ del st.session_state[key]
33
+
34
+
35
  def main():
36
  st.title("🖼️ → 📖 Image-to-Story App for Kids")
37
  st.write(
 
39
  "with audio playback powered by Hugging Face pipelines!"
40
  )
41
 
42
+ # Choose image source with callback to reset state
43
+ source = st.radio(
44
+ "Choose image source:",
45
+ ("Upload my own image", "Use demo image"),
46
+ on_change=reset_state
47
+ )
48
+
49
  # Load pipelines (cached) with a friendly spinner
50
  if "models_loaded" not in st.session_state:
51
  with st.spinner("Loading AI models, please wait…"):
 
58
  story_gen = st.session_state.story_gen
59
  tts_pipe = st.session_state.tts_pipe
60
 
61
+ # Get image object
 
 
 
 
 
62
  if source == "Use demo image":
 
63
  img = Image.open("test_kids_playing.jpg")
64
  else:
65
  uploaded = st.file_uploader(
 
74
  # Display image
75
  st.image(img, use_container_width=True)
76
 
77
+ # Generate caption
78
+ if "caption" not in st.session_state:
79
  with st.spinner("Generating image caption…"):
80
  st.session_state.caption = generate_caption(captioner, img)
 
81
  st.markdown(f"**Caption:** {st.session_state.caption}")
82
 
83
+ # Generate story
84
+ if "story" not in st.session_state:
85
  with st.spinner("Creating story for kids…"):
86
  st.session_state.story = generate_story(
87
  story_gen,
 
89
  min_words=50,
90
  max_words=100
91
  )
 
92
  st.markdown(f"**Story:** {st.session_state.story}")
93
 
94
+ # Generate audio
95
+ if "audio_data" not in st.session_state:
96
  with st.spinner("Synthesizing speech…"):
97
  audio_array, sr = generate_audio(tts_pipe, st.session_state.story)
98
  st.session_state.audio_data = audio_array
99
  st.session_state.audio_sr = sr
 
100
 
101
  # Audio playback button
102
  if st.button("🔊 Play Story Audio"):