mutisya committed
Commit e2ccb09 · verified · 1 parent: 5e05084

Deploy Polyglot backend with quantized models

.dockerignore ADDED
@@ -0,0 +1,12 @@
+ .cache
+ nltk_data
+ __pycache__
+ *.pyc
+ *.pyo
+ *.pyd
+ .Python
+ *.so
+ *.egg
+ *.egg-info
+ dist
+ build
Dockerfile ADDED
@@ -0,0 +1,47 @@
+ FROM python:3.11-slim
+
+ WORKDIR /app
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y \
+     ffmpeg \
+     libsndfile1 \
+     sox \
+     espeak \
+     espeak-data \
+     libespeak1 \
+     libespeak-dev \
+     wget \
+     gnupg \
+     curl \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Copy requirements and install Python dependencies
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy application code
+ COPY app ./app
+ COPY preload_models.py .
+
+ # Set environment variables for caching
+ ENV HF_HOME=/app/.cache
+ ENV TRANSFORMERS_CACHE=/app/.cache
+ ENV NLTK_DATA=/app/nltk_data
+ ENV PYTHONPATH=/app
+ ENV PORT=7860
+
+ # Create cache directories
+ RUN mkdir -p $HF_HOME && chmod -R 777 $HF_HOME
+ RUN mkdir -p $NLTK_DATA && chmod -R 777 $NLTK_DATA
+
+ # Download models using HF token from environment
+ # HuggingFace Spaces automatically provides HUGGING_FACE_HUB_TOKEN
+ ARG HUGGING_FACE_HUB_TOKEN
+ RUN python preload_models.py $HUGGING_FACE_HUB_TOKEN || echo "Model preload skipped - will download on first use"
+
+ # Expose port 7860 (HuggingFace Spaces standard)
+ EXPOSE 7860
+
+ # Run the application
+ CMD ["uvicorn", "app.main:socket_app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,10 +1,40 @@
- ---
- title: Polyglot Backend Quant
- emoji: 📊
- colorFrom: yellow
- colorTo: green
- sdk: docker
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ---
+ title: Polyglot Translation Backend
+ emoji: 🌍
+ colorFrom: blue
+ colorTo: green
+ sdk: docker
+ pinned: false
+ license: mit
+ app_port: 7860
+ ---
+
+ # Polyglot Translation Backend - Quantized Models
+
+ Real-time speech transcription and translation API with Socket.IO for WebSocket communication. This version uses INT8 quantized models for improved performance and reduced memory footprint.
+
+ ## Features
+
+ - **Real-time Speech Recognition**: Support for English, Swahili, Kikuyu, Kamba, Kimeru, Luo, and Somali
+ - **Translation**: Multi-language translation using NLLB models
+ - **Text-to-Speech**: Generate speech in multiple languages
+ - **WebSocket Support**: Real-time communication via Socket.IO
+ - **Model Quantization**: INT8 dynamic quantization for faster inference
+
+ ## API Endpoints
+
+ - `GET /health` - Health check endpoint
+ - `WebSocket /` - Socket.IO connection for real-time communication
+
+ ## Environment
+
+ This Space requires a HuggingFace token for model access. The token is automatically provided by HuggingFace Spaces when configured as a secret.
+
+ ## Technical Details
+
+ - **Framework**: FastAPI with Socket.IO
+ - **Models**:
+   - ASR: Whisper (English) and Wav2Vec2-BERT (African languages)
+   - Translation: NLLB-600M fine-tuned model
+   - TTS: VITS models for each language
+ - **Optimization**: INT8 dynamic quantization via PyTorch
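The README describes token-authenticated Socket.IO access. A minimal client sketch, assuming the `python-socketio` client package; the Space URL and token below are placeholders, and the auth-dict shape matches the `auth_data` path in `app/auth.py` further down this commit:

```python
import socketio

# Hypothetical values; replace with your own Space URL and HF token.
BACKEND_URL = "https://your-space.hf.space"
HF_TOKEN = "hf_xxx"

sio = socketio.Client()

@sio.event
def connect():
    print("connected")
    # test_echo is a server event defined in app/main.py below
    sio.emit("test_echo", {"ping": 1})

@sio.on("test_echo_response")
def on_echo(data):
    print("echo:", data)
    sio.disconnect()

# The server's connect handler reads the token from the Socket.IO auth dict.
sio.connect(BACKEND_URL, auth={"token": HF_TOKEN})
sio.wait()
```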
app/__init__.py ADDED
@@ -0,0 +1 @@
+ # Backend application package
app/auth.py ADDED
@@ -0,0 +1,310 @@
+ """
+ Authentication module for HuggingFace token validation
+ """
+ import os
+ from typing import Optional
+ from fastapi import HTTPException, status, Request
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
+ from fastapi.security.utils import get_authorization_scheme_param
+
+
+ def is_local_development() -> bool:
+     """
+     Detect if the application is running in local development mode.
+     This checks multiple indicators to determine if auth should be disabled.
+     """
+     # Method 1: Explicit disable auth flag
+     disable_auth = os.getenv('DISABLE_AUTH', '').lower()
+     if disable_auth in ['true', '1', 'yes']:
+         return True
+
+     # Method 2: Check ENVIRONMENT variable
+     environment = os.getenv('ENVIRONMENT', '').lower()
+     if environment in ['development', 'dev', 'local']:
+         return True
+
+     # Method 3: Check DEBUG flag
+     debug = os.getenv('DEBUG', '').lower()
+     if debug in ['true', '1', 'yes']:
+         return True
+
+     # Method 4: Check if running on localhost/development ports
+     host = os.getenv('HOST', '')
+     port = os.getenv('PORT', '')
+     if host in ['localhost', '127.0.0.1', '0.0.0.0'] and port == '7860':
+         return True
+
+     # Method 5: Check for presence of local development files
+     local_indicators = [
+         '.env.local',
+         'docker-compose.local.yml',
+         'Dockerfile.local'
+     ]
+     for indicator in local_indicators:
+         if os.path.exists(indicator):
+             return True
+
+     # Method 6: Check if we're in a Docker container with local development setup
+     if os.path.exists('/.dockerenv'):
+         # We're in Docker, check if it's local development
+         if os.getenv('ALLOW_ALL_ORIGINS', '').lower() == 'true':
+             return True
+
+     return False
+
+
+ class HuggingFaceTokenAuth:
+     """HuggingFace token authentication handler"""
+
+     def __init__(self):
+         self.bearer = HTTPBearer(auto_error=False)
+         self.is_local = is_local_development()
+
+         if self.is_local:
+             print("🔓 RUNNING IN LOCAL DEVELOPMENT MODE - AUTH DISABLED")
+             print("   Environment indicators:")
+             print(f"   - DISABLE_AUTH: {os.getenv('DISABLE_AUTH', 'not set')}")
+             print(f"   - ENVIRONMENT: {os.getenv('ENVIRONMENT', 'not set')}")
+             print(f"   - DEBUG: {os.getenv('DEBUG', 'not set')}")
+             print(f"   - HOST: {os.getenv('HOST', 'not set')}")
+             print(f"   - PORT: {os.getenv('PORT', 'not set')}")
+             print(f"   - ALLOW_ALL_ORIGINS: {os.getenv('ALLOW_ALL_ORIGINS', 'not set')}")
+             print(f"   - Docker container: {os.path.exists('/.dockerenv')}")
+             print(f"   - .env.local exists: {os.path.exists('.env.local')}")
+         else:
+             print("🔒 RUNNING IN PRODUCTION MODE - AUTH REQUIRED")
+
+     def verify_token(self, token: str) -> bool:
+         """
+         Verify if the token is a valid HuggingFace token.
+         In local development mode, always returns True.
+         """
+         # Skip token validation in local development
+         if self.is_local:
+             print("🔓 Local development mode: skipping token validation")
+             return True
+
+         try:
+             if not token:
+                 return False
+
+             if not isinstance(token, str):
+                 print(f"❌ Token is not a string: {type(token)}")
+                 return False
+
+             # HuggingFace tokens start with 'hf_'
+             if not token.startswith('hf_'):
+                 print(f"❌ Token does not start with 'hf_': {token[:10]}...")
+                 return False
+
+             # Additional validation can be added here
+             # For example, you could make a request to HuggingFace API
+             # to validate the token, but that would add latency
+
+             return True
+
+         except Exception as e:
+             print(f"❌ Error in verify_token: {e}")
+             return False
+
+     def get_token_from_request(self, request: Request) -> Optional[str]:
+         """Extract token from various sources in the request"""
+
+         # Method 1: Authorization header
+         authorization = request.headers.get("Authorization")
+         if authorization:
+             scheme, token = get_authorization_scheme_param(authorization)
+             if scheme.lower() == "bearer":
+                 return token
+
+         # Method 2: Query parameter (for WebSocket initial handshake)
+         token = request.query_params.get("token")
+         if token:
+             return token
+
+         # Method 3: Custom header (alternative)
+         token = request.headers.get("X-HF-Token")
+         if token:
+             return token
+
+         return None
+
+     async def authenticate_request(self, request: Request) -> bool:
+         """Authenticate a request using HuggingFace token"""
+         token = self.get_token_from_request(request)
+
+         if not token:
+             return False
+
+         return self.verify_token(token)
+
+
+ # Global instance
+ hf_auth = HuggingFaceTokenAuth()
+
+
+ async def require_hf_token(request: Request) -> str:
+     """
+     FastAPI dependency that requires a valid HuggingFace token.
+     In local development mode, returns a dummy token.
+     Returns the token if valid, raises HTTPException if not.
+     """
+     # Skip authentication in local development
+     if hf_auth.is_local:
+         print("🔓 Local development mode: bypassing HF token requirement")
+         return "local-development-bypass"
+
+     token = hf_auth.get_token_from_request(request)
+
+     if not token:
+         raise HTTPException(
+             status_code=status.HTTP_401_UNAUTHORIZED,
+             detail="HuggingFace token required. Please provide a valid token in Authorization header.",
+             headers={"WWW-Authenticate": "Bearer"},
+         )
+
+     if not hf_auth.verify_token(token):
+         raise HTTPException(
+             status_code=status.HTTP_401_UNAUTHORIZED,
+             detail="Invalid HuggingFace token. Token must start with 'hf_'.",
+             headers={"WWW-Authenticate": "Bearer"},
+         )
+
+     return token
+
+
+ async def optional_hf_token(request: Request) -> Optional[str]:
+     """
+     FastAPI dependency that optionally validates HuggingFace token.
+     In local development mode, returns a dummy token if no real token provided.
+     Returns the token if present and valid, None otherwise.
+     Useful for endpoints that work with or without authentication.
+     """
+     # In local development, always return a token
+     if hf_auth.is_local:
+         token = hf_auth.get_token_from_request(request)
+         if token and hf_auth.verify_token(token):
+             return token
+         else:
+             print("🔓 Local development mode: providing dummy token for optional auth")
+             return "local-development-bypass"
+
+     token = hf_auth.get_token_from_request(request)
+
+     if not token:
+         return None
+
+     if hf_auth.verify_token(token):
+         return token
+
+     return None
+
+
+ def authenticate_websocket_connect(environ: dict) -> bool:
+     """
+     Authenticate WebSocket connection using token from various sources.
+     In local development mode, always returns True.
+     This is called during the Socket.IO connect event.
+     """
+     # Skip authentication in local development
+     if hf_auth.is_local:
+         print("🔓 Local development mode: bypassing WebSocket authentication")
+         return True
+
+     try:
+         print("=== WEBSOCKET ENVIRON AUTHENTICATION ===")
+         print(f"Environ type: {type(environ)}")
+
+         if not isinstance(environ, dict):
+             print(f"❌ Environ is not a dict: {type(environ)}")
+             return False
+
+         # Method 1: Check query parameters
+         query_string = environ.get('QUERY_STRING', '')
+         print(f"Query string: {query_string}")
+         if query_string:
+             from urllib.parse import parse_qs
+             query_params = parse_qs(query_string)
+             print(f"Parsed query params: {query_params}")
+             tokens = query_params.get('token', [])
+             if tokens:
+                 token = tokens[0]
+                 print(f"Found token in query: {token[:10]}...")
+                 if hf_auth.verify_token(token):
+                     print("✓ Token validated via query params")
+                     return True
+
+         # Method 2: Check headers
+         auth_header = environ.get('HTTP_AUTHORIZATION', '')
+         print(f"Authorization header: {auth_header[:20] if auth_header else 'None'}...")
+         if auth_header:
+             if auth_header.startswith('Bearer '):
+                 token = auth_header[7:]  # Remove 'Bearer ' prefix
+                 print(f"Found token in Authorization header: {token[:10]}...")
+                 if hf_auth.verify_token(token):
+                     print("✓ Token validated via Authorization header")
+                     return True
+
+         # Method 3: Check custom header
+         hf_token_header = environ.get('HTTP_X_HF_TOKEN', '')
+         print(f"X-HF-Token header: {hf_token_header[:10] if hf_token_header else 'None'}...")
+         if hf_token_header:
+             if hf_auth.verify_token(hf_token_header):
+                 print("✓ Token validated via X-HF-Token header")
+                 return True
+
+         print("❌ No valid token found in environ")
+         print(f"Available environ keys: {list(environ.keys())}")
+         return False
+
+     except Exception as e:
+         print(f"❌ Error in authenticate_websocket_connect: {e}")
+         import traceback
+         traceback.print_exc()
+         return False
+
+
+ def authenticate_websocket_auth_data(auth_data: dict) -> bool:
+     """
+     Authenticate WebSocket connection using auth data from Socket.IO.
+     In local development mode, always returns True.
+     This is called when the client sends auth data in the connection.
+     """
+     # Skip authentication in local development
+     if hf_auth.is_local:
+         print("🔓 Local development mode: bypassing WebSocket auth data validation")
+         return True
+
+     try:
+         print("=== WEBSOCKET AUTH DATA AUTHENTICATION ===")
+         print(f"Auth data received: {auth_data}")
+         print(f"Auth data type: {type(auth_data)}")
+
+         if not auth_data:
+             print("❌ No auth data provided")
+             return False
+
+         if not isinstance(auth_data, dict):
+             print(f"❌ Auth data is not a dict: {type(auth_data)}")
+             return False
+
+         # Check for token in auth data
+         token = auth_data.get('token')
+         if token:
+             print(f"Found token in auth data: {token[:10]}...")
+             if hf_auth.verify_token(token):
+                 print("✓ Token validated via auth data")
+                 return True
+             else:
+                 print("❌ Invalid token in auth data")
+         else:
+             print("❌ No token in auth data")
+             print(f"Available keys in auth data: {list(auth_data.keys())}")
+
+         return False
+
+     except Exception as e:
+         print(f"❌ Error in authenticate_websocket_auth_data: {e}")
+         import traceback
+         traceback.print_exc()
+         return False
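For reference, a minimal sketch of how the two HTTP dependencies above would be wired into routes. The `/protected` and `/open` paths are hypothetical, not part of this commit:

```python
from typing import Optional

from fastapi import FastAPI, Depends

from app.auth import require_hf_token, optional_hf_token

app = FastAPI()

@app.get("/protected")  # hypothetical route: rejects requests without a valid hf_ token
async def protected(token: str = Depends(require_hf_token)):
    return {"token_prefix": token[:6]}

@app.get("/open")  # hypothetical route: works with or without a token
async def open_endpoint(token: Optional[str] = Depends(optional_hf_token)):
    return {"authenticated": token is not None}
```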
app/config/__init__.py ADDED
@@ -0,0 +1,7 @@
+ """
+ Configuration package for Polyglot backend
+ """
+
+ from .cors import cors_config
+
+ __all__ = ["cors_config"]
app/config/cors.py ADDED
@@ -0,0 +1,295 @@
+ """
+ CORS Configuration Module
+
+ Centralized CORS configuration supporting multiple deployment environments.
+ """
+
+ import os
+ import re
+ from typing import List, Optional
+ from enum import Enum
+
+
+ class Environment(str, Enum):
+     """Deployment environment types"""
+     LOCAL = "local"
+     DEVELOPMENT = "development"
+     STAGING = "staging"
+     PRODUCTION = "production"
+
+
+ class CORSConfig:
+     """CORS configuration manager"""
+
+     # Default origins for local development
+     DEFAULT_LOCAL_ORIGINS = [
+         "http://localhost:3000",  # React/Next.js dev server
+         "http://localhost:3001",  # Polyglot frontend (Vite)
+         "http://localhost:3002",  # Lessons UI (Vite)
+         "http://localhost:3003",  # Podium (Vite)
+         "http://localhost:3004",  # Podium alternative port
+         "http://localhost:5173",  # Vite dev server
+         "http://localhost:7860",  # Backend self-reference
+         "http://localhost:8080",  # Alternative dev server
+         "http://127.0.0.1:3000",  # IPv4 localhost variant
+         "http://127.0.0.1:3001",  # IPv4 localhost variant
+         "http://127.0.0.1:3002",  # IPv4 localhost variant
+         "http://127.0.0.1:3003",  # IPv4 localhost variant
+         "http://127.0.0.1:3004",  # IPv4 localhost variant
+         "http://127.0.0.1:5173",  # IPv4 localhost variant
+         "http://127.0.0.1:7860",  # IPv4 localhost variant
+     ]
+
+     # Default patterns for production deployments
+     DEFAULT_PRODUCTION_PATTERNS = [
+         r"^https://.*\.tafiti\.dev$",  # Tafiti production/staging
+         r"^https://.*\.vercel\.app$",  # Vercel deployments
+         r"^https://.*\.hf\.space$",  # HuggingFace Spaces
+         r"^https://milimani\.tafiti-api\.org$",  # Production API
+     ]
+
+     # Mobile app protocols
+     MOBILE_PROTOCOLS = [
+         "capacitor://localhost",  # Capacitor apps
+         "ionic://localhost",  # Ionic apps
+         "http://localhost",  # Mobile WebView
+     ]
+
+     def __init__(self):
+         self.environment = self._get_environment()
+         self.allowed_origins = self._build_allowed_origins()
+         self.allow_all = self._should_allow_all()
+         self.origin_patterns = self._build_origin_patterns()
+
+     def _get_environment(self) -> Environment:
+         """Get current deployment environment"""
+         env_str = os.getenv("ENVIRONMENT", "local").lower()
+
+         try:
+             return Environment(env_str)
+         except ValueError:
+             print(f"⚠️ Unknown environment '{env_str}', defaulting to 'local'")
+             return Environment.LOCAL
+
+     def _should_allow_all(self) -> bool:
+         """Check if CORS should allow all origins (insecure, dev only)"""
+         allow_all = os.getenv("CORS_ALLOW_ALL", "false").lower()
+
+         if allow_all == "true":
+             if self.environment == Environment.PRODUCTION:
+                 print("❌ ERROR: CORS_ALLOW_ALL=true is not allowed in production")
+                 return False
+             else:
+                 print("⚠️ WARNING: CORS allowing all origins - INSECURE, use only for development")
+                 return True
+
+         return False
+
+     def _build_allowed_origins(self) -> List[str]:
+         """Build list of allowed origins from environment and defaults"""
+         origins = []
+
+         # Get custom origins from environment variable
+         custom_origins_str = os.getenv("CORS_ALLOWED_ORIGINS", "")
+
+         if custom_origins_str:
+             # Parse comma-separated origins
+             custom_origins = [
+                 origin.strip()
+                 for origin in custom_origins_str.split(",")
+                 if origin.strip()
+             ]
+             origins.extend(custom_origins)
+             print(f"✓ Loaded {len(custom_origins)} custom CORS origins from environment")
+
+         # Add defaults based on environment
+         if self.environment == Environment.LOCAL:
+             origins.extend(self.DEFAULT_LOCAL_ORIGINS)
+             print(f"✓ Added {len(self.DEFAULT_LOCAL_ORIGINS)} default local origins")
+
+         # Always include mobile protocols in non-production
+         if self.environment != Environment.PRODUCTION:
+             origins.extend(self.MOBILE_PROTOCOLS)
+             print(f"✓ Added {len(self.MOBILE_PROTOCOLS)} mobile protocol origins")
+
+         # Remove duplicates while preserving order
+         seen = set()
+         unique_origins = []
+         for origin in origins:
+             if origin not in seen:
+                 seen.add(origin)
+                 unique_origins.append(origin)
+
+         return unique_origins
+
+     def _build_origin_patterns(self) -> List[re.Pattern]:
+         """Build regex patterns for origin matching"""
+         patterns = []
+
+         # Get custom patterns from environment
+         custom_patterns_str = os.getenv("CORS_ALLOWED_PATTERNS", "")
+
+         if custom_patterns_str:
+             custom_pattern_strs = [
+                 p.strip()
+                 for p in custom_patterns_str.split(",")
+                 if p.strip()
+             ]
+
+             for pattern_str in custom_pattern_strs:
+                 try:
+                     patterns.append(re.compile(pattern_str))
+                 except re.error as e:
+                     print(f"⚠️ Invalid regex pattern '{pattern_str}': {e}")
+
+             print(f"✓ Loaded {len(patterns)} custom CORS patterns from environment")
+
+         # Add default production patterns if in production/staging/development
+         if self.environment in [Environment.PRODUCTION, Environment.STAGING, Environment.DEVELOPMENT]:
+             for pattern_str in self.DEFAULT_PRODUCTION_PATTERNS:
+                 patterns.append(re.compile(pattern_str))
+
+             print(f"✓ Added {len(self.DEFAULT_PRODUCTION_PATTERNS)} default production patterns")
+
+         # Add localhost pattern for development
+         if self.environment == Environment.LOCAL:
+             patterns.append(re.compile(r"^http://localhost:\d+$"))
+             patterns.append(re.compile(r"^http://127\.0\.0\.1:\d+$"))
+             print("✓ Added localhost wildcard patterns for development")
+
+         return patterns
+
+     def is_origin_allowed(self, origin: str) -> bool:
+         """
+         Check if an origin is allowed based on explicit list or patterns
+
+         Args:
+             origin: Origin to check (e.g., "https://app.tafiti.dev")
+
+         Returns:
+             True if origin is allowed, False otherwise
+         """
+         # If allow_all is enabled (dev only)
+         if self.allow_all:
+             return True
+
+         # Check explicit origins list
+         if origin in self.allowed_origins:
+             return True
+
+         # Check against patterns
+         for pattern in self.origin_patterns:
+             if pattern.match(origin):
+                 return True
+
+         return False
+
+     def get_cors_middleware_config(self) -> dict:
+         """Get configuration dict for FastAPI CORSMiddleware"""
+         if self.allow_all:
+             return {
+                 "allow_origins": ["*"],
+                 "allow_credentials": False,  # Cannot use credentials with wildcard
+                 "allow_methods": ["*"],
+                 "allow_headers": ["*"],
+             }
+
+         # Build origin regex for pattern matching
+         if self.origin_patterns:
+             # Combine all patterns into a single regex
+             combined_pattern = "|".join(f"({p.pattern})" for p in self.origin_patterns)
+
+             return {
+                 "allow_origins": self.allowed_origins,
+                 "allow_origin_regex": combined_pattern,
+                 "allow_credentials": True,
+                 "allow_methods": ["*"],
+                 "allow_headers": ["*"],
+             }
+         else:
+             return {
+                 "allow_origins": self.allowed_origins,
+                 "allow_credentials": True,
+                 "allow_methods": ["*"],
+                 "allow_headers": ["*"],
+             }
+
+     def get_socketio_cors_origins(self):
+         """
+         Get CORS origins for Socket.IO
+
+         Socket.IO doesn't support regex patterns, so we need to provide an explicit list.
+         For production, this means we need to enumerate common origins.
+         """
+         if self.allow_all:
+             return "*"
+
+         # For Socket.IO, we can only provide explicit origins
+         # In production, we may need to enumerate common subdomains
+         socketio_origins = self.allowed_origins.copy()
+
+         # Add common production subdomains if using production patterns
+         if self.environment in [Environment.PRODUCTION, Environment.STAGING]:
+             # These should be added to CORS_ALLOWED_ORIGINS for Socket.IO support
+             production_origins = [
+                 "https://app.tafiti.dev",
+                 "https://www.tafiti.dev",
+                 "https://polyglot.tafiti.dev",
+                 "https://podium.tafiti.dev",
+                 "https://milimani.tafiti-api.org",
+                 "https://polyglot-ashy-beta.vercel.app",
+                 "https://lessons-silk.vercel.app",
+                 "https://lessons.tafiti.dev",
+                 "https://podium-chi.vercel.app",
+             ]
+             for origin in production_origins:
+                 if origin not in socketio_origins:
+                     socketio_origins.append(origin)
+
+         return socketio_origins
+
+     def print_config_summary(self):
+         """Print CORS configuration summary for debugging"""
+         print("\n" + "="*70)
+         print("CORS CONFIGURATION SUMMARY")
+         print("="*70)
+         print(f"Environment: {self.environment.value}")
+         print(f"Allow All: {self.allow_all}")
+         print(f"\nExplicit Origins ({len(self.allowed_origins)}):")
+         for origin in self.allowed_origins:
+             print(f"  • {origin}")
+
+         if self.origin_patterns:
+             print(f"\nOrigin Patterns ({len(self.origin_patterns)}):")
+             for pattern in self.origin_patterns:
+                 print(f"  • {pattern.pattern}")
+
+         print("\nExample Origins That Would Be Allowed:")
+         test_origins = [
+             "http://localhost:3001",
+             "http://localhost:3002",
+             "http://localhost:3003",
+             "http://localhost:3004",
+             "http://localhost:5173",
+             "https://app.tafiti.dev",
+             "https://polyglot.tafiti.dev",
+             "https://podium.tafiti.dev",
+             "https://polyglot.vercel.app",
+             "https://lessons-silk.vercel.app",
+             "https://podium-chi.vercel.app",
+             "https://polyglot-ashy-beta.vercel.app",
+             "https://mutisya-translator.hf.space",
+             "https://milimani.tafiti-api.org",
+             "capacitor://localhost",
+             "https://example.com",
+         ]
+
+         for test_origin in test_origins:
+             allowed = "✓" if self.is_origin_allowed(test_origin) else "✗"
+             print(f"  {allowed} {test_origin}")
+
+         print("="*70 + "\n")
+
+
+ # Global CORS configuration instance
+ cors_config = CORSConfig()
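A short sketch of how the environment variables feed this config. The `example.com` values are illustrative, not deployment settings from this commit; note the variables must be set before the module is imported, since `cors_config` is instantiated at import time:

```python
import os

# Illustrative values only.
os.environ["ENVIRONMENT"] = "production"
os.environ["CORS_ALLOWED_ORIGINS"] = "https://app.example.com"
os.environ["CORS_ALLOWED_PATTERNS"] = r"^https://.*\.example\.com$"

from app.config.cors import CORSConfig

cfg = CORSConfig()
print(cfg.is_origin_allowed("https://app.example.com"))      # True (explicit origin)
print(cfg.is_origin_allowed("https://staging.example.com"))  # True (custom pattern)
print(cfg.is_origin_allowed("http://evil.test"))             # False
```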
app/main.py ADDED
@@ -0,0 +1,345 @@
+ import os
+ import asyncio
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, Depends
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.staticfiles import StaticFiles
+ from contextlib import asynccontextmanager
+ import logging
+ import socketio
+ import engineio
+ import re
+
+ from app.routers import sessions, mobile, watch, learning
+ from app.services.session_manager import SessionManager
+ from app.services.transcription_service import TranscriptionService
+ from app.services.translation_service import TranslationService
+ from app.services.tts_service import TTSService
+ from app.services.websocket_manager import WebSocketManager
+ from app.auth import require_hf_token, optional_hf_token, authenticate_websocket_connect, authenticate_websocket_auth_data
+ from app.config.cors import cors_config
+
+ class ChunkArrayTruncateFilter(logging.Filter):
+     """Custom logging filter to truncate long arrays in Socket.IO logs for better readability"""
+
+     def filter(self, record):
+         if hasattr(record, 'msg') and isinstance(record.msg, str):
+             # More aggressive approach to truncate audioData arrays
+             # Pattern to match: "audioData":[numbers,numbers,numbers,...]
+             audiodata_pattern = r'"audioData":\[([0-9,-]+(?:,[0-9,-]+)*)\]'
+
+             def truncate_audiodata(match):
+                 array_content = match.group(1)
+                 # Split by comma and get first 10 items
+                 items = array_content.split(',')
+                 if len(items) > 10:
+                     truncated = ','.join(items[:10])
+                     return f'"audioData":[{truncated}, ...] (truncated {len(items)-10} more items)'
+                 return match.group(0)
+
+             record.msg = re.sub(audiodata_pattern, truncate_audiodata, record.msg)
+
+             # Also handle any other large numeric arrays in brackets
+             # Pattern for arrays with more than 20 numbers
+             large_numeric_array_pattern = r'(\[)([0-9,-]+(?:,[0-9,-]+){20,})(\])'
+
+             def truncate_large_numeric_array(match):
+                 prefix = match.group(1)
+                 array_content = match.group(2)
+                 suffix = match.group(3)
+
+                 # Split by comma and get first 10 items
+                 items = array_content.split(',')
+                 if len(items) > 10:
+                     truncated = ','.join(items[:10])
+                     return f'{prefix}{truncated}, ... (truncated {len(items)-10} more){suffix}'
+                 return match.group(0)
+
+             record.msg = re.sub(large_numeric_array_pattern, truncate_large_numeric_array, record.msg)
+
+             # Truncate other field types
+             for field_name in ['chunk', 'wavChunk', 'data']:
+                 field_pattern = rf'"{field_name}":\[([0-9,-]+(?:,[0-9,-]+)*)\]'
+                 def make_truncate_field(fname):
+                     def truncate_field(match):
+                         array_content = match.group(1)
+                         items = array_content.split(',')
+                         if len(items) > 10:
+                             truncated = ','.join(items[:10])
+                             return f'"{fname}":[{truncated}, ...] (truncated {len(items)-10} more)'
+                         return match.group(0)
+                     return truncate_field
+
+                 record.msg = re.sub(field_pattern, make_truncate_field(field_name), record.msg)
+
+         return True
+
+
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+     # Initialize services
+     print("=== INITIALIZING BACKEND SERVICES ===")
+     try:
+         print("Initializing transcription service...")
+         await transcription_service.initialize()
+         print("✓ Transcription service initialized")
+
+         print("Initializing translation service...")
+         await translation_service.initialize()
+         print("✓ Translation service initialized")
+
+         print("Initializing TTS service...")
+         await tts_service.initialize()
+         print("✓ TTS service initialized")
+
+         print("=== ALL SERVICES INITIALIZED SUCCESSFULLY ===")
+
+         # Start background loading of additional models after successful startup
+         print("=== STARTING BACKGROUND MODEL LOADING ===")
+         transcription_service.start_background_loading()
+         tts_service.start_background_loading()
+         print("=== BACKGROUND MODEL LOADING INITIATED ===")
+
+         # Print CORS configuration summary
+         cors_config.print_config_summary()
+
+     except Exception as e:
+         print(f"❌ SERVICE INITIALIZATION FAILED: {e}")
+         import traceback
+         traceback.print_exc()
+         raise
+
+     yield
+
+     # Cleanup
+     print("=== CLEANING UP SERVICES ===")
+     await transcription_service.cleanup()
+     await translation_service.cleanup()
+     await tts_service.cleanup()
+     print("=== CLEANUP COMPLETE ===")
+
+ app = FastAPI(
+     title="Real-time Transcription & Translation API",
+     description="Backend API for real-time speech transcription and translation",
+     version="1.0.0",
+     lifespan=lifespan
+ )
+
+ # CORS middleware with environment-based configuration
+ cors_middleware_config = cors_config.get_cors_middleware_config()
+ print(f"Configuring CORS middleware with keys: {list(cors_middleware_config.keys())}")
+
+ app.add_middleware(
+     CORSMiddleware,
+     **cors_middleware_config
+ )
+
+ # Initialize services - using PyTorch models for better compatibility
+ session_manager = SessionManager()
+ transcription_service = TranscriptionService()
+ translation_service = TranslationService()
+ tts_service = TTSService()
+ websocket_manager = WebSocketManager(
+     session_manager=session_manager,
+     transcription_service=transcription_service,
+     translation_service=translation_service,
+     tts_service=tts_service
+ )
+
+ # Include routers
+ app.include_router(sessions.router, prefix="/api")
+ app.include_router(mobile.router, prefix="/api")
+ app.include_router(watch.router, prefix="/api")
+ app.include_router(learning.router)
+
+ # Set the session manager in the router
+ sessions.session_manager = session_manager
+ sessions.translation_service = translation_service
+ sessions.tts_service = tts_service
+ sessions.transcription_service = transcription_service
+
+ # Set the mobile router
+ mobile.translation_service = translation_service
+ mobile.tts_service = tts_service
+ mobile.transcription_service = transcription_service
+
+ # Set the watch router
+ watch.translation_service = translation_service
+ watch.tts_service = tts_service
+ watch.transcription_service = transcription_service
+
+
+ # Configure logging with custom filter to truncate chunk arrays
+ chunk_filter = ChunkArrayTruncateFilter()
+
+ sio_logger = logging.getLogger('socketio')
+ sio_logger.setLevel(logging.INFO)  # Show info logs with truncated arrays
+ sio_logger.addFilter(chunk_filter)
+
+ engineio_logger = logging.getLogger('engineio')
+ engineio_logger.setLevel(logging.INFO)  # Show info logs with truncated arrays
+ engineio_logger.addFilter(chunk_filter)
+
+ # Also apply filter to the root logger to catch any other verbose logging
+ root_logger = logging.getLogger()
+ root_logger.addFilter(chunk_filter)
+
+ # Configure Engine.IO payload limits for large audio chunks
+ engineio.payload.Payload.max_decode_packets = 250
+
+ # Socket.IO setup with environment-based CORS
+ socketio_cors_origins = cors_config.get_socketio_cors_origins()
+ print(f"Configuring Socket.IO CORS: {len(socketio_cors_origins) if isinstance(socketio_cors_origins, list) else 'all'} origins")
+
+ sio = socketio.AsyncServer(
+     async_mode='asgi',
+     cors_allowed_origins=socketio_cors_origins,
+     cors_credentials=not cors_config.allow_all,  # Cannot use credentials with wildcard
+     logger=True,  # Re-enabled with custom filtering
+     engineio_logger=True,  # Re-enabled with custom filtering
+     always_connect=False  # This ensures the connect event is called for authentication
+ )
+
+ # Set the socketio instance in websocket manager
+ websocket_manager.set_socketio(sio)
+
+ socket_app = socketio.ASGIApp(sio, app)
+
+ @app.get("/health")
+ async def health_check(token: str = Depends(optional_hf_token)):
+     """Health check endpoint - optionally authenticated"""
+     from app.auth import hf_auth
+
+     auth_status = "bypassed (local development)" if hf_auth.is_local else "authenticated"
+     if not hf_auth.is_local and not token:
+         auth_status = "unauthenticated"
+
+     return {
+         "status": "healthy",
+         "message": "Translation service is running",
+         "auth_status": auth_status,
+         "local_development": hf_auth.is_local,
+         "auth_bypassed": hf_auth.is_local,
+         "token_prefix": token[:10] + "..." if token and token != "local-development-bypass" else "local-bypass" if hf_auth.is_local else None,
+         "environment": {
+             "ENVIRONMENT": os.getenv('ENVIRONMENT', 'not set'),
+             "DEBUG": os.getenv('DEBUG', 'not set'),
+             "DISABLE_AUTH": os.getenv('DISABLE_AUTH', 'not set'),
+             "HOST": os.getenv('HOST', 'not set'),
+             "PORT": os.getenv('PORT', 'not set')
+         },
+         "services": {
+             "transcription": transcription_service is not None,
+             "translation": translation_service is not None,
+             "tts": tts_service is not None,
+             "sessions": session_manager is not None
+         }
+     }
+
+ @sio.event
+ async def connect(sid, environ=None, auth=None):
+     """Handle Socket.IO connection with authentication"""
+     try:
+         print("=== WEBSOCKET CONNECTION ATTEMPT ===")
+         print(f"SID: {sid}")
+         print(f"Auth data: {auth}")
+         print(f"Environ type: {type(environ)}")
+         print(f"Environ data: {environ}")
+
+         # Ensure environ is a dict
+         if environ is None:
+             environ = {}
+
+         print(f"Query string: {environ.get('QUERY_STRING', 'None')}")
+         print(f"Headers: {[k for k in environ.keys() if k.startswith('HTTP_')] if isinstance(environ, dict) else 'environ not dict'}")
+
+         # Check authentication from multiple sources
+         authenticated = False
+         auth_method = None
+
+         # Method 1: Check auth data from client
+         if auth and authenticate_websocket_auth_data(auth):
+             authenticated = True
+             auth_method = "auth_data"
+             print("✓ Authenticated via auth data")
+
+         # Method 2: Check environment (headers, query params)
+         elif environ and isinstance(environ, dict) and authenticate_websocket_connect(environ):
+             authenticated = True
+             auth_method = "environ"
+             print("✓ Authenticated via headers/query")
+
+         # TEMPORARY: Allow connections for debugging (remove in production)
+         # This helps identify if the issue is authentication or something else
+         if not authenticated:
+             print("⚠️ Authentication failed, but allowing for debugging")
+             if isinstance(environ, dict):
+                 print(f"Available environ keys: {list(environ.keys())}")
+             # The next two lines temporarily allow unauthenticated connections for
+             # debugging; remove them to enforce authentication
+             authenticated = True
+             auth_method = "debug_bypass"
+
+         if not authenticated:
+             print("❌ Authentication failed - disconnecting")
+             await sio.disconnect(sid)
+             return False
+
+         print(f"✓ WebSocket connection authenticated successfully via {auth_method}")
+         return True
+
+     except Exception as e:
+         print(f"❌ Error in connect handler: {e}")
+         import traceback
+         traceback.print_exc()
+         try:
+             await sio.disconnect(sid)
+         except Exception:
+             pass
+         return False
+
+ @sio.event
+ async def disconnect(sid):
+     await websocket_manager.handle_disconnect(sid)
+
+ @sio.event
+ async def join_session(sid, data):
+     await websocket_manager.handle_join_session(sid, data)
+
+ @sio.event
+ async def join_hub(sid, data):
+     await websocket_manager.handle_join_hub(sid, data)
+
+ @sio.event
+ async def leave_session(sid, data):
+     await websocket_manager.handle_leave_session(sid, data)
+
+ @sio.event
+ async def audio_chunk(sid, data):
+     await websocket_manager.handle_audio_chunk(sid, data)
+
+ @sio.event
+ async def speaking_status(sid, data):
+     await websocket_manager.handle_speaking_status(sid, data)
+
+ @sio.event
+ async def test_echo(sid, data):
+     """Test event to verify WebSocket communication"""
+     await sio.emit('test_echo_response', data, room=sid)
+
+ @sio.event
+ async def update_participant_language(sid, data):
+     """Update participant's language (affects speech recognition)"""
+     await websocket_manager.handle_update_participant_language(sid, data)
+
+ @sio.event
+ async def update_session_languages(sid, data):
+     """Update session's languages (affects translation targets)"""
+     await websocket_manager.handle_update_session_languages(sid, data)
+
+ # Serve static files (for frontend)
+ if os.path.exists("../frontend/dist"):
+     app.mount("/", StaticFiles(directory="../frontend/dist", html=True), name="static")
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run("main:socket_app", host="0.0.0.0", port=7860, reload=True)
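The truncation filter is the least obvious piece above, so here is a standalone check of its first regex. The pattern is copied rather than imported, since importing `app.main` constructs all services at import time:

```python
import re

# Mirrors the "audioData" pattern used by ChunkArrayTruncateFilter above.
audiodata_pattern = r'"audioData":\[([0-9,-]+(?:,[0-9,-]+)*)\]'

def truncate_audiodata(match: "re.Match") -> str:
    items = match.group(1).split(',')
    if len(items) > 10:
        head = ','.join(items[:10])
        return f'"audioData":[{head}, ...] (truncated {len(items) - 10} more items)'
    return match.group(0)

msg = '{"audioData":[' + ','.join(str(i) for i in range(30)) + ']}'
print(re.sub(audiodata_pattern, truncate_audiodata, msg))
# {"audioData":[0,1,2,3,4,5,6,7,8,9, ...] (truncated 20 more items)}
```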
app/models/__init__.py ADDED
@@ -0,0 +1,77 @@
+ from pydantic import BaseModel, Field
+ from typing import List, Dict, Optional
+ from enum import Enum
+
+ class LanguageCode(str, Enum):
+     ENGLISH = "eng"
+     SWAHILI = "swa"
+     KIKUYU = "kik"
+     KAMBA = "kam"
+     KIMERU = "mer"
+     LUO = "luo"
+     SOMALI = "som"
+
+ class Language(BaseModel):
+     code: LanguageCode
+     name: str
+     display_name: str
+
+ class ParticipantCreate(BaseModel):
+     name: str
+     language: LanguageCode
+
+ class Participant(BaseModel):
+     id: str
+     name: str
+     language: Language
+     is_organizer: bool = False
+     is_speaking: bool = False
+     is_connected: bool = False
+
+ class SessionCreate(BaseModel):
+     name: str
+     organizer_name: str
+     languages: List[LanguageCode]
+     enable_tts: bool = True  # Enable TTS by default for backward compatibility
+
+ class Session(BaseModel):
+     id: str
+     name: str
+     organizer_name: str
+     participants: List[Participant] = []
+     languages: List[Language] = []
+     qr_code_url: Optional[str] = None
+     is_active: bool = True
+     enable_tts: bool = True  # TTS enabled by default
+
+ class Message(BaseModel):
+     id: str
+     session_id: str
+     speaker_id: str
+     speaker_name: str
+     original_text: str
+     original_language: Language
+     translations: Dict[str, str] = {}
+     is_transcribing: bool = False
+
+ class TranscriptionUpdate(BaseModel):
+     message_id: str
+     text: str
+     is_complete: bool
+     confidence: Optional[float] = None
+
+ class TranslationUpdate(BaseModel):
+     message_id: str
+     target_language: LanguageCode
+     translated_text: str
+
+ class AudioChunk(BaseModel):
+     session_id: str
+     participant_id: str
+     audio_data: bytes
+
+ class WebSocketMessage(BaseModel):
+     type: str
+     data: Dict
+     session_id: str
+     participant_id: Optional[str] = None
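A quick usage sketch of the schema above. The names and IDs are made up; `model_dump()` assumes Pydantic v2 (on v1, use `.dict()` instead):

```python
from app.models import Language, LanguageCode, Participant, SessionCreate

swahili = Language(code=LanguageCode.SWAHILI, name="Swahili", display_name="Kiswahili")
speaker = Participant(id="p-1", name="Amina", language=swahili, is_organizer=True)

new_session = SessionCreate(
    name="Community meeting",
    organizer_name="Amina",
    languages=[LanguageCode.ENGLISH, LanguageCode.SWAHILI],
)
print(speaker.model_dump())
print(new_session.model_dump())
```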
app/routers/__init__.py ADDED
@@ -0,0 +1 @@
+ # Routers package
app/routers/add_phase_endpoints.py ADDED
@@ -0,0 +1,490 @@
1
+ # Script to add remaining Phase 1-3 endpoints to learning.py
2
+
3
+ endpoints_code = """
4
+
5
+ @router.post("/vocabulary/add")
6
+ async def add_vocabulary_to_practice(
7
+ vocab_request: VocabularyAddRequest,
8
+ request: Request,
9
+ token: Optional[str] = Depends(optional_hf_token)
10
+ ):
11
+ \"\"\"Add a vocabulary word to user's practice queue with FSRS initialization\"\"\"
12
+ try:
13
+ user_id = token if token else 'anonymous'
14
+
15
+ vocab = learning_service.get_vocabulary(vocab_request.vocab_id)
16
+ if not vocab:
17
+ raise HTTPException(status_code=404, detail="Vocabulary not found")
18
+
19
+ fsrs_data = {
20
+ 'difficulty': 0.3,
21
+ 'stability': 2.5,
22
+ 'retrievability': 1.0,
23
+ 'review_count': 0,
24
+ 'last_review': None,
25
+ 'next_review': datetime.utcnow().isoformat() + 'Z',
26
+ 'lapses': 0,
27
+ 'state': 'new'
28
+ }
29
+
30
+ user_vocab = {
31
+ 'vocabulary_id': vocab_request.vocab_id,
32
+ 'swahili': vocab.get('swahili', ''),
33
+ 'english': vocab.get('english', ''),
34
+ 'part_of_speech': vocab.get('part_of_speech', 'unknown'),
35
+ 'added_at': datetime.utcnow().isoformat() + 'Z',
36
+ 'added_from': vocab_request.source_lesson_id,
37
+ 'fsrs': fsrs_data,
38
+ 'mastery_level': 0,
39
+ 'times_reviewed': 0,
40
+ 'times_correct': 0,
41
+ 'accuracy': 0.0
42
+ }
43
+
44
+ success = learning_service.update_vocabulary_progress(
45
+ user_id, str(vocab_request.vocab_id), user_vocab
46
+ )
47
+
48
+ if success:
49
+ return {"success": True, "vocabulary": user_vocab}
50
+ else:
51
+ raise HTTPException(status_code=500, detail="Failed to add vocabulary")
52
+ except HTTPException:
53
+ raise
54
+ except Exception as e:
55
+ logger.error(f"Error adding vocabulary: {e}")
56
+ raise HTTPException(status_code=500, detail="Failed to add vocabulary")
57
+
58
+
59
+ def calculate_next_review_fsrs(fsrs: Dict, grade: int) -> Dict:
60
+ \"\"\"Implement FSRS algorithm\"\"\"
61
+ from datetime import timedelta
62
+
63
+ difficulty = fsrs['difficulty']
64
+ stability = fsrs['stability']
65
+
66
+ if grade == 0:
67
+ new_difficulty = min(difficulty + 0.2, 1.0)
68
+ elif grade == 2:
69
+ new_difficulty = min(difficulty + 0.1, 1.0)
70
+ elif grade == 4:
71
+ new_difficulty = max(difficulty - 0.1, 0.0)
72
+ else:
73
+ new_difficulty = difficulty
74
+
75
+ if grade == 0:
76
+ new_stability = stability * 0.5
77
+ state = 'relearning'
78
+ interval_minutes = 10
79
+ elif grade == 2:
80
+ new_stability = stability * 1.2
81
+ state = 'review'
82
+ interval_minutes = int(new_stability * 24 * 60)
83
+ elif grade == 3:
84
+ new_stability = stability * 2.5
85
+ state = 'review'
86
+ interval_minutes = int(new_stability * 24 * 60)
87
+ else:
88
+ new_stability = stability * 4.0
89
+ state = 'review'
90
+ interval_minutes = int(new_stability * 24 * 60)
91
+
92
+ next_review = datetime.utcnow() + timedelta(minutes=interval_minutes)
93
+
94
+ return {
95
+ 'difficulty': new_difficulty,
96
+ 'stability': new_stability,
97
+ 'retrievability': 0.9 if grade >= 2 else 0.0,
98
+ 'review_count': fsrs['review_count'] + 1,
99
+ 'last_review': datetime.utcnow().isoformat() + 'Z',
100
+ 'next_review': next_review.isoformat() + 'Z',
101
+ 'lapses': fsrs['lapses'],
102
+ 'state': state,
103
+ 'interval_days': interval_minutes / (24 * 60)
104
+ }
105
+
106
+
107
+ def calculate_mastery_level(vocab: Dict) -> int:
108
+ \"\"\"Calculate mastery level (0-5)\"\"\"
109
+ accuracy = vocab['accuracy']
110
+ reviews = vocab['times_reviewed']
111
+ stability = vocab['fsrs']['stability']
112
+
113
+ if reviews == 0:
114
+ return 0
115
+ elif reviews >= 40 and accuracy >= 98 and stability >= 90:
116
+ return 5
117
+ elif reviews >= 20 and accuracy >= 95 and stability >= 30:
118
+ return 4
119
+ elif reviews < 5 or accuracy < 70:
120
+ return 1
121
+ elif reviews < 10 or accuracy < 85:
122
+ return 2
123
+ else:
124
+ return 3
127
+
128
+
129
+ @router.post("/vocabulary/review")
130
+ async def record_vocabulary_review_fsrs(
131
+ review_request: VocabularyReviewRequest,
132
+ request: Request,
133
+ token: Optional[str] = Depends(optional_hf_token)
134
+ ):
135
+ \"\"\"Record vocabulary review and update FSRS parameters\"\"\"
136
+ try:
137
+ user_id = token if token else 'anonymous'
138
+ progress = learning_service.get_user_progress(user_id)
139
+
140
+ if not progress or str(review_request.vocab_id) not in progress.get('vocabulary_progress', {}):
141
+ raise HTTPException(status_code=404, detail="Vocabulary not in practice queue")
142
+
143
+ vocab = progress['vocabulary_progress'][str(review_request.vocab_id)]
144
+ fsrs = vocab['fsrs']
145
+
146
+ grade_map = {'again': 0, 'hard': 2, 'good': 3, 'easy': 4}
147
+ grade = grade_map.get(review_request.rating, 3)
148
+
149
+ new_fsrs = calculate_next_review_fsrs(fsrs, grade)
150
+
151
+ vocab['fsrs'] = new_fsrs
152
+ vocab['times_reviewed'] += 1
153
+ if grade >= 2:
154
+ vocab['times_correct'] += 1
155
+ else:
156
+ vocab['fsrs']['lapses'] += 1
157
+
158
+ vocab['accuracy'] = (vocab['times_correct'] / vocab['times_reviewed']) * 100 if vocab['times_reviewed'] > 0 else 0
159
+ vocab['mastery_level'] = calculate_mastery_level(vocab)
160
+ vocab['last_reviewed_at'] = datetime.utcnow().isoformat() + 'Z'
161
+
162
+ if 'vocabulary_reviewed' not in progress['overall_stats']:
163
+ progress['overall_stats']['vocabulary_reviewed'] = 0
164
+ progress['overall_stats']['vocabulary_reviewed'] += 1
165
+
166
+ learning_service.save_user_progress(user_id, progress)
167
+
168
+ return {
169
+ "success": True,
170
+ "vocabulary": vocab,
171
+ "next_review": new_fsrs['next_review'],
172
+ "interval_days": new_fsrs['interval_days']
173
+ }
174
+ except HTTPException:
175
+ raise
176
+ except Exception as e:
177
+ logger.error(f"Error recording vocabulary review: {e}")
178
+ raise HTTPException(status_code=500, detail="Failed to record review")
179
+
180
+
181
+ @router.get("/vocabulary/stats")
182
+ async def get_vocabulary_stats(
183
+ request: Request,
184
+ token: Optional[str] = Depends(optional_hf_token)
185
+ ):
186
+ \"\"\"Get vocabulary mastery statistics\"\"\"
187
+ try:
188
+ user_id = token if token else 'anonymous'
189
+ progress = learning_service.get_user_progress(user_id)
190
+
191
+ if not progress:
192
+ return {
193
+ "total_words": 0,
194
+ "in_practice": 0,
195
+ "mastery_breakdown": {str(i): 0 for i in range(6)},
196
+ "average_accuracy": 0
197
+ }
198
+
199
+ vocab_progress = progress.get('vocabulary_progress', {})
200
+ mastery_breakdown = {str(i): 0 for i in range(6)}
201
+ total_accuracy = 0
202
+ total_with_reviews = 0
203
+
204
+ for vocab_data in vocab_progress.values():
205
+ level = vocab_data.get('mastery_level', 0)
206
+ mastery_breakdown[str(level)] += 1
207
+
208
+ if vocab_data.get('times_reviewed', 0) > 0:
209
+ total_accuracy += vocab_data.get('accuracy', 0)
210
+ total_with_reviews += 1
211
+
212
+ avg_accuracy = total_accuracy / total_with_reviews if total_with_reviews > 0 else 0
213
+
214
+ return {
215
+ "total_words": len(vocab_progress),
216
+ "in_practice": len(vocab_progress),
217
+ "mastery_breakdown": mastery_breakdown,
218
+ "average_accuracy": round(avg_accuracy, 1),
219
+ "total_reviews": sum(v.get('times_reviewed', 0) for v in vocab_progress.values())
220
+ }
221
+ except Exception as e:
222
+ logger.error(f"Error getting vocabulary stats: {e}")
223
+ raise HTTPException(status_code=500, detail="Failed to get stats")
224
+
225
+
226
+ @router.get("/vocabulary/library")
227
+ async def get_vocabulary_library(
228
+ lesson_id: Optional[int] = None,
229
+ level: Optional[str] = None,
230
+ search: Optional[str] = None,
231
+ request: Request = None,
232
+ token: Optional[str] = Depends(optional_hf_token)
233
+ ):
234
+ \"\"\"Browse all vocabulary with filters\"\"\"
235
+ try:
236
+ user_id = token if token else 'anonymous'
237
+
238
+ all_vocab = learning_service.get_all_vocabulary()
239
+ progress = learning_service.get_user_progress(user_id)
240
+ user_vocab = progress.get('vocabulary_progress', {}) if progress else {}
241
+
242
+ filtered_vocab = all_vocab
243
+
244
+ if lesson_id:
245
+ filtered_vocab = [v for v in filtered_vocab if v.get('lesson_id') == lesson_id]
246
+
247
+ if level:
248
+ filtered_vocab = [v for v in filtered_vocab if v.get('level') == level]
249
+
250
+ if search:
251
+ search_lower = search.lower()
252
+ filtered_vocab = [v for v in filtered_vocab
253
+ if search_lower in v.get('swahili', '').lower()
254
+ or search_lower in v.get('english', '').lower()]
255
+
256
+ for vocab in filtered_vocab:
257
+ vocab_id = str(vocab.get('vocabulary_id') or vocab.get('id'))
258
+ if vocab_id in user_vocab:
259
+ vocab['status'] = 'practicing'
260
+ vocab['mastery_level'] = user_vocab[vocab_id].get('mastery_level', 0)
261
+ vocab['accuracy'] = user_vocab[vocab_id].get('accuracy', 0)
262
+ vocab['next_review'] = user_vocab[vocab_id].get('fsrs', {}).get('next_review')
263
+ else:
264
+ vocab['status'] = 'not_practicing'
265
+ vocab['mastery_level'] = 0
266
+
267
+ return {
268
+ "vocabulary": filtered_vocab,
269
+ "total": len(filtered_vocab),
270
+ "filters_applied": {
271
+ "lesson_id": lesson_id,
272
+ "level": level,
273
+ "search": search
274
+ }
275
+ }
276
+ except Exception as e:
277
+ logger.error(f"Error getting vocabulary library: {e}")
278
+ raise HTTPException(status_code=500, detail="Failed to get vocabulary")
279
+
280
+
281
+ # Reading Comprehension
282
+
283
+ class ComprehensionAnswer(BaseModel):
284
+ question_id: str
285
+ answer: str
286
+
287
+
288
+ class ComprehensionSubmission(BaseModel):
289
+ lesson_id: int
290
+ passage_id: str
291
+ answers: List[ComprehensionAnswer]
292
+
293
+
294
+ @router.post("/comprehension/submit")
295
+ async def submit_comprehension_answers(
296
+ submission: ComprehensionSubmission,
297
+ request: Request,
298
+ token: Optional[str] = Depends(optional_hf_token)
299
+ ):
300
+ \"\"\"Submit reading comprehension answers and get scoring\"\"\"
301
+ try:
302
+ user_id = token if token else 'anonymous'
303
+
304
+ lesson = learning_service.get_lesson(submission.lesson_id)
305
+ if not lesson:
306
+ raise HTTPException(status_code=404, detail="Lesson not found")
307
+
308
+ passage = None
309
+ for p in lesson.get('reading_passages', []):
310
+ if p['passage_id'] == submission.passage_id:
311
+ passage = p
312
+ break
313
+
314
+ if not passage:
315
+ raise HTTPException(status_code=404, detail="Passage not found")
316
+
317
+ results = []
318
+ correct_count = 0
319
+
320
+ for submitted in submission.answers:
321
+ question_id = submitted.question_id
322
+ user_answer = submitted.answer.strip().lower()
323
+
324
+ question = None
325
+ for q in passage['comprehension_questions']:
326
+ if q['question_id'] == question_id:
327
+ question = q
328
+ break
329
+
330
+ if not question:
331
+ continue
332
+
333
+ correct_answers = [ans.strip().lower() for ans in question.get('correct_answers', [])]
334
+ is_correct = user_answer in correct_answers
335
+
336
+ if is_correct:
337
+ correct_count += 1
338
+
339
+ results.append({
340
+ "question_id": question_id,
341
+ "correct": is_correct,
342
+ "user_answer": user_answer,
343
+ "correct_answer": question['correct_answers'][0] if correct_answers else None,
344
+ "explanation": question.get('explanation')
345
+ })
346
+
347
+ score = (correct_count / len(submission.answers)) * 100 if submission.answers else 0
348
+
349
+ progress = learning_service.get_user_progress(user_id)
350
+ if not progress:
351
+ progress = learning_service.create_default_progress(user_id)
352
+
353
+ if 'comprehension_scores' not in progress:
354
+ progress['comprehension_scores'] = {}
355
+
356
+ progress['comprehension_scores'][f"{submission.lesson_id}_{submission.passage_id}"] = {
357
+ "score": score,
358
+ "completed_at": datetime.utcnow().isoformat() + 'Z',
359
+ "attempts": progress['comprehension_scores'].get(f"{submission.lesson_id}_{submission.passage_id}", {}).get('attempts', 0) + 1
360
+ }
361
+
362
+ learning_service.save_user_progress(user_id, progress)
363
+
364
+ return {
365
+ "results": results,
366
+ "score": round(score, 1),
367
+ "correct": correct_count,
368
+ "total": len(submission.answers)
369
+ }
370
+ except HTTPException:
371
+ raise
372
+ except Exception as e:
373
+ logger.error(f"Error submitting comprehension: {e}")
374
+ raise HTTPException(status_code=500, detail="Failed to submit comprehension")
375
+
376
+
377
+ # Task Scenarios
378
+
379
+ class ScenarioProgressUpdate(BaseModel):
380
+ turn_id: str
381
+ choice_id: str
382
+
383
+
384
+ @router.get("/scenarios/{scenario_id}")
385
+ async def get_scenario(
386
+ scenario_id: str,
387
+ request: Request,
388
+ token: Optional[str] = Depends(optional_hf_token)
389
+ ):
390
+ \"\"\"Get task scenario with branching dialogue\"\"\"
391
+ try:
392
+ user_id = token if token else 'anonymous'
393
+
394
+ scenario = learning_service.get_scenario(scenario_id)
395
+ if not scenario:
396
+ raise HTTPException(status_code=404, detail="Scenario not found")
397
+
398
+ progress = learning_service.get_user_progress(user_id)
399
+ scenario_progress = None
400
+
401
+ if progress and 'scenario_progress' in progress:
402
+ scenario_progress = progress['scenario_progress'].get(scenario_id)
403
+
404
+ return {
405
+ "scenario": scenario,
406
+ "user_progress": scenario_progress
407
+ }
408
+ except HTTPException:
409
+ raise
410
+ except Exception as e:
411
+ logger.error(f"Error getting scenario: {e}")
412
+ raise HTTPException(status_code=500, detail="Failed to get scenario")
413
+
414
+
415
+ @router.post("/scenarios/{scenario_id}/progress")
416
+ async def update_scenario_progress(
417
+ scenario_id: str,
418
+ progress_update: ScenarioProgressUpdate,
419
+ request: Request,
420
+ token: Optional[str] = Depends(optional_hf_token)
421
+ ):
422
+ \"\"\"Update scenario progress with user choice\"\"\"
423
+ try:
424
+ user_id = token if token else 'anonymous'
425
+
426
+ scenario = learning_service.get_scenario(scenario_id)
427
+ if not scenario:
428
+ raise HTTPException(status_code=404, detail="Scenario not found")
429
+
430
+ progress = learning_service.get_user_progress(user_id)
431
+ if not progress:
432
+ progress = learning_service.create_default_progress(user_id)
433
+
434
+ if 'scenario_progress' not in progress:
435
+ progress['scenario_progress'] = {}
436
+
437
+ if scenario_id not in progress['scenario_progress']:
438
+ progress['scenario_progress'][scenario_id] = {
439
+ "started_at": datetime.utcnow().isoformat() + 'Z',
440
+ "turns": [],
441
+ "completed": False
442
+ }
443
+
444
+ progress['scenario_progress'][scenario_id]['turns'].append({
445
+ "turn_id": progress_update.turn_id,
446
+ "choice_id": progress_update.choice_id,
447
+ "timestamp": datetime.utcnow().isoformat() + 'Z'
448
+ })
449
+
450
+ turns_count = len(progress['scenario_progress'][scenario_id]['turns'])
451
+ if turns_count >= scenario.get('required_turns', 6):
452
+ progress['scenario_progress'][scenario_id]['completed'] = True
453
+ progress['scenario_progress'][scenario_id]['completed_at'] = datetime.utcnow().isoformat() + 'Z'
454
+
455
+ learning_service.save_user_progress(user_id, progress)
456
+
457
+ return {
458
+ "success": True,
459
+ "progress": progress['scenario_progress'][scenario_id]
460
+ }
461
+ except HTTPException:
462
+ raise
463
+ except Exception as e:
464
+ logger.error(f"Error updating scenario progress: {e}")
465
+ raise HTTPException(status_code=500, detail="Failed to update scenario")
466
+
467
+
468
+ @router.get("/scenarios")
469
+ async def list_scenarios(
470
+ request: Request,
471
+ token: Optional[str] = Depends(optional_hf_token)
472
+ ):
473
+ \"\"\"Get list of all available scenarios\"\"\"
474
+ try:
475
+ scenarios = learning_service.get_all_scenarios()
476
+ return {
477
+ "success": True,
478
+ "scenarios": scenarios,
479
+ "total": len(scenarios)
480
+ }
481
+ except Exception as e:
482
+ logger.error(f"Error listing scenarios: {e}")
483
+ raise HTTPException(status_code=500, detail="Failed to list scenarios")
484
+ """
485
+
486
+ # Append to learning.py
487
+ with open('C:/repos/polyglot/backend/app/routers/learning.py', 'a', encoding='utf-8') as f:
488
+ f.write(endpoints_code)
489
+
490
+ print("Successfully added all remaining Phase 1-3 endpoints!")
app/routers/learning.py ADDED
@@ -0,0 +1,1020 @@
1
+ """
2
+ Learning API Router - REST endpoints for language learning functionality
3
+
4
+ Provides endpoints for:
5
+ - Fetching lesson catalog and individual lessons
6
+ - Managing user progress
7
+ - Recording lesson completion and scores
8
+ - Achievement tracking
9
+ """
10
+
11
+ from fastapi import APIRouter, HTTPException, Depends, Request, File, UploadFile
12
+ from fastapi.responses import Response
13
+ from pydantic import BaseModel
14
+ from typing import List, Dict, Optional, Any
15
+ from datetime import datetime
16
+ import logging
17
+ import io
18
+
19
+ from app.services.learning_data_service import LearningDataService
20
+ from app.auth import optional_hf_token
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+ router = APIRouter(prefix="/api/learning", tags=["learning"])
25
+
26
+ # Initialize data service
27
+ learning_service = LearningDataService()
28
+
29
+
30
+ # ==================== Request/Response Models ====================
31
+
32
+ class LessonProgressUpdate(BaseModel):
33
+ lesson_id: int
34
+ status: str # 'in_progress' or 'completed'
35
+ score: Optional[int] = None
36
+ pronunciation_score: Optional[float] = None
37
+ listening_score: Optional[float] = None
38
+ comprehension_score: Optional[float] = None
39
+ time_spent_seconds: Optional[int] = None
40
+ steps_completed: Optional[int] = None
41
+ steps_skipped: Optional[int] = None
42
+
43
+
44
+ class VocabularyReview(BaseModel):
45
+ vocabulary_id: int
46
+ swahili: str
47
+ is_correct: bool
48
+ mastery_level: Optional[int] = None
49
+
50
+
51
+ class AchievementCheck(BaseModel):
52
+ achievement_id: str
53
+ progress: int
54
+ target: int
55
+
56
+
57
+ # ==================== Lesson Endpoints ====================
58
+
59
+ @router.get("/lessons")
60
+ async def get_lessons(language: Optional[str] = 'swahili', request: Request = None, token: Optional[str] = Depends(optional_hf_token)):
61
+ """
62
+ Get catalog of all available lessons for a specific language
63
+
64
+ Args:
65
+ language: Language code (swahili, kamba, maasai)
66
+
67
+ Returns the lesson index with metadata for all lessons
68
+ """
69
+ try:
70
+ index = learning_service.get_lessons_index(language)
71
+ if not index:
72
+ raise HTTPException(status_code=404, detail=f"Lessons catalog not found for {language}")
73
+
74
+ return {
75
+ "success": True,
76
+ "lessons": index.get('lessons', []),
77
+ "learning_paths": index.get('learning_paths', {}),
78
+ "metadata": index.get('metadata', {})
79
+ }
80
+ except HTTPException:
81
+ raise
82
+ except Exception as e:
83
+ logger.error(f"Error fetching lessons for {language}: {e}")
84
+ raise HTTPException(status_code=500, detail="Failed to fetch lessons")
85
+
86
+
87
+ @router.get("/lessons/{lesson_id}")
88
+ async def get_lesson(lesson_id: int, language: Optional[str] = 'swahili', request: Request = None, token: Optional[str] = Depends(optional_hf_token)):
89
+ """
90
+ Get detailed lesson content including vocabulary, dialogue, and exercises
91
+
92
+ Args:
93
+ lesson_id: ID of the lesson to fetch
94
+ language: Language code (swahili, kamba, maasai)
95
+ """
96
+ try:
97
+ lesson = learning_service.get_lesson(lesson_id, language)
98
+ if not lesson:
99
+ raise HTTPException(status_code=404, detail=f"Lesson {lesson_id} not found for {language}")
100
+
101
+ return {
102
+ "success": True,
103
+ "lesson": lesson
104
+ }
105
+ except HTTPException:
106
+ raise
107
+ except Exception as e:
108
+ logger.error(f"Error fetching lesson {lesson_id} for {language}: {e}")
109
+ raise HTTPException(status_code=500, detail="Failed to fetch lesson")
110
+
111
+
112
+ # ==================== User Progress Endpoints ====================
113
+
114
+ @router.get("/progress")
115
+ async def get_user_progress(request: Request, token: Optional[str] = Depends(optional_hf_token)):
116
+ """
117
+ Get user's learning progress
118
+
119
+ Returns overall stats, lesson progress, vocabulary progress, and achievements
120
+ """
121
+ try:
122
+ # Use authenticated user ID or default for anonymous users
123
+ user_id = token if token else 'anonymous'
124
+
125
+ progress = learning_service.get_user_progress(user_id)
126
+ if not progress:
127
+ raise HTTPException(status_code=500, detail="Failed to load user progress")
128
+
129
+ return {
130
+ "success": True,
131
+ "progress": progress
132
+ }
133
+ except HTTPException:
134
+ raise
135
+ except Exception as e:
136
+ logger.error(f"Error fetching user progress: {e}")
137
+ raise HTTPException(status_code=500, detail="Failed to fetch user progress")
138
+
139
+
140
+ @router.post("/progress/lesson")
141
+ async def update_lesson_progress(
142
+ progress_update: LessonProgressUpdate,
143
+ request: Request,
144
+ token: Optional[str] = Depends(optional_hf_token)
145
+ ):
146
+ """
147
+ Update progress for a specific lesson
148
+
149
+ Records completion status, scores, and time spent on a lesson
150
+ """
151
+ try:
152
+ user_id = token if token else 'anonymous'
153
+
154
+ # Build progress update dict
155
+ update_data = {
156
+ 'lesson_id': progress_update.lesson_id,
157
+ 'status': progress_update.status
158
+ }
159
+
160
+ # Add optional fields if provided
161
+ if progress_update.score is not None:
162
+ update_data['latest_score'] = progress_update.score
163
+
164
+ # Track best score
165
+ user_progress = learning_service.get_user_progress(user_id)
166
+ if user_progress:
167
+ lesson_key = str(progress_update.lesson_id)
168
+ current_best = user_progress.get('lesson_progress', {}).get(lesson_key, {}).get('best_score', 0)
169
+ update_data['best_score'] = max(current_best, progress_update.score)
170
+
171
+ if progress_update.pronunciation_score is not None:
172
+ update_data['pronunciation_score'] = progress_update.pronunciation_score
173
+
174
+ if progress_update.listening_score is not None:
175
+ update_data['listening_score'] = progress_update.listening_score
176
+
177
+ if progress_update.comprehension_score is not None:
178
+ update_data['comprehension_score'] = progress_update.comprehension_score
179
+
180
+ if progress_update.time_spent_seconds is not None:
181
+ update_data['time_spent_seconds'] = progress_update.time_spent_seconds
182
+
183
+ if progress_update.steps_completed is not None:
184
+ update_data['steps_completed'] = progress_update.steps_completed
185
+
186
+ if progress_update.steps_skipped is not None:
187
+ update_data['steps_skipped'] = progress_update.steps_skipped
188
+
189
+ # Add completion timestamp if status is completed
190
+ if progress_update.status == 'completed':
191
+ update_data['completed_at'] = datetime.utcnow().isoformat() + 'Z'
192
+
193
+ # Increment attempts
194
+ user_progress = learning_service.get_user_progress(user_id)
195
+ if user_progress:
196
+ lesson_key = str(progress_update.lesson_id)
197
+ current_attempts = user_progress.get('lesson_progress', {}).get(lesson_key, {}).get('attempts', 0)
198
+ update_data['attempts'] = current_attempts + 1
199
+
200
+ # Save to file
201
+ success = learning_service.update_lesson_progress(
202
+ user_id,
203
+ progress_update.lesson_id,
204
+ update_data
205
+ )
206
+
207
+ if not success:
208
+ raise HTTPException(status_code=500, detail="Failed to save progress")
209
+
210
+ return {
211
+ "success": True,
212
+ "message": "Lesson progress updated"
213
+ }
214
+ except HTTPException:
215
+ raise
216
+ except Exception as e:
217
+ logger.error(f"Error updating lesson progress: {e}")
218
+ raise HTTPException(status_code=500, detail="Failed to update lesson progress")
219
+
220
+
221
+ @router.post("/progress/vocabulary")
222
+ async def record_vocabulary_review(
223
+ review: VocabularyReview,
224
+ request: Request,
225
+ token: Optional[str] = Depends(optional_hf_token)
226
+ ):
227
+ """
228
+ Record a vocabulary review/practice session
229
+
230
+ Updates mastery level and review statistics for a vocabulary word
231
+ """
232
+ try:
233
+ user_id = token if token else 'anonymous'
234
+
235
+ # Get current vocabulary progress
236
+ user_progress = learning_service.get_user_progress(user_id)
237
+ if not user_progress:
238
+ raise HTTPException(status_code=500, detail="Failed to load user progress")
239
+
240
+ vocab_key = str(review.vocabulary_id)
241
+ vocab_progress = user_progress.get('vocabulary_progress', {}).get(vocab_key, {
242
+ 'vocabulary_id': review.vocabulary_id,
243
+ 'swahili': review.swahili,
244
+ 'mastery_level': 0,
245
+ 'times_reviewed': 0,
246
+ 'times_correct': 0,
247
+ 'ease_factor': 2.5,
248
+ 'interval_days': 0
249
+ })
250
+
251
+ # Update review counts
252
+ vocab_progress['times_reviewed'] = vocab_progress.get('times_reviewed', 0) + 1
253
+ if review.is_correct:
254
+ vocab_progress['times_correct'] = vocab_progress.get('times_correct', 0) + 1
255
+
256
+ # Update mastery level if provided
257
+ if review.mastery_level is not None:
258
+ vocab_progress['mastery_level'] = review.mastery_level
259
+
260
+ # Update timestamps
261
+ vocab_progress['last_reviewed_at'] = datetime.utcnow().isoformat() + 'Z'
262
+
263
+ # Calculate next review date using simple spaced repetition
264
+ # (simplified version - could use SuperMemo SM-2 algorithm)
265
+ interval_days = vocab_progress.get('interval_days', 0)
266
+ if review.is_correct:
267
+ interval_days = max(1, interval_days * 2) # Double the interval
268
+ else:
269
+ interval_days = 1 # Reset to 1 day if incorrect
270
+
271
+ vocab_progress['interval_days'] = interval_days
272
+
273
+ from datetime import timedelta
274
+ next_review = datetime.utcnow() + timedelta(days=interval_days)
275
+ vocab_progress['next_review_at'] = next_review.isoformat() + 'Z'
276
+
277
+ # Save to file
278
+ success = learning_service.update_vocabulary_progress(
279
+ user_id,
280
+ review.vocabulary_id,
281
+ vocab_progress
282
+ )
283
+
284
+ if not success:
285
+ raise HTTPException(status_code=500, detail="Failed to save vocabulary progress")
286
+
287
+ return {
288
+ "success": True,
289
+ "message": "Vocabulary review recorded",
290
+ "next_review_at": vocab_progress['next_review_at']
291
+ }
292
+ except HTTPException:
293
+ raise
294
+ except Exception as e:
295
+ logger.error(f"Error recording vocabulary review: {e}")
296
+ raise HTTPException(status_code=500, detail="Failed to record vocabulary review")
297
+
298
+
299
+ # ==================== Achievement Endpoints ====================
300
+
301
+ @router.get("/achievements")
302
+ async def get_achievements(request: Request, token: Optional[str] = Depends(optional_hf_token)):
303
+ """
304
+ Get all available achievements and user's progress on them
305
+ """
306
+ try:
307
+ # Get achievements configuration
308
+ achievements_config = learning_service.get_achievements()
309
+ if not achievements_config:
310
+ raise HTTPException(status_code=404, detail="Achievements not found")
311
+
312
+ # Get user progress
313
+ user_id = token if token else 'anonymous'
314
+ user_progress = learning_service.get_user_progress(user_id)
315
+
316
+ # Merge achievement definitions with user progress
317
+ user_achievements = user_progress.get('achievements', {}) if user_progress else {}
318
+
319
+ achievements_with_progress = []
320
+ for achievement in achievements_config.get('achievements', []):
321
+ achievement_id = achievement['achievement_id']
322
+ achievement_data = {
323
+ **achievement,
324
+ 'unlocked': False,
325
+ 'progress': 0
326
+ }
327
+
328
+ # Add user progress if available
329
+ if achievement_id in user_achievements:
330
+ achievement_data.update(user_achievements[achievement_id])
331
+
332
+ achievements_with_progress.append(achievement_data)
333
+
334
+ return {
335
+ "success": True,
336
+ "achievements": achievements_with_progress,
337
+ "tiers": achievements_config.get('tiers', {})
338
+ }
339
+ except HTTPException:
340
+ raise
341
+ except Exception as e:
342
+ logger.error(f"Error fetching achievements: {e}")
343
+ raise HTTPException(status_code=500, detail="Failed to fetch achievements")
344
+
345
+
346
+ @router.post("/achievements/check")
347
+ async def check_achievement(
348
+ achievement: AchievementCheck,
349
+ request: Request,
350
+ token: Optional[str] = Depends(optional_hf_token)
351
+ ):
352
+ """
353
+ Check and potentially unlock an achievement
354
+
355
+ Updates achievement progress and unlocks if target is reached
356
+ """
357
+ try:
358
+ user_id = token if token else 'anonymous'
359
+
360
+ success = learning_service.unlock_achievement(
361
+ user_id,
362
+ achievement.achievement_id,
363
+ achievement.progress,
364
+ achievement.target
365
+ )
366
+
367
+ if not success:
368
+ raise HTTPException(status_code=500, detail="Failed to update achievement")
369
+
370
+ is_unlocked = achievement.progress >= achievement.target
371
+
372
+ return {
373
+ "success": True,
374
+ "unlocked": is_unlocked,
375
+ "achievement_id": achievement.achievement_id
376
+ }
377
+ except HTTPException:
378
+ raise
379
+ except Exception as e:
380
+ logger.error(f"Error checking achievement: {e}")
381
+ raise HTTPException(status_code=500, detail="Failed to check achievement")
382
+
383
+
384
+ # ==================== Statistics Endpoints ====================
385
+
386
+ @router.get("/stats")
387
+ async def get_user_stats(request: Request, token: Optional[str] = Depends(optional_hf_token)):
388
+ """
389
+ Get user's overall learning statistics
390
+
391
+ Returns aggregated stats like total XP, streak, lessons completed, etc.
392
+ """
393
+ try:
394
+ user_id = token if token else 'anonymous'
395
+ progress = learning_service.get_user_progress(user_id)
396
+
397
+ if not progress:
398
+ raise HTTPException(status_code=500, detail="Failed to load user progress")
399
+
400
+ return {
401
+ "success": True,
402
+ "stats": progress.get('overall_stats', {}),
403
+ "daily_stats": progress.get('daily_stats', {})
404
+ }
405
+ except HTTPException:
406
+ raise
407
+ except Exception as e:
408
+ logger.error(f"Error fetching user stats: {e}")
409
+ raise HTTPException(status_code=500, detail="Failed to fetch user stats")
410
+
411
+
412
+ # ==================== TTS and ASR Endpoints ====================
413
+
414
+ class TTSRequest(BaseModel):
415
+ text: str
416
+ language: str
417
+ messageId: Optional[str] = None
418
+
419
+
420
+ @router.post("/tts/generate")
421
+ async def generate_tts(
422
+ tts_request: TTSRequest,
423
+ request: Request
424
+ ):
425
+ """
426
+ Generate TTS audio for lesson text
427
+ """
428
+ try:
429
+ from app.main import tts_service
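+ # imported inside the handler, presumably to avoid a circular import with app.main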
430
+
431
+ # Generate TTS audio
432
+ audio_data = await tts_service.generate_speech(
433
+ text=tts_request.text,
434
+ language_code=tts_request.language
435
+ )
436
+
437
+ if not audio_data:
438
+ raise HTTPException(status_code=500, detail="Failed to generate TTS audio")
439
+
440
+ # Return audio as WAV file
441
+ return Response(
442
+ content=audio_data,
443
+ media_type="audio/wav",
444
+ headers={
445
+ "Content-Disposition": f"inline; filename=tts_{tts_request.messageId or 'audio'}.wav"
446
+ }
447
+ )
448
+ except Exception as e:
449
+ logger.error(f"Error generating TTS: {e}")
450
+ raise HTTPException(status_code=500, detail=f"Failed to generate TTS: {str(e)}")
451
+
452
+
453
+ @router.post("/transcribe")
454
+ async def transcribe_audio(
455
+ request: Request,
456
+ audio: UploadFile = File(...)
457
+ ):
458
+ """
459
+ Transcribe audio for pronunciation practice
460
+ """
461
+ try:
462
+ from app.main import transcription_service
463
+
464
+ # Read audio file
465
+ audio_bytes = await audio.read()
466
+
467
+ # Get language from form data (default to Swahili)
468
+ form = await request.form()
469
+ language = form.get('language', 'swa')
470
+
471
+ # Transcribe
472
+ text = await transcription_service.transcribe_audio(
473
+ audio_data=audio_bytes,
474
+ language_code=language
475
+ )
476
+
477
+ return {
478
+ "success": True,
479
+ "text": text,
480
+ "language": language
481
+ }
482
+ except Exception as e:
483
+ logger.error(f"Error transcribing audio: {e}")
484
+ raise HTTPException(status_code=500, detail=f"Failed to transcribe: {str(e)}")
485
+
486
+
487
+ # ==================== Phase 1-3 Endpoints ====================
488
+
489
+ # Vocabulary Management
490
+
491
+ class VocabularyAddRequest(BaseModel):
492
+ vocab_id: int
493
+ source_lesson_id: Optional[int] = None
494
+
495
+
496
+ class VocabularyReviewRequest(BaseModel):
497
+ vocab_id: int
498
+ rating: str # 'again', 'hard', 'good', 'easy'
499
+
500
+
501
+ @router.get("/vocabulary/due")
502
+ async def get_due_vocabulary(
503
+ request: Request,
504
+ token: Optional[str] = Depends(optional_hf_token)
505
+ ):
506
+ """Get vocabulary words due for FSRS review"""
507
+ try:
508
+ user_id = token if token else 'anonymous'
509
+ progress = learning_service.get_user_progress(user_id)
510
+
511
+ if not progress:
512
+ return {"due_words": [], "total_due": 0}
513
+
514
+ vocab_progress = progress.get('vocabulary_progress', {})
515
+ now = datetime.utcnow()
516
+ due_words = []
517
+
518
+ for vocab_id, vocab_data in vocab_progress.items():
519
+ next_review_str = vocab_data.get('fsrs', {}).get('next_review')
520
+ if not next_review_str:
521
+ continue
522
+
523
+ next_review = datetime.fromisoformat(next_review_str.rstrip('Z'))
524
+
525
+ if next_review <= now:
526
+ hours_overdue = (now - next_review).total_seconds() / 3600
527
+ vocab_data['priority'] = 1000 - hours_overdue
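+ # higher priority = less overdue; the reverse sort below puts the least-overdue words first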
528
+ due_words.append(vocab_data)
529
+
530
+ due_words.sort(key=lambda x: x.get('priority', 0), reverse=True)
531
+
532
+ return {
533
+ "due_words": due_words,
534
+ "total_due": len(due_words),
535
+ "timestamp": now.isoformat() + 'Z'
536
+ }
537
+ except Exception as e:
538
+ logger.error(f"Error getting due vocabulary: {e}")
539
+ raise HTTPException(status_code=500, detail="Failed to get due vocabulary")
540
+
541
+
542
+ @router.post("/vocabulary/add")
543
+ async def add_vocabulary_to_practice(
544
+ vocab_request: VocabularyAddRequest,
545
+ request: Request,
546
+ token: Optional[str] = Depends(optional_hf_token)
547
+ ):
548
+ """Add a vocabulary word to user's practice queue with FSRS initialization"""
549
+ try:
550
+ user_id = token if token else 'anonymous'
551
+
552
+ vocab = learning_service.get_vocabulary(vocab_request.vocab_id)
553
+ if not vocab:
554
+ raise HTTPException(status_code=404, detail="Vocabulary not found")
555
+
556
+ fsrs_data = {
557
+ 'difficulty': 0.3,
558
+ 'stability': 2.5,
559
+ 'retrievability': 1.0,
560
+ 'review_count': 0,
561
+ 'last_review': None,
562
+ 'next_review': datetime.utcnow().isoformat() + 'Z',
563
+ 'lapses': 0,
564
+ 'state': 'new'
565
+ }
566
+
567
+ user_vocab = {
568
+ 'vocabulary_id': vocab_request.vocab_id,
569
+ 'swahili': vocab.get('swahili', ''),
570
+ 'english': vocab.get('english', ''),
571
+ 'part_of_speech': vocab.get('part_of_speech', 'unknown'),
572
+ 'added_at': datetime.utcnow().isoformat() + 'Z',
573
+ 'added_from': vocab_request.source_lesson_id,
574
+ 'fsrs': fsrs_data,
575
+ 'mastery_level': 0,
576
+ 'times_reviewed': 0,
577
+ 'times_correct': 0,
578
+ 'accuracy': 0.0
579
+ }
580
+
581
+ success = learning_service.update_vocabulary_progress(
582
+ user_id, str(vocab_request.vocab_id), user_vocab
583
+ )
584
+
585
+ if success:
586
+ return {"success": True, "vocabulary": user_vocab}
587
+ else:
588
+ raise HTTPException(status_code=500, detail="Failed to add vocabulary")
589
+ except HTTPException:
590
+ raise
591
+ except Exception as e:
592
+ logger.error(f"Error adding vocabulary: {e}")
593
+ raise HTTPException(status_code=500, detail="Failed to add vocabulary")
594
+
595
+
596
+ def calculate_next_review_fsrs(fsrs: Dict, grade: int) -> Dict:
597
+ """Implement FSRS algorithm"""
598
+ from datetime import timedelta
599
+
600
+ difficulty = fsrs['difficulty']
601
+ stability = fsrs['stability']
602
+
603
+ if grade == 0:
604
+ new_difficulty = min(difficulty + 0.2, 1.0)
605
+ elif grade == 2:
606
+ new_difficulty = min(difficulty + 0.1, 1.0)
607
+ elif grade == 4:
608
+ new_difficulty = max(difficulty - 0.1, 0.0)
609
+ else:
610
+ new_difficulty = difficulty
611
+
612
+ if grade == 0:
613
+ new_stability = stability * 0.5
614
+ state = 'relearning'
615
+ interval_minutes = 10
616
+ elif grade == 2:
617
+ new_stability = stability * 1.2
618
+ state = 'review'
619
+ interval_minutes = int(new_stability * 24 * 60)
620
+ elif grade == 3:
621
+ new_stability = stability * 2.5
622
+ state = 'review'
623
+ interval_minutes = int(new_stability * 24 * 60)
624
+ else:
625
+ new_stability = stability * 4.0
626
+ state = 'review'
627
+ interval_minutes = int(new_stability * 24 * 60)
628
+
629
+ next_review = datetime.utcnow() + timedelta(minutes=interval_minutes)
630
+
631
+ return {
632
+ 'difficulty': new_difficulty,
633
+ 'stability': new_stability,
634
+ 'retrievability': 0.9 if grade >= 2 else 0.0,
635
+ 'review_count': fsrs['review_count'] + 1,
636
+ 'last_review': datetime.utcnow().isoformat() + 'Z',
637
+ 'next_review': next_review.isoformat() + 'Z',
638
+ 'lapses': fsrs['lapses'],
639
+ 'state': state,
640
+ 'interval_days': interval_minutes / (24 * 60)
641
+ }
642
+
643
+
644
+ def calculate_mastery_level(vocab: Dict) -> int:
645
+ """Calculate mastery level (0-5)"""
646
+ accuracy = vocab['accuracy']
647
+ reviews = vocab['times_reviewed']
648
+ stability = vocab['fsrs']['stability']
649
+
650
+ if reviews == 0:
651
+ return 0
652
+ elif reviews >= 40 and accuracy >= 98 and stability >= 90:
653
+ return 5
654
+ elif reviews >= 20 and accuracy >= 95 and stability >= 30:
655
+ return 4
656
+ elif reviews < 5 or accuracy < 70:
657
+ return 1
658
+ elif reviews < 10 or accuracy < 85:
659
+ return 2
660
+ else:
661
+ return 3
664
+
665
+
666
+ @router.post("/vocabulary/review")
667
+ async def record_vocabulary_review_fsrs(
668
+ review_request: VocabularyReviewRequest,
669
+ request: Request,
670
+ token: Optional[str] = Depends(optional_hf_token)
671
+ ):
672
+ """Record vocabulary review and update FSRS parameters"""
673
+ try:
674
+ user_id = token if token else 'anonymous'
675
+ progress = learning_service.get_user_progress(user_id)
676
+
677
+ if not progress or str(review_request.vocab_id) not in progress.get('vocabulary_progress', {}):
678
+ raise HTTPException(status_code=404, detail="Vocabulary not in practice queue")
679
+
680
+ vocab = progress['vocabulary_progress'][str(review_request.vocab_id)]
681
+ fsrs = vocab['fsrs']
682
+
683
+ grade_map = {'again': 0, 'hard': 2, 'good': 3, 'easy': 4}
684
+ grade = grade_map.get(review_request.rating, 3)
685
+
686
+ new_fsrs = calculate_next_review_fsrs(fsrs, grade)
687
+
688
+ vocab['fsrs'] = new_fsrs
689
+ vocab['times_reviewed'] += 1
690
+ if grade >= 2:
691
+ vocab['times_correct'] += 1
692
+ else:
693
+ vocab['fsrs']['lapses'] += 1
694
+
695
+ vocab['accuracy'] = (vocab['times_correct'] / vocab['times_reviewed']) * 100 if vocab['times_reviewed'] > 0 else 0
696
+ vocab['mastery_level'] = calculate_mastery_level(vocab)
697
+ vocab['last_reviewed_at'] = datetime.utcnow().isoformat() + 'Z'
698
+
699
+ if 'vocabulary_reviewed' not in progress['overall_stats']:
700
+ progress['overall_stats']['vocabulary_reviewed'] = 0
701
+ progress['overall_stats']['vocabulary_reviewed'] += 1
702
+
703
+ learning_service.save_user_progress(user_id, progress)
704
+
705
+ return {
706
+ "success": True,
707
+ "vocabulary": vocab,
708
+ "next_review": new_fsrs['next_review'],
709
+ "interval_days": new_fsrs['interval_days']
710
+ }
711
+ except HTTPException:
712
+ raise
713
+ except Exception as e:
714
+ logger.error(f"Error recording vocabulary review: {e}")
715
+ raise HTTPException(status_code=500, detail="Failed to record review")
716
+
717
+
718
+ @router.get("/vocabulary/stats")
719
+ async def get_vocabulary_stats(
720
+ request: Request,
721
+ token: Optional[str] = Depends(optional_hf_token)
722
+ ):
723
+ """Get vocabulary mastery statistics"""
724
+ try:
725
+ user_id = token if token else 'anonymous'
726
+ progress = learning_service.get_user_progress(user_id)
727
+
728
+ if not progress:
729
+ return {
730
+ "total_words": 0,
731
+ "in_practice": 0,
732
+ "mastery_breakdown": {str(i): 0 for i in range(6)},
733
+ "average_accuracy": 0
734
+ }
735
+
736
+ vocab_progress = progress.get('vocabulary_progress', {})
737
+ mastery_breakdown = {str(i): 0 for i in range(6)}
738
+ total_accuracy = 0
739
+ total_with_reviews = 0
740
+
741
+ for vocab_data in vocab_progress.values():
742
+ level = vocab_data.get('mastery_level', 0)
743
+ mastery_breakdown[str(level)] += 1
744
+
745
+ if vocab_data.get('times_reviewed', 0) > 0:
746
+ total_accuracy += vocab_data.get('accuracy', 0)
747
+ total_with_reviews += 1
748
+
749
+ avg_accuracy = total_accuracy / total_with_reviews if total_with_reviews > 0 else 0
750
+
751
+ return {
752
+ "total_words": len(vocab_progress),
753
+ "in_practice": len(vocab_progress),
754
+ "mastery_breakdown": mastery_breakdown,
755
+ "average_accuracy": round(avg_accuracy, 1),
756
+ "total_reviews": sum(v.get('times_reviewed', 0) for v in vocab_progress.values())
757
+ }
758
+ except Exception as e:
759
+ logger.error(f"Error getting vocabulary stats: {e}")
760
+ raise HTTPException(status_code=500, detail="Failed to get stats")
761
+
762
+
763
+ @router.get("/vocabulary/library")
764
+ async def get_vocabulary_library(
765
+ lesson_id: Optional[int] = None,
766
+ level: Optional[str] = None,
767
+ search: Optional[str] = None,
768
+ request: Request = None,
769
+ token: Optional[str] = Depends(optional_hf_token)
770
+ ):
771
+ """Browse all vocabulary with filters"""
772
+ try:
773
+ user_id = token if token else 'anonymous'
774
+
775
+ all_vocab = learning_service.get_all_vocabulary()
776
+ progress = learning_service.get_user_progress(user_id)
777
+ user_vocab = progress.get('vocabulary_progress', {}) if progress else {}
778
+
779
+ filtered_vocab = all_vocab
780
+
781
+ if lesson_id:
782
+ filtered_vocab = [v for v in filtered_vocab if v.get('lesson_id') == lesson_id]
783
+
784
+ if level:
785
+ filtered_vocab = [v for v in filtered_vocab if v.get('level') == level]
786
+
787
+ if search:
788
+ search_lower = search.lower()
789
+ filtered_vocab = [v for v in filtered_vocab
790
+ if search_lower in v.get('swahili', '').lower()
791
+ or search_lower in v.get('english', '').lower()]
792
+
793
+ for vocab in filtered_vocab:
794
+ vocab_id = str(vocab.get('vocabulary_id') or vocab.get('id'))
795
+ if vocab_id in user_vocab:
796
+ vocab['status'] = 'practicing'
797
+ vocab['mastery_level'] = user_vocab[vocab_id].get('mastery_level', 0)
798
+ vocab['accuracy'] = user_vocab[vocab_id].get('accuracy', 0)
799
+ vocab['next_review'] = user_vocab[vocab_id].get('fsrs', {}).get('next_review')
800
+ else:
801
+ vocab['status'] = 'not_practicing'
802
+ vocab['mastery_level'] = 0
803
+
804
+ return {
805
+ "vocabulary": filtered_vocab,
806
+ "total": len(filtered_vocab),
807
+ "filters_applied": {
808
+ "lesson_id": lesson_id,
809
+ "level": level,
810
+ "search": search
811
+ }
812
+ }
813
+ except Exception as e:
814
+ logger.error(f"Error getting vocabulary library: {e}")
815
+ raise HTTPException(status_code=500, detail="Failed to get vocabulary")
816
+
817
+
818
+ # Reading Comprehension
819
+
820
+ class ComprehensionAnswer(BaseModel):
821
+ question_id: str
822
+ answer: str
823
+
824
+
825
+ class ComprehensionSubmission(BaseModel):
826
+ lesson_id: int
827
+ passage_id: str
828
+ answers: List[ComprehensionAnswer]
829
+
830
+
831
+ @router.post("/comprehension/submit")
832
+ async def submit_comprehension_answers(
833
+ submission: ComprehensionSubmission,
834
+ request: Request,
835
+ token: Optional[str] = Depends(optional_hf_token)
836
+ ):
837
+ """Submit reading comprehension answers and get scoring"""
838
+ try:
839
+ user_id = token if token else 'anonymous'
840
+
841
+ lesson = learning_service.get_lesson(submission.lesson_id)
842
+ if not lesson:
843
+ raise HTTPException(status_code=404, detail="Lesson not found")
844
+
845
+ passage = None
846
+ for p in lesson.get('reading_passages', []):
847
+ if p['passage_id'] == submission.passage_id:
848
+ passage = p
849
+ break
850
+
851
+ if not passage:
852
+ raise HTTPException(status_code=404, detail="Passage not found")
853
+
854
+ results = []
855
+ correct_count = 0
856
+
857
+ for submitted in submission.answers:
858
+ question_id = submitted.question_id
859
+ user_answer = submitted.answer.strip().lower()
860
+
861
+ question = None
862
+ for q in passage['comprehension_questions']:
863
+ if q['question_id'] == question_id:
864
+ question = q
865
+ break
866
+
867
+ if not question:
868
+ continue
869
+
870
+ correct_answers = [ans.strip().lower() for ans in question.get('correct_answers', [])]
871
+ is_correct = user_answer in correct_answers
872
+
873
+ if is_correct:
874
+ correct_count += 1
875
+
876
+ results.append({
877
+ "question_id": question_id,
878
+ "correct": is_correct,
879
+ "user_answer": user_answer,
880
+ "correct_answer": question['correct_answers'][0] if correct_answers else None,
881
+ "explanation": question.get('explanation')
882
+ })
883
+
884
+ score = (correct_count / len(submission.answers)) * 100 if submission.answers else 0
885
+
886
+ progress = learning_service.get_user_progress(user_id)
887
+ if not progress:
888
+ progress = learning_service.create_default_progress(user_id)
889
+
890
+ if 'comprehension_scores' not in progress:
891
+ progress['comprehension_scores'] = {}
892
+
893
+ progress['comprehension_scores'][f"{submission.lesson_id}_{submission.passage_id}"] = {
894
+ "score": score,
895
+ "completed_at": datetime.utcnow().isoformat() + 'Z',
896
+ "attempts": progress['comprehension_scores'].get(f"{submission.lesson_id}_{submission.passage_id}", {}).get('attempts', 0) + 1
897
+ }
898
+
899
+ learning_service.save_user_progress(user_id, progress)
900
+
901
+ return {
902
+ "results": results,
903
+ "score": round(score, 1),
904
+ "correct": correct_count,
905
+ "total": len(submission.answers)
906
+ }
907
+ except HTTPException:
908
+ raise
909
+ except Exception as e:
910
+ logger.error(f"Error submitting comprehension: {e}")
911
+ raise HTTPException(status_code=500, detail="Failed to submit comprehension")
912
+
913
+
914
+ # Task Scenarios
915
+
916
+ class ScenarioProgressUpdate(BaseModel):
917
+ turn_id: str
918
+ choice_id: str
919
+
920
+
921
+ @router.get("/scenarios/{scenario_id}")
922
+ async def get_scenario(
923
+ scenario_id: str,
924
+ request: Request,
925
+ token: Optional[str] = Depends(optional_hf_token)
926
+ ):
927
+ """Get task scenario with branching dialogue"""
928
+ try:
929
+ user_id = token if token else 'anonymous'
930
+
931
+ scenario = learning_service.get_scenario(scenario_id)
932
+ if not scenario:
933
+ raise HTTPException(status_code=404, detail="Scenario not found")
934
+
935
+ progress = learning_service.get_user_progress(user_id)
936
+ scenario_progress = None
937
+
938
+ if progress and 'scenario_progress' in progress:
939
+ scenario_progress = progress['scenario_progress'].get(scenario_id)
940
+
941
+ return {
942
+ "scenario": scenario,
943
+ "user_progress": scenario_progress
944
+ }
945
+ except HTTPException:
946
+ raise
947
+ except Exception as e:
948
+ logger.error(f"Error getting scenario: {e}")
949
+ raise HTTPException(status_code=500, detail="Failed to get scenario")
950
+
951
+
952
+ @router.post("/scenarios/{scenario_id}/progress")
953
+ async def update_scenario_progress(
954
+ scenario_id: str,
955
+ progress_update: ScenarioProgressUpdate,
956
+ request: Request,
957
+ token: Optional[str] = Depends(optional_hf_token)
958
+ ):
959
+ """Update scenario progress with user choice"""
960
+ try:
961
+ user_id = token if token else 'anonymous'
962
+
963
+ scenario = learning_service.get_scenario(scenario_id)
964
+ if not scenario:
965
+ raise HTTPException(status_code=404, detail="Scenario not found")
966
+
967
+ progress = learning_service.get_user_progress(user_id)
968
+ if not progress:
969
+ progress = learning_service.create_default_progress(user_id)
970
+
971
+ if 'scenario_progress' not in progress:
972
+ progress['scenario_progress'] = {}
973
+
974
+ if scenario_id not in progress['scenario_progress']:
975
+ progress['scenario_progress'][scenario_id] = {
976
+ "started_at": datetime.utcnow().isoformat() + 'Z',
977
+ "turns": [],
978
+ "completed": False
979
+ }
980
+
981
+ progress['scenario_progress'][scenario_id]['turns'].append({
982
+ "turn_id": progress_update.turn_id,
983
+ "choice_id": progress_update.choice_id,
984
+ "timestamp": datetime.utcnow().isoformat() + 'Z'
985
+ })
986
+
987
+ turns_count = len(progress['scenario_progress'][scenario_id]['turns'])
988
+ if turns_count >= scenario.get('required_turns', 6):
989
+ progress['scenario_progress'][scenario_id]['completed'] = True
990
+ progress['scenario_progress'][scenario_id]['completed_at'] = datetime.utcnow().isoformat() + 'Z'
991
+
992
+ learning_service.save_user_progress(user_id, progress)
993
+
994
+ return {
995
+ "success": True,
996
+ "progress": progress['scenario_progress'][scenario_id]
997
+ }
998
+ except HTTPException:
999
+ raise
1000
+ except Exception as e:
1001
+ logger.error(f"Error updating scenario progress: {e}")
1002
+ raise HTTPException(status_code=500, detail="Failed to update scenario")
1003
+
1004
+
1005
+ @router.get("/scenarios")
1006
+ async def list_scenarios(
1007
+ request: Request,
1008
+ token: Optional[str] = Depends(optional_hf_token)
1009
+ ):
1010
+ """Get list of all available scenarios"""
1011
+ try:
1012
+ scenarios = learning_service.get_all_scenarios()
1013
+ return {
1014
+ "success": True,
1015
+ "scenarios": scenarios,
1016
+ "total": len(scenarios)
1017
+ }
1018
+ except Exception as e:
1019
+ logger.error(f"Error listing scenarios: {e}")
1020
+ raise HTTPException(status_code=500, detail="Failed to list scenarios")
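Taken together, the vocabulary endpoints above form the practice loop: add a word, record ratings as the user drills it, and poll for due items. A hypothetical client session (the `requests` dependency and the localhost URL are assumptions; paths and payload fields come from the router code):

```python
import requests  # assumed client-side dependency

BASE = "http://localhost:7860/api/learning"

# Put word 42 into the practice queue with fresh FSRS state.
requests.post(f"{BASE}/vocabulary/add", json={"vocab_id": 42, "source_lesson_id": 1})

# Record a review; the response carries the next due date.
r = requests.post(f"{BASE}/vocabulary/review", json={"vocab_id": 42, "rating": "good"})
print(r.json()["next_review"])  # ISO-8601 timestamp from the FSRS update

# Fetch everything whose next_review is already in the past.
due = requests.get(f"{BASE}/vocabulary/due").json()
print(due["total_due"])
```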
app/routers/mobile.py ADDED
@@ -0,0 +1,536 @@
1
+ from fastapi import APIRouter, HTTPException, File, UploadFile, Form, Query, Depends
2
+ from fastapi.responses import Response
3
+ from typing import Optional
4
+ from pydantic import BaseModel
5
+ import base64
6
+ import json
7
+ import uuid
8
+ import datetime
9
+ from app.services.translation_service import TranslationService
10
+ from app.services.tts_service import TTSService
11
+ from app.services.transcription_service import TranscriptionService
12
+ from app.auth import require_hf_token
13
+
14
+ router = APIRouter()
15
+
16
+ # Service instances - these will be injected by main app
17
+ translation_service = None
18
+ tts_service = None
19
+ transcription_service = None
20
+
21
+ # Mobile-specific data models
22
+ class MobileSessionRequest(BaseModel):
23
+ user_name: str
24
+ default_source_lang: str = "eng"
25
+ default_target_lang: str = "swa"
26
+
27
+ class MobileSessionResponse(BaseModel):
28
+ session_id: str
29
+ participant_id: str
30
+ user_name: str
31
+ source_language: str
32
+ target_language: str
33
+
34
+ class MobileTranscribeRequest(BaseModel):
35
+ participant_id: str
36
+ source_language: str
37
+ target_language: str
38
+ is_final_chunk: bool = False
39
+
40
+ class MobileLanguageUpdateRequest(BaseModel):
41
+ participant_id: str
42
+ source_language: str
43
+ target_language: str
44
+
45
+ # In-memory session storage (in production, use Redis or database)
46
+ mobile_sessions = {}
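The dict above keeps sessions in process memory, which the inline comment already flags as non-production: each replica would hold its own sessions and a restart drops them all. A hedged sketch of the Redis-backed variant the comment suggests, not part of this commit (assumes the `redis` package; key names and TTL are illustrative):

```python
# Illustrative only: a Redis store for mobile sessions.
import json

import redis

r = redis.Redis(host="localhost", port=6379, decode_responses=True)

def save_mobile_session(session: dict) -> None:
    # Expire idle sessions after 24h so the store cannot grow without bound.
    r.setex(f"mobile_session:{session['session_id']}", 86400, json.dumps(session))

def load_mobile_session(session_id: str) -> dict | None:
    raw = r.get(f"mobile_session:{session_id}")
    return json.loads(raw) if raw else None
```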
47
+
48
+ @router.post("/mobile/session/create", response_model=MobileSessionResponse)
49
+ async def create_mobile_session(
50
+ user_name: str = Form(...),
51
+ default_source_lang: str = Form("eng"),
52
+ default_target_lang: str = Form("swa"),
53
+ token: str = Depends(require_hf_token)
54
+ ):
55
+ """Create a mobile-specific single-user session"""
56
+ try:
57
+ print(f"=== MOBILE SESSION CREATE REQUEST ===")
58
+ print(f"User name: {user_name}")
59
+ print(f"Source language: {default_source_lang}")
60
+ print(f"Target language: {default_target_lang}")
61
+
62
+ # Validate inputs
63
+ if not user_name or user_name.strip() == "":
64
+ raise HTTPException(status_code=400, detail="User name is required")
65
+
66
+ # Validate language codes
67
+ valid_languages = ["eng", "swa", "kik", "kam", "mer", "luo", "som"]
68
+ if default_source_lang not in valid_languages:
69
+ print(f"Invalid source language: {default_source_lang}, defaulting to 'eng'")
70
+ default_source_lang = "eng"
71
+ if default_target_lang not in valid_languages:
72
+ print(f"Invalid target language: {default_target_lang}, defaulting to 'swa'")
73
+ default_target_lang = "swa"
74
+
75
+ session_id = f"mobile-{uuid.uuid4().hex[:8]}"
76
+ participant_id = f"user-{uuid.uuid4().hex[:8]}"
77
+
78
+ # Store session data
79
+ mobile_sessions[session_id] = {
80
+ "session_id": session_id,
81
+ "participant_id": participant_id,
82
+ "user_name": user_name.strip(),
83
+ "source_language": default_source_lang,
84
+ "target_language": default_target_lang,
85
+ "created_at": datetime.datetime.now().isoformat()
86
+ }
87
+
88
+ print(f"Created session: {session_id} for user: {user_name}")
89
+ print(f"Total sessions: {len(mobile_sessions)}")
90
+
91
+ response = MobileSessionResponse(
92
+ session_id=session_id,
93
+ participant_id=participant_id,
94
+ user_name=user_name.strip(),
95
+ source_language=default_source_lang,
96
+ target_language=default_target_lang
97
+ )
98
+
99
+ print(f"Returning response: {response}")
100
+ return response
101
+
102
+ except HTTPException:
103
+ raise
104
+ except Exception as e:
105
+ print(f"ERROR creating mobile session: {e}")
106
+ import traceback
107
+ traceback.print_exc()
108
+ raise HTTPException(status_code=500, detail=f"Failed to create mobile session: {str(e)}")
109
+
110
+ @router.get("/mobile/session/{session_id}")
111
+ async def get_mobile_session(session_id: str, token: str = Depends(require_hf_token)):
112
+ """Get mobile session details"""
113
+ if session_id not in mobile_sessions:
114
+ raise HTTPException(status_code=404, detail="Session not found")
115
+
116
+ return mobile_sessions[session_id]
117
+
118
+ @router.put("/mobile/session/{session_id}/languages")
119
+ async def update_session_languages(
120
+ session_id: str,
121
+ participant_id: str = Form(...),
122
+ source_language: str = Form(...),
123
+ target_language: str = Form(...)
124
+ ):
125
+ """Update the default languages for a mobile session"""
126
+ try:
127
+ if session_id not in mobile_sessions:
128
+ raise HTTPException(status_code=404, detail="Session not found")
129
+
130
+ session = mobile_sessions[session_id]
131
+ if session["participant_id"] != participant_id:
132
+ raise HTTPException(status_code=403, detail="Invalid participant")
133
+
134
+ # Update session languages
135
+ session["source_language"] = source_language
136
+ session["target_language"] = target_language
137
+
138
+ return {
139
+ "success": True,
140
+ "session_id": session_id,
141
+ "source_language": source_language,
142
+ "target_language": target_language
143
+ }
144
+
145
+ except HTTPException:
+ raise
+ except Exception as e:
146
+ raise HTTPException(status_code=500, detail=f"Failed to update languages: {str(e)}")
147
+
148
+ @router.post("/mobile/session/{session_id}/transcribe-realtime")
149
+ async def transcribe_realtime(
150
+ session_id: str,
151
+ audio: UploadFile = File(...),
152
+ participant_id: str = Form(...),
153
+ source_language: str = Form(...),
154
+ target_language: str = Form(...),
155
+ is_final_chunk: bool = Form(False),
156
+ chunk_sequence: int = Form(0)
157
+ ):
158
+ """Real-time transcription endpoint for mobile with streaming support"""
159
+ try:
160
+ if session_id not in mobile_sessions:
161
+ raise HTTPException(status_code=404, detail="Session not found")
162
+
163
+ session = mobile_sessions[session_id]
164
+ if session["participant_id"] != participant_id:
165
+ raise HTTPException(status_code=403, detail="Invalid participant")
166
+
167
+ # Read audio file
168
+ audio_data = await audio.read()
169
+
170
+ # Generate unique message ID for this chunk sequence
171
+ message_id = f"msg-{participant_id}-{chunk_sequence}"
172
+
173
+ # Initialize response data
174
+ response_data = {
175
+ "success": True,
176
+ "message_id": message_id,
177
+ "chunk_sequence": chunk_sequence,
178
+ "original_text": "",
179
+ "original_language": source_language,
180
+ "is_final_chunk": is_final_chunk,
181
+ "is_interim": not is_final_chunk,
182
+ "session_id": session_id,
183
+ "translated_text": None,
184
+ "target_language": target_language,
185
+ "has_audio": False,
186
+ "audio_base64": None,
187
+ "audio_format": None
188
+ }
189
+
190
+ # Process transcription
191
+ if transcription_service:
192
+ try:
193
+ # Use streaming transcription if available
194
+ if hasattr(transcription_service, 'process_realtime_chunk'):
195
+ transcription_result = await transcription_service.process_realtime_chunk(
196
+ audio_data, source_language, participant_id, is_final_chunk
197
+ )
198
+ else:
199
+ # Fallback to regular transcription
200
+ transcription_result = await transcription_service.transcribe_audio(
201
+ audio_data, source_language
202
+ )
203
+
204
+ response_data["original_text"] = transcription_result or ""
205
+
206
+ # Only process translation and TTS for final chunks with actual text
207
+ if is_final_chunk and transcription_result and transcription_result.strip():
208
+ if translation_service:
209
+ try:
210
+ translated_text = await translation_service.translate_text(
211
+ transcription_result, source_language, target_language
212
+ )
213
+ response_data["translated_text"] = translated_text
214
+
215
+ # Generate TTS audio in target language
216
+ if tts_service and translated_text:
217
+ try:
218
+ tts_audio = await tts_service.generate_speech(
219
+ translated_text, target_language, output_format="wav"
220
+ )
221
+
222
+ if tts_audio:
223
+ response_data.update({
224
+ "has_audio": True,
225
+ "audio_base64": base64.b64encode(tts_audio).decode('utf-8'),
226
+ "audio_format": "wav"
227
+ })
228
+ except Exception as tts_error:
229
+ print(f"TTS generation failed: {tts_error}")
230
+ # Continue without TTS
231
+
232
+ except Exception as translation_error:
233
+ print(f"Translation failed: {translation_error}")
234
+ # Continue without translation
235
+
236
+ except Exception as transcription_error:
237
+ print(f"Transcription failed: {transcription_error}")
238
+ response_data["original_text"] = ""
239
+
240
+ return response_data
241
+
242
+ else:
243
+ raise HTTPException(status_code=500, detail="Transcription service not available")
244
+
245
+ except HTTPException:
+ raise
+ except Exception as e:
246
+ raise HTTPException(status_code=500, detail=f"Real-time transcription failed: {str(e)}")
247
+
248
+ @router.post("/mobile/session/{session_id}/stream-audio")
249
+ async def stream_audio_chunk(
250
+ session_id: str,
251
+ participant_id: str = Form(...),
252
+ audio_chunk: UploadFile = File(...),
253
+ source_language: str = Form(...),
254
+ target_language: str = Form(...),
255
+ chunk_index: int = Form(0),
256
+ is_speaking: bool = Form(True),
257
+ force_complete: bool = Form(False)
258
+ ):
259
+ """Stream audio chunks for continuous processing"""
260
+ try:
261
+ if session_id not in mobile_sessions:
262
+ raise HTTPException(status_code=404, detail="Session not found")
263
+
264
+ session = mobile_sessions[session_id]
265
+ if session["participant_id"] != participant_id:
266
+ raise HTTPException(status_code=403, detail="Invalid participant")
267
+
268
+ audio_data = await audio_chunk.read()
269
+
270
+ # Use streaming approach similar to WebSocket
271
+ interim_text = ""
272
+ if transcription_service:
273
+ try:
274
+ if hasattr(transcription_service, 'process_audio_chunk'):
275
+ result = await transcription_service.process_audio_chunk(
276
+ audio_data,
277
+ source_language,
278
+ participant_id,
279
+ has_voice_activity=is_speaking,
280
+ progress_callback=None, # No callback for HTTP
281
+ sentence_callback=None # No callback for HTTP
282
+ )
283
+ interim_text = result or ""
284
+ else:
285
+ # Fallback to regular transcription for interim results
286
+ interim_text = await transcription_service.transcribe_audio(
287
+ audio_data, source_language
288
+ ) or ""
289
+ except Exception as e:
290
+ print(f"Streaming transcription error: {e}")
291
+ interim_text = ""
292
+
293
+ return {
294
+ "success": True,
295
+ "chunk_index": chunk_index,
296
+ "session_id": session_id,
297
+ "interim_text": interim_text,
298
+ "is_speaking": is_speaking,
299
+ "force_complete": force_complete
300
+ }
301
+ else:
302
+ raise HTTPException(status_code=500, detail="Transcription service not available")
303
+
304
+ except HTTPException:
+ raise
+ except Exception as e:
305
+ raise HTTPException(status_code=500, detail=f"Audio streaming failed: {str(e)}")
306
+
307
+ @router.get("/mobile/session/{session_id}/realtime-status")
308
+ async def get_realtime_status(session_id: str, participant_id: str = Query(...)):
309
+ """Get current real-time processing status"""
310
+ try:
311
+ if session_id not in mobile_sessions:
312
+ raise HTTPException(status_code=404, detail="Session not found")
313
+
314
+ session = mobile_sessions[session_id]
315
+ if session["participant_id"] != participant_id:
316
+ raise HTTPException(status_code=403, detail="Invalid participant")
317
+
318
+ # Check if transcription service has any pending messages
319
+ pending_messages = []
320
+ if transcription_service:
321
+ try:
322
+ if hasattr(transcription_service, 'get_participant_status'):
323
+ pending_messages = transcription_service.get_participant_status(participant_id)
324
+ else:
325
+ pending_messages = []
326
+ except Exception as e:
327
+ print(f"Error getting participant status: {e}")
328
+ pending_messages = []
329
+
330
+ return {
331
+ "session_id": session_id,
332
+ "participant_id": participant_id,
333
+ "is_active": True,
334
+ "pending_messages": pending_messages,
335
+ "current_languages": {
336
+ "source": session["source_language"],
337
+ "target": session["target_language"]
338
+ },
339
+ "service_status": {
340
+ "transcription": transcription_service is not None,
341
+ "translation": translation_service is not None,
342
+ "tts": tts_service is not None
343
+ }
344
+ }
345
+
346
+ except HTTPException:
+ raise
+ except Exception as e:
347
+ raise HTTPException(status_code=500, detail=f"Status check failed: {str(e)}")
348
+
349
+ @router.post("/mobile/session/{session_id}/transcribe-with-languages")
350
+ async def transcribe_with_languages_legacy(
351
+ session_id: str,
352
+ audio: UploadFile = File(...),
353
+ participant_id: str = Form(...),
354
+ source_language: str = Form(...),
355
+ target_language: str = Form(...),
356
+ is_final_chunk: bool = Form(False)
357
+ ):
358
+ """Legacy endpoint - transcribe audio with specific source/target languages for mobile"""
359
+ try:
360
+ if session_id not in mobile_sessions:
361
+ raise HTTPException(status_code=404, detail="Session not found")
362
+
363
+ session = mobile_sessions[session_id]
364
+ if session["participant_id"] != participant_id:
365
+ raise HTTPException(status_code=403, detail="Invalid participant")
366
+
367
+ # Read audio file
368
+ audio_data = await audio.read()
369
+
370
+ # Generate unique message ID
371
+ message_id = f"msg-{uuid.uuid4().hex[:8]}"
372
+
373
+ # Initialize response
374
+ response_data = {
375
+ "success": True,
376
+ "message_id": message_id,
377
+ "original_text": "",
378
+ "original_language": source_language,
379
+ "translated_text": None,
380
+ "target_language": target_language,
381
+ "has_audio": False,
382
+ "is_final_chunk": is_final_chunk,
383
+ "audio_base64": None
384
+ }
385
+
386
+ # Process transcription in source language
387
+ if transcription_service:
388
+ try:
389
+ transcription_result = await transcription_service.transcribe_audio(
390
+ audio_data, source_language
391
+ )
392
+ response_data["original_text"] = transcription_result or ""
393
+
394
+ # Process translation to target language
395
+ if translation_service and transcription_result and transcription_result.strip():
396
+ try:
397
+ translated_text = await translation_service.translate_text(
398
+ transcription_result, source_language, target_language
399
+ )
400
+ response_data["translated_text"] = translated_text
401
+
402
+ # Generate TTS audio in target language
403
+ if tts_service and translated_text:
404
+ try:
405
+ tts_audio = await tts_service.generate_speech(
406
+ translated_text, target_language, output_format="wav"
407
+ )
408
+
409
+ if tts_audio:
410
+ response_data["has_audio"] = True
411
+ response_data["audio_base64"] = base64.b64encode(tts_audio).decode('utf-8')
412
+ except Exception as tts_error:
413
+ print(f"TTS generation failed: {tts_error}")
414
+
415
+ except Exception as translation_error:
416
+ print(f"Translation failed: {translation_error}")
417
+
418
+ except Exception as transcription_error:
419
+ print(f"Transcription failed: {transcription_error}")
420
+
421
+ return response_data
422
+ else:
423
+ raise HTTPException(status_code=500, detail="Transcription service not available")
424
+
425
+ except HTTPException:
+ raise
+ except Exception as e:
426
+ raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
427
+
428
+ @router.post("/mobile/translate")
429
+ async def translate_text_mobile(
430
+ text: str = Form(...),
431
+ source_lang: str = Form(...),
432
+ target_lang: str = Form(...)
433
+ ):
434
+ """Mobile-friendly text translation endpoint"""
435
+ try:
436
+ if not translation_service:
437
+ raise HTTPException(status_code=500, detail="Translation service not initialized")
438
+
439
+ # Map common language codes to internal format
440
+ lang_mapping = {
441
+ "english": "eng", "en": "eng",
442
+ "swahili": "swa", "sw": "swa",
443
+ "kikuyu": "kik", "ki": "kik",
444
+ "kamba": "kam", "kam": "kam",
445
+ "kimeru": "mer", "mer": "mer",
446
+ "luo": "luo", "luo": "luo",
447
+ "somali": "som", "so": "som"
448
+ }
449
+
450
+ source_code = lang_mapping.get(source_lang.lower(), source_lang.lower())
451
+ target_code = lang_mapping.get(target_lang.lower(), target_lang.lower())
452
+
453
+ translated_text = await translation_service.translate_text(text, source_code, target_code)
454
+
455
+ return {
456
+ "success": True,
457
+ "original_text": text,
458
+ "translated_text": translated_text or text,
459
+ "source_language": source_code,
460
+ "target_language": target_code
461
+ }
462
+
463
+ except HTTPException:
+ raise
+ except Exception as e:
464
+ raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}")
465
+
466
+ @router.get("/mobile/languages")
467
+ async def get_supported_languages():
468
+ """Get list of supported languages for mobile app"""
469
+ return {
470
+ "supported_languages": [
471
+ {"code": "eng", "name": "English", "display_name": "English (eng)"},
472
+ {"code": "swa", "name": "Swahili", "display_name": "Swahili (swa)"},
473
+ {"code": "kik", "name": "Kikuyu", "display_name": "Kikuyu (kik)"},
474
+ {"code": "kam", "name": "Kamba", "display_name": "Kamba (kam)"},
475
+ {"code": "mer", "name": "Kimeru", "display_name": "Kimeru (mer)"},
476
+ {"code": "luo", "name": "Luo", "display_name": "Luo (luo)"},
477
+ {"code": "som", "name": "Somali", "display_name": "Somali (som)"}
478
+ ]
479
+ }
480
+
481
+ @router.get("/mobile/test")
482
+ async def test_mobile_endpoints():
483
+ """Test endpoint for mobile app connectivity"""
484
+ return {
485
+ "status": "Mobile API is working",
486
+ "endpoints": [
487
+ "/mobile/session/create",
488
+ "/mobile/session/{session_id}",
489
+ "/mobile/session/{session_id}/languages",
490
+ "/mobile/session/{session_id}/transcribe-realtime",
491
+ "/mobile/session/{session_id}/stream-audio",
492
+ "/mobile/session/{session_id}/realtime-status",
493
+ "/mobile/session/{session_id}/transcribe-with-languages",
494
+ "/mobile/translate",
495
+ "/mobile/languages",
496
+ "/mobile/test"
497
+ ],
498
+ "timestamp": datetime.datetime.now().isoformat(),
499
+ "services_available": {
500
+ "transcription": transcription_service is not None,
501
+ "translation": translation_service is not None,
502
+ "tts": tts_service is not None
503
+ },
504
+ "active_sessions": len(mobile_sessions),
505
+ "session_list": list(mobile_sessions.keys())
506
+ }
507
+
508
+ @router.post("/mobile/test-session")
509
+ async def test_session_creation(
510
+ test_user: str = Form("TestUser"),
511
+ test_source: str = Form("eng"),
512
+ test_target: str = Form("swa")
513
+ ):
514
+ """Test session creation with debug info"""
515
+ try:
516
+ print(f"=== TEST SESSION CREATE ===")
517
+ print(f"Received: user={test_user}, source={test_source}, target={test_target}")
518
+
519
+ session_id = f"test-{uuid.uuid4().hex[:8]}"
520
+
521
+ return {
522
+ "success": True,
523
+ "test_session_id": session_id,
524
+ "received_params": {
525
+ "user": test_user,
526
+ "source": test_source,
527
+ "target": test_target
528
+ },
529
+ "form_processing": "OK"
530
+ }
531
+ except Exception as e:
532
+ print(f"Test session error: {e}")
533
+ return {
534
+ "success": False,
535
+ "error": str(e)
536
+ }
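
A minimal client sketch for the mobile endpoints above, for orientation only. It assumes the Space is reachable at http://localhost:7860 and that require_hf_token (defined in app.auth, which is not part of this excerpt) accepts a standard Bearer token header; both the base URL and the header format are assumptions.

import requests

BASE_URL = "http://localhost:7860"            # assumed deployment URL
HEADERS = {"Authorization": "Bearer hf_xxx"}  # hypothetical auth header

# Create a single-user mobile session; the endpoint takes form fields.
resp = requests.post(
    f"{BASE_URL}/mobile/session/create",
    data={"user_name": "Asha",
          "default_source_lang": "eng",
          "default_target_lang": "swa"},
    headers=HEADERS,
)
session = resp.json()
print(session["session_id"], session["participant_id"])

# Translate text; the endpoint maps full names ("english") or short codes ("sw").
resp = requests.post(
    f"{BASE_URL}/mobile/translate",
    data={"text": "Good morning", "source_lang": "english", "target_lang": "sw"},
)
print(resp.json()["translated_text"])
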
app/routers/sessions.py ADDED
@@ -0,0 +1,200 @@
1
+ from fastapi import APIRouter, HTTPException, Response, Depends
2
+ from typing import List
3
+ from pydantic import BaseModel
4
+ import qrcode
5
+ import io
6
+ from app.models import Session, SessionCreate
7
+ from app.auth import require_hf_token, optional_hf_token
8
+
9
+ router = APIRouter()
10
+
11
+ # This will be set by the main app
12
+ session_manager = None
13
+
14
+ # Initialize services (these will be injected by main app)
15
+ transcription_service = None
16
+ translation_service = None
17
+ tts_service = None
18
+
19
+ class TextTranslationRequest(BaseModel):
20
+ text: str
21
+ source_language: str
22
+ target_language: str
23
+
24
+ class TextTranslationResponse(BaseModel):
25
+ original_text: str
26
+ translated_text: str
27
+ source_language: str
28
+ target_language: str
29
+
30
+ @router.post("/sessions", response_model=Session)
31
+ async def create_session(session_data: SessionCreate, token: str = Depends(require_hf_token)):
32
+ """Create a new transcription session"""
33
+ try:
34
+ session = await session_manager.create_session(session_data)
35
+ return session
36
+ except Exception as e:
37
+ raise HTTPException(status_code=500, detail=str(e))
38
+
39
+ @router.get("/sessions", response_model=List[Session])
40
+ async def get_all_sessions(token: str = Depends(require_hf_token)):
41
+ """Get all active sessions"""
42
+ try:
43
+ sessions = await session_manager.get_all_sessions()
44
+ return sessions
45
+ except Exception as e:
46
+ raise HTTPException(status_code=500, detail=str(e))
47
+
48
+ @router.get("/sessions/{session_id}", response_model=Session)
49
+ async def get_session(session_id: str, token: str = Depends(require_hf_token)):
50
+ """Get specific session by ID or short code"""
51
+ session = await session_manager.get_session(session_id)
52
+ if not session:
53
+ raise HTTPException(status_code=404, detail="Session not found")
54
+ return session
55
+
56
+ @router.get("/sessions/{session_id}/short-code")
57
+ async def get_session_short_code(session_id: str, token: str = Depends(require_hf_token)):
58
+ """Get short code for a session"""
59
+ session = await session_manager.get_session(session_id)
60
+ if not session:
61
+ raise HTTPException(status_code=404, detail="Session not found")
62
+
63
+ short_code = session_manager.get_short_code(session.id)
64
+ return {"session_id": session.id, "short_code": short_code}
65
+
66
+ @router.delete("/sessions/{session_id}")
67
+ async def delete_session(session_id: str, token: str = Depends(require_hf_token)):
68
+ """Delete a session"""
69
+ success = await session_manager.delete_session(session_id)
70
+ if not success:
71
+ raise HTTPException(status_code=404, detail="Session not found")
72
+ return {"message": "Session deleted successfully"}
73
+
74
+ @router.post("/sessions/{session_id}/languages/{language_code}")
75
+ async def add_language_to_session(session_id: str, language_code: str, token: str = Depends(require_hf_token)):
76
+ """Add a language to a session"""
77
+ from app.models import LanguageCode
78
+
79
+ # Convert string to LanguageCode enum
80
+ try:
81
+ lang_code_enum = LanguageCode(language_code)
82
+ except ValueError:
83
+ raise HTTPException(status_code=400, detail=f"Invalid language code: {language_code}")
84
+
85
+ success = await session_manager.add_language_to_session(session_id, lang_code_enum)
86
+ if success:
87
+ session = await session_manager.get_session(session_id)
88
+ return {"message": f"Language {language_code} added to session", "session": session}
89
+ else:
90
+ # Check if session exists
91
+ session = await session_manager.get_session(session_id)
92
+ if not session:
93
+ raise HTTPException(status_code=404, detail="Session not found")
94
+ return {"message": f"Language {language_code} already exists in session", "session": session}
95
+
96
+ @router.post("/translate", response_model=TextTranslationResponse)
97
+ async def translate_text(request: TextTranslationRequest, token: str = Depends(require_hf_token)):
98
+ """Translate text from source language to target language"""
99
+ try:
100
+ # Map language codes to proper names
101
+ lang_map = {
102
+ 'eng': 'English',
103
+ 'swa': 'Swahili',
104
+ 'kik': 'Kikuyu',
105
+ 'kam': 'Kamba',
106
+ 'mer': 'Kimeru',
107
+ 'luo': 'Luo',
108
+ 'som': 'Somali'
109
+ }
110
+
111
+ source_lang_name = lang_map.get(request.source_language.lower(), request.source_language)
112
+ target_lang_name = lang_map.get(request.target_language.lower(), request.target_language)
113
+
114
+ # Perform translation
115
+ translated_text = await translation_service.translate_text(
116
+ text=request.text,
117
+ source_lang=source_lang_name,
118
+ target_lang=target_lang_name
119
+ )
120
+
121
+ return TextTranslationResponse(
122
+ original_text=request.text,
123
+ translated_text=translated_text,
124
+ source_language=request.source_language,
125
+ target_language=request.target_language
126
+ )
127
+
128
+ except Exception as e:
129
+ raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}")
130
+
131
+ @router.get("/test")
132
+ async def test_endpoint(token: str = Depends(optional_hf_token)):
133
+ """Test endpoint to verify API is working"""
134
+ auth_status = "authenticated" if token else "public"
135
+ return {
136
+ "status": "API is working",
137
+ "sessions_count": len(session_manager.sessions),
138
+ "auth_status": auth_status
139
+ }
140
+
141
+ @router.get("/test/translation")
142
+ async def test_translation(token: str = Depends(require_hf_token)):
143
+ """Test translation service directly"""
144
+ try:
145
+ # Test English to Swahili translation
146
+ result = await translation_service.translate_text("Hello, how are you?", "English", "Swahili")
147
+
148
+ return {
149
+ "status": "Translation test completed",
150
+ "original": "Hello, how are you?",
151
+ "translated": result,
152
+ "source_lang": "English",
153
+ "target_lang": "Swahili"
154
+ }
155
+ except Exception as e:
156
+ return {"status": "Translation test failed", "error": str(e)}
157
+
158
+ @router.get("/test/tts")
159
+ async def test_tts(token: str = Depends(require_hf_token)):
160
+ """Test TTS service directly"""
161
+ try:
162
+ # Test TTS generation
163
+ audio_data = await tts_service.generate_speech("Hello world", "eng")
164
+
165
+ return {
166
+ "status": "TTS test completed",
167
+ "text": "Hello world",
168
+ "language": "eng",
169
+ "audio_generated": audio_data is not None,
170
+ "audio_size": len(audio_data) if audio_data else 0
171
+ }
172
+ except Exception as e:
173
+ return {"status": "TTS test failed", "error": str(e)}
174
+
175
+ @router.get("/sessions/{session_id}/qr-code")
176
+ async def get_session_qr_code(session_id: str, token: str = Depends(require_hf_token)):
177
+ """Generate QR code for session"""
178
+ if session_manager is None:
179
+ raise HTTPException(status_code=500, detail="Session manager not initialized")
180
+
181
+ session = await session_manager.get_session(session_id)
182
+
183
+ if not session:
184
+ raise HTTPException(status_code=404, detail="Session not found")
185
+
186
+ # Generate QR code with session join URL - use your HF space URL
187
+ join_url = f"https://mutisya-realtime-translator-5-27-25-v2.hf.space/?join={session_id}"
188
+
189
+ qr = qrcode.QRCode(version=1, box_size=10, border=5)
190
+ qr.add_data(join_url)
191
+ qr.make(fit=True)
192
+
193
+ img = qr.make_image(fill_color="black", back_color="white")
194
+
195
+ # Convert to bytes
196
+ img_buffer = io.BytesIO()
197
+ img.save(img_buffer, format='PNG')
198
+ img_buffer.seek(0)
199
+
200
+ return Response(content=img_buffer.getvalue(), media_type="image/png")
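
A sketch of the session and QR-code flow defined above. The SessionCreate payload fields (name, organizer_name, languages, enable_tts) are inferred from how SessionManager.create_session consumes them later in this commit; the actual Pydantic model lives in app.models and may differ. Base URL and auth header are the same assumptions as in the mobile sketch.

import requests

BASE_URL = "http://localhost:7860"
HEADERS = {"Authorization": "Bearer hf_xxx"}  # hypothetical

payload = {
    "name": "Clinic visit",
    "organizer_name": "Asha",
    "languages": ["eng", "swa"],  # assumed to coerce into LanguageCode values
    "enable_tts": True,
}
session = requests.post(f"{BASE_URL}/sessions", json=payload, headers=HEADERS).json()

# Download the join QR code for this session as a PNG.
png = requests.get(f"{BASE_URL}/sessions/{session['id']}/qr-code", headers=HEADERS)
with open("join_qr.png", "wb") as f:
    f.write(png.content)
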
app/routers/watch.py ADDED
@@ -0,0 +1,152 @@
1
+ from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Depends
2
+ from fastapi.responses import Response
3
+ from typing import Optional
4
+ import io
5
+ import base64
6
+ from app.services.transcription_service import TranscriptionService
7
+ from app.services.translation_service import TranslationService
8
+ from app.services.tts_service import TTSService
9
+ from app.models import LanguageCode
10
+ from pydantic import BaseModel
11
+ from app.auth import require_hf_token
12
+
13
+ router = APIRouter()
14
+
15
+ class WatchTranslationRequest(BaseModel):
16
+ source_language: str
17
+ target_language: str
18
+ audio_base64: str
19
+
20
+ class WatchTranslationResponse(BaseModel):
21
+ original_text: str
22
+ original_language: str
23
+ translated_text: str
24
+ target_language: str
25
+ translated_audio_base64: str
26
+ success: bool
27
+ error: Optional[str] = None
28
+
29
+ # Initialize services (these will be injected by main app)
30
+ transcription_service = None
31
+ translation_service = None
32
+ tts_service = None
33
+
34
+ @router.post("/watch/translate", response_model=WatchTranslationResponse)
35
+ async def watch_translate_audio(request: WatchTranslationRequest, token: str = Depends(require_hf_token)):
36
+ """
37
+ Process audio for watch app translation
38
+ - Transcribe audio using source language model
39
+ - Translate text to target language
40
+ - Generate TTS audio for target language
41
+ - Return all data to watch app
42
+ """
43
+ try:
44
+ # Validate languages
45
+ source_lang = request.source_language.lower()
46
+ target_lang = request.target_language.lower()
47
+
48
+ if source_lang not in ['eng', 'swa', 'kik', 'kam', 'mer', 'luo', 'som']:
49
+ raise HTTPException(status_code=400, detail=f"Unsupported source language: {source_lang}")
50
+
51
+ if target_lang not in ['eng', 'swa', 'kik', 'kam', 'mer', 'luo', 'som']:
52
+ raise HTTPException(status_code=400, detail=f"Unsupported target language: {target_lang}")
53
+
54
+ # Decode base64 audio
55
+ try:
56
+ audio_data = base64.b64decode(request.audio_base64)
57
+ print(f"Decoded audio data: {len(audio_data)} bytes")
58
+ except Exception as e:
59
+ raise HTTPException(status_code=400, detail=f"Invalid base64 audio data: {str(e)}")
60
+
61
+ # Step 1: Transcribe audio
62
+ print(f"Transcribing audio with {source_lang} model...")
63
+ transcribed_text = await transcription_service.transcribe_audio(audio_data, source_lang)
64
+
65
+ if not transcribed_text or transcribed_text.strip() == "":
66
+ return WatchTranslationResponse(
67
+ original_text="",
68
+ original_language=source_lang,
69
+ translated_text="No speech detected",
70
+ target_language=target_lang,
71
+ translated_audio_base64="",
72
+ success=False,
73
+ error="No speech detected in audio"
74
+ )
75
+
76
+ print(f"Transcribed text: {transcribed_text}")
77
+
78
+ # Step 2: Translate text (skip if source and target are the same)
79
+ if source_lang == target_lang:
80
+ translated_text = transcribed_text
81
+ else:
82
+ print(f"Translating from {source_lang} to {target_lang}...")
83
+
84
+ # Convert language codes to full names for translation service
85
+ lang_name_map = {
86
+ 'eng': 'English',
87
+ 'swa': 'Swahili',
88
+ 'kik': 'Kikuyu',
89
+ 'kam': 'Kamba',
90
+ 'mer': 'Kimeru',
91
+ 'luo': 'Luo',
92
+ 'som': 'Somali'
93
+ }
94
+
95
+ source_lang_name = lang_name_map.get(source_lang, 'English')
96
+ target_lang_name = lang_name_map.get(target_lang, 'Swahili')
97
+
98
+ translated_text = await translation_service.translate_text(
99
+ transcribed_text,
100
+ source_lang_name,
101
+ target_lang_name
102
+ )
103
+
104
+ print(f"Translated text: {translated_text}")
105
+
106
+ # Step 3: Generate TTS audio for translated text (Android-compatible WAV format)
107
+ print(f"Generating TTS audio for {target_lang} in WAV format for Android...")
108
+ tts_audio_data = await tts_service.generate_speech(translated_text, target_lang, output_format="wav")
109
+
110
+ # Encode TTS audio as base64
111
+ tts_audio_base64 = ""
112
+ if tts_audio_data:
113
+ tts_audio_base64 = base64.b64encode(tts_audio_data).decode('utf-8')
114
+ print(f"TTS audio generated: {len(tts_audio_data)} bytes, base64: {len(tts_audio_base64)} chars")
115
+ else:
116
+ print("TTS audio generation failed - no data returned")
117
+
118
+ return WatchTranslationResponse(
119
+ original_text=transcribed_text,
120
+ original_language=source_lang,
121
+ translated_text=translated_text,
122
+ target_language=target_lang,
123
+ translated_audio_base64=tts_audio_base64,
124
+ success=True
125
+ )
126
+
127
+ except HTTPException:
+ raise
+ except Exception as e:
128
+ print(f"Error in watch translation: {str(e)}")
129
+ import traceback
130
+ traceback.print_exc()
131
+
132
+ return WatchTranslationResponse(
133
+ original_text="",
134
+ original_language=request.source_language,
135
+ translated_text="",
136
+ target_language=request.target_language,
137
+ translated_audio_base64="",
138
+ success=False,
139
+ error=str(e)
140
+ )
141
+
142
+ @router.get("/watch/test")
143
+ async def test_watch_endpoint(token: str = Depends(require_hf_token)):
144
+ """Test endpoint for watch app connectivity"""
145
+ return {
146
+ "status": "Watch API is working",
147
+ "services": {
148
+ "transcription": transcription_service is not None,
149
+ "translation": translation_service is not None,
150
+ "tts": tts_service is not None
151
+ }
152
+ }
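
A round-trip sketch for /watch/translate: read a recorded WAV, base64-encode it as the request model expects, and decode the returned TTS audio. Base URL and auth header are assumed as in the earlier sketches.

import base64
import requests

BASE_URL = "http://localhost:7860"
HEADERS = {"Authorization": "Bearer hf_xxx"}  # hypothetical

with open("recording.wav", "rb") as f:
    audio_b64 = base64.b64encode(f.read()).decode("utf-8")

resp = requests.post(
    f"{BASE_URL}/watch/translate",
    json={"source_language": "eng",
          "target_language": "swa",
          "audio_base64": audio_b64},
    headers=HEADERS,
).json()

if resp["success"]:
    print("Heard:", resp["original_text"], "->", resp["translated_text"])
    # The TTS audio comes back base64-encoded in Android-compatible WAV format.
    with open("translated.wav", "wb") as f:
        f.write(base64.b64decode(resp["translated_audio_base64"]))
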
app/services/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Services package
app/services/learning_data_service.py ADDED
@@ -0,0 +1,415 @@
1
+ """
2
+ Learning Data Service - File-based data access for language learning prototype
3
+
4
+ This service provides access to lesson data, user progress, and achievements
5
+ using JSON files stored in the data/learning directory at the project root.
6
+ """
7
+
8
+ import json
9
+ import os
10
+ from pathlib import Path
11
+ from typing import Dict, List, Optional, Any
12
+ from datetime import datetime
13
+ import logging
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ class LearningDataService:
19
+ """Service for managing language learning data using JSON files"""
20
+
21
+ def __init__(self):
22
+ # Get the data directory relative to this file
23
+ self.data_dir = Path(__file__).parent.parent.parent / "data" / "learning"
24
+ self.lessons_dir = self.data_dir / "lessons"
25
+ self.users_dir = self.data_dir / "users"
26
+
27
+ # Ensure directories exist
28
+ self.users_dir.mkdir(parents=True, exist_ok=True)
29
+
30
+ logger.info(f"Learning data directory: {self.data_dir}")
31
+ logger.info(f"Lessons directory: {self.lessons_dir}")
32
+ logger.info(f"Users directory: {self.users_dir}")
33
+
34
+ # ==================== Lesson Data ====================
35
+
36
+ def get_lessons_index(self, language: str = 'swahili') -> Optional[Dict]:
37
+ """Load the lessons index/catalog for a specific language"""
38
+ try:
39
+ # Map language codes to folder names
40
+ language_map = {
41
+ 'swahili': 'swahili',
42
+ 'swa': 'swahili',
43
+ 'kamba': 'kamba',
44
+ 'kam': 'kamba',
45
+ 'maasai': 'maasai',
46
+ 'mas': 'maasai'
47
+ }
48
+
49
+ language_folder = language_map.get(language.lower(), 'swahili')
50
+ index_path = self.lessons_dir / language_folder / "index.json"
51
+
52
+ logger.info(f"Loading lessons index for language '{language}' -> folder '{language_folder}' at {index_path}")
53
+
54
+ if not index_path.exists():
55
+ logger.warning(f"Lessons index not found at {index_path}")
56
+ logger.info(f"Lessons dir contents: {list(self.lessons_dir.iterdir())}")
57
+ return None
58
+
59
+ with open(index_path, 'r', encoding='utf-8') as f:
60
+ data = json.load(f)
61
+ logger.info(f"Successfully loaded {len(data.get('lessons', []))} lessons for {language}")
62
+ return data
63
+ except Exception as e:
64
+ logger.error(f"Error loading lessons index for {language}: {e}")
65
+ return None
66
+
67
+ def get_lesson(self, lesson_id: int, language: str = 'swahili') -> Optional[Dict]:
68
+ """Load a specific lesson by ID for a specific language"""
69
+ try:
70
+ # Map language codes to folder names
71
+ language_map = {
72
+ 'swahili': 'swahili',
73
+ 'swa': 'swahili',
74
+ 'kamba': 'kamba',
75
+ 'kam': 'kamba',
76
+ 'maasai': 'maasai',
77
+ 'mas': 'maasai'
78
+ }
79
+
80
+ language_folder = language_map.get(language.lower(), 'swahili')
81
+
82
+ # First get the index to find the lesson file
83
+ index = self.get_lessons_index(language)
84
+ if not index:
85
+ return None
86
+
87
+ # Find the lesson in the index
88
+ lesson_meta = None
89
+ for lesson in index.get('lessons', []):
90
+ if lesson['lesson_id'] == lesson_id:
91
+ lesson_meta = lesson
92
+ break
93
+
94
+ if not lesson_meta:
95
+ logger.warning(f"Lesson {lesson_id} not found in index for {language}")
96
+ return None
97
+
98
+ # Load the lesson file
99
+ lesson_path = self.lessons_dir / language_folder / lesson_meta['file']
100
+ if not lesson_path.exists():
101
+ logger.warning(f"Lesson file not found: {lesson_path}")
102
+ return None
103
+
104
+ with open(lesson_path, 'r', encoding='utf-8') as f:
105
+ return json.load(f)
106
+ except Exception as e:
107
+ logger.error(f"Error loading lesson {lesson_id} for {language}: {e}")
108
+ return None
109
+
110
+ def get_available_lessons(self) -> List[Dict]:
111
+ """Get list of available lessons (not planned)"""
112
+ try:
113
+ index = self.get_lessons_index()
114
+ if not index:
115
+ return []
116
+
117
+ available = [
118
+ lesson for lesson in index.get('lessons', [])
119
+ if lesson.get('status') == 'available'
120
+ ]
121
+ return available
122
+ except Exception as e:
123
+ logger.error(f"Error getting available lessons: {e}")
124
+ return []
125
+
126
+ # ==================== Achievements ====================
127
+
128
+ def get_achievements(self) -> Optional[Dict]:
129
+ """Load achievements configuration"""
130
+ try:
131
+ achievements_path = self.data_dir / "achievements.json"
132
+ if not achievements_path.exists():
133
+ logger.warning(f"Achievements file not found at {achievements_path}")
134
+ return None
135
+
136
+ with open(achievements_path, 'r', encoding='utf-8') as f:
137
+ return json.load(f)
138
+ except Exception as e:
139
+ logger.error(f"Error loading achievements: {e}")
140
+ return None
141
+
142
+ # ==================== User Progress ====================
143
+
144
+ def get_user_progress(self, user_id: str) -> Optional[Dict]:
145
+ """Load user progress data"""
146
+ try:
147
+ user_file = self.users_dir / f"user-{user_id}.json"
148
+ if not user_file.exists():
149
+ # Return default progress structure for new users
150
+ return self._create_default_user_progress(user_id)
151
+
152
+ with open(user_file, 'r', encoding='utf-8') as f:
153
+ return json.load(f)
154
+ except Exception as e:
155
+ logger.error(f"Error loading user progress for {user_id}: {e}")
156
+ return None
157
+
158
+ def save_user_progress(self, user_id: str, progress_data: Dict) -> bool:
159
+ """Save user progress data"""
160
+ try:
161
+ user_file = self.users_dir / f"user-{user_id}.json"
162
+
163
+ # Update last_active timestamp
164
+ if 'profile' in progress_data:
165
+ progress_data['profile']['last_active'] = datetime.utcnow().isoformat() + 'Z'
166
+
167
+ with open(user_file, 'w', encoding='utf-8') as f:
168
+ json.dump(progress_data, f, indent=2, ensure_ascii=False)
169
+
170
+ logger.info(f"Saved progress for user {user_id}")
171
+ return True
172
+ except Exception as e:
173
+ logger.error(f"Error saving user progress for {user_id}: {e}")
174
+ return False
175
+
176
+ def update_lesson_progress(
177
+ self,
178
+ user_id: str,
179
+ lesson_id: int,
180
+ progress_update: Dict
181
+ ) -> bool:
182
+ """Update progress for a specific lesson"""
183
+ try:
184
+ user_progress = self.get_user_progress(user_id)
185
+ if not user_progress:
186
+ return False
187
+
188
+ # Initialize lesson_progress if it doesn't exist
189
+ if 'lesson_progress' not in user_progress:
190
+ user_progress['lesson_progress'] = {}
191
+
192
+ lesson_key = str(lesson_id)
193
+
194
+ # Update or create lesson progress
195
+ if lesson_key in user_progress['lesson_progress']:
196
+ user_progress['lesson_progress'][lesson_key].update(progress_update)
197
+ else:
198
+ user_progress['lesson_progress'][lesson_key] = progress_update
199
+
200
+ return self.save_user_progress(user_id, user_progress)
201
+ except Exception as e:
202
+ logger.error(f"Error updating lesson progress: {e}")
203
+ return False
204
+
205
+ def update_vocabulary_progress(
206
+ self,
207
+ user_id: str,
208
+ vocab_id: int,
209
+ vocab_update: Dict
210
+ ) -> bool:
211
+ """Update progress for a specific vocabulary word"""
212
+ try:
213
+ user_progress = self.get_user_progress(user_id)
214
+ if not user_progress:
215
+ return False
216
+
217
+ # Initialize vocabulary_progress if it doesn't exist
218
+ if 'vocabulary_progress' not in user_progress:
219
+ user_progress['vocabulary_progress'] = {}
220
+
221
+ vocab_key = str(vocab_id)
222
+
223
+ # Update or create vocabulary progress
224
+ if vocab_key in user_progress['vocabulary_progress']:
225
+ user_progress['vocabulary_progress'][vocab_key].update(vocab_update)
226
+ else:
227
+ user_progress['vocabulary_progress'][vocab_key] = vocab_update
228
+
229
+ return self.save_user_progress(user_id, user_progress)
230
+ except Exception as e:
231
+ logger.error(f"Error updating vocabulary progress: {e}")
232
+ return False
233
+
234
+ def unlock_achievement(
235
+ self,
236
+ user_id: str,
237
+ achievement_id: str,
238
+ progress: int,
239
+ target: int
240
+ ) -> bool:
241
+ """Unlock or update progress on an achievement"""
242
+ try:
243
+ user_progress = self.get_user_progress(user_id)
244
+ if not user_progress:
245
+ return False
246
+
247
+ # Initialize achievements if it doesn't exist
248
+ if 'achievements' not in user_progress:
249
+ user_progress['achievements'] = {}
250
+
251
+ # Update achievement
252
+ achievement_data = {
253
+ 'achievement_id': achievement_id,
254
+ 'unlocked': progress >= target,
255
+ 'progress': progress,
256
+ 'target': target
257
+ }
258
+
259
+ # Add unlock timestamp if newly unlocked
260
+ if achievement_data['unlocked'] and achievement_id not in user_progress['achievements']:
261
+ achievement_data['unlocked_at'] = datetime.utcnow().isoformat() + 'Z'
262
+ elif achievement_data['unlocked'] and achievement_id in user_progress['achievements']:
263
+ # Preserve original unlock time
264
+ if 'unlocked_at' in user_progress['achievements'][achievement_id]:
265
+ achievement_data['unlocked_at'] = user_progress['achievements'][achievement_id]['unlocked_at']
266
+ else:
267
+ achievement_data['unlocked_at'] = datetime.utcnow().isoformat() + 'Z'
268
+
269
+ user_progress['achievements'][achievement_id] = achievement_data
270
+
271
+ return self.save_user_progress(user_id, user_progress)
272
+ except Exception as e:
273
+ logger.error(f"Error unlocking achievement: {e}")
274
+ return False
275
+
276
+ # ==================== Helper Methods ====================
277
+
278
+ def _create_default_user_progress(self, user_id: str) -> Dict:
279
+ """Create default progress structure for a new user"""
280
+ return {
281
+ 'user_id': user_id,
282
+ 'profile': {
283
+ 'user_id': user_id,
284
+ 'learning_language': 'swa',
285
+ 'native_language': 'eng',
286
+ 'created_at': datetime.utcnow().isoformat() + 'Z',
287
+ 'last_active': datetime.utcnow().isoformat() + 'Z'
288
+ },
289
+ 'overall_stats': {
290
+ 'level': 'beginner',
291
+ 'total_xp': 0,
292
+ 'next_level_xp': 1000,
293
+ 'current_streak': 0,
294
+ 'longest_streak': 0,
295
+ 'lessons_completed': 0,
296
+ 'vocabulary_learned': 0,
297
+ 'vocabulary_mastered': 0,
298
+ 'total_practice_time_seconds': 0,
299
+ 'pronunciation_avg_score': 0.0,
300
+ 'listening_avg_score': 0.0,
301
+ 'reading_avg_score': 0.0
302
+ },
303
+ 'daily_stats': {},
304
+ 'lesson_progress': {},
305
+ 'vocabulary_progress': {},
306
+ 'achievements': {},
307
+ 'session_history': []
308
+ }
309
+
310
+ def create_default_progress(self, user_id: str) -> Dict:
311
+ """Public method to create default progress structure"""
312
+ progress = self._create_default_user_progress(user_id)
313
+ # Add Phase 1-3 specific fields
314
+ progress['overall_stats']['vocabulary_reviewed'] = 0
315
+ progress['comprehension_scores'] = {}
316
+ progress['scenario_progress'] = {}
317
+ return progress
318
+
319
+ # ==================== Phase 1-3 Methods ====================
320
+
321
+ def get_vocabulary(self, vocab_id: int) -> Optional[Dict]:
322
+ """Get a single vocabulary word by ID from any lesson"""
323
+ try:
324
+ lessons_index = self.get_lessons_index()
325
+ if not lessons_index:
326
+ return None
327
+
328
+ # Search through all lessons
329
+ for lesson_meta in lessons_index.get('lessons', []):
330
+ lesson = self.get_lesson(lesson_meta['lesson_id'])
331
+ if lesson and 'vocabulary' in lesson:
332
+ for vocab in lesson['vocabulary']:
333
+ # Support both 'id' and 'vocabulary_id' fields
334
+ vocab_item_id = vocab.get('vocabulary_id') or vocab.get('id')
335
+ if vocab_item_id == vocab_id:
336
+ # Add lesson context
337
+ vocab['lesson_id'] = lesson['lesson_id']
338
+ vocab['lesson_title'] = lesson.get('title', '')
339
+ return vocab
340
+
341
+ logger.warning(f"Vocabulary {vocab_id} not found in any lesson")
342
+ return None
343
+ except Exception as e:
344
+ logger.error(f"Error getting vocabulary {vocab_id}: {e}")
345
+ return None
346
+
347
+ def get_all_vocabulary(self) -> List[Dict]:
348
+ """Get all vocabulary words from all lessons"""
349
+ try:
350
+ all_vocab = []
351
+ lessons_index = self.get_lessons_index()
352
+ if not lessons_index:
353
+ return all_vocab
354
+
355
+ for lesson_meta in lessons_index.get('lessons', []):
356
+ lesson = self.get_lesson(lesson_meta['lesson_id'])
357
+ if lesson and 'vocabulary' in lesson:
358
+ for vocab in lesson['vocabulary']:
359
+ # Add lesson context
360
+ vocab_copy = vocab.copy()
361
+ vocab_copy['lesson_id'] = lesson['lesson_id']
362
+ vocab_copy['lesson_title'] = lesson.get('title', '')
363
+ vocab_copy['lesson_level'] = lesson.get('difficulty_level', 1)
364
+ all_vocab.append(vocab_copy)
365
+
366
+ return all_vocab
367
+ except Exception as e:
368
+ logger.error(f"Error getting all vocabulary: {e}")
369
+ return []
370
+
371
+ def get_scenario(self, scenario_id: str) -> Optional[Dict]:
372
+ """Load a task scenario by ID"""
373
+ try:
374
+ scenarios_dir = self.data_dir / "scenarios"
375
+ scenario_path = scenarios_dir / f"{scenario_id}.json"
376
+
377
+ if not scenario_path.exists():
378
+ logger.warning(f"Scenario file not found: {scenario_path}")
379
+ return None
380
+
381
+ with open(scenario_path, 'r', encoding='utf-8') as f:
382
+ return json.load(f)
383
+ except Exception as e:
384
+ logger.error(f"Error loading scenario {scenario_id}: {e}")
385
+ return None
386
+
387
+ def get_all_scenarios(self) -> List[Dict]:
388
+ """Get list of all available scenarios"""
389
+ try:
390
+ scenarios_dir = self.data_dir / "scenarios"
391
+ if not scenarios_dir.exists():
392
+ return []
393
+
394
+ scenarios = []
395
+ for scenario_file in scenarios_dir.glob("*.json"):
396
+ try:
397
+ with open(scenario_file, 'r', encoding='utf-8') as f:
398
+ scenario_data = json.load(f)
399
+ # Add just metadata, not full dialogue tree
400
+ scenarios.append({
401
+ 'scenario_id': scenario_data.get('scenario_id'),
402
+ 'title': scenario_data.get('title'),
403
+ 'title_en': scenario_data.get('title_en'),
404
+ 'level': scenario_data.get('level'),
405
+ 'estimated_duration_minutes': scenario_data.get('estimated_duration_minutes'),
406
+ 'learning_goals': scenario_data.get('learning_goals', [])
407
+ })
408
+ except Exception as e:
409
+ logger.error(f"Error loading scenario {scenario_file}: {e}")
410
+ continue
411
+
412
+ return scenarios
413
+ except Exception as e:
414
+ logger.error(f"Error getting all scenarios: {e}")
415
+ return []
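
A local usage sketch for LearningDataService. It assumes the lesson JSON files under data/learning exist as described in the module docstring; user progress files are created lazily on first save.

from app.services.learning_data_service import LearningDataService

svc = LearningDataService()

# 'swa' resolves to the swahili lesson folder via the internal language map.
index = svc.get_lessons_index("swa")
lesson = svc.get_lesson(1, "swahili")  # loads the file listed in index.json

# Record a completed lesson; a default progress file is created for new users.
svc.update_lesson_progress("demo-user", 1, {"status": "completed", "score": 0.9})
print(svc.get_user_progress("demo-user")["lesson_progress"]["1"])
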
app/services/quantization_utils.py ADDED
@@ -0,0 +1,124 @@
1
+ """
2
+ Dynamic INT8 Quantization utilities for ASR models.
3
+
4
+ This module provides utilities to apply PyTorch dynamic quantization to
5
+ Hugging Face transformer models, specifically optimized for ASR models like
6
+ Whisper and Wav2Vec2-BERT.
7
+ """
8
+
9
+ import torch
10
+ from torch.quantization import quantize_dynamic
11
+ from transformers import PreTrainedModel
12
+ import time
13
+
14
+
15
+ def apply_dynamic_int8_quantization(model: PreTrainedModel, model_type: str = "auto") -> PreTrainedModel:
16
+ """
17
+ Apply dynamic INT8 quantization to a Hugging Face model.
18
+
19
+ Dynamic quantization converts model weights to INT8 ahead of time and quantizes
20
+ activations to INT8 on the fly during inference, reducing model size and improving
21
+ inference speed with minimal accuracy loss.
22
+
23
+ Args:
24
+ model: The Hugging Face model to quantize
25
+ model_type: Type of model ("whisper", "wav2vec2-bert", or "auto")
26
+
27
+ Returns:
28
+ Quantized model
29
+
30
+ References:
31
+ - PyTorch Quantization: https://pytorch.org/docs/stable/quantization.html
32
+ - Dynamic Quantization for NLP: https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html
33
+ """
34
+ print(f"\n{'='*60}")
35
+ print(f"Applying Dynamic INT8 Quantization to {model_type} model")
36
+ print(f"{'='*60}")
37
+
38
+ # Get model size before quantization
39
+ param_size = 0
40
+ for param in model.parameters():
41
+ param_size += param.nelement() * param.element_size()
42
+ buffer_size = 0
43
+ for buffer in model.buffers():
44
+ buffer_size += buffer.nelement() * buffer.element_size()
45
+ size_before_mb = (param_size + buffer_size) / 1024**2
46
+
47
+ print(f"Model size before quantization: {size_before_mb:.2f} MB")
48
+
49
+ # Start quantization timer
50
+ start_time = time.time()
51
+
52
+ try:
53
+ # Dynamic quantization targets:
54
+ # - torch.nn.Linear: Most common layer type in transformers
55
+ # - torch.nn.LSTM/GRU/RNN: For sequential models (if present)
56
+ #
57
+ # Note: We use qint8 (quantized int8) which converts weights to INT8
58
+ # and performs INT8 arithmetic for linear layers during inference
59
+ quantized_model = quantize_dynamic(
60
+ model,
61
+ {torch.nn.Linear}, # Quantize all Linear layers
62
+ dtype=torch.qint8 # Use 8-bit integer quantization
63
+ )
64
+
65
+ # Estimate size after quantization (note: packed INT8 weights are not visible via model.parameters(), so this estimate is approximate)
66
+ param_size_q = 0
67
+ for param in quantized_model.parameters():
68
+ param_size_q += param.nelement() * param.element_size()
69
+ buffer_size_q = 0
70
+ for buffer in quantized_model.buffers():
71
+ buffer_size_q += buffer.nelement() * buffer.element_size()
72
+ size_after_mb = (param_size_q + buffer_size_q) / 1024**2
73
+
74
+ quantization_time = time.time() - start_time
75
+ size_reduction = ((size_before_mb - size_after_mb) / size_before_mb) * 100
76
+
77
+ print(f"✓ Quantization successful!")
78
+ print(f" - Model size after quantization: {size_after_mb:.2f} MB")
79
+ print(f" - Size reduction: {size_reduction:.1f}%")
80
+ print(f" - Quantization time: {quantization_time:.2f}s")
81
+ print(f"{'='*60}\n")
82
+
83
+ return quantized_model
84
+
85
+ except Exception as e:
86
+ print(f"✗ Quantization failed: {e}")
87
+ print(f" Returning original unquantized model")
88
+ print(f"{'='*60}\n")
89
+ return model
90
+
91
+
92
+ def get_quantization_stats(model: PreTrainedModel) -> dict:
93
+ """
94
+ Get statistics about a model's quantization status.
95
+
96
+ Args:
97
+ model: The model to analyze
98
+
99
+ Returns:
100
+ Dictionary with quantization statistics
101
+ """
102
+ stats = {
103
+ "is_quantized": False,
104
+ "quantized_layers": 0,
105
+ "total_layers": 0,
106
+ "size_mb": 0.0
107
+ }
108
+
109
+ # Count quantized vs regular layers
110
+ for name, module in model.named_modules():
111
+ if isinstance(module, (torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU)) or hasattr(module, '_packed_params'):
112
+ stats["total_layers"] += 1
113
+
114
+ # Quantized dynamic layers are no longer torch.nn.Linear but expose _packed_params
115
+ if hasattr(module, '_packed_params'):
116
+ stats["quantized_layers"] += 1
117
+ stats["is_quantized"] = True
118
+
119
+ # Calculate model size
120
+ param_size = sum(p.nelement() * p.element_size() for p in model.parameters())
121
+ buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())
122
+ stats["size_mb"] = (param_size + buffer_size) / 1024**2
123
+
124
+ return stats
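
A usage sketch for the utilities above, quantizing the English Whisper checkpoint that the ASR config later in this commit points at. The model choice is illustrative; any model containing torch.nn.Linear layers is handled the same way, and dynamic INT8 quantization targets CPU inference.

from transformers import WhisperForConditionalGeneration

from app.services.quantization_utils import (
    apply_dynamic_int8_quantization,
    get_quantization_stats,
)

# Load on CPU, then swap Linear layers for dynamically quantized equivalents.
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base.en")
model = apply_dynamic_int8_quantization(model, model_type="whisper")

print(get_quantization_stats(model))
# Expected shape: {'is_quantized': ..., 'quantized_layers': ..., 'total_layers': ..., 'size_mb': ...}
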
app/services/session_manager.py ADDED
@@ -0,0 +1,180 @@
1
+ import uuid
2
+ import random
3
+ import string
4
+ from typing import Dict, List, Optional
5
+ from app.models import Session, SessionCreate, Participant, Language, LanguageCode
6
+
7
+ def generate_short_code(length: int = 8) -> str:
8
+ """Generate a random short code using uppercase letters and digits"""
9
+ # Use uppercase letters and digits only; look-alike characters (O/0, I/1) are removed below
10
+ alphabet = string.ascii_uppercase + string.digits
11
+ # Remove confusing characters
12
+ alphabet = alphabet.replace('O', '').replace('0', '').replace('I', '').replace('1', '')
13
+ return ''.join(random.choice(alphabet) for _ in range(length))
14
+
15
+ # Language mappings
16
+ LANGUAGE_MAP = {
17
+ LanguageCode.ENGLISH: Language(code=LanguageCode.ENGLISH, name="English", display_name="English (eng)"),
18
+ LanguageCode.SWAHILI: Language(code=LanguageCode.SWAHILI, name="Swahili", display_name="Swahili (swa)"),
19
+ LanguageCode.KIKUYU: Language(code=LanguageCode.KIKUYU, name="Kikuyu", display_name="Kikuyu (kik)"),
20
+ LanguageCode.KAMBA: Language(code=LanguageCode.KAMBA, name="Kamba", display_name="Kamba (kam)"),
21
+ LanguageCode.KIMERU: Language(code=LanguageCode.KIMERU, name="Kimeru", display_name="Kimeru (mer)"),
22
+ LanguageCode.LUO: Language(code=LanguageCode.LUO, name="Luo", display_name="Luo (luo)"),
23
+ LanguageCode.SOMALI: Language(code=LanguageCode.SOMALI, name="Somali", display_name="Somali (som)"),
24
+ }
25
+
26
+ class SessionManager:
27
+ def __init__(self):
28
+ self.sessions: Dict[str, Session] = {}
29
+ self.participant_sessions: Dict[str, str] = {} # participant_id -> session_id
30
+ self.short_code_to_id: Dict[str, str] = {} # short_code -> session_id
31
+ self.id_to_short_code: Dict[str, str] = {} # session_id -> short_code
32
+
33
+ async def create_session(self, session_data: SessionCreate) -> Session:
34
+ session_id = str(uuid.uuid4())
35
+
36
+ # Generate unique short code
37
+ short_code = generate_short_code(8)
38
+ while short_code in self.short_code_to_id:
39
+ # Extremely unlikely collision, but regenerate if needed
40
+ short_code = generate_short_code(8)
41
+
42
+ # Convert language codes to Language objects
43
+ languages = [LANGUAGE_MAP[lang_code] for lang_code in session_data.languages]
44
+
45
+ session = Session(
46
+ id=session_id,
47
+ name=session_data.name,
48
+ organizer_name=session_data.organizer_name,
49
+ languages=languages,
50
+ participants=[],
51
+ is_active=True,
52
+ enable_tts=session_data.enable_tts
53
+ )
54
+
55
+ self.sessions[session_id] = session
56
+ self.short_code_to_id[short_code] = session_id
57
+ self.id_to_short_code[session_id] = short_code
58
+ return session
59
+
60
+ async def get_session(self, session_id_or_code: str) -> Optional[Session]:
61
+ """Get session by full UUID or short code"""
62
+ # Try as full UUID first
63
+ session = self.sessions.get(session_id_or_code)
64
+ if session:
65
+ return session
66
+
67
+ # Try as short code
68
+ session_id = self.short_code_to_id.get(session_id_or_code.upper())
69
+ if session_id:
70
+ return self.sessions.get(session_id)
71
+
72
+ return None
73
+
74
+ def get_short_code(self, session_id: str) -> str:
75
+ """Get short code for a session ID"""
76
+ return self.id_to_short_code.get(session_id, session_id)
77
+
78
+ async def get_all_sessions(self) -> List[Session]:
79
+ return list(self.sessions.values())
80
+
81
+ async def add_participant(self, session_id: str, participant_name: str, language_code: LanguageCode) -> Optional[Participant]:
82
+ session = await self.get_session(session_id)
83
+ if not session:
84
+ return None
85
+
86
+ participant_id = str(uuid.uuid4())
87
+ language = LANGUAGE_MAP[language_code]
88
+
89
+ # Check if the participant's language is already in the session languages
90
+ language_exists = any(lang.code == language_code for lang in session.languages)
91
+ if not language_exists:
92
+ print(f"Adding new language {language.name} ({language_code.value}) to session {session_id}")
93
+ session.languages.append(language)
94
+
95
+ participant = Participant(
96
+ id=participant_id,
97
+ name=participant_name,
98
+ language=language,
99
+ is_organizer=len(session.participants) == 0, # First participant is organizer
100
+ is_speaking=False,
101
+ is_connected=True
102
+ )
103
+
104
+ session.participants.append(participant)
105
+ self.participant_sessions[participant_id] = session_id
106
+
107
+ print(f"Participant {participant_name} added to session. Session now has {len(session.languages)} languages: {[lang.name for lang in session.languages]}")
108
+
109
+ return participant
110
+
111
+ async def remove_participant(self, participant_id: str) -> bool:
112
+ session_id = self.participant_sessions.get(participant_id)
113
+ if not session_id:
114
+ return False
115
+
116
+ session = await self.get_session(session_id)
117
+ if not session:
118
+ return False
119
+
120
+ # Remove participant from session
121
+ session.participants = [p for p in session.participants if p.id != participant_id]
122
+ del self.participant_sessions[participant_id]
123
+
124
+ return True
125
+
126
+ async def update_participant_speaking_status(self, participant_id: str, is_speaking: bool) -> bool:
127
+ session_id = self.participant_sessions.get(participant_id)
128
+ if not session_id:
129
+ return False
130
+
131
+ session = await self.get_session(session_id)
132
+ if not session:
133
+ return False
134
+
135
+ for participant in session.participants:
136
+ if participant.id == participant_id:
137
+ participant.is_speaking = is_speaking
138
+ return True
139
+
140
+ return False
141
+
142
+ async def get_participant_session_id(self, participant_id: str) -> Optional[str]:
143
+ return self.participant_sessions.get(participant_id)
144
+
145
+ async def add_language_to_session(self, session_id: str, language_code: LanguageCode) -> bool:
146
+ """Add a language to the session if it doesn't already exist"""
147
+ session = await self.get_session(session_id)
148
+ if not session:
149
+ return False
150
+
151
+ language = LANGUAGE_MAP[language_code]
152
+
153
+ # Check if the language is already in the session languages
154
+ language_exists = any(lang.code == language_code for lang in session.languages)
155
+ if not language_exists:
156
+ print(f"Adding new language {language.name} ({language_code.value}) to session {session_id}")
157
+ session.languages.append(language)
158
+ print(f"Session {session_id} now has {len(session.languages)} languages: {[lang.name for lang in session.languages]}")
159
+ return True
160
+ else:
161
+ print(f"Language {language.name} ({language_code.value}) already exists in session {session_id}")
162
+ return False
163
+
164
+ async def delete_session(self, session_id: str) -> bool:
165
+ if session_id in self.sessions:
166
+ # Remove all participants from tracking
167
+ session = self.sessions[session_id]
168
+ for participant in session.participants:
169
+ if participant.id in self.participant_sessions:
170
+ del self.participant_sessions[participant.id]
171
+
172
+ # Remove short code mapping
173
+ short_code = self.id_to_short_code.get(session_id)
174
+ if short_code:
175
+ del self.short_code_to_id[short_code]
176
+ del self.id_to_short_code[session_id]
177
+
178
+ del self.sessions[session_id]
179
+ return True
180
+ return False
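The session manager above keeps two lookup tables in sync (participant_id -> session_id, plus the short-code maps cleaned up in delete_session). A minimal usage sketch of that lifecycle follows; `manager`, the session id, and the LanguageCode members used here are illustrative assumptions, not part of this diff:

```python
# Hypothetical walk-through of the SessionManager participant lifecycle.
# Assumes an already-created session and LanguageCode members ENG/SWA (illustrative).
async def demo(manager: "SessionManager", session_id: str) -> None:
    # The first participant to join a session is flagged as the organizer.
    alice = await manager.add_participant(session_id, "Alice", LanguageCode.ENG)
    bob = await manager.add_participant(session_id, "Bob", LanguageCode.SWA)
    assert alice.is_organizer and not bob.is_organizer

    # A language is appended to session.languages at most once;
    # re-adding an existing one returns False.
    await manager.add_language_to_session(session_id, LanguageCode.SWA)

    # Speaking status and removal are keyed by participant id.
    await manager.update_participant_speaking_status(alice.id, True)
    await manager.remove_participant(bob.id)

    # delete_session also clears participant tracking and short-code mappings.
    await manager.delete_session(session_id)
```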
app/services/transcription_service.py ADDED
@@ -0,0 +1,736 @@
1
+ import asyncio
2
+ import io
3
+ import wave
4
+ import numpy as np
5
+ import time
6
+ from typing import Any, Callable, Dict, Optional
7
+ from transformers import pipeline
8
+ import torch
9
+ from app.models import LanguageCode
10
+ import os
11
+ from app.services.quantization_utils import apply_dynamic_int8_quantization, get_quantization_stats
12
+
13
+ # Silero VAD imports
14
+ try:
15
+ import silero_vad
16
+ SILERO_VAD_AVAILABLE = True
17
+ except ImportError:
18
+ SILERO_VAD_AVAILABLE = False
19
+ print("Warning: silero-vad not installed. Falling back to RMS-based VAD.")
20
+
21
+ class TranscriptionService:
22
+ def __init__(self):
23
+ self.asr_pipelines: Dict[str, Any] = {}
24
+ self.device = 0 if torch.cuda.is_available() else -1
25
+
26
+ # Model configurations - using original mutisya models with updated config
27
+ self.asr_config = {
28
+ "eng": {"model_repo": "openai/whisper-base.en", "model_type": "whisper"},
29
+ "swa": {"model_repo": "mutisya/w2v-bert-2.0-asr-swh-superv-v25-37-1", "model_type": "wav2vec2-bert", "trust_remote_code": True},
30
+ "kik": {"model_repo": "mutisya/w2v-bert-2.0-asr-kik-superv-v25-37-1", "model_type": "wav2vec2-bert", "trust_remote_code": True},
31
+ "kam": {"model_repo": "mutisya/w2v-bert-2.0-asr-kam-superv-v25-36-1", "model_type": "wav2vec2-bert", "trust_remote_code": True},
32
+ "mer": {"model_repo": "mutisya/w2v-bert-2.0-asr-mer-superv-v25-36-1", "model_type": "wav2vec2-bert", "trust_remote_code": True},
33
+ "luo": {"model_repo": "mutisya/w2v-bert-2.0-asr-luo-superv-v25-36-1", "model_type": "wav2vec2-bert", "trust_remote_code": True},
34
+ "som": {"model_repo": "mutisya/w2v-bert-2.0-asr-som-superv-v25-36-1", "model_type": "wav2vec2-bert", "trust_remote_code": True}
35
+ }
36
+
37
+ self.preload_languages = ["eng"]
38
+ self.background_loading_task = None
39
+ self.models_loading_status = {}
40
+
41
+ # Enhanced audio buffering for VAD-based sentence detection
42
+ self.candidate_audio_buffers: Dict[str, np.ndarray] = {} # participant_id -> accumulated audio samples (stored as float32 arrays, not bytes)
43
+ self.candidate_text_cache: Dict[str, str] = {} # participant_id -> current candidate text
44
+ self.silence_counters: Dict[str, int] = {} # participant_id -> consecutive silence chunks
45
+ self.sentence_finalized: Dict[str, bool] = {} # participant_id -> whether current sentence is already finalized
46
+
47
+ # VAD parameters - made more lenient for better detection
48
+ self.silence_threshold = 1 # Number of consecutive silent chunks before sentence break (1 second for natural pauses)
49
+ self.min_sentence_length = 0.03 # Minimum sentence length in seconds (very short)
50
+
51
+ # Silero VAD initialization
52
+ self.vad_model = None
53
+ self.vad_sample_rate = 16000
54
+ self.vad_available = SILERO_VAD_AVAILABLE
55
+
56
+ # Quantization configuration
57
+ # Set ENABLE_INT8_QUANTIZATION=true in environment to enable quantization
58
+ self.enable_quantization = os.getenv('ENABLE_INT8_QUANTIZATION', 'true').lower() == 'true'
59
+ print(f"INT8 Quantization: {'ENABLED' if self.enable_quantization else 'DISABLED'}")
60
+
61
+ async def initialize(self):
62
+ """Initialize ASR models for preloaded languages and Silero VAD"""
63
+ # Initialize Silero VAD model
64
+ if self.vad_available:
65
+ try:
66
+ print("Loading Silero VAD model...")
67
+ self.vad_model = silero_vad.load_silero_vad(onnx=False)
68
+ print("✓ Silero VAD model loaded successfully")
69
+ except Exception as e:
70
+ print(f"Failed to load Silero VAD model: {e}")
71
+ print("Falling back to RMS-based VAD")
72
+ self.vad_available = False
73
+
74
+ # Initialize ASR models
75
+ for lang_code in self.preload_languages:
76
+ if lang_code in self.asr_config:
77
+ try:
78
+ model_config = self.asr_config[lang_code]
79
+ pipeline_obj = self._load_and_quantize_pipeline(lang_code, model_config)
80
+ self.asr_pipelines[lang_code] = pipeline_obj
81
+ except Exception as e:
82
+ print(f"Failed to load ASR model for {lang_code}: {e}")
83
+
84
+ def _load_and_quantize_pipeline(self, lang_code: str, model_config: dict):
85
+ """Load ASR pipeline and optionally apply INT8 quantization"""
86
+ # Build pipeline parameters
87
+ pipeline_params = {
88
+ "task": "automatic-speech-recognition",
89
+ "model": model_config["model_repo"],
90
+ "device": self.device,
91
+ "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32
92
+ }
93
+
94
+ # Add trust_remote_code if specified
95
+ if model_config.get("trust_remote_code", False):
96
+ pipeline_params["trust_remote_code"] = True
97
+
98
+ print(f"Loading ASR model for {lang_code}: {model_config['model_repo']}")
99
+ pipeline_obj = pipeline(**pipeline_params)
100
+
101
+ # Apply quantization if enabled
102
+ if self.enable_quantization:
103
+ try:
104
+ # Get the underlying model from the pipeline
105
+ model = pipeline_obj.model
106
+ model_type = model_config.get("model_type", "auto")
107
+
108
+ # Apply dynamic INT8 quantization
109
+ quantized_model = apply_dynamic_int8_quantization(model, model_type)
110
+
111
+ # Replace the model in the pipeline
112
+ pipeline_obj.model = quantized_model
113
+
114
+ # Print quantization stats
115
+ stats = get_quantization_stats(quantized_model)
116
+ print(f"✓ {lang_code} model quantized: {stats['quantized_layers']}/{stats['total_layers']} layers, {stats['size_mb']:.2f} MB")
117
+
118
+ except Exception as e:
119
+ print(f"Warning: Could not quantize {lang_code} model: {e}")
120
+ print("Continuing with unquantized model")
121
+
122
+ return pipeline_obj
123
+
124
+ async def ensure_model_loaded(self, language_code: str):
125
+ """Load ASR model for language if not already loaded"""
126
+ if language_code not in self.asr_pipelines and language_code in self.asr_config:
127
+ try:
128
+ model_config = self.asr_config[language_code]
129
+ pipeline_obj = self._load_and_quantize_pipeline(language_code, model_config)
130
+ self.asr_pipelines[language_code] = pipeline_obj
131
+ except Exception as e:
132
+ print(f"Failed to load ASR model for {language_code}: {e}")
133
+ raise
134
+
135
+ async def process_audio_chunk(self, audio_data: bytes, language_code: str, participant_id: str,
136
+ has_voice_activity: bool = True,
137
+ progress_callback: Optional[Callable] = None,
138
+ sentence_callback: Optional[Callable] = None,
139
+ debug_callback: Optional[Callable] = None) -> str:
140
+ """Process audio chunk with VAD-based sentence detection"""
141
+ try:
142
+ # Initialize buffers if needed
143
+ if participant_id not in self.candidate_audio_buffers:
144
+ # Store as numpy array, not bytes, to avoid multiple WAV header issues
145
+ self.candidate_audio_buffers[participant_id] = np.array([], dtype=np.float32)
146
+ self.candidate_text_cache[participant_id] = ""
147
+ self.silence_counters[participant_id] = 0
148
+ self.sentence_finalized[participant_id] = False
149
+
150
+ # Convert current chunk to numpy array for processing
151
+ current_chunk_array = self._bytes_to_audio_array(audio_data)
152
+ if len(current_chunk_array) == 0:
153
+ print(f"WARNING: Received empty audio chunk for participant {participant_id}")
154
+ return self.candidate_text_cache.get(participant_id, "")
155
+
156
+ print(f"DEBUG: Received audio chunk - bytes: {len(audio_data)}, samples: {len(current_chunk_array)}, "
157
+ f"duration: {len(current_chunk_array)/16000:.3f}s, "
158
+ f"first 4 bytes: {audio_data[:4]}")
159
+
160
+ # DO NOT normalize individual chunks - this causes audio distortion
161
+ # We'll normalize the entire accumulated audio buffer before transcription
162
+ current_chunk_array = current_chunk_array.astype(np.float32)
163
+
164
+ # Get existing accumulated audio array (now stored as numpy array)
165
+ existing_array = self.candidate_audio_buffers[participant_id]
166
+ if len(existing_array) > 0:
167
+ # Concatenate with existing audio (like stream = np.concatenate([stream, y]))
168
+ combined_array = np.concatenate([existing_array, current_chunk_array])
169
+ else:
170
+ combined_array = current_chunk_array
171
+
172
+ # Store as numpy array to avoid WAV header accumulation issues
173
+ self.candidate_audio_buffers[participant_id] = combined_array
174
+
175
+ # For debug callback, convert to bytes (this adds ONE WAV header)
176
+ combined_bytes = self._audio_array_to_bytes(combined_array)
177
+
178
+ # Update silence counter based on voice activity
179
+ if not has_voice_activity:
180
+ self.silence_counters[participant_id] += 1
181
+ else:
182
+ self.silence_counters[participant_id] = 0
183
+
184
+ # Check if we should finalize sentence due to prolonged silence
185
+ should_finalize = (self.silence_counters[participant_id] >= self.silence_threshold and
186
+ len(combined_array) > 0 and
187
+ not self.sentence_finalized[participant_id])
188
+
189
+ if should_finalize:
190
+ return await self._finalize_candidate_sentence(
191
+ language_code, participant_id, sentence_callback
192
+ )
193
+
194
+ # Always run transcription on the accumulated audio
195
+ audio_duration_sec = len(combined_array) / 16000.0 # 16kHz sample rate
196
+
197
+ # Minimum duration check - ignore very short audio bursts
198
+ MIN_CHUNK_DURATION = 0.3 # 300ms minimum
199
+ if audio_duration_sec < MIN_CHUNK_DURATION:
200
+ print(f"Audio chunk too short: {audio_duration_sec:.3f}s < {MIN_CHUNK_DURATION}s, skipping transcription")
201
+ if progress_callback:
202
+ cached_text = self.candidate_text_cache.get(participant_id, "")
203
+ await progress_callback(cached_text, False)
204
+ return self.candidate_text_cache.get(participant_id, "")
205
+
206
+ # Force finalization if buffer gets too long (prevent infinite accumulation)
207
+ if audio_duration_sec > 15.0 and not self.sentence_finalized[participant_id]: # Force completion after 15 seconds
208
+ return await self._finalize_candidate_sentence(
209
+ language_code, participant_id, sentence_callback
210
+ )
211
+
212
+ # Run voice activity detection on the accumulated audio before transcription
213
+ has_voice_in_buffer = self.has_meaningful_voice_activity(combined_bytes)
214
+
215
+ if not has_voice_in_buffer:
216
+ # Still send progress update with cached text to maintain UI state
217
+ if progress_callback:
218
+ cached_text = self.candidate_text_cache.get(participant_id, "")
219
+ await progress_callback(cached_text, False)
220
+ return self.candidate_text_cache.get(participant_id, "")
221
+
222
+ # Run transcription
223
+ await self.ensure_model_loaded(language_code)
224
+
225
+ # Double-check voice activity before running expensive ASR
226
+ has_voice_for_asr = self.has_voice_activity(combined_bytes)
227
+ if not has_voice_for_asr:
228
+ print(f"ASR: No voice activity detected in audio buffer for participant {participant_id}, skipping ASR execution")
229
+ # Return cached text and send progress update
230
+ if progress_callback:
231
+ cached_text = self.candidate_text_cache.get(participant_id, "")
232
+ await progress_callback(cached_text, False)
233
+ return self.candidate_text_cache.get(participant_id, "")
234
+
235
+ if language_code not in self.asr_pipelines:
236
+ raise ValueError(f"ASR model not available for language: {language_code}")
237
+
238
+ print(f"ASR: Running transcription for participant {participant_id} with {len(combined_array)/16000:.2f}s of audio")
239
+ pipeline_obj = self.asr_pipelines[language_code]
240
+
241
+ # Normalize the ENTIRE accumulated audio buffer before transcription
242
+ # This prevents audio distortion from per-chunk normalization
243
+ normalized_array = combined_array.astype(np.float32)
244
+ max_val = np.max(np.abs(normalized_array))
245
+ if max_val > 0:
246
+ normalized_array = normalized_array / max_val
247
+
248
+ # Track transcription latency
249
+ transcription_start_time = time.time()
250
+
251
+ # For wav2vec2 models, request word timestamps
252
+ model_type = self.asr_config[language_code].get("model_type", "whisper")
253
+ if model_type in ["wav2vec2-bert", "wav2vec2"]:
254
+ result = pipeline_obj(
255
+ {"sampling_rate": 16000, "raw": normalized_array},
256
+ return_timestamps="word"
257
+ )
258
+ else:
259
+ # Whisper model - chunked long-form decoding (chunk/stride below) helps curb hallucinations
260
+ # Note: HuggingFace pipeline uses different parameter names than OpenAI Whisper
261
+ result = pipeline_obj(
262
+ {"sampling_rate": 16000, "raw": normalized_array},
263
+ return_timestamps=True,
264
+ chunk_length_s=30, # Process in 30s chunks
265
+ stride_length_s=5 # 5s stride for context
266
+ )
267
+
268
+ transcription_latency_ms = (time.time() - transcription_start_time) * 1000
269
+
270
+ candidate_text = result.get("text", "").strip()
271
+ word_timestamps = result.get("chunks", []) if model_type in ["wav2vec2-bert", "wav2vec2"] else None
272
+
273
+ # Send debug information if callback provided (for wav2vec2 models only)
274
+ if debug_callback and word_timestamps is not None:
275
+ debug_info = {
276
+ "text": candidate_text,
277
+ "timestamps": word_timestamps,
278
+ "audio_data": combined_bytes,
279
+ "audio_duration": audio_duration_sec,
280
+ "model_type": model_type,
281
+ "transcription_latency_ms": transcription_latency_ms
282
+ }
283
+ await debug_callback(debug_info)
284
+
285
+ # Filter out common ASR artifacts and very short responses
286
+ artifacts = [
287
+ "thank you", "thanks", "bye", ".", ",", "?", "!",
288
+ "um", "uh", "ah", "hmm", "mm", "mhm",
289
+ "you", "the", "a", "an", "and", "but", "or",
290
+ "music", "laughter", "applause", "[music]", "[laughter]",
291
+ # Common Whisper hallucinations:
292
+ "subscribe", "subtitles", "amara", "www", "http",
293
+ "please subscribe", "like and subscribe",
294
+ "thank you for watching", "don't forget to subscribe",
295
+ "[blank_audio]", "[noise]", "[silence]",
296
+ ]
297
+
298
+ # Check if the result is likely an artifact
299
+ is_artifact = (
300
+ len(candidate_text) < 3 or # Very short
301
+ candidate_text.lower() in artifacts or # Common artifacts
302
+ (len(candidate_text.split()) == 1 and len(candidate_text) < 6) # Single very short word (explicit grouping)
303
+ )
304
+
305
+ if is_artifact:
306
+ # Keep the previous cached text instead of updating with artifact
307
+ candidate_text = self.candidate_text_cache.get(participant_id, "")
308
+
309
+ # Cache the current candidate text
310
+ self.candidate_text_cache[participant_id] = candidate_text
311
+
312
+ # Force completion if we have a reasonable amount of text and some silence
313
+ word_count = len(candidate_text.split()) if candidate_text else 0
314
+ if (word_count >= 3 and self.silence_counters[participant_id] >= 2 and
315
+ not self.sentence_finalized[participant_id]): # At least 3 words and 2 silent chunks
316
+ return await self._finalize_candidate_sentence(
317
+ language_code, participant_id, sentence_callback
318
+ )
319
+
320
+ # Always send progress update
321
+ if progress_callback:
322
+ await progress_callback(candidate_text, False)
323
+
324
+ return candidate_text
325
+
326
+ except Exception as e:
327
+ print(f"TranscriptionService: Error processing audio chunk: {e}")
328
+ import traceback
329
+ traceback.print_exc()
330
+ # Even on error, try to send cached text
331
+ if progress_callback:
332
+ cached_text = self.candidate_text_cache.get(participant_id, "")
333
+ await progress_callback(cached_text, False)
334
+ return self.candidate_text_cache.get(participant_id, "")
335
+
336
+ async def _finalize_candidate_sentence(self, language_code: str, participant_id: str,
337
+ sentence_callback: Optional[Callable] = None) -> str:
338
+ """Finalize the current candidate sentence and clear buffers"""
339
+ try:
340
+ # Check if sentence was already finalized
341
+ if self.sentence_finalized.get(participant_id, False):
342
+ print(f"Sentence for participant {participant_id} already finalized, skipping duplicate")
343
+ return self.candidate_text_cache.get(participant_id, "")
344
+
345
+ final_text = self.candidate_text_cache.get(participant_id, "")
346
+ final_audio_array = self.candidate_audio_buffers.get(participant_id, np.array([], dtype=np.float32))
347
+
348
+ # Convert audio array to bytes for VAD check and callback
349
+ final_audio_bytes = self._audio_array_to_bytes(final_audio_array) if len(final_audio_array) > 0 else b''
350
+
351
+ if final_text and len(final_text.strip()) > 0:
352
+ # Run VAD check on the final accumulated buffer before sending for translation
353
+ if len(final_audio_bytes) > 0:
354
+ has_voice_in_final = self.has_meaningful_voice_activity(final_audio_bytes)
355
+ if not has_voice_in_final:
356
+ print(f"Finalize: No voice activity in final buffer for participant {participant_id}, discarding sentence: '{final_text}'")
357
+ # Clear buffers without sending to translation
358
+ self.candidate_audio_buffers[participant_id] = np.array([], dtype=np.float32)
359
+ self.candidate_text_cache[participant_id] = ""
360
+ self.silence_counters[participant_id] = 0
361
+ self.sentence_finalized[participant_id] = False
362
+ return ""
363
+
364
+ # Mark as finalized BEFORE calling the callback to prevent race conditions
365
+ self.sentence_finalized[participant_id] = True
366
+
367
+ # Send to sentence callback for translation
368
+ if sentence_callback and len(final_audio_bytes) > 0:
369
+ print(f"Finalizing sentence for participant {participant_id}: '{final_text}'")
370
+ await sentence_callback(final_text, final_audio_bytes)
371
+
372
+ # Clear buffers for next sentence
373
+ self.candidate_audio_buffers[participant_id] = np.array([], dtype=np.float32)
374
+ self.candidate_text_cache[participant_id] = ""
375
+ self.silence_counters[participant_id] = 0
376
+ self.sentence_finalized[participant_id] = False # Reset for next sentence
377
+
378
+ return final_text
379
+
380
+ except Exception as e:
381
+ print(f"Error finalizing sentence: {e}")
382
+ import traceback
383
+ traceback.print_exc()
384
+ # Reset finalized flag on error
385
+ self.sentence_finalized[participant_id] = False
386
+ return ""
387
+
388
+ def has_voice_activity(self, audio_data: bytes, threshold: float = 0.5) -> bool:
389
+ """Voice Activity Detection using Silero VAD (with RMS fallback)"""
390
+ try:
391
+ audio_array = self._bytes_to_audio_array(audio_data)
392
+ if len(audio_array) == 0:
393
+ print("VAD: No audio array, returning False")
394
+ return False
395
+
396
+ # Normalize audio to float32 range [-1, 1]
397
+ audio_array = audio_array.astype(np.float32)
398
+ if np.max(np.abs(audio_array)) > 0:
399
+ audio_array /= np.max(np.abs(audio_array))
400
+
401
+ # Use Silero VAD if available
402
+ if self.vad_available and self.vad_model is not None:
403
+ try:
404
+ # Silero VAD expects 512 samples (32ms) or 1536 samples (96ms) for 16kHz
405
+ # Process audio in chunks and average the probabilities
406
+ frame_size = 512 # 32ms at 16kHz
407
+ num_samples = len(audio_array)
408
+
409
+ # If audio is too short, pad it
410
+ if num_samples < frame_size:
411
+ audio_array = np.pad(audio_array, (0, frame_size - num_samples), mode='constant')
412
+ num_samples = frame_size
413
+
414
+ # Process in frames and collect probabilities
415
+ speech_probs = []
416
+ for i in range(0, num_samples, frame_size):
417
+ frame = audio_array[i:i + frame_size]
418
+ if len(frame) < frame_size:
419
+ # Pad last frame if needed
420
+ frame = np.pad(frame, (0, frame_size - len(frame)), mode='constant')
421
+
422
+ # Convert to torch tensor
423
+ frame_tensor = torch.from_numpy(frame).float()
424
+
425
+ # Get speech probability from Silero VAD
426
+ with torch.no_grad():
427
+ prob = self.vad_model(frame_tensor, self.vad_sample_rate).item()
428
+ speech_probs.append(prob)
429
+
430
+ # Average probability across all frames
431
+ speech_prob = np.mean(speech_probs)
432
+ has_voice = speech_prob > threshold
433
+
434
+ print(f"VAD: Silero speech_prob={speech_prob:.4f} (avg of {len(speech_probs)} frames), threshold={threshold}, RESULT={has_voice}")
435
+
436
+ return has_voice
437
+
438
+ except Exception as e:
439
+ print(f"Silero VAD error: {e}, falling back to RMS-based VAD")
440
+ # Fall through to RMS-based VAD below
441
+
442
+ # Fallback: RMS-based VAD (original implementation)
443
+ rms_threshold = 0.002
444
+ rms = np.sqrt(np.mean(audio_array ** 2))
445
+ peak = np.max(np.abs(audio_array))
446
+ audio_std = np.std(audio_array)
447
+ zero_crossing_rate = np.sum(np.diff(np.sign(audio_array)) != 0) / len(audio_array)
448
+
449
+ has_voice_rms = rms > rms_threshold
450
+ has_voice_peak = peak > rms_threshold * 3
451
+ has_voice_variation = audio_std > rms_threshold * 0.8
452
+ has_voice_zcr = zero_crossing_rate > 0.008
453
+
454
+ has_voice = has_voice_rms or (has_voice_peak and has_voice_variation) or has_voice_zcr
455
+
456
+ print(f"VAD: RMS-based - RMS={rms:.6f}({has_voice_rms}), peak={peak:.6f}({has_voice_peak}), std={audio_std:.6f}({has_voice_variation}), zcr={zero_crossing_rate:.6f}({has_voice_zcr}), RESULT={has_voice}")
457
+
458
+ return has_voice
459
+
460
+ except Exception as e:
461
+ print(f"Error in VAD: {e}")
462
+ return True # Default to assuming voice activity on error
463
+
464
+ def has_meaningful_voice_activity(self, audio_data: bytes, threshold: float = 0.005) -> bool:
465
+ """Stricter VAD check specifically for pre-transcription filtering"""
466
+ try:
467
+ audio_array = self._bytes_to_audio_array(audio_data)
468
+ if len(audio_array) == 0:
469
+ return False
470
+
471
+ # Normalize audio
472
+ audio_array = audio_array.astype(np.float32)
473
+ if np.max(np.abs(audio_array)) > 0:
474
+ audio_array /= np.max(np.abs(audio_array))
475
+
476
+ # Calculate features with higher thresholds for meaningful speech
477
+ rms = np.sqrt(np.mean(audio_array ** 2))
478
+ peak = np.max(np.abs(audio_array))
479
+ audio_std = np.std(audio_array)
480
+ zero_crossing_rate = np.sum(np.diff(np.sign(audio_array)) != 0) / len(audio_array)
481
+
482
+ # Higher thresholds for meaningful speech detection
483
+ has_meaningful_voice = (
484
+ rms > threshold and
485
+ peak > threshold * 2 and
486
+ audio_std > threshold * 0.5 and
487
+ zero_crossing_rate > 0.015 # Higher ZCR threshold for meaningful speech
488
+ )
489
+
490
+ return has_meaningful_voice
491
+
492
+ except Exception as e:
493
+ print(f"Error in meaningful VAD: {e}")
494
+ return False # Default to no meaningful voice on error
495
+
496
+ async def force_complete_sentence(self, participant_id: str, language_code: str, sentence_callback: Optional[Callable] = None) -> str:
497
+ """Force complete any pending sentence for a participant"""
498
+ try:
499
+ # Check if sentence was already finalized
500
+ if self.sentence_finalized.get(participant_id, False):
501
+ print(f"Force completion: Sentence for participant {participant_id} already finalized, skipping")
502
+ return ""
503
+
504
+ if participant_id in self.candidate_text_cache:
505
+ cached_text = self.candidate_text_cache[participant_id]
506
+
507
+ if cached_text and len(cached_text.strip()) > 0:
508
+ result = await self._finalize_candidate_sentence(language_code, participant_id, sentence_callback)
509
+ return result
510
+
511
+ return ""
512
+
513
+ except Exception as e:
514
+ print(f"Error in force_complete_sentence: {e}")
515
+ import traceback
516
+ traceback.print_exc()
517
+ return ""
518
+
519
+ async def transcribe_audio(self, audio_data: bytes, language_code: str, callback: Optional[Callable] = None) -> str:
520
+ """Transcribe audio data to text"""
521
+ try:
522
+ # Check for voice activity before running ASR
523
+ has_voice = self.has_voice_activity(audio_data)
524
+ if not has_voice:
525
+ print("ASR: No voice activity detected in audio data, skipping transcription")
526
+ return ""
527
+
528
+ await self.ensure_model_loaded(language_code)
529
+
530
+ if language_code not in self.asr_pipelines:
531
+ raise ValueError(f"ASR model not available for language: {language_code}")
532
+
533
+ # Convert audio bytes to numpy array
534
+ audio_array = self._bytes_to_audio_array(audio_data)
535
+
536
+ print(f"ASR: Running transcription with {len(audio_array)/16000:.2f}s of audio")
537
+ # Transcribe
538
+ pipeline_obj = self.asr_pipelines[language_code]
539
+ result = pipeline_obj({"sampling_rate": 16000, "raw": audio_array})
540
+
541
+ text = result.get("text", "")
542
+
543
+ if callback:
544
+ await callback(text)
545
+
546
+ return text
547
+
548
+ except Exception as e:
549
+ print(f"TranscriptionService: Transcription error: {e}")
550
+ import traceback
551
+ traceback.print_exc()
552
+ return ""
553
+
554
+ def _bytes_to_audio_array(self, audio_data: bytes) -> np.ndarray:
555
+ """Convert audio bytes to numpy array (supports WAV, WebM/Opus)"""
556
+ try:
557
+ # Detect format by checking magic bytes
558
+ is_webm = audio_data[:4] == b'\x1a\x45\xdf\xa3' # WebM/Matroska magic bytes
559
+ is_wav = audio_data[:4] == b'RIFF'
560
+
561
+ import sys
562
+ print(f"_bytes_to_audio_array: length={len(audio_data)}, first 4 bytes={audio_data[:4]}, is_wav={is_wav}", flush=True)
563
+ sys.stdout.flush()
564
+
565
+ # Handle raw PCM (16-bit, 48kHz from extendable-media-recorder)
566
+ # This is the most common case for microphone input
567
+ if not is_wav and not is_webm and len(audio_data) > 0:
568
+ try:
569
+ # Assume 16-bit PCM at 48kHz (browser's native rate)
570
+ audio_array = np.frombuffer(audio_data, dtype=np.int16)
571
+
572
+ # Check if this looks like valid audio data (not NaN, reasonable range)
573
+ if len(audio_array) > 0 and not np.isnan(audio_array).any():
574
+ print(f"Raw PCM: {len(audio_array)} samples, assuming 48kHz 16-bit", flush=True)
575
+
576
+ # Convert to float32 and normalize
577
+ audio_float = audio_array.astype(np.float32) / 32768.0
578
+
579
+ # Resample from 48kHz to 16kHz
580
+ import librosa
581
+ audio_array = librosa.resample(audio_float, orig_sr=48000, target_sr=16000)
582
+ print(f"Resampled to 16kHz: {len(audio_array)} samples", flush=True)
583
+
584
+ return audio_array
585
+ except Exception as pcm_error:
586
+ print(f"TranscriptionService: Raw PCM decoding error: {pcm_error}", flush=True)
587
+ # Fall through to other methods
588
+
589
+ if is_webm:
590
+ # Decode WebM/Opus using pydub (requires ffmpeg)
591
+ try:
592
+ from pydub import AudioSegment
593
+ audio_io = io.BytesIO(audio_data)
594
+ audio_segment = AudioSegment.from_file(audio_io, format="webm")
595
+
596
+ # Convert to mono 16kHz
597
+ audio_segment = audio_segment.set_channels(1)
598
+ audio_segment = audio_segment.set_frame_rate(16000)
599
+
600
+ # Convert to numpy array
601
+ samples = np.array(audio_segment.get_array_of_samples(), dtype=np.int16)
602
+ # Normalize to float32 [-1, 1]
603
+ audio_array = samples.astype(np.float32) / 32768.0
604
+ return audio_array
605
+ except Exception as webm_error:
606
+ print(f"TranscriptionService: WebM decoding error: {webm_error}")
607
+ # Fall through to other methods
608
+
609
+ if is_wav:
610
+ # Decode WAV format (first chunk from frontend includes WAV header with sample rate)
611
+ try:
612
+ audio_io = io.BytesIO(audio_data)
613
+ with wave.open(audio_io, 'rb') as wav_file:
614
+ sample_rate = wav_file.getframerate()
615
+ channels = wav_file.getnchannels()
616
+ sample_width = wav_file.getsampwidth()
617
+
618
+ print(f"WAV format: {sample_rate}Hz, {channels} channel(s), {sample_width*8}-bit", flush=True)
619
+
620
+ frames = wav_file.readframes(-1)
621
+ audio_array = np.frombuffer(frames, dtype=np.int16)
622
+
623
+ # Resample if needed
624
+ if sample_rate != 16000:
625
+ print(f"WARNING: Resampling from {sample_rate}Hz to 16000Hz", flush=True)
626
+ import librosa
627
+ # Convert to float first
628
+ audio_float = audio_array.astype(np.float32) / 32768.0
629
+ # Resample
630
+ audio_array = librosa.resample(audio_float, orig_sr=sample_rate, target_sr=16000)
631
+ print(f"Resampled: {len(audio_array)} samples at 16kHz", flush=True)
632
+ else:
633
+ # Convert to float32 and normalize
634
+ audio_array = audio_array.astype(np.float32) / 32768.0
635
+
636
+ print(f"Returning audio array: {len(audio_array)} samples", flush=True)
637
+ return audio_array
638
+ except Exception as wav_error:
639
+ print(f"TranscriptionService: WAV decoding error: {wav_error}")
640
+ import traceback
641
+ traceback.print_exc()
642
+
643
+ # Fallback: assume raw float32 audio data
644
+ try:
645
+ audio_array = np.frombuffer(audio_data, dtype=np.float32)
646
+ return audio_array
647
+ except Exception:
648
+ pass
649
+
650
+ # Last resort: return empty array
651
+ return np.array([], dtype=np.float32)
652
+
653
+ except Exception as e:
654
+ print(f"TranscriptionService: Audio conversion error: {e}")
655
+ return np.array([], dtype=np.float32)
656
+
657
+ def _audio_array_to_bytes(self, audio_array: np.ndarray) -> bytes:
658
+ """Convert numpy audio array back to WAV bytes for storage"""
659
+ try:
660
+ # Ensure float32 format
661
+ if audio_array.dtype != np.float32:
662
+ audio_array = audio_array.astype(np.float32)
663
+
664
+ # Convert to 16-bit PCM for WAV storage
665
+ audio_int16 = (np.clip(audio_array, -1.0, 1.0) * 32767).astype(np.int16) # clip to [-1, 1] to avoid int16 overflow
666
+
667
+ # Create WAV bytes
668
+ wav_buffer = io.BytesIO()
669
+ with wave.open(wav_buffer, 'wb') as wav_file:
670
+ wav_file.setnchannels(1) # Mono
671
+ wav_file.setsampwidth(2) # 16-bit
672
+ wav_file.setframerate(16000) # 16kHz
673
+ wav_file.writeframes(audio_int16.tobytes())
674
+
675
+ return wav_buffer.getvalue()
676
+
677
+ except Exception as e:
678
+ print(f"Error converting audio array to bytes: {e}")
679
+ return b''
680
+
681
+ def clear_participant_buffers(self, participant_id: str):
682
+ """Clear all buffers for a participant (e.g., when they stop speaking or disconnect)"""
683
+ if participant_id in self.candidate_audio_buffers:
684
+ del self.candidate_audio_buffers[participant_id]
685
+ if participant_id in self.candidate_text_cache:
686
+ del self.candidate_text_cache[participant_id]
687
+ if participant_id in self.silence_counters:
688
+ del self.silence_counters[participant_id]
689
+ if participant_id in self.sentence_finalized:
690
+ del self.sentence_finalized[participant_id]
691
+
692
+ async def load_remaining_models_in_background(self):
693
+ """Load all remaining ASR models in the background after startup"""
694
+ try:
695
+ print("ASR: Starting background loading of additional language models...")
696
+ for lang_code in self.asr_config.keys():
697
+ if lang_code not in self.preload_languages and lang_code not in self.asr_pipelines:
698
+ try:
699
+ print(f"ASR: Background loading model for {lang_code}...")
700
+ self.models_loading_status[lang_code] = "loading"
701
+
702
+ model_config = self.asr_config[lang_code]
703
+ # Use quantization helper for background loading too
704
+ pipeline_obj = self._load_and_quantize_pipeline(lang_code, model_config)
705
+ self.asr_pipelines[lang_code] = pipeline_obj
706
+ self.models_loading_status[lang_code] = "loaded"
707
+ print(f"ASR: Successfully loaded model for {lang_code} in background")
708
+
709
+ # Add a small delay between loading models to prevent overwhelming the system
710
+ await asyncio.sleep(2)
711
+ except Exception as e:
712
+ print(f"ASR: Failed to load model for {lang_code} in background: {e}")
713
+ self.models_loading_status[lang_code] = "failed"
714
+
715
+ print("ASR: Background loading of all language models complete")
716
+ print(f"ASR: Loaded models: {list(self.asr_pipelines.keys())}")
717
+ except Exception as e:
718
+ print(f"ASR: Error in background model loading: {e}")
719
+
720
+ def start_background_loading(self):
721
+ """Start background loading of models as a non-blocking task"""
722
+ if self.background_loading_task is None:
723
+ self.background_loading_task = asyncio.create_task(self.load_remaining_models_in_background())
724
+ print("ASR: Background model loading task started")
725
+
726
+ async def cleanup(self):
727
+ """Cleanup resources"""
728
+ # Cancel background loading if still running
729
+ if self.background_loading_task and not self.background_loading_task.done():
730
+ self.background_loading_task.cancel()
731
+ try:
732
+ await self.background_loading_task
733
+ except asyncio.CancelledError:
734
+ pass
735
+
736
+ self.asr_pipelines.clear()
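`apply_dynamic_int8_quantization` and `get_quantization_stats` come from `app/services/quantization_utils.py`, which is not part of this excerpt. Below is a minimal sketch of what the helper plausibly wraps, using PyTorch's standard dynamic-quantization API; the function name and the Linear-only target set are assumptions, not the committed implementation:

```python
# Sketch only: dynamic INT8 quantization via PyTorch's built-in API.
# Weights of the selected layer types are stored as int8 and activations are
# quantized on the fly at inference time; this targets CPU execution, which
# matches the service default (device = -1 when no GPU is available).
import torch

def apply_dynamic_int8_quantization_sketch(model: torch.nn.Module) -> torch.nn.Module:
    return torch.quantization.quantize_dynamic(
        model.eval(),        # quantize in eval mode
        {torch.nn.Linear},   # layer types to convert (assumed: Linear only)
        dtype=torch.qint8,   # int8 weights
    )
```

Dynamic quantization pays off most on Linear-heavy encoder models such as the Wav2Vec2-BERT checkpoints loaded above, which is consistent with this commit enabling `ENABLE_INT8_QUANTIZATION` by default.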
app/services/transcription_service.py.bak ADDED
@@ -0,0 +1,726 @@
1
+ import asyncio
2
+ import io
3
+ import wave
4
+ import numpy as np
5
+ import time
6
+ from typing import Any, Callable, Dict, Optional
7
+ from transformers import pipeline
8
+ import torch
9
+ from app.models import LanguageCode
10
+ from app.services.performance_mixin import track_performance
11
+
12
+ # Silero VAD imports
13
+ try:
14
+ import silero_vad
15
+ SILERO_VAD_AVAILABLE = True
16
+ except ImportError:
17
+ SILERO_VAD_AVAILABLE = False
18
+ print("Warning: silero-vad not installed. Falling back to RMS-based VAD.")
19
+
20
+ class TranscriptionService:
21
+ def __init__(self):
22
+ self.asr_pipelines: Dict[str, Any] = {}
23
+ self.device = 0 if torch.cuda.is_available() else -1
24
+
25
+ # Model configurations - using original mutisya models with updated config
26
+ self.asr_config = {
27
+ "eng": {"model_repo": "openai/whisper-base.en", "model_type": "whisper"},
28
+ "swa": {"model_repo": "mutisya/w2v-bert-2.0-asr-swh-superv-v25-37-1", "model_type": "wav2vec2-bert", "trust_remote_code": True},
29
+ "kik": {"model_repo": "mutisya/w2v-bert-2.0-asr-kik-superv-v25-37-1", "model_type": "wav2vec2-bert", "trust_remote_code": True},
30
+ "kam": {"model_repo": "mutisya/w2v-bert-2.0-asr-kam-superv-v25-36-1", "model_type": "wav2vec2-bert", "trust_remote_code": True},
31
+ "mer": {"model_repo": "mutisya/w2v-bert-2.0-asr-mer-superv-v25-36-1", "model_type": "wav2vec2-bert", "trust_remote_code": True},
32
+ "luo": {"model_repo": "mutisya/w2v-bert-2.0-asr-luo-superv-v25-36-1", "model_type": "wav2vec2-bert", "trust_remote_code": True},
33
+ "som": {"model_repo": "mutisya/w2v-bert-2.0-asr-som-superv-v25-36-1", "model_type": "wav2vec2-bert", "trust_remote_code": True}
34
+ }
35
+
36
+ self.preload_languages = ["eng"]
37
+ self.background_loading_task = None
38
+ self.models_loading_status = {}
39
+
40
+ # Enhanced audio buffering for VAD-based sentence detection
41
+ self.candidate_audio_buffers: Dict[str, np.ndarray] = {} # participant_id -> accumulated audio samples (stored as float32 arrays, not bytes)
42
+ self.candidate_text_cache: Dict[str, str] = {} # participant_id -> current candidate text
43
+ self.silence_counters: Dict[str, int] = {} # participant_id -> consecutive silence chunks
44
+ self.sentence_finalized: Dict[str, bool] = {} # participant_id -> whether current sentence is already finalized
45
+
46
+ # VAD parameters - made more lenient for better detection
47
+ self.silence_threshold = 1 # Number of consecutive silent chunks before sentence break (1 second for natural pauses)
48
+ self.min_sentence_length = 0.03 # Minimum sentence length in seconds (very short)
49
+
50
+ # Silero VAD initialization
51
+ self.vad_model = None
52
+ self.vad_sample_rate = 16000
53
+ self.vad_available = SILERO_VAD_AVAILABLE
54
+
55
+ async def initialize(self):
56
+ """Initialize ASR models for preloaded languages and Silero VAD"""
57
+ # Initialize Silero VAD model
58
+ if self.vad_available:
59
+ try:
60
+ print("Loading Silero VAD model...")
61
+ self.vad_model = silero_vad.load_silero_vad(onnx=False)
62
+ print("✓ Silero VAD model loaded successfully")
63
+ except Exception as e:
64
+ print(f"Failed to load Silero VAD model: {e}")
65
+ print("Falling back to RMS-based VAD")
66
+ self.vad_available = False
67
+
68
+ # Initialize ASR models
69
+ for lang_code in self.preload_languages:
70
+ if lang_code in self.asr_config:
71
+ try:
72
+ model_config = self.asr_config[lang_code]
73
+ # Build pipeline parameters
74
+ pipeline_params = {
75
+ "task": "automatic-speech-recognition",
76
+ "model": model_config["model_repo"],
77
+ "device": self.device,
78
+ "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32
79
+ }
80
+
81
+ # Add trust_remote_code if specified
82
+ if model_config.get("trust_remote_code", False):
83
+ pipeline_params["trust_remote_code"] = True
84
+
85
+ pipeline_obj = pipeline(**pipeline_params)
86
+ self.asr_pipelines[lang_code] = pipeline_obj
87
+ except Exception as e:
88
+ print(f"Failed to load ASR model for {lang_code}: {e}")
89
+
90
+ async def ensure_model_loaded(self, language_code: str):
91
+ """Load ASR model for language if not already loaded"""
92
+ if language_code not in self.asr_pipelines and language_code in self.asr_config:
93
+ try:
94
+ model_config = self.asr_config[language_code]
95
+ # Build pipeline parameters
96
+ pipeline_params = {
97
+ "task": "automatic-speech-recognition",
98
+ "model": model_config["model_repo"],
99
+ "device": self.device,
100
+ "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32
101
+ }
102
+
103
+ # Add trust_remote_code if specified
104
+ if model_config.get("trust_remote_code", False):
105
+ pipeline_params["trust_remote_code"] = True
106
+
107
+ pipeline_obj = pipeline(**pipeline_params)
108
+ self.asr_pipelines[language_code] = pipeline_obj
109
+ except Exception as e:
110
+ print(f"Failed to load ASR model for {language_code}: {e}")
111
+ raise
112
+
113
+ async def process_audio_chunk(self, audio_data: bytes, language_code: str, participant_id: str,
114
+ has_voice_activity: bool = True,
115
+ progress_callback: Optional[Callable] = None,
116
+ sentence_callback: Optional[Callable] = None,
117
+ debug_callback: Optional[Callable] = None) -> str:
118
+ """Process audio chunk with VAD-based sentence detection"""
119
+ try:
120
+ # Initialize buffers if needed
121
+ if participant_id not in self.candidate_audio_buffers:
122
+ # Store as numpy array, not bytes, to avoid multiple WAV header issues
123
+ self.candidate_audio_buffers[participant_id] = np.array([], dtype=np.float32)
124
+ self.candidate_text_cache[participant_id] = ""
125
+ self.silence_counters[participant_id] = 0
126
+ self.sentence_finalized[participant_id] = False
127
+
128
+ # Convert current chunk to numpy array for processing
129
+ current_chunk_array = self._bytes_to_audio_array(audio_data)
130
+ if len(current_chunk_array) == 0:
131
+ print(f"WARNING: Received empty audio chunk for participant {participant_id}")
132
+ return self.candidate_text_cache.get(participant_id, "")
133
+
134
+ print(f"DEBUG: Received audio chunk - bytes: {len(audio_data)}, samples: {len(current_chunk_array)}, "
135
+ f"duration: {len(current_chunk_array)/16000:.3f}s, "
136
+ f"first 4 bytes: {audio_data[:4]}")
137
+
138
+ # DO NOT normalize individual chunks - this causes audio distortion
139
+ # We'll normalize the entire accumulated audio buffer before transcription
140
+ current_chunk_array = current_chunk_array.astype(np.float32)
141
+
142
+ # Get existing accumulated audio array (now stored as numpy array)
143
+ existing_array = self.candidate_audio_buffers[participant_id]
144
+ if len(existing_array) > 0:
145
+ # Concatenate with existing audio (like stream = np.concatenate([stream, y]))
146
+ combined_array = np.concatenate([existing_array, current_chunk_array])
147
+ else:
148
+ combined_array = current_chunk_array
149
+
150
+ # Store as numpy array to avoid WAV header accumulation issues
151
+ self.candidate_audio_buffers[participant_id] = combined_array
152
+
153
+ # For debug callback, convert to bytes (this adds ONE WAV header)
154
+ combined_bytes = self._audio_array_to_bytes(combined_array)
155
+
156
+ # Update silence counter based on voice activity
157
+ if not has_voice_activity:
158
+ self.silence_counters[participant_id] += 1
159
+ else:
160
+ self.silence_counters[participant_id] = 0
161
+
162
+ # Check if we should finalize sentence due to prolonged silence
163
+ should_finalize = (self.silence_counters[participant_id] >= self.silence_threshold and
164
+ len(combined_array) > 0 and
165
+ not self.sentence_finalized[participant_id])
166
+
167
+ if should_finalize:
168
+ return await self._finalize_candidate_sentence(
169
+ language_code, participant_id, sentence_callback
170
+ )
171
+
172
+ # Always run transcription on the accumulated audio
173
+ audio_duration_sec = len(combined_array) / 16000.0 # 16kHz sample rate
174
+
175
+ # Minimum duration check - ignore very short audio bursts
176
+ MIN_CHUNK_DURATION = 0.3 # 300ms minimum
177
+ if audio_duration_sec < MIN_CHUNK_DURATION:
178
+ print(f"Audio chunk too short: {audio_duration_sec:.3f}s < {MIN_CHUNK_DURATION}s, skipping transcription")
179
+ if progress_callback:
180
+ cached_text = self.candidate_text_cache.get(participant_id, "")
181
+ await progress_callback(cached_text, False)
182
+ return self.candidate_text_cache.get(participant_id, "")
183
+
184
+ # Force finalization if buffer gets too long (prevent infinite accumulation)
185
+ if audio_duration_sec > 15.0 and not self.sentence_finalized[participant_id]: # Force completion after 15 seconds
186
+ return await self._finalize_candidate_sentence(
187
+ language_code, participant_id, sentence_callback
188
+ )
189
+
190
+ # Run voice activity detection on the accumulated audio before transcription
191
+ has_voice_in_buffer = self.has_meaningful_voice_activity(combined_bytes)
192
+
193
+ if not has_voice_in_buffer:
194
+ # Still send progress update with cached text to maintain UI state
195
+ if progress_callback:
196
+ cached_text = self.candidate_text_cache.get(participant_id, "")
197
+ await progress_callback(cached_text, False)
198
+ return self.candidate_text_cache.get(participant_id, "")
199
+
200
+ # Run transcription
201
+ await self.ensure_model_loaded(language_code)
202
+
203
+ # Double-check voice activity before running expensive ASR
204
+ has_voice_for_asr = self.has_voice_activity(combined_bytes)
205
+ if not has_voice_for_asr:
206
+ print(f"ASR: No voice activity detected in audio buffer for participant {participant_id}, skipping ASR execution")
207
+ # Return cached text and send progress update
208
+ if progress_callback:
209
+ cached_text = self.candidate_text_cache.get(participant_id, "")
210
+ await progress_callback(cached_text, False)
211
+ return self.candidate_text_cache.get(participant_id, "")
212
+
213
+ if language_code not in self.asr_pipelines:
214
+ raise ValueError(f"ASR model not available for language: {language_code}")
215
+
216
+ print(f"ASR: Running transcription for participant {participant_id} with {len(combined_array)/16000:.2f}s of audio")
217
+ pipeline_obj = self.asr_pipelines[language_code]
218
+
219
+ # Normalize the ENTIRE accumulated audio buffer before transcription
220
+ # This prevents audio distortion from per-chunk normalization
221
+ normalized_array = combined_array.astype(np.float32)
222
+ max_val = np.max(np.abs(normalized_array))
223
+ if max_val > 0:
224
+ normalized_array = normalized_array / max_val
225
+
226
+ # Track transcription latency
227
+ transcription_start_time = time.time()
228
+
229
+ # For wav2vec2 models, request word timestamps
230
+ model_type = self.asr_config[language_code].get("model_type", "whisper")
231
+ if model_type in ["wav2vec2-bert", "wav2vec2"]:
232
+ result = pipeline_obj(
233
+ {"sampling_rate": 16000, "raw": normalized_array},
234
+ return_timestamps="word"
235
+ )
236
+ else:
237
+ # Whisper model - add anti-hallucination parameters
238
+ # Note: HuggingFace pipeline uses different parameter names than OpenAI Whisper
239
+ result = pipeline_obj(
240
+ {"sampling_rate": 16000, "raw": normalized_array},
241
+ return_timestamps=True,
242
+ chunk_length_s=30, # Process in 30s chunks
243
+ stride_length_s=5 # 5s stride for context
244
+ )
245
+
246
+ transcription_latency_ms = (time.time() - transcription_start_time) * 1000
247
+
248
+ candidate_text = result.get("text", "").strip()
249
+ word_timestamps = result.get("chunks", []) if model_type in ["wav2vec2-bert", "wav2vec2"] else None
250
+
251
+ # Send debug information if callback provided (for wav2vec2 models only)
252
+ if debug_callback and word_timestamps is not None:
253
+ debug_info = {
254
+ "text": candidate_text,
255
+ "timestamps": word_timestamps,
256
+ "audio_data": combined_bytes,
257
+ "audio_duration": audio_duration_sec,
258
+ "model_type": model_type,
259
+ "transcription_latency_ms": transcription_latency_ms
260
+ }
261
+ await debug_callback(debug_info)
262
+
263
+ # Filter out common ASR artifacts and very short responses
264
+ artifacts = [
265
+ "thank you", "thanks", "bye", ".", ",", "?", "!",
266
+ "um", "uh", "ah", "hmm", "mm", "mhm",
267
+ "you", "the", "a", "an", "and", "but", "or",
268
+ "music", "laughter", "applause", "[music]", "[laughter]",
269
+ # Common Whisper hallucinations:
270
+ "subscribe", "subtitles", "amara", "www", "http",
271
+ "please subscribe", "like and subscribe",
272
+ "thank you for watching", "don't forget to subscribe",
273
+ "[blank_audio]", "[noise]", "[silence]",
274
+ ]
275
+
276
+ # Check if the result is likely an artifact
277
+ is_artifact = (
278
+ len(candidate_text) < 3 or # Very short
279
+ candidate_text.lower() in artifacts or # Common artifacts
280
+ len(candidate_text.split()) == 1 and len(candidate_text) < 6 # Single very short word
281
+ )
282
+
283
+ if is_artifact:
284
+ # Keep the previous cached text instead of updating with artifact
285
+ candidate_text = self.candidate_text_cache.get(participant_id, "")
286
+
287
+ # Cache the current candidate text
288
+ self.candidate_text_cache[participant_id] = candidate_text
289
+
290
+ # Force completion if we have a reasonable amount of text and some silence
291
+ word_count = len(candidate_text.split()) if candidate_text else 0
292
+ if (word_count >= 3 and self.silence_counters[participant_id] >= 2 and
293
+ not self.sentence_finalized[participant_id]): # At least 3 words and 2 silent chunks
294
+ return await self._finalize_candidate_sentence(
295
+ language_code, participant_id, sentence_callback
296
+ )
297
+
298
+ # Always send progress update
299
+ if progress_callback:
300
+ await progress_callback(candidate_text, False)
301
+
302
+ return candidate_text
303
+
304
+ except Exception as e:
305
+ print(f"TranscriptionService: Error processing audio chunk: {e}")
306
+ import traceback
307
+ traceback.print_exc()
308
+ # Even on error, try to send cached text
309
+ if progress_callback:
310
+ cached_text = self.candidate_text_cache.get(participant_id, "")
311
+ await progress_callback(cached_text, False)
312
+ return self.candidate_text_cache.get(participant_id, "")
313
+
314
+ async def _finalize_candidate_sentence(self, language_code: str, participant_id: str,
315
+ sentence_callback: Optional[Callable] = None) -> str:
316
+ """Finalize the current candidate sentence and clear buffers"""
317
+ try:
318
+ # Check if sentence was already finalized
319
+ if self.sentence_finalized.get(participant_id, False):
320
+ print(f"Sentence for participant {participant_id} already finalized, skipping duplicate")
321
+ return self.candidate_text_cache.get(participant_id, "")
322
+
323
+ final_text = self.candidate_text_cache.get(participant_id, "")
324
+ final_audio_array = self.candidate_audio_buffers.get(participant_id, np.array([], dtype=np.float32))
325
+
326
+ # Convert audio array to bytes for VAD check and callback
327
+ final_audio_bytes = self._audio_array_to_bytes(final_audio_array) if len(final_audio_array) > 0 else b''
328
+
329
+ if final_text and len(final_text.strip()) > 0:
330
+ # Run VAD check on the final accumulated buffer before sending for translation
331
+ if len(final_audio_bytes) > 0:
332
+ has_voice_in_final = self.has_meaningful_voice_activity(final_audio_bytes)
333
+ if not has_voice_in_final:
334
+ print(f"Finalize: No voice activity in final buffer for participant {participant_id}, discarding sentence: '{final_text}'")
335
+ # Clear buffers without sending to translation
336
+ self.candidate_audio_buffers[participant_id] = np.array([], dtype=np.float32)
337
+ self.candidate_text_cache[participant_id] = ""
338
+ self.silence_counters[participant_id] = 0
339
+ self.sentence_finalized[participant_id] = False
340
+ return ""
341
+
342
+ # Mark as finalized BEFORE calling the callback to prevent race conditions
343
+ self.sentence_finalized[participant_id] = True
344
+
345
+ # Send to sentence callback for translation
346
+ if sentence_callback and len(final_audio_bytes) > 0:
347
+ print(f"Finalizing sentence for participant {participant_id}: '{final_text}'")
348
+ await sentence_callback(final_text, final_audio_bytes)
349
+
350
+ # Clear buffers for next sentence
351
+ self.candidate_audio_buffers[participant_id] = np.array([], dtype=np.float32)
352
+ self.candidate_text_cache[participant_id] = ""
353
+ self.silence_counters[participant_id] = 0
354
+ self.sentence_finalized[participant_id] = False # Reset for next sentence
355
+
356
+ return final_text
357
+
358
+ except Exception as e:
359
+ print(f"Error finalizing sentence: {e}")
360
+ import traceback
361
+ traceback.print_exc()
362
+ # Reset finalized flag on error
363
+ self.sentence_finalized[participant_id] = False
364
+ return ""
365
+
366
+ def has_voice_activity(self, audio_data: bytes, threshold: float = 0.5) -> bool:
367
+ """Voice Activity Detection using Silero VAD (with RMS fallback)"""
368
+ try:
369
+ audio_array = self._bytes_to_audio_array(audio_data)
370
+ if len(audio_array) == 0:
371
+ print("VAD: No audio array, returning False")
372
+ return False
373
+
374
+ # Normalize audio to float32 range [-1, 1]
375
+ audio_array = audio_array.astype(np.float32)
376
+ if np.max(np.abs(audio_array)) > 0:
377
+ audio_array /= np.max(np.abs(audio_array))
378
+
379
+ # Use Silero VAD if available
380
+ if self.vad_available and self.vad_model is not None:
381
+ try:
382
+ # Silero VAD expects 512 samples (32ms) or 1536 samples (96ms) for 16kHz
383
+ # Process audio in chunks and average the probabilities
384
+ frame_size = 512 # 32ms at 16kHz
385
+ num_samples = len(audio_array)
386
+
387
+ # If audio is too short, pad it
388
+ if num_samples < frame_size:
389
+ audio_array = np.pad(audio_array, (0, frame_size - num_samples), mode='constant')
390
+ num_samples = frame_size
391
+
392
+ # Process in frames and collect probabilities
393
+ speech_probs = []
394
+ for i in range(0, num_samples, frame_size):
395
+ frame = audio_array[i:i + frame_size]
396
+ if len(frame) < frame_size:
397
+ # Pad last frame if needed
398
+ frame = np.pad(frame, (0, frame_size - len(frame)), mode='constant')
399
+
400
+ # Convert to torch tensor
401
+ frame_tensor = torch.from_numpy(frame).float()
402
+
403
+ # Get speech probability from Silero VAD
404
+ with torch.no_grad():
405
+ prob = self.vad_model(frame_tensor, self.vad_sample_rate).item()
406
+ speech_probs.append(prob)
407
+
408
+ # Average probability across all frames
409
+ speech_prob = np.mean(speech_probs)
410
+ has_voice = speech_prob > threshold
411
+
412
+ print(f"VAD: Silero speech_prob={speech_prob:.4f} (avg of {len(speech_probs)} frames), threshold={threshold}, RESULT={has_voice}")
413
+
414
+ return has_voice
415
+
416
+ except Exception as e:
417
+ print(f"Silero VAD error: {e}, falling back to RMS-based VAD")
418
+ # Fall through to RMS-based VAD below
419
+
420
+ # Fallback: RMS-based VAD (original implementation)
421
+ rms_threshold = 0.002
422
+ rms = np.sqrt(np.mean(audio_array ** 2))
423
+ peak = np.max(np.abs(audio_array))
424
+ audio_std = np.std(audio_array)
425
+ zero_crossing_rate = np.sum(np.diff(np.sign(audio_array)) != 0) / len(audio_array)
426
+
427
+ has_voice_rms = rms > rms_threshold
428
+ has_voice_peak = peak > rms_threshold * 3
429
+ has_voice_variation = audio_std > rms_threshold * 0.8
430
+ has_voice_zcr = zero_crossing_rate > 0.008
431
+
432
+ has_voice = has_voice_rms or (has_voice_peak and has_voice_variation) or has_voice_zcr
433
+
434
+ print(f"VAD: RMS-based - RMS={rms:.6f}({has_voice_rms}), peak={peak:.6f}({has_voice_peak}), std={audio_std:.6f}({has_voice_variation}), zcr={zero_crossing_rate:.6f}({has_voice_zcr}), RESULT={has_voice}")
435
+
436
+ return has_voice
437
+
438
+ except Exception as e:
439
+ print(f"Error in VAD: {e}")
440
+ return True # Default to assuming voice activity on error
441
+
442
+ def has_meaningful_voice_activity(self, audio_data: bytes, threshold: float = 0.005) -> bool:
443
+ """Stricter VAD check specifically for pre-transcription filtering"""
444
+ try:
445
+ audio_array = self._bytes_to_audio_array(audio_data)
446
+ if len(audio_array) == 0:
447
+ return False
448
+
449
+ # Normalize audio
450
+ audio_array = audio_array.astype(np.float32)
451
+ if np.max(np.abs(audio_array)) > 0:
452
+ audio_array /= np.max(np.abs(audio_array))
453
+
454
+ # Calculate features with higher thresholds for meaningful speech
455
+ rms = np.sqrt(np.mean(audio_array ** 2))
456
+ peak = np.max(np.abs(audio_array))
457
+ audio_std = np.std(audio_array)
458
+ zero_crossing_rate = np.sum(np.diff(np.sign(audio_array)) != 0) / len(audio_array)
459
+
460
+ # Higher thresholds for meaningful speech detection
461
+ has_meaningful_voice = (
462
+ rms > threshold and
463
+ peak > threshold * 2 and
464
+ audio_std > threshold * 0.5 and
465
+ zero_crossing_rate > 0.015 # Higher ZCR threshold for meaningful speech
466
+ )
467
+
468
+ return has_meaningful_voice
469
+
470
+ except Exception as e:
471
+ print(f"Error in meaningful VAD: {e}")
472
+ return False # Default to no meaningful voice on error
473
+
474
+ async def force_complete_sentence(self, participant_id: str, language_code: str, sentence_callback: Optional[Callable] = None) -> str:
475
+ """Force complete any pending sentence for a participant"""
476
+ try:
477
+ # Check if sentence was already finalized
478
+ if self.sentence_finalized.get(participant_id, False):
479
+ print(f"Force completion: Sentence for participant {participant_id} already finalized, skipping")
480
+ return ""
481
+
482
+ if participant_id in self.candidate_text_cache:
483
+ cached_text = self.candidate_text_cache[participant_id]
484
+
485
+ if cached_text and len(cached_text.strip()) > 0:
486
+ result = await self._finalize_candidate_sentence(language_code, participant_id, sentence_callback)
487
+ return result
488
+
489
+ return ""
490
+
491
+ except Exception as e:
492
+ print(f"Error in force_complete_sentence: {e}")
493
+ import traceback
494
+ traceback.print_exc()
495
+ return ""
496
+
497
+ @track_performance("transcription", "transcribe_audio")
498
+ async def transcribe_audio(self, audio_data: bytes, language_code: str, callback: Optional[Callable] = None) -> str:
499
+ """Transcribe audio data to text"""
500
+ try:
501
+ # Check for voice activity before running ASR
502
+ has_voice = self.has_voice_activity(audio_data)
503
+ if not has_voice:
504
+ print(f"ASR: No voice activity detected in audio data, skipping transcription")
505
+ return ""
506
+
507
+ await self.ensure_model_loaded(language_code)
508
+
509
+ if language_code not in self.asr_pipelines:
510
+ raise ValueError(f"ASR model not available for language: {language_code}")
511
+
512
+ # Convert audio bytes to numpy array
513
+ audio_array = self._bytes_to_audio_array(audio_data)
514
+
515
+ print(f"ASR: Running transcription with {len(audio_array)/16000:.2f}s of audio")
516
+ # Transcribe
517
+ pipeline_obj = self.asr_pipelines[language_code]
518
+ result = pipeline_obj({"sampling_rate": 16000, "raw": audio_array})
519
+
520
+ text = result.get("text", "")
521
+
522
+ if callback:
523
+ await callback(text)
524
+
525
+ return text
526
+
527
+ except Exception as e:
528
+ print(f"TranscriptionService: Transcription error: {e}")
529
+ import traceback
530
+ traceback.print_exc()
531
+ return ""
532
+
533
+ def _bytes_to_audio_array(self, audio_data: bytes) -> np.ndarray:
534
+ """Convert audio bytes to numpy array (supports WAV, WebM/Opus)"""
535
+ try:
536
+ # Detect format by checking magic bytes
537
+ is_webm = audio_data[:4] == b'\x1a\x45\xdf\xa3' # WebM/Matroska magic bytes
538
+ is_wav = audio_data[:4] == b'RIFF'
539
+
540
+ import sys
541
+ print(f"_bytes_to_audio_array: length={len(audio_data)}, first 4 bytes={audio_data[:4]}, is_wav={is_wav}", flush=True)
542
+ sys.stdout.flush()
543
+
544
+ # Handle raw PCM (16-bit, 48kHz from extendable-media-recorder)
545
+ # This is the most common case now that we strip WAV headers in frontend
546
+ if not is_wav and not is_webm and len(audio_data) > 0:
547
+ try:
548
+ # Assume 16-bit PCM at 48kHz (browser's native rate)
549
+ audio_array = np.frombuffer(audio_data, dtype=np.int16)
550
+
551
+ # Check if this looks like valid audio data (not NaN, reasonable range)
552
+ if len(audio_array) > 0 and not np.isnan(audio_array).any():
553
+ print(f"Raw PCM: {len(audio_array)} samples, assuming 48kHz 16-bit", flush=True)
554
+
555
+ # Convert to float32 and normalize
556
+ audio_float = audio_array.astype(np.float32) / 32768.0
557
+
558
+ # Resample from 48kHz to 16kHz
559
+ import librosa
560
+ audio_array = librosa.resample(audio_float, orig_sr=48000, target_sr=16000)
561
+ print(f"Resampled to 16kHz: {len(audio_array)} samples", flush=True)
562
+
563
+ return audio_array
564
+ except Exception as pcm_error:
565
+ print(f"TranscriptionService: Raw PCM decoding error: {pcm_error}", flush=True)
566
+ # Fall through to other methods
567
+
568
+ if is_webm:
569
+ # Decode WebM/Opus using pydub (requires ffmpeg)
570
+ try:
571
+ from pydub import AudioSegment
572
+ audio_io = io.BytesIO(audio_data)
573
+ audio_segment = AudioSegment.from_file(audio_io, format="webm")
574
+
575
+ # Convert to mono 16kHz
576
+ audio_segment = audio_segment.set_channels(1)
577
+ audio_segment = audio_segment.set_frame_rate(16000)
578
+
579
+ # Convert to numpy array
580
+ samples = np.array(audio_segment.get_array_of_samples(), dtype=np.int16)
581
+ # Normalize to float32 [-1, 1]
582
+ audio_array = samples.astype(np.float32) / 32768.0
583
+ return audio_array
584
+ except Exception as webm_error:
585
+ print(f"TranscriptionService: WebM decoding error: {webm_error}")
586
+ # Fall through to other methods
587
+
588
+ if is_wav:
589
+ # Decode WAV format
590
+ try:
591
+ audio_io = io.BytesIO(audio_data)
592
+ with wave.open(audio_io, 'rb') as wav_file:
593
+ sample_rate = wav_file.getframerate()
594
+ channels = wav_file.getnchannels()
595
+ sample_width = wav_file.getsampwidth()
596
+
597
+ print(f"WAV format: {sample_rate}Hz, {channels} channel(s), {sample_width*8}-bit", flush=True)
598
+
599
+ frames = wav_file.readframes(-1)
600
+ audio_array = np.frombuffer(frames, dtype=np.int16)
601
+
602
+ # Resample if needed
603
+ if sample_rate != 16000:
604
+ print(f"WARNING: Resampling from {sample_rate}Hz to 16000Hz", flush=True)
605
+ import librosa
606
+ # Convert to float first
607
+ audio_float = audio_array.astype(np.float32) / 32768.0
608
+ # Resample
609
+ audio_array = librosa.resample(audio_float, orig_sr=sample_rate, target_sr=16000)
610
+ print(f"Resampled: {len(audio_array)} samples at 16kHz", flush=True)
611
+ else:
612
+ # Convert to float32 and normalize
613
+ audio_array = audio_array.astype(np.float32) / 32768.0
614
+
615
+ print(f"Returning audio array: {len(audio_array)} samples", flush=True)
616
+ return audio_array
617
+ except Exception as wav_error:
618
+ print(f"TranscriptionService: WAV decoding error: {wav_error}")
619
+ import traceback
620
+ traceback.print_exc()
621
+
622
+ # Fallback: assume raw float32 audio data
623
+ try:
624
+ audio_array = np.frombuffer(audio_data, dtype=np.float32)
625
+ return audio_array
626
+ except Exception:
627
+ pass
628
+
629
+ # Last resort: return empty array
630
+ return np.array([], dtype=np.float32)
631
+
632
+ except Exception as e:
633
+ print(f"TranscriptionService: Audio conversion error: {e}")
634
+ return np.array([], dtype=np.float32)
635
+
636
+ def _audio_array_to_bytes(self, audio_array: np.ndarray) -> bytes:
637
+ """Convert numpy audio array back to WAV bytes for storage"""
638
+ try:
639
+ # Ensure float32 format
640
+ if audio_array.dtype != np.float32:
641
+ audio_array = audio_array.astype(np.float32)
642
+
643
+ # Convert to 16-bit PCM for WAV storage
644
+ audio_int16 = (audio_array * 32767).astype(np.int16)
645
+
646
+ # Create WAV bytes
647
+ wav_buffer = io.BytesIO()
648
+ with wave.open(wav_buffer, 'wb') as wav_file:
649
+ wav_file.setnchannels(1) # Mono
650
+ wav_file.setsampwidth(2) # 16-bit
651
+ wav_file.setframerate(16000) # 16kHz
652
+ wav_file.writeframes(audio_int16.tobytes())
653
+
654
+ return wav_buffer.getvalue()
655
+
656
+ except Exception as e:
657
+ print(f"Error converting audio array to bytes: {e}")
658
+ return b''
659
+
660
+ def clear_participant_buffers(self, participant_id: str):
661
+ """Clear all buffers for a participant (e.g., when they stop speaking or disconnect)"""
662
+ if participant_id in self.candidate_audio_buffers:
663
+ del self.candidate_audio_buffers[participant_id]
664
+ if participant_id in self.candidate_text_cache:
665
+ del self.candidate_text_cache[participant_id]
666
+ if participant_id in self.silence_counters:
667
+ del self.silence_counters[participant_id]
668
+ if participant_id in self.sentence_finalized:
669
+ del self.sentence_finalized[participant_id]
670
+
671
+ async def load_remaining_models_in_background(self):
672
+ """Load all remaining ASR models in the background after startup"""
673
+ try:
674
+ print("ASR: Starting background loading of additional language models...")
675
+ for lang_code in self.asr_config.keys():
676
+ if lang_code not in self.preload_languages and lang_code not in self.asr_pipelines:
677
+ try:
678
+ print(f"ASR: Background loading model for {lang_code}...")
679
+ self.models_loading_status[lang_code] = "loading"
680
+
681
+ model_config = self.asr_config[lang_code]
682
+ # Build pipeline parameters
683
+ pipeline_params = {
684
+ "task": "automatic-speech-recognition",
685
+ "model": model_config["model_repo"],
686
+ "device": self.device,
687
+ "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32
688
+ }
689
+
690
+ # Add trust_remote_code if specified
691
+ if model_config.get("trust_remote_code", False):
692
+ pipeline_params["trust_remote_code"] = True
693
+
694
+ pipeline_obj = pipeline(**pipeline_params)
695
+ self.asr_pipelines[lang_code] = pipeline_obj
696
+ self.models_loading_status[lang_code] = "loaded"
697
+ print(f"ASR: Successfully loaded model for {lang_code} in background")
698
+
699
+ # Add a small delay between loading models to prevent overwhelming the system
700
+ await asyncio.sleep(2)
701
+ except Exception as e:
702
+ print(f"ASR: Failed to load model for {lang_code} in background: {e}")
703
+ self.models_loading_status[lang_code] = "failed"
704
+
705
+ print("ASR: Background loading of all language models complete")
706
+ print(f"ASR: Loaded models: {list(self.asr_pipelines.keys())}")
707
+ except Exception as e:
708
+ print(f"ASR: Error in background model loading: {e}")
709
+
710
+ def start_background_loading(self):
711
+ """Start background loading of models as a non-blocking task"""
712
+ if self.background_loading_task is None:
713
+ self.background_loading_task = asyncio.create_task(self.load_remaining_models_in_background())
714
+ print("ASR: Background model loading task started")
715
+
716
+ async def cleanup(self):
717
+ """Cleanup resources"""
718
+ # Cancel background loading if still running
719
+ if self.background_loading_task and not self.background_loading_task.done():
720
+ self.background_loading_task.cancel()
721
+ try:
722
+ await self.background_loading_task
723
+ except asyncio.CancelledError:
724
+ pass
725
+
726
+ self.asr_pipelines.clear()
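Below is a minimal usage sketch for the streaming service defined above. The import path, the no-argument constructor, and the file name `utterance.wav` are assumptions; the method names come from the code in this diff.

```python
# Hedged sketch: drive TranscriptionService end to end (paths/names assumed).
import asyncio
from app.services.transcription_service import TranscriptionService

async def main():
    service = TranscriptionService()
    service.start_background_loading()      # non-blocking warm-up of remaining models
    with open("utterance.wav", "rb") as f:  # 16 kHz mono WAV avoids resampling
        audio = f.read()
    # ensure_model_loaded() is called internally before inference
    text = await service.transcribe_audio(audio, "swa")
    print(text)
    await service.cleanup()                 # also cancels background loading

asyncio.run(main())
```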
app/services/transcription_service_onnx.py ADDED
@@ -0,0 +1,682 @@
1
+ import asyncio
2
+ import io
3
+ import wave
4
+ import numpy as np
5
+ from typing import Dict, Optional, Callable
6
+ from collections import OrderedDict
7
+ import onnxruntime as ort
8
+ from transformers import AutoProcessor, WhisperProcessor
9
+ from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
10
+ import os
11
+ from app.models import LanguageCode
12
+
13
+ class ONNXTranscriptionService:
14
+ def __init__(self):
15
+ self.asr_models: Dict[str, any] = {}
16
+ self.processors: Dict[str, any] = {}
17
+ self.max_asr_models = 2 # Memory management - keep max 2 models loaded
18
+ self.model_cache = OrderedDict() # LRU cache for models
19
+
20
+ # GPU optimization - detect and configure providers
21
+ available_providers = ort.get_available_providers()
22
+ print(f"ONNX ASR: Available providers: {available_providers}")
23
+
24
+ if 'CUDAExecutionProvider' in available_providers:
25
+ # Configure CUDA provider with optimizations
26
+ cuda_provider_options = {
27
+ 'device_id': 0,
28
+ 'arena_extend_strategy': 'kNextPowerOfTwo',
29
+ 'gpu_mem_limit': int(0.8 * 1024 * 1024 * 1024), # 0.8 GB CUDA arena cap (not 80% of GPU memory)
30
+ 'cudnn_conv_algo_search': 'EXHAUSTIVE',
31
+ 'do_copy_in_default_stream': True,
32
+ 'enable_tracing': True, # Enable tracing for better diagnostics
33
+ }
34
+
35
+ # Include TensorRT if available, then CUDA, then CPU
36
+ provider_list = []
37
+ if 'TensorrtExecutionProvider' in available_providers:
38
+ provider_list.append('TensorrtExecutionProvider')
39
+ provider_list.append(('CUDAExecutionProvider', cuda_provider_options))
40
+ provider_list.append('CPUExecutionProvider')
41
+
42
+ self.providers = provider_list
43
+ print(f"ONNX ASR: Using GPU acceleration with providers: {[p[0] if isinstance(p, tuple) else p for p in provider_list]}")
44
+ print(f"ONNX ASR: GPU memory limit: {cuda_provider_options['gpu_mem_limit'] // (1024**3)}GB")
45
+ else:
46
+ self.providers = ['CPUExecutionProvider']
47
+ print("ONNX ASR: CUDA not available, using CPU execution")
48
+
49
+ print(f"ONNX ASR: Configured providers: {[p[0] if isinstance(p, tuple) else p for p in self.providers]}")
50
+
51
+ # ONNX Model configurations - using pre-converted ONNX models from HuggingFace
52
+ self.asr_config = {
53
+ "eng": {"model_repo": "mutisya/whisper-medium-en-onnx", "model_type": "whisper", "use_onnx": True}, # Pre-converted ONNX model
54
+ "swa": {"model_repo": "mutisya/w2v-bert-2.0-asr-onnx-swh-v25-37-1", "model_type": "wav2vec2-bert", "use_onnx": True},
55
+ "kik": {"model_repo": "mutisya/w2v-bert-2.0-asr-onnx-kik-v25-37-1", "model_type": "wav2vec2-bert", "use_onnx": True},
56
+ "kam": {"model_repo": "mutisya/w2v-bert-2.0-asr-onnx-kam-v25-37-1", "model_type": "wav2vec2-bert", "use_onnx": True},
57
+ "mer": {"model_repo": "mutisya/w2v-bert-2.0-asr-onnx-mer-v25-37-1", "model_type": "wav2vec2-bert", "use_onnx": True},
58
+ "luo": {"model_repo": "mutisya/w2v-bert-2.0-asr-onnx-luo-v25-37-1", "model_type": "wav2vec2-bert", "use_onnx": True},
59
+ "som": {"model_repo": "mutisya/w2v-bert-2.0-asr-onnx-som-v25-37-1", "model_type": "wav2vec2-bert", "use_onnx": True}
60
+ }
61
+
62
+ # Alternative model configurations for different performance tiers
63
+ self.alternative_models = {
64
+ "eng_small": {"model_repo": "mutisya/whisper-small-en-onnx", "model_type": "whisper", "use_onnx": True},
65
+ "eng_base": {"model_repo": "mutisya/whisper-base-en-onnx", "model_type": "whisper", "use_onnx": True},
66
+ "eng_medium": {"model_repo": "mutisya/whisper-medium-en-onnx", "model_type": "whisper", "use_onnx": True}
67
+ }
68
+
69
+ self.preload_languages = ["eng"]
70
+
71
+ # Current model performance mode (small, base, medium)
72
+ # Can be configured via environment variable WHISPER_MODEL_SIZE
73
+ self.performance_mode = os.getenv("WHISPER_MODEL_SIZE", "medium").lower()
74
+
75
+ # Enhanced audio buffering for VAD-based sentence detection
76
+ self.candidate_audio_buffers: Dict[str, bytes] = {}
77
+ self.candidate_text_cache: Dict[str, str] = {}
78
+ self.silence_counters: Dict[str, int] = {}
79
+ self.sentence_finalized: Dict[str, bool] = {}
80
+
81
+ # VAD parameters
82
+ self.silence_threshold = 2
83
+ self.min_sentence_length = 0.03
84
+
85
+ def set_performance_mode(self, mode: str):
86
+ """Set the performance mode for English models (small, base, medium)"""
87
+ if mode in ["small", "base", "medium"]:
88
+ self.performance_mode = mode
89
+ # Update the English model configuration based on performance mode
90
+ if f"eng_{mode}" in self.alternative_models:
91
+ self.asr_config["eng"] = self.alternative_models[f"eng_{mode}"]
92
+ # Clear cached English model to force reload with new configuration
93
+ if "eng" in self.model_cache:
94
+ del self.model_cache["eng"]
95
+ if "eng" in self.asr_models:
96
+ del self.asr_models["eng"]
97
+ if "eng" in self.processors:
98
+ del self.processors["eng"]
99
+ print(f"Performance mode set to {mode}. English model will be reloaded on next use.")
100
+ else:
101
+ print(f"Warning: No model configuration found for performance mode {mode}")
102
+ else:
103
+ print(f"Invalid performance mode: {mode}. Must be one of: small, base, medium")
104
+
105
+ async def initialize(self):
106
+ """Initialize ASR models for preloaded languages"""
107
+ print(f"ONNX ASR: Initializing with providers: {self.providers}")
108
+
109
+ # Apply performance mode to English model configuration
110
+ if self.performance_mode in ["small", "base", "medium"]:
111
+ if f"eng_{self.performance_mode}" in self.alternative_models:
112
+ self.asr_config["eng"] = self.alternative_models[f"eng_{self.performance_mode}"]
113
+ print(f"Using Whisper {self.performance_mode} model for English")
114
+ else:
115
+ print(f"Warning: Performance mode {self.performance_mode} not available, using default medium")
116
+
117
+ for lang_code in self.preload_languages:
118
+ if lang_code in self.asr_config:
119
+ try:
120
+ await self.ensure_model_loaded(lang_code)
121
+ except Exception as e:
122
+ print(f"Failed to load ASR model for {lang_code}: {e}")
123
+
124
+ async def ensure_model_loaded(self, language_code: str):
125
+ """Load ASR model for language if not already loaded with LRU cache"""
126
+ if language_code in self.model_cache:
127
+ # Move to end (most recently used)
128
+ self.model_cache.move_to_end(language_code)
129
+ return
130
+
131
+ if language_code not in self.asr_config:
132
+ raise ValueError(f"Language {language_code} not supported")
133
+
134
+ model_config = self.asr_config[language_code]
135
+
136
+ # Check if we need to evict old models
137
+ while len(self.model_cache) >= self.max_asr_models:
138
+ # Remove least recently used model
139
+ old_lang, _ = self.model_cache.popitem(last=False)
140
+ if old_lang in self.asr_models:
141
+ del self.asr_models[old_lang]
142
+ if old_lang in self.processors:
143
+ del self.processors[old_lang]
144
+ print(f"ONNX ASR: Evicted model for {old_lang} (LRU cache)")
145
+
146
+ try:
147
+ if model_config.get("use_onnx", False):
148
+ # Load ONNX model
149
+ print(f"ONNX ASR: Loading ONNX model for {language_code}")
150
+
151
+ # Special handling for Whisper models
152
+ if model_config.get("model_type") == "whisper":
153
+ print(f"ONNX ASR: Loading Whisper ONNX model from {model_config['model_repo']}")
154
+
155
+ # Get authentication token for private repos
156
+ import os
157
+ auth_token = os.getenv('HUGGING_FACE_HUB_TOKEN') or os.getenv('HF_TOKEN')
158
+
159
+ # Load pre-converted Whisper ONNX model using Optimum
160
+ load_kwargs = {
161
+ # export=False because we're using pre-converted models
162
+ "export": False,
163
+ # use_cache=True because our models now include past key value variants for optimization
164
+ "use_cache": True,
165
+ # Add authentication token for private repos
166
+ "token": auth_token
167
+ }
168
+
169
+ # Configure providers - pass all available providers to Optimum
170
+ provider_names = [p[0] if isinstance(p, tuple) else p for p in self.providers]
171
+ load_kwargs["providers"] = provider_names
172
+ print(f"ONNX ASR: Whisper using providers: {provider_names}")
173
+
174
+ # Add subfolder if specified (for models that store ONNX in subfolders)
175
+ if "subfolder" in model_config:
176
+ load_kwargs["subfolder"] = model_config["subfolder"]
177
+
178
+ model = ORTModelForSpeechSeq2Seq.from_pretrained(
179
+ model_config["model_repo"],
180
+ **load_kwargs
181
+ )
182
+
183
+ # Load Whisper processor with authentication token
184
+ processor = WhisperProcessor.from_pretrained(
185
+ model_config["model_repo"],
186
+ token=auth_token
187
+ )
188
+
189
+ # Configure for English transcription
190
+ model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
191
+ language="en",
192
+ task="transcribe"
193
+ )
194
+
195
+ self.asr_models[language_code] = model
196
+ self.processors[language_code] = processor
197
+
198
+ print(f"ONNX ASR: Successfully loaded Whisper ONNX model for {language_code}")
199
+
200
+ else:
201
+ # Original wav2vec2-bert model loading logic
202
+ # Create ONNX session with optimizations and verbose logging
203
+ session_options = ort.SessionOptions()
204
+ session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
205
+
206
+ # Enable verbose logging to diagnose operator assignments
207
+ session_options.log_severity_level = 1 # INFO level (0=VERBOSE, 1=INFO, 2=WARNING) for detailed logs
208
+ session_options.logid = "ONNX_ASR" # Prefix for log identification
209
+
210
+ # Use configured providers with optimizations
211
+ providers = self.providers
212
+ print(f"ONNX ASR: wav2vec2-bert using providers: {[p[0] if isinstance(p, tuple) else p for p in providers]}")
213
+
214
+ # Get authentication token for private repos
215
+ import os
216
+ auth_token = os.getenv('HUGGING_FACE_HUB_TOKEN') or os.getenv('HF_TOKEN')
217
+
218
+ # Download model files from HuggingFace Hub with authentication
219
+ from huggingface_hub import hf_hub_download
220
+ onnx_path = hf_hub_download(
221
+ repo_id=model_config["model_repo"],
222
+ filename="model.onnx",
223
+ token=auth_token
224
+ )
225
+
226
+ session = ort.InferenceSession(onnx_path, providers=providers, sess_options=session_options)
227
+
228
+ # Load processor for preprocessing with authentication
229
+ processor = AutoProcessor.from_pretrained(
230
+ model_config["model_repo"],
231
+ token=auth_token
232
+ )
233
+
234
+ self.asr_models[language_code] = session
235
+ self.processors[language_code] = processor
236
+
237
+ print(f"ONNX ASR: Successfully loaded ONNX model for {language_code}")
238
+
239
+ else:
240
+ # This service is ONNX-only - no PyTorch fallback
241
+ raise ValueError(f"Language {language_code} is not configured for ONNX models. Set 'use_onnx': True in config.")
242
+
243
+ # Add to cache
244
+ self.model_cache[language_code] = True
245
+
246
+ except Exception as e:
247
+ print(f"Failed to load ASR model for {language_code}: {e}")
248
+ raise
249
+
250
+ async def process_audio_chunk(self, audio_data: bytes, language_code: str, participant_id: str,
251
+ has_voice_activity: bool = True,
252
+ progress_callback: Optional[Callable] = None,
253
+ sentence_callback: Optional[Callable] = None) -> str:
254
+ """Process audio chunk with VAD-based sentence detection using ONNX models"""
255
+ try:
256
+ # Initialize buffers if needed
257
+ if participant_id not in self.candidate_audio_buffers:
258
+ self.candidate_audio_buffers[participant_id] = b''
259
+ self.candidate_text_cache[participant_id] = ""
260
+ self.silence_counters[participant_id] = 0
261
+ self.sentence_finalized[participant_id] = False
262
+
263
+ # Convert current chunk to numpy array for processing
264
+ current_chunk_array = self._bytes_to_audio_array(audio_data)
265
+ if len(current_chunk_array) == 0:
266
+ return self.candidate_text_cache.get(participant_id, "")
267
+
268
+ # Normalize the audio chunk
269
+ current_chunk_array = current_chunk_array.astype(np.float32)
270
+ if np.max(np.abs(current_chunk_array)) > 0:
271
+ current_chunk_array /= np.max(np.abs(current_chunk_array))
272
+
273
+ # Get existing accumulated audio array
274
+ existing_buffer = self.candidate_audio_buffers[participant_id]
275
+ if len(existing_buffer) > 0:
276
+ existing_array = self._bytes_to_audio_array(existing_buffer)
277
+ if len(existing_array) > 0:
278
+ combined_array = np.concatenate([existing_array, current_chunk_array])
279
+ else:
280
+ combined_array = current_chunk_array
281
+ else:
282
+ combined_array = current_chunk_array
283
+
284
+ # Convert back to bytes for storage
285
+ combined_bytes = self._audio_array_to_bytes(combined_array)
286
+ self.candidate_audio_buffers[participant_id] = combined_bytes
287
+
288
+ # Update silence counter based on voice activity
289
+ if not has_voice_activity:
290
+ self.silence_counters[participant_id] += 1
291
+ else:
292
+ self.silence_counters[participant_id] = 0
293
+
294
+ # Check if we should finalize sentence due to prolonged silence
295
+ should_finalize = (self.silence_counters[participant_id] >= self.silence_threshold and
296
+ len(combined_array) > 0 and
297
+ not self.sentence_finalized[participant_id])
298
+
299
+ if should_finalize:
300
+ return await self._finalize_candidate_sentence(
301
+ language_code, participant_id, sentence_callback
302
+ )
303
+
304
+ # Always run transcription on the accumulated audio
305
+ audio_duration_sec = len(combined_array) / 16000.0 # 16kHz sample rate
306
+
307
+ if audio_duration_sec < 0.1: # Very short minimum
308
+ if progress_callback:
309
+ cached_text = self.candidate_text_cache.get(participant_id, "")
310
+ await progress_callback(cached_text, False)
311
+ return self.candidate_text_cache.get(participant_id, "")
312
+
313
+ # Force finalization if buffer gets too long
314
+ if audio_duration_sec > 15.0 and not self.sentence_finalized[participant_id]:
315
+ return await self._finalize_candidate_sentence(
316
+ language_code, participant_id, sentence_callback
317
+ )
318
+
319
+ # Run voice activity detection on the accumulated audio before transcription
320
+ has_voice_in_buffer = self.has_meaningful_voice_activity(combined_bytes)
321
+
322
+ if not has_voice_in_buffer:
323
+ if progress_callback:
324
+ cached_text = self.candidate_text_cache.get(participant_id, "")
325
+ await progress_callback(cached_text, False)
326
+ return self.candidate_text_cache.get(participant_id, "")
327
+
328
+ # Run transcription
329
+ await self.ensure_model_loaded(language_code)
330
+
331
+ # Double-check voice activity before running expensive ASR
332
+ has_voice_for_asr = self.has_voice_activity(combined_bytes)
333
+ if not has_voice_for_asr:
334
+ print(f"ONNX ASR: No voice activity detected, skipping ASR execution for {participant_id}")
335
+ if progress_callback:
336
+ cached_text = self.candidate_text_cache.get(participant_id, "")
337
+ await progress_callback(cached_text, False)
338
+ return self.candidate_text_cache.get(participant_id, "")
339
+
340
+ if language_code not in self.asr_models:
341
+ raise ValueError(f"ASR model not available for language: {language_code}")
342
+
343
+ print(f"ONNX ASR: Running transcription for {participant_id} with {audio_duration_sec:.2f}s of audio")
344
+
345
+ # Run ONNX inference (this service is ONNX-only)
346
+ model_config = self.asr_config[language_code]
347
+ if not model_config.get("use_onnx", False):
348
+ raise ValueError(f"Language {language_code} is not configured for ONNX. This service only supports ONNX models.")
349
+
350
+ # ONNX inference
351
+ text = await self._run_onnx_inference(combined_array, language_code)
352
+
353
+ # Filter out common ASR artifacts
354
+ artifacts = [
355
+ "thank you", "thanks", "bye", ".", ",", "?", "!",
356
+ "um", "uh", "ah", "hmm", "mm", "mhm",
357
+ "you", "the", "a", "an", "and", "but", "or",
358
+ "music", "laughter", "applause", "[music]", "[laughter]",
359
+ ]
360
+
361
+ # Check if the result is likely an artifact
362
+ is_artifact = (
363
+ len(text) < 3 or
364
+ text.lower() in artifacts or
365
+ (len(text.split()) == 1 and len(text) < 6)
366
+ )
367
+
368
+ if is_artifact:
369
+ text = self.candidate_text_cache.get(participant_id, "")
370
+
371
+ # Cache the current candidate text
372
+ self.candidate_text_cache[participant_id] = text
373
+
374
+ # Force completion if we have reasonable text and some silence
375
+ word_count = len(text.split()) if text else 0
376
+ if (word_count >= 3 and self.silence_counters[participant_id] >= 2 and
377
+ not self.sentence_finalized[participant_id]):
378
+ return await self._finalize_candidate_sentence(
379
+ language_code, participant_id, sentence_callback
380
+ )
381
+
382
+ # Always send progress update
383
+ if progress_callback:
384
+ await progress_callback(text, False)
385
+
386
+ return text
387
+
388
+ except Exception as e:
389
+ print(f"ONNX TranscriptionService: Error processing audio chunk: {e}")
390
+ import traceback
391
+ traceback.print_exc()
392
+ if progress_callback:
393
+ cached_text = self.candidate_text_cache.get(participant_id, "")
394
+ await progress_callback(cached_text, False)
395
+ return self.candidate_text_cache.get(participant_id, "")
396
+
397
+ async def _run_onnx_inference(self, audio_array: np.ndarray, language_code: str) -> str:
398
+ """Run ONNX inference for speech recognition"""
399
+ try:
400
+ model = self.asr_models[language_code]
401
+ processor = self.processors[language_code]
402
+ model_config = self.asr_config[language_code]
403
+
404
+ # Check if this is a Whisper model
405
+ if model_config.get("model_type") == "whisper":
406
+ # Whisper-specific processing using Optimum
407
+ import torch
408
+
409
+ # Process audio input for Whisper
410
+ inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt")
411
+
412
+ # Generate transcription using the ORTModelForSpeechSeq2Seq
413
+ predicted_ids = model.generate(inputs.input_features, max_length=448)
414
+
415
+ # Decode the generated IDs
416
+ transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
417
+
418
+ return transcription[0].strip() if transcription else ""
419
+ else:
420
+ # Original wav2vec2-bert processing
421
+ session = model
422
+
423
+ # Preprocess audio
424
+ inputs = processor(audio_array, sampling_rate=16000, return_tensors="np")
425
+
426
+ # Get input names for ONNX session
427
+ input_names = [inp.name for inp in session.get_inputs()]
428
+
429
+ # Prepare inputs for ONNX
430
+ onnx_inputs = {}
431
+ for name in input_names:
432
+ if name in inputs:
433
+ onnx_inputs[name] = inputs[name]
434
+ elif name == "input_values" and "input_features" in inputs:
435
+ onnx_inputs[name] = inputs["input_features"]
436
+ elif name == "attention_mask" and "attention_mask" in inputs:
437
+ onnx_inputs[name] = inputs["attention_mask"]
438
+
439
+ # Run ONNX inference
440
+ outputs = session.run(None, onnx_inputs)
441
+
442
+ # Post-process outputs (assuming CTC decoding)
443
+ logits = outputs[0] # First output should be logits
444
+
445
+ # Simple greedy CTC decoding
446
+ predicted_ids = np.argmax(logits, axis=-1)
447
+
448
+ # Decode using processor
449
+ text = processor.batch_decode(predicted_ids)[0]
450
+
451
+ return text.strip()
452
+
453
+ except Exception as e:
454
+ print(f"ONNX ASR: Inference error: {e}")
455
+ import traceback
456
+ traceback.print_exc()
457
+ return ""
458
+
459
+ async def _finalize_candidate_sentence(self, language_code: str, participant_id: str,
460
+ sentence_callback: Optional[Callable] = None) -> str:
461
+ """Finalize the current candidate sentence and clear buffers"""
462
+ try:
463
+ if self.sentence_finalized.get(participant_id, False):
464
+ print(f"Sentence for participant {participant_id} already finalized, skipping duplicate")
465
+ return self.candidate_text_cache.get(participant_id, "")
466
+
467
+ final_text = self.candidate_text_cache.get(participant_id, "")
468
+ final_audio_bytes = self.candidate_audio_buffers.get(participant_id, b'')
469
+
470
+ if final_text and len(final_text.strip()) > 0:
471
+ self.sentence_finalized[participant_id] = True
472
+
473
+ if sentence_callback and len(final_audio_bytes) > 0:
474
+ print(f"Finalizing sentence for participant {participant_id}: '{final_text}'")
475
+ await sentence_callback(final_text, final_audio_bytes)
476
+
477
+ # Clear buffers for next sentence
478
+ self.candidate_audio_buffers[participant_id] = b''
479
+ self.candidate_text_cache[participant_id] = ""
480
+ self.silence_counters[participant_id] = 0
481
+ self.sentence_finalized[participant_id] = False
482
+
483
+ return final_text
484
+
485
+ except Exception as e:
486
+ print(f"Error finalizing sentence: {e}")
487
+ import traceback
488
+ traceback.print_exc()
489
+ self.sentence_finalized[participant_id] = False
490
+ return ""
491
+
492
+ def has_voice_activity(self, audio_data: bytes, threshold: float = 0.0005) -> bool:
493
+ """Enhanced VAD based on audio analysis"""
494
+ try:
495
+ audio_array = self._bytes_to_audio_array(audio_data)
496
+ if len(audio_array) == 0:
497
+ return False
498
+
499
+ # Normalize audio
500
+ audio_array = audio_array.astype(np.float32)
501
+ if np.max(np.abs(audio_array)) > 0:
502
+ audio_array /= np.max(np.abs(audio_array))
503
+
504
+ # Calculate multiple features for better VAD
505
+ rms = np.sqrt(np.mean(audio_array ** 2))
506
+ peak = np.max(np.abs(audio_array))
507
+ audio_std = np.std(audio_array)
508
+ zero_crossing_rate = np.sum(np.diff(np.sign(audio_array)) != 0) / len(audio_array)
509
+
510
+ # Voice activity detection
511
+ has_voice_rms = rms > threshold
512
+ has_voice_peak = peak > threshold * 3
513
+ has_voice_variation = audio_std > threshold * 0.8
514
+ has_voice_zcr = zero_crossing_rate > 0.008
515
+
516
+ has_voice = has_voice_rms or (has_voice_peak and has_voice_variation) or has_voice_zcr
517
+
518
+ return has_voice
519
+
520
+ except Exception as e:
521
+ print(f"Error in VAD: {e}")
522
+ return True
523
+
524
+ def has_meaningful_voice_activity(self, audio_data: bytes, threshold: float = 0.002) -> bool:
525
+ """Stricter VAD check for pre-transcription filtering"""
526
+ try:
527
+ audio_array = self._bytes_to_audio_array(audio_data)
528
+ if len(audio_array) == 0:
529
+ return False
530
+
531
+ # Normalize audio
532
+ audio_array = audio_array.astype(np.float32)
533
+ if np.max(np.abs(audio_array)) > 0:
534
+ audio_array /= np.max(np.abs(audio_array))
535
+
536
+ # Calculate features with higher thresholds
537
+ rms = np.sqrt(np.mean(audio_array ** 2))
538
+ peak = np.max(np.abs(audio_array))
539
+ audio_std = np.std(audio_array)
540
+ zero_crossing_rate = np.sum(np.diff(np.sign(audio_array)) != 0) / len(audio_array)
541
+
542
+ # Higher thresholds for meaningful speech detection
543
+ has_meaningful_voice = (
544
+ rms > threshold and
545
+ peak > threshold * 2 and
546
+ audio_std > threshold * 0.5 and
547
+ zero_crossing_rate > 0.015
548
+ )
549
+
550
+ return has_meaningful_voice
551
+
552
+ except Exception as e:
553
+ print(f"Error in meaningful VAD: {e}")
554
+ return False
555
+
556
+ async def force_complete_sentence(self, participant_id: str, language_code: str, sentence_callback: Optional[Callable] = None) -> str:
557
+ """Force complete any pending sentence for a participant"""
558
+ try:
559
+ if self.sentence_finalized.get(participant_id, False):
560
+ print(f"Force completion: Sentence for participant {participant_id} already finalized")
561
+ return ""
562
+
563
+ if participant_id in self.candidate_text_cache:
564
+ cached_text = self.candidate_text_cache[participant_id]
565
+
566
+ if cached_text and len(cached_text.strip()) > 0:
567
+ result = await self._finalize_candidate_sentence(language_code, participant_id, sentence_callback)
568
+ return result
569
+
570
+ return ""
571
+
572
+ except Exception as e:
573
+ print(f"Error in force_complete_sentence: {e}")
574
+ import traceback
575
+ traceback.print_exc()
576
+ return ""
577
+
578
+ async def transcribe_audio(self, audio_data: bytes, language_code: str, callback: Optional[Callable] = None) -> str:
579
+ """Transcribe audio data to text using ONNX models"""
580
+ try:
581
+ # Check for voice activity before running ASR
582
+ has_voice = self.has_voice_activity(audio_data)
583
+ if not has_voice:
584
+ print(f"ONNX ASR: No voice activity detected, skipping transcription")
585
+ return ""
586
+
587
+ await self.ensure_model_loaded(language_code)
588
+
589
+ if language_code not in self.asr_models:
590
+ raise ValueError(f"ASR model not available for language: {language_code}")
591
+
592
+ # Convert audio bytes to numpy array
593
+ audio_array = self._bytes_to_audio_array(audio_data)
594
+
595
+ print(f"ONNX ASR: Running transcription with {len(audio_array)/16000:.2f}s of audio")
596
+
597
+ # Run ONNX inference (this service is ONNX-only)
598
+ model_config = self.asr_config[language_code]
599
+ if not model_config.get("use_onnx", False):
600
+ raise ValueError(f"Language {language_code} is not configured for ONNX. This service only supports ONNX models.")
601
+
602
+ # ONNX inference
603
+ text = await self._run_onnx_inference(audio_array, language_code)
604
+
605
+ if callback:
606
+ await callback(text)
607
+
608
+ return text
609
+
610
+ except Exception as e:
611
+ print(f"ONNX TranscriptionService: Transcription error: {e}")
612
+ import traceback
613
+ traceback.print_exc()
614
+ return ""
615
+
616
+ def _bytes_to_audio_array(self, audio_data: bytes) -> np.ndarray:
617
+ """Convert audio bytes to numpy array"""
618
+ try:
619
+ # Try to decode as WAV
620
+ try:
621
+ audio_io = io.BytesIO(audio_data)
622
+ with wave.open(audio_io, 'rb') as wav_file:
623
+ frames = wav_file.readframes(-1)
624
+ audio_array = np.frombuffer(frames, dtype=np.int16)
625
+ # Convert to float32 and normalize
626
+ audio_array = audio_array.astype(np.float32) / 32768.0
627
+ return audio_array
628
+ except Exception:
629
+ pass
630
+
631
+ # Fallback: assume raw float32 audio data
632
+ try:
633
+ audio_array = np.frombuffer(audio_data, dtype=np.float32)
634
+ return audio_array
635
+ except Exception:
636
+ pass
637
+
638
+ return np.array([], dtype=np.float32)
639
+
640
+ except Exception as e:
641
+ print(f"ONNX TranscriptionService: Audio conversion error: {e}")
642
+ return np.array([], dtype=np.float32)
643
+
644
+ def _audio_array_to_bytes(self, audio_array: np.ndarray) -> bytes:
645
+ """Convert numpy audio array back to WAV bytes for storage"""
646
+ try:
647
+ if audio_array.dtype != np.float32:
648
+ audio_array = audio_array.astype(np.float32)
649
+
650
+ # Convert to 16-bit PCM for WAV storage
651
+ audio_int16 = (audio_array * 32767).astype(np.int16)
652
+
653
+ # Create WAV bytes
654
+ wav_buffer = io.BytesIO()
655
+ with wave.open(wav_buffer, 'wb') as wav_file:
656
+ wav_file.setnchannels(1) # Mono
657
+ wav_file.setsampwidth(2) # 16-bit
658
+ wav_file.setframerate(16000) # 16kHz
659
+ wav_file.writeframes(audio_int16.tobytes())
660
+
661
+ return wav_buffer.getvalue()
662
+
663
+ except Exception as e:
664
+ print(f"Error converting audio array to bytes: {e}")
665
+ return b''
666
+
667
+ def clear_participant_buffers(self, participant_id: str):
668
+ """Clear all buffers for a participant"""
669
+ if participant_id in self.candidate_audio_buffers:
670
+ del self.candidate_audio_buffers[participant_id]
671
+ if participant_id in self.candidate_text_cache:
672
+ del self.candidate_text_cache[participant_id]
673
+ if participant_id in self.silence_counters:
674
+ del self.silence_counters[participant_id]
675
+ if participant_id in self.sentence_finalized:
676
+ del self.sentence_finalized[participant_id]
677
+
678
+ async def cleanup(self):
679
+ """Cleanup resources"""
680
+ self.asr_models.clear()
681
+ self.processors.clear()
682
+ self.model_cache.clear()
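A short sketch of how the English model tier can be selected for this service; `WHISPER_MODEL_SIZE` and `set_performance_mode` come from the code above, while the import path is an assumption.

```python
# Hedged sketch: pick the Whisper tier before and after construction.
import os

os.environ["WHISPER_MODEL_SIZE"] = "base"   # read in __init__: small | base | medium

from app.services.transcription_service_onnx import ONNXTranscriptionService

service = ONNXTranscriptionService()
service.set_performance_mode("small")       # runtime switch; English model reloads on next use
```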
app/services/transcription_service_onnx_optimized.py ADDED
@@ -0,0 +1,251 @@
1
+ import asyncio
2
+ import io
3
+ import wave
4
+ import numpy as np
5
+ from typing import Dict, Optional, Callable
6
+ from collections import OrderedDict
7
+ import onnxruntime as ort
8
+ from transformers import AutoProcessor, WhisperProcessor
9
+ from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
10
+ import os
11
+ from app.models import LanguageCode
12
+
13
+ class OptimizedONNXTranscriptionService:
14
+ """
15
+ Optimized ONNX Transcription Service that uses pre-converted ONNX models
16
+ instead of performing runtime conversion from PyTorch models.
17
+
18
+ Benefits:
19
+ - Faster container startup (no conversion time)
20
+ - Reduced memory usage during initialization
21
+ - More predictable deployment times
22
+ - Better resource utilization in production
23
+ """
24
+
25
+ def __init__(self):
26
+ self.asr_models: Dict[str, any] = {}
27
+ self.processors: Dict[str, any] = {}
28
+ self.max_asr_models = 2 # Memory management - keep max 2 models loaded
29
+ self.model_cache = OrderedDict() # LRU cache for models
30
+
31
+ # GPU optimization
32
+ self.providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if 'CUDAExecutionProvider' in ort.get_available_providers() else ['CPUExecutionProvider']
33
+
34
+ # OPTIMIZED ONNX Model configurations - using pre-converted models
35
+ self.asr_config = {
36
+ # English: Use pre-converted ONNX model (no runtime conversion!)
37
+ "eng": {
38
+ "model_repo": "mutisya/whisper-medium-en-onnx", # Pre-converted ONNX model
39
+ "model_type": "whisper",
40
+ "use_onnx": True,
41
+ "export": False # ⭐ KEY CHANGE: No runtime export needed!
42
+ },
43
+
44
+ # African languages: Already using ONNX models
45
+ "swa": {"model_repo": "mutisya/w2v-bert-2.0-asr-onnx-swh-v25-37-1", "model_type": "wav2vec2-bert", "use_onnx": True},
46
+ "kik": {"model_repo": "mutisya/w2v-bert-2.0-asr-onnx-kik-v25-37-1", "model_type": "wav2vec2-bert", "use_onnx": True},
47
+ "kam": {"model_repo": "mutisya/w2v-bert-2.0-asr-onnx-kam-v25-37-1", "model_type": "wav2vec2-bert", "use_onnx": True},
48
+ "mer": {"model_repo": "mutisya/w2v-bert-2.0-asr-onnx-mer-v25-37-1", "model_type": "wav2vec2-bert", "use_onnx": True},
49
+ "luo": {"model_repo": "mutisya/w2v-bert-2.0-asr-onnx-luo-v25-37-1", "model_type": "wav2vec2-bert", "use_onnx": True},
50
+ "som": {"model_repo": "mutisya/w2v-bert-2.0-asr-onnx-som-v25-37-1", "model_type": "wav2vec2-bert", "use_onnx": True}
51
+ }
52
+
53
+ self.preload_languages = ["eng"]
54
+
55
+ # Enhanced audio buffering for VAD-based sentence detection
56
+ self.candidate_audio_buffers: Dict[str, bytes] = {}
57
+ self.candidate_text_cache: Dict[str, str] = {}
58
+ self.silence_counters: Dict[str, int] = {}
59
+ self.sentence_finalized: Dict[str, bool] = {}
60
+
61
+ # VAD parameters
62
+ self.silence_threshold = 2
63
+ self.min_sentence_length = 0.03
64
+
65
+ async def initialize(self):
66
+ """Initialize ASR models for preloaded languages"""
67
+ print(f"🚀 Optimized ONNX ASR: Initializing with providers: {self.providers}")
68
+ print(f"📈 Performance Improvement: Using pre-converted ONNX models (no runtime conversion)")
69
+
70
+ for lang_code in self.preload_languages:
71
+ if lang_code in self.asr_config:
72
+ try:
73
+ start_time = asyncio.get_event_loop().time()
74
+ await self.ensure_model_loaded(lang_code)
75
+ end_time = asyncio.get_event_loop().time()
76
+ print(f"⚡ Model loading time for {lang_code}: {end_time - start_time:.2f}s")
77
+ except Exception as e:
78
+ print(f"❌ Failed to load ASR model for {lang_code}: {e}")
79
+
80
+ async def ensure_model_loaded(self, language_code: str):
81
+ """Load ASR model for language if not already loaded with LRU cache"""
82
+ if language_code in self.model_cache:
83
+ # Move to end (most recently used)
84
+ self.model_cache.move_to_end(language_code)
85
+ return
86
+
87
+ if language_code not in self.asr_config:
88
+ raise ValueError(f"Language {language_code} not supported")
89
+
90
+ model_config = self.asr_config[language_code]
91
+
92
+ # Check if we need to evict old models
93
+ while len(self.model_cache) >= self.max_asr_models:
94
+ # Remove least recently used model
95
+ old_lang, _ = self.model_cache.popitem(last=False)
96
+ if old_lang in self.asr_models:
97
+ del self.asr_models[old_lang]
98
+ if old_lang in self.processors:
99
+ del self.processors[old_lang]
100
+ print(f"🗑️ ONNX ASR: Evicted model for {old_lang} (LRU cache)")
101
+
102
+ try:
103
+ if model_config.get("use_onnx", False):
104
+ # Load ONNX model
105
+ print(f"📥 ONNX ASR: Loading ONNX model for {language_code}")
106
+
107
+ # Special handling for Whisper models
108
+ if model_config.get("model_type") == "whisper":
109
+ print(f"🎙️ ONNX ASR: Loading pre-converted Whisper ONNX model from {model_config['model_repo']}")
110
+
111
+ # Load pre-converted Whisper ONNX model using Optimum
112
+ load_kwargs = {
113
+ # Note: No 'export' parameter needed since model is already in ONNX format
114
+ # This is the key optimization - no runtime conversion!
115
+ }
116
+
117
+ # Add subfolder if specified (for models that store ONNX in subfolders)
118
+ if "subfolder" in model_config:
119
+ load_kwargs["subfolder"] = model_config["subfolder"]
120
+
121
+ # ⭐ KEY OPTIMIZATION: No export flag needed for pre-converted models
122
+ # The old code had: if model_config.get("export", False): load_kwargs["export"] = True
123
+ # Now we skip this entirely since the model is already in ONNX format
124
+
125
+ model = ORTModelForSpeechSeq2Seq.from_pretrained(
126
+ model_config["model_repo"],
127
+ **load_kwargs
128
+ )
129
+
130
+ # Load Whisper processor
131
+ processor = WhisperProcessor.from_pretrained(model_config["model_repo"])
132
+
133
+ # Configure for English transcription
134
+ model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
135
+ language="en",
136
+ task="transcribe"
137
+ )
138
+
139
+ self.asr_models[language_code] = model
140
+ self.processors[language_code] = processor
141
+
142
+ print(f"✅ ONNX ASR: Successfully loaded pre-converted Whisper ONNX model for {language_code}")
143
+
144
+ else:
145
+ # Original wav2vec2-bert model loading logic (unchanged)
146
+ # Create ONNX session with optimizations
147
+ session_options = ort.SessionOptions()
148
+ session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
149
+
150
+ # Enable parallel execution
151
+ session_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
152
+
153
+ model_path = model_config["model_repo"]
154
+
155
+ try:
156
+ # Try to load from HuggingFace directly
157
+ from huggingface_hub import hf_hub_download
158
+ model_file = hf_hub_download(repo_id=model_path, filename="model.onnx")
159
+
160
+ # Create ONNX Runtime session
161
+ session = ort.InferenceSession(
162
+ model_file,
163
+ session_options,
164
+ providers=self.providers
165
+ )
166
+
167
+ # Load processor/tokenizer
168
+ processor = AutoProcessor.from_pretrained(model_path)
169
+
170
+ self.asr_models[language_code] = session
171
+ self.processors[language_code] = processor
172
+
173
+ print(f"✅ ONNX ASR: Successfully loaded {model_config['model_type']} ONNX model for {language_code}")
174
+
175
+ except Exception as e:
176
+ print(f"❌ Error loading ONNX model {model_path}: {e}")
177
+ raise
178
+
179
+ else:
180
+ raise ValueError(f"Non-ONNX models not supported in optimized service")
181
+
182
+ # Add to cache
183
+ self.model_cache[language_code] = True
184
+
185
+ except Exception as e:
186
+ print(f"❌ Error loading model for {language_code}: {e}")
187
+ raise
188
+
189
+ # Rest of the methods remain the same as the original transcription service
190
+ # (transcribe_audio, process_audio_chunk, etc.)
191
+ # ... [Include all other methods from the original service]
192
+
193
+ async def transcribe_audio(self, participant_id: str, audio_data: bytes, language_code: str = "eng") -> Optional[str]:
194
+ """Transcribe audio using ONNX models"""
195
+ try:
196
+ await self.ensure_model_loaded(language_code)
197
+
198
+ if language_code not in self.asr_models or language_code not in self.processors:
199
+ raise ValueError(f"Model not loaded for language: {language_code}")
200
+
201
+ model = self.asr_models[language_code]
202
+ processor = self.processors[language_code]
203
+
204
+ # Convert audio bytes to numpy array
205
+ audio_io = io.BytesIO(audio_data)
206
+ with wave.open(audio_io, 'rb') as wav_file:
207
+ frames = wav_file.readframes(-1)
208
+ sample_rate = wav_file.getframerate()
209
+ audio_np = np.frombuffer(frames, dtype=np.int16).astype(np.float32) / 32768.0
210
+
211
+ # Get model configuration
212
+ model_config = self.asr_config[language_code]
213
+
214
+ if model_config.get("model_type") == "whisper":
215
+ # Process with Whisper ONNX model
+ import torch  # needed for no_grad() below; torch is not imported at module level in this file
216
+ inputs = processor(audio_np, sampling_rate=sample_rate, return_tensors="pt")
217
+
218
+ with torch.no_grad():
219
+ predicted_ids = model.generate(**inputs)
220
+ transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
221
+
222
+ return transcription.strip()
223
+
224
+ else:
225
+ # Process with wav2vec2-bert ONNX model
226
+ inputs = processor(audio_np, sampling_rate=sample_rate, return_tensors="np")
227
+
228
+ # Run ONNX inference
229
+ # w2v-bert processors emit "input_features"; older wav2vec2 processors emit "input_values"
+ input_array = inputs.get("input_features", inputs.get("input_values"))
+ ort_inputs = {model.get_inputs()[0].name: input_array}
230
+ ort_outputs = model.run(None, ort_inputs)
231
+
232
+ # Decode results
233
+ predicted_ids = np.argmax(ort_outputs[0], axis=-1)
234
+ transcription = processor.decode(predicted_ids[0])
235
+
236
+ return transcription.strip()
237
+
238
+ except Exception as e:
239
+ print(f"❌ Transcription error for {participant_id}: {e}")
240
+ return None
241
+
242
+ def get_performance_stats(self) -> Dict[str, any]:
243
+ """Get performance statistics for monitoring"""
244
+ return {
245
+ "loaded_models": list(self.model_cache.keys()),
246
+ "cache_size": len(self.model_cache),
247
+ "max_cache_size": self.max_asr_models,
248
+ "providers": self.providers,
249
+ "optimization_enabled": True,
250
+ "runtime_conversion": False # Key metric: no runtime conversion
251
+ }
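For monitoring, the stats dictionary above can be surfaced directly; wiring it into the `/health` endpoint is an assumption, not something this diff does.

```python
# Hedged sketch: read the optimized service's cache/provider state.
from app.services.transcription_service_onnx_optimized import OptimizedONNXTranscriptionService

service = OptimizedONNXTranscriptionService()
stats = service.get_performance_stats()
print(f"loaded: {stats['loaded_models']} ({stats['cache_size']}/{stats['max_cache_size']})")
print(f"providers: {stats['providers']}, runtime conversion: {stats['runtime_conversion']}")
```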
app/services/translation_service.py ADDED
@@ -0,0 +1,151 @@
1
+ import asyncio
2
+ from typing import Dict, Optional
3
+ from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
4
+ import torch
5
+ import nltk
6
+ from app.models import LanguageCode
7
+ from app.services.quantization_utils import apply_dynamic_int8_quantization, get_quantization_stats
8
+
9
+ # FLORES-200 language codes mapping
10
+ FLORES_CODES = {
11
+ "English": "eng_Latn",
12
+ "eng": "eng_Latn",
13
+ "Swahili": "swh_Latn",
14
+ "swa": "swh_Latn",
15
+ "Kikuyu": "kik_Latn",
16
+ "kik": "kik_Latn",
17
+ "Kamba": "kam_Latn",
18
+ "kam": "kam_Latn",
19
+ "Kimeru": "mer_Latn",
20
+ "mer": "mer_Latn",
21
+ "Luo": "luo_Latn",
22
+ "luo": "luo_Latn",
23
+ "Somali": "som_Latn",
24
+ "som": "som_Latn",
25
+ }
26
+
27
+ class TranslationService:
28
+ def __init__(self, enable_quantization: bool = True):
29
+ self.translation_pipeline = None
30
+ self.device = 0 if torch.cuda.is_available() else -1
31
+ self.model_path = "mutisya/nllb_600m-en-kik-kam-luo-mer-som-swh-drL-24_5-filtered-v24_28_4"
32
+ self.enable_quantization = enable_quantization
33
+
34
+ async def initialize(self):
35
+ """Initialize translation model"""
36
+ try:
37
+ # Download NLTK data with better error handling
38
+ try:
39
+ nltk.download("punkt", quiet=True)
40
+ nltk.download('punkt_tab', quiet=True)
41
+ except Exception as nltk_error:
42
+ print(f"Warning: NLTK data download failed: {nltk_error}")
43
+ # Continue anyway, sentence tokenization might still work
44
+
45
+ # Load translation model with explicit model kwargs for newer transformers
46
+ print(f"Loading translation model: {self.model_path}")
47
+ model = AutoModelForSeq2SeqLM.from_pretrained(
48
+ self.model_path,
49
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
50
+ )
51
+ tokenizer = AutoTokenizer.from_pretrained(self.model_path)
52
+
53
+ # Apply quantization if enabled
54
+ if self.enable_quantization:
55
+ try:
56
+ print("Applying INT8 quantization to translation model...")
57
+ model = apply_dynamic_int8_quantization(model, "translation")
58
+ stats = get_quantization_stats(model)
59
+ print(f"✓ Translation model quantized: {stats['quantized_layers']}/{stats['total_layers']} layers, {stats['size_mb']:.2f} MB")
60
+ except Exception as e:
61
+ print(f"Warning: Could not quantize translation model: {e}")
62
+ print(f"Continuing with unquantized model")
63
+
64
+ self.translation_pipeline = pipeline(
65
+ 'translation',
66
+ model=model,
67
+ tokenizer=tokenizer,
68
+ device=self.device,
69
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
70
+ )
71
+
72
+ except Exception as e:
73
+ print(f"Failed to initialize translation service: {e}")
74
+ raise
75
+
76
+ async def translate_text(self, text: str, source_lang: str, target_lang: str) -> str:
77
+ """Translate text from source language to target language"""
78
+ print(f"=== TRANSLATION REQUEST ===")
79
+ print(f"Text: '{text}'")
80
+ print(f"Source: {source_lang}")
81
+ print(f"Target: {target_lang}")
82
+
83
+ if not self.translation_pipeline:
84
+ print("TRANSLATION ERROR: Translation service not initialized")
85
+ raise RuntimeError("Translation service not initialized")
86
+
87
+ if not text or not text.strip():
88
+ print("TRANSLATION ERROR: Empty text provided")
89
+ return ""
90
+
91
+ try:
92
+ # Get FLORES codes
93
+ src_code = FLORES_CODES.get(source_lang, "eng_Latn")
94
+ tgt_code = FLORES_CODES.get(target_lang, "eng_Latn")
95
+
96
+ print(f"FLORES codes: {source_lang} -> {src_code}, {target_lang} -> {tgt_code}")
97
+
98
+ # Skip translation if same language
99
+ if src_code == tgt_code:
100
+ print("TRANSLATION SKIPPED: Same source and target language")
101
+ return text
102
+
103
+ # Tokenize into sentences for better translation
104
+ sentences = nltk.sent_tokenize(text)
105
+ translated_sentences = []
106
+
107
+ print(f"Translating {len(sentences)} sentences...")
108
+
109
+ for i, sentence in enumerate(sentences):
110
+ if sentence.strip():
111
+ print(f"Translating sentence {i+1}: '{sentence}'")
112
+
113
+ result = self.translation_pipeline(
114
+ sentence,
115
+ src_lang=src_code,
116
+ tgt_lang=tgt_code
117
+ )
118
+
119
+ translated = result[0]['translation_text']
120
+ print(f"Translation result: '{translated}'")
121
+
122
+ # Preserve punctuation and capitalization
123
+ if sentence.strip().endswith(".") and not translated.strip().endswith("."):
124
+ translated += "."
125
+
126
+ if sentence.strip()[0].isupper() and translated.strip():
127
+ translated = translated[0].upper() + translated[1:]
128
+
129
+ translated_sentences.append(translated)
130
+
131
+ final_translation = " ".join(translated_sentences)
132
+
133
+ # Preserve paragraph breaks
134
+ if text.endswith(".\n\n"):
135
+ final_translation += ".\n\n"
136
+
137
+ print(f"FINAL TRANSLATION: '{final_translation}'")
138
+ print(f"=== TRANSLATION COMPLETE ===")
139
+
140
+ return final_translation
141
+
142
+ except Exception as e:
143
+ print(f"TRANSLATION ERROR: {e}")
144
+ import traceback
145
+ traceback.print_exc()
146
+ return text # Return original text if translation fails
147
+
148
+ async def cleanup(self):
149
+ """Cleanup resources"""
150
+ self.translation_pipeline = None
151
+
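A minimal sketch of the translation flow above, assuming the model weights are reachable (the NLLB checkpoint may require an HF token) and using the short language codes mapped in `FLORES_CODES`.

```python
# Hedged sketch: translate one English sentence to Swahili.
import asyncio
from app.services.translation_service import TranslationService

async def main():
    service = TranslationService(enable_quantization=True)
    await service.initialize()   # downloads NLTK data and the NLLB model
    text = await service.translate_text("Good morning, everyone.", "eng", "swa")
    print(text)                  # a Swahili rendering of the sentence
    await service.cleanup()

asyncio.run(main())
```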
app/services/translation_service_onnx.py ADDED
@@ -0,0 +1,268 @@
1
+ import asyncio
2
+ from typing import Dict, Optional
3
+ from transformers import AutoTokenizer, pipeline
4
+ from optimum.onnxruntime import ORTModelForSeq2SeqLM
5
+ import nltk
6
+ from app.models import LanguageCode
7
+
8
+ # FLORES-200 language codes mapping
9
+ FLORES_CODES = {
10
+ "English": "eng_Latn",
11
+ "eng": "eng_Latn",
12
+ "Swahili": "swh_Latn",
13
+ "swa": "swh_Latn",
14
+ "Kikuyu": "kik_Latn",
15
+ "kik": "kik_Latn",
16
+ "Kamba": "kam_Latn",
17
+ "kam": "kam_Latn",
18
+ "Kimeru": "mer_Latn",
19
+ "mer": "mer_Latn",
20
+ "Luo": "luo_Latn",
21
+ "luo": "luo_Latn",
22
+ "Somali": "som_Latn",
23
+ "som": "som_Latn",
24
+ }
25
+
26
+ class ONNXTranslationService:
27
+ def __init__(self):
28
+ self.model = None
29
+ self.tokenizer = None
30
+ self.translation_pipeline = None
31
+
32
+ # Use ONNX optimized NLLB model (FP32 format with separate encoder/decoder)
33
+ self.model_repo = "mutisya/nllb-translation-onnx-v25-37-1"
34
+
35
+ async def initialize(self):
36
+ """Initialize ONNX translation model using optimum.onnxruntime"""
37
+ try:
38
+ print("ONNX Translation: Initializing translation service with ONNX Runtime...")
39
+ print(f"ONNX Translation: Loading model from {self.model_repo}")
40
+
41
+ # Check available providers for GPU detection
42
+ import onnxruntime as ort
43
+ available_providers = ort.get_available_providers()
44
+ print(f"ONNX Translation: Available providers: {available_providers}")
45
+
46
+ # Download NLTK data with better error handling
47
+ try:
48
+ nltk.download("punkt", quiet=True)
49
+ nltk.download('punkt_tab', quiet=True)
50
+ except Exception as nltk_error:
51
+ print(f"Warning: NLTK data download failed: {nltk_error}")
52
+
53
+ # Get authentication token for private repo
54
+ import os
55
+ auth_token = os.getenv('HUGGING_FACE_HUB_TOKEN') or os.getenv('HF_TOKEN')
56
+
57
+ # Configure providers list for optimal performance
58
+ print("ONNX Translation: Configuring execution providers...")
59
+ if 'CUDAExecutionProvider' in available_providers:
60
+ # Use both CUDA and CPU providers to eliminate assignment warnings
61
+ providers_list = ['CUDAExecutionProvider', 'CPUExecutionProvider']
62
+ primary_provider = 'CUDAExecutionProvider'
63
+ print(f"ONNX Translation: Using providers: {providers_list} (primary: {primary_provider})")
64
+ else:
65
+ providers_list = ['CPUExecutionProvider']
66
+ primary_provider = 'CPUExecutionProvider'
67
+ print(f"ONNX Translation: Using CPU-only providers: {providers_list}")
68
+
69
+ # Load ONNX model using optimum (handles separate encoder/decoder files)
70
+ # Configure session options for optimal CUDA performance
72
+ session_options = ort.SessionOptions()
73
+ session_options.log_severity_level = 1 # INFO level (0=verbose, 1=info, 2=warning) for detailed logs
74
+ session_options.logid = "ONNX_Translation"
75
+
76
+ # Enable all graph optimizations to reduce memcpy operations
77
+ session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
78
+
79
+ # Optimize threading for better GPU utilization
80
+ session_options.inter_op_num_threads = 1 # Reduce CPU thread contention
81
+ session_options.intra_op_num_threads = 1 # Focus on GPU execution
82
+
83
+ # Note: enable_cuda_graph not available in this ONNX Runtime version
84
+
85
+ # Configure provider options with performance optimizations for CUDA
86
+ provider_options = []
87
+ if primary_provider == 'CUDAExecutionProvider':
88
+ cuda_options = {
89
+ 'device_id': 0,
90
+ 'arena_extend_strategy': 'kNextPowerOfTwo',
91
+ 'gpu_mem_limit': int(0.6 * 1024 * 1024 * 1024), # cap the CUDA arena at ~0.6 GB for translation
92
+ 'cudnn_conv_algo_search': 'EXHAUSTIVE',
93
+ 'cudnn_conv_use_max_workspace': '1', # Enable max workspace for fp16 tensor cores
94
+ 'do_copy_in_default_stream': True,
95
+ 'enable_skip_layer_norm_strict_mode': False, # Better performance for transformers
96
+ 'prefer_nhwc': True, # Optimize data layout for GPU
97
+ }
98
+ # Configure providers with options
99
+ provider_options = [
100
+ ('CUDAExecutionProvider', cuda_options),
101
+ ('CPUExecutionProvider', {})
102
+ ]
103
+
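The `(provider_name, options_dict)` tuples built here are the same structure plain `onnxruntime` accepts, so the configuration can be smoke-tested against any ONNX file before handing it to optimum. A minimal sketch, where `model.onnx` is a placeholder path:

```python
# Sketch: the same (name, options) provider tuples work on a raw
# onnxruntime session; "model.onnx" is a placeholder path.
import onnxruntime as ort

sess_opts = ort.SessionOptions()
sess_opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

if "CUDAExecutionProvider" in ort.get_available_providers():
    providers = [("CUDAExecutionProvider", {"device_id": 0}), ("CPUExecutionProvider", {})]
else:
    providers = ["CPUExecutionProvider"]

session = ort.InferenceSession("model.onnx", sess_options=sess_opts, providers=providers)
print(session.get_providers())  # the providers actually selected after fallback
```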
104
+ # Try with optimized provider configuration and session options
105
+ try:
106
+ print("ONNX Translation: Attempting optimized provider configuration...")
107
+ self.model = ORTModelForSeq2SeqLM.from_pretrained(
108
+ self.model_repo,
109
+ token=auth_token,
110
+ providers=provider_options if provider_options else providers_list, # Use provider options or list
111
+ session_options=session_options, # Add session options
112
+ )
113
+ print(f"ONNX Translation: Model loaded successfully with providers: {providers_list}")
114
+
115
+ # Check what providers the model is actually using
116
+ if hasattr(self.model, 'providers'):
117
+ print(f"ONNX Translation: Model is using providers: {self.model.providers}")
118
+ if hasattr(self.model, 'device'):
119
+ print(f"ONNX Translation: Model device: {self.model.device}")
120
+
121
+ except Exception as e1:
122
+ print(f"ONNX Translation: Optimized provider approach failed: {e1}")
123
+ print("ONNX Translation: Falling back to simple provider list...")
124
+
125
+ # Fallback: Try with simple provider list (no options)
126
+ try:
127
+ self.model = ORTModelForSeq2SeqLM.from_pretrained(
128
+ self.model_repo,
129
+ token=auth_token,
130
+ providers=providers_list, # Simple provider list
131
+ session_options=session_options,
132
+ )
133
+ print(f"ONNX Translation: Model loaded successfully with simple providers: {providers_list}")
134
+
135
+ # Check what the model is actually using
136
+ if hasattr(self.model, 'providers'):
137
+ print(f"ONNX Translation: Model is using providers: {self.model.providers}")
138
+ if hasattr(self.model, 'device'):
139
+ print(f"ONNX Translation: Model device: {self.model.device}")
140
+
141
+ except Exception as e2:
142
+ print(f"ONNX Translation: Simple provider approach failed: {e2}")
143
+ print("ONNX Translation: Falling back to auto-detect...")
144
+
145
+ # Final fallback: Let model auto-detect
146
+ self.model = ORTModelForSeq2SeqLM.from_pretrained(
147
+ self.model_repo,
148
+ token=auth_token
149
+ # Not passing provider, letting it auto-detect based on device
150
+ )
151
+ print(f"ONNX Translation: Model loaded successfully with auto-detection")
152
+
153
+ # Check what the model is actually using
154
+ if hasattr(self.model, 'providers'):
155
+ print(f"ONNX Translation: Model auto-selected providers: {self.model.providers}")
156
+ if hasattr(self.model, 'device'):
157
+ print(f"ONNX Translation: Model device: {self.model.device}")
158
+
159
+ # Load tokenizer
160
+ self.tokenizer = AutoTokenizer.from_pretrained(
161
+ self.model_repo,
162
+ token=auth_token
163
+ )
164
+
165
+ # Create translation pipeline
166
+ # For ONNX models, we should specify device to ensure pipeline uses GPU
167
+ # Use the same provider detection as the model to ensure consistency
168
+ device = 0 if primary_provider == 'CUDAExecutionProvider' else -1
169
+ print(f"ONNX Translation: Setting pipeline device to: {device} ({'GPU' if device >= 0 else 'CPU'})")
170
+ print(f"ONNX Translation: Pipeline will use device based on primary provider: {primary_provider}")
171
+
172
+ self.translation_pipeline = pipeline(
173
+ "translation",
174
+ model=self.model,
175
+ tokenizer=self.tokenizer,
176
+ device=device
177
+ )
178
+
179
+ print("ONNX Translation: Successfully initialized ONNX translation model")
180
+
181
+ except Exception as e:
182
+ print(f"Failed to initialize ONNX translation service: {e}")
183
+ print("ONNX translation model is not available. Please ensure the model repository exists and contains the required ONNX files.")
184
+ import traceback
185
+ traceback.print_exc()
186
+ raise RuntimeError(f"ONNX translation model unavailable at {self.model_repo}: {e}")
187
+
188
+ async def translate_text(self, text: str, source_lang: str, target_lang: str) -> str:
189
+ """Translate text from source language to target language using ONNX"""
190
+ print(f"=== ONNX TRANSLATION REQUEST ===")
191
+ print(f"Text: '{text}'")
192
+ print(f"Source: {source_lang}")
193
+ print(f"Target: {target_lang}")
194
+
195
+ if not self.translation_pipeline:
196
+ print("ONNX TRANSLATION ERROR: Translation service not initialized")
197
+ raise RuntimeError("ONNX Translation service not initialized")
198
+
199
+ if not text or not text.strip():
200
+ print("ONNX TRANSLATION ERROR: Empty text provided")
201
+ return ""
202
+
203
+ try:
204
+ # Get FLORES codes
205
+ src_code = FLORES_CODES.get(source_lang, "eng_Latn")
206
+ tgt_code = FLORES_CODES.get(target_lang, "eng_Latn")
207
+
208
+ print(f"FLORES codes: {source_lang} -> {src_code}, {target_lang} -> {tgt_code}")
209
+
210
+ # Skip translation if same language
211
+ if src_code == tgt_code:
212
+ print("ONNX TRANSLATION SKIPPED: Same source and target language")
213
+ return text
214
+
215
+ # Tokenize into sentences for better translation
216
+ sentences = nltk.sent_tokenize(text)
217
+ translated_sentences = []
218
+
219
+ print(f"Translating {len(sentences)} sentences with ONNX...")
220
+
221
+ for i, sentence in enumerate(sentences):
222
+ if sentence.strip():
223
+ print(f"Translating sentence {i+1}: '{sentence}'")
224
+
225
+ # Use the pipeline for translation
226
+ result = self.translation_pipeline(
227
+ sentence.strip(),
228
+ src_lang=src_code,
229
+ tgt_lang=tgt_code,
230
+ max_length=512
231
+ )
232
+
233
+ translated = result[0]['translation_text']
234
+ print(f"ONNX Translation result: '{translated}'")
235
+
236
+ # Preserve punctuation and capitalization
237
+ if sentence.strip().endswith(".") and not translated.strip().endswith("."):
238
+ translated += "."
239
+
240
+ if sentence.strip() and sentence.strip()[0].isupper() and translated.strip():
241
+ translated = translated[0].upper() + translated[1:]
242
+
243
+ translated_sentences.append(translated)
244
+
245
+ final_translation = " ".join(translated_sentences)
246
+
247
+ # Preserve paragraph breaks
248
+ if text.endswith(".\n\n"):
249
+ final_translation += ".\n\n"
250
+
251
+ print(f"ONNX FINAL TRANSLATION: '{final_translation}'")
252
+ print(f"=== ONNX TRANSLATION COMPLETE ===")
253
+
254
+ return final_translation
255
+
256
+ except Exception as e:
257
+ print(f"ONNX TRANSLATION ERROR: {e}")
258
+ import traceback
259
+ traceback.print_exc()
260
+ raise RuntimeError(f"Translation failed: {e}")
261
+
262
+
263
+ async def cleanup(self):
264
+ """Cleanup resources"""
265
+ self.model = None
266
+ self.tokenizer = None
267
+ self.translation_pipeline = None
268
+ print("ONNX Translation: Translation service cleaned up")
app/services/tts_service.py ADDED
@@ -0,0 +1,541 @@
1
+ import asyncio
2
+ import io
3
+ import wave
4
+ import numpy as np
5
+ import subprocess
6
+ from typing import Any, Dict, Optional
7
+ from transformers import pipeline
8
+ import torch
9
+ import os
10
+ from app.services.quantization_utils import apply_dynamic_int8_quantization, get_quantization_stats
11
+
12
+ class TTSService:
13
+ def __init__(self, enable_quantization: bool = True):
14
+ self.tts_pipelines: Dict[str, Any] = {}
15
+ self.device = 0 if torch.cuda.is_available() else -1
16
+ self.enable_quantization = enable_quantization
17
+
18
+ # Check if espeak is available
19
+ self.espeak_available = self._check_espeak_availability()
20
+
21
+ # TTS model configurations from your original code
22
+ self.tts_config = {
23
+ "kik": {"model_repo": "mutisya/vits_kik_drL_24_5-v24_27_1_f", "model_type": "vits"},
24
+ "luo": {"model_repo": "mutisya/vits_luo_drL_24_5-v24_27_1_f", "model_type": "vits"},
25
+ "kam": {"model_repo": "mutisya/vits_kam_drL_24_5-v24_27_1_f", "model_type": "vits"},
26
+ "mer": {"model_repo": "mutisya/vits_mer_drL_24_5-v24_27_1_f", "model_type": "vits"},
27
+ "som": {"model_repo": "mutisya/vits_som_drL_24_5-v24_27_1_m", "model_type": "vits"},
28
+ "swa": {"model_repo": "mutisya/vits_swh_biblica-v24_27_1_m", "model_type": "vits"},
29
+ "eng": {"model_repo": "kakao-enterprise/vits-ljs", "model_type": "vits"},
30
+ }
31
+
32
+ # Alternative TTS models that don't require espeak (fallback)
33
+ self.fallback_tts_config = {
34
+ "eng": {"model_repo": "microsoft/speecht5_tts", "model_type": "speecht5"},
35
+ "swa": {"model_repo": "facebook/mms-tts-swh", "model_type": "mms"},
36
+ "som": {"model_repo": "facebook/mms-tts-som", "model_type": "mms"},
37
+ }
38
+
39
+ self.preload_languages = ["kik", "swa"]
40
+ self.background_loading_task = None
41
+ self.models_loading_status = {}
42
+
43
+ def _check_espeak_availability(self) -> bool:
44
+ """Check if espeak is available on the system"""
45
+ try:
46
+ result = subprocess.run(['espeak', '--version'],
47
+ capture_output=True, text=True, timeout=5)
48
+ if result.returncode == 0:
49
+ print("TTS: espeak is available")
50
+ return True
51
+ else:
52
+ print("TTS: espeak command failed")
53
+ return False
54
+ except (subprocess.TimeoutExpired, FileNotFoundError, Exception) as e:
55
+ print(f"TTS: espeak not available: {e}")
56
+ return False
57
+
58
+ async def initialize(self):
59
+ """Initialize TTS models for preloaded languages"""
60
+ print("TTS: Initializing TTS service...")
61
+ print(f"TTS: espeak available: {self.espeak_available}")
62
+
63
+ for lang_code in self.preload_languages:
64
+ await self.ensure_model_loaded(lang_code)
65
+
66
+ def _load_and_quantize_tts_pipeline(self, lang_code: str, model_repo: str, model_type: str = "vits"):
67
+ """Load TTS pipeline and optionally apply INT8 quantization"""
68
+ print(f"TTS: Loading model for {lang_code}: {model_repo}")
69
+
70
+ pipeline_obj = pipeline(
71
+ "text-to-speech",
72
+ model=model_repo,
73
+ device=self.device
74
+ )
75
+
76
+ # Apply quantization if enabled
77
+ if self.enable_quantization:
78
+ try:
79
+ # Get the underlying model from the pipeline
80
+ model = pipeline_obj.model
81
+
82
+ print(f"TTS: Applying INT8 quantization to {lang_code} model...")
83
+ quantized_model = apply_dynamic_int8_quantization(model, model_type)
84
+
85
+ # Replace the model in the pipeline
86
+ pipeline_obj.model = quantized_model
87
+
88
+ # Print quantization stats
89
+ stats = get_quantization_stats(quantized_model)
90
+ print(f"✓ TTS {lang_code} model quantized: {stats['quantized_layers']}/{stats['total_layers']} layers, {stats['size_mb']:.2f} MB")
91
+
92
+ except Exception as e:
93
+ print(f"TTS: Warning - Could not quantize {lang_code} model: {e}")
94
+ print(f"TTS: Continuing with unquantized model")
95
+
96
+ return pipeline_obj
97
+
98
+
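`apply_dynamic_int8_quantization` is defined in `app/services/quantization_utils.py`, which is not part of this hunk. For orientation, PyTorch's stock dynamic-quantization recipe, which a helper like that presumably wraps (an assumption; the helper itself is not shown here), looks roughly like this sketch:

```python
# Sketch: dynamic INT8 quantization of Linear layers - the standard PyTorch
# recipe a helper like apply_dynamic_int8_quantization presumably wraps.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 64))
quantized = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

# Linear weights are now stored as INT8 and dequantized on the fly at matmul time
print(quantized[0])  # DynamicQuantizedLinear(in_features=256, out_features=256, ...)
```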
99
+ async def ensure_model_loaded(self, language_code: str):
100
+ """Load TTS model for language if not already loaded"""
101
+ if language_code in self.tts_pipelines:
102
+ return
103
+
104
+ # First try to load primary model if espeak is available
105
+ if self.espeak_available and language_code in self.tts_config:
106
+ try:
107
+ model_config = self.tts_config[language_code]
108
+ pipeline_obj = self._load_and_quantize_tts_pipeline(
109
+ language_code,
110
+ model_config["model_repo"],
111
+ model_config.get("model_type", "vits")
112
+ )
113
+ self.tts_pipelines[language_code] = pipeline_obj
114
+ print(f"TTS: Loaded primary TTS model for {language_code}")
115
+ return
116
+ except Exception as e:
117
+ print(f"TTS: Failed to load primary TTS model for {language_code}: {e}")
118
+ # Continue to try fallback models
119
+
120
+ # Try fallback models if primary failed or espeak not available
121
+ if language_code in self.fallback_tts_config:
122
+ try:
123
+ model_config = self.fallback_tts_config[language_code]
124
+
125
+ if model_config["model_type"] == "speecht5":
126
+ # Special handling for SpeechT5
127
+ from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
128
+ import torch
129
+
130
+ processor = SpeechT5Processor.from_pretrained(model_config["model_repo"])
131
+ model = SpeechT5ForTextToSpeech.from_pretrained(model_config["model_repo"])
132
+ vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
133
+
134
+ # Create a custom pipeline-like object
135
+ class SpeechT5Pipeline:
136
+ def __init__(self, processor, model, vocoder):
137
+ self.processor = processor
138
+ self.model = model
139
+ self.vocoder = vocoder
140
+ # Load the default speaker embedding once at construction time
141
+ # instead of re-downloading the dataset on every synthesis call
142
+ import datasets
143
+ embeddings_dataset = datasets.load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
144
+ self.speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
145
+
146
+ def __call__(self, text):
147
+ inputs = self.processor(text=text, return_tensors="pt")
148
+ speech = self.model.generate_speech(inputs["input_ids"], self.speaker_embeddings, vocoder=self.vocoder)
149
+
150
+ return {
151
+ "audio": speech.numpy(),
152
+ "sampling_rate": 16000
153
+ }
154
+
155
+ pipeline_obj = SpeechT5Pipeline(processor, model, vocoder)
156
+ else:
157
+ # Standard pipeline for MMS models
158
+ pipeline_obj = pipeline(
159
+ "text-to-speech",
160
+ model=model_config["model_repo"],
161
+ device=self.device
162
+ )
163
+
164
+ self.tts_pipelines[language_code] = pipeline_obj
165
+ print(f"TTS: Loaded fallback TTS model for {language_code}")
166
+ return
167
+
168
+ except Exception as e:
169
+ print(f"TTS: Failed to load fallback TTS model for {language_code}: {e}")
170
+
171
+ print(f"TTS: No TTS model available for language: {language_code}")
172
+
173
+ async def generate_speech(self, text: str, language_code: str, output_format: str = "webm") -> Optional[bytes]:
174
+ """Generate speech audio from text
175
+
176
+ Args:
177
+ text: Text to convert to speech
178
+ language_code: Language code for TTS model
179
+ output_format: Output format - "webm" (default, web-compatible) or "wav" (Android-compatible)
180
+
181
+ Returns:
182
+ Audio bytes in the requested format, or None if generation fails
183
+ """
184
+ try:
185
+ print(f"=== TTS GENERATION REQUEST ===")
186
+ print(f"Text: '{text}'")
187
+ print(f"Language: {language_code}")
188
+ print(f"Output format: {output_format}")
189
+
190
+ # Input validation: Check for invalid or problematic text
191
+ if not text or not text.strip():
192
+ print("TTS: Empty or whitespace-only text, skipping TTS generation")
193
+ return None
194
+
195
+ # Check for very short text that might cause issues
196
+ clean_text = text.strip()
197
+ if len(clean_text) <= 2 and clean_text in [".", ",", "!", "?", ":", ";", "-"]:
198
+ print(f"TTS: Text '{clean_text}' is too short or punctuation-only, skipping TTS generation")
199
+ return None
200
+
201
+ # Check for minimum meaningful length
202
+ if len(clean_text.replace(" ", "").replace(".", "").replace(",", "")) < 2:
203
+ print(f"TTS: Text '{clean_text}' has insufficient content for TTS, skipping")
204
+ return None
205
+
206
+ print(f"TTS pipelines available: {list(self.tts_pipelines.keys())}")
207
+ print(f"TTS config available: {list(self.tts_config.keys())}")
208
+ print(f"Fallback config available: {list(self.fallback_tts_config.keys())}")
209
+
210
+ # Check if the language is supported
211
+ if language_code not in self.tts_config and language_code not in self.fallback_tts_config:
212
+ print(f"TTS: Language {language_code} not configured for TTS")
213
+ return None
214
+
215
+ await self.ensure_model_loaded(language_code)
216
+
217
+ if language_code not in self.tts_pipelines:
218
+ print(f"TTS: TTS model not available for language: {language_code}")
219
+ return None
220
+
221
+ if not text or not text.strip():
222
+ print("TTS: Empty text provided")
223
+ return None
224
+
225
+ print(f"TTS: Generating speech for '{text}' in {language_code}")
226
+
227
+ # Generate speech
228
+ pipeline_obj = self.tts_pipelines[language_code]
229
+ result = pipeline_obj(text)
230
+
231
+ audio_array = result["audio"]
232
+ sample_rate = result.get("sampling_rate", 22050)
233
+
234
+ print(f"TTS: Generated audio array of length {len(audio_array)} at {sample_rate}Hz")
235
+
236
+ # Validate audio array
237
+ if len(audio_array) == 0:
238
+ print("TTS: Warning - Generated audio array is empty")
239
+ return None
240
+
241
+ # Check for potential issues with audio data
242
+ audio_min = np.min(audio_array)
243
+ audio_max = np.max(audio_array)
244
+ audio_rms = np.sqrt(np.mean(audio_array**2))
245
+ print(f"TTS: Audio statistics - Min: {audio_min:.4f}, Max: {audio_max:.4f}, RMS: {audio_rms:.4f}")
246
+
247
+ # Check if audio might be silent or corrupted
248
+ if audio_rms < 0.001:
249
+ print("TTS: Warning - Audio appears to be very quiet or silent")
250
+ if audio_max > 1.0 or audio_min < -1.0:
251
+ print("TTS: Warning - Audio values outside expected range [-1, 1]")
252
+ # Clip to valid range
253
+ audio_array = np.clip(audio_array, -1.0, 1.0)
254
+ print("TTS: Clipped audio to valid range")
255
+
256
+ # Convert to WAV bytes with appropriate sample rate
257
+ if output_format == "wav":
258
+ # For Android: use 16kHz sample rate
259
+ target_sample_rate = 16000
260
+ wav_bytes = self._convert_to_wav_bytes(audio_array, sample_rate)
261
+ print(f"TTS: Converted to WAV: {len(wav_bytes)} bytes")
262
+
263
+ # Convert sample rate to 16kHz if needed for Android compatibility
264
+ if sample_rate != target_sample_rate:
265
+ print(f"TTS: Converting sample rate from {sample_rate}Hz to {target_sample_rate}Hz for Android compatibility")
266
+ wav_bytes = await self._resample_wav_to_16khz(wav_bytes, sample_rate)
267
+ print(f"TTS: Resampled WAV: {len(wav_bytes)} bytes")
268
+
269
+ print(f"TTS: Generated {len(wav_bytes)} bytes of WAV audio for '{text}'")
270
+ print(f"=== TTS GENERATION COMPLETE ===")
271
+
272
+ return wav_bytes
273
+ else:
274
+ # For web: use original sample rate and convert to WebM
275
+ wav_bytes = self._convert_to_wav_bytes(audio_array, sample_rate)
276
+ print(f"TTS: Converted to WAV: {len(wav_bytes)} bytes")
277
+
278
+ # Convert to WebM format for web compatibility
279
+ webm_bytes = await self._convert_to_webm(wav_bytes)
280
+
281
+ print(f"TTS: Generated {len(webm_bytes)} bytes of WebM audio for '{text}'")
282
+ print(f"=== TTS GENERATION COMPLETE ===")
283
+
284
+ return webm_bytes
285
+
286
+ except Exception as e:
287
+ print(f"TTS: TTS generation error: {e}")
288
+ import traceback
289
+ traceback.print_exc()
290
+ return None
291
+
292
+ async def generate_speech_dual_format(self, text: str, language_code: str) -> tuple[Optional[bytes], Optional[bytes]]:
293
+ """Generate speech audio in both WebM and WAV formats
294
+
295
+ Args:
296
+ text: Text to convert to speech
297
+ language_code: Language code for TTS model
298
+
299
+ Returns:
300
+ Tuple of (webm_bytes, wav_bytes), either can be None if generation fails
301
+ """
302
+ try:
303
+ print(f"=== TTS DUAL FORMAT GENERATION REQUEST ===")
304
+ print(f"Text: '{text}'")
305
+ print(f"Language: {language_code}")
306
+
307
+ # Input validation: Check for invalid or problematic text
308
+ if not text or not text.strip():
309
+ print("TTS: Empty or whitespace-only text, skipping TTS generation")
310
+ return None, None
311
+
312
+ # Check for very short text that might cause issues
313
+ clean_text = text.strip()
314
+ if len(clean_text) <= 2 and clean_text in [".", ",", "!", "?", ":", ";", "-"]:
315
+ print(f"TTS: Text '{clean_text}' is too short or punctuation-only, skipping TTS generation")
316
+ return None, None
317
+
318
+ # Check for minimum meaningful length
319
+ if len(clean_text.replace(" ", "").replace(".", "").replace(",", "")) < 2:
320
+ print(f"TTS: Text '{clean_text}' has insufficient content for TTS, skipping")
321
+ return None, None
322
+
323
+ # Check if the language is supported
324
+ if language_code not in self.tts_config and language_code not in self.fallback_tts_config:
325
+ print(f"TTS: Language {language_code} not configured for TTS")
326
+ return None, None
327
+
328
+ await self.ensure_model_loaded(language_code)
329
+
330
+ if language_code not in self.tts_pipelines:
331
+ print(f"TTS: TTS model not available for language: {language_code}")
332
+ return None, None
333
+
334
+ print(f"TTS: Generating speech for '{text}' in {language_code}")
335
+
336
+ # Generate speech once
337
+ pipeline_obj = self.tts_pipelines[language_code]
338
+ result = pipeline_obj(text)
339
+
340
+ audio_array = result["audio"]
341
+ sample_rate = result.get("sampling_rate", 22050)
342
+
343
+ print(f"TTS: Generated audio array of length {len(audio_array)} at {sample_rate}Hz")
344
+
345
+ # Validate audio array
346
+ if len(audio_array) == 0:
347
+ print("TTS: Warning - Generated audio array is empty")
348
+ return None, None
349
+
350
+ # Check for potential issues with audio data
351
+ audio_min = np.min(audio_array)
352
+ audio_max = np.max(audio_array)
353
+ audio_rms = np.sqrt(np.mean(audio_array**2))
354
+ print(f"TTS: Audio statistics - Min: {audio_min:.4f}, Max: {audio_max:.4f}, RMS: {audio_rms:.4f}")
355
+
356
+ # Check if audio might be silent or corrupted
357
+ if audio_rms < 0.001:
358
+ print("TTS: Warning - Audio appears to be very quiet or silent")
359
+ if audio_max > 1.0 or audio_min < -1.0:
360
+ print("TTS: Warning - Audio values outside expected range [-1, 1]")
361
+ # Clip to valid range
362
+ audio_array = np.clip(audio_array, -1.0, 1.0)
363
+ print("TTS: Clipped audio to valid range")
364
+
365
+ # Generate WAV at original sample rate first
366
+ wav_bytes_original = self._convert_to_wav_bytes(audio_array, sample_rate)
367
+ print(f"TTS: Converted to WAV: {len(wav_bytes_original)} bytes")
368
+
369
+ # Generate WebM from original WAV
370
+ webm_bytes = await self._convert_to_webm(wav_bytes_original)
371
+ print(f"TTS: Converted to WebM: {len(webm_bytes)} bytes")
372
+
373
+ # Generate 16kHz WAV for Android
374
+ wav_bytes_16k = await self._resample_wav_to_16khz(wav_bytes_original, sample_rate)
375
+ print(f"TTS: Resampled to 16kHz WAV: {len(wav_bytes_16k)} bytes")
376
+
377
+ print(f"TTS: Generated dual format audio for '{text}'")
378
+ print(f"=== TTS DUAL FORMAT GENERATION COMPLETE ===")
379
+
380
+ return webm_bytes, wav_bytes_16k
381
+
382
+ except Exception as e:
383
+ print(f"TTS: Dual format TTS generation error: {e}")
384
+ import traceback
385
+ traceback.print_exc()
386
+ return None, None
387
+
388
+ def _convert_to_wav_bytes(self, audio_array: np.ndarray, sample_rate: int) -> bytes:
389
+ """Convert numpy audio array to WAV bytes"""
390
+ buffer = io.BytesIO()
391
+ with wave.open(buffer, 'wb') as wav_file:
392
+ wav_file.setnchannels(1) # Mono
393
+ wav_file.setsampwidth(2) # 16-bit
394
+ wav_file.setframerate(sample_rate)
395
+
396
+ # Ensure audio is in valid range [-1, 1]
397
+ audio_array = np.clip(audio_array, -1.0, 1.0)
398
+
399
+ # Convert to int16 with proper scaling
400
+ int16_audio = (audio_array * 32767).astype(np.int16)
401
+
402
+ # Validate the converted audio
403
+ print(f"TTS: Converting {len(audio_array)} samples to WAV at {sample_rate}Hz")
404
+ print(f"TTS: Int16 audio range: {np.min(int16_audio)} to {np.max(int16_audio)}")
405
+
406
+ wav_file.writeframes(int16_audio.tobytes())
407
+
408
+ wav_data = buffer.getvalue()
409
+ print(f"TTS: WAV file created: {len(wav_data)} bytes (expected header: 44 bytes + {len(int16_audio) * 2} data bytes)")
410
+
411
+ return wav_data
412
+
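The size printed above follows directly from the PCM layout: a 44-byte canonical WAV header plus two bytes per 16-bit mono sample. For example:

```python
# One second of 22,050 Hz mono 16-bit audio:
samples = 22_050
expected_bytes = 44 + samples * 2  # header + int16 payload
print(expected_bytes)  # 44144
```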
413
+ async def _resample_wav_to_16khz(self, wav_bytes: bytes, original_sample_rate: int) -> bytes:
414
+ """Resample WAV audio to 16kHz using FFmpeg"""
415
+ try:
416
+ process = subprocess.Popen([
417
+ "ffmpeg", "-f", "wav", "-i", "pipe:0",
418
+ "-ar", "16000", # Set output sample rate to 16kHz
419
+ "-ac", "1", # Ensure mono output
420
+ "-f", "wav", "pipe:1"
421
+ ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
422
+
423
+ resampled_data, stderr = process.communicate(input=wav_bytes)
424
+
425
+ if process.returncode != 0:
426
+ print(f"TTS: FFmpeg resampling error: {stderr.decode()}")
427
+ return wav_bytes # Return original if resampling fails
428
+
429
+ return resampled_data
430
+
431
+ except Exception as e:
432
+ print(f"TTS: Resampling error: {e}")
433
+ return wav_bytes # Return original if resampling fails
434
+
435
+ async def _convert_to_webm(self, wav_bytes: bytes) -> bytes:
436
+ """Convert WAV bytes to WebM format using FFmpeg"""
437
+ try:
438
+ process = subprocess.Popen([
439
+ "ffmpeg", "-f", "wav", "-i", "pipe:0",
440
+ "-c:a", "libopus", "-b:a", "64k",
441
+ "-f", "webm", "pipe:1"
442
+ ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
443
+
444
+ webm_data, stderr = process.communicate(input=wav_bytes)
445
+
446
+ if process.returncode != 0:
447
+ print(f"TTS: FFmpeg error: {stderr.decode()}")
448
+ return wav_bytes # Return original WAV if conversion fails
449
+
450
+ return webm_data
451
+
452
+ except Exception as e:
453
+ print(f"TTS: WebM conversion error: {e}")
454
+ return wav_bytes # Return original WAV if conversion fails
455
+
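One caveat with both FFmpeg helpers above: `subprocess.Popen(...).communicate()` blocks the event loop even though the enclosing methods are `async`. A non-blocking sketch using `asyncio.create_subprocess_exec`, with the same flags as `_convert_to_webm`:

```python
# Sketch: non-blocking FFmpeg invocation; same flags as _convert_to_webm,
# but the event loop stays free while FFmpeg runs.
import asyncio

async def convert_to_webm_async(wav_bytes: bytes) -> bytes:
    proc = await asyncio.create_subprocess_exec(
        "ffmpeg", "-f", "wav", "-i", "pipe:0",
        "-c:a", "libopus", "-b:a", "64k",
        "-f", "webm", "pipe:1",
        stdin=asyncio.subprocess.PIPE,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    webm_data, stderr = await proc.communicate(input=wav_bytes)
    if proc.returncode != 0:
        print(f"FFmpeg error: {stderr.decode()}")
        return wav_bytes  # fall back to the original WAV
    return webm_data
```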
456
+ async def load_remaining_models_in_background(self):
457
+ """Load all remaining TTS models in the background after startup"""
458
+ try:
459
+ print("TTS: Starting background loading of additional voice models...")
460
+
461
+ # Load primary models first
462
+ for lang_code in self.tts_config.keys():
463
+ if lang_code not in self.preload_languages and lang_code not in self.tts_pipelines:
464
+ if self.espeak_available:
465
+ try:
466
+ print(f"TTS: Background loading primary model for {lang_code}...")
467
+ self.models_loading_status[lang_code] = "loading"
468
+
469
+ model_config = self.tts_config[lang_code]
470
+ pipeline_obj = pipeline(
471
+ "text-to-speech",
472
+ model=model_config["model_repo"],
473
+ device=self.device
474
+ )
475
+ self.tts_pipelines[lang_code] = pipeline_obj
476
+ self.models_loading_status[lang_code] = "loaded"
477
+ print(f"TTS: Successfully loaded primary model for {lang_code} in background")
478
+
479
+ # Add a small delay between loading models
480
+ await asyncio.sleep(2)
481
+ except Exception as e:
482
+ print(f"TTS: Failed to load primary model for {lang_code} in background: {e}")
483
+ self.models_loading_status[lang_code] = "failed"
484
+
485
+ # Load fallback models for languages not yet loaded
486
+ for lang_code in self.fallback_tts_config.keys():
487
+ if lang_code not in self.tts_pipelines:
488
+ try:
489
+ print(f"TTS: Background loading fallback model for {lang_code}...")
490
+ model_config = self.fallback_tts_config[lang_code]
491
+
492
+ if model_config["model_type"] == "speecht5":
493
+ # Reuse the SpeechT5 path in ensure_model_loaded so the cached
494
+ # object is callable like a pipeline; storing a raw dict here
495
+ # would break generate_speech, which invokes pipeline_obj(text)
496
+ await self.ensure_model_loaded(lang_code)
506
+ else:
507
+ pipeline_obj = pipeline(
508
+ "text-to-speech",
509
+ model=model_config["model_repo"],
510
+ device=self.device
511
+ )
512
+ self.tts_pipelines[lang_code] = pipeline_obj
513
+
514
+ print(f"TTS: Successfully loaded fallback model for {lang_code} in background")
515
+ await asyncio.sleep(2)
516
+ except Exception as e:
517
+ print(f"TTS: Failed to load fallback model for {lang_code}: {e}")
518
+
519
+ print("TTS: Background loading of all voice models complete")
520
+ print(f"TTS: Loaded models: {list(self.tts_pipelines.keys())}")
521
+ except Exception as e:
522
+ print(f"TTS: Error in background model loading: {e}")
523
+
524
+ def start_background_loading(self):
525
+ """Start background loading of models as a non-blocking task"""
526
+ if self.background_loading_task is None:
527
+ self.background_loading_task = asyncio.create_task(self.load_remaining_models_in_background())
528
+ print("TTS: Background model loading task started")
529
+
530
+ async def cleanup(self):
531
+ """Cleanup resources"""
532
+ # Cancel background loading if still running
533
+ if self.background_loading_task and not self.background_loading_task.done():
534
+ self.background_loading_task.cancel()
535
+ try:
536
+ await self.background_loading_task
537
+ except asyncio.CancelledError:
538
+ pass
539
+
540
+ self.tts_pipelines.clear()
541
+ print("TTS: TTS service cleaned up")
app/services/tts_service_onnx.py ADDED
@@ -0,0 +1,587 @@
1
+ import asyncio
2
+ import io
3
+ import wave
4
+ import numpy as np
5
+ import subprocess
6
+ from typing import Any, Dict, Optional
7
+ import onnxruntime as ort
8
+ from transformers import AutoProcessor
9
+ from collections import OrderedDict
10
+ import os
11
+
12
+ class ONNXTTSService:
13
+ def __init__(self):
14
+ self.tts_models: Dict[str, Any] = {}
15
+ self.processors: Dict[str, Any] = {}
16
+ self.max_tts_models = 3 # Keep up to 3 TTS models in memory
17
+ self.model_cache = OrderedDict() # LRU cache
18
+
19
+ # GPU optimization - detect and configure providers
20
+ available_providers = ort.get_available_providers()
21
+ print(f"ONNX TTS: Available providers: {available_providers}")
22
+
+ # Remember CUDA availability as a plain flag: self.providers below may
+ # hold (name, options) tuples, so naive string membership tests against it fail
+ self.use_cuda = 'CUDAExecutionProvider' in available_providers
+
23
+ if 'CUDAExecutionProvider' in available_providers:
24
+ # Configure CUDA provider with optimizations
25
+ cuda_provider_options = {
26
+ 'device_id': 0,
27
+ 'arena_extend_strategy': 'kNextPowerOfTwo',
28
+ 'gpu_mem_limit': int(0.7 * 1024 * 1024 * 1024), # cap the CUDA arena at ~0.7 GB (TTS needs less than ASR)
29
+ 'cudnn_conv_algo_search': 'EXHAUSTIVE',
30
+ 'do_copy_in_default_stream': True,
31
+ }
32
+ self.providers = [('CUDAExecutionProvider', cuda_provider_options), 'CPUExecutionProvider']
33
+ print(f"ONNX TTS: Using CUDA acceleration with GPU memory limit: {cuda_provider_options['gpu_mem_limit'] // (1024**3)}GB")
34
+ else:
35
+ self.providers = ['CPUExecutionProvider']
36
+ print("ONNX TTS: CUDA not available, using CPU execution")
37
+
38
+ print(f"ONNX TTS: Configured providers: {[p[0] if isinstance(p, tuple) else p for p in self.providers]}")
39
+
40
+ # Check if espeak is available
41
+ self.espeak_available = self._check_espeak_availability()
42
+
43
+ # ONNX TTS model configurations - using FP32 optimized models (16kHz corrected)
44
+ self.tts_config = {
45
+ "kik": {"model_repo": "mutisya/vits-tts-onnx-fp32-kikuyu-v25-37-1", "model_type": "vits", "use_onnx": True},
46
+ "luo": {"model_repo": "mutisya/vits-tts-onnx-fp32-luo-v25-37-1", "model_type": "vits", "use_onnx": True},
47
+ "kam": {"model_repo": "mutisya/vits-tts-onnx-fp32-kamba-v25-37-1", "model_type": "vits", "use_onnx": True},
48
+ "mer": {"model_repo": "mutisya/vits-tts-onnx-fp32-kimeru-v25-37-1", "model_type": "vits", "use_onnx": True},
49
+ "som": {"model_repo": "mutisya/vits-tts-onnx-fp32-somali-v25-37-1", "model_type": "vits", "use_onnx": True},
50
+ "swa": {"model_repo": "mutisya/vits-tts-onnx-fp32-swahili-v25-37-1", "model_type": "vits", "use_onnx": True},
51
+ "eng": {"model_repo": "kakao-enterprise/vits-ljs", "model_type": "vits", "use_onnx": False}, # Fallback to PyTorch
52
+ }
53
+
54
+ # Alternative TTS models that don't require espeak (fallback)
55
+ self.fallback_tts_config = {
56
+ "eng": {"model_repo": "microsoft/speecht5_tts", "model_type": "speecht5"},
57
+ "swa": {"model_repo": "facebook/mms-tts-swh", "model_type": "mms"},
58
+ "som": {"model_repo": "facebook/mms-tts-som", "model_type": "mms"},
59
+ }
60
+
61
+ self.preload_languages = ["kik", "swa"]
62
+
63
+ def _check_espeak_availability(self) -> bool:
64
+ """Check if espeak is available on the system"""
65
+ try:
66
+ result = subprocess.run(['espeak', '--version'],
67
+ capture_output=True, text=True, timeout=5)
68
+ if result.returncode == 0:
69
+ print("ONNX TTS: espeak is available")
70
+ return True
71
+ else:
72
+ print("ONNX TTS: espeak command failed")
73
+ return False
74
+ except (subprocess.TimeoutExpired, FileNotFoundError, Exception) as e:
75
+ print(f"ONNX TTS: espeak not available: {e}")
76
+ return False
77
+
78
+ async def initialize(self):
79
+ """Initialize TTS models for preloaded languages"""
80
+ print("ONNX TTS: Initializing TTS service with ONNX Runtime...")
81
+ print(f"ONNX TTS: espeak available: {self.espeak_available}")
82
+ print(f"ONNX TTS: Using providers: {self.providers}")
83
+
84
+ for lang_code in self.preload_languages:
85
+ await self.ensure_model_loaded(lang_code)
86
+
87
+ async def ensure_model_loaded(self, language_code: str):
88
+ """Load TTS model for language if not already loaded with LRU cache"""
89
+ if language_code in self.model_cache:
90
+ # Move to end (most recently used)
91
+ self.model_cache.move_to_end(language_code)
92
+ return
93
+
94
+ # Check if we need to evict old models
95
+ while len(self.model_cache) >= self.max_tts_models:
96
+ # Remove least recently used model
97
+ old_lang, _ = self.model_cache.popitem(last=False)
98
+ if old_lang in self.tts_models:
99
+ del self.tts_models[old_lang]
100
+ if old_lang in self.processors:
101
+ del self.processors[old_lang]
102
+ print(f"ONNX TTS: Evicted model for {old_lang} (LRU cache)")
103
+
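The eviction logic above is a plain `OrderedDict` LRU: hits call `move_to_end`, and `popitem(last=False)` drops the least recently used entry. In isolation:

```python
# Sketch: the OrderedDict LRU pattern used by ensure_model_loaded.
from collections import OrderedDict

cache, max_size = OrderedDict(), 3
for lang in ["kik", "swa", "luo", "kik", "mer"]:
    if lang in cache:
        cache.move_to_end(lang)                 # mark as most recently used
        continue
    while len(cache) >= max_size:
        evicted, _ = cache.popitem(last=False)  # least recently used
        print(f"evicted {evicted}")             # prints: evicted swa
    cache[lang] = True

print(list(cache))  # ['luo', 'kik', 'mer']
```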
104
+ # First try to load ONNX model
105
+ if language_code in self.tts_config:
106
+ model_config = self.tts_config[language_code]
107
+
108
+ if model_config.get("use_onnx", False):
109
+ try:
110
+ print(f"ONNX TTS: Loading ONNX model for {language_code}")
111
+
112
+ # Create ONNX session with optimizations and verbose logging
113
+ session_options = ort.SessionOptions()
114
+ session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
115
+
116
+ # Enable verbose logging to diagnose operator assignments
117
+ session_options.log_severity_level = 1 # INFO level (0=verbose, 1=info, 2=warning) for detailed logs
118
+ session_options.logid = "ONNX_TTS" # Prefix for log identification
119
+
120
+ # GPU memory optimization for T4 with diagnostic tracing
121
+ if self.use_cuda: # self.providers holds (name, options) tuples, so test the flag
122
+ provider_options = [{
123
+ 'device_id': 0,
124
+ 'arena_extend_strategy': 'kSameAsRequested',
125
+ 'gpu_mem_limit': int(0.3 * 1024 * 1024 * 1024), # cap the CUDA arena at ~0.3 GB for TTS
126
+ 'cudnn_conv_algo_search': 'EXHAUSTIVE',
127
+ 'do_copy_in_default_stream': True,
128
+ 'enable_tracing': True, # Enable tracing for better diagnostics
129
+ }]
130
+ providers = [('CUDAExecutionProvider', provider_options[0]), 'CPUExecutionProvider']
131
+ else:
132
+ providers = self.providers
133
+
134
+ # Get authentication token for private repos
135
+ import os
136
+ auth_token = os.getenv('HUGGING_FACE_HUB_TOKEN') or os.getenv('HF_TOKEN')
137
+
138
+ # Download ONNX model from HuggingFace Hub with authentication
139
+ from huggingface_hub import hf_hub_download
140
+ onnx_path = hf_hub_download(
141
+ repo_id=model_config["model_repo"],
142
+ filename="model.onnx",
143
+ token=auth_token
144
+ )
145
+
146
+ session = ort.InferenceSession(onnx_path, providers=providers, sess_options=session_options)
147
+
148
+ # Load processor for preprocessing with authentication
149
+ processor = AutoProcessor.from_pretrained(
150
+ model_config["model_repo"],
151
+ token=auth_token
152
+ )
153
+
154
+ self.tts_models[language_code] = session
155
+ self.processors[language_code] = processor
156
+ self.model_cache[language_code] = True
157
+
158
+ print(f"ONNX TTS: Successfully loaded ONNX model for {language_code}")
159
+ return
160
+
161
+ except Exception as e:
162
+ print(f"ONNX TTS: Failed to load ONNX model for {language_code}: {e}")
163
+ # Continue to try fallback models
164
+ else:
165
+ # Try PyTorch model if ONNX not available
166
+ try:
167
+ print(f"ONNX TTS: Loading PyTorch model for {language_code} (fallback)")
168
+ from transformers import pipeline
169
+
170
+ pipeline_obj = pipeline(
171
+ "text-to-speech",
172
+ model=model_config["model_repo"],
173
+ device=0 if self.use_cuda else -1 # comparing self.providers[0] to a string fails when it is a tuple
174
+ )
175
+ self.tts_models[language_code] = pipeline_obj
176
+ self.processors[language_code] = None # Not needed for pipeline
177
+ self.model_cache[language_code] = True
178
+
179
+ print(f"ONNX TTS: Successfully loaded PyTorch model for {language_code}")
180
+ return
181
+
182
+ except Exception as e:
183
+ print(f"ONNX TTS: Failed to load PyTorch model for {language_code}: {e}")
184
+
185
+ # Try fallback models if primary failed
186
+ if language_code in self.fallback_tts_config:
187
+ try:
188
+ model_config = self.fallback_tts_config[language_code]
189
+
190
+ if model_config["model_type"] == "speecht5":
191
+ # Special handling for SpeechT5
192
+ from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
193
+ import torch
194
+
195
+ # Get authentication token for private repos
196
+ import os
197
+ auth_token = os.getenv('HUGGING_FACE_HUB_TOKEN') or os.getenv('HF_TOKEN')
198
+
199
+ processor = SpeechT5Processor.from_pretrained(
200
+ model_config["model_repo"],
201
+ token=auth_token
202
+ )
203
+ model = SpeechT5ForTextToSpeech.from_pretrained(
204
+ model_config["model_repo"],
205
+ token=auth_token
206
+ )
207
+ vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
208
+
209
+ # Create a custom pipeline-like object
210
+ class SpeechT5Pipeline:
211
+ def __init__(self, processor, model, vocoder):
212
+ self.processor = processor
213
+ self.model = model
214
+ self.vocoder = vocoder
215
+ # Load the default speaker embedding once at construction time
216
+ # instead of re-downloading the dataset on every synthesis call
217
+ import datasets
218
+ embeddings_dataset = datasets.load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
219
+ self.speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
220
+
221
+ def __call__(self, text):
222
+ inputs = self.processor(text=text, return_tensors="pt")
223
+ speech = self.model.generate_speech(inputs["input_ids"], self.speaker_embeddings, vocoder=self.vocoder)
224
+
225
+ return {
226
+ "audio": speech.numpy(),
227
+ "sampling_rate": 16000
228
+ }
229
+
230
+ pipeline_obj = SpeechT5Pipeline(processor, model, vocoder)
231
+ else:
232
+ # Standard pipeline for MMS models
233
+ from transformers import pipeline
234
+ pipeline_obj = pipeline(
235
+ "text-to-speech",
236
+ model=model_config["model_repo"],
237
+ device=0 if self.use_cuda else -1 # comparing self.providers[0] to a string fails when it is a tuple
238
+ )
239
+
240
+ self.tts_models[language_code] = pipeline_obj
241
+ self.processors[language_code] = None
242
+ self.model_cache[language_code] = True
243
+
244
+ print(f"ONNX TTS: Successfully loaded fallback model for {language_code}")
245
+ return
246
+
247
+ except Exception as e:
248
+ print(f"ONNX TTS: Failed to load fallback TTS model for {language_code}: {e}")
249
+
250
+ print(f"ONNX TTS: No TTS model available for language: {language_code}")
251
+
252
+ async def generate_speech(self, text: str, language_code: str, output_format: str = "webm") -> Optional[bytes]:
253
+ """Generate speech audio from text using ONNX models
254
+
255
+ Args:
256
+ text: Text to convert to speech
257
+ language_code: Language code for TTS model
258
+ output_format: Output format - "webm" (default, web-compatible) or "wav" (Android-compatible)
259
+
260
+ Returns:
261
+ Audio bytes in the requested format, or None if generation fails
262
+ """
263
+ try:
264
+ print(f"=== ONNX TTS GENERATION REQUEST ===")
265
+ print(f"Text: '{text}'")
266
+ print(f"Language: {language_code}")
267
+ print(f"Output format: {output_format}")
268
+
269
+ # Input validation
270
+ if not text or not text.strip():
271
+ print("ONNX TTS: Empty or whitespace-only text, skipping TTS generation")
272
+ return None
273
+
274
+ # Check for very short text that might cause issues
275
+ clean_text = text.strip()
276
+ if len(clean_text) <= 2 and clean_text in [".", ",", "!", "?", ":", ";", "-"]:
277
+ print(f"ONNX TTS: Text '{clean_text}' is too short or punctuation-only, skipping TTS generation")
278
+ return None
279
+
280
+ # Check for minimum meaningful length
281
+ if len(clean_text.replace(" ", "").replace(".", "").replace(",", "")) < 2:
282
+ print(f"ONNX TTS: Text '{clean_text}' has insufficient content for TTS, skipping")
283
+ return None
284
+
285
+ # Check if the language is supported
286
+ if language_code not in self.tts_config and language_code not in self.fallback_tts_config:
287
+ print(f"ONNX TTS: Language {language_code} not configured for TTS")
288
+ return None
289
+
290
+ await self.ensure_model_loaded(language_code)
291
+
292
+ if language_code not in self.tts_models:
293
+ print(f"ONNX TTS: TTS model not available for language: {language_code}")
294
+ return None
295
+
296
+ print(f"ONNX TTS: Generating speech for '{text}' in {language_code}")
297
+
298
+ # Generate speech based on model type
299
+ model_config = self.tts_config.get(language_code, {})
300
+ if model_config.get("use_onnx", False):
301
+ # ONNX inference
302
+ audio_array, sample_rate = await self._run_onnx_tts_inference(text, language_code)
303
+ else:
304
+ # PyTorch pipeline inference
305
+ pipeline_obj = self.tts_models[language_code]
306
+ result = pipeline_obj(text)
307
+
308
+ audio_array = result["audio"]
309
+ sample_rate = result.get("sampling_rate", 16000) # Default to 16kHz (corrected)
310
+
311
+ print(f"ONNX TTS: Generated audio array of length {len(audio_array)} at {sample_rate}Hz")
312
+
313
+ # Validate audio array
314
+ if len(audio_array) == 0:
315
+ print("ONNX TTS: Warning - Generated audio array is empty")
316
+ return None
317
+
318
+ # Check audio statistics
319
+ audio_min = np.min(audio_array)
320
+ audio_max = np.max(audio_array)
321
+ audio_rms = np.sqrt(np.mean(audio_array**2))
322
+ print(f"ONNX TTS: Audio statistics - Min: {audio_min:.4f}, Max: {audio_max:.4f}, RMS: {audio_rms:.4f}")
323
+
324
+ # Check if audio might be silent or corrupted
325
+ if audio_rms < 0.001:
326
+ print("ONNX TTS: Warning - Audio appears to be very quiet or silent")
327
+ if audio_max > 1.0 or audio_min < -1.0:
328
+ print("ONNX TTS: Warning - Audio values outside expected range [-1, 1]")
329
+ # Clip to valid range
330
+ audio_array = np.clip(audio_array, -1.0, 1.0)
331
+ print("ONNX TTS: Clipped audio to valid range")
332
+
333
+ # Convert to requested format
334
+ if output_format == "wav":
335
+ # For Android: use 16kHz sample rate
336
+ target_sample_rate = 16000
337
+ wav_bytes = self._convert_to_wav_bytes(audio_array, sample_rate)
338
+ print(f"ONNX TTS: Converted to WAV: {len(wav_bytes)} bytes")
339
+
340
+ # Convert sample rate to 16kHz if needed for Android compatibility
341
+ if sample_rate != target_sample_rate:
342
+ print(f"ONNX TTS: Converting sample rate from {sample_rate}Hz to {target_sample_rate}Hz")
343
+ wav_bytes = await self._resample_wav_to_16khz(wav_bytes, sample_rate)
344
+ print(f"ONNX TTS: Resampled WAV: {len(wav_bytes)} bytes")
345
+
346
+ print(f"ONNX TTS: Generated {len(wav_bytes)} bytes of WAV audio for '{text}'")
347
+ print(f"=== ONNX TTS GENERATION COMPLETE ===")
348
+
349
+ return wav_bytes
350
+ else:
351
+ # For web: use original sample rate and convert to WebM
352
+ wav_bytes = self._convert_to_wav_bytes(audio_array, sample_rate)
353
+ print(f"ONNX TTS: Converted to WAV: {len(wav_bytes)} bytes")
354
+
355
+ # Convert to WebM format for web compatibility
356
+ webm_bytes = await self._convert_to_webm(wav_bytes)
357
+
358
+ print(f"ONNX TTS: Generated {len(webm_bytes)} bytes of WebM audio for '{text}'")
359
+ print(f"=== ONNX TTS GENERATION COMPLETE ===")
360
+
361
+ return webm_bytes
362
+
363
+ except Exception as e:
364
+ print(f"ONNX TTS: TTS generation error: {e}")
365
+ import traceback
366
+ traceback.print_exc()
367
+ return None
368
+
369
+ async def _run_onnx_tts_inference(self, text: str, language_code: str) -> tuple[np.ndarray, int]:
370
+ """Run ONNX inference for text-to-speech"""
371
+ try:
372
+ session = self.tts_models[language_code]
373
+ processor = self.processors[language_code]
374
+
375
+ # Preprocess text
376
+ inputs = processor(text=text, return_tensors="np")
377
+
378
+ # Get input names for ONNX session
379
+ input_names = [inp.name for inp in session.get_inputs()]
380
+
381
+ # Prepare inputs for ONNX
382
+ onnx_inputs = {}
383
+ for name in input_names:
384
+ if name in inputs:
385
+ onnx_inputs[name] = inputs[name]
386
+ elif name == "input_ids" and "input_ids" in inputs:
387
+ onnx_inputs[name] = inputs["input_ids"].astype(np.int64)
388
+ elif name == "attention_mask" and "attention_mask" in inputs:
389
+ onnx_inputs[name] = inputs["attention_mask"].astype(np.int64)
390
+
391
+ # Run ONNX inference
392
+ outputs = session.run(None, onnx_inputs)
393
+
394
+ # Extract audio from outputs (assuming first output is audio)
395
+ audio_array = outputs[0]
396
+
397
+ # Ensure audio is 1D
398
+ if audio_array.ndim > 1:
399
+ audio_array = audio_array.flatten()
400
+
401
+ # Convert to float32 if needed
402
+ if audio_array.dtype != np.float32:
403
+ audio_array = audio_array.astype(np.float32)
404
+
405
+ # Sample rate is 16kHz for our corrected models
406
+ sample_rate = 16000
407
+
408
+ return audio_array, sample_rate
409
+
410
+ except Exception as e:
411
+ print(f"ONNX TTS: Inference error: {e}")
412
+ import traceback
413
+ traceback.print_exc()
414
+ return np.array([], dtype=np.float32), 16000
415
+
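The input-matching loop works because ONNX Runtime exposes the graph's declared inputs and outputs; when wiring up a new VITS export, it helps to print them first. A quick inspection sketch (the `model.onnx` path is a placeholder):

```python
# Sketch: inspect what an ONNX graph expects and returns - useful when
# mapping processor outputs onto session inputs as done above.
import onnxruntime as ort

session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
for inp in session.get_inputs():
    print("in :", inp.name, inp.shape, inp.type)
for out in session.get_outputs():
    print("out:", out.name, out.shape, out.type)
```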
416
+ async def generate_speech_dual_format(self, text: str, language_code: str) -> tuple[Optional[bytes], Optional[bytes]]:
417
+ """Generate speech audio in both WebM and WAV formats using ONNX
418
+
419
+ Args:
420
+ text: Text to convert to speech
421
+ language_code: Language code for TTS model
422
+
423
+ Returns:
424
+ Tuple of (webm_bytes, wav_bytes), either can be None if generation fails
425
+ """
426
+ try:
427
+ print(f"=== ONNX TTS DUAL FORMAT GENERATION REQUEST ===")
428
+ print(f"Text: '{text}'")
429
+ print(f"Language: {language_code}")
430
+
431
+ # Input validation
432
+ if not text or not text.strip():
433
+ print("ONNX TTS: Empty or whitespace-only text, skipping TTS generation")
434
+ return None, None
435
+
436
+ clean_text = text.strip()
437
+ if len(clean_text) <= 2 and clean_text in [".", ",", "!", "?", ":", ";", "-"]:
438
+ print(f"ONNX TTS: Text '{clean_text}' is too short or punctuation-only, skipping TTS generation")
439
+ return None, None
440
+
441
+ if len(clean_text.replace(" ", "").replace(".", "").replace(",", "")) < 2:
442
+ print(f"ONNX TTS: Text '{clean_text}' has insufficient content for TTS, skipping")
443
+ return None, None
444
+
445
+ # Check if the language is supported
446
+ if language_code not in self.tts_config and language_code not in self.fallback_tts_config:
447
+ print(f"ONNX TTS: Language {language_code} not configured for TTS")
448
+ return None, None
449
+
450
+ await self.ensure_model_loaded(language_code)
451
+
452
+ if language_code not in self.tts_models:
453
+ print(f"ONNX TTS: TTS model not available for language: {language_code}")
454
+ return None, None
455
+
456
+ print(f"ONNX TTS: Generating speech for '{text}' in {language_code}")
457
+
458
+ # Generate speech once
459
+ model_config = self.tts_config.get(language_code, {})
460
+ if model_config.get("use_onnx", False):
461
+ # ONNX inference
462
+ audio_array, sample_rate = await self._run_onnx_tts_inference(text, language_code)
463
+ else:
464
+ # PyTorch pipeline inference
465
+ pipeline_obj = self.tts_models[language_code]
466
+ result = pipeline_obj(text)
467
+
468
+ audio_array = result["audio"]
469
+ sample_rate = result.get("sampling_rate", 16000)
470
+
471
+ print(f"ONNX TTS: Generated audio array of length {len(audio_array)} at {sample_rate}Hz")
472
+
473
+ # Validate audio array
474
+ if len(audio_array) == 0:
475
+ print("ONNX TTS: Warning - Generated audio array is empty")
476
+ return None, None
477
+
478
+ # Check for potential issues with audio data
479
+ audio_min = np.min(audio_array)
480
+ audio_max = np.max(audio_array)
481
+ audio_rms = np.sqrt(np.mean(audio_array**2))
482
+ print(f"ONNX TTS: Audio statistics - Min: {audio_min:.4f}, Max: {audio_max:.4f}, RMS: {audio_rms:.4f}")
483
+
484
+ if audio_rms < 0.001:
485
+ print("ONNX TTS: Warning - Audio appears to be very quiet or silent")
486
+ if audio_max > 1.0 or audio_min < -1.0:
487
+ print("ONNX TTS: Warning - Audio values outside expected range [-1, 1]")
488
+ audio_array = np.clip(audio_array, -1.0, 1.0)
489
+ print("ONNX TTS: Clipped audio to valid range")
490
+
491
+             # Generate WAV at original sample rate first
+             wav_bytes_original = self._convert_to_wav_bytes(audio_array, sample_rate)
+             print(f"ONNX TTS: Converted to WAV: {len(wav_bytes_original)} bytes")
+
+             # Generate WebM from original WAV
+             webm_bytes = await self._convert_to_webm(wav_bytes_original)
+             print(f"ONNX TTS: Converted to WebM: {len(webm_bytes)} bytes")
+
+             # Generate 16kHz WAV for Android
+             wav_bytes_16k = await self._resample_wav_to_16khz(wav_bytes_original, sample_rate)
+             print(f"ONNX TTS: Resampled to 16kHz WAV: {len(wav_bytes_16k)} bytes")
+
+             print(f"ONNX TTS: Generated dual format audio for '{text}'")
+             print(f"=== ONNX TTS DUAL FORMAT GENERATION COMPLETE ===")
+
+             return webm_bytes, wav_bytes_16k
+
+         except Exception as e:
+             print(f"ONNX TTS: Dual format TTS generation error: {e}")
+             import traceback
+             traceback.print_exc()
+             return None, None
+
+     def _convert_to_wav_bytes(self, audio_array: np.ndarray, sample_rate: int) -> bytes:
+         """Convert numpy audio array to WAV bytes"""
+         buffer = io.BytesIO()
+         with wave.open(buffer, 'wb') as wav_file:
+             wav_file.setnchannels(1)  # Mono
+             wav_file.setsampwidth(2)  # 16-bit
+             wav_file.setframerate(sample_rate)
+
+             # Ensure audio is in valid range [-1, 1]
+             audio_array = np.clip(audio_array, -1.0, 1.0)
+
+             # Convert to int16 with proper scaling
+             int16_audio = (audio_array * 32767).astype(np.int16)
+
+             # Validate the converted audio
+             print(f"ONNX TTS: Converting {len(audio_array)} samples to WAV at {sample_rate}Hz")
+             print(f"ONNX TTS: Int16 audio range: {np.min(int16_audio)} to {np.max(int16_audio)}")
+
+             wav_file.writeframes(int16_audio.tobytes())
+
+         wav_data = buffer.getvalue()
+         print(f"ONNX TTS: WAV file created: {len(wav_data)} bytes")
+
+         return wav_data
+
+     async def _resample_wav_to_16khz(self, wav_bytes: bytes, original_sample_rate: int) -> bytes:
+         """Resample WAV audio to 16kHz using FFmpeg"""
+         try:
+             process = subprocess.Popen([
+                 "ffmpeg", "-f", "wav", "-i", "pipe:0",
+                 "-ar", "16000",  # Set output sample rate to 16kHz
+                 "-ac", "1",      # Ensure mono output
+                 "-f", "wav", "pipe:1"
+             ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+
+             resampled_data, stderr = process.communicate(input=wav_bytes)
+
+             if process.returncode != 0:
+                 print(f"ONNX TTS: FFmpeg resampling error: {stderr.decode()}")
+                 return wav_bytes  # Return original if resampling fails
+
+             return resampled_data
+
+         except Exception as e:
+             print(f"ONNX TTS: Resampling error: {e}")
+             return wav_bytes  # Return original if resampling fails
+
+     async def _convert_to_webm(self, wav_bytes: bytes) -> bytes:
+         """Convert WAV bytes to WebM format using FFmpeg"""
+         try:
+             process = subprocess.Popen([
+                 "ffmpeg", "-f", "wav", "-i", "pipe:0",
+                 "-c:a", "libopus", "-b:a", "64k",
+                 "-f", "webm", "pipe:1"
+             ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+
+             webm_data, stderr = process.communicate(input=wav_bytes)
+
+             if process.returncode != 0:
+                 print(f"ONNX TTS: FFmpeg error: {stderr.decode()}")
+                 return wav_bytes  # Return original WAV if conversion fails
+
+             return webm_data
+
+         except Exception as e:
+             print(f"ONNX TTS: WebM conversion error: {e}")
+             return wav_bytes  # Return original WAV if conversion fails
+
+     async def cleanup(self):
+         """Cleanup resources"""
+         self.tts_models.clear()
+         self.processors.clear()
+         self.model_cache.clear()
+         print("ONNX TTS: TTS service cleaned up")
app/services/websocket_manager.py ADDED
@@ -0,0 +1,909 @@
+ import asyncio
+ import uuid
+ from typing import Dict, Set, Optional
+ import socketio
+ import numpy as np
+
+ from app.models import Message, LanguageCode
+ from app.services.session_manager import SessionManager, LANGUAGE_MAP
+ from app.services.transcription_service import TranscriptionService
+ from app.services.translation_service import TranslationService
+ from app.services.tts_service import TTSService
+
+ def truncate_array_for_log(arr, max_items=10):
+     """Helper function to truncate arrays in log messages for readability"""
+     # Explicit None check instead of `not arr`, which raises on numpy arrays
+     if arr is None or len(arr) <= max_items:
+         return arr
+     return arr[:max_items] + [f"... {len(arr) - max_items} more items"]
+
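The helper above keeps raw-sample dumps out of the logs; a quick usage sketch with hypothetical values:

```python
# Shorten a long PCM sample list before printing it.
samples = list(range(100))
print(truncate_array_for_log(samples, max_items=3))
# -> [0, 1, 2, '... 97 more items']
```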
+ class WebSocketManager:
+     def __init__(self, session_manager: SessionManager, transcription_service: TranscriptionService,
+                  translation_service: TranslationService, tts_service: TTSService):
+         self.session_manager = session_manager
+         self.transcription_service = transcription_service
+         self.translation_service = translation_service
+         self.tts_service = tts_service
+         self.sio = None  # Will be set by main.py
+
+         self.client_sessions: Dict[str, str] = {}  # sid -> session_id
+         self.client_participants: Dict[str, str] = {}  # sid -> participant_id
+         self.session_clients: Dict[str, Set[str]] = {}  # session_id -> set of sids
+
+         self.messages: Dict[str, Message] = {}  # message_id -> message
+         self.participant_current_message: Dict[str, str] = {}  # participant_id -> current_message_id
+         self.processed_messages: Set[str] = set()  # Track processed message IDs to prevent duplicates
+
+     def set_socketio(self, sio):
+         """Set the Socket.IO server instance"""
+         self.sio = sio
+
+     async def handle_join_session(self, sid: str, data: dict):
+         """Handle participant joining a session"""
+         try:
+             session_id = data.get('sessionId')
+             participant_name = data.get('participantName')
+             language_code = data.get('language')
+
+             print(f"=== JOIN SESSION REQUEST ===")
+             print(f"Session ID: {session_id}")
+             print(f"Participant: {participant_name}")
+             print(f"Language: {language_code}")
+
+             if not all([session_id, participant_name, language_code]):
+                 await self._emit_error(sid, "Missing required fields")
+                 return
+
+             # Validate language code
+             try:
+                 lang_enum = LanguageCode(language_code)
+                 print(f"Language code validated: {lang_enum}")
+             except ValueError:
+                 await self._emit_error(sid, f"Invalid language code: {language_code}")
+                 return
+
+             # Resolve session ID (in case it's a short code)
+             session = await self.session_manager.get_session(session_id)
+             if not session:
+                 await self._emit_error(sid, "Session not found")
+                 return
+
+             # Use the full UUID for all subsequent operations
+             session_id = session.id
+             print(f"Resolved session ID: {session_id}")
+
+             # Add participant to session
+             participant = await self.session_manager.add_participant(
+                 session_id, participant_name, lang_enum
+             )
+
+             print(f"Participant created: {participant}")
+
+             if not participant:
+                 await self._emit_error(sid, "Session not found or unable to join")
+                 return
+
+             # Get updated session info
+             session = await self.session_manager.get_session(session_id)
+             if session:
+                 print(f"Session {session_id} now has {len(session.languages)} languages: {[f'{lang.name}({lang.code.value})' for lang in session.languages]}")
+                 print(f"Session participants: {[f'{p.name}({p.language.name})' for p in session.participants]}")
+
+             # Track client connections
+             self.client_sessions[sid] = session_id
+             self.client_participants[sid] = participant.id
+
+             if session_id not in self.session_clients:
+                 self.session_clients[session_id] = set()
+             self.session_clients[session_id].add(sid)
+
+             # Send success response
+             await self.sio.emit('participant_joined', participant.dict(), room=sid)
+
+             # Notify other participants
+             await self._broadcast_to_session(session_id, 'participant_update', participant.dict(), exclude_sid=sid)
+
+             print(f"=== JOIN SESSION COMPLETE ===")
+
+         except Exception as e:
+             print(f"Error in handle_join_session: {e}")
+             import traceback
+             traceback.print_exc()
+             await self._emit_error(sid, "Failed to join session")
+
+     async def handle_join_hub(self, sid: str, data: dict):
+         """Handle hub joining a session for observation"""
+         try:
+             session_id = data.get('sessionId')
+             if not session_id:
+                 await self._emit_error(sid, "Missing sessionId for hub")
+                 return
+
+             # Verify session exists
+             session = await self.session_manager.get_session(session_id)
+             if not session:
+                 await self._emit_error(sid, "Session not found")
+                 return
+
+             # Track hub connection
+             self.client_sessions[sid] = session_id
+
+             if session_id not in self.session_clients:
+                 self.session_clients[session_id] = set()
+             self.session_clients[session_id].add(sid)
+
+             # Send success response
+             await self.sio.emit('hub_joined', {'sessionId': session_id}, room=sid)
+
+             print(f"Hub joined session {session_id} with sid {sid}")
+
+         except Exception as e:
+             print(f"Error in handle_join_hub: {e}")
+             await self._emit_error(sid, "Failed to join as hub")
+
+     async def handle_audio_chunk(self, sid: str, data: dict):
+         """Handle incoming audio chunk from participant"""
+         try:
+             participant_id = self.client_participants.get(sid)
+             if not participant_id:
+                 return
+
+             audio_data = data.get('audioData', [])
+             is_pause_boundary = data.get('isPauseBoundary', False)
+
+             if not audio_data:
+                 return
+
+             # Convert array to bytes
+             audio_bytes = bytes(audio_data)
+
+             # Process audio chunk using VAD-based approach
+             if audio_bytes:
+                 # Check for voice activity in this chunk
+                 has_voice = self.transcription_service.has_voice_activity(audio_bytes)
+
+                 # Process the chunk (even if no voice, to handle silence detection)
+                 # If isPauseBoundary is True, force finalization by treating as silence
+                 await self._process_audio_chunk_vad(participant_id, audio_bytes, has_voice and not is_pause_boundary, is_pause_boundary)
+
+         except Exception as e:
+             print(f"Error in handle_audio_chunk: {e}")
+             import traceback
+             traceback.print_exc()
+
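For reference, a client feeds this handler by emitting raw PCM bytes as a plain integer list. A minimal sketch with python-socketio, assuming the event is registered as `audio_chunk` in main.py (that registration is not part of this diff):

```python
import socketio

async def send_chunk(client: socketio.AsyncClient, pcm_bytes: bytes, is_pause: bool = False):
    # Mirrors the server side, which rebuilds the payload with bytes(audio_data).
    await client.emit('audio_chunk', {
        'audioData': list(pcm_bytes),  # 16-bit PCM bytes as a list of ints
        'isPauseBoundary': is_pause,   # True forces sentence finalization
    })
```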
+     async def handle_speaking_status(self, sid: str, data: dict):
+         """Handle speaking status updates"""
+         try:
+             participant_id = self.client_participants.get(sid)
+             if not participant_id:
+                 return
+
+             is_speaking = data.get('isSpeaking', False)
+             await self.session_manager.update_participant_speaking_status(participant_id, is_speaking)
+
+             # If participant stopped speaking, force complete any pending sentence
+             if not is_speaking:
+                 # Get session and participant info for force completion
+                 session_id = await self.session_manager.get_participant_session_id(participant_id)
+                 if session_id:
+                     session = await self.session_manager.get_session(session_id)
+                     participant = next((p for p in session.participants if p.id == participant_id), None)
+
+                     if participant:
+                         # Define the sentence callback for force completion
+                         async def force_sentence_callback(final_text: str, final_audio: bytes):
+                             # Create or get existing message
+                             current_message_id = self.participant_current_message.get(participant_id)
+                             if not current_message_id:
+                                 current_message_id = str(uuid.uuid4())
+
+                             # Check if this message was already processed
+                             if current_message_id in self.processed_messages:
+                                 print(f"Force completion: Message {current_message_id} already processed, skipping duplicate")
+                                 return
+
+                             # Mark as processed to prevent duplicates
+                             self.processed_messages.add(current_message_id)
+
+                             from app.models import Message
+                             message = Message(
+                                 id=current_message_id,
+                                 session_id=session_id,
+                                 speaker_id=participant_id,
+                                 speaker_name=participant.name,
+                                 original_text=final_text,
+                                 original_language=participant.language,
+                                 translations={},
+                                 is_transcribing=False
+                             )
+                             self.messages[current_message_id] = message
+
+                             # Broadcast the completed message
+                             print(f"Force completion: Broadcasting message_complete for {current_message_id}: '{final_text}'")
+                             await self._broadcast_to_session(session_id, 'message_complete', {
+                                 'messageId': current_message_id,
+                                 'sessionId': session_id,
+                                 'text': final_text,
+                                 'speakerId': participant_id,
+                                 'speakerName': participant.name,
+                                 'language': participant.language.code.value
+                             })
+
+                             # Clear current message tracking
+                             if participant_id in self.participant_current_message:
+                                 del self.participant_current_message[participant_id]
+
+                             # Start translation processing (non-blocking to allow continued audio processing)
+                             print("Starting TRANSLATION and TTS (background task)")
+                             asyncio.create_task(self._process_translations_and_tts(message, session))
+
+                         # Force complete any pending sentence
+                         await self.transcription_service.force_complete_sentence(
+                             participant_id,
+                             participant.language.code.value,
+                             force_sentence_callback
+                         )
+
+                         # Clear transcription service buffers after force completion
+                         self.transcription_service.clear_participant_buffers(participant_id)
+
+             # Broadcast speaking status to session
+             session_id = self.client_sessions.get(sid)
+             if session_id:
+                 await self._broadcast_to_session(session_id, 'speaking_status', {
+                     'participantId': participant_id,
+                     'isSpeaking': is_speaking
+                 })
+
+         except Exception as e:
+             print(f"Error in handle_speaking_status: {e}")
+             import traceback
+             traceback.print_exc()
+
+     async def handle_leave_session(self, sid: str, data: dict):
+         """Handle participant leaving a session"""
+         await self._cleanup_client(sid)
+
+     async def handle_disconnect(self, sid: str):
+         """Handle client disconnection"""
+         await self._cleanup_client(sid)
+
+     async def _process_audio_chunk_vad(self, participant_id: str, audio_data: bytes, has_voice_activity: bool, is_pause_boundary: bool = False):
+         """Process audio chunk using VAD-based sentence detection
+
+         Args:
+             participant_id: ID of the participant
+             audio_data: Raw audio data bytes
+             has_voice_activity: Whether voice activity was detected in this chunk
+             is_pause_boundary: If True, forces sentence finalization (from stop button or explicit pause)
+         """
+         try:
+             session_id = await self.session_manager.get_participant_session_id(participant_id)
+             if not session_id:
+                 return
+
+             session = await self.session_manager.get_session(session_id)
+             if not session:
+                 return
+
+             participant = next((p for p in session.participants if p.id == participant_id), None)
+             if not participant:
+                 return
+
+             # Get or create current message for this participant
+             current_message_id = self.participant_current_message.get(participant_id)
+             if not current_message_id:
+                 current_message_id = str(uuid.uuid4())
+                 message = Message(
+                     id=current_message_id,
+                     session_id=session_id,
+                     speaker_id=participant_id,
+                     speaker_name=participant.name,
+                     original_text="",
+                     original_language=participant.language,
+                     translations={},
+                     is_transcribing=True
+                 )
+                 self.messages[current_message_id] = message
+                 self.participant_current_message[participant_id] = current_message_id
+
+                 # Start typing indicator
+                 await self._broadcast_to_session(session_id, 'typing_start', {
+                     'speakerId': participant_id,
+                     'speakerName': participant.name,
+                     'languageCode': participant.language.code.value
+                 })
+
+             message = self.messages[current_message_id]
+
+             # Define callbacks
+             async def on_progress(text: str, is_complete: bool):
+                 """Called with in-progress transcription updates"""
+                 # Update the message text even for progress updates
+                 message.original_text = text
+
+                 await self._broadcast_to_session(session_id, 'transcription_progress', {
+                     'messageId': current_message_id,
+                     'text': text,
+                     'isTranscribing': not is_complete,
+                     'speakerId': participant_id,
+                     'speakerName': participant.name
+                 })
+
+             async def on_debug(debug_info: dict):
+                 """Called with debug information from ASR (wav2vec2 models only)"""
+                 # Prepare debug data for transmission
+                 debug_data = {
+                     'messageId': current_message_id,
+                     'text': debug_info['text'],
+                     'timestamps': debug_info['timestamps'],
+                     'audioData': list(debug_info['audio_data']),
+                     'audioDuration': debug_info['audio_duration'],
+                     'modelType': debug_info['model_type'],
+                     'language': participant.language.code.value
+                 }
+
+                 await self._broadcast_to_session(session_id, 'transcription_debug', debug_data)
+
+             async def on_sentence_complete(final_text: str, final_audio: bytes):
+                 """Called when a complete sentence is detected"""
+
+                 # Check if this message was already processed
+                 if current_message_id in self.processed_messages:
+                     print(f"Message {current_message_id} already processed, skipping duplicate")
+                     return
+
+                 # Mark as processed to prevent duplicates
+                 self.processed_messages.add(current_message_id)
+
+                 message.original_text = final_text
+                 message.is_transcribing = False
+
+                 # Broadcast complete sentence with session ID
+                 message_data = {
+                     'messageId': current_message_id,
+                     'sessionId': session_id,
+                     'text': final_text,
+                     'speakerId': participant_id,
+                     'speakerName': participant.name,
+                     'language': participant.language.code.value,
+                     'audioData': list(final_audio)
+                 }
+
+                 print(f"Broadcasting message_complete for {current_message_id}: '{final_text}'")
+                 await self._broadcast_to_session(session_id, 'message_complete', message_data)
+
+                 # Stop typing indicator
+                 await self._broadcast_to_session(session_id, 'typing_stop', {
+                     'speakerId': participant_id
+                 })
+
+                 # Clear current message tracking
+                 if participant_id in self.participant_current_message:
+                     del self.participant_current_message[participant_id]
+
+                 # Start translation and TTS processing (non-blocking to allow continued audio processing)
+                 print("Starting TRANSLATION and TTS (background task)")
+                 asyncio.create_task(self._process_translations_and_tts(message, session))
+
+             # Process the audio chunk
+             result_text = await self.transcription_service.process_audio_chunk(
+                 audio_data,
+                 participant.language.code.value,
+                 participant_id,
+                 has_voice_activity,
+                 progress_callback=on_progress,
+                 sentence_callback=on_sentence_complete,
+                 debug_callback=on_debug
+             )
+
+             # If this is a pause boundary (stop button clicked), force immediate finalization
+             if is_pause_boundary and participant_id in self.participant_current_message:
+                 print(f"Pause boundary detected - forcing sentence finalization for participant {participant_id}")
+                 # Get the current accumulated text from transcription service
+                 if hasattr(self.transcription_service, 'candidate_text_cache') and participant_id in self.transcription_service.candidate_text_cache:
+                     final_text = self.transcription_service.candidate_text_cache.get(participant_id, "").strip()
+                     if final_text:  # Only finalize if there's actual text
+                         # Get accumulated audio
+                         final_audio = b""
+                         if hasattr(self.transcription_service, 'candidate_audio_buffers') and participant_id in self.transcription_service.candidate_audio_buffers:
+                             audio_array = self.transcription_service.candidate_audio_buffers.get(participant_id, np.array([]))
+                             if len(audio_array) > 0:
+                                 # Convert float array to int16 bytes
+                                 audio_int16 = (audio_array * 32767).astype(np.int16)
+                                 final_audio = audio_int16.tobytes()
+
+                         # Trigger sentence completion
+                         await on_sentence_complete(final_text, final_audio)
+
+                         # Clear the buffers manually since we're forcing finalization
+                         if participant_id in self.transcription_service.candidate_text_cache:
+                             self.transcription_service.candidate_text_cache[participant_id] = ""
+                         if participant_id in self.transcription_service.candidate_audio_buffers:
+                             self.transcription_service.candidate_audio_buffers[participant_id] = np.array([], dtype=np.float32)
+                         if participant_id in self.transcription_service.silence_counters:
+                             self.transcription_service.silence_counters[participant_id] = 0
+                         if participant_id in self.transcription_service.sentence_finalized:
+                             self.transcription_service.sentence_finalized[participant_id] = False
+
+         except Exception as e:
+             print(f"Error in _process_audio_chunk_vad: {e}")
+             import traceback
+             traceback.print_exc()
+
+     async def _process_translations_and_tts(self, message: Message, session):
+         """Process translations and TTS for all session languages"""
+         try:
+             source_lang = message.original_language.name
+
+             print(f"=== TRANSLATION/TTS PROCESSING START ===")
+             print(f"Message ID: {message.id}")
+             print(f"Original message: '{message.original_text}'")
+             print(f"Original language: {message.original_language.name} ({message.original_language.code.value})")
+             print(f"Session languages: {[f'{lang.name} ({lang.code.value})' for lang in session.languages]}")
+             print(f"Session ID for verification: {session.id}")
+
+             # Create a mapping to track which audio belongs to which message and language
+             audio_tasks = []
+
+             # Check if TTS is enabled for this session
+             if session.enable_tts:
+                 # First, generate TTS for the original message
+                 print(f"TTS: Generating TTS for original message in {message.original_language.code.value}: '{message.original_text}'")
+                 print(f"TTS Model: VITS ONNX (mutisya/vits-tts-onnx-fp32-{message.original_language.name.lower()}) - File: tts_service_onnx.py")
+                 original_audio_task = asyncio.create_task(
+                     self.tts_service.generate_speech_dual_format(message.original_text, message.original_language.code.value)
+                 )
+                 audio_tasks.append((
+                     message.original_language.code.value,
+                     message.original_text,
+                     original_audio_task,
+                     True  # is_original
+                 ))
+             else:
+                 print(f"TTS: Skipping TTS generation (disabled for this session)")
+
+             # Process translations for each language in the session
+             print(f"Processing translations for {len(session.languages)} session languages...")
+             print(f"Session languages: {[f'{lang.name}({lang.code.value})' for lang in session.languages]}")
+             translation_tasks = []
+
+             for language in session.languages:
+                 print(f"Checking language: {language.name} ({language.code.value})")
+                 if language.code != message.original_language.code:
+                     print(f"TRANSLATING: '{message.original_text}' from {source_lang} to {language.name}")
+                     print(f"Translation Model: mutisya/nllb_600m (NLLB-600M) - File: translation_service.py")
+
+                     # Create translation task
+                     translation_task = asyncio.create_task(
+                         self.translation_service.translate_text(
+                             message.original_text, source_lang, language.name
+                         )
+                     )
+                     translation_tasks.append((language, translation_task))
+                 else:
+                     print(f"SKIPPING translation for {language.name} (same as original language)")
+
+             print(f"Created {len(translation_tasks)} translation tasks for non-original languages")
+
+             # Wait for all translations to complete
+             for language, translation_task in translation_tasks:
+                 try:
+                     translated_text = await translation_task
+
+                     if translated_text:
+                         print(f"TRANSLATION SUCCESS: '{translated_text}' for {language.name}")
+
+                         message.translations[language.code.value] = translated_text
+
+                         # Broadcast translation update to all clients
+                         await self._broadcast_to_session(message.session_id, 'translation_update', {
+                             'messageId': message.id,
+                             'targetLanguage': language.code.value,
+                             'translatedText': translated_text
+                         })
+
+                         # Check if TTS is enabled for this session
+                         if session.enable_tts:
+                             # Generate TTS for the translated text
+                             print(f"TTS: Generating TTS for translation in {language.code.value}: '{translated_text}'")
+                             print(f"TTS Model: VITS ONNX (mutisya/vits-tts-onnx-fp32-{language.name.lower()}) - File: tts_service_onnx.py")
+                             tts_task = asyncio.create_task(
+                                 self.tts_service.generate_speech_dual_format(translated_text, language.code.value)
+                             )
+                             audio_tasks.append((
+                                 language.code.value,
+                                 translated_text,
+                                 tts_task,
+                                 False  # is_original
+                             ))
+                         else:
+                             print(f"TTS: Skipping TTS generation for translation (disabled for this session)")
+                     else:
+                         print(f"TRANSLATION FAILED: No translated text returned for {language.name}")
+                 except Exception as e:
+                     print(f"Translation error for {language.name}: {e}")
+
+             # Wait for all TTS generation to complete and broadcast with proper alignment
+             for language_code, text, audio_task, is_original in audio_tasks:
+                 try:
+                     audio_result = await audio_task
+
+                     if audio_result and (audio_result[0] or audio_result[1]):
+                         webm_data, wav_data = audio_result
+                         print(f"TTS: Audio generated successfully for {language_code}")
+                         if webm_data:
+                             print(f"TTS: WebM audio: {len(webm_data)} bytes")
+                         if wav_data:
+                             print(f"TTS: WAV audio: {len(wav_data)} bytes")
+                         print(f"TTS: Text for {language_code}: '{text}'")
+
+                         # Broadcast TTS audio with explicit message-text-audio alignment (dual format)
+                         await self._broadcast_tts_audio_aligned_dual_format(
+                             message.session_id,
+                             message.id,
+                             language_code,
+                             text,
+                             webm_data,
+                             wav_data,
+                             is_original
+                         )
+                     else:
+                         print(f"TTS: Failed to generate audio for {language_code}")
+                 except Exception as e:
+                     print(f"TTS generation error for {language_code}: {e}")
+
+             print(f"=== TRANSLATION/TTS PROCESSING END ===")
+
+         except Exception as e:
+             print(f"Error in _process_translations_and_tts: {e}")
+             import traceback
+             traceback.print_exc()
+
+     async def _broadcast_to_session(self, session_id: str, event: str, data: dict, exclude_sid: str = None):
+         """Broadcast message to all clients in a session"""
+         if session_id not in self.session_clients:
+             return
+
+         # Create a copy of the set to avoid concurrent modification
+         client_sids = list(self.session_clients[session_id])
+
+         for sid in client_sids:
+             if sid != exclude_sid:
+                 try:
+                     await self.sio.emit(event, data, room=sid)
+                 except Exception as e:
+                     print(f"Error broadcasting {event} to client {sid}: {e}")
+
+     async def _broadcast_tts_audio_aligned(self, session_id: str, message_id: str,
+                                            language_code: str, text: str, audio_data: bytes,
+                                            is_original: bool = False):
+         """Broadcast TTS audio with explicit message-text-audio alignment"""
+         try:
+             if session_id not in self.session_clients:
+                 return
+
+             print(f"TTS ALIGNED: Broadcasting audio for message {message_id}")
+             print(f"TTS ALIGNED: Language: {language_code}")
+             print(f"TTS ALIGNED: Text: '{text}'")
+             print(f"TTS ALIGNED: Audio size: {len(audio_data)} bytes")
+             print(f"TTS ALIGNED: Is original: {is_original}")
+
+             # Create a copy of the set to avoid concurrent modification
+             client_sids = list(self.session_clients[session_id])
+
+             # Send audio data in chunks to all participants with explicit alignment data
+             for sid in client_sids:
+                 try:
+                     chunk_size = 4096
+                     for i in range(0, len(audio_data), chunk_size):
+                         chunk = audio_data[i:i + chunk_size]
+                         is_last_chunk = i + chunk_size >= len(audio_data)
+
+                         chunk_data = {
+                             'messageId': message_id,        # Explicit message ID
+                             'languageCode': language_code,  # Language of THIS audio
+                             'text': text,                   # Text that THIS audio represents
+                             'isOriginal': is_original,      # Whether this is original or translation
+                             'chunk': list(chunk),
+                             'isLast': is_last_chunk,
+                             'chunkIndex': i // chunk_size,  # Chunk ordering
+                             'totalChunks': (len(audio_data) + chunk_size - 1) // chunk_size
+                         }
+
+                         await self.sio.emit('tts_audio_chunk', chunk_data, room=sid)
+
+                         # Small delay to prevent overwhelming
+                         await asyncio.sleep(0.01)
+
+                     print(f"TTS ALIGNED: Successfully sent aligned audio to participant {sid}")
+                 except Exception as e:
+                     print(f"TTS ALIGNED: Error sending audio to participant {sid}: {e}")
+
+         except Exception as e:
+             print(f"TTS ALIGNED: Error broadcasting aligned audio: {e}")
+
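Since each utterance arrives as ordered `tts_audio_chunk` events, a receiver buffers until `isLast`. A hypothetical client-side reassembly sketch, keyed the way the payload above suggests:

```python
from collections import defaultdict

buffers = defaultdict(list)  # (messageId, languageCode) -> [(chunkIndex, bytes), ...]

def on_tts_audio_chunk(data: dict):
    key = (data['messageId'], data['languageCode'])
    buffers[key].append((data['chunkIndex'], bytes(data['chunk'])))
    if data['isLast']:
        parts = sorted(buffers.pop(key))  # restore chunk order
        return b''.join(chunk for _, chunk in parts)  # complete audio payload
    return None  # still accumulating
```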
+     async def _broadcast_tts_audio_aligned_dual_format(self, session_id: str, message_id: str,
+                                                        language_code: str, text: str, webm_data: bytes,
+                                                        wav_data: bytes, is_original: bool = False):
+         """Broadcast TTS audio with both WebM and WAV formats for cross-platform compatibility"""
+         try:
+             if session_id not in self.session_clients:
+                 return
+
+             print(f"TTS DUAL FORMAT: Broadcasting audio for message {message_id}")
+             print(f"TTS DUAL FORMAT: Language: {language_code}")
+             print(f"TTS DUAL FORMAT: Text: '{text}'")
+             if webm_data:
+                 print(f"TTS DUAL FORMAT: WebM size: {len(webm_data)} bytes")
+             if wav_data:
+                 print(f"TTS DUAL FORMAT: WAV size: {len(wav_data)} bytes")
+             print(f"TTS DUAL FORMAT: Is original: {is_original}")
+
+             # Create a copy of the set to avoid concurrent modification
+             client_sids = list(self.session_clients[session_id])
+
+             # Use WebM data for chunking (primary format for web clients)
+             primary_audio_data = webm_data if webm_data else wav_data
+             if not primary_audio_data:
+                 print("TTS DUAL FORMAT: No audio data available")
+                 return
+
+             # Send audio data in chunks to all participants with dual format support
+             chunk_size = 4096
+             for sid in client_sids:
+                 try:
+                     for i in range(0, len(primary_audio_data), chunk_size):
+                         chunk = primary_audio_data[i:i + chunk_size]
+                         is_last_chunk = i + chunk_size >= len(primary_audio_data)
+
+                         # Prepare WAV chunk if available
+                         wav_chunk = None
+                         if wav_data and i < len(wav_data):
+                             wav_end = min(i + chunk_size, len(wav_data))
+                             wav_chunk = wav_data[i:wav_end]
+
+                         chunk_data = {
+                             'messageId': message_id,        # Explicit message ID
+                             'languageCode': language_code,  # Language of THIS audio
+                             'text': text,                   # Text that THIS audio represents
+                             'isOriginal': is_original,      # Whether this is original or translation
+                             'chunk': list(chunk),           # WebM audio chunk (for web clients)
+                             'wavChunk': list(wav_chunk) if wav_chunk else None,  # WAV audio chunk (for Android clients)
+                             'isLast': is_last_chunk,
+                             'chunkIndex': i // chunk_size,  # Chunk ordering
+                             'totalChunks': (len(primary_audio_data) + chunk_size - 1) // chunk_size,
+                             'format': 'webm',               # Primary format
+                             'wavFormat': 'wav' if wav_chunk else None  # Secondary format available
+                         }
+
+                         await self.sio.emit('tts_audio_chunk', chunk_data, room=sid)
+
+                         # Small delay to prevent overwhelming
+                         await asyncio.sleep(0.01)
+
+                     print(f"TTS DUAL FORMAT: Successfully sent dual format audio to participant {sid}")
+                 except Exception as e:
+                     print(f"TTS DUAL FORMAT: Error sending audio to participant {sid}: {e}")
+
+         except Exception as e:
+             print(f"TTS DUAL FORMAT: Error broadcasting dual format audio: {e}")
+
+     async def _broadcast_tts_audio_to_all_participants(self, session_id: str, language_code: str,
+                                                        audio_data: bytes, message_id: str, text: str):
+         """Legacy method - now calls the aligned version"""
+         await self._broadcast_tts_audio_aligned(
+             session_id, message_id, language_code, text, audio_data, False
+         )
+
+     async def _broadcast_audio_to_language_participants(self, session_id: str, language_code: str,
+                                                         audio_data: bytes, message_id: str):
+         """Broadcast audio to participants listening in specific language (legacy method)"""
+         try:
+             session = await self.session_manager.get_session(session_id)
+             if not session:
+                 return
+
+             # Find participants with matching language
+             target_participants = [p for p in session.participants if p.language.code.value == language_code]
+
+             for participant in target_participants:
+                 # Find client SID for this participant
+                 participant_sid = None
+                 for sid, pid in self.client_participants.items():
+                     if pid == participant.id:
+                         participant_sid = sid
+                         break
+
+                 if participant_sid:
+                     print(f"TTS: Broadcasting audio to participant {participant.name} in {language_code}")
+                     # Send audio data in chunks
+                     chunk_size = 4096
+                     for i in range(0, len(audio_data), chunk_size):
+                         chunk = audio_data[i:i + chunk_size]
+                         await self.sio.emit('tts_audio_chunk', {
+                             'messageId': message_id,
+                             'chunk': list(chunk),
+                             'isLast': i + chunk_size >= len(audio_data)
+                         }, room=participant_sid)
+
+                         # Small delay to prevent overwhelming
+                         await asyncio.sleep(0.01)
+
+         except Exception as e:
+             print(f"TTS: Error broadcasting audio: {e}")
+
+     async def _cleanup_client(self, sid: str):
+         """Clean up client data on disconnect"""
+         try:
+             participant_id = self.client_participants.get(sid)
+             session_id = self.client_sessions.get(sid)
+
+             if participant_id:
+                 # Remove participant from session
+                 await self.session_manager.remove_participant(participant_id)
+
+                 # Clear participant buffers
+                 self.transcription_service.clear_participant_buffers(participant_id)
+
+                 # Clear current message tracking
+                 if participant_id in self.participant_current_message:
+                     del self.participant_current_message[participant_id]
+
+                 del self.client_participants[sid]
+
+             if session_id:
+                 # Remove from session clients
+                 if session_id in self.session_clients:
+                     self.session_clients[session_id].discard(sid)
+                     if not self.session_clients[session_id]:
+                         del self.session_clients[session_id]
+                         # If session is empty, clear processed messages for this session
+                         self._cleanup_session_processed_messages(session_id)
+
+                 del self.client_sessions[sid]
+
+         except Exception as e:
+             print(f"Error cleaning up client {sid}: {e}")
+
+     def _cleanup_session_processed_messages(self, session_id: str):
+         """Clean up processed messages for empty sessions to prevent memory leaks"""
+         try:
+             # Remove processed messages that belong to this session
+             messages_to_remove = []
+             for message_id in list(self.processed_messages):
+                 if message_id in self.messages and self.messages[message_id].session_id == session_id:
+                     messages_to_remove.append(message_id)
+
+             for message_id in messages_to_remove:
+                 self.processed_messages.discard(message_id)
+                 if message_id in self.messages:
+                     del self.messages[message_id]
+
+             print(f"Cleaned up {len(messages_to_remove)} processed messages for session {session_id}")
+         except Exception as e:
+             print(f"Error cleaning up session processed messages: {e}")
+
+     async def _emit_error(self, sid: str, message: str):
+         """Emit error message to specific client"""
+         try:
+             await self.sio.emit('join_error', message, room=sid)
+         except Exception as e:
+             print(f"Error emitting error to {sid}: {e}")
+
+     async def handle_update_participant_language(self, sid: str, data: dict):
+         """Handle participant language update (affects speech recognition)"""
+         try:
+             session_id = data.get('sessionId')
+             participant_id = data.get('participantId')
+             language_code = data.get('language')
+
+             print(f"=== UPDATE PARTICIPANT LANGUAGE ===")
+             print(f"Session ID: {session_id}")
+             print(f"Participant ID: {participant_id}")
+             print(f"New Language: {language_code}")
+
+             if not all([session_id, participant_id, language_code]):
+                 await self._emit_error(sid, "Missing required fields")
+                 return
+
+             # Validate language code
+             try:
+                 from app.models import LanguageCode
+                 lang_enum = LanguageCode(language_code)
+                 print(f"Language code validated: {lang_enum}")
+             except ValueError:
+                 await self._emit_error(sid, f"Invalid language code: {language_code}")
+                 return
+
+             # Update participant's language in session
+             session = await self.session_manager.get_session(session_id)
+             if session:
+                 for participant in session.participants:
+                     if participant.id == participant_id:
+                         # Update participant's language using LANGUAGE_MAP for complete Language object
+                         if lang_enum in LANGUAGE_MAP:
+                             participant.language = LANGUAGE_MAP[lang_enum]
+                             print(f"Updated participant {participant.name} language to {lang_enum.value} ({participant.language.display_name})")
+                         else:
+                             print(f"Warning: Language {lang_enum.value} not found in LANGUAGE_MAP, using fallback")
+                             from app.models import Language
+                             participant.language = Language(code=lang_enum, name=lang_enum.value, display_name=lang_enum.value)
+
+                         # Notify all clients in session
+                         await self._broadcast_to_session(session_id, 'participant_language_updated', {
+                             'participantId': participant_id,
+                             'language': language_code
+                         })
+                         break
+
+             print(f"=== UPDATE PARTICIPANT LANGUAGE COMPLETE ===")
+
+         except Exception as e:
+             print(f"Error in handle_update_participant_language: {e}")
+             import traceback
+             traceback.print_exc()
+             await self._emit_error(sid, "Failed to update participant language")
+
+     async def handle_update_session_languages(self, sid: str, data: dict):
+         """Handle session languages update (affects translation targets)"""
+         try:
+             session_id = data.get('sessionId')
+             languages = data.get('languages', [])
+
+             print(f"=== UPDATE SESSION LANGUAGES (REPLACE MODE) ===")
+             print(f"Session ID: {session_id}")
+             print(f"New Languages: {languages}")
+
+             if not session_id or not languages:
+                 await self._emit_error(sid, "Missing required fields")
+                 return
+
+             # Get current session for comparison
+             session = await self.session_manager.get_session(session_id)
+             if not session:
+                 await self._emit_error(sid, "Session not found")
+                 return
+
+             current_languages = [lang.code.value for lang in session.languages]
+             print(f"Before update - Session languages: {current_languages}")
+
+             # Validate all language codes and create Language objects
+             validated_languages = []
+             try:
+                 from app.models import Language, LanguageCode
+                 from app.services.session_manager import LANGUAGE_MAP
+
+                 for lang_code in languages:
+                     lang_enum = LanguageCode(lang_code)
+                     language = LANGUAGE_MAP[lang_enum]
+                     validated_languages.append(language)
+                     print(f"Validated language: {lang_code} -> {language.name}")
+
+             except ValueError as e:
+                 await self._emit_error(sid, f"Invalid language code: {e}")
+                 return
+
+             # REPLACE session languages (not add to them)
+             session.languages = validated_languages
+             new_languages = [lang.code.value for lang in session.languages]
+             print(f"After update - Session languages: {new_languages}")
+
+             # Verify the session manager has the updated languages
+             verification_session = await self.session_manager.get_session(session_id)
+             if verification_session:
+                 verification_languages = [lang.code.value for lang in verification_session.languages]
+                 print(f"Verification - Session manager languages: {verification_languages}")
+
+             # Notify all clients in session about the update
+             await self._broadcast_to_session(session_id, 'session_languages_updated', {
+                 'sessionId': session_id,
+                 'languages': new_languages,
+                 'previous': current_languages
+             })
+
+             print(f"=== UPDATE SESSION LANGUAGES COMPLETE ===")
+
+         except Exception as e:
+             print(f"Error in handle_update_session_languages: {e}")
+             import traceback
+             traceback.print_exc()
+             await self._emit_error(sid, "Failed to update session languages")
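The manager only defines handlers; the event registration happens elsewhere. A hypothetical wiring sketch (event names and the pre-built service instances are assumptions, since app/main.py is not part of this diff):

```python
import socketio

from app.services.websocket_manager import WebSocketManager

def build_socket_app(session_manager, transcription_service,
                     translation_service, tts_service):
    sio = socketio.AsyncServer(async_mode='asgi', cors_allowed_origins='*')
    manager = WebSocketManager(session_manager, transcription_service,
                               translation_service, tts_service)
    manager.set_socketio(sio)

    # Assumed event names; only the handler methods are defined in this file.
    sio.on('join_session', manager.handle_join_session)
    sio.on('join_hub', manager.handle_join_hub)
    sio.on('audio_chunk', manager.handle_audio_chunk)
    sio.on('speaking_status', manager.handle_speaking_status)
    sio.on('update_participant_language', manager.handle_update_participant_language)
    sio.on('update_session_languages', manager.handle_update_session_languages)
    sio.on('leave_session', manager.handle_leave_session)
    sio.on('disconnect', manager.handle_disconnect)

    # ASGI wrapper that uvicorn can serve directly.
    return socketio.ASGIApp(sio), manager
```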
preload_models.py ADDED
@@ -0,0 +1,23 @@
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
+ import io
+ import nltk
+ import json
+ import torch
+ import os
+ import sys
+
+ import gc
+
+
+ if len(sys.argv) > 1:
+     os.environ["HUGGING_FACE_HUB_TOKEN"] = sys.argv[1]
+
+ nltk.download("punkt")
+ nltk.download("punkt_tab")
+
+ device = 0 if torch.cuda.is_available() else -1
+
+ def cleanup_model_resource(model):
+     del model
+     gc.collect()
+     torch.cuda.empty_cache()
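As committed, this script only warms the NLTK tokenizer data and accepts the HF token. A hedged sketch of the model warm-up step it appears to be building toward, using the NLLB repo id referenced in the websocket_manager.py logs (whether this exact repo id is correct is an assumption):

```python
# Hypothetical warm-up: downloading once at build time fills the HF cache
# (/app/.cache) so the first request does not pay the download cost.
MODEL_ID = "mutisya/nllb_600m"  # assumption: repo id used by translation_service.py

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
cleanup_model_resource(model)  # free RAM; the cached files on disk are what matter
```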
requirements.txt ADDED
@@ -0,0 +1,58 @@
+ # Web framework and server
+ fastapi==0.115.5
+ uvicorn[standard]==0.32.1
+ websockets==13.1
+ python-socketio==5.11.4
+ python-multipart==0.0.17
+ pydantic==2.10.3
+
+ # PyTorch ecosystem - latest stable versions
+ torch==2.5.1
+ torchaudio==2.5.1
+ transformers==4.45.2
+ datasets==3.1.0
+ tokenizers==0.20.4
+ accelerate==1.2.1
+
+ # ONNX Runtime for optimized inference - GPU enabled
+ onnxruntime-gpu==1.19.2
+ onnx==1.17.0
+ optimum[onnxruntime-gpu]==1.23.0
+ huggingface-hub==0.26.2
+
+ # Audio processing
+ soundfile==0.12.1
+ librosa==0.10.2
+ phonemizer==3.3.0
+ pydub==0.25.1
+
+ # Scientific computing
+ scipy==1.14.1
+ numpy==2.1.3
+
+ # Natural language processing
+ nltk==3.9.1
+ sentencepiece==0.2.0
+
+ # Computer vision and image processing
+ pillow==11.0.0
+ qrcode[pil]==8.0
+
+ # Authentication and security
+ python-jose[cryptography]==3.3.0
+ passlib[bcrypt]==1.7.4
+
+ # File handling
+ aiofiles==24.1.0
+
+ # Model optimization
+ bitsandbytes==0.45.0
+
+ # Protocol buffers - compatible version
+ protobuf==5.28.3
+
+ # Speech processing
+ speechbrain==1.0.2
+
+ # Voice Activity Detection
+ silero-vad>=5.1