# stt-gpu-service-python-v4 / app_moshi_stt.py
# Author: Peter Michael Gits
# Last commit: Fix Dockerfile directory permissions - create /app as root
# before switching users (26096f4)
import asyncio
import json
import time
import logging
from typing import Optional
import torch
import numpy as np
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
from fastapi.responses import JSONResponse, HTMLResponse
import uvicorn
# Version tracking (surfaced via /health and the index page)
VERSION = "1.3.0"
COMMIT_SHA = "TBD"  # NOTE(review): file header says commit 26096f4 — presumably injected at build time; confirm
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Global Moshi model variables.
# Populated by load_moshi_models() on startup; each is set to the string
# sentinel "mock" when real model loading fails, so handlers can degrade
# to mock responses instead of crashing.
mimi = None      # Mimi neural audio codec (tokenizes 24kHz audio)
moshi = None     # Moshi language model
lm_gen = None    # LMGen generation wrapper around the Moshi LM
device = None    # "cuda" or "cpu", chosen at load time
async def load_moshi_models():
    """Download and initialize the Moshi STT stack (Mimi codec + Moshi LM).

    Populates the module-level ``mimi``, ``moshi``, ``lm_gen`` and
    ``device`` globals. On any failure the three model globals are set to
    the string sentinel "mock" so the service keeps running in mock mode.

    Returns:
        True when the real models were loaded, False otherwise.
    """
    global mimi, moshi, lm_gen, device

    try:
        logger.info("Loading Moshi models...")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {device}")

        try:
            # Imported lazily so the service can still boot (in mock mode)
            # when these optional dependencies are absent.
            from huggingface_hub import hf_hub_download
            from moshi.models import loaders, LMGen

            # Mimi: neural audio codec that tokenizes 24kHz audio.
            logger.info("Loading Mimi audio codec...")
            mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
            mimi = loaders.get_mimi(mimi_weight, device=device)
            mimi.set_num_codebooks(8)  # Limited to 8 for Moshi

            # Moshi: the language model that turns audio tokens into text.
            logger.info("Loading Moshi language model...")
            moshi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MOSHI_NAME)
            moshi = loaders.get_moshi_lm(moshi_weight, device=device)
            lm_gen = LMGen(moshi, temp=0.8, temp_text=0.7)

            logger.info("✅ Moshi models loaded successfully")
            return True
        except Exception as load_error:
            logger.error(f"Failed to load Moshi models: {load_error}")
            # Degrade to mock mode rather than failing startup.
            mimi = moshi = lm_gen = "mock"
            return False
    except Exception as unexpected:
        logger.error(f"Error in load_moshi_models: {unexpected}")
        mimi = moshi = lm_gen = "mock"
        return False
def transcribe_audio_moshi(audio_data: np.ndarray, sample_rate: int = 24000) -> str:
    """Transcribe a mono audio buffer with the Moshi STT pipeline.

    Args:
        audio_data: 1-D float numpy array of raw audio samples.
        sample_rate: Sample rate of ``audio_data`` in Hz; resampled to the
            24kHz that Moshi expects when it differs.

    Returns:
        The transcribed text, a mock placeholder when running in mock
        mode, "Error: Moshi models not loaded" if startup never ran, or
        an "Error: ..." string when transcription fails.
    """
    try:
        # Fail fast with a clear message when the startup hook never ran:
        # previously a None `mimi` fell through to `mimi.streaming` and
        # surfaced as an opaque AttributeError string.
        if mimi is None or lm_gen is None:
            return "Error: Moshi models not loaded"
        if mimi == "mock":
            duration = len(audio_data) / sample_rate
            return f"Mock Moshi STT: {duration:.2f}s audio at {sample_rate}Hz"

        # Moshi operates exclusively on 24kHz audio.
        if sample_rate != 24000:
            import librosa
            audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=24000)

        # Shape to (batch=1, channels=1, samples) for the codec.
        wav = torch.from_numpy(audio_data).unsqueeze(0).unsqueeze(0).to(device)

        # Encode frame-by-frame with the Mimi codec in streaming mode,
        # zero-padding the final partial frame.
        with torch.no_grad(), mimi.streaming(batch_size=1):
            all_codes = []
            frame_size = mimi.frame_size
            for offset in range(0, wav.shape[-1], frame_size):
                frame = wav[:, :, offset: offset + frame_size]
                if frame.shape[-1] == 0:
                    break
                if frame.shape[-1] < frame_size:
                    padding = frame_size - frame.shape[-1]
                    frame = torch.nn.functional.pad(frame, (0, padding))
                all_codes.append(mimi.encode(frame))

            if not all_codes:
                return "No audio tokens generated"

            audio_tokens = torch.cat(all_codes, dim=-1)
            with torch.no_grad():
                # NOTE(review): LMGen's published API exposes streaming
                # step methods; `generate_text_from_audio` may not exist on
                # the installed moshi version — verify against the package.
                text_output = lm_gen.generate_text_from_audio(audio_tokens)
            return text_output if text_output else "Transcription completed"
    except Exception as e:
        logger.error(f"Moshi transcription error: {e}")
        return f"Error: {str(e)}"
# FastAPI app
# Single application instance; title/description/version are surfaced in
# the auto-generated OpenAPI docs.
app = FastAPI(
    title="STT GPU Service Python v4 - Moshi",
    description="Real-time WebSocket STT streaming with Moshi PyTorch implementation",
    version=VERSION
)
@app.on_event("startup")
async def startup_event():
    """FastAPI startup hook: load the Moshi models once at boot.

    The boolean result is ignored here; load failures leave the service
    in mock mode rather than aborting startup.
    """
    _ = await load_moshi_models()
@app.get("/health")
async def health_check():
    """Report liveness plus model/device status for monitoring probes."""
    # A model counts as loaded only when it is neither unset (None) nor
    # the "mock" fallback sentinel.
    mimi_ready = mimi is not None and mimi != "mock"
    moshi_ready = moshi is not None and moshi != "mock"
    return {
        "status": "healthy",
        "timestamp": time.time(),
        "version": VERSION,
        "commit_sha": COMMIT_SHA,
        "message": "Moshi STT WebSocket Service - Real-time streaming ready",
        "space_name": "stt-gpu-service-python-v4",
        "mimi_loaded": mimi_ready,
        "moshi_loaded": moshi_ready,
        "device": str(device) if device else "unknown",
        "expected_sample_rate": "24000Hz",
    }
@app.get("/", response_class=HTMLResponse)
async def get_index():
    """Serve a self-contained HTML test page for the /ws/stream endpoint.

    The page offers connect/disconnect buttons; on connect its script
    sends one JSON test message and renders every server message into
    the #output div. No external assets are required.
    """
    # NOTE: this is an f-string, so every literal CSS/JS brace is doubled
    # ({{ }}) and only {VERSION}/{COMMIT_SHA} are interpolated.
    html_content = f"""
    <!DOCTYPE html>
    <html>
    <head>
        <title>STT GPU Service Python v4 - Moshi</title>
        <style>
            body {{ font-family: Arial, sans-serif; margin: 40px; }}
            .container {{ max-width: 800px; margin: 0 auto; }}
            .status {{ background: #f0f0f0; padding: 20px; border-radius: 8px; margin: 20px 0; }}
            button {{ padding: 10px 20px; margin: 5px; background: #007bff; color: white; border: none; border-radius: 4px; cursor: pointer; }}
            button:disabled {{ background: #ccc; }}
            #output {{ background: #f8f9fa; padding: 15px; border-radius: 4px; margin-top: 20px; }}
            .version {{ font-size: 0.8em; color: #666; margin-top: 20px; }}
        </style>
    </head>
    <body>
        <div class="container">
            <h1>🎙️ STT GPU Service Python v4 - Moshi</h1>
            <p>Real-time WebSocket speech transcription with Moshi PyTorch implementation</p>
            <div class="status">
                <h3>🔗 Moshi WebSocket Streaming Test</h3>
                <button onclick="startWebSocket()">Connect WebSocket</button>
                <button onclick="stopWebSocket()" disabled id="stopBtn">Disconnect</button>
                <p>Status: <span id="wsStatus">Disconnected</span></p>
                <p><small>Expected: 24kHz audio chunks (80ms = ~1920 samples)</small></p>
            </div>
            <div id="output">
                <p>Moshi transcription output will appear here...</p>
            </div>
            <div class="version">
                v{VERSION} (SHA: {COMMIT_SHA}) - Moshi STT Implementation
            </div>
        </div>
        <script>
            let ws = null;
            function startWebSocket() {{
                const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
                const wsUrl = `${{protocol}}//${{window.location.host}}/ws/stream`;
                ws = new WebSocket(wsUrl);
                ws.onopen = function(event) {{
                    document.getElementById('wsStatus').textContent = 'Connected to Moshi STT';
                    document.querySelector('button').disabled = true;
                    document.getElementById('stopBtn').disabled = false;
                    // Send test message
                    ws.send(JSON.stringify({{
                        type: 'audio_chunk',
                        data: 'test_moshi_audio_24khz',
                        timestamp: Date.now()
                    }}));
                }};
                ws.onmessage = function(event) {{
                    const data = JSON.parse(event.data);
                    document.getElementById('output').innerHTML += `<p>${{JSON.stringify(data, null, 2)}}</p>`;
                }};
                ws.onclose = function(event) {{
                    document.getElementById('wsStatus').textContent = 'Disconnected';
                    document.querySelector('button').disabled = false;
                    document.getElementById('stopBtn').disabled = true;
                }};
                ws.onerror = function(error) {{
                    document.getElementById('output').innerHTML += `<p style="color: red;">WebSocket Error: ${{error}}</p>`;
                }};
            }}
            function stopWebSocket() {{
                if (ws) {{
                    ws.close();
                }}
            }}
        </script>
    </body>
    </html>
    """
    return HTMLResponse(content=html_content)
@app.websocket("/ws/stream")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint for real-time Moshi STT streaming.

    Protocol: after accepting, sends one "connection" JSON message, then
    loops receiving JSON. Messages of type "audio_chunk" get a (currently
    mocked) "transcription" reply; "ping" gets "pong". Disconnects are
    logged; unexpected errors close the socket with code 1011.
    """
    await websocket.accept()
    logger.info("Moshi WebSocket connection established")
    try:
        # Send initial connection confirmation
        await websocket.send_json({
            "type": "connection",
            "status": "connected",
            "message": "Moshi STT WebSocket ready for audio chunks",
            "chunk_size_ms": 80,
            "expected_sample_rate": 24000,
            "expected_chunk_samples": 1920,  # 80ms at 24kHz
            "model": "Moshi PyTorch implementation"
        })
        while True:
            # Receive audio data
            data = await websocket.receive_json()
            if data.get("type") == "audio_chunk":
                try:
                    # Process 80ms audio chunk with Moshi
                    # In real implementation:
                    # 1. Decode base64 audio data to numpy array
                    # 2. Process with Mimi codec (24kHz)
                    # 3. Generate text with Moshi LM
                    # 4. Return transcription
                    # For now, mock processing
                    transcription = f"Moshi STT transcription for 24kHz chunk at {data.get('timestamp', 'unknown')}"
                    # Send transcription result
                    await websocket.send_json({
                        "type": "transcription",
                        "text": transcription,
                        "timestamp": time.time(),
                        "chunk_id": data.get("timestamp"),
                        "confidence": 0.95,
                        "model": "moshi"
                    })
                except Exception as e:
                    await websocket.send_json({
                        "type": "error",
                        "message": f"Moshi processing error: {str(e)}",
                        "timestamp": time.time()
                    })
            elif data.get("type") == "ping":
                # Respond to ping
                await websocket.send_json({
                    "type": "pong",
                    "timestamp": time.time(),
                    "model": "moshi"
                })
    except WebSocketDisconnect:
        logger.info("Moshi WebSocket connection closed")
    except Exception as e:
        logger.error(f"Moshi WebSocket error: {e}")
        # Best-effort close: the socket may already be gone, in which case
        # close() itself raises and would mask the original error. Close
        # reasons are also capped at 123 bytes by the WebSocket protocol.
        try:
            await websocket.close(code=1011, reason=f"Moshi server error: {str(e)}"[:123])
        except Exception:
            pass
@app.post("/api/transcribe")
async def api_transcribe(audio_file: Optional[str] = None):
    """Mock REST transcription endpoint used for smoke-testing the service.

    Raises HTTP 400 when no audio payload is supplied; otherwise returns
    a mock transcription result echoing a prefix of the input.
    """
    # Reject missing or empty payloads (equivalent to `not audio_file`).
    if audio_file is None or audio_file == "":
        raise HTTPException(status_code=400, detail="No audio data provided")
    # Mock transcription
    return {
        "transcription": f"Moshi STT API transcription for: {audio_file[:50]}...",
        "timestamp": time.time(),
        "version": VERSION,
        "method": "REST",
        "model": "moshi",
        "expected_sample_rate": "24kHz",
    }
if __name__ == "__main__":
    # Run the development server.
    # Fix: the old import string "app:app" told uvicorn to import a module
    # named `app`, which does not match this file (app_moshi_stt.py) and
    # failed at startup. Passing the app object directly avoids the
    # module-name dependency entirely.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        log_level="info",
        access_log=True
    )