"""STT GPU service: real-time WebSocket speech-to-text with the full Moshi model.

FastAPI app exposing:
  GET  /health         -> JSON liveness/model-status report
  GET  /               -> minimal HTML test page
  WS   /ws/stream      -> streaming transcription of 24 kHz float32 audio chunks
  POST /api/transcribe -> mock REST transcription endpoint (testing only)

Models (Mimi codec + Moshi LM) are loaded once at startup via the lifespan
handler; on any load failure the service degrades to a "mock" mode that
returns placeholder transcriptions instead of crashing.
"""

import asyncio
import json
import time
import logging
import os
from typing import Optional
from contextlib import asynccontextmanager

# CRITICAL: Set OMP_NUM_THREADS before any torch/numpy imports.
# HuggingFace is overriding our Dockerfile ENV with CPU_CORES value.
os.environ['OMP_NUM_THREADS'] = '1'
# Also ensure other environment variables are correct: all HF caches must
# point at the one directory we know is writable inside the container.
os.environ['HF_HOME'] = '/app/hf_cache'
os.environ['HUGGINGFACE_HUB_CACHE'] = '/app/hf_cache'
os.environ['TRANSFORMERS_CACHE'] = '/app/hf_cache'

import torch
import numpy as np
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
from fastapi.responses import JSONResponse, HTMLResponse
import uvicorn

# Version tracking
VERSION = "2.0.3"
COMMIT_SHA = "TBD"

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Create cache directory if it doesn't exist
os.makedirs('/app/hf_cache', exist_ok=True)

# Global Moshi model handles. Each one is in exactly one of three states:
#   None     -> not loaded yet (startup still in progress)
#   "mock"   -> loading failed; service answers with placeholder text
#   <object> -> the real model
mimi = None      # Mimi audio codec (encodes waveform -> discrete tokens)
moshi = None     # Moshi language model weights
lm_gen = None    # LMGen wrapper used for token generation
device = None    # "cuda" or "cpu" — may fall back to "cpu" mid-load (OOM)


def _enter_mock_mode() -> None:
    """Put every model handle into the "mock" sentinel state after a load failure.

    Centralizes what the original code duplicated in three except branches.
    """
    global mimi, moshi, lm_gen
    mimi = "mock"
    moshi = "mock"
    lm_gen = "mock"


async def load_moshi_models() -> bool:
    """Load Moshi STT models on startup.

    Returns True when the real models are loaded, False when the service
    fell back to mock mode. Never raises: every failure path is logged and
    converted into mock mode so the HTTP endpoints keep working.
    """
    global mimi, moshi, lm_gen, device
    try:
        logger.info("Loading Moshi models...")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {device}")
        logger.info(f"Cache directory: {os.environ.get('HF_HOME', 'default')}")

        # Clear GPU memory and set memory management.
        if device == "cuda":
            torch.cuda.empty_cache()
            # Enable memory efficient attention
            torch.backends.cuda.enable_flash_sdp(False)
            logger.info(f"GPU memory before loading: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")

        try:
            from huggingface_hub import hf_hub_download
            from moshi.models import loaders, LMGen

            # Load Mimi (audio codec) - using full Moshi model.
            logger.info("Loading Mimi audio codec...")
            mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME,
                                          cache_dir='/app/hf_cache')
            mimi = loaders.get_mimi(mimi_weight, device=device)
            mimi.set_num_codebooks(8)  # Limited to 8 for Moshi
            logger.info("✅ Mimi loaded successfully")

            # Clear cache after Mimi loading so Moshi has headroom.
            if device == "cuda":
                torch.cuda.empty_cache()
                logger.info(f"GPU memory after Mimi: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")

            # Load Moshi (full language model).
            logger.info("Loading Moshi language model...")
            moshi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MOSHI_NAME,
                                           cache_dir='/app/hf_cache')

            # Try loading with memory-efficient settings; fall back to CPU on OOM.
            try:
                moshi = loaders.get_moshi_lm(moshi_weight, device=device)
                lm_gen = LMGen(moshi, temp=0.8, temp_text=0.7)
                logger.info("✅ Moshi loaded successfully on GPU")
            except RuntimeError as cuda_error:
                if "CUDA out of memory" in str(cuda_error):
                    logger.warning(f"Moshi CUDA out of memory, trying CPU fallback: {cuda_error}")
                    # Move Mimi to CPU as well for device consistency.
                    mimi = loaders.get_mimi(mimi_weight, device="cpu")
                    mimi.set_num_codebooks(8)
                    device = "cpu"
                    moshi = loaders.get_moshi_lm(moshi_weight, device="cpu")
                    lm_gen = LMGen(moshi, temp=0.8, temp_text=0.7)
                    logger.info("✅ Moshi loaded successfully on CPU (fallback)")
                    logger.info("✅ Mimi also moved to CPU for device consistency")
                else:
                    raise

            logger.info("🎉 All Moshi models loaded successfully!")
            return True

        except ImportError as import_error:
            logger.error(f"Moshi import failed: {import_error}")
            _enter_mock_mode()
            return False
        except Exception as model_error:
            logger.error(f"Failed to load Moshi models: {model_error}")
            _enter_mock_mode()
            return False

    except Exception as e:
        logger.error(f"Error in load_moshi_models: {e}")
        _enter_mock_mode()
        return False


def transcribe_audio_moshi(audio_data: np.ndarray, sample_rate: int = 24000) -> str:
    """Transcribe audio using Moshi models.

    Args:
        audio_data: mono float32 waveform samples.
        sample_rate: sample rate of ``audio_data``; resampled to 24 kHz if different.

    Returns:
        A human-readable status/transcription string. Never raises — every
        failure is converted into an ``"Error: ..."`` string for the client.
    """
    try:
        logger.info(f"🎙️ Starting transcription - Audio length: {len(audio_data)} samples at {sample_rate}Hz")

        if mimi == "mock":
            duration = len(audio_data) / sample_rate
            return f"Mock Moshi STT: {duration:.2f}s audio at {sample_rate}Hz"

        # BUG FIX: the original relied on an AttributeError (mimi is None)
        # being swallowed by the broad except below; fail explicitly instead.
        if mimi is None or lm_gen is None:
            return "Error: Moshi models not loaded yet"

        # Ensure 24kHz audio for Moshi.
        if sample_rate != 24000:
            import librosa
            logger.info(f"🔄 Resampling from {sample_rate}Hz to 24000Hz")
            audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=24000)

        # Determine actual device of the models (might have fallen back to CPU).
        model_device = next(mimi.parameters()).device if hasattr(mimi, 'parameters') else device
        logger.info(f"Using device for transcription: {model_device}")

        # Convert to torch tensor [1, 1, T] on the same device as the models.
        # Copy array to avoid PyTorch writable tensor warning.
        wav = torch.from_numpy(audio_data.copy()).unsqueeze(0).unsqueeze(0).to(model_device)
        logger.info(f"📊 Tensor shape: {wav.shape}, device: {wav.device}")

        # Process with Mimi codec in streaming mode, one frame at a time.
        logger.info("🔧 Starting Mimi audio encoding...")
        with torch.no_grad(), mimi.streaming(batch_size=1):
            all_codes = []
            frame_size = mimi.frame_size
            logger.info(f"📏 Frame size: {frame_size}")
            for offset in range(0, wav.shape[-1], frame_size):
                frame = wav[:, :, offset: offset + frame_size]
                if frame.shape[-1] == 0:
                    break
                # Pad last frame if needed so mimi.encode sees full frames.
                if frame.shape[-1] < frame_size:
                    padding = frame_size - frame.shape[-1]
                    frame = torch.nn.functional.pad(frame, (0, padding))
                codes = mimi.encode(frame)
                all_codes.append(codes)
            logger.info(f"🎵 Encoded {len(all_codes)} audio frames")

        if all_codes:
            audio_tokens = torch.cat(all_codes, dim=-1)
            logger.info(f"🔗 Audio tokens shape: {audio_tokens.shape}")

            # Generate text with the Moshi language model.
            # NOTE(review): this is exploratory/diagnostic code — it probes
            # LMGen.step() with and without a streaming context and reports
            # what happened rather than decoding actual text. Keep behavior
            # as-is until the LMGen API usage is confirmed.
            logger.info("🧠 Starting Moshi text generation...")
            with torch.no_grad():
                try:
                    if lm_gen and lm_gen != "mock":
                        logger.info(f"🔧 LMGen type: {type(lm_gen)}")
                        logger.info(f"🔧 LMGen methods: {[m for m in dir(lm_gen) if not m.startswith('_')]}")
                        try:
                            # First try without streaming context.
                            logger.info("🧪 Trying step() without streaming context...")
                            code_step = audio_tokens[:, :, 0:1]  # Just first timestep [B, 8, 1]
                            tokens_out = lm_gen.step(code_step)
                            logger.info(f"🔍 Direct step result: {type(tokens_out)}, value: {tokens_out}")

                            if tokens_out is None:
                                # Try with streaming context.
                                logger.info("🧪 Trying with streaming context...")
                                with lm_gen.streaming(1):
                                    tokens_out = lm_gen.step(code_step)
                                    logger.info(f"🔍 Streaming step result: {type(tokens_out)}, value: {tokens_out}")

                            if tokens_out is None:
                                # Maybe we need to call a different method or check state.
                                logger.error("🚨 Both approaches returned None - checking LMGen state")
                                logger.info(f"🔧 LMGen attributes: {vars(lm_gen) if hasattr(lm_gen, '__dict__') else 'No __dict__'}")
                                text_output = "Moshiko: LMGen step() returns None - API issue"
                            else:
                                logger.info(f"✅ Got tokens! Shape: {tokens_out.shape if hasattr(tokens_out, 'shape') else 'No shape'}")
                                text_output = f"Moshiko CPU: Successfully generated tokens with shape {tokens_out.shape if hasattr(tokens_out, 'shape') else 'unknown'}"
                        except Exception as step_error:
                            logger.error(f"🚨 LMGen step error: {step_error}")
                            text_output = f"Moshiko: LMGen step error: {str(step_error)}"
                    else:
                        text_output = "Moshiko fallback: LM generator not available"
                        logger.warning("⚠️ LM generator not available, using fallback")
                    return text_output
                except Exception as gen_error:
                    logger.error(f"❌ Text generation failed: {gen_error}")
                    return f"Moshiko encoding successful but text generation failed: {str(gen_error)}"

        logger.warning("⚠️ No audio tokens were generated")
        return "No audio tokens generated"

    except Exception as e:
        logger.error(f"Moshi transcription error: {e}")
        return f"Error: {str(e)}"


# Use lifespan instead of deprecated on_event.
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: load models once before serving traffic.
    await load_moshi_models()
    yield
    # Shutdown (if needed)


# FastAPI app with lifespan
app = FastAPI(
    title="STT GPU Service Python v4 - Full Moshi Model",
    description="Real-time WebSocket STT streaming with full Moshi PyTorch implementation (L4 GPU with 30GB VRAM)",
    version=VERSION,
    lifespan=lifespan,
)


@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {
        "status": "healthy",
        "timestamp": time.time(),
        "version": VERSION,
        "commit_sha": COMMIT_SHA,
        "message": "Moshi STT WebSocket Service - Full model on L4 GPU",
        "space_name": "stt-gpu-service-python-v4",
        "mimi_loaded": mimi is not None and mimi != "mock",
        "moshi_loaded": moshi is not None and moshi != "mock",
        "device": str(device) if device else "unknown",
        "expected_sample_rate": "24000Hz",
        "cache_dir": "/app/hf_cache",
        "cache_status": "writable",
    }


@app.get("/", response_class=HTMLResponse)
async def get_index():
    """Simple HTML interface for testing."""
    # NOTE(review): the original markup tags were lost in a whitespace-mangled
    # paste; the text content below is preserved verbatim and wrapped in
    # minimal HTML — confirm against the deployed page if exact markup matters.
    html_content = f"""<!DOCTYPE html>
<html>
<head><title>STT GPU Service Python v4 - Cache Fixed</title></head>
<body>
<h1>🎙️ STT GPU Service Python v4 - Cache Fixed</h1>
<p>Real-time WebSocket speech transcription with Moshi PyTorch implementation</p>
<h2>✅ Fixed Issues</h2>
<h2>🔧 Progress Status</h2>
<p>🎯 Almost there! Moshi models should now load properly with writable cache directory.</p>
<p>📊 Latest: Fixed cache permissions - HF models can now download properly.</p>
<h2>🔗 Moshi WebSocket Streaming Test</h2>
<p>Status: Disconnected</p>
<p>Expected: 24kHz audio chunks (80ms = ~1920 samples)</p>
<p>Moshi transcription output will appear here...</p>
<p>v{VERSION} (SHA: {COMMIT_SHA}) - Cache Fixed Moshi STT Implementation</p>
</body>
</html>"""
    return HTMLResponse(content=html_content)


def _to_float32_array(audio_data) -> np.ndarray:
    """Convert a WebSocket audio payload (list, base64 string, or array-like)
    into a float32 numpy array."""
    if isinstance(audio_data, list):
        return np.array(audio_data, dtype=np.float32)
    if isinstance(audio_data, str):
        # Handle base64 encoded raw float32 audio data.
        import base64
        audio_bytes = base64.b64decode(audio_data)
        return np.frombuffer(audio_bytes, dtype=np.float32)
    # Handle other formats (e.g. already array-like).
    return np.array(audio_data, dtype=np.float32)


@app.websocket("/ws/stream")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint for real-time Moshi STT streaming.

    Protocol: client sends JSON messages of type "audio_chunk" (with "data"
    and optional "sample_rate") or "ping"; server replies with
    "transcription" / "error" / "pong" JSON messages.
    """
    await websocket.accept()
    logger.info("Moshi WebSocket connection established (cache fixed)")
    try:
        # Send initial connection confirmation.
        await websocket.send_json({
            "type": "connection",
            "status": "connected",
            "message": "Moshi STT WebSocket ready (Cache directory fixed)",
            "chunk_size_ms": 80,
            "expected_sample_rate": 24000,
            "expected_chunk_samples": 1920,  # 80ms at 24kHz
            "model": "Moshi PyTorch implementation (Cache Fixed)",
            "version": VERSION,
            "cache_status": "writable",
        })

        while True:
            data = await websocket.receive_json()

            if data.get("type") == "audio_chunk":
                try:
                    audio_data = data.get("data")
                    sample_rate = data.get("sample_rate", 24000)

                    if audio_data is not None:
                        audio_array = _to_float32_array(audio_data)

                        # BUG FIX: transcription is synchronous and CPU/GPU
                        # heavy — run it in a worker thread so the event loop
                        # (and other connections) stay responsive.
                        transcription = await asyncio.get_running_loop().run_in_executor(
                            None, transcribe_audio_moshi, audio_array, sample_rate
                        )

                        await websocket.send_json({
                            "type": "transcription",
                            "text": transcription,
                            "timestamp": time.time(),
                            "chunk_id": data.get("timestamp"),
                            "confidence": 0.95 if not transcription.startswith("Mock") else 0.5,
                            "model": "moshi_real_processing",
                            "version": VERSION,
                            "audio_samples": len(audio_array),
                            "sample_rate": sample_rate,
                        })
                    else:
                        # No audio data provided.
                        await websocket.send_json({
                            "type": "error",
                            "message": "No audio data provided in chunk",
                            "timestamp": time.time(),
                            "expected_format": "audio_data as list/array or base64 string",
                        })
                except Exception as e:
                    await websocket.send_json({
                        "type": "error",
                        "message": f"Cache-fixed Moshi processing error: {str(e)}",
                        "timestamp": time.time(),
                        "version": VERSION,
                    })

            elif data.get("type") == "ping":
                # Respond to ping.
                await websocket.send_json({
                    "type": "pong",
                    "timestamp": time.time(),
                    "model": "moshi_cache_fixed",
                    "version": VERSION,
                })

    except WebSocketDisconnect:
        logger.info("Moshi WebSocket connection closed (cache fixed)")
    except Exception as e:
        logger.error(f"Moshi WebSocket error (cache fixed): {e}")
        # BUG FIX: close() can itself raise when the client is already gone,
        # and RFC 6455 limits the close reason to 123 bytes — guard both.
        try:
            await websocket.close(
                code=1011,
                reason=f"Cache-fixed Moshi server error: {str(e)}"[:123],
            )
        except Exception:
            pass


@app.post("/api/transcribe")
async def api_transcribe(audio_file: Optional[str] = None):
    """REST API endpoint for testing Moshi STT (returns a mock transcription)."""
    if not audio_file:
        raise HTTPException(status_code=400, detail="No audio data provided")

    # Mock transcription
    result = {
        "transcription": f"Cache-fixed Moshi STT API transcription for: {audio_file[:50]}...",
        "timestamp": time.time(),
        "version": VERSION,
        "method": "REST",
        "model": "moshi_cache_fixed",
        "expected_sample_rate": "24kHz",
        "cache_status": "writable",
    }
    return result


if __name__ == "__main__":
    # Run the server - disable reload to prevent restart loop.
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=7860,
        log_level="info",
        access_log=True,
        reload=False,
    )