# --- repository metadata (pasted from the git viewer; kept as comments so the file parses) ---
# stt-gpu-service / app.py
# Author: Peter Michael Gits
# Commit 489f3e7: CRITICAL FIX: Return single value from gradio_transcribe_memory endpoint
#!/usr/bin/env python3
"""
STT WebSocket Service with Gradio + FastAPI Integration
ZeroGPU compatible service with WebSocket endpoints for VoiceCal
Following unmute.sh WebRTC pattern for HuggingFace Spaces
"""
import os
import logging
# Configure logging first so the authentication block below can log.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# CRITICAL: ZeroGPU authentication must happen BEFORE importing `spaces`,
# so the credentials/environment are in place at `spaces` import time.
try:
    from huggingface_hub import login, HfFolder

    # Get and validate HF_TOKEN from the environment.
    hf_token = os.getenv('HF_TOKEN')
    logger.info(f"πŸ” DEBUG: HF_TOKEN environment variable: {'SET' if hf_token else 'NOT SET'}")
    if hf_token:
        # Try multiple authentication methods for ZeroGPU compatibility.
        # Method 1: Standard huggingface_hub login
        login(token=hf_token, write_permission=True)
        # Method 2: Explicitly set in HfFolder (legacy method)
        HfFolder.save_token(hf_token)
        # Method 3: Ensure the environment variables are explicitly set
        os.environ['HF_TOKEN'] = hf_token
        os.environ['HUGGINGFACE_HUB_TOKEN'] = hf_token  # Alternative env var
        # Method 4: ZeroGPU-specific environment variables (from community reports)
        os.environ['ZEROGPU_V2'] = 'true'
        os.environ['ZERO_GPU_PATCH_TORCH_DEVICE'] = '1'
        logger.info(f"πŸ”‘ DEBUG: HuggingFace authentication successful (early init) - Token length: {len(hf_token)}")
        # NOTE(review): logging a token prefix/suffix leaks partial credentials
        # into the logs — confirm this is acceptable for this deployment.
        logger.info(f"πŸ”‘ DEBUG: Token prefix: {hf_token[:10]}...{hf_token[-4:]}")
    else:
        logger.warning(f"⚠️ DEBUG: No HF_TOKEN found - using anonymous access (limited quota)")
except Exception as e:
    # Authentication is best-effort: on failure the service still starts
    # (anonymous access), but the full traceback is logged for diagnosis.
    logger.error(f"❌ DEBUG: HuggingFace login failed (early init): {e}")
    import traceback
    logger.error(f"❌ DEBUG: Full traceback: {traceback.format_exc()}")
# Now import everything else
import asyncio
import json
import uuid
import base64
import tempfile
from datetime import datetime
from typing import Optional, Dict, Any
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import torchaudio
import soundfile as sf
import numpy as np
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
import gradio as gr
import spaces
# Version info
__version__ = "1.0.4"
__service__ = "STT WebSocket Service"

# Global state for the lazily-loaded Whisper model. These must live at
# module scope because the @spaces.GPU-decorated function below (re)uses
# them across calls.
model = None        # WhisperForConditionalGeneration; loaded on first request
processor = None    # WhisperProcessor paired with `model`
model_size = "base"  # default Whisper checkpoint size (also shown in the UI)
device = "cuda" if torch.cuda.is_available() else "cpu"  # decided at import time
@spaces.GPU(duration=30)
def transcribe_audio_zerogpu(
    audio_path: str,
    language: str = "en",
    model_size_param: str = "base"
) -> tuple[str, str, Dict[str, Any]]:
    """Transcribe an audio file with Whisper on ZeroGPU.

    Args:
        audio_path: Path to an audio file readable by torchaudio.
        language: ISO language code used to force the decoding language,
            or "auto" to let Whisper detect it.
        model_size_param: Whisper checkpoint size ("tiny" ... "large-v2").

    Returns:
        (transcription, status, timing_info): status is "success" or
        "error"; on error the transcription is "" and timing_info carries
        {"error": message}.
    """
    global model, processor, model_size, device
    try:
        start_time = datetime.now()
        # (Re)load the model when none is cached OR when the caller asks for
        # a different checkpoint size than the one currently cached.
        # BUGFIX: the previous version cached the first model forever and
        # silently ignored later model_size_param values.
        if model is None or model_size != model_size_param:
            logger.info(f"Loading Whisper {model_size_param} model...")
            model_name = f"openai/whisper-{model_size_param}"
            processor = WhisperProcessor.from_pretrained(model_name)
            model = WhisperForConditionalGeneration.from_pretrained(model_name)
            if device == "cuda":
                model = model.to(device)
            model_size = model_size_param  # remember which checkpoint is cached
            logger.info(f"βœ… Model loaded on {device}")
        # Load and preprocess audio (following unmute.sh pattern).
        audio_input, sample_rate = torchaudio.load(audio_path)
        # Whisper requires 16 kHz mono input.
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            audio_input = resampler(audio_input)
        if audio_input.shape[0] > 1:
            audio_input = torch.mean(audio_input, dim=0, keepdim=True)
        audio_array = audio_input.squeeze().numpy()
        # Convert the waveform into Whisper input features.
        inputs = processor(
            audio_array,
            sampling_rate=16000,
            return_tensors="pt"
        )
        if device == "cuda":
            inputs = {k: v.to(device) for k, v in inputs.items()}
        # BUGFIX: honour the `language` argument — it was previously accepted
        # but never used, so every request decoded with Whisper's default.
        # "auto" (or empty) leaves Whisper to auto-detect the language.
        generate_kwargs: Dict[str, Any] = {}
        if language and language != "auto":
            generate_kwargs["forced_decoder_ids"] = processor.get_decoder_prompt_ids(
                language=language, task="transcribe"
            )
        # Generate transcription (inference only, no gradients needed).
        with torch.no_grad():
            predicted_ids = model.generate(**inputs, **generate_kwargs)
        transcription = processor.batch_decode(
            predicted_ids,
            skip_special_tokens=True
        )[0]
        # Calculate timing metadata returned to the caller.
        end_time = datetime.now()
        processing_time = (end_time - start_time).total_seconds()
        timing_info = {
            "processing_time": processing_time,
            "start_time": start_time.isoformat(),
            "end_time": end_time.isoformat(),
            "model_size": model_size_param,
            "device": device
        }
        logger.info(f"Transcription completed in {processing_time:.2f}s: '{transcription[:50]}...'")
        return transcription.strip(), "success", timing_info
    except Exception as e:
        # Never raise to callers (Gradio/WebSocket handlers expect a tuple).
        logger.error(f"Transcription error: {str(e)}")
        return "", "error", {"error": str(e)}
# Global WebSocket connection tracker: client_id (uuid4 string) -> WebSocket.
active_connections: Dict[str, WebSocket] = {}
# Simple Gradio interface for HF Spaces compliance
def get_service_info():
    """Return a Markdown summary of the running STT service for Gradio."""
    zerogpu_state = 'βœ… Available' if torch.cuda.is_available() else '❌ Not Available'
    return f"""
# 🎀 STT WebSocket Service v{__version__}
**WebSocket Endpoint:** `/ws/stt`
**Model:** Whisper {model_size}
**Device:** {device}
**ZeroGPU:** {zerogpu_state}
**Status:** Ready for WebSocket connections
Connect your WebRTC client to: `wss://your-space.hf.space/ws/stt`
"""
def gradio_transcribe_wrapper(audio_file, language="en", model_size_param="base"):
    """Gradio click handler: transcribe an uploaded/recorded audio file.

    Returns a (result_text, timing_json, status_text) triple matching the
    three output components wired up in the Blocks UI. Errors are reported
    in the returned strings; the handler never raises.
    """
    try:
        # Trace every incoming request for debugging on Spaces.
        logger.info(f"🎀 DEBUG: Gradio transcription request received")
        logger.info(f"🎀 DEBUG: Audio file: {audio_file}")
        logger.info(f"🎀 DEBUG: Language: {language}")
        logger.info(f"🎀 DEBUG: Model size: {model_size_param}")

        # Guard clause: nothing to transcribe.
        if audio_file is None:
            logger.warning("🎀 DEBUG: No audio file provided to Gradio wrapper")
            return "❌ No audio file provided", "{}", "Please upload an audio file"

        # Log file details when Gradio handed us a real path on disk.
        if isinstance(audio_file, str) and os.path.exists(audio_file):
            logger.info(f"🎀 DEBUG: Audio file size: {os.path.getsize(audio_file)} bytes")
            logger.info(f"🎀 DEBUG: Audio file path: {audio_file}")
        else:
            logger.warning(f"🎀 DEBUG: Invalid audio file: {type(audio_file)}")

        # Delegate the heavy lifting to the shared ZeroGPU function.
        logger.info(f"🎀 DEBUG: Calling transcribe_audio_zerogpu...")
        text, status, timing = transcribe_audio_zerogpu(
            audio_file, language, model_size_param
        )
        logger.info(f"🎀 DEBUG: Transcription result: '{text[:100]}...'")
        logger.info(f"🎀 DEBUG: Status: {status}")

        timing_json = json.dumps(timing, indent=2)
        if status == "success":
            return f"βœ… {text}", timing_json, f"Status: {status}"
        return f"❌ Transcription failed", timing_json, f"Status: {status}"
    except Exception as e:
        error_msg = f"Error in gradio_transcribe_wrapper: {str(e)}"
        logger.error(f"🎀 DEBUG: {error_msg}")
        return f"❌ Error: {str(e)}", "{}", "Error occurred during transcription"
def gradio_transcribe_memory(audio_base64: str, language: str = "en", model_size_param: str = "base"):
    """πŸš€ In-memory transcription endpoint - TRUE UNMUTE.SH METHODOLOGY.

    Accepts base64-encoded audio (e.g. WebM from a WebRTC recorder),
    decodes it, and returns ONLY the transcription string — a single
    value, for voiceCal-ai client compatibility. Error conditions are
    reported as "❌ ..." strings; the function never raises.

    FIX: dropped the redundant function-local `import base64` (already
    imported at module level) and the unused `import io`.
    """
    try:
        # Trace every in-memory request for debugging on Spaces.
        logger.info(f"πŸš€ MEMORY: In-memory transcription request received")
        logger.info(f"πŸš€ MEMORY: Audio base64 length: {len(audio_base64)} chars")
        logger.info(f"πŸš€ MEMORY: Language: {language}")
        logger.info(f"πŸš€ MEMORY: Model size: {model_size_param}")

        # Guard clause: nothing to decode.
        if not audio_base64 or audio_base64 == "":
            logger.warning("πŸš€ MEMORY: No audio data provided")
            return "❌ No audio data provided"

        # Decode base64 payload into binary audio data.
        try:
            audio_binary = base64.b64decode(audio_base64)
            logger.info(f"πŸš€ MEMORY: Decoded audio size: {len(audio_binary)} bytes")
        except Exception as decode_error:
            logger.error(f"πŸš€ MEMORY: Base64 decode error: {decode_error}")
            return f"❌ Invalid base64 audio data: {decode_error}"

        # torchaudio.load still needs a path, so spill to a temp file.
        with tempfile.NamedTemporaryFile(suffix='.webm', delete=False) as tmp_file:
            tmp_file.write(audio_binary)
            temp_path = tmp_file.name
        try:
            logger.info(f"πŸš€ MEMORY: Processing with ZeroGPU...")
            transcription, status, timing = transcribe_audio_zerogpu(
                temp_path, language, model_size_param
            )
            logger.info(f"πŸš€ MEMORY: Transcription result: '{transcription[:100] if transcription else 'None'}...'")
            logger.info(f"πŸš€ MEMORY: Status: {status}")
            if status == "success":
                return transcription  # single value for voiceCal-ai compatibility
            else:
                return "❌ Transcription failed"
        finally:
            # Always remove the temp file, even when transcription raised.
            if os.path.exists(temp_path):
                os.unlink(temp_path)
                logger.info(f"πŸš€ MEMORY: Cleaned up temp file")
    except Exception as e:
        error_msg = f"Error in gradio_transcribe_memory: {str(e)}"
        logger.error(f"πŸš€ MEMORY: {error_msg}")
        return f"❌ Error: {str(e)}"
# Create the Gradio interface with transcription functionality.
# FIX: the displayed version is now derived from __version__ so it can no
# longer drift from the release (it was hard-coded "v1.0.0" while
# __version__ is "1.0.4").
with gr.Blocks(title=f"🎀 STT WebSocket Service v{__version__}") as demo:
    gr.Markdown(f"""
# 🎀 STT WebSocket Service v{__version__}
**WebSocket-enabled Speech-to-Text service with ZeroGPU acceleration**
## πŸ“‹ Service Information
- **Service:** STT WebSocket Service v{__version__}
- **ZeroGPU:** Enabled with H200 acceleration
- **WebSocket Endpoint:** `wss://your-space.hf.space/ws/stt`
- **HTTP API:** Available for direct file uploads
## 🎀 Upload Audio for Transcription
""")
    # --- File-upload transcription panel ---
    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(
                label="Upload Audio File",
                type="filepath",
                sources=["upload", "microphone"]
            )
            language_input = gr.Dropdown(
                choices=["en", "auto", "es", "fr", "de", "it", "pt", "ru", "ja", "ko", "zh"],
                value="en",
                label="Language (English by default)"
            )
            model_input = gr.Dropdown(
                choices=["tiny", "base", "small", "medium", "large-v2"],
                value="base",
                label="Model Size"
            )
            transcribe_btn = gr.Button("🎀 Transcribe Audio", variant="primary")
        with gr.Column():
            transcription_output = gr.Textbox(
                label="Transcription Result",
                lines=4,
                placeholder="Transcription will appear here..."
            )
            timing_output = gr.Code(
                label="Timing Information",
                language="json"
            )
            status_output = gr.Textbox(
                label="Status",
                lines=1
            )
    transcribe_btn.click(
        fn=gradio_transcribe_wrapper,
        inputs=[audio_input, language_input, model_input],
        outputs=[transcription_output, timing_output, status_output]
    )
    # πŸš€ In-Memory Transcription Interface (TRUE UNMUTE.SH)
    gr.Markdown("""
---
## πŸš€ In-Memory Transcription (WebRTC Streaming)
**For real-time audio processing without file uploads**
""")
    with gr.Row():
        with gr.Column():
            audio_base64_input = gr.Textbox(
                label="Audio Base64 Data",
                placeholder="Base64 encoded audio data will be sent here from WebRTC...",
                lines=3
            )
            memory_language_input = gr.Dropdown(
                choices=["en", "auto", "es", "fr", "de", "it", "pt", "ru", "ja", "ko", "zh"],
                value="en",
                label="Language (English by default)"
            )
            memory_model_input = gr.Dropdown(
                choices=["tiny", "base", "small", "medium", "large-v2"],
                value="base",
                label="Model Size"
            )
            memory_transcribe_btn = gr.Button("πŸš€ Transcribe In-Memory", variant="secondary")
        with gr.Column():
            memory_transcription_output = gr.Textbox(
                label="In-Memory Transcription Result",
                lines=4,
                placeholder="Real-time transcription will appear here..."
            )
            # These two components are display placeholders only: the
            # single-value memory endpoint does not update them.
            memory_timing_output = gr.Code(
                label="Performance Metrics",
                language="json"
            )
            memory_status_output = gr.Textbox(
                label="Processing Status",
                lines=1
            )
    memory_transcribe_btn.click(
        fn=gradio_transcribe_memory,
        inputs=[audio_base64_input, memory_language_input, memory_model_input],
        # BUGFIX: gradio_transcribe_memory returns a single string (see
        # commit 489f3e7), but this click was wired to three outputs, so
        # Gradio raised a "returned too few output values" error on every
        # request. Only the transcription textbox is updated now.
        outputs=[memory_transcription_output]
    )
# Create a module-level FastAPI app for WebSocket/HTTP endpoints.
# NOTE(review): when run as a script, the __main__ block below rebinds
# `fastapi_app` to a fresh instance, so this app (and the routes registered
# on it) only matters when the module is imported programmatically.
fastapi_app = FastAPI(
    title="STT WebSocket Service",
    version=__version__
)

# Permit cross-origin requests from any origin — required for browser
# WebRTC clients hosted on other domains.
fastapi_app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@fastapi_app.get("/api/health")
async def health_check():
    """Liveness probe: report service identity, model state and load."""
    status_payload = {
        "service": __service__,
        "version": __version__,
        "status": "healthy",
        "model_loaded": model is not None,
        "active_connections": len(active_connections),
        "device": device,
        "timestamp": datetime.now().isoformat(),
    }
    return status_payload
async def connect_websocket(websocket: WebSocket) -> str:
    """Accept a WebSocket, register it, confirm to the client, return its id."""
    client_id = str(uuid.uuid4())
    await websocket.accept()
    active_connections[client_id] = websocket
    # Tell the client the connection is ready and which model is serving it.
    confirmation = {
        "type": "stt_connection_confirmed",
        "client_id": client_id,
        "service": __service__,
        "version": __version__,
        "model": f"whisper-{model_size}",
        "device": device,
        "message": "STT WebSocket connected and ready",
    }
    await websocket.send_text(json.dumps(confirmation))
    logger.info(f"Client {client_id} connected")
    return client_id
async def disconnect_websocket(client_id: str):
    """Remove a client's WebSocket from the active-connection registry."""
    # pop() with a default is a no-op for unknown ids, matching the
    # original membership check before deletion.
    if active_connections.pop(client_id, None) is not None:
        logger.info(f"Client {client_id} disconnected")
async def process_audio_message(client_id: str, message: Dict[str, Any]):
    """Handle one "stt_audio_chunk" message for a connected client.

    Expects base64-encoded audio under message["audio_data"], writes it to
    a temp .webm file (torchaudio needs a path), transcribes it with the
    shared ZeroGPU function, and sends either an
    "stt_transcription_complete" or "stt_transcription_error" JSON frame
    back on the client's WebSocket. Never raises to the caller.
    """
    try:
        websocket = active_connections[client_id]
        # Extract audio data (base64 encoded).
        audio_data_b64 = message.get("audio_data")
        if not audio_data_b64:
            await websocket.send_text(json.dumps({
                "type": "stt_transcription_error",
                "client_id": client_id,
                "error": "No audio data provided"
            }))
            return
        # Decode base64 audio; invalid base64 is caught by the outer handler.
        audio_bytes = base64.b64decode(audio_data_b64)
        # Save to a temporary file (removed in the finally block below).
        with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as tmp_file:
            tmp_file.write(audio_bytes)
            temp_path = tmp_file.name
        try:
            # Transcribe with the shared ZeroGPU function; language and
            # model size fall back to "en" / the module-level model_size.
            transcription, status, timing = transcribe_audio_zerogpu(
                temp_path,
                message.get("language", "en"),
                message.get("model_size", model_size)
            )
            # Report the result back to the client.
            if status == "success" and transcription:
                await websocket.send_text(json.dumps({
                    "type": "stt_transcription_complete",
                    "client_id": client_id,
                    "transcription": transcription,
                    "timing": timing,
                    "status": "success"
                }))
            else:
                await websocket.send_text(json.dumps({
                    "type": "stt_transcription_error",
                    "client_id": client_id,
                    "error": "Transcription failed or empty result",
                    "timing": timing
                }))
        finally:
            # Clean up the temp file regardless of transcription outcome.
            if os.path.exists(temp_path):
                os.unlink(temp_path)
    except Exception as e:
        logger.error(f"Error processing audio for {client_id}: {str(e)}")
        # Best-effort error report; the client may already be gone.
        if client_id in active_connections:
            websocket = active_connections[client_id]
            await websocket.send_text(json.dumps({
                "type": "stt_transcription_error",
                "client_id": client_id,
                "error": f"Processing error: {str(e)}"
            }))
# WebSocket endpoint will be added to Gradio's FastAPI app in main block
# For HuggingFace Spaces - we need to launch the Gradio demo
# and add WebSocket routes to its internal FastAPI app
if __name__ == "__main__":
    # Script entry point for HuggingFace Spaces.
    logger.info(f"🎀 DEBUG: Starting {__service__} v{__version__} with Gradio+WebSocket integration")
    logger.info(f"🎀 DEBUG: Device: {device}")
    logger.info(f"🎀 DEBUG: Model size: {model_size}")
    logger.info(f"🎀 DEBUG: Default language: English (en)")
    logger.info(f"🎀 DEBUG: Service ready for connections")

    # Create a FastAPI app for WebSocket endpoints.
    # NOTE(review): this rebinds the module-level `fastapi_app` defined
    # above, and the routes registered below attach only to this new
    # instance. Since this branch ends with `demo.launch()` (Gradio only),
    # this FastAPI app — including /ws/stt and /api/transcribe — is never
    # actually served. Confirm whether these endpoints are still required;
    # if so they must be mounted into the served app.
    fastapi_app = FastAPI(title="STT WebSocket API")

    # Add CORS middleware (same permissive policy as the module-level app).
    fastapi_app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    @fastapi_app.websocket("/ws/stt")
    async def websocket_stt_endpoint(websocket: WebSocket):
        """Main STT WebSocket endpoint.

        Protocol: the client sends JSON messages; "stt_audio_chunk"
        triggers transcription via process_audio_message(), "ping" gets a
        "pong" reply, anything else is logged and ignored.
        """
        client_id = None
        try:
            # Accept the connection and register it in active_connections.
            client_id = await connect_websocket(websocket)
            # Message loop: runs until disconnect or a fatal error.
            while True:
                try:
                    data = await websocket.receive_text()
                    message = json.loads(data)
                    # Dispatch on the message type.
                    message_type = message.get("type", "unknown")
                    if message_type == "stt_audio_chunk":
                        await process_audio_message(client_id, message)
                    elif message_type == "ping":
                        # Respond to keep-alive pings.
                        await websocket.send_text(json.dumps({
                            "type": "pong",
                            "client_id": client_id,
                            "timestamp": datetime.now().isoformat()
                        }))
                    else:
                        logger.warning(f"Unknown message type from {client_id}: {message_type}")
                except WebSocketDisconnect:
                    break
                except json.JSONDecodeError:
                    # Malformed JSON is reported to the client; loop continues.
                    await websocket.send_text(json.dumps({
                        "type": "stt_transcription_error",
                        "client_id": client_id,
                        "error": "Invalid JSON message format"
                    }))
                except Exception as e:
                    # Any other per-message failure ends the session.
                    logger.error(f"Error handling message from {client_id}: {str(e)}")
                    break
        except WebSocketDisconnect:
            logger.info(f"Client {client_id} disconnected normally")
        except Exception as e:
            logger.error(f"WebSocket error for {client_id}: {str(e)}")
        finally:
            # Always deregister the client, whatever ended the session.
            if client_id:
                await disconnect_websocket(client_id)

    # Add a health check route to this (script-mode) FastAPI app.
    # NOTE(review): duplicates the module-level /api/health handler above.
    @fastapi_app.get("/api/health")
    async def health_check():
        """Health check endpoint"""
        return {
            "service": __service__,
            "version": __version__,
            "status": "healthy",
            "model_loaded": model is not None,
            "active_connections": len(active_connections),
            "device": device,
            "timestamp": datetime.now().isoformat()
        }

    # Add an HTTP transcription endpoint for Streamlit integration.
    from fastapi import File, Form, UploadFile

    @fastapi_app.post("/api/transcribe")
    async def http_transcribe_endpoint(
        file: UploadFile = File(...),
        language: str = Form("en"),
        model_size_param: str = Form("base")
    ):
        """HTTP transcription endpoint for Streamlit WebRTC integration.

        Accepts a multipart file upload, spills it to a temp file, runs the
        shared ZeroGPU transcription function, and returns a JSON payload
        with status/transcription/timing. Errors are returned as JSON
        bodies rather than raised.
        """
        try:
            # Trace the incoming HTTP request.
            logger.info(f"🌐 DEBUG: HTTP transcribe request received")
            logger.info(f"🌐 DEBUG: File name: {file.filename}")
            logger.info(f"🌐 DEBUG: Content type: {file.content_type}")
            logger.info(f"🌐 DEBUG: Language: {language}")
            logger.info(f"🌐 DEBUG: Model size: {model_size_param}")
            # Save the uploaded bytes to disk (torchaudio needs a path).
            with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as tmp_file:
                content = await file.read()
                tmp_file.write(content)
                temp_path = tmp_file.name
            # Log file details.
            file_size = len(content)
            logger.info(f"🌐 DEBUG: Uploaded file size: {file_size} bytes")
            logger.info(f"🌐 DEBUG: Temp file path: {temp_path}")
            try:
                # Transcribe using the shared ZeroGPU function.
                logger.info(f"🌐 DEBUG: Starting HTTP transcription...")
                transcription, status, timing = transcribe_audio_zerogpu(
                    temp_path, language, model_size_param
                )
                logger.info(f"🌐 DEBUG: HTTP transcription result: '{transcription[:100] if transcription else 'None'}...'")
                logger.info(f"🌐 DEBUG: HTTP status: {status}")
                if status == "success":
                    return {
                        "status": "success",
                        "transcription": transcription,
                        "timing": timing,
                        "timestamp": datetime.now().isoformat()
                    }
                else:
                    return {
                        "status": "error",
                        "message": "Transcription failed",
                        "timing": timing,
                        "timestamp": datetime.now().isoformat()
                    }
            finally:
                # Remove the temp file even when transcription raised.
                if os.path.exists(temp_path):
                    os.unlink(temp_path)
                    logger.info(f"🌐 DEBUG: Cleaned up temp file: {temp_path}")
        except Exception as e:
            error_msg = f"HTTP transcription error: {e}"
            logger.error(f"🌐 DEBUG: {error_msg}")
            return {
                "status": "error",
                "message": f"HTTP transcription failed: {str(e)}",
                "timestamp": datetime.now().isoformat()
            }

    # For HF Spaces, launch the Gradio interface directly.
    # This avoids FastAPI mounting conflicts causing port issues — but see
    # the NOTE above: the FastAPI routes in this branch are not exposed by
    # this call.
    demo.launch(show_api=True, show_error=True)
else:
# For programmatic use, create the mounted app
fastapi_app = FastAPI(title="STT WebSocket API")
fastapi_app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app = gr.mount_gradio_app(fastapi_app, demo, path="/")