R-TA committed on
Commit
4d0887b
·
verified ·
1 Parent(s): 91cdcbf

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +148 -0
utils.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utility functions for Multi-Language TTS application
3
+ """
4
+
5
+ import os
6
+ import tempfile
7
+ import logging
8
+ from typing import Optional, Tuple, List
9
+ import numpy as np
10
+ import torch
11
+ import librosa
12
+ from pathlib import Path
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
def get_device() -> str:
    """Return the name of the preferred torch device for inference.

    CUDA is checked first, then the MPS backend (when this torch build
    exposes ``torch.backends.mps``); the CPU is the fallback.
    """
    if torch.cuda.is_available():
        return "cuda"

    # Older torch builds have no ``mps`` backend attribute at all.
    has_mps_backend = hasattr(torch.backends, "mps")
    if has_mps_backend and torch.backends.mps.is_available():
        return "mps"  # Apple Silicon

    return "cpu"
24
+
25
def validate_text(text: str, max_length: int = 1000) -> str:
    """Strip the input text and cap it at ``max_length`` characters.

    Raises ``ValueError`` when the text is empty or whitespace-only.
    A warning is logged whenever the stripped text has to be truncated.
    """
    # Falsy input (None, "") is treated the same as whitespace-only text.
    text = text.strip() if text else ""
    if not text:
        raise ValueError("Text cannot be empty")

    if len(text) <= max_length:
        return text

    logger.warning(f"Text truncated from {len(text)} to {max_length} characters")
    return text[:max_length]
36
+
37
def validate_audio_file(file_path: str) -> bool:
    """Check that a path points to an existing file with a supported audio extension.

    Only the file extension is inspected (case-insensitively); the file
    contents are not opened or decoded.

    Args:
        file_path: Path to the candidate audio file. Empty or ``None``
            values are rejected.

    Returns:
        ``True`` when ``file_path`` is an existing regular file whose suffix
        is one of ``.wav``, ``.mp3``, ``.flac``, ``.ogg`` or ``.m4a``;
        ``False`` otherwise.
    """
    # isfile (not exists) so that a directory named e.g. "clips.wav" is rejected.
    if not file_path or not os.path.isfile(file_path):
        return False

    supported_formats = {'.wav', '.mp3', '.flac', '.ogg', '.m4a'}
    return Path(file_path).suffix.lower() in supported_formats
46
+
47
def create_temp_audio_file(suffix: str = ".wav") -> str:
    """Create an empty temporary file and return its path.

    The file is created with ``delete=False``, so it stays on disk after
    being closed; removing it is left to the caller.
    """
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as handle:
        return handle.name
52
+
53
def cleanup_temp_file(file_path: str) -> None:
    """Delete a temporary file on a best-effort basis.

    Empty paths and paths that do not exist are ignored. Any failure while
    removing the file is logged as a warning instead of being raised.
    """
    try:
        should_remove = bool(file_path) and os.path.exists(file_path)
        if should_remove:
            os.unlink(file_path)
    except Exception as e:
        logger.warning(f"Failed to cleanup temp file {file_path}: {e}")
60
+
61
def load_audio(file_path: str, target_sr: int = 22050) -> Tuple[np.ndarray, int]:
    """Load an audio file through ``librosa.load``.

    Args:
        file_path: Path of the audio file to read.
        target_sr: Sample rate passed to ``librosa.load`` as ``sr``.

    Returns:
        The ``(audio, sr)`` pair returned by ``librosa.load``.

    Raises:
        ValueError: If loading fails for any reason. The original exception
            is logged and attached as ``__cause__``.
    """
    # Keep the try body limited to the call that can actually fail.
    try:
        audio, sr = librosa.load(file_path, sr=target_sr)
    except Exception as e:
        logger.error(f"Failed to load audio from {file_path}: {e}")
        # Chain explicitly so the underlying decoder error stays in the traceback.
        raise ValueError(f"Could not load audio file: {e}") from e
    return audio, sr
69
+
70
def normalize_audio(audio: np.ndarray) -> np.ndarray:
    """Peak-normalize audio so its largest absolute sample becomes 1.

    Args:
        audio: Array of audio samples (any shape).

    Returns:
        ``audio`` divided by its peak absolute value. The input is returned
        unchanged when it contains no samples or when the peak is not
        greater than zero (e.g. all-zero audio).
    """
    # np.size covers empty multi-dimensional arrays such as shape (2, 0),
    # for which len() is non-zero but np.max would raise.
    if np.size(audio) == 0:
        return audio

    # Normalize to [-1, 1] range
    peak = np.max(np.abs(audio))
    if peak > 0:
        audio = audio / peak

    return audio
81
+
82
def get_supported_languages() -> List[str]:
    """Return the keys of ``config.LANGUAGE_MODELS`` as a list."""
    # Imported at call time rather than at module import time.
    from config import LANGUAGE_MODELS

    language_names = LANGUAGE_MODELS.keys()
    return list(language_names)
86
+
87
def format_duration(seconds: float) -> str:
    """Format a duration in seconds as a short human-readable string.

    The unit is chosen from the *rounded* value, so results never overflow
    their unit (no ``"1000ms"``, ``"60.0s"`` or ``"1m 60.0s"``).

    Args:
        seconds: Duration in seconds.

    Returns:
        ``"<n>ms"`` for values that round below one second, ``"<s.s>s"``
        for values that round below one minute, otherwise ``"<m>m <s.s>s"``.
    """
    if seconds < 1:
        millis = f"{seconds * 1000:.0f}"
        # 0.9995..0.9999s rounds up to 1000ms; report that as seconds instead.
        if millis != "1000":
            return f"{millis}ms"

    # Round once up front so the unit choice and the printed digits agree.
    rounded = round(seconds, 1)
    if rounded < 60:
        return f"{rounded:.1f}s"

    minutes, remainder = divmod(rounded, 60)
    return f"{int(minutes)}m {remainder:.1f}s"
97
+
98
def estimate_synthesis_time(text_length: int, language: str = "English") -> float:
    """Return a rough synthesis-time estimate in seconds.

    The estimate is ``text_length`` times a per-character cost for the
    given language, plus a fixed 2 second overhead. Languages without an
    entry in the table use 0.06 seconds per character.
    """
    # Seconds per character, keyed by language name.
    seconds_per_char = dict(
        English=0.05,
        Korean=0.08,
        German=0.06,
        Spanish=0.05,
    )
    fallback_rate = 0.06
    fixed_overhead = 2.0  # Add 2s overhead

    if language in seconds_per_char:
        rate = seconds_per_char[language]
    else:
        rate = fallback_rate
    return text_length * rate + fixed_overhead
110
+
111
def log_system_info():
    """Write device, PyTorch and CUDA details to the module logger.

    The CUDA device name and total memory are only logged when CUDA is
    reported as available.
    """
    logger.info(f"Device: {get_device()}")
    logger.info(f"PyTorch version: {torch.__version__}")
    logger.info(f"CUDA available: {torch.cuda.is_available()}")

    # Nothing more to report on machines without CUDA.
    if not torch.cuda.is_available():
        return

    logger.info(f"CUDA device: {torch.cuda.get_device_name()}")
    logger.info(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB")
120
+
121
class AudioProcessor:
    """Audio processing utilities"""

    @staticmethod
    def resample_audio(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
        """Resample audio to a target sample rate.

        Args:
            audio: Audio samples.
            orig_sr: Sample rate of ``audio``.
            target_sr: Desired sample rate.

        Returns:
            ``audio`` itself when both rates are equal, otherwise the result
            of ``librosa.resample``.
        """
        if orig_sr == target_sr:
            return audio
        return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)

    @staticmethod
    def trim_silence(audio: np.ndarray, sr: int, threshold: float = 0.01) -> np.ndarray:
        """Trim silence from the beginning and end of audio.

        Args:
            audio: Audio samples.
            sr: Sample rate. Currently unused.
            threshold: Currently unused; trimming always runs with a fixed
                ``top_db=20``.

        Returns:
            The first element of the result of ``librosa.effects.trim``.
        """
        # NOTE(review): `threshold` and `sr` are accepted but ignored; wiring
        # them in would change default results, so it is left as-is here.
        return librosa.effects.trim(audio, top_db=20)[0]

    @staticmethod
    def apply_fade(audio: np.ndarray, sr: int, fade_duration: float = 0.1) -> np.ndarray:
        """Apply a linear fade-in and fade-out to audio.

        The fades are applied IN PLACE: the array passed in is modified and
        the same object is returned.

        Args:
            audio: Audio samples; modified in place.
            sr: Sample rate, used to convert ``fade_duration`` to samples.
            fade_duration: Length of each fade in seconds.

        Returns:
            ``audio``. It is returned untouched when the fade length is not
            positive or when the audio is not longer than both fades combined.
        """
        fade_samples = int(fade_duration * sr)
        # A zero-length fade must be skipped explicitly: audio[-0:] is the
        # WHOLE array, so multiplying it by an empty ramp raises.
        if fade_samples <= 0 or len(audio) <= 2 * fade_samples:
            return audio

        # Fade in
        audio[:fade_samples] *= np.linspace(0, 1, fade_samples)
        # Fade out
        audio[-fade_samples:] *= np.linspace(1, 0, fade_samples)

        return audio