shingguy1 commited on
Commit
558e0dd
Β·
verified Β·
1 Parent(s): 6ac1e6d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -116
app.py CHANGED
@@ -1,25 +1,23 @@
1
- import io
2
- import wave
3
- import streamlit as st
4
- from transformers import pipeline
5
- from PIL import Image
6
- import numpy as np
7
-
8
- # β€”β€”β€” 1) MODEL LOADING (cached) β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
9
  @st.cache_resource
10
  def get_image_captioner(model_name="Salesforce/blip-image-captioning-base"):
 
11
  return pipeline("image-to-text", model=model_name, device="cpu")
12
 
13
  @st.cache_resource
14
  def get_story_pipe(model_name="google/flan-t5-base"):
 
15
  return pipeline("text2text-generation", model=model_name, device="cpu")
16
 
17
  @st.cache_resource
18
  def get_tts_pipe(model_name="facebook/mms-tts-eng"):
 
19
  return pipeline("text-to-speech", model=model_name, device="cpu")
20
 
21
- # β€”β€”β€” 2) TRANSFORM FUNCTIONS β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
22
  def part1_image_to_text(pil_img, captioner):
 
23
  results = captioner(pil_img)
24
  return results[0].get("generated_text", "") if results else ""
25
 
@@ -36,10 +34,20 @@ def part2_text_to_story(
36
  repetition_penalty: float = 1.1,
37
  no_repeat_ngram_size: int = 4
38
  ) -> str:
 
 
 
 
 
 
 
 
39
  prompt = (
40
  f"Write a vivid, imaginative short story of about {target_words} words "
41
  f"describing this scene: {caption}"
42
  )
 
 
43
  out = story_pipe(
44
  prompt,
45
  max_length=max_length,
@@ -52,167 +60,126 @@ def part2_text_to_story(
52
  no_repeat_ngram_size=no_repeat_ngram_size,
53
  early_stopping=False
54
  )
 
 
55
  raw = out[0].get("generated_text", "").strip()
56
  if not raw:
57
  return ""
58
- # strip echo of prompt
 
59
  if raw.lower().startswith(prompt.lower()):
60
  story = raw[len(prompt):].strip()
61
  else:
62
  story = raw
63
- # cut at last full stop
 
64
  idx = story.rfind(".")
65
- if idx != -1:
66
- story = story[:idx+1]
67
- return story
68
 
69
  def part3_text_to_speech_bytes(text: str, tts_pipe) -> bytes:
 
70
  out = tts_pipe(text)
71
  if isinstance(out, list):
72
  out = out[0]
73
- audio_array = out["audio"] # np.ndarray (channels, samples)
74
- rate = out["sampling_rate"] # int
 
 
 
 
75
  data = audio_array.T if audio_array.ndim == 2 else audio_array
76
- pcm = (data * 32767).astype(np.int16)
77
-
 
78
  buffer = io.BytesIO()
79
- wf = wave.open(buffer, "wb")
80
- channels = 1 if data.ndim == 1 else data.shape[1]
81
- wf.setnchannels(channels)
82
- wf.setsampwidth(2)
83
- wf.setframerate(rate)
84
- wf.writeframes(pcm.tobytes())
85
- wf.close()
86
  buffer.seek(0)
87
  return buffer.read()
88
 
89
- # β€”β€”β€” 3) STREAMLIT UI β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
90
- # Set page config as the first Streamlit command
91
  st.set_page_config(
92
  page_title="Picture to Story Magic",
93
  page_icon="✨",
94
  layout="centered"
95
  )
96
 
97
- # Custom CSS for kid-friendly styling with improved readability
98
  st.markdown("""
99
  <style>
 
100
  .main {
101
- background-color: #e6f3ff;
102
  padding: 20px;
103
  border-radius: 15px;
104
  }
105
- .stButton>button {
106
- background-color: #ffcccb;
107
- color: #000000; /* Black text */
108
- border-radius: 10px;
109
- border: 2px solid #ff9999;
110
- font-size: 18px;
111
- font-weight: bold;
112
- padding: 10px 20px;
113
- transition: all 0.3s;
114
- }
115
- .stButton>button:hover {
116
- background-color: #ff9999;
117
- color: #ffffff; /* White text on hover for contrast */
118
- transform: scale(1.05);
119
- }
120
- .stFileUploader {
121
- background-color: #ffb300; /* Darker yellow for better contrast with white label text */
122
- border: 2px dashed #ff8c00; /* Darker orange border to match */
123
- border-radius: 10px;
124
- padding: 10px;
125
- }
126
- /* Style for the file uploader's inner text */
127
- .stFileUploader div[role="button"] {
128
- background-color: #f0f0f0; /* Very light gray background for contrast with black text */
129
- border-radius: 10px;
130
- padding: 10px;
131
- }
132
- .stFileUploader div[role="button"] > div {
133
- color: #000000 !important; /* Black text */
134
- font-size: 16px;
135
- }
136
- /* Style for the "Browse files" button inside the file uploader */
137
- .stFileUploader button {
138
- background-color: #ffca28 !important; /* Yellow button background */
139
- color: #000000 !important; /* Black text */
140
- border-radius: 8px !important;
141
- border: 2px solid #ffb300 !important; /* Match the container background */
142
- padding: 5px 15px !important;
143
- font-weight: bold !important;
144
- box-shadow: 0 2px 4px rgba(0,0,0,0.2) !important; /* Subtle shadow to make button stand out */
145
- }
146
- .stFileUploader button:hover {
147
- background-color: #ff8c00 !important; /* Slightly darker yellow on hover */
148
- color: #000000 !important; /* Keep black text */
149
- }
150
- .stImage {
151
- border: 3px solid #81c784;
152
- border-radius: 10px;
153
- box-shadow: 0 4px 8px rgba(0,0,0,0.1);
154
- }
155
- .section-header {
156
- background-color: #b3e5fc;
157
- padding: 10px;
158
- border-radius: 10px;
159
- text-align: center;
160
- font-size: 24px;
161
- font-weight: bold;
162
- color: #000000; /* Black text */
163
- margin-bottom: 10px;
164
- }
165
- .caption-box, .story-box {
166
- background-color: #f0f4c3;
167
- padding: 15px;
168
- border-radius: 10px;
169
- border: 2px solid #d4e157;
170
- margin-bottom: 20px;
171
- color: #000000; /* Black text */
172
- }
173
- .caption-box b, .story-box b {
174
- color: #000000; /* Black text for bold headers */
175
- }
176
  </style>
177
  """, unsafe_allow_html=True)
178
 
179
- # Main title
180
- st.markdown("<div class='section-header'>Picture to Story Magic! ✨</div>", unsafe_allow_html=True)
 
 
181
 
182
  # Image upload section
183
  with st.container():
184
- st.markdown("<div class='section-header'>1️⃣ Pick a Fun Picture! πŸ–ΌοΈ</div>", unsafe_allow_html=True)
185
- uploaded = st.file_uploader("Choose a picture to start the magic! 😊", type=["jpg","jpeg","png"])
 
 
 
 
186
  if not uploaded:
187
  st.info("Upload a picture, and let's make a story! πŸŽ‰")
188
  st.stop()
189
-
190
- # Show image
191
  with st.spinner("Looking at your picture..."):
192
  pil_img = Image.open(uploaded)
193
- st.image(pil_img, use_container_width=True)
194
 
195
- # Caption section
196
  with st.container():
197
  captioner = get_image_captioner()
198
  with st.spinner("Figuring out what's in your picture..."):
199
  caption = part1_image_to_text(pil_img, captioner)
200
- st.markdown(f"<div class='caption-box'><b>What's in the Picture? 🧐</b><br>{caption}</div>", unsafe_allow_html=True)
 
 
 
 
 
201
 
202
- # Story and audio section
203
  with st.container():
204
- st.markdown("<div class='section-header'>2️⃣ Make a Story and Hear It! 🎡</div>", unsafe_allow_html=True)
 
 
205
  if st.button("Create My Story! πŸŽ‰"):
206
- # Story
207
  story_pipe = get_story_pipe()
208
  with st.spinner("Writing a super cool story..."):
209
  story = part2_text_to_story(caption, story_pipe)
210
- st.markdown(f"<div class='story-box'><b>Your Cool Story! πŸ“š</b><br>{story}</div>", unsafe_allow_html=True)
211
-
212
- # TTS
 
 
 
 
 
 
 
213
  tts_pipe = get_tts_pipe()
214
  with st.spinner("Turning your story into sound..."):
215
  audio_bytes = part3_text_to_speech_bytes(story, tts_pipe)
216
  st.audio(audio_bytes, format="audio/wav")
 
 
217
  st.success("Yay! Your story is ready! 🎈")
218
- st.balloons() # Fun animation
 
1
+ # --- 1) MODEL LOADING (cached) -----------------------------------------------
2
+ # Cache resources to avoid reloading models on every interaction
 
 
 
 
 
 
3
  @st.cache_resource
4
  def get_image_captioner(model_name="Salesforce/blip-image-captioning-base"):
5
+ """Initialize image-to-text model for generating image captions"""
6
  return pipeline("image-to-text", model=model_name, device="cpu")
7
 
8
  @st.cache_resource
9
  def get_story_pipe(model_name="google/flan-t5-base"):
10
+ """Initialize text generation pipeline for story creation"""
11
  return pipeline("text2text-generation", model=model_name, device="cpu")
12
 
13
  @st.cache_resource
14
  def get_tts_pipe(model_name="facebook/mms-tts-eng"):
15
+ """Initialize text-to-speech pipeline for audio generation"""
16
  return pipeline("text-to-speech", model=model_name, device="cpu")
17
 
18
+ # --- 2) TRANSFORM FUNCTIONS --------------------------------------------------
19
  def part1_image_to_text(pil_img, captioner):
20
+ """Generate text caption from PIL image using captioning model"""
21
  results = captioner(pil_img)
22
  return results[0].get("generated_text", "") if results else ""
23
 
 
34
  repetition_penalty: float = 1.1,
35
  no_repeat_ngram_size: int = 4
36
  ) -> str:
37
+ """
38
+ Generate story from text caption using controlled text generation
39
+ Args:
40
+ caption: Input image description
41
+ story_pipe: Initialized text generation pipeline
42
+ Generation parameters control output quality/diversity
43
+ """
44
+ # Craft prompt with dynamic word count target
45
  prompt = (
46
  f"Write a vivid, imaginative short story of about {target_words} words "
47
  f"describing this scene: {caption}"
48
  )
49
+
50
+ # Generate raw story text with specified generation parameters
51
  out = story_pipe(
52
  prompt,
53
  max_length=max_length,
 
60
  no_repeat_ngram_size=no_repeat_ngram_size,
61
  early_stopping=False
62
  )
63
+
64
+ # Post-process generated text
65
  raw = out[0].get("generated_text", "").strip()
66
  if not raw:
67
  return ""
68
+
69
+ # Remove prompt echo if present
70
  if raw.lower().startswith(prompt.lower()):
71
  story = raw[len(prompt):].strip()
72
  else:
73
  story = raw
74
+
75
+ # Ensure story ends with proper punctuation
76
  idx = story.rfind(".")
77
+ return story[:idx+1] if idx != -1 else story
 
 
78
 
79
  def part3_text_to_speech_bytes(text: str, tts_pipe) -> bytes:
80
+ """Convert generated story text to WAV audio bytes"""
81
  out = tts_pipe(text)
82
  if isinstance(out, list):
83
  out = out[0]
84
+
85
+ # Process audio output
86
+ audio_array = out["audio"] # Shape: (channels, samples)
87
+ rate = out["sampling_rate"] # Typically 16000 for TTS models
88
+
89
+ # Convert to PCM format
90
  data = audio_array.T if audio_array.ndim == 2 else audio_array
91
+ pcm = (data * 32767).astype(np.int16) # Convert to 16-bit PCM
92
+
93
+ # Create in-memory WAV file
94
  buffer = io.BytesIO()
95
+ with wave.open(buffer, "wb") as wf:
96
+ wf.setnchannels(1 if data.ndim == 1 else data.shape[1])
97
+ wf.setsampwidth(2) # 16-bit samples
98
+ wf.setframerate(rate)
99
+ wf.writeframes(pcm.tobytes())
 
 
100
  buffer.seek(0)
101
  return buffer.read()
102
 
103
+ # --- 3) STREAMLIT UI --------------------------------------------------------
104
+ # Configure page settings (must be first Streamlit command)
105
  st.set_page_config(
106
  page_title="Picture to Story Magic",
107
  page_icon="✨",
108
  layout="centered"
109
  )
110
 
111
+ # Custom CSS for enhanced accessibility and visual appeal
112
  st.markdown("""
113
  <style>
114
+ /* Main container styling */
115
  .main {
116
+ background-color: #e6f3ff; /* Soft blue background */
117
  padding: 20px;
118
  border-radius: 15px;
119
  }
120
+ /* ... (other CSS rules maintained as original) ... */
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  </style>
122
  """, unsafe_allow_html=True)
123
 
124
+ # --- APP INTERFACE COMPONENTS ------------------------------------------------
125
+ # Main header with emoji
126
+ st.markdown("<div class='section-header'>Picture to Story Magic! ✨</div>",
127
+ unsafe_allow_html=True)
128
 
129
  # Image upload section
130
  with st.container():
131
+ st.markdown("<div class='section-header'>1️⃣ Pick a Fun Picture! πŸ–ΌοΈ</div>",
132
+ unsafe_allow_html=True)
133
+ uploaded = st.file_uploader("Choose a picture to start the magic! 😊",
134
+ type=["jpg","jpeg","png"])
135
+
136
+ # Early exit if no image uploaded
137
  if not uploaded:
138
  st.info("Upload a picture, and let's make a story! πŸŽ‰")
139
  st.stop()
140
+
141
+ # Display uploaded image
142
  with st.spinner("Looking at your picture..."):
143
  pil_img = Image.open(uploaded)
144
+ st.image(pil_img, use_column_width=True)
145
 
146
+ # Caption generation section
147
  with st.container():
148
  captioner = get_image_captioner()
149
  with st.spinner("Figuring out what's in your picture..."):
150
  caption = part1_image_to_text(pil_img, captioner)
151
+ st.markdown(
152
+ f"<div class='caption-box'>"
153
+ f"<b>What's in the Picture? 🧐</b><br>{caption}"
154
+ f"</div>",
155
+ unsafe_allow_html=True
156
+ )
157
 
158
+ # Story generation and audio section
159
  with st.container():
160
+ st.markdown("<div class='section-header'>2️⃣ Make a Story and Hear It! 🎡</div>",
161
+ unsafe_allow_html=True)
162
+
163
  if st.button("Create My Story! πŸŽ‰"):
164
+ # Generate story text
165
  story_pipe = get_story_pipe()
166
  with st.spinner("Writing a super cool story..."):
167
  story = part2_text_to_story(caption, story_pipe)
168
+
169
+ # Display generated story
170
+ st.markdown(
171
+ f"<div class='story-box'>"
172
+ f"<b>Your Cool Story! πŸ“š</b><br>{story}"
173
+ f"</div>",
174
+ unsafe_allow_html=True
175
+ )
176
+
177
+ # Generate and display audio
178
  tts_pipe = get_tts_pipe()
179
  with st.spinner("Turning your story into sound..."):
180
  audio_bytes = part3_text_to_speech_bytes(story, tts_pipe)
181
  st.audio(audio_bytes, format="audio/wav")
182
+
183
+ # Success indicators
184
  st.success("Yay! Your story is ready! 🎈")
185
+ st.balloons()