hskwon7 commited on
Commit
639df53
·
verified ·
1 Parent(s): 1fd7c62

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -26
app.py CHANGED
@@ -1,13 +1,12 @@
1
  import streamlit as st
2
  from transformers import pipeline
3
  from PIL import Image
4
- import io
5
- from gtts import gTTS
6
- import tempfile
7
 
8
# App header: title plus a one-line usage hint shown above the uploader.
st.title("🖼️ → 📖 Image-to-Story Demo")
st.write("Upload an image and watch as it’s captioned, turned into a short story, and even read aloud!")
10
 
 
11
@st.cache_resource
def load_captioner():
    """Build (once per session, via Streamlit's resource cache) the
    BLIP long-caption image-to-text pipeline."""
    captioning_pipe = pipeline("image-to-text", model="unography/blip-large-long-cap")
    return captioning_pipe
@@ -16,49 +15,58 @@ def load_captioner():
16
def load_story_gen():
    """Build the GPT-2 text-generation pipeline used to turn a caption
    into a short story."""
    generator = pipeline("text-generation", model="gpt2", tokenizer="gpt2")
    return generator
18
 
 
 
 
 
 
19
# Instantiate the cached pipelines once per session.
captioner = load_captioner()
story_gen = load_story_gen()

uploaded = st.file_uploader("Upload an image", type=["png","jpg","jpeg"], key="image")
if uploaded:
    img = Image.open(uploaded)
    st.image(img, use_column_width=True)

    # Caption: the image-to-text pipeline returns [{"generated_text": "..."}],
    # so unwrap to a plain string. (The original stored the dict itself, which
    # then fed a dict — not text — into the story generator below.)
    if "caption" not in st.session_state:
        with st.spinner("Generating caption…"):
            caps = captioner(img)
            if isinstance(caps, list):
                caps = caps[0]
            st.session_state.caption = caps.get("generated_text", "") if isinstance(caps, dict) else caps
    st.write("**Caption:**", st.session_state.caption)

    # Story: sample a short GPT-2 continuation of the caption.
    if "story" not in st.session_state:
        with st.spinner("Spinning up a story…"):
            out = story_gen(
                st.session_state.caption,
                max_length=200,
                num_return_sequences=1,
                do_sample=True,
                top_p=0.9
            )
            st.session_state.story = out[0]["generated_text"]
    st.write("**Story:**", st.session_state.story)

    # Prepare audio bytes once: render the story to MP3 in memory with gTTS
    # and cache the raw bytes in session state so reruns don't re-synthesize.
    if "audio_bytes" not in st.session_state:
        with st.spinner("Generating audio…"):
            tts = gTTS(text=st.session_state.story, lang="en")
            buf = io.BytesIO()
            tts.write_to_fp(buf)
            st.session_state.audio_bytes = buf.getvalue()

    # Play button: stream the cached bytes directly. The original wrote them
    # to a NamedTemporaryFile(delete=False) that was never deleted (temp-file
    # leak on every click); st.audio accepts raw bytes, so no file is needed.
    if st.button("🔊 Play Story Audio"):
        st.audio(st.session_state.audio_bytes, format="audio/mp3")
 
1
  import streamlit as st
2
  from transformers import pipeline
3
  from PIL import Image
4
+ import numpy as np
 
 
5
 
6
# App header (Hugging Face SpeechT5 TTS variant of the demo).
st.title("🖼️ → 📖 Image-to-Story Demo (with HF TTS)")
st.write("Upload an image and watch as it’s captioned, turned into a short story, and even read aloud!")
8
 
9
+ # 1) load and cache pipelines
10
@st.cache_resource
def load_captioner():
    """Return a session-cached BLIP image-to-text pipeline (long captions)."""
    return pipeline("image-to-text", model="unography/blip-large-long-cap")
 
15
def load_story_gen():
    """Return the GPT-2 text-generation pipeline used for story writing."""
    story_pipeline = pipeline("text-generation", model="gpt2", tokenizer="gpt2")
    return story_pipeline
17
 
18
@st.cache_resource
def load_tts():
    """Return a session-cached SpeechT5 text-to-speech pipeline.

    NOTE(review): SpeechT5 generation normally requires speaker embeddings
    (passed via forward_params={"speaker_embeddings": ...}) — confirm a bare
    pipeline call works with the installed transformers version.
    """
    return pipeline("text-to-speech", model="microsoft/speecht5_tts")
22
+
23
# Instantiate the cached pipelines once per session.
captioner = load_captioner()
story_gen = load_story_gen()
tts = load_tts()

# 2) upload image
uploaded = st.file_uploader("Upload an image", type=["png","jpg","jpeg"], key="image")
if uploaded:
    img = Image.open(uploaded)
    st.image(img, use_column_width=True)

    # 3) generate caption. The image-to-text pipeline returns
    # [{"generated_text": "..."}] — NOT "a list of strings" as the original
    # comment claimed — so unwrap to a plain string; storing the dict would
    # break the story prompt below.
    if "caption" not in st.session_state:
        with st.spinner("Generating caption…"):
            cap = captioner(img)
            if isinstance(cap, list):
                cap = cap[0]
            st.session_state.caption = cap.get("generated_text", "") if isinstance(cap, dict) else cap
    st.write("**Caption:**", st.session_state.caption)

    # 4) generate story: sample a GPT-2 continuation of the caption.
    if "story" not in st.session_state:
        with st.spinner("Spinning up a story…"):
            out = story_gen(
                st.session_state.caption,
                max_length=200,
                do_sample=True,
                top_p=0.9,
                num_return_sequences=1
            )
            st.session_state.story = out[0]["generated_text"]
    st.write("**Story:**", st.session_state.story)

    # 5) generate TTS once. For a single input the text-to-speech pipeline
    # returns a dict {"audio": float32 ndarray in [-1, 1], "sampling_rate": int}
    # — the original's speech[0]["array"] would fail on the real output shape.
    if "tts_array" not in st.session_state:
        with st.spinner("Generating speech…"):
            speech = tts(st.session_state.story)
            if isinstance(speech, list):
                speech = speech[0]
            arr = np.asarray(speech["audio"]).squeeze()
            sr = speech["sampling_rate"]
            # Convert float32 [-1, 1] audio to int16 PCM for st.audio playback.
            st.session_state.tts_array = (arr * 32767).astype(np.int16)
            st.session_state.tts_sr = sr

    # 6) play on button
    if st.button("🔊 Play Story Audio"):
        st.audio(
            data=st.session_state.tts_array,
            format="audio/wav",
            sample_rate=st.session_state.tts_sr
        )