Peter Michael Gits Claude committed on
Commit
ee64ed2
·
1 Parent(s): 1111d6d

feat: Implement hybrid Gradio+FastAPI WebSocket service for ZeroGPU compatibility

Browse files

- Use gr.mount_gradio_app() for proper WebSocket routing (fixes 404 WebSocket errors)
- Create minimal Gradio interface for HF Spaces compliance with ZeroGPU
- Mount FastAPI WebSocket endpoints at /ws/stt using official mounting approach
- Maintain ZeroGPU compatibility with @spaces.GPU decorators on global functions
- Add CORS middleware for WebRTC connectivity
- Implement WebSocket connection tracking and message handling
- Remove complex lifecycle management (let Gradio handle queue management)
- Based on research of HF Spaces best practices and known WebSocket fixes

Architecture: HF Spaces (Gradio SDK) → gr.mount_gradio_app() → FastAPI WebSocket

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

Files changed (2) hide show
  1. app.py +138 -166
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,7 +1,7 @@
1
  #!/usr/bin/env python3
2
  """
3
- Standalone WebSocket-only STT Service
4
- Simplified service without Gradio, MCP, or web interfaces
5
  Following unmute.sh WebRTC pattern for HuggingFace Spaces
6
  """
7
 
@@ -14,7 +14,6 @@ import os
14
  import logging
15
  from datetime import datetime
16
  from typing import Optional, Dict, Any
17
- from contextlib import asynccontextmanager
18
  import torch
19
  from transformers import WhisperProcessor, WhisperForConditionalGeneration
20
  import torchaudio
@@ -22,8 +21,8 @@ import soundfile as sf
22
  import numpy as np
23
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
24
  from fastapi.middleware.cors import CORSMiddleware
 
25
  import spaces
26
- import uvicorn
27
 
28
  # Configure logging
29
  logging.basicConfig(level=logging.INFO)
@@ -115,142 +114,44 @@ def transcribe_audio_zerogpu(
115
  logger.error(f"Transcription error: {str(e)}")
116
  return "", "error", {"error": str(e)}
117
 
118
- class STTWebSocketService:
119
- """Standalone STT service with WebSocket-only interface"""
120
-
121
- def __init__(self):
122
- self.active_connections: Dict[str, WebSocket] = {}
123
-
124
- logger.info(f"🎤 {__service__} v{__version__} initializing...")
125
- logger.info(f"Device: {device}")
126
- logger.info(f"Model: whisper-{model_size}")
127
-
128
- async def load_model(self):
129
- """Load Whisper model with ZeroGPU compatibility - delegates to global function"""
130
- global model
131
- if model is None:
132
- # Trigger model loading by calling the ZeroGPU function with a dummy path
133
- # The actual loading will happen on first real transcription
134
- logger.info("Model will be loaded on first transcription request...")
135
- else:
136
- logger.info("✅ Model already loaded")
137
-
138
- async def connect_websocket(self, websocket: WebSocket) -> str:
139
- """Accept WebSocket connection and return client ID"""
140
- client_id = str(uuid.uuid4())
141
- await websocket.accept()
142
- self.active_connections[client_id] = websocket
143
-
144
- # Send connection confirmation
145
- await websocket.send_text(json.dumps({
146
- "type": "stt_connection_confirmed",
147
- "client_id": client_id,
148
- "service": __service__,
149
- "version": __version__,
150
- "model": f"whisper-{model_size}",
151
- "device": device,
152
- "message": "STT WebSocket connected and ready"
153
- }))
154
-
155
- logger.info(f"Client {client_id} connected")
156
- return client_id
157
-
158
- async def disconnect_websocket(self, client_id: str):
159
- """Clean up WebSocket connection"""
160
- if client_id in self.active_connections:
161
- del self.active_connections[client_id]
162
- logger.info(f"Client {client_id} disconnected")
163
-
164
- async def process_audio_message(self, client_id: str, message: Dict[str, Any]):
165
- """Process incoming audio data from WebSocket"""
166
- try:
167
- websocket = self.active_connections[client_id]
168
-
169
- # Extract audio data (base64 encoded)
170
- audio_data_b64 = message.get("audio_data")
171
- if not audio_data_b64:
172
- await websocket.send_text(json.dumps({
173
- "type": "stt_transcription_error",
174
- "client_id": client_id,
175
- "error": "No audio data provided"
176
- }))
177
- return
178
-
179
- # Decode base64 audio
180
- audio_bytes = base64.b64decode(audio_data_b64)
181
-
182
- # Save to temporary file
183
- with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as tmp_file:
184
- tmp_file.write(audio_bytes)
185
- temp_path = tmp_file.name
186
-
187
- try:
188
- # Transcribe audio using global ZeroGPU function
189
- transcription, status, timing = transcribe_audio_zerogpu(
190
- temp_path,
191
- message.get("language", "auto"),
192
- message.get("model_size", model_size)
193
- )
194
-
195
- # Send result back
196
- if status == "success" and transcription:
197
- await websocket.send_text(json.dumps({
198
- "type": "stt_transcription_complete",
199
- "client_id": client_id,
200
- "transcription": transcription,
201
- "timing": timing,
202
- "status": "success"
203
- }))
204
- else:
205
- await websocket.send_text(json.dumps({
206
- "type": "stt_transcription_error",
207
- "client_id": client_id,
208
- "error": "Transcription failed or empty result",
209
- "timing": timing
210
- }))
211
-
212
- finally:
213
- # Clean up temp file
214
- if os.path.exists(temp_path):
215
- os.unlink(temp_path)
216
-
217
- except Exception as e:
218
- logger.error(f"Error processing audio for {client_id}: {str(e)}")
219
- if client_id in self.active_connections:
220
- websocket = self.active_connections[client_id]
221
- await websocket.send_text(json.dumps({
222
- "type": "stt_transcription_error",
223
- "client_id": client_id,
224
- "error": f"Processing error: {str(e)}"
225
- }))
226
 
227
- # Initialize service
228
- stt_service = STTWebSocketService()
 
 
 
229
 
230
- @asynccontextmanager
231
- async def lifespan(app: FastAPI):
232
- """Lifespan event handler for FastAPI app startup/shutdown"""
233
- # Startup
234
- logger.info(f"🚀 {__service__} v{__version__} starting...")
235
- logger.info("Pre-loading Whisper model for optimal performance...")
236
- await stt_service.load_model()
237
- logger.info("✅ Service ready for WebSocket connections")
238
-
239
- yield
240
-
241
- # Shutdown
242
- logger.info("🛑 STT WebSocket Service shutting down...")
243
 
244
- # Create FastAPI app with lifespan
245
- app = FastAPI(
246
- title="STT WebSocket Service",
247
- description="Standalone WebSocket-only Speech-to-Text service",
248
- version=__version__,
249
- lifespan=lifespan
 
 
 
250
  )
251
 
252
- # Add CORS middleware
253
- app.add_middleware(
 
 
 
 
 
 
254
  CORSMiddleware,
255
  allow_origins=["*"],
256
  allow_credentials=True,
@@ -258,42 +159,116 @@ app.add_middleware(
258
  allow_headers=["*"],
259
  )
260
 
261
- @app.get("/")
262
- async def root():
263
- """Health check endpoint"""
264
- return {
265
- "service": __service__,
266
- "version": __version__,
267
- "status": "ready",
268
- "endpoints": {
269
- "websocket": "/ws/stt",
270
- "health": "/health"
271
- },
272
- "model": f"whisper-{model_size}",
273
- "device": device
274
- }
275
-
276
- @app.get("/health")
277
  async def health_check():
278
- """Detailed health check"""
279
  return {
280
  "service": __service__,
281
  "version": __version__,
282
  "status": "healthy",
283
  "model_loaded": model is not None,
284
- "active_connections": len(stt_service.active_connections),
285
  "device": device,
286
  "timestamp": datetime.now().isoformat()
287
  }
288
 
289
- @app.websocket("/ws/stt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
  async def websocket_stt_endpoint(websocket: WebSocket):
291
  """Main STT WebSocket endpoint"""
292
  client_id = None
293
 
294
  try:
295
  # Accept connection
296
- client_id = await stt_service.connect_websocket(websocket)
297
 
298
  # Handle messages
299
  while True:
@@ -306,7 +281,7 @@ async def websocket_stt_endpoint(websocket: WebSocket):
306
  message_type = message.get("type", "unknown")
307
 
308
  if message_type == "stt_audio_chunk":
309
- await stt_service.process_audio_message(client_id, message)
310
  elif message_type == "ping":
311
  # Respond to ping
312
  await websocket.send_text(json.dumps({
@@ -335,15 +310,12 @@ async def websocket_stt_endpoint(websocket: WebSocket):
335
  logger.error(f"WebSocket error for {client_id}: {str(e)}")
336
  finally:
337
  if client_id:
338
- await stt_service.disconnect_websocket(client_id)
339
 
 
 
 
 
340
  if __name__ == "__main__":
341
- port = int(os.environ.get("PORT", 7860))
342
- logger.info(f"🎤 Starting {__service__} v{__version__} on port {port}")
343
-
344
- uvicorn.run(
345
- app,
346
- host="0.0.0.0",
347
- port=port,
348
- log_level="info"
349
- )
 
1
  #!/usr/bin/env python3
2
  """
3
+ STT WebSocket Service with Gradio + FastAPI Integration
4
+ ZeroGPU compatible service with WebSocket endpoints for VoiceCal
5
  Following unmute.sh WebRTC pattern for HuggingFace Spaces
6
  """
7
 
 
14
  import logging
15
  from datetime import datetime
16
  from typing import Optional, Dict, Any
 
17
  import torch
18
  from transformers import WhisperProcessor, WhisperForConditionalGeneration
19
  import torchaudio
 
21
  import numpy as np
22
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
23
  from fastapi.middleware.cors import CORSMiddleware
24
+ import gradio as gr
25
  import spaces
 
26
 
27
  # Configure logging
28
  logging.basicConfig(level=logging.INFO)
 
114
  logger.error(f"Transcription error: {str(e)}")
115
  return "", "error", {"error": str(e)}
116
 
117
+ # Global WebSocket connection tracker
118
+ active_connections: Dict[str, WebSocket] = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
+ # Simple Gradio interface for HF Spaces compliance
121
# Minimal callable backing the Gradio UI; HF Spaces (Gradio SDK) requires a
# Gradio app, so this just renders static service metadata as Markdown.
def get_service_info():
    """Return a Markdown summary of the STT WebSocket service."""
    info = f"""
    # 🎤 STT WebSocket Service v{__version__}

    **WebSocket Endpoint:** `/ws/stt`
    **Model:** Whisper {model_size}
    **Device:** {device}
    **ZeroGPU:** {'✅ Available' if torch.cuda.is_available() else '❌ Not Available'}

    **Status:** Ready for WebSocket connections

    Connect your WebRTC client to: `wss://your-space.hf.space/ws/stt`
    """
    return info
 
 
 
 
135
 
136
+ # Create minimal Gradio interface for HF Spaces
137
# Minimal Gradio interface required for HF Spaces (Gradio SDK) compliance.
# CONSISTENCY FIX: the title previously hard-coded "v1.0.0"; derive it from
# __version__ like the rest of the file so the UI can never drift out of sync.
demo = gr.Interface(
    fn=get_service_info,
    inputs=None,
    outputs=gr.Markdown(),
    title=f"🎤 STT WebSocket Service v{__version__}",
    description="WebSocket-enabled Speech-to-Text service with ZeroGPU acceleration",
    examples=None,
    live=False
)
146
 
147
+ # Create FastAPI app for WebSocket endpoints
148
+ fastapi_app = FastAPI(
149
+ title="STT WebSocket Service",
150
+ version=__version__
151
+ )
152
+
153
+ # Add CORS middleware for WebRTC
154
+ fastapi_app.add_middleware(
155
  CORSMiddleware,
156
  allow_origins=["*"],
157
  allow_credentials=True,
 
159
  allow_headers=["*"],
160
  )
161
 
162
@fastapi_app.get("/api/health")
async def health_check():
    """Report service liveness, model state, and connection count as JSON."""
    status_payload = {
        "service": __service__,
        "version": __version__,
        "status": "healthy",
        # True once the lazily-loaded Whisper model has been initialized.
        "model_loaded": model is not None,
        # Number of currently open /ws/stt WebSocket sessions.
        "active_connections": len(active_connections),
        "device": device,
        "timestamp": datetime.now().isoformat(),
    }
    return status_payload
174
 
175
async def connect_websocket(websocket: WebSocket) -> str:
    """Accept a WebSocket, register it, and send a connection confirmation.

    Returns the freshly generated client ID used to track this session in
    the module-level ``active_connections`` map.
    """
    client_id = str(uuid.uuid4())
    await websocket.accept()
    active_connections[client_id] = websocket

    greeting = {
        "type": "stt_connection_confirmed",
        "client_id": client_id,
        "service": __service__,
        "version": __version__,
        "model": f"whisper-{model_size}",
        "device": device,
        "message": "STT WebSocket connected and ready",
    }
    # Tell the client the session is live before any audio arrives.
    await websocket.send_text(json.dumps(greeting))

    logger.info(f"Client {client_id} connected")
    return client_id
194
+
195
async def disconnect_websocket(client_id: str):
    """Unregister *client_id*'s WebSocket, tolerating repeated disconnects."""
    # pop() with a sentinel keeps this safe if the error path and the
    # endpoint's finally-block both try to clean up the same client.
    if active_connections.pop(client_id, None) is not None:
        logger.info(f"Client {client_id} disconnected")
200
+
201
async def process_audio_message(client_id: str, message: Dict[str, Any]):
    """Decode one base64 audio chunk, transcribe it, and reply on the socket.

    Expected *message* keys: ``audio_data`` (base64, required) plus optional
    ``language`` and ``model_size``. Sends either an
    ``stt_transcription_complete`` or ``stt_transcription_error`` frame back
    to the client; never raises to the caller.
    """
    try:
        ws = active_connections[client_id]

        encoded = message.get("audio_data")
        if not encoded:
            await ws.send_text(json.dumps({
                "type": "stt_transcription_error",
                "client_id": client_id,
                "error": "No audio data provided"
            }))
            return

        raw_audio = base64.b64decode(encoded)

        # Persist the chunk so the transcription backend can consume a path.
        with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as handle:
            handle.write(raw_audio)
            audio_path = handle.name

        try:
            # Global ZeroGPU-decorated function does the actual inference.
            transcription, status, timing = transcribe_audio_zerogpu(
                audio_path,
                message.get("language", "auto"),
                message.get("model_size", model_size)
            )

            if status == "success" and transcription:
                reply = {
                    "type": "stt_transcription_complete",
                    "client_id": client_id,
                    "transcription": transcription,
                    "timing": timing,
                    "status": "success"
                }
            else:
                reply = {
                    "type": "stt_transcription_error",
                    "client_id": client_id,
                    "error": "Transcription failed or empty result",
                    "timing": timing
                }
            await ws.send_text(json.dumps(reply))
        finally:
            # Always remove the temp file, even if transcription raised.
            if os.path.exists(audio_path):
                os.unlink(audio_path)

    except Exception as e:
        logger.error(f"Error processing audio for {client_id}: {str(e)}")
        # The client may have disconnected mid-processing; only reply if the
        # connection is still registered.
        if client_id in active_connections:
            ws = active_connections[client_id]
            await ws.send_text(json.dumps({
                "type": "stt_transcription_error",
                "client_id": client_id,
                "error": f"Processing error: {str(e)}"
            }))
263
+
264
+ @fastapi_app.websocket("/ws/stt")
265
  async def websocket_stt_endpoint(websocket: WebSocket):
266
  """Main STT WebSocket endpoint"""
267
  client_id = None
268
 
269
  try:
270
  # Accept connection
271
+ client_id = await connect_websocket(websocket)
272
 
273
  # Handle messages
274
  while True:
 
281
  message_type = message.get("type", "unknown")
282
 
283
  if message_type == "stt_audio_chunk":
284
+ await process_audio_message(client_id, message)
285
  elif message_type == "ping":
286
  # Respond to ping
287
  await websocket.send_text(json.dumps({
 
310
  logger.error(f"WebSocket error for {client_id}: {str(e)}")
311
  finally:
312
  if client_id:
313
+ await disconnect_websocket(client_id)
314
 
315
# CRITICAL: gr.mount_gradio_app() combines the FastAPI routes (/ws/stt,
# /api/health) and the Gradio UI into ONE ASGI application. Anything that
# serves this process must serve `app`, not `demo`.
app = gr.mount_gradio_app(fastapi_app, demo, path="/")

# For HuggingFace Spaces - this becomes the main app
if __name__ == "__main__":
    # BUG FIX: demo.launch() starts Gradio's OWN server, which does not carry
    # the FastAPI routes mounted above — /ws/stt would 404, reintroducing the
    # exact WebSocket failure this architecture is meant to fix. Serve the
    # combined ASGI `app` with uvicorn instead.
    import uvicorn  # local import: only needed when run as a script

    port = int(os.environ.get("PORT", 7860))
    logger.info(f"🎤 Starting {__service__} v{__version__} with Gradio+WebSocket integration")
    uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,9 +1,10 @@
1
- # Minimal requirements for WebSocket-only STT service
2
  torch>=2.1.0
3
  torchaudio>=2.1.0
4
  transformers>=4.35.0
5
  accelerate>=0.24.0
6
  spaces>=0.19.0
 
7
  numpy>=1.21.0
8
  soundfile>=0.12.0
9
  fastapi>=0.104.0
 
1
+ # Requirements for Gradio+WebSocket STT service with ZeroGPU
2
  torch>=2.1.0
3
  torchaudio>=2.1.0
4
  transformers>=4.35.0
5
  accelerate>=0.24.0
6
  spaces>=0.19.0
7
+ gradio>=5.42.0
8
  numpy>=1.21.0
9
  soundfile>=0.12.0
10
  fastapi>=0.104.0