import asyncio
import json
import time
import logging
from typing import Optional

import torch
import numpy as np
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
from fastapi.responses import JSONResponse, HTMLResponse
import uvicorn

# Version tracking
VERSION = "1.3.0"
COMMIT_SHA = "TBD"

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Sentinel stored in the model globals when real models could not be loaded.
_MOCK = "mock"

# A WebSocket close frame is a control frame (payload <= 125 bytes); two of
# those bytes carry the status code, leaving 123 bytes for the UTF-8 reason.
_MAX_CLOSE_REASON_BYTES = 123

# Global Moshi model variables
mimi = None
moshi = None
lm_gen = None
device = None


def _enable_mock_mode() -> None:
    """Replace all model globals with the mock sentinel."""
    global mimi, moshi, lm_gen
    mimi = _MOCK
    moshi = _MOCK
    lm_gen = _MOCK


def _truncate_close_reason(reason: str, max_bytes: int = _MAX_CLOSE_REASON_BYTES) -> str:
    """Return *reason* cut down so its UTF-8 encoding fits in *max_bytes*."""
    encoded = reason.encode("utf-8")
    if len(encoded) <= max_bytes:
        return reason
    # errors="ignore" drops a multi-byte character split by the cut.
    return encoded[:max_bytes].decode("utf-8", errors="ignore")


async def load_moshi_models():
    """Load the Mimi codec and Moshi language model into the module globals.

    Picks "cuda" when available, otherwise "cpu", downloads the weights via
    ``hf_hub_download`` and builds ``mimi``, ``moshi`` and ``lm_gen``.

    Returns:
        True when every model loaded; False when anything failed, in which
        case all three model globals are set to the mock sentinel.
    """
    global mimi, moshi, lm_gen, device

    try:
        logger.info("Loading Moshi models...")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {device}")

        # Imported lazily so a missing package degrades to mock mode
        # instead of preventing the service from starting.
        from huggingface_hub import hf_hub_download
        from moshi.models import loaders, LMGen

        # Load Mimi (audio codec)
        logger.info("Loading Mimi audio codec...")
        mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
        mimi = loaders.get_mimi(mimi_weight, device=device)
        mimi.set_num_codebooks(8)  # Limited to 8 for Moshi

        # Load Moshi (language model)
        logger.info("Loading Moshi language model...")
        moshi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MOSHI_NAME)
        moshi = loaders.get_moshi_lm(moshi_weight, device=device)
        lm_gen = LMGen(moshi, temp=0.8, temp_text=0.7)

        logger.info("✅ Moshi models loaded successfully")
        return True
    except Exception as load_error:
        # Top-level boundary: any failure (import, download, construction)
        # switches the service to mock mode rather than crashing startup.
        logger.exception(f"Failed to load Moshi models: {load_error}")
        _enable_mock_mode()
        return False


def transcribe_audio_moshi(audio_data: np.ndarray, sample_rate: int = 24000) -> str:
    """Transcribe audio using the loaded Moshi models.

    Args:
        audio_data: Audio samples as a NumPy array.
        sample_rate: Sample rate of ``audio_data`` in Hz; anything other than
            24000 is resampled to 24000 with ``librosa``.

    Returns:
        The transcription text, a mock description in mock mode, or a string
        starting with ``"Error: "`` when processing fails. This function does
        not raise; every failure is reported through the return value.
    """
    try:
        if sample_rate <= 0:
            raise ValueError("sample_rate must be positive")

        if mimi == _MOCK:
            duration = len(audio_data) / sample_rate
            return f"Mock Moshi STT: {duration:.2f}s audio at {sample_rate}Hz"

        if mimi is None or lm_gen is None:
            # Models are only populated by load_moshi_models(); fail with a
            # readable message instead of an AttributeError on None.
            raise RuntimeError("Moshi models are not loaded")

        # Ensure 24kHz audio for Moshi
        if sample_rate != 24000:
            import librosa
            audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=24000)

        # torch.from_numpy keeps the array dtype, so a float64 array would
        # become a double tensor. Presumably the Mimi weights are float32 --
        # TODO confirm against the loaded checkpoint.
        audio_data = np.ascontiguousarray(audio_data, dtype=np.float32)

        # Convert to torch tensor, adding two leading singleton dimensions
        wav = torch.from_numpy(audio_data).unsqueeze(0).unsqueeze(0).to(device)

        # Process with Mimi codec in streaming mode
        all_codes = []
        with torch.no_grad(), mimi.streaming(batch_size=1):
            frame_size = mimi.frame_size
            for offset in range(0, wav.shape[-1], frame_size):
                frame = wav[:, :, offset: offset + frame_size]

                # Pad last frame if needed
                missing = frame_size - frame.shape[-1]
                if missing > 0:
                    frame = torch.nn.functional.pad(frame, (0, missing))

                all_codes.append(mimi.encode(frame))

        if not all_codes:
            return "No audio tokens generated"

        # Concatenate all codes
        audio_tokens = torch.cat(all_codes, dim=-1)

        # Generate text with language model
        with torch.no_grad():
            # Simple text generation from audio tokens
            # This is a simplified approach - Moshi has more complex generation
            # NOTE(review): confirm that LMGen actually exposes
            # generate_text_from_audio; if not, this call raises and the
            # function returns an "Error: ..." string.
            text_output = lm_gen.generate_text_from_audio(audio_tokens)

        return text_output if text_output else "Transcription completed"

    except Exception as e:
        logger.error(f"Moshi transcription error: {e}")
        return f"Error: {str(e)}"


# FastAPI app
app = FastAPI(
    title="STT GPU Service Python v4 - Moshi",
    description="Real-time WebSocket STT streaming with Moshi PyTorch implementation",
    version=VERSION
)


@app.on_event("startup")
async def startup_event():
    """Load Moshi models on startup"""
    await load_moshi_models()


@app.get("/health")
async def health_check():
    """Report service metadata and whether real (non-mock) models are loaded."""
    return {
        "status": "healthy",
        "timestamp": time.time(),
        "version": VERSION,
        "commit_sha": COMMIT_SHA,
        "message": "Moshi STT WebSocket Service - Real-time streaming ready",
        "space_name": "stt-gpu-service-python-v4",
        "mimi_loaded": mimi is not None and mimi != _MOCK,
        "moshi_loaded": moshi is not None and moshi != _MOCK,
        "device": str(device) if device else "unknown",
        "expected_sample_rate": "24000Hz"
    }


@app.get("/", response_class=HTMLResponse)
async def get_index():
    """Simple HTML interface for testing"""
    # NOTE(review): the body below contains no HTML markup even though it is
    # served as HTML -- confirm whether tags/scripts were lost from this file.
    html_content = f""" STT GPU Service Python v4 - Moshi

🎙️ STT GPU Service Python v4 - Moshi

Real-time WebSocket speech transcription with Moshi PyTorch implementation

🔗 Moshi WebSocket Streaming Test

Status: Disconnected

Expected: 24kHz audio chunks (80ms = ~1920 samples)

Moshi transcription output will appear here...

v{VERSION} (SHA: {COMMIT_SHA}) - Moshi STT Implementation
"""
    return HTMLResponse(content=html_content)


@app.websocket("/ws/stream")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint for real-time Moshi STT streaming.

    Accepts the connection, sends a ``connection`` message describing the
    expected audio format, then answers each JSON message by ``type``:
    ``audio_chunk`` gets a (currently mocked) ``transcription`` reply and
    ``ping`` gets a ``pong``. Other types are ignored. Malformed messages
    (invalid JSON or a non-object payload) are answered with an ``error``
    message and the connection stays open.
    """
    await websocket.accept()
    logger.info("Moshi WebSocket connection established")

    try:
        # Send initial connection confirmation
        await websocket.send_json({
            "type": "connection",
            "status": "connected",
            "message": "Moshi STT WebSocket ready for audio chunks",
            "chunk_size_ms": 80,
            "expected_sample_rate": 24000,
            "expected_chunk_samples": 1920,  # 80ms at 24kHz
            "model": "Moshi PyTorch implementation"
        })

        while True:
            # Receive audio data
            try:
                data = await websocket.receive_json()
            except json.JSONDecodeError as decode_error:
                # One bad frame should not tear down the whole stream.
                await websocket.send_json({
                    "type": "error",
                    "message": f"Invalid JSON message: {decode_error}",
                    "timestamp": time.time()
                })
                continue

            if not isinstance(data, dict):
                # Valid JSON that is not an object (list, number, ...) has no
                # "type" field to dispatch on.
                await websocket.send_json({
                    "type": "error",
                    "message": "Expected a JSON object with a 'type' field",
                    "timestamp": time.time()
                })
                continue

            message_type = data.get("type")

            if message_type == "audio_chunk":
                try:
                    # Process 80ms audio chunk with Moshi
                    # In real implementation:
                    # 1. Decode base64 audio data to numpy array
                    # 2. Process with Mimi codec (24kHz)
                    # 3. Generate text with Moshi LM
                    # 4. Return transcription

                    # For now, mock processing
                    transcription = f"Moshi STT transcription for 24kHz chunk at {data.get('timestamp', 'unknown')}"

                    # Send transcription result
                    await websocket.send_json({
                        "type": "transcription",
                        "text": transcription,
                        "timestamp": time.time(),
                        "chunk_id": data.get("timestamp"),
                        "confidence": 0.95,
                        "model": "moshi"
                    })

                except Exception as e:
                    await websocket.send_json({
                        "type": "error",
                        "message": f"Moshi processing error: {str(e)}",
                        "timestamp": time.time()
                    })

            elif message_type == "ping":
                # Respond to ping
                await websocket.send_json({
                    "type": "pong",
                    "timestamp": time.time(),
                    "model": "moshi"
                })

    except WebSocketDisconnect:
        logger.info("Moshi WebSocket connection closed")
    except Exception as e:
        logger.exception(f"Moshi WebSocket error: {e}")
        reason = _truncate_close_reason(f"Moshi server error: {str(e)}")
        try:
            await websocket.close(code=1011, reason=reason)
        except Exception as close_error:
            # Best effort: the socket may already be closed or broken, in
            # which case there is nothing left to tell the client.
            logger.debug(f"Could not close Moshi WebSocket cleanly: {close_error}")


@app.post("/api/transcribe")
async def api_transcribe(audio_file: Optional[str] = None):
    """REST API endpoint for testing Moshi STT.

    Returns a mocked transcription payload that echoes the first 50
    characters of ``audio_file``.

    Raises:
        HTTPException: 400 when ``audio_file`` is missing or empty.
    """
    if not audio_file:
        raise HTTPException(status_code=400, detail="No audio data provided")

    # Mock transcription
    result = {
        "transcription": f"Moshi STT API transcription for: {audio_file[:50]}...",
        "timestamp": time.time(),
        "version": VERSION,
        "method": "REST",
        "model": "moshi",
        "expected_sample_rate": "24kHz"
    }

    return result


if __name__ == "__main__":
    # Run the server. The app object is passed directly (instead of the
    # "app:app" import string) so this works whatever the file is named and
    # the module is not imported a second time; reload/workers are not used.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        log_level="info",
        access_log=True
    )