Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3
"""
STT WebSocket Service with Gradio + FastAPI Integration
ZeroGPU compatible service with WebSocket endpoints for VoiceCal
Following unmute.sh WebRTC pattern for HuggingFace Spaces
"""
import os
import logging

# Configure logging first so the early HF-token bootstrap below can report
# progress before any heavyweight imports happen.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# CRITICAL: ZeroGPU authentication must happen BEFORE importing spaces
# (per the original ordering note; presumably `spaces` reads the token at
# import time — TODO confirm against the ZeroGPU docs).
try:
    from huggingface_hub import login, HfFolder

    # Get and validate HF_TOKEN
    hf_token = os.getenv('HF_TOKEN')
    logger.info(f"π DEBUG: HF_TOKEN environment variable: {'SET' if hf_token else 'NOT SET'}")
    if hf_token:
        # Try multiple authentication methods for ZeroGPU compatibility
        # Method 1: Standard huggingface_hub login
        # NOTE(review): `write_permission` and `HfFolder` are deprecated in
        # newer huggingface_hub releases — confirm the pinned version still
        # supports them.
        login(token=hf_token, write_permission=True)
        # Method 2: Explicitly set in HfFolder (legacy method)
        HfFolder.save_token(hf_token)
        # Method 3: Ensure environment variable is explicitly set
        os.environ['HF_TOKEN'] = hf_token
        os.environ['HUGGINGFACE_HUB_TOKEN'] = hf_token  # Alternative env var
        # Method 4: ZeroGPU-specific environment variables (from community reports)
        os.environ['ZEROGPU_V2'] = 'true'
        os.environ['ZERO_GPU_PATCH_TORCH_DEVICE'] = '1'
        logger.info(f"π DEBUG: HuggingFace authentication successful (early init) - Token length: {len(hf_token)}")
        logger.info(f"π DEBUG: Token prefix: {hf_token[:10]}...{hf_token[-4:]}")
    else:
        logger.warning(f"β οΈ DEBUG: No HF_TOKEN found - using anonymous access (limited quota)")
except Exception as e:
    # Deliberately broad: auth failure must not kill the service at import
    # time; the service degrades to anonymous access instead.
    logger.error(f"β DEBUG: HuggingFace login failed (early init): {e}")
    import traceback
    logger.error(f"β DEBUG: Full traceback: {traceback.format_exc()}")
| # Now import everything else | |
| import asyncio | |
| import json | |
| import uuid | |
| import base64 | |
| import tempfile | |
| from datetime import datetime | |
| from typing import Optional, Dict, Any | |
| import torch | |
| from transformers import WhisperProcessor, WhisperForConditionalGeneration | |
| import torchaudio | |
| import soundfile as sf | |
| import numpy as np | |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import gradio as gr | |
| import spaces | |
# Version info
__version__ = "1.0.4"
__service__ = "STT WebSocket Service"

# Global variables for model (required for ZeroGPU decorator)
model = None        # lazily loaded WhisperForConditionalGeneration
processor = None    # WhisperProcessor matching the loaded model
model_size = "base" # size tag of the currently loaded checkpoint
device = "cuda" if torch.cuda.is_available() else "cpu"
def transcribe_audio_zerogpu(
    audio_path: str,
    language: str = "en",
    model_size_param: str = "base"
) -> tuple[str, str, Dict[str, Any]]:
    """Transcribe an audio file with Whisper (ZeroGPU-compatible).

    Args:
        audio_path: Path to an audio file torchaudio can decode.
        language: ISO language code, or "auto" to let Whisper detect it.
        model_size_param: Whisper checkpoint size ("tiny" ... "large-v2").

    Returns:
        (transcription, status, timing_info) where status is "success" or
        "error"; on error timing_info is {"error": message}.
    """
    global model, processor, model_size, device
    try:
        start_time = datetime.now()

        # Load (or reload) the model.
        # BUGFIX: the original only checked `model is None`, so after the
        # first call a different model_size_param was silently ignored and
        # the stale model was reused.  Also, `model_size` was declared
        # global but never updated.
        if model is None or model_size != model_size_param:
            logger.info(f"Loading Whisper {model_size_param} model...")
            model_name = f"openai/whisper-{model_size_param}"
            processor = WhisperProcessor.from_pretrained(model_name)
            model = WhisperForConditionalGeneration.from_pretrained(model_name)
            if device == "cuda":
                model = model.to(device)
            model_size = model_size_param  # record what is actually loaded
            logger.info(f"β Model loaded on {device}")

        # Load and preprocess audio (following unmute.sh pattern)
        audio_input, sample_rate = torchaudio.load(audio_path)

        # Convert to 16kHz mono (Whisper requirement)
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            audio_input = resampler(audio_input)
        if audio_input.shape[0] > 1:
            audio_input = torch.mean(audio_input, dim=0, keepdim=True)
        audio_array = audio_input.squeeze().numpy()

        # Process with Whisper
        inputs = processor(
            audio_array,
            sampling_rate=16000,
            return_tensors="pt"
        )
        if device == "cuda":
            inputs = {k: v.to(device) for k, v in inputs.items()}

        # BUGFIX: `language` was previously accepted but never used, so the
        # UI language selector had no effect.  "auto" (or empty) keeps
        # Whisper's built-in language detection.
        generate_kwargs: Dict[str, Any] = {}
        if language and language != "auto":
            generate_kwargs["forced_decoder_ids"] = processor.get_decoder_prompt_ids(
                language=language, task="transcribe"
            )

        # Generate transcription
        with torch.no_grad():
            predicted_ids = model.generate(**inputs, **generate_kwargs)
        transcription = processor.batch_decode(
            predicted_ids,
            skip_special_tokens=True
        )[0]

        # Calculate timing
        end_time = datetime.now()
        processing_time = (end_time - start_time).total_seconds()
        timing_info = {
            "processing_time": processing_time,
            "start_time": start_time.isoformat(),
            "end_time": end_time.isoformat(),
            "model_size": model_size_param,
            "device": device
        }
        logger.info(f"Transcription completed in {processing_time:.2f}s: '{transcription[:50]}...'")
        return transcription.strip(), "success", timing_info
    except Exception as e:
        # Never raise out of the transcription path: WebSocket/HTTP callers
        # translate the "error" status into a client-visible message.
        logger.error(f"Transcription error: {str(e)}")
        return "", "error", {"error": str(e)}
# Global WebSocket connection tracker: client_id (uuid4 string) -> live socket.
active_connections: Dict[str, WebSocket] = {}
# Simple Gradio interface for HF Spaces compliance
def get_service_info():
    """Return a Markdown status blurb describing the service.

    NOTE(review): not referenced anywhere visible in this file — possibly
    dead code or called externally; confirm before removing.
    """
    return f"""
# π€ STT WebSocket Service v{__version__}
**WebSocket Endpoint:** `/ws/stt`
**Model:** Whisper {model_size}
**Device:** {device}
**ZeroGPU:** {'β Available' if torch.cuda.is_available() else 'β Not Available'}
**Status:** Ready for WebSocket connections
Connect your WebRTC client to: `wss://your-space.hf.space/ws/stt`
"""
def gradio_transcribe_wrapper(audio_file, language="en", model_size_param="base"):
    """Adapt transcribe_audio_zerogpu() to the three Gradio output widgets.

    Returns a (display_text, timing_json, status_text) triple; all failure
    modes are reported through those outputs rather than raised.
    """
    try:
        logger.info(f"π€ DEBUG: Gradio transcription request received")
        logger.info(f"π€ DEBUG: Audio file: {audio_file}")
        logger.info(f"π€ DEBUG: Language: {language}")
        logger.info(f"π€ DEBUG: Model size: {model_size_param}")

        # Guard clause: Gradio passes None when nothing was uploaded/recorded.
        if audio_file is None:
            logger.warning("π€ DEBUG: No audio file provided to Gradio wrapper")
            return "β No audio file provided", "{}", "Please upload an audio file"

        # Purely diagnostic: report what actually arrived on disk.
        if isinstance(audio_file, str) and os.path.exists(audio_file):
            logger.info(f"π€ DEBUG: Audio file size: {os.path.getsize(audio_file)} bytes")
            logger.info(f"π€ DEBUG: Audio file path: {audio_file}")
        else:
            logger.warning(f"π€ DEBUG: Invalid audio file: {type(audio_file)}")

        logger.info(f"π€ DEBUG: Calling transcribe_audio_zerogpu...")
        transcription, status, timing = transcribe_audio_zerogpu(
            audio_file, language, model_size_param
        )
        logger.info(f"π€ DEBUG: Transcription result: '{transcription[:100]}...'")
        logger.info(f"π€ DEBUG: Status: {status}")

        timing_json = json.dumps(timing, indent=2)
        status_text = f"Status: {status}"
        if status == "success":
            return f"β {transcription}", timing_json, status_text
        return f"β Transcription failed", timing_json, status_text
    except Exception as e:
        error_msg = f"Error in gradio_transcribe_wrapper: {str(e)}"
        logger.error(f"π€ DEBUG: {error_msg}")
        return f"β Error: {str(e)}", "{}", "Error occurred during transcription"
def gradio_transcribe_memory(audio_base64: str, language: str = "en", model_size_param: str = "base"):
    """Transcribe base64-encoded audio sent from a WebRTC client.

    Unlike gradio_transcribe_wrapper, returns only the transcription string
    (or an error string prefixed with "β") for voiceCal-ai compatibility.

    NOTE(review): the click handler wires this single return value to three
    output components — confirm Gradio tolerates that mismatch.

    Args:
        audio_base64: Base64-encoded audio bytes (webm container assumed).
        language: Language code forwarded to transcribe_audio_zerogpu.
        model_size_param: Whisper checkpoint size.
    """
    try:
        logger.info(f"π MEMORY: In-memory transcription request received")

        # BUGFIX: guard before using the value — the original logged
        # len(audio_base64) first, so a None input raised TypeError instead
        # of reaching this check.  (Also dropped the redundant local
        # `import base64` / unused `import io`; base64 is imported at module
        # level.)
        if not audio_base64:
            logger.warning("π MEMORY: No audio data provided")
            return "β No audio data provided"

        logger.info(f"π MEMORY: Audio base64 length: {len(audio_base64)} chars")
        logger.info(f"π MEMORY: Language: {language}")
        logger.info(f"π MEMORY: Model size: {model_size_param}")

        # Decode base64 to binary audio data
        try:
            audio_binary = base64.b64decode(audio_base64)
            logger.info(f"π MEMORY: Decoded audio size: {len(audio_binary)} bytes")
        except Exception as decode_error:
            logger.error(f"π MEMORY: Base64 decode error: {decode_error}")
            return f"β Invalid base64 audio data: {decode_error}"

        # Save to a temporary file — torchaudio.load needs a real path.
        with tempfile.NamedTemporaryFile(suffix='.webm', delete=False) as tmp_file:
            tmp_file.write(audio_binary)
            temp_path = tmp_file.name
        try:
            logger.info(f"π MEMORY: Processing with ZeroGPU...")
            transcription, status, timing = transcribe_audio_zerogpu(
                temp_path, language, model_size_param
            )
            logger.info(f"π MEMORY: Transcription result: '{transcription[:100] if transcription else 'None'}...'")
            logger.info(f"π MEMORY: Status: {status}")
            if status == "success":
                return transcription  # Return only transcription for voiceCal-ai compatibility
            return "β Transcription failed"
        finally:
            # Clean up temp file
            if os.path.exists(temp_path):
                os.unlink(temp_path)
                logger.info(f"π MEMORY: Cleaned up temp file")
    except Exception as e:
        error_msg = f"Error in gradio_transcribe_memory: {str(e)}"
        logger.error(f"π MEMORY: {error_msg}")
        return f"β Error: {str(e)}"
# Create Gradio interface with transcription functionality.
# BUGFIX: title and Markdown previously hard-coded "v1.0.0" while
# __version__ is "1.0.4" — interpolate __version__ so they cannot drift.
with gr.Blocks(title=f"π€ STT WebSocket Service v{__version__}") as demo:
    gr.Markdown(f"""
# π€ STT WebSocket Service v{__version__}
**WebSocket-enabled Speech-to-Text service with ZeroGPU acceleration**
## π Service Information
- **Service:** STT WebSocket Service v{__version__}
- **ZeroGPU:** Enabled with H200 acceleration
- **WebSocket Endpoint:** `wss://your-space.hf.space/ws/stt`
- **HTTP API:** Available for direct file uploads
## π€ Upload Audio for Transcription
""")
    # --- File/microphone upload panel --------------------------------------
    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(
                label="Upload Audio File",
                type="filepath",
                sources=["upload", "microphone"]
            )
            language_input = gr.Dropdown(
                choices=["en", "auto", "es", "fr", "de", "it", "pt", "ru", "ja", "ko", "zh"],
                value="en",
                label="Language (English by default)"
            )
            model_input = gr.Dropdown(
                choices=["tiny", "base", "small", "medium", "large-v2"],
                value="base",
                label="Model Size"
            )
            transcribe_btn = gr.Button("π€ Transcribe Audio", variant="primary")
        with gr.Column():
            transcription_output = gr.Textbox(
                label="Transcription Result",
                lines=4,
                placeholder="Transcription will appear here..."
            )
            timing_output = gr.Code(
                label="Timing Information",
                language="json"
            )
            status_output = gr.Textbox(
                label="Status",
                lines=1
            )
    transcribe_btn.click(
        fn=gradio_transcribe_wrapper,
        inputs=[audio_input, language_input, model_input],
        outputs=[transcription_output, timing_output, status_output]
    )

    # --- In-Memory Transcription panel (TRUE UNMUTE.SH) ---------------------
    gr.Markdown("""
---
## π In-Memory Transcription (WebRTC Streaming)
**For real-time audio processing without file uploads**
""")
    with gr.Row():
        with gr.Column():
            audio_base64_input = gr.Textbox(
                label="Audio Base64 Data",
                placeholder="Base64 encoded audio data will be sent here from WebRTC...",
                lines=3
            )
            memory_language_input = gr.Dropdown(
                choices=["en", "auto", "es", "fr", "de", "it", "pt", "ru", "ja", "ko", "zh"],
                value="en",
                label="Language (English by default)"
            )
            memory_model_input = gr.Dropdown(
                choices=["tiny", "base", "small", "medium", "large-v2"],
                value="base",
                label="Model Size"
            )
            memory_transcribe_btn = gr.Button("π Transcribe In-Memory", variant="secondary")
        with gr.Column():
            memory_transcription_output = gr.Textbox(
                label="In-Memory Transcription Result",
                lines=4,
                placeholder="Real-time transcription will appear here..."
            )
            memory_timing_output = gr.Code(
                label="Performance Metrics",
                language="json"
            )
            memory_status_output = gr.Textbox(
                label="Processing Status",
                lines=1
            )
    # NOTE(review): gradio_transcribe_memory returns a single string but is
    # wired to three outputs — confirm the handler/outputs agree.
    memory_transcribe_btn.click(
        fn=gradio_transcribe_memory,
        inputs=[audio_base64_input, memory_language_input, memory_model_input],
        outputs=[memory_transcription_output, memory_timing_output, memory_status_output]
    )
# Create FastAPI app for WebSocket endpoints.
# NOTE(review): this module-level instance is rebuilt with a fresh FastAPI()
# in both the __main__ block and the else branch below; as written no route
# is ever attached to THIS app and it is never served — confirm intent.
fastapi_app = FastAPI(
    title="STT WebSocket Service",
    version=__version__
)

# Add CORS middleware for WebRTC (wide open: any origin, method, header).
fastapi_app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
async def health_check():
    """Build the health-check payload (service/version/model/connection state).

    NOTE(review): never registered on any FastAPI app — there is no
    @app.get decorator anywhere in this file, and the coroutine is
    re-defined inside the __main__ block.  As written it is unreachable;
    confirm before relying on a /health endpoint.
    """
    return {
        "service": __service__,
        "version": __version__,
        "status": "healthy",
        "model_loaded": model is not None,  # True once first transcription ran
        "active_connections": len(active_connections),
        "device": device,
        "timestamp": datetime.now().isoformat()
    }
async def connect_websocket(websocket: WebSocket) -> str:
    """Accept an incoming WebSocket, register it, and confirm to the client.

    Returns the freshly generated client id used as the registry key.
    """
    new_id = str(uuid.uuid4())
    await websocket.accept()
    active_connections[new_id] = websocket

    # Tell the client the link is live and which model/device will serve it.
    greeting = {
        "type": "stt_connection_confirmed",
        "client_id": new_id,
        "service": __service__,
        "version": __version__,
        "model": f"whisper-{model_size}",
        "device": device,
        "message": "STT WebSocket connected and ready"
    }
    await websocket.send_text(json.dumps(greeting))
    logger.info(f"Client {new_id} connected")
    return new_id
async def disconnect_websocket(client_id: str):
    """Remove *client_id* from the connection registry, logging the departure."""
    # pop() with a default makes this a no-op for already-removed ids.
    if active_connections.pop(client_id, None) is not None:
        logger.info(f"Client {client_id} disconnected")
async def process_audio_message(client_id: str, message: Dict[str, Any]):
    """Decode one base64 audio chunk from *client_id*, transcribe it, and
    reply over the client's WebSocket.

    Sends "stt_transcription_complete" on success, otherwise
    "stt_transcription_error".  Any unexpected failure (including a missing
    client_id in the registry) is caught, logged, and reported to the client
    if it is still connected.
    """
    try:
        # KeyError for an already-disconnected client falls through to the
        # outer except, which re-checks the registry before replying.
        websocket = active_connections[client_id]
        # Extract audio data (base64 encoded)
        audio_data_b64 = message.get("audio_data")
        if not audio_data_b64:
            await websocket.send_text(json.dumps({
                "type": "stt_transcription_error",
                "client_id": client_id,
                "error": "No audio data provided"
            }))
            return
        # Decode base64 audio
        audio_bytes = base64.b64decode(audio_data_b64)
        # Save to temporary file — the transcription path loads from disk.
        # NOTE(review): suffix assumes webm-encoded chunks; confirm clients
        # always send webm.
        with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as tmp_file:
            tmp_file.write(audio_bytes)
            temp_path = tmp_file.name
        try:
            # Transcribe audio using global ZeroGPU function; per-message
            # overrides for language/model size fall back to defaults.
            transcription, status, timing = transcribe_audio_zerogpu(
                temp_path,
                message.get("language", "en"),
                message.get("model_size", model_size)
            )
            # Send result back
            if status == "success" and transcription:
                await websocket.send_text(json.dumps({
                    "type": "stt_transcription_complete",
                    "client_id": client_id,
                    "transcription": transcription,
                    "timing": timing,
                    "status": "success"
                }))
            else:
                await websocket.send_text(json.dumps({
                    "type": "stt_transcription_error",
                    "client_id": client_id,
                    "error": "Transcription failed or empty result",
                    "timing": timing
                }))
        finally:
            # Clean up temp file regardless of outcome.
            if os.path.exists(temp_path):
                os.unlink(temp_path)
    except Exception as e:
        logger.error(f"Error processing audio for {client_id}: {str(e)}")
        # Best-effort error report — only if the client is still registered.
        if client_id in active_connections:
            websocket = active_connections[client_id]
            await websocket.send_text(json.dumps({
                "type": "stt_transcription_error",
                "client_id": client_id,
                "error": f"Processing error: {str(e)}"
            }))
| # WebSocket endpoint will be added to Gradio's FastAPI app in main block | |
| # For HuggingFace Spaces - we need to launch the Gradio demo | |
| # and add WebSocket routes to its internal FastAPI app | |
if __name__ == "__main__":
    logger.info(f"π€ DEBUG: Starting {__service__} v{__version__} with Gradio+WebSocket integration")
    logger.info(f"π€ DEBUG: Device: {device}")
    logger.info(f"π€ DEBUG: Model size: {model_size}")
    logger.info(f"π€ DEBUG: Default language: English (en)")
    logger.info(f"π€ DEBUG: Service ready for connections")

    # Create FastAPI app for WebSocket endpoints.
    # NOTE(review): this app shadows the module-level fastapi_app and is never
    # handed to any server — demo.launch() below serves only Gradio's own app.
    # Nothing defined on this instance is reachable; serving it would require
    # gr.mount_gradio_app + a uvicorn run (new dependency / architecture
    # decision), so it is flagged here rather than changed.
    fastapi_app = FastAPI(title="STT WebSocket API")

    # Add CORS middleware (wide open for WebRTC clients).
    fastapi_app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    async def websocket_stt_endpoint(websocket: WebSocket):
        """Main STT WebSocket endpoint.

        NOTE(review): not decorated with @fastapi_app.websocket("/ws/stt")
        and never registered — as written this handler is dead code even
        though the UI advertises wss://.../ws/stt.
        """
        client_id = None
        try:
            # Accept connection and register it in active_connections.
            client_id = await connect_websocket(websocket)
            # Handle messages until the client disconnects or a fatal error.
            while True:
                try:
                    # Receive message (JSON text frames only).
                    data = await websocket.receive_text()
                    message = json.loads(data)
                    # Process based on message type
                    message_type = message.get("type", "unknown")
                    if message_type == "stt_audio_chunk":
                        await process_audio_message(client_id, message)
                    elif message_type == "ping":
                        # Respond to ping (client keep-alive).
                        await websocket.send_text(json.dumps({
                            "type": "pong",
                            "client_id": client_id,
                            "timestamp": datetime.now().isoformat()
                        }))
                    else:
                        logger.warning(f"Unknown message type from {client_id}: {message_type}")
                except WebSocketDisconnect:
                    break
                except json.JSONDecodeError:
                    # Malformed frame: report and keep the connection alive.
                    await websocket.send_text(json.dumps({
                        "type": "stt_transcription_error",
                        "client_id": client_id,
                        "error": "Invalid JSON message format"
                    }))
                except Exception as e:
                    logger.error(f"Error handling message from {client_id}: {str(e)}")
                    break
        except WebSocketDisconnect:
            logger.info(f"Client {client_id} disconnected normally")
        except Exception as e:
            logger.error(f"WebSocket error for {client_id}: {str(e)}")
        finally:
            # Always unregister, whatever ended the session.
            if client_id:
                await disconnect_websocket(client_id)

    # Add health check route to FastAPI app.
    # NOTE(review): second definition of health_check in this file; like the
    # module-level one it carries no route decorator and is never registered.
    async def health_check():
        """Health check endpoint"""
        return {
            "service": __service__,
            "version": __version__,
            "status": "healthy",
            "model_loaded": model is not None,
            "active_connections": len(active_connections),
            "device": device,
            "timestamp": datetime.now().isoformat()
        }

    # Add HTTP transcription endpoint for Streamlit integration
    from fastapi import File, Form, UploadFile

    async def http_transcribe_endpoint(
        file: UploadFile = File(...),
        language: str = Form("en"),
        model_size_param: str = Form("base")
    ):
        """HTTP transcription endpoint for Streamlit WebRTC integration.

        NOTE(review): also unregistered (no @fastapi_app.post) — unreachable
        as written.
        """
        try:
            # DEBUG: Log incoming HTTP request
            logger.info(f"π DEBUG: HTTP transcribe request received")
            logger.info(f"π DEBUG: File name: {file.filename}")
            logger.info(f"π DEBUG: Content type: {file.content_type}")
            logger.info(f"π DEBUG: Language: {language}")
            logger.info(f"π DEBUG: Model size: {model_size_param}")
            # Save uploaded file to disk for the transcription path.
            with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as tmp_file:
                content = await file.read()
                tmp_file.write(content)
                temp_path = tmp_file.name
            # DEBUG: Log file details
            file_size = len(content)
            logger.info(f"π DEBUG: Uploaded file size: {file_size} bytes")
            logger.info(f"π DEBUG: Temp file path: {temp_path}")
            try:
                # Transcribe using ZeroGPU function
                logger.info(f"π DEBUG: Starting HTTP transcription...")
                transcription, status, timing = transcribe_audio_zerogpu(
                    temp_path, language, model_size_param
                )
                logger.info(f"π DEBUG: HTTP transcription result: '{transcription[:100] if transcription else 'None'}...'")
                logger.info(f"π DEBUG: HTTP status: {status}")
                if status == "success":
                    return {
                        "status": "success",
                        "transcription": transcription,
                        "timing": timing,
                        "timestamp": datetime.now().isoformat()
                    }
                else:
                    return {
                        "status": "error",
                        "message": "Transcription failed",
                        "timing": timing,
                        "timestamp": datetime.now().isoformat()
                    }
            finally:
                # Clean up the uploaded temp file.
                if os.path.exists(temp_path):
                    os.unlink(temp_path)
                    logger.info(f"π DEBUG: Cleaned up temp file: {temp_path}")
        except Exception as e:
            error_msg = f"HTTP transcription error: {e}"
            logger.error(f"π DEBUG: {error_msg}")
            return {
                "status": "error",
                "message": f"HTTP transcription failed: {str(e)}",
                "timestamp": datetime.now().isoformat()
            }

    # For HF Spaces, launch Gradio interface directly
    # This avoids FastAPI mounting conflicts causing port issues
    demo.launch(show_api=True, show_error=True)
else:
    # For programmatic use, create the mounted app.
    # NOTE(review): this app gets CORS + the Gradio UI but none of the
    # WebSocket/HTTP endpoints above (they are defined only inside the
    # __main__ branch), so importers get no /ws/stt either.
    fastapi_app = FastAPI(title="STT WebSocket API")
    fastapi_app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    app = gr.mount_gradio_app(fastapi_app, demo, path="/")