ducnguyen1978 committed on
Commit
b00ea69
·
verified ·
1 Parent(s): 5cca058

Upload utils.py

Browse files
Files changed (1) hide show
  1. utils.py +370 -0
utils.py ADDED
@@ -0,0 +1,370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Utility functions for Translation AI Agent
4
+ """
5
+
6
+ import os
7
+ import time
8
+ import tempfile
9
+ import logging
10
+ import hashlib
11
+ from typing import Optional, Tuple, List, Dict, Any
12
+ import numpy as np
13
+ import librosa
14
+ import soundfile as sf
15
+ from pathlib import Path
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
class AudioProcessor:
    """Audio processing utilities.

    Every method is static; the class only serves as a namespace.
    """

    @staticmethod
    def load_audio(file_path: str, target_sr: int = 16000) -> Tuple[np.ndarray, int]:
        """Load an audio file through ``librosa.load``.

        Args:
            file_path: Path of the audio file to read.
            target_sr: Sample rate passed to ``librosa.load`` as ``sr``.

        Returns:
            The ``(audio, sample_rate)`` pair returned by ``librosa.load``
            when called with ``mono=True``.

        Raises:
            Exception: Any loading error is logged and re-raised unchanged.
        """
        try:
            audio, sr = librosa.load(file_path, sr=target_sr, mono=True)
            return audio, sr
        except Exception as e:
            logger.error(f"Error loading audio: {e}")
            raise

    @staticmethod
    def save_audio(audio: np.ndarray, file_path: str, sample_rate: int = 16000):
        """Write an audio array to ``file_path`` with ``soundfile.write``.

        Raises:
            Exception: Any write error is logged and re-raised unchanged.
        """
        try:
            sf.write(file_path, audio, sample_rate)
        except Exception as e:
            logger.error(f"Error saving audio: {e}")
            raise

    @staticmethod
    def get_audio_duration(file_path: str) -> float:
        """Return the duration of an audio file in seconds.

        The file is loaded with ``sr=None`` and the number of samples is
        divided by the sample rate that ``librosa.load`` reports.

        Returns:
            The duration in seconds, or ``0.0`` when the file cannot be
            loaded (the error is logged, not raised).
        """
        try:
            audio, sr = librosa.load(file_path, sr=None)
            return len(audio) / sr
        except Exception as e:
            logger.error(f"Error getting audio duration: {e}")
            return 0.0

    @staticmethod
    def validate_audio_file(file_path: str, max_duration: int = 300) -> bool:
        """Check that an audio file exists and has an acceptable duration.

        Args:
            file_path: Path of the file to check.
            max_duration: Longest accepted duration, in seconds.

        Returns:
            ``True`` when the file exists and ``0 < duration <= max_duration``.
            Unreadable files report a duration of ``0.0`` and are rejected.
        """
        if not os.path.exists(file_path):
            return False

        try:
            duration = AudioProcessor.get_audio_duration(file_path)
            return 0 < duration <= max_duration
        except Exception:
            # Catch Exception rather than using a bare except so that
            # KeyboardInterrupt / SystemExit are never swallowed here.
            return False

    @staticmethod
    def normalize_audio(audio: np.ndarray) -> np.ndarray:
        """Scale audio into the ``[-1, 1]`` range when it exceeds it.

        Audio that already lies within ``[-1, 1]`` is returned unchanged
        (it is not amplified). Empty arrays are returned as-is instead of
        failing on the max/min reduction.
        """
        if audio.size == 0:
            return audio
        if audio.max() > 1.0 or audio.min() < -1.0:
            audio = audio / np.max(np.abs(audio))
        return audio

    @staticmethod
    def add_silence(audio: np.ndarray, duration: float, sample_rate: int) -> np.ndarray:
        """Pad audio with silence at both the beginning and the end.

        Args:
            audio: Samples to pad.
            duration: Length of each silent segment, in seconds.
            sample_rate: Samples per second, used to size the padding.

        Returns:
            A new array: silence, then the audio, then silence. The padding
            uses the dtype of ``audio`` so the result is not upcast.
        """
        samples = np.asarray(audio)
        silence_samples = int(duration * sample_rate)
        # Match the input dtype; a default float64 buffer would silently
        # promote float32 audio to float64 during concatenation.
        silence = np.zeros(silence_samples, dtype=samples.dtype)
        return np.concatenate([silence, samples, silence])
76
+
77
class LanguageDetector:
    """Language detection utilities based on keyword matching."""

    def __init__(self, keywords_dict: Dict[str, List[str]]):
        """Store the mapping of language code to its list of keywords."""
        self.keywords = keywords_dict

    @staticmethod
    def _tokenize(text: str) -> set:
        """Return the set of lowercase, whitespace-separated tokens in text."""
        return set(text.lower().split())

    @staticmethod
    def _count_matches(tokens: set, keywords: List[str]) -> int:
        """Count how many entries of ``keywords`` appear among ``tokens``."""
        return sum(1 for word in keywords if word in tokens)

    def detect(self, text: str, threshold: int = 2) -> str:
        """Detect the language of ``text`` using keyword matching.

        Args:
            text: Text to classify.
            threshold: Minimum number of matching keywords required to
                accept the best-scoring language.

        Returns:
            The language code with the highest keyword count when that
            count reaches ``threshold``; otherwise ``'en'``.
        """
        # Tokenize once into a set so each keyword lookup is O(1) instead
        # of scanning the token list again for every keyword.
        tokens = self._tokenize(text)
        scores = {
            lang: self._count_matches(tokens, keywords)
            for lang, keywords in self.keywords.items()
        }

        # Get language with highest score
        if scores:
            detected_lang, score = max(scores.items(), key=lambda x: x[1])
            if score >= threshold:
                return detected_lang

        return 'en'  # Default to English

    def get_confidence(self, text: str, detected_lang: str) -> float:
        """Return the share of a language's keywords that occur in ``text``.

        Returns:
            ``matches / len(keywords)`` capped at ``1.0``, or ``0.0`` when
            the language has no keywords or the text has no tokens.
        """
        tokens = self._tokenize(text)
        keywords = self.keywords.get(detected_lang, [])

        if not keywords or not tokens:
            return 0.0

        matches = self._count_matches(tokens, keywords)
        return min(matches / len(keywords), 1.0)
110
+
111
class FileManager:
    """File management utilities.

    Every method is static; the class only serves as a namespace.
    """

    @staticmethod
    def create_temp_file(suffix: str = '.wav', prefix: str = 'temp_') -> str:
        """Create an empty temporary file and return its path.

        The file is created with ``delete=False``, so it stays on disk
        after being closed; the caller is responsible for removing it
        (for example with ``cleanup_temp_files``).
        """
        with tempfile.NamedTemporaryFile(
            suffix=suffix,
            prefix=prefix,
            delete=False
        ) as temp_file:
            return temp_file.name

    @staticmethod
    def cleanup_temp_files(file_paths: List[str]):
        """Remove the given files, skipping paths that do not exist.

        Failures are logged as warnings and do not stop the remaining
        files from being processed.
        """
        for file_path in file_paths:
            try:
                if os.path.exists(file_path):
                    os.remove(file_path)
            except Exception as e:
                logger.warning(f"Could not remove temp file {file_path}: {e}")

    @staticmethod
    def ensure_directory(directory: str):
        """Create ``directory`` (and missing parents) if it does not exist."""
        Path(directory).mkdir(parents=True, exist_ok=True)

    @staticmethod
    def get_file_hash(file_path: str) -> str:
        """Return the SHA256 hex digest of a file's contents.

        The file is hashed in fixed-size chunks so that large files are
        not read into memory in one piece.

        Returns:
            The hex digest, or ``""`` when the file cannot be read (the
            error is logged, not raised).
        """
        digest = hashlib.sha256()
        try:
            with open(file_path, 'rb') as f:
                # iter() with a sentinel stops at the empty read at EOF.
                for chunk in iter(lambda: f.read(65536), b""):
                    digest.update(chunk)
        except Exception as e:
            logger.error(f"Error computing file hash: {e}")
            return ""
        return digest.hexdigest()
149
+
150
class ModelManager:
    """Model loading and management utilities.

    Every method is static; the class only serves as a namespace.
    """

    @staticmethod
    def check_cuda_availability() -> bool:
        """Return whether CUDA is available.

        ``torch`` is imported lazily; when it is not installed the result
        is ``False``.
        """
        try:
            import torch
            return torch.cuda.is_available()
        except ImportError:
            return False

    @staticmethod
    def get_device_info() -> Dict[str, Any]:
        """Return a summary of the available CUDA devices.

        Returns:
            A dict with the keys ``has_cuda``, ``device_count`` and
            ``device_names``. When ``torch`` is missing or CUDA is not
            available the values stay ``False``, ``0`` and ``[]``.
        """
        info = {"has_cuda": False, "device_count": 0, "device_names": []}

        try:
            import torch
            if torch.cuda.is_available():
                info["has_cuda"] = True
                info["device_count"] = torch.cuda.device_count()
                info["device_names"] = [
                    torch.cuda.get_device_name(i)
                    for i in range(torch.cuda.device_count())
                ]
        except ImportError:
            pass

        return info

    @staticmethod
    def estimate_model_memory(model_name: str) -> int:
        """Estimate the memory requirement of a model in MB.

        The estimate is chosen by case-insensitive substring match of the
        known model identifiers against ``model_name``.

        Returns:
            The estimate for the first matching identifier, or ``1000``
            when no identifier matches.
        """
        # Rough estimates based on common model sizes
        memory_estimates = {
            "whisper-tiny": 128,
            "whisper-base": 256,
            "whisper-small": 512,
            "whisper-medium": 1024,
            "nllb-200-distilled-600M": 1200,
            "nllb-200-1.3B": 2600,
            "speecht5": 800
        }

        normalized_name = model_name.lower()
        for key, memory in memory_estimates.items():
            # Lowercase the key as well: comparing a mixed-case key with a
            # lowercased name could never match the NLLB entries.
            if key.lower() in normalized_name:
                return memory

        return 1000  # Default estimate
200
+
201
class CacheManager:
    """File-backed text cache with a time-to-live and a size limit.

    Each entry is stored as ``<key>.cache`` inside ``cache_dir``. Insertion
    times are tracked in the in-memory ``cache_info`` dict.
    """

    def __init__(self, cache_dir: str, max_size: int = 1000, ttl: int = 3600):
        """Set up the cache and create its directory if needed.

        Args:
            cache_dir: Directory that holds the cache files.
            max_size: Maximum number of tracked entries before the oldest
                ones are evicted.
            ttl: Time to live of an entry, in seconds.
        """
        self.cache_dir = Path(cache_dir)
        self.max_size = max_size
        self.ttl = ttl  # Time to live in seconds
        self.cache_info = {}  # key -> time.time() at insertion
        self.ensure_cache_dir()

    def ensure_cache_dir(self):
        """Ensure the cache directory exists."""
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def _cache_path(self, key: str) -> Path:
        """Return the path of the file that stores ``key``."""
        return self.cache_dir / f"{key}.cache"

    def get_cache_key(self, data: str) -> str:
        """Generate a cache key: the MD5 hex digest of ``data``."""
        return hashlib.md5(data.encode()).hexdigest()

    def is_cached(self, key: str) -> bool:
        """Check whether ``key`` is cached and not expired.

        Expired entries are removed as a side effect. Entries that exist on
        disk but are not tracked in ``cache_info`` (for example files
        written by an earlier process) are aged by the file's modification
        time, so they expire too instead of living forever.
        """
        cache_file = self._cache_path(key)
        if not cache_file.exists():
            return False

        cached_at = self.cache_info.get(key)
        if cached_at is None:
            try:
                cached_at = cache_file.stat().st_mtime
            except OSError:
                # The file vanished between exists() and stat().
                return False

        # Check TTL
        if time.time() - cached_at > self.ttl:
            self.remove_from_cache(key)
            return False

        return True

    def get_from_cache(self, key: str) -> Optional[Any]:
        """Return the cached text for ``key``.

        Returns:
            The file contents, or ``None`` when the key is missing, expired
            or unreadable (read errors are logged, not raised).
        """
        if not self.is_cached(key):
            return None

        try:
            with open(self._cache_path(key), 'r', encoding='utf-8') as f:
                return f.read()
        except Exception as e:
            logger.error(f"Error reading from cache: {e}")
            return None

    def add_to_cache(self, key: str, data: str):
        """Store ``data`` under ``key`` and evict old entries if needed.

        Write errors are logged, not raised.
        """
        try:
            with open(self._cache_path(key), 'w', encoding='utf-8') as f:
                f.write(data)

            self.cache_info[key] = time.time()
            self.cleanup_old_cache()
        except Exception as e:
            logger.error(f"Error writing to cache: {e}")

    def remove_from_cache(self, key: str):
        """Delete the file and the bookkeeping entry for ``key``.

        Errors are logged, not raised.
        """
        try:
            cache_file = self._cache_path(key)
            if cache_file.exists():
                cache_file.unlink()

            if key in self.cache_info:
                del self.cache_info[key]
        except Exception as e:
            logger.error(f"Error removing from cache: {e}")

    def cleanup_old_cache(self):
        """Remove the oldest tracked entries while over ``max_size``."""
        if len(self.cache_info) <= self.max_size:
            return

        # Sort by timestamp and remove oldest
        sorted_items = sorted(self.cache_info.items(), key=lambda x: x[1])
        items_to_remove = len(sorted_items) - self.max_size

        for key, _ in sorted_items[:items_to_remove]:
            self.remove_from_cache(key)
282
+
283
class MetricsTracker:
    """Collect counters and timing figures for processed operations."""

    def __init__(self):
        """Start with all counters at zero and remember the creation time."""
        self.start_time = time.time()
        self.metrics = {
            "translations": 0,
            "speech_recognitions": 0,
            "text_to_speech": 0,
            "total_processing_time": 0,
            "average_processing_time": 0,
            "errors": 0
        }

    def _operation_count(self) -> int:
        """Return the number of recorded operations, errors excluded."""
        counters = self.metrics
        return (
            counters["translations"]
            + counters["speech_recognitions"]
            + counters["text_to_speech"]
        )

    def record_translation(self, processing_time: float):
        """Count one translation and add its processing time."""
        self.metrics["translations"] += 1
        self._update_timing(processing_time)

    def record_speech_recognition(self, processing_time: float):
        """Count one speech recognition and add its processing time."""
        self.metrics["speech_recognitions"] += 1
        self._update_timing(processing_time)

    def record_tts(self, processing_time: float):
        """Count one text-to-speech run and add its processing time."""
        self.metrics["text_to_speech"] += 1
        self._update_timing(processing_time)

    def record_error(self):
        """Count one error."""
        self.metrics["errors"] += 1

    def _update_timing(self, processing_time: float):
        """Add to the total time and refresh the per-operation average."""
        self.metrics["total_processing_time"] += processing_time
        operations = self._operation_count()
        if operations <= 0:
            # Without recorded operations the average is left untouched.
            return
        self.metrics["average_processing_time"] = (
            self.metrics["total_processing_time"] / operations
        )

    def get_stats(self) -> Dict[str, Any]:
        """Return the metrics plus uptime and operations per minute."""
        uptime = time.time() - self.start_time
        if uptime > 0:
            per_minute = self._operation_count() / (uptime / 60)
        else:
            per_minute = 0

        stats = dict(self.metrics)
        stats["uptime_seconds"] = uptime
        stats["operations_per_minute"] = per_minute
        return stats
342
+
343
+ # Utility functions
344
def format_duration(seconds: float) -> str:
    """Render a duration given in seconds as a short readable string.

    Below one minute the value is shown with one decimal (``"5.0s"``),
    below one hour as whole minutes and seconds (``"1m 5s"``), and from
    one hour up as whole hours and minutes (``"1h 1m"``).
    """
    if seconds < 60:
        return f"{seconds:.1f}s"

    if seconds < 3600:
        whole_minutes, leftover = divmod(seconds, 60)
        return f"{int(whole_minutes)}m {int(leftover)}s"

    whole_hours, remainder = divmod(seconds, 3600)
    return f"{int(whole_hours)}h {int(remainder // 60)}m"
356
+
357
def validate_language_code(code: str, supported_languages: Dict[str, str]) -> bool:
    """Tell whether ``code`` is one of the keys of ``supported_languages``."""
    is_supported = code in supported_languages
    return is_supported
360
+
361
def extract_language_code(display_string: str) -> str:
    """Return the part of a display string before the first ``' - '``.

    For ``'en - English'`` this yields ``'en'``. A string without the
    separator is returned unchanged.
    """
    code, _separator, _label = display_string.partition(' - ')
    return code
364
+
365
def create_progress_callback(progress_bar=None):
    """Create a progress callback for long-running operations.

    Args:
        progress_bar: Optional object exposing a ``progress(fraction)``
            method. When it is missing or falsy the callback does nothing.

    Returns:
        A ``callback(current, total)`` function that reports
        ``current / total`` to the progress bar.
    """
    def callback(current: int, total: int):
        if not progress_bar:
            return
        if total == 0:
            # No work to measure: skip the update rather than raising
            # ZeroDivisionError in the middle of the caller's operation.
            return
        progress_bar.progress(current / total)
    return callback