import asyncio
import json
import logging
import os
import time
from contextlib import asynccontextmanager
from typing import Optional

# CRITICAL: Set OMP_NUM_THREADS before any torch/numpy imports.
# HuggingFace overrides our Dockerfile ENV with the CPU_CORES value.
os.environ['OMP_NUM_THREADS'] = '1'
# Point every HF cache variable at the one writable directory in the container.
os.environ['HF_HOME'] = '/app/hf_cache'
os.environ['HUGGINGFACE_HUB_CACHE'] = '/app/hf_cache'
os.environ['TRANSFORMERS_CACHE'] = '/app/hf_cache'  # legacy name, kept for older transformers

import torch
import numpy as np
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
from fastapi.responses import JSONResponse, HTMLResponse
import uvicorn

# Version tracking
VERSION = "2.0.3"
COMMIT_SHA = "TBD"

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Create cache directory if it doesn't exist
os.makedirs('/app/hf_cache', exist_ok=True)

# Global Moshi model handles. Each is one of:
#   None     -> not loaded yet
#   "mock"   -> loading failed; service runs in mock mode
#   <object> -> the real model / generator
# health_check() relies on the "mock" sentinel, so it must be preserved.
mimi = None
moshi = None
lm_gen = None
device = None  # "cuda" or "cpu" once load_moshi_models() has run


def _set_mock_mode() -> None:
    """Switch the service into mock mode after a failed model load."""
    global mimi, moshi, lm_gen
    mimi = "mock"
    moshi = "mock"
    lm_gen = "mock"


async def load_moshi_models() -> bool:
    """Load the Mimi audio codec and the Moshi language model at startup.

    Falls back to CPU when the Moshi LM hits CUDA OOM (reloading Mimi on CPU
    too, so both models share a device), and to mock mode when the ``moshi``
    package is missing or loading fails for any other reason.

    Returns:
        True when the real models loaded, False when the service is in
        mock mode.
    """
    global mimi, moshi, lm_gen, device
    try:
        logger.info("Loading Moshi models...")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info("Using device: %s", device)
        logger.info("Cache directory: %s", os.environ.get('HF_HOME', 'default'))

        if device == "cuda":
            torch.cuda.empty_cache()
            # Disable flash SDP in favor of the memory-efficient attention path.
            torch.backends.cuda.enable_flash_sdp(False)
            logger.info("GPU memory before loading: %.2f GB",
                        torch.cuda.memory_allocated() / 1024 ** 3)

        try:
            from huggingface_hub import hf_hub_download
            from moshi.models import loaders, LMGen

            # Load Mimi (audio codec) - using full Moshi model weights.
            logger.info("Loading Mimi audio codec...")
            mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME,
                                          cache_dir='/app/hf_cache')
            mimi = loaders.get_mimi(mimi_weight, device=device)
            mimi.set_num_codebooks(8)  # Limited to 8 codebooks for Moshi
            logger.info("Mimi loaded successfully")

            if device == "cuda":
                torch.cuda.empty_cache()
                logger.info("GPU memory after Mimi: %.2f GB",
                            torch.cuda.memory_allocated() / 1024 ** 3)

            # Load Moshi (full language model).
            logger.info("Loading Moshi language model...")
            moshi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MOSHI_NAME,
                                           cache_dir='/app/hf_cache')
            try:
                moshi = loaders.get_moshi_lm(moshi_weight, device=device)
                lm_gen = LMGen(moshi, temp=0.8, temp_text=0.7)
                # Report the actual device (the original always said "GPU").
                logger.info("Moshi loaded successfully on %s", device.upper())
            except RuntimeError as cuda_error:
                if "CUDA out of memory" not in str(cuda_error):
                    raise
                logger.warning("Moshi CUDA out of memory, trying CPU fallback: %s",
                               cuda_error)
                # Reload Mimi on CPU as well so both models share a device.
                mimi = loaders.get_mimi(mimi_weight, device="cpu")
                mimi.set_num_codebooks(8)
                device = "cpu"
                moshi = loaders.get_moshi_lm(moshi_weight, device="cpu")
                lm_gen = LMGen(moshi, temp=0.8, temp_text=0.7)
                logger.info("Moshi loaded successfully on CPU (fallback)")
                logger.info("Mimi also moved to CPU for device consistency")

            logger.info("All Moshi models loaded successfully!")
            return True

        except ImportError as import_error:
            logger.error("Moshi import failed: %s", import_error)
            _set_mock_mode()
            return False
        except Exception as model_error:
            logger.error("Failed to load Moshi models: %s", model_error)
            _set_mock_mode()
            return False

    except Exception as e:
        logger.error("Error in load_moshi_models: %s", e)
        _set_mock_mode()
        return False


def _resample_to_24k(audio_data: np.ndarray, sample_rate: int) -> np.ndarray:
    """Return *audio_data* resampled to the 24 kHz rate Moshi expects."""
    if sample_rate == 24000:
        return audio_data
    import librosa  # local import: heavy dependency, only needed off the 24 kHz path
    logger.info("Resampling from %dHz to 24000Hz", sample_rate)
    return librosa.resample(audio_data, orig_sr=sample_rate, target_sr=24000)


def _encode_audio(wav: torch.Tensor):
    """Encode a [1, 1, T] waveform tensor with Mimi frame by frame.

    Returns the concatenated code tensor, or None when the input produced
    no frames at all.
    """
    all_codes = []
    with torch.no_grad(), mimi.streaming(batch_size=1):
        frame_size = mimi.frame_size
        logger.info("Frame size: %d", frame_size)
        for offset in range(0, wav.shape[-1], frame_size):
            frame = wav[:, :, offset: offset + frame_size]
            if frame.shape[-1] == 0:
                break
            if frame.shape[-1] < frame_size:
                # Zero-pad the trailing partial chunk up to a full codec frame.
                padding = frame_size - frame.shape[-1]
                frame = torch.nn.functional.pad(frame, (0, padding))
            all_codes.append(mimi.encode(frame))
    logger.info("Encoded %d audio frames", len(all_codes))
    if not all_codes:
        return None
    return torch.cat(all_codes, dim=-1)


def _generate_text(audio_tokens: torch.Tensor) -> str:
    """Experimental LMGen probe: step the first codec timestep and report.

    NOTE(review): this does not decode a full transcript yet - it pushes a
    single timestep through LMGen.step(), with and without a streaming
    context, and returns a diagnostic string describing what happened.
    """
    with torch.no_grad():
        try:
            if lm_gen and lm_gen != "mock":
                logger.info("LMGen type: %s", type(lm_gen))
                logger.info("LMGen methods: %s",
                            [m for m in dir(lm_gen) if not m.startswith('_')])
                try:
                    code_step = audio_tokens[:, :, 0:1]  # first timestep [B, 8, 1]
                    logger.info("Trying step() without streaming context...")
                    tokens_out = lm_gen.step(code_step)
                    logger.info("Direct step result: %s, value: %s",
                                type(tokens_out), tokens_out)

                    if tokens_out is None:
                        # Maybe the streaming context is required for state setup.
                        logger.info("Trying with streaming context...")
                        with lm_gen.streaming(1):
                            tokens_out = lm_gen.step(code_step)
                        logger.info("Streaming step result: %s, value: %s",
                                    type(tokens_out), tokens_out)

                    if tokens_out is None:
                        logger.error("Both approaches returned None - checking LMGen state")
                        logger.info("LMGen attributes: %s",
                                    vars(lm_gen) if hasattr(lm_gen, '__dict__')
                                    else 'No __dict__')
                        return "Moshiko: LMGen step() returns None - API issue"

                    shape = tokens_out.shape if hasattr(tokens_out, 'shape') else 'unknown'
                    logger.info("Got tokens! Shape: %s", shape)
                    return f"Moshiko CPU: Successfully generated tokens with shape {shape}"
                except Exception as step_error:
                    logger.error("LMGen step error: %s", step_error)
                    return f"Moshiko: LMGen step error: {str(step_error)}"

            logger.warning("LM generator not available, using fallback")
            return "Moshiko fallback: LM generator not available"
        except Exception as gen_error:
            logger.error("Text generation failed: %s", gen_error)
            return f"Moshiko encoding successful but text generation failed: {str(gen_error)}"


def transcribe_audio_moshi(audio_data: np.ndarray, sample_rate: int = 24000) -> str:
    """Transcribe a mono waveform using the Moshi models.

    Args:
        audio_data: 1-D float waveform; resampled to 24 kHz when needed.
        sample_rate: sample rate of *audio_data* in Hz.

    Returns:
        A human-readable transcription/status string. Never raises - errors
        are caught and returned as "Error: ..." strings.
    """
    try:
        logger.info("Starting transcription - Audio length: %d samples at %dHz",
                    len(audio_data), sample_rate)

        if mimi == "mock":
            duration = len(audio_data) / sample_rate
            return f"Mock Moshi STT: {duration:.2f}s audio at {sample_rate}Hz"

        # Ensure 24kHz audio for Moshi.
        audio_data = _resample_to_24k(audio_data, sample_rate)

        # The models may have fallen back to CPU at load time; trust their
        # actual device rather than the module-level `device` global.
        model_device = next(mimi.parameters()).device if hasattr(mimi, 'parameters') else device
        logger.info("Using device for transcription: %s", model_device)

        # Copy into a float32 contiguous array: avoids PyTorch's non-writable
        # tensor warning and matches the dtype Mimi expects (decoded audio
        # often arrives as float64). Shape becomes [1, 1, T].
        wav = torch.from_numpy(np.ascontiguousarray(audio_data, dtype=np.float32))
        wav = wav.unsqueeze(0).unsqueeze(0).to(model_device)
        logger.info("Tensor shape: %s, device: %s", wav.shape, wav.device)

        logger.info("Starting Mimi audio encoding...")
        audio_tokens = _encode_audio(wav)
        if audio_tokens is None:
            logger.warning("No audio tokens were generated")
            return "No audio tokens generated"
        logger.info("Audio tokens shape: %s", audio_tokens.shape)

        logger.info("Starting Moshi text generation...")
        return _generate_text(audio_tokens)

    except Exception as e:
        logger.error("Moshi transcription error: %s", e)
        return f"Error: {str(e)}"


# Use lifespan instead of the deprecated on_event hooks.
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: load models on startup; nothing to tear down."""
    await load_moshi_models()
    yield
    # Shutdown (if needed)


# FastAPI app with lifespan
app = FastAPI(
    title="STT GPU Service Python v4 - Full Moshi Model",
    description=(
        "Real-time WebSocket STT streaming with full Moshi PyTorch "
        "implementation (L4 GPU with 30GB VRAM)"
    ),
    version=VERSION,
    lifespan=lifespan,
)
# --- HTTP endpoints -------------------------------------------------------

@app.get("/health")
async def health_check():
    """Health check endpoint.

    Reports static service metadata plus live load status: the *_loaded
    flags are True only when the real models are present (not None and not
    the "mock" fallback sentinel set by a failed load).
    """
    return {
        "status": "healthy",
        "timestamp": time.time(),
        "version": VERSION,
        "commit_sha": COMMIT_SHA,
        "message": "Moshi STT WebSocket Service - Full model on L4 GPU",
        "space_name": "stt-gpu-service-python-v4",
        # "mock" marks a failed load; only a real model object counts as loaded.
        "mimi_loaded": mimi is not None and mimi != "mock",
        "moshi_loaded": moshi is not None and moshi != "mock",
        # device is None until the startup hook has run.
        "device": str(device) if device else "unknown",
        "expected_sample_rate": "24000Hz",
        "cache_dir": "/app/hf_cache",
        "cache_status": "writable"
    }

@app.get("/", response_class=HTMLResponse)
async def get_index():
    """Simple HTML interface for testing."""
    # NOTE: the served page is a long f-string that continues below.
    html_content = f"""
Real-time WebSocket speech transcription with Moshi PyTorch implementation
🎯 Almost there! Moshi models should now load properly with writable cache directory.
🚀 Latest: Fixed cache permissions - HF models can now download properly.
Status: Disconnected
Expected: 24kHz audio chunks (80ms = ~1920 samples)
Moshi transcription output will appear here...