hskwon7 committed on
Commit
2cbbfe6
·
verified ·
1 Parent(s): 07e7687

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -63
app.py CHANGED
@@ -1,81 +1,89 @@
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
- from transformers import pipeline
3
  from PIL import Image
4
- import numpy as np
5
  import warnings
 
 
 
 
 
 
 
 
6
 
 
7
  warnings.filterwarnings("ignore", category=DeprecationWarning)
8
 
9
- st.title("🖼️ → 📖 Image-to-Story App (with Bark TTS)")
10
- st.write("Upload an image, get a caption, spin it into a story, and play it aloud!")
11
 
12
- # 1) Define your cached loaders
13
- @st.cache_resource
14
- def load_captioner():
15
- return pipeline("image-to-text", model="unography/blip-large-long-cap")
 
 
16
 
17
- @st.cache_resource
18
- def load_story_gen():
19
- return pipeline("text-generation", model="gpt2", tokenizer="gpt2")
 
 
 
 
20
 
21
- @st.cache_resource
22
- def load_bark_tts():
23
- # Zero-config Bark text-to-speech
24
- return pipeline("text-to-speech", model="suno/bark")
25
 
26
- # 2) Load all models under one spinner
27
- if "models_loaded" not in st.session_state:
28
- with st.spinner("Getting everything ready…"):
29
- st.session_state.captioner = load_captioner()
30
- st.session_state.story_gen = load_story_gen()
31
- st.session_state.tts_pipe = load_bark_tts()
32
- st.session_state.models_loaded = True
33
 
34
- captioner = st.session_state.captioner
35
- story_gen = st.session_state.story_gen
36
- tts_pipe = st.session_state.tts_pipe
37
 
38
- # 3) Upload image
39
- uploaded = st.file_uploader("Upload an image", type=["png","jpg","jpeg"])
40
- if not uploaded:
41
- st.stop()
 
42
 
43
- img = Image.open(uploaded)
44
- st.image(img, use_container_width=True)
 
 
 
 
 
 
 
 
45
 
46
- # 4) Generate caption once
47
- if "caption" not in st.session_state:
48
- with st.spinner("Describing your image…"):
49
- caps = captioner(img)
50
- first = caps[0]
51
- st.session_state.caption = first.get("generated_text", "") if isinstance(first, dict) else str(first)
52
- st.write("**Caption:**", st.session_state.caption)
53
 
54
- # 5) Generate story once
55
- if "story" not in st.session_state:
56
- with st.spinner("Weaving a story…"):
57
- out = story_gen(
58
- st.session_state.caption,
59
- max_length=200,
60
- do_sample=True,
61
- top_p=0.9,
62
- num_return_sequences=1
63
  )
64
- st.session_state.story = out[0]["generated_text"]
65
- st.write("**Story:**", st.session_state.story)
66
 
67
- # 6) Generate Bark audio once
68
- if "audio_data" not in st.session_state:
69
- with st.spinner("Converting to speech…"):
70
- speech = tts_pipe(st.session_state.story)
71
- # Bark returns a dict: {"audio": np.ndarray, "sampling_rate": int}
72
- st.session_state.audio_data = speech["audio"]
73
- st.session_state.audio_sr = speech["sampling_rate"]
74
 
75
- # 7) Play button
76
- if st.button("🔊 Play Story Audio"):
77
- st.audio(
78
- data=st.session_state.audio_data,
79
- format="audio/wav",
80
- sample_rate=st.session_state.audio_sr
81
- )
 
1
+ # app.py
2
+ """
3
+ app.py
4
+
5
+ Streamlit application for Image-to-Story demo.
6
+ Allows users to upload an image, generates a caption, creates a child-friendly story,
7
+ and plays it back as audio.
8
+ """
9
  import streamlit as st
 
10
  from PIL import Image
 
11
  import warnings
12
+ from modules import (
13
+ load_captioner,
14
+ load_story_gen,
15
+ load_tts,
16
+ generate_caption,
17
+ generate_story,
18
+ generate_audio
19
+ )
20
 
21
+ # Suppress deprecation warnings for cleaner UI
22
  warnings.filterwarnings("ignore", category=DeprecationWarning)
23
 
 
 
24
 
25
def main():
    """Run the Image-to-Story Streamlit app.

    Flow: load the captioning / story-generation / TTS pipelines once,
    accept an image upload, then derive (and cache in ``st.session_state``)
    a caption, a child-friendly story, and synthesized speech for playback.

    Side effects only (Streamlit UI + session state); returns ``None``.
    """
    st.title("🖼️ → 📖 Image-to-Story App for Kids")
    st.write(
        "Upload an image and get an engaging story suitable for 3–10 year-olds, "
        "with audio playback powered by Hugging Face pipelines!"
    )

    # Load pipelines once per session with a friendly spinner.
    # (load_* helpers are presumably cached via st.cache_resource in
    # `modules` — TODO confirm; the session flag avoids re-showing the
    # spinner on every rerun either way.)
    if "models_loaded" not in st.session_state:
        with st.spinner("Loading AI models, please wait…"):
            st.session_state.captioner = load_captioner()
            st.session_state.story_gen = load_story_gen()
            st.session_state.tts_pipe = load_tts()
        st.session_state.models_loaded = True

    captioner = st.session_state.captioner
    story_gen = st.session_state.story_gen
    tts_pipe = st.session_state.tts_pipe

    # File uploader for images
    uploaded = st.file_uploader(
        "Upload an image (PNG, JPG, JPEG)", type=["png", "jpg", "jpeg"]
    )
    if not uploaded:
        return

    # BUG FIX: previously the cached caption/story/audio were never
    # invalidated, so uploading a *different* image kept showing the first
    # image's results. Key the derived state on the uploaded file's
    # identity and clear it when the file changes.
    upload_key = (uploaded.name, uploaded.size)
    if st.session_state.get("upload_key") != upload_key:
        st.session_state.upload_key = upload_key
        for stale in ("caption", "story", "audio_data", "audio_sr"):
            st.session_state.pop(stale, None)

    # Display uploaded image
    img = Image.open(uploaded)
    st.image(img, use_container_width=True)

    # Generate caption once per uploaded image
    if "caption" not in st.session_state:
        with st.spinner("Generating image caption…"):
            st.session_state.caption = generate_caption(captioner, img)
    st.markdown(f"**Caption:** {st.session_state.caption}")

    # Generate story once per uploaded image
    if "story" not in st.session_state:
        with st.spinner("Creating story for kids…"):
            st.session_state.story = generate_story(
                story_gen,
                st.session_state.caption,
                min_words=50,
                max_words=100
            )
    st.markdown(f"**Story:** {st.session_state.story}")

    # Generate TTS audio once per uploaded image
    if "audio_data" not in st.session_state:
        with st.spinner("Synthesizing speech…"):
            audio_array, sr = generate_audio(tts_pipe, st.session_state.story)
        st.session_state.audio_data = audio_array
        st.session_state.audio_sr = sr

    # Audio playback button (audio is already synthesized; the button
    # only triggers the player, so replays are instant)
    if st.button("🔊 Play Story Audio"):
        st.audio(
            data=st.session_state.audio_data,
            format="audio/wav",
            sample_rate=st.session_state.audio_sr
        )
 
 
86
 
 
 
 
 
 
 
 
87
 
88
# Script entry point: launch the app only when executed directly
# (e.g. `streamlit run app.py`), not when imported as a module.
if __name__ == "__main__":
    main()