SalexAI committed on
Commit
9e0fd09
·
verified ·
1 Parent(s): 9ab5c56

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +306 -53
app.py CHANGED
@@ -1,81 +1,334 @@
1
  import os
2
- import requests
3
- from fastapi import FastAPI, Request
 
 
 
 
 
4
  from fastapi.middleware.cors import CORSMiddleware
5
- from fastapi.responses import JSONResponse
 
6
 
7
# FastAPI application exposing the ephemeral-token proxy endpoints below.
app = FastAPI()

# Allow ScratchX / PenguinMod
# CORS is wide open so browser-hosted Scratch extensions can reach the proxy.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # 🔒 restrict later if desired
    allow_credentials=True,
    allow_methods=["GET", "POST", "OPTIONS"],
    allow_headers=["*"],
)

# Secret key read from the environment; its absence is reported at request
# time by _mint_ephemeral so the app can still boot without it.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
# OpenAI endpoint that mints short-lived Realtime session tokens.
OPENAI_REALTIME_URL = "https://api.openai.com/v1/realtime/sessions"
 
 
 
 
20
 
 
 
 
 
21
 
22
def _mint_ephemeral(model: str, voice: str):
    """Mint an ephemeral OpenAI Realtime session token.

    Calls the OpenAI /v1/realtime/sessions endpoint and returns the parsed
    JSON session object on success. On any failure (missing API key, HTTP
    error, network error) a JSONResponse with status 500 is returned instead,
    so callers can hand the result straight back to the client.

    Args:
        model: Realtime model name to request the session for.
        voice: Voice preset to bind to the session.
    """
    if not OPENAI_API_KEY:
        return JSONResponse(
            status_code=500,
            content={"error": "OPENAI_API_KEY not set in environment"},
        )

    headers = {
        "Authorization": f"Bearer {OPENAI_API_KEY}",
        "Content-Type": "application/json",
        # Opt-in header required while the Realtime API is in beta.
        "OpenAI-Beta": "realtime=v1",
    }
    body = {"model": model, "voice": voice}

    try:
        # BUGFIX: requests has no default timeout — without one a stalled
        # upstream connection would hang this worker indefinitely.
        r = requests.post(OPENAI_REALTIME_URL, headers=headers, json=body, timeout=30)
        r.raise_for_status()
        return r.json()
    except Exception as e:
        # Surface the failure as a structured 500 instead of a raw traceback.
        return JSONResponse(status_code=500, content={"error": str(e)})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
 
 
 
 
 
 
 
44
 
45
- # --- Health endpoints ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
@app.get("/health")
@app.get("/health/")
@app.get("/proxy/health")
@app.get("/proxy/health/")
def health():
    """Liveness probe, reachable with and without the /proxy prefix."""
    return dict(status="ok")
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
@app.middleware("http")
async def log_requests(request: Request, call_next):
    """HTTP middleware that prints every incoming request for debugging."""
    print(f"[DEBUG] Incoming: {request.method} {request.url.path}")
    return await call_next(request)
 
59
 
 
 
 
 
 
60
 
61
# --- Ephemeral endpoints ---
@app.get("/ephemeral")
@app.get("/ephemeral/")
@app.get("/proxy/ephemeral")
@app.get("/proxy/ephemeral/")
def ephemeral_get(model: str = "gpt-4o-realtime-preview", voice: str = "verse"):
    """GET variant: mint a session using query-string (or default) params."""
    return _mint_ephemeral(model=model, voice=voice)
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
@app.post("/ephemeral")
@app.post("/ephemeral/")
@app.post("/proxy/ephemeral")
@app.post("/proxy/ephemeral/")
async def ephemeral_post(request: Request):
    """POST variant: read model/voice from the JSON body, defaulting on any parse error."""
    default_model = "gpt-4o-realtime-preview"
    default_voice = "verse"
    try:
        payload = await request.json()
        model = payload.get("model", default_model)
        voice = payload.get("voice", default_voice)
    except Exception:
        # Missing/invalid JSON body: fall back to the defaults.
        model, voice = default_model, default_voice
    return _mint_ephemeral(model, voice)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import json
3
+ import asyncio
4
+ import threading
5
+ import numpy as np
6
+ from scipy.signal import resample
7
+
8
+ from fastapi import FastAPI, WebSocket, Request, Response
9
  from fastapi.middleware.cors import CORSMiddleware
10
+ from fastapi.responses import HTMLResponse, JSONResponse, PlainTextResponse
11
+ from fastapi.staticfiles import StaticFiles
12
 
13
+ # Realtime STT
14
+ from RealtimeSTT import AudioToTextRecorder
15
 
16
+ # ----------------------------
17
+ # App + CORS
18
+ # ----------------------------
19
# ----------------------------
# App + CORS
# ----------------------------
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # tighten if desired
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ----------------------------
# Global recorder (singleton)
# ----------------------------
# AudioToTextRecorder instance; constructed lazily inside _recorder_thread.
recorder = None
# Set once the recorder has been constructed and can accept audio.
recorder_ready = threading.Event()
# Cleared on shutdown to terminate the recorder polling loop.
is_running = True

# Active websocket(s) to stream results to (basic single-client model)
# You can turn this into a set() if you want multi-client broadcast.
active_ws: WebSocket | None = None
main_loop = None  # asyncio loop for scheduling cross-thread sends
39
 
40
+ # ----------------------------
41
+ # RealtimeSTT callbacks
42
+ # ----------------------------
43
async def _send_to_client(payload: dict):
    """Serialize *payload* as JSON and push it to the connected client, if any."""
    global active_ws
    ws = active_ws
    if ws is None:
        return
    try:
        await ws.send_text(json.dumps(payload))
    except Exception:
        # Best effort: the client has probably disconnected already.
        pass
52
+
53
def on_realtime_text(text: str):
    """Recorder-thread callback: forward a stabilized realtime partial transcript.

    Runs on the recorder thread, so the coroutine is scheduled onto the main
    asyncio loop rather than awaited directly.
    """
    global main_loop
    loop = main_loop
    if not loop:
        return
    message = {"type": "realtime", "text": text}
    asyncio.run_coroutine_threadsafe(_send_to_client(message), loop)
61
 
62
+ # ----------------------------
63
+ # Recorder thread
64
+ # ----------------------------
65
def _recorder_thread():
    """Background thread: build the RealtimeSTT recorder and pump final sentences.

    Constructs the AudioToTextRecorder singleton, signals readiness via
    recorder_ready, then loops until shutdown, forwarding each completed
    sentence to the connected websocket client via the main asyncio loop.
    """
    import time  # local import: only this worker thread needs it

    global recorder, is_running, main_loop
    cfg = {
        "spinner": False,
        "use_microphone": False,  # we feed audio via .feed_audio
        "model": "large-v2",      # adjust if your hardware is limited
        "language": "en",
        "silero_sensitivity": 0.4,
        "webrtc_sensitivity": 2,
        "post_speech_silence_duration": 0.7,
        "min_length_of_recording": 0,
        "min_gap_between_recordings": 0,
        "enable_realtime_transcription": True,
        "realtime_processing_pause": 0,
        "realtime_model_type": "tiny.en",  # fast streaming model
        "on_realtime_transcription_stabilized": on_realtime_text,
    }

    recorder = AudioToTextRecorder(**cfg)
    recorder_ready.set()

    # Continuously poll for final sentences and forward them to the client.
    while is_running:
        try:
            full_sentence = recorder.text()
            if full_sentence and main_loop:
                asyncio.run_coroutine_threadsafe(
                    _send_to_client({"type": "fullSentence", "text": full_sentence}),
                    main_loop,
                )
        except Exception:
            # BUGFIX: the original `continue` busy-spun at 100% CPU when
            # recorder.text() failed repeatedly; back off briefly instead.
            time.sleep(0.1)
98
+
99
# Start recorder thread once on startup
@app.on_event("startup")
async def _startup():
    """Capture the running event loop and launch the recorder thread at boot."""
    global main_loop
    main_loop = asyncio.get_running_loop()
    worker = threading.Thread(target=_recorder_thread, daemon=True)
    worker.start()
    # Hold startup (up to 2 minutes) until the STT model has finished loading.
    recorder_ready.wait(timeout=120)
107
+
108
@app.on_event("shutdown")
async def _shutdown():
    """Stop the polling loop and tear down the recorder on app shutdown."""
    global is_running, recorder
    is_running = False
    if not recorder:
        return
    try:
        recorder.stop()
        recorder.shutdown()
    except Exception:
        # Teardown is best-effort; swallow recorder shutdown errors.
        pass
118
+
119
+ # ----------------------------
120
+ # Audio helpers
121
+ # ----------------------------
122
def decode_and_resample(audio_bytes: bytes, orig_sr: int, target_sr: int = 16000) -> bytes:
    """Resample a little-endian PCM16 buffer from orig_sr to target_sr.

    Returns the buffer unchanged when the rates already match, an empty
    bytes object for empty input, and — as a best-effort fallback — the
    original chunk if resampling raises for any reason.
    """
    try:
        samples = np.frombuffer(audio_bytes, dtype=np.int16)
        if orig_sr == target_sr:
            return samples.tobytes()
        if samples.size == 0:
            return b""
        out_len = int(samples.size * target_sr / orig_sr)
        return resample(samples, out_len).astype(np.int16).tobytes()
    except Exception:
        # Best-effort: hand back the raw chunk rather than dropping audio.
        return audio_bytes
137
+
138
# ----------------------------
# WebSocket: /ws
# Frame format: [4-byte little-endian length][UTF-8 JSON metadata][PCM16 payload]
# metadata: {"sampleRate": 48000}
# ----------------------------
@app.websocket("/ws")
async def ws_endpoint(ws: WebSocket):
    """Receive framed PCM16 audio chunks and feed them to the STT recorder.

    Each binary frame is: a 4-byte little-endian metadata length, JSON
    metadata (currently just the client's sampleRate), then raw PCM16
    samples. Malformed frames are skipped silently. Only one client is
    tracked at a time via the module-level active_ws.
    """
    global active_ws
    await ws.accept()
    # NOTE(review): this silently displaces any previously connected client
    # (single-client model) — confirm that is acceptable for your use case.
    active_ws = ws

    # Ensure recorder is ready (warn the client but keep the socket open)
    if not recorder_ready.is_set():
        await ws.send_text(json.dumps({"type": "error", "error": "Recorder not ready"}))

    try:
        while True:
            # Expect a single binary message per chunk
            data = await ws.receive_bytes()

            # Parse metadata length; drop frames too short to hold the header
            if len(data) < 4:
                continue
            meta_len = int.from_bytes(data[:4], byteorder="little", signed=False)
            if 4 + meta_len > len(data):
                # Declared metadata overruns the frame; skip it
                continue

            # Parse metadata JSON; fall back to 48 kHz on any parse error
            meta_json = data[4:4+meta_len].decode("utf-8", errors="ignore")
            try:
                meta = json.loads(meta_json)
                sample_rate = int(meta.get("sampleRate", 48000))
            except Exception:
                sample_rate = 48000

            # PCM16 payload follows the metadata
            chunk = data[4+meta_len:]
            if not chunk:
                continue

            # Convert to 16k mono PCM16
            resampled = decode_and_resample(chunk, sample_rate, 16000)

            # Feed into the recorder
            try:
                recorder.feed_audio(resampled)
            except Exception:
                # recorder not ready or an intermittent error; ignore this chunk
                pass

    except Exception:
        # connection closed or error
        pass
    finally:
        # mark inactive (only if we are still the registered client)
        if active_ws is ws:
            active_ws = None
195
+
196
+ # ----------------------------
197
+ # Health
198
+ # ----------------------------
199
@app.get("/health")
def health():
    """Simple liveness probe."""
    return dict(status="ok")
202
 
203
# ----------------------------
# Frontend: index.html + client JS
# ----------------------------
# Single-page demo client served at "/": captures mic audio with a
# ScriptProcessor node, converts Float32 samples to PCM16LE, prepends the
# [4-byte len][JSON meta] header, and streams frames over the /ws endpoint.
# NOTE: this is a runtime string literal — its content is what the browser
# receives; do not edit it for cosmetic reasons.
INDEX_HTML = """<!doctype html>
<html>
<head>
<meta charset="utf-8" />
<title>Realtime STT (HF Space)</title>
<style>
body { font-family: system-ui, -apple-system, Segoe UI, Roboto, sans-serif; margin: 24px; }
.row { margin: 12px 0; }
#log { white-space: pre-wrap; background: #111; color: #0f0; padding: 12px; border-radius: 8px; height: 240px; overflow:auto; }
button { padding: 8px 12px; border-radius: 8px; border: 1px solid #888; background: #222; color: #fff; cursor: pointer; }
input[type=number] { width: 100px; }
label { display: inline-block; min-width: 130px; }
</style>
</head>
<body>
<h2>Realtime STT WebSocket Demo</h2>
<div class="row">
<label>Sample Rate</label>
<input id="sr" type="number" value="48000" />
</div>
<div class="row">
<button id="start">Start</button>
<button id="stop">Stop</button>
</div>
<div class="row">
<strong>Live output:</strong>
<div id="log"></div>
</div>

<script>
let ws = null;
let audioCtx = null;
let micStream = null;
let processor = null;
let source = null;

function log(s) {
const el = document.getElementById('log');
el.textContent += s + "\\n";
el.scrollTop = el.scrollHeight;
}

async function start() {
const targetSR = parseInt(document.getElementById('sr').value, 10) || 48000;

// Setup WS
const wsProto = location.protocol === 'https:' ? 'wss' : 'ws';
const wsURL = wsProto + '://' + location.host + '/ws';
ws = new WebSocket(wsURL);
ws.onopen = () => log('WS connected: ' + wsURL);
ws.onmessage = (ev) => {
try {
const msg = JSON.parse(ev.data);
if (msg.type === 'realtime') {
log('[partial] ' + msg.text);
} else if (msg.type === 'fullSentence') {
log('[final] ' + msg.text);
} else if (msg.type === 'error') {
log('[error] ' + msg.error);
} else {
log('[msg] ' + ev.data);
}
} catch (e) {
log('[raw] ' + ev.data);
}
};
ws.onerror = (e) => log('WS error');
ws.onclose = () => log('WS closed');

// Setup audio
audioCtx = new (window.AudioContext || window.webkitAudioContext)({ sampleRate: targetSR });
micStream = await navigator.mediaDevices.getUserMedia({ audio: true });
source = audioCtx.createMediaStreamSource(micStream);

// ScriptProcessor (deprecated but widely supported)
processor = audioCtx.createScriptProcessor(4096, 1, 1);
processor.onaudioprocess = (e) => {
// Float32 [-1,1] -> PCM16 little-endian
const input = e.inputBuffer.getChannelData(0);
const buf = new ArrayBuffer(input.length * 2);
const view = new DataView(buf);
for (let i = 0; i < input.length; i++) {
let s = Math.max(-1, Math.min(1, input[i]));
view.setInt16(i * 2, s < 0 ? s * 0x8000 : s * 0x7FFF, true);
}

// Build frame: [4-byte meta length][meta JSON][PCM16 payload]
const meta = JSON.stringify({ sampleRate: audioCtx.sampleRate });
const metaBytes = new TextEncoder().encode(meta);
const header = new Uint8Array(4 + metaBytes.length);
const dv = new DataView(header.buffer);
dv.setUint32(0, metaBytes.length, true);
header.set(metaBytes, 4);

// Concatenate header + payload
const payload = new Uint8Array(buf);
const frame = new Uint8Array(header.length + payload.length);
frame.set(header, 0);
frame.set(payload, header.length);

if (ws && ws.readyState === 1) {
ws.send(frame);
}
};

source.connect(processor);
processor.connect(audioCtx.destination);
}

function stop() {
try { if (processor) processor.disconnect(); } catch {}
try { if (source) source.disconnect(); } catch {}
try { if (micStream) micStream.getTracks().forEach(t => t.stop()); } catch {}
try { if (audioCtx) audioCtx.close(); } catch {}
try { if (ws) ws.close(); } catch {}
ws = null; micStream = null; source = null; processor = null; audioCtx = null;
log('stopped.');
}

document.getElementById('start').onclick = () => start().catch(e => log('start error: ' + e.message));
document.getElementById('stop').onclick = () => stop();
</script>
</body>
</html>
"""
331
+
332
@app.get("/")
def index():
    """Serve the embedded single-page demo client."""
    return HTMLResponse(content=INDEX_HTML)