mutisya commited on
Commit
52ed889
·
verified ·
1 Parent(s): d6e8bff

Remove the app folder — the code will be downloaded from the private code space during the build

Browse files
app/__init__.py DELETED
@@ -1 +0,0 @@
1
- # Backend application package
 
 
app/auth.py DELETED
@@ -1,310 +0,0 @@
1
- """
2
- Authentication module for HuggingFace token validation
3
- """
4
- import os
5
- from typing import Optional
6
- from fastapi import HTTPException, status, Request
7
- from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
8
- from fastapi.security.utils import get_authorization_scheme_param
9
-
10
-
11
def is_local_development() -> bool:
    """Return True when the application appears to be running in local development.

    Authentication is disabled whenever this returns True, so every heuristic
    here must be something that cannot plausibly hold in production.

    Checks, in order:
      1. Explicit DISABLE_AUTH flag.
      2. ENVIRONMENT set to a development value.
      3. DEBUG flag.
      4. Serving on an explicit loopback host with the default dev port.
      5. Presence of local-development marker files in the working directory.
      6. Running inside Docker with ALLOW_ALL_ORIGINS enabled.
    """
    truthy = ('true', '1', 'yes')

    # Method 1: explicit opt-out of authentication.
    if os.getenv('DISABLE_AUTH', '').lower() in truthy:
        return True

    # Method 2: conventional environment name.
    if os.getenv('ENVIRONMENT', '').lower() in ('development', 'dev', 'local'):
        return True

    # Method 3: debug mode implies local development.
    if os.getenv('DEBUG', '').lower() in truthy:
        return True

    # Method 4: explicit loopback host on the default dev port.
    # BUGFIX: '0.0.0.0' is intentionally NOT treated as local anymore —
    # binding all interfaces on port 7860 is exactly how production
    # containers (e.g. HuggingFace Spaces) run, and treating it as local
    # would silently disable authentication in production.
    if os.getenv('HOST', '') in ('localhost', '127.0.0.1') and os.getenv('PORT', '') == '7860':
        return True

    # Method 5: local development marker files in the working directory.
    for indicator in ('.env.local', 'docker-compose.local.yml', 'Dockerfile.local'):
        if os.path.exists(indicator):
            return True

    # Method 6: Docker container explicitly configured for permissive local CORS.
    if os.path.exists('/.dockerenv') and os.getenv('ALLOW_ALL_ORIGINS', '').lower() == 'true':
        return True

    return False
54
-
55
-
56
- class HuggingFaceTokenAuth:
57
- """HuggingFace token authentication handler"""
58
-
59
- def __init__(self):
60
- self.bearer = HTTPBearer(auto_error=False)
61
- self.is_local = is_local_development()
62
-
63
- if self.is_local:
64
- print("🔓 RUNNING IN LOCAL DEVELOPMENT MODE - AUTH DISABLED")
65
- print(" Environment indicators:")
66
- print(f" - DISABLE_AUTH: {os.getenv('DISABLE_AUTH', 'not set')}")
67
- print(f" - ENVIRONMENT: {os.getenv('ENVIRONMENT', 'not set')}")
68
- print(f" - DEBUG: {os.getenv('DEBUG', 'not set')}")
69
- print(f" - HOST: {os.getenv('HOST', 'not set')}")
70
- print(f" - PORT: {os.getenv('PORT', 'not set')}")
71
- print(f" - ALLOW_ALL_ORIGINS: {os.getenv('ALLOW_ALL_ORIGINS', 'not set')}")
72
- print(f" - Docker container: {os.path.exists('/.dockerenv')}")
73
- print(f" - .env.local exists: {os.path.exists('.env.local')}")
74
- else:
75
- print("🔒 RUNNING IN PRODUCTION MODE - AUTH REQUIRED")
76
-
77
- def verify_token(self, token: str) -> bool:
78
- """
79
- Verify if the token is a valid HuggingFace token.
80
- In local development mode, always returns True.
81
- """
82
- # Skip token validation in local development
83
- if self.is_local:
84
- print("🔓 Local development mode: skipping token validation")
85
- return True
86
-
87
- try:
88
- if not token:
89
- return False
90
-
91
- if not isinstance(token, str):
92
- print(f"❌ Token is not a string: {type(token)}")
93
- return False
94
-
95
- # HuggingFace tokens start with 'hf_'
96
- if not token.startswith('hf_'):
97
- print(f"❌ Token does not start with 'hf_': {token[:10]}...")
98
- return False
99
-
100
- # Additional validation can be added here
101
- # For example, you could make a request to HuggingFace API
102
- # to validate the token, but that would add latency
103
-
104
- return True
105
-
106
- except Exception as e:
107
- print(f"❌ Error in verify_token: {e}")
108
- return False
109
-
110
- def get_token_from_request(self, request: Request) -> Optional[str]:
111
- """Extract token from various sources in the request"""
112
-
113
- # Method 1: Authorization header
114
- authorization = request.headers.get("Authorization")
115
- if authorization:
116
- scheme, token = get_authorization_scheme_param(authorization)
117
- if scheme.lower() == "bearer":
118
- return token
119
-
120
- # Method 2: Query parameter (for WebSocket initial handshake)
121
- token = request.query_params.get("token")
122
- if token:
123
- return token
124
-
125
- # Method 3: Custom header (alternative)
126
- token = request.headers.get("X-HF-Token")
127
- if token:
128
- return token
129
-
130
- return None
131
-
132
- async def authenticate_request(self, request: Request) -> bool:
133
- """Authenticate a request using HuggingFace token"""
134
- token = self.get_token_from_request(request)
135
-
136
- if not token:
137
- return False
138
-
139
- return self.verify_token(token)
140
-
141
-
142
# Global instance
# Module-level singleton shared by the FastAPI dependencies and WebSocket
# helpers below; constructed (and its mode banner printed) at import time.
hf_auth = HuggingFaceTokenAuth()
144
-
145
-
146
async def require_hf_token(request: Request) -> str:
    """FastAPI dependency enforcing a valid HuggingFace token.

    Returns the token on success. In local development mode validation is
    skipped entirely and a placeholder token is returned. Raises HTTP 401
    when the token is missing or malformed.
    """
    if hf_auth.is_local:
        print("🔓 Local development mode: bypassing HF token requirement")
        return "local-development-bypass"

    token = hf_auth.get_token_from_request(request)

    # Distinguish "no token at all" from "token present but invalid" so the
    # client receives an actionable error message.
    if not token:
        detail = "HuggingFace token required. Please provide a valid token in Authorization header."
    elif not hf_auth.verify_token(token):
        detail = "Invalid HuggingFace token. Token must start with 'hf_'."
    else:
        return token

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )
174
-
175
-
176
async def optional_hf_token(request: Request) -> Optional[str]:
    """FastAPI dependency with best-effort token validation.

    Returns the token when one is present and valid, otherwise None. In
    local development mode a placeholder token is substituted whenever the
    request carries no usable token, so downstream code always sees a value.
    Useful for endpoints that work with or without authentication.
    """
    token = hf_auth.get_token_from_request(request)
    valid = bool(token) and hf_auth.verify_token(token)

    if hf_auth.is_local:
        if valid:
            return token
        print("🔓 Local development mode: providing dummy token for optional auth")
        return "local-development-bypass"

    return token if valid else None
201
-
202
-
203
def authenticate_websocket_connect(environ: dict) -> bool:
    """Authenticate a Socket.IO connection from its WSGI-style environ.

    Token sources are tried in order: ?token= query parameter, the
    Authorization bearer header, then the X-HF-Token header. Always True in
    local development mode. Never raises; every failure path returns False.
    Called during the Socket.IO connect event.
    """
    if hf_auth.is_local:
        print("🔓 Local development mode: bypassing WebSocket authentication")
        return True

    try:
        print("=== WEBSOCKET ENVIRON AUTHENTICATION ===")
        print(f"Environ type: {type(environ)}")

        if not isinstance(environ, dict):
            print(f"❌ Environ is not a dict: {type(environ)}")
            return False

        # Source 1: ?token= query parameter.
        query_string = environ.get('QUERY_STRING', '')
        print(f"Query string: {query_string}")
        if query_string:
            from urllib.parse import parse_qs
            query_params = parse_qs(query_string)
            print(f"Parsed query params: {query_params}")
            tokens = query_params.get('token', [])
            if tokens:
                token = tokens[0]
                print(f"Found token in query: {token[:10]}...")
                if hf_auth.verify_token(token):
                    print("✓ Token validated via query params")
                    return True

        # Source 2: "Authorization: Bearer <token>" header.
        auth_header = environ.get('HTTP_AUTHORIZATION', '')
        print(f"Authorization header: {auth_header[:20] if auth_header else 'None'}...")
        if auth_header.startswith('Bearer '):
            token = auth_header[len('Bearer '):]
            print(f"Found token in Authorization header: {token[:10]}...")
            if hf_auth.verify_token(token):
                print("✓ Token validated via Authorization header")
                return True

        # Source 3: custom X-HF-Token header.
        hf_token_header = environ.get('HTTP_X_HF_TOKEN', '')
        print(f"X-HF-Token header: {hf_token_header[:10] if hf_token_header else 'None'}...")
        if hf_token_header and hf_auth.verify_token(hf_token_header):
            print("✓ Token validated via X-HF-Token header")
            return True

        print("❌ No valid token found in environ")
        print(f"Available environ keys: {list(environ.keys())}")
        return False

    except Exception as e:
        print(f"❌ Error in authenticate_websocket_connect: {e}")
        import traceback
        traceback.print_exc()
        return False
265
-
266
-
267
def authenticate_websocket_auth_data(auth_data: dict) -> bool:
    """Authenticate a Socket.IO connection from client-supplied auth data.

    Expects ``auth_data['token']`` to hold a HuggingFace token. Always True
    in local development mode. Never raises; every failure path returns
    False. Called when the client sends auth data in the connection.
    """
    if hf_auth.is_local:
        print("🔓 Local development mode: bypassing WebSocket auth data validation")
        return True

    try:
        print("=== WEBSOCKET AUTH DATA AUTHENTICATION ===")
        # SECURITY FIX: do not log the raw auth payload — it contains the
        # full token. Every other log site truncates tokens to 10 chars,
        # so redact the token here to match.
        if isinstance(auth_data, dict):
            redacted = {
                k: (f"{v[:10]}..." if k == 'token' and isinstance(v, str) else v)
                for k, v in auth_data.items()
            }
            print(f"Auth data received: {redacted}")
        print(f"Auth data type: {type(auth_data)}")

        if not auth_data:
            print("❌ No auth data provided")
            return False

        if not isinstance(auth_data, dict):
            print(f"❌ Auth data is not a dict: {type(auth_data)}")
            return False

        token = auth_data.get('token')
        if token:
            print(f"Found token in auth data: {token[:10]}...")
            if hf_auth.verify_token(token):
                print("✓ Token validated via auth data")
                return True
            print("❌ Invalid token in auth data")
        else:
            print("❌ No token in auth data")
            print(f"Available keys in auth data: {list(auth_data.keys())}")

        return False

    except Exception as e:
        print(f"❌ Error in authenticate_websocket_auth_data: {e}")
        import traceback
        traceback.print_exc()
        return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/config/__init__.py DELETED
@@ -1,7 +0,0 @@
1
- """
2
- Configuration package for Polyglot backend
3
- """
4
-
5
- from .cors import cors_config
6
-
7
- __all__ = ["cors_config"]
 
 
 
 
 
 
 
 
app/config/cors.py DELETED
@@ -1,295 +0,0 @@
1
- """
2
- CORS Configuration Module
3
-
4
- Centralized CORS configuration supporting multiple deployment environments.
5
- """
6
-
7
- import os
8
- import re
9
- from typing import List, Optional
10
- from enum import Enum
11
-
12
-
13
class Environment(str, Enum):
    """Deployment environment types"""
    LOCAL = "local"
    DEVELOPMENT = "development"
    STAGING = "staging"
    PRODUCTION = "production"


class CORSConfig:
    """CORS configuration manager.

    Reads ENVIRONMENT, CORS_ALLOW_ALL, CORS_ALLOWED_ORIGINS and
    CORS_ALLOWED_PATTERNS once at construction time and exposes the
    resulting policy to both FastAPI's CORSMiddleware and Socket.IO.
    """

    # Default origins for local development
    DEFAULT_LOCAL_ORIGINS = [
        "http://localhost:3000",   # React/Next.js dev server
        "http://localhost:3001",   # Polyglot frontend (Vite)
        "http://localhost:3002",   # Lessons UI (Vite)
        "http://localhost:3003",   # Podium (Vite)
        "http://localhost:3004",   # Podium alternative port
        "http://localhost:5173",   # Vite dev server
        "http://localhost:7860",   # Backend self-reference
        "http://localhost:8080",   # Alternative dev server
        "http://127.0.0.1:3000",   # IPv4 localhost variant
        "http://127.0.0.1:3001",   # IPv4 localhost variant
        "http://127.0.0.1:3002",   # IPv4 localhost variant
        "http://127.0.0.1:3003",   # IPv4 localhost variant
        "http://127.0.0.1:3004",   # IPv4 localhost variant
        "http://127.0.0.1:5173",   # IPv4 localhost variant
        "http://127.0.0.1:7860",   # IPv4 localhost variant
    ]

    # Default regex patterns for production deployments
    DEFAULT_PRODUCTION_PATTERNS = [
        r"^https://.*\.tafiti\.dev$",            # Tafiti production/staging
        r"^https://.*\.vercel\.app$",            # Vercel deployments
        r"^https://.*\.hf\.space$",              # HuggingFace Spaces
        r"^https://milimani\.tafiti-api\.org$",  # Production API
    ]

    # Mobile app protocols (WebView origins)
    MOBILE_PROTOCOLS = [
        "capacitor://localhost",  # Capacitor apps
        "ionic://localhost",      # Ionic apps
        "http://localhost",       # Mobile WebView
    ]

    def __init__(self):
        # All configuration is resolved eagerly; consumers read attributes.
        self.environment = self._get_environment()
        self.allowed_origins = self._build_allowed_origins()
        self.allow_all = self._should_allow_all()
        self.origin_patterns = self._build_origin_patterns()

    def _get_environment(self) -> Environment:
        """Parse ENVIRONMENT into an Environment, defaulting to LOCAL."""
        env_str = os.getenv("ENVIRONMENT", "local").lower()

        try:
            return Environment(env_str)
        except ValueError:
            print(f"⚠️ Unknown environment '{env_str}', defaulting to 'local'")
            return Environment.LOCAL

    def _should_allow_all(self) -> bool:
        """Check if CORS should allow all origins (insecure, dev only).

        CORS_ALLOW_ALL=true is refused outright in production.
        """
        allow_all = os.getenv("CORS_ALLOW_ALL", "false").lower()

        if allow_all == "true":
            if self.environment == Environment.PRODUCTION:
                print("❌ ERROR: CORS_ALLOW_ALL=true is not allowed in production")
                return False
            else:
                print("⚠️ WARNING: CORS allowing all origins - INSECURE, use only for development")
                return True

        return False

    def _build_allowed_origins(self) -> List[str]:
        """Build the explicit allow-list from the environment and defaults.

        Order: custom CORS_ALLOWED_ORIGINS first, then per-environment
        defaults; duplicates are removed while preserving order.
        """
        origins = []

        # Custom origins: comma-separated list in CORS_ALLOWED_ORIGINS.
        custom_origins_str = os.getenv("CORS_ALLOWED_ORIGINS", "")

        if custom_origins_str:
            custom_origins = [
                origin.strip()
                for origin in custom_origins_str.split(",")
                if origin.strip()
            ]
            origins.extend(custom_origins)
            print(f"✓ Loaded {len(custom_origins)} custom CORS origins from environment")

        # Local development gets the full localhost allow-list.
        if self.environment == Environment.LOCAL:
            origins.extend(self.DEFAULT_LOCAL_ORIGINS)
            print(f"✓ Added {len(self.DEFAULT_LOCAL_ORIGINS)} default local origins")

        # Mobile WebView origins are allowed everywhere except production.
        if self.environment != Environment.PRODUCTION:
            origins.extend(self.MOBILE_PROTOCOLS)
            print(f"✓ Added {len(self.MOBILE_PROTOCOLS)} mobile protocol origins")

        # Deduplicate while preserving order.
        seen = set()
        unique_origins = []
        for origin in origins:
            if origin not in seen:
                seen.add(origin)
                unique_origins.append(origin)

        return unique_origins

    def _build_origin_patterns(self) -> List[re.Pattern]:
        """Compile regex patterns used for origin matching.

        Invalid custom patterns are skipped with a warning rather than
        aborting startup.
        """
        patterns = []

        # Custom patterns: comma-separated regexes in CORS_ALLOWED_PATTERNS.
        custom_patterns_str = os.getenv("CORS_ALLOWED_PATTERNS", "")

        if custom_patterns_str:
            custom_pattern_strs = [
                p.strip()
                for p in custom_patterns_str.split(",")
                if p.strip()
            ]

            for pattern_str in custom_pattern_strs:
                try:
                    patterns.append(re.compile(pattern_str))
                except re.error as e:
                    print(f"⚠️ Invalid regex pattern '{pattern_str}': {e}")

            print(f"✓ Loaded {len(patterns)} custom CORS patterns from environment")

        # Default deployment patterns for any non-local environment.
        if self.environment in [Environment.PRODUCTION, Environment.STAGING, Environment.DEVELOPMENT]:
            for pattern_str in self.DEFAULT_PRODUCTION_PATTERNS:
                patterns.append(re.compile(pattern_str))

            print(f"✓ Added {len(self.DEFAULT_PRODUCTION_PATTERNS)} default production patterns")

        # Local development accepts any localhost port.
        if self.environment == Environment.LOCAL:
            patterns.append(re.compile(r"^http://localhost:\d+$"))
            patterns.append(re.compile(r"^http://127\.0\.0\.1:\d+$"))
            print("✓ Added localhost wildcard patterns for development")

        return patterns

    def is_origin_allowed(self, origin: str) -> bool:
        """
        Check if an origin is allowed based on explicit list or patterns

        Args:
            origin: Origin to check (e.g., "https://app.tafiti.dev")

        Returns:
            True if origin is allowed, False otherwise
        """
        # If allow_all is enabled (dev only)
        if self.allow_all:
            return True

        # Check explicit origins list
        if origin in self.allowed_origins:
            return True

        # Check against compiled patterns
        for pattern in self.origin_patterns:
            if pattern.match(origin):
                return True

        return False

    def get_cors_middleware_config(self) -> dict:
        """Get the keyword arguments for FastAPI's CORSMiddleware."""
        if self.allow_all:
            return {
                "allow_origins": ["*"],
                "allow_credentials": False,  # Cannot use credentials with wildcard
                "allow_methods": ["*"],
                "allow_headers": ["*"],
            }

        if self.origin_patterns:
            # CORSMiddleware accepts a single regex, so OR all patterns together.
            combined_pattern = "|".join(f"({p.pattern})" for p in self.origin_patterns)

            return {
                "allow_origins": self.allowed_origins,
                "allow_origin_regex": combined_pattern,
                "allow_credentials": True,
                "allow_methods": ["*"],
                "allow_headers": ["*"],
            }
        else:
            return {
                "allow_origins": self.allowed_origins,
                "allow_credentials": True,
                "allow_methods": ["*"],
                "allow_headers": ["*"],
            }

    def get_socketio_cors_origins(self):
        """
        Get CORS origins for Socket.IO

        Socket.IO doesn't support regex patterns, so we need to provide an
        explicit list. For production, this means enumerating common origins.
        """
        if self.allow_all:
            return "*"

        socketio_origins = self.allowed_origins.copy()

        # Enumerate well-known production origins that the regex patterns
        # would otherwise cover (Socket.IO cannot evaluate regexes).
        if self.environment in [Environment.PRODUCTION, Environment.STAGING]:
            production_origins = [
                "https://app.tafiti.dev",
                "https://www.tafiti.dev",
                "https://polyglot.tafiti.dev",
                "https://podium.tafiti.dev",
                "https://milimani.tafiti-api.org",
                "https://polyglot-ashy-beta.vercel.app",
                "https://lessons-silk.vercel.app",
                # BUGFIX: this origin previously ended with an invisible
                # U+2060 WORD JOINER character, so it never matched a real
                # Origin header.
                "https://lessons.tafiti.dev",
                "https://podium-chi.vercel.app",
            ]
            for origin in production_origins:
                if origin not in socketio_origins:
                    socketio_origins.append(origin)

        return socketio_origins

    def print_config_summary(self):
        """Print CORS configuration summary for debugging"""
        print("\n" + "="*70)
        print("CORS CONFIGURATION SUMMARY")
        print("="*70)
        print(f"Environment: {self.environment.value}")
        print(f"Allow All: {self.allow_all}")
        print(f"\nExplicit Origins ({len(self.allowed_origins)}):")
        for origin in self.allowed_origins:
            print(f" • {origin}")

        if self.origin_patterns:
            print(f"\nOrigin Patterns ({len(self.origin_patterns)}):")
            for pattern in self.origin_patterns:
                print(f" • {pattern.pattern}")

        print("\nExample Origins That Would Be Allowed:")
        test_origins = [
            "http://localhost:3001",
            "http://localhost:3002",
            "http://localhost:3003",
            "http://localhost:3004",
            "http://localhost:5173",
            "https://app.tafiti.dev",
            "https://polyglot.tafiti.dev",
            "https://podium.tafiti.dev",
            "https://polyglot.vercel.app",
            "https://lessons-silk.vercel.app",
            "https://podium-chi.vercel.app",
            "https://polyglot-ashy-beta.vercel.app",
            "https://mutisya-translator.hf.space",
            "https://milimani.tafiti-api.org",
            "capacitor://localhost",
            "https://example.com",
        ]

        for test_origin in test_origins:
            allowed = "✓" if self.is_origin_allowed(test_origin) else "✗"
            print(f" {allowed} {test_origin}")

        print("="*70 + "\n")
292
-
293
-
294
# Global CORS configuration instance
# Built once at import time; import this singleton rather than constructing
# CORSConfig directly so every consumer shares the same parsed settings.
cors_config = CORSConfig()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/main.py DELETED
@@ -1,345 +0,0 @@
1
- import os
2
- import os
3
- import asyncio
4
- from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, Depends
5
- from fastapi.middleware.cors import CORSMiddleware
6
- from fastapi.staticfiles import StaticFiles
7
- from contextlib import asynccontextmanager
8
- import logging
9
- import socketio
10
- import engineio
11
- import re
12
-
13
- from app.routers import sessions, mobile, watch, learning
14
- from app.services.session_manager import SessionManager
15
- from app.services.transcription_service import TranscriptionService
16
- from app.services.translation_service import TranslationService
17
- from app.services.tts_service import TTSService
18
- from app.services.websocket_manager import WebSocketManager
19
- from app.auth import require_hf_token, optional_hf_token, authenticate_websocket_connect, authenticate_websocket_auth_data
20
- from app.config.cors import cors_config
21
-
22
- class ChunkArrayTruncateFilter(logging.Filter):
23
- """Custom logging filter to truncate long arrays in Socket.IO logs for better readability"""
24
-
25
- def filter(self, record):
26
- if hasattr(record, 'msg') and isinstance(record.msg, str):
27
- # More aggressive approach to truncate audioData arrays
28
- # Pattern to match: "audioData":[numbers,numbers,numbers,...]
29
- audiodata_pattern = r'"audioData":\[([0-9,-]+(?:,[0-9,-]+)*)\]'
30
-
31
- def truncate_audiodata(match):
32
- array_content = match.group(1)
33
- # Split by comma and get first 10 items
34
- items = array_content.split(',')
35
- if len(items) > 10:
36
- truncated = ','.join(items[:10])
37
- return f'"audioData":[{truncated}, ...] (truncated {len(items)-10} more items)'
38
- return match.group(0)
39
-
40
- record.msg = re.sub(audiodata_pattern, truncate_audiodata, record.msg)
41
-
42
- # Also handle any other large numeric arrays in brackets
43
- # Pattern for arrays with more than 20 numbers
44
- large_numeric_array_pattern = r'(\[)([0-9,-]+(?:,[0-9,-]+){20,})(\])'
45
-
46
- def truncate_large_numeric_array(match):
47
- prefix = match.group(1)
48
- array_content = match.group(2)
49
- suffix = match.group(3)
50
-
51
- # Split by comma and get first 10 items
52
- items = array_content.split(',')
53
- if len(items) > 10:
54
- truncated = ','.join(items[:10])
55
- return f'{prefix}{truncated}, ... (truncated {len(items)-10} more){suffix}'
56
- return match.group(0)
57
-
58
- record.msg = re.sub(large_numeric_array_pattern, truncate_large_numeric_array, record.msg)
59
-
60
- # Truncate other field types
61
- for field_name in ['chunk', 'wavChunk', 'data']:
62
- field_pattern = rf'"{field_name}":\[([0-9,-]+(?:,[0-9,-]+)*)\]'
63
- def make_truncate_field(fname):
64
- def truncate_field(match):
65
- array_content = match.group(1)
66
- items = array_content.split(',')
67
- if len(items) > 10:
68
- truncated = ','.join(items[:10])
69
- return f'"{fname}":[{truncated}, ...] (truncated {len(items)-10} more)'
70
- return match.group(0)
71
- return truncate_field
72
-
73
- record.msg = re.sub(field_pattern, make_truncate_field(field_name), record.msg)
74
-
75
- return True
76
-
77
-
78
- @asynccontextmanager
79
- async def lifespan(app: FastAPI):
80
- # Initialize services
81
- print("=== INITIALIZING BACKEND SERVICES ===")
82
- try:
83
- print("Initializing transcription service...")
84
- await transcription_service.initialize()
85
- print("✓ Transcription service initialized")
86
-
87
- print("Initializing translation service...")
88
- await translation_service.initialize()
89
- print("✓ Translation service initialized")
90
-
91
- print("Initializing TTS service...")
92
- await tts_service.initialize()
93
- print("✓ TTS service initialized")
94
-
95
- print("=== ALL SERVICES INITIALIZED SUCCESSFULLY ===" )
96
-
97
- # Start background loading of additional models after successful startup
98
- print("=== STARTING BACKGROUND MODEL LOADING ===")
99
- transcription_service.start_background_loading()
100
- tts_service.start_background_loading()
101
- print("=== BACKGROUND MODEL LOADING INITIATED ===")
102
-
103
- # Print CORS configuration summary
104
- cors_config.print_config_summary()
105
-
106
- except Exception as e:
107
- print(f"❌ SERVICE INITIALIZATION FAILED: {e}")
108
- import traceback
109
- traceback.print_exc()
110
- raise
111
-
112
- yield
113
-
114
- # Cleanup
115
- print("=== CLEANING UP SERVICES ===")
116
- await transcription_service.cleanup()
117
- await translation_service.cleanup()
118
- await tts_service.cleanup()
119
- print("=== CLEANUP COMPLETE ===")
120
-
121
- app = FastAPI(
122
- title="Real-time Transcription & Translation API",
123
- description="Backend API for real-time speech transcription and translation",
124
- version="1.0.0",
125
- lifespan=lifespan
126
- )
127
-
128
- # CORS middleware with environment-based configuration
129
- cors_middleware_config = cors_config.get_cors_middleware_config()
130
- print(f"Configuring CORS middleware with keys: {list(cors_middleware_config.keys())}")
131
-
132
- app.add_middleware(
133
- CORSMiddleware,
134
- **cors_middleware_config
135
- )
136
-
137
- # Initialize services - using PyTorch models for better compatibility
138
- session_manager = SessionManager()
139
- transcription_service = TranscriptionService()
140
- translation_service = TranslationService()
141
- tts_service = TTSService()
142
- websocket_manager = WebSocketManager(
143
- session_manager=session_manager,
144
- transcription_service=transcription_service,
145
- translation_service=translation_service,
146
- tts_service=tts_service
147
- )
148
-
149
- # Include routers
150
- app.include_router(sessions.router, prefix="/api")
151
- app.include_router(mobile.router, prefix="/api")
152
- app.include_router(watch.router, prefix="/api")
153
- app.include_router(learning.router)
154
-
155
- # Set the session manager in the router
156
- sessions.session_manager = session_manager
157
- sessions.translation_service = translation_service
158
- sessions.tts_service = tts_service
159
- sessions.transcription_service = transcription_service
160
-
161
- # Set the mobile router
162
- mobile.translation_service = translation_service
163
- mobile.tts_service = tts_service
164
- mobile.transcription_service = transcription_service
165
-
166
- # Set the watch router
167
- watch.translation_service = translation_service
168
- watch.tts_service = tts_service
169
- watch.transcription_service = transcription_service
170
-
171
-
172
- # Configure logging with custom filter to truncate chunk arrays
173
- chunk_filter = ChunkArrayTruncateFilter()
174
-
175
- sio_logger = logging.getLogger('socketio')
176
- sio_logger.setLevel(logging.INFO) # Show info logs with truncated arrays
177
- sio_logger.addFilter(chunk_filter)
178
-
179
- engineio_logger = logging.getLogger('engineio')
180
- engineio_logger.setLevel(logging.INFO) # Show info logs with truncated arrays
181
- engineio_logger.addFilter(chunk_filter)
182
-
183
- # Also apply filter to the root logger to catch any other verbose logging
184
- root_logger = logging.getLogger()
185
- root_logger.addFilter(chunk_filter)
186
-
187
- # Configure Engine.IO payload limits for large audio chunks
188
- engineio.payload.Payload.max_decode_packets = 250
189
-
190
- # Socket.IO setup with environment-based CORS
191
- socketio_cors_origins = cors_config.get_socketio_cors_origins()
192
- print(f"Configuring Socket.IO CORS: {len(socketio_cors_origins) if isinstance(socketio_cors_origins, list) else 'all'} origins")
193
-
194
- sio = socketio.AsyncServer(
195
- async_mode='asgi',
196
- cors_allowed_origins=socketio_cors_origins,
197
- cors_credentials=not cors_config.allow_all, # Cannot use credentials with wildcard
198
- logger=True, # Re-enabled with custom filtering
199
- engineio_logger=True, # Re-enabled with custom filtering
200
- always_connect=False # This ensures connect event is called for authentication
201
- )
202
-
203
- # Set the socketio instance in websocket manager
204
- websocket_manager.set_socketio(sio)
205
-
206
- socket_app = socketio.ASGIApp(sio, app)
207
-
208
- @app.get("/health")
209
- async def health_check(token: str = Depends(optional_hf_token)):
210
- """Health check endpoint - optionally authenticated"""
211
- from app.auth import hf_auth
212
-
213
- auth_status = "bypassed (local development)" if hf_auth.is_local else "authenticated"
214
- if not hf_auth.is_local and not token:
215
- auth_status = "unauthenticated"
216
-
217
- return {
218
- "status": "healthy",
219
- "message": "Translation service is running",
220
- "auth_status": auth_status,
221
- "local_development": hf_auth.is_local,
222
- "auth_bypassed": hf_auth.is_local,
223
- "token_prefix": token[:10] + "..." if token and token != "local-development-bypass" else "local-bypass" if hf_auth.is_local else None,
224
- "environment": {
225
- "ENVIRONMENT": os.getenv('ENVIRONMENT', 'not set'),
226
- "DEBUG": os.getenv('DEBUG', 'not set'),
227
- "DISABLE_AUTH": os.getenv('DISABLE_AUTH', 'not set'),
228
- "HOST": os.getenv('HOST', 'not set'),
229
- "PORT": os.getenv('PORT', 'not set')
230
- },
231
- "services": {
232
- "transcription": transcription_service is not None,
233
- "translation": translation_service is not None,
234
- "tts": tts_service is not None,
235
- "sessions": session_manager is not None
236
- }
237
- }
238
-
239
- @sio.event
240
- async def connect(sid, environ=None, auth=None):
241
- """Handle Socket.IO connection with authentication"""
242
- try:
243
- print(f"=== WEBSOCKET CONNECTION ATTEMPT ===")
244
- print(f"SID: {sid}")
245
- print(f"Auth data: {auth}")
246
- print(f"Environ type: {type(environ)}")
247
- print(f"Environ data: {environ}")
248
-
249
- # Ensure environ is a dict
250
- if environ is None:
251
- environ = {}
252
-
253
- print(f"Query string: {environ.get('QUERY_STRING', 'None')}")
254
- print(f"Headers: {[k for k in environ.keys() if k.startswith('HTTP_')] if isinstance(environ, dict) else 'environ not dict'}")
255
-
256
- # Check authentication from multiple sources
257
- authenticated = False
258
- auth_method = None
259
-
260
- # Method 1: Check auth data from client
261
- if auth and authenticate_websocket_auth_data(auth):
262
- authenticated = True
263
- auth_method = "auth_data"
264
- print("✓ Authenticated via auth data")
265
-
266
- # Method 2: Check environment (headers, query params)
267
- elif environ and isinstance(environ, dict) and authenticate_websocket_connect(environ):
268
- authenticated = True
269
- auth_method = "environ"
270
- print("✓ Authenticated via headers/query")
271
-
272
- # TEMPORARY: Allow connections for debugging (remove in production)
273
- # This helps identify if the issue is authentication or something else
274
- if not authenticated:
275
- print("⚠️ Authentication failed, but allowing for debugging")
276
- if isinstance(environ, dict):
277
- print(f"Available environ keys: {list(environ.keys())}")
278
- # Uncomment the next line to temporarily allow unauthenticated connections for debugging
279
- authenticated = True
280
- auth_method = "debug_bypass"
281
-
282
- if not authenticated:
283
- print("❌ Authentication failed - disconnecting")
284
- await sio.disconnect(sid)
285
- return False
286
-
287
- print(f"✓ WebSocket connection authenticated successfully via {auth_method}")
288
- return True
289
-
290
- except Exception as e:
291
- print(f"❌ Error in connect handler: {e}")
292
- import traceback
293
- traceback.print_exc()
294
- try:
295
- await sio.disconnect(sid)
296
- except:
297
- pass
298
- return False
299
-
300
- @sio.event
301
- async def disconnect(sid):
302
- await websocket_manager.handle_disconnect(sid)
303
-
304
- @sio.event
305
- async def join_session(sid, data):
306
- await websocket_manager.handle_join_session(sid, data)
307
-
308
- @sio.event
309
- async def join_hub(sid, data):
310
- await websocket_manager.handle_join_hub(sid, data)
311
-
312
- @sio.event
313
- async def leave_session(sid, data):
314
- await websocket_manager.handle_leave_session(sid, data)
315
-
316
- @sio.event
317
- async def audio_chunk(sid, data):
318
- await websocket_manager.handle_audio_chunk(sid, data)
319
-
320
- @sio.event
321
- async def speaking_status(sid, data):
322
- await websocket_manager.handle_speaking_status(sid, data)
323
-
324
- @sio.event
325
- async def test_echo(sid, data):
326
- """Test event to verify WebSocket communication"""
327
- await sio.emit('test_echo_response', data, room=sid)
328
-
329
- @sio.event
330
- async def update_participant_language(sid, data):
331
- """Update participant's language (affects speech recognition)"""
332
- await websocket_manager.handle_update_participant_language(sid, data)
333
-
334
- @sio.event
335
- async def update_session_languages(sid, data):
336
- """Update session's languages (affects translation targets)"""
337
- await websocket_manager.handle_update_session_languages(sid, data)
338
-
339
- # Serve static files (for frontend)
340
- if os.path.exists("../frontend/dist"):
341
- app.mount("/", StaticFiles(directory="../frontend/dist", html=True), name="static")
342
-
343
- if __name__ == "__main__":
344
- import uvicorn
345
- uvicorn.run("main:socket_app", host="0.0.0.0", port=7860, reload=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/main.py.bak DELETED
@@ -1,345 +0,0 @@
1
- import os
2
- import os
3
- import asyncio
4
- from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, Depends
5
- from fastapi.middleware.cors import CORSMiddleware
6
- from fastapi.staticfiles import StaticFiles
7
- from contextlib import asynccontextmanager
8
- import logging
9
- import socketio
10
- import engineio
11
- import re
12
-
13
- from app.routers import sessions, mobile, watch, learning
14
- from app.services.session_manager import SessionManager
15
- from app.services.transcription_service import TranscriptionService
16
- from app.services.translation_service import TranslationService
17
- from app.services.tts_service import TTSService
18
- from app.services.websocket_manager import WebSocketManager
19
- from app.auth import require_hf_token, optional_hf_token, authenticate_websocket_connect, authenticate_websocket_auth_data
20
- from app.config.cors import cors_config
21
-
22
- class ChunkArrayTruncateFilter(logging.Filter):
23
- """Custom logging filter to truncate long arrays in Socket.IO logs for better readability"""
24
-
25
- def filter(self, record):
26
- if hasattr(record, 'msg') and isinstance(record.msg, str):
27
- # More aggressive approach to truncate audioData arrays
28
- # Pattern to match: "audioData":[numbers,numbers,numbers,...]
29
- audiodata_pattern = r'"audioData":\[([0-9,-]+(?:,[0-9,-]+)*)\]'
30
-
31
- def truncate_audiodata(match):
32
- array_content = match.group(1)
33
- # Split by comma and get first 10 items
34
- items = array_content.split(',')
35
- if len(items) > 10:
36
- truncated = ','.join(items[:10])
37
- return f'"audioData":[{truncated}, ...] (truncated {len(items)-10} more items)'
38
- return match.group(0)
39
-
40
- record.msg = re.sub(audiodata_pattern, truncate_audiodata, record.msg)
41
-
42
- # Also handle any other large numeric arrays in brackets
43
- # Pattern for arrays with more than 20 numbers
44
- large_numeric_array_pattern = r'(\[)([0-9,-]+(?:,[0-9,-]+){20,})(\])'
45
-
46
- def truncate_large_numeric_array(match):
47
- prefix = match.group(1)
48
- array_content = match.group(2)
49
- suffix = match.group(3)
50
-
51
- # Split by comma and get first 10 items
52
- items = array_content.split(',')
53
- if len(items) > 10:
54
- truncated = ','.join(items[:10])
55
- return f'{prefix}{truncated}, ... (truncated {len(items)-10} more){suffix}'
56
- return match.group(0)
57
-
58
- record.msg = re.sub(large_numeric_array_pattern, truncate_large_numeric_array, record.msg)
59
-
60
- # Truncate other field types
61
- for field_name in ['chunk', 'wavChunk', 'data']:
62
- field_pattern = rf'"{field_name}":\[([0-9,-]+(?:,[0-9,-]+)*)\]'
63
- def make_truncate_field(fname):
64
- def truncate_field(match):
65
- array_content = match.group(1)
66
- items = array_content.split(',')
67
- if len(items) > 10:
68
- truncated = ','.join(items[:10])
69
- return f'"{fname}":[{truncated}, ...] (truncated {len(items)-10} more)'
70
- return match.group(0)
71
- return truncate_field
72
-
73
- record.msg = re.sub(field_pattern, make_truncate_field(field_name), record.msg)
74
-
75
- return True
76
-
77
-
78
- @asynccontextmanager
79
- async def lifespan(app: FastAPI):
80
- # Initialize services
81
- print("=== INITIALIZING BACKEND SERVICES ===")
82
- try:
83
- print("Initializing transcription service...")
84
- await transcription_service.initialize()
85
- print("✓ Transcription service initialized")
86
-
87
- print("Initializing translation service...")
88
- await translation_service.initialize()
89
- print("✓ Translation service initialized")
90
-
91
- print("Initializing TTS service...")
92
- await tts_service.initialize()
93
- print("✓ TTS service initialized")
94
-
95
- print("=== ALL SERVICES INITIALIZED SUCCESSFULLY ===")
96
-
97
- # Start background loading of additional models after successful startup
98
- print("=== STARTING BACKGROUND MODEL LOADING ===")
99
- transcription_service.start_background_loading()
100
- tts_service.start_background_loading()
101
- print("=== BACKGROUND MODEL LOADING INITIATED ===")
102
-
103
- # Print CORS configuration summary
104
- cors_config.print_config_summary()
105
-
106
- except Exception as e:
107
- print(f"❌ SERVICE INITIALIZATION FAILED: {e}")
108
- import traceback
109
- traceback.print_exc()
110
- raise
111
-
112
- yield
113
-
114
- # Cleanup
115
- print("=== CLEANING UP SERVICES ===")
116
- await transcription_service.cleanup()
117
- await translation_service.cleanup()
118
- await tts_service.cleanup()
119
- print("=== CLEANUP COMPLETE ===")
120
-
121
- app = FastAPI(
122
- title="Real-time Transcription & Translation API",
123
- description="Backend API for real-time speech transcription and translation",
124
- version="1.0.0",
125
- lifespan=lifespan
126
- )
127
-
128
- # CORS middleware with environment-based configuration
129
- cors_middleware_config = cors_config.get_cors_middleware_config()
130
- print(f"Configuring CORS middleware with keys: {list(cors_middleware_config.keys())}")
131
-
132
- app.add_middleware(
133
- CORSMiddleware,
134
- **cors_middleware_config
135
- )
136
-
137
- # Initialize services - using PyTorch models for better compatibility
138
- session_manager = SessionManager()
139
- transcription_service = TranscriptionService()
140
- translation_service = TranslationService()
141
- tts_service = TTSService()
142
- websocket_manager = WebSocketManager(
143
- session_manager=session_manager,
144
- transcription_service=transcription_service,
145
- translation_service=translation_service,
146
- tts_service=tts_service
147
- )
148
-
149
- # Include routers
150
- app.include_router(sessions.router, prefix="/api")
151
- app.include_router(mobile.router, prefix="/api")
152
- app.include_router(watch.router, prefix="/api")
153
- app.include_router(learning.router)
154
-
155
- # Set the session manager in the router
156
- sessions.session_manager = session_manager
157
- sessions.translation_service = translation_service
158
- sessions.tts_service = tts_service
159
- sessions.transcription_service = transcription_service
160
-
161
- # Set the mobile router
162
- mobile.translation_service = translation_service
163
- mobile.tts_service = tts_service
164
- mobile.transcription_service = transcription_service
165
-
166
- # Set the watch router
167
- watch.translation_service = translation_service
168
- watch.tts_service = tts_service
169
- watch.transcription_service = transcription_service
170
-
171
-
172
- # Configure logging with custom filter to truncate chunk arrays
173
- chunk_filter = ChunkArrayTruncateFilter()
174
-
175
- sio_logger = logging.getLogger('socketio')
176
- sio_logger.setLevel(logging.INFO) # Show info logs with truncated arrays
177
- sio_logger.addFilter(chunk_filter)
178
-
179
- engineio_logger = logging.getLogger('engineio')
180
- engineio_logger.setLevel(logging.INFO) # Show info logs with truncated arrays
181
- engineio_logger.addFilter(chunk_filter)
182
-
183
- # Also apply filter to the root logger to catch any other verbose logging
184
- root_logger = logging.getLogger()
185
- root_logger.addFilter(chunk_filter)
186
-
187
- # Configure Engine.IO payload limits for large audio chunks
188
- engineio.payload.Payload.max_decode_packets = 250
189
-
190
- # Socket.IO setup with environment-based CORS
191
- socketio_cors_origins = cors_config.get_socketio_cors_origins()
192
- print(f"Configuring Socket.IO CORS: {len(socketio_cors_origins) if isinstance(socketio_cors_origins, list) else 'all'} origins")
193
-
194
- sio = socketio.AsyncServer(
195
- async_mode='asgi',
196
- cors_allowed_origins=socketio_cors_origins,
197
- cors_credentials=not cors_config.allow_all, # Cannot use credentials with wildcard
198
- logger=True, # Re-enabled with custom filtering
199
- engineio_logger=True, # Re-enabled with custom filtering
200
- always_connect=False # This ensures connect event is called for authentication
201
- )
202
-
203
- # Set the socketio instance in websocket manager
204
- websocket_manager.set_socketio(sio)
205
-
206
- socket_app = socketio.ASGIApp(sio, app)
207
-
208
- @app.get("/health")
209
- async def health_check(token: str = Depends(optional_hf_token)):
210
- """Health check endpoint - optionally authenticated"""
211
- from app.auth import hf_auth
212
-
213
- auth_status = "bypassed (local development)" if hf_auth.is_local else "authenticated"
214
- if not hf_auth.is_local and not token:
215
- auth_status = "unauthenticated"
216
-
217
- return {
218
- "status": "healthy",
219
- "message": "Translation service is running",
220
- "auth_status": auth_status,
221
- "local_development": hf_auth.is_local,
222
- "auth_bypassed": hf_auth.is_local,
223
- "token_prefix": token[:10] + "..." if token and token != "local-development-bypass" else "local-bypass" if hf_auth.is_local else None,
224
- "environment": {
225
- "ENVIRONMENT": os.getenv('ENVIRONMENT', 'not set'),
226
- "DEBUG": os.getenv('DEBUG', 'not set'),
227
- "DISABLE_AUTH": os.getenv('DISABLE_AUTH', 'not set'),
228
- "HOST": os.getenv('HOST', 'not set'),
229
- "PORT": os.getenv('PORT', 'not set')
230
- },
231
- "services": {
232
- "transcription": transcription_service is not None,
233
- "translation": translation_service is not None,
234
- "tts": tts_service is not None,
235
- "sessions": session_manager is not None
236
- }
237
- }
238
-
239
- @sio.event
240
- async def connect(sid, environ=None, auth=None):
241
- """Handle Socket.IO connection with authentication"""
242
- try:
243
- print(f"=== WEBSOCKET CONNECTION ATTEMPT ===")
244
- print(f"SID: {sid}")
245
- print(f"Auth data: {auth}")
246
- print(f"Environ type: {type(environ)}")
247
- print(f"Environ data: {environ}")
248
-
249
- # Ensure environ is a dict
250
- if environ is None:
251
- environ = {}
252
-
253
- print(f"Query string: {environ.get('QUERY_STRING', 'None')}")
254
- print(f"Headers: {[k for k in environ.keys() if k.startswith('HTTP_')] if isinstance(environ, dict) else 'environ not dict'}")
255
-
256
- # Check authentication from multiple sources
257
- authenticated = False
258
- auth_method = None
259
-
260
- # Method 1: Check auth data from client
261
- if auth and authenticate_websocket_auth_data(auth):
262
- authenticated = True
263
- auth_method = "auth_data"
264
- print("✓ Authenticated via auth data")
265
-
266
- # Method 2: Check environment (headers, query params)
267
- elif environ and isinstance(environ, dict) and authenticate_websocket_connect(environ):
268
- authenticated = True
269
- auth_method = "environ"
270
- print("✓ Authenticated via headers/query")
271
-
272
- # TEMPORARY: Allow connections for debugging (remove in production)
273
- # This helps identify if the issue is authentication or something else
274
- if not authenticated:
275
- print("⚠️ Authentication failed, but allowing for debugging")
276
- if isinstance(environ, dict):
277
- print(f"Available environ keys: {list(environ.keys())}")
278
- # Uncomment the next line to temporarily allow unauthenticated connections for debugging
279
- authenticated = True
280
- auth_method = "debug_bypass"
281
-
282
- if not authenticated:
283
- print("❌ Authentication failed - disconnecting")
284
- await sio.disconnect(sid)
285
- return False
286
-
287
- print(f"✓ WebSocket connection authenticated successfully via {auth_method}")
288
- return True
289
-
290
- except Exception as e:
291
- print(f"❌ Error in connect handler: {e}")
292
- import traceback
293
- traceback.print_exc()
294
- try:
295
- await sio.disconnect(sid)
296
- except:
297
- pass
298
- return False
299
-
300
- @sio.event
301
- async def disconnect(sid):
302
- await websocket_manager.handle_disconnect(sid)
303
-
304
- @sio.event
305
- async def join_session(sid, data):
306
- await websocket_manager.handle_join_session(sid, data)
307
-
308
- @sio.event
309
- async def join_hub(sid, data):
310
- await websocket_manager.handle_join_hub(sid, data)
311
-
312
- @sio.event
313
- async def leave_session(sid, data):
314
- await websocket_manager.handle_leave_session(sid, data)
315
-
316
- @sio.event
317
- async def audio_chunk(sid, data):
318
- await websocket_manager.handle_audio_chunk(sid, data)
319
-
320
- @sio.event
321
- async def speaking_status(sid, data):
322
- await websocket_manager.handle_speaking_status(sid, data)
323
-
324
- @sio.event
325
- async def test_echo(sid, data):
326
- """Test event to verify WebSocket communication"""
327
- await sio.emit('test_echo_response', data, room=sid)
328
-
329
- @sio.event
330
- async def update_participant_language(sid, data):
331
- """Update participant's language (affects speech recognition)"""
332
- await websocket_manager.handle_update_participant_language(sid, data)
333
-
334
- @sio.event
335
- async def update_session_languages(sid, data):
336
- """Update session's languages (affects translation targets)"""
337
- await websocket_manager.handle_update_session_languages(sid, data)
338
-
339
- # Serve static files (for frontend)
340
- if os.path.exists("../frontend/dist"):
341
- app.mount("/", StaticFiles(directory="../frontend/dist", html=True), name="static")
342
-
343
- if __name__ == "__main__":
344
- import uvicorn
345
- uvicorn.run("main:socket_app", host="0.0.0.0", port=7860, reload=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/models/__init__.py DELETED
@@ -1,77 +0,0 @@
1
- from pydantic import BaseModel, Field
2
- from typing import List, Dict, Optional
3
- from enum import Enum
4
-
5
- class LanguageCode(str, Enum):
6
- ENGLISH = "eng"
7
- SWAHILI = "swa"
8
- KIKUYU = "kik"
9
- KAMBA = "kam"
10
- KIMERU = "mer"
11
- LUO = "luo"
12
- SOMALI = "som"
13
-
14
- class Language(BaseModel):
15
- code: LanguageCode
16
- name: str
17
- display_name: str
18
-
19
- class ParticipantCreate(BaseModel):
20
- name: str
21
- language: LanguageCode
22
-
23
- class Participant(BaseModel):
24
- id: str
25
- name: str
26
- language: Language
27
- is_organizer: bool = False
28
- is_speaking: bool = False
29
- is_connected: bool = False
30
-
31
- class SessionCreate(BaseModel):
32
- name: str
33
- organizer_name: str
34
- languages: List[LanguageCode]
35
- enable_tts: bool = True # Enable TTS by default for backward compatibility
36
-
37
- class Session(BaseModel):
38
- id: str
39
- name: str
40
- organizer_name: str
41
- participants: List[Participant] = []
42
- languages: List[Language] = []
43
- qr_code_url: Optional[str] = None
44
- is_active: bool = True
45
- enable_tts: bool = True # TTS enabled by default
46
-
47
- class Message(BaseModel):
48
- id: str
49
- session_id: str
50
- speaker_id: str
51
- speaker_name: str
52
- original_text: str
53
- original_language: Language
54
- translations: Dict[str, str] = {}
55
- is_transcribing: bool = False
56
-
57
- class TranscriptionUpdate(BaseModel):
58
- message_id: str
59
- text: str
60
- is_complete: bool
61
- confidence: Optional[float] = None
62
-
63
- class TranslationUpdate(BaseModel):
64
- message_id: str
65
- target_language: LanguageCode
66
- translated_text: str
67
-
68
- class AudioChunk(BaseModel):
69
- session_id: str
70
- participant_id: str
71
- audio_data: bytes
72
-
73
- class WebSocketMessage(BaseModel):
74
- type: str
75
- data: Dict
76
- session_id: str
77
- participant_id: Optional[str] = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/routers/__init__.py DELETED
@@ -1 +0,0 @@
1
- # Routers package
 
 
app/routers/add_phase_endpoints.py DELETED
@@ -1,490 +0,0 @@
1
- # Script to add remaining Phase 1-3 endpoints to learning.py
2
-
3
- endpoints_code = """
4
-
5
- @router.post("/vocabulary/add")
6
- async def add_vocabulary_to_practice(
7
- vocab_request: VocabularyAddRequest,
8
- request: Request,
9
- token: Optional[str] = Depends(optional_hf_token)
10
- ):
11
- \"\"\"Add a vocabulary word to user's practice queue with FSRS initialization\"\"\"
12
- try:
13
- user_id = token if token else 'anonymous'
14
-
15
- vocab = learning_service.get_vocabulary(vocab_request.vocab_id)
16
- if not vocab:
17
- raise HTTPException(status_code=404, detail="Vocabulary not found")
18
-
19
- fsrs_data = {
20
- 'difficulty': 0.3,
21
- 'stability': 2.5,
22
- 'retrievability': 1.0,
23
- 'review_count': 0,
24
- 'last_review': None,
25
- 'next_review': datetime.utcnow().isoformat() + 'Z',
26
- 'lapses': 0,
27
- 'state': 'new'
28
- }
29
-
30
- user_vocab = {
31
- 'vocabulary_id': vocab_request.vocab_id,
32
- 'swahili': vocab.get('swahili', ''),
33
- 'english': vocab.get('english', ''),
34
- 'part_of_speech': vocab.get('part_of_speech', 'unknown'),
35
- 'added_at': datetime.utcnow().isoformat() + 'Z',
36
- 'added_from': vocab_request.source_lesson_id,
37
- 'fsrs': fsrs_data,
38
- 'mastery_level': 0,
39
- 'times_reviewed': 0,
40
- 'times_correct': 0,
41
- 'accuracy': 0.0
42
- }
43
-
44
- success = learning_service.update_vocabulary_progress(
45
- user_id, str(vocab_request.vocab_id), user_vocab
46
- )
47
-
48
- if success:
49
- return {"success": True, "vocabulary": user_vocab}
50
- else:
51
- raise HTTPException(status_code=500, detail="Failed to add vocabulary")
52
- except HTTPException:
53
- raise
54
- except Exception as e:
55
- logger.error(f"Error adding vocabulary: {e}")
56
- raise HTTPException(status_code=500, detail="Failed to add vocabulary")
57
-
58
-
59
- def calculate_next_review_fsrs(fsrs: Dict, grade: int) -> Dict:
60
- \"\"\"Implement FSRS algorithm\"\"\"
61
- from datetime import timedelta
62
-
63
- difficulty = fsrs['difficulty']
64
- stability = fsrs['stability']
65
-
66
- if grade == 0:
67
- new_difficulty = min(difficulty + 0.2, 1.0)
68
- elif grade == 2:
69
- new_difficulty = min(difficulty + 0.1, 1.0)
70
- elif grade == 4:
71
- new_difficulty = max(difficulty - 0.1, 0.0)
72
- else:
73
- new_difficulty = difficulty
74
-
75
- if grade == 0:
76
- new_stability = stability * 0.5
77
- state = 'relearning'
78
- interval_minutes = 10
79
- elif grade == 2:
80
- new_stability = stability * 1.2
81
- state = 'review'
82
- interval_minutes = int(new_stability * 24 * 60)
83
- elif grade == 3:
84
- new_stability = stability * 2.5
85
- state = 'review'
86
- interval_minutes = int(new_stability * 24 * 60)
87
- else:
88
- new_stability = stability * 4.0
89
- state = 'review'
90
- interval_minutes = int(new_stability * 24 * 60)
91
-
92
- next_review = datetime.utcnow() + timedelta(minutes=interval_minutes)
93
-
94
- return {
95
- 'difficulty': new_difficulty,
96
- 'stability': new_stability,
97
- 'retrievability': 0.9 if grade >= 2 else 0.0,
98
- 'review_count': fsrs['review_count'] + 1,
99
- 'last_review': datetime.utcnow().isoformat() + 'Z',
100
- 'next_review': next_review.isoformat() + 'Z',
101
- 'lapses': fsrs['lapses'],
102
- 'state': state,
103
- 'interval_days': interval_minutes / (24 * 60)
104
- }
105
-
106
-
107
- def calculate_mastery_level(vocab: Dict) -> int:
108
- \"\"\"Calculate mastery level (0-5)\"\"\"
109
- accuracy = vocab['accuracy']
110
- reviews = vocab['times_reviewed']
111
- stability = vocab['fsrs']['stability']
112
-
113
- if reviews == 0:
114
- return 0
115
- elif reviews < 5 or accuracy < 70:
116
- return 1
117
- elif reviews < 10 or accuracy < 85:
118
- return 2
119
- elif reviews < 20 or accuracy < 95:
120
- return 3
121
- elif reviews >= 20 and accuracy >= 95 and stability >= 30:
122
- return 4
123
- elif reviews >= 40 and accuracy >= 98 and stability >= 90:
124
- return 5
125
- else:
126
- return 3
127
-
128
-
129
- @router.post("/vocabulary/review")
130
- async def record_vocabulary_review_fsrs(
131
- review_request: VocabularyReviewRequest,
132
- request: Request,
133
- token: Optional[str] = Depends(optional_hf_token)
134
- ):
135
- \"\"\"Record vocabulary review and update FSRS parameters\"\"\"
136
- try:
137
- user_id = token if token else 'anonymous'
138
- progress = learning_service.get_user_progress(user_id)
139
-
140
- if not progress or str(review_request.vocab_id) not in progress.get('vocabulary_progress', {}):
141
- raise HTTPException(status_code=404, detail="Vocabulary not in practice queue")
142
-
143
- vocab = progress['vocabulary_progress'][str(review_request.vocab_id)]
144
- fsrs = vocab['fsrs']
145
-
146
- grade_map = {'again': 0, 'hard': 2, 'good': 3, 'easy': 4}
147
- grade = grade_map.get(review_request.rating, 3)
148
-
149
- new_fsrs = calculate_next_review_fsrs(fsrs, grade)
150
-
151
- vocab['fsrs'] = new_fsrs
152
- vocab['times_reviewed'] += 1
153
- if grade >= 2:
154
- vocab['times_correct'] += 1
155
- else:
156
- vocab['fsrs']['lapses'] += 1
157
-
158
- vocab['accuracy'] = (vocab['times_correct'] / vocab['times_reviewed']) * 100 if vocab['times_reviewed'] > 0 else 0
159
- vocab['mastery_level'] = calculate_mastery_level(vocab)
160
- vocab['last_reviewed_at'] = datetime.utcnow().isoformat() + 'Z'
161
-
162
- if 'vocabulary_reviewed' not in progress['overall_stats']:
163
- progress['overall_stats']['vocabulary_reviewed'] = 0
164
- progress['overall_stats']['vocabulary_reviewed'] += 1
165
-
166
- learning_service.save_user_progress(user_id, progress)
167
-
168
- return {
169
- "success": True,
170
- "vocabulary": vocab,
171
- "next_review": new_fsrs['next_review'],
172
- "interval_days": new_fsrs['interval_days']
173
- }
174
- except HTTPException:
175
- raise
176
- except Exception as e:
177
- logger.error(f"Error recording vocabulary review: {e}")
178
- raise HTTPException(status_code=500, detail="Failed to record review")
179
-
180
-
181
- @router.get("/vocabulary/stats")
182
- async def get_vocabulary_stats(
183
- request: Request,
184
- token: Optional[str] = Depends(optional_hf_token)
185
- ):
186
- \"\"\"Get vocabulary mastery statistics\"\"\"
187
- try:
188
- user_id = token if token else 'anonymous'
189
- progress = learning_service.get_user_progress(user_id)
190
-
191
- if not progress:
192
- return {
193
- "total_words": 0,
194
- "in_practice": 0,
195
- "mastery_breakdown": {str(i): 0 for i in range(6)},
196
- "average_accuracy": 0
197
- }
198
-
199
- vocab_progress = progress.get('vocabulary_progress', {})
200
- mastery_breakdown = {str(i): 0 for i in range(6)}
201
- total_accuracy = 0
202
- total_with_reviews = 0
203
-
204
- for vocab_data in vocab_progress.values():
205
- level = vocab_data.get('mastery_level', 0)
206
- mastery_breakdown[str(level)] += 1
207
-
208
- if vocab_data.get('times_reviewed', 0) > 0:
209
- total_accuracy += vocab_data.get('accuracy', 0)
210
- total_with_reviews += 1
211
-
212
- avg_accuracy = total_accuracy / total_with_reviews if total_with_reviews > 0 else 0
213
-
214
- return {
215
- "total_words": len(vocab_progress),
216
- "in_practice": len(vocab_progress),
217
- "mastery_breakdown": mastery_breakdown,
218
- "average_accuracy": round(avg_accuracy, 1),
219
- "total_reviews": sum(v.get('times_reviewed', 0) for v in vocab_progress.values())
220
- }
221
- except Exception as e:
222
- logger.error(f"Error getting vocabulary stats: {e}")
223
- raise HTTPException(status_code=500, detail="Failed to get stats")
224
-
225
-
226
- @router.get("/vocabulary/library")
227
- async def get_vocabulary_library(
228
- lesson_id: Optional[int] = None,
229
- level: Optional[str] = None,
230
- search: Optional[str] = None,
231
- request: Request = None,
232
- token: Optional[str] = Depends(optional_hf_token)
233
- ):
234
- \"\"\"Browse all vocabulary with filters\"\"\"
235
- try:
236
- user_id = token if token else 'anonymous'
237
-
238
- all_vocab = learning_service.get_all_vocabulary()
239
- progress = learning_service.get_user_progress(user_id)
240
- user_vocab = progress.get('vocabulary_progress', {}) if progress else {}
241
-
242
- filtered_vocab = all_vocab
243
-
244
- if lesson_id:
245
- filtered_vocab = [v for v in filtered_vocab if v.get('lesson_id') == lesson_id]
246
-
247
- if level:
248
- filtered_vocab = [v for v in filtered_vocab if v.get('level') == level]
249
-
250
- if search:
251
- search_lower = search.lower()
252
- filtered_vocab = [v for v in filtered_vocab
253
- if search_lower in v.get('swahili', '').lower()
254
- or search_lower in v.get('english', '').lower()]
255
-
256
- for vocab in filtered_vocab:
257
- vocab_id = str(vocab.get('vocabulary_id') or vocab.get('id'))
258
- if vocab_id in user_vocab:
259
- vocab['status'] = 'practicing'
260
- vocab['mastery_level'] = user_vocab[vocab_id].get('mastery_level', 0)
261
- vocab['accuracy'] = user_vocab[vocab_id].get('accuracy', 0)
262
- vocab['next_review'] = user_vocab[vocab_id].get('fsrs', {}).get('next_review')
263
- else:
264
- vocab['status'] = 'not_practicing'
265
- vocab['mastery_level'] = 0
266
-
267
- return {
268
- "vocabulary": filtered_vocab,
269
- "total": len(filtered_vocab),
270
- "filters_applied": {
271
- "lesson_id": lesson_id,
272
- "level": level,
273
- "search": search
274
- }
275
- }
276
- except Exception as e:
277
- logger.error(f"Error getting vocabulary library: {e}")
278
- raise HTTPException(status_code=500, detail="Failed to get vocabulary")
279
-
280
-
281
- # Reading Comprehension
282
-
283
- class ComprehensionAnswer(BaseModel):
284
- question_id: str
285
- answer: str
286
-
287
-
288
- class ComprehensionSubmission(BaseModel):
289
- lesson_id: int
290
- passage_id: str
291
- answers: List[ComprehensionAnswer]
292
-
293
-
294
- @router.post("/comprehension/submit")
295
- async def submit_comprehension_answers(
296
- submission: ComprehensionSubmission,
297
- request: Request,
298
- token: Optional[str] = Depends(optional_hf_token)
299
- ):
300
- \"\"\"Submit reading comprehension answers and get scoring\"\"\"
301
- try:
302
- user_id = token if token else 'anonymous'
303
-
304
- lesson = learning_service.get_lesson(submission.lesson_id)
305
- if not lesson:
306
- raise HTTPException(status_code=404, detail="Lesson not found")
307
-
308
- passage = None
309
- for p in lesson.get('reading_passages', []):
310
- if p['passage_id'] == submission.passage_id:
311
- passage = p
312
- break
313
-
314
- if not passage:
315
- raise HTTPException(status_code=404, detail="Passage not found")
316
-
317
- results = []
318
- correct_count = 0
319
-
320
- for submitted in submission.answers:
321
- question_id = submitted.question_id
322
- user_answer = submitted.answer.strip().lower()
323
-
324
- question = None
325
- for q in passage['comprehension_questions']:
326
- if q['question_id'] == question_id:
327
- question = q
328
- break
329
-
330
- if not question:
331
- continue
332
-
333
- correct_answers = [ans.strip().lower() for ans in question.get('correct_answers', [])]
334
- is_correct = user_answer in correct_answers
335
-
336
- if is_correct:
337
- correct_count += 1
338
-
339
- results.append({
340
- "question_id": question_id,
341
- "correct": is_correct,
342
- "user_answer": user_answer,
343
- "correct_answer": question['correct_answers'][0] if correct_answers else None,
344
- "explanation": question.get('explanation')
345
- })
346
-
347
- score = (correct_count / len(submission.answers)) * 100 if submission.answers else 0
348
-
349
- progress = learning_service.get_user_progress(user_id)
350
- if not progress:
351
- progress = learning_service.create_default_progress(user_id)
352
-
353
- if 'comprehension_scores' not in progress:
354
- progress['comprehension_scores'] = {}
355
-
356
- progress['comprehension_scores'][f"{submission.lesson_id}_{submission.passage_id}"] = {
357
- "score": score,
358
- "completed_at": datetime.utcnow().isoformat() + 'Z',
359
- "attempts": progress['comprehension_scores'].get(f"{submission.lesson_id}_{submission.passage_id}", {}).get('attempts', 0) + 1
360
- }
361
-
362
- learning_service.save_user_progress(user_id, progress)
363
-
364
- return {
365
- "results": results,
366
- "score": round(score, 1),
367
- "correct": correct_count,
368
- "total": len(submission.answers)
369
- }
370
- except HTTPException:
371
- raise
372
- except Exception as e:
373
- logger.error(f"Error submitting comprehension: {e}")
374
- raise HTTPException(status_code=500, detail="Failed to submit comprehension")
375
-
376
-
377
- # Task Scenarios
378
-
379
- class ScenarioProgressUpdate(BaseModel):
380
- turn_id: str
381
- choice_id: str
382
-
383
-
384
- @router.get("/scenarios/{scenario_id}")
385
- async def get_scenario(
386
- scenario_id: str,
387
- request: Request,
388
- token: Optional[str] = Depends(optional_hf_token)
389
- ):
390
- \"\"\"Get task scenario with branching dialogue\"\"\"
391
- try:
392
- user_id = token if token else 'anonymous'
393
-
394
- scenario = learning_service.get_scenario(scenario_id)
395
- if not scenario:
396
- raise HTTPException(status_code=404, detail="Scenario not found")
397
-
398
- progress = learning_service.get_user_progress(user_id)
399
- scenario_progress = None
400
-
401
- if progress and 'scenario_progress' in progress:
402
- scenario_progress = progress['scenario_progress'].get(scenario_id)
403
-
404
- return {
405
- "scenario": scenario,
406
- "user_progress": scenario_progress
407
- }
408
- except HTTPException:
409
- raise
410
- except Exception as e:
411
- logger.error(f"Error getting scenario: {e}")
412
- raise HTTPException(status_code=500, detail="Failed to get scenario")
413
-
414
-
415
- @router.post("/scenarios/{scenario_id}/progress")
416
- async def update_scenario_progress(
417
- scenario_id: str,
418
- progress_update: ScenarioProgressUpdate,
419
- request: Request,
420
- token: Optional[str] = Depends(optional_hf_token)
421
- ):
422
- \"\"\"Update scenario progress with user choice\"\"\"
423
- try:
424
- user_id = token if token else 'anonymous'
425
-
426
- scenario = learning_service.get_scenario(scenario_id)
427
- if not scenario:
428
- raise HTTPException(status_code=404, detail="Scenario not found")
429
-
430
- progress = learning_service.get_user_progress(user_id)
431
- if not progress:
432
- progress = learning_service.create_default_progress(user_id)
433
-
434
- if 'scenario_progress' not in progress:
435
- progress['scenario_progress'] = {}
436
-
437
- if scenario_id not in progress['scenario_progress']:
438
- progress['scenario_progress'][scenario_id] = {
439
- "started_at": datetime.utcnow().isoformat() + 'Z',
440
- "turns": [],
441
- "completed": False
442
- }
443
-
444
- progress['scenario_progress'][scenario_id]['turns'].append({
445
- "turn_id": progress_update.turn_id,
446
- "choice_id": progress_update.choice_id,
447
- "timestamp": datetime.utcnow().isoformat() + 'Z'
448
- })
449
-
450
- turns_count = len(progress['scenario_progress'][scenario_id]['turns'])
451
- if turns_count >= scenario.get('required_turns', 6):
452
- progress['scenario_progress'][scenario_id]['completed'] = True
453
- progress['scenario_progress'][scenario_id]['completed_at'] = datetime.utcnow().isoformat() + 'Z'
454
-
455
- learning_service.save_user_progress(user_id, progress)
456
-
457
- return {
458
- "success": True,
459
- "progress": progress['scenario_progress'][scenario_id]
460
- }
461
- except HTTPException:
462
- raise
463
- except Exception as e:
464
- logger.error(f"Error updating scenario progress: {e}")
465
- raise HTTPException(status_code=500, detail="Failed to update scenario")
466
-
467
-
468
- @router.get("/scenarios")
469
- async def list_scenarios(
470
- request: Request,
471
- token: Optional[str] = Depends(optional_hf_token)
472
- ):
473
- \"\"\"Get list of all available scenarios\"\"\"
474
- try:
475
- scenarios = learning_service.get_all_scenarios()
476
- return {
477
- "success": True,
478
- "scenarios": scenarios,
479
- "total": len(scenarios)
480
- }
481
- except Exception as e:
482
- logger.error(f"Error listing scenarios: {e}")
483
- raise HTTPException(status_code=500, detail="Failed to list scenarios")
484
- """
485
-
486
# Append the generated endpoint code to learning.py.
# The target file is configurable via LEARNING_PY_PATH so this one-off codegen
# script is not tied to a single developer's machine; the historical absolute
# Windows path remains the default for backward compatibility.
import os

_target_path = os.environ.get(
    'LEARNING_PY_PATH',
    'C:/repos/polyglot/backend/app/routers/learning.py'
)
with open(_target_path, 'a', encoding='utf-8') as f:
    f.write(endpoints_code)

print("Successfully added all remaining Phase 1-3 endpoints!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/routers/learning.py DELETED
@@ -1,1020 +0,0 @@
1
- """
2
- Learning API Router - REST endpoints for language learning functionality
3
-
4
- Provides endpoints for:
5
- - Fetching lesson catalog and individual lessons
6
- - Managing user progress
7
- - Recording lesson completion and scores
8
- - Achievement tracking
9
- """
10
-
11
- from fastapi import APIRouter, HTTPException, Depends, Request, File, UploadFile
12
- from fastapi.responses import Response
13
- from pydantic import BaseModel
14
- from typing import List, Dict, Optional, Any
15
- from datetime import datetime
16
- import logging
17
- import io
18
-
19
- from app.services.learning_data_service import LearningDataService
20
- from app.auth import optional_hf_token
21
-
22
- logger = logging.getLogger(__name__)
23
-
24
- router = APIRouter(prefix="/api/learning", tags=["learning"])
25
-
26
- # Initialize data service
27
- learning_service = LearningDataService()
28
-
29
-
30
- # ==================== Request/Response Models ====================
31
-
32
class LessonProgressUpdate(BaseModel):
    """Request payload for recording progress on a single lesson.

    Fields left as None are simply not written into the stored record.
    """
    lesson_id: int
    status: str  # 'in_progress' or 'completed'
    score: Optional[int] = None
    pronunciation_score: Optional[float] = None
    listening_score: Optional[float] = None
    comprehension_score: Optional[float] = None
    time_spent_seconds: Optional[int] = None
    steps_completed: Optional[int] = None
    steps_skipped: Optional[int] = None
-
43
-
44
class VocabularyReview(BaseModel):
    """Request payload for one vocabulary practice attempt."""
    vocabulary_id: int
    swahili: str
    is_correct: bool
    mastery_level: Optional[int] = None  # when supplied, overwrites the stored level
49
-
50
-
51
class AchievementCheck(BaseModel):
    """Request payload reporting progress toward an achievement target."""
    achievement_id: str
    progress: int
    target: int  # achievement unlocks when progress >= target
55
-
56
-
57
- # ==================== Lesson Endpoints ====================
58
-
59
@router.get("/lessons")
async def get_lessons(language: Optional[str] = 'swahili', request: Request = None, token: Optional[str] = Depends(optional_hf_token)):
    """Return the lesson catalog for *language*.

    The response carries the lesson index metadata: lesson summaries, the
    configured learning paths, and catalog metadata. Responds 404 when no
    catalog exists for the language and 500 on unexpected service failures.
    """
    try:
        catalog = learning_service.get_lessons_index(language)
        if not catalog:
            raise HTTPException(status_code=404, detail=f"Lessons catalog not found for {language}")

        payload = {"success": True}
        payload["lessons"] = catalog.get('lessons', [])
        payload["learning_paths"] = catalog.get('learning_paths', {})
        payload["metadata"] = catalog.get('metadata', {})
        return payload
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error fetching lessons for {language}: {e}")
        raise HTTPException(status_code=500, detail="Failed to fetch lessons")
85
-
86
-
87
@router.get("/lessons/{lesson_id}")
async def get_lesson(lesson_id: int, language: Optional[str] = 'swahili', request: Request = None, token: Optional[str] = Depends(optional_hf_token)):
    """Return the full content of one lesson (vocabulary, dialogue, exercises).

    Responds 404 when the lesson does not exist for the requested language,
    500 on unexpected service failures.
    """
    try:
        found = learning_service.get_lesson(lesson_id, language)
        if found:
            return {"success": True, "lesson": found}
        raise HTTPException(status_code=404, detail=f"Lesson {lesson_id} not found for {language}")
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error fetching lesson {lesson_id} for {language}: {e}")
        raise HTTPException(status_code=500, detail="Failed to fetch lesson")
110
-
111
-
112
- # ==================== User Progress Endpoints ====================
113
-
114
@router.get("/progress")
async def get_user_progress(request: Request, token: Optional[str] = Depends(optional_hf_token)):
    """Return the caller's full learning progress record.

    The record includes overall stats, per-lesson progress, vocabulary
    progress and achievements. Unauthenticated callers share the shared
    'anonymous' record.
    """
    try:
        user_id = token or 'anonymous'
        record = learning_service.get_user_progress(user_id)
        if not record:
            raise HTTPException(status_code=500, detail="Failed to load user progress")
        return {"success": True, "progress": record}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error fetching user progress: {e}")
        raise HTTPException(status_code=500, detail="Failed to fetch user progress")
138
-
139
-
140
@router.post("/progress/lesson")
async def update_lesson_progress(
    progress_update: LessonProgressUpdate,
    request: Request,
    token: Optional[str] = Depends(optional_hf_token)
):
    """
    Update progress for a specific lesson.

    Records completion status, scores, and time spent on a lesson. Optional
    fields omitted from the request are not written into the stored record.

    Raises:
        HTTPException: 500 when the progress record cannot be persisted.
    """
    try:
        user_id = token if token else 'anonymous'

        # Fix: the previous implementation re-read the entire user-progress
        # store twice per request (once for best_score, once for attempts).
        # Read it once and reuse the snapshot.
        user_progress = learning_service.get_user_progress(user_id)
        lesson_key = str(progress_update.lesson_id)
        existing = (user_progress or {}).get('lesson_progress', {}).get(lesson_key, {})

        update_data = {
            'lesson_id': progress_update.lesson_id,
            'status': progress_update.status
        }

        if progress_update.score is not None:
            update_data['latest_score'] = progress_update.score
            if user_progress:
                # Preserve the best score ever achieved for this lesson.
                update_data['best_score'] = max(existing.get('best_score', 0), progress_update.score)

        if progress_update.pronunciation_score is not None:
            update_data['pronunciation_score'] = progress_update.pronunciation_score

        if progress_update.listening_score is not None:
            update_data['listening_score'] = progress_update.listening_score

        if progress_update.comprehension_score is not None:
            update_data['comprehension_score'] = progress_update.comprehension_score

        if progress_update.time_spent_seconds is not None:
            update_data['time_spent_seconds'] = progress_update.time_spent_seconds

        if progress_update.steps_completed is not None:
            update_data['steps_completed'] = progress_update.steps_completed

        if progress_update.steps_skipped is not None:
            update_data['steps_skipped'] = progress_update.steps_skipped

        # Completions get a timestamp and bump the attempt counter.
        if progress_update.status == 'completed':
            update_data['completed_at'] = datetime.utcnow().isoformat() + 'Z'
            if user_progress:
                update_data['attempts'] = existing.get('attempts', 0) + 1

        success = learning_service.update_lesson_progress(
            user_id,
            progress_update.lesson_id,
            update_data
        )

        if not success:
            raise HTTPException(status_code=500, detail="Failed to save progress")

        return {
            "success": True,
            "message": "Lesson progress updated"
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error updating lesson progress: {e}")
        raise HTTPException(status_code=500, detail="Failed to update lesson progress")
219
-
220
-
221
@router.post("/progress/vocabulary")
async def record_vocabulary_review(
    review: VocabularyReview,
    request: Request,
    token: Optional[str] = Depends(optional_hf_token)
):
    """
    Record a vocabulary review/practice session.

    Updates review counts, optionally the mastery level, and schedules the
    next review with a simple doubling spaced-repetition interval (doubled on
    a correct answer, reset to 1 day on a miss).

    Returns:
        dict with success flag, message, and the computed next_review_at.

    Raises:
        HTTPException: 500 when the progress record cannot be loaded or saved.
    """
    try:
        user_id = token if token else 'anonymous'

        # Get current vocabulary progress
        user_progress = learning_service.get_user_progress(user_id)
        if not user_progress:
            raise HTTPException(status_code=500, detail="Failed to load user progress")

        # First-ever review of this word falls back to a fresh default record.
        # (The literal dict is evaluated per call, so it is not a shared default.)
        vocab_key = str(review.vocabulary_id)
        vocab_progress = user_progress.get('vocabulary_progress', {}).get(vocab_key, {
            'vocabulary_id': review.vocabulary_id,
            'swahili': review.swahili,
            'mastery_level': 0,
            'times_reviewed': 0,
            'times_correct': 0,
            'ease_factor': 2.5,
            'interval_days': 0
        })

        # Update review counts
        vocab_progress['times_reviewed'] = vocab_progress.get('times_reviewed', 0) + 1
        if review.is_correct:
            vocab_progress['times_correct'] = vocab_progress.get('times_correct', 0) + 1

        # Update mastery level if provided
        if review.mastery_level is not None:
            vocab_progress['mastery_level'] = review.mastery_level

        # Update timestamps
        vocab_progress['last_reviewed_at'] = datetime.utcnow().isoformat() + 'Z'

        # Calculate next review date using simple spaced repetition
        # (simplified version - could use SuperMemo SM-2 algorithm)
        interval_days = vocab_progress.get('interval_days', 0)
        if review.is_correct:
            interval_days = max(1, interval_days * 2)  # Double the interval
        else:
            interval_days = 1  # Reset to 1 day if incorrect

        vocab_progress['interval_days'] = interval_days

        from datetime import timedelta
        next_review = datetime.utcnow() + timedelta(days=interval_days)
        vocab_progress['next_review_at'] = next_review.isoformat() + 'Z'

        # Save to file
        success = learning_service.update_vocabulary_progress(
            user_id,
            review.vocabulary_id,
            vocab_progress
        )

        if not success:
            raise HTTPException(status_code=500, detail="Failed to save vocabulary progress")

        return {
            "success": True,
            "message": "Vocabulary review recorded",
            "next_review_at": vocab_progress['next_review_at']
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error recording vocabulary review: {e}")
        raise HTTPException(status_code=500, detail="Failed to record vocabulary review")
297
-
298
-
299
- # ==================== Achievement Endpoints ====================
300
-
301
@router.get("/achievements")
async def get_achievements(request: Request, token: Optional[str] = Depends(optional_hf_token)):
    """Return every achievement definition merged with the caller's progress.

    Each entry starts locked (unlocked=False, progress=0) and is overlaid with
    the user's stored state for that achievement when one exists. Responds 404
    when the achievements configuration is missing.
    """
    try:
        achievements_config = learning_service.get_achievements()
        if not achievements_config:
            raise HTTPException(status_code=404, detail="Achievements not found")

        user_id = token or 'anonymous'
        user_progress = learning_service.get_user_progress(user_id)
        user_achievements = user_progress.get('achievements', {}) if user_progress else {}

        merged_list = []
        for definition in achievements_config.get('achievements', []):
            # Start from the definition with locked defaults, then overlay
            # whatever the user has recorded for this achievement.
            merged = {**definition, 'unlocked': False, 'progress': 0}
            merged.update(user_achievements.get(definition['achievement_id'], {}))
            merged_list.append(merged)

        return {
            "success": True,
            "achievements": merged_list,
            "tiers": achievements_config.get('tiers', {})
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error fetching achievements: {e}")
        raise HTTPException(status_code=500, detail="Failed to fetch achievements")
344
-
345
-
346
@router.post("/achievements/check")
async def check_achievement(
    achievement: AchievementCheck,
    request: Request,
    token: Optional[str] = Depends(optional_hf_token)
):
    """Record achievement progress and report whether it is now unlocked.

    Persistence is delegated to the learning service; responds 500 when the
    update cannot be saved.
    """
    try:
        user_id = token or 'anonymous'
        saved = learning_service.unlock_achievement(
            user_id,
            achievement.achievement_id,
            achievement.progress,
            achievement.target
        )
        if not saved:
            raise HTTPException(status_code=500, detail="Failed to update achievement")

        return {
            "success": True,
            "unlocked": achievement.progress >= achievement.target,
            "achievement_id": achievement.achievement_id
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error checking achievement: {e}")
        raise HTTPException(status_code=500, detail="Failed to check achievement")
382
-
383
-
384
- # ==================== Statistics Endpoints ====================
385
-
386
@router.get("/stats")
async def get_user_stats(request: Request, token: Optional[str] = Depends(optional_hf_token)):
    """Return aggregated learning statistics (overall and daily) for the caller."""
    try:
        user_id = token or 'anonymous'
        record = learning_service.get_user_progress(user_id)
        if not record:
            raise HTTPException(status_code=500, detail="Failed to load user progress")

        response = {"success": True}
        response["stats"] = record.get('overall_stats', {})
        response["daily_stats"] = record.get('daily_stats', {})
        return response
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error fetching user stats: {e}")
        raise HTTPException(status_code=500, detail="Failed to fetch user stats")
410
-
411
-
412
- # ==================== TTS and ASR Endpoints ====================
413
-
414
class TTSRequest(BaseModel):
    """Request payload for generating TTS audio for a piece of lesson text."""
    text: str
    language: str
    messageId: Optional[str] = None  # used only to name the returned WAV file
418
-
419
-
420
@router.post("/tts/generate")
async def generate_tts(
    tts_request: TTSRequest,
    request: Request
):
    """
    Generate TTS audio for lesson text and return it as an inline WAV response.

    Raises:
        HTTPException: 500 on any failure, including empty synthesis output.
    """
    try:
        # Imported lazily — presumably to avoid a circular import with app.main; confirm.
        from app.main import tts_service

        # Generate TTS audio
        audio_data = await tts_service.generate_speech(
            text=tts_request.text,
            language_code=tts_request.language
        )

        if not audio_data:
            raise HTTPException(status_code=500, detail="Failed to generate TTS audio")

        # Return audio as WAV file
        return Response(
            content=audio_data,
            media_type="audio/wav",
            headers={
                "Content-Disposition": f"inline; filename=tts_{tts_request.messageId or 'audio'}.wav"
            }
        )
    except Exception as e:
        # NOTE(review): the broad except also catches the HTTPException raised
        # above and re-wraps it into a new 500 with the str() of the original.
        logger.error(f"Error generating TTS: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to generate TTS: {str(e)}")
451
-
452
-
453
- @router.post("/transcribe")
454
- async def transcribe_audio(
455
- request: Request,
456
- audio: UploadFile = File(...)
457
- ):
458
- """
459
- Transcribe audio for pronunciation practice
460
- """
461
- try:
462
- from app.main import transcription_service
463
-
464
- # Read audio file
465
- audio_bytes = await audio.read()
466
-
467
- # Get language from form data (default to Swahili)
468
- form = await request.form()
469
- language = form.get('language', 'swa')
470
-
471
- # Transcribe
472
- text = await transcription_service.transcribe_audio(
473
- audio_data=audio_bytes,
474
- language_code=language
475
- )
476
-
477
- return {
478
- "success": True,
479
- "text": text,
480
- "language": language
481
- }
482
- except Exception as e:
483
- logger.error(f"Error transcribing audio: {e}")
484
- raise HTTPException(status_code=500, detail=f"Failed to transcribe: {str(e)}")
485
-
486
-
487
- # ==================== Phase 1-3 Endpoints ====================
488
-
489
- # Vocabulary Management
490
-
491
- class VocabularyAddRequest(BaseModel):
492
- vocab_id: int
493
- source_lesson_id: Optional[int] = None
494
-
495
-
496
- class VocabularyReviewRequest(BaseModel):
497
- vocab_id: int
498
- rating: str # 'again', 'hard', 'good', 'easy'
499
-
500
-
501
- @router.get("/vocabulary/due")
502
- async def get_due_vocabulary(
503
- request: Request,
504
- token: Optional[str] = Depends(optional_hf_token)
505
- ):
506
- """Get vocabulary words due for FSRS review"""
507
- try:
508
- user_id = token if token else 'anonymous'
509
- progress = learning_service.get_user_progress(user_id)
510
-
511
- if not progress:
512
- return {"due_words": [], "total_due": 0}
513
-
514
- vocab_progress = progress.get('vocabulary_progress', {})
515
- now = datetime.utcnow()
516
- due_words = []
517
-
518
- for vocab_id, vocab_data in vocab_progress.items():
519
- next_review_str = vocab_data.get('fsrs', {}).get('next_review')
520
- if not next_review_str:
521
- continue
522
-
523
- next_review = datetime.fromisoformat(next_review_str.rstrip('Z'))
524
-
525
- if next_review <= now:
526
- hours_overdue = (now - next_review).total_seconds() / 3600
527
- vocab_data['priority'] = 1000 - hours_overdue
528
- due_words.append(vocab_data)
529
-
530
- due_words.sort(key=lambda x: x.get('priority', 0), reverse=True)
531
-
532
- return {
533
- "due_words": due_words,
534
- "total_due": len(due_words),
535
- "timestamp": now.isoformat() + 'Z'
536
- }
537
- except Exception as e:
538
- logger.error(f"Error getting due vocabulary: {e}")
539
- raise HTTPException(status_code=500, detail="Failed to get due vocabulary")
540
-
541
-
542
- @router.post("/vocabulary/add")
543
- async def add_vocabulary_to_practice(
544
- vocab_request: VocabularyAddRequest,
545
- request: Request,
546
- token: Optional[str] = Depends(optional_hf_token)
547
- ):
548
- """Add a vocabulary word to user's practice queue with FSRS initialization"""
549
- try:
550
- user_id = token if token else 'anonymous'
551
-
552
- vocab = learning_service.get_vocabulary(vocab_request.vocab_id)
553
- if not vocab:
554
- raise HTTPException(status_code=404, detail="Vocabulary not found")
555
-
556
- fsrs_data = {
557
- 'difficulty': 0.3,
558
- 'stability': 2.5,
559
- 'retrievability': 1.0,
560
- 'review_count': 0,
561
- 'last_review': None,
562
- 'next_review': datetime.utcnow().isoformat() + 'Z',
563
- 'lapses': 0,
564
- 'state': 'new'
565
- }
566
-
567
- user_vocab = {
568
- 'vocabulary_id': vocab_request.vocab_id,
569
- 'swahili': vocab.get('swahili', ''),
570
- 'english': vocab.get('english', ''),
571
- 'part_of_speech': vocab.get('part_of_speech', 'unknown'),
572
- 'added_at': datetime.utcnow().isoformat() + 'Z',
573
- 'added_from': vocab_request.source_lesson_id,
574
- 'fsrs': fsrs_data,
575
- 'mastery_level': 0,
576
- 'times_reviewed': 0,
577
- 'times_correct': 0,
578
- 'accuracy': 0.0
579
- }
580
-
581
- success = learning_service.update_vocabulary_progress(
582
- user_id, str(vocab_request.vocab_id), user_vocab
583
- )
584
-
585
- if success:
586
- return {"success": True, "vocabulary": user_vocab}
587
- else:
588
- raise HTTPException(status_code=500, detail="Failed to add vocabulary")
589
- except HTTPException:
590
- raise
591
- except Exception as e:
592
- logger.error(f"Error adding vocabulary: {e}")
593
- raise HTTPException(status_code=500, detail="Failed to add vocabulary")
594
-
595
-
596
- def calculate_next_review_fsrs(fsrs: Dict, grade: int) -> Dict:
597
- """Implement FSRS algorithm"""
598
- from datetime import timedelta
599
-
600
- difficulty = fsrs['difficulty']
601
- stability = fsrs['stability']
602
-
603
- if grade == 0:
604
- new_difficulty = min(difficulty + 0.2, 1.0)
605
- elif grade == 2:
606
- new_difficulty = min(difficulty + 0.1, 1.0)
607
- elif grade == 4:
608
- new_difficulty = max(difficulty - 0.1, 0.0)
609
- else:
610
- new_difficulty = difficulty
611
-
612
- if grade == 0:
613
- new_stability = stability * 0.5
614
- state = 'relearning'
615
- interval_minutes = 10
616
- elif grade == 2:
617
- new_stability = stability * 1.2
618
- state = 'review'
619
- interval_minutes = int(new_stability * 24 * 60)
620
- elif grade == 3:
621
- new_stability = stability * 2.5
622
- state = 'review'
623
- interval_minutes = int(new_stability * 24 * 60)
624
- else:
625
- new_stability = stability * 4.0
626
- state = 'review'
627
- interval_minutes = int(new_stability * 24 * 60)
628
-
629
- next_review = datetime.utcnow() + timedelta(minutes=interval_minutes)
630
-
631
- return {
632
- 'difficulty': new_difficulty,
633
- 'stability': new_stability,
634
- 'retrievability': 0.9 if grade >= 2 else 0.0,
635
- 'review_count': fsrs['review_count'] + 1,
636
- 'last_review': datetime.utcnow().isoformat() + 'Z',
637
- 'next_review': next_review.isoformat() + 'Z',
638
- 'lapses': fsrs['lapses'],
639
- 'state': state,
640
- 'interval_days': interval_minutes / (24 * 60)
641
- }
642
-
643
-
644
- def calculate_mastery_level(vocab: Dict) -> int:
645
- """Calculate mastery level (0-5)"""
646
- accuracy = vocab['accuracy']
647
- reviews = vocab['times_reviewed']
648
- stability = vocab['fsrs']['stability']
649
-
650
- if reviews == 0:
651
- return 0
652
- elif reviews < 5 or accuracy < 70:
653
- return 1
654
- elif reviews < 10 or accuracy < 85:
655
- return 2
656
- elif reviews < 20 or accuracy < 95:
657
- return 3
658
- elif reviews >= 20 and accuracy >= 95 and stability >= 30:
659
- return 4
660
- elif reviews >= 40 and accuracy >= 98 and stability >= 90:
661
- return 5
662
- else:
663
- return 3
664
-
665
-
666
- @router.post("/vocabulary/review")
667
- async def record_vocabulary_review_fsrs(
668
- review_request: VocabularyReviewRequest,
669
- request: Request,
670
- token: Optional[str] = Depends(optional_hf_token)
671
- ):
672
- """Record vocabulary review and update FSRS parameters"""
673
- try:
674
- user_id = token if token else 'anonymous'
675
- progress = learning_service.get_user_progress(user_id)
676
-
677
- if not progress or str(review_request.vocab_id) not in progress.get('vocabulary_progress', {}):
678
- raise HTTPException(status_code=404, detail="Vocabulary not in practice queue")
679
-
680
- vocab = progress['vocabulary_progress'][str(review_request.vocab_id)]
681
- fsrs = vocab['fsrs']
682
-
683
- grade_map = {'again': 0, 'hard': 2, 'good': 3, 'easy': 4}
684
- grade = grade_map.get(review_request.rating, 3)
685
-
686
- new_fsrs = calculate_next_review_fsrs(fsrs, grade)
687
-
688
- vocab['fsrs'] = new_fsrs
689
- vocab['times_reviewed'] += 1
690
- if grade >= 2:
691
- vocab['times_correct'] += 1
692
- else:
693
- vocab['fsrs']['lapses'] += 1
694
-
695
- vocab['accuracy'] = (vocab['times_correct'] / vocab['times_reviewed']) * 100 if vocab['times_reviewed'] > 0 else 0
696
- vocab['mastery_level'] = calculate_mastery_level(vocab)
697
- vocab['last_reviewed_at'] = datetime.utcnow().isoformat() + 'Z'
698
-
699
- if 'vocabulary_reviewed' not in progress['overall_stats']:
700
- progress['overall_stats']['vocabulary_reviewed'] = 0
701
- progress['overall_stats']['vocabulary_reviewed'] += 1
702
-
703
- learning_service.save_user_progress(user_id, progress)
704
-
705
- return {
706
- "success": True,
707
- "vocabulary": vocab,
708
- "next_review": new_fsrs['next_review'],
709
- "interval_days": new_fsrs['interval_days']
710
- }
711
- except HTTPException:
712
- raise
713
- except Exception as e:
714
- logger.error(f"Error recording vocabulary review: {e}")
715
- raise HTTPException(status_code=500, detail="Failed to record review")
716
-
717
-
718
- @router.get("/vocabulary/stats")
719
- async def get_vocabulary_stats(
720
- request: Request,
721
- token: Optional[str] = Depends(optional_hf_token)
722
- ):
723
- """Get vocabulary mastery statistics"""
724
- try:
725
- user_id = token if token else 'anonymous'
726
- progress = learning_service.get_user_progress(user_id)
727
-
728
- if not progress:
729
- return {
730
- "total_words": 0,
731
- "in_practice": 0,
732
- "mastery_breakdown": {str(i): 0 for i in range(6)},
733
- "average_accuracy": 0
734
- }
735
-
736
- vocab_progress = progress.get('vocabulary_progress', {})
737
- mastery_breakdown = {str(i): 0 for i in range(6)}
738
- total_accuracy = 0
739
- total_with_reviews = 0
740
-
741
- for vocab_data in vocab_progress.values():
742
- level = vocab_data.get('mastery_level', 0)
743
- mastery_breakdown[str(level)] += 1
744
-
745
- if vocab_data.get('times_reviewed', 0) > 0:
746
- total_accuracy += vocab_data.get('accuracy', 0)
747
- total_with_reviews += 1
748
-
749
- avg_accuracy = total_accuracy / total_with_reviews if total_with_reviews > 0 else 0
750
-
751
- return {
752
- "total_words": len(vocab_progress),
753
- "in_practice": len(vocab_progress),
754
- "mastery_breakdown": mastery_breakdown,
755
- "average_accuracy": round(avg_accuracy, 1),
756
- "total_reviews": sum(v.get('times_reviewed', 0) for v in vocab_progress.values())
757
- }
758
- except Exception as e:
759
- logger.error(f"Error getting vocabulary stats: {e}")
760
- raise HTTPException(status_code=500, detail="Failed to get stats")
761
-
762
-
763
- @router.get("/vocabulary/library")
764
- async def get_vocabulary_library(
765
- lesson_id: Optional[int] = None,
766
- level: Optional[str] = None,
767
- search: Optional[str] = None,
768
- request: Request = None,
769
- token: Optional[str] = Depends(optional_hf_token)
770
- ):
771
- """Browse all vocabulary with filters"""
772
- try:
773
- user_id = token if token else 'anonymous'
774
-
775
- all_vocab = learning_service.get_all_vocabulary()
776
- progress = learning_service.get_user_progress(user_id)
777
- user_vocab = progress.get('vocabulary_progress', {}) if progress else {}
778
-
779
- filtered_vocab = all_vocab
780
-
781
- if lesson_id:
782
- filtered_vocab = [v for v in filtered_vocab if v.get('lesson_id') == lesson_id]
783
-
784
- if level:
785
- filtered_vocab = [v for v in filtered_vocab if v.get('level') == level]
786
-
787
- if search:
788
- search_lower = search.lower()
789
- filtered_vocab = [v for v in filtered_vocab
790
- if search_lower in v.get('swahili', '').lower()
791
- or search_lower in v.get('english', '').lower()]
792
-
793
- for vocab in filtered_vocab:
794
- vocab_id = str(vocab.get('vocabulary_id') or vocab.get('id'))
795
- if vocab_id in user_vocab:
796
- vocab['status'] = 'practicing'
797
- vocab['mastery_level'] = user_vocab[vocab_id].get('mastery_level', 0)
798
- vocab['accuracy'] = user_vocab[vocab_id].get('accuracy', 0)
799
- vocab['next_review'] = user_vocab[vocab_id].get('fsrs', {}).get('next_review')
800
- else:
801
- vocab['status'] = 'not_practicing'
802
- vocab['mastery_level'] = 0
803
-
804
- return {
805
- "vocabulary": filtered_vocab,
806
- "total": len(filtered_vocab),
807
- "filters_applied": {
808
- "lesson_id": lesson_id,
809
- "level": level,
810
- "search": search
811
- }
812
- }
813
- except Exception as e:
814
- logger.error(f"Error getting vocabulary library: {e}")
815
- raise HTTPException(status_code=500, detail="Failed to get vocabulary")
816
-
817
-
818
- # Reading Comprehension
819
-
820
- class ComprehensionAnswer(BaseModel):
821
- question_id: str
822
- answer: str
823
-
824
-
825
- class ComprehensionSubmission(BaseModel):
826
- lesson_id: int
827
- passage_id: str
828
- answers: List[ComprehensionAnswer]
829
-
830
-
831
- @router.post("/comprehension/submit")
832
- async def submit_comprehension_answers(
833
- submission: ComprehensionSubmission,
834
- request: Request,
835
- token: Optional[str] = Depends(optional_hf_token)
836
- ):
837
- """Submit reading comprehension answers and get scoring"""
838
- try:
839
- user_id = token if token else 'anonymous'
840
-
841
- lesson = learning_service.get_lesson(submission.lesson_id)
842
- if not lesson:
843
- raise HTTPException(status_code=404, detail="Lesson not found")
844
-
845
- passage = None
846
- for p in lesson.get('reading_passages', []):
847
- if p['passage_id'] == submission.passage_id:
848
- passage = p
849
- break
850
-
851
- if not passage:
852
- raise HTTPException(status_code=404, detail="Passage not found")
853
-
854
- results = []
855
- correct_count = 0
856
-
857
- for submitted in submission.answers:
858
- question_id = submitted.question_id
859
- user_answer = submitted.answer.strip().lower()
860
-
861
- question = None
862
- for q in passage['comprehension_questions']:
863
- if q['question_id'] == question_id:
864
- question = q
865
- break
866
-
867
- if not question:
868
- continue
869
-
870
- correct_answers = [ans.strip().lower() for ans in question.get('correct_answers', [])]
871
- is_correct = user_answer in correct_answers
872
-
873
- if is_correct:
874
- correct_count += 1
875
-
876
- results.append({
877
- "question_id": question_id,
878
- "correct": is_correct,
879
- "user_answer": user_answer,
880
- "correct_answer": question['correct_answers'][0] if correct_answers else None,
881
- "explanation": question.get('explanation')
882
- })
883
-
884
- score = (correct_count / len(submission.answers)) * 100 if submission.answers else 0
885
-
886
- progress = learning_service.get_user_progress(user_id)
887
- if not progress:
888
- progress = learning_service.create_default_progress(user_id)
889
-
890
- if 'comprehension_scores' not in progress:
891
- progress['comprehension_scores'] = {}
892
-
893
- progress['comprehension_scores'][f"{submission.lesson_id}_{submission.passage_id}"] = {
894
- "score": score,
895
- "completed_at": datetime.utcnow().isoformat() + 'Z',
896
- "attempts": progress['comprehension_scores'].get(f"{submission.lesson_id}_{submission.passage_id}", {}).get('attempts', 0) + 1
897
- }
898
-
899
- learning_service.save_user_progress(user_id, progress)
900
-
901
- return {
902
- "results": results,
903
- "score": round(score, 1),
904
- "correct": correct_count,
905
- "total": len(submission.answers)
906
- }
907
- except HTTPException:
908
- raise
909
- except Exception as e:
910
- logger.error(f"Error submitting comprehension: {e}")
911
- raise HTTPException(status_code=500, detail="Failed to submit comprehension")
912
-
913
-
914
- # Task Scenarios
915
-
916
- class ScenarioProgressUpdate(BaseModel):
917
- turn_id: str
918
- choice_id: str
919
-
920
-
921
- @router.get("/scenarios/{scenario_id}")
922
- async def get_scenario(
923
- scenario_id: str,
924
- request: Request,
925
- token: Optional[str] = Depends(optional_hf_token)
926
- ):
927
- """Get task scenario with branching dialogue"""
928
- try:
929
- user_id = token if token else 'anonymous'
930
-
931
- scenario = learning_service.get_scenario(scenario_id)
932
- if not scenario:
933
- raise HTTPException(status_code=404, detail="Scenario not found")
934
-
935
- progress = learning_service.get_user_progress(user_id)
936
- scenario_progress = None
937
-
938
- if progress and 'scenario_progress' in progress:
939
- scenario_progress = progress['scenario_progress'].get(scenario_id)
940
-
941
- return {
942
- "scenario": scenario,
943
- "user_progress": scenario_progress
944
- }
945
- except HTTPException:
946
- raise
947
- except Exception as e:
948
- logger.error(f"Error getting scenario: {e}")
949
- raise HTTPException(status_code=500, detail="Failed to get scenario")
950
-
951
-
952
- @router.post("/scenarios/{scenario_id}/progress")
953
- async def update_scenario_progress(
954
- scenario_id: str,
955
- progress_update: ScenarioProgressUpdate,
956
- request: Request,
957
- token: Optional[str] = Depends(optional_hf_token)
958
- ):
959
- """Update scenario progress with user choice"""
960
- try:
961
- user_id = token if token else 'anonymous'
962
-
963
- scenario = learning_service.get_scenario(scenario_id)
964
- if not scenario:
965
- raise HTTPException(status_code=404, detail="Scenario not found")
966
-
967
- progress = learning_service.get_user_progress(user_id)
968
- if not progress:
969
- progress = learning_service.create_default_progress(user_id)
970
-
971
- if 'scenario_progress' not in progress:
972
- progress['scenario_progress'] = {}
973
-
974
- if scenario_id not in progress['scenario_progress']:
975
- progress['scenario_progress'][scenario_id] = {
976
- "started_at": datetime.utcnow().isoformat() + 'Z',
977
- "turns": [],
978
- "completed": False
979
- }
980
-
981
- progress['scenario_progress'][scenario_id]['turns'].append({
982
- "turn_id": progress_update.turn_id,
983
- "choice_id": progress_update.choice_id,
984
- "timestamp": datetime.utcnow().isoformat() + 'Z'
985
- })
986
-
987
- turns_count = len(progress['scenario_progress'][scenario_id]['turns'])
988
- if turns_count >= scenario.get('required_turns', 6):
989
- progress['scenario_progress'][scenario_id]['completed'] = True
990
- progress['scenario_progress'][scenario_id]['completed_at'] = datetime.utcnow().isoformat() + 'Z'
991
-
992
- learning_service.save_user_progress(user_id, progress)
993
-
994
- return {
995
- "success": True,
996
- "progress": progress['scenario_progress'][scenario_id]
997
- }
998
- except HTTPException:
999
- raise
1000
- except Exception as e:
1001
- logger.error(f"Error updating scenario progress: {e}")
1002
- raise HTTPException(status_code=500, detail="Failed to update scenario")
1003
-
1004
-
1005
- @router.get("/scenarios")
1006
- async def list_scenarios(
1007
- request: Request,
1008
- token: Optional[str] = Depends(optional_hf_token)
1009
- ):
1010
- """Get list of all available scenarios"""
1011
- try:
1012
- scenarios = learning_service.get_all_scenarios()
1013
- return {
1014
- "success": True,
1015
- "scenarios": scenarios,
1016
- "total": len(scenarios)
1017
- }
1018
- except Exception as e:
1019
- logger.error(f"Error listing scenarios: {e}")
1020
- raise HTTPException(status_code=500, detail="Failed to list scenarios")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/routers/mobile.py DELETED
@@ -1,536 +0,0 @@
1
- from fastapi import APIRouter, HTTPException, File, UploadFile, Form, Query, Depends
2
- from fastapi.responses import Response
3
- from typing import Optional
4
- from pydantic import BaseModel
5
- import base64
6
- import json
7
- import uuid
8
- import datetime
9
- from app.services.translation_service import TranslationService
10
- from app.services.tts_service import TTSService
11
- from app.services.transcription_service import TranscriptionService
12
- from app.auth import require_hf_token
13
-
14
- router = APIRouter()
15
-
16
- # Service instances - these will be injected by main app
17
- translation_service = None
18
- tts_service = None
19
- transcription_service = None
20
-
21
- # Mobile-specific data models
22
- class MobileSessionRequest(BaseModel):
23
- user_name: str
24
- default_source_lang: str = "eng"
25
- default_target_lang: str = "swa"
26
-
27
- class MobileSessionResponse(BaseModel):
28
- session_id: str
29
- participant_id: str
30
- user_name: str
31
- source_language: str
32
- target_language: str
33
-
34
- class MobileTranscribeRequest(BaseModel):
35
- participant_id: str
36
- source_language: str
37
- target_language: str
38
- is_final_chunk: bool = False
39
-
40
- class MobileLanguageUpdateRequest(BaseModel):
41
- participant_id: str
42
- source_language: str
43
- target_language: str
44
-
45
- # In-memory session storage (in production, use Redis or database)
46
- mobile_sessions = {}
47
-
48
- @router.post("/mobile/session/create", response_model=MobileSessionResponse)
49
- async def create_mobile_session(
50
- user_name: str = Form(...),
51
- default_source_lang: str = Form("eng"),
52
- default_target_lang: str = Form("swa"),
53
- token: str = Depends(require_hf_token)
54
- ):
55
- """Create a mobile-specific single-user session"""
56
- try:
57
- print(f"=== MOBILE SESSION CREATE REQUEST ===")
58
- print(f"User name: {user_name}")
59
- print(f"Source language: {default_source_lang}")
60
- print(f"Target language: {default_target_lang}")
61
-
62
- # Validate inputs
63
- if not user_name or user_name.strip() == "":
64
- raise HTTPException(status_code=400, detail="User name is required")
65
-
66
- # Validate language codes
67
- valid_languages = ["eng", "swa", "kik", "kam", "mer", "luo", "som"]
68
- if default_source_lang not in valid_languages:
69
- print(f"Invalid source language: {default_source_lang}, defaulting to 'eng'")
70
- default_source_lang = "eng"
71
- if default_target_lang not in valid_languages:
72
- print(f"Invalid target language: {default_target_lang}, defaulting to 'swa'")
73
- default_target_lang = "swa"
74
-
75
- session_id = f"mobile-{uuid.uuid4().hex[:8]}"
76
- participant_id = f"user-{uuid.uuid4().hex[:8]}"
77
-
78
- # Store session data
79
- mobile_sessions[session_id] = {
80
- "session_id": session_id,
81
- "participant_id": participant_id,
82
- "user_name": user_name.strip(),
83
- "source_language": default_source_lang,
84
- "target_language": default_target_lang,
85
- "created_at": datetime.datetime.now().isoformat()
86
- }
87
-
88
- print(f"Created session: {session_id} for user: {user_name}")
89
- print(f"Total sessions: {len(mobile_sessions)}")
90
-
91
- response = MobileSessionResponse(
92
- session_id=session_id,
93
- participant_id=participant_id,
94
- user_name=user_name.strip(),
95
- source_language=default_source_lang,
96
- target_language=default_target_lang
97
- )
98
-
99
- print(f"Returning response: {response}")
100
- return response
101
-
102
- except HTTPException:
103
- raise
104
- except Exception as e:
105
- print(f"ERROR creating mobile session: {e}")
106
- import traceback
107
- traceback.print_exc()
108
- raise HTTPException(status_code=500, detail=f"Failed to create mobile session: {str(e)}")
109
-
110
- @router.get("/mobile/session/{session_id}")
111
- async def get_mobile_session(session_id: str, token: str = Depends(require_hf_token)):
112
- """Get mobile session details"""
113
- if session_id not in mobile_sessions:
114
- raise HTTPException(status_code=404, detail="Session not found")
115
-
116
- return mobile_sessions[session_id]
117
-
118
- @router.put("/mobile/session/{session_id}/languages")
119
- async def update_session_languages(
120
- session_id: str,
121
- participant_id: str = Form(...),
122
- source_language: str = Form(...),
123
- target_language: str = Form(...)
124
- ):
125
- """Update the default languages for a mobile session"""
126
- try:
127
- if session_id not in mobile_sessions:
128
- raise HTTPException(status_code=404, detail="Session not found")
129
-
130
- session = mobile_sessions[session_id]
131
- if session["participant_id"] != participant_id:
132
- raise HTTPException(status_code=403, detail="Invalid participant")
133
-
134
- # Update session languages
135
- session["source_language"] = source_language
136
- session["target_language"] = target_language
137
-
138
- return {
139
- "success": True,
140
- "session_id": session_id,
141
- "source_language": source_language,
142
- "target_language": target_language
143
- }
144
-
145
- except Exception as e:
146
- raise HTTPException(status_code=500, detail=f"Failed to update languages: {str(e)}")
147
-
148
- @router.post("/mobile/session/{session_id}/transcribe-realtime")
149
- async def transcribe_realtime(
150
- session_id: str,
151
- audio: UploadFile = File(...),
152
- participant_id: str = Form(...),
153
- source_language: str = Form(...),
154
- target_language: str = Form(...),
155
- is_final_chunk: bool = Form(False),
156
- chunk_sequence: int = Form(0)
157
- ):
158
- """Real-time transcription endpoint for mobile with streaming support"""
159
- try:
160
- if session_id not in mobile_sessions:
161
- raise HTTPException(status_code=404, detail="Session not found")
162
-
163
- session = mobile_sessions[session_id]
164
- if session["participant_id"] != participant_id:
165
- raise HTTPException(status_code=403, detail="Invalid participant")
166
-
167
- # Read audio file
168
- audio_data = await audio.read()
169
-
170
- # Generate unique message ID for this chunk sequence
171
- message_id = f"msg-{participant_id}-{chunk_sequence}"
172
-
173
- # Initialize response data
174
- response_data = {
175
- "success": True,
176
- "message_id": message_id,
177
- "chunk_sequence": chunk_sequence,
178
- "original_text": "",
179
- "original_language": source_language,
180
- "is_final_chunk": is_final_chunk,
181
- "is_interim": not is_final_chunk,
182
- "session_id": session_id,
183
- "translated_text": None,
184
- "target_language": target_language,
185
- "has_audio": False,
186
- "audio_base64": None,
187
- "audio_format": None
188
- }
189
-
190
- # Process transcription
191
- if transcription_service:
192
- try:
193
- # Use streaming transcription if available
194
- if hasattr(transcription_service, 'process_realtime_chunk'):
195
- transcription_result = await transcription_service.process_realtime_chunk(
196
- audio_data, source_language, participant_id, is_final_chunk
197
- )
198
- else:
199
- # Fallback to regular transcription
200
- transcription_result = await transcription_service.transcribe_audio(
201
- audio_data, source_language
202
- )
203
-
204
- response_data["original_text"] = transcription_result or ""
205
-
206
- # Only process translation and TTS for final chunks with actual text
207
- if is_final_chunk and transcription_result and transcription_result.strip():
208
- if translation_service:
209
- try:
210
- translated_text = await translation_service.translate_text(
211
- transcription_result, source_language, target_language
212
- )
213
- response_data["translated_text"] = translated_text
214
-
215
- # Generate TTS audio in target language
216
- if tts_service and translated_text:
217
- try:
218
- tts_audio = await tts_service.generate_speech(
219
- translated_text, target_language, output_format="wav"
220
- )
221
-
222
- if tts_audio:
223
- response_data.update({
224
- "has_audio": True,
225
- "audio_base64": base64.b64encode(tts_audio).decode('utf-8'),
226
- "audio_format": "wav"
227
- })
228
- except Exception as tts_error:
229
- print(f"TTS generation failed: {tts_error}")
230
- # Continue without TTS
231
-
232
- except Exception as translation_error:
233
- print(f"Translation failed: {translation_error}")
234
- # Continue without translation
235
-
236
- except Exception as transcription_error:
237
- print(f"Transcription failed: {transcription_error}")
238
- response_data["original_text"] = ""
239
-
240
- return response_data
241
-
242
- else:
243
- raise HTTPException(status_code=500, detail="Transcription service not available")
244
-
245
- except Exception as e:
246
- raise HTTPException(status_code=500, detail=f"Real-time transcription failed: {str(e)}")
247
-
248
- @router.post("/mobile/session/{session_id}/stream-audio")
249
- async def stream_audio_chunk(
250
- session_id: str,
251
- participant_id: str = Form(...),
252
- audio_chunk: UploadFile = File(...),
253
- source_language: str = Form(...),
254
- target_language: str = Form(...),
255
- chunk_index: int = Form(0),
256
- is_speaking: bool = Form(True),
257
- force_complete: bool = Form(False)
258
- ):
259
- """Stream audio chunks for continuous processing"""
260
- try:
261
- if session_id not in mobile_sessions:
262
- raise HTTPException(status_code=404, detail="Session not found")
263
-
264
- session = mobile_sessions[session_id]
265
- if session["participant_id"] != participant_id:
266
- raise HTTPException(status_code=403, detail="Invalid participant")
267
-
268
- audio_data = await audio_chunk.read()
269
-
270
- # Use streaming approach similar to WebSocket
271
- interim_text = ""
272
- if transcription_service:
273
- try:
274
- if hasattr(transcription_service, 'process_audio_chunk'):
275
- result = await transcription_service.process_audio_chunk(
276
- audio_data,
277
- source_language,
278
- participant_id,
279
- has_voice_activity=is_speaking,
280
- progress_callback=None, # No callback for HTTP
281
- sentence_callback=None # No callback for HTTP
282
- )
283
- interim_text = result or ""
284
- else:
285
- # Fallback to regular transcription for interim results
286
- interim_text = await transcription_service.transcribe_audio(
287
- audio_data, source_language
288
- ) or ""
289
- except Exception as e:
290
- print(f"Streaming transcription error: {e}")
291
- interim_text = ""
292
-
293
- return {
294
- "success": True,
295
- "chunk_index": chunk_index,
296
- "session_id": session_id,
297
- "interim_text": interim_text,
298
- "is_speaking": is_speaking,
299
- "force_complete": force_complete
300
- }
301
- else:
302
- raise HTTPException(status_code=500, detail="Transcription service not available")
303
-
304
- except Exception as e:
305
- raise HTTPException(status_code=500, detail=f"Audio streaming failed: {str(e)}")
306
-
307
- @router.get("/mobile/session/{session_id}/realtime-status")
308
- async def get_realtime_status(session_id: str, participant_id: str = Query(...)):
309
- """Get current real-time processing status"""
310
- try:
311
- if session_id not in mobile_sessions:
312
- raise HTTPException(status_code=404, detail="Session not found")
313
-
314
- session = mobile_sessions[session_id]
315
- if session["participant_id"] != participant_id:
316
- raise HTTPException(status_code=403, detail="Invalid participant")
317
-
318
- # Check if transcription service has any pending messages
319
- pending_messages = []
320
- if transcription_service:
321
- try:
322
- if hasattr(transcription_service, 'get_participant_status'):
323
- pending_messages = transcription_service.get_participant_status(participant_id)
324
- else:
325
- pending_messages = []
326
- except Exception as e:
327
- print(f"Error getting participant status: {e}")
328
- pending_messages = []
329
-
330
- return {
331
- "session_id": session_id,
332
- "participant_id": participant_id,
333
- "is_active": True,
334
- "pending_messages": pending_messages,
335
- "current_languages": {
336
- "source": session["source_language"],
337
- "target": session["target_language"]
338
- },
339
- "service_status": {
340
- "transcription": transcription_service is not None,
341
- "translation": translation_service is not None,
342
- "tts": tts_service is not None
343
- }
344
- }
345
-
346
- except Exception as e:
347
- raise HTTPException(status_code=500, detail=f"Status check failed: {str(e)}")
348
-
349
- @router.post("/mobile/session/{session_id}/transcribe-with-languages")
350
- async def transcribe_with_languages_legacy(
351
- session_id: str,
352
- audio: UploadFile = File(...),
353
- participant_id: str = Form(...),
354
- source_language: str = Form(...),
355
- target_language: str = Form(...),
356
- is_final_chunk: bool = Form(False)
357
- ):
358
- """Legacy endpoint - transcribe audio with specific source/target languages for mobile"""
359
- try:
360
- if session_id not in mobile_sessions:
361
- raise HTTPException(status_code=404, detail="Session not found")
362
-
363
- session = mobile_sessions[session_id]
364
- if session["participant_id"] != participant_id:
365
- raise HTTPException(status_code=403, detail="Invalid participant")
366
-
367
- # Read audio file
368
- audio_data = await audio.read()
369
-
370
- # Generate unique message ID
371
- message_id = f"msg-{uuid.uuid4().hex[:8]}"
372
-
373
- # Initialize response
374
- response_data = {
375
- "success": True,
376
- "message_id": message_id,
377
- "original_text": "",
378
- "original_language": source_language,
379
- "translated_text": None,
380
- "target_language": target_language,
381
- "has_audio": False,
382
- "is_final_chunk": is_final_chunk,
383
- "audio_base64": None
384
- }
385
-
386
- # Process transcription in source language
387
- if transcription_service:
388
- try:
389
- transcription_result = await transcription_service.transcribe_audio(
390
- audio_data, source_language
391
- )
392
- response_data["original_text"] = transcription_result or ""
393
-
394
- # Process translation to target language
395
- if translation_service and transcription_result and transcription_result.strip():
396
- try:
397
- translated_text = await translation_service.translate_text(
398
- transcription_result, source_language, target_language
399
- )
400
- response_data["translated_text"] = translated_text
401
-
402
- # Generate TTS audio in target language
403
- if tts_service and translated_text:
404
- try:
405
- tts_audio = await tts_service.generate_speech(
406
- translated_text, target_language, output_format="wav"
407
- )
408
-
409
- if tts_audio:
410
- response_data["has_audio"] = True
411
- response_data["audio_base64"] = base64.b64encode(tts_audio).decode('utf-8')
412
- except Exception as tts_error:
413
- print(f"TTS generation failed: {tts_error}")
414
-
415
- except Exception as translation_error:
416
- print(f"Translation failed: {translation_error}")
417
-
418
- except Exception as transcription_error:
419
- print(f"Transcription failed: {transcription_error}")
420
-
421
- return response_data
422
- else:
423
- raise HTTPException(status_code=500, detail="Transcription service not available")
424
-
425
- except Exception as e:
426
- raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
427
-
428
- @router.post("/mobile/translate")
429
- async def translate_text_mobile(
430
- text: str = Form(...),
431
- source_lang: str = Form(...),
432
- target_lang: str = Form(...)
433
- ):
434
- """Mobile-friendly text translation endpoint"""
435
- try:
436
- if not translation_service:
437
- raise HTTPException(status_code=500, detail="Translation service not initialized")
438
-
439
- # Map common language codes to internal format
440
- lang_mapping = {
441
- "english": "eng", "en": "eng",
442
- "swahili": "swa", "sw": "swa",
443
- "kikuyu": "kik", "ki": "kik",
444
- "kamba": "kam", "kam": "kam",
445
- "kimeru": "mer", "mer": "mer",
446
- "luo": "luo", "luo": "luo",
447
- "somali": "som", "so": "som"
448
- }
449
-
450
- source_code = lang_mapping.get(source_lang.lower(), source_lang.lower())
451
- target_code = lang_mapping.get(target_lang.lower(), target_lang.lower())
452
-
453
- translated_text = await translation_service.translate_text(text, source_code, target_code)
454
-
455
- return {
456
- "success": True,
457
- "original_text": text,
458
- "translated_text": translated_text or text,
459
- "source_language": source_code,
460
- "target_language": target_code
461
- }
462
-
463
- except Exception as e:
464
- raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}")
465
-
466
- @router.get("/mobile/languages")
467
- async def get_supported_languages():
468
- """Get list of supported languages for mobile app"""
469
- return {
470
- "supported_languages": [
471
- {"code": "eng", "name": "English", "display_name": "English (eng)"},
472
- {"code": "swa", "name": "Swahili", "display_name": "Swahili (swa)"},
473
- {"code": "kik", "name": "Kikuyu", "display_name": "Kikuyu (kik)"},
474
- {"code": "kam", "name": "Kamba", "display_name": "Kamba (kam)"},
475
- {"code": "mer", "name": "Kimeru", "display_name": "Kimeru (mer)"},
476
- {"code": "luo", "name": "Luo", "display_name": "Luo (luo)"},
477
- {"code": "som", "name": "Somali", "display_name": "Somali (som)"}
478
- ]
479
- }
480
-
481
- @router.get("/mobile/test")
482
- async def test_mobile_endpoints():
483
- """Test endpoint for mobile app connectivity"""
484
- return {
485
- "status": "Mobile API is working",
486
- "endpoints": [
487
- "/mobile/session/create",
488
- "/mobile/session/{session_id}",
489
- "/mobile/session/{session_id}/languages",
490
- "/mobile/session/{session_id}/transcribe-realtime",
491
- "/mobile/session/{session_id}/stream-audio",
492
- "/mobile/session/{session_id}/realtime-status",
493
- "/mobile/session/{session_id}/transcribe-with-languages",
494
- "/mobile/translate",
495
- "/mobile/languages",
496
- "/mobile/test"
497
- ],
498
- "timestamp": datetime.datetime.now().isoformat(),
499
- "services_available": {
500
- "transcription": transcription_service is not None,
501
- "translation": translation_service is not None,
502
- "tts": tts_service is not None
503
- },
504
- "active_sessions": len(mobile_sessions),
505
- "session_list": list(mobile_sessions.keys())
506
- }
507
-
508
- @router.post("/mobile/test-session")
509
- async def test_session_creation(
510
- test_user: str = Form("TestUser"),
511
- test_source: str = Form("eng"),
512
- test_target: str = Form("swa")
513
- ):
514
- """Test session creation with debug info"""
515
- try:
516
- print(f"=== TEST SESSION CREATE ===")
517
- print(f"Received: user={test_user}, source={test_source}, target={test_target}")
518
-
519
- session_id = f"test-{uuid.uuid4().hex[:8]}"
520
-
521
- return {
522
- "success": True,
523
- "test_session_id": session_id,
524
- "received_params": {
525
- "user": test_user,
526
- "source": test_source,
527
- "target": test_target
528
- },
529
- "form_processing": "OK"
530
- }
531
- except Exception as e:
532
- print(f"Test session error: {e}")
533
- return {
534
- "success": False,
535
- "error": str(e)
536
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/routers/sessions.py DELETED
@@ -1,200 +0,0 @@
1
- from fastapi import APIRouter, HTTPException, Response, Depends
2
- from typing import List
3
- from pydantic import BaseModel
4
- import qrcode
5
- import io
6
- from app.models import Session, SessionCreate
7
- from app.auth import require_hf_token, optional_hf_token
8
-
9
- router = APIRouter()
10
-
11
- # This will be set by the main app
12
- session_manager = None
13
-
14
- # Initialize services (these will be injected by main app)
15
- transcription_service = None
16
- translation_service = None
17
- tts_service = None
18
-
19
- class TextTranslationRequest(BaseModel):
20
- text: str
21
- source_language: str
22
- target_language: str
23
-
24
- class TextTranslationResponse(BaseModel):
25
- original_text: str
26
- translated_text: str
27
- source_language: str
28
- target_language: str
29
-
30
- @router.post("/sessions", response_model=Session)
31
- async def create_session(session_data: SessionCreate, token: str = Depends(require_hf_token)):
32
- """Create a new transcription session"""
33
- try:
34
- session = await session_manager.create_session(session_data)
35
- return session
36
- except Exception as e:
37
- raise HTTPException(status_code=500, detail=str(e))
38
-
39
- @router.get("/sessions", response_model=List[Session])
40
- async def get_all_sessions(token: str = Depends(require_hf_token)):
41
- """Get all active sessions"""
42
- try:
43
- sessions = await session_manager.get_all_sessions()
44
- return sessions
45
- except Exception as e:
46
- raise HTTPException(status_code=500, detail=str(e))
47
-
48
- @router.get("/sessions/{session_id}", response_model=Session)
49
- async def get_session(session_id: str, token: str = Depends(require_hf_token)):
50
- """Get specific session by ID or short code"""
51
- session = await session_manager.get_session(session_id)
52
- if not session:
53
- raise HTTPException(status_code=404, detail="Session not found")
54
- return session
55
-
56
- @router.get("/sessions/{session_id}/short-code")
57
- async def get_session_short_code(session_id: str, token: str = Depends(require_hf_token)):
58
- """Get short code for a session"""
59
- session = await session_manager.get_session(session_id)
60
- if not session:
61
- raise HTTPException(status_code=404, detail="Session not found")
62
-
63
- short_code = session_manager.get_short_code(session.id)
64
- return {"session_id": session.id, "short_code": short_code}
65
-
66
- @router.delete("/sessions/{session_id}")
67
- async def delete_session(session_id: str, token: str = Depends(require_hf_token)):
68
- """Delete a session"""
69
- success = await session_manager.delete_session(session_id)
70
- if not success:
71
- raise HTTPException(status_code=404, detail="Session not found")
72
- return {"message": "Session deleted successfully"}
73
-
74
- @router.post("/sessions/{session_id}/languages/{language_code}")
75
- async def add_language_to_session(session_id: str, language_code: str, token: str = Depends(require_hf_token)):
76
- """Add a language to a session"""
77
- from app.models import LanguageCode
78
-
79
- # Convert string to LanguageCode enum
80
- try:
81
- lang_code_enum = LanguageCode(language_code)
82
- except ValueError:
83
- raise HTTPException(status_code=400, detail=f"Invalid language code: {language_code}")
84
-
85
- success = await session_manager.add_language_to_session(session_id, lang_code_enum)
86
- if success:
87
- session = await session_manager.get_session(session_id)
88
- return {"message": f"Language {language_code} added to session", "session": session}
89
- else:
90
- # Check if session exists
91
- session = await session_manager.get_session(session_id)
92
- if not session:
93
- raise HTTPException(status_code=404, detail="Session not found")
94
- return {"message": f"Language {language_code} already exists in session", "session": session}
95
-
96
- @router.post("/translate", response_model=TextTranslationResponse)
97
- async def translate_text(request: TextTranslationRequest, token: str = Depends(require_hf_token)):
98
- """Translate text from source language to target language"""
99
- try:
100
- # Map language codes to proper names
101
- lang_map = {
102
- 'eng': 'English',
103
- 'swa': 'Swahili',
104
- 'kik': 'Kikuyu',
105
- 'kam': 'Kamba',
106
- 'mer': 'Kimeru',
107
- 'luo': 'Luo',
108
- 'som': 'Somali'
109
- }
110
-
111
- source_lang_name = lang_map.get(request.source_language.lower(), request.source_language)
112
- target_lang_name = lang_map.get(request.target_language.lower(), request.target_language)
113
-
114
- # Perform translation
115
- translated_text = await translation_service.translate_text(
116
- text=request.text,
117
- source_lang=source_lang_name,
118
- target_lang=target_lang_name
119
- )
120
-
121
- return TextTranslationResponse(
122
- original_text=request.text,
123
- translated_text=translated_text,
124
- source_language=request.source_language,
125
- target_language=request.target_language
126
- )
127
-
128
- except Exception as e:
129
- raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}")
130
-
131
- @router.get("/test")
132
- async def test_endpoint(token: str = Depends(optional_hf_token)):
133
- """Test endpoint to verify API is working"""
134
- auth_status = "authenticated" if token else "public"
135
- return {
136
- "status": "API is working",
137
- "sessions_count": len(session_manager.sessions),
138
- "auth_status": auth_status
139
- }
140
-
141
- @router.get("/test/translation")
142
- async def test_translation(token: str = Depends(require_hf_token)):
143
- """Test translation service directly"""
144
- try:
145
- # Test English to Swahili translation
146
- result = await translation_service.translate_text("Hello, how are you?", "English", "Swahili")
147
-
148
- return {
149
- "status": "Translation test completed",
150
- "original": "Hello, how are you?",
151
- "translated": result,
152
- "source_lang": "English",
153
- "target_lang": "Swahili"
154
- }
155
- except Exception as e:
156
- return {"status": "Translation test failed", "error": str(e)}
157
-
158
- @router.get("/test/tts")
159
- async def test_tts(token: str = Depends(require_hf_token)):
160
- """Test TTS service directly"""
161
- try:
162
- # Test TTS generation
163
- audio_data = await tts_service.generate_speech("Hello world", "eng")
164
-
165
- return {
166
- "status": "TTS test completed",
167
- "text": "Hello world",
168
- "language": "eng",
169
- "audio_generated": audio_data is not None,
170
- "audio_size": len(audio_data) if audio_data else 0
171
- }
172
- except Exception as e:
173
- return {"status": "TTS test failed", "error": str(e)}
174
-
175
- @router.get("/sessions/{session_id}/qr-code")
176
- async def get_session_qr_code(session_id: str, token: str = Depends(require_hf_token)):
177
- """Generate QR code for session"""
178
- if session_manager is None:
179
- raise HTTPException(status_code=500, detail="Session manager not initialized")
180
-
181
- session = await session_manager.get_session(session_id)
182
-
183
- if not session:
184
- raise HTTPException(status_code=404, detail="Session not found")
185
-
186
- # Generate QR code with session join URL - use your HF space URL
187
- join_url = f"https://mutisya-realtime-translator-5-27-25-v2.hf.space/?join={session_id}"
188
-
189
- qr = qrcode.QRCode(version=1, box_size=10, border=5)
190
- qr.add_data(join_url)
191
- qr.make(fit=True)
192
-
193
- img = qr.make_image(fill_color="black", back_color="white")
194
-
195
- # Convert to bytes
196
- img_buffer = io.BytesIO()
197
- img.save(img_buffer, format='PNG')
198
- img_buffer.seek(0)
199
-
200
- return Response(content=img_buffer.getvalue(), media_type="image/png")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/routers/watch.py DELETED
@@ -1,152 +0,0 @@
1
- from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Depends
2
- from fastapi.responses import Response
3
- from typing import Optional
4
- import io
5
- import base64
6
- from app.services.transcription_service import TranscriptionService
7
- from app.services.translation_service import TranslationService
8
- from app.services.tts_service import TTSService
9
- from app.models import LanguageCode
10
- from pydantic import BaseModel
11
- from app.auth import require_hf_token
12
-
13
- router = APIRouter()
14
-
15
- class WatchTranslationRequest(BaseModel):
16
- source_language: str
17
- target_language: str
18
- audio_base64: str
19
-
20
- class WatchTranslationResponse(BaseModel):
21
- original_text: str
22
- original_language: str
23
- translated_text: str
24
- target_language: str
25
- translated_audio_base64: str
26
- success: bool
27
- error: Optional[str] = None
28
-
29
- # Initialize services (these will be injected by main app)
30
- transcription_service = None
31
- translation_service = None
32
- tts_service = None
33
-
34
- @router.post("/watch/translate", response_model=WatchTranslationResponse)
35
- async def watch_translate_audio(request: WatchTranslationRequest, token: str = Depends(require_hf_token)):
36
- """
37
- Process audio for watch app translation
38
- - Transcribe audio using source language model
39
- - Translate text to target language
40
- - Generate TTS audio for target language
41
- - Return all data to watch app
42
- """
43
- try:
44
- # Validate languages
45
- source_lang = request.source_language.lower()
46
- target_lang = request.target_language.lower()
47
-
48
- if source_lang not in ['eng', 'swa', 'kik', 'kam', 'mer', 'luo', 'som']:
49
- raise HTTPException(status_code=400, detail=f"Unsupported source language: {source_lang}")
50
-
51
- if target_lang not in ['eng', 'swa', 'kik', 'kam', 'mer', 'luo', 'som']:
52
- raise HTTPException(status_code=400, detail=f"Unsupported target language: {target_lang}")
53
-
54
- # Decode base64 audio
55
- try:
56
- audio_data = base64.b64decode(request.audio_base64)
57
- print(f"Decoded audio data: {len(audio_data)} bytes")
58
- except Exception as e:
59
- raise HTTPException(status_code=400, detail=f"Invalid base64 audio data: {str(e)}")
60
-
61
- # Step 1: Transcribe audio
62
- print(f"Transcribing audio with {source_lang} model...")
63
- transcribed_text = await transcription_service.transcribe_audio(audio_data, source_lang)
64
-
65
- if not transcribed_text or transcribed_text.strip() == "":
66
- return WatchTranslationResponse(
67
- original_text="",
68
- original_language=source_lang,
69
- translated_text="No speech detected",
70
- target_language=target_lang,
71
- translated_audio_base64="",
72
- success=False,
73
- error="No speech detected in audio"
74
- )
75
-
76
- print(f"Transcribed text: {transcribed_text}")
77
-
78
- # Step 2: Translate text (skip if source and target are the same)
79
- if source_lang == target_lang:
80
- translated_text = transcribed_text
81
- else:
82
- print(f"Translating from {source_lang} to {target_lang}...")
83
-
84
- # Convert language codes to full names for translation service
85
- lang_name_map = {
86
- 'eng': 'English',
87
- 'swa': 'Swahili',
88
- 'kik': 'Kikuyu',
89
- 'kam': 'Kamba',
90
- 'mer': 'Kimeru',
91
- 'luo': 'Luo',
92
- 'som': 'Somali'
93
- }
94
-
95
- source_lang_name = lang_name_map.get(source_lang, 'English')
96
- target_lang_name = lang_name_map.get(target_lang, 'Swahili')
97
-
98
- translated_text = await translation_service.translate_text(
99
- transcribed_text,
100
- source_lang_name,
101
- target_lang_name
102
- )
103
-
104
- print(f"Translated text: {translated_text}")
105
-
106
- # Step 3: Generate TTS audio for translated text (Android-compatible WAV format)
107
- print(f"Generating TTS audio for {target_lang} in WAV format for Android...")
108
- tts_audio_data = await tts_service.generate_speech(translated_text, target_lang, output_format="wav")
109
-
110
- # Encode TTS audio as base64
111
- tts_audio_base64 = ""
112
- if tts_audio_data:
113
- tts_audio_base64 = base64.b64encode(tts_audio_data).decode('utf-8')
114
- print(f"TTS audio generated: {len(tts_audio_data)} bytes, base64: {len(tts_audio_base64)} chars")
115
- else:
116
- print("TTS audio generation failed - no data returned")
117
-
118
- return WatchTranslationResponse(
119
- original_text=transcribed_text,
120
- original_language=source_lang,
121
- translated_text=translated_text,
122
- target_language=target_lang,
123
- translated_audio_base64=tts_audio_base64,
124
- success=True
125
- )
126
-
127
- except Exception as e:
128
- print(f"Error in watch translation: {str(e)}")
129
- import traceback
130
- traceback.print_exc()
131
-
132
- return WatchTranslationResponse(
133
- original_text="",
134
- original_language=request.source_language,
135
- translated_text="",
136
- target_language=request.target_language,
137
- translated_audio_base64="",
138
- success=False,
139
- error=str(e)
140
- )
141
-
142
- @router.get("/watch/test")
143
- async def test_watch_endpoint(token: str = Depends(require_hf_token)):
144
- """Test endpoint for watch app connectivity"""
145
- return {
146
- "status": "Watch API is working",
147
- "services": {
148
- "transcription": transcription_service is not None,
149
- "translation": translation_service is not None,
150
- "tts": tts_service is not None
151
- }
152
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/services/__init__.py DELETED
@@ -1 +0,0 @@
1
- # Services package
 
 
app/services/learning_data_service.py DELETED
@@ -1,415 +0,0 @@
1
- """
2
- Learning Data Service - File-based data access for language learning prototype
3
-
4
- This service provides access to lesson data, user progress, and achievements
5
- using JSON files stored in the backend/data/learning directory.
6
- """
7
-
8
- import json
9
- import os
10
- from pathlib import Path
11
- from typing import Dict, List, Optional, Any
12
- from datetime import datetime
13
- import logging
14
-
15
- logger = logging.getLogger(__name__)
16
-
17
-
18
- class LearningDataService:
19
- """Service for managing language learning data using JSON files"""
20
-
21
- def __init__(self):
22
- # Get the data directory relative to this file
23
- self.data_dir = Path(__file__).parent.parent.parent / "data" / "learning"
24
- self.lessons_dir = self.data_dir / "lessons"
25
- self.users_dir = self.data_dir / "users"
26
-
27
- # Ensure directories exist
28
- self.users_dir.mkdir(parents=True, exist_ok=True)
29
-
30
- logger.info(f"Learning data directory: {self.data_dir}")
31
- logger.info(f"Lessons directory: {self.lessons_dir}")
32
- logger.info(f"Users directory: {self.users_dir}")
33
-
34
- # ==================== Lesson Data ====================
35
-
36
- def get_lessons_index(self, language: str = 'swahili') -> Optional[Dict]:
37
- """Load the lessons index/catalog for a specific language"""
38
- try:
39
- # Map language codes to folder names
40
- language_map = {
41
- 'swahili': 'swahili',
42
- 'swa': 'swahili',
43
- 'kamba': 'kamba',
44
- 'kam': 'kamba',
45
- 'maasai': 'maasai',
46
- 'mas': 'maasai'
47
- }
48
-
49
- language_folder = language_map.get(language.lower(), 'swahili')
50
- index_path = self.lessons_dir / language_folder / "index.json"
51
-
52
- logger.info(f"Loading lessons index for language '{language}' -> folder '{language_folder}' at {index_path}")
53
-
54
- if not index_path.exists():
55
- logger.warning(f"Lessons index not found at {index_path}")
56
- logger.info(f"Lessons dir contents: {list(self.lessons_dir.iterdir())}")
57
- return None
58
-
59
- with open(index_path, 'r', encoding='utf-8') as f:
60
- data = json.load(f)
61
- logger.info(f"Successfully loaded {len(data.get('lessons', []))} lessons for {language}")
62
- return data
63
- except Exception as e:
64
- logger.error(f"Error loading lessons index for {language}: {e}")
65
- return None
66
-
67
- def get_lesson(self, lesson_id: int, language: str = 'swahili') -> Optional[Dict]:
68
- """Load a specific lesson by ID for a specific language"""
69
- try:
70
- # Map language codes to folder names
71
- language_map = {
72
- 'swahili': 'swahili',
73
- 'swa': 'swahili',
74
- 'kamba': 'kamba',
75
- 'kam': 'kamba',
76
- 'maasai': 'maasai',
77
- 'mas': 'maasai'
78
- }
79
-
80
- language_folder = language_map.get(language.lower(), 'swahili')
81
-
82
- # First get the index to find the lesson file
83
- index = self.get_lessons_index(language)
84
- if not index:
85
- return None
86
-
87
- # Find the lesson in the index
88
- lesson_meta = None
89
- for lesson in index.get('lessons', []):
90
- if lesson['lesson_id'] == lesson_id:
91
- lesson_meta = lesson
92
- break
93
-
94
- if not lesson_meta:
95
- logger.warning(f"Lesson {lesson_id} not found in index for {language}")
96
- return None
97
-
98
- # Load the lesson file
99
- lesson_path = self.lessons_dir / language_folder / lesson_meta['file']
100
- if not lesson_path.exists():
101
- logger.warning(f"Lesson file not found: {lesson_path}")
102
- return None
103
-
104
- with open(lesson_path, 'r', encoding='utf-8') as f:
105
- return json.load(f)
106
- except Exception as e:
107
- logger.error(f"Error loading lesson {lesson_id} for {language}: {e}")
108
- return None
109
-
110
- def get_available_lessons(self) -> List[Dict]:
111
- """Get list of available lessons (not planned)"""
112
- try:
113
- index = self.get_lessons_index()
114
- if not index:
115
- return []
116
-
117
- available = [
118
- lesson for lesson in index.get('lessons', [])
119
- if lesson.get('status') == 'available'
120
- ]
121
- return available
122
- except Exception as e:
123
- logger.error(f"Error getting available lessons: {e}")
124
- return []
125
-
126
- # ==================== Achievements ====================
127
-
128
- def get_achievements(self) -> Optional[Dict]:
129
- """Load achievements configuration"""
130
- try:
131
- achievements_path = self.data_dir / "achievements.json"
132
- if not achievements_path.exists():
133
- logger.warning(f"Achievements file not found at {achievements_path}")
134
- return None
135
-
136
- with open(achievements_path, 'r', encoding='utf-8') as f:
137
- return json.load(f)
138
- except Exception as e:
139
- logger.error(f"Error loading achievements: {e}")
140
- return None
141
-
142
- # ==================== User Progress ====================
143
-
144
- def get_user_progress(self, user_id: str) -> Optional[Dict]:
145
- """Load user progress data"""
146
- try:
147
- user_file = self.users_dir / f"user-{user_id}.json"
148
- if not user_file.exists():
149
- # Return default progress structure for new users
150
- return self._create_default_user_progress(user_id)
151
-
152
- with open(user_file, 'r', encoding='utf-8') as f:
153
- return json.load(f)
154
- except Exception as e:
155
- logger.error(f"Error loading user progress for {user_id}: {e}")
156
- return None
157
-
158
- def save_user_progress(self, user_id: str, progress_data: Dict) -> bool:
159
- """Save user progress data"""
160
- try:
161
- user_file = self.users_dir / f"user-{user_id}.json"
162
-
163
- # Update last_active timestamp
164
- if 'profile' in progress_data:
165
- progress_data['profile']['last_active'] = datetime.utcnow().isoformat() + 'Z'
166
-
167
- with open(user_file, 'w', encoding='utf-8') as f:
168
- json.dump(progress_data, f, indent=2, ensure_ascii=False)
169
-
170
- logger.info(f"Saved progress for user {user_id}")
171
- return True
172
- except Exception as e:
173
- logger.error(f"Error saving user progress for {user_id}: {e}")
174
- return False
175
-
176
- def update_lesson_progress(
177
- self,
178
- user_id: str,
179
- lesson_id: int,
180
- progress_update: Dict
181
- ) -> bool:
182
- """Update progress for a specific lesson"""
183
- try:
184
- user_progress = self.get_user_progress(user_id)
185
- if not user_progress:
186
- return False
187
-
188
- # Initialize lesson_progress if it doesn't exist
189
- if 'lesson_progress' not in user_progress:
190
- user_progress['lesson_progress'] = {}
191
-
192
- lesson_key = str(lesson_id)
193
-
194
- # Update or create lesson progress
195
- if lesson_key in user_progress['lesson_progress']:
196
- user_progress['lesson_progress'][lesson_key].update(progress_update)
197
- else:
198
- user_progress['lesson_progress'][lesson_key] = progress_update
199
-
200
- return self.save_user_progress(user_id, user_progress)
201
- except Exception as e:
202
- logger.error(f"Error updating lesson progress: {e}")
203
- return False
204
-
205
- def update_vocabulary_progress(
206
- self,
207
- user_id: str,
208
- vocab_id: int,
209
- vocab_update: Dict
210
- ) -> bool:
211
- """Update progress for a specific vocabulary word"""
212
- try:
213
- user_progress = self.get_user_progress(user_id)
214
- if not user_progress:
215
- return False
216
-
217
- # Initialize vocabulary_progress if it doesn't exist
218
- if 'vocabulary_progress' not in user_progress:
219
- user_progress['vocabulary_progress'] = {}
220
-
221
- vocab_key = str(vocab_id)
222
-
223
- # Update or create vocabulary progress
224
- if vocab_key in user_progress['vocabulary_progress']:
225
- user_progress['vocabulary_progress'][vocab_key].update(vocab_update)
226
- else:
227
- user_progress['vocabulary_progress'][vocab_key] = vocab_update
228
-
229
- return self.save_user_progress(user_id, user_progress)
230
- except Exception as e:
231
- logger.error(f"Error updating vocabulary progress: {e}")
232
- return False
233
-
234
- def unlock_achievement(
235
- self,
236
- user_id: str,
237
- achievement_id: str,
238
- progress: int,
239
- target: int
240
- ) -> bool:
241
- """Unlock or update progress on an achievement"""
242
- try:
243
- user_progress = self.get_user_progress(user_id)
244
- if not user_progress:
245
- return False
246
-
247
- # Initialize achievements if it doesn't exist
248
- if 'achievements' not in user_progress:
249
- user_progress['achievements'] = {}
250
-
251
- # Update achievement
252
- achievement_data = {
253
- 'achievement_id': achievement_id,
254
- 'unlocked': progress >= target,
255
- 'progress': progress,
256
- 'target': target
257
- }
258
-
259
- # Add unlock timestamp if newly unlocked
260
- if achievement_data['unlocked'] and achievement_id not in user_progress['achievements']:
261
- achievement_data['unlocked_at'] = datetime.utcnow().isoformat() + 'Z'
262
- elif achievement_data['unlocked'] and achievement_id in user_progress['achievements']:
263
- # Preserve original unlock time
264
- if 'unlocked_at' in user_progress['achievements'][achievement_id]:
265
- achievement_data['unlocked_at'] = user_progress['achievements'][achievement_id]['unlocked_at']
266
- else:
267
- achievement_data['unlocked_at'] = datetime.utcnow().isoformat() + 'Z'
268
-
269
- user_progress['achievements'][achievement_id] = achievement_data
270
-
271
- return self.save_user_progress(user_id, user_progress)
272
- except Exception as e:
273
- logger.error(f"Error unlocking achievement: {e}")
274
- return False
275
-
276
- # ==================== Helper Methods ====================
277
-
278
- def _create_default_user_progress(self, user_id: str) -> Dict:
279
- """Create default progress structure for a new user"""
280
- return {
281
- 'user_id': user_id,
282
- 'profile': {
283
- 'user_id': user_id,
284
- 'learning_language': 'swa',
285
- 'native_language': 'eng',
286
- 'created_at': datetime.utcnow().isoformat() + 'Z',
287
- 'last_active': datetime.utcnow().isoformat() + 'Z'
288
- },
289
- 'overall_stats': {
290
- 'level': 'beginner',
291
- 'total_xp': 0,
292
- 'next_level_xp': 1000,
293
- 'current_streak': 0,
294
- 'longest_streak': 0,
295
- 'lessons_completed': 0,
296
- 'vocabulary_learned': 0,
297
- 'vocabulary_mastered': 0,
298
- 'total_practice_time_seconds': 0,
299
- 'pronunciation_avg_score': 0.0,
300
- 'listening_avg_score': 0.0,
301
- 'reading_avg_score': 0.0
302
- },
303
- 'daily_stats': {},
304
- 'lesson_progress': {},
305
- 'vocabulary_progress': {},
306
- 'achievements': {},
307
- 'session_history': []
308
- }
309
-
310
- def create_default_progress(self, user_id: str) -> Dict:
311
- """Public method to create default progress structure"""
312
- progress = self._create_default_user_progress(user_id)
313
- # Add Phase 1-3 specific fields
314
- progress['overall_stats']['vocabulary_reviewed'] = 0
315
- progress['comprehension_scores'] = {}
316
- progress['scenario_progress'] = {}
317
- return progress
318
-
319
- # ==================== Phase 1-3 Methods ====================
320
-
321
- def get_vocabulary(self, vocab_id: int) -> Optional[Dict]:
322
- """Get a single vocabulary word by ID from any lesson"""
323
- try:
324
- lessons_index = self.get_lessons_index()
325
- if not lessons_index:
326
- return None
327
-
328
- # Search through all lessons
329
- for lesson_meta in lessons_index.get('lessons', []):
330
- lesson = self.get_lesson(lesson_meta['lesson_id'])
331
- if lesson and 'vocabulary' in lesson:
332
- for vocab in lesson['vocabulary']:
333
- # Support both 'id' and 'vocabulary_id' fields
334
- vocab_item_id = vocab.get('vocabulary_id') or vocab.get('id')
335
- if vocab_item_id == vocab_id:
336
- # Add lesson context
337
- vocab['lesson_id'] = lesson['lesson_id']
338
- vocab['lesson_title'] = lesson.get('title', '')
339
- return vocab
340
-
341
- logger.warning(f"Vocabulary {vocab_id} not found in any lesson")
342
- return None
343
- except Exception as e:
344
- logger.error(f"Error getting vocabulary {vocab_id}: {e}")
345
- return None
346
-
347
- def get_all_vocabulary(self) -> List[Dict]:
348
- """Get all vocabulary words from all lessons"""
349
- try:
350
- all_vocab = []
351
- lessons_index = self.get_lessons_index()
352
- if not lessons_index:
353
- return all_vocab
354
-
355
- for lesson_meta in lessons_index.get('lessons', []):
356
- lesson = self.get_lesson(lesson_meta['lesson_id'])
357
- if lesson and 'vocabulary' in lesson:
358
- for vocab in lesson['vocabulary']:
359
- # Add lesson context
360
- vocab_copy = vocab.copy()
361
- vocab_copy['lesson_id'] = lesson['lesson_id']
362
- vocab_copy['lesson_title'] = lesson.get('title', '')
363
- vocab_copy['lesson_level'] = lesson.get('difficulty_level', 1)
364
- all_vocab.append(vocab_copy)
365
-
366
- return all_vocab
367
- except Exception as e:
368
- logger.error(f"Error getting all vocabulary: {e}")
369
- return []
370
-
371
- def get_scenario(self, scenario_id: str) -> Optional[Dict]:
372
- """Load a task scenario by ID"""
373
- try:
374
- scenarios_dir = self.data_dir / "scenarios"
375
- scenario_path = scenarios_dir / f"{scenario_id}.json"
376
-
377
- if not scenario_path.exists():
378
- logger.warning(f"Scenario file not found: {scenario_path}")
379
- return None
380
-
381
- with open(scenario_path, 'r', encoding='utf-8') as f:
382
- return json.load(f)
383
- except Exception as e:
384
- logger.error(f"Error loading scenario {scenario_id}: {e}")
385
- return None
386
-
387
- def get_all_scenarios(self) -> List[Dict]:
388
- """Get list of all available scenarios"""
389
- try:
390
- scenarios_dir = self.data_dir / "scenarios"
391
- if not scenarios_dir.exists():
392
- return []
393
-
394
- scenarios = []
395
- for scenario_file in scenarios_dir.glob("*.json"):
396
- try:
397
- with open(scenario_file, 'r', encoding='utf-8') as f:
398
- scenario_data = json.load(f)
399
- # Add just metadata, not full dialogue tree
400
- scenarios.append({
401
- 'scenario_id': scenario_data.get('scenario_id'),
402
- 'title': scenario_data.get('title'),
403
- 'title_en': scenario_data.get('title_en'),
404
- 'level': scenario_data.get('level'),
405
- 'estimated_duration_minutes': scenario_data.get('estimated_duration_minutes'),
406
- 'learning_goals': scenario_data.get('learning_goals', [])
407
- })
408
- except Exception as e:
409
- logger.error(f"Error loading scenario {scenario_file}: {e}")
410
- continue
411
-
412
- return scenarios
413
- except Exception as e:
414
- logger.error(f"Error getting all scenarios: {e}")
415
- return []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/services/quantization_utils.py DELETED
@@ -1,124 +0,0 @@
1
- """
2
- Dynamic INT8 Quantization utilities for ASR models.
3
-
4
- This module provides utilities to apply PyTorch dynamic quantization to
5
- Hugging Face transformer models, specifically optimized for ASR models like
6
- Whisper and Wav2Vec2-BERT.
7
- """
8
-
9
- import torch
10
- from torch.quantization import quantize_dynamic
11
- from transformers import PreTrainedModel
12
- import time
13
-
14
-
15
- def apply_dynamic_int8_quantization(model: PreTrainedModel, model_type: str = "auto") -> PreTrainedModel:
16
- """
17
- Apply dynamic INT8 quantization to a Hugging Face model.
18
-
19
- Dynamic quantization converts model weights to INT8 and activations to INT8 on-the-fly
20
- during inference, reducing model size and improving inference speed with minimal
21
- accuracy loss.
22
-
23
- Args:
24
- model: The Hugging Face model to quantize
25
- model_type: Type of model ("whisper", "wav2vec2-bert", or "auto")
26
-
27
- Returns:
28
- Quantized model
29
-
30
- References:
31
- - PyTorch Quantization: https://pytorch.org/docs/stable/quantization.html
32
- - Dynamic Quantization for NLP: https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html
33
- """
34
- print(f"\n{'='*60}")
35
- print(f"Applying Dynamic INT8 Quantization to {model_type} model")
36
- print(f"{'='*60}")
37
-
38
- # Get model size before quantization
39
- param_size = 0
40
- for param in model.parameters():
41
- param_size += param.nelement() * param.element_size()
42
- buffer_size = 0
43
- for buffer in model.buffers():
44
- buffer_size += buffer.nelement() * buffer.element_size()
45
- size_before_mb = (param_size + buffer_size) / 1024**2
46
-
47
- print(f"Model size before quantization: {size_before_mb:.2f} MB")
48
-
49
- # Start quantization timer
50
- start_time = time.time()
51
-
52
- try:
53
- # Dynamic quantization targets:
54
- # - torch.nn.Linear: Most common layer type in transformers
55
- # - torch.nn.LSTM/GRU/RNN: For sequential models (if present)
56
- #
57
- # Note: We use qint8 (quantized int8) which converts weights to INT8
58
- # and performs INT8 arithmetic for linear layers during inference
59
- quantized_model = quantize_dynamic(
60
- model,
61
- {torch.nn.Linear}, # Quantize all Linear layers
62
- dtype=torch.qint8 # Use 8-bit integer quantization
63
- )
64
-
65
- # Get model size after quantization
66
- param_size_q = 0
67
- for param in quantized_model.parameters():
68
- param_size_q += param.nelement() * param.element_size()
69
- buffer_size_q = 0
70
- for buffer in quantized_model.buffers():
71
- buffer_size_q += buffer.nelement() * buffer.element_size()
72
- size_after_mb = (param_size_q + buffer_size_q) / 1024**2
73
-
74
- quantization_time = time.time() - start_time
75
- size_reduction = ((size_before_mb - size_after_mb) / size_before_mb) * 100
76
-
77
- print(f"✓ Quantization successful!")
78
- print(f" - Model size after quantization: {size_after_mb:.2f} MB")
79
- print(f" - Size reduction: {size_reduction:.1f}%")
80
- print(f" - Quantization time: {quantization_time:.2f}s")
81
- print(f"{'='*60}\n")
82
-
83
- return quantized_model
84
-
85
- except Exception as e:
86
- print(f"✗ Quantization failed: {e}")
87
- print(f" Returning original unquantized model")
88
- print(f"{'='*60}\n")
89
- return model
90
-
91
-
92
- def get_quantization_stats(model: PreTrainedModel) -> dict:
93
- """
94
- Get statistics about a model's quantization status.
95
-
96
- Args:
97
- model: The model to analyze
98
-
99
- Returns:
100
- Dictionary with quantization statistics
101
- """
102
- stats = {
103
- "is_quantized": False,
104
- "quantized_layers": 0,
105
- "total_layers": 0,
106
- "size_mb": 0.0
107
- }
108
-
109
- # Count quantized vs regular layers
110
- for name, module in model.named_modules():
111
- if isinstance(module, (torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU)):
112
- stats["total_layers"] += 1
113
-
114
- # Check if layer is quantized (will have _packed_params attribute)
115
- if hasattr(module, '_packed_params'):
116
- stats["quantized_layers"] += 1
117
- stats["is_quantized"] = True
118
-
119
- # Calculate model size
120
- param_size = sum(p.nelement() * p.element_size() for p in model.parameters())
121
- buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())
122
- stats["size_mb"] = (param_size + buffer_size) / 1024**2
123
-
124
- return stats
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/services/session_manager.py DELETED
@@ -1,180 +0,0 @@
1
- import uuid
2
- import random
3
- import string
4
- from typing import Dict, List, Optional
5
- from app.models import Session, SessionCreate, Participant, Language, LanguageCode
6
-
7
- def generate_short_code(length: int = 8) -> str:
8
- """Generate a random short code using uppercase letters and digits"""
9
- # Use only uppercase letters and digits to avoid confusion (no lowercase to avoid O/0, I/1 confusion)
10
- alphabet = string.ascii_uppercase + string.digits
11
- # Remove confusing characters
12
- alphabet = alphabet.replace('O', '').replace('0', '').replace('I', '').replace('1', '')
13
- return ''.join(random.choice(alphabet) for _ in range(length))
14
-
15
- # Language mappings
16
- LANGUAGE_MAP = {
17
- LanguageCode.ENGLISH: Language(code=LanguageCode.ENGLISH, name="English", display_name="English (eng)"),
18
- LanguageCode.SWAHILI: Language(code=LanguageCode.SWAHILI, name="Swahili", display_name="Swahili (swa)"),
19
- LanguageCode.KIKUYU: Language(code=LanguageCode.KIKUYU, name="Kikuyu", display_name="Kikuyu (kik)"),
20
- LanguageCode.KAMBA: Language(code=LanguageCode.KAMBA, name="Kamba", display_name="Kamba (kam)"),
21
- LanguageCode.KIMERU: Language(code=LanguageCode.KIMERU, name="Kimeru", display_name="Kimeru (mer)"),
22
- LanguageCode.LUO: Language(code=LanguageCode.LUO, name="Luo", display_name="Luo (luo)"),
23
- LanguageCode.SOMALI: Language(code=LanguageCode.SOMALI, name="Somali", display_name="Somali (som)"),
24
- }
25
-
26
- class SessionManager:
27
- def __init__(self):
28
- self.sessions: Dict[str, Session] = {}
29
- self.participant_sessions: Dict[str, str] = {} # participant_id -> session_id
30
- self.short_code_to_id: Dict[str, str] = {} # short_code -> session_id
31
- self.id_to_short_code: Dict[str, str] = {} # session_id -> short_code
32
-
33
- async def create_session(self, session_data: SessionCreate) -> Session:
34
- session_id = str(uuid.uuid4())
35
-
36
- # Generate unique short code
37
- short_code = generate_short_code(8)
38
- while short_code in self.short_code_to_id:
39
- # Extremely unlikely collision, but regenerate if needed
40
- short_code = generate_short_code(8)
41
-
42
- # Convert language codes to Language objects
43
- languages = [LANGUAGE_MAP[lang_code] for lang_code in session_data.languages]
44
-
45
- session = Session(
46
- id=session_id,
47
- name=session_data.name,
48
- organizer_name=session_data.organizer_name,
49
- languages=languages,
50
- participants=[],
51
- is_active=True,
52
- enable_tts=session_data.enable_tts
53
- )
54
-
55
- self.sessions[session_id] = session
56
- self.short_code_to_id[short_code] = session_id
57
- self.id_to_short_code[session_id] = short_code
58
- return session
59
-
60
- async def get_session(self, session_id_or_code: str) -> Optional[Session]:
61
- """Get session by full UUID or short code"""
62
- # Try as full UUID first
63
- session = self.sessions.get(session_id_or_code)
64
- if session:
65
- return session
66
-
67
- # Try as short code
68
- session_id = self.short_code_to_id.get(session_id_or_code.upper())
69
- if session_id:
70
- return self.sessions.get(session_id)
71
-
72
- return None
73
-
74
- def get_short_code(self, session_id: str) -> str:
75
- """Get short code for a session ID"""
76
- return self.id_to_short_code.get(session_id, session_id)
77
-
78
- async def get_all_sessions(self) -> List[Session]:
79
- return list(self.sessions.values())
80
-
81
- async def add_participant(self, session_id: str, participant_name: str, language_code: LanguageCode) -> Optional[Participant]:
82
- session = await self.get_session(session_id)
83
- if not session:
84
- return None
85
-
86
- participant_id = str(uuid.uuid4())
87
- language = LANGUAGE_MAP[language_code]
88
-
89
- # Check if the participant's language is already in the session languages
90
- language_exists = any(lang.code == language_code for lang in session.languages)
91
- if not language_exists:
92
- print(f"Adding new language {language.name} ({language_code.value}) to session {session_id}")
93
- session.languages.append(language)
94
-
95
- participant = Participant(
96
- id=participant_id,
97
- name=participant_name,
98
- language=language,
99
- is_organizer=len(session.participants) == 0, # First participant is organizer
100
- is_speaking=False,
101
- is_connected=True
102
- )
103
-
104
- session.participants.append(participant)
105
- self.participant_sessions[participant_id] = session_id
106
-
107
- print(f"Participant {participant_name} added to session. Session now has {len(session.languages)} languages: {[lang.name for lang in session.languages]}")
108
-
109
- return participant
110
-
111
- async def remove_participant(self, participant_id: str) -> bool:
112
- session_id = self.participant_sessions.get(participant_id)
113
- if not session_id:
114
- return False
115
-
116
- session = await self.get_session(session_id)
117
- if not session:
118
- return False
119
-
120
- # Remove participant from session
121
- session.participants = [p for p in session.participants if p.id != participant_id]
122
- del self.participant_sessions[participant_id]
123
-
124
- return True
125
-
126
- async def update_participant_speaking_status(self, participant_id: str, is_speaking: bool) -> bool:
127
- session_id = self.participant_sessions.get(participant_id)
128
- if not session_id:
129
- return False
130
-
131
- session = await self.get_session(session_id)
132
- if not session:
133
- return False
134
-
135
- for participant in session.participants:
136
- if participant.id == participant_id:
137
- participant.is_speaking = is_speaking
138
- return True
139
-
140
- return False
141
-
142
- async def get_participant_session_id(self, participant_id: str) -> Optional[str]:
143
- return self.participant_sessions.get(participant_id)
144
-
145
- async def add_language_to_session(self, session_id: str, language_code: LanguageCode) -> bool:
146
- """Add a language to the session if it doesn't already exist"""
147
- session = await self.get_session(session_id)
148
- if not session:
149
- return False
150
-
151
- language = LANGUAGE_MAP[language_code]
152
-
153
- # Check if the language is already in the session languages
154
- language_exists = any(lang.code == language_code for lang in session.languages)
155
- if not language_exists:
156
- print(f"Adding new language {language.name} ({language_code.value}) to session {session_id}")
157
- session.languages.append(language)
158
- print(f"Session {session_id} now has {len(session.languages)} languages: {[lang.name for lang in session.languages]}")
159
- return True
160
- else:
161
- print(f"Language {language.name} ({language_code.value}) already exists in session {session_id}")
162
- return False
163
-
164
- async def delete_session(self, session_id: str) -> bool:
165
- if session_id in self.sessions:
166
- # Remove all participants from tracking
167
- session = self.sessions[session_id]
168
- for participant in session.participants:
169
- if participant.id in self.participant_sessions:
170
- del self.participant_sessions[participant.id]
171
-
172
- # Remove short code mapping
173
- short_code = self.id_to_short_code.get(session_id)
174
- if short_code:
175
- del self.short_code_to_id[short_code]
176
- del self.id_to_short_code[session_id]
177
-
178
- del self.sessions[session_id]
179
- return True
180
- return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/services/transcription_service.py DELETED
@@ -1,736 +0,0 @@
1
- import asyncio
2
- import io
3
- import wave
4
- import numpy as np
5
- import time
6
- from typing import Dict, Optional, Callable
7
- from transformers import pipeline
8
- import torch
9
- from app.models import LanguageCode
10
- import os
11
- from app.services.quantization_utils import apply_dynamic_int8_quantization, get_quantization_stats
12
-
13
- # Silero VAD imports
14
- try:
15
- import silero_vad
16
- SILERO_VAD_AVAILABLE = True
17
- except ImportError:
18
- SILERO_VAD_AVAILABLE = False
19
- print("Warning: silero-vad not installed. Falling back to RMS-based VAD.")
20
-
21
- class TranscriptionService:
22
- def __init__(self):
23
- self.asr_pipelines: Dict[str, any] = {}
24
- self.device = 0 if torch.cuda.is_available() else -1
25
-
26
- # Model configurations - using original mutisya models with updated config
27
- self.asr_config = {
28
- "eng": {"model_repo": "openai/whisper-base.en", "model_type": "whisper"},
29
- "swa": {"model_repo": "mutisya/w2v-bert-2.0-asr-swh-superv-v25-37-1", "model_type": "wav2vec2-bert", "trust_remote_code": True},
30
- "kik": {"model_repo": "mutisya/w2v-bert-2.0-asr-kik-superv-v25-37-1", "model_type": "wav2vec2-bert", "trust_remote_code": True},
31
- "kam": {"model_repo": "mutisya/w2v-bert-2.0-asr-kam-superv-v25-36-1", "model_type": "wav2vec2-bert", "trust_remote_code": True},
32
- "mer": {"model_repo": "mutisya/w2v-bert-2.0-asr-mer-superv-v25-36-1", "model_type": "wav2vec2-bert", "trust_remote_code": True},
33
- "luo": {"model_repo": "mutisya/w2v-bert-2.0-asr-luo-superv-v25-36-1", "model_type": "wav2vec2-bert", "trust_remote_code": True},
34
- "som": {"model_repo": "mutisya/w2v-bert-2.0-asr-som-superv-v25-36-1", "model_type": "wav2vec2-bert", "trust_remote_code": True}
35
- }
36
-
37
- self.preload_languages = ["eng"]
38
- self.background_loading_task = None
39
- self.models_loading_status = {}
40
-
41
- # Enhanced audio buffering for VAD-based sentence detection
42
- self.candidate_audio_buffers: Dict[str, bytes] = {} # participant_id -> candidate audio buffer
43
- self.candidate_text_cache: Dict[str, str] = {} # participant_id -> current candidate text
44
- self.silence_counters: Dict[str, int] = {} # participant_id -> consecutive silence chunks
45
- self.sentence_finalized: Dict[str, bool] = {} # participant_id -> whether current sentence is already finalized
46
-
47
- # VAD parameters - made more lenient for better detection
48
- self.silence_threshold = 1 # Number of consecutive silent chunks before sentence break (1 second for natural pauses)
49
- self.min_sentence_length = 0.03 # Minimum sentence length in seconds (very short)
50
-
51
- # Silero VAD initialization
52
- self.vad_model = None
53
- self.vad_sample_rate = 16000
54
- self.vad_available = SILERO_VAD_AVAILABLE
55
-
56
- # Quantization configuration
57
- # Set ENABLE_INT8_QUANTIZATION=true in environment to enable quantization
58
- self.enable_quantization = os.getenv('ENABLE_INT8_QUANTIZATION', 'true').lower() == 'true'
59
- print(f"INT8 Quantization: {'ENABLED' if self.enable_quantization else 'DISABLED'}")
60
-
61
- async def initialize(self):
62
- """Initialize ASR models for preloaded languages and Silero VAD"""
63
- # Initialize Silero VAD model
64
- if self.vad_available:
65
- try:
66
- print("Loading Silero VAD model...")
67
- self.vad_model = silero_vad.load_silero_vad(onnx=False)
68
- print("✓ Silero VAD model loaded successfully")
69
- except Exception as e:
70
- print(f"Failed to load Silero VAD model: {e}")
71
- print("Falling back to RMS-based VAD")
72
- self.vad_available = False
73
-
74
- # Initialize ASR models
75
- for lang_code in self.preload_languages:
76
- if lang_code in self.asr_config:
77
- try:
78
- model_config = self.asr_config[lang_code]
79
- pipeline_obj = self._load_and_quantize_pipeline(lang_code, model_config)
80
- self.asr_pipelines[lang_code] = pipeline_obj
81
- except Exception as e:
82
- print(f"Failed to load ASR model for {lang_code}: {e}")
83
-
84
- def _load_and_quantize_pipeline(self, lang_code: str, model_config: dict):
85
- """Load ASR pipeline and optionally apply INT8 quantization"""
86
- # Build pipeline parameters
87
- pipeline_params = {
88
- "task": "automatic-speech-recognition",
89
- "model": model_config["model_repo"],
90
- "device": self.device,
91
- "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32
92
- }
93
-
94
- # Add trust_remote_code if specified
95
- if model_config.get("trust_remote_code", False):
96
- pipeline_params["trust_remote_code"] = True
97
-
98
- print(f"Loading ASR model for {lang_code}: {model_config['model_repo']}")
99
- pipeline_obj = pipeline(**pipeline_params)
100
-
101
- # Apply quantization if enabled
102
- if self.enable_quantization:
103
- try:
104
- # Get the underlying model from the pipeline
105
- model = pipeline_obj.model
106
- model_type = model_config.get("model_type", "auto")
107
-
108
- # Apply dynamic INT8 quantization
109
- quantized_model = apply_dynamic_int8_quantization(model, model_type)
110
-
111
- # Replace the model in the pipeline
112
- pipeline_obj.model = quantized_model
113
-
114
- # Print quantization stats
115
- stats = get_quantization_stats(quantized_model)
116
- print(f"✓ {lang_code} model quantized: {stats['quantized_layers']}/{stats['total_layers']} layers, {stats['size_mb']:.2f} MB")
117
-
118
- except Exception as e:
119
- print(f"Warning: Could not quantize {lang_code} model: {e}")
120
- print(f"Continuing with unquantized model")
121
-
122
- return pipeline_obj
123
-
124
- async def ensure_model_loaded(self, language_code: str):
125
- """Load ASR model for language if not already loaded"""
126
- if language_code not in self.asr_pipelines and language_code in self.asr_config:
127
- try:
128
- model_config = self.asr_config[language_code]
129
- pipeline_obj = self._load_and_quantize_pipeline(language_code, model_config)
130
- self.asr_pipelines[language_code] = pipeline_obj
131
- except Exception as e:
132
- print(f"Failed to load ASR model for {language_code}: {e}")
133
- raise
134
-
135
- async def process_audio_chunk(self, audio_data: bytes, language_code: str, participant_id: str,
136
- has_voice_activity: bool = True,
137
- progress_callback: Optional[Callable] = None,
138
- sentence_callback: Optional[Callable] = None,
139
- debug_callback: Optional[Callable] = None) -> str:
140
- """Process audio chunk with VAD-based sentence detection"""
141
- try:
142
- # Initialize buffers if needed
143
- if participant_id not in self.candidate_audio_buffers:
144
- # Store as numpy array, not bytes, to avoid multiple WAV header issues
145
- self.candidate_audio_buffers[participant_id] = np.array([], dtype=np.float32)
146
- self.candidate_text_cache[participant_id] = ""
147
- self.silence_counters[participant_id] = 0
148
- self.sentence_finalized[participant_id] = False
149
-
150
- # Convert current chunk to numpy array for processing
151
- current_chunk_array = self._bytes_to_audio_array(audio_data)
152
- if len(current_chunk_array) == 0:
153
- print(f"WARNING: Received empty audio chunk for participant {participant_id}")
154
- return self.candidate_text_cache.get(participant_id, "")
155
-
156
- print(f"DEBUG: Received audio chunk - bytes: {len(audio_data)}, samples: {len(current_chunk_array)}, "
157
- f"duration: {len(current_chunk_array)/16000:.3f}s, "
158
- f"first 4 bytes: {audio_data[:4]}")
159
-
160
- # DO NOT normalize individual chunks - this causes audio distortion
161
- # We'll normalize the entire accumulated audio buffer before transcription
162
- current_chunk_array = current_chunk_array.astype(np.float32)
163
-
164
- # Get existing accumulated audio array (now stored as numpy array)
165
- existing_array = self.candidate_audio_buffers[participant_id]
166
- if len(existing_array) > 0:
167
- # Concatenate with existing audio (like stream = np.concatenate([stream, y]))
168
- combined_array = np.concatenate([existing_array, current_chunk_array])
169
- else:
170
- combined_array = current_chunk_array
171
-
172
- # Store as numpy array to avoid WAV header accumulation issues
173
- self.candidate_audio_buffers[participant_id] = combined_array
174
-
175
- # For debug callback, convert to bytes (this adds ONE WAV header)
176
- combined_bytes = self._audio_array_to_bytes(combined_array)
177
-
178
- # Update silence counter based on voice activity
179
- if not has_voice_activity:
180
- self.silence_counters[participant_id] += 1
181
- else:
182
- self.silence_counters[participant_id] = 0
183
-
184
- # Check if we should finalize sentence due to prolonged silence
185
- should_finalize = (self.silence_counters[participant_id] >= self.silence_threshold and
186
- len(combined_array) > 0 and
187
- not self.sentence_finalized[participant_id])
188
-
189
- if should_finalize:
190
- return await self._finalize_candidate_sentence(
191
- language_code, participant_id, sentence_callback
192
- )
193
-
194
- # Always run transcription on the accumulated audio
195
- audio_duration_sec = len(combined_array) / 16000.0 # 16kHz sample rate
196
-
197
- # Minimum duration check - ignore very short audio bursts
198
- MIN_CHUNK_DURATION = 0.3 # 300ms minimum
199
- if audio_duration_sec < MIN_CHUNK_DURATION:
200
- print(f"Audio chunk too short: {audio_duration_sec:.3f}s < {MIN_CHUNK_DURATION}s, skipping transcription")
201
- if progress_callback:
202
- cached_text = self.candidate_text_cache.get(participant_id, "")
203
- await progress_callback(cached_text, False)
204
- return self.candidate_text_cache.get(participant_id, "")
205
-
206
- # Force finalization if buffer gets too long (prevent infinite accumulation)
207
- if audio_duration_sec > 15.0 and not self.sentence_finalized[participant_id]: # Force completion after 15 seconds
208
- return await self._finalize_candidate_sentence(
209
- language_code, participant_id, sentence_callback
210
- )
211
-
212
- # Run voice activity detection on the accumulated audio before transcription
213
- has_voice_in_buffer = self.has_meaningful_voice_activity(combined_bytes)
214
-
215
- if not has_voice_in_buffer:
216
- # Still send progress update with cached text to maintain UI state
217
- if progress_callback:
218
- cached_text = self.candidate_text_cache.get(participant_id, "")
219
- await progress_callback(cached_text, False)
220
- return self.candidate_text_cache.get(participant_id, "")
221
-
222
- # Run transcription
223
- await self.ensure_model_loaded(language_code)
224
-
225
- # Double-check voice activity before running expensive ASR
226
- has_voice_for_asr = self.has_voice_activity(combined_bytes)
227
- if not has_voice_for_asr:
228
- print(f"ASR: No voice activity detected in audio buffer for participant {participant_id}, skipping ASR execution")
229
- # Return cached text and send progress update
230
- if progress_callback:
231
- cached_text = self.candidate_text_cache.get(participant_id, "")
232
- await progress_callback(cached_text, False)
233
- return self.candidate_text_cache.get(participant_id, "")
234
-
235
- if language_code not in self.asr_pipelines:
236
- raise ValueError(f"ASR model not available for language: {language_code}")
237
-
238
- print(f"ASR: Running transcription for participant {participant_id} with {len(combined_array)/16000:.2f}s of audio")
239
- pipeline_obj = self.asr_pipelines[language_code]
240
-
241
- # Normalize the ENTIRE accumulated audio buffer before transcription
242
- # This prevents audio distortion from per-chunk normalization
243
- normalized_array = combined_array.astype(np.float32)
244
- max_val = np.max(np.abs(normalized_array))
245
- if max_val > 0:
246
- normalized_array = normalized_array / max_val
247
-
248
- # Track transcription latency
249
- transcription_start_time = time.time()
250
-
251
- # For wav2vec2 models, request word timestamps
252
- model_type = self.asr_config[language_code].get("model_type", "whisper")
253
- if model_type in ["wav2vec2-bert", "wav2vec2"]:
254
- result = pipeline_obj(
255
- {"sampling_rate": 16000, "raw": normalized_array},
256
- return_timestamps="word"
257
- )
258
- else:
259
- # Whisper model - add anti-hallucination parameters
260
- # Note: HuggingFace pipeline uses different parameter names than OpenAI Whisper
261
- result = pipeline_obj(
262
- {"sampling_rate": 16000, "raw": normalized_array},
263
- return_timestamps=True,
264
- chunk_length_s=30, # Process in 30s chunks
265
- stride_length_s=5 # 5s stride for context
266
- )
267
-
268
- transcription_latency_ms = (time.time() - transcription_start_time) * 1000
269
-
270
- candidate_text = result.get("text", "").strip()
271
- word_timestamps = result.get("chunks", []) if model_type in ["wav2vec2-bert", "wav2vec2"] else None
272
-
273
- # Send debug information if callback provided (for wav2vec2 models only)
274
- if debug_callback and word_timestamps is not None:
275
- debug_info = {
276
- "text": candidate_text,
277
- "timestamps": word_timestamps,
278
- "audio_data": combined_bytes,
279
- "audio_duration": audio_duration_sec,
280
- "model_type": model_type,
281
- "transcription_latency_ms": transcription_latency_ms
282
- }
283
- await debug_callback(debug_info)
284
-
285
- # Filter out common ASR artifacts and very short responses
286
- artifacts = [
287
- "thank you", "thanks", "bye", ".", ",", "?", "!",
288
- "um", "uh", "ah", "hmm", "mm", "mhm",
289
- "you", "the", "a", "an", "and", "but", "or",
290
- "music", "laughter", "applause", "[music]", "[laughter]",
291
- # Common Whisper hallucinations:
292
- "subscribe", "subtitles", "amara", "www", "http",
293
- "please subscribe", "like and subscribe",
294
- "thank you for watching", "don't forget to subscribe",
295
- "[blank_audio]", "[noise]", "[silence]",
296
- ]
297
-
298
- # Check if the result is likely an artifact
299
- is_artifact = (
300
- len(candidate_text) < 3 or # Very short
301
- candidate_text.lower() in artifacts or # Common artifacts
302
- len(candidate_text.split()) == 1 and len(candidate_text) < 6 # Single very short word
303
- )
304
-
305
- if is_artifact:
306
- # Keep the previous cached text instead of updating with artifact
307
- candidate_text = self.candidate_text_cache.get(participant_id, "")
308
-
309
- # Cache the current candidate text
310
- self.candidate_text_cache[participant_id] = candidate_text
311
-
312
- # Force completion if we have a reasonable amount of text and some silence
313
- word_count = len(candidate_text.split()) if candidate_text else 0
314
- if (word_count >= 3 and self.silence_counters[participant_id] >= 2 and
315
- not self.sentence_finalized[participant_id]): # At least 3 words and 2 silent chunks
316
- return await self._finalize_candidate_sentence(
317
- language_code, participant_id, sentence_callback
318
- )
319
-
320
- # Always send progress update
321
- if progress_callback:
322
- await progress_callback(candidate_text, False)
323
-
324
- return candidate_text
325
-
326
- except Exception as e:
327
- print(f"TranscriptionService: Error processing audio chunk: {e}")
328
- import traceback
329
- traceback.print_exc()
330
- # Even on error, try to send cached text
331
- if progress_callback:
332
- cached_text = self.candidate_text_cache.get(participant_id, "")
333
- await progress_callback(cached_text, False)
334
- return self.candidate_text_cache.get(participant_id, "")
335
-
336
- async def _finalize_candidate_sentence(self, language_code: str, participant_id: str,
337
- sentence_callback: Optional[Callable] = None) -> str:
338
- """Finalize the current candidate sentence and clear buffers"""
339
- try:
340
- # Check if sentence was already finalized
341
- if self.sentence_finalized.get(participant_id, False):
342
- print(f"Sentence for participant {participant_id} already finalized, skipping duplicate")
343
- return self.candidate_text_cache.get(participant_id, "")
344
-
345
- final_text = self.candidate_text_cache.get(participant_id, "")
346
- final_audio_array = self.candidate_audio_buffers.get(participant_id, np.array([], dtype=np.float32))
347
-
348
- # Convert audio array to bytes for VAD check and callback
349
- final_audio_bytes = self._audio_array_to_bytes(final_audio_array) if len(final_audio_array) > 0 else b''
350
-
351
- if final_text and len(final_text.strip()) > 0:
352
- # Run VAD check on the final accumulated buffer before sending for translation
353
- if len(final_audio_bytes) > 0:
354
- has_voice_in_final = self.has_meaningful_voice_activity(final_audio_bytes)
355
- if not has_voice_in_final:
356
- print(f"Finalize: No voice activity in final buffer for participant {participant_id}, discarding sentence: '{final_text}'")
357
- # Clear buffers without sending to translation
358
- self.candidate_audio_buffers[participant_id] = np.array([], dtype=np.float32)
359
- self.candidate_text_cache[participant_id] = ""
360
- self.silence_counters[participant_id] = 0
361
- self.sentence_finalized[participant_id] = False
362
- return ""
363
-
364
- # Mark as finalized BEFORE calling the callback to prevent race conditions
365
- self.sentence_finalized[participant_id] = True
366
-
367
- # Send to sentence callback for translation
368
- if sentence_callback and len(final_audio_bytes) > 0:
369
- print(f"Finalizing sentence for participant {participant_id}: '{final_text}'")
370
- await sentence_callback(final_text, final_audio_bytes)
371
-
372
- # Clear buffers for next sentence
373
- self.candidate_audio_buffers[participant_id] = np.array([], dtype=np.float32)
374
- self.candidate_text_cache[participant_id] = ""
375
- self.silence_counters[participant_id] = 0
376
- self.sentence_finalized[participant_id] = False # Reset for next sentence
377
-
378
- return final_text
379
-
380
- except Exception as e:
381
- print(f"Error finalizing sentence: {e}")
382
- import traceback
383
- traceback.print_exc()
384
- # Reset finalized flag on error
385
- self.sentence_finalized[participant_id] = False
386
- return ""
387
-
388
- def has_voice_activity(self, audio_data: bytes, threshold: float = 0.5) -> bool:
389
- """Voice Activity Detection using Silero VAD (with RMS fallback)"""
390
- try:
391
- audio_array = self._bytes_to_audio_array(audio_data)
392
- if len(audio_array) == 0:
393
- print("VAD: No audio array, returning False")
394
- return False
395
-
396
- # Normalize audio to float32 range [-1, 1]
397
- audio_array = audio_array.astype(np.float32)
398
- if np.max(np.abs(audio_array)) > 0:
399
- audio_array /= np.max(np.abs(audio_array))
400
-
401
- # Use Silero VAD if available
402
- if self.vad_available and self.vad_model is not None:
403
- try:
404
- # Silero VAD expects 512 samples (32ms) or 1536 samples (96ms) for 16kHz
405
- # Process audio in chunks and average the probabilities
406
- frame_size = 512 # 32ms at 16kHz
407
- num_samples = len(audio_array)
408
-
409
- # If audio is too short, pad it
410
- if num_samples < frame_size:
411
- audio_array = np.pad(audio_array, (0, frame_size - num_samples), mode='constant')
412
- num_samples = frame_size
413
-
414
- # Process in frames and collect probabilities
415
- speech_probs = []
416
- for i in range(0, num_samples, frame_size):
417
- frame = audio_array[i:i + frame_size]
418
- if len(frame) < frame_size:
419
- # Pad last frame if needed
420
- frame = np.pad(frame, (0, frame_size - len(frame)), mode='constant')
421
-
422
- # Convert to torch tensor
423
- frame_tensor = torch.from_numpy(frame).float()
424
-
425
- # Get speech probability from Silero VAD
426
- with torch.no_grad():
427
- prob = self.vad_model(frame_tensor, self.vad_sample_rate).item()
428
- speech_probs.append(prob)
429
-
430
- # Average probability across all frames
431
- speech_prob = np.mean(speech_probs)
432
- has_voice = speech_prob > threshold
433
-
434
- print(f"VAD: Silero speech_prob={speech_prob:.4f} (avg of {len(speech_probs)} frames), threshold={threshold}, RESULT={has_voice}")
435
-
436
- return has_voice
437
-
438
- except Exception as e:
439
- print(f"Silero VAD error: {e}, falling back to RMS-based VAD")
440
- # Fall through to RMS-based VAD below
441
-
442
- # Fallback: RMS-based VAD (original implementation)
443
- rms_threshold = 0.002
444
- rms = np.sqrt(np.mean(audio_array ** 2))
445
- peak = np.max(np.abs(audio_array))
446
- audio_std = np.std(audio_array)
447
- zero_crossing_rate = np.sum(np.diff(np.sign(audio_array)) != 0) / len(audio_array)
448
-
449
- has_voice_rms = rms > rms_threshold
450
- has_voice_peak = peak > rms_threshold * 3
451
- has_voice_variation = audio_std > rms_threshold * 0.8
452
- has_voice_zcr = zero_crossing_rate > 0.008
453
-
454
- has_voice = has_voice_rms or (has_voice_peak and has_voice_variation) or has_voice_zcr
455
-
456
- print(f"VAD: RMS-based - RMS={rms:.6f}({has_voice_rms}), peak={peak:.6f}({has_voice_peak}), std={audio_std:.6f}({has_voice_variation}), zcr={zero_crossing_rate:.6f}({has_voice_zcr}), RESULT={has_voice}")
457
-
458
- return has_voice
459
-
460
- except Exception as e:
461
- print(f"Error in VAD: {e}")
462
- return True # Default to assuming voice activity on error
463
-
464
- def has_meaningful_voice_activity(self, audio_data: bytes, threshold: float = 0.005) -> bool:
465
- """Stricter VAD check specifically for pre-transcription filtering"""
466
- try:
467
- audio_array = self._bytes_to_audio_array(audio_data)
468
- if len(audio_array) == 0:
469
- return False
470
-
471
- # Normalize audio
472
- audio_array = audio_array.astype(np.float32)
473
- if np.max(np.abs(audio_array)) > 0:
474
- audio_array /= np.max(np.abs(audio_array))
475
-
476
- # Calculate features with higher thresholds for meaningful speech
477
- rms = np.sqrt(np.mean(audio_array ** 2))
478
- peak = np.max(np.abs(audio_array))
479
- audio_std = np.std(audio_array)
480
- zero_crossing_rate = np.sum(np.diff(np.sign(audio_array)) != 0) / len(audio_array)
481
-
482
- # Higher thresholds for meaningful speech detection
483
- has_meaningful_voice = (
484
- rms > threshold and
485
- peak > threshold * 2 and
486
- audio_std > threshold * 0.5 and
487
- zero_crossing_rate > 0.015 # Higher ZCR threshold for meaningful speech
488
- )
489
-
490
- return has_meaningful_voice
491
-
492
- except Exception as e:
493
- print(f"Error in meaningful VAD: {e}")
494
- return False # Default to no meaningful voice on error
495
-
496
- async def force_complete_sentence(self, participant_id: str, language_code: str, sentence_callback: Optional[Callable] = None) -> str:
497
- """Force complete any pending sentence for a participant"""
498
- try:
499
- # Check if sentence was already finalized
500
- if self.sentence_finalized.get(participant_id, False):
501
- print(f"Force completion: Sentence for participant {participant_id} already finalized, skipping")
502
- return ""
503
-
504
- if participant_id in self.candidate_text_cache:
505
- cached_text = self.candidate_text_cache[participant_id]
506
-
507
- if cached_text and len(cached_text.strip()) > 0:
508
- result = await self._finalize_candidate_sentence(language_code, participant_id, sentence_callback)
509
- return result
510
-
511
- return ""
512
-
513
- except Exception as e:
514
- print(f"Error in force_complete_sentence: {e}")
515
- import traceback
516
- traceback.print_exc()
517
- return ""
518
-
519
- async def transcribe_audio(self, audio_data: bytes, language_code: str, callback: Optional[Callable] = None) -> str:
520
- """Transcribe audio data to text"""
521
- try:
522
- # Check for voice activity before running ASR
523
- has_voice = self.has_voice_activity(audio_data)
524
- if not has_voice:
525
- print(f"ASR: No voice activity detected in audio data, skipping transcription")
526
- return ""
527
-
528
- await self.ensure_model_loaded(language_code)
529
-
530
- if language_code not in self.asr_pipelines:
531
- raise ValueError(f"ASR model not available for language: {language_code}")
532
-
533
- # Convert audio bytes to numpy array
534
- audio_array = self._bytes_to_audio_array(audio_data)
535
-
536
- print(f"ASR: Running transcription with {len(audio_array)/16000:.2f}s of audio")
537
- # Transcribe
538
- pipeline_obj = self.asr_pipelines[language_code]
539
- result = pipeline_obj({"sampling_rate": 16000, "raw": audio_array})
540
-
541
- text = result.get("text", "")
542
-
543
- if callback:
544
- await callback(text)
545
-
546
- return text
547
-
548
- except Exception as e:
549
- print(f"TranscriptionService: Transcription error: {e}")
550
- import traceback
551
- traceback.print_exc()
552
- return ""
553
-
554
- def _bytes_to_audio_array(self, audio_data: bytes) -> np.ndarray:
555
- """Convert audio bytes to numpy array (supports WAV, WebM/Opus)"""
556
- try:
557
- # Detect format by checking magic bytes
558
- is_webm = audio_data[:4] == b'\x1a\x45\xdf\xa3' # WebM/Matroska magic bytes
559
- is_wav = audio_data[:4] == b'RIFF'
560
-
561
- import sys
562
- print(f"_bytes_to_audio_array: length={len(audio_data)}, first 4 bytes={audio_data[:4]}, is_wav={is_wav}", flush=True)
563
- sys.stdout.flush()
564
-
565
- # Handle raw PCM (16-bit, 48kHz from extendable-media-recorder)
566
- # This is the most common case for microphone input
567
- if not is_wav and not is_webm and len(audio_data) > 0:
568
- try:
569
- # Assume 16-bit PCM at 48kHz (browser's native rate)
570
- audio_array = np.frombuffer(audio_data, dtype=np.int16)
571
-
572
- # Check if this looks like valid audio data (not NaN, reasonable range)
573
- if len(audio_array) > 0 and not np.isnan(audio_array).any():
574
- print(f"Raw PCM: {len(audio_array)} samples, assuming 48kHz 16-bit", flush=True)
575
-
576
- # Convert to float32 and normalize
577
- audio_float = audio_array.astype(np.float32) / 32768.0
578
-
579
- # Resample from 48kHz to 16kHz
580
- import librosa
581
- audio_array = librosa.resample(audio_float, orig_sr=48000, target_sr=16000)
582
- print(f"Resampled to 16kHz: {len(audio_array)} samples", flush=True)
583
-
584
- return audio_array
585
- except Exception as pcm_error:
586
- print(f"TranscriptionService: Raw PCM decoding error: {pcm_error}", flush=True)
587
- # Fall through to other methods
588
-
589
- if is_webm:
590
- # Decode WebM/Opus using pydub (requires ffmpeg)
591
- try:
592
- from pydub import AudioSegment
593
- audio_io = io.BytesIO(audio_data)
594
- audio_segment = AudioSegment.from_file(audio_io, format="webm")
595
-
596
- # Convert to mono 16kHz
597
- audio_segment = audio_segment.set_channels(1)
598
- audio_segment = audio_segment.set_frame_rate(16000)
599
-
600
- # Convert to numpy array
601
- samples = np.array(audio_segment.get_array_of_samples(), dtype=np.int16)
602
- # Normalize to float32 [-1, 1]
603
- audio_array = samples.astype(np.float32) / 32768.0
604
- return audio_array
605
- except Exception as webm_error:
606
- print(f"TranscriptionService: WebM decoding error: {webm_error}")
607
- # Fall through to other methods
608
-
609
- if is_wav:
610
- # Decode WAV format (first chunk from frontend includes WAV header with sample rate)
611
- try:
612
- audio_io = io.BytesIO(audio_data)
613
- with wave.open(audio_io, 'rb') as wav_file:
614
- sample_rate = wav_file.getframerate()
615
- channels = wav_file.getnchannels()
616
- sample_width = wav_file.getsampwidth()
617
-
618
- print(f"WAV format: {sample_rate}Hz, {channels} channel(s), {sample_width*8}-bit", flush=True)
619
-
620
- frames = wav_file.readframes(-1)
621
- audio_array = np.frombuffer(frames, dtype=np.int16)
622
-
623
- # Resample if needed
624
- if sample_rate != 16000:
625
- print(f"WARNING: Resampling from {sample_rate}Hz to 16000Hz", flush=True)
626
- import librosa
627
- # Convert to float first
628
- audio_float = audio_array.astype(np.float32) / 32768.0
629
- # Resample
630
- audio_array = librosa.resample(audio_float, orig_sr=sample_rate, target_sr=16000)
631
- print(f"Resampled: {len(audio_array)} samples at 16kHz", flush=True)
632
- else:
633
- # Convert to float32 and normalize
634
- audio_array = audio_array.astype(np.float32) / 32768.0
635
-
636
- print(f"Returning audio array: {len(audio_array)} samples", flush=True)
637
- return audio_array
638
- except Exception as wav_error:
639
- print(f"TranscriptionService: WAV decoding error: {wav_error}")
640
- import traceback
641
- traceback.print_exc()
642
-
643
- # Fallback: assume raw float32 audio data
644
- try:
645
- audio_array = np.frombuffer(audio_data, dtype=np.float32)
646
- return audio_array
647
- except Exception:
648
- pass
649
-
650
- # Last resort: return empty array
651
- return np.array([], dtype=np.float32)
652
-
653
- except Exception as e:
654
- print(f"TranscriptionService: Audio conversion error: {e}")
655
- return np.array([], dtype=np.float32)
656
-
657
- def _audio_array_to_bytes(self, audio_array: np.ndarray) -> bytes:
658
- """Convert numpy audio array back to WAV bytes for storage"""
659
- try:
660
- # Ensure float32 format
661
- if audio_array.dtype != np.float32:
662
- audio_array = audio_array.astype(np.float32)
663
-
664
- # Convert to 16-bit PCM for WAV storage
665
- audio_int16 = (audio_array * 32767).astype(np.int16)
666
-
667
- # Create WAV bytes
668
- wav_buffer = io.BytesIO()
669
- with wave.open(wav_buffer, 'wb') as wav_file:
670
- wav_file.setnchannels(1) # Mono
671
- wav_file.setsampwidth(2) # 16-bit
672
- wav_file.setframerate(16000) # 16kHz
673
- wav_file.writeframes(audio_int16.tobytes())
674
-
675
- return wav_buffer.getvalue()
676
-
677
- except Exception as e:
678
- print(f"Error converting audio array to bytes: {e}")
679
- return b''
680
-
681
- def clear_participant_buffers(self, participant_id: str):
682
- """Clear all buffers for a participant (e.g., when they stop speaking or disconnect)"""
683
- if participant_id in self.candidate_audio_buffers:
684
- del self.candidate_audio_buffers[participant_id]
685
- if participant_id in self.candidate_text_cache:
686
- del self.candidate_text_cache[participant_id]
687
- if participant_id in self.silence_counters:
688
- del self.silence_counters[participant_id]
689
- if participant_id in self.sentence_finalized:
690
- del self.sentence_finalized[participant_id]
691
-
692
- async def load_remaining_models_in_background(self):
693
- """Load all remaining ASR models in the background after startup"""
694
- try:
695
- print("ASR: Starting background loading of additional language models...")
696
- for lang_code in self.asr_config.keys():
697
- if lang_code not in self.preload_languages and lang_code not in self.asr_pipelines:
698
- try:
699
- print(f"ASR: Background loading model for {lang_code}...")
700
- self.models_loading_status[lang_code] = "loading"
701
-
702
- model_config = self.asr_config[lang_code]
703
- # Use quantization helper for background loading too
704
- pipeline_obj = self._load_and_quantize_pipeline(lang_code, model_config)
705
- self.asr_pipelines[lang_code] = pipeline_obj
706
- self.models_loading_status[lang_code] = "loaded"
707
- print(f"ASR: Successfully loaded model for {lang_code} in background")
708
-
709
- # Add a small delay between loading models to prevent overwhelming the system
710
- await asyncio.sleep(2)
711
- except Exception as e:
712
- print(f"ASR: Failed to load model for {lang_code} in background: {e}")
713
- self.models_loading_status[lang_code] = "failed"
714
-
715
- print("ASR: Background loading of all language models complete")
716
- print(f"ASR: Loaded models: {list(self.asr_pipelines.keys())}")
717
- except Exception as e:
718
- print(f"ASR: Error in background model loading: {e}")
719
-
720
- def start_background_loading(self):
721
- """Start background loading of models as a non-blocking task"""
722
- if self.background_loading_task is None:
723
- self.background_loading_task = asyncio.create_task(self.load_remaining_models_in_background())
724
- print("ASR: Background model loading task started")
725
-
726
- async def cleanup(self):
727
- """Cleanup resources"""
728
- # Cancel background loading if still running
729
- if self.background_loading_task and not self.background_loading_task.done():
730
- self.background_loading_task.cancel()
731
- try:
732
- await self.background_loading_task
733
- except asyncio.CancelledError:
734
- pass
735
-
736
- self.asr_pipelines.clear()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/services/transcription_service.py.bak DELETED
@@ -1,726 +0,0 @@
1
- import asyncio
2
- import io
3
- import wave
4
- import numpy as np
5
- import time
6
- from typing import Dict, Optional, Callable
7
- from transformers import pipeline
8
- import torch
9
- from app.models import LanguageCode
10
- from app.services.performance_mixin import track_performance
11
-
12
- # Silero VAD imports
13
- try:
14
- import silero_vad
15
- SILERO_VAD_AVAILABLE = True
16
- except ImportError:
17
- SILERO_VAD_AVAILABLE = False
18
- print("Warning: silero-vad not installed. Falling back to RMS-based VAD.")
19
-
20
- class TranscriptionService:
21
- def __init__(self):
22
- self.asr_pipelines: Dict[str, any] = {}
23
- self.device = 0 if torch.cuda.is_available() else -1
24
-
25
- # Model configurations - using original mutisya models with updated config
26
- self.asr_config = {
27
- "eng": {"model_repo": "openai/whisper-base.en", "model_type": "whisper"},
28
- "swa": {"model_repo": "mutisya/w2v-bert-2.0-asr-swh-superv-v25-37-1", "model_type": "wav2vec2-bert", "trust_remote_code": True},
29
- "kik": {"model_repo": "mutisya/w2v-bert-2.0-asr-kik-superv-v25-37-1", "model_type": "wav2vec2-bert", "trust_remote_code": True},
30
- "kam": {"model_repo": "mutisya/w2v-bert-2.0-asr-kam-superv-v25-36-1", "model_type": "wav2vec2-bert", "trust_remote_code": True},
31
- "mer": {"model_repo": "mutisya/w2v-bert-2.0-asr-mer-superv-v25-36-1", "model_type": "wav2vec2-bert", "trust_remote_code": True},
32
- "luo": {"model_repo": "mutisya/w2v-bert-2.0-asr-luo-superv-v25-36-1", "model_type": "wav2vec2-bert", "trust_remote_code": True},
33
- "som": {"model_repo": "mutisya/w2v-bert-2.0-asr-som-superv-v25-36-1", "model_type": "wav2vec2-bert", "trust_remote_code": True}
34
- }
35
-
36
- self.preload_languages = ["eng"]
37
- self.background_loading_task = None
38
- self.models_loading_status = {}
39
-
40
- # Enhanced audio buffering for VAD-based sentence detection
41
- self.candidate_audio_buffers: Dict[str, bytes] = {} # participant_id -> candidate audio buffer
42
- self.candidate_text_cache: Dict[str, str] = {} # participant_id -> current candidate text
43
- self.silence_counters: Dict[str, int] = {} # participant_id -> consecutive silence chunks
44
- self.sentence_finalized: Dict[str, bool] = {} # participant_id -> whether current sentence is already finalized
45
-
46
- # VAD parameters - made more lenient for better detection
47
- self.silence_threshold = 1 # Number of consecutive silent chunks before sentence break (1 second for natural pauses)
48
- self.min_sentence_length = 0.03 # Minimum sentence length in seconds (very short)
49
-
50
- # Silero VAD initialization
51
- self.vad_model = None
52
- self.vad_sample_rate = 16000
53
- self.vad_available = SILERO_VAD_AVAILABLE
54
-
55
- async def initialize(self):
56
- """Initialize ASR models for preloaded languages and Silero VAD"""
57
- # Initialize Silero VAD model
58
- if self.vad_available:
59
- try:
60
- print("Loading Silero VAD model...")
61
- self.vad_model = silero_vad.load_silero_vad(onnx=False)
62
- print("✓ Silero VAD model loaded successfully")
63
- except Exception as e:
64
- print(f"Failed to load Silero VAD model: {e}")
65
- print("Falling back to RMS-based VAD")
66
- self.vad_available = False
67
-
68
- # Initialize ASR models
69
- for lang_code in self.preload_languages:
70
- if lang_code in self.asr_config:
71
- try:
72
- model_config = self.asr_config[lang_code]
73
- # Build pipeline parameters
74
- pipeline_params = {
75
- "task": "automatic-speech-recognition",
76
- "model": model_config["model_repo"],
77
- "device": self.device,
78
- "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32
79
- }
80
-
81
- # Add trust_remote_code if specified
82
- if model_config.get("trust_remote_code", False):
83
- pipeline_params["trust_remote_code"] = True
84
-
85
- pipeline_obj = pipeline(**pipeline_params)
86
- self.asr_pipelines[lang_code] = pipeline_obj
87
- except Exception as e:
88
- print(f"Failed to load ASR model for {lang_code}: {e}")
89
-
90
- async def ensure_model_loaded(self, language_code: str):
91
- """Load ASR model for language if not already loaded"""
92
- if language_code not in self.asr_pipelines and language_code in self.asr_config:
93
- try:
94
- model_config = self.asr_config[language_code]
95
- # Build pipeline parameters
96
- pipeline_params = {
97
- "task": "automatic-speech-recognition",
98
- "model": model_config["model_repo"],
99
- "device": self.device,
100
- "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32
101
- }
102
-
103
- # Add trust_remote_code if specified
104
- if model_config.get("trust_remote_code", False):
105
- pipeline_params["trust_remote_code"] = True
106
-
107
- pipeline_obj = pipeline(**pipeline_params)
108
- self.asr_pipelines[language_code] = pipeline_obj
109
- except Exception as e:
110
- print(f"Failed to load ASR model for {language_code}: {e}")
111
- raise
112
-
113
- async def process_audio_chunk(self, audio_data: bytes, language_code: str, participant_id: str,
114
- has_voice_activity: bool = True,
115
- progress_callback: Optional[Callable] = None,
116
- sentence_callback: Optional[Callable] = None,
117
- debug_callback: Optional[Callable] = None) -> str:
118
- """Process audio chunk with VAD-based sentence detection"""
119
- try:
120
- # Initialize buffers if needed
121
- if participant_id not in self.candidate_audio_buffers:
122
- # Store as numpy array, not bytes, to avoid multiple WAV header issues
123
- self.candidate_audio_buffers[participant_id] = np.array([], dtype=np.float32)
124
- self.candidate_text_cache[participant_id] = ""
125
- self.silence_counters[participant_id] = 0
126
- self.sentence_finalized[participant_id] = False
127
-
128
- # Convert current chunk to numpy array for processing
129
- current_chunk_array = self._bytes_to_audio_array(audio_data)
130
- if len(current_chunk_array) == 0:
131
- print(f"WARNING: Received empty audio chunk for participant {participant_id}")
132
- return self.candidate_text_cache.get(participant_id, "")
133
-
134
- print(f"DEBUG: Received audio chunk - bytes: {len(audio_data)}, samples: {len(current_chunk_array)}, "
135
- f"duration: {len(current_chunk_array)/16000:.3f}s, "
136
- f"first 4 bytes: {audio_data[:4]}")
137
-
138
- # DO NOT normalize individual chunks - this causes audio distortion
139
- # We'll normalize the entire accumulated audio buffer before transcription
140
- current_chunk_array = current_chunk_array.astype(np.float32)
141
-
142
- # Get existing accumulated audio array (now stored as numpy array)
143
- existing_array = self.candidate_audio_buffers[participant_id]
144
- if len(existing_array) > 0:
145
- # Concatenate with existing audio (like stream = np.concatenate([stream, y]))
146
- combined_array = np.concatenate([existing_array, current_chunk_array])
147
- else:
148
- combined_array = current_chunk_array
149
-
150
- # Store as numpy array to avoid WAV header accumulation issues
151
- self.candidate_audio_buffers[participant_id] = combined_array
152
-
153
- # For debug callback, convert to bytes (this adds ONE WAV header)
154
- combined_bytes = self._audio_array_to_bytes(combined_array)
155
-
156
- # Update silence counter based on voice activity
157
- if not has_voice_activity:
158
- self.silence_counters[participant_id] += 1
159
- else:
160
- self.silence_counters[participant_id] = 0
161
-
162
- # Check if we should finalize sentence due to prolonged silence
163
- should_finalize = (self.silence_counters[participant_id] >= self.silence_threshold and
164
- len(combined_array) > 0 and
165
- not self.sentence_finalized[participant_id])
166
-
167
- if should_finalize:
168
- return await self._finalize_candidate_sentence(
169
- language_code, participant_id, sentence_callback
170
- )
171
-
172
- # Always run transcription on the accumulated audio
173
- audio_duration_sec = len(combined_array) / 16000.0 # 16kHz sample rate
174
-
175
- # Minimum duration check - ignore very short audio bursts
176
- MIN_CHUNK_DURATION = 0.3 # 300ms minimum
177
- if audio_duration_sec < MIN_CHUNK_DURATION:
178
- print(f"Audio chunk too short: {audio_duration_sec:.3f}s < {MIN_CHUNK_DURATION}s, skipping transcription")
179
- if progress_callback:
180
- cached_text = self.candidate_text_cache.get(participant_id, "")
181
- await progress_callback(cached_text, False)
182
- return self.candidate_text_cache.get(participant_id, "")
183
-
184
- # Force finalization if buffer gets too long (prevent infinite accumulation)
185
- if audio_duration_sec > 15.0 and not self.sentence_finalized[participant_id]: # Force completion after 15 seconds
186
- return await self._finalize_candidate_sentence(
187
- language_code, participant_id, sentence_callback
188
- )
189
-
190
- # Run voice activity detection on the accumulated audio before transcription
191
- has_voice_in_buffer = self.has_meaningful_voice_activity(combined_bytes)
192
-
193
- if not has_voice_in_buffer:
194
- # Still send progress update with cached text to maintain UI state
195
- if progress_callback:
196
- cached_text = self.candidate_text_cache.get(participant_id, "")
197
- await progress_callback(cached_text, False)
198
- return self.candidate_text_cache.get(participant_id, "")
199
-
200
- # Run transcription
201
- await self.ensure_model_loaded(language_code)
202
-
203
- # Double-check voice activity before running expensive ASR
204
- has_voice_for_asr = self.has_voice_activity(combined_bytes)
205
- if not has_voice_for_asr:
206
- print(f"ASR: No voice activity detected in audio buffer for participant {participant_id}, skipping ASR execution")
207
- # Return cached text and send progress update
208
- if progress_callback:
209
- cached_text = self.candidate_text_cache.get(participant_id, "")
210
- await progress_callback(cached_text, False)
211
- return self.candidate_text_cache.get(participant_id, "")
212
-
213
- if language_code not in self.asr_pipelines:
214
- raise ValueError(f"ASR model not available for language: {language_code}")
215
-
216
- print(f"ASR: Running transcription for participant {participant_id} with {len(combined_array)/16000:.2f}s of audio")
217
- pipeline_obj = self.asr_pipelines[language_code]
218
-
219
- # Normalize the ENTIRE accumulated audio buffer before transcription
220
- # This prevents audio distortion from per-chunk normalization
221
- normalized_array = combined_array.astype(np.float32)
222
- max_val = np.max(np.abs(normalized_array))
223
- if max_val > 0:
224
- normalized_array = normalized_array / max_val
225
-
226
- # Track transcription latency
227
- transcription_start_time = time.time()
228
-
229
- # For wav2vec2 models, request word timestamps
230
- model_type = self.asr_config[language_code].get("model_type", "whisper")
231
- if model_type in ["wav2vec2-bert", "wav2vec2"]:
232
- result = pipeline_obj(
233
- {"sampling_rate": 16000, "raw": normalized_array},
234
- return_timestamps="word"
235
- )
236
- else:
237
- # Whisper model - add anti-hallucination parameters
238
- # Note: HuggingFace pipeline uses different parameter names than OpenAI Whisper
239
- result = pipeline_obj(
240
- {"sampling_rate": 16000, "raw": normalized_array},
241
- return_timestamps=True,
242
- chunk_length_s=30, # Process in 30s chunks
243
- stride_length_s=5 # 5s stride for context
244
- )
245
-
246
- transcription_latency_ms = (time.time() - transcription_start_time) * 1000
247
-
248
- candidate_text = result.get("text", "").strip()
249
- word_timestamps = result.get("chunks", []) if model_type in ["wav2vec2-bert", "wav2vec2"] else None
250
-
251
- # Send debug information if callback provided (for wav2vec2 models only)
252
- if debug_callback and word_timestamps is not None:
253
- debug_info = {
254
- "text": candidate_text,
255
- "timestamps": word_timestamps,
256
- "audio_data": combined_bytes,
257
- "audio_duration": audio_duration_sec,
258
- "model_type": model_type,
259
- "transcription_latency_ms": transcription_latency_ms
260
- }
261
- await debug_callback(debug_info)
262
-
263
- # Filter out common ASR artifacts and very short responses
264
- artifacts = [
265
- "thank you", "thanks", "bye", ".", ",", "?", "!",
266
- "um", "uh", "ah", "hmm", "mm", "mhm",
267
- "you", "the", "a", "an", "and", "but", "or",
268
- "music", "laughter", "applause", "[music]", "[laughter]",
269
- # Common Whisper hallucinations:
270
- "subscribe", "subtitles", "amara", "www", "http",
271
- "please subscribe", "like and subscribe",
272
- "thank you for watching", "don't forget to subscribe",
273
- "[blank_audio]", "[noise]", "[silence]",
274
- ]
275
-
276
- # Check if the result is likely an artifact
277
- is_artifact = (
278
- len(candidate_text) < 3 or # Very short
279
- candidate_text.lower() in artifacts or # Common artifacts
280
- len(candidate_text.split()) == 1 and len(candidate_text) < 6 # Single very short word
281
- )
282
-
283
- if is_artifact:
284
- # Keep the previous cached text instead of updating with artifact
285
- candidate_text = self.candidate_text_cache.get(participant_id, "")
286
-
287
- # Cache the current candidate text
288
- self.candidate_text_cache[participant_id] = candidate_text
289
-
290
- # Force completion if we have a reasonable amount of text and some silence
291
- word_count = len(candidate_text.split()) if candidate_text else 0
292
- if (word_count >= 3 and self.silence_counters[participant_id] >= 2 and
293
- not self.sentence_finalized[participant_id]): # At least 3 words and 2 silent chunks
294
- return await self._finalize_candidate_sentence(
295
- language_code, participant_id, sentence_callback
296
- )
297
-
298
- # Always send progress update
299
- if progress_callback:
300
- await progress_callback(candidate_text, False)
301
-
302
- return candidate_text
303
-
304
- except Exception as e:
305
- print(f"TranscriptionService: Error processing audio chunk: {e}")
306
- import traceback
307
- traceback.print_exc()
308
- # Even on error, try to send cached text
309
- if progress_callback:
310
- cached_text = self.candidate_text_cache.get(participant_id, "")
311
- await progress_callback(cached_text, False)
312
- return self.candidate_text_cache.get(participant_id, "")
313
-
314
- async def _finalize_candidate_sentence(self, language_code: str, participant_id: str,
315
- sentence_callback: Optional[Callable] = None) -> str:
316
- """Finalize the current candidate sentence and clear buffers"""
317
- try:
318
- # Check if sentence was already finalized
319
- if self.sentence_finalized.get(participant_id, False):
320
- print(f"Sentence for participant {participant_id} already finalized, skipping duplicate")
321
- return self.candidate_text_cache.get(participant_id, "")
322
-
323
- final_text = self.candidate_text_cache.get(participant_id, "")
324
- final_audio_array = self.candidate_audio_buffers.get(participant_id, np.array([], dtype=np.float32))
325
-
326
- # Convert audio array to bytes for VAD check and callback
327
- final_audio_bytes = self._audio_array_to_bytes(final_audio_array) if len(final_audio_array) > 0 else b''
328
-
329
- if final_text and len(final_text.strip()) > 0:
330
- # Run VAD check on the final accumulated buffer before sending for translation
331
- if len(final_audio_bytes) > 0:
332
- has_voice_in_final = self.has_meaningful_voice_activity(final_audio_bytes)
333
- if not has_voice_in_final:
334
- print(f"Finalize: No voice activity in final buffer for participant {participant_id}, discarding sentence: '{final_text}'")
335
- # Clear buffers without sending to translation
336
- self.candidate_audio_buffers[participant_id] = np.array([], dtype=np.float32)
337
- self.candidate_text_cache[participant_id] = ""
338
- self.silence_counters[participant_id] = 0
339
- self.sentence_finalized[participant_id] = False
340
- return ""
341
-
342
- # Mark as finalized BEFORE calling the callback to prevent race conditions
343
- self.sentence_finalized[participant_id] = True
344
-
345
- # Send to sentence callback for translation
346
- if sentence_callback and len(final_audio_bytes) > 0:
347
- print(f"Finalizing sentence for participant {participant_id}: '{final_text}'")
348
- await sentence_callback(final_text, final_audio_bytes)
349
-
350
- # Clear buffers for next sentence
351
- self.candidate_audio_buffers[participant_id] = np.array([], dtype=np.float32)
352
- self.candidate_text_cache[participant_id] = ""
353
- self.silence_counters[participant_id] = 0
354
- self.sentence_finalized[participant_id] = False # Reset for next sentence
355
-
356
- return final_text
357
-
358
- except Exception as e:
359
- print(f"Error finalizing sentence: {e}")
360
- import traceback
361
- traceback.print_exc()
362
- # Reset finalized flag on error
363
- self.sentence_finalized[participant_id] = False
364
- return ""
365
-
366
- def has_voice_activity(self, audio_data: bytes, threshold: float = 0.5) -> bool:
367
- """Voice Activity Detection using Silero VAD (with RMS fallback)"""
368
- try:
369
- audio_array = self._bytes_to_audio_array(audio_data)
370
- if len(audio_array) == 0:
371
- print("VAD: No audio array, returning False")
372
- return False
373
-
374
- # Normalize audio to float32 range [-1, 1]
375
- audio_array = audio_array.astype(np.float32)
376
- if np.max(np.abs(audio_array)) > 0:
377
- audio_array /= np.max(np.abs(audio_array))
378
-
379
- # Use Silero VAD if available
380
- if self.vad_available and self.vad_model is not None:
381
- try:
382
- # Silero VAD expects 512 samples (32ms) or 1536 samples (96ms) for 16kHz
383
- # Process audio in chunks and average the probabilities
384
- frame_size = 512 # 32ms at 16kHz
385
- num_samples = len(audio_array)
386
-
387
- # If audio is too short, pad it
388
- if num_samples < frame_size:
389
- audio_array = np.pad(audio_array, (0, frame_size - num_samples), mode='constant')
390
- num_samples = frame_size
391
-
392
- # Process in frames and collect probabilities
393
- speech_probs = []
394
- for i in range(0, num_samples, frame_size):
395
- frame = audio_array[i:i + frame_size]
396
- if len(frame) < frame_size:
397
- # Pad last frame if needed
398
- frame = np.pad(frame, (0, frame_size - len(frame)), mode='constant')
399
-
400
- # Convert to torch tensor
401
- frame_tensor = torch.from_numpy(frame).float()
402
-
403
- # Get speech probability from Silero VAD
404
- with torch.no_grad():
405
- prob = self.vad_model(frame_tensor, self.vad_sample_rate).item()
406
- speech_probs.append(prob)
407
-
408
- # Average probability across all frames
409
- speech_prob = np.mean(speech_probs)
410
- has_voice = speech_prob > threshold
411
-
412
- print(f"VAD: Silero speech_prob={speech_prob:.4f} (avg of {len(speech_probs)} frames), threshold={threshold}, RESULT={has_voice}")
413
-
414
- return has_voice
415
-
416
- except Exception as e:
417
- print(f"Silero VAD error: {e}, falling back to RMS-based VAD")
418
- # Fall through to RMS-based VAD below
419
-
420
- # Fallback: RMS-based VAD (original implementation)
421
- rms_threshold = 0.002
422
- rms = np.sqrt(np.mean(audio_array ** 2))
423
- peak = np.max(np.abs(audio_array))
424
- audio_std = np.std(audio_array)
425
- zero_crossing_rate = np.sum(np.diff(np.sign(audio_array)) != 0) / len(audio_array)
426
-
427
- has_voice_rms = rms > rms_threshold
428
- has_voice_peak = peak > rms_threshold * 3
429
- has_voice_variation = audio_std > rms_threshold * 0.8
430
- has_voice_zcr = zero_crossing_rate > 0.008
431
-
432
- has_voice = has_voice_rms or (has_voice_peak and has_voice_variation) or has_voice_zcr
433
-
434
- print(f"VAD: RMS-based - RMS={rms:.6f}({has_voice_rms}), peak={peak:.6f}({has_voice_peak}), std={audio_std:.6f}({has_voice_variation}), zcr={zero_crossing_rate:.6f}({has_voice_zcr}), RESULT={has_voice}")
435
-
436
- return has_voice
437
-
438
- except Exception as e:
439
- print(f"Error in VAD: {e}")
440
- return True # Default to assuming voice activity on error
441
-
442
- def has_meaningful_voice_activity(self, audio_data: bytes, threshold: float = 0.005) -> bool:
443
- """Stricter VAD check specifically for pre-transcription filtering"""
444
- try:
445
- audio_array = self._bytes_to_audio_array(audio_data)
446
- if len(audio_array) == 0:
447
- return False
448
-
449
- # Normalize audio
450
- audio_array = audio_array.astype(np.float32)
451
- if np.max(np.abs(audio_array)) > 0:
452
- audio_array /= np.max(np.abs(audio_array))
453
-
454
- # Calculate features with higher thresholds for meaningful speech
455
- rms = np.sqrt(np.mean(audio_array ** 2))
456
- peak = np.max(np.abs(audio_array))
457
- audio_std = np.std(audio_array)
458
- zero_crossing_rate = np.sum(np.diff(np.sign(audio_array)) != 0) / len(audio_array)
459
-
460
- # Higher thresholds for meaningful speech detection
461
- has_meaningful_voice = (
462
- rms > threshold and
463
- peak > threshold * 2 and
464
- audio_std > threshold * 0.5 and
465
- zero_crossing_rate > 0.015 # Higher ZCR threshold for meaningful speech
466
- )
467
-
468
- return has_meaningful_voice
469
-
470
- except Exception as e:
471
- print(f"Error in meaningful VAD: {e}")
472
- return False # Default to no meaningful voice on error
473
-
474
- async def force_complete_sentence(self, participant_id: str, language_code: str, sentence_callback: Optional[Callable] = None) -> str:
475
- """Force complete any pending sentence for a participant"""
476
- try:
477
- # Check if sentence was already finalized
478
- if self.sentence_finalized.get(participant_id, False):
479
- print(f"Force completion: Sentence for participant {participant_id} already finalized, skipping")
480
- return ""
481
-
482
- if participant_id in self.candidate_text_cache:
483
- cached_text = self.candidate_text_cache[participant_id]
484
-
485
- if cached_text and len(cached_text.strip()) > 0:
486
- result = await self._finalize_candidate_sentence(language_code, participant_id, sentence_callback)
487
- return result
488
-
489
- return ""
490
-
491
- except Exception as e:
492
- print(f"Error in force_complete_sentence: {e}")
493
- import traceback
494
- traceback.print_exc()
495
- return ""
496
-
497
- @track_performance("transcription", "transcribe_audio")
498
- async def transcribe_audio(self, audio_data: bytes, language_code: str, callback: Optional[Callable] = None) -> str:
499
- """Transcribe audio data to text"""
500
- try:
501
- # Check for voice activity before running ASR
502
- has_voice = self.has_voice_activity(audio_data)
503
- if not has_voice:
504
- print(f"ASR: No voice activity detected in audio data, skipping transcription")
505
- return ""
506
-
507
- await self.ensure_model_loaded(language_code)
508
-
509
- if language_code not in self.asr_pipelines:
510
- raise ValueError(f"ASR model not available for language: {language_code}")
511
-
512
- # Convert audio bytes to numpy array
513
- audio_array = self._bytes_to_audio_array(audio_data)
514
-
515
- print(f"ASR: Running transcription with {len(audio_array)/16000:.2f}s of audio")
516
- # Transcribe
517
- pipeline_obj = self.asr_pipelines[language_code]
518
- result = pipeline_obj({"sampling_rate": 16000, "raw": audio_array})
519
-
520
- text = result.get("text", "")
521
-
522
- if callback:
523
- await callback(text)
524
-
525
- return text
526
-
527
- except Exception as e:
528
- print(f"TranscriptionService: Transcription error: {e}")
529
- import traceback
530
- traceback.print_exc()
531
- return ""
532
-
533
- def _bytes_to_audio_array(self, audio_data: bytes) -> np.ndarray:
534
- """Convert audio bytes to numpy array (supports WAV, WebM/Opus)"""
535
- try:
536
- # Detect format by checking magic bytes
537
- is_webm = audio_data[:4] == b'\x1a\x45\xdf\xa3' # WebM/Matroska magic bytes
538
- is_wav = audio_data[:4] == b'RIFF'
539
-
540
- import sys
541
- print(f"_bytes_to_audio_array: length={len(audio_data)}, first 4 bytes={audio_data[:4]}, is_wav={is_wav}", flush=True)
542
- sys.stdout.flush()
543
-
544
- # Handle raw PCM (16-bit, 48kHz from extendable-media-recorder)
545
- # This is the most common case now that we strip WAV headers in frontend
546
- if not is_wav and not is_webm and len(audio_data) > 0:
547
- try:
548
- # Assume 16-bit PCM at 48kHz (browser's native rate)
549
- audio_array = np.frombuffer(audio_data, dtype=np.int16)
550
-
551
- # Check if this looks like valid audio data (not NaN, reasonable range)
552
- if len(audio_array) > 0 and not np.isnan(audio_array).any():
553
- print(f"Raw PCM: {len(audio_array)} samples, assuming 48kHz 16-bit", flush=True)
554
-
555
- # Convert to float32 and normalize
556
- audio_float = audio_array.astype(np.float32) / 32768.0
557
-
558
- # Resample from 48kHz to 16kHz
559
- import librosa
560
- audio_array = librosa.resample(audio_float, orig_sr=48000, target_sr=16000)
561
- print(f"Resampled to 16kHz: {len(audio_array)} samples", flush=True)
562
-
563
- return audio_array
564
- except Exception as pcm_error:
565
- print(f"TranscriptionService: Raw PCM decoding error: {pcm_error}", flush=True)
566
- # Fall through to other methods
567
-
568
- if is_webm:
569
- # Decode WebM/Opus using pydub (requires ffmpeg)
570
- try:
571
- from pydub import AudioSegment
572
- audio_io = io.BytesIO(audio_data)
573
- audio_segment = AudioSegment.from_file(audio_io, format="webm")
574
-
575
- # Convert to mono 16kHz
576
- audio_segment = audio_segment.set_channels(1)
577
- audio_segment = audio_segment.set_frame_rate(16000)
578
-
579
- # Convert to numpy array
580
- samples = np.array(audio_segment.get_array_of_samples(), dtype=np.int16)
581
- # Normalize to float32 [-1, 1]
582
- audio_array = samples.astype(np.float32) / 32768.0
583
- return audio_array
584
- except Exception as webm_error:
585
- print(f"TranscriptionService: WebM decoding error: {webm_error}")
586
- # Fall through to other methods
587
-
588
- if is_wav:
589
- # Decode WAV format
590
- try:
591
- audio_io = io.BytesIO(audio_data)
592
- with wave.open(audio_io, 'rb') as wav_file:
593
- sample_rate = wav_file.getframerate()
594
- channels = wav_file.getnchannels()
595
- sample_width = wav_file.getsampwidth()
596
-
597
- print(f"WAV format: {sample_rate}Hz, {channels} channel(s), {sample_width*8}-bit", flush=True)
598
-
599
- frames = wav_file.readframes(-1)
600
- audio_array = np.frombuffer(frames, dtype=np.int16)
601
-
602
- # Resample if needed
603
- if sample_rate != 16000:
604
- print(f"WARNING: Resampling from {sample_rate}Hz to 16000Hz", flush=True)
605
- import librosa
606
- # Convert to float first
607
- audio_float = audio_array.astype(np.float32) / 32768.0
608
- # Resample
609
- audio_array = librosa.resample(audio_float, orig_sr=sample_rate, target_sr=16000)
610
- print(f"Resampled: {len(audio_array)} samples at 16kHz", flush=True)
611
- else:
612
- # Convert to float32 and normalize
613
- audio_array = audio_array.astype(np.float32) / 32768.0
614
-
615
- print(f"Returning audio array: {len(audio_array)} samples", flush=True)
616
- return audio_array
617
- except Exception as wav_error:
618
- print(f"TranscriptionService: WAV decoding error: {wav_error}")
619
- import traceback
620
- traceback.print_exc()
621
-
622
- # Fallback: assume raw float32 audio data
623
- try:
624
- audio_array = np.frombuffer(audio_data, dtype=np.float32)
625
- return audio_array
626
- except Exception:
627
- pass
628
-
629
- # Last resort: return empty array
630
- return np.array([], dtype=np.float32)
631
-
632
- except Exception as e:
633
- print(f"TranscriptionService: Audio conversion error: {e}")
634
- return np.array([], dtype=np.float32)
635
-
636
- def _audio_array_to_bytes(self, audio_array: np.ndarray) -> bytes:
637
- """Convert numpy audio array back to WAV bytes for storage"""
638
- try:
639
- # Ensure float32 format
640
- if audio_array.dtype != np.float32:
641
- audio_array = audio_array.astype(np.float32)
642
-
643
- # Convert to 16-bit PCM for WAV storage
644
- audio_int16 = (audio_array * 32767).astype(np.int16)
645
-
646
- # Create WAV bytes
647
- wav_buffer = io.BytesIO()
648
- with wave.open(wav_buffer, 'wb') as wav_file:
649
- wav_file.setnchannels(1) # Mono
650
- wav_file.setsampwidth(2) # 16-bit
651
- wav_file.setframerate(16000) # 16kHz
652
- wav_file.writeframes(audio_int16.tobytes())
653
-
654
- return wav_buffer.getvalue()
655
-
656
- except Exception as e:
657
- print(f"Error converting audio array to bytes: {e}")
658
- return b''
659
-
660
- def clear_participant_buffers(self, participant_id: str):
661
- """Clear all buffers for a participant (e.g., when they stop speaking or disconnect)"""
662
- if participant_id in self.candidate_audio_buffers:
663
- del self.candidate_audio_buffers[participant_id]
664
- if participant_id in self.candidate_text_cache:
665
- del self.candidate_text_cache[participant_id]
666
- if participant_id in self.silence_counters:
667
- del self.silence_counters[participant_id]
668
- if participant_id in self.sentence_finalized:
669
- del self.sentence_finalized[participant_id]
670
-
671
- async def load_remaining_models_in_background(self):
672
- """Load all remaining ASR models in the background after startup"""
673
- try:
674
- print("ASR: Starting background loading of additional language models...")
675
- for lang_code in self.asr_config.keys():
676
- if lang_code not in self.preload_languages and lang_code not in self.asr_pipelines:
677
- try:
678
- print(f"ASR: Background loading model for {lang_code}...")
679
- self.models_loading_status[lang_code] = "loading"
680
-
681
- model_config = self.asr_config[lang_code]
682
- # Build pipeline parameters
683
- pipeline_params = {
684
- "task": "automatic-speech-recognition",
685
- "model": model_config["model_repo"],
686
- "device": self.device,
687
- "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32
688
- }
689
-
690
- # Add trust_remote_code if specified
691
- if model_config.get("trust_remote_code", False):
692
- pipeline_params["trust_remote_code"] = True
693
-
694
- pipeline_obj = pipeline(**pipeline_params)
695
- self.asr_pipelines[lang_code] = pipeline_obj
696
- self.models_loading_status[lang_code] = "loaded"
697
- print(f"ASR: Successfully loaded model for {lang_code} in background")
698
-
699
- # Add a small delay between loading models to prevent overwhelming the system
700
- await asyncio.sleep(2)
701
- except Exception as e:
702
- print(f"ASR: Failed to load model for {lang_code} in background: {e}")
703
- self.models_loading_status[lang_code] = "failed"
704
-
705
- print("ASR: Background loading of all language models complete")
706
- print(f"ASR: Loaded models: {list(self.asr_pipelines.keys())}")
707
- except Exception as e:
708
- print(f"ASR: Error in background model loading: {e}")
709
-
710
- def start_background_loading(self):
711
- """Start background loading of models as a non-blocking task"""
712
- if self.background_loading_task is None:
713
- self.background_loading_task = asyncio.create_task(self.load_remaining_models_in_background())
714
- print("ASR: Background model loading task started")
715
-
716
- async def cleanup(self):
717
- """Cleanup resources"""
718
- # Cancel background loading if still running
719
- if self.background_loading_task and not self.background_loading_task.done():
720
- self.background_loading_task.cancel()
721
- try:
722
- await self.background_loading_task
723
- except asyncio.CancelledError:
724
- pass
725
-
726
- self.asr_pipelines.clear()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/services/transcription_service_onnx.py DELETED
@@ -1,682 +0,0 @@
1
- import asyncio
2
- import io
3
- import wave
4
- import numpy as np
5
- from typing import Dict, Optional, Callable
6
- from collections import OrderedDict
7
- import onnxruntime as ort
8
- from transformers import AutoProcessor, WhisperProcessor
9
- from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
10
- import os
11
- from app.models import LanguageCode
12
-
13
- class ONNXTranscriptionService:
14
- def __init__(self):
15
- self.asr_models: Dict[str, any] = {}
16
- self.processors: Dict[str, any] = {}
17
- self.max_asr_models = 2 # Memory management - keep max 2 models loaded
18
- self.model_cache = OrderedDict() # LRU cache for models
19
-
20
- # GPU optimization - detect and configure providers
21
- available_providers = ort.get_available_providers()
22
- print(f"ONNX ASR: Available providers: {available_providers}")
23
-
24
- if 'CUDAExecutionProvider' in available_providers:
25
- # Configure CUDA provider with optimizations
26
- cuda_provider_options = {
27
- 'device_id': 0,
28
- 'arena_extend_strategy': 'kNextPowerOfTwo',
29
- 'gpu_mem_limit': int(0.8 * 1024 * 1024 * 1024), # 80% of GPU memory
30
- 'cudnn_conv_algo_search': 'EXHAUSTIVE',
31
- 'do_copy_in_default_stream': True,
32
- 'enable_tracing': True, # Enable tracing for better diagnostics
33
- }
34
-
35
- # Include TensorRT if available, then CUDA, then CPU
36
- provider_list = []
37
- if 'TensorrtExecutionProvider' in available_providers:
38
- provider_list.append('TensorrtExecutionProvider')
39
- provider_list.append(('CUDAExecutionProvider', cuda_provider_options))
40
- provider_list.append('CPUExecutionProvider')
41
-
42
- self.providers = provider_list
43
- print(f"ONNX ASR: Using GPU acceleration with providers: {[p[0] if isinstance(p, tuple) else p for p in provider_list]}")
44
- print(f"ONNX ASR: GPU memory limit: {cuda_provider_options['gpu_mem_limit'] // (1024**3)}GB")
45
- else:
46
- self.providers = ['CPUExecutionProvider']
47
- print("ONNX ASR: CUDA not available, using CPU execution")
48
-
49
- print(f"ONNX ASR: Configured providers: {[p[0] if isinstance(p, tuple) else p for p in self.providers]}")
50
-
51
- # ONNX Model configurations - using pre-converted ONNX models from HuggingFace
52
- self.asr_config = {
53
- "eng": {"model_repo": "mutisya/whisper-medium-en-onnx", "model_type": "whisper", "use_onnx": True}, # Pre-converted ONNX model
54
- "swa": {"model_repo": "mutisya/w2v-bert-2.0-asr-onnx-swh-v25-37-1", "model_type": "wav2vec2-bert", "use_onnx": True},
55
- "kik": {"model_repo": "mutisya/w2v-bert-2.0-asr-onnx-kik-v25-37-1", "model_type": "wav2vec2-bert", "use_onnx": True},
56
- "kam": {"model_repo": "mutisya/w2v-bert-2.0-asr-onnx-kam-v25-37-1", "model_type": "wav2vec2-bert", "use_onnx": True},
57
- "mer": {"model_repo": "mutisya/w2v-bert-2.0-asr-onnx-mer-v25-37-1", "model_type": "wav2vec2-bert", "use_onnx": True},
58
- "luo": {"model_repo": "mutisya/w2v-bert-2.0-asr-onnx-luo-v25-37-1", "model_type": "wav2vec2-bert", "use_onnx": True},
59
- "som": {"model_repo": "mutisya/w2v-bert-2.0-asr-onnx-som-v25-37-1", "model_type": "wav2vec2-bert", "use_onnx": True}
60
- }
61
-
62
- # Alternative model configurations for different performance tiers
63
- self.alternative_models = {
64
- "eng_small": {"model_repo": "mutisya/whisper-small-en-onnx", "model_type": "whisper", "use_onnx": True},
65
- "eng_base": {"model_repo": "mutisya/whisper-base-en-onnx", "model_type": "whisper", "use_onnx": True},
66
- "eng_medium": {"model_repo": "mutisya/whisper-medium-en-onnx", "model_type": "whisper", "use_onnx": True}
67
- }
68
-
69
- self.preload_languages = ["eng"]
70
-
71
- # Current model performance mode (small, base, medium)
72
- # Can be configured via environment variable WHISPER_MODEL_SIZE
73
- self.performance_mode = os.getenv("WHISPER_MODEL_SIZE", "medium").lower()
74
-
75
- # Enhanced audio buffering for VAD-based sentence detection
76
- self.candidate_audio_buffers: Dict[str, bytes] = {}
77
- self.candidate_text_cache: Dict[str, str] = {}
78
- self.silence_counters: Dict[str, int] = {}
79
- self.sentence_finalized: Dict[str, bool] = {}
80
-
81
- # VAD parameters
82
- self.silence_threshold = 2
83
- self.min_sentence_length = 0.03
84
-
85
- def set_performance_mode(self, mode: str):
86
- """Set the performance mode for English models (small, base, medium)"""
87
- if mode in ["small", "base", "medium"]:
88
- self.performance_mode = mode
89
- # Update the English model configuration based on performance mode
90
- if f"eng_{mode}" in self.alternative_models:
91
- self.asr_config["eng"] = self.alternative_models[f"eng_{mode}"]
92
- # Clear cached English model to force reload with new configuration
93
- if "eng" in self.model_cache:
94
- del self.model_cache["eng"]
95
- if "eng" in self.asr_models:
96
- del self.asr_models["eng"]
97
- if "eng" in self.processors:
98
- del self.processors["eng"]
99
- print(f"Performance mode set to {mode}. English model will be reloaded on next use.")
100
- else:
101
- print(f"Warning: No model configuration found for performance mode {mode}")
102
- else:
103
- print(f"Invalid performance mode: {mode}. Must be one of: small, base, medium")
104
-
105
- async def initialize(self):
106
- """Initialize ASR models for preloaded languages"""
107
- print(f"ONNX ASR: Initializing with providers: {self.providers}")
108
-
109
- # Apply performance mode to English model configuration
110
- if self.performance_mode in ["small", "base", "medium"]:
111
- if f"eng_{self.performance_mode}" in self.alternative_models:
112
- self.asr_config["eng"] = self.alternative_models[f"eng_{self.performance_mode}"]
113
- print(f"Using Whisper {self.performance_mode} model for English")
114
- else:
115
- print(f"Warning: Performance mode {self.performance_mode} not available, using default medium")
116
-
117
- for lang_code in self.preload_languages:
118
- if lang_code in self.asr_config:
119
- try:
120
- await self.ensure_model_loaded(lang_code)
121
- except Exception as e:
122
- print(f"Failed to load ASR model for {lang_code}: {e}")
123
-
124
- async def ensure_model_loaded(self, language_code: str):
125
- """Load ASR model for language if not already loaded with LRU cache"""
126
- if language_code in self.model_cache:
127
- # Move to end (most recently used)
128
- self.model_cache.move_to_end(language_code)
129
- return
130
-
131
- if language_code not in self.asr_config:
132
- raise ValueError(f"Language {language_code} not supported")
133
-
134
- model_config = self.asr_config[language_code]
135
-
136
- # Check if we need to evict old models
137
- while len(self.model_cache) >= self.max_asr_models:
138
- # Remove least recently used model
139
- old_lang, _ = self.model_cache.popitem(last=False)
140
- if old_lang in self.asr_models:
141
- del self.asr_models[old_lang]
142
- if old_lang in self.processors:
143
- del self.processors[old_lang]
144
- print(f"ONNX ASR: Evicted model for {old_lang} (LRU cache)")
145
-
146
- try:
147
- if model_config.get("use_onnx", False):
148
- # Load ONNX model
149
- print(f"ONNX ASR: Loading ONNX model for {language_code}")
150
-
151
- # Special handling for Whisper models
152
- if model_config.get("model_type") == "whisper":
153
- print(f"ONNX ASR: Loading Whisper ONNX model from {model_config['model_repo']}")
154
-
155
- # Get authentication token for private repos
156
- import os
157
- auth_token = os.getenv('HUGGING_FACE_HUB_TOKEN') or os.getenv('HF_TOKEN')
158
-
159
- # Load pre-converted Whisper ONNX model using Optimum
160
- load_kwargs = {
161
- # export=False because we're using pre-converted models
162
- "export": False,
163
- # use_cache=True because our models now include past key value variants for optimization
164
- "use_cache": True,
165
- # Add authentication token for private repos
166
- "token": auth_token
167
- }
168
-
169
- # Configure providers - pass all available providers to Optimum
170
- provider_names = [p[0] if isinstance(p, tuple) else p for p in self.providers]
171
- load_kwargs["providers"] = provider_names
172
- print(f"ONNX ASR: Whisper using providers: {provider_names}")
173
-
174
- # Add subfolder if specified (for models that store ONNX in subfolders)
175
- if "subfolder" in model_config:
176
- load_kwargs["subfolder"] = model_config["subfolder"]
177
-
178
- model = ORTModelForSpeechSeq2Seq.from_pretrained(
179
- model_config["model_repo"],
180
- **load_kwargs
181
- )
182
-
183
- # Load Whisper processor with authentication token
184
- processor = WhisperProcessor.from_pretrained(
185
- model_config["model_repo"],
186
- token=auth_token
187
- )
188
-
189
- # Configure for English transcription
190
- model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
191
- language="en",
192
- task="transcribe"
193
- )
194
-
195
- self.asr_models[language_code] = model
196
- self.processors[language_code] = processor
197
-
198
- print(f"ONNX ASR: Successfully loaded Whisper ONNX model for {language_code}")
199
-
200
- else:
201
- # Original wav2vec2-bert model loading logic
202
- # Create ONNX session with optimizations and verbose logging
203
- session_options = ort.SessionOptions()
204
- session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
205
-
206
- # Enable verbose logging to diagnose operator assignments
207
- session_options.log_severity_level = 1 # WARNING level for detailed logs
208
- session_options.logid = "ONNX_ASR" # Prefix for log identification
209
-
210
- # Use configured providers with optimizations
211
- providers = self.providers
212
- print(f"ONNX ASR: wav2vec2-bert using providers: {[p[0] if isinstance(p, tuple) else p for p in providers]}")
213
-
214
- # Get authentication token for private repos
215
- import os
216
- auth_token = os.getenv('HUGGING_FACE_HUB_TOKEN') or os.getenv('HF_TOKEN')
217
-
218
- # Download model files from HuggingFace Hub with authentication
219
- from huggingface_hub import hf_hub_download
220
- onnx_path = hf_hub_download(
221
- repo_id=model_config["model_repo"],
222
- filename="model.onnx",
223
- token=auth_token
224
- )
225
-
226
- session = ort.InferenceSession(onnx_path, providers=providers, sess_options=session_options)
227
-
228
- # Load processor for preprocessing with authentication
229
- processor = AutoProcessor.from_pretrained(
230
- model_config["model_repo"],
231
- token=auth_token
232
- )
233
-
234
- self.asr_models[language_code] = session
235
- self.processors[language_code] = processor
236
-
237
- print(f"ONNX ASR: Successfully loaded ONNX model for {language_code}")
238
-
239
- else:
240
- # This service is ONNX-only - no PyTorch fallback
241
- raise ValueError(f"Language {language_code} is not configured for ONNX models. Set 'use_onnx': True in config.")
242
-
243
- # Add to cache
244
- self.model_cache[language_code] = True
245
-
246
- except Exception as e:
247
- print(f"Failed to load ASR model for {language_code}: {e}")
248
- raise
249
-
250
- async def process_audio_chunk(self, audio_data: bytes, language_code: str, participant_id: str,
251
- has_voice_activity: bool = True,
252
- progress_callback: Optional[Callable] = None,
253
- sentence_callback: Optional[Callable] = None) -> str:
254
- """Process audio chunk with VAD-based sentence detection using ONNX models"""
255
- try:
256
- # Initialize buffers if needed
257
- if participant_id not in self.candidate_audio_buffers:
258
- self.candidate_audio_buffers[participant_id] = b''
259
- self.candidate_text_cache[participant_id] = ""
260
- self.silence_counters[participant_id] = 0
261
- self.sentence_finalized[participant_id] = False
262
-
263
- # Convert current chunk to numpy array for processing
264
- current_chunk_array = self._bytes_to_audio_array(audio_data)
265
- if len(current_chunk_array) == 0:
266
- return self.candidate_text_cache.get(participant_id, "")
267
-
268
- # Normalize the audio chunk
269
- current_chunk_array = current_chunk_array.astype(np.float32)
270
- if np.max(np.abs(current_chunk_array)) > 0:
271
- current_chunk_array /= np.max(np.abs(current_chunk_array))
272
-
273
- # Get existing accumulated audio array
274
- existing_buffer = self.candidate_audio_buffers[participant_id]
275
- if len(existing_buffer) > 0:
276
- existing_array = self._bytes_to_audio_array(existing_buffer)
277
- if len(existing_array) > 0:
278
- combined_array = np.concatenate([existing_array, current_chunk_array])
279
- else:
280
- combined_array = current_chunk_array
281
- else:
282
- combined_array = current_chunk_array
283
-
284
- # Convert back to bytes for storage
285
- combined_bytes = self._audio_array_to_bytes(combined_array)
286
- self.candidate_audio_buffers[participant_id] = combined_bytes
287
-
288
- # Update silence counter based on voice activity
289
- if not has_voice_activity:
290
- self.silence_counters[participant_id] += 1
291
- else:
292
- self.silence_counters[participant_id] = 0
293
-
294
- # Check if we should finalize sentence due to prolonged silence
295
- should_finalize = (self.silence_counters[participant_id] >= self.silence_threshold and
296
- len(combined_array) > 0 and
297
- not self.sentence_finalized[participant_id])
298
-
299
- if should_finalize:
300
- return await self._finalize_candidate_sentence(
301
- language_code, participant_id, sentence_callback
302
- )
303
-
304
- # Always run transcription on the accumulated audio
305
- audio_duration_sec = len(combined_array) / 16000.0 # 16kHz sample rate
306
-
307
- if audio_duration_sec < 0.1: # Very short minimum
308
- if progress_callback:
309
- cached_text = self.candidate_text_cache.get(participant_id, "")
310
- await progress_callback(cached_text, False)
311
- return self.candidate_text_cache.get(participant_id, "")
312
-
313
- # Force finalization if buffer gets too long
314
- if audio_duration_sec > 15.0 and not self.sentence_finalized[participant_id]:
315
- return await self._finalize_candidate_sentence(
316
- language_code, participant_id, sentence_callback
317
- )
318
-
319
- # Run voice activity detection on the accumulated audio before transcription
320
- has_voice_in_buffer = self.has_meaningful_voice_activity(combined_bytes)
321
-
322
- if not has_voice_in_buffer:
323
- if progress_callback:
324
- cached_text = self.candidate_text_cache.get(participant_id, "")
325
- await progress_callback(cached_text, False)
326
- return self.candidate_text_cache.get(participant_id, "")
327
-
328
- # Run transcription
329
- await self.ensure_model_loaded(language_code)
330
-
331
- # Double-check voice activity before running expensive ASR
332
- has_voice_for_asr = self.has_voice_activity(combined_bytes)
333
- if not has_voice_for_asr:
334
- print(f"ONNX ASR: No voice activity detected, skipping ASR execution for {participant_id}")
335
- if progress_callback:
336
- cached_text = self.candidate_text_cache.get(participant_id, "")
337
- await progress_callback(cached_text, False)
338
- return self.candidate_text_cache.get(participant_id, "")
339
-
340
- if language_code not in self.asr_models:
341
- raise ValueError(f"ASR model not available for language: {language_code}")
342
-
343
- print(f"ONNX ASR: Running transcription for {participant_id} with {audio_duration_sec:.2f}s of audio")
344
-
345
- # Run ONNX inference (this service is ONNX-only)
346
- model_config = self.asr_config[language_code]
347
- if not model_config.get("use_onnx", False):
348
- raise ValueError(f"Language {language_code} is not configured for ONNX. This service only supports ONNX models.")
349
-
350
- # ONNX inference
351
- text = await self._run_onnx_inference(combined_array, language_code)
352
-
353
- # Filter out common ASR artifacts
354
- artifacts = [
355
- "thank you", "thanks", "bye", ".", ",", "?", "!",
356
- "um", "uh", "ah", "hmm", "mm", "mhm",
357
- "you", "the", "a", "an", "and", "but", "or",
358
- "music", "laughter", "applause", "[music]", "[laughter]",
359
- ]
360
-
361
- # Check if the result is likely an artifact
362
- is_artifact = (
363
- len(text) < 3 or
364
- text.lower() in artifacts or
365
- len(text.split()) == 1 and len(text) < 6
366
- )
367
-
368
- if is_artifact:
369
- text = self.candidate_text_cache.get(participant_id, "")
370
-
371
- # Cache the current candidate text
372
- self.candidate_text_cache[participant_id] = text
373
-
374
- # Force completion if we have reasonable text and some silence
375
- word_count = len(text.split()) if text else 0
376
- if (word_count >= 3 and self.silence_counters[participant_id] >= 2 and
377
- not self.sentence_finalized[participant_id]):
378
- return await self._finalize_candidate_sentence(
379
- language_code, participant_id, sentence_callback
380
- )
381
-
382
- # Always send progress update
383
- if progress_callback:
384
- await progress_callback(text, False)
385
-
386
- return text
387
-
388
- except Exception as e:
389
- print(f"ONNX TranscriptionService: Error processing audio chunk: {e}")
390
- import traceback
391
- traceback.print_exc()
392
- if progress_callback:
393
- cached_text = self.candidate_text_cache.get(participant_id, "")
394
- await progress_callback(cached_text, False)
395
- return self.candidate_text_cache.get(participant_id, "")
396
-
397
- async def _run_onnx_inference(self, audio_array: np.ndarray, language_code: str) -> str:
398
- """Run ONNX inference for speech recognition"""
399
- try:
400
- model = self.asr_models[language_code]
401
- processor = self.processors[language_code]
402
- model_config = self.asr_config[language_code]
403
-
404
- # Check if this is a Whisper model
405
- if model_config.get("model_type") == "whisper":
406
- # Whisper-specific processing using Optimum
407
- import torch
408
-
409
- # Process audio input for Whisper
410
- inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt")
411
-
412
- # Generate transcription using the ORTModelForSpeechSeq2Seq
413
- predicted_ids = model.generate(inputs.input_features, max_length=448)
414
-
415
- # Decode the generated IDs
416
- transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
417
-
418
- return transcription[0].strip() if transcription else ""
419
- else:
420
- # Original wav2vec2-bert processing
421
- session = model
422
-
423
- # Preprocess audio
424
- inputs = processor(audio_array, sampling_rate=16000, return_tensors="np")
425
-
426
- # Get input names for ONNX session
427
- input_names = [inp.name for inp in session.get_inputs()]
428
-
429
- # Prepare inputs for ONNX
430
- onnx_inputs = {}
431
- for name in input_names:
432
- if name in inputs:
433
- onnx_inputs[name] = inputs[name]
434
- elif name == "input_values" and "input_features" in inputs:
435
- onnx_inputs[name] = inputs["input_features"]
436
- elif name == "attention_mask" and "attention_mask" in inputs:
437
- onnx_inputs[name] = inputs["attention_mask"]
438
-
439
- # Run ONNX inference
440
- outputs = session.run(None, onnx_inputs)
441
-
442
- # Post-process outputs (assuming CTC decoding)
443
- logits = outputs[0] # First output should be logits
444
-
445
- # Simple greedy CTC decoding
446
- predicted_ids = np.argmax(logits, axis=-1)
447
-
448
- # Decode using processor
449
- text = processor.batch_decode(predicted_ids)[0]
450
-
451
- return text.strip()
452
-
453
- except Exception as e:
454
- print(f"ONNX ASR: Inference error: {e}")
455
- import traceback
456
- traceback.print_exc()
457
- return ""
458
-
459
- async def _finalize_candidate_sentence(self, language_code: str, participant_id: str,
460
- sentence_callback: Optional[Callable] = None) -> str:
461
- """Finalize the current candidate sentence and clear buffers"""
462
- try:
463
- if self.sentence_finalized.get(participant_id, False):
464
- print(f"Sentence for participant {participant_id} already finalized, skipping duplicate")
465
- return self.candidate_text_cache.get(participant_id, "")
466
-
467
- final_text = self.candidate_text_cache.get(participant_id, "")
468
- final_audio_bytes = self.candidate_audio_buffers.get(participant_id, b'')
469
-
470
- if final_text and len(final_text.strip()) > 0:
471
- self.sentence_finalized[participant_id] = True
472
-
473
- if sentence_callback and len(final_audio_bytes) > 0:
474
- print(f"Finalizing sentence for participant {participant_id}: '{final_text}'")
475
- await sentence_callback(final_text, final_audio_bytes)
476
-
477
- # Clear buffers for next sentence
478
- self.candidate_audio_buffers[participant_id] = b''
479
- self.candidate_text_cache[participant_id] = ""
480
- self.silence_counters[participant_id] = 0
481
- self.sentence_finalized[participant_id] = False
482
-
483
- return final_text
484
-
485
- except Exception as e:
486
- print(f"Error finalizing sentence: {e}")
487
- import traceback
488
- traceback.print_exc()
489
- self.sentence_finalized[participant_id] = False
490
- return ""
491
-
492
- def has_voice_activity(self, audio_data: bytes, threshold: float = 0.0005) -> bool:
493
- """Enhanced VAD based on audio analysis"""
494
- try:
495
- audio_array = self._bytes_to_audio_array(audio_data)
496
- if len(audio_array) == 0:
497
- return False
498
-
499
- # Normalize audio
500
- audio_array = audio_array.astype(np.float32)
501
- if np.max(np.abs(audio_array)) > 0:
502
- audio_array /= np.max(np.abs(audio_array))
503
-
504
- # Calculate multiple features for better VAD
505
- rms = np.sqrt(np.mean(audio_array ** 2))
506
- peak = np.max(np.abs(audio_array))
507
- audio_std = np.std(audio_array)
508
- zero_crossing_rate = np.sum(np.diff(np.sign(audio_array)) != 0) / len(audio_array)
509
-
510
- # Voice activity detection
511
- has_voice_rms = rms > threshold
512
- has_voice_peak = peak > threshold * 3
513
- has_voice_variation = audio_std > threshold * 0.8
514
- has_voice_zcr = zero_crossing_rate > 0.008
515
-
516
- has_voice = has_voice_rms or (has_voice_peak and has_voice_variation) or has_voice_zcr
517
-
518
- return has_voice
519
-
520
- except Exception as e:
521
- print(f"Error in VAD: {e}")
522
- return True
523
-
524
- def has_meaningful_voice_activity(self, audio_data: bytes, threshold: float = 0.002) -> bool:
525
- """Stricter VAD check for pre-transcription filtering"""
526
- try:
527
- audio_array = self._bytes_to_audio_array(audio_data)
528
- if len(audio_array) == 0:
529
- return False
530
-
531
- # Normalize audio
532
- audio_array = audio_array.astype(np.float32)
533
- if np.max(np.abs(audio_array)) > 0:
534
- audio_array /= np.max(np.abs(audio_array))
535
-
536
- # Calculate features with higher thresholds
537
- rms = np.sqrt(np.mean(audio_array ** 2))
538
- peak = np.max(np.abs(audio_array))
539
- audio_std = np.std(audio_array)
540
- zero_crossing_rate = np.sum(np.diff(np.sign(audio_array)) != 0) / len(audio_array)
541
-
542
- # Higher thresholds for meaningful speech detection
543
- has_meaningful_voice = (
544
- rms > threshold and
545
- peak > threshold * 2 and
546
- audio_std > threshold * 0.5 and
547
- zero_crossing_rate > 0.015
548
- )
549
-
550
- return has_meaningful_voice
551
-
552
- except Exception as e:
553
- print(f"Error in meaningful VAD: {e}")
554
- return False
555
-
556
- async def force_complete_sentence(self, participant_id: str, language_code: str, sentence_callback: Optional[Callable] = None) -> str:
557
- """Force complete any pending sentence for a participant"""
558
- try:
559
- if self.sentence_finalized.get(participant_id, False):
560
- print(f"Force completion: Sentence for participant {participant_id} already finalized")
561
- return ""
562
-
563
- if participant_id in self.candidate_text_cache:
564
- cached_text = self.candidate_text_cache[participant_id]
565
-
566
- if cached_text and len(cached_text.strip()) > 0:
567
- result = await self._finalize_candidate_sentence(language_code, participant_id, sentence_callback)
568
- return result
569
-
570
- return ""
571
-
572
- except Exception as e:
573
- print(f"Error in force_complete_sentence: {e}")
574
- import traceback
575
- traceback.print_exc()
576
- return ""
577
-
578
- async def transcribe_audio(self, audio_data: bytes, language_code: str, callback: Optional[Callable] = None) -> str:
579
- """Transcribe audio data to text using ONNX models"""
580
- try:
581
- # Check for voice activity before running ASR
582
- has_voice = self.has_voice_activity(audio_data)
583
- if not has_voice:
584
- print(f"ONNX ASR: No voice activity detected, skipping transcription")
585
- return ""
586
-
587
- await self.ensure_model_loaded(language_code)
588
-
589
- if language_code not in self.asr_models:
590
- raise ValueError(f"ASR model not available for language: {language_code}")
591
-
592
- # Convert audio bytes to numpy array
593
- audio_array = self._bytes_to_audio_array(audio_data)
594
-
595
- print(f"ONNX ASR: Running transcription with {len(audio_array)/16000:.2f}s of audio")
596
-
597
- # Run ONNX inference (this service is ONNX-only)
598
- model_config = self.asr_config[language_code]
599
- if not model_config.get("use_onnx", False):
600
- raise ValueError(f"Language {language_code} is not configured for ONNX. This service only supports ONNX models.")
601
-
602
- # ONNX inference
603
- text = await self._run_onnx_inference(audio_array, language_code)
604
-
605
- if callback:
606
- await callback(text)
607
-
608
- return text
609
-
610
- except Exception as e:
611
- print(f"ONNX TranscriptionService: Transcription error: {e}")
612
- import traceback
613
- traceback.print_exc()
614
- return ""
615
-
616
- def _bytes_to_audio_array(self, audio_data: bytes) -> np.ndarray:
617
- """Convert audio bytes to numpy array"""
618
- try:
619
- # Try to decode as WAV
620
- try:
621
- audio_io = io.BytesIO(audio_data)
622
- with wave.open(audio_io, 'rb') as wav_file:
623
- frames = wav_file.readframes(-1)
624
- audio_array = np.frombuffer(frames, dtype=np.int16)
625
- # Convert to float32 and normalize
626
- audio_array = audio_array.astype(np.float32) / 32768.0
627
- return audio_array
628
- except Exception:
629
- pass
630
-
631
- # Fallback: assume raw float32 audio data
632
- try:
633
- audio_array = np.frombuffer(audio_data, dtype=np.float32)
634
- return audio_array
635
- except Exception:
636
- pass
637
-
638
- return np.array([], dtype=np.float32)
639
-
640
- except Exception as e:
641
- print(f"ONNX TranscriptionService: Audio conversion error: {e}")
642
- return np.array([], dtype=np.float32)
643
-
644
- def _audio_array_to_bytes(self, audio_array: np.ndarray) -> bytes:
645
- """Convert numpy audio array back to WAV bytes for storage"""
646
- try:
647
- if audio_array.dtype != np.float32:
648
- audio_array = audio_array.astype(np.float32)
649
-
650
- # Convert to 16-bit PCM for WAV storage
651
- audio_int16 = (audio_array * 32767).astype(np.int16)
652
-
653
- # Create WAV bytes
654
- wav_buffer = io.BytesIO()
655
- with wave.open(wav_buffer, 'wb') as wav_file:
656
- wav_file.setnchannels(1) # Mono
657
- wav_file.setsampwidth(2) # 16-bit
658
- wav_file.setframerate(16000) # 16kHz
659
- wav_file.writeframes(audio_int16.tobytes())
660
-
661
- return wav_buffer.getvalue()
662
-
663
- except Exception as e:
664
- print(f"Error converting audio array to bytes: {e}")
665
- return b''
666
-
667
- def clear_participant_buffers(self, participant_id: str):
668
- """Clear all buffers for a participant"""
669
- if participant_id in self.candidate_audio_buffers:
670
- del self.candidate_audio_buffers[participant_id]
671
- if participant_id in self.candidate_text_cache:
672
- del self.candidate_text_cache[participant_id]
673
- if participant_id in self.silence_counters:
674
- del self.silence_counters[participant_id]
675
- if participant_id in self.sentence_finalized:
676
- del self.sentence_finalized[participant_id]
677
-
678
- async def cleanup(self):
679
- """Cleanup resources"""
680
- self.asr_models.clear()
681
- self.processors.clear()
682
- self.model_cache.clear()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/services/transcription_service_onnx_optimized.py DELETED
@@ -1,251 +0,0 @@
1
- import asyncio
2
- import io
3
- import wave
4
- import numpy as np
5
- from typing import Dict, Optional, Callable
6
- from collections import OrderedDict
7
- import onnxruntime as ort
8
- from transformers import AutoProcessor, WhisperProcessor
9
- from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
10
- import os
11
- from app.models import LanguageCode
12
-
13
- class OptimizedONNXTranscriptionService:
14
- """
15
- Optimized ONNX Transcription Service that uses pre-converted ONNX models
16
- instead of performing runtime conversion from PyTorch models.
17
-
18
- Benefits:
19
- - Faster container startup (no conversion time)
20
- - Reduced memory usage during initialization
21
- - More predictable deployment times
22
- - Better resource utilization in production
23
- """
24
-
25
- def __init__(self):
26
- self.asr_models: Dict[str, any] = {}
27
- self.processors: Dict[str, any] = {}
28
- self.max_asr_models = 2 # Memory management - keep max 2 models loaded
29
- self.model_cache = OrderedDict() # LRU cache for models
30
-
31
- # GPU optimization
32
- self.providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if ort.get_available_providers().__contains__('CUDAExecutionProvider') else ['CPUExecutionProvider']
33
-
34
- # OPTIMIZED ONNX Model configurations - using pre-converted models
35
- self.asr_config = {
36
- # English: Use pre-converted ONNX model (no runtime conversion!)
37
- "eng": {
38
- "model_repo": "mutisya/whisper-medium-en-onnx", # Pre-converted ONNX model
39
- "model_type": "whisper",
40
- "use_onnx": True,
41
- "export": False # ⭐ KEY CHANGE: No runtime export needed!
42
- },
43
-
44
- # African languages: Already using ONNX models
45
- "swa": {"model_repo": "mutisya/w2v-bert-2.0-asr-onnx-swh-v25-37-1", "model_type": "wav2vec2-bert", "use_onnx": True},
46
- "kik": {"model_repo": "mutisya/w2v-bert-2.0-asr-onnx-kik-v25-37-1", "model_type": "wav2vec2-bert", "use_onnx": True},
47
- "kam": {"model_repo": "mutisya/w2v-bert-2.0-asr-onnx-kam-v25-37-1", "model_type": "wav2vec2-bert", "use_onnx": True},
48
- "mer": {"model_repo": "mutisya/w2v-bert-2.0-asr-onnx-mer-v25-37-1", "model_type": "wav2vec2-bert", "use_onnx": True},
49
- "luo": {"model_repo": "mutisya/w2v-bert-2.0-asr-onnx-luo-v25-37-1", "model_type": "wav2vec2-bert", "use_onnx": True},
50
- "som": {"model_repo": "mutisya/w2v-bert-2.0-asr-onnx-som-v25-37-1", "model_type": "wav2vec2-bert", "use_onnx": True}
51
- }
52
-
53
- self.preload_languages = ["eng"]
54
-
55
- # Enhanced audio buffering for VAD-based sentence detection
56
- self.candidate_audio_buffers: Dict[str, bytes] = {}
57
- self.candidate_text_cache: Dict[str, str] = {}
58
- self.silence_counters: Dict[str, int] = {}
59
- self.sentence_finalized: Dict[str, bool] = {}
60
-
61
- # VAD parameters
62
- self.silence_threshold = 2
63
- self.min_sentence_length = 0.03
64
-
65
- async def initialize(self):
66
- """Initialize ASR models for preloaded languages"""
67
- print(f"🚀 Optimized ONNX ASR: Initializing with providers: {self.providers}")
68
- print(f"📈 Performance Improvement: Using pre-converted ONNX models (no runtime conversion)")
69
-
70
- for lang_code in self.preload_languages:
71
- if lang_code in self.asr_config:
72
- try:
73
- start_time = asyncio.get_event_loop().time()
74
- await self.ensure_model_loaded(lang_code)
75
- end_time = asyncio.get_event_loop().time()
76
- print(f"⚡ Model loading time for {lang_code}: {end_time - start_time:.2f}s")
77
- except Exception as e:
78
- print(f"❌ Failed to load ASR model for {lang_code}: {e}")
79
-
80
- async def ensure_model_loaded(self, language_code: str):
81
- """Load ASR model for language if not already loaded with LRU cache"""
82
- if language_code in self.model_cache:
83
- # Move to end (most recently used)
84
- self.model_cache.move_to_end(language_code)
85
- return
86
-
87
- if language_code not in self.asr_config:
88
- raise ValueError(f"Language {language_code} not supported")
89
-
90
- model_config = self.asr_config[language_code]
91
-
92
- # Check if we need to evict old models
93
- while len(self.model_cache) >= self.max_asr_models:
94
- # Remove least recently used model
95
- old_lang, _ = self.model_cache.popitem(last=False)
96
- if old_lang in self.asr_models:
97
- del self.asr_models[old_lang]
98
- if old_lang in self.processors:
99
- del self.processors[old_lang]
100
- print(f"🗑️ ONNX ASR: Evicted model for {old_lang} (LRU cache)")
101
-
102
- try:
103
- if model_config.get("use_onnx", False):
104
- # Load ONNX model
105
- print(f"📥 ONNX ASR: Loading ONNX model for {language_code}")
106
-
107
- # Special handling for Whisper models
108
- if model_config.get("model_type") == "whisper":
109
- print(f"🎙️ ONNX ASR: Loading pre-converted Whisper ONNX model from {model_config['model_repo']}")
110
-
111
- # Load pre-converted Whisper ONNX model using Optimum
112
- load_kwargs = {
113
- # Note: No 'export' parameter needed since model is already in ONNX format
114
- # This is the key optimization - no runtime conversion!
115
- }
116
-
117
- # Add subfolder if specified (for models that store ONNX in subfolders)
118
- if "subfolder" in model_config:
119
- load_kwargs["subfolder"] = model_config["subfolder"]
120
-
121
- # ⭐ KEY OPTIMIZATION: No export flag needed for pre-converted models
122
- # The old code had: if model_config.get("export", False): load_kwargs["export"] = True
123
- # Now we skip this entirely since the model is already in ONNX format
124
-
125
- model = ORTModelForSpeechSeq2Seq.from_pretrained(
126
- model_config["model_repo"],
127
- **load_kwargs
128
- )
129
-
130
- # Load Whisper processor
131
- processor = WhisperProcessor.from_pretrained(model_config["model_repo"])
132
-
133
- # Configure for English transcription
134
- model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
135
- language="en",
136
- task="transcribe"
137
- )
138
-
139
- self.asr_models[language_code] = model
140
- self.processors[language_code] = processor
141
-
142
- print(f"✅ ONNX ASR: Successfully loaded pre-converted Whisper ONNX model for {language_code}")
143
-
144
- else:
145
- # Original wav2vec2-bert model loading logic (unchanged)
146
- # Create ONNX session with optimizations
147
- session_options = ort.SessionOptions()
148
- session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
149
-
150
- # Enable parallel execution
151
- session_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
152
-
153
- model_path = model_config["model_repo"]
154
-
155
- try:
156
- # Try to load from HuggingFace directly
157
- from huggingface_hub import hf_hub_download
158
- model_file = hf_hub_download(repo_id=model_path, filename="model.onnx")
159
-
160
- # Create ONNX Runtime session
161
- session = ort.InferenceSession(
162
- model_file,
163
- session_options,
164
- providers=self.providers
165
- )
166
-
167
- # Load processor/tokenizer
168
- processor = AutoProcessor.from_pretrained(model_path)
169
-
170
- self.asr_models[language_code] = session
171
- self.processors[language_code] = processor
172
-
173
- print(f"✅ ONNX ASR: Successfully loaded {model_config['model_type']} ONNX model for {language_code}")
174
-
175
- except Exception as e:
176
- print(f"❌ Error loading ONNX model {model_path}: {e}")
177
- raise
178
-
179
- else:
180
- raise ValueError(f"Non-ONNX models not supported in optimized service")
181
-
182
- # Add to cache
183
- self.model_cache[language_code] = True
184
-
185
- except Exception as e:
186
- print(f"❌ Error loading model for {language_code}: {e}")
187
- raise
188
-
189
- # Rest of the methods remain the same as the original transcription service
190
- # (transcribe_audio, process_audio_chunk, etc.)
191
- # ... [Include all other methods from the original service]
192
-
193
- async def transcribe_audio(self, participant_id: str, audio_data: bytes, language_code: str = "eng") -> Optional[str]:
194
- """Transcribe audio using ONNX models"""
195
- try:
196
- await self.ensure_model_loaded(language_code)
197
-
198
- if language_code not in self.asr_models or language_code not in self.processors:
199
- raise ValueError(f"Model not loaded for language: {language_code}")
200
-
201
- model = self.asr_models[language_code]
202
- processor = self.processors[language_code]
203
-
204
- # Convert audio bytes to numpy array
205
- audio_io = io.BytesIO(audio_data)
206
- with wave.open(audio_io, 'rb') as wav_file:
207
- frames = wav_file.readframes(-1)
208
- sample_rate = wav_file.getframerate()
209
- audio_np = np.frombuffer(frames, dtype=np.int16).astype(np.float32) / 32768.0
210
-
211
- # Get model configuration
212
- model_config = self.asr_config[language_code]
213
-
214
- if model_config.get("model_type") == "whisper":
215
- # Process with Whisper ONNX model
216
- inputs = processor(audio_np, sampling_rate=sample_rate, return_tensors="pt")
217
-
218
- with torch.no_grad():
219
- predicted_ids = model.generate(**inputs)
220
- transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
221
-
222
- return transcription.strip()
223
-
224
- else:
225
- # Process with wav2vec2-bert ONNX model
226
- inputs = processor(audio_np, sampling_rate=sample_rate, return_tensors="np")
227
-
228
- # Run ONNX inference
229
- ort_inputs = {model.get_inputs()[0].name: inputs.input_values}
230
- ort_outputs = model.run(None, ort_inputs)
231
-
232
- # Decode results
233
- predicted_ids = np.argmax(ort_outputs[0], axis=-1)
234
- transcription = processor.decode(predicted_ids[0])
235
-
236
- return transcription.strip()
237
-
238
- except Exception as e:
239
- print(f"❌ Transcription error for {participant_id}: {e}")
240
- return None
241
-
242
- def get_performance_stats(self) -> Dict[str, any]:
243
- """Get performance statistics for monitoring"""
244
- return {
245
- "loaded_models": list(self.model_cache.keys()),
246
- "cache_size": len(self.model_cache),
247
- "max_cache_size": self.max_asr_models,
248
- "providers": self.providers,
249
- "optimization_enabled": True,
250
- "runtime_conversion": False # Key metric: no runtime conversion
251
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/services/translation_service.py DELETED
@@ -1,151 +0,0 @@
1
- import asyncio
2
- from typing import Dict, Optional
3
- from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
4
- import torch
5
- import nltk
6
- from app.models import LanguageCode
7
- from app.services.quantization_utils import apply_dynamic_int8_quantization, get_quantization_stats
8
-
9
- # FLORES-200 language codes mapping
10
- FLORES_CODES = {
11
- "English": "eng_Latn",
12
- "eng": "eng_Latn",
13
- "Swahili": "swh_Latn",
14
- "swa": "swh_Latn",
15
- "Kikuyu": "kik_Latn",
16
- "kik": "kik_Latn",
17
- "Kamba": "kam_Latn",
18
- "kam": "kam_Latn",
19
- "Kimeru": "mer_Latn",
20
- "mer": "mer_Latn",
21
- "Luo": "luo_Latn",
22
- "luo": "luo_Latn",
23
- "Somali": "som_Latn",
24
- "som": "som_Latn",
25
- }
26
-
27
- class TranslationService:
28
- def __init__(self, enable_quantization: bool = True):
29
- self.translation_pipeline = None
30
- self.device = 0 if torch.cuda.is_available() else -1
31
- self.model_path = "mutisya/nllb_600m-en-kik-kam-luo-mer-som-swh-drL-24_5-filtered-v24_28_4"
32
- self.enable_quantization = enable_quantization
33
-
34
- async def initialize(self):
35
- """Initialize translation model"""
36
- try:
37
- # Download NLTK data with better error handling
38
- try:
39
- nltk.download("punkt", quiet=True)
40
- nltk.download('punkt_tab', quiet=True)
41
- except Exception as nltk_error:
42
- print(f"Warning: NLTK data download failed: {nltk_error}")
43
- # Continue anyway, sentence tokenization might still work
44
-
45
- # Load translation model with explicit model kwargs for newer transformers
46
- print(f"Loading translation model: {self.model_path}")
47
- model = AutoModelForSeq2SeqLM.from_pretrained(
48
- self.model_path,
49
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
50
- )
51
- tokenizer = AutoTokenizer.from_pretrained(self.model_path)
52
-
53
- # Apply quantization if enabled
54
- if self.enable_quantization:
55
- try:
56
- print("Applying INT8 quantization to translation model...")
57
- model = apply_dynamic_int8_quantization(model, "translation")
58
- stats = get_quantization_stats(model)
59
- print(f"✓ Translation model quantized: {stats['quantized_layers']}/{stats['total_layers']} layers, {stats['size_mb']:.2f} MB")
60
- except Exception as e:
61
- print(f"Warning: Could not quantize translation model: {e}")
62
- print(f"Continuing with unquantized model")
63
-
64
- self.translation_pipeline = pipeline(
65
- 'translation',
66
- model=model,
67
- tokenizer=tokenizer,
68
- device=self.device,
69
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
70
- )
71
-
72
- except Exception as e:
73
- print(f"Failed to initialize translation service: {e}")
74
- raise
75
-
76
- async def translate_text(self, text: str, source_lang: str, target_lang: str) -> str:
77
- """Translate text from source language to target language"""
78
- print(f"=== TRANSLATION REQUEST ===")
79
- print(f"Text: '{text}'")
80
- print(f"Source: {source_lang}")
81
- print(f"Target: {target_lang}")
82
-
83
- if not self.translation_pipeline:
84
- print("TRANSLATION ERROR: Translation service not initialized")
85
- raise RuntimeError("Translation service not initialized")
86
-
87
- if not text or not text.strip():
88
- print("TRANSLATION ERROR: Empty text provided")
89
- return ""
90
-
91
- try:
92
- # Get FLORES codes
93
- src_code = FLORES_CODES.get(source_lang, "eng_Latn")
94
- tgt_code = FLORES_CODES.get(target_lang, "eng_Latn")
95
-
96
- print(f"FLORES codes: {source_lang} -> {src_code}, {target_lang} -> {tgt_code}")
97
-
98
- # Skip translation if same language
99
- if src_code == tgt_code:
100
- print("TRANSLATION SKIPPED: Same source and target language")
101
- return text
102
-
103
- # Tokenize into sentences for better translation
104
- sentences = nltk.sent_tokenize(text)
105
- translated_sentences = []
106
-
107
- print(f"Translating {len(sentences)} sentences...")
108
-
109
- for i, sentence in enumerate(sentences):
110
- if sentence.strip():
111
- print(f"Translating sentence {i+1}: '{sentence}'")
112
-
113
- result = self.translation_pipeline(
114
- sentence,
115
- src_lang=src_code,
116
- tgt_lang=tgt_code
117
- )
118
-
119
- translated = result[0]['translation_text']
120
- print(f"Translation result: '{translated}'")
121
-
122
- # Preserve punctuation and capitalization
123
- if sentence.strip().endswith(".") and not translated.strip().endswith("."):
124
- translated += "."
125
-
126
- if sentence.strip()[0].isupper() and translated.strip():
127
- translated = translated[0].upper() + translated[1:]
128
-
129
- translated_sentences.append(translated)
130
-
131
- final_translation = " ".join(translated_sentences)
132
-
133
- # Preserve paragraph breaks
134
- if text.endswith(".\n\n"):
135
- final_translation += ".\n\n"
136
-
137
- print(f"FINAL TRANSLATION: '{final_translation}'")
138
- print(f"=== TRANSLATION COMPLETE ===")
139
-
140
- return final_translation
141
-
142
- except Exception as e:
143
- print(f"TRANSLATION ERROR: {e}")
144
- import traceback
145
- traceback.print_exc()
146
- return text # Return original text if translation fails
147
-
148
- async def cleanup(self):
149
- """Cleanup resources"""
150
- self.translation_pipeline = None
151
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/services/translation_service_onnx.py DELETED
@@ -1,268 +0,0 @@
1
- import asyncio
2
- from typing import Dict, Optional
3
- from transformers import AutoTokenizer, pipeline
4
- from optimum.onnxruntime import ORTModelForSeq2SeqLM
5
- import nltk
6
- from app.models import LanguageCode
7
-
8
- # FLORES-200 language codes mapping
9
- FLORES_CODES = {
10
- "English": "eng_Latn",
11
- "eng": "eng_Latn",
12
- "Swahili": "swh_Latn",
13
- "swa": "swh_Latn",
14
- "Kikuyu": "kik_Latn",
15
- "kik": "kik_Latn",
16
- "Kamba": "kam_Latn",
17
- "kam": "kam_Latn",
18
- "Kimeru": "mer_Latn",
19
- "mer": "mer_Latn",
20
- "Luo": "luo_Latn",
21
- "luo": "luo_Latn",
22
- "Somali": "som_Latn",
23
- "som": "som_Latn",
24
- }
25
-
26
- class ONNXTranslationService:
27
- def __init__(self):
28
- self.model = None
29
- self.tokenizer = None
30
- self.translation_pipeline = None
31
-
32
- # Use ONNX optimized NLLB model (FP32 format with separate encoder/decoder)
33
- self.model_repo = "mutisya/nllb-translation-onnx-v25-37-1"
34
-
35
- async def initialize(self):
36
- """Initialize ONNX translation model using optimum.onnxruntime"""
37
- try:
38
- print("ONNX Translation: Initializing translation service with ONNX Runtime...")
39
- print(f"ONNX Translation: Loading model from {self.model_repo}")
40
-
41
- # Check available providers for GPU detection
42
- import onnxruntime as ort
43
- available_providers = ort.get_available_providers()
44
- print(f"ONNX Translation: Available providers: {available_providers}")
45
-
46
- # Download NLTK data with better error handling
47
- try:
48
- nltk.download("punkt", quiet=True)
49
- nltk.download('punkt_tab', quiet=True)
50
- except Exception as nltk_error:
51
- print(f"Warning: NLTK data download failed: {nltk_error}")
52
-
53
- # Get authentication token for private repo
54
- import os
55
- auth_token = os.getenv('HUGGING_FACE_HUB_TOKEN') or os.getenv('HF_TOKEN')
56
-
57
- # Configure providers list for optimal performance
58
- print("ONNX Translation: Configuring execution providers...")
59
- if 'CUDAExecutionProvider' in available_providers:
60
- # Use both CUDA and CPU providers to eliminate assignment warnings
61
- providers_list = ['CUDAExecutionProvider', 'CPUExecutionProvider']
62
- primary_provider = 'CUDAExecutionProvider'
63
- print(f"ONNX Translation: Using providers: {providers_list} (primary: {primary_provider})")
64
- else:
65
- providers_list = ['CPUExecutionProvider']
66
- primary_provider = 'CPUExecutionProvider'
67
- print(f"ONNX Translation: Using CPU-only providers: {providers_list}")
68
-
69
- # Load ONNX model using optimum (handles separate encoder/decoder files)
70
- # Configure session options for optimal CUDA performance
71
- import onnxruntime as ort
72
- session_options = ort.SessionOptions()
73
- session_options.log_severity_level = 1 # WARNING level for detailed logs
74
- session_options.logid = "ONNX_Translation"
75
-
76
- # Enable all graph optimizations to reduce memcpy operations
77
- session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
78
-
79
- # Optimize threading for better GPU utilization
80
- session_options.inter_op_num_threads = 1 # Reduce CPU thread contention
81
- session_options.intra_op_num_threads = 1 # Focus on GPU execution
82
-
83
- # Note: enable_cuda_graph not available in this ONNX Runtime version
84
-
85
- # Configure provider options with performance optimizations for CUDA
86
- provider_options = []
87
- if primary_provider == 'CUDAExecutionProvider':
88
- cuda_options = {
89
- 'device_id': 0,
90
- 'arena_extend_strategy': 'kNextPowerOfTwo',
91
- 'gpu_mem_limit': int(0.6 * 1024 * 1024 * 1024), # 60% of GPU memory for translation
92
- 'cudnn_conv_algo_search': 'EXHAUSTIVE',
93
- 'cudnn_conv_use_max_workspace': '1', # Enable max workspace for fp16 tensor cores
94
- 'do_copy_in_default_stream': True,
95
- 'enable_skip_layer_norm_strict_mode': False, # Better performance for transformers
96
- 'prefer_nhwc': True, # Optimize data layout for GPU
97
- }
98
- # Configure providers with options
99
- provider_options = [
100
- ('CUDAExecutionProvider', cuda_options),
101
- ('CPUExecutionProvider', {})
102
- ]
103
-
104
- # Try with optimized provider configuration and session options
105
- try:
106
- print("ONNX Translation: Attempting optimized provider configuration...")
107
- self.model = ORTModelForSeq2SeqLM.from_pretrained(
108
- self.model_repo,
109
- token=auth_token,
110
- providers=provider_options if provider_options else providers_list, # Use provider options or list
111
- session_options=session_options, # Add session options
112
- )
113
- print(f"ONNX Translation: Model loaded successfully with providers: {providers_list}")
114
-
115
- # Check what providers the model is actually using
116
- if hasattr(self.model, 'providers'):
117
- print(f"ONNX Translation: Model is using providers: {self.model.providers}")
118
- if hasattr(self.model, 'device'):
119
- print(f"ONNX Translation: Model device: {self.model.device}")
120
-
121
- except Exception as e1:
122
- print(f"ONNX Translation: Optimized provider approach failed: {e1}")
123
- print("ONNX Translation: Falling back to simple provider list...")
124
-
125
- # Fallback: Try with simple provider list (no options)
126
- try:
127
- self.model = ORTModelForSeq2SeqLM.from_pretrained(
128
- self.model_repo,
129
- token=auth_token,
130
- providers=providers_list, # Simple provider list
131
- session_options=session_options,
132
- )
133
- print(f"ONNX Translation: Model loaded successfully with simple providers: {providers_list}")
134
-
135
- # Check what the model is actually using
136
- if hasattr(self.model, 'providers'):
137
- print(f"ONNX Translation: Model is using providers: {self.model.providers}")
138
- if hasattr(self.model, 'device'):
139
- print(f"ONNX Translation: Model device: {self.model.device}")
140
-
141
- except Exception as e2:
142
- print(f"ONNX Translation: Simple provider approach failed: {e2}")
143
- print("ONNX Translation: Falling back to auto-detect...")
144
-
145
- # Final fallback: Let model auto-detect
146
- self.model = ORTModelForSeq2SeqLM.from_pretrained(
147
- self.model_repo,
148
- token=auth_token
149
- # Not passing provider, letting it auto-detect based on device
150
- )
151
- print(f"ONNX Translation: Model loaded successfully with auto-detection")
152
-
153
- # Check what the model is actually using
154
- if hasattr(self.model, 'providers'):
155
- print(f"ONNX Translation: Model auto-selected providers: {self.model.providers}")
156
- if hasattr(self.model, 'device'):
157
- print(f"ONNX Translation: Model device: {self.model.device}")
158
-
159
- # Load tokenizer
160
- self.tokenizer = AutoTokenizer.from_pretrained(
161
- self.model_repo,
162
- token=auth_token
163
- )
164
-
165
- # Create translation pipeline
166
- # For ONNX models, we should specify device to ensure pipeline uses GPU
167
- # Use the same provider detection as the model to ensure consistency
168
- device = 0 if primary_provider == 'CUDAExecutionProvider' else -1
169
- print(f"ONNX Translation: Setting pipeline device to: {device} ({'GPU' if device >= 0 else 'CPU'})")
170
- print(f"ONNX Translation: Pipeline will use device based on primary provider: {primary_provider}")
171
-
172
- self.translation_pipeline = pipeline(
173
- "translation",
174
- model=self.model,
175
- tokenizer=self.tokenizer,
176
- device=device
177
- )
178
-
179
- print("ONNX Translation: Successfully initialized ONNX translation model")
180
-
181
- except Exception as e:
182
- print(f"Failed to initialize ONNX translation service: {e}")
183
- print("ONNX translation model is not available. Please ensure the model repository exists and contains the required ONNX files.")
184
- import traceback
185
- traceback.print_exc()
186
- raise RuntimeError(f"ONNX translation model unavailable at {self.model_repo}: {e}")
187
-
188
- async def translate_text(self, text: str, source_lang: str, target_lang: str) -> str:
189
- """Translate text from source language to target language using ONNX"""
190
- print(f"=== ONNX TRANSLATION REQUEST ===")
191
- print(f"Text: '{text}'")
192
- print(f"Source: {source_lang}")
193
- print(f"Target: {target_lang}")
194
-
195
- if not self.translation_pipeline:
196
- print("ONNX TRANSLATION ERROR: Translation service not initialized")
197
- raise RuntimeError("ONNX Translation service not initialized")
198
-
199
- if not text or not text.strip():
200
- print("ONNX TRANSLATION ERROR: Empty text provided")
201
- return ""
202
-
203
- try:
204
- # Get FLORES codes
205
- src_code = FLORES_CODES.get(source_lang, "eng_Latn")
206
- tgt_code = FLORES_CODES.get(target_lang, "eng_Latn")
207
-
208
- print(f"FLORES codes: {source_lang} -> {src_code}, {target_lang} -> {tgt_code}")
209
-
210
- # Skip translation if same language
211
- if src_code == tgt_code:
212
- print("ONNX TRANSLATION SKIPPED: Same source and target language")
213
- return text
214
-
215
- # Tokenize into sentences for better translation
216
- sentences = nltk.sent_tokenize(text)
217
- translated_sentences = []
218
-
219
- print(f"Translating {len(sentences)} sentences with ONNX...")
220
-
221
- for i, sentence in enumerate(sentences):
222
- if sentence.strip():
223
- print(f"Translating sentence {i+1}: '{sentence}'")
224
-
225
- # Use the pipeline for translation
226
- result = self.translation_pipeline(
227
- sentence.strip(),
228
- src_lang=src_code,
229
- tgt_lang=tgt_code,
230
- max_length=512
231
- )
232
-
233
- translated = result[0]['translation_text']
234
- print(f"ONNX Translation result: '{translated}'")
235
-
236
- # Preserve punctuation and capitalization
237
- if sentence.strip().endswith(".") and not translated.strip().endswith("."):
238
- translated += "."
239
-
240
- if sentence.strip() and sentence.strip()[0].isupper() and translated.strip():
241
- translated = translated[0].upper() + translated[1:]
242
-
243
- translated_sentences.append(translated)
244
-
245
- final_translation = " ".join(translated_sentences)
246
-
247
- # Preserve paragraph breaks
248
- if text.endswith(".\n\n"):
249
- final_translation += ".\n\n"
250
-
251
- print(f"ONNX FINAL TRANSLATION: '{final_translation}'")
252
- print(f"=== ONNX TRANSLATION COMPLETE ===")
253
-
254
- return final_translation
255
-
256
- except Exception as e:
257
- print(f"ONNX TRANSLATION ERROR: {e}")
258
- import traceback
259
- traceback.print_exc()
260
- raise RuntimeError(f"Translation failed: {e}")
261
-
262
-
263
- async def cleanup(self):
264
- """Cleanup resources"""
265
- self.model = None
266
- self.tokenizer = None
267
- self.translation_pipeline = None
268
- print("ONNX Translation: Translation service cleaned up")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/services/tts_service.py DELETED
@@ -1,541 +0,0 @@
1
- import asyncio
2
- import io
3
- import wave
4
- import numpy as np
5
- import subprocess
6
- from typing import Dict, Optional
7
- from transformers import pipeline
8
- import torch
9
- import os
10
- from app.services.quantization_utils import apply_dynamic_int8_quantization, get_quantization_stats
11
-
12
- class TTSService:
13
- def __init__(self, enable_quantization: bool = True):
14
- self.tts_pipelines: Dict[str, any] = {}
15
- self.device = 0 if torch.cuda.is_available() else -1
16
- self.enable_quantization = enable_quantization
17
-
18
- # Check if espeak is available
19
- self.espeak_available = self._check_espeak_availability()
20
-
21
- # TTS model configurations from your original code
22
- self.tts_config = {
23
- "kik": {"model_repo": "mutisya/vits_kik_drL_24_5-v24_27_1_f", "model_type": "vits"},
24
- "luo": {"model_repo": "mutisya/vits_luo_drL_24_5-v24_27_1_f", "model_type": "vits"},
25
- "kam": {"model_repo": "mutisya/vits_kam_drL_24_5-v24_27_1_f", "model_type": "vits"},
26
- "mer": {"model_repo": "mutisya/vits_mer_drL_24_5-v24_27_1_f", "model_type": "vits"},
27
- "som": {"model_repo": "mutisya/vits_som_drL_24_5-v24_27_1_m", "model_type": "vits"},
28
- "swa": {"model_repo": "mutisya/vits_swh_biblica-v24_27_1_m", "model_type": "vits"},
29
- "eng": {"model_repo": "kakao-enterprise/vits-ljs", "model_type": "vits"},
30
- }
31
-
32
- # Alternative TTS models that don't require espeak (fallback)
33
- self.fallback_tts_config = {
34
- "eng": {"model_repo": "microsoft/speecht5_tts", "model_type": "speecht5"},
35
- "swa": {"model_repo": "facebook/mms-tts-swh", "model_type": "mms"},
36
- "som": {"model_repo": "facebook/mms-tts-som", "model_type": "mms"},
37
- }
38
-
39
- self.preload_languages = ["kik", "swa"]
40
- self.background_loading_task = None
41
- self.models_loading_status = {}
42
-
43
- def _check_espeak_availability(self) -> bool:
44
- """Check if espeak is available on the system"""
45
- try:
46
- result = subprocess.run(['espeak', '--version'],
47
- capture_output=True, text=True, timeout=5)
48
- if result.returncode == 0:
49
- print("TTS: espeak is available")
50
- return True
51
- else:
52
- print("TTS: espeak command failed")
53
- return False
54
- except (subprocess.TimeoutExpired, FileNotFoundError, Exception) as e:
55
- print(f"TTS: espeak not available: {e}")
56
- return False
57
-
58
- async def initialize(self):
59
- """Initialize TTS models for preloaded languages"""
60
- print("TTS: Initializing TTS service...")
61
- print(f"TTS: espeak available: {self.espeak_available}")
62
-
63
- for lang_code in self.preload_languages:
64
- await self.ensure_model_loaded(lang_code)
65
-
66
- def _load_and_quantize_tts_pipeline(self, lang_code: str, model_repo: str, model_type: str = "vits"):
67
- """Load TTS pipeline and optionally apply INT8 quantization"""
68
- print(f"TTS: Loading model for {lang_code}: {model_repo}")
69
-
70
- pipeline_obj = pipeline(
71
- "text-to-speech",
72
- model=model_repo,
73
- device=self.device
74
- )
75
-
76
- # Apply quantization if enabled
77
- if self.enable_quantization:
78
- try:
79
- # Get the underlying model from the pipeline
80
- model = pipeline_obj.model
81
-
82
- print(f"TTS: Applying INT8 quantization to {lang_code} model...")
83
- quantized_model = apply_dynamic_int8_quantization(model, model_type)
84
-
85
- # Replace the model in the pipeline
86
- pipeline_obj.model = quantized_model
87
-
88
- # Print quantization stats
89
- stats = get_quantization_stats(quantized_model)
90
- print(f"✓ TTS {lang_code} model quantized: {stats['quantized_layers']}/{stats['total_layers']} layers, {stats['size_mb']:.2f} MB")
91
-
92
- except Exception as e:
93
- print(f"TTS: Warning - Could not quantize {lang_code} model: {e}")
94
- print(f"TTS: Continuing with unquantized model")
95
-
96
- return pipeline_obj
97
-
98
-
99
- async def ensure_model_loaded(self, language_code: str):
100
- """Load TTS model for language if not already loaded"""
101
- if language_code in self.tts_pipelines:
102
- return
103
-
104
- # First try to load primary model if espeak is available
105
- if self.espeak_available and language_code in self.tts_config:
106
- try:
107
- model_config = self.tts_config[language_code]
108
- pipeline_obj = self._load_and_quantize_tts_pipeline(
109
- language_code,
110
- model_config["model_repo"],
111
- model_config.get("model_type", "vits")
112
- )
113
- self.tts_pipelines[language_code] = pipeline_obj
114
- print(f"TTS: Loaded primary TTS model for {language_code}")
115
- return
116
- except Exception as e:
117
- print(f"TTS: Failed to load primary TTS model for {language_code}: {e}")
118
- # Continue to try fallback models
119
-
120
- # Try fallback models if primary failed or espeak not available
121
- if language_code in self.fallback_tts_config:
122
- try:
123
- model_config = self.fallback_tts_config[language_code]
124
-
125
- if model_config["model_type"] == "speecht5":
126
- # Special handling for SpeechT5
127
- from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
128
- import torch
129
-
130
- processor = SpeechT5Processor.from_pretrained(model_config["model_repo"])
131
- model = SpeechT5ForTextToSpeech.from_pretrained(model_config["model_repo"])
132
- vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
133
-
134
- # Create a custom pipeline-like object
135
- class SpeechT5Pipeline:
136
- def __init__(self, processor, model, vocoder):
137
- self.processor = processor
138
- self.model = model
139
- self.vocoder = vocoder
140
-
141
- def __call__(self, text):
142
- inputs = self.processor(text=text, return_tensors="pt")
143
- # Use default speaker embeddings
144
- import datasets
145
- embeddings_dataset = datasets.load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
146
- speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
147
-
148
- speech = self.model.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=self.vocoder)
149
-
150
- return {
151
- "audio": speech.numpy(),
152
- "sampling_rate": 16000
153
- }
154
-
155
- pipeline_obj = SpeechT5Pipeline(processor, model, vocoder)
156
- else:
157
- # Standard pipeline for MMS models
158
- pipeline_obj = pipeline(
159
- "text-to-speech",
160
- model=model_config["model_repo"],
161
- device=self.device
162
- )
163
-
164
- self.tts_pipelines[language_code] = pipeline_obj
165
- print(f"TTS: Loaded fallback TTS model for {language_code}")
166
- return
167
-
168
- except Exception as e:
169
- print(f"TTS: Failed to load fallback TTS model for {language_code}: {e}")
170
-
171
- print(f"TTS: No TTS model available for language: {language_code}")
172
-
173
- async def generate_speech(self, text: str, language_code: str, output_format: str = "webm") -> Optional[bytes]:
174
- """Generate speech audio from text
175
-
176
- Args:
177
- text: Text to convert to speech
178
- language_code: Language code for TTS model
179
- output_format: Output format - "webm" (default, web-compatible) or "wav" (Android-compatible)
180
-
181
- Returns:
182
- Audio bytes in the requested format, or None if generation fails
183
- """
184
- try:
185
- print(f"=== TTS GENERATION REQUEST ===")
186
- print(f"Text: '{text}'")
187
- print(f"Language: {language_code}")
188
- print(f"Output format: {output_format}")
189
-
190
- # Input validation: Check for invalid or problematic text
191
- if not text or not text.strip():
192
- print("TTS: Empty or whitespace-only text, skipping TTS generation")
193
- return None
194
-
195
- # Check for very short text that might cause issues
196
- clean_text = text.strip()
197
- if len(clean_text) <= 2 and clean_text in [".", ",", "!", "?", ":", ";", "-"]:
198
- print(f"TTS: Text '{clean_text}' is too short or punctuation-only, skipping TTS generation")
199
- return None
200
-
201
- # Check for minimum meaningful length
202
- if len(clean_text.replace(" ", "").replace(".", "").replace(",", "")) < 2:
203
- print(f"TTS: Text '{clean_text}' has insufficient content for TTS, skipping")
204
- return None
205
-
206
- print(f"TTS pipelines available: {list(self.tts_pipelines.keys())}")
207
- print(f"TTS config available: {list(self.tts_config.keys())}")
208
- print(f"Fallback config available: {list(self.fallback_tts_config.keys())}")
209
-
210
- # Check if the language is supported
211
- if language_code not in self.tts_config and language_code not in self.fallback_tts_config:
212
- print(f"TTS: Language {language_code} not configured for TTS")
213
- return None
214
-
215
- await self.ensure_model_loaded(language_code)
216
-
217
- if language_code not in self.tts_pipelines:
218
- print(f"TTS: TTS model not available for language: {language_code}")
219
- return None
220
-
221
- if not text or not text.strip():
222
- print("TTS: Empty text provided")
223
- return None
224
-
225
- print(f"TTS: Generating speech for '{text}' in {language_code}")
226
-
227
- # Generate speech
228
- pipeline_obj = self.tts_pipelines[language_code]
229
- result = pipeline_obj(text)
230
-
231
- audio_array = result["audio"]
232
- sample_rate = result.get("sampling_rate", 22050)
233
-
234
- print(f"TTS: Generated audio array of length {len(audio_array)} at {sample_rate}Hz")
235
-
236
- # Validate audio array
237
- if len(audio_array) == 0:
238
- print("TTS: Warning - Generated audio array is empty")
239
- return None
240
-
241
- # Check for potential issues with audio data
242
- audio_min = np.min(audio_array)
243
- audio_max = np.max(audio_array)
244
- audio_rms = np.sqrt(np.mean(audio_array**2))
245
- print(f"TTS: Audio statistics - Min: {audio_min:.4f}, Max: {audio_max:.4f}, RMS: {audio_rms:.4f}")
246
-
247
- # Check if audio might be silent or corrupted
248
- if audio_rms < 0.001:
249
- print("TTS: Warning - Audio appears to be very quiet or silent")
250
- if audio_max > 1.0 or audio_min < -1.0:
251
- print("TTS: Warning - Audio values outside expected range [-1, 1]")
252
- # Clip to valid range
253
- audio_array = np.clip(audio_array, -1.0, 1.0)
254
- print("TTS: Clipped audio to valid range")
255
-
256
- # Convert to WAV bytes with appropriate sample rate
257
- if output_format == "wav":
258
- # For Android: use 16kHz sample rate
259
- target_sample_rate = 16000
260
- wav_bytes = self._convert_to_wav_bytes(audio_array, sample_rate)
261
- print(f"TTS: Converted to WAV: {len(wav_bytes)} bytes")
262
-
263
- # Convert sample rate to 16kHz if needed for Android compatibility
264
- if sample_rate != target_sample_rate:
265
- print(f"TTS: Converting sample rate from {sample_rate}Hz to {target_sample_rate}Hz for Android compatibility")
266
- wav_bytes = await self._resample_wav_to_16khz(wav_bytes, sample_rate)
267
- print(f"TTS: Resampled WAV: {len(wav_bytes)} bytes")
268
-
269
- print(f"TTS: Generated {len(wav_bytes)} bytes of WAV audio for '{text}'")
270
- print(f"=== TTS GENERATION COMPLETE ===")
271
-
272
- return wav_bytes
273
- else:
274
- # For web: use original sample rate and convert to WebM
275
- wav_bytes = self._convert_to_wav_bytes(audio_array, sample_rate)
276
- print(f"TTS: Converted to WAV: {len(wav_bytes)} bytes")
277
-
278
- # Convert to WebM format for web compatibility
279
- webm_bytes = await self._convert_to_webm(wav_bytes)
280
-
281
- print(f"TTS: Generated {len(webm_bytes)} bytes of WebM audio for '{text}'")
282
- print(f"=== TTS GENERATION COMPLETE ===")
283
-
284
- return webm_bytes
285
-
286
- except Exception as e:
287
- print(f"TTS: TTS generation error: {e}")
288
- import traceback
289
- traceback.print_exc()
290
- return None
291
-
292
- async def generate_speech_dual_format(self, text: str, language_code: str) -> tuple[Optional[bytes], Optional[bytes]]:
293
- """Generate speech audio in both WebM and WAV formats
294
-
295
- Args:
296
- text: Text to convert to speech
297
- language_code: Language code for TTS model
298
-
299
- Returns:
300
- Tuple of (webm_bytes, wav_bytes), either can be None if generation fails
301
- """
302
- try:
303
- print(f"=== TTS DUAL FORMAT GENERATION REQUEST ===")
304
- print(f"Text: '{text}'")
305
- print(f"Language: {language_code}")
306
-
307
- # Input validation: Check for invalid or problematic text
308
- if not text or not text.strip():
309
- print("TTS: Empty or whitespace-only text, skipping TTS generation")
310
- return None, None
311
-
312
- # Check for very short text that might cause issues
313
- clean_text = text.strip()
314
- if len(clean_text) <= 2 and clean_text in [".", ",", "!", "?", ":", ";", "-"]:
315
- print(f"TTS: Text '{clean_text}' is too short or punctuation-only, skipping TTS generation")
316
- return None, None
317
-
318
- # Check for minimum meaningful length
319
- if len(clean_text.replace(" ", "").replace(".", "").replace(",", "")) < 2:
320
- print(f"TTS: Text '{clean_text}' has insufficient content for TTS, skipping")
321
- return None, None
322
-
323
- # Check if the language is supported
324
- if language_code not in self.tts_config and language_code not in self.fallback_tts_config:
325
- print(f"TTS: Language {language_code} not configured for TTS")
326
- return None, None
327
-
328
- await self.ensure_model_loaded(language_code)
329
-
330
- if language_code not in self.tts_pipelines:
331
- print(f"TTS: TTS model not available for language: {language_code}")
332
- return None, None
333
-
334
- print(f"TTS: Generating speech for '{text}' in {language_code}")
335
-
336
- # Generate speech once
337
- pipeline_obj = self.tts_pipelines[language_code]
338
- result = pipeline_obj(text)
339
-
340
- audio_array = result["audio"]
341
- sample_rate = result.get("sampling_rate", 22050)
342
-
343
- print(f"TTS: Generated audio array of length {len(audio_array)} at {sample_rate}Hz")
344
-
345
- # Validate audio array
346
- if len(audio_array) == 0:
347
- print("TTS: Warning - Generated audio array is empty")
348
- return None, None
349
-
350
- # Check for potential issues with audio data
351
- audio_min = np.min(audio_array)
352
- audio_max = np.max(audio_array)
353
- audio_rms = np.sqrt(np.mean(audio_array**2))
354
- print(f"TTS: Audio statistics - Min: {audio_min:.4f}, Max: {audio_max:.4f}, RMS: {audio_rms:.4f}")
355
-
356
- # Check if audio might be silent or corrupted
357
- if audio_rms < 0.001:
358
- print("TTS: Warning - Audio appears to be very quiet or silent")
359
- if audio_max > 1.0 or audio_min < -1.0:
360
- print("TTS: Warning - Audio values outside expected range [-1, 1]")
361
- # Clip to valid range
362
- audio_array = np.clip(audio_array, -1.0, 1.0)
363
- print("TTS: Clipped audio to valid range")
364
-
365
- # Generate WAV at original sample rate first
366
- wav_bytes_original = self._convert_to_wav_bytes(audio_array, sample_rate)
367
- print(f"TTS: Converted to WAV: {len(wav_bytes_original)} bytes")
368
-
369
- # Generate WebM from original WAV
370
- webm_bytes = await self._convert_to_webm(wav_bytes_original)
371
- print(f"TTS: Converted to WebM: {len(webm_bytes)} bytes")
372
-
373
- # Generate 16kHz WAV for Android
374
- wav_bytes_16k = await self._resample_wav_to_16khz(wav_bytes_original, sample_rate)
375
- print(f"TTS: Resampled to 16kHz WAV: {len(wav_bytes_16k)} bytes")
376
-
377
- print(f"TTS: Generated dual format audio for '{text}'")
378
- print(f"=== TTS DUAL FORMAT GENERATION COMPLETE ===")
379
-
380
- return webm_bytes, wav_bytes_16k
381
-
382
- except Exception as e:
383
- print(f"TTS: Dual format TTS generation error: {e}")
384
- import traceback
385
- traceback.print_exc()
386
- return None, None
387
-
388
- def _convert_to_wav_bytes(self, audio_array: np.ndarray, sample_rate: int) -> bytes:
389
- """Convert numpy audio array to WAV bytes"""
390
- buffer = io.BytesIO()
391
- with wave.open(buffer, 'wb') as wav_file:
392
- wav_file.setnchannels(1) # Mono
393
- wav_file.setsampwidth(2) # 16-bit
394
- wav_file.setframerate(sample_rate)
395
-
396
- # Ensure audio is in valid range [-1, 1]
397
- audio_array = np.clip(audio_array, -1.0, 1.0)
398
-
399
- # Convert to int16 with proper scaling
400
- int16_audio = (audio_array * 32767).astype(np.int16)
401
-
402
- # Validate the converted audio
403
- print(f"TTS: Converting {len(audio_array)} samples to WAV at {sample_rate}Hz")
404
- print(f"TTS: Int16 audio range: {np.min(int16_audio)} to {np.max(int16_audio)}")
405
-
406
- wav_file.writeframes(int16_audio.tobytes())
407
-
408
- wav_data = buffer.getvalue()
409
- print(f"TTS: WAV file created: {len(wav_data)} bytes (expected header: 44 bytes + {len(int16_audio) * 2} data bytes)")
410
-
411
- return wav_data
412
-
413
- async def _resample_wav_to_16khz(self, wav_bytes: bytes, original_sample_rate: int) -> bytes:
414
- """Resample WAV audio to 16kHz using FFmpeg"""
415
- try:
416
- process = subprocess.Popen([
417
- "ffmpeg", "-f", "wav", "-i", "pipe:0",
418
- "-ar", "16000", # Set output sample rate to 16kHz
419
- "-ac", "1", # Ensure mono output
420
- "-f", "wav", "pipe:1"
421
- ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
422
-
423
- resampled_data, stderr = process.communicate(input=wav_bytes)
424
-
425
- if process.returncode != 0:
426
- print(f"TTS: FFmpeg resampling error: {stderr.decode()}")
427
- return wav_bytes # Return original if resampling fails
428
-
429
- return resampled_data
430
-
431
- except Exception as e:
432
- print(f"TTS: Resampling error: {e}")
433
- return wav_bytes # Return original if resampling fails
434
-
435
- async def _convert_to_webm(self, wav_bytes: bytes) -> bytes:
436
- """Convert WAV bytes to WebM format using FFmpeg"""
437
- try:
438
- process = subprocess.Popen([
439
- "ffmpeg", "-f", "wav", "-i", "pipe:0",
440
- "-c:a", "libopus", "-b:a", "64k",
441
- "-f", "webm", "pipe:1"
442
- ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
443
-
444
- webm_data, stderr = process.communicate(input=wav_bytes)
445
-
446
- if process.returncode != 0:
447
- print(f"TTS: FFmpeg error: {stderr.decode()}")
448
- return wav_bytes # Return original WAV if conversion fails
449
-
450
- return webm_data
451
-
452
- except Exception as e:
453
- print(f"TTS: WebM conversion error: {e}")
454
- return wav_bytes # Return original WAV if conversion fails
455
-
456
- async def load_remaining_models_in_background(self):
457
- """Load all remaining TTS models in the background after startup"""
458
- try:
459
- print("TTS: Starting background loading of additional voice models...")
460
-
461
- # Load primary models first
462
- for lang_code in self.tts_config.keys():
463
- if lang_code not in self.preload_languages and lang_code not in self.tts_pipelines:
464
- if self.espeak_available:
465
- try:
466
- print(f"TTS: Background loading primary model for {lang_code}...")
467
- self.models_loading_status[lang_code] = "loading"
468
-
469
- model_config = self.tts_config[lang_code]
470
- pipeline_obj = pipeline(
471
- "text-to-speech",
472
- model=model_config["model_repo"],
473
- device=self.device
474
- )
475
- self.tts_pipelines[lang_code] = pipeline_obj
476
- self.models_loading_status[lang_code] = "loaded"
477
- print(f"TTS: Successfully loaded primary model for {lang_code} in background")
478
-
479
- # Add a small delay between loading models
480
- await asyncio.sleep(2)
481
- except Exception as e:
482
- print(f"TTS: Failed to load primary model for {lang_code} in background: {e}")
483
- self.models_loading_status[lang_code] = "failed"
484
-
485
- # Load fallback models for languages not yet loaded
486
- for lang_code in self.fallback_tts_config.keys():
487
- if lang_code not in self.tts_pipelines:
488
- try:
489
- print(f"TTS: Background loading fallback model for {lang_code}...")
490
- model_config = self.fallback_tts_config[lang_code]
491
-
492
- if model_config["model_type"] == "speecht5":
493
- from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor
494
- processor = SpeechT5Processor.from_pretrained(model_config["model_repo"])
495
- model = SpeechT5ForTextToSpeech.from_pretrained(model_config["model_repo"])
496
- vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
497
- if self.device >= 0:
498
- model = model.to(f"cuda:{self.device}")
499
- vocoder = vocoder.to(f"cuda:{self.device}")
500
- self.tts_pipelines[lang_code] = {
501
- "type": "speecht5",
502
- "processor": processor,
503
- "model": model,
504
- "vocoder": vocoder
505
- }
506
- else:
507
- pipeline_obj = pipeline(
508
- "text-to-speech",
509
- model=model_config["model_repo"],
510
- device=self.device
511
- )
512
- self.tts_pipelines[lang_code] = pipeline_obj
513
-
514
- print(f"TTS: Successfully loaded fallback model for {lang_code} in background")
515
- await asyncio.sleep(2)
516
- except Exception as e:
517
- print(f"TTS: Failed to load fallback model for {lang_code}: {e}")
518
-
519
- print("TTS: Background loading of all voice models complete")
520
- print(f"TTS: Loaded models: {list(self.tts_pipelines.keys())}")
521
- except Exception as e:
522
- print(f"TTS: Error in background model loading: {e}")
523
-
524
- def start_background_loading(self):
525
- """Start background loading of models as a non-blocking task"""
526
- if self.background_loading_task is None:
527
- self.background_loading_task = asyncio.create_task(self.load_remaining_models_in_background())
528
- print("TTS: Background model loading task started")
529
-
530
- async def cleanup(self):
531
- """Cleanup resources"""
532
- # Cancel background loading if still running
533
- if self.background_loading_task and not self.background_loading_task.done():
534
- self.background_loading_task.cancel()
535
- try:
536
- await self.background_loading_task
537
- except asyncio.CancelledError:
538
- pass
539
-
540
- self.tts_pipelines.clear()
541
- print("TTS: TTS service cleaned up")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/services/tts_service_onnx.py DELETED
@@ -1,587 +0,0 @@
1
- import asyncio
2
- import io
3
- import wave
4
- import numpy as np
5
- import subprocess
6
- from typing import Dict, Optional
7
- import onnxruntime as ort
8
- from transformers import AutoProcessor
9
- from collections import OrderedDict
10
- import os
11
-
12
- class ONNXTTSService:
13
- def __init__(self):
14
- self.tts_models: Dict[str, any] = {}
15
- self.processors: Dict[str, any] = {}
16
- self.max_tts_models = 3 # Keep up to 3 TTS models in memory
17
- self.model_cache = OrderedDict() # LRU cache
18
-
19
- # GPU optimization - detect and configure providers
20
- available_providers = ort.get_available_providers()
21
- print(f"ONNX TTS: Available providers: {available_providers}")
22
-
23
- if 'CUDAExecutionProvider' in available_providers:
24
- # Configure CUDA provider with optimizations
25
- cuda_provider_options = {
26
- 'device_id': 0,
27
- 'arena_extend_strategy': 'kNextPowerOfTwo',
28
- 'gpu_mem_limit': int(0.7 * 1024 * 1024 * 1024), # 70% of GPU memory (TTS uses less than ASR)
29
- 'cudnn_conv_algo_search': 'EXHAUSTIVE',
30
- 'do_copy_in_default_stream': True,
31
- }
32
- self.providers = [('CUDAExecutionProvider', cuda_provider_options), 'CPUExecutionProvider']
33
- print(f"ONNX TTS: Using CUDA acceleration with GPU memory limit: {cuda_provider_options['gpu_mem_limit'] // (1024**3)}GB")
34
- else:
35
- self.providers = ['CPUExecutionProvider']
36
- print("ONNX TTS: CUDA not available, using CPU execution")
37
-
38
- print(f"ONNX TTS: Configured providers: {[p[0] if isinstance(p, tuple) else p for p in self.providers]}")
39
-
40
- # Check if espeak is available
41
- self.espeak_available = self._check_espeak_availability()
42
-
43
- # ONNX TTS model configurations - using FP32 optimized models (16kHz corrected)
44
- self.tts_config = {
45
- "kik": {"model_repo": "mutisya/vits-tts-onnx-fp32-kikuyu-v25-37-1", "model_type": "vits", "use_onnx": True},
46
- "luo": {"model_repo": "mutisya/vits-tts-onnx-fp32-luo-v25-37-1", "model_type": "vits", "use_onnx": True},
47
- "kam": {"model_repo": "mutisya/vits-tts-onnx-fp32-kamba-v25-37-1", "model_type": "vits", "use_onnx": True},
48
- "mer": {"model_repo": "mutisya/vits-tts-onnx-fp32-kimeru-v25-37-1", "model_type": "vits", "use_onnx": True},
49
- "som": {"model_repo": "mutisya/vits-tts-onnx-fp32-somali-v25-37-1", "model_type": "vits", "use_onnx": True},
50
- "swa": {"model_repo": "mutisya/vits-tts-onnx-fp32-swahili-v25-37-1", "model_type": "vits", "use_onnx": True},
51
- "eng": {"model_repo": "kakao-enterprise/vits-ljs", "model_type": "vits", "use_onnx": False}, # Fallback to PyTorch
52
- }
53
-
54
- # Alternative TTS models that don't require espeak (fallback)
55
- self.fallback_tts_config = {
56
- "eng": {"model_repo": "microsoft/speecht5_tts", "model_type": "speecht5"},
57
- "swa": {"model_repo": "facebook/mms-tts-swh", "model_type": "mms"},
58
- "som": {"model_repo": "facebook/mms-tts-som", "model_type": "mms"},
59
- }
60
-
61
- self.preload_languages = ["kik", "swa"]
62
-
63
- def _check_espeak_availability(self) -> bool:
64
- """Check if espeak is available on the system"""
65
- try:
66
- result = subprocess.run(['espeak', '--version'],
67
- capture_output=True, text=True, timeout=5)
68
- if result.returncode == 0:
69
- print("ONNX TTS: espeak is available")
70
- return True
71
- else:
72
- print("ONNX TTS: espeak command failed")
73
- return False
74
- except (subprocess.TimeoutExpired, FileNotFoundError, Exception) as e:
75
- print(f"ONNX TTS: espeak not available: {e}")
76
- return False
77
-
78
- async def initialize(self):
79
- """Initialize TTS models for preloaded languages"""
80
- print("ONNX TTS: Initializing TTS service with ONNX Runtime...")
81
- print(f"ONNX TTS: espeak available: {self.espeak_available}")
82
- print(f"ONNX TTS: Using providers: {self.providers}")
83
-
84
- for lang_code in self.preload_languages:
85
- await self.ensure_model_loaded(lang_code)
86
-
87
- async def ensure_model_loaded(self, language_code: str):
88
- """Load TTS model for language if not already loaded with LRU cache"""
89
- if language_code in self.model_cache:
90
- # Move to end (most recently used)
91
- self.model_cache.move_to_end(language_code)
92
- return
93
-
94
- # Check if we need to evict old models
95
- while len(self.model_cache) >= self.max_tts_models:
96
- # Remove least recently used model
97
- old_lang, _ = self.model_cache.popitem(last=False)
98
- if old_lang in self.tts_models:
99
- del self.tts_models[old_lang]
100
- if old_lang in self.processors:
101
- del self.processors[old_lang]
102
- print(f"ONNX TTS: Evicted model for {old_lang} (LRU cache)")
103
-
104
- # First try to load ONNX model
105
- if language_code in self.tts_config:
106
- model_config = self.tts_config[language_code]
107
-
108
- if model_config.get("use_onnx", False):
109
- try:
110
- print(f"ONNX TTS: Loading ONNX model for {language_code}")
111
-
112
- # Create ONNX session with optimizations and verbose logging
113
- session_options = ort.SessionOptions()
114
- session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
115
-
116
- # Enable verbose logging to diagnose operator assignments
117
- session_options.log_severity_level = 1 # WARNING level for detailed logs
118
- session_options.logid = "ONNX_TTS" # Prefix for log identification
119
-
120
- # GPU memory optimization for T4 with diagnostic tracing
121
- if 'CUDAExecutionProvider' in self.providers:
122
- provider_options = [{
123
- 'device_id': 0,
124
- 'arena_extend_strategy': 'kSameAsRequested',
125
- 'gpu_mem_limit': int(0.3 * 1024 * 1024 * 1024), # 30% of GPU memory for TTS
126
- 'cudnn_conv_algo_search': 'EXHAUSTIVE',
127
- 'do_copy_in_default_stream': True,
128
- 'enable_tracing': True, # Enable tracing for better diagnostics
129
- }]
130
- providers = [('CUDAExecutionProvider', provider_options[0]), 'CPUExecutionProvider']
131
- else:
132
- providers = self.providers
133
-
134
- # Get authentication token for private repos
135
- import os
136
- auth_token = os.getenv('HUGGING_FACE_HUB_TOKEN') or os.getenv('HF_TOKEN')
137
-
138
- # Download ONNX model from HuggingFace Hub with authentication
139
- from huggingface_hub import hf_hub_download
140
- onnx_path = hf_hub_download(
141
- repo_id=model_config["model_repo"],
142
- filename="model.onnx",
143
- token=auth_token
144
- )
145
-
146
- session = ort.InferenceSession(onnx_path, providers=providers, sess_options=session_options)
147
-
148
- # Load processor for preprocessing with authentication
149
- processor = AutoProcessor.from_pretrained(
150
- model_config["model_repo"],
151
- token=auth_token
152
- )
153
-
154
- self.tts_models[language_code] = session
155
- self.processors[language_code] = processor
156
- self.model_cache[language_code] = True
157
-
158
- print(f"ONNX TTS: Successfully loaded ONNX model for {language_code}")
159
- return
160
-
161
- except Exception as e:
162
- print(f"ONNX TTS: Failed to load ONNX model for {language_code}: {e}")
163
- # Continue to try fallback models
164
- else:
165
- # Try PyTorch model if ONNX not available
166
- try:
167
- print(f"ONNX TTS: Loading PyTorch model for {language_code} (fallback)")
168
- from transformers import pipeline
169
-
170
- pipeline_obj = pipeline(
171
- "text-to-speech",
172
- model=model_config["model_repo"],
173
- device=0 if self.providers[0] == 'CUDAExecutionProvider' else -1
174
- )
175
- self.tts_models[language_code] = pipeline_obj
176
- self.processors[language_code] = None # Not needed for pipeline
177
- self.model_cache[language_code] = True
178
-
179
- print(f"ONNX TTS: Successfully loaded PyTorch model for {language_code}")
180
- return
181
-
182
- except Exception as e:
183
- print(f"ONNX TTS: Failed to load PyTorch model for {language_code}: {e}")
184
-
185
- # Try fallback models if primary failed
186
- if language_code in self.fallback_tts_config:
187
- try:
188
- model_config = self.fallback_tts_config[language_code]
189
-
190
- if model_config["model_type"] == "speecht5":
191
- # Special handling for SpeechT5
192
- from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
193
- import torch
194
-
195
- # Get authentication token for private repos
196
- import os
197
- auth_token = os.getenv('HUGGING_FACE_HUB_TOKEN') or os.getenv('HF_TOKEN')
198
-
199
- processor = SpeechT5Processor.from_pretrained(
200
- model_config["model_repo"],
201
- token=auth_token
202
- )
203
- model = SpeechT5ForTextToSpeech.from_pretrained(
204
- model_config["model_repo"],
205
- token=auth_token
206
- )
207
- vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
208
-
209
- # Create a custom pipeline-like object
210
- class SpeechT5Pipeline:
211
- def __init__(self, processor, model, vocoder):
212
- self.processor = processor
213
- self.model = model
214
- self.vocoder = vocoder
215
-
216
- def __call__(self, text):
217
- inputs = self.processor(text=text, return_tensors="pt")
218
- # Use default speaker embeddings
219
- import datasets
220
- embeddings_dataset = datasets.load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
221
- speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
222
-
223
- speech = self.model.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=self.vocoder)
224
-
225
- return {
226
- "audio": speech.numpy(),
227
- "sampling_rate": 16000
228
- }
229
-
230
- pipeline_obj = SpeechT5Pipeline(processor, model, vocoder)
231
- else:
232
- # Standard pipeline for MMS models
233
- from transformers import pipeline
234
- pipeline_obj = pipeline(
235
- "text-to-speech",
236
- model=model_config["model_repo"],
237
- device=0 if self.providers[0] == 'CUDAExecutionProvider' else -1
238
- )
239
-
240
- self.tts_models[language_code] = pipeline_obj
241
- self.processors[language_code] = None
242
- self.model_cache[language_code] = True
243
-
244
- print(f"ONNX TTS: Successfully loaded fallback model for {language_code}")
245
- return
246
-
247
- except Exception as e:
248
- print(f"ONNX TTS: Failed to load fallback TTS model for {language_code}: {e}")
249
-
250
- print(f"ONNX TTS: No TTS model available for language: {language_code}")
251
-
252
- async def generate_speech(self, text: str, language_code: str, output_format: str = "webm") -> Optional[bytes]:
253
- """Generate speech audio from text using ONNX models
254
-
255
- Args:
256
- text: Text to convert to speech
257
- language_code: Language code for TTS model
258
- output_format: Output format - "webm" (default, web-compatible) or "wav" (Android-compatible)
259
-
260
- Returns:
261
- Audio bytes in the requested format, or None if generation fails
262
- """
263
- try:
264
- print(f"=== ONNX TTS GENERATION REQUEST ===")
265
- print(f"Text: '{text}'")
266
- print(f"Language: {language_code}")
267
- print(f"Output format: {output_format}")
268
-
269
- # Input validation
270
- if not text or not text.strip():
271
- print("ONNX TTS: Empty or whitespace-only text, skipping TTS generation")
272
- return None
273
-
274
- # Check for very short text that might cause issues
275
- clean_text = text.strip()
276
- if len(clean_text) <= 2 and clean_text in [".", ",", "!", "?", ":", ";", "-"]:
277
- print(f"ONNX TTS: Text '{clean_text}' is too short or punctuation-only, skipping TTS generation")
278
- return None
279
-
280
- # Check for minimum meaningful length
281
- if len(clean_text.replace(" ", "").replace(".", "").replace(",", "")) < 2:
282
- print(f"ONNX TTS: Text '{clean_text}' has insufficient content for TTS, skipping")
283
- return None
284
-
285
- # Check if the language is supported
286
- if language_code not in self.tts_config and language_code not in self.fallback_tts_config:
287
- print(f"ONNX TTS: Language {language_code} not configured for TTS")
288
- return None
289
-
290
- await self.ensure_model_loaded(language_code)
291
-
292
- if language_code not in self.tts_models:
293
- print(f"ONNX TTS: TTS model not available for language: {language_code}")
294
- return None
295
-
296
- print(f"ONNX TTS: Generating speech for '{text}' in {language_code}")
297
-
298
- # Generate speech based on model type
299
- model_config = self.tts_config.get(language_code, {})
300
- if model_config.get("use_onnx", False):
301
- # ONNX inference
302
- audio_array, sample_rate = await self._run_onnx_tts_inference(text, language_code)
303
- else:
304
- # PyTorch pipeline inference
305
- pipeline_obj = self.tts_models[language_code]
306
- result = pipeline_obj(text)
307
-
308
- audio_array = result["audio"]
309
- sample_rate = result.get("sampling_rate", 16000) # Default to 16kHz (corrected)
310
-
311
- print(f"ONNX TTS: Generated audio array of length {len(audio_array)} at {sample_rate}Hz")
312
-
313
- # Validate audio array
314
- if len(audio_array) == 0:
315
- print("ONNX TTS: Warning - Generated audio array is empty")
316
- return None
317
-
318
- # Check audio statistics
319
- audio_min = np.min(audio_array)
320
- audio_max = np.max(audio_array)
321
- audio_rms = np.sqrt(np.mean(audio_array**2))
322
- print(f"ONNX TTS: Audio statistics - Min: {audio_min:.4f}, Max: {audio_max:.4f}, RMS: {audio_rms:.4f}")
323
-
324
- # Check if audio might be silent or corrupted
325
- if audio_rms < 0.001:
326
- print("ONNX TTS: Warning - Audio appears to be very quiet or silent")
327
- if audio_max > 1.0 or audio_min < -1.0:
328
- print("ONNX TTS: Warning - Audio values outside expected range [-1, 1]")
329
- # Clip to valid range
330
- audio_array = np.clip(audio_array, -1.0, 1.0)
331
- print("ONNX TTS: Clipped audio to valid range")
332
-
333
- # Convert to requested format
334
- if output_format == "wav":
335
- # For Android: use 16kHz sample rate
336
- target_sample_rate = 16000
337
- wav_bytes = self._convert_to_wav_bytes(audio_array, sample_rate)
338
- print(f"ONNX TTS: Converted to WAV: {len(wav_bytes)} bytes")
339
-
340
- # Convert sample rate to 16kHz if needed for Android compatibility
341
- if sample_rate != target_sample_rate:
342
- print(f"ONNX TTS: Converting sample rate from {sample_rate}Hz to {target_sample_rate}Hz")
343
- wav_bytes = await self._resample_wav_to_16khz(wav_bytes, sample_rate)
344
- print(f"ONNX TTS: Resampled WAV: {len(wav_bytes)} bytes")
345
-
346
- print(f"ONNX TTS: Generated {len(wav_bytes)} bytes of WAV audio for '{text}'")
347
- print(f"=== ONNX TTS GENERATION COMPLETE ===")
348
-
349
- return wav_bytes
350
- else:
351
- # For web: use original sample rate and convert to WebM
352
- wav_bytes = self._convert_to_wav_bytes(audio_array, sample_rate)
353
- print(f"ONNX TTS: Converted to WAV: {len(wav_bytes)} bytes")
354
-
355
- # Convert to WebM format for web compatibility
356
- webm_bytes = await self._convert_to_webm(wav_bytes)
357
-
358
- print(f"ONNX TTS: Generated {len(webm_bytes)} bytes of WebM audio for '{text}'")
359
- print(f"=== ONNX TTS GENERATION COMPLETE ===")
360
-
361
- return webm_bytes
362
-
363
- except Exception as e:
364
- print(f"ONNX TTS: TTS generation error: {e}")
365
- import traceback
366
- traceback.print_exc()
367
- return None
368
-
369
- async def _run_onnx_tts_inference(self, text: str, language_code: str) -> tuple[np.ndarray, int]:
370
- """Run ONNX inference for text-to-speech"""
371
- try:
372
- session = self.tts_models[language_code]
373
- processor = self.processors[language_code]
374
-
375
- # Preprocess text
376
- inputs = processor(text=text, return_tensors="np")
377
-
378
- # Get input names for ONNX session
379
- input_names = [inp.name for inp in session.get_inputs()]
380
-
381
- # Prepare inputs for ONNX
382
- onnx_inputs = {}
383
- for name in input_names:
384
- if name in inputs:
385
- onnx_inputs[name] = inputs[name]
386
- elif name == "input_ids" and "input_ids" in inputs:
387
- onnx_inputs[name] = inputs["input_ids"].astype(np.int64)
388
- elif name == "attention_mask" and "attention_mask" in inputs:
389
- onnx_inputs[name] = inputs["attention_mask"].astype(np.int64)
390
-
391
- # Run ONNX inference
392
- outputs = session.run(None, onnx_inputs)
393
-
394
- # Extract audio from outputs (assuming first output is audio)
395
- audio_array = outputs[0]
396
-
397
- # Ensure audio is 1D
398
- if audio_array.ndim > 1:
399
- audio_array = audio_array.flatten()
400
-
401
- # Convert to float32 if needed
402
- if audio_array.dtype != np.float32:
403
- audio_array = audio_array.astype(np.float32)
404
-
405
- # Sample rate is 16kHz for our corrected models
406
- sample_rate = 16000
407
-
408
- return audio_array, sample_rate
409
-
410
- except Exception as e:
411
- print(f"ONNX TTS: Inference error: {e}")
412
- import traceback
413
- traceback.print_exc()
414
- return np.array([], dtype=np.float32), 16000
415
-
416
- async def generate_speech_dual_format(self, text: str, language_code: str) -> tuple[Optional[bytes], Optional[bytes]]:
417
- """Generate speech audio in both WebM and WAV formats using ONNX
418
-
419
- Args:
420
- text: Text to convert to speech
421
- language_code: Language code for TTS model
422
-
423
- Returns:
424
- Tuple of (webm_bytes, wav_bytes), either can be None if generation fails
425
- """
426
- try:
427
- print(f"=== ONNX TTS DUAL FORMAT GENERATION REQUEST ===")
428
- print(f"Text: '{text}'")
429
- print(f"Language: {language_code}")
430
-
431
- # Input validation
432
- if not text or not text.strip():
433
- print("ONNX TTS: Empty or whitespace-only text, skipping TTS generation")
434
- return None, None
435
-
436
- clean_text = text.strip()
437
- if len(clean_text) <= 2 and clean_text in [".", ",", "!", "?", ":", ";", "-"]:
438
- print(f"ONNX TTS: Text '{clean_text}' is too short or punctuation-only, skipping TTS generation")
439
- return None, None
440
-
441
- if len(clean_text.replace(" ", "").replace(".", "").replace(",", "")) < 2:
442
- print(f"ONNX TTS: Text '{clean_text}' has insufficient content for TTS, skipping")
443
- return None, None
444
-
445
- # Check if the language is supported
446
- if language_code not in self.tts_config and language_code not in self.fallback_tts_config:
447
- print(f"ONNX TTS: Language {language_code} not configured for TTS")
448
- return None, None
449
-
450
- await self.ensure_model_loaded(language_code)
451
-
452
- if language_code not in self.tts_models:
453
- print(f"ONNX TTS: TTS model not available for language: {language_code}")
454
- return None, None
455
-
456
- print(f"ONNX TTS: Generating speech for '{text}' in {language_code}")
457
-
458
- # Generate speech once
459
- model_config = self.tts_config.get(language_code, {})
460
- if model_config.get("use_onnx", False):
461
- # ONNX inference
462
- audio_array, sample_rate = await self._run_onnx_tts_inference(text, language_code)
463
- else:
464
- # PyTorch pipeline inference
465
- pipeline_obj = self.tts_models[language_code]
466
- result = pipeline_obj(text)
467
-
468
- audio_array = result["audio"]
469
- sample_rate = result.get("sampling_rate", 16000)
470
-
471
- print(f"ONNX TTS: Generated audio array of length {len(audio_array)} at {sample_rate}Hz")
472
-
473
- # Validate audio array
474
- if len(audio_array) == 0:
475
- print("ONNX TTS: Warning - Generated audio array is empty")
476
- return None, None
477
-
478
- # Check for potential issues with audio data
479
- audio_min = np.min(audio_array)
480
- audio_max = np.max(audio_array)
481
- audio_rms = np.sqrt(np.mean(audio_array**2))
482
- print(f"ONNX TTS: Audio statistics - Min: {audio_min:.4f}, Max: {audio_max:.4f}, RMS: {audio_rms:.4f}")
483
-
484
- if audio_rms < 0.001:
485
- print("ONNX TTS: Warning - Audio appears to be very quiet or silent")
486
- if audio_max > 1.0 or audio_min < -1.0:
487
- print("ONNX TTS: Warning - Audio values outside expected range [-1, 1]")
488
- audio_array = np.clip(audio_array, -1.0, 1.0)
489
- print("ONNX TTS: Clipped audio to valid range")
490
-
491
- # Generate WAV at original sample rate first
492
- wav_bytes_original = self._convert_to_wav_bytes(audio_array, sample_rate)
493
- print(f"ONNX TTS: Converted to WAV: {len(wav_bytes_original)} bytes")
494
-
495
- # Generate WebM from original WAV
496
- webm_bytes = await self._convert_to_webm(wav_bytes_original)
497
- print(f"ONNX TTS: Converted to WebM: {len(webm_bytes)} bytes")
498
-
499
- # Generate 16kHz WAV for Android
500
- wav_bytes_16k = await self._resample_wav_to_16khz(wav_bytes_original, sample_rate)
501
- print(f"ONNX TTS: Resampled to 16kHz WAV: {len(wav_bytes_16k)} bytes")
502
-
503
- print(f"ONNX TTS: Generated dual format audio for '{text}'")
504
- print(f"=== ONNX TTS DUAL FORMAT GENERATION COMPLETE ===")
505
-
506
- return webm_bytes, wav_bytes_16k
507
-
508
- except Exception as e:
509
- print(f"ONNX TTS: Dual format TTS generation error: {e}")
510
- import traceback
511
- traceback.print_exc()
512
- return None, None
513
-
514
- def _convert_to_wav_bytes(self, audio_array: np.ndarray, sample_rate: int) -> bytes:
515
- """Convert numpy audio array to WAV bytes"""
516
- buffer = io.BytesIO()
517
- with wave.open(buffer, 'wb') as wav_file:
518
- wav_file.setnchannels(1) # Mono
519
- wav_file.setsampwidth(2) # 16-bit
520
- wav_file.setframerate(sample_rate)
521
-
522
- # Ensure audio is in valid range [-1, 1]
523
- audio_array = np.clip(audio_array, -1.0, 1.0)
524
-
525
- # Convert to int16 with proper scaling
526
- int16_audio = (audio_array * 32767).astype(np.int16)
527
-
528
- # Validate the converted audio
529
- print(f"ONNX TTS: Converting {len(audio_array)} samples to WAV at {sample_rate}Hz")
530
- print(f"ONNX TTS: Int16 audio range: {np.min(int16_audio)} to {np.max(int16_audio)}")
531
-
532
- wav_file.writeframes(int16_audio.tobytes())
533
-
534
- wav_data = buffer.getvalue()
535
- print(f"ONNX TTS: WAV file created: {len(wav_data)} bytes")
536
-
537
- return wav_data
538
-
539
- async def _resample_wav_to_16khz(self, wav_bytes: bytes, original_sample_rate: int) -> bytes:
540
- """Resample WAV audio to 16kHz using FFmpeg"""
541
- try:
542
- process = subprocess.Popen([
543
- "ffmpeg", "-f", "wav", "-i", "pipe:0",
544
- "-ar", "16000", # Set output sample rate to 16kHz
545
- "-ac", "1", # Ensure mono output
546
- "-f", "wav", "pipe:1"
547
- ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
548
-
549
- resampled_data, stderr = process.communicate(input=wav_bytes)
550
-
551
- if process.returncode != 0:
552
- print(f"ONNX TTS: FFmpeg resampling error: {stderr.decode()}")
553
- return wav_bytes # Return original if resampling fails
554
-
555
- return resampled_data
556
-
557
- except Exception as e:
558
- print(f"ONNX TTS: Resampling error: {e}")
559
- return wav_bytes # Return original if resampling fails
560
-
561
- async def _convert_to_webm(self, wav_bytes: bytes) -> bytes:
562
- """Convert WAV bytes to WebM format using FFmpeg"""
563
- try:
564
- process = subprocess.Popen([
565
- "ffmpeg", "-f", "wav", "-i", "pipe:0",
566
- "-c:a", "libopus", "-b:a", "64k",
567
- "-f", "webm", "pipe:1"
568
- ], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
569
-
570
- webm_data, stderr = process.communicate(input=wav_bytes)
571
-
572
- if process.returncode != 0:
573
- print(f"ONNX TTS: FFmpeg error: {stderr.decode()}")
574
- return wav_bytes # Return original WAV if conversion fails
575
-
576
- return webm_data
577
-
578
- except Exception as e:
579
- print(f"ONNX TTS: WebM conversion error: {e}")
580
- return wav_bytes # Return original WAV if conversion fails
581
-
582
- async def cleanup(self):
583
- """Cleanup resources"""
584
- self.tts_models.clear()
585
- self.processors.clear()
586
- self.model_cache.clear()
587
- print("ONNX TTS: TTS service cleaned up")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/services/websocket_manager.py DELETED
@@ -1,909 +0,0 @@
1
- import asyncio
2
- import uuid
3
- from typing import Dict, Set, Optional
4
- import socketio
5
- import numpy as np
6
-
7
- from app.models import Message, LanguageCode
8
- from app.services.session_manager import SessionManager, LANGUAGE_MAP
9
- from app.services.transcription_service import TranscriptionService
10
- from app.services.translation_service import TranslationService
11
- from app.services.tts_service import TTSService
12
-
13
def truncate_array_for_log(arr, max_items=10):
    """Return *arr* shortened for readable log output.

    Falsy values (``None``, ``[]``) and sequences with at most ``max_items``
    elements are returned unchanged.  Longer sequences become a list of the
    first ``max_items`` elements followed by a summary string noting how many
    items were omitted.

    Fix: the slice is wrapped in ``list()`` so non-list sequences (e.g.
    tuples) no longer raise ``TypeError`` on the ``+ [summary]`` concatenation.
    Behavior for lists is unchanged.
    """
    if not arr or len(arr) <= max_items:
        return arr
    return list(arr[:max_items]) + [f"... {len(arr) - max_items} more items"]
18
-
19
- class WebSocketManager:
20
def __init__(self, session_manager: SessionManager, transcription_service: TranscriptionService,
             translation_service: TranslationService, tts_service: TTSService):
    """Store the collaborating services and initialise empty connection/message state."""
    # Injected collaborators.
    self.session_manager = session_manager
    self.transcription_service = transcription_service
    self.translation_service = translation_service
    self.tts_service = tts_service
    # Socket.IO server; attached later by main.py via set_socketio().
    self.sio = None

    # Connection bookkeeping:
    #   client_sessions:     sid -> session_id
    #   client_participants: sid -> participant_id
    #   session_clients:     session_id -> set of connected sids
    self.client_sessions = {}
    self.client_participants = {}
    self.session_clients = {}

    # Message bookkeeping:
    #   messages:                     message_id -> Message
    #   participant_current_message:  participant_id -> id of in-progress message
    #   processed_messages:           finalized ids (duplicate-completion guard)
    self.messages = {}
    self.participant_current_message = {}
    self.processed_messages = set()
35
-
36
def set_socketio(self, sio):
    """Attach the Socket.IO server used for all outbound emits (called from main.py)."""
    self.sio = sio
39
-
40
async def handle_join_session(self, sid: str, data: dict):
    """Handle a participant joining a session.

    Expects ``data`` with 'sessionId' (full UUID or short code),
    'participantName' and 'language'.  On success emits
    'participant_joined' back to the joining client and
    'participant_update' to everyone else in the session; any failure
    is reported to the caller via a 'join_error' emit.
    """
    try:
        session_id = data.get('sessionId')
        participant_name = data.get('participantName')
        language_code = data.get('language')

        print(f"=== JOIN SESSION REQUEST ===")
        print(f"Session ID: {session_id}")
        print(f"Participant: {participant_name}")
        print(f"Language: {language_code}")

        # All three fields are mandatory; reject early if any is missing/empty.
        if not all([session_id, participant_name, language_code]):
            await self._emit_error(sid, "Missing required fields")
            return

        # Validate language code
        try:
            lang_enum = LanguageCode(language_code)
            print(f"Language code validated: {lang_enum}")
        except ValueError:
            await self._emit_error(sid, f"Invalid language code: {language_code}")
            return

        # Resolve session ID (in case it's a short code)
        session = await self.session_manager.get_session(session_id)
        if not session:
            await self._emit_error(sid, "Session not found")
            return

        # Use the full UUID for all subsequent operations
        session_id = session.id
        print(f"Resolved session ID: {session_id}")

        # Add participant to session
        participant = await self.session_manager.add_participant(
            session_id, participant_name, lang_enum
        )

        print(f"Participant created: {participant}")

        if not participant:
            await self._emit_error(sid, "Session not found or unable to join")
            return

        # Get updated session info (re-fetched so the logged language/participant
        # lists include the participant we just added).
        session = await self.session_manager.get_session(session_id)
        if session:
            print(f"Session {session_id} now has {len(session.languages)} languages: {[f'{lang.name}({lang.code.value})' for lang in session.languages]}")
            print(f"Session participants: {[f'{p.name}({p.language.name})' for p in session.participants]}")

        # Track client connections (sid -> session/participant maps used by
        # the audio, status and cleanup handlers).
        self.client_sessions[sid] = session_id
        self.client_participants[sid] = participant.id

        if session_id not in self.session_clients:
            self.session_clients[session_id] = set()
        self.session_clients[session_id].add(sid)

        # Send success response
        await self.sio.emit('participant_joined', participant.dict(), room=sid)

        # Notify other participants
        await self._broadcast_to_session(session_id, 'participant_update', participant.dict(), exclude_sid=sid)

        print(f"=== JOIN SESSION COMPLETE ===")

    except Exception as e:
        # Broad catch: any unexpected failure is logged and surfaced to the
        # client as a generic join error rather than dropping the request.
        print(f"Error in handle_join_session: {e}")
        import traceback
        traceback.print_exc()
        await self._emit_error(sid, "Failed to join session")
112
-
113
async def handle_join_hub(self, sid: str, data: dict):
    """Register a hub (observer) client on an existing session.

    A hub joins the broadcast set for the session without becoming a
    participant; success is acknowledged with a 'hub_joined' emit.
    """
    try:
        session_id = data.get('sessionId')
        if not session_id:
            await self._emit_error(sid, "Missing sessionId for hub")
            return

        # Verify the session exists before tracking the hub.
        if not await self.session_manager.get_session(session_id):
            await self._emit_error(sid, "Session not found")
            return

        # Track the hub connection alongside regular clients.
        self.client_sessions[sid] = session_id
        self.session_clients.setdefault(session_id, set()).add(sid)

        # Acknowledge to the hub client.
        await self.sio.emit('hub_joined', {'sessionId': session_id}, room=sid)

        print(f"Hub joined session {session_id} with sid {sid}")

    except Exception as e:
        print(f"Error in handle_join_hub: {e}")
        await self._emit_error(sid, "Failed to join as hub")
142
-
143
async def handle_audio_chunk(self, sid: str, data: dict):
    """Forward a participant's raw audio chunk into VAD-based processing.

    ``data['audioData']`` is a list of byte values; ``isPauseBoundary``
    marks an explicit pause (e.g. stop button) that should force sentence
    finalization.
    """
    try:
        participant_id = self.client_participants.get(sid)
        if not participant_id:
            return  # unknown/unjoined client

        raw_samples = data.get('audioData', [])
        pause_boundary = data.get('isPauseBoundary', False)
        if not raw_samples:
            return

        chunk = bytes(raw_samples)
        if not chunk:
            return

        # Voice-activity detection drives sentence segmentation; a pause
        # boundary is treated as silence so the pending sentence finalizes.
        voiced = self.transcription_service.has_voice_activity(chunk)
        await self._process_audio_chunk_vad(
            participant_id, chunk, voiced and not pause_boundary, pause_boundary
        )

    except Exception as e:
        print(f"Error in handle_audio_chunk: {e}")
        import traceback
        traceback.print_exc()
172
-
173
async def handle_speaking_status(self, sid: str, data: dict):
    """Handle speaking status updates.

    Persists the flag via the session manager, and when a participant
    stops speaking, force-completes any partially transcribed sentence
    so it is broadcast and translated immediately.  Finally relays the
    status to the whole session as a 'speaking_status' event.
    """
    try:
        participant_id = self.client_participants.get(sid)
        if not participant_id:
            return

        is_speaking = data.get('isSpeaking', False)
        await self.session_manager.update_participant_speaking_status(participant_id, is_speaking)

        # If participant stopped speaking, force complete any pending sentence
        if not is_speaking:
            # Get session and participant info for force completion
            session_id = await self.session_manager.get_participant_session_id(participant_id)
            if session_id:
                session = await self.session_manager.get_session(session_id)
                # NOTE(review): `session` is not checked for None before
                # `.participants` — a missing session would raise here and be
                # swallowed by the outer except; confirm get_session cannot
                # return None once get_participant_session_id succeeded.
                participant = next((p for p in session.participants if p.id == participant_id), None)

                if participant:
                    # Define the sentence callback for force completion.
                    # Invoked by the transcription service with the final text
                    # and accumulated audio if a sentence was still pending.
                    async def force_sentence_callback(final_text: str, final_audio: bytes):
                        # Create or get existing message
                        current_message_id = self.participant_current_message.get(participant_id)
                        if not current_message_id:
                            current_message_id = str(uuid.uuid4())

                        # Check if this message was already processed
                        if current_message_id in self.processed_messages:
                            print(f"Force completion: Message {current_message_id} already processed, skipping duplicate")
                            return

                        # Mark as processed to prevent duplicates
                        self.processed_messages.add(current_message_id)

                        from app.models import Message
                        message = Message(
                            id=current_message_id,
                            session_id=session_id,
                            speaker_id=participant_id,
                            speaker_name=participant.name,
                            original_text=final_text,
                            original_language=participant.language,
                            translations={},
                            is_transcribing=False
                        )
                        self.messages[current_message_id] = message

                        # Broadcast the completed message
                        print(f"Force completion: Broadcasting message_complete for {current_message_id}: '{final_text}'")
                        await self._broadcast_to_session(session_id, 'message_complete', {
                            'messageId': current_message_id,
                            'sessionId': session_id,
                            'text': final_text,
                            'speakerId': participant_id,
                            'speakerName': participant.name,
                            'language': participant.language.code.value
                        })

                        # Clear current message tracking
                        if participant_id in self.participant_current_message:
                            del self.participant_current_message[participant_id]

                        # Start translation processing (non-blocking to allow continued audio processing)
                        print("Starting TRANSLATION and TTS (background task)")
                        asyncio.create_task(self._process_translations_and_tts(message, session))

                    # Force complete any pending sentence
                    await self.transcription_service.force_complete_sentence(
                        participant_id,
                        participant.language.code.value,
                        force_sentence_callback
                    )

                    # Clear transcription service buffers after force completion
                    self.transcription_service.clear_participant_buffers(participant_id)

        # Broadcast speaking status to session
        session_id = self.client_sessions.get(sid)
        if session_id:
            await self._broadcast_to_session(session_id, 'speaking_status', {
                'participantId': participant_id,
                'isSpeaking': is_speaking
            })

    except Exception as e:
        print(f"Error in handle_speaking_status: {e}")
        import traceback
        traceback.print_exc()
261
-
262
async def handle_leave_session(self, sid: str, data: dict):
    """Tear down all state for a participant that explicitly left (``data`` is unused)."""
    await self._cleanup_client(sid)
265
-
266
async def handle_disconnect(self, sid: str):
    """Tear down all state for a client whose socket dropped."""
    await self._cleanup_client(sid)
269
async def _process_audio_chunk_vad(self, participant_id: str, audio_data: bytes, has_voice_activity: bool, is_pause_boundary: bool = False):
    """Process audio chunk using VAD-based sentence detection.

    Creates a placeholder Message (and a 'typing_start' broadcast) on the
    first chunk of a new utterance, then feeds the chunk to the
    transcription service with three callbacks: streaming progress,
    ASR debug data, and sentence completion (which broadcasts
    'message_complete' and kicks off translation/TTS as a background task).

    Args:
        participant_id: ID of the participant
        audio_data: Raw audio data bytes
        has_voice_activity: Whether voice activity was detected in this chunk
        is_pause_boundary: If True, forces sentence finalization (from stop button or explicit pause)
    """
    try:
        session_id = await self.session_manager.get_participant_session_id(participant_id)
        if not session_id:
            return

        session = await self.session_manager.get_session(session_id)
        if not session:
            return

        participant = next((p for p in session.participants if p.id == participant_id), None)
        if not participant:
            return

        # Get or create current message for this participant
        current_message_id = self.participant_current_message.get(participant_id)
        if not current_message_id:
            current_message_id = str(uuid.uuid4())
            message = Message(
                id=current_message_id,
                session_id=session_id,
                speaker_id=participant_id,
                speaker_name=participant.name,
                original_text="",
                original_language=participant.language,
                translations={},
                is_transcribing=True
            )
            self.messages[current_message_id] = message
            self.participant_current_message[participant_id] = current_message_id

            # Start typing indicator
            await self._broadcast_to_session(session_id, 'typing_start', {
                'speakerId': participant_id,
                'speakerName': participant.name,
                'languageCode': participant.language.code.value
            })

        message = self.messages[current_message_id]

        # Define callbacks.  Each closes over current_message_id/session_id,
        # so they stay bound to THIS utterance even if a new one starts.
        async def on_progress(text: str, is_complete: bool):
            """Called with in-progress transcription updates"""
            # Update the message text even for progress updates
            message.original_text = text

            await self._broadcast_to_session(session_id, 'transcription_progress', {
                'messageId': current_message_id,
                'text': text,
                'isTranscribing': not is_complete,
                'speakerId': participant_id,
                'speakerName': participant.name
            })

        async def on_debug(debug_info: dict):
            """Called with debug information from ASR (wav2vec2 models only)"""
            # Prepare debug data for transmission
            debug_data = {
                'messageId': current_message_id,
                'text': debug_info['text'],
                'timestamps': debug_info['timestamps'],
                'audioData': list(debug_info['audio_data']),
                'audioDuration': debug_info['audio_duration'],
                'modelType': debug_info['model_type'],
                'language': participant.language.code.value
            }

            await self._broadcast_to_session(session_id, 'transcription_debug', debug_data)

        async def on_sentence_complete(final_text: str, final_audio: bytes):
            """Called when a complete sentence is detected"""

            # Check if this message was already processed
            if current_message_id in self.processed_messages:
                print(f"Message {current_message_id} already processed, skipping duplicate")
                return

            # Mark as processed to prevent duplicates
            self.processed_messages.add(current_message_id)

            message.original_text = final_text
            message.is_transcribing = False

            # Broadcast complete sentence with session ID
            message_data = {
                'messageId': current_message_id,
                'sessionId': session_id,
                'text': final_text,
                'speakerId': participant_id,
                'speakerName': participant.name,
                'language': participant.language.code.value,
                'audioData': list(final_audio)
            }

            print(f"Broadcasting message_complete for {current_message_id}: '{final_text}'")
            await self._broadcast_to_session(session_id, 'message_complete', message_data)

            # Stop typing indicator
            await self._broadcast_to_session(session_id, 'typing_stop', {
                'speakerId': participant_id
            })

            # Clear current message tracking
            if participant_id in self.participant_current_message:
                del self.participant_current_message[participant_id]

            # Start translation and TTS processing (non-blocking to allow continued audio processing)
            print("Starting TRANSLATION and TTS (background task)")
            asyncio.create_task(self._process_translations_and_tts(message, session))

        # Process the audio chunk
        result_text = await self.transcription_service.process_audio_chunk(
            audio_data,
            participant.language.code.value,
            participant_id,
            has_voice_activity,
            progress_callback=on_progress,
            sentence_callback=on_sentence_complete,
            debug_callback=on_debug
        )

        # If this is a pause boundary (stop button clicked), force immediate finalization
        if is_pause_boundary and participant_id in self.participant_current_message:
            print(f"Pause boundary detected - forcing sentence finalization for participant {participant_id}")
            # Get the current accumulated text from transcription service.
            # NOTE(review): this reaches into the transcription service's
            # internal buffers (candidate_text_cache etc.); the hasattr guards
            # make it best-effort — confirm these attribute names against
            # TranscriptionService.
            if hasattr(self.transcription_service, 'candidate_text_cache') and participant_id in self.transcription_service.candidate_text_cache:
                final_text = self.transcription_service.candidate_text_cache.get(participant_id, "").strip()
                if final_text:  # Only finalize if there's actual text
                    # Get accumulated audio
                    final_audio = b""
                    if hasattr(self.transcription_service, 'candidate_audio_buffers') and participant_id in self.transcription_service.candidate_audio_buffers:
                        audio_array = self.transcription_service.candidate_audio_buffers.get(participant_id, np.array([]))
                        if len(audio_array) > 0:
                            # Convert float array to int16 bytes
                            audio_int16 = (audio_array * 32767).astype(np.int16)
                            final_audio = audio_int16.tobytes()

                    # Trigger sentence completion
                    await on_sentence_complete(final_text, final_audio)

                    # Clear the buffers manually since we're forcing finalization
                    if participant_id in self.transcription_service.candidate_text_cache:
                        self.transcription_service.candidate_text_cache[participant_id] = ""
                    if participant_id in self.transcription_service.candidate_audio_buffers:
                        self.transcription_service.candidate_audio_buffers[participant_id] = np.array([], dtype=np.float32)
                    if participant_id in self.transcription_service.silence_counters:
                        self.transcription_service.silence_counters[participant_id] = 0
                    if participant_id in self.transcription_service.sentence_finalized:
                        self.transcription_service.sentence_finalized[participant_id] = False

    except Exception as e:
        print(f"Error in _process_audio_chunk_vad: {e}")
        import traceback
        traceback.print_exc()
431
-
432
async def _process_translations_and_tts(self, message: Message, session):
    """Process translations and TTS for all session languages.

    Pipeline: (1) optionally schedule TTS for the original text,
    (2) schedule one translation task per non-original session language,
    (3) await translations in order, broadcasting 'translation_update'
    and scheduling TTS per translation, (4) await all TTS tasks and
    broadcast the audio in dual (WebM+WAV) format.  Runs as a background
    task; all failures are logged, never raised to the caller.
    """
    try:
        source_lang = message.original_language.name

        print(f"=== TRANSLATION/TTS PROCESSING START ===")
        print(f"Message ID: {message.id}")
        print(f"Original message: '{message.original_text}'")
        print(f"Original language: {message.original_language.name} ({message.original_language.code.value})")
        print(f"Session languages: {[f'{lang.name} ({lang.code.value})' for lang in session.languages]}")
        print(f"Session ID for verification: {session.id}")

        # Create a mapping to track which audio belongs to which message and
        # language: (language_code, text, task, is_original) tuples.
        audio_tasks = []

        # Check if TTS is enabled for this session
        if session.enable_tts:
            # First, generate TTS for the original message
            print(f"TTS: Generating TTS for original message in {message.original_language.code.value}: '{message.original_text}'")
            print(f"TTS Model: VITS ONNX (mutisya/vits-tts-onnx-fp32-{message.original_language.name.lower()}) - File: tts_service_onnx.py")
            original_audio_task = asyncio.create_task(
                self.tts_service.generate_speech_dual_format(message.original_text, message.original_language.code.value)
            )
            audio_tasks.append((
                message.original_language.code.value,
                message.original_text,
                original_audio_task,
                True  # is_original
            ))
        else:
            print(f"TTS: Skipping TTS generation (disabled for this session)")

        # Process translations for each language in the session.  Tasks are
        # created first so the translations run concurrently.
        print(f"Processing translations for {len(session.languages)} session languages...")
        print(f"Session languages: {[f'{lang.name}({lang.code.value})' for lang in session.languages]}")
        translation_tasks = []

        for language in session.languages:
            print(f"Checking language: {language.name} ({language.code.value})")
            if language.code != message.original_language.code:
                print(f"TRANSLATING: '{message.original_text}' from {source_lang} to {language.name}")
                print(f"Translation Model: mutisya/nllb_600m (NLLB-600M) - File: translation_service.py")

                # Create translation task
                translation_task = asyncio.create_task(
                    self.translation_service.translate_text(
                        message.original_text, source_lang, language.name
                    )
                )
                translation_tasks.append((language, translation_task))
            else:
                print(f"SKIPPING translation for {language.name} (same as original language)")

        print(f"Created {len(translation_tasks)} translation tasks for non-original languages")

        # Wait for all translations to complete (awaited in creation order;
        # each failure is isolated so one bad language doesn't stop the rest).
        for language, translation_task in translation_tasks:
            try:
                translated_text = await translation_task

                if translated_text:
                    print(f"TRANSLATION SUCCESS: '{translated_text}' for {language.name}")

                    message.translations[language.code.value] = translated_text

                    # Broadcast translation update to all clients
                    await self._broadcast_to_session(message.session_id, 'translation_update', {
                        'messageId': message.id,
                        'targetLanguage': language.code.value,
                        'translatedText': translated_text
                    })

                    # Check if TTS is enabled for this session
                    if session.enable_tts:
                        # Generate TTS for the translated text
                        print(f"TTS: Generating TTS for translation in {language.code.value}: '{translated_text}'")
                        print(f"TTS Model: VITS ONNX (mutisya/vits-tts-onnx-fp32-{language.name.lower()}) - File: tts_service_onnx.py")
                        tts_task = asyncio.create_task(
                            self.tts_service.generate_speech_dual_format(translated_text, language.code.value)
                        )
                        audio_tasks.append((
                            language.code.value,
                            translated_text,
                            tts_task,
                            False  # is_original
                        ))
                    else:
                        print(f"TTS: Skipping TTS generation for translation (disabled for this session)")
                else:
                    print(f"TRANSLATION FAILED: No translated text returned for {language.name}")
            except Exception as e:
                print(f"Translation error for {language.name}: {e}")

        # Wait for all TTS generation to complete and broadcast with proper alignment.
        # Each task resolves to a (webm_bytes, wav_bytes) pair; either may be empty.
        for language_code, text, audio_task, is_original in audio_tasks:
            try:
                audio_result = await audio_task

                if audio_result and (audio_result[0] or audio_result[1]):
                    webm_data, wav_data = audio_result
                    print(f"TTS: Audio generated successfully for {language_code}")
                    if webm_data:
                        print(f"TTS: WebM audio: {len(webm_data)} bytes")
                    if wav_data:
                        print(f"TTS: WAV audio: {len(wav_data)} bytes")
                    print(f"TTS: Text for {language_code}: '{text}'")

                    # Broadcast TTS audio with explicit message-text-audio alignment (dual format)
                    await self._broadcast_tts_audio_aligned_dual_format(
                        message.session_id,
                        message.id,
                        language_code,
                        text,
                        webm_data,
                        wav_data,
                        is_original
                    )
                else:
                    print(f"TTS: Failed to generate audio for {language_code}")
            except Exception as e:
                print(f"TTS generation error for {language_code}: {e}")

        print(f"=== TRANSLATION/TTS PROCESSING END ===")

    except Exception as e:
        print(f"Error in _process_translations_and_tts: {e}")
        import traceback
        traceback.print_exc()
560
-
561
async def _broadcast_to_session(self, session_id: str, event: str, data: dict, exclude_sid: str = None):
    """Emit *event* with *data* to every client of *session_id*.

    ``exclude_sid`` skips one client (typically the originator).  Unknown
    sessions are a no-op; per-client emit failures are logged and do not
    stop delivery to the remaining clients.
    """
    recipients = self.session_clients.get(session_id)
    if not recipients:
        return

    # Snapshot membership so concurrent joins/leaves can't break iteration.
    for client_sid in list(recipients):
        if client_sid == exclude_sid:
            continue
        try:
            await self.sio.emit(event, data, room=client_sid)
        except Exception as e:
            print(f"Error broadcasting {event} to client {client_sid}: {e}")
575
-
576
async def _broadcast_tts_audio_aligned(self, session_id: str, message_id: str,
                                       language_code: str, text: str, audio_data: bytes,
                                       is_original: bool = False):
    """Stream one TTS clip to every session client in 4 KiB chunks.

    Every chunk payload carries the message id, language and text it
    belongs to, so receivers can align audio with the right transcript.
    """
    try:
        if session_id not in self.session_clients:
            return

        print(f"TTS ALIGNED: Broadcasting audio for message {message_id}")
        print(f"TTS ALIGNED: Language: {language_code}")
        print(f"TTS ALIGNED: Text: '{text}'")
        print(f"TTS ALIGNED: Audio size: {len(audio_data)} bytes")
        print(f"TTS ALIGNED: Is original: {is_original}")

        chunk_size = 4096
        total_chunks = (len(audio_data) + chunk_size - 1) // chunk_size

        # Snapshot the membership to avoid concurrent-modification issues.
        for client_sid in list(self.session_clients[session_id]):
            try:
                for index in range(total_chunks):
                    start = index * chunk_size
                    piece = audio_data[start:start + chunk_size]

                    payload = {
                        'messageId': message_id,        # Explicit message ID
                        'languageCode': language_code,  # Language of THIS audio
                        'text': text,                   # Text that THIS audio represents
                        'isOriginal': is_original,      # Original vs translation
                        'chunk': list(piece),
                        'isLast': start + chunk_size >= len(audio_data),
                        'chunkIndex': index,            # Chunk ordering
                        'totalChunks': total_chunks
                    }

                    await self.sio.emit('tts_audio_chunk', payload, room=client_sid)

                    # Pace the stream so the client isn't overwhelmed.
                    await asyncio.sleep(0.01)

                print(f"TTS ALIGNED: Successfully sent aligned audio to participant {client_sid}")
            except Exception as e:
                print(f"TTS ALIGNED: Error sending audio to participant {client_sid}: {e}")

    except Exception as e:
        print(f"TTS ALIGNED: Error broadcasting aligned audio: {e}")
623
-
624
async def _broadcast_tts_audio_aligned_dual_format(self, session_id: str, message_id: str,
                                                   language_code: str, text: str, webm_data: bytes,
                                                   wav_data: bytes, is_original: bool = False):
    """Broadcast TTS audio with both WebM and WAV formats for cross-platform compatibility.

    WebM is the primary stream (web clients); a parallel WAV chunk is
    attached for clients (e.g. Android) that cannot play WebM.  Either
    format may be empty; at least one must be present.
    """
    try:
        if session_id not in self.session_clients:
            return

        print(f"TTS DUAL FORMAT: Broadcasting audio for message {message_id}")
        print(f"TTS DUAL FORMAT: Language: {language_code}")
        print(f"TTS DUAL FORMAT: Text: '{text}'")
        if webm_data:
            print(f"TTS DUAL FORMAT: WebM size: {len(webm_data)} bytes")
        if wav_data:
            print(f"TTS DUAL FORMAT: WAV size: {len(wav_data)} bytes")
        print(f"TTS DUAL FORMAT: Is original: {is_original}")

        # Create a copy of the set to avoid concurrent modification
        client_sids = list(self.session_clients[session_id])

        # Use WebM data for chunking (primary format for web clients)
        primary_audio_data = webm_data if webm_data else wav_data
        if not primary_audio_data:
            print("TTS DUAL FORMAT: No audio data available")
            return

        # Send audio data in chunks to all participants with dual format support.
        # NOTE(review): WAV chunks reuse the WebM stream's byte offsets, and the
        # loop ends at len(primary_audio_data) — if wav_data is longer than
        # webm_data its tail is never sent, and chunk boundaries between the two
        # formats do not correspond to the same audio time. Confirm receivers
        # tolerate this before relying on the WAV stream.
        chunk_size = 4096
        for sid in client_sids:
            try:
                for i in range(0, len(primary_audio_data), chunk_size):
                    chunk = primary_audio_data[i:i + chunk_size]
                    is_last_chunk = i + chunk_size >= len(primary_audio_data)

                    # Prepare WAV chunk if available
                    wav_chunk = None
                    if wav_data and i < len(wav_data):
                        wav_end = min(i + chunk_size, len(wav_data))
                        wav_chunk = wav_data[i:wav_end]

                    chunk_data = {
                        'messageId': message_id,  # Explicit message ID
                        'languageCode': language_code,  # Language of THIS audio
                        'text': text,  # Text that THIS audio represents
                        'isOriginal': is_original,  # Whether this is original or translation
                        'chunk': list(chunk),  # WebM audio chunk (for web clients)
                        'wavChunk': list(wav_chunk) if wav_chunk else None,  # WAV audio chunk (for Android clients)
                        'isLast': is_last_chunk,
                        'chunkIndex': i // chunk_size,  # Chunk ordering
                        'totalChunks': (len(primary_audio_data) + chunk_size - 1) // chunk_size,
                        'format': 'webm',  # Primary format
                        'wavFormat': 'wav' if wav_chunk else None  # Secondary format available
                    }

                    await self.sio.emit('tts_audio_chunk', chunk_data, room=sid)

                    # Small delay to prevent overwhelming
                    await asyncio.sleep(0.01)

                print(f"TTS DUAL FORMAT: Successfully sent dual format audio to participant {sid}")
            except Exception as e:
                print(f"TTS DUAL FORMAT: Error sending audio to participant {sid}: {e}")

    except Exception as e:
        print(f"TTS DUAL FORMAT: Error broadcasting dual format audio: {e}")
689
-
690
async def _broadcast_tts_audio_to_all_participants(self, session_id: str, language_code: str,
                                                   audio_data: bytes, message_id: str, text: str):
    """Legacy entry point kept for compatibility; delegates to the aligned broadcast."""
    await self._broadcast_tts_audio_aligned(
        session_id, message_id, language_code, text, audio_data, False
    )
696
-
697
async def _broadcast_audio_to_language_participants(self, session_id: str, language_code: str,
                                                    audio_data: bytes, message_id: str):
    """Broadcast audio only to participants listening in *language_code* (legacy method)."""
    try:
        session = await self.session_manager.get_session(session_id)
        if not session:
            return

        chunk_size = 4096
        for participant in session.participants:
            if participant.language.code.value != language_code:
                continue

            # Reverse-lookup the first connected sid for this participant.
            participant_sid = next(
                (sid for sid, pid in self.client_participants.items() if pid == participant.id),
                None
            )
            if not participant_sid:
                continue

            print(f"TTS: Broadcasting audio to participant {participant.name} in {language_code}")
            # Stream the clip in fixed-size chunks, pacing each emit.
            for offset in range(0, len(audio_data), chunk_size):
                await self.sio.emit('tts_audio_chunk', {
                    'messageId': message_id,
                    'chunk': list(audio_data[offset:offset + chunk_size]),
                    'isLast': offset + chunk_size >= len(audio_data)
                }, room=participant_sid)

                # Small delay to prevent overwhelming
                await asyncio.sleep(0.01)

    except Exception as e:
        print(f"TTS: Error broadcasting audio: {e}")
733
-
734
async def _cleanup_client(self, sid: str):
    """Remove all state associated with a disconnecting/leaving client.

    Drops the participant from its session, clears per-participant
    transcription buffers and message tracking, and — when the session's
    last client leaves — discards the session's processed-message history.
    """
    try:
        participant_id = self.client_participants.get(sid)
        session_id = self.client_sessions.get(sid)

        if participant_id:
            # Remove the participant and any per-participant processing state.
            await self.session_manager.remove_participant(participant_id)
            self.transcription_service.clear_participant_buffers(participant_id)
            self.participant_current_message.pop(participant_id, None)
            self.client_participants.pop(sid, None)

        if session_id:
            members = self.session_clients.get(session_id)
            if members is not None:
                members.discard(sid)
                if not members:
                    # Last client gone: drop the roster and purge its
                    # processed-message records to avoid a memory leak.
                    del self.session_clients[session_id]
                    self._cleanup_session_processed_messages(session_id)
            self.client_sessions.pop(sid, None)

    except Exception as e:
        print(f"Error cleaning up client {sid}: {e}")
766
-
767
def _cleanup_session_processed_messages(self, session_id: str):
    """Drop processed-message records belonging to *session_id* to prevent memory leaks."""
    try:
        # Collect ids first so we don't mutate the set while scanning it.
        stale_ids = [
            mid for mid in list(self.processed_messages)
            if mid in self.messages and self.messages[mid].session_id == session_id
        ]

        for mid in stale_ids:
            self.processed_messages.discard(mid)
            self.messages.pop(mid, None)

        print(f"Cleaned up {len(stale_ids)} processed messages for session {session_id}")
    except Exception as e:
        print(f"Error cleaning up session processed messages: {e}")
784
-
785
- async def _emit_error(self, sid: str, message: str):
786
- """Emit error message to specific client"""
787
- try:
788
- await self.sio.emit('join_error', message, room=sid)
789
- except Exception as e:
790
- print(f"Error emitting error to {sid}: {e}")
791
-
792
- async def handle_update_participant_language(self, sid: str, data: dict):
793
- """Handle participant language update (affects speech recognition)"""
794
- try:
795
- session_id = data.get('sessionId')
796
- participant_id = data.get('participantId')
797
- language_code = data.get('language')
798
-
799
- print(f"=== UPDATE PARTICIPANT LANGUAGE ===")
800
- print(f"Session ID: {session_id}")
801
- print(f"Participant ID: {participant_id}")
802
- print(f"New Language: {language_code}")
803
-
804
- if not all([session_id, participant_id, language_code]):
805
- await self._emit_error(sid, "Missing required fields")
806
- return
807
-
808
- # Validate language code
809
- try:
810
- from app.models import LanguageCode
811
- lang_enum = LanguageCode(language_code)
812
- print(f"Language code validated: {lang_enum}")
813
- except ValueError:
814
- await self._emit_error(sid, f"Invalid language code: {language_code}")
815
- return
816
-
817
- # Update participant's language in session
818
- session = await self.session_manager.get_session(session_id)
819
- if session:
820
- for participant in session.participants:
821
- if participant.id == participant_id:
822
- # Update participant's language using LANGUAGE_MAP for complete Language object
823
- if lang_enum in LANGUAGE_MAP:
824
- participant.language = LANGUAGE_MAP[lang_enum]
825
- print(f"Updated participant {participant.name} language to {lang_enum.value} ({participant.language.display_name})")
826
- else:
827
- print(f"Warning: Language {lang_enum.value} not found in LANGUAGE_MAP, using fallback")
828
- from app.models import Language
829
- participant.language = Language(code=lang_enum, name=lang_enum.value, display_name=lang_enum.value)
830
-
831
- # Notify all clients in session
832
- await self._broadcast_to_session(session_id, 'participant_language_updated', {
833
- 'participantId': participant_id,
834
- 'language': language_code
835
- })
836
- break
837
-
838
- print(f"=== UPDATE PARTICIPANT LANGUAGE COMPLETE ===")
839
-
840
- except Exception as e:
841
- print(f"Error in handle_update_participant_language: {e}")
842
- import traceback
843
- traceback.print_exc()
844
- await self._emit_error(sid, "Failed to update participant language")
845
-
846
- async def handle_update_session_languages(self, sid: str, data: dict):
847
- """Handle session languages update (affects translation targets)"""
848
- try:
849
- session_id = data.get('sessionId')
850
- languages = data.get('languages', [])
851
-
852
- print(f"=== UPDATE SESSION LANGUAGES (REPLACE MODE) ===")
853
- print(f"Session ID: {session_id}")
854
- print(f"New Languages: {languages}")
855
-
856
- if not session_id or not languages:
857
- await self._emit_error(sid, "Missing required fields")
858
- return
859
-
860
- # Get current session for comparison
861
- session = await self.session_manager.get_session(session_id)
862
- if not session:
863
- await self._emit_error(sid, "Session not found")
864
- return
865
-
866
- current_languages = [lang.code.value for lang in session.languages]
867
- print(f"Before update - Session languages: {current_languages}")
868
-
869
- # Validate all language codes and create Language objects
870
- validated_languages = []
871
- try:
872
- from app.models import Language, LanguageCode
873
- from app.services.session_manager import LANGUAGE_MAP
874
-
875
- for lang_code in languages:
876
- lang_enum = LanguageCode(lang_code)
877
- language = LANGUAGE_MAP[lang_enum]
878
- validated_languages.append(language)
879
- print(f"Validated language: {lang_code} -> {language.name}")
880
-
881
- except ValueError as e:
882
- await self._emit_error(sid, f"Invalid language code: {e}")
883
- return
884
-
885
- # REPLACE session languages (not add to them)
886
- session.languages = validated_languages
887
- new_languages = [lang.code.value for lang in session.languages]
888
- print(f"After update - Session languages: {new_languages}")
889
-
890
- # Verify the session manager has the updated languages
891
- verification_session = await self.session_manager.get_session(session_id)
892
- if verification_session:
893
- verification_languages = [lang.code.value for lang in verification_session.languages]
894
- print(f"Verification - Session manager languages: {verification_languages}")
895
-
896
- # Notify all clients in session about the update
897
- await self._broadcast_to_session(session_id, 'session_languages_updated', {
898
- 'sessionId': session_id,
899
- 'languages': new_languages,
900
- 'previous': current_languages
901
- })
902
-
903
- print(f"=== UPDATE SESSION LANGUAGES COMPLETE ===")
904
-
905
- except Exception as e:
906
- print(f"Error in handle_update_session_languages: {e}")
907
- import traceback
908
- traceback.print_exc()
909
- await self._emit_error(sid, "Failed to update session languages")