shingguy1 commited on
Commit
6543b0a
Β·
verified Β·
1 Parent(s): 2e503f2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +185 -159
app.py CHANGED
@@ -6,188 +6,214 @@ from transformers import pipeline
6
  from PIL import Image
7
  import numpy as np
8
 
9
- # --- 1) MODEL LOADING (cached) -----------------------------------------------
10
- # Cache resources to avoid reloading models on every interaction
11
  @st.cache_resource
12
  def get_image_captioner(model_name="Salesforce/blip-image-captioning-base"):
13
- """Initialize image-to-text model for generating image captions"""
14
- return pipeline("image-to-text", model=model_name, device="cpu")
15
 
16
  @st.cache_resource
17
  def get_story_pipe(model_name="google/flan-t5-base"):
18
- """Initialize text generation pipeline for story creation"""
19
- return pipeline("text2text-generation", model=model_name, device="cpu")
20
 
21
  @st.cache_resource
22
  def get_tts_pipe(model_name="facebook/mms-tts-eng"):
23
- """Initialize text-to-speech pipeline for audio generation"""
24
- return pipeline("text-to-speech", model=model_name, device="cpu")
25
 
26
- # --- 2) TRANSFORM FUNCTIONS --------------------------------------------------
27
  def part1_image_to_text(pil_img, captioner):
28
- """Generate text caption from PIL image using captioning model"""
29
- results = captioner(pil_img)
30
- return results[0].get("generated_text", "") if results else ""
31
 
32
  def part2_text_to_story(
33
- caption: str,
34
- story_pipe,
35
- target_words: int = 100,
36
- max_length: int = 100,
37
- min_length: int = 80,
38
- do_sample: bool = True,
39
- top_k: int = 100,
40
- top_p: float= 0.9,
41
- temperature: float= 0.7,
42
- repetition_penalty: float = 1.1,
43
- no_repeat_ngram_size: int = 4
44
  ) -> str:
45
- """
46
- Generate story from text caption using controlled text generation
47
- Args:
48
- caption: Input image description
49
- story_pipe: Initialized text generation pipeline
50
- Generation parameters control output quality/diversity
51
- """
52
- # Craft prompt with dynamic word count target
53
- prompt = (
54
- f"Write a vivid, imaginative short story of about {target_words} words "
55
- f"describing this scene: {caption}"
56
- )
57
-
58
- # Generate raw story text with specified generation parameters
59
- out = story_pipe(
60
- prompt,
61
- max_length=max_length,
62
- min_length=min_length,
63
- do_sample=do_sample,
64
- top_k=top_k,
65
- top_p=top_p,
66
- temperature=temperature,
67
- repetition_penalty=repetition_penalty,
68
- no_repeat_ngram_size=no_repeat_ngram_size,
69
- early_stopping=False
70
- )
71
-
72
- # Post-process generated text
73
- raw = out[0].get("generated_text", "").strip()
74
- if not raw:
75
- return ""
76
-
77
- # Remove prompt echo if present
78
- if raw.lower().startswith(prompt.lower()):
79
- story = raw[len(prompt):].strip()
80
- else:
81
- story = raw
82
-
83
- # Ensure story ends with proper punctuation
84
- idx = story.rfind(".")
85
- return story[:idx+1] if idx != -1 else story
86
 
87
  def part3_text_to_speech_bytes(text: str, tts_pipe) -> bytes:
88
- """Convert generated story text to WAV audio bytes"""
89
- out = tts_pipe(text)
90
- if isinstance(out, list):
91
- out = out[0]
92
-
93
- # Process audio output
94
- audio_array = out["audio"] # Shape: (channels, samples)
95
- rate = out["sampling_rate"] # Typically 16000 for TTS models
96
-
97
- # Convert to PCM format
98
- data = audio_array.T if audio_array.ndim == 2 else audio_array
99
- pcm = (data * 32767).astype(np.int16) # Convert to 16-bit PCM
100
-
101
- # Create in-memory WAV file
102
- buffer = io.BytesIO()
103
- with wave.open(buffer, "wb") as wf:
104
- wf.setnchannels(1 if data.ndim == 1 else data.shape[1])
105
- wf.setsampwidth(2) # 16-bit samples
106
- wf.setframerate(rate)
107
- wf.writeframes(pcm.tobytes())
108
- buffer.seek(0)
109
- return buffer.read()
110
-
111
- # --- 3) STREAMLIT UI --------------------------------------------------------
112
- # Configure page settings (must be first Streamlit command)
113
  st.set_page_config(
114
- page_title="Picture to Story Magic",
115
- page_icon="✨",
116
- layout="centered"
117
  )
118
 
119
- # Custom CSS for enhanced accessibility and visual appeal
120
  st.markdown("""
121
- <style>
122
- /* Main container styling */
123
- .main {
124
- background-color: #e6f3ff; /* Soft blue background */
125
- padding: 20px;
126
- border-radius: 15px;
127
- }
128
- /* ... (other CSS rules maintained as original) ... */
129
- </style>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  """, unsafe_allow_html=True)
131
 
132
- # --- APP INTERFACE COMPONENTS ------------------------------------------------
133
- # Main header with emoji
134
- st.markdown("<div class='section-header'>Picture to Story Magic! ✨</div>",
135
- unsafe_allow_html=True)
136
 
137
  # Image upload section
138
  with st.container():
139
- st.markdown("<div class='section-header'>1️⃣ Pick a Fun Picture! πŸ–ΌοΈ</div>",
140
- unsafe_allow_html=True)
141
- uploaded = st.file_uploader("Choose a picture to start the magic! 😊",
142
- type=["jpg","jpeg","png"])
143
-
144
- # Early exit if no image uploaded
145
- if not uploaded:
146
- st.info("Upload a picture, and let's make a story! πŸŽ‰")
147
- st.stop()
148
-
149
- # Display uploaded image
150
- with st.spinner("Looking at your picture..."):
151
- pil_img = Image.open(uploaded)
152
- st.image(pil_img, use_column_width=True)
153
-
154
- # Caption generation section
155
  with st.container():
156
- captioner = get_image_captioner()
157
- with st.spinner("Figuring out what's in your picture..."):
158
- caption = part1_image_to_text(pil_img, captioner)
159
- st.markdown(
160
- f"<div class='caption-box'>"
161
- f"<b>What's in the Picture? 🧐</b><br>{caption}"
162
- f"</div>",
163
- unsafe_allow_html=True
164
- )
165
-
166
- # Story generation and audio section
167
  with st.container():
168
- st.markdown("<div class='section-header'>2️⃣ Make a Story and Hear It! 🎡</div>",
169
- unsafe_allow_html=True)
170
-
171
- if st.button("Create My Story! πŸŽ‰"):
172
- # Generate story text
173
- story_pipe = get_story_pipe()
174
- with st.spinner("Writing a super cool story..."):
175
- story = part2_text_to_story(caption, story_pipe)
176
-
177
- # Display generated story
178
- st.markdown(
179
- f"<div class='story-box'>"
180
- f"<b>Your Cool Story! πŸ“š</b><br>{story}"
181
- f"</div>",
182
- unsafe_allow_html=True
183
- )
184
-
185
- # Generate and display audio
186
- tts_pipe = get_tts_pipe()
187
- with st.spinner("Turning your story into sound..."):
188
- audio_bytes = part3_text_to_speech_bytes(story, tts_pipe)
189
- st.audio(audio_bytes, format="audio/wav")
190
-
191
- # Success indicators
192
- st.success("Yay! Your story is ready! 🎈")
193
- st.balloons()
 
6
  from PIL import Image
7
  import numpy as np
8
 
9
+ # β€”β€”β€” 1) MODEL LOADING (cached) β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
 
10
  @st.cache_resource
11
  def get_image_captioner(model_name="Salesforce/blip-image-captioning-base"):
12
+ return pipeline("image-to-text", model=model_name, device="cpu")
 
13
 
14
  @st.cache_resource
15
  def get_story_pipe(model_name="google/flan-t5-base"):
16
+ return pipeline("text2text-generation", model=model_name, device="cpu")
 
17
 
18
  @st.cache_resource
19
  def get_tts_pipe(model_name="facebook/mms-tts-eng"):
20
+ return pipeline("text-to-speech", model=model_name, device="cpu")
 
21
 
22
+ # β€”β€”β€” 2) TRANSFORM FUNCTIONS β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
23
  def part1_image_to_text(pil_img, captioner):
24
+ results = captioner(pil_img)
25
+ return results[0].get("generated_text", "") if results else ""
 
26
 
27
  def part2_text_to_story(
28
+ caption: str,
29
+ story_pipe,
30
+ target_words: int = 100,
31
+ max_length: int = 100,
32
+ min_length: int = 80,
33
+ do_sample: bool = True,
34
+ top_k: int = 100,
35
+ top_p: float= 0.9,
36
+ temperature: float= 0.7,
37
+ repetition_penalty: float = 1.1,
38
+ no_repeat_ngram_size: int = 4
39
  ) -> str:
40
+ prompt = (
41
+ f"Write a vivid, imaginative short story of about {target_words} words "
42
+ f"describing this scene: {caption}"
43
+ )
44
+ out = story_pipe(
45
+ prompt,
46
+ max_length=max_length,
47
+ min_length=min_length,
48
+ do_sample=do_sample,
49
+ top_k=top_k,
50
+ top_p=top_p,
51
+ temperature=temperature,
52
+ repetition_penalty=repetition_penalty,
53
+ no_repeat_ngram_size=no_repeat_ngram_size,
54
+ early_stopping=False
55
+ )
56
+ raw = out[0].get("generated_text", "").strip()
57
+ if not raw:
58
+ return ""
59
+ # strip echo of prompt
60
+ if raw.lower().startswith(prompt.lower()):
61
+ story = raw[len(prompt):].strip()
62
+ else:
63
+ story = raw
64
+ # cut at last full stop
65
+ idx = story.rfind(".")
66
+ if idx != -1:
67
+ story = story[:idx+1]
68
+ return story
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
  def part3_text_to_speech_bytes(text: str, tts_pipe) -> bytes:
71
+ out = tts_pipe(text)
72
+ if isinstance(out, list):
73
+ out = out[0]
74
+ audio_array = out["audio"] # np.ndarray (channels, samples)
75
+ rate = out["sampling_rate"] # int
76
+ data = audio_array.T if audio_array.ndim == 2 else audio_array
77
+ pcm = (data * 32767).astype(np.int16)
78
+
79
+ buffer = io.BytesIO()
80
+ wf = wave.open(buffer, "wb")
81
+ channels = 1 if data.ndim == 1 else data.shape[1]
82
+ wf.setnchannels(channels)
83
+ wf.setsampwidth(2)
84
+ wf.setframerate(rate)
85
+ wf.writeframes(pcm.tobytes())
86
+ wf.close()
87
+ buffer.seek(0)
88
+ return buffer.read()
89
+
90
+ # β€”β€”β€” 3) STREAMLIT UI β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
91
+ # Set page config as the first Streamlit command
 
 
 
 
92
  st.set_page_config(
93
+ page_title="Picture to Story Magic",
94
+ page_icon="✨",
95
+ layout="centered"
96
  )
97
 
98
+ # Custom CSS for kid-friendly styling with improved readability
99
  st.markdown("""
100
+ <style>
101
+ .main {
102
+ background-color: #e6f3ff;
103
+ padding: 20px;
104
+ border-radius: 15px;
105
+ }
106
+ .stButton>button {
107
+ background-color: #ffcccb;
108
+ color: #000000; /* Black text */
109
+ border-radius: 10px;
110
+ border: 2px solid #ff9999;
111
+ font-size: 18px;
112
+ font-weight: bold;
113
+ padding: 10px 20px;
114
+ transition: all 0.3s;
115
+ }
116
+ .stButton>button:hover {
117
+ background-color: #ff9999;
118
+ color: #ffffff; /* White text on hover for contrast */
119
+ transform: scale(1.05);
120
+ }
121
+ .stFileUploader {
122
+ background-color: #ffb300; /* Darker yellow for better contrast with white label text */
123
+ border: 2px dashed #ff8c00; /* Darker orange border to match */
124
+ border-radius: 10px;
125
+ padding: 10px;
126
+ }
127
+ /* Style for the file uploader's inner text */
128
+ .stFileUploader div[role="button"] {
129
+ background-color: #f0f0f0; /* Very light gray background for contrast with black text */
130
+ border-radius: 10px;
131
+ padding: 10px;
132
+ }
133
+ .stFileUploader div[role="button"] > div {
134
+ color: #000000 !important; /* Black text */
135
+ font-size: 16px;
136
+ }
137
+ /* Style for the "Browse files" button inside the file uploader */
138
+ .stFileUploader button {
139
+ background-color: #ffca28 !important; /* Yellow button background */
140
+ color: #000000 !important; /* Black text */
141
+ border-radius: 8px !important;
142
+ border: 2px solid #ffb300 !important; /* Match the container background */
143
+ padding: 5px 15px !important;
144
+ font-weight: bold !important;
145
+ box-shadow: 0 2px 4px rgba(0,0,0,0.2) !important; /* Subtle shadow to make button stand out */
146
+ }
147
+ .stFileUploader button:hover {
148
+ background-color: #ff8c00 !important; /* Slightly darker yellow on hover */
149
+ color: #000000 !important; /* Keep black text */
150
+ }
151
+ .stImage {
152
+ border: 3px solid #81c784;
153
+ border-radius: 10px;
154
+ box-shadow: 0 4px 8px rgba(0,0,0,0.1);
155
+ }
156
+ .section-header {
157
+ background-color: #b3e5fc;
158
+ padding: 10px;
159
+ border-radius: 10px;
160
+ text-align: center;
161
+ font-size: 24px;
162
+ font-weight: bold;
163
+ color: #000000; /* Black text */
164
+ margin-bottom: 10px;
165
+ }
166
+ .caption-box, .story-box {
167
+ background-color: #f0f4c3;
168
+ padding: 15px;
169
+ border-radius: 10px;
170
+ border: 2px solid #d4e157;
171
+ margin-bottom: 20px;
172
+ color: #000000; /* Black text */
173
+ }
174
+ .caption-box b, .story-box b {
175
+ color: #000000; /* Black text for bold headers */
176
+ }
177
+ </style>
178
  """, unsafe_allow_html=True)
179
 
180
+ # Main title
181
+ st.markdown("<div class='section-header'>Picture to Story Magic! ✨</div>", unsafe_allow_html=True)
 
 
182
 
183
  # Image upload section
184
  with st.container():
185
+ st.markdown("<div class='section-header'>1️⃣ Pick a Fun Picture! πŸ–ΌοΈ</div>", unsafe_allow_html=True)
186
+ uploaded = st.file_uploader("Choose a picture to start the magic! 😊", type=["jpg","jpeg","png"])
187
+ if not uploaded:
188
+ st.info("Upload a picture, and let's make a story! πŸŽ‰")
189
+ st.stop()
190
+
191
+ # Show image
192
+ with st.spinner("Looking at your picture..."):
193
+ pil_img = Image.open(uploaded)
194
+ st.image(pil_img, use_container_width=True)
195
+
196
+ # Caption section
 
 
 
 
197
  with st.container():
198
+ captioner = get_image_captioner()
199
+ with st.spinner("Figuring out what's in your picture..."):
200
+ caption = part1_image_to_text(pil_img, captioner)
201
+ st.markdown(f"<div class='caption-box'><b>What's in the Picture? 🧐</b><br>{caption}</div>", unsafe_allow_html=True)
202
+
203
+ # Story and audio section
 
 
 
 
 
204
  with st.container():
205
+ st.markdown("<div class='section-header'>2️⃣ Make a Story and Hear It! 🎡</div>", unsafe_allow_html=True)
206
+ if st.button("Create My Story! πŸŽ‰"):
207
+ # Story
208
+ story_pipe = get_story_pipe()
209
+ with st.spinner("Writing a super cool story..."):
210
+ story = part2_text_to_story(caption, story_pipe)
211
+ st.markdown(f"<div class='story-box'><b>Your Cool Story! πŸ“š</b><br>{story}</div>", unsafe_allow_html=True)
212
+
213
+ # TTS
214
+ tts_pipe = get_tts_pipe()
215
+ with st.spinner("Turning your story into sound..."):
216
+ audio_bytes = part3_text_to_speech_bytes(story, tts_pipe)
217
+ st.audio(audio_bytes, format="audio/wav")
218
+ st.success("Yay! Your story is ready! 🎈")
219
+ st.balloons() # Fun animation