shingguy1 commited on
Commit
4816af6
·
verified ·
1 Parent(s): f97d7d0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -59
app.py CHANGED
@@ -8,21 +8,27 @@ import time
8
  import threading
9
 
10
  # ——— 1) MODEL LOADING (cached) ————————————————
 
11
  @st.cache_resource
12
  def get_image_captioner(model_name="Salesforce/blip-image-captioning-base"):
 
13
  return pipeline("image-to-text", model=model_name, device="cpu")
14
 
15
  @st.cache_resource
16
  def get_story_pipe(model_name="google/flan-t5-base"):
 
17
  return pipeline("text2text-generation", model=model_name, device="cpu")
18
 
19
  @st.cache_resource
20
  def get_tts_pipe(model_name="facebook/mms-tts-eng"):
 
21
  return pipeline("text-to-speech", model=model_name, device="cpu")
22
 
23
  # ——— 2) TRANSFORM FUNCTIONS ————————————————
24
  def part1_image_to_text(pil_img, captioner):
 
25
  results = captioner(pil_img)
 
26
  return results[0].get("generated_text", "") if results else ""
27
 
28
  def part2_text_to_story(
@@ -38,122 +44,130 @@ def part2_text_to_story(
38
  repetition_penalty: float = 1.1,
39
  no_repeat_ngram_size: int = 4
40
  ) -> str:
 
41
  prompt = (
42
  f"Write a vivid, imaginative short story of about {target_words} words "
43
  f"describing this scene: {caption}"
44
  )
 
45
  out = story_pipe(
46
  prompt,
47
- max_length=max_length,
48
- min_length=min_length,
49
- do_sample=do_sample,
50
- top_k=top_k,
51
- top_p=top_p,
52
- temperature=temperature,
53
- repetition_penalty=repetition_penalty,
54
- no_repeat_ngram_size=no_repeat_ngram_size,
55
- early_stopping=False
56
  )
 
57
  raw = out[0].get("generated_text", "").strip()
58
  if not raw:
59
  return ""
60
- # strip echo of prompt
61
  if raw.lower().startswith(prompt.lower()):
62
  story = raw[len(prompt):].strip()
63
  else:
64
  story = raw
65
- # cut at last full stop
66
  idx = story.rfind(".")
67
  if idx != -1:
68
  story = story[:idx+1]
69
  return story
70
 
71
  def part3_text_to_speech_bytes(text: str, tts_pipe) -> bytes:
 
72
  out = tts_pipe(text)
73
  if isinstance(out, list):
74
  out = out[0]
 
75
  audio_array = out["audio"] # np.ndarray (channels, samples)
76
  rate = out["sampling_rate"] # int
 
77
  data = audio_array.T if audio_array.ndim == 2 else audio_array
 
78
  pcm = (data * 32767).astype(np.int16)
79
 
 
80
  buffer = io.BytesIO()
81
  wf = wave.open(buffer, "wb")
82
- channels = 1 if data.ndim == 1 else data.shape[1]
83
  wf.setnchannels(channels)
84
- wf.setsampwidth(2)
85
- wf.setframerate(rate)
86
- wf.writeframes(pcm.tobytes())
87
  wf.close()
88
- buffer.seek(0)
89
- return buffer.read()
90
 
91
  # ——— 3) STREAMLIT UI ————————————————————————
92
- # Set page config as the first Streamlit command
93
  st.set_page_config(
94
  page_title="Picture to Story Magic",
95
  page_icon="✨",
96
  layout="centered"
97
  )
98
 
99
- # Custom CSS for kid-friendly styling with improved readability
100
  st.markdown("""
101
  <style>
102
  .main {
103
- background-color: #e6f3ff;
104
  padding: 20px;
105
  border-radius: 15px;
106
  }
107
  .stButton>button {
108
- background-color: #ffcccb;
109
  button-color: #000000;
110
  border-radius: 10px;
111
- border: 2px solid #ff9999;
112
  font-size: 18px;
113
  font-weight: bold;
114
  padding: 10px 20px;
115
- transition: all 0.3s;
116
  }
117
  .stButton>button:hover {
118
- background-color: #ff9999;
119
  color: #ffffff;
120
- transform: scale(1.05);
121
  }
122
  .stFileUploader {
123
- background-color: #ffb300;
124
- border: 2px dashed #ff8c00;
125
  border-radius: 10px;
126
  padding: 10px;
127
  }
128
  .stFileUploader div[role="button"] {
129
- background-color: #f0f0f0;
130
  border-radius: 10px;
131
  padding: 10px;
132
  }
133
  .stFileUploader div[role="button"] > div {
134
- color: #000000 !important;
135
  font-size: 16px;
136
  }
137
  .stFileUploader button {
138
- background-color: #ffca28 !important;
139
  color: #000000 !important;
140
  border-radius: 8px !important;
141
- border: 2px solid #ffb300 !important;
142
  padding: 5px 15px !important;
143
  font-weight: bold !important;
144
- box-shadow: 0 2px 4px rgba(0,0,0,0.2) !important;
145
  }
146
  .stFileUploader button:hover {
147
- background-color: #ff8c00 !important;
148
  color: #000000 !important;
149
  }
150
  .stImage {
151
- border: 3px solid #81c784;
152
  border-radius: 10px;
153
- box-shadow: 0 4px 8px rgba(0,0,0,0.1);
154
  }
155
  .section-header {
156
- background-color: #b3e5fc;
157
  padding: 10px;
158
  border-radius: 10px;
159
  text-align: center;
@@ -163,90 +177,101 @@ st.markdown("""
163
  margin-bottom: 10px;
164
  }
165
  .caption-box, .story-box {
166
- background-color: #f0f4c3;
167
  padding: 15px;
168
  border-radius: 10px;
169
- border: 2px solid #d4e157;
170
  margin-bottom: 20px;
171
  color: #000000;
172
  }
173
  .caption-box b, .story-box b {
174
- color: #000000;
175
  }
176
  .stProgress > div > div {
177
- background-color: #81c784;
178
  }
179
  </style>
180
  """, unsafe_allow_html=True)
181
 
182
- # Main title
183
  st.markdown("<div class='section-header'>Picture to Story Magic! ✨</div>", unsafe_allow_html=True)
184
 
185
  # Image upload section
186
  with st.container():
 
187
  st.markdown("<div class='section-header'>1️⃣ Pick a Fun Picture! πŸ–ΌοΈ</div>", unsafe_allow_html=True)
188
  uploaded = st.file_uploader("Choose a picture to start the magic! 😊", type=["jpg","jpeg","png"])
189
  if not uploaded:
 
190
  st.info("Upload a picture, and let's make a story! πŸŽ‰")
191
  st.stop()
192
 
193
- # Show image
194
  with st.spinner("Looking at your picture..."):
195
  pil_img = Image.open(uploaded)
196
- st.image(pil_img, use_container_width=True)
197
 
198
- # Caption section
199
  with st.container():
200
  st.markdown("<div class='section-header'>2️⃣ What's in the Picture? 🧐</div>", unsafe_allow_html=True)
201
- captioner = get_image_captioner()
202
- progress_bar = st.progress(0)
203
- result = [None]
204
  def run_caption():
 
205
  result[0] = part1_image_to_text(pil_img, captioner)
206
  with st.spinner("Figuring out what's in your picture..."):
207
  thread = threading.Thread(target=run_caption)
208
  thread.start()
 
209
  for i in range(100):
210
  progress_bar.progress(i + 1)
211
- time.sleep(0.05) # Adjust for ~5 seconds total
212
- thread.join()
213
- progress_bar.empty()
214
  caption = result[0]
 
215
  st.markdown(f"<div class='caption-box'><b>Picture Description:</b><br>{caption}</div>", unsafe_allow_html=True)
216
 
217
- # Story and audio section
218
  with st.container():
219
  st.markdown("<div class='section-header'>3️⃣ Your Story and Audio! 🎡</div>", unsafe_allow_html=True)
220
- # Story
221
- story_pipe = get_story_pipe()
222
  progress_bar = st.progress(0)
223
- result = [None]
224
  def run_story():
 
225
  result[0] = part2_text_to_story(caption, story_pipe)
226
  with st.spinner("Writing a super cool story..."):
227
  thread = threading.Thread(target=run_story)
228
  thread.start()
 
229
  for i in range(100):
230
  progress_bar.progress(i + 1)
231
- time.sleep(0.07) # Adjust for ~7 seconds total
232
  thread.join()
233
  progress_bar.empty()
234
  story = result[0]
 
235
  st.markdown(f"<div class='story-box'><b>Your Cool Story! πŸ“š</b><br>{story}</div>", unsafe_allow_html=True)
236
 
237
- # TTS
238
- tts_pipe = get_tts_pipe()
239
  progress_bar = st.progress(0)
240
- result = [None]
241
  def run_tts():
 
242
  result[0] = part3_text_to_speech_bytes(story, tts_pipe)
243
  with st.spinner("Turning your story into sound..."):
244
  thread = threading.Thread(target=run_tts)
245
  thread.start()
 
246
  for i in range(100):
247
  progress_bar.progress(i + 1)
248
- time.sleep(0.10) # Adjust for ~10 seconds total
249
  thread.join()
250
  progress_bar.empty()
251
  audio_bytes = result[0]
 
252
  st.audio(audio_bytes, format="audio/wav")
 
8
  import threading
9
 
10
  # ——— 1) MODEL LOADING (cached) ————————————————
11
+ # Cache the model loading to avoid reloading on every rerun, improving performance
12
  @st.cache_resource
13
  def get_image_captioner(model_name="Salesforce/blip-image-captioning-base"):
14
+ # Load the image-to-text model for generating captions from images
15
  return pipeline("image-to-text", model=model_name, device="cpu")
16
 
17
  @st.cache_resource
18
  def get_story_pipe(model_name="google/flan-t5-base"):
19
+ # Load the text-to-text model for generating stories from captions
20
  return pipeline("text2text-generation", model=model_name, device="cpu")
21
 
22
  @st.cache_resource
23
  def get_tts_pipe(model_name="facebook/mms-tts-eng"):
24
+ # Load the text-to-speech model for converting stories to audio
25
  return pipeline("text-to-speech", model=model_name, device="cpu")
26
 
27
  # ——— 2) TRANSFORM FUNCTIONS ————————————————
28
  def part1_image_to_text(pil_img, captioner):
29
+ # Generate a caption for the input image using the captioner model
30
  results = captioner(pil_img)
31
+ # Extract the generated caption, return empty string if no result
32
  return results[0].get("generated_text", "") if results else ""
33
 
34
  def part2_text_to_story(
 
44
  repetition_penalty: float = 1.1,
45
  no_repeat_ngram_size: int = 4
46
  ) -> str:
47
+ # Create a prompt instructing the model to write a story based on the caption
48
  prompt = (
49
  f"Write a vivid, imaginative short story of about {target_words} words "
50
  f"describing this scene: {caption}"
51
  )
52
+ # Generate the story using the text-to-text model with specified parameters
53
  out = story_pipe(
54
  prompt,
55
+ max_length=max_length, # Maximum length of generated text
56
+ min_length=min_length, # Minimum length to ensure sufficient content
57
+ do_sample=do_sample, # Enable sampling for creative output
58
+ top_k=top_k, # Consider top-k tokens for sampling
59
+ top_p=top_p, # Use nucleus sampling for diversity
60
+ temperature=temperature, # Control randomness of output
61
+ repetition_penalty=repetition_penalty, # Penalize repeated phrases
62
+ no_repeat_ngram_size=no_repeat_ngram_size, # Prevent repeating n-grams
63
+ early_stopping=False # Continue until max_length is reached
64
  )
65
+ # Extract the generated text and clean it
66
  raw = out[0].get("generated_text", "").strip()
67
  if not raw:
68
  return ""
69
+ # Remove the prompt if it appears in the output
70
  if raw.lower().startswith(prompt.lower()):
71
  story = raw[len(prompt):].strip()
72
  else:
73
  story = raw
74
+ # Truncate at the last full stop for a natural ending
75
  idx = story.rfind(".")
76
  if idx != -1:
77
  story = story[:idx+1]
78
  return story
79
 
80
  def part3_text_to_speech_bytes(text: str, tts_pipe) -> bytes:
81
+ # Convert the input text to audio using the text-to-speech model
82
  out = tts_pipe(text)
83
  if isinstance(out, list):
84
  out = out[0]
85
+ # Extract audio data (numpy array) and sampling rate
86
  audio_array = out["audio"] # np.ndarray (channels, samples)
87
  rate = out["sampling_rate"] # int
88
+ # Transpose audio array if it has multiple channels
89
  data = audio_array.T if audio_array.ndim == 2 else audio_array
90
+ # Convert audio to 16-bit PCM format for WAV compatibility
91
  pcm = (data * 32767).astype(np.int16)
92
 
93
+ # Create a WAV file in memory
94
  buffer = io.BytesIO()
95
  wf = wave.open(buffer, "wb")
96
+ channels = 1 if data.ndim == 1 else data.shape[1] # Set mono or stereo
97
  wf.setnchannels(channels)
98
+ wf.setsampwidth(2) # 2 bytes for 16-bit audio
99
+ wf.setframerate(rate) # Set sampling rate
100
+ wf.writeframes(pcm.tobytes()) # Write audio data
101
  wf.close()
102
+ buffer.seek(0) # Reset buffer to start for reading
103
+ return buffer.read() # Return WAV bytes
104
 
105
  # ——— 3) STREAMLIT UI ————————————————————————
106
+ # Configure the Streamlit page for a kid-friendly, centered layout
107
  st.set_page_config(
108
  page_title="Picture to Story Magic",
109
  page_icon="✨",
110
  layout="centered"
111
  )
112
 
113
+ # Apply custom CSS for a colorful, engaging, and readable interface
114
  st.markdown("""
115
  <style>
116
  .main {
117
+ background-color: #e6f3ff; /* Light blue background for main area */
118
  padding: 20px;
119
  border-radius: 15px;
120
  }
121
  .stButton>button {
122
+ background-color: #ffcccb; /* Pink button background */
123
  button-color: #000000;
124
  border-radius: 10px;
125
+ border: 2px solid #ff9999; /* Red border */
126
  font-size: 18px;
127
  font-weight: bold;
128
  padding: 10px 20px;
129
+ transition: all 0.3s; /* Smooth hover effect */
130
  }
131
  .stButton>button:hover {
132
+ background-color: #ff9999; /* Darker pink on hover */
133
  color: #ffffff;
134
+ transform: scale(1.05); /* Slight zoom on hover */
135
  }
136
  .stFileUploader {
137
+ background-color: #ffb300; /* Orange uploader background */
138
+ border: 2px dashed #ff8c00; /* Dashed orange border */
139
  border-radius: 10px;
140
  padding: 10px;
141
  }
142
  .stFileUploader div[role="button"] {
143
+ background-color: #f0f0f0; /* Light gray button */
144
  border-radius: 10px;
145
  padding: 10px;
146
  }
147
  .stFileUploader div[role="button"] > div {
148
+ color: #000000 !important; /* Black text for readability */
149
  font-size: 16px;
150
  }
151
  .stFileUploader button {
152
+ background-color: #ffca28 !important; /* Yellow button */
153
  color: #000000 !important;
154
  border-radius: 8px !important;
155
+ border: 2px solid #ffb300 !important; /* Orange border */
156
  padding: 5px 15px !important;
157
  font-weight: bold !important;
158
+ box-shadow: 0 2px 4px rgba(0,0,0,0.2) !important; /* Subtle shadow */
159
  }
160
  .stFileUploader button:hover {
161
+ background-color: #ff8c00 !important; /* Orange on hover */
162
  color: #000000 !important;
163
  }
164
  .stImage {
165
+ border: 3px solid #81c784; /* Green border for images */
166
  border-radius: 10px;
167
+ box-shadow: 0 4px 8px rgba(0,0,0,0.1); /* Soft shadow */
168
  }
169
  .section-header {
170
+ background-color: #b3e5fc; /* Light blue header background */
171
  padding: 10px;
172
  border-radius: 10px;
173
  text-align: center;
 
177
  margin-bottom: 10px;
178
  }
179
  .caption-box, .story-box {
180
+ background-color: #f0f4c3; /* Light yellow for text boxes */
181
  padding: 15px;
182
  border-radius: 10px;
183
+ border: 2px solid #d4e157; /* Green-yellow border */
184
  margin-bottom: 20px;
185
  color: #000000;
186
  }
187
  .caption-box b, .story-box b {
188
+ color: #000000; /* Black for bold text */
189
  }
190
  .stProgress > div > div {
191
+ background-color: #81c784; /* Green progress bar */
192
  }
193
  </style>
194
  """, unsafe_allow_html=True)
195
 
196
+ # Display the main title with a fun, magical theme
197
  st.markdown("<div class='section-header'>Picture to Story Magic! ✨</div>", unsafe_allow_html=True)
198
 
199
  # Image upload section
200
  with st.container():
201
+ # Prompt user to upload an image
202
  st.markdown("<div class='section-header'>1️⃣ Pick a Fun Picture! πŸ–ΌοΈ</div>", unsafe_allow_html=True)
203
  uploaded = st.file_uploader("Choose a picture to start the magic! 😊", type=["jpg","jpeg","png"])
204
  if not uploaded:
205
+ # Stop execution if no image is uploaded, with a friendly message
206
  st.info("Upload a picture, and let's make a story! πŸŽ‰")
207
  st.stop()
208
 
209
+ # Display the uploaded image
210
  with st.spinner("Looking at your picture..."):
211
  pil_img = Image.open(uploaded)
212
+ st.image(pil_img, use_container_width=True) # Show image scaled to container
213
 
214
+ # Caption generation section
215
  with st.container():
216
  st.markdown("<div class='section-header'>2️⃣ What's in the Picture? 🧐</div>", unsafe_allow_html=True)
217
+ captioner = get_image_captioner() # Load captioning model
218
+ progress_bar = st.progress(0) # Initialize progress bar
219
+ result = [None] # Store caption result
220
  def run_caption():
221
+ # Run captioning in a separate thread to avoid blocking UI
222
  result[0] = part1_image_to_text(pil_img, captioner)
223
  with st.spinner("Figuring out what's in your picture..."):
224
  thread = threading.Thread(target=run_caption)
225
  thread.start()
226
+ # Simulate progress for ~5 seconds
227
  for i in range(100):
228
  progress_bar.progress(i + 1)
229
+ time.sleep(0.05)
230
+ thread.join() # Wait for captioning to complete
231
+ progress_bar.empty() # Clear progress bar
232
  caption = result[0]
233
+ # Display the generated caption in a styled box
234
  st.markdown(f"<div class='caption-box'><b>Picture Description:</b><br>{caption}</div>", unsafe_allow_html=True)
235
 
236
+ # Story and audio generation section
237
  with st.container():
238
  st.markdown("<div class='section-header'>3️⃣ Your Story and Audio! 🎡</div>", unsafe_allow_html=True)
239
+ # Story generation
240
+ story_pipe = get_story_pipe() # Load story model
241
  progress_bar = st.progress(0)
242
+ result = [None] # Store story result
243
  def run_story():
244
+ # Generate story in a separate thread
245
  result[0] = part2_text_to_story(caption, story_pipe)
246
  with st.spinner("Writing a super cool story..."):
247
  thread = threading.Thread(target=run_story)
248
  thread.start()
249
+ # Simulate progress for ~7 seconds
250
  for i in range(100):
251
  progress_bar.progress(i + 1)
252
+ time.sleep(0.07)
253
  thread.join()
254
  progress_bar.empty()
255
  story = result[0]
256
+ # Display the generated story in a styled box
257
  st.markdown(f"<div class='story-box'><b>Your Cool Story! πŸ“š</b><br>{story}</div>", unsafe_allow_html=True)
258
 
259
+ # Text-to-speech conversion
260
+ tts_pipe = get_tts_pipe() # Load TTS model
261
  progress_bar = st.progress(0)
262
+ result = [None] # Store audio result
263
  def run_tts():
264
+ # Generate audio in a separate thread
265
  result[0] = part3_text_to_speech_bytes(story, tts_pipe)
266
  with st.spinner("Turning your story into sound..."):
267
  thread = threading.Thread(target=run_tts)
268
  thread.start()
269
+ # Simulate progress for ~10 seconds
270
  for i in range(100):
271
  progress_bar.progress(i + 1)
272
+ time.sleep(0.10)
273
  thread.join()
274
  progress_bar.empty()
275
  audio_bytes = result[0]
276
+ # Play the generated audio in the UI
277
  st.audio(audio_bytes, format="audio/wav")