Deploy Polyglot backend with quantized models
Files changed:

- .dockerignore +12 -0
- Dockerfile +47 -0
- README.md +40 -10
- app/__init__.py +1 -0
- app/auth.py +310 -0
- app/config/__init__.py +7 -0
- app/config/cors.py +295 -0
- app/main.py +345 -0
- app/main.py.bak +345 -0
- app/models/__init__.py +77 -0
- app/routers/__init__.py +1 -0
- app/routers/add_phase_endpoints.py +490 -0
- app/routers/learning.py +1020 -0
- app/routers/mobile.py +536 -0
- app/routers/sessions.py +200 -0
- app/routers/watch.py +152 -0
- app/services/__init__.py +1 -0
- app/services/learning_data_service.py +415 -0
- app/services/quantization_utils.py +124 -0
- app/services/session_manager.py +180 -0
- app/services/transcription_service.py +736 -0
- app/services/transcription_service.py.bak +726 -0
- app/services/transcription_service_onnx.py +682 -0
- app/services/transcription_service_onnx_optimized.py +251 -0
- app/services/translation_service.py +151 -0
- app/services/translation_service_onnx.py +268 -0
- app/services/tts_service.py +541 -0
- app/services/tts_service_onnx.py +587 -0
- app/services/websocket_manager.py +909 -0
- preload_models.py +23 -0
- requirements.txt +58 -0
.dockerignore ADDED

```diff
@@ -0,0 +1,12 @@
+.cache
+nltk_data
+__pycache__
+*.pyc
+*.pyo
+*.pyd
+.Python
+*.so
+*.egg
+*.egg-info
+dist
+build
```
Dockerfile ADDED

```diff
@@ -0,0 +1,47 @@
+FROM python:3.11-slim
+
+WORKDIR /app
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y \
+    ffmpeg \
+    libsndfile1 \
+    sox \
+    espeak \
+    espeak-data \
+    libespeak1 \
+    libespeak-dev \
+    wget \
+    gnupg \
+    curl \
+    && rm -rf /var/lib/apt/lists/*
+
+# Copy requirements and install Python dependencies
+COPY requirements.txt .
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Copy application code
+COPY app ./app
+COPY preload_models.py .
+
+# Set environment variables for caching
+ENV HF_HOME=/app/.cache
+ENV TRANSFORMERS_CACHE=/app/.cache
+ENV NLTK_DATA=/app/nltk_data
+ENV PYTHONPATH=/app
+ENV PORT=7860
+
+# Create cache directories
+RUN mkdir -p $HF_HOME && chmod -R 777 $HF_HOME
+RUN mkdir -p $NLTK_DATA && chmod -R 777 $NLTK_DATA
+
+# Download models using HF token from environment
+# HuggingFace Spaces automatically provides HUGGING_FACE_HUB_TOKEN
+ARG HUGGING_FACE_HUB_TOKEN
+RUN python preload_models.py $HUGGING_FACE_HUB_TOKEN || echo "Model preload skipped - will download on first use"
+
+# Expose port 7860 (HuggingFace Spaces standard)
+EXPOSE 7860
+
+# Run the application
+CMD ["uvicorn", "app.main:socket_app", "--host", "0.0.0.0", "--port", "7860"]
```
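The build step `RUN python preload_models.py $HUGGING_FACE_HUB_TOKEN` bakes model downloads into the image so the Space does not fetch weights on first request; `preload_models.py` itself (+23 lines) is not shown in this excerpt. A minimal sketch of what such a preload script could look like, assuming `huggingface_hub` is available and using placeholder repo IDs rather than the Space's actual models:

```python
# Hypothetical sketch only -- the committed preload_models.py is not shown
# in this diff, and the repo IDs below are placeholders.
import sys
from huggingface_hub import snapshot_download

def preload(token=None):
    # Weights land in HF_HOME (/app/.cache per the Dockerfile), so the
    # running container can load them without hitting the network.
    for repo_id in ("openai/whisper-small", "facebook/nllb-200-distilled-600M"):
        snapshot_download(repo_id=repo_id, token=token)

if __name__ == "__main__":
    preload(sys.argv[1] if len(sys.argv) > 1 else None)
```

The `|| echo ...` fallback in the Dockerfile means a failed preload (for example, a missing build-time token) only defers downloads to first use instead of failing the build.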
README.md CHANGED

```diff
@@ -1,10 +1,40 @@
----
-title: Polyglot Backend
-emoji:
-colorFrom:
-colorTo: green
-sdk: docker
-pinned: false
-
-
-
+---
+title: Polyglot Translation Backend
+emoji: 🌍
+colorFrom: blue
+colorTo: green
+sdk: docker
+pinned: false
+license: mit
+app_port: 7860
+---
+
+# Polyglot Translation Backend - Quantized Models
+
+Real-time speech transcription and translation API with Socket.IO for WebSocket communication. This version uses INT8 quantized models for improved performance and reduced memory footprint.
+
+## Features
+
+- **Real-time Speech Recognition**: Support for English, Swahili, Kikuyu, Kamba, Kimeru, Luo, and Somali
+- **Translation**: Multi-language translation using NLLB models
+- **Text-to-Speech**: Generate speech in multiple languages
+- **WebSocket Support**: Real-time communication via Socket.IO
+- **Model Quantization**: INT8 dynamic quantization for faster inference
+
+## API Endpoints
+
+- `GET /health` - Health check endpoint
+- `WebSocket /` - Socket.IO connection for real-time communication
+
+## Environment
+
+This Space requires a HuggingFace token for model access. The token is automatically provided by HuggingFace Spaces when configured as a secret.
+
+## Technical Details
+
+- **Framework**: FastAPI with Socket.IO
+- **Models**:
+  - ASR: Whisper (English) and Wav2Vec2-BERT (African languages)
+  - Translation: NLLB-600M fine-tuned model
+  - TTS: VITS models for each language
+- **Optimization**: INT8 dynamic quantization via PyTorch
```
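The README's "INT8 dynamic quantization via PyTorch" maps to `app/services/quantization_utils.py` in the file list, whose contents are not part of this excerpt. As a hedged illustration of the general technique (the model class and repo ID here are assumptions, not taken from this commit):

```python
# General illustration of PyTorch INT8 dynamic quantization; this is not
# the Space's quantization_utils.py, which is not shown in this diff.
import torch
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")

# Dynamic quantization rewrites nn.Linear weights as INT8 and computes
# activation scales on the fly, cutting memory and speeding up CPU inference.
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
```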
app/__init__.py ADDED

```diff
@@ -0,0 +1 @@
+# Backend application package
```
app/auth.py ADDED

```diff
@@ -0,0 +1,310 @@
+"""
+Authentication module for HuggingFace token validation
+"""
+import os
+from typing import Optional
+from fastapi import HTTPException, status, Request
+from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
+from fastapi.security.utils import get_authorization_scheme_param
+
+
+def is_local_development() -> bool:
+    """
+    Detect if the application is running in local development mode.
+    This checks multiple indicators to determine if auth should be disabled.
+    """
+    # Method 1: Explicit disable auth flag
+    disable_auth = os.getenv('DISABLE_AUTH', '').lower()
+    if disable_auth in ['true', '1', 'yes']:
+        return True
+
+    # Method 2: Check ENVIRONMENT variable
+    environment = os.getenv('ENVIRONMENT', '').lower()
+    if environment in ['development', 'dev', 'local']:
+        return True
+
+    # Method 3: Check DEBUG flag
+    debug = os.getenv('DEBUG', '').lower()
+    if debug in ['true', '1', 'yes']:
+        return True
+
+    # Method 4: Check if running on localhost/development ports
+    host = os.getenv('HOST', '')
+    port = os.getenv('PORT', '')
+    if host in ['localhost', '127.0.0.1', '0.0.0.0'] and port == '7860':
+        return True
+
+    # Method 5: Check for presence of local development files
+    local_indicators = [
+        '.env.local',
+        'docker-compose.local.yml',
+        'Dockerfile.local'
+    ]
+    for indicator in local_indicators:
+        if os.path.exists(indicator):
+            return True
+
+    # Method 6: Check if we're in a Docker container with local development setup
+    if os.path.exists('/.dockerenv'):
+        # We're in Docker, check if it's local development
+        if os.getenv('ALLOW_ALL_ORIGINS', '').lower() == 'true':
+            return True
+
+    return False
+
+
+class HuggingFaceTokenAuth:
+    """HuggingFace token authentication handler"""
+
+    def __init__(self):
+        self.bearer = HTTPBearer(auto_error=False)
+        self.is_local = is_local_development()
+
+        if self.is_local:
+            print("🔓 RUNNING IN LOCAL DEVELOPMENT MODE - AUTH DISABLED")
+            print("   Environment indicators:")
+            print(f"   - DISABLE_AUTH: {os.getenv('DISABLE_AUTH', 'not set')}")
+            print(f"   - ENVIRONMENT: {os.getenv('ENVIRONMENT', 'not set')}")
+            print(f"   - DEBUG: {os.getenv('DEBUG', 'not set')}")
+            print(f"   - HOST: {os.getenv('HOST', 'not set')}")
+            print(f"   - PORT: {os.getenv('PORT', 'not set')}")
+            print(f"   - ALLOW_ALL_ORIGINS: {os.getenv('ALLOW_ALL_ORIGINS', 'not set')}")
+            print(f"   - Docker container: {os.path.exists('/.dockerenv')}")
+            print(f"   - .env.local exists: {os.path.exists('.env.local')}")
+        else:
+            print("🔒 RUNNING IN PRODUCTION MODE - AUTH REQUIRED")
+
+    def verify_token(self, token: str) -> bool:
+        """
+        Verify if the token is a valid HuggingFace token.
+        In local development mode, always returns True.
+        """
+        # Skip token validation in local development
+        if self.is_local:
+            print("🔓 Local development mode: skipping token validation")
+            return True
+
+        try:
+            if not token:
+                return False
+
+            if not isinstance(token, str):
+                print(f"❌ Token is not a string: {type(token)}")
+                return False
+
+            # HuggingFace tokens start with 'hf_'
+            if not token.startswith('hf_'):
+                print(f"❌ Token does not start with 'hf_': {token[:10]}...")
+                return False
+
+            # Additional validation can be added here
+            # For example, you could make a request to HuggingFace API
+            # to validate the token, but that would add latency
+
+            return True
+
+        except Exception as e:
+            print(f"❌ Error in verify_token: {e}")
+            return False
+
+    def get_token_from_request(self, request: Request) -> Optional[str]:
+        """Extract token from various sources in the request"""
+
+        # Method 1: Authorization header
+        authorization = request.headers.get("Authorization")
+        if authorization:
+            scheme, token = get_authorization_scheme_param(authorization)
+            if scheme.lower() == "bearer":
+                return token
+
+        # Method 2: Query parameter (for WebSocket initial handshake)
+        token = request.query_params.get("token")
+        if token:
+            return token
+
+        # Method 3: Custom header (alternative)
+        token = request.headers.get("X-HF-Token")
+        if token:
+            return token
+
+        return None
+
+    async def authenticate_request(self, request: Request) -> bool:
+        """Authenticate a request using HuggingFace token"""
+        token = self.get_token_from_request(request)
+
+        if not token:
+            return False
+
+        return self.verify_token(token)
+
+
+# Global instance
+hf_auth = HuggingFaceTokenAuth()
+
+
+async def require_hf_token(request: Request) -> str:
+    """
+    FastAPI dependency that requires a valid HuggingFace token.
+    In local development mode, returns a dummy token.
+    Returns the token if valid, raises HTTPException if not.
+    """
+    # Skip authentication in local development
+    if hf_auth.is_local:
+        print("🔓 Local development mode: bypassing HF token requirement")
+        return "local-development-bypass"
+
+    token = hf_auth.get_token_from_request(request)
+
+    if not token:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail="HuggingFace token required. Please provide a valid token in Authorization header.",
+            headers={"WWW-Authenticate": "Bearer"},
+        )
+
+    if not hf_auth.verify_token(token):
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail="Invalid HuggingFace token. Token must start with 'hf_'.",
+            headers={"WWW-Authenticate": "Bearer"},
+        )
+
+    return token
+
+
+async def optional_hf_token(request: Request) -> Optional[str]:
+    """
+    FastAPI dependency that optionally validates HuggingFace token.
+    In local development mode, returns a dummy token if no real token provided.
+    Returns the token if present and valid, None otherwise.
+    Useful for endpoints that work with or without authentication.
+    """
+    # In local development, always return a token
+    if hf_auth.is_local:
+        token = hf_auth.get_token_from_request(request)
+        if token and hf_auth.verify_token(token):
+            return token
+        else:
+            print("🔓 Local development mode: providing dummy token for optional auth")
+            return "local-development-bypass"
+
+    token = hf_auth.get_token_from_request(request)
+
+    if not token:
+        return None
+
+    if hf_auth.verify_token(token):
+        return token
+
+    return None
+
+
+def authenticate_websocket_connect(environ: dict) -> bool:
+    """
+    Authenticate WebSocket connection using token from various sources.
+    In local development mode, always returns True.
+    This is called during the Socket.IO connect event.
+    """
+    # Skip authentication in local development
+    if hf_auth.is_local:
+        print("🔓 Local development mode: bypassing WebSocket authentication")
+        return True
+
+    try:
+        print("=== WEBSOCKET ENVIRON AUTHENTICATION ===")
+        print(f"Environ type: {type(environ)}")
+
+        if not isinstance(environ, dict):
+            print(f"❌ Environ is not a dict: {type(environ)}")
+            return False
+
+        # Method 1: Check query parameters
+        query_string = environ.get('QUERY_STRING', '')
+        print(f"Query string: {query_string}")
+        if query_string:
+            from urllib.parse import parse_qs
+            query_params = parse_qs(query_string)
+            print(f"Parsed query params: {query_params}")
+            tokens = query_params.get('token', [])
+            if tokens:
+                token = tokens[0]
+                print(f"Found token in query: {token[:10]}...")
+                if hf_auth.verify_token(token):
+                    print("✓ Token validated via query params")
+                    return True
+
+        # Method 2: Check headers
+        auth_header = environ.get('HTTP_AUTHORIZATION', '')
+        print(f"Authorization header: {auth_header[:20] if auth_header else 'None'}...")
+        if auth_header:
+            if auth_header.startswith('Bearer '):
+                token = auth_header[7:]  # Remove 'Bearer ' prefix
+                print(f"Found token in Authorization header: {token[:10]}...")
+                if hf_auth.verify_token(token):
+                    print("✓ Token validated via Authorization header")
+                    return True
+
+        # Method 3: Check custom header
+        hf_token_header = environ.get('HTTP_X_HF_TOKEN', '')
+        print(f"X-HF-Token header: {hf_token_header[:10] if hf_token_header else 'None'}...")
+        if hf_token_header:
+            if hf_auth.verify_token(hf_token_header):
+                print("✓ Token validated via X-HF-Token header")
+                return True
+
+        print("❌ No valid token found in environ")
+        print(f"Available environ keys: {list(environ.keys())}")
+        return False
+
+    except Exception as e:
+        print(f"❌ Error in authenticate_websocket_connect: {e}")
+        import traceback
+        traceback.print_exc()
+        return False
+
+
+def authenticate_websocket_auth_data(auth_data: dict) -> bool:
+    """
+    Authenticate WebSocket connection using auth data from Socket.IO.
+    In local development mode, always returns True.
+    This is called when the client sends auth data in the connection.
+    """
+    # Skip authentication in local development
+    if hf_auth.is_local:
+        print("🔓 Local development mode: bypassing WebSocket auth data validation")
+        return True
+
+    try:
+        print("=== WEBSOCKET AUTH DATA AUTHENTICATION ===")
+        print(f"Auth data received: {auth_data}")
+        print(f"Auth data type: {type(auth_data)}")
+
+        if not auth_data:
+            print("❌ No auth data provided")
+            return False
+
+        if not isinstance(auth_data, dict):
+            print(f"❌ Auth data is not a dict: {type(auth_data)}")
+            return False
+
+        # Check for token in auth data
+        token = auth_data.get('token')
+        if token:
+            print(f"Found token in auth data: {token[:10]}...")
+            if hf_auth.verify_token(token):
+                print("✓ Token validated via auth data")
+                return True
+            else:
+                print("❌ Invalid token in auth data")
+        else:
+            print("❌ No token in auth data")
+            print(f"Available keys in auth data: {list(auth_data.keys())}")
+
+        return False
+
+    except Exception as e:
+        print(f"❌ Error in authenticate_websocket_auth_data: {e}")
+        import traceback
+        traceback.print_exc()
+        return False
```
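For orientation, `require_hf_token` and `optional_hf_token` are ordinary FastAPI dependencies. A minimal sketch of consuming one (the route below is hypothetical, not one of the Space's routers):

```python
# Hypothetical endpoint demonstrating the dependency; not part of this commit.
from fastapi import FastAPI, Depends
from app.auth import require_hf_token

demo_app = FastAPI()

@demo_app.get("/protected")
async def protected(token: str = Depends(require_hf_token)):
    # require_hf_token raises HTTP 401 unless the request carries an
    # 'hf_'-prefixed token (Bearer header, ?token=, or X-HF-Token),
    # or the app detected local-development mode at startup.
    return {"token_prefix": token[:6]}
```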
app/config/__init__.py ADDED

```diff
@@ -0,0 +1,7 @@
+"""
+Configuration package for Polyglot backend
+"""
+
+from .cors import cors_config
+
+__all__ = ["cors_config"]
```
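Since the package `__init__` re-exports `cors_config`, the shorter package import is equivalent to the direct module path used elsewhere in this commit; a trivial usage sketch:

```python
# Both imports resolve to the same module-level singleton.
from app.config import cors_config        # via the package re-export
# from app.config.cors import cors_config # direct path, as in app/main.py

print(cors_config.environment.value)
```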
app/config/cors.py ADDED

```diff
@@ -0,0 +1,295 @@
+"""
+CORS Configuration Module
+
+Centralized CORS configuration supporting multiple deployment environments.
+"""
+
+import os
+import re
+from typing import List, Optional
+from enum import Enum
+
+
+class Environment(str, Enum):
+    """Deployment environment types"""
+    LOCAL = "local"
+    DEVELOPMENT = "development"
+    STAGING = "staging"
+    PRODUCTION = "production"
+
+
+class CORSConfig:
+    """CORS configuration manager"""
+
+    # Default origins for local development
+    DEFAULT_LOCAL_ORIGINS = [
+        "http://localhost:3000",  # React/Next.js dev server
+        "http://localhost:3001",  # Polyglot frontend (Vite)
+        "http://localhost:3002",  # Lessons UI (Vite)
+        "http://localhost:3003",  # Podium (Vite)
+        "http://localhost:3004",  # Podium alternative port
+        "http://localhost:5173",  # Vite dev server
+        "http://localhost:7860",  # Backend self-reference
+        "http://localhost:8080",  # Alternative dev server
+        "http://127.0.0.1:3000",  # IPv4 localhost variant
+        "http://127.0.0.1:3001",  # IPv4 localhost variant
+        "http://127.0.0.1:3002",  # IPv4 localhost variant
+        "http://127.0.0.1:3003",  # IPv4 localhost variant
+        "http://127.0.0.1:3004",  # IPv4 localhost variant
+        "http://127.0.0.1:5173",  # IPv4 localhost variant
+        "http://127.0.0.1:7860",  # IPv4 localhost variant
+    ]
+
+    # Default patterns for production deployments
+    DEFAULT_PRODUCTION_PATTERNS = [
+        r"^https://.*\.tafiti\.dev$",  # Tafiti production/staging
+        r"^https://.*\.vercel\.app$",  # Vercel deployments
+        r"^https://.*\.hf\.space$",  # HuggingFace Spaces
+        r"^https://milimani\.tafiti-api\.org$",  # Production API
+    ]
+
+    # Mobile app protocols
+    MOBILE_PROTOCOLS = [
+        "capacitor://localhost",  # Capacitor apps
+        "ionic://localhost",  # Ionic apps
+        "http://localhost",  # Mobile WebView
+    ]
+
+    def __init__(self):
+        self.environment = self._get_environment()
+        self.allowed_origins = self._build_allowed_origins()
+        self.allow_all = self._should_allow_all()
+        self.origin_patterns = self._build_origin_patterns()
+
+    def _get_environment(self) -> Environment:
+        """Get current deployment environment"""
+        env_str = os.getenv("ENVIRONMENT", "local").lower()
+
+        try:
+            return Environment(env_str)
+        except ValueError:
+            print(f"⚠️ Unknown environment '{env_str}', defaulting to 'local'")
+            return Environment.LOCAL
+
+    def _should_allow_all(self) -> bool:
+        """Check if CORS should allow all origins (insecure, dev only)"""
+        allow_all = os.getenv("CORS_ALLOW_ALL", "false").lower()
+
+        if allow_all == "true":
+            if self.environment == Environment.PRODUCTION:
+                print("❌ ERROR: CORS_ALLOW_ALL=true is not allowed in production")
+                return False
+            else:
+                print("⚠️ WARNING: CORS allowing all origins - INSECURE, use only for development")
+                return True
+
+        return False
+
+    def _build_allowed_origins(self) -> List[str]:
+        """Build list of allowed origins from environment and defaults"""
+        origins = []
+
+        # Get custom origins from environment variable
+        custom_origins_str = os.getenv("CORS_ALLOWED_ORIGINS", "")
+
+        if custom_origins_str:
+            # Parse comma-separated origins
+            custom_origins = [
+                origin.strip()
+                for origin in custom_origins_str.split(",")
+                if origin.strip()
+            ]
+            origins.extend(custom_origins)
+            print(f"✓ Loaded {len(custom_origins)} custom CORS origins from environment")
+
+        # Add defaults based on environment
+        if self.environment == Environment.LOCAL:
+            origins.extend(self.DEFAULT_LOCAL_ORIGINS)
+            print(f"✓ Added {len(self.DEFAULT_LOCAL_ORIGINS)} default local origins")
+
+        # Always include mobile protocols in non-production
+        if self.environment != Environment.PRODUCTION:
+            origins.extend(self.MOBILE_PROTOCOLS)
+            print(f"✓ Added {len(self.MOBILE_PROTOCOLS)} mobile protocol origins")
+
+        # Remove duplicates while preserving order
+        seen = set()
+        unique_origins = []
+        for origin in origins:
+            if origin not in seen:
+                seen.add(origin)
+                unique_origins.append(origin)
+
+        return unique_origins
+
+    def _build_origin_patterns(self) -> List[re.Pattern]:
+        """Build regex patterns for origin matching"""
+        patterns = []
+
+        # Get custom patterns from environment
+        custom_patterns_str = os.getenv("CORS_ALLOWED_PATTERNS", "")
+
+        if custom_patterns_str:
+            custom_pattern_strs = [
+                p.strip()
+                for p in custom_patterns_str.split(",")
+                if p.strip()
+            ]
+
+            for pattern_str in custom_pattern_strs:
+                try:
+                    patterns.append(re.compile(pattern_str))
+                except re.error as e:
+                    print(f"⚠️ Invalid regex pattern '{pattern_str}': {e}")
+
+            print(f"✓ Loaded {len(patterns)} custom CORS patterns from environment")
+
+        # Add default production patterns if in production/staging/development
+        if self.environment in [Environment.PRODUCTION, Environment.STAGING, Environment.DEVELOPMENT]:
+            for pattern_str in self.DEFAULT_PRODUCTION_PATTERNS:
+                patterns.append(re.compile(pattern_str))
+
+            print(f"✓ Added {len(self.DEFAULT_PRODUCTION_PATTERNS)} default production patterns")
+
+        # Add localhost pattern for development
+        if self.environment == Environment.LOCAL:
+            patterns.append(re.compile(r"^http://localhost:\d+$"))
+            patterns.append(re.compile(r"^http://127\.0\.0\.1:\d+$"))
+            print("✓ Added localhost wildcard patterns for development")
+
+        return patterns
+
+    def is_origin_allowed(self, origin: str) -> bool:
+        """
+        Check if an origin is allowed based on explicit list or patterns
+
+        Args:
+            origin: Origin to check (e.g., "https://app.tafiti.dev")
+
+        Returns:
+            True if origin is allowed, False otherwise
+        """
+        # If allow_all is enabled (dev only)
+        if self.allow_all:
+            return True
+
+        # Check explicit origins list
+        if origin in self.allowed_origins:
+            return True
+
+        # Check against patterns
+        for pattern in self.origin_patterns:
+            if pattern.match(origin):
+                return True
+
+        return False
+
+    def get_cors_middleware_config(self) -> dict:
+        """Get configuration dict for FastAPI CORSMiddleware"""
+        if self.allow_all:
+            return {
+                "allow_origins": ["*"],
+                "allow_credentials": False,  # Cannot use credentials with wildcard
+                "allow_methods": ["*"],
+                "allow_headers": ["*"],
+            }
+
+        # Build origin regex for pattern matching
+        if self.origin_patterns:
+            # Combine all patterns into a single regex
+            combined_pattern = "|".join(f"({p.pattern})" for p in self.origin_patterns)
+
+            return {
+                "allow_origins": self.allowed_origins,
+                "allow_origin_regex": combined_pattern,
+                "allow_credentials": True,
+                "allow_methods": ["*"],
+                "allow_headers": ["*"],
+            }
+        else:
+            return {
+                "allow_origins": self.allowed_origins,
+                "allow_credentials": True,
+                "allow_methods": ["*"],
+                "allow_headers": ["*"],
+            }
+
+    def get_socketio_cors_origins(self):
+        """
+        Get CORS origins for Socket.IO
+
+        Socket.IO doesn't support regex patterns, so we need to provide explicit list.
+        For production, this means we need to enumerate common origins.
+        """
+        if self.allow_all:
+            return "*"
+
+        # For Socket.IO, we can only provide explicit origins
+        # In production, we may need to enumerate common subdomains
+        socketio_origins = self.allowed_origins.copy()
+
+        # Add common production subdomains if using production patterns
+        if self.environment in [Environment.PRODUCTION, Environment.STAGING]:
+            # These should be added to CORS_ALLOWED_ORIGINS for Socket.IO support
+            production_origins = [
+                "https://app.tafiti.dev",
+                "https://www.tafiti.dev",
+                "https://polyglot.tafiti.dev",
+                "https://podium.tafiti.dev",
+                "https://milimani.tafiti-api.org",
+                "https://polyglot-ashy-beta.vercel.app",
+                "https://lessons-silk.vercel.app",
+                "https://lessons.tafiti.dev",
+                "https://podium-chi.vercel.app",
+            ]
+            for origin in production_origins:
+                if origin not in socketio_origins:
+                    socketio_origins.append(origin)
+
+        return socketio_origins
+
+    def print_config_summary(self):
+        """Print CORS configuration summary for debugging"""
+        print("\n" + "="*70)
+        print("CORS CONFIGURATION SUMMARY")
+        print("="*70)
+        print(f"Environment: {self.environment.value}")
+        print(f"Allow All: {self.allow_all}")
+        print(f"\nExplicit Origins ({len(self.allowed_origins)}):")
+        for origin in self.allowed_origins:
+            print(f"  • {origin}")
+
+        if self.origin_patterns:
+            print(f"\nOrigin Patterns ({len(self.origin_patterns)}):")
+            for pattern in self.origin_patterns:
+                print(f"  • {pattern.pattern}")
+
+        print("\nExample Origins That Would Be Allowed:")
+        test_origins = [
+            "http://localhost:3001",
+            "http://localhost:3002",
+            "http://localhost:3003",
+            "http://localhost:3004",
+            "http://localhost:5173",
+            "https://app.tafiti.dev",
+            "https://polyglot.tafiti.dev",
+            "https://podium.tafiti.dev",
+            "https://polyglot.vercel.app",
+            "https://lessons-silk.vercel.app",
+            "https://podium-chi.vercel.app",
+            "https://polyglot-ashy-beta.vercel.app",
+            "https://mutisya-translator.hf.space",
+            "https://milimani.tafiti-api.org",
+            "capacitor://localhost",
+            "https://example.com",
+        ]
+
+        for test_origin in test_origins:
+            allowed = "✓" if self.is_origin_allowed(test_origin) else "✗"
+            print(f"  {allowed} {test_origin}")
+
+        print("="*70 + "\n")
+
+
+# Global CORS configuration instance
+cors_config = CORSConfig()
```
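`CORSConfig` reads its entire policy from environment variables at construction time, so per-deployment tuning needs no code change. A sketch of exercising it directly, using illustrative origin values (not a recommended production list):

```python
# Demonstration of environment-driven CORS policy; origin values are examples.
import os

os.environ["ENVIRONMENT"] = "staging"
os.environ["CORS_ALLOWED_ORIGINS"] = "https://app.example.com"
os.environ["CORS_ALLOWED_PATTERNS"] = r"^https://.*\.example\.com$"

from app.config.cors import CORSConfig

cfg = CORSConfig()  # re-instantiate; the module-level cors_config was built at import
assert cfg.is_origin_allowed("https://app.example.com")      # explicit list
assert cfg.is_origin_allowed("https://preview.example.com")  # custom regex pattern
assert not cfg.is_origin_allowed("http://evil.example")      # rejected
```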
app/main.py ADDED

```diff
@@ -0,0 +1,345 @@
+import os
+import os
+import asyncio
+from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, Depends
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.staticfiles import StaticFiles
+from contextlib import asynccontextmanager
+import logging
+import socketio
+import engineio
+import re
+
+from app.routers import sessions, mobile, watch, learning
+from app.services.session_manager import SessionManager
+from app.services.transcription_service import TranscriptionService
+from app.services.translation_service import TranslationService
+from app.services.tts_service import TTSService
+from app.services.websocket_manager import WebSocketManager
+from app.auth import require_hf_token, optional_hf_token, authenticate_websocket_connect, authenticate_websocket_auth_data
+from app.config.cors import cors_config
+
+class ChunkArrayTruncateFilter(logging.Filter):
+    """Custom logging filter to truncate long arrays in Socket.IO logs for better readability"""
+
+    def filter(self, record):
+        if hasattr(record, 'msg') and isinstance(record.msg, str):
+            # More aggressive approach to truncate audioData arrays
+            # Pattern to match: "audioData":[numbers,numbers,numbers,...]
+            audiodata_pattern = r'"audioData":\[([0-9,-]+(?:,[0-9,-]+)*)\]'
+
+            def truncate_audiodata(match):
+                array_content = match.group(1)
+                # Split by comma and get first 10 items
+                items = array_content.split(',')
+                if len(items) > 10:
+                    truncated = ','.join(items[:10])
+                    return f'"audioData":[{truncated}, ...] (truncated {len(items)-10} more items)'
+                return match.group(0)
+
+            record.msg = re.sub(audiodata_pattern, truncate_audiodata, record.msg)
+
+            # Also handle any other large numeric arrays in brackets
+            # Pattern for arrays with more than 20 numbers
+            large_numeric_array_pattern = r'(\[)([0-9,-]+(?:,[0-9,-]+){20,})(\])'
+
+            def truncate_large_numeric_array(match):
+                prefix = match.group(1)
+                array_content = match.group(2)
+                suffix = match.group(3)
+
+                # Split by comma and get first 10 items
+                items = array_content.split(',')
+                if len(items) > 10:
+                    truncated = ','.join(items[:10])
+                    return f'{prefix}{truncated}, ... (truncated {len(items)-10} more){suffix}'
+                return match.group(0)
+
+            record.msg = re.sub(large_numeric_array_pattern, truncate_large_numeric_array, record.msg)
+
+            # Truncate other field types
+            for field_name in ['chunk', 'wavChunk', 'data']:
+                field_pattern = rf'"{field_name}":\[([0-9,-]+(?:,[0-9,-]+)*)\]'
+                def make_truncate_field(fname):
+                    def truncate_field(match):
+                        array_content = match.group(1)
+                        items = array_content.split(',')
+                        if len(items) > 10:
+                            truncated = ','.join(items[:10])
+                            return f'"{fname}":[{truncated}, ...] (truncated {len(items)-10} more)'
+                        return match.group(0)
+                    return truncate_field
+
+                record.msg = re.sub(field_pattern, make_truncate_field(field_name), record.msg)
+
+        return True
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    # Initialize services
+    print("=== INITIALIZING BACKEND SERVICES ===")
+    try:
+        print("Initializing transcription service...")
+        await transcription_service.initialize()
+        print("✓ Transcription service initialized")
+
+        print("Initializing translation service...")
+        await translation_service.initialize()
+        print("✓ Translation service initialized")
+
+        print("Initializing TTS service...")
+        await tts_service.initialize()
+        print("✓ TTS service initialized")
+
+        print("=== ALL SERVICES INITIALIZED SUCCESSFULLY ===")
+
+        # Start background loading of additional models after successful startup
+        print("=== STARTING BACKGROUND MODEL LOADING ===")
+        transcription_service.start_background_loading()
+        tts_service.start_background_loading()
+        print("=== BACKGROUND MODEL LOADING INITIATED ===")
+
+        # Print CORS configuration summary
+        cors_config.print_config_summary()
+
+    except Exception as e:
+        print(f"❌ SERVICE INITIALIZATION FAILED: {e}")
+        import traceback
+        traceback.print_exc()
+        raise
+
+    yield
+
+    # Cleanup
+    print("=== CLEANING UP SERVICES ===")
+    await transcription_service.cleanup()
+    await translation_service.cleanup()
+    await tts_service.cleanup()
+    print("=== CLEANUP COMPLETE ===")
+
+app = FastAPI(
+    title="Real-time Transcription & Translation API",
+    description="Backend API for real-time speech transcription and translation",
+    version="1.0.0",
+    lifespan=lifespan
+)
+
+# CORS middleware with environment-based configuration
+cors_middleware_config = cors_config.get_cors_middleware_config()
+print(f"Configuring CORS middleware with keys: {list(cors_middleware_config.keys())}")
+
+app.add_middleware(
+    CORSMiddleware,
+    **cors_middleware_config
+)
+
+# Initialize services - using PyTorch models for better compatibility
+session_manager = SessionManager()
+transcription_service = TranscriptionService()
+translation_service = TranslationService()
+tts_service = TTSService()
+websocket_manager = WebSocketManager(
+    session_manager=session_manager,
+    transcription_service=transcription_service,
+    translation_service=translation_service,
+    tts_service=tts_service
+)
+
+# Include routers
+app.include_router(sessions.router, prefix="/api")
+app.include_router(mobile.router, prefix="/api")
+app.include_router(watch.router, prefix="/api")
+app.include_router(learning.router)
+
+# Set the session manager in the router
+sessions.session_manager = session_manager
+sessions.translation_service = translation_service
+sessions.tts_service = tts_service
+sessions.transcription_service = transcription_service
+
+# Set the mobile router
+mobile.translation_service = translation_service
+mobile.tts_service = tts_service
+mobile.transcription_service = transcription_service
+
+# Set the watch router
+watch.translation_service = translation_service
+watch.tts_service = tts_service
+watch.transcription_service = transcription_service
+
+
+# Configure logging with custom filter to truncate chunk arrays
+chunk_filter = ChunkArrayTruncateFilter()
+
+sio_logger = logging.getLogger('socketio')
+sio_logger.setLevel(logging.INFO)  # Show info logs with truncated arrays
+sio_logger.addFilter(chunk_filter)
+
+engineio_logger = logging.getLogger('engineio')
+engineio_logger.setLevel(logging.INFO)  # Show info logs with truncated arrays
+engineio_logger.addFilter(chunk_filter)
+
+# Also apply filter to the root logger to catch any other verbose logging
+root_logger = logging.getLogger()
+root_logger.addFilter(chunk_filter)
+
+# Configure Engine.IO payload limits for large audio chunks
+engineio.payload.Payload.max_decode_packets = 250
+
+# Socket.IO setup with environment-based CORS
+socketio_cors_origins = cors_config.get_socketio_cors_origins()
+print(f"Configuring Socket.IO CORS: {len(socketio_cors_origins) if isinstance(socketio_cors_origins, list) else 'all'} origins")
+
+sio = socketio.AsyncServer(
+    async_mode='asgi',
+    cors_allowed_origins=socketio_cors_origins,
+    cors_credentials=not cors_config.allow_all,  # Cannot use credentials with wildcard
+    logger=True,  # Re-enabled with custom filtering
+    engineio_logger=True,  # Re-enabled with custom filtering
+    always_connect=False  # This ensures connect event is called for authentication
+)
+
+# Set the socketio instance in websocket manager
+websocket_manager.set_socketio(sio)
+
+socket_app = socketio.ASGIApp(sio, app)
+
+@app.get("/health")
+async def health_check(token: str = Depends(optional_hf_token)):
+    """Health check endpoint - optionally authenticated"""
+    from app.auth import hf_auth
+
+    auth_status = "bypassed (local development)" if hf_auth.is_local else "authenticated"
+    if not hf_auth.is_local and not token:
+        auth_status = "unauthenticated"
+
+    return {
+        "status": "healthy",
+        "message": "Translation service is running",
+        "auth_status": auth_status,
+        "local_development": hf_auth.is_local,
+        "auth_bypassed": hf_auth.is_local,
+        "token_prefix": token[:10] + "..." if token and token != "local-development-bypass" else "local-bypass" if hf_auth.is_local else None,
+        "environment": {
+            "ENVIRONMENT": os.getenv('ENVIRONMENT', 'not set'),
+            "DEBUG": os.getenv('DEBUG', 'not set'),
+            "DISABLE_AUTH": os.getenv('DISABLE_AUTH', 'not set'),
+            "HOST": os.getenv('HOST', 'not set'),
+            "PORT": os.getenv('PORT', 'not set')
+        },
+        "services": {
+            "transcription": transcription_service is not None,
+            "translation": translation_service is not None,
+            "tts": tts_service is not None,
+            "sessions": session_manager is not None
+        }
+    }
+
+@sio.event
+async def connect(sid, environ=None, auth=None):
+    """Handle Socket.IO connection with authentication"""
+    try:
+        print("=== WEBSOCKET CONNECTION ATTEMPT ===")
+        print(f"SID: {sid}")
+        print(f"Auth data: {auth}")
+        print(f"Environ type: {type(environ)}")
+        print(f"Environ data: {environ}")
+
+        # Ensure environ is a dict
+        if environ is None:
+            environ = {}
+
+        print(f"Query string: {environ.get('QUERY_STRING', 'None')}")
+        print(f"Headers: {[k for k in environ.keys() if k.startswith('HTTP_')] if isinstance(environ, dict) else 'environ not dict'}")
+
+        # Check authentication from multiple sources
+        authenticated = False
+        auth_method = None
+
+        # Method 1: Check auth data from client
+        if auth and authenticate_websocket_auth_data(auth):
+            authenticated = True
+            auth_method = "auth_data"
+            print("✓ Authenticated via auth data")
+
+        # Method 2: Check environment (headers, query params)
+        elif environ and isinstance(environ, dict) and authenticate_websocket_connect(environ):
+            authenticated = True
+            auth_method = "environ"
+            print("✓ Authenticated via headers/query")
+
+        # TEMPORARY: Allow connections for debugging (remove in production)
+        # This helps identify if the issue is authentication or something else
+        if not authenticated:
+            print("⚠️ Authentication failed, but allowing for debugging")
+            if isinstance(environ, dict):
+                print(f"Available environ keys: {list(environ.keys())}")
+            # NOTE: this debug bypass is active; delete the next two lines to enforce authentication
+            authenticated = True
+            auth_method = "debug_bypass"
+
+        if not authenticated:
+            print("❌ Authentication failed - disconnecting")
+            await sio.disconnect(sid)
+            return False
+
+        print(f"✓ WebSocket connection authenticated successfully via {auth_method}")
+        return True
+
+    except Exception as e:
+        print(f"❌ Error in connect handler: {e}")
+        import traceback
+        traceback.print_exc()
+        try:
+            await sio.disconnect(sid)
+        except:
+            pass
+        return False
+
+@sio.event
+async def disconnect(sid):
+    await websocket_manager.handle_disconnect(sid)
+
+@sio.event
+async def join_session(sid, data):
+    await websocket_manager.handle_join_session(sid, data)
+
+@sio.event
+async def join_hub(sid, data):
+    await websocket_manager.handle_join_hub(sid, data)
+
+@sio.event
+async def leave_session(sid, data):
+    await websocket_manager.handle_leave_session(sid, data)
+
+@sio.event
+async def audio_chunk(sid, data):
+    await websocket_manager.handle_audio_chunk(sid, data)
+
+@sio.event
+async def speaking_status(sid, data):
+    await websocket_manager.handle_speaking_status(sid, data)
+
+@sio.event
+async def test_echo(sid, data):
+    """Test event to verify WebSocket communication"""
+    await sio.emit('test_echo_response', data, room=sid)
+
+@sio.event
+async def update_participant_language(sid, data):
+    """Update participant's language (affects speech recognition)"""
+    await websocket_manager.handle_update_participant_language(sid, data)
+
+@sio.event
+async def update_session_languages(sid, data):
+    """Update session's languages (affects translation targets)"""
+    await websocket_manager.handle_update_session_languages(sid, data)
+
+# Serve static files (for frontend)
+if os.path.exists("../frontend/dist"):
+    app.mount("/", StaticFiles(directory="../frontend/dist", html=True), name="static")
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run("main:socket_app", host="0.0.0.0", port=7860, reload=True)
```
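On the wire, the `connect` handler above accepts the token from Socket.IO `auth` data, a `token` query parameter, or the `Authorization`/`X-HF-Token` headers. A minimal `python-socketio` client sketch (the Space URL and token are placeholders):

```python
# Minimal Socket.IO client for this backend; URL and token are placeholders.
import asyncio
import socketio

async def main():
    sio = socketio.AsyncClient()

    @sio.on("test_echo_response")
    async def on_echo(data):
        print("echo:", data)

    # auth data is what authenticate_websocket_auth_data() inspects server-side
    await sio.connect("https://your-space.hf.space", auth={"token": "hf_..."})
    await sio.emit("test_echo", {"ping": 1})
    await asyncio.sleep(1)
    await sio.disconnect()

asyncio.run(main())
```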
app/main.py.bak
ADDED
|
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
import asyncio
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from contextlib import asynccontextmanager
import logging
import socketio
import engineio
import re

from app.routers import sessions, mobile, watch, learning
from app.services.session_manager import SessionManager
from app.services.transcription_service import TranscriptionService
from app.services.translation_service import TranslationService
from app.services.tts_service import TTSService
from app.services.websocket_manager import WebSocketManager
from app.auth import require_hf_token, optional_hf_token, authenticate_websocket_connect, authenticate_websocket_auth_data
from app.config.cors import cors_config

class ChunkArrayTruncateFilter(logging.Filter):
    """Custom logging filter to truncate long arrays in Socket.IO logs for better readability"""

    def filter(self, record):
        if hasattr(record, 'msg') and isinstance(record.msg, str):
            # Aggressively truncate audioData arrays
            # Pattern to match: "audioData":[numbers,numbers,numbers,...]
            audiodata_pattern = r'"audioData":\[([0-9,-]+(?:,[0-9,-]+)*)\]'

            def truncate_audiodata(match):
                array_content = match.group(1)
                # Split by comma and keep the first 10 items
                items = array_content.split(',')
                if len(items) > 10:
                    truncated = ','.join(items[:10])
                    return f'"audioData":[{truncated}, ...] (truncated {len(items)-10} more items)'
                return match.group(0)

            record.msg = re.sub(audiodata_pattern, truncate_audiodata, record.msg)

            # Also handle any other large numeric arrays in brackets
            # Pattern for arrays with more than 20 numbers
            large_numeric_array_pattern = r'(\[)([0-9,-]+(?:,[0-9,-]+){20,})(\])'

            def truncate_large_numeric_array(match):
                prefix = match.group(1)
                array_content = match.group(2)
                suffix = match.group(3)

                # Split by comma and keep the first 10 items
                items = array_content.split(',')
                if len(items) > 10:
                    truncated = ','.join(items[:10])
                    return f'{prefix}{truncated}, ... (truncated {len(items)-10} more){suffix}'
                return match.group(0)

            record.msg = re.sub(large_numeric_array_pattern, truncate_large_numeric_array, record.msg)

            # Truncate other field types
            for field_name in ['chunk', 'wavChunk', 'data']:
                field_pattern = rf'"{field_name}":\[([0-9,-]+(?:,[0-9,-]+)*)\]'
                def make_truncate_field(fname):
                    def truncate_field(match):
                        array_content = match.group(1)
                        items = array_content.split(',')
                        if len(items) > 10:
                            truncated = ','.join(items[:10])
                            return f'"{fname}":[{truncated}, ...] (truncated {len(items)-10} more)'
                        return match.group(0)
                    return truncate_field

                record.msg = re.sub(field_pattern, make_truncate_field(field_name), record.msg)

        return True
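# --- Illustrative aside, not part of the committed file: exercising the
# --- filter above with a made-up logger name and payload. Note the filter
# --- only rewrites record.msg, so the array has to be embedded in the
# --- message string itself rather than passed lazily as a %s argument.
#
#   demo = logging.getLogger("filter-demo")          # hypothetical name
#   demo.setLevel(logging.INFO)
#   demo.addHandler(logging.StreamHandler())
#   demo.addFilter(ChunkArrayTruncateFilter())
#   payload = '"audioData":[' + ','.join(str(i) for i in range(100)) + ']'
#   demo.info("emit " + payload)
#   # -> emit "audioData":[0,1,2,3,4,5,6,7,8,9, ...] (truncated 90 more items)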
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Initialize services
    print("=== INITIALIZING BACKEND SERVICES ===")
    try:
        print("Initializing transcription service...")
        await transcription_service.initialize()
        print("✓ Transcription service initialized")

        print("Initializing translation service...")
        await translation_service.initialize()
        print("✓ Translation service initialized")

        print("Initializing TTS service...")
        await tts_service.initialize()
        print("✓ TTS service initialized")

        print("=== ALL SERVICES INITIALIZED SUCCESSFULLY ===")

        # Start background loading of additional models after successful startup
        print("=== STARTING BACKGROUND MODEL LOADING ===")
        transcription_service.start_background_loading()
        tts_service.start_background_loading()
        print("=== BACKGROUND MODEL LOADING INITIATED ===")

        # Print CORS configuration summary
        cors_config.print_config_summary()

    except Exception as e:
        print(f"❌ SERVICE INITIALIZATION FAILED: {e}")
        import traceback
        traceback.print_exc()
        raise

    yield

    # Cleanup
    print("=== CLEANING UP SERVICES ===")
    await transcription_service.cleanup()
    await translation_service.cleanup()
    await tts_service.cleanup()
    print("=== CLEANUP COMPLETE ===")

app = FastAPI(
    title="Real-time Transcription & Translation API",
    description="Backend API for real-time speech transcription and translation",
    version="1.0.0",
    lifespan=lifespan
)

# CORS middleware with environment-based configuration
cors_middleware_config = cors_config.get_cors_middleware_config()
print(f"Configuring CORS middleware with keys: {list(cors_middleware_config.keys())}")

app.add_middleware(
    CORSMiddleware,
    **cors_middleware_config
)

# Initialize services - using PyTorch models for better compatibility
session_manager = SessionManager()
transcription_service = TranscriptionService()
translation_service = TranslationService()
tts_service = TTSService()
websocket_manager = WebSocketManager(
    session_manager=session_manager,
    transcription_service=transcription_service,
    translation_service=translation_service,
    tts_service=tts_service
)

# Include routers
app.include_router(sessions.router, prefix="/api")
app.include_router(mobile.router, prefix="/api")
app.include_router(watch.router, prefix="/api")
app.include_router(learning.router)

# Set the services in the sessions router
sessions.session_manager = session_manager
sessions.translation_service = translation_service
sessions.tts_service = tts_service
sessions.transcription_service = transcription_service

# Set the services in the mobile router
mobile.translation_service = translation_service
mobile.tts_service = tts_service
mobile.transcription_service = transcription_service

# Set the services in the watch router
watch.translation_service = translation_service
watch.tts_service = tts_service
watch.transcription_service = transcription_service


# Configure logging with custom filter to truncate chunk arrays
chunk_filter = ChunkArrayTruncateFilter()

sio_logger = logging.getLogger('socketio')
sio_logger.setLevel(logging.INFO)  # Show info logs with truncated arrays
sio_logger.addFilter(chunk_filter)

engineio_logger = logging.getLogger('engineio')
engineio_logger.setLevel(logging.INFO)  # Show info logs with truncated arrays
engineio_logger.addFilter(chunk_filter)

# Also apply the filter to the root logger to catch any other verbose logging
root_logger = logging.getLogger()
root_logger.addFilter(chunk_filter)

# Configure Engine.IO payload limits for large audio chunks
engineio.payload.Payload.max_decode_packets = 250

# Socket.IO setup with environment-based CORS
socketio_cors_origins = cors_config.get_socketio_cors_origins()
print(f"Configuring Socket.IO CORS: {len(socketio_cors_origins) if isinstance(socketio_cors_origins, list) else 'all'} origins")

sio = socketio.AsyncServer(
    async_mode='asgi',
    cors_allowed_origins=socketio_cors_origins,
    cors_credentials=not cors_config.allow_all,  # Cannot use credentials with a wildcard origin
    logger=True,  # Re-enabled with custom filtering
    engineio_logger=True,  # Re-enabled with custom filtering
    always_connect=False  # Ensures the connect event is called for authentication
)

# Set the socketio instance in the websocket manager
websocket_manager.set_socketio(sio)

socket_app = socketio.ASGIApp(sio, app)

@app.get("/health")
async def health_check(token: str = Depends(optional_hf_token)):
    """Health check endpoint - optionally authenticated"""
    from app.auth import hf_auth

    auth_status = "bypassed (local development)" if hf_auth.is_local else "authenticated"
    if not hf_auth.is_local and not token:
        auth_status = "unauthenticated"

    return {
        "status": "healthy",
        "message": "Translation service is running",
        "auth_status": auth_status,
        "local_development": hf_auth.is_local,
        "auth_bypassed": hf_auth.is_local,
        "token_prefix": token[:10] + "..." if token and token != "local-development-bypass" else "local-bypass" if hf_auth.is_local else None,
        "environment": {
            "ENVIRONMENT": os.getenv('ENVIRONMENT', 'not set'),
            "DEBUG": os.getenv('DEBUG', 'not set'),
            "DISABLE_AUTH": os.getenv('DISABLE_AUTH', 'not set'),
            "HOST": os.getenv('HOST', 'not set'),
            "PORT": os.getenv('PORT', 'not set')
        },
        "services": {
            "transcription": transcription_service is not None,
            "translation": translation_service is not None,
            "tts": tts_service is not None,
            "sessions": session_manager is not None
        }
    }
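# --- Illustrative aside, not part of the committed file: probing the health
# --- endpoint above. The port comes from this commit's Dockerfile; the
# --- Authorization: Bearer header is an assumption about how
# --- optional_hf_token extracts the token.
#
#   curl http://localhost:7860/health
#   curl -H "Authorization: Bearer hf_xxx" http://localhost:7860/health
#
# Both return 200; without a valid token the body simply reports auth_status
# "unauthenticated" (or "bypassed (local development)" when running locally).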
@sio.event
async def connect(sid, environ=None, auth=None):
    """Handle Socket.IO connection with authentication"""
    try:
        print(f"=== WEBSOCKET CONNECTION ATTEMPT ===")
        print(f"SID: {sid}")
        print(f"Auth data: {auth}")
        print(f"Environ type: {type(environ)}")
        print(f"Environ data: {environ}")

        # Ensure environ is a dict
        if environ is None:
            environ = {}

        print(f"Query string: {environ.get('QUERY_STRING', 'None')}")
        print(f"Headers: {[k for k in environ.keys() if k.startswith('HTTP_')] if isinstance(environ, dict) else 'environ not dict'}")

        # Check authentication from multiple sources
        authenticated = False
        auth_method = None

        # Method 1: Check auth data from client
        if auth and authenticate_websocket_auth_data(auth):
            authenticated = True
            auth_method = "auth_data"
            print("✓ Authenticated via auth data")

        # Method 2: Check environment (headers, query params)
        elif environ and isinstance(environ, dict) and authenticate_websocket_connect(environ):
            authenticated = True
            auth_method = "environ"
            print("✓ Authenticated via headers/query")

        # TEMPORARY: Allow connections for debugging (remove in production)
        # This helps identify whether the issue is authentication or something else
        if not authenticated:
            print("⚠️ Authentication failed, but allowing for debugging")
            if isinstance(environ, dict):
                print(f"Available environ keys: {list(environ.keys())}")
            # NOTE: this debug bypass is currently ACTIVE, which makes the
            # disconnect branch below unreachable; remove the next two lines
            # to enforce authentication in production
            authenticated = True
            auth_method = "debug_bypass"

        if not authenticated:
            print("❌ Authentication failed - disconnecting")
            await sio.disconnect(sid)
            return False

        print(f"✓ WebSocket connection authenticated successfully via {auth_method}")
        return True

    except Exception as e:
        print(f"❌ Error in connect handler: {e}")
        import traceback
        traceback.print_exc()
        try:
            await sio.disconnect(sid)
        except:
            pass
        return False

@sio.event
async def disconnect(sid):
    await websocket_manager.handle_disconnect(sid)

@sio.event
async def join_session(sid, data):
    await websocket_manager.handle_join_session(sid, data)

@sio.event
async def join_hub(sid, data):
    await websocket_manager.handle_join_hub(sid, data)

@sio.event
async def leave_session(sid, data):
    await websocket_manager.handle_leave_session(sid, data)

@sio.event
async def audio_chunk(sid, data):
    await websocket_manager.handle_audio_chunk(sid, data)

@sio.event
async def speaking_status(sid, data):
    await websocket_manager.handle_speaking_status(sid, data)

@sio.event
async def test_echo(sid, data):
    """Test event to verify WebSocket communication"""
    await sio.emit('test_echo_response', data, room=sid)

@sio.event
async def update_participant_language(sid, data):
    """Update participant's language (affects speech recognition)"""
    await websocket_manager.handle_update_participant_language(sid, data)

@sio.event
async def update_session_languages(sid, data):
    """Update session's languages (affects translation targets)"""
    await websocket_manager.handle_update_session_languages(sid, data)

# Serve static files (for frontend)
if os.path.exists("../frontend/dist"):
    app.mount("/", StaticFiles(directory="../frontend/dist", html=True), name="static")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run("main:socket_app", host="0.0.0.0", port=7860, reload=True)
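The Socket.IO surface above is easiest to exercise with the python-socketio client. A minimal sketch, assuming the server is reachable on port 7860 (per the Dockerfile) and that the auth dict is something authenticate_websocket_auth_data accepts; the token and IDs are placeholders:

import asyncio
import socketio

async def main():
    client = socketio.AsyncClient()

    @client.on('test_echo_response')
    async def on_echo(data):
        print('echo:', data)

    # auth payload shape is an assumption; see authenticate_websocket_auth_data
    await client.connect('http://localhost:7860', auth={'token': 'hf_xxx'})
    await client.emit('test_echo', {'ping': 1})
    await client.emit('join_session', {'session_id': 'abc123',     # placeholder
                                       'participant_id': 'p1'})    # placeholder
    await asyncio.sleep(2)
    await client.disconnect()

asyncio.run(main())

Note that with the debug bypass active in the connect handler, this connects even when the auth payload is rejected.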
app/models/__init__.py
ADDED
@@ -0,0 +1,77 @@
from pydantic import BaseModel, Field
from typing import List, Dict, Optional
from enum import Enum

class LanguageCode(str, Enum):
    ENGLISH = "eng"
    SWAHILI = "swa"
    KIKUYU = "kik"
    KAMBA = "kam"
    KIMERU = "mer"
    LUO = "luo"
    SOMALI = "som"

class Language(BaseModel):
    code: LanguageCode
    name: str
    display_name: str

class ParticipantCreate(BaseModel):
    name: str
    language: LanguageCode

class Participant(BaseModel):
    id: str
    name: str
    language: Language
    is_organizer: bool = False
    is_speaking: bool = False
    is_connected: bool = False

class SessionCreate(BaseModel):
    name: str
    organizer_name: str
    languages: List[LanguageCode]
    enable_tts: bool = True  # Enable TTS by default for backward compatibility

class Session(BaseModel):
    id: str
    name: str
    organizer_name: str
    participants: List[Participant] = []
    languages: List[Language] = []
    qr_code_url: Optional[str] = None
    is_active: bool = True
    enable_tts: bool = True  # TTS enabled by default

class Message(BaseModel):
    id: str
    session_id: str
    speaker_id: str
    speaker_name: str
    original_text: str
    original_language: Language
    translations: Dict[str, str] = {}
    is_transcribing: bool = False

class TranscriptionUpdate(BaseModel):
    message_id: str
    text: str
    is_complete: bool
    confidence: Optional[float] = None

class TranslationUpdate(BaseModel):
    message_id: str
    target_language: LanguageCode
    translated_text: str

class AudioChunk(BaseModel):
    session_id: str
    participant_id: str
    audio_data: bytes

class WebSocketMessage(BaseModel):
    type: str
    data: Dict
    session_id: str
    participant_id: Optional[str] = None
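These Pydantic models double as the wire format for the REST and Socket.IO layers. A small sketch of constructing and serializing one, assuming Pydantic v2 (on v1, .dict() replaces .model_dump(); the names below are made up):

from app.models import Language, LanguageCode, Participant

lang = Language(code=LanguageCode.SWAHILI, name="Swahili", display_name="Kiswahili")
p = Participant(id="p1", name="Amina", language=lang, is_organizer=True)

print(p.model_dump(mode="json"))
# {'id': 'p1', 'name': 'Amina',
#  'language': {'code': 'swa', 'name': 'Swahili', 'display_name': 'Kiswahili'},
#  'is_organizer': True, 'is_speaking': False, 'is_connected': False}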
app/routers/__init__.py
ADDED
@@ -0,0 +1 @@
# Routers package
app/routers/add_phase_endpoints.py
ADDED
@@ -0,0 +1,490 @@
# Script to add remaining Phase 1-3 endpoints to learning.py

endpoints_code = """

@router.post("/vocabulary/add")
async def add_vocabulary_to_practice(
    vocab_request: VocabularyAddRequest,
    request: Request,
    token: Optional[str] = Depends(optional_hf_token)
):
    \"\"\"Add a vocabulary word to user's practice queue with FSRS initialization\"\"\"
    try:
        user_id = token if token else 'anonymous'

        vocab = learning_service.get_vocabulary(vocab_request.vocab_id)
        if not vocab:
            raise HTTPException(status_code=404, detail="Vocabulary not found")

        fsrs_data = {
            'difficulty': 0.3,
            'stability': 2.5,
            'retrievability': 1.0,
            'review_count': 0,
            'last_review': None,
            'next_review': datetime.utcnow().isoformat() + 'Z',
            'lapses': 0,
            'state': 'new'
        }

        user_vocab = {
            'vocabulary_id': vocab_request.vocab_id,
            'swahili': vocab.get('swahili', ''),
            'english': vocab.get('english', ''),
            'part_of_speech': vocab.get('part_of_speech', 'unknown'),
            'added_at': datetime.utcnow().isoformat() + 'Z',
            'added_from': vocab_request.source_lesson_id,
            'fsrs': fsrs_data,
            'mastery_level': 0,
            'times_reviewed': 0,
            'times_correct': 0,
            'accuracy': 0.0
        }

        success = learning_service.update_vocabulary_progress(
            user_id, str(vocab_request.vocab_id), user_vocab
        )

        if success:
            return {"success": True, "vocabulary": user_vocab}
        else:
            raise HTTPException(status_code=500, detail="Failed to add vocabulary")
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error adding vocabulary: {e}")
        raise HTTPException(status_code=500, detail="Failed to add vocabulary")


def calculate_next_review_fsrs(fsrs: Dict, grade: int) -> Dict:
    \"\"\"Simplified FSRS-style scheduling: update difficulty/stability and compute the next review\"\"\"
    from datetime import timedelta

    difficulty = fsrs['difficulty']
    stability = fsrs['stability']

    if grade == 0:
        new_difficulty = min(difficulty + 0.2, 1.0)
    elif grade == 2:
        new_difficulty = min(difficulty + 0.1, 1.0)
    elif grade == 4:
        new_difficulty = max(difficulty - 0.1, 0.0)
    else:
        new_difficulty = difficulty

    if grade == 0:
        new_stability = stability * 0.5
        state = 'relearning'
        interval_minutes = 10
    elif grade == 2:
        new_stability = stability * 1.2
        state = 'review'
        interval_minutes = int(new_stability * 24 * 60)
    elif grade == 3:
        new_stability = stability * 2.5
        state = 'review'
        interval_minutes = int(new_stability * 24 * 60)
    else:
        new_stability = stability * 4.0
        state = 'review'
        interval_minutes = int(new_stability * 24 * 60)

    next_review = datetime.utcnow() + timedelta(minutes=interval_minutes)

    return {
        'difficulty': new_difficulty,
        'stability': new_stability,
        'retrievability': 0.9 if grade >= 2 else 0.0,
        'review_count': fsrs['review_count'] + 1,
        'last_review': datetime.utcnow().isoformat() + 'Z',
        'next_review': next_review.isoformat() + 'Z',
        'lapses': fsrs['lapses'],
        'state': state,
        'interval_days': interval_minutes / (24 * 60)
    }


def calculate_mastery_level(vocab: Dict) -> int:
    \"\"\"Calculate mastery level (0-5)\"\"\"
    accuracy = vocab['accuracy']
    reviews = vocab['times_reviewed']
    stability = vocab['fsrs']['stability']

    if reviews == 0:
        return 0
    elif reviews < 5 or accuracy < 70:
        return 1
    elif reviews < 10 or accuracy < 85:
        return 2
    elif reviews < 20 or accuracy < 95:
        return 3
    # Check level 5 before level 4; the broader level-4 condition would
    # otherwise shadow it and make level 5 unreachable
    elif reviews >= 40 and accuracy >= 98 and stability >= 90:
        return 5
    elif reviews >= 20 and accuracy >= 95 and stability >= 30:
        return 4
    else:
        return 3


@router.post("/vocabulary/review")
async def record_vocabulary_review_fsrs(
    review_request: VocabularyReviewRequest,
    request: Request,
    token: Optional[str] = Depends(optional_hf_token)
):
    \"\"\"Record vocabulary review and update FSRS parameters\"\"\"
    try:
        user_id = token if token else 'anonymous'
        progress = learning_service.get_user_progress(user_id)

        if not progress or str(review_request.vocab_id) not in progress.get('vocabulary_progress', {}):
            raise HTTPException(status_code=404, detail="Vocabulary not in practice queue")

        vocab = progress['vocabulary_progress'][str(review_request.vocab_id)]
        fsrs = vocab['fsrs']

        grade_map = {'again': 0, 'hard': 2, 'good': 3, 'easy': 4}
        grade = grade_map.get(review_request.rating, 3)

        new_fsrs = calculate_next_review_fsrs(fsrs, grade)

        vocab['fsrs'] = new_fsrs
        vocab['times_reviewed'] += 1
        if grade >= 2:
            vocab['times_correct'] += 1
        else:
            vocab['fsrs']['lapses'] += 1

        vocab['accuracy'] = (vocab['times_correct'] / vocab['times_reviewed']) * 100 if vocab['times_reviewed'] > 0 else 0
        vocab['mastery_level'] = calculate_mastery_level(vocab)
        vocab['last_reviewed_at'] = datetime.utcnow().isoformat() + 'Z'

        if 'vocabulary_reviewed' not in progress['overall_stats']:
            progress['overall_stats']['vocabulary_reviewed'] = 0
        progress['overall_stats']['vocabulary_reviewed'] += 1

        learning_service.save_user_progress(user_id, progress)

        return {
            "success": True,
            "vocabulary": vocab,
            "next_review": new_fsrs['next_review'],
            "interval_days": new_fsrs['interval_days']
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error recording vocabulary review: {e}")
        raise HTTPException(status_code=500, detail="Failed to record review")


@router.get("/vocabulary/stats")
async def get_vocabulary_stats(
    request: Request,
    token: Optional[str] = Depends(optional_hf_token)
):
    \"\"\"Get vocabulary mastery statistics\"\"\"
    try:
        user_id = token if token else 'anonymous'
        progress = learning_service.get_user_progress(user_id)

        if not progress:
            return {
                "total_words": 0,
                "in_practice": 0,
                "mastery_breakdown": {str(i): 0 for i in range(6)},
                "average_accuracy": 0
            }

        vocab_progress = progress.get('vocabulary_progress', {})
        mastery_breakdown = {str(i): 0 for i in range(6)}
        total_accuracy = 0
        total_with_reviews = 0

        for vocab_data in vocab_progress.values():
            level = vocab_data.get('mastery_level', 0)
            mastery_breakdown[str(level)] += 1

            if vocab_data.get('times_reviewed', 0) > 0:
                total_accuracy += vocab_data.get('accuracy', 0)
                total_with_reviews += 1

        avg_accuracy = total_accuracy / total_with_reviews if total_with_reviews > 0 else 0

        return {
            "total_words": len(vocab_progress),
            "in_practice": len(vocab_progress),
            "mastery_breakdown": mastery_breakdown,
            "average_accuracy": round(avg_accuracy, 1),
            "total_reviews": sum(v.get('times_reviewed', 0) for v in vocab_progress.values())
        }
    except Exception as e:
        logger.error(f"Error getting vocabulary stats: {e}")
        raise HTTPException(status_code=500, detail="Failed to get stats")


@router.get("/vocabulary/library")
async def get_vocabulary_library(
    lesson_id: Optional[int] = None,
    level: Optional[str] = None,
    search: Optional[str] = None,
    request: Request = None,
    token: Optional[str] = Depends(optional_hf_token)
):
    \"\"\"Browse all vocabulary with filters\"\"\"
    try:
        user_id = token if token else 'anonymous'

        all_vocab = learning_service.get_all_vocabulary()
        progress = learning_service.get_user_progress(user_id)
        user_vocab = progress.get('vocabulary_progress', {}) if progress else {}

        filtered_vocab = all_vocab

        if lesson_id:
            filtered_vocab = [v for v in filtered_vocab if v.get('lesson_id') == lesson_id]

        if level:
            filtered_vocab = [v for v in filtered_vocab if v.get('level') == level]

        if search:
            search_lower = search.lower()
            filtered_vocab = [v for v in filtered_vocab
                              if search_lower in v.get('swahili', '').lower()
                              or search_lower in v.get('english', '').lower()]

        for vocab in filtered_vocab:
            vocab_id = str(vocab.get('vocabulary_id') or vocab.get('id'))
            if vocab_id in user_vocab:
                vocab['status'] = 'practicing'
                vocab['mastery_level'] = user_vocab[vocab_id].get('mastery_level', 0)
                vocab['accuracy'] = user_vocab[vocab_id].get('accuracy', 0)
                vocab['next_review'] = user_vocab[vocab_id].get('fsrs', {}).get('next_review')
            else:
                vocab['status'] = 'not_practicing'
                vocab['mastery_level'] = 0

        return {
            "vocabulary": filtered_vocab,
            "total": len(filtered_vocab),
            "filters_applied": {
                "lesson_id": lesson_id,
                "level": level,
                "search": search
            }
        }
    except Exception as e:
        logger.error(f"Error getting vocabulary library: {e}")
        raise HTTPException(status_code=500, detail="Failed to get vocabulary")


# Reading Comprehension

class ComprehensionAnswer(BaseModel):
    question_id: str
    answer: str


class ComprehensionSubmission(BaseModel):
    lesson_id: int
    passage_id: str
    answers: List[ComprehensionAnswer]


@router.post("/comprehension/submit")
async def submit_comprehension_answers(
    submission: ComprehensionSubmission,
    request: Request,
    token: Optional[str] = Depends(optional_hf_token)
):
    \"\"\"Submit reading comprehension answers and get scoring\"\"\"
    try:
        user_id = token if token else 'anonymous'

        lesson = learning_service.get_lesson(submission.lesson_id)
        if not lesson:
            raise HTTPException(status_code=404, detail="Lesson not found")

        passage = None
        for p in lesson.get('reading_passages', []):
            if p['passage_id'] == submission.passage_id:
                passage = p
                break

        if not passage:
            raise HTTPException(status_code=404, detail="Passage not found")

        results = []
        correct_count = 0

        for submitted in submission.answers:
            question_id = submitted.question_id
            user_answer = submitted.answer.strip().lower()

            question = None
            for q in passage['comprehension_questions']:
                if q['question_id'] == question_id:
                    question = q
                    break

            if not question:
                continue

            correct_answers = [ans.strip().lower() for ans in question.get('correct_answers', [])]
            is_correct = user_answer in correct_answers

            if is_correct:
                correct_count += 1

            results.append({
                "question_id": question_id,
                "correct": is_correct,
                "user_answer": user_answer,
                "correct_answer": question['correct_answers'][0] if correct_answers else None,
                "explanation": question.get('explanation')
            })

        score = (correct_count / len(submission.answers)) * 100 if submission.answers else 0

        progress = learning_service.get_user_progress(user_id)
        if not progress:
            progress = learning_service.create_default_progress(user_id)

        if 'comprehension_scores' not in progress:
            progress['comprehension_scores'] = {}

        progress['comprehension_scores'][f"{submission.lesson_id}_{submission.passage_id}"] = {
            "score": score,
            "completed_at": datetime.utcnow().isoformat() + 'Z',
            "attempts": progress['comprehension_scores'].get(f"{submission.lesson_id}_{submission.passage_id}", {}).get('attempts', 0) + 1
        }

        learning_service.save_user_progress(user_id, progress)

        return {
            "results": results,
            "score": round(score, 1),
            "correct": correct_count,
            "total": len(submission.answers)
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error submitting comprehension: {e}")
        raise HTTPException(status_code=500, detail="Failed to submit comprehension")


# Task Scenarios

class ScenarioProgressUpdate(BaseModel):
    turn_id: str
    choice_id: str


@router.get("/scenarios/{scenario_id}")
async def get_scenario(
    scenario_id: str,
    request: Request,
    token: Optional[str] = Depends(optional_hf_token)
):
    \"\"\"Get task scenario with branching dialogue\"\"\"
    try:
        user_id = token if token else 'anonymous'

        scenario = learning_service.get_scenario(scenario_id)
        if not scenario:
            raise HTTPException(status_code=404, detail="Scenario not found")

        progress = learning_service.get_user_progress(user_id)
        scenario_progress = None

        if progress and 'scenario_progress' in progress:
            scenario_progress = progress['scenario_progress'].get(scenario_id)

        return {
            "scenario": scenario,
            "user_progress": scenario_progress
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting scenario: {e}")
        raise HTTPException(status_code=500, detail="Failed to get scenario")


@router.post("/scenarios/{scenario_id}/progress")
async def update_scenario_progress(
    scenario_id: str,
    progress_update: ScenarioProgressUpdate,
    request: Request,
    token: Optional[str] = Depends(optional_hf_token)
):
    \"\"\"Update scenario progress with user choice\"\"\"
    try:
        user_id = token if token else 'anonymous'

        scenario = learning_service.get_scenario(scenario_id)
        if not scenario:
            raise HTTPException(status_code=404, detail="Scenario not found")

        progress = learning_service.get_user_progress(user_id)
        if not progress:
            progress = learning_service.create_default_progress(user_id)

        if 'scenario_progress' not in progress:
            progress['scenario_progress'] = {}

        if scenario_id not in progress['scenario_progress']:
            progress['scenario_progress'][scenario_id] = {
                "started_at": datetime.utcnow().isoformat() + 'Z',
                "turns": [],
                "completed": False
            }

        progress['scenario_progress'][scenario_id]['turns'].append({
            "turn_id": progress_update.turn_id,
            "choice_id": progress_update.choice_id,
            "timestamp": datetime.utcnow().isoformat() + 'Z'
        })

        turns_count = len(progress['scenario_progress'][scenario_id]['turns'])
        if turns_count >= scenario.get('required_turns', 6):
            progress['scenario_progress'][scenario_id]['completed'] = True
            progress['scenario_progress'][scenario_id]['completed_at'] = datetime.utcnow().isoformat() + 'Z'

        learning_service.save_user_progress(user_id, progress)

        return {
            "success": True,
            "progress": progress['scenario_progress'][scenario_id]
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error updating scenario progress: {e}")
        raise HTTPException(status_code=500, detail="Failed to update scenario")


@router.get("/scenarios")
async def list_scenarios(
    request: Request,
    token: Optional[str] = Depends(optional_hf_token)
):
    \"\"\"Get list of all available scenarios\"\"\"
    try:
        scenarios = learning_service.get_all_scenarios()
        return {
            "success": True,
            "scenarios": scenarios,
            "total": len(scenarios)
        }
    except Exception as e:
        logger.error(f"Error listing scenarios: {e}")
        raise HTTPException(status_code=500, detail="Failed to list scenarios")
"""

# Append to learning.py
with open('C:/repos/polyglot/backend/app/routers/learning.py', 'a', encoding='utf-8') as f:
    f.write(endpoints_code)

print("Successfully added all remaining Phase 1-3 endpoints!")
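The scheduling math in calculate_next_review_fsrs is easiest to see as a trace. A hypothetical run (only meaningful once the script above has appended the function to learning.py; the starting dict mirrors fsrs_data): each 'good' (grade 3) review multiplies stability by 2.5, and the interval in days tracks the new stability.

fsrs = {'difficulty': 0.3, 'stability': 2.5, 'retrievability': 1.0,
        'review_count': 0, 'last_review': None, 'next_review': None,
        'lapses': 0, 'state': 'new'}

for _ in range(3):
    fsrs = calculate_next_review_fsrs(fsrs, grade=3)   # 'good'
    print(round(fsrs['stability'], 2), fsrs['state'])

Stability climbs 2.5 -> 6.25 -> 15.62 -> 39.06, so the gaps between reviews stretch to roughly 6, 16 and 39 days. A grade of 0 ('again') would instead halve stability and schedule a retry in 10 minutes.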
app/routers/learning.py
ADDED
@@ -0,0 +1,1020 @@
"""
Learning API Router - REST endpoints for language learning functionality

Provides endpoints for:
- Fetching lesson catalog and individual lessons
- Managing user progress
- Recording lesson completion and scores
- Achievement tracking
"""

from fastapi import APIRouter, HTTPException, Depends, Request, File, UploadFile
from fastapi.responses import Response
from pydantic import BaseModel
from typing import List, Dict, Optional, Any
from datetime import datetime
import logging
import io

from app.services.learning_data_service import LearningDataService
from app.auth import optional_hf_token

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/learning", tags=["learning"])

# Initialize data service
learning_service = LearningDataService()


# ==================== Request/Response Models ====================

class LessonProgressUpdate(BaseModel):
    lesson_id: int
    status: str  # 'in_progress' or 'completed'
    score: Optional[int] = None
    pronunciation_score: Optional[float] = None
    listening_score: Optional[float] = None
    comprehension_score: Optional[float] = None
    time_spent_seconds: Optional[int] = None
    steps_completed: Optional[int] = None
    steps_skipped: Optional[int] = None


class VocabularyReview(BaseModel):
    vocabulary_id: int
    swahili: str
    is_correct: bool
    mastery_level: Optional[int] = None


class AchievementCheck(BaseModel):
    achievement_id: str
    progress: int
    target: int


# ==================== Lesson Endpoints ====================

@router.get("/lessons")
async def get_lessons(language: Optional[str] = 'swahili', request: Request = None, token: Optional[str] = Depends(optional_hf_token)):
    """
    Get catalog of all available lessons for a specific language

    Args:
        language: Language code (swahili, kamba, maasai)

    Returns the lesson index with metadata for all lessons
    """
    try:
        index = learning_service.get_lessons_index(language)
        if not index:
            raise HTTPException(status_code=404, detail=f"Lessons catalog not found for {language}")

        return {
            "success": True,
            "lessons": index.get('lessons', []),
            "learning_paths": index.get('learning_paths', {}),
            "metadata": index.get('metadata', {})
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error fetching lessons for {language}: {e}")
        raise HTTPException(status_code=500, detail="Failed to fetch lessons")


@router.get("/lessons/{lesson_id}")
async def get_lesson(lesson_id: int, language: Optional[str] = 'swahili', request: Request = None, token: Optional[str] = Depends(optional_hf_token)):
    """
    Get detailed lesson content including vocabulary, dialogue, and exercises

    Args:
        lesson_id: ID of the lesson to fetch
        language: Language code (swahili, kamba, maasai)
    """
    try:
        lesson = learning_service.get_lesson(lesson_id, language)
        if not lesson:
            raise HTTPException(status_code=404, detail=f"Lesson {lesson_id} not found for {language}")

        return {
            "success": True,
            "lesson": lesson
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error fetching lesson {lesson_id} for {language}: {e}")
        raise HTTPException(status_code=500, detail="Failed to fetch lesson")


# ==================== User Progress Endpoints ====================

@router.get("/progress")
async def get_user_progress(request: Request, token: Optional[str] = Depends(optional_hf_token)):
    """
    Get user's learning progress

    Returns overall stats, lesson progress, vocabulary progress, and achievements
    """
    try:
        # Use authenticated user ID or default for anonymous users
        user_id = token if token else 'anonymous'

        progress = learning_service.get_user_progress(user_id)
        if not progress:
            raise HTTPException(status_code=500, detail="Failed to load user progress")

        return {
            "success": True,
            "progress": progress
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error fetching user progress: {e}")
        raise HTTPException(status_code=500, detail="Failed to fetch user progress")


@router.post("/progress/lesson")
async def update_lesson_progress(
    progress_update: LessonProgressUpdate,
    request: Request,
    token: Optional[str] = Depends(optional_hf_token)
):
    """
    Update progress for a specific lesson

    Records completion status, scores, and time spent on a lesson
    """
    try:
        user_id = token if token else 'anonymous'

        # Build progress update dict
        update_data = {
            'lesson_id': progress_update.lesson_id,
            'status': progress_update.status
        }

        # Add optional fields if provided
        if progress_update.score is not None:
            update_data['latest_score'] = progress_update.score

            # Track best score
            user_progress = learning_service.get_user_progress(user_id)
            if user_progress:
                lesson_key = str(progress_update.lesson_id)
                current_best = user_progress.get('lesson_progress', {}).get(lesson_key, {}).get('best_score', 0)
                update_data['best_score'] = max(current_best, progress_update.score)

        if progress_update.pronunciation_score is not None:
            update_data['pronunciation_score'] = progress_update.pronunciation_score

        if progress_update.listening_score is not None:
            update_data['listening_score'] = progress_update.listening_score

        if progress_update.comprehension_score is not None:
            update_data['comprehension_score'] = progress_update.comprehension_score

        if progress_update.time_spent_seconds is not None:
            update_data['time_spent_seconds'] = progress_update.time_spent_seconds

        if progress_update.steps_completed is not None:
            update_data['steps_completed'] = progress_update.steps_completed

        if progress_update.steps_skipped is not None:
            update_data['steps_skipped'] = progress_update.steps_skipped

        # Add completion timestamp if status is completed
        if progress_update.status == 'completed':
            update_data['completed_at'] = datetime.utcnow().isoformat() + 'Z'

            # Increment attempts
            user_progress = learning_service.get_user_progress(user_id)
            if user_progress:
                lesson_key = str(progress_update.lesson_id)
                current_attempts = user_progress.get('lesson_progress', {}).get(lesson_key, {}).get('attempts', 0)
                update_data['attempts'] = current_attempts + 1

        # Save to file
        success = learning_service.update_lesson_progress(
            user_id,
            progress_update.lesson_id,
            update_data
        )

        if not success:
            raise HTTPException(status_code=500, detail="Failed to save progress")

        return {
            "success": True,
            "message": "Lesson progress updated"
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error updating lesson progress: {e}")
        raise HTTPException(status_code=500, detail="Failed to update lesson progress")


@router.post("/progress/vocabulary")
async def record_vocabulary_review(
    review: VocabularyReview,
    request: Request,
    token: Optional[str] = Depends(optional_hf_token)
):
    """
    Record a vocabulary review/practice session

    Updates mastery level and review statistics for a vocabulary word
    """
    try:
        user_id = token if token else 'anonymous'

        # Get current vocabulary progress
        user_progress = learning_service.get_user_progress(user_id)
        if not user_progress:
            raise HTTPException(status_code=500, detail="Failed to load user progress")

        vocab_key = str(review.vocabulary_id)
        vocab_progress = user_progress.get('vocabulary_progress', {}).get(vocab_key, {
            'vocabulary_id': review.vocabulary_id,
            'swahili': review.swahili,
            'mastery_level': 0,
            'times_reviewed': 0,
            'times_correct': 0,
            'ease_factor': 2.5,
            'interval_days': 0
        })

        # Update review counts
        vocab_progress['times_reviewed'] = vocab_progress.get('times_reviewed', 0) + 1
        if review.is_correct:
            vocab_progress['times_correct'] = vocab_progress.get('times_correct', 0) + 1

        # Update mastery level if provided
        if review.mastery_level is not None:
            vocab_progress['mastery_level'] = review.mastery_level

        # Update timestamps
        vocab_progress['last_reviewed_at'] = datetime.utcnow().isoformat() + 'Z'

        # Calculate next review date using simple spaced repetition
        # (simplified version - could use the SuperMemo SM-2 algorithm)
        interval_days = vocab_progress.get('interval_days', 0)
        if review.is_correct:
            interval_days = max(1, interval_days * 2)  # Double the interval
        else:
            interval_days = 1  # Reset to 1 day if incorrect

        vocab_progress['interval_days'] = interval_days

        from datetime import timedelta
        next_review = datetime.utcnow() + timedelta(days=interval_days)
        vocab_progress['next_review_at'] = next_review.isoformat() + 'Z'

        # Save to file
        success = learning_service.update_vocabulary_progress(
            user_id,
            review.vocabulary_id,
            vocab_progress
        )

        if not success:
            raise HTTPException(status_code=500, detail="Failed to save vocabulary progress")

        return {
            "success": True,
            "message": "Vocabulary review recorded",
            "next_review_at": vocab_progress['next_review_at']
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error recording vocabulary review: {e}")
        raise HTTPException(status_code=500, detail="Failed to record vocabulary review")
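# --- Illustrative aside, not part of the committed file: calling the review
# --- endpoint above from a client. The port and path come from this commit;
# --- the word and IDs are made up.
#
#   import requests
#   r = requests.post(
#       "http://localhost:7860/api/learning/progress/vocabulary",
#       json={"vocabulary_id": 12, "swahili": "maji", "is_correct": True},
#   )
#   print(r.json()["next_review_at"])
#
# With the doubling schedule above, consecutive correct reviews produce
# intervals of 1, 2, 4, 8, ... days, and a single miss resets the interval
# to one day.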
| 297 |
+
|
| 298 |
+
|
| 299 |
+
# ==================== Achievement Endpoints ====================


@router.get("/achievements")
async def get_achievements(request: Request, token: Optional[str] = Depends(optional_hf_token)):
    """
    Get all available achievements and the user's progress on them
    """
    try:
        # Get achievements configuration
        achievements_config = learning_service.get_achievements()
        if not achievements_config:
            raise HTTPException(status_code=404, detail="Achievements not found")

        # Get user progress
        user_id = token if token else 'anonymous'
        user_progress = learning_service.get_user_progress(user_id)

        # Merge achievement definitions with user progress
        user_achievements = user_progress.get('achievements', {}) if user_progress else {}

        achievements_with_progress = []
        for achievement in achievements_config.get('achievements', []):
            achievement_id = achievement['achievement_id']
            achievement_data = {
                **achievement,
                'unlocked': False,
                'progress': 0
            }

            # Add user progress if available
            if achievement_id in user_achievements:
                achievement_data.update(user_achievements[achievement_id])

            achievements_with_progress.append(achievement_data)

        return {
            "success": True,
            "achievements": achievements_with_progress,
            "tiers": achievements_config.get('tiers', {})
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error fetching achievements: {e}")
        raise HTTPException(status_code=500, detail="Failed to fetch achievements")


@router.post("/achievements/check")
async def check_achievement(
    achievement: AchievementCheck,
    request: Request,
    token: Optional[str] = Depends(optional_hf_token)
):
    """
    Check and potentially unlock an achievement

    Updates achievement progress and unlocks it if the target is reached
    """
    try:
        user_id = token if token else 'anonymous'

        success = learning_service.unlock_achievement(
            user_id,
            achievement.achievement_id,
            achievement.progress,
            achievement.target
        )

        if not success:
            raise HTTPException(status_code=500, detail="Failed to update achievement")

        is_unlocked = achievement.progress >= achievement.target

        return {
            "success": True,
            "unlocked": is_unlocked,
            "achievement_id": achievement.achievement_id
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error checking achievement: {e}")
        raise HTTPException(status_code=500, detail="Failed to check achievement")

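# The unlock rule above is simply progress >= target; a minimal client sketch
# (a sketch, not part of this module; the prefix and IDs are assumptions):
def _achievement_check_sketch():
    import requests  # assumed available on the client
    r = requests.post("http://localhost:7860/api/learning/achievements/check",
                      json={"achievement_id": "first_lesson", "progress": 1, "target": 1})
    print(r.json()["unlocked"])  # True once progress reaches the target
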
# ==================== Statistics Endpoints ====================


@router.get("/stats")
async def get_user_stats(request: Request, token: Optional[str] = Depends(optional_hf_token)):
    """
    Get user's overall learning statistics

    Returns aggregated stats like total XP, streak, lessons completed, etc.
    """
    try:
        user_id = token if token else 'anonymous'
        progress = learning_service.get_user_progress(user_id)

        if not progress:
            raise HTTPException(status_code=500, detail="Failed to load user progress")

        return {
            "success": True,
            "stats": progress.get('overall_stats', {}),
            "daily_stats": progress.get('daily_stats', {})
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error fetching user stats: {e}")
        raise HTTPException(status_code=500, detail="Failed to fetch user stats")

# ==================== TTS and ASR Endpoints ====================


class TTSRequest(BaseModel):
    text: str
    language: str
    messageId: Optional[str] = None


@router.post("/tts/generate")
async def generate_tts(
    tts_request: TTSRequest,
    request: Request
):
    """
    Generate TTS audio for lesson text
    """
    try:
        from app.main import tts_service

        # Generate TTS audio
        audio_data = await tts_service.generate_speech(
            text=tts_request.text,
            language_code=tts_request.language
        )

        if not audio_data:
            raise HTTPException(status_code=500, detail="Failed to generate TTS audio")

        # Return audio as WAV file
        return Response(
            content=audio_data,
            media_type="audio/wav",
            headers={
                "Content-Disposition": f"inline; filename=tts_{tts_request.messageId or 'audio'}.wav"
            }
        )
    except Exception as e:
        logger.error(f"Error generating TTS: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to generate TTS: {str(e)}")


@router.post("/transcribe")
async def transcribe_audio(
    request: Request,
    audio: UploadFile = File(...)
):
    """
    Transcribe audio for pronunciation practice
    """
    try:
        from app.main import transcription_service

        # Read audio file
        audio_bytes = await audio.read()

        # Get language from form data (default to Swahili)
        form = await request.form()
        language = form.get('language', 'swa')

        # Transcribe
        text = await transcription_service.transcribe_audio(
            audio_data=audio_bytes,
            language_code=language
        )

        return {
            "success": True,
            "text": text,
            "language": language
        }
    except Exception as e:
        logger.error(f"Error transcribing audio: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to transcribe: {str(e)}")

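# A minimal client-side sketch of the two endpoints above (a sketch, not part
# of this module; the host and router prefix are assumptions, not taken from
# this commit):
def _tts_and_transcribe_sketch():
    import requests  # assumed available on the client
    base = "http://localhost:7860/api/learning"  # assumed prefix
    r = requests.post(f"{base}/tts/generate",
                      json={"text": "Habari yako", "language": "swa"})
    with open("habari.wav", "wb") as f:
        f.write(r.content)  # raw WAV bytes from the Response
    with open("habari.wav", "rb") as f:
        r = requests.post(f"{base}/transcribe",
                          files={"audio": f}, data={"language": "swa"})
    print(r.json()["text"])
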
# ==================== Phase 1-3 Endpoints ====================


# Vocabulary Management

class VocabularyAddRequest(BaseModel):
    vocab_id: int
    source_lesson_id: Optional[int] = None


class VocabularyReviewRequest(BaseModel):
    vocab_id: int
    rating: str  # 'again', 'hard', 'good', 'easy'

@router.get("/vocabulary/due")
|
| 502 |
+
async def get_due_vocabulary(
|
| 503 |
+
request: Request,
|
| 504 |
+
token: Optional[str] = Depends(optional_hf_token)
|
| 505 |
+
):
|
| 506 |
+
"""Get vocabulary words due for FSRS review"""
|
| 507 |
+
try:
|
| 508 |
+
user_id = token if token else 'anonymous'
|
| 509 |
+
progress = learning_service.get_user_progress(user_id)
|
| 510 |
+
|
| 511 |
+
if not progress:
|
| 512 |
+
return {"due_words": [], "total_due": 0}
|
| 513 |
+
|
| 514 |
+
vocab_progress = progress.get('vocabulary_progress', {})
|
| 515 |
+
now = datetime.utcnow()
|
| 516 |
+
due_words = []
|
| 517 |
+
|
| 518 |
+
for vocab_id, vocab_data in vocab_progress.items():
|
| 519 |
+
next_review_str = vocab_data.get('fsrs', {}).get('next_review')
|
| 520 |
+
if not next_review_str:
|
| 521 |
+
continue
|
| 522 |
+
|
| 523 |
+
next_review = datetime.fromisoformat(next_review_str.rstrip('Z'))
|
| 524 |
+
|
| 525 |
+
if next_review <= now:
|
| 526 |
+
hours_overdue = (now - next_review).total_seconds() / 3600
|
| 527 |
+
vocab_data['priority'] = 1000 - hours_overdue
|
| 528 |
+
due_words.append(vocab_data)
|
| 529 |
+
|
| 530 |
+
due_words.sort(key=lambda x: x.get('priority', 0), reverse=True)
|
| 531 |
+
|
| 532 |
+
return {
|
| 533 |
+
"due_words": due_words,
|
| 534 |
+
"total_due": len(due_words),
|
| 535 |
+
"timestamp": now.isoformat() + 'Z'
|
| 536 |
+
}
|
| 537 |
+
except Exception as e:
|
| 538 |
+
logger.error(f"Error getting due vocabulary: {e}")
|
| 539 |
+
raise HTTPException(status_code=500, detail="Failed to get due vocabulary")
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
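# Note on the ordering above: priority = 1000 - hours_overdue, sorted in
# descending order, so the most recently due words are served first and
# long-overdue words sink. A quick trace (a sketch, not part of the module):
def _due_priority_sketch():
    for hours_overdue in (1, 24, 240):
        print(1000 - hours_overdue)  # 999, 976, 760 -> the 1-hour-overdue word ranks first
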
@router.post("/vocabulary/add")
|
| 543 |
+
async def add_vocabulary_to_practice(
|
| 544 |
+
vocab_request: VocabularyAddRequest,
|
| 545 |
+
request: Request,
|
| 546 |
+
token: Optional[str] = Depends(optional_hf_token)
|
| 547 |
+
):
|
| 548 |
+
"""Add a vocabulary word to user's practice queue with FSRS initialization"""
|
| 549 |
+
try:
|
| 550 |
+
user_id = token if token else 'anonymous'
|
| 551 |
+
|
| 552 |
+
vocab = learning_service.get_vocabulary(vocab_request.vocab_id)
|
| 553 |
+
if not vocab:
|
| 554 |
+
raise HTTPException(status_code=404, detail="Vocabulary not found")
|
| 555 |
+
|
| 556 |
+
fsrs_data = {
|
| 557 |
+
'difficulty': 0.3,
|
| 558 |
+
'stability': 2.5,
|
| 559 |
+
'retrievability': 1.0,
|
| 560 |
+
'review_count': 0,
|
| 561 |
+
'last_review': None,
|
| 562 |
+
'next_review': datetime.utcnow().isoformat() + 'Z',
|
| 563 |
+
'lapses': 0,
|
| 564 |
+
'state': 'new'
|
| 565 |
+
}
|
| 566 |
+
|
| 567 |
+
user_vocab = {
|
| 568 |
+
'vocabulary_id': vocab_request.vocab_id,
|
| 569 |
+
'swahili': vocab.get('swahili', ''),
|
| 570 |
+
'english': vocab.get('english', ''),
|
| 571 |
+
'part_of_speech': vocab.get('part_of_speech', 'unknown'),
|
| 572 |
+
'added_at': datetime.utcnow().isoformat() + 'Z',
|
| 573 |
+
'added_from': vocab_request.source_lesson_id,
|
| 574 |
+
'fsrs': fsrs_data,
|
| 575 |
+
'mastery_level': 0,
|
| 576 |
+
'times_reviewed': 0,
|
| 577 |
+
'times_correct': 0,
|
| 578 |
+
'accuracy': 0.0
|
| 579 |
+
}
|
| 580 |
+
|
| 581 |
+
success = learning_service.update_vocabulary_progress(
|
| 582 |
+
user_id, str(vocab_request.vocab_id), user_vocab
|
| 583 |
+
)
|
| 584 |
+
|
| 585 |
+
if success:
|
| 586 |
+
return {"success": True, "vocabulary": user_vocab}
|
| 587 |
+
else:
|
| 588 |
+
raise HTTPException(status_code=500, detail="Failed to add vocabulary")
|
| 589 |
+
except HTTPException:
|
| 590 |
+
raise
|
| 591 |
+
except Exception as e:
|
| 592 |
+
logger.error(f"Error adding vocabulary: {e}")
|
| 593 |
+
raise HTTPException(status_code=500, detail="Failed to add vocabulary")
|
| 594 |
+
|
| 595 |
+
|
| 596 |
+
def calculate_next_review_fsrs(fsrs: Dict, grade: int) -> Dict:
|
| 597 |
+
"""Implement FSRS algorithm"""
|
| 598 |
+
from datetime import timedelta
|
| 599 |
+
|
| 600 |
+
difficulty = fsrs['difficulty']
|
| 601 |
+
stability = fsrs['stability']
|
| 602 |
+
|
| 603 |
+
if grade == 0:
|
| 604 |
+
new_difficulty = min(difficulty + 0.2, 1.0)
|
| 605 |
+
elif grade == 2:
|
| 606 |
+
new_difficulty = min(difficulty + 0.1, 1.0)
|
| 607 |
+
elif grade == 4:
|
| 608 |
+
new_difficulty = max(difficulty - 0.1, 0.0)
|
| 609 |
+
else:
|
| 610 |
+
new_difficulty = difficulty
|
| 611 |
+
|
| 612 |
+
if grade == 0:
|
| 613 |
+
new_stability = stability * 0.5
|
| 614 |
+
state = 'relearning'
|
| 615 |
+
interval_minutes = 10
|
| 616 |
+
elif grade == 2:
|
| 617 |
+
new_stability = stability * 1.2
|
| 618 |
+
state = 'review'
|
| 619 |
+
interval_minutes = int(new_stability * 24 * 60)
|
| 620 |
+
elif grade == 3:
|
| 621 |
+
new_stability = stability * 2.5
|
| 622 |
+
state = 'review'
|
| 623 |
+
interval_minutes = int(new_stability * 24 * 60)
|
| 624 |
+
else:
|
| 625 |
+
new_stability = stability * 4.0
|
| 626 |
+
state = 'review'
|
| 627 |
+
interval_minutes = int(new_stability * 24 * 60)
|
| 628 |
+
|
| 629 |
+
next_review = datetime.utcnow() + timedelta(minutes=interval_minutes)
|
| 630 |
+
|
| 631 |
+
return {
|
| 632 |
+
'difficulty': new_difficulty,
|
| 633 |
+
'stability': new_stability,
|
| 634 |
+
'retrievability': 0.9 if grade >= 2 else 0.0,
|
| 635 |
+
'review_count': fsrs['review_count'] + 1,
|
| 636 |
+
'last_review': datetime.utcnow().isoformat() + 'Z',
|
| 637 |
+
'next_review': next_review.isoformat() + 'Z',
|
| 638 |
+
'lapses': fsrs['lapses'],
|
| 639 |
+
'state': state,
|
| 640 |
+
'interval_days': interval_minutes / (24 * 60)
|
| 641 |
+
}
|
| 642 |
+
|
| 643 |
+
|
| 644 |
+
def calculate_mastery_level(vocab: Dict) -> int:
|
| 645 |
+
"""Calculate mastery level (0-5)"""
|
| 646 |
+
accuracy = vocab['accuracy']
|
| 647 |
+
reviews = vocab['times_reviewed']
|
| 648 |
+
stability = vocab['fsrs']['stability']
|
| 649 |
+
|
| 650 |
+
if reviews == 0:
|
| 651 |
+
return 0
|
| 652 |
+
elif reviews < 5 or accuracy < 70:
|
| 653 |
+
return 1
|
| 654 |
+
elif reviews < 10 or accuracy < 85:
|
| 655 |
+
return 2
|
| 656 |
+
elif reviews < 20 or accuracy < 95:
|
| 657 |
+
return 3
|
| 658 |
+
elif reviews >= 20 and accuracy >= 95 and stability >= 30:
|
| 659 |
+
return 4
|
| 660 |
+
elif reviews >= 40 and accuracy >= 98 and stability >= 90:
|
| 661 |
+
return 5
|
| 662 |
+
else:
|
| 663 |
+
return 3
|
| 664 |
+
|
| 665 |
+
|
| 666 |
+
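# A minimal trace of the scheduler above (a sketch, not part of the module):
# with the initial stability of 2.5 days, consecutive 'good' reviews (grade 3)
# multiply stability by 2.5, so intervals run roughly 6.25 -> 15.6 -> 39 days.
def _fsrs_growth_sketch():
    fsrs = {'difficulty': 0.3, 'stability': 2.5, 'review_count': 0, 'lapses': 0}
    for _ in range(3):
        fsrs = calculate_next_review_fsrs(fsrs, grade=3)
        print(round(fsrs['interval_days'], 1))  # ~6.2, 15.6, 39.1
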
@router.post("/vocabulary/review")
|
| 667 |
+
async def record_vocabulary_review_fsrs(
|
| 668 |
+
review_request: VocabularyReviewRequest,
|
| 669 |
+
request: Request,
|
| 670 |
+
token: Optional[str] = Depends(optional_hf_token)
|
| 671 |
+
):
|
| 672 |
+
"""Record vocabulary review and update FSRS parameters"""
|
| 673 |
+
try:
|
| 674 |
+
user_id = token if token else 'anonymous'
|
| 675 |
+
progress = learning_service.get_user_progress(user_id)
|
| 676 |
+
|
| 677 |
+
if not progress or str(review_request.vocab_id) not in progress.get('vocabulary_progress', {}):
|
| 678 |
+
raise HTTPException(status_code=404, detail="Vocabulary not in practice queue")
|
| 679 |
+
|
| 680 |
+
vocab = progress['vocabulary_progress'][str(review_request.vocab_id)]
|
| 681 |
+
fsrs = vocab['fsrs']
|
| 682 |
+
|
| 683 |
+
grade_map = {'again': 0, 'hard': 2, 'good': 3, 'easy': 4}
|
| 684 |
+
grade = grade_map.get(review_request.rating, 3)
|
| 685 |
+
|
| 686 |
+
new_fsrs = calculate_next_review_fsrs(fsrs, grade)
|
| 687 |
+
|
| 688 |
+
vocab['fsrs'] = new_fsrs
|
| 689 |
+
vocab['times_reviewed'] += 1
|
| 690 |
+
if grade >= 2:
|
| 691 |
+
vocab['times_correct'] += 1
|
| 692 |
+
else:
|
| 693 |
+
vocab['fsrs']['lapses'] += 1
|
| 694 |
+
|
| 695 |
+
vocab['accuracy'] = (vocab['times_correct'] / vocab['times_reviewed']) * 100 if vocab['times_reviewed'] > 0 else 0
|
| 696 |
+
vocab['mastery_level'] = calculate_mastery_level(vocab)
|
| 697 |
+
vocab['last_reviewed_at'] = datetime.utcnow().isoformat() + 'Z'
|
| 698 |
+
|
| 699 |
+
if 'vocabulary_reviewed' not in progress['overall_stats']:
|
| 700 |
+
progress['overall_stats']['vocabulary_reviewed'] = 0
|
| 701 |
+
progress['overall_stats']['vocabulary_reviewed'] += 1
|
| 702 |
+
|
| 703 |
+
learning_service.save_user_progress(user_id, progress)
|
| 704 |
+
|
| 705 |
+
return {
|
| 706 |
+
"success": True,
|
| 707 |
+
"vocabulary": vocab,
|
| 708 |
+
"next_review": new_fsrs['next_review'],
|
| 709 |
+
"interval_days": new_fsrs['interval_days']
|
| 710 |
+
}
|
| 711 |
+
except HTTPException:
|
| 712 |
+
raise
|
| 713 |
+
except Exception as e:
|
| 714 |
+
logger.error(f"Error recording vocabulary review: {e}")
|
| 715 |
+
raise HTTPException(status_code=500, detail="Failed to record review")
|
| 716 |
+
|
| 717 |
+
|
| 718 |
+
@router.get("/vocabulary/stats")
|
| 719 |
+
async def get_vocabulary_stats(
|
| 720 |
+
request: Request,
|
| 721 |
+
token: Optional[str] = Depends(optional_hf_token)
|
| 722 |
+
):
|
| 723 |
+
"""Get vocabulary mastery statistics"""
|
| 724 |
+
try:
|
| 725 |
+
user_id = token if token else 'anonymous'
|
| 726 |
+
progress = learning_service.get_user_progress(user_id)
|
| 727 |
+
|
| 728 |
+
if not progress:
|
| 729 |
+
return {
|
| 730 |
+
"total_words": 0,
|
| 731 |
+
"in_practice": 0,
|
| 732 |
+
"mastery_breakdown": {str(i): 0 for i in range(6)},
|
| 733 |
+
"average_accuracy": 0
|
| 734 |
+
}
|
| 735 |
+
|
| 736 |
+
vocab_progress = progress.get('vocabulary_progress', {})
|
| 737 |
+
mastery_breakdown = {str(i): 0 for i in range(6)}
|
| 738 |
+
total_accuracy = 0
|
| 739 |
+
total_with_reviews = 0
|
| 740 |
+
|
| 741 |
+
for vocab_data in vocab_progress.values():
|
| 742 |
+
level = vocab_data.get('mastery_level', 0)
|
| 743 |
+
mastery_breakdown[str(level)] += 1
|
| 744 |
+
|
| 745 |
+
if vocab_data.get('times_reviewed', 0) > 0:
|
| 746 |
+
total_accuracy += vocab_data.get('accuracy', 0)
|
| 747 |
+
total_with_reviews += 1
|
| 748 |
+
|
| 749 |
+
avg_accuracy = total_accuracy / total_with_reviews if total_with_reviews > 0 else 0
|
| 750 |
+
|
| 751 |
+
return {
|
| 752 |
+
"total_words": len(vocab_progress),
|
| 753 |
+
"in_practice": len(vocab_progress),
|
| 754 |
+
"mastery_breakdown": mastery_breakdown,
|
| 755 |
+
"average_accuracy": round(avg_accuracy, 1),
|
| 756 |
+
"total_reviews": sum(v.get('times_reviewed', 0) for v in vocab_progress.values())
|
| 757 |
+
}
|
| 758 |
+
except Exception as e:
|
| 759 |
+
logger.error(f"Error getting vocabulary stats: {e}")
|
| 760 |
+
raise HTTPException(status_code=500, detail="Failed to get stats")
|
| 761 |
+
|
| 762 |
+
|
| 763 |
+
@router.get("/vocabulary/library")
|
| 764 |
+
async def get_vocabulary_library(
|
| 765 |
+
lesson_id: Optional[int] = None,
|
| 766 |
+
level: Optional[str] = None,
|
| 767 |
+
search: Optional[str] = None,
|
| 768 |
+
request: Request = None,
|
| 769 |
+
token: Optional[str] = Depends(optional_hf_token)
|
| 770 |
+
):
|
| 771 |
+
"""Browse all vocabulary with filters"""
|
| 772 |
+
try:
|
| 773 |
+
user_id = token if token else 'anonymous'
|
| 774 |
+
|
| 775 |
+
all_vocab = learning_service.get_all_vocabulary()
|
| 776 |
+
progress = learning_service.get_user_progress(user_id)
|
| 777 |
+
user_vocab = progress.get('vocabulary_progress', {}) if progress else {}
|
| 778 |
+
|
| 779 |
+
filtered_vocab = all_vocab
|
| 780 |
+
|
| 781 |
+
if lesson_id:
|
| 782 |
+
filtered_vocab = [v for v in filtered_vocab if v.get('lesson_id') == lesson_id]
|
| 783 |
+
|
| 784 |
+
if level:
|
| 785 |
+
filtered_vocab = [v for v in filtered_vocab if v.get('level') == level]
|
| 786 |
+
|
| 787 |
+
if search:
|
| 788 |
+
search_lower = search.lower()
|
| 789 |
+
filtered_vocab = [v for v in filtered_vocab
|
| 790 |
+
if search_lower in v.get('swahili', '').lower()
|
| 791 |
+
or search_lower in v.get('english', '').lower()]
|
| 792 |
+
|
| 793 |
+
for vocab in filtered_vocab:
|
| 794 |
+
vocab_id = str(vocab.get('vocabulary_id') or vocab.get('id'))
|
| 795 |
+
if vocab_id in user_vocab:
|
| 796 |
+
vocab['status'] = 'practicing'
|
| 797 |
+
vocab['mastery_level'] = user_vocab[vocab_id].get('mastery_level', 0)
|
| 798 |
+
vocab['accuracy'] = user_vocab[vocab_id].get('accuracy', 0)
|
| 799 |
+
vocab['next_review'] = user_vocab[vocab_id].get('fsrs', {}).get('next_review')
|
| 800 |
+
else:
|
| 801 |
+
vocab['status'] = 'not_practicing'
|
| 802 |
+
vocab['mastery_level'] = 0
|
| 803 |
+
|
| 804 |
+
return {
|
| 805 |
+
"vocabulary": filtered_vocab,
|
| 806 |
+
"total": len(filtered_vocab),
|
| 807 |
+
"filters_applied": {
|
| 808 |
+
"lesson_id": lesson_id,
|
| 809 |
+
"level": level,
|
| 810 |
+
"search": search
|
| 811 |
+
}
|
| 812 |
+
}
|
| 813 |
+
except Exception as e:
|
| 814 |
+
logger.error(f"Error getting vocabulary library: {e}")
|
| 815 |
+
raise HTTPException(status_code=500, detail="Failed to get vocabulary")
|
| 816 |
+
|
| 817 |
+
|
| 818 |
+
# Reading Comprehension

class ComprehensionAnswer(BaseModel):
    question_id: str
    answer: str


class ComprehensionSubmission(BaseModel):
    lesson_id: int
    passage_id: str
    answers: List[ComprehensionAnswer]


@router.post("/comprehension/submit")
async def submit_comprehension_answers(
    submission: ComprehensionSubmission,
    request: Request,
    token: Optional[str] = Depends(optional_hf_token)
):
    """Submit reading comprehension answers and get scoring"""
    try:
        user_id = token if token else 'anonymous'

        lesson = learning_service.get_lesson(submission.lesson_id)
        if not lesson:
            raise HTTPException(status_code=404, detail="Lesson not found")

        passage = None
        for p in lesson.get('reading_passages', []):
            if p['passage_id'] == submission.passage_id:
                passage = p
                break

        if not passage:
            raise HTTPException(status_code=404, detail="Passage not found")

        results = []
        correct_count = 0

        for submitted in submission.answers:
            question_id = submitted.question_id
            user_answer = submitted.answer.strip().lower()

            question = None
            for q in passage['comprehension_questions']:
                if q['question_id'] == question_id:
                    question = q
                    break

            if not question:
                continue

            correct_answers = [ans.strip().lower() for ans in question.get('correct_answers', [])]
            is_correct = user_answer in correct_answers

            if is_correct:
                correct_count += 1

            results.append({
                "question_id": question_id,
                "correct": is_correct,
                "user_answer": user_answer,
                "correct_answer": question['correct_answers'][0] if correct_answers else None,
                "explanation": question.get('explanation')
            })

        score = (correct_count / len(submission.answers)) * 100 if submission.answers else 0

        progress = learning_service.get_user_progress(user_id)
        if not progress:
            progress = learning_service.create_default_progress(user_id)

        if 'comprehension_scores' not in progress:
            progress['comprehension_scores'] = {}

        progress['comprehension_scores'][f"{submission.lesson_id}_{submission.passage_id}"] = {
            "score": score,
            "completed_at": datetime.utcnow().isoformat() + 'Z',
            "attempts": progress['comprehension_scores'].get(f"{submission.lesson_id}_{submission.passage_id}", {}).get('attempts', 0) + 1
        }

        learning_service.save_user_progress(user_id, progress)

        return {
            "results": results,
            "score": round(score, 1),
            "correct": correct_count,
            "total": len(submission.answers)
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error submitting comprehension: {e}")
        raise HTTPException(status_code=500, detail="Failed to submit comprehension")

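# A minimal request sketch for /comprehension/submit (hypothetical IDs and an
# assumed prefix; answers are matched case-insensitively against correct_answers):
def _comprehension_submit_sketch():
    import requests  # assumed available on the client
    payload = {
        "lesson_id": 3,
        "passage_id": "passage-1",
        "answers": [
            {"question_id": "q1", "answer": "Sokoni"},
            {"question_id": "q2", "answer": "asubuhi"},
        ],
    }
    r = requests.post("http://localhost:7860/api/learning/comprehension/submit", json=payload)
    print(r.json()["score"])  # percentage of correct answers, rounded to one decimal
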
# Task Scenarios

class ScenarioProgressUpdate(BaseModel):
    turn_id: str
    choice_id: str


@router.get("/scenarios/{scenario_id}")
async def get_scenario(
    scenario_id: str,
    request: Request,
    token: Optional[str] = Depends(optional_hf_token)
):
    """Get task scenario with branching dialogue"""
    try:
        user_id = token if token else 'anonymous'

        scenario = learning_service.get_scenario(scenario_id)
        if not scenario:
            raise HTTPException(status_code=404, detail="Scenario not found")

        progress = learning_service.get_user_progress(user_id)
        scenario_progress = None

        if progress and 'scenario_progress' in progress:
            scenario_progress = progress['scenario_progress'].get(scenario_id)

        return {
            "scenario": scenario,
            "user_progress": scenario_progress
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting scenario: {e}")
        raise HTTPException(status_code=500, detail="Failed to get scenario")


@router.post("/scenarios/{scenario_id}/progress")
async def update_scenario_progress(
    scenario_id: str,
    progress_update: ScenarioProgressUpdate,
    request: Request,
    token: Optional[str] = Depends(optional_hf_token)
):
    """Update scenario progress with user choice"""
    try:
        user_id = token if token else 'anonymous'

        scenario = learning_service.get_scenario(scenario_id)
        if not scenario:
            raise HTTPException(status_code=404, detail="Scenario not found")

        progress = learning_service.get_user_progress(user_id)
        if not progress:
            progress = learning_service.create_default_progress(user_id)

        if 'scenario_progress' not in progress:
            progress['scenario_progress'] = {}

        if scenario_id not in progress['scenario_progress']:
            progress['scenario_progress'][scenario_id] = {
                "started_at": datetime.utcnow().isoformat() + 'Z',
                "turns": [],
                "completed": False
            }

        progress['scenario_progress'][scenario_id]['turns'].append({
            "turn_id": progress_update.turn_id,
            "choice_id": progress_update.choice_id,
            "timestamp": datetime.utcnow().isoformat() + 'Z'
        })

        turns_count = len(progress['scenario_progress'][scenario_id]['turns'])
        if turns_count >= scenario.get('required_turns', 6):
            progress['scenario_progress'][scenario_id]['completed'] = True
            progress['scenario_progress'][scenario_id]['completed_at'] = datetime.utcnow().isoformat() + 'Z'

        learning_service.save_user_progress(user_id, progress)

        return {
            "success": True,
            "progress": progress['scenario_progress'][scenario_id]
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error updating scenario progress: {e}")
        raise HTTPException(status_code=500, detail="Failed to update scenario")


@router.get("/scenarios")
async def list_scenarios(
    request: Request,
    token: Optional[str] = Depends(optional_hf_token)
):
    """Get list of all available scenarios"""
    try:
        scenarios = learning_service.get_all_scenarios()
        return {
            "success": True,
            "scenarios": scenarios,
            "total": len(scenarios)
        }
    except Exception as e:
        logger.error(f"Error listing scenarios: {e}")
        raise HTTPException(status_code=500, detail="Failed to list scenarios")

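# A minimal sketch of driving the scenario endpoints above from a client
# (assumed host/prefix, hypothetical IDs): post one choice per turn until the
# turn count reaches required_turns and 'completed' flips to True.
def _scenario_flow_sketch():
    import requests  # assumed available on the client
    base = "http://localhost:7860/api/learning"  # assumed prefix
    scenario = requests.get(f"{base}/scenarios/market-haggling").json()["scenario"]
    print(scenario.get("required_turns", 6))  # turns needed to complete
    for turn_id, choice_id in [("t1", "c2"), ("t2", "c1")]:
        r = requests.post(f"{base}/scenarios/market-haggling/progress",
                          json={"turn_id": turn_id, "choice_id": choice_id})
        print(r.json()["progress"]["completed"])
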
app/routers/mobile.py
ADDED
@@ -0,0 +1,536 @@
from fastapi import APIRouter, HTTPException, File, UploadFile, Form, Query, Depends
from fastapi.responses import Response
from typing import Optional
from pydantic import BaseModel
import base64
import json
import uuid
import datetime
from app.services.translation_service import TranslationService
from app.services.tts_service import TTSService
from app.services.transcription_service import TranscriptionService
from app.auth import require_hf_token

router = APIRouter()

# Service instances - these will be injected by main app
translation_service = None
tts_service = None
transcription_service = None

# Mobile-specific data models
class MobileSessionRequest(BaseModel):
    user_name: str
    default_source_lang: str = "eng"
    default_target_lang: str = "swa"

class MobileSessionResponse(BaseModel):
    session_id: str
    participant_id: str
    user_name: str
    source_language: str
    target_language: str

class MobileTranscribeRequest(BaseModel):
    participant_id: str
    source_language: str
    target_language: str
    is_final_chunk: bool = False

class MobileLanguageUpdateRequest(BaseModel):
    participant_id: str
    source_language: str
    target_language: str

# In-memory session storage (in production, use Redis or database)
mobile_sessions = {}

@router.post("/mobile/session/create", response_model=MobileSessionResponse)
async def create_mobile_session(
    user_name: str = Form(...),
    default_source_lang: str = Form("eng"),
    default_target_lang: str = Form("swa"),
    token: str = Depends(require_hf_token)
):
    """Create a mobile-specific single-user session"""
    try:
        print(f"=== MOBILE SESSION CREATE REQUEST ===")
        print(f"User name: {user_name}")
        print(f"Source language: {default_source_lang}")
        print(f"Target language: {default_target_lang}")

        # Validate inputs
        if not user_name or user_name.strip() == "":
            raise HTTPException(status_code=400, detail="User name is required")

        # Validate language codes
        valid_languages = ["eng", "swa", "kik", "kam", "mer", "luo", "som"]
        if default_source_lang not in valid_languages:
            print(f"Invalid source language: {default_source_lang}, defaulting to 'eng'")
            default_source_lang = "eng"
        if default_target_lang not in valid_languages:
            print(f"Invalid target language: {default_target_lang}, defaulting to 'swa'")
            default_target_lang = "swa"

        session_id = f"mobile-{uuid.uuid4().hex[:8]}"
        participant_id = f"user-{uuid.uuid4().hex[:8]}"

        # Store session data
        mobile_sessions[session_id] = {
            "session_id": session_id,
            "participant_id": participant_id,
            "user_name": user_name.strip(),
            "source_language": default_source_lang,
            "target_language": default_target_lang,
            "created_at": datetime.datetime.now().isoformat()
        }

        print(f"Created session: {session_id} for user: {user_name}")
        print(f"Total sessions: {len(mobile_sessions)}")

        response = MobileSessionResponse(
            session_id=session_id,
            participant_id=participant_id,
            user_name=user_name.strip(),
            source_language=default_source_lang,
            target_language=default_target_lang
        )

        print(f"Returning response: {response}")
        return response

    except HTTPException:
        raise
    except Exception as e:
        print(f"ERROR creating mobile session: {e}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Failed to create mobile session: {str(e)}")

@router.get("/mobile/session/{session_id}")
async def get_mobile_session(session_id: str, token: str = Depends(require_hf_token)):
    """Get mobile session details"""
    if session_id not in mobile_sessions:
        raise HTTPException(status_code=404, detail="Session not found")

    return mobile_sessions[session_id]

@router.put("/mobile/session/{session_id}/languages")
async def update_session_languages(
    session_id: str,
    participant_id: str = Form(...),
    source_language: str = Form(...),
    target_language: str = Form(...)
):
    """Update the default languages for a mobile session"""
    try:
        if session_id not in mobile_sessions:
            raise HTTPException(status_code=404, detail="Session not found")

        session = mobile_sessions[session_id]
        if session["participant_id"] != participant_id:
            raise HTTPException(status_code=403, detail="Invalid participant")

        # Update session languages
        session["source_language"] = source_language
        session["target_language"] = target_language

        return {
            "success": True,
            "session_id": session_id,
            "source_language": source_language,
            "target_language": target_language
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to update languages: {str(e)}")

@router.post("/mobile/session/{session_id}/transcribe-realtime")
async def transcribe_realtime(
    session_id: str,
    audio: UploadFile = File(...),
    participant_id: str = Form(...),
    source_language: str = Form(...),
    target_language: str = Form(...),
    is_final_chunk: bool = Form(False),
    chunk_sequence: int = Form(0)
):
    """Real-time transcription endpoint for mobile with streaming support"""
    try:
        if session_id not in mobile_sessions:
            raise HTTPException(status_code=404, detail="Session not found")

        session = mobile_sessions[session_id]
        if session["participant_id"] != participant_id:
            raise HTTPException(status_code=403, detail="Invalid participant")

        # Read audio file
        audio_data = await audio.read()

        # Generate unique message ID for this chunk sequence
        message_id = f"msg-{participant_id}-{chunk_sequence}"

        # Initialize response data
        response_data = {
            "success": True,
            "message_id": message_id,
            "chunk_sequence": chunk_sequence,
            "original_text": "",
            "original_language": source_language,
            "is_final_chunk": is_final_chunk,
            "is_interim": not is_final_chunk,
            "session_id": session_id,
            "translated_text": None,
            "target_language": target_language,
            "has_audio": False,
            "audio_base64": None,
            "audio_format": None
        }

        # Process transcription
        if transcription_service:
            try:
                # Use streaming transcription if available
                if hasattr(transcription_service, 'process_realtime_chunk'):
                    transcription_result = await transcription_service.process_realtime_chunk(
                        audio_data, source_language, participant_id, is_final_chunk
                    )
                else:
                    # Fallback to regular transcription
                    transcription_result = await transcription_service.transcribe_audio(
                        audio_data, source_language
                    )

                response_data["original_text"] = transcription_result or ""

                # Only process translation and TTS for final chunks with actual text
                if is_final_chunk and transcription_result and transcription_result.strip():
                    if translation_service:
                        try:
                            translated_text = await translation_service.translate_text(
                                transcription_result, source_language, target_language
                            )
                            response_data["translated_text"] = translated_text

                            # Generate TTS audio in target language
                            if tts_service and translated_text:
                                try:
                                    tts_audio = await tts_service.generate_speech(
                                        translated_text, target_language, output_format="wav"
                                    )

                                    if tts_audio:
                                        response_data.update({
                                            "has_audio": True,
                                            "audio_base64": base64.b64encode(tts_audio).decode('utf-8'),
                                            "audio_format": "wav"
                                        })
                                except Exception as tts_error:
                                    print(f"TTS generation failed: {tts_error}")
                                    # Continue without TTS

                        except Exception as translation_error:
                            print(f"Translation failed: {translation_error}")
                            # Continue without translation

            except Exception as transcription_error:
                print(f"Transcription failed: {transcription_error}")
                response_data["original_text"] = ""

            return response_data

        else:
            raise HTTPException(status_code=500, detail="Transcription service not available")

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Real-time transcription failed: {str(e)}")

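# A minimal client sketch for the endpoint above (a sketch, not part of this
# module; the host is an assumption, and the session/participant IDs would come
# from /mobile/session/create):
def _transcribe_realtime_sketch():
    import requests  # assumed available on the client
    base = "http://localhost:7860"  # assumed host; the mounted prefix may differ
    with open("chunk.wav", "rb") as f:
        r = requests.post(f"{base}/mobile/session/mobile-abc123/transcribe-realtime",
                          files={"audio": f},
                          data={"participant_id": "user-def456",
                                "source_language": "eng",
                                "target_language": "swa",
                                "is_final_chunk": "true",
                                "chunk_sequence": "7"})
    body = r.json()
    print(body["original_text"], "->", body["translated_text"])
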
@router.post("/mobile/session/{session_id}/stream-audio")
|
| 249 |
+
async def stream_audio_chunk(
|
| 250 |
+
session_id: str,
|
| 251 |
+
participant_id: str = Form(...),
|
| 252 |
+
audio_chunk: UploadFile = File(...),
|
| 253 |
+
source_language: str = Form(...),
|
| 254 |
+
target_language: str = Form(...),
|
| 255 |
+
chunk_index: int = Form(0),
|
| 256 |
+
is_speaking: bool = Form(True),
|
| 257 |
+
force_complete: bool = Form(False)
|
| 258 |
+
):
|
| 259 |
+
"""Stream audio chunks for continuous processing"""
|
| 260 |
+
try:
|
| 261 |
+
if session_id not in mobile_sessions:
|
| 262 |
+
raise HTTPException(status_code=404, detail="Session not found")
|
| 263 |
+
|
| 264 |
+
session = mobile_sessions[session_id]
|
| 265 |
+
if session["participant_id"] != participant_id:
|
| 266 |
+
raise HTTPException(status_code=403, detail="Invalid participant")
|
| 267 |
+
|
| 268 |
+
audio_data = await audio_chunk.read()
|
| 269 |
+
|
| 270 |
+
# Use streaming approach similar to WebSocket
|
| 271 |
+
interim_text = ""
|
| 272 |
+
if transcription_service:
|
| 273 |
+
try:
|
| 274 |
+
if hasattr(transcription_service, 'process_audio_chunk'):
|
| 275 |
+
result = await transcription_service.process_audio_chunk(
|
| 276 |
+
audio_data,
|
| 277 |
+
source_language,
|
| 278 |
+
participant_id,
|
| 279 |
+
has_voice_activity=is_speaking,
|
| 280 |
+
progress_callback=None, # No callback for HTTP
|
| 281 |
+
sentence_callback=None # No callback for HTTP
|
| 282 |
+
)
|
| 283 |
+
interim_text = result or ""
|
| 284 |
+
else:
|
| 285 |
+
# Fallback to regular transcription for interim results
|
| 286 |
+
interim_text = await transcription_service.transcribe_audio(
|
| 287 |
+
audio_data, source_language
|
| 288 |
+
) or ""
|
| 289 |
+
except Exception as e:
|
| 290 |
+
print(f"Streaming transcription error: {e}")
|
| 291 |
+
interim_text = ""
|
| 292 |
+
|
| 293 |
+
return {
|
| 294 |
+
"success": True,
|
| 295 |
+
"chunk_index": chunk_index,
|
| 296 |
+
"session_id": session_id,
|
| 297 |
+
"interim_text": interim_text,
|
| 298 |
+
"is_speaking": is_speaking,
|
| 299 |
+
"force_complete": force_complete
|
| 300 |
+
}
|
| 301 |
+
else:
|
| 302 |
+
raise HTTPException(status_code=500, detail="Transcription service not available")
|
| 303 |
+
|
| 304 |
+
except Exception as e:
|
| 305 |
+
raise HTTPException(status_code=500, detail=f"Audio streaming failed: {str(e)}")
|
| 306 |
+
|
| 307 |
+
@router.get("/mobile/session/{session_id}/realtime-status")
|
| 308 |
+
async def get_realtime_status(session_id: str, participant_id: str = Query(...)):
|
| 309 |
+
"""Get current real-time processing status"""
|
| 310 |
+
try:
|
| 311 |
+
if session_id not in mobile_sessions:
|
| 312 |
+
raise HTTPException(status_code=404, detail="Session not found")
|
| 313 |
+
|
| 314 |
+
session = mobile_sessions[session_id]
|
| 315 |
+
if session["participant_id"] != participant_id:
|
| 316 |
+
raise HTTPException(status_code=403, detail="Invalid participant")
|
| 317 |
+
|
| 318 |
+
# Check if transcription service has any pending messages
|
| 319 |
+
pending_messages = []
|
| 320 |
+
if transcription_service:
|
| 321 |
+
try:
|
| 322 |
+
if hasattr(transcription_service, 'get_participant_status'):
|
| 323 |
+
pending_messages = transcription_service.get_participant_status(participant_id)
|
| 324 |
+
else:
|
| 325 |
+
pending_messages = []
|
| 326 |
+
except Exception as e:
|
| 327 |
+
print(f"Error getting participant status: {e}")
|
| 328 |
+
pending_messages = []
|
| 329 |
+
|
| 330 |
+
return {
|
| 331 |
+
"session_id": session_id,
|
| 332 |
+
"participant_id": participant_id,
|
| 333 |
+
"is_active": True,
|
| 334 |
+
"pending_messages": pending_messages,
|
| 335 |
+
"current_languages": {
|
| 336 |
+
"source": session["source_language"],
|
| 337 |
+
"target": session["target_language"]
|
| 338 |
+
},
|
| 339 |
+
"service_status": {
|
| 340 |
+
"transcription": transcription_service is not None,
|
| 341 |
+
"translation": translation_service is not None,
|
| 342 |
+
"tts": tts_service is not None
|
| 343 |
+
}
|
| 344 |
+
}
|
| 345 |
+
|
| 346 |
+
except Exception as e:
|
| 347 |
+
raise HTTPException(status_code=500, detail=f"Status check failed: {str(e)}")
|
| 348 |
+
|
| 349 |
+
@router.post("/mobile/session/{session_id}/transcribe-with-languages")
|
| 350 |
+
async def transcribe_with_languages_legacy(
|
| 351 |
+
session_id: str,
|
| 352 |
+
audio: UploadFile = File(...),
|
| 353 |
+
participant_id: str = Form(...),
|
| 354 |
+
source_language: str = Form(...),
|
| 355 |
+
target_language: str = Form(...),
|
| 356 |
+
is_final_chunk: bool = Form(False)
|
| 357 |
+
):
|
| 358 |
+
"""Legacy endpoint - transcribe audio with specific source/target languages for mobile"""
|
| 359 |
+
try:
|
| 360 |
+
if session_id not in mobile_sessions:
|
| 361 |
+
raise HTTPException(status_code=404, detail="Session not found")
|
| 362 |
+
|
| 363 |
+
session = mobile_sessions[session_id]
|
| 364 |
+
if session["participant_id"] != participant_id:
|
| 365 |
+
raise HTTPException(status_code=403, detail="Invalid participant")
|
| 366 |
+
|
| 367 |
+
# Read audio file
|
| 368 |
+
audio_data = await audio.read()
|
| 369 |
+
|
| 370 |
+
# Generate unique message ID
|
| 371 |
+
message_id = f"msg-{uuid.uuid4().hex[:8]}"
|
| 372 |
+
|
| 373 |
+
# Initialize response
|
| 374 |
+
response_data = {
|
| 375 |
+
"success": True,
|
| 376 |
+
"message_id": message_id,
|
| 377 |
+
"original_text": "",
|
| 378 |
+
"original_language": source_language,
|
| 379 |
+
"translated_text": None,
|
| 380 |
+
"target_language": target_language,
|
| 381 |
+
"has_audio": False,
|
| 382 |
+
"is_final_chunk": is_final_chunk,
|
| 383 |
+
"audio_base64": None
|
| 384 |
+
}
|
| 385 |
+
|
| 386 |
+
# Process transcription in source language
|
| 387 |
+
if transcription_service:
|
| 388 |
+
try:
|
| 389 |
+
transcription_result = await transcription_service.transcribe_audio(
|
| 390 |
+
audio_data, source_language
|
| 391 |
+
)
|
| 392 |
+
response_data["original_text"] = transcription_result or ""
|
| 393 |
+
|
| 394 |
+
# Process translation to target language
|
| 395 |
+
if translation_service and transcription_result and transcription_result.strip():
|
| 396 |
+
try:
|
| 397 |
+
translated_text = await translation_service.translate_text(
|
| 398 |
+
transcription_result, source_language, target_language
|
| 399 |
+
)
|
| 400 |
+
response_data["translated_text"] = translated_text
|
| 401 |
+
|
| 402 |
+
# Generate TTS audio in target language
|
| 403 |
+
if tts_service and translated_text:
|
| 404 |
+
try:
|
| 405 |
+
tts_audio = await tts_service.generate_speech(
|
| 406 |
+
translated_text, target_language, output_format="wav"
|
| 407 |
+
)
|
| 408 |
+
|
| 409 |
+
if tts_audio:
|
| 410 |
+
response_data["has_audio"] = True
|
| 411 |
+
response_data["audio_base64"] = base64.b64encode(tts_audio).decode('utf-8')
|
| 412 |
+
except Exception as tts_error:
|
| 413 |
+
print(f"TTS generation failed: {tts_error}")
|
| 414 |
+
|
| 415 |
+
except Exception as translation_error:
|
| 416 |
+
print(f"Translation failed: {translation_error}")
|
| 417 |
+
|
| 418 |
+
except Exception as transcription_error:
|
| 419 |
+
print(f"Transcription failed: {transcription_error}")
|
| 420 |
+
|
| 421 |
+
return response_data
|
| 422 |
+
else:
|
| 423 |
+
raise HTTPException(status_code=500, detail="Transcription service not available")
|
| 424 |
+
|
| 425 |
+
except Exception as e:
|
| 426 |
+
raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
|
| 427 |
+
|
| 428 |
+
@router.post("/mobile/translate")
|
| 429 |
+
async def translate_text_mobile(
|
| 430 |
+
text: str = Form(...),
|
| 431 |
+
source_lang: str = Form(...),
|
| 432 |
+
target_lang: str = Form(...)
|
| 433 |
+
):
|
| 434 |
+
"""Mobile-friendly text translation endpoint"""
|
| 435 |
+
try:
|
| 436 |
+
if not translation_service:
|
| 437 |
+
raise HTTPException(status_code=500, detail="Translation service not initialized")
|
| 438 |
+
|
| 439 |
+
# Map common language codes to internal format
|
| 440 |
+
lang_mapping = {
|
| 441 |
+
"english": "eng", "en": "eng",
|
| 442 |
+
"swahili": "swa", "sw": "swa",
|
| 443 |
+
"kikuyu": "kik", "ki": "kik",
|
| 444 |
+
"kamba": "kam", "kam": "kam",
|
| 445 |
+
"kimeru": "mer", "mer": "mer",
|
| 446 |
+
"luo": "luo", "luo": "luo",
|
| 447 |
+
"somali": "som", "so": "som"
|
| 448 |
+
}
|
| 449 |
+
|
| 450 |
+
source_code = lang_mapping.get(source_lang.lower(), source_lang.lower())
|
| 451 |
+
target_code = lang_mapping.get(target_lang.lower(), target_lang.lower())
|
| 452 |
+
|
| 453 |
+
translated_text = await translation_service.translate_text(text, source_code, target_code)
|
| 454 |
+
|
| 455 |
+
return {
|
| 456 |
+
"success": True,
|
| 457 |
+
"original_text": text,
|
| 458 |
+
"translated_text": translated_text or text,
|
| 459 |
+
"source_language": source_code,
|
| 460 |
+
"target_language": target_code
|
| 461 |
+
}
|
| 462 |
+
|
| 463 |
+
except Exception as e:
|
| 464 |
+
raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}")
|
| 465 |
+
|
| 466 |
+
@router.get("/mobile/languages")
|
| 467 |
+
async def get_supported_languages():
|
| 468 |
+
"""Get list of supported languages for mobile app"""
|
| 469 |
+
return {
|
| 470 |
+
"supported_languages": [
|
| 471 |
+
{"code": "eng", "name": "English", "display_name": "English (eng)"},
|
| 472 |
+
{"code": "swa", "name": "Swahili", "display_name": "Swahili (swa)"},
|
| 473 |
+
{"code": "kik", "name": "Kikuyu", "display_name": "Kikuyu (kik)"},
|
| 474 |
+
{"code": "kam", "name": "Kamba", "display_name": "Kamba (kam)"},
|
| 475 |
+
{"code": "mer", "name": "Kimeru", "display_name": "Kimeru (mer)"},
|
| 476 |
+
{"code": "luo", "name": "Luo", "display_name": "Luo (luo)"},
|
| 477 |
+
{"code": "som", "name": "Somali", "display_name": "Somali (som)"}
|
| 478 |
+
]
|
| 479 |
+
}
|
| 480 |
+
|
| 481 |
+
@router.get("/mobile/test")
|
| 482 |
+
async def test_mobile_endpoints():
|
| 483 |
+
"""Test endpoint for mobile app connectivity"""
|
| 484 |
+
return {
|
| 485 |
+
"status": "Mobile API is working",
|
| 486 |
+
"endpoints": [
|
| 487 |
+
"/mobile/session/create",
|
| 488 |
+
"/mobile/session/{session_id}",
|
| 489 |
+
"/mobile/session/{session_id}/languages",
|
| 490 |
+
"/mobile/session/{session_id}/transcribe-realtime",
|
| 491 |
+
"/mobile/session/{session_id}/stream-audio",
|
| 492 |
+
"/mobile/session/{session_id}/realtime-status",
|
| 493 |
+
"/mobile/session/{session_id}/transcribe-with-languages",
|
| 494 |
+
"/mobile/translate",
|
| 495 |
+
"/mobile/languages",
|
| 496 |
+
"/mobile/test"
|
| 497 |
+
],
|
| 498 |
+
"timestamp": datetime.datetime.now().isoformat(),
|
| 499 |
+
"services_available": {
|
| 500 |
+
"transcription": transcription_service is not None,
|
| 501 |
+
"translation": translation_service is not None,
|
| 502 |
+
"tts": tts_service is not None
|
| 503 |
+
},
|
| 504 |
+
"active_sessions": len(mobile_sessions),
|
| 505 |
+
"session_list": list(mobile_sessions.keys())
|
| 506 |
+
}
|
| 507 |
+
|
| 508 |
+
@router.post("/mobile/test-session")
|
| 509 |
+
async def test_session_creation(
|
| 510 |
+
test_user: str = Form("TestUser"),
|
| 511 |
+
test_source: str = Form("eng"),
|
| 512 |
+
test_target: str = Form("swa")
|
| 513 |
+
):
|
| 514 |
+
"""Test session creation with debug info"""
|
| 515 |
+
try:
|
| 516 |
+
print(f"=== TEST SESSION CREATE ===")
|
| 517 |
+
print(f"Received: user={test_user}, source={test_source}, target={test_target}")
|
| 518 |
+
|
| 519 |
+
session_id = f"test-{uuid.uuid4().hex[:8]}"
|
| 520 |
+
|
| 521 |
+
return {
|
| 522 |
+
"success": True,
|
| 523 |
+
"test_session_id": session_id,
|
| 524 |
+
"received_params": {
|
| 525 |
+
"user": test_user,
|
| 526 |
+
"source": test_source,
|
| 527 |
+
"target": test_target
|
| 528 |
+
},
|
| 529 |
+
"form_processing": "OK"
|
| 530 |
+
}
|
| 531 |
+
except Exception as e:
|
| 532 |
+
print(f"Test session error: {e}")
|
| 533 |
+
return {
|
| 534 |
+
"success": False,
|
| 535 |
+
"error": str(e)
|
| 536 |
+
}
|
app/routers/sessions.py
ADDED
@@ -0,0 +1,200 @@
from fastapi import APIRouter, HTTPException, Response, Depends
from typing import List
from pydantic import BaseModel
import qrcode
import io
from app.models import Session, SessionCreate
from app.auth import require_hf_token, optional_hf_token

router = APIRouter()

# This will be set by the main app
session_manager = None

# Initialize services (these will be injected by main app)
transcription_service = None
translation_service = None
tts_service = None

class TextTranslationRequest(BaseModel):
    text: str
    source_language: str
    target_language: str

class TextTranslationResponse(BaseModel):
    original_text: str
    translated_text: str
    source_language: str
    target_language: str

@router.post("/sessions", response_model=Session)
async def create_session(session_data: SessionCreate, token: str = Depends(require_hf_token)):
    """Create a new transcription session"""
    try:
        session = await session_manager.create_session(session_data)
        return session
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@router.get("/sessions", response_model=List[Session])
async def get_all_sessions(token: str = Depends(require_hf_token)):
    """Get all active sessions"""
    try:
        sessions = await session_manager.get_all_sessions()
        return sessions
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@router.get("/sessions/{session_id}", response_model=Session)
async def get_session(session_id: str, token: str = Depends(require_hf_token)):
    """Get specific session by ID or short code"""
    session = await session_manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")
    return session

@router.get("/sessions/{session_id}/short-code")
async def get_session_short_code(session_id: str, token: str = Depends(require_hf_token)):
    """Get short code for a session"""
    session = await session_manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    short_code = session_manager.get_short_code(session.id)
    return {"session_id": session.id, "short_code": short_code}

@router.delete("/sessions/{session_id}")
async def delete_session(session_id: str, token: str = Depends(require_hf_token)):
    """Delete a session"""
    success = await session_manager.delete_session(session_id)
    if not success:
        raise HTTPException(status_code=404, detail="Session not found")
    return {"message": "Session deleted successfully"}

@router.post("/sessions/{session_id}/languages/{language_code}")
async def add_language_to_session(session_id: str, language_code: str, token: str = Depends(require_hf_token)):
    """Add a language to a session"""
    from app.models import LanguageCode

    # Convert string to LanguageCode enum
    try:
        lang_code_enum = LanguageCode(language_code)
    except ValueError:
        raise HTTPException(status_code=400, detail=f"Invalid language code: {language_code}")

    success = await session_manager.add_language_to_session(session_id, lang_code_enum)
    if success:
        session = await session_manager.get_session(session_id)
        return {"message": f"Language {language_code} added to session", "session": session}
    else:
        # Check if session exists
        session = await session_manager.get_session(session_id)
        if not session:
            raise HTTPException(status_code=404, detail="Session not found")
        return {"message": f"Language {language_code} already exists in session", "session": session}

@router.post("/translate", response_model=TextTranslationResponse)
|
| 97 |
+
async def translate_text(request: TextTranslationRequest, token: str = Depends(require_hf_token)):
|
| 98 |
+
"""Translate text from source language to target language"""
|
| 99 |
+
try:
|
| 100 |
+
# Map language codes to proper names
|
| 101 |
+
lang_map = {
|
| 102 |
+
'eng': 'English',
|
| 103 |
+
'swa': 'Swahili',
|
| 104 |
+
'kik': 'Kikuyu',
|
| 105 |
+
'kam': 'Kamba',
|
| 106 |
+
'mer': 'Kimeru',
|
| 107 |
+
'luo': 'Luo',
|
| 108 |
+
'som': 'Somali'
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
source_lang_name = lang_map.get(request.source_language.lower(), request.source_language)
|
| 112 |
+
target_lang_name = lang_map.get(request.target_language.lower(), request.target_language)
|
| 113 |
+
|
| 114 |
+
# Perform translation
|
| 115 |
+
translated_text = await translation_service.translate_text(
|
| 116 |
+
text=request.text,
|
| 117 |
+
source_lang=source_lang_name,
|
| 118 |
+
target_lang=target_lang_name
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
return TextTranslationResponse(
|
| 122 |
+
original_text=request.text,
|
| 123 |
+
translated_text=translated_text,
|
| 124 |
+
source_language=request.source_language,
|
| 125 |
+
target_language=request.target_language
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
except Exception as e:
|
| 129 |
+
raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}")
|
| 130 |
+
|
| 131 |
+
@router.get("/test")
|
| 132 |
+
async def test_endpoint(token: str = Depends(optional_hf_token)):
|
| 133 |
+
"""Test endpoint to verify API is working"""
|
| 134 |
+
auth_status = "authenticated" if token else "public"
|
| 135 |
+
return {
|
| 136 |
+
"status": "API is working",
|
| 137 |
+
"sessions_count": len(session_manager.sessions),
|
| 138 |
+
"auth_status": auth_status
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
@router.get("/test/translation")
|
| 142 |
+
async def test_translation(token: str = Depends(require_hf_token)):
|
| 143 |
+
"""Test translation service directly"""
|
| 144 |
+
try:
|
| 145 |
+
# Test English to Swahili translation
|
| 146 |
+
result = await translation_service.translate_text("Hello, how are you?", "English", "Swahili")
|
| 147 |
+
|
| 148 |
+
return {
|
| 149 |
+
"status": "Translation test completed",
|
| 150 |
+
"original": "Hello, how are you?",
|
| 151 |
+
"translated": result,
|
| 152 |
+
"source_lang": "English",
|
| 153 |
+
"target_lang": "Swahili"
|
| 154 |
+
}
|
| 155 |
+
except Exception as e:
|
| 156 |
+
return {"status": "Translation test failed", "error": str(e)}
|
| 157 |
+
|
| 158 |
+
@router.get("/test/tts")
|
| 159 |
+
async def test_tts(token: str = Depends(require_hf_token)):
|
| 160 |
+
"""Test TTS service directly"""
|
| 161 |
+
try:
|
| 162 |
+
# Test TTS generation
|
| 163 |
+
audio_data = await tts_service.generate_speech("Hello world", "eng")
|
| 164 |
+
|
| 165 |
+
return {
|
| 166 |
+
"status": "TTS test completed",
|
| 167 |
+
"text": "Hello world",
|
| 168 |
+
"language": "eng",
|
| 169 |
+
"audio_generated": audio_data is not None,
|
| 170 |
+
"audio_size": len(audio_data) if audio_data else 0
|
| 171 |
+
}
|
| 172 |
+
except Exception as e:
|
| 173 |
+
return {"status": "TTS test failed", "error": str(e)}
|
| 174 |
+
|
| 175 |
+
@router.get("/sessions/{session_id}/qr-code")
|
| 176 |
+
async def get_session_qr_code(session_id: str, token: str = Depends(require_hf_token)):
|
| 177 |
+
"""Generate QR code for session"""
|
| 178 |
+
if session_manager is None:
|
| 179 |
+
raise HTTPException(status_code=500, detail="Session manager not initialized")
|
| 180 |
+
|
| 181 |
+
session = await session_manager.get_session(session_id)
|
| 182 |
+
|
| 183 |
+
if not session:
|
| 184 |
+
raise HTTPException(status_code=404, detail="Session not found")
|
| 185 |
+
|
| 186 |
+
# Generate QR code with session join URL - use your HF space URL
|
| 187 |
+
join_url = f"https://mutisya-realtime-translator-5-27-25-v2.hf.space/?join={session_id}"
|
| 188 |
+
|
| 189 |
+
qr = qrcode.QRCode(version=1, box_size=10, border=5)
|
| 190 |
+
qr.add_data(join_url)
|
| 191 |
+
qr.make(fit=True)
|
| 192 |
+
|
| 193 |
+
img = qr.make_image(fill_color="black", back_color="white")
|
| 194 |
+
|
| 195 |
+
# Convert to bytes
|
| 196 |
+
img_buffer = io.BytesIO()
|
| 197 |
+
img.save(img_buffer, format='PNG')
|
| 198 |
+
img_buffer.seek(0)
|
| 199 |
+
|
| 200 |
+
return Response(content=img_buffer.getvalue(), media_type="image/png")
|
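A minimal client sketch for the /translate endpoint above. The host (taken from the QR-code join URL in get_session_qr_code), any router prefix applied when app/main.py mounts this router, and the Bearer-token header scheme are assumptions; only the JSON fields come from TextTranslationRequest and TextTranslationResponse.

# Hypothetical client for POST /translate. Host, router prefix, and the
# Authorization scheme are assumptions; require_hf_token in app/auth.py
# defines the actual auth contract.
import requests

resp = requests.post(
    "https://mutisya-realtime-translator-5-27-25-v2.hf.space/translate",
    json={
        "text": "Hello, how are you?",
        "source_language": "eng",
        "target_language": "swa",
    },
    headers={"Authorization": "Bearer hf_xxx"},  # placeholder token (assumed scheme)
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["translated_text"])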
app/routers/watch.py
ADDED
@@ -0,0 +1,155 @@
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Depends
from fastapi.responses import Response
from typing import Optional
import io
import base64
from app.services.transcription_service import TranscriptionService
from app.services.translation_service import TranslationService
from app.services.tts_service import TTSService
from app.models import LanguageCode
from pydantic import BaseModel
from app.auth import require_hf_token

router = APIRouter()

class WatchTranslationRequest(BaseModel):
    source_language: str
    target_language: str
    audio_base64: str

class WatchTranslationResponse(BaseModel):
    original_text: str
    original_language: str
    translated_text: str
    target_language: str
    translated_audio_base64: str
    success: bool
    error: Optional[str] = None

# Initialize services (these will be injected by main app)
transcription_service = None
translation_service = None
tts_service = None

@router.post("/watch/translate", response_model=WatchTranslationResponse)
async def watch_translate_audio(request: WatchTranslationRequest, token: str = Depends(require_hf_token)):
    """
    Process audio for watch app translation
    - Transcribe audio using source language model
    - Translate text to target language
    - Generate TTS audio for target language
    - Return all data to watch app
    """
    try:
        # Validate languages
        source_lang = request.source_language.lower()
        target_lang = request.target_language.lower()

        if source_lang not in ['eng', 'swa', 'kik', 'kam', 'mer', 'luo', 'som']:
            raise HTTPException(status_code=400, detail=f"Unsupported source language: {source_lang}")

        if target_lang not in ['eng', 'swa', 'kik', 'kam', 'mer', 'luo', 'som']:
            raise HTTPException(status_code=400, detail=f"Unsupported target language: {target_lang}")

        # Decode base64 audio
        try:
            audio_data = base64.b64decode(request.audio_base64)
            print(f"Decoded audio data: {len(audio_data)} bytes")
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid base64 audio data: {str(e)}")

        # Step 1: Transcribe audio
        print(f"Transcribing audio with {source_lang} model...")
        transcribed_text = await transcription_service.transcribe_audio(audio_data, source_lang)

        if not transcribed_text or transcribed_text.strip() == "":
            return WatchTranslationResponse(
                original_text="",
                original_language=source_lang,
                translated_text="No speech detected",
                target_language=target_lang,
                translated_audio_base64="",
                success=False,
                error="No speech detected in audio"
            )

        print(f"Transcribed text: {transcribed_text}")

        # Step 2: Translate text (skip if source and target are the same)
        if source_lang == target_lang:
            translated_text = transcribed_text
        else:
            print(f"Translating from {source_lang} to {target_lang}...")

            # Convert language codes to full names for translation service
            lang_name_map = {
                'eng': 'English',
                'swa': 'Swahili',
                'kik': 'Kikuyu',
                'kam': 'Kamba',
                'mer': 'Kimeru',
                'luo': 'Luo',
                'som': 'Somali'
            }

            source_lang_name = lang_name_map.get(source_lang, 'English')
            target_lang_name = lang_name_map.get(target_lang, 'Swahili')

            translated_text = await translation_service.translate_text(
                transcribed_text,
                source_lang_name,
                target_lang_name
            )

        print(f"Translated text: {translated_text}")

        # Step 3: Generate TTS audio for translated text (Android-compatible WAV format)
        print(f"Generating TTS audio for {target_lang} in WAV format for Android...")
        tts_audio_data = await tts_service.generate_speech(translated_text, target_lang, output_format="wav")

        # Encode TTS audio as base64
        tts_audio_base64 = ""
        if tts_audio_data:
            tts_audio_base64 = base64.b64encode(tts_audio_data).decode('utf-8')
            print(f"TTS audio generated: {len(tts_audio_data)} bytes, base64: {len(tts_audio_base64)} chars")
        else:
            print("TTS audio generation failed - no data returned")

        return WatchTranslationResponse(
            original_text=transcribed_text,
            original_language=source_lang,
            translated_text=translated_text,
            target_language=target_lang,
            translated_audio_base64=tts_audio_base64,
            success=True
        )

    except HTTPException:
        # Re-raise validation errors so clients get the intended 4xx status instead of a 200 with success=False
        raise
    except Exception as e:
        print(f"Error in watch translation: {str(e)}")
        import traceback
        traceback.print_exc()

        return WatchTranslationResponse(
            original_text="",
            original_language=request.source_language,
            translated_text="",
            target_language=request.target_language,
            translated_audio_base64="",
            success=False,
            error=str(e)
        )

@router.get("/watch/test")
async def test_watch_endpoint(token: str = Depends(require_hf_token)):
    """Test endpoint for watch app connectivity"""
    return {
        "status": "Watch API is working",
        "services": {
            "transcription": transcription_service is not None,
            "translation": translation_service is not None,
            "tts": tts_service is not None
        }
    }
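A client sketch for POST /watch/translate above: it base64-encodes a recorded WAV file, posts it, and writes the returned audio back to disk. The host and the Authorization header are assumptions; the JSON field names mirror WatchTranslationRequest and WatchTranslationResponse.

# Hypothetical watch-app client for POST /watch/translate. The host and auth
# header are assumptions; the field names come from the models above.
import base64
import requests

with open("utterance.wav", "rb") as f:   # a short mono WAV recording
    audio_b64 = base64.b64encode(f.read()).decode("utf-8")

resp = requests.post(
    "https://<your-space>.hf.space/watch/translate",   # placeholder host
    json={
        "source_language": "eng",
        "target_language": "swa",
        "audio_base64": audio_b64,
    },
    headers={"Authorization": "Bearer hf_xxx"},        # assumed token scheme
    timeout=120,
)
body = resp.json()
if body["success"]:
    # The service returns Android-compatible WAV bytes, base64-encoded
    with open("translated.wav", "wb") as f:
        f.write(base64.b64decode(body["translated_audio_base64"]))
else:
    print("Translation failed:", body["error"])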
app/services/__init__.py
ADDED
@@ -0,0 +1 @@
# Services package
app/services/learning_data_service.py
ADDED
@@ -0,0 +1,415 @@
"""
Learning Data Service - File-based data access for language learning prototype

This service provides access to lesson data, user progress, and achievements
using JSON files stored in the backend/data/learning directory.
"""

import json
import os
from pathlib import Path
from typing import Dict, List, Optional, Any
from datetime import datetime
import logging

logger = logging.getLogger(__name__)


class LearningDataService:
    """Service for managing language learning data using JSON files"""

    def __init__(self):
        # Get the data directory relative to this file
        self.data_dir = Path(__file__).parent.parent.parent / "data" / "learning"
        self.lessons_dir = self.data_dir / "lessons"
        self.users_dir = self.data_dir / "users"

        # Ensure directories exist
        self.users_dir.mkdir(parents=True, exist_ok=True)

        logger.info(f"Learning data directory: {self.data_dir}")
        logger.info(f"Lessons directory: {self.lessons_dir}")
        logger.info(f"Users directory: {self.users_dir}")

    # ==================== Lesson Data ====================

    def get_lessons_index(self, language: str = 'swahili') -> Optional[Dict]:
        """Load the lessons index/catalog for a specific language"""
        try:
            # Map language codes to folder names
            language_map = {
                'swahili': 'swahili',
                'swa': 'swahili',
                'kamba': 'kamba',
                'kam': 'kamba',
                'maasai': 'maasai',
                'mas': 'maasai'
            }

            language_folder = language_map.get(language.lower(), 'swahili')
            index_path = self.lessons_dir / language_folder / "index.json"

            logger.info(f"Loading lessons index for language '{language}' -> folder '{language_folder}' at {index_path}")

            if not index_path.exists():
                logger.warning(f"Lessons index not found at {index_path}")
                logger.info(f"Lessons dir contents: {list(self.lessons_dir.iterdir())}")
                return None

            with open(index_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
                logger.info(f"Successfully loaded {len(data.get('lessons', []))} lessons for {language}")
                return data
        except Exception as e:
            logger.error(f"Error loading lessons index for {language}: {e}")
            return None

    def get_lesson(self, lesson_id: int, language: str = 'swahili') -> Optional[Dict]:
        """Load a specific lesson by ID for a specific language"""
        try:
            # Map language codes to folder names
            language_map = {
                'swahili': 'swahili',
                'swa': 'swahili',
                'kamba': 'kamba',
                'kam': 'kamba',
                'maasai': 'maasai',
                'mas': 'maasai'
            }

            language_folder = language_map.get(language.lower(), 'swahili')

            # First get the index to find the lesson file
            index = self.get_lessons_index(language)
            if not index:
                return None

            # Find the lesson in the index
            lesson_meta = None
            for lesson in index.get('lessons', []):
                if lesson['lesson_id'] == lesson_id:
                    lesson_meta = lesson
                    break

            if not lesson_meta:
                logger.warning(f"Lesson {lesson_id} not found in index for {language}")
                return None

            # Load the lesson file
            lesson_path = self.lessons_dir / language_folder / lesson_meta['file']
            if not lesson_path.exists():
                logger.warning(f"Lesson file not found: {lesson_path}")
                return None

            with open(lesson_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except Exception as e:
            logger.error(f"Error loading lesson {lesson_id} for {language}: {e}")
            return None

    def get_available_lessons(self) -> List[Dict]:
        """Get list of available lessons (not planned)"""
        try:
            index = self.get_lessons_index()
            if not index:
                return []

            available = [
                lesson for lesson in index.get('lessons', [])
                if lesson.get('status') == 'available'
            ]
            return available
        except Exception as e:
            logger.error(f"Error getting available lessons: {e}")
            return []

    # ==================== Achievements ====================

    def get_achievements(self) -> Optional[Dict]:
        """Load achievements configuration"""
        try:
            achievements_path = self.data_dir / "achievements.json"
            if not achievements_path.exists():
                logger.warning(f"Achievements file not found at {achievements_path}")
                return None

            with open(achievements_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except Exception as e:
            logger.error(f"Error loading achievements: {e}")
            return None

    # ==================== User Progress ====================

    def get_user_progress(self, user_id: str) -> Optional[Dict]:
        """Load user progress data"""
        try:
            user_file = self.users_dir / f"user-{user_id}.json"
            if not user_file.exists():
                # Return default progress structure for new users
                return self._create_default_user_progress(user_id)

            with open(user_file, 'r', encoding='utf-8') as f:
                return json.load(f)
        except Exception as e:
            logger.error(f"Error loading user progress for {user_id}: {e}")
            return None

    def save_user_progress(self, user_id: str, progress_data: Dict) -> bool:
        """Save user progress data"""
        try:
            user_file = self.users_dir / f"user-{user_id}.json"

            # Update last_active timestamp
            if 'profile' in progress_data:
                progress_data['profile']['last_active'] = datetime.utcnow().isoformat() + 'Z'

            with open(user_file, 'w', encoding='utf-8') as f:
                json.dump(progress_data, f, indent=2, ensure_ascii=False)

            logger.info(f"Saved progress for user {user_id}")
            return True
        except Exception as e:
            logger.error(f"Error saving user progress for {user_id}: {e}")
            return False

    def update_lesson_progress(
        self,
        user_id: str,
        lesson_id: int,
        progress_update: Dict
    ) -> bool:
        """Update progress for a specific lesson"""
        try:
            user_progress = self.get_user_progress(user_id)
            if not user_progress:
                return False

            # Initialize lesson_progress if it doesn't exist
            if 'lesson_progress' not in user_progress:
                user_progress['lesson_progress'] = {}

            lesson_key = str(lesson_id)

            # Update or create lesson progress
            if lesson_key in user_progress['lesson_progress']:
                user_progress['lesson_progress'][lesson_key].update(progress_update)
            else:
                user_progress['lesson_progress'][lesson_key] = progress_update

            return self.save_user_progress(user_id, user_progress)
        except Exception as e:
            logger.error(f"Error updating lesson progress: {e}")
            return False

    def update_vocabulary_progress(
        self,
        user_id: str,
        vocab_id: int,
        vocab_update: Dict
    ) -> bool:
        """Update progress for a specific vocabulary word"""
        try:
            user_progress = self.get_user_progress(user_id)
            if not user_progress:
                return False

            # Initialize vocabulary_progress if it doesn't exist
            if 'vocabulary_progress' not in user_progress:
                user_progress['vocabulary_progress'] = {}

            vocab_key = str(vocab_id)

            # Update or create vocabulary progress
            if vocab_key in user_progress['vocabulary_progress']:
                user_progress['vocabulary_progress'][vocab_key].update(vocab_update)
            else:
                user_progress['vocabulary_progress'][vocab_key] = vocab_update

            return self.save_user_progress(user_id, user_progress)
        except Exception as e:
            logger.error(f"Error updating vocabulary progress: {e}")
            return False

    def unlock_achievement(
        self,
        user_id: str,
        achievement_id: str,
        progress: int,
        target: int
    ) -> bool:
        """Unlock or update progress on an achievement"""
        try:
            user_progress = self.get_user_progress(user_id)
            if not user_progress:
                return False

            # Initialize achievements if it doesn't exist
            if 'achievements' not in user_progress:
                user_progress['achievements'] = {}

            # Update achievement
            achievement_data = {
                'achievement_id': achievement_id,
                'unlocked': progress >= target,
                'progress': progress,
                'target': target
            }

            # Add unlock timestamp if newly unlocked
            if achievement_data['unlocked'] and achievement_id not in user_progress['achievements']:
                achievement_data['unlocked_at'] = datetime.utcnow().isoformat() + 'Z'
            elif achievement_data['unlocked'] and achievement_id in user_progress['achievements']:
                # Preserve original unlock time
                if 'unlocked_at' in user_progress['achievements'][achievement_id]:
                    achievement_data['unlocked_at'] = user_progress['achievements'][achievement_id]['unlocked_at']
                else:
                    achievement_data['unlocked_at'] = datetime.utcnow().isoformat() + 'Z'

            user_progress['achievements'][achievement_id] = achievement_data

            return self.save_user_progress(user_id, user_progress)
        except Exception as e:
            logger.error(f"Error unlocking achievement: {e}")
            return False

    # ==================== Helper Methods ====================

    def _create_default_user_progress(self, user_id: str) -> Dict:
        """Create default progress structure for a new user"""
        return {
            'user_id': user_id,
            'profile': {
                'user_id': user_id,
                'learning_language': 'swa',
                'native_language': 'eng',
                'created_at': datetime.utcnow().isoformat() + 'Z',
                'last_active': datetime.utcnow().isoformat() + 'Z'
            },
            'overall_stats': {
                'level': 'beginner',
                'total_xp': 0,
                'next_level_xp': 1000,
                'current_streak': 0,
                'longest_streak': 0,
                'lessons_completed': 0,
                'vocabulary_learned': 0,
                'vocabulary_mastered': 0,
                'total_practice_time_seconds': 0,
                'pronunciation_avg_score': 0.0,
                'listening_avg_score': 0.0,
                'reading_avg_score': 0.0
            },
            'daily_stats': {},
            'lesson_progress': {},
            'vocabulary_progress': {},
            'achievements': {},
            'session_history': []
        }

    def create_default_progress(self, user_id: str) -> Dict:
        """Public method to create default progress structure"""
        progress = self._create_default_user_progress(user_id)
        # Add Phase 1-3 specific fields
        progress['overall_stats']['vocabulary_reviewed'] = 0
        progress['comprehension_scores'] = {}
        progress['scenario_progress'] = {}
        return progress

    # ==================== Phase 1-3 Methods ====================

    def get_vocabulary(self, vocab_id: int) -> Optional[Dict]:
        """Get a single vocabulary word by ID from any lesson"""
        try:
            lessons_index = self.get_lessons_index()
            if not lessons_index:
                return None

            # Search through all lessons
            for lesson_meta in lessons_index.get('lessons', []):
                lesson = self.get_lesson(lesson_meta['lesson_id'])
                if lesson and 'vocabulary' in lesson:
                    for vocab in lesson['vocabulary']:
                        # Support both 'id' and 'vocabulary_id' fields
                        vocab_item_id = vocab.get('vocabulary_id') or vocab.get('id')
                        if vocab_item_id == vocab_id:
                            # Add lesson context
                            vocab['lesson_id'] = lesson['lesson_id']
                            vocab['lesson_title'] = lesson.get('title', '')
                            return vocab

            logger.warning(f"Vocabulary {vocab_id} not found in any lesson")
            return None
        except Exception as e:
            logger.error(f"Error getting vocabulary {vocab_id}: {e}")
            return None

    def get_all_vocabulary(self) -> List[Dict]:
        """Get all vocabulary words from all lessons"""
        try:
            all_vocab = []
            lessons_index = self.get_lessons_index()
            if not lessons_index:
                return all_vocab

            for lesson_meta in lessons_index.get('lessons', []):
                lesson = self.get_lesson(lesson_meta['lesson_id'])
                if lesson and 'vocabulary' in lesson:
                    for vocab in lesson['vocabulary']:
                        # Add lesson context
                        vocab_copy = vocab.copy()
                        vocab_copy['lesson_id'] = lesson['lesson_id']
                        vocab_copy['lesson_title'] = lesson.get('title', '')
                        vocab_copy['lesson_level'] = lesson.get('difficulty_level', 1)
                        all_vocab.append(vocab_copy)

            return all_vocab
        except Exception as e:
            logger.error(f"Error getting all vocabulary: {e}")
            return []

    def get_scenario(self, scenario_id: str) -> Optional[Dict]:
        """Load a task scenario by ID"""
        try:
            scenarios_dir = self.data_dir / "scenarios"
            scenario_path = scenarios_dir / f"{scenario_id}.json"

            if not scenario_path.exists():
                logger.warning(f"Scenario file not found: {scenario_path}")
                return None

            with open(scenario_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except Exception as e:
            logger.error(f"Error loading scenario {scenario_id}: {e}")
            return None

    def get_all_scenarios(self) -> List[Dict]:
        """Get list of all available scenarios"""
        try:
            scenarios_dir = self.data_dir / "scenarios"
            if not scenarios_dir.exists():
                return []

            scenarios = []
            for scenario_file in scenarios_dir.glob("*.json"):
                try:
                    with open(scenario_file, 'r', encoding='utf-8') as f:
                        scenario_data = json.load(f)
                        # Add just metadata, not full dialogue tree
                        scenarios.append({
                            'scenario_id': scenario_data.get('scenario_id'),
                            'title': scenario_data.get('title'),
                            'title_en': scenario_data.get('title_en'),
                            'level': scenario_data.get('level'),
                            'estimated_duration_minutes': scenario_data.get('estimated_duration_minutes'),
                            'learning_goals': scenario_data.get('learning_goals', [])
                        })
                except Exception as e:
                    logger.error(f"Error loading scenario {scenario_file}: {e}")
                    continue

            return scenarios
        except Exception as e:
            logger.error(f"Error getting all scenarios: {e}")
            return []
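A short usage sketch for LearningDataService above. It assumes data/learning/lessons/swahili/ already contains an index.json and the lesson files it references; the user ID and progress payload are illustrative.

# Usage sketch for LearningDataService. The lesson data on disk is assumed to
# exist; the user ID and progress payload below are illustrative only.
from app.services.learning_data_service import LearningDataService

svc = LearningDataService()

index = svc.get_lessons_index("swa")   # 'swa' maps to the swahili folder
lesson = svc.get_lesson(1)             # resolved via index.json's 'file' field

# First save creates data/learning/users/user-demo.json from the default structure
svc.update_lesson_progress("demo", 1, {"status": "completed", "score": 0.92})
svc.unlock_achievement("demo", "first_lesson", progress=1, target=1)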
app/services/quantization_utils.py
ADDED
@@ -0,0 +1,124 @@
"""
Dynamic INT8 Quantization utilities for ASR models.

This module provides utilities to apply PyTorch dynamic quantization to
Hugging Face transformer models, specifically optimized for ASR models like
Whisper and Wav2Vec2-BERT.
"""

import torch
from torch.quantization import quantize_dynamic
from transformers import PreTrainedModel
import time


def apply_dynamic_int8_quantization(model: PreTrainedModel, model_type: str = "auto") -> PreTrainedModel:
    """
    Apply dynamic INT8 quantization to a Hugging Face model.

    Dynamic quantization converts model weights to INT8 ahead of time and
    quantizes activations to INT8 on the fly during inference, reducing model
    size and improving inference speed with minimal accuracy loss.

    Args:
        model: The Hugging Face model to quantize
        model_type: Type of model ("whisper", "wav2vec2-bert", or "auto")

    Returns:
        Quantized model

    References:
        - PyTorch Quantization: https://pytorch.org/docs/stable/quantization.html
        - Dynamic Quantization for NLP: https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html
    """
    print(f"\n{'='*60}")
    print(f"Applying Dynamic INT8 Quantization to {model_type} model")
    print(f"{'='*60}")

    # Get model size before quantization
    param_size = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()
    buffer_size = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()
    size_before_mb = (param_size + buffer_size) / 1024**2

    print(f"Model size before quantization: {size_before_mb:.2f} MB")

    # Start quantization timer
    start_time = time.time()

    try:
        # Dynamic quantization targets:
        # - torch.nn.Linear: Most common layer type in transformers
        # - torch.nn.LSTM/GRU/RNN: For sequential models (if present)
        #
        # Note: We use qint8 (quantized int8) which converts weights to INT8
        # and performs INT8 arithmetic for linear layers during inference
        quantized_model = quantize_dynamic(
            model,
            {torch.nn.Linear},  # Quantize all Linear layers
            dtype=torch.qint8   # Use 8-bit integer quantization
        )

        # Get model size after quantization
        param_size_q = 0
        for param in quantized_model.parameters():
            param_size_q += param.nelement() * param.element_size()
        buffer_size_q = 0
        for buffer in quantized_model.buffers():
            buffer_size_q += buffer.nelement() * buffer.element_size()
        size_after_mb = (param_size_q + buffer_size_q) / 1024**2

        quantization_time = time.time() - start_time
        size_reduction = ((size_before_mb - size_after_mb) / size_before_mb) * 100

        print(f"✓ Quantization successful!")
        print(f"  - Model size after quantization: {size_after_mb:.2f} MB")
        print(f"  - Size reduction: {size_reduction:.1f}%")
        print(f"  - Quantization time: {quantization_time:.2f}s")
        print(f"{'='*60}\n")

        return quantized_model

    except Exception as e:
        print(f"✗ Quantization failed: {e}")
        print(f"  Returning original unquantized model")
        print(f"{'='*60}\n")
        return model


def get_quantization_stats(model: PreTrainedModel) -> dict:
    """
    Get statistics about a model's quantization status.

    Args:
        model: The model to analyze

    Returns:
        Dictionary with quantization statistics
    """
    stats = {
        "is_quantized": False,
        "quantized_layers": 0,
        "total_layers": 0,
        "size_mb": 0.0
    }

    # Count quantized vs regular layers
    for name, module in model.named_modules():
        if isinstance(module, (torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU)):
            stats["total_layers"] += 1

        # Check if layer is quantized (will have _packed_params attribute)
        if hasattr(module, '_packed_params'):
            stats["quantized_layers"] += 1
            stats["is_quantized"] = True

    # Calculate model size
    param_size = sum(p.nelement() * p.element_size() for p in model.parameters())
    buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())
    stats["size_mb"] = (param_size + buffer_size) / 1024**2

    return stats
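The same quantize_dynamic recipe used above, shown on a plain PyTorch module so it runs without downloading an ASR checkpoint; the layer sizes are arbitrary.

# Self-contained sketch of the dynamic INT8 recipe above, on a toy module so
# it runs without downloading a checkpoint. Layer sizes are arbitrary.
import torch
from torch.quantization import quantize_dynamic

model = torch.nn.Sequential(
    torch.nn.Linear(256, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 10),
)
quantized = quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)

# Weights are stored as INT8; activations are quantized on the fly per batch
out = quantized(torch.randn(1, 256))
print(out.shape)  # torch.Size([1, 10])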
app/services/session_manager.py
ADDED
@@ -0,0 +1,180 @@
import uuid
import random
import string
from typing import Dict, List, Optional
from app.models import Session, SessionCreate, Participant, Language, LanguageCode

def generate_short_code(length: int = 8) -> str:
    """Generate a random short code using uppercase letters and digits"""
    # Use only uppercase letters and digits; look-alike characters are removed below to avoid O/0 and I/1 confusion
    alphabet = string.ascii_uppercase + string.digits
    # Remove confusing characters
    alphabet = alphabet.replace('O', '').replace('0', '').replace('I', '').replace('1', '')
    return ''.join(random.choice(alphabet) for _ in range(length))

# Language mappings
LANGUAGE_MAP = {
    LanguageCode.ENGLISH: Language(code=LanguageCode.ENGLISH, name="English", display_name="English (eng)"),
    LanguageCode.SWAHILI: Language(code=LanguageCode.SWAHILI, name="Swahili", display_name="Swahili (swa)"),
    LanguageCode.KIKUYU: Language(code=LanguageCode.KIKUYU, name="Kikuyu", display_name="Kikuyu (kik)"),
    LanguageCode.KAMBA: Language(code=LanguageCode.KAMBA, name="Kamba", display_name="Kamba (kam)"),
    LanguageCode.KIMERU: Language(code=LanguageCode.KIMERU, name="Kimeru", display_name="Kimeru (mer)"),
    LanguageCode.LUO: Language(code=LanguageCode.LUO, name="Luo", display_name="Luo (luo)"),
    LanguageCode.SOMALI: Language(code=LanguageCode.SOMALI, name="Somali", display_name="Somali (som)"),
}

class SessionManager:
    def __init__(self):
        self.sessions: Dict[str, Session] = {}
        self.participant_sessions: Dict[str, str] = {}  # participant_id -> session_id
        self.short_code_to_id: Dict[str, str] = {}  # short_code -> session_id
        self.id_to_short_code: Dict[str, str] = {}  # session_id -> short_code

    async def create_session(self, session_data: SessionCreate) -> Session:
        session_id = str(uuid.uuid4())

        # Generate unique short code
        short_code = generate_short_code(8)
        while short_code in self.short_code_to_id:
            # Extremely unlikely collision, but regenerate if needed
            short_code = generate_short_code(8)

        # Convert language codes to Language objects
        languages = [LANGUAGE_MAP[lang_code] for lang_code in session_data.languages]

        session = Session(
            id=session_id,
            name=session_data.name,
            organizer_name=session_data.organizer_name,
            languages=languages,
            participants=[],
            is_active=True,
            enable_tts=session_data.enable_tts
        )

        self.sessions[session_id] = session
        self.short_code_to_id[short_code] = session_id
        self.id_to_short_code[session_id] = short_code
        return session

    async def get_session(self, session_id_or_code: str) -> Optional[Session]:
        """Get session by full UUID or short code"""
        # Try as full UUID first
        session = self.sessions.get(session_id_or_code)
        if session:
            return session

        # Try as short code
        session_id = self.short_code_to_id.get(session_id_or_code.upper())
        if session_id:
            return self.sessions.get(session_id)

        return None

    def get_short_code(self, session_id: str) -> str:
        """Get short code for a session ID"""
        return self.id_to_short_code.get(session_id, session_id)

    async def get_all_sessions(self) -> List[Session]:
        return list(self.sessions.values())

    async def add_participant(self, session_id: str, participant_name: str, language_code: LanguageCode) -> Optional[Participant]:
        session = await self.get_session(session_id)
        if not session:
            return None

        participant_id = str(uuid.uuid4())
        language = LANGUAGE_MAP[language_code]

        # Check if the participant's language is already in the session languages
        language_exists = any(lang.code == language_code for lang in session.languages)
        if not language_exists:
            print(f"Adding new language {language.name} ({language_code.value}) to session {session_id}")
            session.languages.append(language)

        participant = Participant(
            id=participant_id,
            name=participant_name,
            language=language,
            is_organizer=len(session.participants) == 0,  # First participant is organizer
            is_speaking=False,
            is_connected=True
        )

        session.participants.append(participant)
        self.participant_sessions[participant_id] = session_id

        print(f"Participant {participant_name} added to session. Session now has {len(session.languages)} languages: {[lang.name for lang in session.languages]}")

        return participant

    async def remove_participant(self, participant_id: str) -> bool:
        session_id = self.participant_sessions.get(participant_id)
        if not session_id:
            return False

        session = await self.get_session(session_id)
        if not session:
            return False

        # Remove participant from session
        session.participants = [p for p in session.participants if p.id != participant_id]
        del self.participant_sessions[participant_id]

        return True

    async def update_participant_speaking_status(self, participant_id: str, is_speaking: bool) -> bool:
        session_id = self.participant_sessions.get(participant_id)
        if not session_id:
            return False

        session = await self.get_session(session_id)
        if not session:
            return False

        for participant in session.participants:
            if participant.id == participant_id:
                participant.is_speaking = is_speaking
                return True

        return False

    async def get_participant_session_id(self, participant_id: str) -> Optional[str]:
        return self.participant_sessions.get(participant_id)

    async def add_language_to_session(self, session_id: str, language_code: LanguageCode) -> bool:
        """Add a language to the session if it doesn't already exist"""
        session = await self.get_session(session_id)
        if not session:
            return False

        language = LANGUAGE_MAP[language_code]

        # Check if the language is already in the session languages
        language_exists = any(lang.code == language_code for lang in session.languages)
        if not language_exists:
            print(f"Adding new language {language.name} ({language_code.value}) to session {session_id}")
            session.languages.append(language)
            print(f"Session {session_id} now has {len(session.languages)} languages: {[lang.name for lang in session.languages]}")
            return True
        else:
            print(f"Language {language.name} ({language_code.value}) already exists in session {session_id}")
            return False

    async def delete_session(self, session_id: str) -> bool:
        if session_id in self.sessions:
            # Remove all participants from tracking
            session = self.sessions[session_id]
            for participant in session.participants:
                if participant.id in self.participant_sessions:
                    del self.participant_sessions[participant.id]

            # Remove short code mapping
            short_code = self.id_to_short_code.get(session_id)
            if short_code:
                del self.short_code_to_id[short_code]
                del self.id_to_short_code[session_id]

            del self.sessions[session_id]
            return True
        return False
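An async usage sketch for SessionManager above. SessionCreate's exact schema lives in app/models; only the fields that create_session actually reads are shown, and the names and values are illustrative.

# Usage sketch for SessionManager. SessionCreate's full schema is in
# app/models; the session name and organizer below are illustrative.
import asyncio
from app.models import LanguageCode, SessionCreate
from app.services.session_manager import SessionManager

async def demo():
    mgr = SessionManager()
    session = await mgr.create_session(SessionCreate(
        name="Town hall",
        organizer_name="Amina",
        languages=[LanguageCode.ENGLISH, LanguageCode.SWAHILI],
        enable_tts=True,
    ))
    code = mgr.get_short_code(session.id)
    # Lookups accept either the UUID or the short code (case-insensitive)
    found = await mgr.get_session(code.lower())
    assert found is not None and found.id == session.id

asyncio.run(demo())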
app/services/transcription_service.py
ADDED
|
@@ -0,0 +1,736 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import io
|
| 3 |
+
import wave
|
| 4 |
+
import numpy as np
|
| 5 |
+
import time
|
| 6 |
+
from typing import Dict, Optional, Callable
|
| 7 |
+
from transformers import pipeline
|
| 8 |
+
import torch
|
| 9 |
+
from app.models import LanguageCode
|
| 10 |
+
import os
|
| 11 |
+
from app.services.quantization_utils import apply_dynamic_int8_quantization, get_quantization_stats
|
| 12 |
+
|
| 13 |
+
# Silero VAD imports
|
| 14 |
+
try:
|
| 15 |
+
import silero_vad
|
| 16 |
+
SILERO_VAD_AVAILABLE = True
|
| 17 |
+
except ImportError:
|
| 18 |
+
SILERO_VAD_AVAILABLE = False
|
| 19 |
+
print("Warning: silero-vad not installed. Falling back to RMS-based VAD.")
|
| 20 |
+
|
| 21 |
+
class TranscriptionService:
|
| 22 |
+
def __init__(self):
|
| 23 |
+
self.asr_pipelines: Dict[str, any] = {}
|
| 24 |
+
self.device = 0 if torch.cuda.is_available() else -1
|
| 25 |
+
|
| 26 |
+
# Model configurations - using original mutisya models with updated config
|
| 27 |
+
self.asr_config = {
|
| 28 |
+
"eng": {"model_repo": "openai/whisper-base.en", "model_type": "whisper"},
|
| 29 |
+
"swa": {"model_repo": "mutisya/w2v-bert-2.0-asr-swh-superv-v25-37-1", "model_type": "wav2vec2-bert", "trust_remote_code": True},
|
| 30 |
+
"kik": {"model_repo": "mutisya/w2v-bert-2.0-asr-kik-superv-v25-37-1", "model_type": "wav2vec2-bert", "trust_remote_code": True},
|
| 31 |
+
"kam": {"model_repo": "mutisya/w2v-bert-2.0-asr-kam-superv-v25-36-1", "model_type": "wav2vec2-bert", "trust_remote_code": True},
|
| 32 |
+
"mer": {"model_repo": "mutisya/w2v-bert-2.0-asr-mer-superv-v25-36-1", "model_type": "wav2vec2-bert", "trust_remote_code": True},
|
| 33 |
+
"luo": {"model_repo": "mutisya/w2v-bert-2.0-asr-luo-superv-v25-36-1", "model_type": "wav2vec2-bert", "trust_remote_code": True},
|
| 34 |
+
"som": {"model_repo": "mutisya/w2v-bert-2.0-asr-som-superv-v25-36-1", "model_type": "wav2vec2-bert", "trust_remote_code": True}
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
self.preload_languages = ["eng"]
|
| 38 |
+
self.background_loading_task = None
|
| 39 |
+
self.models_loading_status = {}
|
| 40 |
+
|
| 41 |
+
# Enhanced audio buffering for VAD-based sentence detection
|
| 42 |
+
self.candidate_audio_buffers: Dict[str, bytes] = {} # participant_id -> candidate audio buffer
|
| 43 |
+
self.candidate_text_cache: Dict[str, str] = {} # participant_id -> current candidate text
|
| 44 |
+
self.silence_counters: Dict[str, int] = {} # participant_id -> consecutive silence chunks
|
| 45 |
+
self.sentence_finalized: Dict[str, bool] = {} # participant_id -> whether current sentence is already finalized
|
| 46 |
+
|
| 47 |
+
# VAD parameters - made more lenient for better detection
|
| 48 |
+
self.silence_threshold = 1 # Number of consecutive silent chunks before sentence break (1 second for natural pauses)
|
| 49 |
+
self.min_sentence_length = 0.03 # Minimum sentence length in seconds (very short)
|
| 50 |
+
|
| 51 |
+
# Silero VAD initialization
|
| 52 |
+
self.vad_model = None
|
| 53 |
+
self.vad_sample_rate = 16000
|
| 54 |
+
self.vad_available = SILERO_VAD_AVAILABLE
|
| 55 |
+
|
| 56 |
+
# Quantization configuration
|
| 57 |
+
# Set ENABLE_INT8_QUANTIZATION=true in environment to enable quantization
|
| 58 |
+
self.enable_quantization = os.getenv('ENABLE_INT8_QUANTIZATION', 'true').lower() == 'true'
|
| 59 |
+
print(f"INT8 Quantization: {'ENABLED' if self.enable_quantization else 'DISABLED'}")
|
| 60 |
+
|
| 61 |
+
async def initialize(self):
|
| 62 |
+
"""Initialize ASR models for preloaded languages and Silero VAD"""
|
| 63 |
+
# Initialize Silero VAD model
|
| 64 |
+
if self.vad_available:
|
| 65 |
+
try:
|
| 66 |
+
print("Loading Silero VAD model...")
|
| 67 |
+
self.vad_model = silero_vad.load_silero_vad(onnx=False)
|
| 68 |
+
print("✓ Silero VAD model loaded successfully")
|
| 69 |
+
except Exception as e:
|
| 70 |
+
print(f"Failed to load Silero VAD model: {e}")
|
| 71 |
+
print("Falling back to RMS-based VAD")
|
| 72 |
+
self.vad_available = False
|
| 73 |
+
|
| 74 |
+
# Initialize ASR models
|
| 75 |
+
for lang_code in self.preload_languages:
|
| 76 |
+
if lang_code in self.asr_config:
|
| 77 |
+
try:
|
| 78 |
+
model_config = self.asr_config[lang_code]
|
| 79 |
+
pipeline_obj = self._load_and_quantize_pipeline(lang_code, model_config)
|
| 80 |
+
self.asr_pipelines[lang_code] = pipeline_obj
|
| 81 |
+
except Exception as e:
|
| 82 |
+
print(f"Failed to load ASR model for {lang_code}: {e}")
|
| 83 |
+
|
    def _load_and_quantize_pipeline(self, lang_code: str, model_config: dict):
        """Load an ASR pipeline and optionally apply INT8 quantization"""
        # Build pipeline parameters
        pipeline_params = {
            "task": "automatic-speech-recognition",
            "model": model_config["model_repo"],
            "device": self.device,
            "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32
        }

        # Add trust_remote_code if specified
        if model_config.get("trust_remote_code", False):
            pipeline_params["trust_remote_code"] = True

        print(f"Loading ASR model for {lang_code}: {model_config['model_repo']}")
        pipeline_obj = pipeline(**pipeline_params)

        # Apply quantization if enabled
        if self.enable_quantization:
            try:
                # Get the underlying model from the pipeline
                model = pipeline_obj.model
                model_type = model_config.get("model_type", "auto")

                # Apply dynamic INT8 quantization
                quantized_model = apply_dynamic_int8_quantization(model, model_type)

                # Replace the model in the pipeline
                pipeline_obj.model = quantized_model

                # Print quantization stats
                stats = get_quantization_stats(quantized_model)
                print(f"✓ {lang_code} model quantized: {stats['quantized_layers']}/{stats['total_layers']} layers, {stats['size_mb']:.2f} MB")

            except Exception as e:
                print(f"Warning: Could not quantize {lang_code} model: {e}")
                print("Continuing with unquantized model")

        return pipeline_obj
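The helpers `apply_dynamic_int8_quantization` and `get_quantization_stats` come from `app/services/quantization_utils.py`, which is added in this commit but not shown in this excerpt. As a rough sketch of what they could look like, assuming they wrap PyTorch's dynamic quantization (the committed implementation may differ in detail):

```python
# Minimal sketch of the quantization utilities, assuming they are built on
# torch.ao.quantization.quantize_dynamic; not the committed implementation.
import torch

def apply_dynamic_int8_quantization(model: torch.nn.Module, model_type: str = "auto") -> torch.nn.Module:
    # Dynamic quantization stores Linear weights as INT8 and quantizes
    # activations on the fly, so no calibration data is needed. It is a
    # CPU-oriented optimization, which fits CPU-only Spaces hardware.
    return torch.ao.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)

def get_quantization_stats(model: torch.nn.Module) -> dict:
    # After quantize_dynamic, converted layers are dynamic quantized Linear
    # modules; any remaining torch.nn.Linear stayed in floating point.
    quantized = sum(1 for m in model.modules()
                    if isinstance(m, torch.ao.nn.quantized.dynamic.Linear))
    total = quantized + sum(1 for m in model.modules() if isinstance(m, torch.nn.Linear))
    # Rough size estimate; packed INT8 weights are not regular parameters,
    # so this mostly reflects whatever stayed in floating point.
    size_mb = sum(p.numel() * p.element_size() for p in model.parameters()) / 1e6
    return {"quantized_layers": quantized, "total_layers": total, "size_mb": size_mb}
```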
    async def ensure_model_loaded(self, language_code: str):
        """Load the ASR model for a language if it is not already loaded"""
        if language_code not in self.asr_pipelines and language_code in self.asr_config:
            try:
                model_config = self.asr_config[language_code]
                pipeline_obj = self._load_and_quantize_pipeline(language_code, model_config)
                self.asr_pipelines[language_code] = pipeline_obj
            except Exception as e:
                print(f"Failed to load ASR model for {language_code}: {e}")
                raise
    async def process_audio_chunk(self, audio_data: bytes, language_code: str, participant_id: str,
                                  has_voice_activity: bool = True,
                                  progress_callback: Optional[Callable] = None,
                                  sentence_callback: Optional[Callable] = None,
                                  debug_callback: Optional[Callable] = None) -> str:
        """Process an audio chunk with VAD-based sentence detection"""
        try:
            # Initialize buffers if needed
            if participant_id not in self.candidate_audio_buffers:
                # Store as a numpy array, not bytes, to avoid accumulating multiple WAV headers
                self.candidate_audio_buffers[participant_id] = np.array([], dtype=np.float32)
                self.candidate_text_cache[participant_id] = ""
                self.silence_counters[participant_id] = 0
                self.sentence_finalized[participant_id] = False

            # Convert the current chunk to a numpy array for processing
            current_chunk_array = self._bytes_to_audio_array(audio_data)
            if len(current_chunk_array) == 0:
                print(f"WARNING: Received empty audio chunk for participant {participant_id}")
                return self.candidate_text_cache.get(participant_id, "")

            print(f"DEBUG: Received audio chunk - bytes: {len(audio_data)}, samples: {len(current_chunk_array)}, "
                  f"duration: {len(current_chunk_array)/16000:.3f}s, "
                  f"first 4 bytes: {audio_data[:4]}")

            # Do NOT normalize individual chunks - that distorts the audio;
            # the entire accumulated buffer is normalized once before transcription.
            current_chunk_array = current_chunk_array.astype(np.float32)

            # Append to the accumulated audio (stored as a numpy array)
            existing_array = self.candidate_audio_buffers[participant_id]
            if len(existing_array) > 0:
                combined_array = np.concatenate([existing_array, current_chunk_array])
            else:
                combined_array = current_chunk_array

            # Store as a numpy array to avoid WAV header accumulation issues
            self.candidate_audio_buffers[participant_id] = combined_array

            # For the debug callback, convert to bytes (this adds ONE WAV header)
            combined_bytes = self._audio_array_to_bytes(combined_array)

            # Update the silence counter based on voice activity
            if not has_voice_activity:
                self.silence_counters[participant_id] += 1
            else:
                self.silence_counters[participant_id] = 0

            # Finalize the sentence if silence has persisted long enough
            should_finalize = (self.silence_counters[participant_id] >= self.silence_threshold and
                               len(combined_array) > 0 and
                               not self.sentence_finalized[participant_id])

            if should_finalize:
                return await self._finalize_candidate_sentence(
                    language_code, participant_id, sentence_callback
                )

            # Duration of the accumulated audio at the 16 kHz sample rate
            audio_duration_sec = len(combined_array) / 16000.0

            # Minimum duration check - ignore very short audio bursts
            MIN_CHUNK_DURATION = 0.3  # 300 ms minimum
            if audio_duration_sec < MIN_CHUNK_DURATION:
                print(f"Audio chunk too short: {audio_duration_sec:.3f}s < {MIN_CHUNK_DURATION}s, skipping transcription")
                if progress_callback:
                    cached_text = self.candidate_text_cache.get(participant_id, "")
                    await progress_callback(cached_text, False)
                return self.candidate_text_cache.get(participant_id, "")

            # Force finalization after 15 seconds to prevent unbounded accumulation
            if audio_duration_sec > 15.0 and not self.sentence_finalized[participant_id]:
                return await self._finalize_candidate_sentence(
                    language_code, participant_id, sentence_callback
                )

            # Run voice activity detection on the accumulated audio before transcription
            has_voice_in_buffer = self.has_meaningful_voice_activity(combined_bytes)

            if not has_voice_in_buffer:
                # Still send a progress update with cached text to maintain UI state
                if progress_callback:
                    cached_text = self.candidate_text_cache.get(participant_id, "")
                    await progress_callback(cached_text, False)
                return self.candidate_text_cache.get(participant_id, "")

            # Run transcription
            await self.ensure_model_loaded(language_code)

            # Double-check voice activity before running the expensive ASR step
            has_voice_for_asr = self.has_voice_activity(combined_bytes)
            if not has_voice_for_asr:
                print(f"ASR: No voice activity detected in audio buffer for participant {participant_id}, skipping ASR execution")
                # Return cached text and send a progress update
                if progress_callback:
                    cached_text = self.candidate_text_cache.get(participant_id, "")
                    await progress_callback(cached_text, False)
                return self.candidate_text_cache.get(participant_id, "")

            if language_code not in self.asr_pipelines:
                raise ValueError(f"ASR model not available for language: {language_code}")

            print(f"ASR: Running transcription for participant {participant_id} with {len(combined_array)/16000:.2f}s of audio")
            pipeline_obj = self.asr_pipelines[language_code]

            # Normalize the ENTIRE accumulated audio buffer before transcription;
            # this avoids the distortion that per-chunk normalization would cause.
            normalized_array = combined_array.astype(np.float32)
            max_val = np.max(np.abs(normalized_array))
            if max_val > 0:
                normalized_array = normalized_array / max_val

            # Track transcription latency
            transcription_start_time = time.time()

            # For wav2vec2 models, request word timestamps
            model_type = self.asr_config[language_code].get("model_type", "whisper")
            if model_type in ["wav2vec2-bert", "wav2vec2"]:
                result = pipeline_obj(
                    {"sampling_rate": 16000, "raw": normalized_array},
                    return_timestamps="word"
                )
            else:
                # Whisper model - add anti-hallucination parameters.
                # Note: the HuggingFace pipeline uses different parameter names than OpenAI Whisper.
                result = pipeline_obj(
                    {"sampling_rate": 16000, "raw": normalized_array},
                    return_timestamps=True,
                    chunk_length_s=30,  # process in 30 s chunks
                    stride_length_s=5   # 5 s stride for context
                )

            transcription_latency_ms = (time.time() - transcription_start_time) * 1000

            candidate_text = result.get("text", "").strip()
            word_timestamps = result.get("chunks", []) if model_type in ["wav2vec2-bert", "wav2vec2"] else None

            # Send debug information if a callback was provided (wav2vec2 models only)
            if debug_callback and word_timestamps is not None:
                debug_info = {
                    "text": candidate_text,
                    "timestamps": word_timestamps,
                    "audio_data": combined_bytes,
                    "audio_duration": audio_duration_sec,
                    "model_type": model_type,
                    "transcription_latency_ms": transcription_latency_ms
                }
                await debug_callback(debug_info)

            # Filter out common ASR artifacts and very short responses
            artifacts = [
                "thank you", "thanks", "bye", ".", ",", "?", "!",
                "um", "uh", "ah", "hmm", "mm", "mhm",
                "you", "the", "a", "an", "and", "but", "or",
                "music", "laughter", "applause", "[music]", "[laughter]",
                # Common Whisper hallucinations:
                "subscribe", "subtitles", "amara", "www", "http",
                "please subscribe", "like and subscribe",
                "thank you for watching", "don't forget to subscribe",
                "[blank_audio]", "[noise]", "[silence]",
            ]

            # Check whether the result is likely an artifact
            is_artifact = (
                len(candidate_text) < 3 or                                      # very short
                candidate_text.lower() in artifacts or                          # common artifacts
                (len(candidate_text.split()) == 1 and len(candidate_text) < 6)  # single very short word
            )

            if is_artifact:
                # Keep the previously cached text instead of updating with the artifact
                candidate_text = self.candidate_text_cache.get(participant_id, "")

            # Cache the current candidate text
            self.candidate_text_cache[participant_id] = candidate_text

            # Force completion once there is a reasonable amount of text and some silence
            word_count = len(candidate_text.split()) if candidate_text else 0
            if (word_count >= 3 and self.silence_counters[participant_id] >= 2 and
                    not self.sentence_finalized[participant_id]):  # at least 3 words and 2 silent chunks
                return await self._finalize_candidate_sentence(
                    language_code, participant_id, sentence_callback
                )

            # Always send a progress update
            if progress_callback:
                await progress_callback(candidate_text, False)

            return candidate_text

        except Exception as e:
            print(f"TranscriptionService: Error processing audio chunk: {e}")
            import traceback
            traceback.print_exc()
            # Even on error, try to send cached text
            if progress_callback:
                cached_text = self.candidate_text_cache.get(participant_id, "")
                await progress_callback(cached_text, False)
            return self.candidate_text_cache.get(participant_id, "")
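For orientation, a hypothetical caller would drive this method as below; the real callers live in the websocket manager and routers (not shown here), and the callback names and chunk source are illustrative:

```python
# Illustrative driver for process_audio_chunk; names are hypothetical.
async def feed_audio(service: TranscriptionService, chunks: list[bytes]) -> None:
    async def on_progress(text: str, is_final: bool) -> None:
        print(f"partial: {text!r}")

    async def on_sentence(text: str, wav_bytes: bytes) -> None:
        print(f"final ({len(wav_bytes)} WAV bytes): {text!r}")

    for chunk in chunks:
        await service.process_audio_chunk(
            chunk, "eng", "participant-1",
            has_voice_activity=True,
            progress_callback=on_progress,
            sentence_callback=on_sentence,
        )

    # One silent chunk is enough to finalize, since silence_threshold == 1.
    await service.process_audio_chunk(
        b"\x00" * 9600, "eng", "participant-1",  # silent 16-bit PCM, read as 48 kHz
        has_voice_activity=False,
        progress_callback=on_progress,
        sentence_callback=on_sentence,
    )
```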
    async def _finalize_candidate_sentence(self, language_code: str, participant_id: str,
                                           sentence_callback: Optional[Callable] = None) -> str:
        """Finalize the current candidate sentence and clear the buffers"""
        try:
            # Skip if the sentence was already finalized
            if self.sentence_finalized.get(participant_id, False):
                print(f"Sentence for participant {participant_id} already finalized, skipping duplicate")
                return self.candidate_text_cache.get(participant_id, "")

            final_text = self.candidate_text_cache.get(participant_id, "")
            final_audio_array = self.candidate_audio_buffers.get(participant_id, np.array([], dtype=np.float32))

            # Convert the audio array to bytes for the VAD check and the callback
            final_audio_bytes = self._audio_array_to_bytes(final_audio_array) if len(final_audio_array) > 0 else b''

            if final_text and len(final_text.strip()) > 0:
                # Run a VAD check on the final accumulated buffer before sending it for translation
                if len(final_audio_bytes) > 0:
                    has_voice_in_final = self.has_meaningful_voice_activity(final_audio_bytes)
                    if not has_voice_in_final:
                        print(f"Finalize: No voice activity in final buffer for participant {participant_id}, discarding sentence: '{final_text}'")
                        # Clear buffers without sending anything to translation
                        self.candidate_audio_buffers[participant_id] = np.array([], dtype=np.float32)
                        self.candidate_text_cache[participant_id] = ""
                        self.silence_counters[participant_id] = 0
                        self.sentence_finalized[participant_id] = False
                        return ""

                # Mark as finalized BEFORE calling the callback to prevent race conditions
                self.sentence_finalized[participant_id] = True

                # Send to the sentence callback for translation
                if sentence_callback and len(final_audio_bytes) > 0:
                    print(f"Finalizing sentence for participant {participant_id}: '{final_text}'")
                    await sentence_callback(final_text, final_audio_bytes)

            # Clear buffers for the next sentence
            self.candidate_audio_buffers[participant_id] = np.array([], dtype=np.float32)
            self.candidate_text_cache[participant_id] = ""
            self.silence_counters[participant_id] = 0
            self.sentence_finalized[participant_id] = False  # reset for the next sentence

            return final_text

        except Exception as e:
            print(f"Error finalizing sentence: {e}")
            import traceback
            traceback.print_exc()
            # Reset the finalized flag on error
            self.sentence_finalized[participant_id] = False
            return ""
    def has_voice_activity(self, audio_data: bytes, threshold: float = 0.5) -> bool:
        """Voice activity detection using Silero VAD (with an RMS fallback)"""
        try:
            audio_array = self._bytes_to_audio_array(audio_data)
            if len(audio_array) == 0:
                print("VAD: No audio array, returning False")
                return False

            # Normalize audio to the float32 range [-1, 1]
            audio_array = audio_array.astype(np.float32)
            if np.max(np.abs(audio_array)) > 0:
                audio_array /= np.max(np.abs(audio_array))

            # Use Silero VAD if available
            if self.vad_available and self.vad_model is not None:
                try:
                    # Silero VAD expects 512 samples (32 ms) or 1536 samples (96 ms) at 16 kHz.
                    # Process the audio in frames and average the probabilities.
                    frame_size = 512  # 32 ms at 16 kHz
                    num_samples = len(audio_array)

                    # Pad audio that is shorter than one frame
                    if num_samples < frame_size:
                        audio_array = np.pad(audio_array, (0, frame_size - num_samples), mode='constant')
                        num_samples = frame_size

                    # Process in frames and collect probabilities
                    speech_probs = []
                    for i in range(0, num_samples, frame_size):
                        frame = audio_array[i:i + frame_size]
                        if len(frame) < frame_size:
                            # Pad the last frame if needed
                            frame = np.pad(frame, (0, frame_size - len(frame)), mode='constant')

                        # Convert to a torch tensor
                        frame_tensor = torch.from_numpy(frame).float()

                        # Get the speech probability from Silero VAD
                        with torch.no_grad():
                            prob = self.vad_model(frame_tensor, self.vad_sample_rate).item()
                        speech_probs.append(prob)

                    # Average the probability across all frames
                    speech_prob = np.mean(speech_probs)
                    has_voice = speech_prob > threshold

                    print(f"VAD: Silero speech_prob={speech_prob:.4f} (avg of {len(speech_probs)} frames), threshold={threshold}, RESULT={has_voice}")

                    return has_voice

                except Exception as e:
                    print(f"Silero VAD error: {e}, falling back to RMS-based VAD")
                    # Fall through to the RMS-based VAD below

            # Fallback: RMS-based VAD (original implementation)
            rms_threshold = 0.002
            rms = np.sqrt(np.mean(audio_array ** 2))
            peak = np.max(np.abs(audio_array))
            audio_std = np.std(audio_array)
            zero_crossing_rate = np.sum(np.diff(np.sign(audio_array)) != 0) / len(audio_array)

            has_voice_rms = rms > rms_threshold
            has_voice_peak = peak > rms_threshold * 3
            has_voice_variation = audio_std > rms_threshold * 0.8
            has_voice_zcr = zero_crossing_rate > 0.008

            has_voice = has_voice_rms or (has_voice_peak and has_voice_variation) or has_voice_zcr

            print(f"VAD: RMS-based - RMS={rms:.6f}({has_voice_rms}), peak={peak:.6f}({has_voice_peak}), std={audio_std:.6f}({has_voice_variation}), zcr={zero_crossing_rate:.6f}({has_voice_zcr}), RESULT={has_voice}")

            return has_voice

        except Exception as e:
            print(f"Error in VAD: {e}")
            return True  # default to assuming voice activity on error
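The frame loop above can be sanity-checked in isolation, reusing exactly the calls the service makes; with the silero-vad package installed, uniform noise should score far below the 0.5 speech threshold. A standalone snippet (illustrative, not part of the service):

```python
# Standalone check of the frame-wise Silero scoring used in has_voice_activity.
import numpy as np
import torch
import silero_vad

model = silero_vad.load_silero_vad(onnx=False)
audio = np.random.uniform(-0.1, 0.1, 16000).astype(np.float32)  # 1 s of noise

probs = []
for i in range(0, len(audio), 512):          # 512 samples = 32 ms at 16 kHz
    frame = audio[i:i + 512]
    if len(frame) < 512:
        frame = np.pad(frame, (0, 512 - len(frame)))
    with torch.no_grad():
        probs.append(model(torch.from_numpy(frame), 16000).item())

print(f"mean speech probability: {np.mean(probs):.4f}")  # expect well under 0.5
```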
    def has_meaningful_voice_activity(self, audio_data: bytes, threshold: float = 0.005) -> bool:
        """Stricter VAD check used specifically for pre-transcription filtering"""
        try:
            audio_array = self._bytes_to_audio_array(audio_data)
            if len(audio_array) == 0:
                return False

            # Normalize the audio
            audio_array = audio_array.astype(np.float32)
            if np.max(np.abs(audio_array)) > 0:
                audio_array /= np.max(np.abs(audio_array))

            # Calculate features, applying higher thresholds for meaningful speech
            rms = np.sqrt(np.mean(audio_array ** 2))
            peak = np.max(np.abs(audio_array))
            audio_std = np.std(audio_array)
            zero_crossing_rate = np.sum(np.diff(np.sign(audio_array)) != 0) / len(audio_array)

            # All gates must pass for the buffer to count as meaningful speech
            has_meaningful_voice = (
                rms > threshold and
                peak > threshold * 2 and
                audio_std > threshold * 0.5 and
                zero_crossing_rate > 0.015  # higher ZCR threshold for meaningful speech
            )

            return has_meaningful_voice

        except Exception as e:
            print(f"Error in meaningful VAD: {e}")
            return False  # default to no meaningful voice on error
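A quick numeric check of those gates, after the same peak normalization the method applies: a pure tone clears the level-based thresholds and the ZCR cutoff, while digital silence fails everything.

```python
# Feature math behind has_meaningful_voice_activity, on synthetic signals.
import numpy as np

t = np.arange(16000) / 16000.0
signals = {
    "440 Hz tone": np.sin(2 * np.pi * 440 * t).astype(np.float32),
    "silence": np.zeros(16000, dtype=np.float32),
}
for name, x in signals.items():
    if np.max(np.abs(x)) > 0:
        x = x / np.max(np.abs(x))  # peak-normalize, as the method does
    rms = np.sqrt(np.mean(x ** 2))
    zcr = np.sum(np.diff(np.sign(x)) != 0) / len(x)
    print(f"{name}: rms={rms:.4f}, zcr={zcr:.4f}")
# 440 Hz tone: rms ≈ 0.707 (> 0.005), zcr ≈ 0.055 (> 0.015) -> passes
# silence:     rms = 0.0,              zcr = 0.0             -> fails
```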
    async def force_complete_sentence(self, participant_id: str, language_code: str, sentence_callback: Optional[Callable] = None) -> str:
        """Force-complete any pending sentence for a participant"""
        try:
            # Skip if the sentence was already finalized
            if self.sentence_finalized.get(participant_id, False):
                print(f"Force completion: Sentence for participant {participant_id} already finalized, skipping")
                return ""

            if participant_id in self.candidate_text_cache:
                cached_text = self.candidate_text_cache[participant_id]

                if cached_text and len(cached_text.strip()) > 0:
                    result = await self._finalize_candidate_sentence(language_code, participant_id, sentence_callback)
                    return result

            return ""

        except Exception as e:
            print(f"Error in force_complete_sentence: {e}")
            import traceback
            traceback.print_exc()
            return ""
    async def transcribe_audio(self, audio_data: bytes, language_code: str, callback: Optional[Callable] = None) -> str:
        """Transcribe audio data to text"""
        try:
            # Check for voice activity before running ASR
            has_voice = self.has_voice_activity(audio_data)
            if not has_voice:
                print("ASR: No voice activity detected in audio data, skipping transcription")
                return ""

            await self.ensure_model_loaded(language_code)

            if language_code not in self.asr_pipelines:
                raise ValueError(f"ASR model not available for language: {language_code}")

            # Convert the audio bytes to a numpy array
            audio_array = self._bytes_to_audio_array(audio_data)

            print(f"ASR: Running transcription with {len(audio_array)/16000:.2f}s of audio")
            # Transcribe
            pipeline_obj = self.asr_pipelines[language_code]
            result = pipeline_obj({"sampling_rate": 16000, "raw": audio_array})

            text = result.get("text", "")

            if callback:
                await callback(text)

            return text

        except Exception as e:
            print(f"TranscriptionService: Transcription error: {e}")
            import traceback
            traceback.print_exc()
            return ""
    def _bytes_to_audio_array(self, audio_data: bytes) -> np.ndarray:
        """Convert audio bytes to a numpy array (supports raw PCM, WAV, and WebM/Opus)"""
        try:
            # Detect the format by checking magic bytes
            is_webm = audio_data[:4] == b'\x1a\x45\xdf\xa3'  # WebM/Matroska magic bytes
            is_wav = audio_data[:4] == b'RIFF'

            import sys
            print(f"_bytes_to_audio_array: length={len(audio_data)}, first 4 bytes={audio_data[:4]}, is_wav={is_wav}", flush=True)
            sys.stdout.flush()

            # Handle raw PCM (16-bit, 48 kHz from extendable-media-recorder);
            # this is the most common case for microphone input.
            if not is_wav and not is_webm and len(audio_data) > 0:
                try:
                    # Assume 16-bit PCM at 48 kHz (the browser's native rate)
                    audio_array = np.frombuffer(audio_data, dtype=np.int16)

                    # Check that this looks like valid audio data (not NaN, reasonable range)
                    if len(audio_array) > 0 and not np.isnan(audio_array).any():
                        print(f"Raw PCM: {len(audio_array)} samples, assuming 48kHz 16-bit", flush=True)

                        # Convert to float32 and normalize
                        audio_float = audio_array.astype(np.float32) / 32768.0

                        # Resample from 48 kHz to 16 kHz
                        import librosa
                        audio_array = librosa.resample(audio_float, orig_sr=48000, target_sr=16000)
                        print(f"Resampled to 16kHz: {len(audio_array)} samples", flush=True)

                        return audio_array
                except Exception as pcm_error:
                    print(f"TranscriptionService: Raw PCM decoding error: {pcm_error}", flush=True)
                    # Fall through to the other methods

            if is_webm:
                # Decode WebM/Opus using pydub (requires ffmpeg)
                try:
                    from pydub import AudioSegment
                    audio_io = io.BytesIO(audio_data)
                    audio_segment = AudioSegment.from_file(audio_io, format="webm")

                    # Convert to mono 16 kHz
                    audio_segment = audio_segment.set_channels(1)
                    audio_segment = audio_segment.set_frame_rate(16000)

                    # Convert to a numpy array
                    samples = np.array(audio_segment.get_array_of_samples(), dtype=np.int16)
                    # Normalize to float32 [-1, 1]
                    audio_array = samples.astype(np.float32) / 32768.0
                    return audio_array
                except Exception as webm_error:
                    print(f"TranscriptionService: WebM decoding error: {webm_error}")
                    # Fall through to the other methods

            if is_wav:
                # Decode WAV (the first chunk from the frontend includes a WAV header with the sample rate)
                try:
                    audio_io = io.BytesIO(audio_data)
                    with wave.open(audio_io, 'rb') as wav_file:
                        sample_rate = wav_file.getframerate()
                        channels = wav_file.getnchannels()
                        sample_width = wav_file.getsampwidth()

                        print(f"WAV format: {sample_rate}Hz, {channels} channel(s), {sample_width*8}-bit", flush=True)

                        frames = wav_file.readframes(-1)
                        audio_array = np.frombuffer(frames, dtype=np.int16)

                        # Resample if needed
                        if sample_rate != 16000:
                            print(f"WARNING: Resampling from {sample_rate}Hz to 16000Hz", flush=True)
                            import librosa
                            # Convert to float first
                            audio_float = audio_array.astype(np.float32) / 32768.0
                            # Resample
                            audio_array = librosa.resample(audio_float, orig_sr=sample_rate, target_sr=16000)
                            print(f"Resampled: {len(audio_array)} samples at 16kHz", flush=True)
                        else:
                            # Convert to float32 and normalize
                            audio_array = audio_array.astype(np.float32) / 32768.0

                        print(f"Returning audio array: {len(audio_array)} samples", flush=True)
                        return audio_array
                except Exception as wav_error:
                    print(f"TranscriptionService: WAV decoding error: {wav_error}")
                    import traceback
                    traceback.print_exc()

            # Fallback: assume raw float32 audio data
            try:
                audio_array = np.frombuffer(audio_data, dtype=np.float32)
                return audio_array
            except Exception:
                pass

            # Last resort: return an empty array
            return np.array([], dtype=np.float32)

        except Exception as e:
            print(f"TranscriptionService: Audio conversion error: {e}")
            return np.array([], dtype=np.float32)
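The raw-PCM branch is easy to sanity-check by size: 16-bit mono PCM carries 2 bytes per sample, and the 48 kHz to 16 kHz resample cuts the sample count by a factor of 3 while preserving duration.

```python
# Size arithmetic for the raw-PCM decoding path.
import numpy as np
import librosa

one_second = (np.random.uniform(-0.2, 0.2, 48000) * 32767).astype(np.int16)
raw = one_second.tobytes()
print(len(raw))                      # 96000 bytes = 48000 samples * 2 bytes

samples = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768.0
resampled = librosa.resample(samples, orig_sr=48000, target_sr=16000)
print(len(samples), len(resampled))  # 48000 -> 16000 samples, still 1 second
```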
    def _audio_array_to_bytes(self, audio_array: np.ndarray) -> bytes:
        """Convert a numpy audio array back to WAV bytes for storage"""
        try:
            # Ensure float32 format
            if audio_array.dtype != np.float32:
                audio_array = audio_array.astype(np.float32)

            # Convert to 16-bit PCM for WAV storage
            audio_int16 = (audio_array * 32767).astype(np.int16)

            # Create the WAV bytes
            wav_buffer = io.BytesIO()
            with wave.open(wav_buffer, 'wb') as wav_file:
                wav_file.setnchannels(1)      # mono
                wav_file.setsampwidth(2)      # 16-bit
                wav_file.setframerate(16000)  # 16 kHz
                wav_file.writeframes(audio_int16.tobytes())

            return wav_buffer.getvalue()

        except Exception as e:
            print(f"Error converting audio array to bytes: {e}")
            return b''
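Together with `_bytes_to_audio_array`, this gives a round trip at 16 kHz that is lossless up to int16 rounding, which is what lets the buffer live as a numpy array and only gain a WAV header when serialized. A quick check:

```python
# Round-trip check: encode to WAV bytes, decode back, compare.
import numpy as np

service = TranscriptionService()  # __init__ only sets config; no models load here
original = 0.5 * np.sin(np.linspace(0, 20 * np.pi, 16000)).astype(np.float32)
restored = service._bytes_to_audio_array(service._audio_array_to_bytes(original))
print(np.max(np.abs(restored - original)))  # on the order of 1e-5 (int16 rounding)
```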
    def clear_participant_buffers(self, participant_id: str):
        """Clear all buffers for a participant (e.g., when they stop speaking or disconnect)"""
        if participant_id in self.candidate_audio_buffers:
            del self.candidate_audio_buffers[participant_id]
        if participant_id in self.candidate_text_cache:
            del self.candidate_text_cache[participant_id]
        if participant_id in self.silence_counters:
            del self.silence_counters[participant_id]
        if participant_id in self.sentence_finalized:
            del self.sentence_finalized[participant_id]
    async def load_remaining_models_in_background(self):
        """Load all remaining ASR models in the background after startup"""
        try:
            print("ASR: Starting background loading of additional language models...")
            for lang_code in self.asr_config.keys():
                if lang_code not in self.preload_languages and lang_code not in self.asr_pipelines:
                    try:
                        print(f"ASR: Background loading model for {lang_code}...")
                        self.models_loading_status[lang_code] = "loading"

                        model_config = self.asr_config[lang_code]
                        # Use the quantization helper for background loading too
                        pipeline_obj = self._load_and_quantize_pipeline(lang_code, model_config)
                        self.asr_pipelines[lang_code] = pipeline_obj
                        self.models_loading_status[lang_code] = "loaded"
                        print(f"ASR: Successfully loaded model for {lang_code} in background")

                        # Pause briefly between models to avoid overwhelming the system
                        await asyncio.sleep(2)
                    except Exception as e:
                        print(f"ASR: Failed to load model for {lang_code} in background: {e}")
                        self.models_loading_status[lang_code] = "failed"

            print("ASR: Background loading of all language models complete")
            print(f"ASR: Loaded models: {list(self.asr_pipelines.keys())}")
        except Exception as e:
            print(f"ASR: Error in background model loading: {e}")

    def start_background_loading(self):
        """Start background loading of models as a non-blocking task"""
        if self.background_loading_task is None:
            self.background_loading_task = asyncio.create_task(self.load_remaining_models_in_background())
            print("ASR: Background model loading task started")
    async def cleanup(self):
        """Clean up resources"""
        # Cancel background loading if it is still running
        if self.background_loading_task and not self.background_loading_task.done():
            self.background_loading_task.cancel()
            try:
                await self.background_loading_task
            except asyncio.CancelledError:
                pass

        self.asr_pipelines.clear()
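How these pieces are wired at startup lives in app/main.py, which is not shown in this excerpt. A hypothetical FastAPI wiring for the two-phase loading strategy, blocking startup only on the preloaded English model:

```python
# Hypothetical startup/shutdown wiring; the committed app/main.py may differ.
from fastapi import FastAPI

app = FastAPI()
transcription_service = TranscriptionService()

@app.on_event("startup")
async def on_startup() -> None:
    await transcription_service.initialize()          # "eng" + Silero VAD, blocking
    transcription_service.start_background_loading()  # remaining languages, non-blocking

@app.on_event("shutdown")
async def on_shutdown() -> None:
    await transcription_service.cleanup()
```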
app/services/transcription_service.py.bak
ADDED
@@ -0,0 +1,726 @@
| 1 |
+
import asyncio
|
| 2 |
+
import io
|
| 3 |
+
import wave
|
| 4 |
+
import numpy as np
|
| 5 |
+
import time
|
| 6 |
+
from typing import Dict, Optional, Callable
|
| 7 |
+
from transformers import pipeline
|
| 8 |
+
import torch
|
| 9 |
+
from app.models import LanguageCode
|
| 10 |
+
from app.services.performance_mixin import track_performance
|
| 11 |
+
|
| 12 |
+
# Silero VAD imports
|
| 13 |
+
try:
|
| 14 |
+
import silero_vad
|
| 15 |
+
SILERO_VAD_AVAILABLE = True
|
| 16 |
+
except ImportError:
|
| 17 |
+
SILERO_VAD_AVAILABLE = False
|
| 18 |
+
print("Warning: silero-vad not installed. Falling back to RMS-based VAD.")
|
| 19 |
+
|
| 20 |
+
class TranscriptionService:
|
| 21 |
+
def __init__(self):
|
| 22 |
+
self.asr_pipelines: Dict[str, any] = {}
|
| 23 |
+
self.device = 0 if torch.cuda.is_available() else -1
|
| 24 |
+
|
| 25 |
+
# Model configurations - using original mutisya models with updated config
|
| 26 |
+
self.asr_config = {
|
| 27 |
+
"eng": {"model_repo": "openai/whisper-base.en", "model_type": "whisper"},
|
| 28 |
+
"swa": {"model_repo": "mutisya/w2v-bert-2.0-asr-swh-superv-v25-37-1", "model_type": "wav2vec2-bert", "trust_remote_code": True},
|
| 29 |
+
"kik": {"model_repo": "mutisya/w2v-bert-2.0-asr-kik-superv-v25-37-1", "model_type": "wav2vec2-bert", "trust_remote_code": True},
|
| 30 |
+
"kam": {"model_repo": "mutisya/w2v-bert-2.0-asr-kam-superv-v25-36-1", "model_type": "wav2vec2-bert", "trust_remote_code": True},
|
| 31 |
+
"mer": {"model_repo": "mutisya/w2v-bert-2.0-asr-mer-superv-v25-36-1", "model_type": "wav2vec2-bert", "trust_remote_code": True},
|
| 32 |
+
"luo": {"model_repo": "mutisya/w2v-bert-2.0-asr-luo-superv-v25-36-1", "model_type": "wav2vec2-bert", "trust_remote_code": True},
|
| 33 |
+
"som": {"model_repo": "mutisya/w2v-bert-2.0-asr-som-superv-v25-36-1", "model_type": "wav2vec2-bert", "trust_remote_code": True}
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
self.preload_languages = ["eng"]
|
| 37 |
+
self.background_loading_task = None
|
| 38 |
+
self.models_loading_status = {}
|
| 39 |
+
|
| 40 |
+
# Enhanced audio buffering for VAD-based sentence detection
|
| 41 |
+
self.candidate_audio_buffers: Dict[str, bytes] = {} # participant_id -> candidate audio buffer
|
| 42 |
+
self.candidate_text_cache: Dict[str, str] = {} # participant_id -> current candidate text
|
| 43 |
+
self.silence_counters: Dict[str, int] = {} # participant_id -> consecutive silence chunks
|
| 44 |
+
self.sentence_finalized: Dict[str, bool] = {} # participant_id -> whether current sentence is already finalized
|
| 45 |
+
|
| 46 |
+
# VAD parameters - made more lenient for better detection
|
| 47 |
+
self.silence_threshold = 1 # Number of consecutive silent chunks before sentence break (1 second for natural pauses)
|
| 48 |
+
self.min_sentence_length = 0.03 # Minimum sentence length in seconds (very short)
|
| 49 |
+
|
| 50 |
+
# Silero VAD initialization
|
| 51 |
+
self.vad_model = None
|
| 52 |
+
self.vad_sample_rate = 16000
|
| 53 |
+
self.vad_available = SILERO_VAD_AVAILABLE
|
| 54 |
+
|
| 55 |
+
async def initialize(self):
|
| 56 |
+
"""Initialize ASR models for preloaded languages and Silero VAD"""
|
| 57 |
+
# Initialize Silero VAD model
|
| 58 |
+
if self.vad_available:
|
| 59 |
+
try:
|
| 60 |
+
print("Loading Silero VAD model...")
|
| 61 |
+
self.vad_model = silero_vad.load_silero_vad(onnx=False)
|
| 62 |
+
print("✓ Silero VAD model loaded successfully")
|
| 63 |
+
except Exception as e:
|
| 64 |
+
print(f"Failed to load Silero VAD model: {e}")
|
| 65 |
+
print("Falling back to RMS-based VAD")
|
| 66 |
+
self.vad_available = False
|
| 67 |
+
|
| 68 |
+
# Initialize ASR models
|
| 69 |
+
for lang_code in self.preload_languages:
|
| 70 |
+
if lang_code in self.asr_config:
|
| 71 |
+
try:
|
| 72 |
+
model_config = self.asr_config[lang_code]
|
| 73 |
+
# Build pipeline parameters
|
| 74 |
+
pipeline_params = {
|
| 75 |
+
"task": "automatic-speech-recognition",
|
| 76 |
+
"model": model_config["model_repo"],
|
| 77 |
+
"device": self.device,
|
| 78 |
+
"torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
# Add trust_remote_code if specified
|
| 82 |
+
if model_config.get("trust_remote_code", False):
|
| 83 |
+
pipeline_params["trust_remote_code"] = True
|
| 84 |
+
|
| 85 |
+
pipeline_obj = pipeline(**pipeline_params)
|
| 86 |
+
self.asr_pipelines[lang_code] = pipeline_obj
|
| 87 |
+
except Exception as e:
|
| 88 |
+
print(f"Failed to load ASR model for {lang_code}: {e}")
|
| 89 |
+
|
| 90 |
+
async def ensure_model_loaded(self, language_code: str):
|
| 91 |
+
"""Load ASR model for language if not already loaded"""
|
| 92 |
+
if language_code not in self.asr_pipelines and language_code in self.asr_config:
|
| 93 |
+
try:
|
| 94 |
+
model_config = self.asr_config[language_code]
|
| 95 |
+
# Build pipeline parameters
|
| 96 |
+
pipeline_params = {
|
| 97 |
+
"task": "automatic-speech-recognition",
|
| 98 |
+
"model": model_config["model_repo"],
|
| 99 |
+
"device": self.device,
|
| 100 |
+
"torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
# Add trust_remote_code if specified
|
| 104 |
+
if model_config.get("trust_remote_code", False):
|
| 105 |
+
pipeline_params["trust_remote_code"] = True
|
| 106 |
+
|
| 107 |
+
pipeline_obj = pipeline(**pipeline_params)
|
| 108 |
+
self.asr_pipelines[language_code] = pipeline_obj
|
| 109 |
+
except Exception as e:
|
| 110 |
+
print(f"Failed to load ASR model for {language_code}: {e}")
|
| 111 |
+
raise
|
| 112 |
+
|
| 113 |
+
async def process_audio_chunk(self, audio_data: bytes, language_code: str, participant_id: str,
|
| 114 |
+
has_voice_activity: bool = True,
|
| 115 |
+
progress_callback: Optional[Callable] = None,
|
| 116 |
+
sentence_callback: Optional[Callable] = None,
|
| 117 |
+
debug_callback: Optional[Callable] = None) -> str:
|
| 118 |
+
"""Process audio chunk with VAD-based sentence detection"""
|
| 119 |
+
try:
|
| 120 |
+
# Initialize buffers if needed
|
| 121 |
+
if participant_id not in self.candidate_audio_buffers:
|
| 122 |
+
# Store as numpy array, not bytes, to avoid multiple WAV header issues
|
| 123 |
+
self.candidate_audio_buffers[participant_id] = np.array([], dtype=np.float32)
|
| 124 |
+
self.candidate_text_cache[participant_id] = ""
|
| 125 |
+
self.silence_counters[participant_id] = 0
|
| 126 |
+
self.sentence_finalized[participant_id] = False
|
| 127 |
+
|
| 128 |
+
# Convert current chunk to numpy array for processing
|
| 129 |
+
current_chunk_array = self._bytes_to_audio_array(audio_data)
|
| 130 |
+
if len(current_chunk_array) == 0:
|
| 131 |
+
print(f"WARNING: Received empty audio chunk for participant {participant_id}")
|
| 132 |
+
return self.candidate_text_cache.get(participant_id, "")
|
| 133 |
+
|
| 134 |
+
print(f"DEBUG: Received audio chunk - bytes: {len(audio_data)}, samples: {len(current_chunk_array)}, "
|
| 135 |
+
f"duration: {len(current_chunk_array)/16000:.3f}s, "
|
| 136 |
+
f"first 4 bytes: {audio_data[:4]}")
|
| 137 |
+
|
| 138 |
+
# DO NOT normalize individual chunks - this causes audio distortion
|
| 139 |
+
# We'll normalize the entire accumulated audio buffer before transcription
|
| 140 |
+
current_chunk_array = current_chunk_array.astype(np.float32)
|
| 141 |
+
|
| 142 |
+
# Get existing accumulated audio array (now stored as numpy array)
|
| 143 |
+
existing_array = self.candidate_audio_buffers[participant_id]
|
| 144 |
+
if len(existing_array) > 0:
|
| 145 |
+
# Concatenate with existing audio (like stream = np.concatenate([stream, y]))
|
| 146 |
+
combined_array = np.concatenate([existing_array, current_chunk_array])
|
| 147 |
+
else:
|
| 148 |
+
combined_array = current_chunk_array
|
| 149 |
+
|
| 150 |
+
# Store as numpy array to avoid WAV header accumulation issues
|
| 151 |
+
self.candidate_audio_buffers[participant_id] = combined_array
|
| 152 |
+
|
| 153 |
+
# For debug callback, convert to bytes (this adds ONE WAV header)
|
| 154 |
+
combined_bytes = self._audio_array_to_bytes(combined_array)
|
| 155 |
+
|
| 156 |
+
# Update silence counter based on voice activity
|
| 157 |
+
if not has_voice_activity:
|
| 158 |
+
self.silence_counters[participant_id] += 1
|
| 159 |
+
else:
|
| 160 |
+
self.silence_counters[participant_id] = 0
|
| 161 |
+
|
| 162 |
+
# Check if we should finalize sentence due to prolonged silence
|
| 163 |
+
should_finalize = (self.silence_counters[participant_id] >= self.silence_threshold and
|
| 164 |
+
len(combined_array) > 0 and
|
| 165 |
+
not self.sentence_finalized[participant_id])
|
| 166 |
+
|
| 167 |
+
if should_finalize:
|
| 168 |
+
return await self._finalize_candidate_sentence(
|
| 169 |
+
language_code, participant_id, sentence_callback
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
# Always run transcription on the accumulated audio
|
| 173 |
+
audio_duration_sec = len(combined_array) / 16000.0 # 16kHz sample rate
|
| 174 |
+
|
| 175 |
+
# Minimum duration check - ignore very short audio bursts
|
| 176 |
+
MIN_CHUNK_DURATION = 0.3 # 300ms minimum
|
| 177 |
+
if audio_duration_sec < MIN_CHUNK_DURATION:
|
| 178 |
+
print(f"Audio chunk too short: {audio_duration_sec:.3f}s < {MIN_CHUNK_DURATION}s, skipping transcription")
|
| 179 |
+
if progress_callback:
|
| 180 |
+
cached_text = self.candidate_text_cache.get(participant_id, "")
|
| 181 |
+
await progress_callback(cached_text, False)
|
| 182 |
+
return self.candidate_text_cache.get(participant_id, "")
|
| 183 |
+
|
| 184 |
+
# Force finalization if buffer gets too long (prevent infinite accumulation)
|
| 185 |
+
if audio_duration_sec > 15.0 and not self.sentence_finalized[participant_id]: # Force completion after 15 seconds
|
| 186 |
+
return await self._finalize_candidate_sentence(
|
| 187 |
+
language_code, participant_id, sentence_callback
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
# Run voice activity detection on the accumulated audio before transcription
|
| 191 |
+
has_voice_in_buffer = self.has_meaningful_voice_activity(combined_bytes)
|
| 192 |
+
|
| 193 |
+
if not has_voice_in_buffer:
|
| 194 |
+
# Still send progress update with cached text to maintain UI state
|
| 195 |
+
if progress_callback:
|
| 196 |
+
cached_text = self.candidate_text_cache.get(participant_id, "")
|
| 197 |
+
await progress_callback(cached_text, False)
|
| 198 |
+
return self.candidate_text_cache.get(participant_id, "")
|
| 199 |
+
|
| 200 |
+
# Run transcription
|
| 201 |
+
await self.ensure_model_loaded(language_code)
|
| 202 |
+
|
| 203 |
+
# Double-check voice activity before running expensive ASR
|
| 204 |
+
has_voice_for_asr = self.has_voice_activity(combined_bytes)
|
| 205 |
+
if not has_voice_for_asr:
|
| 206 |
+
print(f"ASR: No voice activity detected in audio buffer for participant {participant_id}, skipping ASR execution")
|
| 207 |
+
# Return cached text and send progress update
|
| 208 |
+
if progress_callback:
|
| 209 |
+
cached_text = self.candidate_text_cache.get(participant_id, "")
|
| 210 |
+
await progress_callback(cached_text, False)
|
| 211 |
+
return self.candidate_text_cache.get(participant_id, "")
|
| 212 |
+
|
| 213 |
+
if language_code not in self.asr_pipelines:
|
| 214 |
+
raise ValueError(f"ASR model not available for language: {language_code}")
|
| 215 |
+
|
| 216 |
+
print(f"ASR: Running transcription for participant {participant_id} with {len(combined_array)/16000:.2f}s of audio")
|
| 217 |
+
pipeline_obj = self.asr_pipelines[language_code]
|
| 218 |
+
|
| 219 |
+
# Normalize the ENTIRE accumulated audio buffer before transcription
|
| 220 |
+
# This prevents audio distortion from per-chunk normalization
|
| 221 |
+
normalized_array = combined_array.astype(np.float32)
|
| 222 |
+
max_val = np.max(np.abs(normalized_array))
|
| 223 |
+
if max_val > 0:
|
| 224 |
+
normalized_array = normalized_array / max_val
|
| 225 |
+
|
| 226 |
+
# Track transcription latency
|
| 227 |
+
transcription_start_time = time.time()
|
| 228 |
+
|
| 229 |
+
# For wav2vec2 models, request word timestamps
|
| 230 |
+
model_type = self.asr_config[language_code].get("model_type", "whisper")
|
| 231 |
+
if model_type in ["wav2vec2-bert", "wav2vec2"]:
|
| 232 |
+
result = pipeline_obj(
|
| 233 |
+
{"sampling_rate": 16000, "raw": normalized_array},
|
| 234 |
+
return_timestamps="word"
|
| 235 |
+
)
|
| 236 |
+
else:
|
| 237 |
+
# Whisper model - add anti-hallucination parameters
|
| 238 |
+
# Note: HuggingFace pipeline uses different parameter names than OpenAI Whisper
|
| 239 |
+
result = pipeline_obj(
|
| 240 |
+
{"sampling_rate": 16000, "raw": normalized_array},
|
| 241 |
+
return_timestamps=True,
|
| 242 |
+
chunk_length_s=30, # Process in 30s chunks
|
| 243 |
+
stride_length_s=5 # 5s stride for context
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
transcription_latency_ms = (time.time() - transcription_start_time) * 1000
|
| 247 |
+
|
| 248 |
+
candidate_text = result.get("text", "").strip()
|
| 249 |
+
word_timestamps = result.get("chunks", []) if model_type in ["wav2vec2-bert", "wav2vec2"] else None
|
| 250 |
+
|
| 251 |
+
# Send debug information if callback provided (for wav2vec2 models only)
|
| 252 |
+
if debug_callback and word_timestamps is not None:
|
| 253 |
+
debug_info = {
|
| 254 |
+
"text": candidate_text,
|
| 255 |
+
"timestamps": word_timestamps,
|
| 256 |
+
"audio_data": combined_bytes,
|
| 257 |
+
"audio_duration": audio_duration_sec,
|
| 258 |
+
"model_type": model_type,
|
| 259 |
+
"transcription_latency_ms": transcription_latency_ms
|
| 260 |
+
}
|
| 261 |
+
await debug_callback(debug_info)
|
| 262 |
+
|
| 263 |
+
# Filter out common ASR artifacts and very short responses
|
| 264 |
+
artifacts = [
|
| 265 |
+
"thank you", "thanks", "bye", ".", ",", "?", "!",
|
| 266 |
+
"um", "uh", "ah", "hmm", "mm", "mhm",
|
| 267 |
+
"you", "the", "a", "an", "and", "but", "or",
|
| 268 |
+
"music", "laughter", "applause", "[music]", "[laughter]",
|
| 269 |
+
# Common Whisper hallucinations:
|
| 270 |
+
"subscribe", "subtitles", "amara", "www", "http",
|
| 271 |
+
"please subscribe", "like and subscribe",
|
| 272 |
+
"thank you for watching", "don't forget to subscribe",
|
| 273 |
+
"[blank_audio]", "[noise]", "[silence]",
|
| 274 |
+
]
|
| 275 |
+
|
| 276 |
+
# Check if the result is likely an artifact
|
| 277 |
+
is_artifact = (
|
| 278 |
+
len(candidate_text) < 3 or # Very short
|
| 279 |
+
candidate_text.lower() in artifacts or # Common artifacts
|
| 280 |
+
len(candidate_text.split()) == 1 and len(candidate_text) < 6 # Single very short word
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
if is_artifact:
|
| 284 |
+
# Keep the previous cached text instead of updating with artifact
|
| 285 |
+
candidate_text = self.candidate_text_cache.get(participant_id, "")
|
| 286 |
+
|
| 287 |
+
# Cache the current candidate text
|
| 288 |
+
self.candidate_text_cache[participant_id] = candidate_text
|
| 289 |
+
|
| 290 |
+
# Force completion if we have a reasonable amount of text and some silence
|
| 291 |
+
word_count = len(candidate_text.split()) if candidate_text else 0
|
| 292 |
+
if (word_count >= 3 and self.silence_counters[participant_id] >= 2 and
|
| 293 |
+
not self.sentence_finalized[participant_id]): # At least 3 words and 2 silent chunks
|
| 294 |
+
return await self._finalize_candidate_sentence(
|
| 295 |
+
language_code, participant_id, sentence_callback
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
# Always send progress update
|
| 299 |
+
if progress_callback:
|
| 300 |
+
await progress_callback(candidate_text, False)
|
| 301 |
+
|
| 302 |
+
return candidate_text
|
| 303 |
+
|
| 304 |
+
except Exception as e:
|
| 305 |
+
print(f"TranscriptionService: Error processing audio chunk: {e}")
|
| 306 |
+
import traceback
|
| 307 |
+
traceback.print_exc()
|
| 308 |
+
# Even on error, try to send cached text
|
| 309 |
+
if progress_callback:
|
| 310 |
+
cached_text = self.candidate_text_cache.get(participant_id, "")
|
| 311 |
+
await progress_callback(cached_text, False)
|
| 312 |
+
return self.candidate_text_cache.get(participant_id, "")
|
| 313 |
+
|
| 314 |
+
async def _finalize_candidate_sentence(self, language_code: str, participant_id: str,
|
| 315 |
+
sentence_callback: Optional[Callable] = None) -> str:
|
| 316 |
+
"""Finalize the current candidate sentence and clear buffers"""
|
| 317 |
+
try:
|
| 318 |
+
# Check if sentence was already finalized
|
| 319 |
+
if self.sentence_finalized.get(participant_id, False):
|
| 320 |
+
print(f"Sentence for participant {participant_id} already finalized, skipping duplicate")
|
| 321 |
+
return self.candidate_text_cache.get(participant_id, "")
|
| 322 |
+
|
| 323 |
+
final_text = self.candidate_text_cache.get(participant_id, "")
|
| 324 |
+
final_audio_array = self.candidate_audio_buffers.get(participant_id, np.array([], dtype=np.float32))
|
| 325 |
+
|
| 326 |
+
# Convert audio array to bytes for VAD check and callback
|
| 327 |
+
final_audio_bytes = self._audio_array_to_bytes(final_audio_array) if len(final_audio_array) > 0 else b''
|
| 328 |
+
|
| 329 |
+
if final_text and len(final_text.strip()) > 0:
|
| 330 |
+
# Run VAD check on the final accumulated buffer before sending for translation
|
| 331 |
+
if len(final_audio_bytes) > 0:
|
| 332 |
+
has_voice_in_final = self.has_meaningful_voice_activity(final_audio_bytes)
|
| 333 |
+
if not has_voice_in_final:
|
| 334 |
+
print(f"Finalize: No voice activity in final buffer for participant {participant_id}, discarding sentence: '{final_text}'")
|
| 335 |
+
# Clear buffers without sending to translation
|
| 336 |
+
self.candidate_audio_buffers[participant_id] = np.array([], dtype=np.float32)
|
| 337 |
+
self.candidate_text_cache[participant_id] = ""
|
| 338 |
+
self.silence_counters[participant_id] = 0
|
| 339 |
+
self.sentence_finalized[participant_id] = False
|
| 340 |
+
return ""
|
| 341 |
+
|
| 342 |
+
# Mark as finalized BEFORE calling the callback to prevent race conditions
|
| 343 |
+
self.sentence_finalized[participant_id] = True
|
| 344 |
+
|
| 345 |
+
# Send to sentence callback for translation
|
| 346 |
+
if sentence_callback and len(final_audio_bytes) > 0:
|
| 347 |
+
print(f"Finalizing sentence for participant {participant_id}: '{final_text}'")
|
| 348 |
+
await sentence_callback(final_text, final_audio_bytes)
|
| 349 |
+
|
| 350 |
+
# Clear buffers for next sentence
|
| 351 |
+
self.candidate_audio_buffers[participant_id] = np.array([], dtype=np.float32)
|
| 352 |
+
self.candidate_text_cache[participant_id] = ""
|
| 353 |
+
self.silence_counters[participant_id] = 0
|
| 354 |
+
self.sentence_finalized[participant_id] = False # Reset for next sentence
|
| 355 |
+
|
| 356 |
+
return final_text
|
| 357 |
+
|
| 358 |
+
except Exception as e:
|
| 359 |
+
print(f"Error finalizing sentence: {e}")
|
| 360 |
+
import traceback
|
| 361 |
+
traceback.print_exc()
|
| 362 |
+
# Reset finalized flag on error
|
| 363 |
+
self.sentence_finalized[participant_id] = False
|
| 364 |
+
return ""
|
| 365 |
+
|
| 366 |
+
    def has_voice_activity(self, audio_data: bytes, threshold: float = 0.5) -> bool:
        """Voice Activity Detection using Silero VAD (with RMS fallback)"""
        try:
            audio_array = self._bytes_to_audio_array(audio_data)
            if len(audio_array) == 0:
                print("VAD: No audio array, returning False")
                return False

            # Normalize audio to float32 range [-1, 1]
            audio_array = audio_array.astype(np.float32)
            if np.max(np.abs(audio_array)) > 0:
                audio_array /= np.max(np.abs(audio_array))

            # Use Silero VAD if available
            if self.vad_available and self.vad_model is not None:
                try:
                    # Silero VAD expects 512 samples (32ms) or 1536 samples (96ms) for 16kHz
                    # Process audio in chunks and average the probabilities
                    frame_size = 512  # 32ms at 16kHz
                    num_samples = len(audio_array)

                    # If audio is too short, pad it
                    if num_samples < frame_size:
                        audio_array = np.pad(audio_array, (0, frame_size - num_samples), mode='constant')
                        num_samples = frame_size

                    # Process in frames and collect probabilities
                    speech_probs = []
                    for i in range(0, num_samples, frame_size):
                        frame = audio_array[i:i + frame_size]
                        if len(frame) < frame_size:
                            # Pad last frame if needed
                            frame = np.pad(frame, (0, frame_size - len(frame)), mode='constant')

                        # Convert to torch tensor
                        frame_tensor = torch.from_numpy(frame).float()

                        # Get speech probability from Silero VAD
                        with torch.no_grad():
                            prob = self.vad_model(frame_tensor, self.vad_sample_rate).item()
                        speech_probs.append(prob)

                    # Average probability across all frames
                    speech_prob = np.mean(speech_probs)
                    has_voice = speech_prob > threshold

                    print(f"VAD: Silero speech_prob={speech_prob:.4f} (avg of {len(speech_probs)} frames), threshold={threshold}, RESULT={has_voice}")

                    return has_voice

                except Exception as e:
                    print(f"Silero VAD error: {e}, falling back to RMS-based VAD")
                    # Fall through to RMS-based VAD below

            # Fallback: RMS-based VAD (original implementation)
            rms_threshold = 0.002
            rms = np.sqrt(np.mean(audio_array ** 2))
            peak = np.max(np.abs(audio_array))
            audio_std = np.std(audio_array)
            zero_crossing_rate = np.sum(np.diff(np.sign(audio_array)) != 0) / len(audio_array)

            has_voice_rms = rms > rms_threshold
            has_voice_peak = peak > rms_threshold * 3
            has_voice_variation = audio_std > rms_threshold * 0.8
            has_voice_zcr = zero_crossing_rate > 0.008

            has_voice = has_voice_rms or (has_voice_peak and has_voice_variation) or has_voice_zcr

            print(f"VAD: RMS-based - RMS={rms:.6f}({has_voice_rms}), peak={peak:.6f}({has_voice_peak}), std={audio_std:.6f}({has_voice_variation}), zcr={zero_crossing_rate:.6f}({has_voice_zcr}), RESULT={has_voice}")

            return has_voice

        except Exception as e:
            print(f"Error in VAD: {e}")
            return True  # Default to assuming voice activity on error

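    # --- Editor's sketch (not part of this commit): one plausible way to set up the
    # --- `vad_available` / `vad_model` / `vad_sample_rate` attributes consumed by
    # --- has_voice_activity() above, via the public snakers4/silero-vad torch.hub
    # --- entry point. The method name and defaults here are illustrative assumptions.
    def _init_silero_vad_example(self):
        try:
            # torch.hub downloads and caches the Silero VAD model on first use
            self.vad_model, _vad_utils = torch.hub.load(
                repo_or_dir="snakers4/silero-vad", model="silero_vad"
            )
            self.vad_sample_rate = 16000  # Silero VAD accepts 8kHz or 16kHz input
            self.vad_available = True
        except Exception as e:
            print(f"VAD: Silero VAD unavailable ({e}); falling back to RMS-based VAD")
            self.vad_model = None
            self.vad_available = False
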
    def has_meaningful_voice_activity(self, audio_data: bytes, threshold: float = 0.005) -> bool:
        """Stricter VAD check specifically for pre-transcription filtering"""
        try:
            audio_array = self._bytes_to_audio_array(audio_data)
            if len(audio_array) == 0:
                return False

            # Normalize audio
            audio_array = audio_array.astype(np.float32)
            if np.max(np.abs(audio_array)) > 0:
                audio_array /= np.max(np.abs(audio_array))

            # Calculate features with higher thresholds for meaningful speech
            rms = np.sqrt(np.mean(audio_array ** 2))
            peak = np.max(np.abs(audio_array))
            audio_std = np.std(audio_array)
            zero_crossing_rate = np.sum(np.diff(np.sign(audio_array)) != 0) / len(audio_array)

            # Higher thresholds for meaningful speech detection
            has_meaningful_voice = (
                rms > threshold and
                peak > threshold * 2 and
                audio_std > threshold * 0.5 and
                zero_crossing_rate > 0.015  # Higher ZCR threshold for meaningful speech
            )

            return has_meaningful_voice

        except Exception as e:
            print(f"Error in meaningful VAD: {e}")
            return False  # Default to no meaningful voice on error

    async def force_complete_sentence(self, participant_id: str, language_code: str, sentence_callback: Optional[Callable] = None) -> str:
        """Force complete any pending sentence for a participant"""
        try:
            # Check if sentence was already finalized
            if self.sentence_finalized.get(participant_id, False):
                print(f"Force completion: Sentence for participant {participant_id} already finalized, skipping")
                return ""

            if participant_id in self.candidate_text_cache:
                cached_text = self.candidate_text_cache[participant_id]

                if cached_text and len(cached_text.strip()) > 0:
                    result = await self._finalize_candidate_sentence(language_code, participant_id, sentence_callback)
                    return result

            return ""

        except Exception as e:
            print(f"Error in force_complete_sentence: {e}")
            import traceback
            traceback.print_exc()
            return ""

    @track_performance("transcription", "transcribe_audio")
    async def transcribe_audio(self, audio_data: bytes, language_code: str, callback: Optional[Callable] = None) -> str:
        """Transcribe audio data to text"""
        try:
            # Check for voice activity before running ASR
            has_voice = self.has_voice_activity(audio_data)
            if not has_voice:
                print("ASR: No voice activity detected in audio data, skipping transcription")
                return ""

            await self.ensure_model_loaded(language_code)

            if language_code not in self.asr_pipelines:
                raise ValueError(f"ASR model not available for language: {language_code}")

            # Convert audio bytes to numpy array
            audio_array = self._bytes_to_audio_array(audio_data)

            print(f"ASR: Running transcription with {len(audio_array)/16000:.2f}s of audio")
            # Transcribe
            pipeline_obj = self.asr_pipelines[language_code]
            result = pipeline_obj({"sampling_rate": 16000, "raw": audio_array})

            text = result.get("text", "")

            if callback:
                await callback(text)

            return text

        except Exception as e:
            print(f"TranscriptionService: Transcription error: {e}")
            import traceback
            traceback.print_exc()
            return ""

    def _bytes_to_audio_array(self, audio_data: bytes) -> np.ndarray:
        """Convert audio bytes to numpy array (supports raw PCM, WAV, WebM/Opus)"""
        try:
            # Detect format by checking magic bytes
            is_webm = audio_data[:4] == b'\x1a\x45\xdf\xa3'  # WebM/Matroska magic bytes
            is_wav = audio_data[:4] == b'RIFF'

            import sys
            print(f"_bytes_to_audio_array: length={len(audio_data)}, first 4 bytes={audio_data[:4]}, is_wav={is_wav}", flush=True)
            sys.stdout.flush()

            # Handle raw PCM (16-bit, 48kHz from extendable-media-recorder)
            # This is the most common case now that we strip WAV headers in frontend
            if not is_wav and not is_webm and len(audio_data) > 0:
                try:
                    # Assume 16-bit PCM at 48kHz (browser's native rate)
                    audio_array = np.frombuffer(audio_data, dtype=np.int16)

                    # int16 samples cannot be NaN (np.isnan is undefined for integer
                    # dtypes), so a non-empty buffer is enough to proceed
                    if len(audio_array) > 0:
                        print(f"Raw PCM: {len(audio_array)} samples, assuming 48kHz 16-bit", flush=True)

                        # Convert to float32 and normalize
                        audio_float = audio_array.astype(np.float32) / 32768.0

                        # Resample from 48kHz to 16kHz
                        import librosa
                        audio_array = librosa.resample(audio_float, orig_sr=48000, target_sr=16000)
                        print(f"Resampled to 16kHz: {len(audio_array)} samples", flush=True)

                        return audio_array
                except Exception as pcm_error:
                    print(f"TranscriptionService: Raw PCM decoding error: {pcm_error}", flush=True)
                    # Fall through to other methods

            if is_webm:
                # Decode WebM/Opus using pydub (requires ffmpeg)
                try:
                    from pydub import AudioSegment
                    audio_io = io.BytesIO(audio_data)
                    audio_segment = AudioSegment.from_file(audio_io, format="webm")

                    # Convert to mono 16kHz
                    audio_segment = audio_segment.set_channels(1)
                    audio_segment = audio_segment.set_frame_rate(16000)

                    # Convert to numpy array
                    samples = np.array(audio_segment.get_array_of_samples(), dtype=np.int16)
                    # Normalize to float32 [-1, 1]
                    audio_array = samples.astype(np.float32) / 32768.0
                    return audio_array
                except Exception as webm_error:
                    print(f"TranscriptionService: WebM decoding error: {webm_error}")
                    # Fall through to other methods

            if is_wav:
                # Decode WAV format
                try:
                    audio_io = io.BytesIO(audio_data)
                    with wave.open(audio_io, 'rb') as wav_file:
                        sample_rate = wav_file.getframerate()
                        channels = wav_file.getnchannels()
                        sample_width = wav_file.getsampwidth()

                        print(f"WAV format: {sample_rate}Hz, {channels} channel(s), {sample_width*8}-bit", flush=True)

                        frames = wav_file.readframes(-1)
                        audio_array = np.frombuffer(frames, dtype=np.int16)

                        # Resample if needed
                        if sample_rate != 16000:
                            print(f"WARNING: Resampling from {sample_rate}Hz to 16000Hz", flush=True)
                            import librosa
                            # Convert to float first
                            audio_float = audio_array.astype(np.float32) / 32768.0
                            # Resample
                            audio_array = librosa.resample(audio_float, orig_sr=sample_rate, target_sr=16000)
                            print(f"Resampled: {len(audio_array)} samples at 16kHz", flush=True)
                        else:
                            # Convert to float32 and normalize
                            audio_array = audio_array.astype(np.float32) / 32768.0

                        print(f"Returning audio array: {len(audio_array)} samples", flush=True)
                        return audio_array
                except Exception as wav_error:
                    print(f"TranscriptionService: WAV decoding error: {wav_error}")
                    import traceback
                    traceback.print_exc()

            # Fallback: assume raw float32 audio data
            try:
                audio_array = np.frombuffer(audio_data, dtype=np.float32)
                return audio_array
            except Exception:
                pass

            # Last resort: return empty array
            return np.array([], dtype=np.float32)

        except Exception as e:
            print(f"TranscriptionService: Audio conversion error: {e}")
            return np.array([], dtype=np.float32)

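    # --- Editor's sketch (not part of this commit): how to build the raw 16-bit/48kHz
    # --- PCM bytes that _bytes_to_audio_array() treats as its default input (the path
    # --- taken when neither a RIFF nor a WebM magic header is present). The test tone
    # --- below is purely illustrative; real input comes from the browser recorder.
    @staticmethod
    def _example_raw_pcm_bytes(duration_sec: float = 0.5, sample_rate: int = 48000) -> bytes:
        t = np.linspace(0.0, duration_sec, int(sample_rate * duration_sec), endpoint=False)
        tone = 0.3 * np.sin(2 * np.pi * 440.0 * t)        # 440 Hz tone in [-1, 1]
        return (tone * 32767).astype(np.int16).tobytes()  # 16-bit little-endian PCM
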
    def _audio_array_to_bytes(self, audio_array: np.ndarray) -> bytes:
        """Convert numpy audio array back to WAV bytes for storage"""
        try:
            # Ensure float32 format
            if audio_array.dtype != np.float32:
                audio_array = audio_array.astype(np.float32)

            # Convert to 16-bit PCM for WAV storage
            audio_int16 = (audio_array * 32767).astype(np.int16)

            # Create WAV bytes
            wav_buffer = io.BytesIO()
            with wave.open(wav_buffer, 'wb') as wav_file:
                wav_file.setnchannels(1)  # Mono
                wav_file.setsampwidth(2)  # 16-bit
                wav_file.setframerate(16000)  # 16kHz
                wav_file.writeframes(audio_int16.tobytes())

            return wav_buffer.getvalue()

        except Exception as e:
            print(f"Error converting audio array to bytes: {e}")
            return b''

    def clear_participant_buffers(self, participant_id: str):
        """Clear all buffers for a participant (e.g., when they stop speaking or disconnect)"""
        if participant_id in self.candidate_audio_buffers:
            del self.candidate_audio_buffers[participant_id]
        if participant_id in self.candidate_text_cache:
            del self.candidate_text_cache[participant_id]
        if participant_id in self.silence_counters:
            del self.silence_counters[participant_id]
        if participant_id in self.sentence_finalized:
            del self.sentence_finalized[participant_id]

    async def load_remaining_models_in_background(self):
        """Load all remaining ASR models in the background after startup"""
        try:
            print("ASR: Starting background loading of additional language models...")
            for lang_code in self.asr_config.keys():
                if lang_code not in self.preload_languages and lang_code not in self.asr_pipelines:
                    try:
                        print(f"ASR: Background loading model for {lang_code}...")
                        self.models_loading_status[lang_code] = "loading"

                        model_config = self.asr_config[lang_code]
                        # Build pipeline parameters
                        pipeline_params = {
                            "task": "automatic-speech-recognition",
                            "model": model_config["model_repo"],
                            "device": self.device,
                            "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32
                        }

                        # Add trust_remote_code if specified
                        if model_config.get("trust_remote_code", False):
                            pipeline_params["trust_remote_code"] = True

                        pipeline_obj = pipeline(**pipeline_params)
                        self.asr_pipelines[lang_code] = pipeline_obj
                        self.models_loading_status[lang_code] = "loaded"
                        print(f"ASR: Successfully loaded model for {lang_code} in background")

                        # Add a small delay between loading models to prevent overwhelming the system
                        await asyncio.sleep(2)
                    except Exception as e:
                        print(f"ASR: Failed to load model for {lang_code} in background: {e}")
                        self.models_loading_status[lang_code] = "failed"

            print("ASR: Background loading of all language models complete")
            print(f"ASR: Loaded models: {list(self.asr_pipelines.keys())}")
        except Exception as e:
            print(f"ASR: Error in background model loading: {e}")

    def start_background_loading(self):
        """Start background loading of models as a non-blocking task"""
        if self.background_loading_task is None:
            self.background_loading_task = asyncio.create_task(self.load_remaining_models_in_background())
            print("ASR: Background model loading task started")

    async def cleanup(self):
        """Cleanup resources"""
        # Cancel background loading if still running
        if self.background_loading_task and not self.background_loading_task.done():
            self.background_loading_task.cancel()
            try:
                await self.background_loading_task
            except asyncio.CancelledError:
                pass

        self.asr_pipelines.clear()
app/services/transcription_service_onnx.py
ADDED
@@ -0,0 +1,682 @@
import asyncio
import io
import wave
import numpy as np
from typing import Dict, Optional, Callable
from collections import OrderedDict
import onnxruntime as ort
from transformers import AutoProcessor, WhisperProcessor
from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
import os
from app.models import LanguageCode

class ONNXTranscriptionService:
    def __init__(self):
        self.asr_models: Dict[str, any] = {}
        self.processors: Dict[str, any] = {}
        self.max_asr_models = 2  # Memory management - keep max 2 models loaded
        self.model_cache = OrderedDict()  # LRU cache for models

        # GPU optimization - detect and configure providers
        available_providers = ort.get_available_providers()
        print(f"ONNX ASR: Available providers: {available_providers}")

        if 'CUDAExecutionProvider' in available_providers:
            # Configure CUDA provider with optimizations
            cuda_provider_options = {
                'device_id': 0,
                'arena_extend_strategy': 'kNextPowerOfTwo',
                'gpu_mem_limit': int(0.8 * 1024 * 1024 * 1024),  # 80% of GPU memory
                'cudnn_conv_algo_search': 'EXHAUSTIVE',
                'do_copy_in_default_stream': True,
                'enable_tracing': True,  # Enable tracing for better diagnostics
            }

            # Include TensorRT if available, then CUDA, then CPU
            provider_list = []
            if 'TensorrtExecutionProvider' in available_providers:
                provider_list.append('TensorrtExecutionProvider')
            provider_list.append(('CUDAExecutionProvider', cuda_provider_options))
            provider_list.append('CPUExecutionProvider')

            self.providers = provider_list
            print(f"ONNX ASR: Using GPU acceleration with providers: {[p[0] if isinstance(p, tuple) else p for p in provider_list]}")
            print(f"ONNX ASR: GPU memory limit: {cuda_provider_options['gpu_mem_limit'] // (1024**3)}GB")
        else:
            self.providers = ['CPUExecutionProvider']
            print("ONNX ASR: CUDA not available, using CPU execution")

        print(f"ONNX ASR: Configured providers: {[p[0] if isinstance(p, tuple) else p for p in self.providers]}")

        # ONNX Model configurations - using pre-converted ONNX models from HuggingFace
        self.asr_config = {
            "eng": {"model_repo": "mutisya/whisper-medium-en-onnx", "model_type": "whisper", "use_onnx": True},  # Pre-converted ONNX model
            "swa": {"model_repo": "mutisya/w2v-bert-2.0-asr-onnx-swh-v25-37-1", "model_type": "wav2vec2-bert", "use_onnx": True},
            "kik": {"model_repo": "mutisya/w2v-bert-2.0-asr-onnx-kik-v25-37-1", "model_type": "wav2vec2-bert", "use_onnx": True},
            "kam": {"model_repo": "mutisya/w2v-bert-2.0-asr-onnx-kam-v25-37-1", "model_type": "wav2vec2-bert", "use_onnx": True},
            "mer": {"model_repo": "mutisya/w2v-bert-2.0-asr-onnx-mer-v25-37-1", "model_type": "wav2vec2-bert", "use_onnx": True},
            "luo": {"model_repo": "mutisya/w2v-bert-2.0-asr-onnx-luo-v25-37-1", "model_type": "wav2vec2-bert", "use_onnx": True},
            "som": {"model_repo": "mutisya/w2v-bert-2.0-asr-onnx-som-v25-37-1", "model_type": "wav2vec2-bert", "use_onnx": True}
        }

        # Alternative model configurations for different performance tiers
        self.alternative_models = {
            "eng_small": {"model_repo": "mutisya/whisper-small-en-onnx", "model_type": "whisper", "use_onnx": True},
            "eng_base": {"model_repo": "mutisya/whisper-base-en-onnx", "model_type": "whisper", "use_onnx": True},
            "eng_medium": {"model_repo": "mutisya/whisper-medium-en-onnx", "model_type": "whisper", "use_onnx": True}
        }

        self.preload_languages = ["eng"]

        # Current model performance mode (small, base, medium)
        # Can be configured via environment variable WHISPER_MODEL_SIZE
        self.performance_mode = os.getenv("WHISPER_MODEL_SIZE", "medium").lower()

        # Enhanced audio buffering for VAD-based sentence detection
        self.candidate_audio_buffers: Dict[str, bytes] = {}
        self.candidate_text_cache: Dict[str, str] = {}
        self.silence_counters: Dict[str, int] = {}
        self.sentence_finalized: Dict[str, bool] = {}

        # VAD parameters
        self.silence_threshold = 2
        self.min_sentence_length = 0.03

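    # --- Editor's sketch (not part of this commit): probing the provider detection the
    # --- constructor above branches on, using only public onnxruntime APIs
    # --- (ort.get_device / ort.get_available_providers).
    @staticmethod
    def _example_report_providers() -> None:
        # The plain `onnxruntime` wheel reports only CPUExecutionProvider; CUDA and
        # TensorRT entries show up with the `onnxruntime-gpu` build on a GPU machine.
        print(f"onnxruntime device: {ort.get_device()}")
        for provider in ort.get_available_providers():
            print(f"  available provider: {provider}")
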
    def set_performance_mode(self, mode: str):
        """Set the performance mode for English models (small, base, medium)"""
        if mode in ["small", "base", "medium"]:
            self.performance_mode = mode
            # Update the English model configuration based on performance mode
            if f"eng_{mode}" in self.alternative_models:
                self.asr_config["eng"] = self.alternative_models[f"eng_{mode}"]
                # Clear cached English model to force reload with new configuration
                if "eng" in self.model_cache:
                    del self.model_cache["eng"]
                if "eng" in self.asr_models:
                    del self.asr_models["eng"]
                if "eng" in self.processors:
                    del self.processors["eng"]
                print(f"Performance mode set to {mode}. English model will be reloaded on next use.")
            else:
                print(f"Warning: No model configuration found for performance mode {mode}")
        else:
            print(f"Invalid performance mode: {mode}. Must be one of: small, base, medium")

    async def initialize(self):
        """Initialize ASR models for preloaded languages"""
        print(f"ONNX ASR: Initializing with providers: {self.providers}")

        # Apply performance mode to English model configuration
        if self.performance_mode in ["small", "base", "medium"]:
            if f"eng_{self.performance_mode}" in self.alternative_models:
                self.asr_config["eng"] = self.alternative_models[f"eng_{self.performance_mode}"]
                print(f"Using Whisper {self.performance_mode} model for English")
            else:
                print(f"Warning: Performance mode {self.performance_mode} not available, using default medium")

        for lang_code in self.preload_languages:
            if lang_code in self.asr_config:
                try:
                    await self.ensure_model_loaded(lang_code)
                except Exception as e:
                    print(f"Failed to load ASR model for {lang_code}: {e}")

    async def ensure_model_loaded(self, language_code: str):
        """Load ASR model for language if not already loaded with LRU cache"""
        if language_code in self.model_cache:
            # Move to end (most recently used)
            self.model_cache.move_to_end(language_code)
            return

        if language_code not in self.asr_config:
            raise ValueError(f"Language {language_code} not supported")

        model_config = self.asr_config[language_code]

        # Check if we need to evict old models
        while len(self.model_cache) >= self.max_asr_models:
            # Remove least recently used model
            old_lang, _ = self.model_cache.popitem(last=False)
            if old_lang in self.asr_models:
                del self.asr_models[old_lang]
            if old_lang in self.processors:
                del self.processors[old_lang]
            print(f"ONNX ASR: Evicted model for {old_lang} (LRU cache)")

        try:
            if model_config.get("use_onnx", False):
                # Load ONNX model
                print(f"ONNX ASR: Loading ONNX model for {language_code}")

                # Special handling for Whisper models
                if model_config.get("model_type") == "whisper":
                    print(f"ONNX ASR: Loading Whisper ONNX model from {model_config['model_repo']}")

                    # Get authentication token for private repos
                    auth_token = os.getenv('HUGGING_FACE_HUB_TOKEN') or os.getenv('HF_TOKEN')

                    # Load pre-converted Whisper ONNX model using Optimum
                    load_kwargs = {
                        # export=False because we're using pre-converted models
                        "export": False,
                        # use_cache=True because our models now include past key value variants for optimization
                        "use_cache": True,
                        # Add authentication token for private repos
                        "token": auth_token
                    }

                    # Configure providers - pass all available providers to Optimum
                    provider_names = [p[0] if isinstance(p, tuple) else p for p in self.providers]
                    load_kwargs["providers"] = provider_names
                    print(f"ONNX ASR: Whisper using providers: {provider_names}")

                    # Add subfolder if specified (for models that store ONNX in subfolders)
                    if "subfolder" in model_config:
                        load_kwargs["subfolder"] = model_config["subfolder"]

                    model = ORTModelForSpeechSeq2Seq.from_pretrained(
                        model_config["model_repo"],
                        **load_kwargs
                    )

                    # Load Whisper processor with authentication token
                    processor = WhisperProcessor.from_pretrained(
                        model_config["model_repo"],
                        token=auth_token
                    )

                    # Configure for English transcription
                    model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
                        language="en",
                        task="transcribe"
                    )

                    self.asr_models[language_code] = model
                    self.processors[language_code] = processor

                    print(f"ONNX ASR: Successfully loaded Whisper ONNX model for {language_code}")

                else:
                    # Original wav2vec2-bert model loading logic
                    # Create ONNX session with optimizations and verbose logging
                    session_options = ort.SessionOptions()
                    session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

                    # Enable verbose logging to diagnose operator assignments
                    session_options.log_severity_level = 1  # WARNING level for detailed logs
                    session_options.logid = "ONNX_ASR"  # Prefix for log identification

                    # Use configured providers with optimizations
                    providers = self.providers
                    print(f"ONNX ASR: wav2vec2-bert using providers: {[p[0] if isinstance(p, tuple) else p for p in providers]}")

                    # Get authentication token for private repos
                    auth_token = os.getenv('HUGGING_FACE_HUB_TOKEN') or os.getenv('HF_TOKEN')

                    # Download model files from HuggingFace Hub with authentication
                    from huggingface_hub import hf_hub_download
                    onnx_path = hf_hub_download(
                        repo_id=model_config["model_repo"],
                        filename="model.onnx",
                        token=auth_token
                    )

                    session = ort.InferenceSession(onnx_path, providers=providers, sess_options=session_options)

                    # Load processor for preprocessing with authentication
                    processor = AutoProcessor.from_pretrained(
                        model_config["model_repo"],
                        token=auth_token
                    )

                    self.asr_models[language_code] = session
                    self.processors[language_code] = processor

                    print(f"ONNX ASR: Successfully loaded ONNX model for {language_code}")

            else:
                # This service is ONNX-only - no PyTorch fallback
                raise ValueError(f"Language {language_code} is not configured for ONNX models. Set 'use_onnx': True in config.")

            # Add to cache
            self.model_cache[language_code] = True

        except Exception as e:
            print(f"Failed to load ASR model for {language_code}: {e}")
            raise

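    # --- Editor's sketch (not part of this commit): the same Optimum load-and-generate
    # --- path that ensure_model_loaded() wires up, shown end to end against the public
    # --- openai/whisper-tiny.en checkpoint. export=True converts it to ONNX at load
    # --- time, unlike the pre-converted private repos configured above.
    @staticmethod
    def _example_whisper_onnx_roundtrip(audio_16k: np.ndarray) -> str:
        model = ORTModelForSpeechSeq2Seq.from_pretrained("openai/whisper-tiny.en", export=True)
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
        inputs = processor(audio_16k, sampling_rate=16000, return_tensors="pt")
        generated_ids = model.generate(inputs.input_features)
        return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
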
    async def process_audio_chunk(self, audio_data: bytes, language_code: str, participant_id: str,
                                  has_voice_activity: bool = True,
                                  progress_callback: Optional[Callable] = None,
                                  sentence_callback: Optional[Callable] = None) -> str:
        """Process audio chunk with VAD-based sentence detection using ONNX models"""
        try:
            # Initialize buffers if needed
            if participant_id not in self.candidate_audio_buffers:
                self.candidate_audio_buffers[participant_id] = b''
                self.candidate_text_cache[participant_id] = ""
                self.silence_counters[participant_id] = 0
                self.sentence_finalized[participant_id] = False

            # Convert current chunk to numpy array for processing
            current_chunk_array = self._bytes_to_audio_array(audio_data)
            if len(current_chunk_array) == 0:
                return self.candidate_text_cache.get(participant_id, "")

            # Normalize the audio chunk
            current_chunk_array = current_chunk_array.astype(np.float32)
            if np.max(np.abs(current_chunk_array)) > 0:
                current_chunk_array /= np.max(np.abs(current_chunk_array))

            # Get existing accumulated audio array
            existing_buffer = self.candidate_audio_buffers[participant_id]
            if len(existing_buffer) > 0:
                existing_array = self._bytes_to_audio_array(existing_buffer)
                if len(existing_array) > 0:
                    combined_array = np.concatenate([existing_array, current_chunk_array])
                else:
                    combined_array = current_chunk_array
            else:
                combined_array = current_chunk_array

            # Convert back to bytes for storage
            combined_bytes = self._audio_array_to_bytes(combined_array)
            self.candidate_audio_buffers[participant_id] = combined_bytes

            # Update silence counter based on voice activity
            if not has_voice_activity:
                self.silence_counters[participant_id] += 1
            else:
                self.silence_counters[participant_id] = 0

            # Check if we should finalize sentence due to prolonged silence
            should_finalize = (self.silence_counters[participant_id] >= self.silence_threshold and
                               len(combined_array) > 0 and
                               not self.sentence_finalized[participant_id])

            if should_finalize:
                return await self._finalize_candidate_sentence(
                    language_code, participant_id, sentence_callback
                )

            # Always run transcription on the accumulated audio
            audio_duration_sec = len(combined_array) / 16000.0  # 16kHz sample rate

            if audio_duration_sec < 0.1:  # Very short minimum
                if progress_callback:
                    cached_text = self.candidate_text_cache.get(participant_id, "")
                    await progress_callback(cached_text, False)
                return self.candidate_text_cache.get(participant_id, "")

            # Force finalization if buffer gets too long
            if audio_duration_sec > 15.0 and not self.sentence_finalized[participant_id]:
                return await self._finalize_candidate_sentence(
                    language_code, participant_id, sentence_callback
                )

            # Run voice activity detection on the accumulated audio before transcription
            has_voice_in_buffer = self.has_meaningful_voice_activity(combined_bytes)

            if not has_voice_in_buffer:
                if progress_callback:
                    cached_text = self.candidate_text_cache.get(participant_id, "")
                    await progress_callback(cached_text, False)
                return self.candidate_text_cache.get(participant_id, "")

            # Run transcription
            await self.ensure_model_loaded(language_code)

            # Double-check voice activity before running expensive ASR
            has_voice_for_asr = self.has_voice_activity(combined_bytes)
            if not has_voice_for_asr:
                print(f"ONNX ASR: No voice activity detected, skipping ASR execution for {participant_id}")
                if progress_callback:
                    cached_text = self.candidate_text_cache.get(participant_id, "")
                    await progress_callback(cached_text, False)
                return self.candidate_text_cache.get(participant_id, "")

            if language_code not in self.asr_models:
                raise ValueError(f"ASR model not available for language: {language_code}")

            print(f"ONNX ASR: Running transcription for {participant_id} with {audio_duration_sec:.2f}s of audio")

            # Run ONNX inference (this service is ONNX-only)
            model_config = self.asr_config[language_code]
            if not model_config.get("use_onnx", False):
                raise ValueError(f"Language {language_code} is not configured for ONNX. This service only supports ONNX models.")

            # ONNX inference
            text = await self._run_onnx_inference(combined_array, language_code)

            # Filter out common ASR artifacts
            artifacts = [
                "thank you", "thanks", "bye", ".", ",", "?", "!",
                "um", "uh", "ah", "hmm", "mm", "mhm",
                "you", "the", "a", "an", "and", "but", "or",
                "music", "laughter", "applause", "[music]", "[laughter]",
            ]

            # Check if the result is likely an artifact
            is_artifact = (
                len(text) < 3 or
                text.lower() in artifacts or
                len(text.split()) == 1 and len(text) < 6
            )

            if is_artifact:
                text = self.candidate_text_cache.get(participant_id, "")

            # Cache the current candidate text
            self.candidate_text_cache[participant_id] = text

            # Force completion if we have reasonable text and some silence
            word_count = len(text.split()) if text else 0
            if (word_count >= 3 and self.silence_counters[participant_id] >= 2 and
                    not self.sentence_finalized[participant_id]):
                return await self._finalize_candidate_sentence(
                    language_code, participant_id, sentence_callback
                )

            # Always send progress update
            if progress_callback:
                await progress_callback(text, False)

            return text

        except Exception as e:
            print(f"ONNX TranscriptionService: Error processing audio chunk: {e}")
            import traceback
            traceback.print_exc()
            if progress_callback:
                cached_text = self.candidate_text_cache.get(participant_id, "")
                await progress_callback(cached_text, False)
            return self.candidate_text_cache.get(participant_id, "")

    async def _run_onnx_inference(self, audio_array: np.ndarray, language_code: str) -> str:
        """Run ONNX inference for speech recognition"""
        try:
            model = self.asr_models[language_code]
            processor = self.processors[language_code]
            model_config = self.asr_config[language_code]

            # Check if this is a Whisper model
            if model_config.get("model_type") == "whisper":
                # Whisper-specific processing using Optimum
                import torch

                # Process audio input for Whisper
                inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt")

                # Generate transcription using the ORTModelForSpeechSeq2Seq
                predicted_ids = model.generate(inputs.input_features, max_length=448)

                # Decode the generated IDs
                transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)

                return transcription[0].strip() if transcription else ""
            else:
                # Original wav2vec2-bert processing
                session = model

                # Preprocess audio
                inputs = processor(audio_array, sampling_rate=16000, return_tensors="np")

                # Get input names for ONNX session
                input_names = [inp.name for inp in session.get_inputs()]

                # Prepare inputs for ONNX
                onnx_inputs = {}
                for name in input_names:
                    if name in inputs:
                        onnx_inputs[name] = inputs[name]
                    elif name == "input_values" and "input_features" in inputs:
                        onnx_inputs[name] = inputs["input_features"]
                    elif name == "attention_mask" and "attention_mask" in inputs:
                        onnx_inputs[name] = inputs["attention_mask"]

                # Run ONNX inference
                outputs = session.run(None, onnx_inputs)

                # Post-process outputs (assuming CTC decoding)
                logits = outputs[0]  # First output should be logits

                # Simple greedy CTC decoding
                predicted_ids = np.argmax(logits, axis=-1)

                # Decode using processor
                text = processor.batch_decode(predicted_ids)[0]

                return text.strip()

        except Exception as e:
            print(f"ONNX ASR: Inference error: {e}")
            import traceback
            traceback.print_exc()
            return ""

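    # --- Editor's sketch (not part of this commit): a toy version of the greedy CTC
    # --- step used above: argmax per frame, collapse repeats, drop the blank token.
    # --- In the real path processor.batch_decode does the collapsing; the blank id
    # --- here is an illustrative assumption, not the actual tokenizer config.
    @staticmethod
    def _example_greedy_ctc(logits: np.ndarray, blank_id: int = 0) -> list:
        frame_ids = np.argmax(logits, axis=-1)             # best token per frame (time, vocab)
        collapsed = []
        prev = None
        for token_id in frame_ids:
            if token_id != prev and token_id != blank_id:  # collapse repeats, skip blanks
                collapsed.append(int(token_id))
            prev = token_id
        return collapsed
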
    async def _finalize_candidate_sentence(self, language_code: str, participant_id: str,
                                           sentence_callback: Optional[Callable] = None) -> str:
        """Finalize the current candidate sentence and clear buffers"""
        try:
            if self.sentence_finalized.get(participant_id, False):
                print(f"Sentence for participant {participant_id} already finalized, skipping duplicate")
                return self.candidate_text_cache.get(participant_id, "")

            final_text = self.candidate_text_cache.get(participant_id, "")
            final_audio_bytes = self.candidate_audio_buffers.get(participant_id, b'')

            if final_text and len(final_text.strip()) > 0:
                self.sentence_finalized[participant_id] = True

                if sentence_callback and len(final_audio_bytes) > 0:
                    print(f"Finalizing sentence for participant {participant_id}: '{final_text}'")
                    await sentence_callback(final_text, final_audio_bytes)

            # Clear buffers for next sentence
            self.candidate_audio_buffers[participant_id] = b''
            self.candidate_text_cache[participant_id] = ""
            self.silence_counters[participant_id] = 0
            self.sentence_finalized[participant_id] = False

            return final_text

        except Exception as e:
            print(f"Error finalizing sentence: {e}")
            import traceback
            traceback.print_exc()
            self.sentence_finalized[participant_id] = False
            return ""

    def has_voice_activity(self, audio_data: bytes, threshold: float = 0.0005) -> bool:
        """Enhanced VAD based on audio analysis"""
        try:
            audio_array = self._bytes_to_audio_array(audio_data)
            if len(audio_array) == 0:
                return False

            # Normalize audio
            audio_array = audio_array.astype(np.float32)
            if np.max(np.abs(audio_array)) > 0:
                audio_array /= np.max(np.abs(audio_array))

            # Calculate multiple features for better VAD
            rms = np.sqrt(np.mean(audio_array ** 2))
            peak = np.max(np.abs(audio_array))
            audio_std = np.std(audio_array)
            zero_crossing_rate = np.sum(np.diff(np.sign(audio_array)) != 0) / len(audio_array)

            # Voice activity detection
            has_voice_rms = rms > threshold
            has_voice_peak = peak > threshold * 3
            has_voice_variation = audio_std > threshold * 0.8
            has_voice_zcr = zero_crossing_rate > 0.008

            has_voice = has_voice_rms or (has_voice_peak and has_voice_variation) or has_voice_zcr

            return has_voice

        except Exception as e:
            print(f"Error in VAD: {e}")
            return True

    def has_meaningful_voice_activity(self, audio_data: bytes, threshold: float = 0.002) -> bool:
        """Stricter VAD check for pre-transcription filtering"""
        try:
            audio_array = self._bytes_to_audio_array(audio_data)
            if len(audio_array) == 0:
                return False

            # Normalize audio
            audio_array = audio_array.astype(np.float32)
            if np.max(np.abs(audio_array)) > 0:
                audio_array /= np.max(np.abs(audio_array))

            # Calculate features with higher thresholds
            rms = np.sqrt(np.mean(audio_array ** 2))
            peak = np.max(np.abs(audio_array))
            audio_std = np.std(audio_array)
            zero_crossing_rate = np.sum(np.diff(np.sign(audio_array)) != 0) / len(audio_array)

            # Higher thresholds for meaningful speech detection
            has_meaningful_voice = (
                rms > threshold and
                peak > threshold * 2 and
                audio_std > threshold * 0.5 and
                zero_crossing_rate > 0.015
            )

            return has_meaningful_voice

        except Exception as e:
            print(f"Error in meaningful VAD: {e}")
            return False

    async def force_complete_sentence(self, participant_id: str, language_code: str, sentence_callback: Optional[Callable] = None) -> str:
        """Force complete any pending sentence for a participant"""
        try:
            if self.sentence_finalized.get(participant_id, False):
                print(f"Force completion: Sentence for participant {participant_id} already finalized")
                return ""

            if participant_id in self.candidate_text_cache:
                cached_text = self.candidate_text_cache[participant_id]

                if cached_text and len(cached_text.strip()) > 0:
                    result = await self._finalize_candidate_sentence(language_code, participant_id, sentence_callback)
                    return result

            return ""

        except Exception as e:
            print(f"Error in force_complete_sentence: {e}")
            import traceback
            traceback.print_exc()
            return ""

    async def transcribe_audio(self, audio_data: bytes, language_code: str, callback: Optional[Callable] = None) -> str:
        """Transcribe audio data to text using ONNX models"""
        try:
            # Check for voice activity before running ASR
            has_voice = self.has_voice_activity(audio_data)
            if not has_voice:
                print("ONNX ASR: No voice activity detected, skipping transcription")
                return ""

            await self.ensure_model_loaded(language_code)

            if language_code not in self.asr_models:
                raise ValueError(f"ASR model not available for language: {language_code}")

            # Convert audio bytes to numpy array
            audio_array = self._bytes_to_audio_array(audio_data)

            print(f"ONNX ASR: Running transcription with {len(audio_array)/16000:.2f}s of audio")

            # Run ONNX inference (this service is ONNX-only)
            model_config = self.asr_config[language_code]
            if not model_config.get("use_onnx", False):
                raise ValueError(f"Language {language_code} is not configured for ONNX. This service only supports ONNX models.")

            # ONNX inference
            text = await self._run_onnx_inference(audio_array, language_code)

            if callback:
                await callback(text)

            return text

        except Exception as e:
            print(f"ONNX TranscriptionService: Transcription error: {e}")
            import traceback
            traceback.print_exc()
            return ""

    def _bytes_to_audio_array(self, audio_data: bytes) -> np.ndarray:
        """Convert audio bytes to numpy array"""
        try:
            # Try to decode as WAV
            try:
                audio_io = io.BytesIO(audio_data)
                with wave.open(audio_io, 'rb') as wav_file:
                    frames = wav_file.readframes(-1)
                    audio_array = np.frombuffer(frames, dtype=np.int16)
                    # Convert to float32 and normalize
                    audio_array = audio_array.astype(np.float32) / 32768.0
                    return audio_array
            except Exception:
                pass

            # Fallback: assume raw float32 audio data
            try:
                audio_array = np.frombuffer(audio_data, dtype=np.float32)
                return audio_array
            except Exception:
                pass

            return np.array([], dtype=np.float32)

        except Exception as e:
            print(f"ONNX TranscriptionService: Audio conversion error: {e}")
            return np.array([], dtype=np.float32)

    def _audio_array_to_bytes(self, audio_array: np.ndarray) -> bytes:
        """Convert numpy audio array back to WAV bytes for storage"""
        try:
            if audio_array.dtype != np.float32:
                audio_array = audio_array.astype(np.float32)

            # Convert to 16-bit PCM for WAV storage
            audio_int16 = (audio_array * 32767).astype(np.int16)

            # Create WAV bytes
            wav_buffer = io.BytesIO()
            with wave.open(wav_buffer, 'wb') as wav_file:
                wav_file.setnchannels(1)  # Mono
                wav_file.setsampwidth(2)  # 16-bit
                wav_file.setframerate(16000)  # 16kHz
                wav_file.writeframes(audio_int16.tobytes())

            return wav_buffer.getvalue()

        except Exception as e:
            print(f"Error converting audio array to bytes: {e}")
            return b''

    def clear_participant_buffers(self, participant_id: str):
        """Clear all buffers for a participant"""
        if participant_id in self.candidate_audio_buffers:
            del self.candidate_audio_buffers[participant_id]
        if participant_id in self.candidate_text_cache:
            del self.candidate_text_cache[participant_id]
        if participant_id in self.silence_counters:
            del self.silence_counters[participant_id]
        if participant_id in self.sentence_finalized:
            del self.sentence_finalized[participant_id]

    async def cleanup(self):
        """Cleanup resources"""
        self.asr_models.clear()
        self.processors.clear()
        self.model_cache.clear()
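
A minimal usage sketch for the ONNX service above (an editor's addition, not part of the commit); the WAV path is a placeholder and the performance-mode call is optional.

    import asyncio

    async def demo():
        service = ONNXTranscriptionService()
        await service.initialize()              # preloads "eng" per preload_languages
        service.set_performance_mode("small")   # optional: trade accuracy for speed
        with open("sample.wav", "rb") as f:     # hypothetical 16kHz mono WAV
            print(await service.transcribe_audio(f.read(), "eng"))
        await service.cleanup()

    asyncio.run(demo())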
app/services/transcription_service_onnx_optimized.py
ADDED
@@ -0,0 +1,251 @@
import asyncio
import io
import wave
import numpy as np
from typing import Dict, Optional, Callable
from collections import OrderedDict
import onnxruntime as ort
from transformers import AutoProcessor, WhisperProcessor
from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
import os
from app.models import LanguageCode

class OptimizedONNXTranscriptionService:
    """
    Optimized ONNX Transcription Service that uses pre-converted ONNX models
    instead of performing runtime conversion from PyTorch models.

    Benefits:
    - Faster container startup (no conversion time)
    - Reduced memory usage during initialization
    - More predictable deployment times
    - Better resource utilization in production
    """

    def __init__(self):
        self.asr_models: Dict[str, any] = {}
        self.processors: Dict[str, any] = {}
        self.max_asr_models = 2  # Memory management - keep max 2 models loaded
        self.model_cache = OrderedDict()  # LRU cache for models

        # GPU optimization
        self.providers = (
            ['CUDAExecutionProvider', 'CPUExecutionProvider']
            if 'CUDAExecutionProvider' in ort.get_available_providers()
            else ['CPUExecutionProvider']
        )

        # OPTIMIZED ONNX Model configurations - using pre-converted models
        self.asr_config = {
            # English: Use pre-converted ONNX model (no runtime conversion!)
            "eng": {
                "model_repo": "mutisya/whisper-medium-en-onnx",  # Pre-converted ONNX model
                "model_type": "whisper",
                "use_onnx": True,
                "export": False  # ⭐ KEY CHANGE: No runtime export needed!
            },

            # African languages: Already using ONNX models
            "swa": {"model_repo": "mutisya/w2v-bert-2.0-asr-onnx-swh-v25-37-1", "model_type": "wav2vec2-bert", "use_onnx": True},
            "kik": {"model_repo": "mutisya/w2v-bert-2.0-asr-onnx-kik-v25-37-1", "model_type": "wav2vec2-bert", "use_onnx": True},
            "kam": {"model_repo": "mutisya/w2v-bert-2.0-asr-onnx-kam-v25-37-1", "model_type": "wav2vec2-bert", "use_onnx": True},
            "mer": {"model_repo": "mutisya/w2v-bert-2.0-asr-onnx-mer-v25-37-1", "model_type": "wav2vec2-bert", "use_onnx": True},
            "luo": {"model_repo": "mutisya/w2v-bert-2.0-asr-onnx-luo-v25-37-1", "model_type": "wav2vec2-bert", "use_onnx": True},
            "som": {"model_repo": "mutisya/w2v-bert-2.0-asr-onnx-som-v25-37-1", "model_type": "wav2vec2-bert", "use_onnx": True}
        }

        self.preload_languages = ["eng"]

        # Enhanced audio buffering for VAD-based sentence detection
        self.candidate_audio_buffers: Dict[str, bytes] = {}
        self.candidate_text_cache: Dict[str, str] = {}
        self.silence_counters: Dict[str, int] = {}
        self.sentence_finalized: Dict[str, bool] = {}

        # VAD parameters
        self.silence_threshold = 2
        self.min_sentence_length = 0.03

    async def initialize(self):
        """Initialize ASR models for preloaded languages"""
        print(f"🚀 Optimized ONNX ASR: Initializing with providers: {self.providers}")
        print("📈 Performance Improvement: Using pre-converted ONNX models (no runtime conversion)")

        for lang_code in self.preload_languages:
            if lang_code in self.asr_config:
                try:
                    start_time = asyncio.get_event_loop().time()
                    await self.ensure_model_loaded(lang_code)
                    end_time = asyncio.get_event_loop().time()
                    print(f"⚡ Model loading time for {lang_code}: {end_time - start_time:.2f}s")
                except Exception as e:
                    print(f"❌ Failed to load ASR model for {lang_code}: {e}")

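    # --- Editor's sketch (not part of this commit): contrasting the two load paths this
    # --- class is optimized around. The pre-converted repo name is the one configured
    # --- above; the PyTorch baseline (openai/whisper-medium.en) is an illustrative
    # --- stand-in for the unconverted checkpoint.
    @staticmethod
    def _example_load_paths():
        # Pre-converted ONNX repo: weights are already .onnx, so loading is mostly I/O.
        fast = ORTModelForSpeechSeq2Seq.from_pretrained("mutisya/whisper-medium-en-onnx")
        # Runtime export: Optimum first converts the PyTorch checkpoint to ONNX, which
        # costs extra CPU time and RAM on every cold start.
        slow = ORTModelForSpeechSeq2Seq.from_pretrained("openai/whisper-medium.en", export=True)
        return fast, slow
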
    async def ensure_model_loaded(self, language_code: str):
        """Load ASR model for language if not already loaded with LRU cache"""
        if language_code in self.model_cache:
            # Move to end (most recently used)
            self.model_cache.move_to_end(language_code)
            return

        if language_code not in self.asr_config:
            raise ValueError(f"Language {language_code} not supported")

        model_config = self.asr_config[language_code]

        # Check if we need to evict old models
        while len(self.model_cache) >= self.max_asr_models:
            # Remove least recently used model
            old_lang, _ = self.model_cache.popitem(last=False)
            if old_lang in self.asr_models:
                del self.asr_models[old_lang]
            if old_lang in self.processors:
                del self.processors[old_lang]
            print(f"🗑️ ONNX ASR: Evicted model for {old_lang} (LRU cache)")

        try:
            if model_config.get("use_onnx", False):
                # Load ONNX model
                print(f"📥 ONNX ASR: Loading ONNX model for {language_code}")

                # Special handling for Whisper models
                if model_config.get("model_type") == "whisper":
                    print(f"🎙️ ONNX ASR: Loading pre-converted Whisper ONNX model from {model_config['model_repo']}")

                    # Load pre-converted Whisper ONNX model using Optimum
                    load_kwargs = {
                        # Note: No 'export' parameter needed since model is already in ONNX format
                        # This is the key optimization - no runtime conversion!
                    }

                    # Add subfolder if specified (for models that store ONNX in subfolders)
                    if "subfolder" in model_config:
                        load_kwargs["subfolder"] = model_config["subfolder"]

                    # ⭐ KEY OPTIMIZATION: No export flag needed for pre-converted models
                    # The old code had: if model_config.get("export", False): load_kwargs["export"] = True
                    # Now we skip this entirely since the model is already in ONNX format

                    model = ORTModelForSpeechSeq2Seq.from_pretrained(
                        model_config["model_repo"],
                        **load_kwargs
                    )

                    # Load Whisper processor
                    processor = WhisperProcessor.from_pretrained(model_config["model_repo"])

                    # Configure for English transcription
                    model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
                        language="en",
                        task="transcribe"
                    )

                    self.asr_models[language_code] = model
                    self.processors[language_code] = processor

                    print(f"✅ ONNX ASR: Successfully loaded pre-converted Whisper ONNX model for {language_code}")

                else:
                    # Original wav2vec2-bert model loading logic (unchanged)
                    # Create ONNX session with optimizations
                    session_options = ort.SessionOptions()
                    session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

                    # Enable parallel execution
                    session_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL

                    model_path = model_config["model_repo"]

                    try:
                        # Try to load from HuggingFace directly
                        from huggingface_hub import hf_hub_download
                        model_file = hf_hub_download(repo_id=model_path, filename="model.onnx")

                        # Create ONNX Runtime session
                        session = ort.InferenceSession(
                            model_file,
                            session_options,
                            providers=self.providers
                        )

                        # Load processor/tokenizer
                        processor = AutoProcessor.from_pretrained(model_path)

                        self.asr_models[language_code] = session
                        self.processors[language_code] = processor

                        print(f"✅ ONNX ASR: Successfully loaded {model_config['model_type']} ONNX model for {language_code}")

                    except Exception as e:
                        print(f"❌ Error loading ONNX model {model_path}: {e}")
|
| 177 |
+
raise
|
| 178 |
+
|
| 179 |
+
else:
|
| 180 |
+
raise ValueError(f"Non-ONNX models not supported in optimized service")
|
| 181 |
+
|
| 182 |
+
# Add to cache
|
| 183 |
+
self.model_cache[language_code] = True
|
| 184 |
+
|
| 185 |
+
except Exception as e:
|
| 186 |
+
print(f"❌ Error loading model for {language_code}: {e}")
|
| 187 |
+
raise
|
| 188 |
+
|
| 189 |
+
# Rest of the methods remain the same as the original transcription service
|
| 190 |
+
# (transcribe_audio, process_audio_chunk, etc.)
|
| 191 |
+
# ... [Include all other methods from the original service]
|
| 192 |
+
|
| 193 |
+
async def transcribe_audio(self, participant_id: str, audio_data: bytes, language_code: str = "eng") -> Optional[str]:
|
| 194 |
+
"""Transcribe audio using ONNX models"""
|
| 195 |
+
try:
|
| 196 |
+
await self.ensure_model_loaded(language_code)
|
| 197 |
+
|
| 198 |
+
if language_code not in self.asr_models or language_code not in self.processors:
|
| 199 |
+
raise ValueError(f"Model not loaded for language: {language_code}")
|
| 200 |
+
|
| 201 |
+
model = self.asr_models[language_code]
|
| 202 |
+
processor = self.processors[language_code]
|
| 203 |
+
|
| 204 |
+
# Convert audio bytes to numpy array
|
| 205 |
+
audio_io = io.BytesIO(audio_data)
|
| 206 |
+
with wave.open(audio_io, 'rb') as wav_file:
|
| 207 |
+
frames = wav_file.readframes(-1)
|
| 208 |
+
sample_rate = wav_file.getframerate()
|
| 209 |
+
audio_np = np.frombuffer(frames, dtype=np.int16).astype(np.float32) / 32768.0
|
| 210 |
+
|
| 211 |
+
# Get model configuration
|
| 212 |
+
model_config = self.asr_config[language_code]
|
| 213 |
+
|
| 214 |
+
if model_config.get("model_type") == "whisper":
|
| 215 |
+
# Process with Whisper ONNX model
|
| 216 |
+
inputs = processor(audio_np, sampling_rate=sample_rate, return_tensors="pt")
|
| 217 |
+
|
| 218 |
+
with torch.no_grad():
|
| 219 |
+
predicted_ids = model.generate(**inputs)
|
| 220 |
+
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
|
| 221 |
+
|
| 222 |
+
return transcription.strip()
|
| 223 |
+
|
| 224 |
+
else:
|
| 225 |
+
# Process with wav2vec2-bert ONNX model
|
| 226 |
+
inputs = processor(audio_np, sampling_rate=sample_rate, return_tensors="np")
|
| 227 |
+
|
| 228 |
+
# Run ONNX inference
|
| 229 |
+
ort_inputs = {model.get_inputs()[0].name: inputs.input_values}
|
| 230 |
+
ort_outputs = model.run(None, ort_inputs)
|
| 231 |
+
|
| 232 |
+
# Decode results
|
| 233 |
+
predicted_ids = np.argmax(ort_outputs[0], axis=-1)
|
| 234 |
+
transcription = processor.decode(predicted_ids[0])
|
| 235 |
+
|
| 236 |
+
return transcription.strip()
|
| 237 |
+
|
| 238 |
+
except Exception as e:
|
| 239 |
+
print(f"❌ Transcription error for {participant_id}: {e}")
|
| 240 |
+
return None
|
| 241 |
+
|
| 242 |
+
def get_performance_stats(self) -> Dict[str, any]:
|
| 243 |
+
"""Get performance statistics for monitoring"""
|
| 244 |
+
return {
|
| 245 |
+
"loaded_models": list(self.model_cache.keys()),
|
| 246 |
+
"cache_size": len(self.model_cache),
|
| 247 |
+
"max_cache_size": self.max_asr_models,
|
| 248 |
+
"providers": self.providers,
|
| 249 |
+
"optimization_enabled": True,
|
| 250 |
+
"runtime_conversion": False # Key metric: no runtime conversion
|
| 251 |
+
}
|
app/services/translation_service.py
ADDED
@@ -0,0 +1,151 @@
import asyncio
from typing import Dict, Optional
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import nltk
from app.models import LanguageCode
from app.services.quantization_utils import apply_dynamic_int8_quantization, get_quantization_stats

# FLORES-200 language codes mapping
FLORES_CODES = {
    "English": "eng_Latn",
    "eng": "eng_Latn",
    "Swahili": "swh_Latn",
    "swa": "swh_Latn",
    "Kikuyu": "kik_Latn",
    "kik": "kik_Latn",
    "Kamba": "kam_Latn",
    "kam": "kam_Latn",
    "Kimeru": "mer_Latn",
    "mer": "mer_Latn",
    "Luo": "luo_Latn",
    "luo": "luo_Latn",
    "Somali": "som_Latn",
    "som": "som_Latn",
}

class TranslationService:
    def __init__(self, enable_quantization: bool = True):
        self.translation_pipeline = None
        self.device = 0 if torch.cuda.is_available() else -1
        self.model_path = "mutisya/nllb_600m-en-kik-kam-luo-mer-som-swh-drL-24_5-filtered-v24_28_4"
        self.enable_quantization = enable_quantization

    async def initialize(self):
        """Initialize the translation model"""
        try:
            # Download NLTK data with better error handling
            try:
                nltk.download("punkt", quiet=True)
                nltk.download("punkt_tab", quiet=True)
            except Exception as nltk_error:
                print(f"Warning: NLTK data download failed: {nltk_error}")
                # Continue anyway; sentence tokenization might still work

            # Load the translation model with explicit model kwargs for newer transformers
            print(f"Loading translation model: {self.model_path}")
            model = AutoModelForSeq2SeqLM.from_pretrained(
                self.model_path,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
            )
            tokenizer = AutoTokenizer.from_pretrained(self.model_path)

            # Apply quantization if enabled
            if self.enable_quantization:
                try:
                    print("Applying INT8 quantization to translation model...")
                    model = apply_dynamic_int8_quantization(model, "translation")
                    stats = get_quantization_stats(model)
                    print(f"✓ Translation model quantized: {stats['quantized_layers']}/{stats['total_layers']} layers, {stats['size_mb']:.2f} MB")
                except Exception as e:
                    print(f"Warning: Could not quantize translation model: {e}")
                    print("Continuing with unquantized model")

            self.translation_pipeline = pipeline(
                'translation',
                model=model,
                tokenizer=tokenizer,
                device=self.device,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
            )

        except Exception as e:
            print(f"Failed to initialize translation service: {e}")
            raise

    async def translate_text(self, text: str, source_lang: str, target_lang: str) -> str:
        """Translate text from the source language to the target language"""
        print("=== TRANSLATION REQUEST ===")
        print(f"Text: '{text}'")
        print(f"Source: {source_lang}")
        print(f"Target: {target_lang}")

        if not self.translation_pipeline:
            print("TRANSLATION ERROR: Translation service not initialized")
            raise RuntimeError("Translation service not initialized")

        if not text or not text.strip():
            print("TRANSLATION ERROR: Empty text provided")
            return ""

        try:
            # Get FLORES codes
            src_code = FLORES_CODES.get(source_lang, "eng_Latn")
            tgt_code = FLORES_CODES.get(target_lang, "eng_Latn")

            print(f"FLORES codes: {source_lang} -> {src_code}, {target_lang} -> {tgt_code}")

            # Skip translation if source and target are the same language
            if src_code == tgt_code:
                print("TRANSLATION SKIPPED: Same source and target language")
                return text

            # Tokenize into sentences for better translation
            sentences = nltk.sent_tokenize(text)
            translated_sentences = []

            print(f"Translating {len(sentences)} sentences...")

            for i, sentence in enumerate(sentences):
                if sentence.strip():
                    print(f"Translating sentence {i+1}: '{sentence}'")

                    result = self.translation_pipeline(
                        sentence,
                        src_lang=src_code,
                        tgt_lang=tgt_code
                    )

                    translated = result[0]['translation_text']
                    print(f"Translation result: '{translated}'")

                    # Preserve punctuation and capitalization
                    if sentence.strip().endswith(".") and not translated.strip().endswith("."):
                        translated += "."

                    if sentence.strip()[0].isupper() and translated.strip():
                        translated = translated[0].upper() + translated[1:]

                    translated_sentences.append(translated)

            final_translation = " ".join(translated_sentences)

            # Preserve paragraph breaks
            if text.endswith(".\n\n"):
                final_translation += ".\n\n"

            print(f"FINAL TRANSLATION: '{final_translation}'")
            print("=== TRANSLATION COMPLETE ===")

            return final_translation

        except Exception as e:
            print(f"TRANSLATION ERROR: {e}")
            import traceback
            traceback.print_exc()
            return text  # Return the original text if translation fails

    async def cleanup(self):
        """Cleanup resources"""
        self.translation_pipeline = None
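Since the service translates sentence by sentence rather than feeding the whole document through the pipeline at once, a condensed standalone sketch of that loop may help. The public `facebook/nllb-200-distilled-600M` checkpoint below is only a stand-in assumption for the private fine-tuned model used above.

```python
# Condensed sketch of the per-sentence NLLB loop from translate_text.
# facebook/nllb-200-distilled-600M is a stand-in for the private checkpoint.
import nltk
from transformers import pipeline

nltk.download("punkt", quiet=True)
translator = pipeline("translation", model="facebook/nllb-200-distilled-600M")

def translate(text: str, src: str = "eng_Latn", tgt: str = "swh_Latn") -> str:
    translated = []
    for sentence in nltk.sent_tokenize(text):
        if sentence.strip():
            result = translator(sentence, src_lang=src, tgt_lang=tgt)
            translated.append(result[0]["translation_text"])
    return " ".join(translated)

print(translate("Hello. How are you today?"))
```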
app/services/translation_service_onnx.py
ADDED
@@ -0,0 +1,268 @@
import asyncio
from typing import Dict, Optional
from transformers import AutoTokenizer, pipeline
from optimum.onnxruntime import ORTModelForSeq2SeqLM
import nltk
from app.models import LanguageCode

# FLORES-200 language codes mapping
FLORES_CODES = {
    "English": "eng_Latn",
    "eng": "eng_Latn",
    "Swahili": "swh_Latn",
    "swa": "swh_Latn",
    "Kikuyu": "kik_Latn",
    "kik": "kik_Latn",
    "Kamba": "kam_Latn",
    "kam": "kam_Latn",
    "Kimeru": "mer_Latn",
    "mer": "mer_Latn",
    "Luo": "luo_Latn",
    "luo": "luo_Latn",
    "Somali": "som_Latn",
    "som": "som_Latn",
}

class ONNXTranslationService:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.translation_pipeline = None

        # Use the ONNX-optimized NLLB model (FP32 format with separate encoder/decoder)
        self.model_repo = "mutisya/nllb-translation-onnx-v25-37-1"

    async def initialize(self):
        """Initialize the ONNX translation model using optimum.onnxruntime"""
        try:
            print("ONNX Translation: Initializing translation service with ONNX Runtime...")
            print(f"ONNX Translation: Loading model from {self.model_repo}")

            # Check available providers for GPU detection
            import onnxruntime as ort
            available_providers = ort.get_available_providers()
            print(f"ONNX Translation: Available providers: {available_providers}")

            # Download NLTK data with better error handling
            try:
                nltk.download("punkt", quiet=True)
                nltk.download("punkt_tab", quiet=True)
            except Exception as nltk_error:
                print(f"Warning: NLTK data download failed: {nltk_error}")

            # Get the authentication token for the private repo
            import os
            auth_token = os.getenv('HUGGING_FACE_HUB_TOKEN') or os.getenv('HF_TOKEN')

            # Configure the providers list for optimal performance
            print("ONNX Translation: Configuring execution providers...")
            if 'CUDAExecutionProvider' in available_providers:
                # Use both CUDA and CPU providers to eliminate assignment warnings
                providers_list = ['CUDAExecutionProvider', 'CPUExecutionProvider']
                primary_provider = 'CUDAExecutionProvider'
                print(f"ONNX Translation: Using providers: {providers_list} (primary: {primary_provider})")
            else:
                providers_list = ['CPUExecutionProvider']
                primary_provider = 'CPUExecutionProvider'
                print(f"ONNX Translation: Using CPU-only providers: {providers_list}")

            # Load the ONNX model using optimum (handles separate encoder/decoder files).
            # Configure session options for optimal CUDA performance.
            session_options = ort.SessionOptions()
            session_options.log_severity_level = 1  # WARNING level for detailed logs
            session_options.logid = "ONNX_Translation"

            # Enable all graph optimizations to reduce memcpy operations
            session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

            # Optimize threading for better GPU utilization
            session_options.inter_op_num_threads = 1  # Reduce CPU thread contention
            session_options.intra_op_num_threads = 1  # Focus on GPU execution

            # Note: enable_cuda_graph is not available in this ONNX Runtime version

            # Configure provider options with performance optimizations for CUDA
            provider_options = []
            if primary_provider == 'CUDAExecutionProvider':
                cuda_options = {
                    'device_id': 0,
                    'arena_extend_strategy': 'kNextPowerOfTwo',
                    'gpu_mem_limit': int(0.6 * 1024 * 1024 * 1024),  # Cap translation at ~0.6 GB of GPU memory
                    'cudnn_conv_algo_search': 'EXHAUSTIVE',
                    'cudnn_conv_use_max_workspace': '1',  # Enable max workspace for fp16 tensor cores
                    'do_copy_in_default_stream': True,
                    'enable_skip_layer_norm_strict_mode': False,  # Better performance for transformers
                    'prefer_nhwc': True,  # Optimize data layout for GPU
                }
                # Pair each provider with its options
                provider_options = [
                    ('CUDAExecutionProvider', cuda_options),
                    ('CPUExecutionProvider', {})
                ]

            # Try the optimized provider configuration and session options first
            try:
                print("ONNX Translation: Attempting optimized provider configuration...")
                self.model = ORTModelForSeq2SeqLM.from_pretrained(
                    self.model_repo,
                    token=auth_token,
                    providers=provider_options if provider_options else providers_list,
                    session_options=session_options,
                )
                print(f"ONNX Translation: Model loaded successfully with providers: {providers_list}")

                # Check which providers the model is actually using
                if hasattr(self.model, 'providers'):
                    print(f"ONNX Translation: Model is using providers: {self.model.providers}")
                if hasattr(self.model, 'device'):
                    print(f"ONNX Translation: Model device: {self.model.device}")

            except Exception as e1:
                print(f"ONNX Translation: Optimized provider approach failed: {e1}")
                print("ONNX Translation: Falling back to simple provider list...")

                # Fallback: try a simple provider list (no options)
                try:
                    self.model = ORTModelForSeq2SeqLM.from_pretrained(
                        self.model_repo,
                        token=auth_token,
                        providers=providers_list,
                        session_options=session_options,
                    )
                    print(f"ONNX Translation: Model loaded successfully with simple providers: {providers_list}")

                    # Check which providers the model is actually using
                    if hasattr(self.model, 'providers'):
                        print(f"ONNX Translation: Model is using providers: {self.model.providers}")
                    if hasattr(self.model, 'device'):
                        print(f"ONNX Translation: Model device: {self.model.device}")

                except Exception as e2:
                    print(f"ONNX Translation: Simple provider approach failed: {e2}")
                    print("ONNX Translation: Falling back to auto-detect...")

                    # Final fallback: let the model auto-detect providers based on the device
                    self.model = ORTModelForSeq2SeqLM.from_pretrained(
                        self.model_repo,
                        token=auth_token
                        # No providers passed; let the library auto-detect
                    )
                    print("ONNX Translation: Model loaded successfully with auto-detection")

                    if hasattr(self.model, 'providers'):
                        print(f"ONNX Translation: Model auto-selected providers: {self.model.providers}")
                    if hasattr(self.model, 'device'):
                        print(f"ONNX Translation: Model device: {self.model.device}")

            # Load the tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_repo,
                token=auth_token
            )

            # Create the translation pipeline.
            # For ONNX models, set the device explicitly so the pipeline uses the GPU;
            # derive it from the same primary provider as the model for consistency.
            device = 0 if primary_provider == 'CUDAExecutionProvider' else -1
            print(f"ONNX Translation: Setting pipeline device to: {device} ({'GPU' if device >= 0 else 'CPU'})")
            print(f"ONNX Translation: Pipeline device follows the primary provider: {primary_provider}")

            self.translation_pipeline = pipeline(
                "translation",
                model=self.model,
                tokenizer=self.tokenizer,
                device=device
            )

            print("ONNX Translation: Successfully initialized ONNX translation model")

        except Exception as e:
            print(f"Failed to initialize ONNX translation service: {e}")
            print("The ONNX translation model is not available. Ensure the model repository exists and contains the required ONNX files.")
            import traceback
            traceback.print_exc()
            raise RuntimeError(f"ONNX translation model unavailable at {self.model_repo}: {e}")

    async def translate_text(self, text: str, source_lang: str, target_lang: str) -> str:
        """Translate text from the source language to the target language using ONNX"""
        print("=== ONNX TRANSLATION REQUEST ===")
        print(f"Text: '{text}'")
        print(f"Source: {source_lang}")
        print(f"Target: {target_lang}")

        if not self.translation_pipeline:
            print("ONNX TRANSLATION ERROR: Translation service not initialized")
            raise RuntimeError("ONNX Translation service not initialized")

        if not text or not text.strip():
            print("ONNX TRANSLATION ERROR: Empty text provided")
            return ""

        try:
            # Get FLORES codes
            src_code = FLORES_CODES.get(source_lang, "eng_Latn")
            tgt_code = FLORES_CODES.get(target_lang, "eng_Latn")

            print(f"FLORES codes: {source_lang} -> {src_code}, {target_lang} -> {tgt_code}")

            # Skip translation if source and target are the same language
            if src_code == tgt_code:
                print("ONNX TRANSLATION SKIPPED: Same source and target language")
                return text

            # Tokenize into sentences for better translation
            sentences = nltk.sent_tokenize(text)
            translated_sentences = []

            print(f"Translating {len(sentences)} sentences with ONNX...")

            for i, sentence in enumerate(sentences):
                if sentence.strip():
                    print(f"Translating sentence {i+1}: '{sentence}'")

                    # Use the pipeline for translation
                    result = self.translation_pipeline(
                        sentence.strip(),
                        src_lang=src_code,
                        tgt_lang=tgt_code,
                        max_length=512
                    )

                    translated = result[0]['translation_text']
                    print(f"ONNX Translation result: '{translated}'")

                    # Preserve punctuation and capitalization
                    if sentence.strip().endswith(".") and not translated.strip().endswith("."):
                        translated += "."

                    if sentence.strip() and sentence.strip()[0].isupper() and translated.strip():
                        translated = translated[0].upper() + translated[1:]

                    translated_sentences.append(translated)

            final_translation = " ".join(translated_sentences)

            # Preserve paragraph breaks
            if text.endswith(".\n\n"):
                final_translation += ".\n\n"

            print(f"ONNX FINAL TRANSLATION: '{final_translation}'")
            print("=== ONNX TRANSLATION COMPLETE ===")

            return final_translation

        except Exception as e:
            print(f"ONNX TRANSLATION ERROR: {e}")
            import traceback
            traceback.print_exc()
            raise RuntimeError(f"Translation failed: {e}")

    async def cleanup(self):
        """Cleanup resources"""
        self.model = None
        self.tokenizer = None
        self.translation_pipeline = None
        print("ONNX Translation: Translation service cleaned up")
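The initialization above tries three provider configurations in decreasing order of specificity: tuned provider options, then a plain provider list, then library auto-detection. Extracted into isolation, the pattern looks roughly like the sketch below; this is a condensed illustration of the fallback logic, not a drop-in replacement, and the tuned CUDA options are abbreviated.

```python
# Sketch of the three-tier fallback from initialize(): tuned provider
# options -> plain provider list -> library auto-detect.
from optimum.onnxruntime import ORTModelForSeq2SeqLM

def load_with_fallback(repo: str, token=None):
    attempts = [
        # 1. Providers paired with tuned options (abbreviated here)
        {"providers": [("CUDAExecutionProvider", {"device_id": 0}),
                       ("CPUExecutionProvider", {})]},
        # 2. Plain provider list, no options
        {"providers": ["CUDAExecutionProvider", "CPUExecutionProvider"]},
        # 3. Let optimum/onnxruntime pick providers on its own
        {},
    ]
    last_err = None
    for kwargs in attempts:
        try:
            return ORTModelForSeq2SeqLM.from_pretrained(repo, token=token, **kwargs)
        except Exception as e:
            last_err = e  # fall through to the next, simpler configuration
    raise RuntimeError(f"All provider configurations failed: {last_err}")
```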
app/services/tts_service.py
ADDED
@@ -0,0 +1,541 @@
import asyncio
import io
import wave
import numpy as np
import subprocess
from typing import Dict, Optional
from transformers import pipeline
import torch
import os
from app.services.quantization_utils import apply_dynamic_int8_quantization, get_quantization_stats

class TTSService:
    def __init__(self, enable_quantization: bool = True):
        self.tts_pipelines: Dict[str, any] = {}
        self.device = 0 if torch.cuda.is_available() else -1
        self.enable_quantization = enable_quantization

        # Check whether espeak is available
        self.espeak_available = self._check_espeak_availability()

        # TTS model configurations (from the original code)
        self.tts_config = {
            "kik": {"model_repo": "mutisya/vits_kik_drL_24_5-v24_27_1_f", "model_type": "vits"},
            "luo": {"model_repo": "mutisya/vits_luo_drL_24_5-v24_27_1_f", "model_type": "vits"},
            "kam": {"model_repo": "mutisya/vits_kam_drL_24_5-v24_27_1_f", "model_type": "vits"},
            "mer": {"model_repo": "mutisya/vits_mer_drL_24_5-v24_27_1_f", "model_type": "vits"},
            "som": {"model_repo": "mutisya/vits_som_drL_24_5-v24_27_1_m", "model_type": "vits"},
            "swa": {"model_repo": "mutisya/vits_swh_biblica-v24_27_1_m", "model_type": "vits"},
            "eng": {"model_repo": "kakao-enterprise/vits-ljs", "model_type": "vits"},
        }

        # Alternative TTS models that don't require espeak (fallback)
        self.fallback_tts_config = {
            "eng": {"model_repo": "microsoft/speecht5_tts", "model_type": "speecht5"},
            "swa": {"model_repo": "facebook/mms-tts-swh", "model_type": "mms"},
            "som": {"model_repo": "facebook/mms-tts-som", "model_type": "mms"},
        }

        self.preload_languages = ["kik", "swa"]
        self.background_loading_task = None
        self.models_loading_status = {}

    def _check_espeak_availability(self) -> bool:
        """Check whether espeak is available on the system"""
        try:
            result = subprocess.run(['espeak', '--version'],
                                    capture_output=True, text=True, timeout=5)
            if result.returncode == 0:
                print("TTS: espeak is available")
                return True
            else:
                print("TTS: espeak command failed")
                return False
        except Exception as e:  # covers TimeoutExpired and FileNotFoundError too
            print(f"TTS: espeak not available: {e}")
            return False

    async def initialize(self):
        """Initialize TTS models for preloaded languages"""
        print("TTS: Initializing TTS service...")
        print(f"TTS: espeak available: {self.espeak_available}")

        for lang_code in self.preload_languages:
            await self.ensure_model_loaded(lang_code)

    def _load_and_quantize_tts_pipeline(self, lang_code: str, model_repo: str, model_type: str = "vits"):
        """Load a TTS pipeline and optionally apply INT8 quantization"""
        print(f"TTS: Loading model for {lang_code}: {model_repo}")

        pipeline_obj = pipeline(
            "text-to-speech",
            model=model_repo,
            device=self.device
        )

        # Apply quantization if enabled
        if self.enable_quantization:
            try:
                # Get the underlying model from the pipeline
                model = pipeline_obj.model

                print(f"TTS: Applying INT8 quantization to {lang_code} model...")
                quantized_model = apply_dynamic_int8_quantization(model, model_type)

                # Replace the model in the pipeline
                pipeline_obj.model = quantized_model

                # Print quantization stats
                stats = get_quantization_stats(quantized_model)
                print(f"✓ TTS {lang_code} model quantized: {stats['quantized_layers']}/{stats['total_layers']} layers, {stats['size_mb']:.2f} MB")

            except Exception as e:
                print(f"TTS: Warning - could not quantize {lang_code} model: {e}")
                print("TTS: Continuing with unquantized model")

        return pipeline_obj

    async def ensure_model_loaded(self, language_code: str):
        """Load the TTS model for a language if it is not already loaded"""
        if language_code in self.tts_pipelines:
            return

        # First try to load the primary model if espeak is available
        if self.espeak_available and language_code in self.tts_config:
            try:
                model_config = self.tts_config[language_code]
                pipeline_obj = self._load_and_quantize_tts_pipeline(
                    language_code,
                    model_config["model_repo"],
                    model_config.get("model_type", "vits")
                )
                self.tts_pipelines[language_code] = pipeline_obj
                print(f"TTS: Loaded primary TTS model for {language_code}")
                return
            except Exception as e:
                print(f"TTS: Failed to load primary TTS model for {language_code}: {e}")
                # Continue and try the fallback models

        # Try fallback models if the primary failed or espeak is unavailable
        if language_code in self.fallback_tts_config:
            try:
                model_config = self.fallback_tts_config[language_code]

                if model_config["model_type"] == "speecht5":
                    # Special handling for SpeechT5
                    from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
                    import torch

                    processor = SpeechT5Processor.from_pretrained(model_config["model_repo"])
                    model = SpeechT5ForTextToSpeech.from_pretrained(model_config["model_repo"])
                    vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

                    # Create a custom pipeline-like object
                    class SpeechT5Pipeline:
                        def __init__(self, processor, model, vocoder):
                            self.processor = processor
                            self.model = model
                            self.vocoder = vocoder

                        def __call__(self, text):
                            inputs = self.processor(text=text, return_tensors="pt")
                            # Use default speaker embeddings
                            import datasets
                            embeddings_dataset = datasets.load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
                            speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)

                            speech = self.model.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=self.vocoder)

                            return {
                                "audio": speech.numpy(),
                                "sampling_rate": 16000
                            }

                    pipeline_obj = SpeechT5Pipeline(processor, model, vocoder)
                else:
                    # Standard pipeline for MMS models
                    pipeline_obj = pipeline(
                        "text-to-speech",
                        model=model_config["model_repo"],
                        device=self.device
                    )

                self.tts_pipelines[language_code] = pipeline_obj
                print(f"TTS: Loaded fallback TTS model for {language_code}")
                return

            except Exception as e:
                print(f"TTS: Failed to load fallback TTS model for {language_code}: {e}")

        print(f"TTS: No TTS model available for language: {language_code}")

    async def generate_speech(self, text: str, language_code: str, output_format: str = "webm") -> Optional[bytes]:
        """Generate speech audio from text.

        Args:
            text: Text to convert to speech
            language_code: Language code for the TTS model
            output_format: Output format - "webm" (default, web-compatible) or "wav" (Android-compatible)

        Returns:
            Audio bytes in the requested format, or None if generation fails
        """
        try:
            print("=== TTS GENERATION REQUEST ===")
            print(f"Text: '{text}'")
            print(f"Language: {language_code}")
            print(f"Output format: {output_format}")

            # Input validation: check for invalid or problematic text
            if not text or not text.strip():
                print("TTS: Empty or whitespace-only text, skipping TTS generation")
                return None

            # Check for very short text that might cause issues
            clean_text = text.strip()
            if len(clean_text) <= 2 and clean_text in [".", ",", "!", "?", ":", ";", "-"]:
                print(f"TTS: Text '{clean_text}' is too short or punctuation-only, skipping TTS generation")
                return None

            # Check for a minimum meaningful length
            if len(clean_text.replace(" ", "").replace(".", "").replace(",", "")) < 2:
                print(f"TTS: Text '{clean_text}' has insufficient content for TTS, skipping")
                return None

            print(f"TTS pipelines available: {list(self.tts_pipelines.keys())}")
            print(f"TTS config available: {list(self.tts_config.keys())}")
            print(f"Fallback config available: {list(self.fallback_tts_config.keys())}")

            # Check whether the language is supported
            if language_code not in self.tts_config and language_code not in self.fallback_tts_config:
                print(f"TTS: Language {language_code} not configured for TTS")
                return None

            await self.ensure_model_loaded(language_code)

            if language_code not in self.tts_pipelines:
                print(f"TTS: TTS model not available for language: {language_code}")
                return None

            print(f"TTS: Generating speech for '{text}' in {language_code}")

            # Generate speech
            pipeline_obj = self.tts_pipelines[language_code]
            result = pipeline_obj(text)

            audio_array = result["audio"]
            sample_rate = result.get("sampling_rate", 22050)

            print(f"TTS: Generated audio array of length {len(audio_array)} at {sample_rate}Hz")

            # Validate the audio array
            if len(audio_array) == 0:
                print("TTS: Warning - generated audio array is empty")
                return None

            # Check for potential issues with the audio data
            audio_min = np.min(audio_array)
            audio_max = np.max(audio_array)
            audio_rms = np.sqrt(np.mean(audio_array**2))
            print(f"TTS: Audio statistics - Min: {audio_min:.4f}, Max: {audio_max:.4f}, RMS: {audio_rms:.4f}")

            # Check whether the audio might be silent or corrupted
            if audio_rms < 0.001:
                print("TTS: Warning - audio appears to be very quiet or silent")
            if audio_max > 1.0 or audio_min < -1.0:
                print("TTS: Warning - audio values outside expected range [-1, 1]")
                # Clip to the valid range
                audio_array = np.clip(audio_array, -1.0, 1.0)
                print("TTS: Clipped audio to valid range")

            # Convert to WAV bytes with the appropriate sample rate
            if output_format == "wav":
                # For Android: use a 16kHz sample rate
                target_sample_rate = 16000
                wav_bytes = self._convert_to_wav_bytes(audio_array, sample_rate)
                print(f"TTS: Converted to WAV: {len(wav_bytes)} bytes")

                # Resample to 16kHz if needed for Android compatibility
                if sample_rate != target_sample_rate:
                    print(f"TTS: Converting sample rate from {sample_rate}Hz to {target_sample_rate}Hz for Android compatibility")
                    wav_bytes = await self._resample_wav_to_16khz(wav_bytes, sample_rate)
                    print(f"TTS: Resampled WAV: {len(wav_bytes)} bytes")

                print(f"TTS: Generated {len(wav_bytes)} bytes of WAV audio for '{text}'")
                print("=== TTS GENERATION COMPLETE ===")

                return wav_bytes
            else:
                # For the web: use the original sample rate and convert to WebM
                wav_bytes = self._convert_to_wav_bytes(audio_array, sample_rate)
                print(f"TTS: Converted to WAV: {len(wav_bytes)} bytes")

                # Convert to WebM format for web compatibility
                webm_bytes = await self._convert_to_webm(wav_bytes)

                print(f"TTS: Generated {len(webm_bytes)} bytes of WebM audio for '{text}'")
                print("=== TTS GENERATION COMPLETE ===")

                return webm_bytes

        except Exception as e:
            print(f"TTS: TTS generation error: {e}")
            import traceback
            traceback.print_exc()
            return None

    async def generate_speech_dual_format(self, text: str, language_code: str) -> tuple[Optional[bytes], Optional[bytes]]:
        """Generate speech audio in both WebM and WAV formats.

        Args:
            text: Text to convert to speech
            language_code: Language code for the TTS model

        Returns:
            Tuple of (webm_bytes, wav_bytes); either can be None if generation fails
        """
        try:
            print("=== TTS DUAL FORMAT GENERATION REQUEST ===")
            print(f"Text: '{text}'")
            print(f"Language: {language_code}")

            # Input validation: check for invalid or problematic text
            if not text or not text.strip():
                print("TTS: Empty or whitespace-only text, skipping TTS generation")
                return None, None

            # Check for very short text that might cause issues
            clean_text = text.strip()
            if len(clean_text) <= 2 and clean_text in [".", ",", "!", "?", ":", ";", "-"]:
                print(f"TTS: Text '{clean_text}' is too short or punctuation-only, skipping TTS generation")
                return None, None

            # Check for a minimum meaningful length
            if len(clean_text.replace(" ", "").replace(".", "").replace(",", "")) < 2:
                print(f"TTS: Text '{clean_text}' has insufficient content for TTS, skipping")
                return None, None

            # Check whether the language is supported
            if language_code not in self.tts_config and language_code not in self.fallback_tts_config:
                print(f"TTS: Language {language_code} not configured for TTS")
                return None, None

            await self.ensure_model_loaded(language_code)

            if language_code not in self.tts_pipelines:
                print(f"TTS: TTS model not available for language: {language_code}")
                return None, None

            print(f"TTS: Generating speech for '{text}' in {language_code}")

            # Generate speech once
            pipeline_obj = self.tts_pipelines[language_code]
            result = pipeline_obj(text)

            audio_array = result["audio"]
            sample_rate = result.get("sampling_rate", 22050)

            print(f"TTS: Generated audio array of length {len(audio_array)} at {sample_rate}Hz")

            # Validate the audio array
            if len(audio_array) == 0:
                print("TTS: Warning - generated audio array is empty")
                return None, None

            # Check for potential issues with the audio data
            audio_min = np.min(audio_array)
            audio_max = np.max(audio_array)
            audio_rms = np.sqrt(np.mean(audio_array**2))
            print(f"TTS: Audio statistics - Min: {audio_min:.4f}, Max: {audio_max:.4f}, RMS: {audio_rms:.4f}")

            # Check whether the audio might be silent or corrupted
            if audio_rms < 0.001:
                print("TTS: Warning - audio appears to be very quiet or silent")
            if audio_max > 1.0 or audio_min < -1.0:
                print("TTS: Warning - audio values outside expected range [-1, 1]")
                # Clip to the valid range
                audio_array = np.clip(audio_array, -1.0, 1.0)
                print("TTS: Clipped audio to valid range")

            # Generate WAV at the original sample rate first
            wav_bytes_original = self._convert_to_wav_bytes(audio_array, sample_rate)
            print(f"TTS: Converted to WAV: {len(wav_bytes_original)} bytes")

            # Generate WebM from the original WAV
            webm_bytes = await self._convert_to_webm(wav_bytes_original)
            print(f"TTS: Converted to WebM: {len(webm_bytes)} bytes")

            # Generate a 16kHz WAV for Android
            wav_bytes_16k = await self._resample_wav_to_16khz(wav_bytes_original, sample_rate)
            print(f"TTS: Resampled to 16kHz WAV: {len(wav_bytes_16k)} bytes")

            print(f"TTS: Generated dual-format audio for '{text}'")
            print("=== TTS DUAL FORMAT GENERATION COMPLETE ===")

            return webm_bytes, wav_bytes_16k

        except Exception as e:
            print(f"TTS: Dual format TTS generation error: {e}")
            import traceback
            traceback.print_exc()
            return None, None

    def _convert_to_wav_bytes(self, audio_array: np.ndarray, sample_rate: int) -> bytes:
        """Convert a numpy audio array to WAV bytes"""
        buffer = io.BytesIO()
        with wave.open(buffer, 'wb') as wav_file:
            wav_file.setnchannels(1)   # Mono
            wav_file.setsampwidth(2)   # 16-bit
            wav_file.setframerate(sample_rate)

            # Ensure the audio is in the valid range [-1, 1]
            audio_array = np.clip(audio_array, -1.0, 1.0)

            # Convert to int16 with proper scaling
            int16_audio = (audio_array * 32767).astype(np.int16)

            # Validate the converted audio
            print(f"TTS: Converting {len(audio_array)} samples to WAV at {sample_rate}Hz")
            print(f"TTS: Int16 audio range: {np.min(int16_audio)} to {np.max(int16_audio)}")

            wav_file.writeframes(int16_audio.tobytes())

        wav_data = buffer.getvalue()
        print(f"TTS: WAV file created: {len(wav_data)} bytes (expected: 44 header bytes + {len(int16_audio) * 2} data bytes)")

        return wav_data

    async def _resample_wav_to_16khz(self, wav_bytes: bytes, original_sample_rate: int) -> bytes:
        """Resample WAV audio to 16kHz using FFmpeg"""
        try:
            process = subprocess.Popen([
                "ffmpeg", "-f", "wav", "-i", "pipe:0",
                "-ar", "16000",  # Set the output sample rate to 16kHz
                "-ac", "1",      # Ensure mono output
                "-f", "wav", "pipe:1"
            ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

            resampled_data, stderr = process.communicate(input=wav_bytes)

            if process.returncode != 0:
                print(f"TTS: FFmpeg resampling error: {stderr.decode()}")
                return wav_bytes  # Return the original if resampling fails

            return resampled_data

        except Exception as e:
            print(f"TTS: Resampling error: {e}")
            return wav_bytes  # Return the original if resampling fails

    async def _convert_to_webm(self, wav_bytes: bytes) -> bytes:
        """Convert WAV bytes to WebM format using FFmpeg"""
        try:
            process = subprocess.Popen([
                "ffmpeg", "-f", "wav", "-i", "pipe:0",
                "-c:a", "libopus", "-b:a", "64k",
                "-f", "webm", "pipe:1"
            ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

            webm_data, stderr = process.communicate(input=wav_bytes)

            if process.returncode != 0:
                print(f"TTS: FFmpeg error: {stderr.decode()}")
                return wav_bytes  # Return the original WAV if conversion fails

            return webm_data

        except Exception as e:
            print(f"TTS: WebM conversion error: {e}")
            return wav_bytes  # Return the original WAV if conversion fails

    async def load_remaining_models_in_background(self):
        """Load all remaining TTS models in the background after startup"""
        try:
            print("TTS: Starting background loading of additional voice models...")

            # Load primary models first
            for lang_code in self.tts_config.keys():
                if lang_code not in self.preload_languages and lang_code not in self.tts_pipelines:
                    if self.espeak_available:
                        try:
                            print(f"TTS: Background loading primary model for {lang_code}...")
                            self.models_loading_status[lang_code] = "loading"

                            model_config = self.tts_config[lang_code]
                            pipeline_obj = pipeline(
                                "text-to-speech",
                                model=model_config["model_repo"],
                                device=self.device
                            )
                            self.tts_pipelines[lang_code] = pipeline_obj
                            self.models_loading_status[lang_code] = "loaded"
                            print(f"TTS: Successfully loaded primary model for {lang_code} in background")

                            # Add a small delay between loading models
                            await asyncio.sleep(2)
                        except Exception as e:
                            print(f"TTS: Failed to load primary model for {lang_code} in background: {e}")
                            self.models_loading_status[lang_code] = "failed"

            # Load fallback models for languages not yet loaded
            for lang_code in self.fallback_tts_config.keys():
                if lang_code not in self.tts_pipelines:
                    try:
                        print(f"TTS: Background loading fallback model for {lang_code}...")
                        model_config = self.fallback_tts_config[lang_code]

                        if model_config["model_type"] == "speecht5":
                            from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor
                            processor = SpeechT5Processor.from_pretrained(model_config["model_repo"])
                            model = SpeechT5ForTextToSpeech.from_pretrained(model_config["model_repo"])
                            vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
                            if self.device >= 0:
                                model = model.to(f"cuda:{self.device}")
                                vocoder = vocoder.to(f"cuda:{self.device}")
                            self.tts_pipelines[lang_code] = {
                                "type": "speecht5",
                                "processor": processor,
                                "model": model,
                                "vocoder": vocoder
                            }
                        else:
                            pipeline_obj = pipeline(
                                "text-to-speech",
                                model=model_config["model_repo"],
                                device=self.device
                            )
                            self.tts_pipelines[lang_code] = pipeline_obj

                        print(f"TTS: Successfully loaded fallback model for {lang_code} in background")
                        await asyncio.sleep(2)
                    except Exception as e:
                        print(f"TTS: Failed to load fallback model for {lang_code}: {e}")

            print("TTS: Background loading of all voice models complete")
            print(f"TTS: Loaded models: {list(self.tts_pipelines.keys())}")
        except Exception as e:
            print(f"TTS: Error in background model loading: {e}")

    def start_background_loading(self):
        """Start background loading of models as a non-blocking task"""
        if self.background_loading_task is None:
            self.background_loading_task = asyncio.create_task(self.load_remaining_models_in_background())
            print("TTS: Background model loading task started")

    async def cleanup(self):
        """Cleanup resources"""
        # Cancel background loading if it is still running
        if self.background_loading_task and not self.background_loading_task.done():
            self.background_loading_task.cancel()
            try:
                await self.background_loading_task
            except asyncio.CancelledError:
                pass

        self.tts_pipelines.clear()
        print("TTS: TTS service cleaned up")
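`_convert_to_wav_bytes` relies on the standard float-to-PCM mapping: samples in [-1.0, 1.0] are scaled by 32767 and cast to int16, framed by the 44-byte RIFF header that the `wave` module writes for mono 16-bit audio. A self-contained illustration of that conversion:

```python
# Standalone illustration of the float -> 16-bit PCM WAV conversion in
# _convert_to_wav_bytes: one second of a 440 Hz sine wave.
import io
import wave
import numpy as np

sample_rate = 22050
t = np.linspace(0, 1, sample_rate, endpoint=False)
audio = 0.5 * np.sin(2 * np.pi * 440 * t)  # float samples in [-1, 1]

buffer = io.BytesIO()
with wave.open(buffer, "wb") as wav_file:
    wav_file.setnchannels(1)   # mono
    wav_file.setsampwidth(2)   # 16-bit
    wav_file.setframerate(sample_rate)
    wav_file.writeframes((np.clip(audio, -1, 1) * 32767).astype(np.int16).tobytes())

wav_bytes = buffer.getvalue()
assert len(wav_bytes) == 44 + 2 * sample_rate  # 44-byte header + 2 bytes/sample
```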
app/services/tts_service_onnx.py
ADDED
@@ -0,0 +1,587 @@
import asyncio
import io
import wave
import numpy as np
import subprocess
from typing import Any, Dict, Optional
import onnxruntime as ort
from transformers import AutoProcessor
from collections import OrderedDict
import os

class ONNXTTSService:
    def __init__(self):
        self.tts_models: Dict[str, Any] = {}
        self.processors: Dict[str, Any] = {}
        self.max_tts_models = 3  # Keep up to 3 TTS models in memory
        self.model_cache = OrderedDict()  # LRU cache

        # GPU optimization - detect and configure providers
        available_providers = ort.get_available_providers()
        print(f"ONNX TTS: Available providers: {available_providers}")

        if 'CUDAExecutionProvider' in available_providers:
            # Configure CUDA provider with optimizations
            cuda_provider_options = {
                'device_id': 0,
                'arena_extend_strategy': 'kNextPowerOfTwo',
                'gpu_mem_limit': int(0.7 * 1024 * 1024 * 1024),  # ~0.7 GB arena (TTS needs less than ASR)
                'cudnn_conv_algo_search': 'EXHAUSTIVE',
                'do_copy_in_default_stream': True,
            }
            self.providers = [('CUDAExecutionProvider', cuda_provider_options), 'CPUExecutionProvider']
            print(f"ONNX TTS: Using CUDA acceleration with GPU memory limit: {cuda_provider_options['gpu_mem_limit'] / (1024**3):.1f}GB")
        else:
            self.providers = ['CPUExecutionProvider']
            print("ONNX TTS: CUDA not available, using CPU execution")

        print(f"ONNX TTS: Configured providers: {[p[0] if isinstance(p, tuple) else p for p in self.providers]}")

        # Check if espeak is available
        self.espeak_available = self._check_espeak_availability()

        # ONNX TTS model configurations - FP32 optimized models (16kHz corrected)
        self.tts_config = {
            "kik": {"model_repo": "mutisya/vits-tts-onnx-fp32-kikuyu-v25-37-1", "model_type": "vits", "use_onnx": True},
            "luo": {"model_repo": "mutisya/vits-tts-onnx-fp32-luo-v25-37-1", "model_type": "vits", "use_onnx": True},
            "kam": {"model_repo": "mutisya/vits-tts-onnx-fp32-kamba-v25-37-1", "model_type": "vits", "use_onnx": True},
            "mer": {"model_repo": "mutisya/vits-tts-onnx-fp32-kimeru-v25-37-1", "model_type": "vits", "use_onnx": True},
            "som": {"model_repo": "mutisya/vits-tts-onnx-fp32-somali-v25-37-1", "model_type": "vits", "use_onnx": True},
            "swa": {"model_repo": "mutisya/vits-tts-onnx-fp32-swahili-v25-37-1", "model_type": "vits", "use_onnx": True},
            "eng": {"model_repo": "kakao-enterprise/vits-ljs", "model_type": "vits", "use_onnx": False},  # falls back to PyTorch
        }

        # Alternative TTS models that don't require espeak (fallback)
        self.fallback_tts_config = {
            "eng": {"model_repo": "microsoft/speecht5_tts", "model_type": "speecht5"},
            "swa": {"model_repo": "facebook/mms-tts-swh", "model_type": "mms"},
            "som": {"model_repo": "facebook/mms-tts-som", "model_type": "mms"},
        }

        self.preload_languages = ["kik", "swa"]
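
# Sketch: how a provider list shaped like self.providers is consumed by ONNX Runtime.
# A ('CUDAExecutionProvider', {...}) tuple carries per-provider options, a bare string
# uses defaults, and ORT tries entries in order. "model.onnx" is a placeholder path.
import onnxruntime as ort

providers = ['CPUExecutionProvider']
if 'CUDAExecutionProvider' in ort.get_available_providers():
    providers = [('CUDAExecutionProvider', {'device_id': 0})] + providers
session = ort.InferenceSession("model.onnx", providers=providers)
print(session.get_providers())  # the providers the session actually resolved to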

    def _check_espeak_availability(self) -> bool:
        """Check if espeak is available on the system"""
        try:
            result = subprocess.run(['espeak', '--version'],
                                    capture_output=True, text=True, timeout=5)
            if result.returncode == 0:
                print("ONNX TTS: espeak is available")
                return True
            else:
                print("ONNX TTS: espeak command failed")
                return False
        except Exception as e:  # TimeoutExpired, FileNotFoundError, etc.
            print(f"ONNX TTS: espeak not available: {e}")
            return False

    async def initialize(self):
        """Initialize TTS models for preloaded languages"""
        print("ONNX TTS: Initializing TTS service with ONNX Runtime...")
        print(f"ONNX TTS: espeak available: {self.espeak_available}")
        print(f"ONNX TTS: Using providers: {self.providers}")

        for lang_code in self.preload_languages:
            await self.ensure_model_loaded(lang_code)
| 87 |
+
async def ensure_model_loaded(self, language_code: str):
|
| 88 |
+
"""Load TTS model for language if not already loaded with LRU cache"""
|
| 89 |
+
if language_code in self.model_cache:
|
| 90 |
+
# Move to end (most recently used)
|
| 91 |
+
self.model_cache.move_to_end(language_code)
|
| 92 |
+
return
|
| 93 |
+
|
| 94 |
+
# Check if we need to evict old models
|
| 95 |
+
while len(self.model_cache) >= self.max_tts_models:
|
| 96 |
+
# Remove least recently used model
|
| 97 |
+
old_lang, _ = self.model_cache.popitem(last=False)
|
| 98 |
+
if old_lang in self.tts_models:
|
| 99 |
+
del self.tts_models[old_lang]
|
| 100 |
+
if old_lang in self.processors:
|
| 101 |
+
del self.processors[old_lang]
|
| 102 |
+
print(f"ONNX TTS: Evicted model for {old_lang} (LRU cache)")
|
| 103 |
+
|
| 104 |
+
# First try to load ONNX model
|
| 105 |
+
if language_code in self.tts_config:
|
| 106 |
+
model_config = self.tts_config[language_code]
|
| 107 |
+
|
| 108 |
+
if model_config.get("use_onnx", False):
|
| 109 |
+
try:
|
| 110 |
+
print(f"ONNX TTS: Loading ONNX model for {language_code}")
|
| 111 |
+
|
| 112 |
+
# Create ONNX session with optimizations and verbose logging
|
| 113 |
+
session_options = ort.SessionOptions()
|
| 114 |
+
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
|
| 115 |
+
|
| 116 |
+
# Enable verbose logging to diagnose operator assignments
|
| 117 |
+
session_options.log_severity_level = 1 # WARNING level for detailed logs
|
| 118 |
+
session_options.logid = "ONNX_TTS" # Prefix for log identification
|
| 119 |
+
|
| 120 |
+
# GPU memory optimization for T4 with diagnostic tracing
|
| 121 |
+
if 'CUDAExecutionProvider' in self.providers:
|
| 122 |
+
provider_options = [{
|
| 123 |
+
'device_id': 0,
|
| 124 |
+
'arena_extend_strategy': 'kSameAsRequested',
|
| 125 |
+
'gpu_mem_limit': int(0.3 * 1024 * 1024 * 1024), # 30% of GPU memory for TTS
|
| 126 |
+
'cudnn_conv_algo_search': 'EXHAUSTIVE',
|
| 127 |
+
'do_copy_in_default_stream': True,
|
| 128 |
+
'enable_tracing': True, # Enable tracing for better diagnostics
|
| 129 |
+
}]
|
| 130 |
+
providers = [('CUDAExecutionProvider', provider_options[0]), 'CPUExecutionProvider']
|
| 131 |
+
else:
|
| 132 |
+
providers = self.providers
|
| 133 |
+
|
| 134 |
+
# Get authentication token for private repos
|
| 135 |
+
import os
|
| 136 |
+
auth_token = os.getenv('HUGGING_FACE_HUB_TOKEN') or os.getenv('HF_TOKEN')
|
| 137 |
+
|
| 138 |
+
# Download ONNX model from HuggingFace Hub with authentication
|
| 139 |
+
from huggingface_hub import hf_hub_download
|
| 140 |
+
onnx_path = hf_hub_download(
|
| 141 |
+
repo_id=model_config["model_repo"],
|
| 142 |
+
filename="model.onnx",
|
| 143 |
+
token=auth_token
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
session = ort.InferenceSession(onnx_path, providers=providers, sess_options=session_options)
|
| 147 |
+
|
| 148 |
+
# Load processor for preprocessing with authentication
|
| 149 |
+
processor = AutoProcessor.from_pretrained(
|
| 150 |
+
model_config["model_repo"],
|
| 151 |
+
token=auth_token
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
self.tts_models[language_code] = session
|
| 155 |
+
self.processors[language_code] = processor
|
| 156 |
+
self.model_cache[language_code] = True
|
| 157 |
+
|
| 158 |
+
print(f"ONNX TTS: Successfully loaded ONNX model for {language_code}")
|
| 159 |
+
return
|
| 160 |
+
|
| 161 |
+
except Exception as e:
|
| 162 |
+
print(f"ONNX TTS: Failed to load ONNX model for {language_code}: {e}")
|
| 163 |
+
# Continue to try fallback models
|
| 164 |
+
else:
|
| 165 |
+
# Try PyTorch model if ONNX not available
|
| 166 |
+
try:
|
| 167 |
+
print(f"ONNX TTS: Loading PyTorch model for {language_code} (fallback)")
|
| 168 |
+
from transformers import pipeline
|
| 169 |
+
|
| 170 |
+
pipeline_obj = pipeline(
|
| 171 |
+
"text-to-speech",
|
| 172 |
+
model=model_config["model_repo"],
|
| 173 |
+
device=0 if self.providers[0] == 'CUDAExecutionProvider' else -1
|
| 174 |
+
)
|
| 175 |
+
self.tts_models[language_code] = pipeline_obj
|
| 176 |
+
self.processors[language_code] = None # Not needed for pipeline
|
| 177 |
+
self.model_cache[language_code] = True
|
| 178 |
+
|
| 179 |
+
print(f"ONNX TTS: Successfully loaded PyTorch model for {language_code}")
|
| 180 |
+
return
|
| 181 |
+
|
| 182 |
+
except Exception as e:
|
| 183 |
+
print(f"ONNX TTS: Failed to load PyTorch model for {language_code}: {e}")
|
| 184 |
+
|
| 185 |
+
# Try fallback models if primary failed
|
| 186 |
+
if language_code in self.fallback_tts_config:
|
| 187 |
+
try:
|
| 188 |
+
model_config = self.fallback_tts_config[language_code]
|
| 189 |
+
|
| 190 |
+
if model_config["model_type"] == "speecht5":
|
| 191 |
+
# Special handling for SpeechT5
|
| 192 |
+
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
|
| 193 |
+
import torch
|
| 194 |
+
|
| 195 |
+
# Get authentication token for private repos
|
| 196 |
+
import os
|
| 197 |
+
auth_token = os.getenv('HUGGING_FACE_HUB_TOKEN') or os.getenv('HF_TOKEN')
|
| 198 |
+
|
| 199 |
+
processor = SpeechT5Processor.from_pretrained(
|
| 200 |
+
model_config["model_repo"],
|
| 201 |
+
token=auth_token
|
| 202 |
+
)
|
| 203 |
+
model = SpeechT5ForTextToSpeech.from_pretrained(
|
| 204 |
+
model_config["model_repo"],
|
| 205 |
+
token=auth_token
|
| 206 |
+
)
|
| 207 |
+
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
|
| 208 |
+
|
| 209 |
+
# Create a custom pipeline-like object
|
| 210 |
+
class SpeechT5Pipeline:
|
| 211 |
+
def __init__(self, processor, model, vocoder):
|
| 212 |
+
self.processor = processor
|
| 213 |
+
self.model = model
|
| 214 |
+
self.vocoder = vocoder
|
| 215 |
+
|
| 216 |
+
def __call__(self, text):
|
| 217 |
+
inputs = self.processor(text=text, return_tensors="pt")
|
| 218 |
+
# Use default speaker embeddings
|
| 219 |
+
import datasets
|
| 220 |
+
embeddings_dataset = datasets.load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
|
| 221 |
+
speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
|
| 222 |
+
|
| 223 |
+
speech = self.model.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=self.vocoder)
|
| 224 |
+
|
| 225 |
+
return {
|
| 226 |
+
"audio": speech.numpy(),
|
| 227 |
+
"sampling_rate": 16000
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
pipeline_obj = SpeechT5Pipeline(processor, model, vocoder)
|
| 231 |
+
else:
|
| 232 |
+
# Standard pipeline for MMS models
|
| 233 |
+
from transformers import pipeline
|
| 234 |
+
pipeline_obj = pipeline(
|
| 235 |
+
"text-to-speech",
|
| 236 |
+
model=model_config["model_repo"],
|
| 237 |
+
device=0 if self.providers[0] == 'CUDAExecutionProvider' else -1
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
self.tts_models[language_code] = pipeline_obj
|
| 241 |
+
self.processors[language_code] = None
|
| 242 |
+
self.model_cache[language_code] = True
|
| 243 |
+
|
| 244 |
+
print(f"ONNX TTS: Successfully loaded fallback model for {language_code}")
|
| 245 |
+
return
|
| 246 |
+
|
| 247 |
+
except Exception as e:
|
| 248 |
+
print(f"ONNX TTS: Failed to load fallback TTS model for {language_code}: {e}")
|
| 249 |
+
|
| 250 |
+
print(f"ONNX TTS: No TTS model available for language: {language_code}")
|
| 251 |
+
|
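
# Sketch: the OrderedDict-based LRU policy used by ensure_model_loaded, in isolation.
# Self-contained and runnable; "kik"/"swa"/... stand in for loaded models.
from collections import OrderedDict

cache, max_items = OrderedDict(), 3
for lang in ["kik", "swa", "som", "kik", "luo"]:
    if lang in cache:
        cache.move_to_end(lang)                     # refresh recency on a hit
    else:
        while len(cache) >= max_items:
            evicted, _ = cache.popitem(last=False)  # drop least recently used
            print(f"evicted {evicted}")
        cache[lang] = True
print(list(cache))  # ['som', 'kik', 'luo'] - "swa" was evicted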

    async def generate_speech(self, text: str, language_code: str, output_format: str = "webm") -> Optional[bytes]:
        """Generate speech audio from text using ONNX models

        Args:
            text: Text to convert to speech
            language_code: Language code for TTS model
            output_format: Output format - "webm" (default, web-compatible) or "wav" (Android-compatible)

        Returns:
            Audio bytes in the requested format, or None if generation fails
        """
        try:
            print("=== ONNX TTS GENERATION REQUEST ===")
            print(f"Text: '{text}'")
            print(f"Language: {language_code}")
            print(f"Output format: {output_format}")

            # Input validation
            if not text or not text.strip():
                print("ONNX TTS: Empty or whitespace-only text, skipping TTS generation")
                return None

            # Check for very short text that might cause issues
            clean_text = text.strip()
            if len(clean_text) <= 2 and clean_text in [".", ",", "!", "?", ":", ";", "-"]:
                print(f"ONNX TTS: Text '{clean_text}' is too short or punctuation-only, skipping TTS generation")
                return None

            # Check for minimum meaningful length
            if len(clean_text.replace(" ", "").replace(".", "").replace(",", "")) < 2:
                print(f"ONNX TTS: Text '{clean_text}' has insufficient content for TTS, skipping")
                return None

            # Check if the language is supported
            if language_code not in self.tts_config and language_code not in self.fallback_tts_config:
                print(f"ONNX TTS: Language {language_code} not configured for TTS")
                return None

            await self.ensure_model_loaded(language_code)

            if language_code not in self.tts_models:
                print(f"ONNX TTS: TTS model not available for language: {language_code}")
                return None

            print(f"ONNX TTS: Generating speech for '{text}' in {language_code}")

            # Generate speech based on model type
            model_config = self.tts_config.get(language_code, {})
            if model_config.get("use_onnx", False):
                # ONNX inference
                audio_array, sample_rate = await self._run_onnx_tts_inference(text, language_code)
            else:
                # PyTorch pipeline inference
                pipeline_obj = self.tts_models[language_code]
                result = pipeline_obj(text)

                audio_array = result["audio"]
                sample_rate = result.get("sampling_rate", 16000)  # Default to 16kHz (corrected)

            print(f"ONNX TTS: Generated audio array of length {len(audio_array)} at {sample_rate}Hz")

            # Validate audio array
            if len(audio_array) == 0:
                print("ONNX TTS: Warning - Generated audio array is empty")
                return None

            # Check audio statistics
            audio_min = np.min(audio_array)
            audio_max = np.max(audio_array)
            audio_rms = np.sqrt(np.mean(audio_array**2))
            print(f"ONNX TTS: Audio statistics - Min: {audio_min:.4f}, Max: {audio_max:.4f}, RMS: {audio_rms:.4f}")

            # Check if audio might be silent or corrupted
            if audio_rms < 0.001:
                print("ONNX TTS: Warning - Audio appears to be very quiet or silent")
            if audio_max > 1.0 or audio_min < -1.0:
                print("ONNX TTS: Warning - Audio values outside expected range [-1, 1]")
                # Clip to valid range
                audio_array = np.clip(audio_array, -1.0, 1.0)
                print("ONNX TTS: Clipped audio to valid range")

            # Convert to requested format
            if output_format == "wav":
                # For Android: use 16kHz sample rate
                target_sample_rate = 16000
                wav_bytes = self._convert_to_wav_bytes(audio_array, sample_rate)
                print(f"ONNX TTS: Converted to WAV: {len(wav_bytes)} bytes")

                # Convert sample rate to 16kHz if needed for Android compatibility
                if sample_rate != target_sample_rate:
                    print(f"ONNX TTS: Converting sample rate from {sample_rate}Hz to {target_sample_rate}Hz")
                    wav_bytes = await self._resample_wav_to_16khz(wav_bytes, sample_rate)
                    print(f"ONNX TTS: Resampled WAV: {len(wav_bytes)} bytes")

                print(f"ONNX TTS: Generated {len(wav_bytes)} bytes of WAV audio for '{text}'")
                print("=== ONNX TTS GENERATION COMPLETE ===")

                return wav_bytes
            else:
                # For web: use original sample rate and convert to WebM
                wav_bytes = self._convert_to_wav_bytes(audio_array, sample_rate)
                print(f"ONNX TTS: Converted to WAV: {len(wav_bytes)} bytes")

                # Convert to WebM format for web compatibility
                webm_bytes = await self._convert_to_webm(wav_bytes)

                print(f"ONNX TTS: Generated {len(webm_bytes)} bytes of WebM audio for '{text}'")
                print("=== ONNX TTS GENERATION COMPLETE ===")

                return webm_bytes

        except Exception as e:
            print(f"ONNX TTS: TTS generation error: {e}")
            import traceback
            traceback.print_exc()
            return None
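
# Usage sketch for generate_speech; assumes the models and any required HF token are
# reachable at runtime. The output is written to disk for manual listening.
import asyncio

async def demo():
    service = ONNXTTSService()
    await service.initialize()
    wav = await service.generate_speech("Habari ya asubuhi", "swa", output_format="wav")
    if wav:
        with open("out.wav", "wb") as f:
            f.write(wav)

asyncio.run(demo())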

    async def _run_onnx_tts_inference(self, text: str, language_code: str) -> tuple[np.ndarray, int]:
        """Run ONNX inference for text-to-speech"""
        try:
            session = self.tts_models[language_code]
            processor = self.processors[language_code]

            # Preprocess text
            inputs = processor(text=text, return_tensors="np")

            # Get input names for the ONNX session
            input_names = [inp.name for inp in session.get_inputs()]

            # Prepare inputs for ONNX
            onnx_inputs = {}
            for name in input_names:
                if name in inputs:
                    onnx_inputs[name] = inputs[name]
                elif name == "input_ids" and "input_ids" in inputs:
                    onnx_inputs[name] = inputs["input_ids"].astype(np.int64)
                elif name == "attention_mask" and "attention_mask" in inputs:
                    onnx_inputs[name] = inputs["attention_mask"].astype(np.int64)

            # Run ONNX inference
            outputs = session.run(None, onnx_inputs)

            # Extract audio from outputs (assuming the first output is audio)
            audio_array = outputs[0]

            # Ensure audio is 1D
            if audio_array.ndim > 1:
                audio_array = audio_array.flatten()

            # Convert to float32 if needed
            if audio_array.dtype != np.float32:
                audio_array = audio_array.astype(np.float32)

            # Sample rate is 16kHz for our corrected models
            sample_rate = 16000

            return audio_array, sample_rate

        except Exception as e:
            print(f"ONNX TTS: Inference error: {e}")
            import traceback
            traceback.print_exc()
            return np.array([], dtype=np.float32), 16000
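
# Sketch: inspecting an ONNX session's I/O signature, useful when the input-name
# matching loop above misses a model-specific input. "model.onnx" is a placeholder path.
import onnxruntime as ort

sess = ort.InferenceSession("model.onnx", providers=['CPUExecutionProvider'])
for inp in sess.get_inputs():
    print("input:", inp.name, inp.shape, inp.type)
for out in sess.get_outputs():
    print("output:", out.name, out.shape, out.type)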

    async def generate_speech_dual_format(self, text: str, language_code: str) -> tuple[Optional[bytes], Optional[bytes]]:
        """Generate speech audio in both WebM and WAV formats using ONNX

        Args:
            text: Text to convert to speech
            language_code: Language code for TTS model

        Returns:
            Tuple of (webm_bytes, wav_bytes); either can be None if generation fails
        """
        try:
            print("=== ONNX TTS DUAL FORMAT GENERATION REQUEST ===")
            print(f"Text: '{text}'")
            print(f"Language: {language_code}")

            # Input validation
            if not text or not text.strip():
                print("ONNX TTS: Empty or whitespace-only text, skipping TTS generation")
                return None, None

            clean_text = text.strip()
            if len(clean_text) <= 2 and clean_text in [".", ",", "!", "?", ":", ";", "-"]:
                print(f"ONNX TTS: Text '{clean_text}' is too short or punctuation-only, skipping TTS generation")
                return None, None

            if len(clean_text.replace(" ", "").replace(".", "").replace(",", "")) < 2:
                print(f"ONNX TTS: Text '{clean_text}' has insufficient content for TTS, skipping")
                return None, None

            # Check if the language is supported
            if language_code not in self.tts_config and language_code not in self.fallback_tts_config:
                print(f"ONNX TTS: Language {language_code} not configured for TTS")
                return None, None

            await self.ensure_model_loaded(language_code)

            if language_code not in self.tts_models:
                print(f"ONNX TTS: TTS model not available for language: {language_code}")
                return None, None

            print(f"ONNX TTS: Generating speech for '{text}' in {language_code}")

            # Generate speech once
            model_config = self.tts_config.get(language_code, {})
            if model_config.get("use_onnx", False):
                # ONNX inference
                audio_array, sample_rate = await self._run_onnx_tts_inference(text, language_code)
            else:
                # PyTorch pipeline inference
                pipeline_obj = self.tts_models[language_code]
                result = pipeline_obj(text)

                audio_array = result["audio"]
                sample_rate = result.get("sampling_rate", 16000)

            print(f"ONNX TTS: Generated audio array of length {len(audio_array)} at {sample_rate}Hz")

            # Validate audio array
            if len(audio_array) == 0:
                print("ONNX TTS: Warning - Generated audio array is empty")
                return None, None

            # Check for potential issues with audio data
            audio_min = np.min(audio_array)
            audio_max = np.max(audio_array)
            audio_rms = np.sqrt(np.mean(audio_array**2))
            print(f"ONNX TTS: Audio statistics - Min: {audio_min:.4f}, Max: {audio_max:.4f}, RMS: {audio_rms:.4f}")

            if audio_rms < 0.001:
                print("ONNX TTS: Warning - Audio appears to be very quiet or silent")
            if audio_max > 1.0 or audio_min < -1.0:
                print("ONNX TTS: Warning - Audio values outside expected range [-1, 1]")
                audio_array = np.clip(audio_array, -1.0, 1.0)
                print("ONNX TTS: Clipped audio to valid range")

            # Generate WAV at the original sample rate first
            wav_bytes_original = self._convert_to_wav_bytes(audio_array, sample_rate)
            print(f"ONNX TTS: Converted to WAV: {len(wav_bytes_original)} bytes")

            # Generate WebM from the original WAV
            webm_bytes = await self._convert_to_webm(wav_bytes_original)
            print(f"ONNX TTS: Converted to WebM: {len(webm_bytes)} bytes")

            # Generate 16kHz WAV for Android
            wav_bytes_16k = await self._resample_wav_to_16khz(wav_bytes_original, sample_rate)
            print(f"ONNX TTS: Resampled to 16kHz WAV: {len(wav_bytes_16k)} bytes")

            print(f"ONNX TTS: Generated dual format audio for '{text}'")
            print("=== ONNX TTS DUAL FORMAT GENERATION COMPLETE ===")

            return webm_bytes, wav_bytes_16k

        except Exception as e:
            print(f"ONNX TTS: Dual format TTS generation error: {e}")
            import traceback
            traceback.print_exc()
            return None, None
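
# Usage sketch for generate_speech_dual_format: one synthesis pass, two encodings.
# Assumes models and any required HF token are reachable at runtime.
import asyncio

async def demo_dual():
    service = ONNXTTSService()
    await service.initialize()
    webm, wav16k = await service.generate_speech_dual_format("Wĩ mwega?", "kik")
    print("webm:", len(webm) if webm else 0, "bytes | wav 16k:", len(wav16k) if wav16k else 0, "bytes")

asyncio.run(demo_dual())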

    def _convert_to_wav_bytes(self, audio_array: np.ndarray, sample_rate: int) -> bytes:
        """Convert a numpy audio array to WAV bytes"""
        buffer = io.BytesIO()
        with wave.open(buffer, 'wb') as wav_file:
            wav_file.setnchannels(1)  # Mono
            wav_file.setsampwidth(2)  # 16-bit
            wav_file.setframerate(sample_rate)

            # Ensure audio is in the valid range [-1, 1]
            audio_array = np.clip(audio_array, -1.0, 1.0)

            # Convert to int16 with proper scaling
            int16_audio = (audio_array * 32767).astype(np.int16)

            # Validate the converted audio
            print(f"ONNX TTS: Converting {len(audio_array)} samples to WAV at {sample_rate}Hz")
            print(f"ONNX TTS: Int16 audio range: {np.min(int16_audio)} to {np.max(int16_audio)}")

            wav_file.writeframes(int16_audio.tobytes())

        wav_data = buffer.getvalue()
        print(f"ONNX TTS: WAV file created: {len(wav_data)} bytes")

        return wav_data
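
# Sketch: verifying the int16 WAV encoding above with a synthetic tone; fully
# self-contained and runnable.
import io
import wave
import numpy as np

sr = 16000
t = np.linspace(0, 1.0, sr, endpoint=False)
tone = 0.5 * np.sin(2 * np.pi * 440 * t)  # 1 s, 440 Hz, within [-1, 1]

buf = io.BytesIO()
with wave.open(buf, 'wb') as w:
    w.setnchannels(1)
    w.setsampwidth(2)
    w.setframerate(sr)
    w.writeframes((tone * 32767).astype(np.int16).tobytes())

with wave.open(io.BytesIO(buf.getvalue()), 'rb') as r:
    assert (r.getframerate(), r.getnchannels(), r.getnframes()) == (sr, 1, sr)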

    async def _resample_wav_to_16khz(self, wav_bytes: bytes, original_sample_rate: int) -> bytes:
        """Resample WAV audio to 16kHz using FFmpeg"""
        try:
            process = subprocess.Popen([
                "ffmpeg", "-f", "wav", "-i", "pipe:0",
                "-ar", "16000",  # Set output sample rate to 16kHz
                "-ac", "1",      # Ensure mono output
                "-f", "wav", "pipe:1"
            ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

            resampled_data, stderr = process.communicate(input=wav_bytes)

            if process.returncode != 0:
                print(f"ONNX TTS: FFmpeg resampling error: {stderr.decode()}")
                return wav_bytes  # Return original if resampling fails

            return resampled_data

        except Exception as e:
            print(f"ONNX TTS: Resampling error: {e}")
            return wav_bytes  # Return original if resampling fails
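
# Sketch: a non-blocking variant of the same resample. Popen.communicate blocks the
# event loop when called from an async method; asyncio's subprocess API keeps it
# responsive. Identical ffmpeg flags to _resample_wav_to_16khz above.
import asyncio

async def resample_wav_to_16khz_async(wav_bytes: bytes) -> bytes:
    proc = await asyncio.create_subprocess_exec(
        "ffmpeg", "-f", "wav", "-i", "pipe:0", "-ar", "16000", "-ac", "1", "-f", "wav", "pipe:1",
        stdin=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE)
    out, _ = await proc.communicate(input=wav_bytes)
    return out if proc.returncode == 0 else wav_bytes  # fall back to the original on failure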

    async def _convert_to_webm(self, wav_bytes: bytes) -> bytes:
        """Convert WAV bytes to WebM format using FFmpeg"""
        try:
            process = subprocess.Popen([
                "ffmpeg", "-f", "wav", "-i", "pipe:0",
                "-c:a", "libopus", "-b:a", "64k",
                "-f", "webm", "pipe:1"
            ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

            webm_data, stderr = process.communicate(input=wav_bytes)

            if process.returncode != 0:
                print(f"ONNX TTS: FFmpeg error: {stderr.decode()}")
                return wav_bytes  # Return original WAV if conversion fails

            return webm_data

        except Exception as e:
            print(f"ONNX TTS: WebM conversion error: {e}")
            return wav_bytes  # Return original WAV if conversion fails

    async def cleanup(self):
        """Cleanup resources"""
        self.tts_models.clear()
        self.processors.clear()
        self.model_cache.clear()
        print("ONNX TTS: TTS service cleaned up")
app/services/websocket_manager.py
ADDED
@@ -0,0 +1,909 @@
import asyncio
import uuid
from typing import Dict, Set, Optional
import socketio
import numpy as np

from app.models import Message, LanguageCode
from app.services.session_manager import SessionManager, LANGUAGE_MAP
from app.services.transcription_service import TranscriptionService
from app.services.translation_service import TranslationService
from app.services.tts_service import TTSService

def truncate_array_for_log(arr, max_items=10):
    """Truncate long arrays in log messages for readability"""
    if not arr or len(arr) <= max_items:
        return arr
    return arr[:max_items] + [f"... {len(arr) - max_items} more items"]
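
# Quick check of the helper above:
print(truncate_array_for_log(list(range(15))))
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, '... 5 more items']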

class WebSocketManager:
    def __init__(self, session_manager: SessionManager, transcription_service: TranscriptionService,
                 translation_service: TranslationService, tts_service: TTSService):
        self.session_manager = session_manager
        self.transcription_service = transcription_service
        self.translation_service = translation_service
        self.tts_service = tts_service
        self.sio = None  # Will be set by main.py

        self.client_sessions: Dict[str, str] = {}       # sid -> session_id
        self.client_participants: Dict[str, str] = {}   # sid -> participant_id
        self.session_clients: Dict[str, Set[str]] = {}  # session_id -> set of sids

        self.messages: Dict[str, Message] = {}                 # message_id -> message
        self.participant_current_message: Dict[str, str] = {}  # participant_id -> current_message_id
        self.processed_messages: Set[str] = set()  # Track processed message IDs to prevent duplicates

    def set_socketio(self, sio):
        """Set the Socket.IO server instance"""
        self.sio = sio

    async def handle_join_session(self, sid: str, data: dict):
        """Handle a participant joining a session"""
        try:
            session_id = data.get('sessionId')
            participant_name = data.get('participantName')
            language_code = data.get('language')

            print("=== JOIN SESSION REQUEST ===")
            print(f"Session ID: {session_id}")
            print(f"Participant: {participant_name}")
            print(f"Language: {language_code}")

            if not all([session_id, participant_name, language_code]):
                await self._emit_error(sid, "Missing required fields")
                return

            # Validate language code
            try:
                lang_enum = LanguageCode(language_code)
                print(f"Language code validated: {lang_enum}")
            except ValueError:
                await self._emit_error(sid, f"Invalid language code: {language_code}")
                return

            # Resolve session ID (in case it's a short code)
            session = await self.session_manager.get_session(session_id)
            if not session:
                await self._emit_error(sid, "Session not found")
                return

            # Use the full UUID for all subsequent operations
            session_id = session.id
            print(f"Resolved session ID: {session_id}")

            # Add participant to session
            participant = await self.session_manager.add_participant(
                session_id, participant_name, lang_enum
            )

            print(f"Participant created: {participant}")

            if not participant:
                await self._emit_error(sid, "Session not found or unable to join")
                return

            # Get updated session info
            session = await self.session_manager.get_session(session_id)
            if session:
                print(f"Session {session_id} now has {len(session.languages)} languages: {[f'{lang.name}({lang.code.value})' for lang in session.languages]}")
                print(f"Session participants: {[f'{p.name}({p.language.name})' for p in session.participants]}")

            # Track client connections
            self.client_sessions[sid] = session_id
            self.client_participants[sid] = participant.id

            if session_id not in self.session_clients:
                self.session_clients[session_id] = set()
            self.session_clients[session_id].add(sid)

            # Send success response
            await self.sio.emit('participant_joined', participant.dict(), room=sid)

            # Notify other participants
            await self._broadcast_to_session(session_id, 'participant_update', participant.dict(), exclude_sid=sid)

            print("=== JOIN SESSION COMPLETE ===")

        except Exception as e:
            print(f"Error in handle_join_session: {e}")
            import traceback
            traceback.print_exc()
            await self._emit_error(sid, "Failed to join session")

    async def handle_join_hub(self, sid: str, data: dict):
        """Handle a hub joining a session for observation"""
        try:
            session_id = data.get('sessionId')
            if not session_id:
                await self._emit_error(sid, "Missing sessionId for hub")
                return

            # Verify the session exists
            session = await self.session_manager.get_session(session_id)
            if not session:
                await self._emit_error(sid, "Session not found")
                return

            # Track hub connection
            self.client_sessions[sid] = session_id

            if session_id not in self.session_clients:
                self.session_clients[session_id] = set()
            self.session_clients[session_id].add(sid)

            # Send success response
            await self.sio.emit('hub_joined', {'sessionId': session_id}, room=sid)

            print(f"Hub joined session {session_id} with sid {sid}")

        except Exception as e:
            print(f"Error in handle_join_hub: {e}")
            await self._emit_error(sid, "Failed to join as hub")

    async def handle_audio_chunk(self, sid: str, data: dict):
        """Handle an incoming audio chunk from a participant"""
        try:
            participant_id = self.client_participants.get(sid)
            if not participant_id:
                return

            audio_data = data.get('audioData', [])
            is_pause_boundary = data.get('isPauseBoundary', False)

            if not audio_data:
                return

            # Convert array to bytes
            audio_bytes = bytes(audio_data)

            # Process the audio chunk using the VAD-based approach
            if audio_bytes:
                # Check for voice activity in this chunk
                has_voice = self.transcription_service.has_voice_activity(audio_bytes)

                # Process the chunk (even with no voice, to handle silence detection).
                # If isPauseBoundary is True, force finalization by treating it as silence.
                await self._process_audio_chunk_vad(participant_id, audio_bytes, has_voice and not is_pause_boundary, is_pause_boundary)

        except Exception as e:
            print(f"Error in handle_audio_chunk: {e}")
            import traceback
            traceback.print_exc()
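
# Client-side sketch with python-socketio. The event names 'join_session',
# 'audio_chunk' and 'speaking_status' are assumptions here (the handler
# registrations live in app/main.py, not shown); payload keys match the handlers above.
import socketio

client = socketio.Client()
client.connect("http://localhost:7860")  # port assumed for a local run
client.emit("join_session", {"sessionId": "abc123", "participantName": "Amina", "language": "swa"})
client.emit("audio_chunk", {"audioData": list(b"\x00\x00" * 320), "isPauseBoundary": False})
client.emit("speaking_status", {"isSpeaking": False})  # triggers force-completion of a pending sentence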

    async def handle_speaking_status(self, sid: str, data: dict):
        """Handle speaking status updates"""
        try:
            participant_id = self.client_participants.get(sid)
            if not participant_id:
                return

            is_speaking = data.get('isSpeaking', False)
            await self.session_manager.update_participant_speaking_status(participant_id, is_speaking)

            # If the participant stopped speaking, force-complete any pending sentence
            if not is_speaking:
                # Get session and participant info for force completion
                session_id = await self.session_manager.get_participant_session_id(participant_id)
                if session_id:
                    session = await self.session_manager.get_session(session_id)
                    participant = next((p for p in session.participants if p.id == participant_id), None)

                    if participant:
                        # Define the sentence callback for force completion
                        async def force_sentence_callback(final_text: str, final_audio: bytes):
                            # Create or get the existing message
                            current_message_id = self.participant_current_message.get(participant_id)
                            if not current_message_id:
                                current_message_id = str(uuid.uuid4())

                            # Check if this message was already processed
                            if current_message_id in self.processed_messages:
                                print(f"Force completion: Message {current_message_id} already processed, skipping duplicate")
                                return

                            # Mark as processed to prevent duplicates
                            self.processed_messages.add(current_message_id)

                            message = Message(
                                id=current_message_id,
                                session_id=session_id,
                                speaker_id=participant_id,
                                speaker_name=participant.name,
                                original_text=final_text,
                                original_language=participant.language,
                                translations={},
                                is_transcribing=False
                            )
                            self.messages[current_message_id] = message

                            # Broadcast the completed message
                            print(f"Force completion: Broadcasting message_complete for {current_message_id}: '{final_text}'")
                            await self._broadcast_to_session(session_id, 'message_complete', {
                                'messageId': current_message_id,
                                'sessionId': session_id,
                                'text': final_text,
                                'speakerId': participant_id,
                                'speakerName': participant.name,
                                'language': participant.language.code.value
                            })

                            # Clear current message tracking
                            if participant_id in self.participant_current_message:
                                del self.participant_current_message[participant_id]

                            # Start translation processing (non-blocking, so audio processing continues)
                            print("Starting TRANSLATION and TTS (background task)")
                            asyncio.create_task(self._process_translations_and_tts(message, session))

                        # Force-complete any pending sentence
                        await self.transcription_service.force_complete_sentence(
                            participant_id,
                            participant.language.code.value,
                            force_sentence_callback
                        )

                        # Clear transcription service buffers after force completion
                        self.transcription_service.clear_participant_buffers(participant_id)

            # Broadcast speaking status to the session
            session_id = self.client_sessions.get(sid)
            if session_id:
                await self._broadcast_to_session(session_id, 'speaking_status', {
                    'participantId': participant_id,
                    'isSpeaking': is_speaking
                })

        except Exception as e:
            print(f"Error in handle_speaking_status: {e}")
            import traceback
            traceback.print_exc()

    async def handle_leave_session(self, sid: str, data: dict):
        """Handle a participant leaving a session"""
        await self._cleanup_client(sid)

    async def handle_disconnect(self, sid: str):
        """Handle client disconnection"""
        await self._cleanup_client(sid)

    async def _process_audio_chunk_vad(self, participant_id: str, audio_data: bytes, has_voice_activity: bool, is_pause_boundary: bool = False):
        """Process an audio chunk using VAD-based sentence detection

        Args:
            participant_id: ID of the participant
            audio_data: Raw audio data bytes
            has_voice_activity: Whether voice activity was detected in this chunk
            is_pause_boundary: If True, forces sentence finalization (from stop button or explicit pause)
        """
        try:
            session_id = await self.session_manager.get_participant_session_id(participant_id)
            if not session_id:
                return

            session = await self.session_manager.get_session(session_id)
            if not session:
                return

            participant = next((p for p in session.participants if p.id == participant_id), None)
            if not participant:
                return

            # Get or create the current message for this participant
            current_message_id = self.participant_current_message.get(participant_id)
            if not current_message_id:
                current_message_id = str(uuid.uuid4())
                message = Message(
                    id=current_message_id,
                    session_id=session_id,
                    speaker_id=participant_id,
                    speaker_name=participant.name,
                    original_text="",
                    original_language=participant.language,
                    translations={},
                    is_transcribing=True
                )
                self.messages[current_message_id] = message
                self.participant_current_message[participant_id] = current_message_id

                # Start typing indicator
                await self._broadcast_to_session(session_id, 'typing_start', {
                    'speakerId': participant_id,
                    'speakerName': participant.name,
                    'languageCode': participant.language.code.value
                })

            message = self.messages[current_message_id]

            # Define callbacks
            async def on_progress(text: str, is_complete: bool):
                """Called with in-progress transcription updates"""
                # Update the message text even for progress updates
                message.original_text = text

                await self._broadcast_to_session(session_id, 'transcription_progress', {
                    'messageId': current_message_id,
                    'text': text,
                    'isTranscribing': not is_complete,
                    'speakerId': participant_id,
                    'speakerName': participant.name
                })

            async def on_debug(debug_info: dict):
                """Called with debug information from ASR (wav2vec2 models only)"""
                # Prepare debug data for transmission
                debug_data = {
                    'messageId': current_message_id,
                    'text': debug_info['text'],
                    'timestamps': debug_info['timestamps'],
                    'audioData': list(debug_info['audio_data']),
                    'audioDuration': debug_info['audio_duration'],
                    'modelType': debug_info['model_type'],
                    'language': participant.language.code.value
                }

                await self._broadcast_to_session(session_id, 'transcription_debug', debug_data)

            async def on_sentence_complete(final_text: str, final_audio: bytes):
                """Called when a complete sentence is detected"""

                # Check if this message was already processed
                if current_message_id in self.processed_messages:
                    print(f"Message {current_message_id} already processed, skipping duplicate")
                    return

                # Mark as processed to prevent duplicates
                self.processed_messages.add(current_message_id)

                message.original_text = final_text
                message.is_transcribing = False

                # Broadcast the complete sentence with session ID
                message_data = {
                    'messageId': current_message_id,
                    'sessionId': session_id,
                    'text': final_text,
                    'speakerId': participant_id,
                    'speakerName': participant.name,
                    'language': participant.language.code.value,
                    'audioData': list(final_audio)
                }

                print(f"Broadcasting message_complete for {current_message_id}: '{final_text}'")
                await self._broadcast_to_session(session_id, 'message_complete', message_data)

                # Stop typing indicator
                await self._broadcast_to_session(session_id, 'typing_stop', {
                    'speakerId': participant_id
                })

                # Clear current message tracking
                if participant_id in self.participant_current_message:
                    del self.participant_current_message[participant_id]

                # Start translation and TTS processing (non-blocking, so audio processing continues)
                print("Starting TRANSLATION and TTS (background task)")
                asyncio.create_task(self._process_translations_and_tts(message, session))

            # Process the audio chunk
            result_text = await self.transcription_service.process_audio_chunk(
                audio_data,
                participant.language.code.value,
                participant_id,
                has_voice_activity,
                progress_callback=on_progress,
                sentence_callback=on_sentence_complete,
                debug_callback=on_debug
            )

            # If this is a pause boundary (stop button clicked), force immediate finalization
            if is_pause_boundary and participant_id in self.participant_current_message:
                print(f"Pause boundary detected - forcing sentence finalization for participant {participant_id}")
                # Get the current accumulated text from the transcription service
                if hasattr(self.transcription_service, 'candidate_text_cache') and participant_id in self.transcription_service.candidate_text_cache:
                    final_text = self.transcription_service.candidate_text_cache.get(participant_id, "").strip()
                    if final_text:  # Only finalize if there's actual text
                        # Get accumulated audio
                        final_audio = b""
                        if hasattr(self.transcription_service, 'candidate_audio_buffers') and participant_id in self.transcription_service.candidate_audio_buffers:
                            audio_array = self.transcription_service.candidate_audio_buffers.get(participant_id, np.array([]))
                            if len(audio_array) > 0:
                                # Convert float array to int16 bytes
                                audio_int16 = (audio_array * 32767).astype(np.int16)
                                final_audio = audio_int16.tobytes()

                        # Trigger sentence completion
                        await on_sentence_complete(final_text, final_audio)

                        # Clear the buffers manually since we're forcing finalization
                        if participant_id in self.transcription_service.candidate_text_cache:
                            self.transcription_service.candidate_text_cache[participant_id] = ""
                        if participant_id in self.transcription_service.candidate_audio_buffers:
                            self.transcription_service.candidate_audio_buffers[participant_id] = np.array([], dtype=np.float32)
                        if participant_id in self.transcription_service.silence_counters:
                            self.transcription_service.silence_counters[participant_id] = 0
                        if participant_id in self.transcription_service.sentence_finalized:
                            self.transcription_service.sentence_finalized[participant_id] = False

        except Exception as e:
            print(f"Error in _process_audio_chunk_vad: {e}")
            import traceback
            traceback.print_exc()
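
# Sketch: the float32 -> int16 PCM conversion used for final_audio above, round-tripped.
# Self-contained and runnable.
import numpy as np

samples = np.array([0.0, 0.5, -0.5, 1.0], dtype=np.float32)
pcm = (samples * 32767).astype(np.int16).tobytes()  # little-endian 16-bit PCM
restored = np.frombuffer(pcm, dtype=np.int16).astype(np.float32) / 32767.0
assert np.allclose(samples, restored, atol=1e-4)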

    async def _process_translations_and_tts(self, message: Message, session):
        """Process translations and TTS for all session languages"""
        try:
            source_lang = message.original_language.name

            print("=== TRANSLATION/TTS PROCESSING START ===")
            print(f"Message ID: {message.id}")
            print(f"Original message: '{message.original_text}'")
            print(f"Original language: {message.original_language.name} ({message.original_language.code.value})")
            print(f"Session languages: {[f'{lang.name} ({lang.code.value})' for lang in session.languages]}")
            print(f"Session ID for verification: {session.id}")

            # Track which audio belongs to which message and language
            audio_tasks = []

            # Check if TTS is enabled for this session
            if session.enable_tts:
                # First, generate TTS for the original message
                print(f"TTS: Generating TTS for original message in {message.original_language.code.value}: '{message.original_text}'")
                print(f"TTS Model: VITS ONNX (mutisya/vits-tts-onnx-fp32-{message.original_language.name.lower()}) - File: tts_service_onnx.py")
                original_audio_task = asyncio.create_task(
                    self.tts_service.generate_speech_dual_format(message.original_text, message.original_language.code.value)
                )
                audio_tasks.append((
                    message.original_language.code.value,
                    message.original_text,
                    original_audio_task,
                    True  # is_original
                ))
            else:
                print("TTS: Skipping TTS generation (disabled for this session)")

            # Process translations for each language in the session
            print(f"Processing translations for {len(session.languages)} session languages...")
            print(f"Session languages: {[f'{lang.name}({lang.code.value})' for lang in session.languages]}")
            translation_tasks = []

            for language in session.languages:
                print(f"Checking language: {language.name} ({language.code.value})")
                if language.code != message.original_language.code:
                    print(f"TRANSLATING: '{message.original_text}' from {source_lang} to {language.name}")
                    print("Translation Model: mutisya/nllb_600m (NLLB-600M) - File: translation_service.py")

                    # Create translation task
                    translation_task = asyncio.create_task(
                        self.translation_service.translate_text(
                            message.original_text, source_lang, language.name
                        )
                    )
                    translation_tasks.append((language, translation_task))
                else:
                    print(f"SKIPPING translation for {language.name} (same as original language)")

            print(f"Created {len(translation_tasks)} translation tasks for non-original languages")

            # Wait for all translations to complete
            for language, translation_task in translation_tasks:
                try:
                    translated_text = await translation_task

                    if translated_text:
                        print(f"TRANSLATION SUCCESS: '{translated_text}' for {language.name}")

                        message.translations[language.code.value] = translated_text

                        # Broadcast translation update to all clients
                        await self._broadcast_to_session(message.session_id, 'translation_update', {
                            'messageId': message.id,
                            'targetLanguage': language.code.value,
                            'translatedText': translated_text
                        })

                        # Check if TTS is enabled for this session
                        if session.enable_tts:
                            # Generate TTS for the translated text
                            print(f"TTS: Generating TTS for translation in {language.code.value}: '{translated_text}'")
                            print(f"TTS Model: VITS ONNX (mutisya/vits-tts-onnx-fp32-{language.name.lower()}) - File: tts_service_onnx.py")
                            tts_task = asyncio.create_task(
                                self.tts_service.generate_speech_dual_format(translated_text, language.code.value)
                            )
                            audio_tasks.append((
                                language.code.value,
                                translated_text,
                                tts_task,
                                False  # is_original
                            ))
                        else:
                            print("TTS: Skipping TTS generation for translation (disabled for this session)")
                    else:
                        print(f"TRANSLATION FAILED: No translated text returned for {language.name}")
                except Exception as e:
                    print(f"Translation error for {language.name}: {e}")

            # Wait for all TTS generation to complete and broadcast with proper alignment
            for language_code, text, audio_task, is_original in audio_tasks:
                try:
                    audio_result = await audio_task

                    if audio_result and (audio_result[0] or audio_result[1]):
|
| 531 |
+
webm_data, wav_data = audio_result
|
| 532 |
+
print(f"TTS: Audio generated successfully for {language_code}")
|
| 533 |
+
if webm_data:
|
| 534 |
+
print(f"TTS: WebM audio: {len(webm_data)} bytes")
|
| 535 |
+
if wav_data:
|
| 536 |
+
print(f"TTS: WAV audio: {len(wav_data)} bytes")
|
| 537 |
+
print(f"TTS: Text for {language_code}: '{text}'")
|
| 538 |
+
|
| 539 |
+
# Broadcast TTS audio with explicit message-text-audio alignment (dual format)
|
| 540 |
+
await self._broadcast_tts_audio_aligned_dual_format(
|
| 541 |
+
message.session_id,
|
| 542 |
+
message.id,
|
| 543 |
+
language_code,
|
| 544 |
+
text,
|
| 545 |
+
webm_data,
|
| 546 |
+
wav_data,
|
| 547 |
+
is_original
|
| 548 |
+
)
|
| 549 |
+
else:
|
| 550 |
+
print(f"TTS: Failed to generate audio for {language_code}")
|
| 551 |
+
except Exception as e:
|
| 552 |
+
print(f"TTS generation error for {language_code}: {e}")
|
| 553 |
+
|
| 554 |
+
print(f"=== TRANSLATION/TTS PROCESSING END ===")
|
| 555 |
+
|
| 556 |
+
except Exception as e:
|
| 557 |
+
print(f"Error in _process_translations_and_tts: {e}")
|
| 558 |
+
import traceback
|
| 559 |
+
traceback.print_exc()
|
| 560 |
+
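The translation loop above creates its tasks eagerly but awaits them one at a time, so a single slow language can hold up the rest. A minimal alternative sketch, not the committed behavior, that drains them concurrently with asyncio.gather while reusing the names defined in the method above:

    # Sketch only: await all translation tasks at once, then pair results
    # back up with their languages.
    results = await asyncio.gather(
        *(task for _, task in translation_tasks),
        return_exceptions=True,
    )
    for (language, _), result in zip(translation_tasks, results):
        if isinstance(result, Exception):
            print(f"Translation error for {language.name}: {result}")
        elif result:
            message.translations[language.code.value] = result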
    async def _broadcast_to_session(self, session_id: str, event: str, data: dict, exclude_sid: str = None):
        """Broadcast message to all clients in a session"""
        if session_id not in self.session_clients:
            return

        # Create a copy of the set to avoid concurrent modification
        client_sids = list(self.session_clients[session_id])

        for sid in client_sids:
            if sid != exclude_sid:
                try:
                    await self.sio.emit(event, data, room=sid)
                except Exception as e:
                    print(f"Error broadcasting {event} to client {sid}: {e}")

    async def _broadcast_tts_audio_aligned(self, session_id: str, message_id: str,
                                           language_code: str, text: str, audio_data: bytes,
                                           is_original: bool = False):
        """Broadcast TTS audio with explicit message-text-audio alignment"""
        try:
            if session_id not in self.session_clients:
                return

            print(f"TTS ALIGNED: Broadcasting audio for message {message_id}")
            print(f"TTS ALIGNED: Language: {language_code}")
            print(f"TTS ALIGNED: Text: '{text}'")
            print(f"TTS ALIGNED: Audio size: {len(audio_data)} bytes")
            print(f"TTS ALIGNED: Is original: {is_original}")

            # Create a copy of the set to avoid concurrent modification
            client_sids = list(self.session_clients[session_id])

            # Send audio data in chunks to all participants with explicit alignment data
            for sid in client_sids:
                try:
                    chunk_size = 4096
                    for i in range(0, len(audio_data), chunk_size):
                        chunk = audio_data[i:i + chunk_size]
                        is_last_chunk = i + chunk_size >= len(audio_data)

                        chunk_data = {
                            'messageId': message_id,  # Explicit message ID
                            'languageCode': language_code,  # Language of THIS audio
                            'text': text,  # Text that THIS audio represents
                            'isOriginal': is_original,  # Whether this is original or translation
                            'chunk': list(chunk),
                            'isLast': is_last_chunk,
                            'chunkIndex': i // chunk_size,  # Chunk ordering
                            'totalChunks': (len(audio_data) + chunk_size - 1) // chunk_size
                        }

                        await self.sio.emit('tts_audio_chunk', chunk_data, room=sid)

                        # Small delay to prevent overwhelming
                        await asyncio.sleep(0.01)

                    print(f"TTS ALIGNED: Successfully sent aligned audio to participant {sid}")
                except Exception as e:
                    print(f"TTS ALIGNED: Error sending audio to participant {sid}: {e}")

        except Exception as e:
            print(f"TTS ALIGNED: Error broadcasting aligned audio: {e}")

    async def _broadcast_tts_audio_aligned_dual_format(self, session_id: str, message_id: str,
                                                       language_code: str, text: str, webm_data: bytes,
                                                       wav_data: bytes, is_original: bool = False):
        """Broadcast TTS audio with both WebM and WAV formats for cross-platform compatibility"""
        try:
            if session_id not in self.session_clients:
                return

            print(f"TTS DUAL FORMAT: Broadcasting audio for message {message_id}")
            print(f"TTS DUAL FORMAT: Language: {language_code}")
            print(f"TTS DUAL FORMAT: Text: '{text}'")
            if webm_data:
                print(f"TTS DUAL FORMAT: WebM size: {len(webm_data)} bytes")
            if wav_data:
                print(f"TTS DUAL FORMAT: WAV size: {len(wav_data)} bytes")
            print(f"TTS DUAL FORMAT: Is original: {is_original}")

            # Create a copy of the set to avoid concurrent modification
            client_sids = list(self.session_clients[session_id])

            # Use WebM data for chunking (primary format for web clients)
            primary_audio_data = webm_data if webm_data else wav_data
            if not primary_audio_data:
                print("TTS DUAL FORMAT: No audio data available")
                return

            # Send audio data in chunks to all participants with dual format support
            chunk_size = 4096
            for sid in client_sids:
                try:
                    for i in range(0, len(primary_audio_data), chunk_size):
                        chunk = primary_audio_data[i:i + chunk_size]
                        is_last_chunk = i + chunk_size >= len(primary_audio_data)

                        # Prepare WAV chunk if available
                        wav_chunk = None
                        if wav_data and i < len(wav_data):
                            wav_end = min(i + chunk_size, len(wav_data))
                            wav_chunk = wav_data[i:wav_end]

                        chunk_data = {
                            'messageId': message_id,  # Explicit message ID
                            'languageCode': language_code,  # Language of THIS audio
                            'text': text,  # Text that THIS audio represents
                            'isOriginal': is_original,  # Whether this is original or translation
                            'chunk': list(chunk),  # WebM audio chunk (for web clients)
                            'wavChunk': list(wav_chunk) if wav_chunk else None,  # WAV audio chunk (for Android clients)
                            'isLast': is_last_chunk,
                            'chunkIndex': i // chunk_size,  # Chunk ordering
                            'totalChunks': (len(primary_audio_data) + chunk_size - 1) // chunk_size,
                            'format': 'webm',  # Primary format
                            'wavFormat': 'wav' if wav_chunk else None  # Secondary format available
                        }

                        await self.sio.emit('tts_audio_chunk', chunk_data, room=sid)

                        # Small delay to prevent overwhelming
                        await asyncio.sleep(0.01)

                    print(f"TTS DUAL FORMAT: Successfully sent dual format audio to participant {sid}")
                except Exception as e:
                    print(f"TTS DUAL FORMAT: Error sending audio to participant {sid}: {e}")

        except Exception as e:
            print(f"TTS DUAL FORMAT: Error broadcasting dual format audio: {e}")
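For context, a minimal sketch of how a receiving client might reassemble the 'tts_audio_chunk' events emitted above. The field names match the chunk_data payload built in the broadcast methods; the python-socketio client object and the playback hand-off are assumptions, since no client code ships in this backend:

    import socketio

    sio = socketio.AsyncClient()
    tts_buffers = {}  # (messageId, languageCode) -> accumulated bytes

    @sio.on('tts_audio_chunk')
    async def on_tts_audio_chunk(data):
        key = (data['messageId'], data['languageCode'])
        buf = tts_buffers.setdefault(key, bytearray())
        buf.extend(bytes(data['chunk']))  # chunks arrive as lists of ints
        if data['isLast']:
            audio = bytes(tts_buffers.pop(key))
            # hand the complete WebM payload (or the wavChunk stream) to a player
            print(f"Received {len(audio)} bytes for message {key[0]} in {key[1]}")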
    async def _broadcast_tts_audio_to_all_participants(self, session_id: str, language_code: str,
                                                       audio_data: bytes, message_id: str, text: str):
        """Legacy method - now calls the aligned version"""
        await self._broadcast_tts_audio_aligned(
            session_id, message_id, language_code, text, audio_data, False
        )

    async def _broadcast_audio_to_language_participants(self, session_id: str, language_code: str,
                                                        audio_data: bytes, message_id: str):
        """Broadcast audio to participants listening in specific language (legacy method)"""
        try:
            session = await self.session_manager.get_session(session_id)
            if not session:
                return

            # Find participants with matching language
            target_participants = [p for p in session.participants if p.language.code.value == language_code]

            for participant in target_participants:
                # Find client SID for this participant
                participant_sid = None
                for sid, pid in self.client_participants.items():
                    if pid == participant.id:
                        participant_sid = sid
                        break

                if participant_sid:
                    print(f"TTS: Broadcasting audio to participant {participant.name} in {language_code}")
                    # Send audio data in chunks
                    chunk_size = 4096
                    for i in range(0, len(audio_data), chunk_size):
                        chunk = audio_data[i:i + chunk_size]
                        await self.sio.emit('tts_audio_chunk', {
                            'messageId': message_id,
                            'chunk': list(chunk),
                            'isLast': i + chunk_size >= len(audio_data)
                        }, room=participant_sid)

                        # Small delay to prevent overwhelming
                        await asyncio.sleep(0.01)

        except Exception as e:
            print(f"TTS: Error broadcasting audio: {e}")

    async def _cleanup_client(self, sid: str):
        """Clean up client data on disconnect"""
        try:
            participant_id = self.client_participants.get(sid)
            session_id = self.client_sessions.get(sid)

            if participant_id:
                # Remove participant from session
                await self.session_manager.remove_participant(participant_id)

                # Clear participant buffers
                self.transcription_service.clear_participant_buffers(participant_id)

                # Clear current message tracking
                if participant_id in self.participant_current_message:
                    del self.participant_current_message[participant_id]

                del self.client_participants[sid]

            if session_id:
                # Remove from session clients
                if session_id in self.session_clients:
                    self.session_clients[session_id].discard(sid)
                    if not self.session_clients[session_id]:
                        del self.session_clients[session_id]
                        # If session is empty, clear processed messages for this session
                        self._cleanup_session_processed_messages(session_id)

                del self.client_sessions[sid]

        except Exception as e:
            print(f"Error cleaning up client {sid}: {e}")

    def _cleanup_session_processed_messages(self, session_id: str):
        """Clean up processed messages for empty sessions to prevent memory leaks"""
        try:
            # Remove processed messages that belong to this session
            messages_to_remove = []
            for message_id in list(self.processed_messages):
                if message_id in self.messages and self.messages[message_id].session_id == session_id:
                    messages_to_remove.append(message_id)

            for message_id in messages_to_remove:
                self.processed_messages.discard(message_id)
                if message_id in self.messages:
                    del self.messages[message_id]

            print(f"Cleaned up {len(messages_to_remove)} processed messages for session {session_id}")
        except Exception as e:
            print(f"Error cleaning up session processed messages: {e}")

    async def _emit_error(self, sid: str, message: str):
        """Emit error message to specific client"""
        try:
            await self.sio.emit('join_error', message, room=sid)
        except Exception as e:
            print(f"Error emitting error to {sid}: {e}")

    async def handle_update_participant_language(self, sid: str, data: dict):
        """Handle participant language update (affects speech recognition)"""
        try:
            session_id = data.get('sessionId')
            participant_id = data.get('participantId')
            language_code = data.get('language')

            print(f"=== UPDATE PARTICIPANT LANGUAGE ===")
            print(f"Session ID: {session_id}")
            print(f"Participant ID: {participant_id}")
            print(f"New Language: {language_code}")

            if not all([session_id, participant_id, language_code]):
                await self._emit_error(sid, "Missing required fields")
                return

            # Validate language code
            try:
                from app.models import LanguageCode
                lang_enum = LanguageCode(language_code)
                print(f"Language code validated: {lang_enum}")
            except ValueError:
                await self._emit_error(sid, f"Invalid language code: {language_code}")
                return

            # Update participant's language in session
            session = await self.session_manager.get_session(session_id)
            if session:
                for participant in session.participants:
                    if participant.id == participant_id:
                        # Update participant's language using LANGUAGE_MAP for complete Language object
                        if lang_enum in LANGUAGE_MAP:
                            participant.language = LANGUAGE_MAP[lang_enum]
                            print(f"Updated participant {participant.name} language to {lang_enum.value} ({participant.language.display_name})")
                        else:
                            print(f"Warning: Language {lang_enum.value} not found in LANGUAGE_MAP, using fallback")
                            from app.models import Language
                            participant.language = Language(code=lang_enum, name=lang_enum.value, display_name=lang_enum.value)

                        # Notify all clients in session
                        await self._broadcast_to_session(session_id, 'participant_language_updated', {
                            'participantId': participant_id,
                            'language': language_code
                        })
                        break

            print(f"=== UPDATE PARTICIPANT LANGUAGE COMPLETE ===")

        except Exception as e:
            print(f"Error in handle_update_participant_language: {e}")
            import traceback
            traceback.print_exc()
            await self._emit_error(sid, "Failed to update participant language")

    async def handle_update_session_languages(self, sid: str, data: dict):
        """Handle session languages update (affects translation targets)"""
        try:
            session_id = data.get('sessionId')
            languages = data.get('languages', [])

            print(f"=== UPDATE SESSION LANGUAGES (REPLACE MODE) ===")
            print(f"Session ID: {session_id}")
            print(f"New Languages: {languages}")

            if not session_id or not languages:
                await self._emit_error(sid, "Missing required fields")
                return

            # Get current session for comparison
            session = await self.session_manager.get_session(session_id)
            if not session:
                await self._emit_error(sid, "Session not found")
                return

            current_languages = [lang.code.value for lang in session.languages]
            print(f"Before update - Session languages: {current_languages}")

            # Validate all language codes and create Language objects
            validated_languages = []
            try:
                from app.models import Language, LanguageCode
                from app.services.session_manager import LANGUAGE_MAP

                for lang_code in languages:
                    lang_enum = LanguageCode(lang_code)
                    language = LANGUAGE_MAP[lang_enum]
                    validated_languages.append(language)
                    print(f"Validated language: {lang_code} -> {language.name}")

            except ValueError as e:
                await self._emit_error(sid, f"Invalid language code: {e}")
                return

            # REPLACE session languages (not add to them)
            session.languages = validated_languages
            new_languages = [lang.code.value for lang in session.languages]
            print(f"After update - Session languages: {new_languages}")

            # Verify the session manager has the updated languages
            verification_session = await self.session_manager.get_session(session_id)
            if verification_session:
                verification_languages = [lang.code.value for lang in verification_session.languages]
                print(f"Verification - Session manager languages: {verification_languages}")

            # Notify all clients in session about the update
            await self._broadcast_to_session(session_id, 'session_languages_updated', {
                'sessionId': session_id,
                'languages': new_languages,
                'previous': current_languages
            })

            print(f"=== UPDATE SESSION LANGUAGES COMPLETE ===")

        except Exception as e:
            print(f"Error in handle_update_session_languages: {e}")
            import traceback
            traceback.print_exc()
            await self._emit_error(sid, "Failed to update session languages")
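For reference, a sketch of the payload shapes the two handlers above expect. Only the field names are confirmed by the data.get(...) calls; the client-side event names and the example language codes are assumptions, since the socket.io handler registration sits outside this hunk:

    # Hypothetical client-side emits; field names match the handlers above.
    async def demo(sio, session_id, participant_id):
        await sio.emit('update_participant_language', {
            'sessionId': session_id,
            'participantId': participant_id,
            'language': 'eng',            # must be a valid LanguageCode value
        })
        await sio.emit('update_session_languages', {
            'sessionId': session_id,
            'languages': ['eng', 'swa'],  # REPLACES the session's language list
        })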
preload_models.py
ADDED
@@ -0,0 +1,23 @@
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import io
import nltk
import json
import torch
import os
import sys

import gc


if len(sys.argv) > 1:
    os.environ["HUGGING_FACE_HUB_TOKEN"] = sys.argv[1]

nltk.download("punkt")
nltk.download('punkt_tab')

device = 0 if torch.cuda.is_available() else -1

def cleanup_model_resource(model):
    del model
    gc.collect()
    torch.cuda.empty_cache()
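As committed, the script only fetches NLTK tokenizer data and defines a cleanup helper, so model weights still download on first use. A hedged sketch of how it could also warm the cache for the NLLB checkpoint that websocket_manager.py logs (mutisya/nllb_600m); whether that repo id resolves with the supplied token is an assumption:

    # Hypothetical extension: pull the translation model into the HF cache at
    # build time, then free the RAM; the weights stay cached on disk.
    def preload_translation_model(repo_id="mutisya/nllb_600m"):
        tokenizer = AutoTokenizer.from_pretrained(repo_id)
        model = AutoModelForSeq2SeqLM.from_pretrained(repo_id)
        cleanup_model_resource(model)
        del tokenizer
        gc.collect()

    preload_translation_model()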
requirements.txt
ADDED
@@ -0,0 +1,58 @@
# Web framework and server
fastapi==0.115.5
uvicorn[standard]==0.32.1
websockets==13.1
python-socketio==5.11.4
python-multipart==0.0.17
pydantic==2.10.3

# PyTorch ecosystem - latest stable versions
torch==2.5.1
torchaudio==2.5.1
transformers==4.45.2
datasets==3.1.0
tokenizers==0.20.4
accelerate==1.2.1

# ONNX Runtime for optimized inference - GPU enabled
onnxruntime-gpu==1.19.2
onnx==1.17.0
optimum[onnxruntime-gpu]==1.23.0
huggingface-hub==0.26.2

# Audio processing
soundfile==0.12.1
librosa==0.10.2
phonemizer==3.3.0
pydub==0.25.1

# Scientific computing
scipy==1.14.1
numpy==2.1.3

# Natural language processing
nltk==3.9.1
sentencepiece==0.2.0

# Computer vision and image processing
pillow==11.0.0
qrcode[pil]==8.0

# Authentication and security
python-jose[cryptography]==3.3.0
passlib[bcrypt]==1.7.4

# File handling
aiofiles==24.1.0

# Model optimization
bitsandbytes==0.45.0

# Protocol buffers - compatible version
protobuf==5.28.3

# Speech processing
speechbrain==1.0.2

# Voice Activity Detection
silero-vad>=5.1