rjzevallos committed on
Commit
20547d7
·
1 Parent(s): 4ae7ed6

Fix: send 'FINISH' text over WebSocket on stop to match server

Browse files
Files changed (1) hide show
  1. app.py +77 -397
app.py CHANGED
@@ -1,415 +1,95 @@
1
- import asyncio
2
- import logging
3
- from fastapi import FastAPI, UploadFile, File, WebSocket
4
- from fastapi.responses import JSONResponse, StreamingResponse
5
  import gradio as gr
 
6
  import numpy as np
7
- import io
8
-
9
- import server_wrapper
10
-
11
- logging.basicConfig(level=logging.INFO)
12
- logger = logging.getLogger(__name__)
13
-
14
- app = FastAPI(title="SimulStreaming ASR")
15
-
16
-
17
- @app.on_event("startup")
18
- async def startup_event():
19
- logger.info("Starting up... initializing model.")
20
- loop = asyncio.get_event_loop()
21
- try:
22
- logger.info("Downloading Whisper model if not already present...")
23
- await loop.run_in_executor(None, _ensure_model_downloaded)
24
- logger.info("Model ready.")
25
-
26
- await loop.run_in_executor(None, server_wrapper.init_model)
27
- logger.info("Model initialized successfully.")
28
- except Exception as e:
29
- logger.error(f"Error during model initialization: {e}")
30
-
31
-
32
- def _ensure_model_downloaded():
33
- """Ensure the Whisper model is downloaded."""
34
- import os
35
- model_dir = os.path.expanduser('~/.cache/whisper')
36
- model_path = os.path.join(model_dir, 'large-v3.pt')
37
-
38
- if not os.path.exists(model_path):
39
- try:
40
- logger.info(f"Downloading Whisper large-v3 model to {model_path}...")
41
- import whisper
42
- whisper.load_model('large-v3')
43
- logger.info("Model downloaded successfully.")
44
- except Exception as e:
45
- logger.warning(f"Could not pre-download model: {e}")
46
  else:
47
- logger.info(f"Model already present at {model_path}")
48
-
49
-
50
- @app.post("/api/reset")
51
- async def api_reset():
52
- try:
53
- server_wrapper.reset()
54
- return JSONResponse({"status": "ok"})
55
- except Exception as e:
56
- return JSONResponse({"status": "error", "message": str(e)}, status_code=500)
57
-
58
-
59
- @app.post("/api/chunk")
60
- async def api_chunk(file: UploadFile = File(...)):
61
- """Process a single audio chunk (streaming)."""
62
- try:
63
- raw = await file.read()
64
- out = await asyncio.get_event_loop().run_in_executor(None, server_wrapper.process_chunk_from_bytes, raw)
65
- return JSONResponse(out or {"text": ""})
66
- except Exception as e:
67
- logger.error(f"Error processing chunk: {e}")
68
- return JSONResponse({"status": "error", "message": str(e)}, status_code=500)
69
-
70
 
71
- @app.post("/api/finish")
72
- async def api_finish():
73
- """Finish streaming and return final transcription."""
74
- try:
75
- out = await asyncio.get_event_loop().run_in_executor(None, server_wrapper.finish)
76
- return JSONResponse(out or {"text": ""})
77
- except Exception as e:
78
- logger.error(f"Error finishing: {e}")
79
- return JSONResponse({"status": "error", "message": str(e)}, status_code=500)
80
 
81
 
82
- @app.websocket("/ws/audio")
83
- async def websocket_audio(websocket: WebSocket):
84
- """WebSocket endpoint for real-time audio streaming."""
85
- await websocket.accept()
86
- logger.info("WebSocket connection established")
87
-
88
- try:
89
- server_wrapper.reset()
90
-
91
- while True:
92
- # Accept either binary frames (audio) or text frames (control messages like FINISH)
93
- message = await websocket.receive()
94
- data = None
95
- is_text = False
96
- if 'bytes' in message and message['bytes'] is not None:
97
- data = message['bytes']
98
- elif 'text' in message and message['text'] is not None:
99
- data = message['text']
100
- is_text = True
101
 
102
- if is_text:
103
- # Control messages
104
- if data == "FINISH":
105
- result = await asyncio.get_event_loop().run_in_executor(None, server_wrapper.finish)
106
- await websocket.send_json({"type": "finish", **(result or {})})
107
- break
108
- elif data == "RESET":
109
- server_wrapper.reset()
110
- await websocket.send_json({"type": "reset", "status": "ok"})
111
- else:
112
- # Unknown text message - ignore or log
113
- logger.debug(f"Unknown WS text message: {data}")
114
- else:
115
- # Binary audio chunk (or binary control marker)
116
- try:
117
- # If client sent the 4-byte control marker 0xFF 0xFF 0xFF 0xFF, treat as FINISH
118
- if isinstance(data, (bytes, bytearray)) and data == b"\xFF\xFF\xFF\xFF":
119
- result = await asyncio.get_event_loop().run_in_executor(None, server_wrapper.finish)
120
- await websocket.send_json({"type": "finish", **(result or {})})
121
- break
122
 
123
- result = await asyncio.get_event_loop().run_in_executor(None, server_wrapper.process_chunk_from_bytes, data)
124
- if result and result.get("text"):
125
- await websocket.send_json({"type": "update", "text": result["text"]})
126
- except Exception as e:
127
- logger.error(f"Error processing audio chunk via websocket: {e}")
128
- await websocket.send_json({"type": "error", "message": str(e)})
129
-
130
- except Exception as e:
131
- logger.error(f"WebSocket error: {e}")
132
- try:
133
- await websocket.send_json({"type": "error", "message": str(e)})
134
- except:
135
- pass
136
- finally:
137
- await websocket.close()
138
- logger.info("WebSocket connection closed")
139
 
140
 
141
- def create_ui():
142
- with gr.Blocks(title="Streaming ASR", theme=gr.themes.Soft()) as demo:
143
- gr.Markdown("""
144
- # 🎙️ Streaming ASR — SimulWhisper
145
-
146
- Transcripción en tiempo real mientras hablas.
147
-
148
- **Instrucciones:**
149
- 1. Haz clic en **"🔴 Start Recording"**
150
- 2. Habla naturalmente - verás la transcripción EN TIEMPO REAL
151
- 3. Haz clic en **"⏹️ Stop Recording"** cuando termines
152
- """)
153
-
154
- with gr.Row():
155
- start_btn = gr.Button("🔴 Start Recording", size="lg", variant="primary", scale=1)
156
- stop_btn = gr.Button("⏹️ Stop Recording", size="lg", variant="stop", scale=1)
157
-
158
- with gr.Row():
159
- with gr.Column(scale=1):
160
- gr.Markdown("### Status")
161
- status = gr.Textbox(
162
- value="Ready",
163
- interactive=False,
164
- show_label=False,
165
- lines=2
166
- )
167
-
168
- with gr.Column(scale=2):
169
- gr.Markdown("### 📝 Transcripción en Vivo")
170
- transcript = gr.Textbox(
171
- show_label=False,
172
- lines=8,
173
- interactive=False,
174
- placeholder="La transcripción aparecerá aquí en tiempo real..."
175
- )
176
-
177
- # JavaScript para captura real-time con WebSocket
178
- html_js = """
179
- <script>
180
- let mediaRecorder;
181
- let audioCtx;
182
- let source;
183
- let processor;
184
- let recording = false;
185
- let ws = null;
186
- let chunkSize = 16000 * 0.5; // 0.5 seconds at 16kHz
187
- let startBtn = null;
188
- let stopBtn = null;
189
- let statusDiv = null;
190
- let transcriptDiv = null;
191
 
192
- function to16BitPCM(float32Array) {
193
- const l = float32Array.length;
194
- const buffer = new ArrayBuffer(l * 2);
195
- const view = new DataView(buffer);
196
- let offset = 0;
197
- for (let i = 0; i < l; i++) {
198
- let s = Math.max(-1, Math.min(1, float32Array[i]));
199
- view.setInt16(offset, s < 0 ? s * 0x8000 : s * 0x7FFF, true);
200
- offset += 2;
201
- }
202
- return buffer;
203
- }
204
 
205
- function writeWAV(samples, sampleRate) {
206
- const buffer = new ArrayBuffer(44 + samples.byteLength);
207
- const view = new DataView(buffer);
208
- function writeString(view, offset, string) {
209
- for (let i = 0; i < string.length; i++) {
210
- view.setUint8(offset + i, string.charCodeAt(i));
211
- }
212
- }
213
- writeString(view, 0, 'RIFF');
214
- view.setUint32(4, 36 + samples.byteLength, true);
215
- writeString(view, 8, 'WAVE');
216
- writeString(view, 12, 'fmt ');
217
- view.setUint32(16, 16, true);
218
- view.setUint16(20, 1, true);
219
- view.setUint16(22, 1, true);
220
- view.setUint32(24, sampleRate, true);
221
- view.setUint32(28, sampleRate * 2, true);
222
- view.setUint16(32, 2, true);
223
- view.setUint16(34, 16, true);
224
- writeString(view, 36, 'data');
225
- view.setUint32(40, samples.byteLength, true);
226
- const bytes = new Uint8Array(buffer, 44);
227
- bytes.set(new Uint8Array(samples));
228
- return buffer;
229
- }
230
 
231
- async function resampleAudio(float32Array, fromSampleRate, toSampleRate) {
232
- if (fromSampleRate === toSampleRate) {
233
- return float32Array;
234
- }
235
- const length = Math.round(float32Array.length * toSampleRate / fromSampleRate);
236
- const offlineCtx = new OfflineAudioContext(1, length, toSampleRate);
237
- const buffer = offlineCtx.createBuffer(1, float32Array.length, fromSampleRate);
238
- buffer.copyToChannel(float32Array, 0, 0);
239
- const src = offlineCtx.createBufferSource();
240
- src.buffer = buffer;
241
- src.connect(offlineCtx.destination);
242
- src.start(0);
243
- const rendered = await offlineCtx.startRendering();
244
- return rendered.getChannelData(0);
245
- }
246
 
247
- async function sendChunk(float32Array, sampleRate) {
248
- if (!ws || ws.readyState !== WebSocket.OPEN) return;
249
-
250
- try {
251
- let resampled = await resampleAudio(float32Array, sampleRate, 16000);
252
- const pcm16 = to16BitPCM(resampled);
253
- const wav = writeWAV(pcm16, 16000);
254
- ws.send(wav);
255
- } catch (e) {
256
- console.error('Error sending chunk:', e);
257
- }
258
- }
259
-
260
- async function startRecording() {
261
- try {
262
- if (recording) return;
263
-
264
- console.log('Starting recording...');
265
-
266
- // Connect WebSocket
267
- const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
268
- ws = new WebSocket(protocol + '//' + window.location.host + '/ws/audio');
269
-
270
- ws.onopen = () => {
271
- console.log('WebSocket connected');
272
- updateStatus('🔴 Recording... listening');
273
- };
274
-
275
- ws.onmessage = (event) => {
276
- const data = JSON.parse(event.data);
277
- if (data.type === 'update' && data.text) {
278
- updateTranscript(data.text);
279
- } else if (data.type === 'finish') {
280
- console.log('Transcription finished:', data);
281
- updateStatus('✅ Done');
282
- }
283
- };
284
-
285
- ws.onerror = (error) => {
286
- console.error('WebSocket error:', error);
287
- updateStatus('❌ Connection error');
288
- };
289
-
290
- ws.onclose = () => {
291
- console.log('WebSocket closed');
292
- recording = false;
293
- };
294
-
295
- // Start audio capture
296
- recording = true;
297
- audioCtx = new (window.AudioContext || window.webkitAudioContext)();
298
- const stream = await navigator.mediaDevices.getUserMedia({
299
- audio: {
300
- echoCancellation: false,
301
- noiseSuppression: false,
302
- autoGainControl: false
303
- }
304
- });
305
- source = audioCtx.createMediaStreamSource(stream);
306
- processor = audioCtx.createScriptProcessor(4096, 1, 1);
307
-
308
- let buffer = [];
309
-
310
- processor.onaudioprocess = function(e) {
311
- const ch = e.inputBuffer.getChannelData(0);
312
- for (let i = 0; i < ch.length; i++) {
313
- buffer.push(ch[i]);
314
- }
315
-
316
- // Send chunk every 0.5 seconds
317
- if (buffer.length >= chunkSize) {
318
- const chunk = new Float32Array(buffer.slice(0, chunkSize));
319
- buffer = buffer.slice(chunkSize);
320
- sendChunk(chunk, audioCtx.sampleRate);
321
- }
322
- };
323
-
324
- source.connect(processor);
325
- processor.connect(audioCtx.destination);
326
-
327
- } catch (e) {
328
- console.error('Error starting recording:', e);
329
- updateStatus('❌ Error: ' + e.message);
330
- recording = false;
331
- }
332
- }
333
-
334
- function stopRecording() {
335
- if (!recording) return;
336
-
337
- recording = false;
338
- updateStatus('⏹️ Stopping...');
339
-
340
- if (source && source.mediaStream) {
341
- const tracks = source.mediaStream.getTracks();
342
- tracks.forEach(t => t.stop());
343
- }
344
- if (processor) processor.disconnect();
345
- if (source) source.disconnect();
346
-
347
- // Send finish signal (binary marker) so server recognizes it
348
- if (ws && ws.readyState === WebSocket.OPEN) {
349
- ws.send(new Uint8Array([0xFF, 0xFF, 0xFF, 0xFF]));
350
- setTimeout(() => {
351
- if (ws) ws.close();
352
- }, 500);
353
- }
354
- }
355
-
356
- function updateTranscript(text) {
357
- const textareas = document.querySelectorAll('textarea');
358
- if (textareas.length >= 2) {
359
- textareas[1].value = text;
360
- textareas[1].dispatchEvent(new Event('input', { bubbles: true }));
361
- }
362
- }
363
-
364
- function updateStatus(text) {
365
- const textareas = document.querySelectorAll('textarea');
366
- if (textareas.length >= 1) {
367
- textareas[0].value = text;
368
- textareas[0].dispatchEvent(new Event('input', { bubbles: true }));
369
- }
370
- }
371
-
372
- // Find and attach button listeners
373
- function attachButtons() {
374
- const buttons = document.querySelectorAll('button');
375
- console.log('Found ' + buttons.length + ' buttons');
376
-
377
- if (buttons.length >= 2) {
378
- startBtn = buttons[0];
379
- stopBtn = buttons[1];
380
-
381
- startBtn.addEventListener('click', startRecording);
382
- stopBtn.addEventListener('click', stopRecording);
383
-
384
- console.log('Buttons attached successfully');
385
- }
386
- }
387
-
388
- // Try to attach buttons when page loads
389
- document.addEventListener('DOMContentLoaded', () => {
390
- console.log('DOM loaded');
391
- setTimeout(attachButtons, 1000);
392
- });
393
-
394
- // Also try immediately
395
- setTimeout(attachButtons, 500);
396
- </script>
397
- """
398
-
399
- gr.HTML(html_js)
400
-
401
- return demo
402
-
403
-
404
- demo = create_ui()
405
-
406
- # Mount Gradio app on FastAPI
407
- app = gr.mount_gradio_app(app, demo, path="/")
408
 
 
 
 
409
 
410
  if __name__ == "__main__":
411
- import uvicorn
412
- uvicorn.run(app, host="0.0.0.0", port=7860)
413
-
414
-
415
 
 
 
1
import time

import gradio as gr
import librosa
import numpy as np
# import soundfile as sf
from transformers import pipeline

# All incoming audio is resampled to this rate before classification.
TARGET_SAMPLE_RATE = 16_000
# Seconds of buffered audio to accumulate before running the classifier.
AUDIO_SECONDS_THRESHOLD = 2
# Audio-event classifier (AST fine-tuned on AudioSet); loaded once at import time.
pipe = pipeline("audio-classification", model="MIT/ast-finetuned-audioset-10-10-0.4593")
# Placeholder prediction shown while the streaming buffer is still filling;
# rebound by streaming_recording_fn via `global prediction`.
prediction = [{"score": 1, "label": "recording..."}]
13
+
14
+
15
def normalize_waveform(waveform, datatype=np.float32):  # source datatype: np.int16
    """Convert an int16 PCM waveform to *datatype*, scaled into [-1.0, 1.0)."""
    return waveform.astype(dtype=datatype) / 32768.0
19
+
20
+
21
def streaming_recording_fn(stream, new_chunk):
    """Gradio streaming callback: buffer microphone chunks, classify periodically.

    Args:
        stream: Accumulated 16 kHz float waveform carried in Gradio "state"
            (None on the first invocation).
        new_chunk: Tuple ``(sample_rate, int16_waveform)`` from gr.Audio.

    Returns:
        Tuple of (updated stream state, ``{label: score}`` mapping for the
        latest prediction). The classifier runs only once at least
        AUDIO_SECONDS_THRESHOLD seconds of audio have accumulated; until
        then the previous module-level prediction is re-used.
    """
    global prediction
    sr, y = new_chunk
    y = normalize_waveform(y)
    y = librosa.resample(y, orig_sr=sr, target_sr=TARGET_SAMPLE_RATE)
    if stream is not None:
        if (stream.shape[-1] / TARGET_SAMPLE_RATE) >= AUDIO_SECONDS_THRESHOLD:
            prediction = pipe(stream)
            file_name = f'./audio/{time.strftime("%Y%m%d_%H%M%S", time.localtime())}.wav'
            # # sf.write(file_name, stream, TARGET_SAMPLE_RATE)
            print(f"SAVE AUDIO: {file_name}")
            print(f">>>>>>1\t{y.shape=}, {stream.shape=}\n\t{prediction[0]=}")
            # Bug fix: seed the next buffer with the chunk just received.
            # The original set `stream = None`, silently dropping this chunk
            # and leaving a gap in the captured audio.
            stream = y
        else:
            stream = np.concatenate([stream, y], axis=-1)
            print(f">>>>>>2\t{y.shape=}, {stream.shape=}")
    else:
        stream = y
        print(f">>>>>>3\t{y.shape=}, {stream.shape=}")

    return stream, {i['label']: i['score'] for i in prediction}
 
 
 
 
 
 
 
 
42
 
43
 
44
def microphone_fn(waveform):
    """Classify a complete (non-streaming) microphone recording.

    Args:
        waveform: Tuple ``(sample_rate, int16_waveform)`` from gr.Audio.

    Returns:
        Mapping of classifier label -> score, suitable for gr.Label.
    """
    print('-' * 120)
    print(f"{waveform=}")
    sr, y = waveform
    y = normalize_waveform(y)
    y = librosa.resample(y, orig_sr=sr, target_sr=TARGET_SAMPLE_RATE)
    result = pipe(y)
    # Clip saving is disabled; the original computed file_name on every call
    # even though sf.write was commented out. Re-enable together if needed:
    # file_name = f'./audio/{time.strftime("%Y%m%d_%H%M%S", time.localtime())}.wav'
    # sf.write(file_name, y, TARGET_SAMPLE_RATE)
    return {i['label']: i['score'] for i in result}
 
 
 
 
 
 
 
 
 
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
def file_fn(waveform):
    """Classify an uploaded audio file (decoded by Gradio to numpy).

    Args:
        waveform: Tuple ``(sample_rate, int16_waveform)`` from gr.Audio.

    Returns:
        Mapping of classifier label -> score, suitable for gr.Label.
    """
    print('-' * 120)
    print(f"{waveform=}")
    sr, y = waveform
    y = normalize_waveform(y)
    y = librosa.resample(y, orig_sr=sr, target_sr=TARGET_SAMPLE_RATE)
    result = pipe(y)
    # Clip saving is disabled; the original computed file_name on every call
    # even though sf.write was commented out. Re-enable together if needed:
    # file_name = f'./audio/{time.strftime("%Y%m%d_%H%M%S", time.localtime())}.wav'
    # sf.write(file_name, y, TARGET_SAMPLE_RATE)
    return {i['label']: i['score'] for i in result}
 
 
 
 
 
 
66
 
67
 
68
# Live streaming tab: gr.Audio(streaming=True) + live=True make Gradio call
# streaming_recording_fn repeatedly with (state, newest_chunk); the "state"
# input/output pair round-trips the accumulated waveform between calls.
streaming_demo = gr.Interface(
    fn=streaming_recording_fn,
    inputs=["state", gr.Audio(sources=["microphone"], streaming=True)],
    outputs=["state", "label"],
    live=True,
)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
# Example-file demo: run_on_click=True runs file_fn on the bundled clip when
# an example row is clicked.
# NOTE(review): this Blocks ("example") is constructed but never mounted in
# the TabbedInterface below — confirm whether it should be exposed as a tab.
with gr.Blocks() as example:
    inputs = [gr.Audio(sources=["upload"], type="numpy")]
    output = gr.Label()

    examples = [
        ["audio/cantina.wav"],
        ["audio/cat.mp3"]
    ]
    ex = gr.Examples(examples,
                     fn=file_fn, inputs=inputs, outputs=output,
                     run_on_click=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
# Top-level app: a single tabbed view wrapping the streaming interface.
with gr.Blocks() as demo:
    gr.TabbedInterface([streaming_demo],
                       ["Streaming"])

if __name__ == "__main__":
    # share=True opens a public Gradio tunnel in addition to the local server.
    demo.launch(share=True)