#!/usr/bin/env python3
"""
STT WebSocket Service with Gradio + FastAPI Integration

ZeroGPU compatible service with WebSocket endpoints for VoiceCal.
Following the unmute.sh WebRTC pattern for HuggingFace Spaces.
"""

import os
import logging

# Configure logging first so the early-auth block below can report status.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# CRITICAL: ZeroGPU authentication must happen BEFORE importing `spaces`,
# otherwise the GPU decorator may initialize with anonymous (limited) access.
try:
    from huggingface_hub import login, HfFolder

    # Get and validate HF_TOKEN
    hf_token = os.getenv('HF_TOKEN')
    logger.info(f"🔍 DEBUG: HF_TOKEN environment variable: {'SET' if hf_token else 'NOT SET'}")

    if hf_token:
        # Try multiple authentication methods for ZeroGPU compatibility.
        # Method 1: Standard huggingface_hub login
        login(token=hf_token, write_permission=True)
        # Method 2: Explicitly set in HfFolder (legacy method)
        HfFolder.save_token(hf_token)
        # Method 3: Ensure environment variables are explicitly set
        os.environ['HF_TOKEN'] = hf_token
        os.environ['HUGGINGFACE_HUB_TOKEN'] = hf_token  # Alternative env var
        # Method 4: ZeroGPU-specific environment variables (from community reports)
        os.environ['ZEROGPU_V2'] = 'true'
        os.environ['ZERO_GPU_PATCH_TORCH_DEVICE'] = '1'

        logger.info(f"🔑 DEBUG: HuggingFace authentication successful (early init) - Token length: {len(hf_token)}")
        logger.info(f"🔑 DEBUG: Token prefix: {hf_token[:10]}...{hf_token[-4:]}")
    else:
        logger.warning("⚠️ DEBUG: No HF_TOKEN found - using anonymous access (limited quota)")
except Exception as e:
    logger.error(f"❌ DEBUG: HuggingFace login failed (early init): {e}")
    import traceback
    logger.error(f"❌ DEBUG: Full traceback: {traceback.format_exc()}")

# Now import everything else
import asyncio
import json
import uuid
import base64
import tempfile
from datetime import datetime
from typing import Optional, Dict, Any

import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import torchaudio
import soundfile as sf
import numpy as np
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, File, Form, UploadFile
from fastapi.middleware.cors import CORSMiddleware
import gradio as gr
import spaces

# Version info
__version__ = "1.0.4"
__service__ = "STT WebSocket Service"

# Global variables for model (required for ZeroGPU decorator).
model = None
processor = None
model_size = "base"
# Tracks which checkpoint is cached so a different model_size_param reloads.
_loaded_model_name: Optional[str] = None
device = "cuda" if torch.cuda.is_available() else "cpu"


@spaces.GPU(duration=30)
def transcribe_audio_zerogpu(
    audio_path: str,
    language: str = "en",
    model_size_param: str = "base"
) -> tuple[str, str, Dict[str, Any]]:
    """Transcribe an audio file with Whisper on ZeroGPU.

    Args:
        audio_path: Path to any audio file readable by torchaudio.
        language: ISO language code, or "auto" for Whisper auto-detection.
        model_size_param: Whisper checkpoint size (tiny/base/small/medium/large-v2).

    Returns:
        ``(transcription, status, info)`` where ``status`` is ``"success"`` or
        ``"error"`` and ``info`` carries timing metadata (or the error message).
    """
    global model, processor, model_size, device, _loaded_model_name

    try:
        start_time = datetime.now()

        # BUG FIX: the original cached the first checkpoint forever and
        # silently ignored a different model_size_param on later calls.
        model_name = f"openai/whisper-{model_size_param}"
        if model is None or _loaded_model_name != model_name:
            logger.info(f"Loading Whisper {model_size_param} model...")
            processor = WhisperProcessor.from_pretrained(model_name)
            model = WhisperForConditionalGeneration.from_pretrained(model_name)
            if device == "cuda":
                model = model.to(device)
            _loaded_model_name = model_name
            logger.info(f"✅ Model loaded on {device}")

        # Load and preprocess audio (following unmute.sh pattern)
        audio_input, sample_rate = torchaudio.load(audio_path)

        # Convert to 16 kHz mono (Whisper requirement)
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            audio_input = resampler(audio_input)
        if audio_input.shape[0] > 1:
            audio_input = torch.mean(audio_input, dim=0, keepdim=True)
        audio_array = audio_input.squeeze().numpy()

        # Process with Whisper
        inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt")
        if device == "cuda":
            inputs = {k: v.to(device) for k, v in inputs.items()}

        # BUG FIX: `language` was accepted but never used, so the language
        # dropdown had no effect. Force it unless the caller asked for "auto".
        generate_kwargs: Dict[str, Any] = {}
        if language and language != "auto":
            try:
                generate_kwargs["forced_decoder_ids"] = processor.get_decoder_prompt_ids(
                    language=language, task="transcribe"
                )
            except (KeyError, ValueError) as lang_error:
                logger.warning(f"Unsupported language '{language}', using auto-detect: {lang_error}")

        # Generate transcription
        with torch.no_grad():
            predicted_ids = model.generate(**inputs, **generate_kwargs)
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

        # Calculate timing
        end_time = datetime.now()
        processing_time = (end_time - start_time).total_seconds()
        timing_info = {
            "processing_time": processing_time,
            "start_time": start_time.isoformat(),
            "end_time": end_time.isoformat(),
            "model_size": model_size_param,
            "device": device
        }

        logger.info(f"Transcription completed in {processing_time:.2f}s: '{transcription[:50]}...'")
        return transcription.strip(), "success", timing_info

    except Exception as e:
        logger.error(f"Transcription error: {str(e)}")
        return "", "error", {"error": str(e)}


# Global WebSocket connection tracker
active_connections: Dict[str, WebSocket] = {}


def get_service_info():
    """Return the service banner shown in the Gradio interface."""
    return f"""
# 🎤 STT WebSocket Service v{__version__}

**WebSocket Endpoint:** `/ws/stt`
**Model:** Whisper {model_size}
**Device:** {device}
**ZeroGPU:** {'✅ Available' if torch.cuda.is_available() else '❌ Not Available'}
**Status:** Ready for WebSocket connections

Connect your WebRTC client to: `wss://your-space.hf.space/ws/stt`
"""


def gradio_transcribe_wrapper(audio_file, language="en", model_size_param="base"):
    """Gradio wrapper around the ZeroGPU transcriber for the file-upload UI.

    Returns ``(display_text, timing_json, status_text)`` for the three
    declared output components.
    """
    try:
        # DEBUG: Log all incoming requests
        logger.info("🎤 DEBUG: Gradio transcription request received")
        logger.info(f"🎤 DEBUG: Audio file: {audio_file}")
        logger.info(f"🎤 DEBUG: Language: {language}")
        logger.info(f"🎤 DEBUG: Model size: {model_size_param}")

        if audio_file is None:
            logger.warning("🎤 DEBUG: No audio file provided to Gradio wrapper")
            return "❌ No audio file provided", "{}", "Please upload an audio file"

        # DEBUG: Check file details
        if isinstance(audio_file, str) and os.path.exists(audio_file):
            file_size = os.path.getsize(audio_file)
            logger.info(f"🎤 DEBUG: Audio file size: {file_size} bytes")
            logger.info(f"🎤 DEBUG: Audio file path: {audio_file}")
        else:
            logger.warning(f"🎤 DEBUG: Invalid audio file: {type(audio_file)}")

        # Use the ZeroGPU transcription function
        logger.info("🎤 DEBUG: Calling transcribe_audio_zerogpu...")
        transcription, status, timing = transcribe_audio_zerogpu(
            audio_file, language, model_size_param
        )
        logger.info(f"🎤 DEBUG: Transcription result: '{transcription[:100]}...'")
        logger.info(f"🎤 DEBUG: Status: {status}")

        if status == "success":
            return f"✅ {transcription}", json.dumps(timing, indent=2), f"Status: {status}"
        return "❌ Transcription failed", json.dumps(timing, indent=2), f"Status: {status}"

    except Exception as e:
        error_msg = f"Error in gradio_transcribe_wrapper: {str(e)}"
        logger.error(f"🎤 DEBUG: {error_msg}")
        return f"❌ Error: {str(e)}", "{}", "Error occurred during transcription"


def gradio_transcribe_memory(audio_base64: str, language: str = "en", model_size_param: str = "base"):
    """In-memory transcription (unmute.sh methodology) from base64 audio.

    BUG FIX: this function is wired to THREE Gradio outputs but originally
    returned a single string, which made every in-memory request fail at the
    Gradio layer. It now returns ``(transcription, timing_json, status)``;
    the first element is still the bare transcription so voiceCal-ai clients
    reading result[0] keep working.
    """
    try:
        # DEBUG: Log in-memory request
        logger.info("🚀 MEMORY: In-memory transcription request received")
        logger.info(f"🚀 MEMORY: Audio base64 length: {len(audio_base64)} chars")
        logger.info(f"🚀 MEMORY: Language: {language}")
        logger.info(f"🚀 MEMORY: Model size: {model_size_param}")

        if not audio_base64:
            logger.warning("🚀 MEMORY: No audio data provided")
            return "❌ No audio data provided", "{}", "Status: error"

        # Decode base64 to binary audio data
        try:
            audio_binary = base64.b64decode(audio_base64)
            logger.info(f"🚀 MEMORY: Decoded audio size: {len(audio_binary)} bytes")
        except Exception as decode_error:
            logger.error(f"🚀 MEMORY: Base64 decode error: {decode_error}")
            return f"❌ Invalid base64 audio data: {decode_error}", "{}", "Status: error"

        # A temp file is still needed because torchaudio.load reads a path.
        with tempfile.NamedTemporaryFile(suffix='.webm', delete=False) as tmp_file:
            tmp_file.write(audio_binary)
            temp_path = tmp_file.name

        try:
            # Use the same ZeroGPU transcription function
            logger.info("🚀 MEMORY: Processing with ZeroGPU...")
            transcription, status, timing = transcribe_audio_zerogpu(
                temp_path, language, model_size_param
            )
            logger.info(f"🚀 MEMORY: Transcription result: '{transcription[:100] if transcription else 'None'}...'")
            logger.info(f"🚀 MEMORY: Status: {status}")

            if status == "success":
                return transcription, json.dumps(timing, indent=2), f"Status: {status}"
            return "❌ Transcription failed", json.dumps(timing, indent=2), f"Status: {status}"
        finally:
            # Clean up temp file
            if os.path.exists(temp_path):
                os.unlink(temp_path)
                logger.info("🚀 MEMORY: Cleaned up temp file")

    except Exception as e:
        error_msg = f"Error in gradio_transcribe_memory: {str(e)}"
        logger.error(f"🚀 MEMORY: {error_msg}")
        return f"❌ Error: {str(e)}", "{}", "Status: error"


# Create Gradio interface with transcription functionality.
# BUG FIX: titles/markdown hard-coded "v1.0.0" while __version__ is 1.0.4;
# interpolate __version__ so the UI can never drift from the code again.
with gr.Blocks(title=f"🎤 {__service__} v{__version__}") as demo:
    gr.Markdown(f"""
    # 🎤 STT WebSocket Service v{__version__}

    **WebSocket-enabled Speech-to-Text service with ZeroGPU acceleration**

    ## 📋 Service Information
    - **Service:** STT WebSocket Service v{__version__}
    - **ZeroGPU:** Enabled with H200 acceleration
    - **WebSocket Endpoint:** `wss://your-space.hf.space/ws/stt`
    - **HTTP API:** Available for direct file uploads

    ## 🎤 Upload Audio for Transcription
    """)

    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(
                label="Upload Audio File",
                type="filepath",
                sources=["upload", "microphone"]
            )
            language_input = gr.Dropdown(
                choices=["en", "auto", "es", "fr", "de", "it", "pt", "ru", "ja", "ko", "zh"],
                value="en",
                label="Language (English by default)"
            )
            model_input = gr.Dropdown(
                choices=["tiny", "base", "small", "medium", "large-v2"],
                value="base",
                label="Model Size"
            )
            transcribe_btn = gr.Button("🎤 Transcribe Audio", variant="primary")

        with gr.Column():
            transcription_output = gr.Textbox(
                label="Transcription Result",
                lines=4,
                placeholder="Transcription will appear here..."
            )
            timing_output = gr.Code(
                label="Timing Information",
                language="json"
            )
            status_output = gr.Textbox(
                label="Status",
                lines=1
            )

    transcribe_btn.click(
        fn=gradio_transcribe_wrapper,
        inputs=[audio_input, language_input, model_input],
        outputs=[transcription_output, timing_output, status_output]
    )

    # 🚀 In-Memory Transcription Interface (TRUE UNMUTE.SH)
    gr.Markdown("""
    ---
    ## 🚀 In-Memory Transcription (WebRTC Streaming)
    **For real-time audio processing without file uploads**
    """)

    with gr.Row():
        with gr.Column():
            audio_base64_input = gr.Textbox(
                label="Audio Base64 Data",
                placeholder="Base64 encoded audio data will be sent here from WebRTC...",
                lines=3
            )
            memory_language_input = gr.Dropdown(
                choices=["en", "auto", "es", "fr", "de", "it", "pt", "ru", "ja", "ko", "zh"],
                value="en",
                label="Language (English by default)"
            )
            memory_model_input = gr.Dropdown(
                choices=["tiny", "base", "small", "medium", "large-v2"],
                value="base",
                label="Model Size"
            )
            memory_transcribe_btn = gr.Button("🚀 Transcribe In-Memory", variant="secondary")

        with gr.Column():
            memory_transcription_output = gr.Textbox(
                label="In-Memory Transcription Result",
                lines=4,
                placeholder="Real-time transcription will appear here..."
            )
            memory_timing_output = gr.Code(
                label="Performance Metrics",
                language="json"
            )
            memory_status_output = gr.Textbox(
                label="Processing Status",
                lines=1
            )

    memory_transcribe_btn.click(
        fn=gradio_transcribe_memory,
        inputs=[audio_base64_input, memory_language_input, memory_model_input],
        outputs=[memory_transcription_output, memory_timing_output, memory_status_output]
    )


# Create the single shared FastAPI app for WebSocket/HTTP endpoints.
# BUG FIX: the original built THREE FastAPI apps (module level, __main__,
# else-branch), shadowing each other and re-registering CORS and /api/health;
# worse, /ws/stt and /api/transcribe only existed in the __main__ path, so the
# mounted (programmatic) app had no WebSocket route at all. All routes are now
# registered once, here, on one app that both launch modes share.
fastapi_app = FastAPI(
    title="STT WebSocket Service",
    version=__version__
)

# Add CORS middleware for WebRTC
fastapi_app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@fastapi_app.get("/api/health")
async def health_check():
    """Health check endpoint."""
    return {
        "service": __service__,
        "version": __version__,
        "status": "healthy",
        "model_loaded": model is not None,
        "active_connections": len(active_connections),
        "device": device,
        "timestamp": datetime.now().isoformat()
    }


async def connect_websocket(websocket: WebSocket) -> str:
    """Accept a WebSocket connection, register it, and return its client ID."""
    client_id = str(uuid.uuid4())
    await websocket.accept()
    active_connections[client_id] = websocket

    # Send connection confirmation
    await websocket.send_text(json.dumps({
        "type": "stt_connection_confirmed",
        "client_id": client_id,
        "service": __service__,
        "version": __version__,
        "model": f"whisper-{model_size}",
        "device": device,
        "message": "STT WebSocket connected and ready"
    }))

    logger.info(f"Client {client_id} connected")
    return client_id


async def disconnect_websocket(client_id: str):
    """Remove a client from the active-connection registry."""
    if client_id in active_connections:
        del active_connections[client_id]
        logger.info(f"Client {client_id} disconnected")


async def process_audio_message(client_id: str, message: Dict[str, Any]):
    """Decode a base64 audio chunk, transcribe it, and reply on the socket.

    Expects ``message["audio_data"]`` (base64), with optional ``language``
    and ``model_size`` overrides.
    """
    try:
        websocket = active_connections[client_id]

        # Extract audio data (base64 encoded)
        audio_data_b64 = message.get("audio_data")
        if not audio_data_b64:
            await websocket.send_text(json.dumps({
                "type": "stt_transcription_error",
                "client_id": client_id,
                "error": "No audio data provided"
            }))
            return

        # Decode base64 audio
        audio_bytes = base64.b64decode(audio_data_b64)

        # Save to temporary file (torchaudio needs a path)
        with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as tmp_file:
            tmp_file.write(audio_bytes)
            temp_path = tmp_file.name

        try:
            # Transcribe audio using global ZeroGPU function
            transcription, status, timing = transcribe_audio_zerogpu(
                temp_path,
                message.get("language", "en"),
                message.get("model_size", model_size)
            )

            # Send result back
            if status == "success" and transcription:
                await websocket.send_text(json.dumps({
                    "type": "stt_transcription_complete",
                    "client_id": client_id,
                    "transcription": transcription,
                    "timing": timing,
                    "status": "success"
                }))
            else:
                await websocket.send_text(json.dumps({
                    "type": "stt_transcription_error",
                    "client_id": client_id,
                    "error": "Transcription failed or empty result",
                    "timing": timing
                }))
        finally:
            # Clean up temp file
            if os.path.exists(temp_path):
                os.unlink(temp_path)

    except Exception as e:
        logger.error(f"Error processing audio for {client_id}: {str(e)}")
        if client_id in active_connections:
            websocket = active_connections[client_id]
            await websocket.send_text(json.dumps({
                "type": "stt_transcription_error",
                "client_id": client_id,
                "error": f"Processing error: {str(e)}"
            }))


@fastapi_app.websocket("/ws/stt")
async def websocket_stt_endpoint(websocket: WebSocket):
    """Main STT WebSocket endpoint.

    Protocol: JSON messages; ``stt_audio_chunk`` triggers transcription,
    ``ping`` gets a ``pong``; unknown types are logged and ignored.
    """
    client_id = None
    try:
        # Accept connection
        client_id = await connect_websocket(websocket)

        # Handle messages
        while True:
            try:
                data = await websocket.receive_text()
                message = json.loads(data)

                message_type = message.get("type", "unknown")
                if message_type == "stt_audio_chunk":
                    await process_audio_message(client_id, message)
                elif message_type == "ping":
                    await websocket.send_text(json.dumps({
                        "type": "pong",
                        "client_id": client_id,
                        "timestamp": datetime.now().isoformat()
                    }))
                else:
                    logger.warning(f"Unknown message type from {client_id}: {message_type}")

            except WebSocketDisconnect:
                break
            except json.JSONDecodeError:
                await websocket.send_text(json.dumps({
                    "type": "stt_transcription_error",
                    "client_id": client_id,
                    "error": "Invalid JSON message format"
                }))
            except Exception as e:
                logger.error(f"Error handling message from {client_id}: {str(e)}")
                break

    except WebSocketDisconnect:
        logger.info(f"Client {client_id} disconnected normally")
    except Exception as e:
        logger.error(f"WebSocket error for {client_id}: {str(e)}")
    finally:
        if client_id:
            await disconnect_websocket(client_id)


@fastapi_app.post("/api/transcribe")
async def http_transcribe_endpoint(
    file: UploadFile = File(...),
    language: str = Form("en"),
    model_size_param: str = Form("base")
):
    """HTTP transcription endpoint for Streamlit WebRTC integration."""
    try:
        # DEBUG: Log incoming HTTP request
        logger.info("🌐 DEBUG: HTTP transcribe request received")
        logger.info(f"🌐 DEBUG: File name: {file.filename}")
        logger.info(f"🌐 DEBUG: Content type: {file.content_type}")
        logger.info(f"🌐 DEBUG: Language: {language}")
        logger.info(f"🌐 DEBUG: Model size: {model_size_param}")

        # Save uploaded file
        with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as tmp_file:
            content = await file.read()
            tmp_file.write(content)
            temp_path = tmp_file.name

        # DEBUG: Log file details
        file_size = len(content)
        logger.info(f"🌐 DEBUG: Uploaded file size: {file_size} bytes")
        logger.info(f"🌐 DEBUG: Temp file path: {temp_path}")

        try:
            # Transcribe using ZeroGPU function
            logger.info("🌐 DEBUG: Starting HTTP transcription...")
            transcription, status, timing = transcribe_audio_zerogpu(
                temp_path, language, model_size_param
            )
            logger.info(f"🌐 DEBUG: HTTP transcription result: '{transcription[:100] if transcription else 'None'}...'")
            logger.info(f"🌐 DEBUG: HTTP status: {status}")

            if status == "success":
                return {
                    "status": "success",
                    "transcription": transcription,
                    "timing": timing,
                    "timestamp": datetime.now().isoformat()
                }
            return {
                "status": "error",
                "message": "Transcription failed",
                "timing": timing,
                "timestamp": datetime.now().isoformat()
            }
        finally:
            # Clean up
            if os.path.exists(temp_path):
                os.unlink(temp_path)
                logger.info(f"🌐 DEBUG: Cleaned up temp file: {temp_path}")

    except Exception as e:
        error_msg = f"HTTP transcription error: {e}"
        logger.error(f"🌐 DEBUG: {error_msg}")
        return {
            "status": "error",
            "message": f"HTTP transcription failed: {str(e)}",
            "timestamp": datetime.now().isoformat()
        }


if __name__ == "__main__":
    logger.info(f"🎤 DEBUG: Starting {__service__} v{__version__} with Gradio+WebSocket integration")
    logger.info(f"🎤 DEBUG: Device: {device}")
    logger.info(f"🎤 DEBUG: Model size: {model_size}")
    logger.info("🎤 DEBUG: Default language: English (en)")
    logger.info("🎤 DEBUG: Service ready for connections")

    # For HF Spaces, launch the Gradio interface directly; this avoids the
    # FastAPI mounting conflicts that caused port issues.
    # NOTE(review): in this mode only the Gradio app is served — the /ws/stt
    # and /api/* routes on fastapi_app are not exposed; run the mounted app
    # below under uvicorn to serve everything together. Confirm against the
    # Space's runtime before relying on the WebSocket endpoint here.
    demo.launch(show_api=True, show_error=True)
else:
    # For programmatic use (e.g. `uvicorn app:app`): mount Gradio onto the
    # shared FastAPI app so UI, WebSocket, and HTTP routes are served together.
    app = gr.mount_gradio_app(fastapi_app, demo, path="/")