Krish-05 committed on
Commit
aa6e621
·
verified ·
1 Parent(s): 9f4d48c

Update streamlit_app.py

Browse files
Files changed (1) hide show
  1. streamlit_app.py +138 -200
streamlit_app.py CHANGED
@@ -1,13 +1,14 @@
1
  import logging
2
  import logging.handlers
3
- import queue
4
  import threading
5
  import time
6
  import urllib.request
7
  import os
8
- from collections import deque
9
  from pathlib import Path
10
  from typing import List
 
 
 
11
 
12
  import av
13
  import numpy as np
@@ -21,11 +22,19 @@ HERE = Path(__file__).parent
21
 
22
  logger = logging.getLogger(__name__)
23
 
 
 
 
 
 
 
 
24
 
25
- # This code is based on https://github.com/streamlit/demo-self-driving/blob/230245391f2dda0cb464008195a470751c01770b/streamlit_app.py#L48 # noqa: E501
 
26
  def download_file(url, download_to: Path, expected_size=None):
27
- # Don't download the file twice.
28
- # (If possible, verify the download using the file length.)
29
  if download_to.exists():
30
  if expected_size:
31
  if download_to.stat().st_size == expected_size:
@@ -37,7 +46,6 @@ def download_file(url, download_to: Path, expected_size=None):
37
 
38
  download_to.parent.mkdir(parents=True, exist_ok=True)
39
 
40
- # These are handles to two visual elements to animate.
41
  weights_warning, progress_bar = None, None
42
  try:
43
  weights_warning = st.warning("Downloading %s..." % url)
@@ -54,13 +62,11 @@ def download_file(url, download_to: Path, expected_size=None):
54
  counter += len(data)
55
  output_file.write(data)
56
 
57
- # We perform animation by overwriting the elements.
58
  weights_warning.warning(
59
  "Downloading %s... (%6.2f/%6.2f MB)"
60
  % (url, counter / MEGABYTES, length / MEGABYTES)
61
  )
62
  progress_bar.progress(min(counter / length, 1.0))
63
- # Finally, we remove these visual elements by calling .empty().
64
  finally:
65
  if weights_warning is not None:
66
  weights_warning.empty()
@@ -68,234 +74,166 @@ def download_file(url, download_to: Path, expected_size=None):
68
  progress_bar.empty()
69
 
70
 
71
- # This code is based on https://github.com/whitphx/streamlit-webrtc/blob/c1fe3c783c9e8042ce0c95d789e833233fd82e74/sample_utils/turn.py
72
- @st.cache_data # type: ignore
73
  def get_ice_servers():
74
- """Use Twilio's TURN server because Streamlit Community Cloud has changed
75
- its infrastructure and WebRTC connection cannot be established without TURN server now. # noqa: E501
76
- We considered Open Relay Project (https://www.metered.ca/tools/openrelay/) too,
77
- but it is not stable and hardly works as some people reported like https://github.com/aiortc/aiortc/issues/832#issuecomment-1482420656 # noqa: E501
78
- See https://github.com/whitphx/streamlit-webrtc/issues/1213
79
- """
80
-
81
- # Ref: https://www.twilio.com/docs/stun-turn/api
82
  try:
83
  account_sid = os.environ["TWILIO_ACCOUNT_SID"]
84
  auth_token = os.environ["TWILIO_AUTH_TOKEN"]
85
  except KeyError:
86
  logger.warning(
87
- "Twilio credentials are not set. Fallback to a free STUN server from Google." # noqa: E501
88
  )
89
  return [{"urls": ["stun:stun.l.google.com:19302"]}]
90
 
91
  client = Client(account_sid, auth_token)
92
-
93
  token = client.tokens.create()
94
-
95
  return token.ice_servers
96
 
97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
  def main():
100
- st.header("Real Time Speech-to-Text")
101
  st.markdown(
102
  """
103
- This demo app is using [DeepSpeech](https://github.com/mozilla/DeepSpeech),
104
- an open speech-to-text engine.
105
-
106
- A pre-trained model released with
107
- [v0.9.3](https://github.com/mozilla/DeepSpeech/releases/tag/v0.9.3),
108
- trained on American English is being served.
109
- """
110
  )
111
 
112
- # https://github.com/mozilla/DeepSpeech/releases/tag/v0.9.3
113
- MODEL_URL = "https://github.com/mozilla/DeepSpeech/releases/download/v0.9.3/deepspeech-0.9.3-models.pbmm" # noqa
114
- LANG_MODEL_URL = "https://github.com/mozilla/DeepSpeech/releases/download/v0.9.3/deepspeech-0.9.3-models.scorer" # noqa
115
- MODEL_LOCAL_PATH = HERE / "models/deepspeech-0.9.3-models.pbmm"
116
- LANG_MODEL_LOCAL_PATH = HERE / "models/deepspeech-0.9.3-models.scorer"
117
-
118
- download_file(MODEL_URL, MODEL_LOCAL_PATH, expected_size=188915987)
119
- download_file(LANG_MODEL_URL, LANG_MODEL_LOCAL_PATH, expected_size=953363776)
120
-
121
- lm_alpha = 0.931289039105002
122
- lm_beta = 1.1834137581510284
123
- beam = 100
124
-
125
- sound_only_page = "Sound only (sendonly)"
126
- with_video_page = "With video (sendrecv)"
127
- app_mode = st.selectbox("Choose the app mode", [sound_only_page, with_video_page])
128
-
129
- if app_mode == sound_only_page:
130
- app_sst(
131
- str(MODEL_LOCAL_PATH), str(LANG_MODEL_LOCAL_PATH), lm_alpha, lm_beta, beam
132
- )
133
- elif app_mode == with_video_page:
134
- app_sst_with_video(
135
- str(MODEL_LOCAL_PATH), str(LANG_MODEL_LOCAL_PATH), lm_alpha, lm_beta, beam
136
- )
137
-
138
-
139
- def app_sst(model_path: str, lm_path: str, lm_alpha: float, lm_beta: float, beam: int):
140
  webrtc_ctx = webrtc_streamer(
141
- key="speech-to-text",
142
  mode=WebRtcMode.SENDONLY,
143
- audio_receiver_size=1024,
144
  rtc_configuration={"iceServers": get_ice_servers()},
145
  media_stream_constraints={"video": False, "audio": True},
 
146
  )
147
 
148
- status_indicator = st.empty()
149
-
150
- if not webrtc_ctx.state.playing:
151
- return
152
-
153
- status_indicator.write("Loading...")
154
- text_output = st.empty()
155
- stream = None
156
-
157
- while True:
158
- if webrtc_ctx.audio_receiver:
159
- if stream is None:
160
- from deepspeech import Model
161
-
162
- model = Model(model_path)
163
- model.enableExternalScorer(lm_path)
164
- model.setScorerAlphaBeta(lm_alpha, lm_beta)
165
- model.setBeamWidth(beam)
166
-
167
- stream = model.createStream()
168
-
169
- status_indicator.write("Model loaded.")
170
-
171
- sound_chunk = pydub.AudioSegment.empty()
172
- try:
173
- audio_frames = webrtc_ctx.audio_receiver.get_frames(timeout=1)
174
- except queue.Empty:
175
- time.sleep(0.1)
176
- status_indicator.write("No frame arrived.")
177
- continue
178
-
179
- status_indicator.write("Running. Say something!")
180
-
181
- for audio_frame in audio_frames:
182
- sound = pydub.AudioSegment(
183
- data=audio_frame.to_ndarray().tobytes(),
184
- sample_width=audio_frame.format.bytes,
185
- frame_rate=audio_frame.sample_rate,
186
- channels=len(audio_frame.layout.channels),
187
- )
188
- sound_chunk += sound
189
-
190
- if len(sound_chunk) > 0:
191
- sound_chunk = sound_chunk.set_channels(1).set_frame_rate(
192
- model.sampleRate()
193
- )
194
- buffer = np.array(sound_chunk.get_array_of_samples())
195
- stream.feedAudioContent(buffer)
196
- text = stream.intermediateDecode()
197
- text_output.markdown(f"**Text:** {text}")
198
- else:
199
- status_indicator.write("AudioReciver is not set. Abort.")
200
- break
201
-
202
-
203
- def app_sst_with_video(
204
- model_path: str, lm_path: str, lm_alpha: float, lm_beta: float, beam: int
205
- ):
206
- frames_deque_lock = threading.Lock()
207
- frames_deque: deque = deque([])
208
-
209
- async def queued_audio_frames_callback(
210
- frames: List[av.AudioFrame],
211
- ) -> av.AudioFrame:
212
- with frames_deque_lock:
213
- frames_deque.extend(frames)
214
-
215
- # Return empty frames to be silent.
216
- new_frames = []
217
- for frame in frames:
218
- input_array = frame.to_ndarray()
219
- new_frame = av.AudioFrame.from_ndarray(
220
- np.zeros(input_array.shape, dtype=input_array.dtype),
221
- layout=frame.layout.name,
222
- )
223
- new_frame.sample_rate = frame.sample_rate
224
- new_frames.append(new_frame)
225
 
226
- return new_frames
 
 
 
227
 
228
- webrtc_ctx = webrtc_streamer(
229
- key="speech-to-text-w-video",
230
- mode=WebRtcMode.SENDRECV,
231
- queued_audio_frames_callback=queued_audio_frames_callback,
232
- rtc_configuration={"iceServers": get_ice_servers()},
233
- media_stream_constraints={"video": True, "audio": True},
234
- )
235
 
236
- status_indicator = st.empty()
 
237
 
238
- if not webrtc_ctx.state.playing:
239
- return
 
 
 
 
 
 
 
 
240
 
241
- status_indicator.write("Loading...")
242
- text_output = st.empty()
243
- stream = None
244
 
245
- while True:
246
  if webrtc_ctx.state.playing:
247
- if stream is None:
248
- from deepspeech import Model
249
-
250
- model = Model(model_path)
251
- model.enableExternalScorer(lm_path)
252
- model.setScorerAlphaBeta(lm_alpha, lm_beta)
253
- model.setBeamWidth(beam)
254
-
255
- stream = model.createStream()
256
-
257
- status_indicator.write("Model loaded.")
258
-
259
- sound_chunk = pydub.AudioSegment.empty()
260
-
261
- audio_frames = []
262
- with frames_deque_lock:
263
- while len(frames_deque) > 0:
264
- frame = frames_deque.popleft()
265
- audio_frames.append(frame)
266
-
267
- if len(audio_frames) == 0:
268
- time.sleep(0.1)
269
- status_indicator.write("No frame arrived.")
270
- continue
271
-
272
- status_indicator.write("Running. Say something!")
273
-
274
- for audio_frame in audio_frames:
275
- sound = pydub.AudioSegment(
276
- data=audio_frame.to_ndarray().tobytes(),
277
- sample_width=audio_frame.format.bytes,
278
- frame_rate=audio_frame.sample_rate,
279
- channels=len(audio_frame.layout.channels),
280
- )
281
- sound_chunk += sound
282
-
283
- if len(sound_chunk) > 0:
284
- sound_chunk = sound_chunk.set_channels(1).set_frame_rate(
285
- model.sampleRate()
286
- )
287
- buffer = np.array(sound_chunk.get_array_of_samples())
288
- stream.feedAudioContent(buffer)
289
- text = stream.intermediateDecode()
290
- text_output.markdown(f"**Text:** {text}")
291
  else:
292
- status_indicator.write("Stopped.")
293
- break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
294
 
295
 
296
  if __name__ == "__main__":
297
- import os
298
-
299
  DEBUG = os.environ.get("DEBUG", "false").lower() not in ["false", "no", "0"]
300
 
301
  logging.basicConfig(
@@ -307,7 +245,7 @@ if __name__ == "__main__":
307
  logger.setLevel(level=logging.DEBUG if DEBUG else logging.INFO)
308
 
309
  st_webrtc_logger = logging.getLogger("streamlit_webrtc")
310
- st_webrtc_logger.setLevel(logging.DEBUG)
311
 
312
  fsevents_logger = logging.getLogger("fsevents")
313
  fsevents_logger.setLevel(logging.WARNING)
 
1
  import logging
2
  import logging.handlers
 
3
  import threading
4
  import time
5
  import urllib.request
6
  import os
 
7
  from pathlib import Path
8
  from typing import List
9
+ import io
10
+ import soundfile as sf
11
+ import requests
12
 
13
  import av
14
  import numpy as np
 
22
 
23
  logger = logging.getLogger(__name__)
24
 
25
+ # --- Session State Initialization ---
26
+ if 'is_recording' not in st.session_state:
27
+ st.session_state.is_recording = False
28
+ if 'transcribed_text' not in st.session_state:
29
+ st.session_state.transcribed_text = ""
30
+ if 'audio_processor_instance' not in st.session_state:
31
+ st.session_state.audio_processor_instance = None
32
 
33
+
34
+ # --- Utility Functions (from original code, kept for completeness) ---
35
  def download_file(url, download_to: Path, expected_size=None):
36
+ # This function is retained but might not be strictly necessary for this new workflow
37
+ # as Whisper model is loaded by FastAPI server.
38
  if download_to.exists():
39
  if expected_size:
40
  if download_to.stat().st_size == expected_size:
 
46
 
47
  download_to.parent.mkdir(parents=True, exist_ok=True)
48
 
 
49
  weights_warning, progress_bar = None, None
50
  try:
51
  weights_warning = st.warning("Downloading %s..." % url)
 
62
  counter += len(data)
63
  output_file.write(data)
64
 
 
65
  weights_warning.warning(
66
  "Downloading %s... (%6.2f/%6.2f MB)"
67
  % (url, counter / MEGABYTES, length / MEGABYTES)
68
  )
69
  progress_bar.progress(min(counter / length, 1.0))
 
70
  finally:
71
  if weights_warning is not None:
72
  weights_warning.empty()
 
74
  progress_bar.empty()
75
 
76
 
77
+ @st.cache_data
 
78
  def get_ice_servers():
79
+ """Fetches ICE servers for WebRTC connection."""
 
 
 
 
 
 
 
80
  try:
81
  account_sid = os.environ["TWILIO_ACCOUNT_SID"]
82
  auth_token = os.environ["TWILIO_AUTH_TOKEN"]
83
  except KeyError:
84
  logger.warning(
85
+ "Twilio credentials are not set. Fallback to a free STUN server from Google."
86
  )
87
  return [{"urls": ["stun:stun.l.google.com:19302"]}]
88
 
89
  client = Client(account_sid, auth_token)
 
90
  token = client.tokens.create()
 
91
  return token.ice_servers
92
 
93
 
94
+ # --- Custom Audio Processor for streamlit-webrtc ---
95
+ class AudioBufferProcessor(AudioProcessorBase):
96
+ def __init__(self) -> None:
97
+ self._audio_buffer = pydub.AudioSegment.empty()
98
+ self._lock = threading.Lock()
99
+
100
+ def recv(self, frame: av.AudioFrame) -> None:
101
+ if st.session_state.is_recording:
102
+ sound = pydub.AudioSegment(
103
+ data=frame.to_ndarray().tobytes(),
104
+ sample_width=frame.format.bytes,
105
+ frame_rate=frame.sample_rate,
106
+ channels=len(frame.layout.channels),
107
+ )
108
+ sound = sound.set_channels(1).set_frame_rate(16000)
109
+ with self._lock:
110
+ self._audio_buffer += sound
111
+
112
+ def get_and_clear_buffered_audio(self) -> pydub.AudioSegment:
113
+ with self._lock:
114
+ recorded_audio = self._audio_buffer
115
+ self._audio_buffer = pydub.AudioSegment.empty()
116
+ return recorded_audio
117
+
118
 
119
  def main():
120
+ st.header("Whisper Speech-to-Text with Recording")
121
  st.markdown(
122
  """
123
+ Click "Start Recording" to begin capturing audio from your microphone.
124
+ Click "Stop Recording" to end the capture, save the audio,
125
+ and send it to the Whisper model for transcription.
126
+ """
 
 
 
127
  )
128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  webrtc_ctx = webrtc_streamer(
130
+ key="audio_recorder",
131
  mode=WebRtcMode.SENDONLY,
132
+ audio_processor_factory=AudioBufferProcessor,
133
  rtc_configuration={"iceServers": get_ice_servers()},
134
  media_stream_constraints={"video": False, "audio": True},
135
+ async_processing=True
136
  )
137
 
138
+ if webrtc_ctx.audio_processor and st.session_state.audio_processor_instance is None:
139
+ st.session_state.audio_processor_instance = webrtc_ctx.audio_processor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
+ if webrtc_ctx.state.playing:
142
+ st.success("Microphone connected. Ready to record.")
143
+ else:
144
+ st.warning("Waiting for microphone connection... Please allow microphone access.")
145
 
 
 
 
 
 
 
 
146
 
147
+ # --- Recording Controls ---
148
+ col1, col2 = st.columns(2)
149
 
150
+ with col1:
151
+ start_button = st.button(
152
+ "Start Recording",
153
+ disabled=st.session_state.is_recording or not webrtc_ctx.state.playing
154
+ )
155
+ with col2:
156
+ stop_button = st.button(
157
+ "Stop Recording",
158
+ disabled=not st.session_state.is_recording
159
+ )
160
 
161
+ # Placeholder for the animated text area
162
+ transcription_text_area = st.text_area("Transcription Result", value="", height=150, disabled=True)
 
163
 
164
+ if start_button:
165
  if webrtc_ctx.state.playing:
166
+ st.session_state.is_recording = True
167
+ st.session_state.transcribed_text = ""
168
+ # Clear text area immediately
169
+ transcription_text_area.empty()
170
+ st.info("Recording... Click 'Stop Recording' to transcribe.")
171
+ logger.info("Recording started.")
172
+ st.rerun()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  else:
174
+ st.error("Cannot start recording: Microphone not connected. Please allow microphone access.")
175
+
176
+ if stop_button:
177
+ if st.session_state.is_recording:
178
+ st.session_state.is_recording = False
179
+ st.info("Processing recording... Please wait.")
180
+ logger.info("Recording stopped. Processing audio...")
181
+
182
+ if st.session_state.audio_processor_instance:
183
+ recorded_audio = st.session_state.audio_processor_instance.get_and_clear_buffered_audio()
184
+
185
+ if len(recorded_audio) > 0:
186
+ wav_file_buffer = io.BytesIO()
187
+ audio_array = np.array(recorded_audio.get_array_of_samples())
188
+ audio_array = audio_array.astype(np.float32)
189
+ sf.write(wav_file_buffer, audio_array, recorded_audio.frame_rate, format='WAV', subtype='PCM_16')
190
+ wav_file_buffer.seek(0)
191
+
192
+ WHISPER_API_URL = "http://localhost:1990/transcribe_audio/"
193
+ try:
194
+ files = {'audio_file': ('recorded_audio.wav', wav_file_buffer, 'audio/wav')}
195
+ response = requests.post(WHISPER_API_URL, files=files, timeout=120)
196
+ response.raise_for_status()
197
+ transcription_data = response.json()
198
+ full_transcribed_text = transcription_data.get("transcription", "No transcription found.")
199
+ st.session_state.transcribed_text = full_transcribed_text
200
+
201
+ # --- Character-by-character display logic ---
202
+ animated_text = ""
203
+ # Re-display the placeholder to clear previous content
204
+ transcription_text_area.empty()
205
+ for char in full_transcribed_text:
206
+ animated_text += char
207
+ transcription_text_area.text_area("Transcription Result", value=animated_text, height=150, disabled=True)
208
+ time.sleep(0.02) # Adjust speed as desired (e.g., 0.05 for slower)
209
+ # Ensure the final text is displayed
210
+ transcription_text_area.text_area("Transcription Result", value=full_transcribed_text, height=150, disabled=True)
211
+ # --- End character-by-character display logic ---
212
+
213
+ st.success("Transcription complete!")
214
+ logger.info(f"Transcription received: '{full_transcribed_text[:100]}...'")
215
+ except requests.exceptions.ConnectionError as e:
216
+ st.error(f"Could not connect to Whisper API at {WHISPER_API_URL}. Is the FastAPI server running on port 1990?")
217
+ logger.error(f"Connection Error: {e}", exc_info=True)
218
+ except requests.exceptions.Timeout:
219
+ st.error("Whisper API request timed out. The model might be busy or the audio too long. Try a shorter recording.")
220
+ logger.error("Request Timeout.", exc_info=True)
221
+ except requests.exceptions.RequestException as e:
222
+ st.error(f"Error during API request: {e}. Response: {e.response.text if e.response else 'No response'}")
223
+ logger.error(f"API Request Error: {e}", exc_info=True)
224
+ except Exception as e:
225
+ st.error(f"An unexpected error occurred during transcription: {e}")
226
+ logger.error(f"Unexpected Transcription Error: {e}", exc_info=True)
227
+
228
+ else:
229
+ st.warning("No audio recorded. Please ensure your microphone is active and you spoke.")
230
+ logger.warning("No audio recorded after stopping.")
231
+ else:
232
+ st.error("Audio processor instance not found. Please refresh the app and allow microphone access.")
233
+ st.rerun()
234
 
235
 
236
  if __name__ == "__main__":
 
 
237
  DEBUG = os.environ.get("DEBUG", "false").lower() not in ["false", "no", "0"]
238
 
239
  logging.basicConfig(
 
245
  logger.setLevel(level=logging.DEBUG if DEBUG else logging.INFO)
246
 
247
  st_webrtc_logger = logging.getLogger("streamlit_webrtc")
248
+ st_webrtc_logger.setLevel(logging.DEBUG if DEBUG else logging.INFO)
249
 
250
  fsevents_logger = logging.getLogger("fsevents")
251
  fsevents_logger.setLevel(logging.WARNING)