Thanh-Lam committed
Commit ec8293e · 1 parent: 8f4a2bc

Fix: correct gender/dialect label mapping (Female=0, Male=1) and remove trim-causing sources param

Files changed (1): app.py (+37 −20)
app.py CHANGED
@@ -27,9 +27,17 @@ MODELS_CONFIG = {
     }
 }
 
-# Labels
-GENDER_LABELS = ["Male", "Female"]
-DIALECT_LABELS = ["Northern", "Central", "Southern"]
+# Labels - IMPORTANT: Must match training order!
+# Model was trained with Female=0, Male=1
+GENDER_LABELS = {
+    0: "Female",
+    1: "Male"
+}
+DIALECT_LABELS = {
+    0: "North",
+    1: "Central",
+    2: "South"
+}
 
 
 class MultiModelProfiler:
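Why the dict fixes things: with the old list, index 0 rendered as "Male" even though the checkpoint was trained with Female=0, so every prediction came out gender-flipped. A minimal worked example (the logits tensor is invented for illustration):

```python
import torch

# Invented logits for one clip whose speaker is female:
# class 0 scores highest, matching the Female=0 training order.
gender_logits = torch.tensor([[2.3, -1.1]])
idx = gender_logits.argmax(dim=-1).item()    # -> 0

OLD_GENDER_LABELS = ["Male", "Female"]        # before: index 0 reads "Male" (wrong)
NEW_GENDER_LABELS = {0: "Female", 1: "Male"}  # after: index 0 reads "Female" (right)

print(OLD_GENDER_LABELS[idx])  # Male
print(NEW_GENDER_LABELS[idx])  # Female
```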
@@ -38,6 +46,7 @@ class MultiModelProfiler:
     def __init__(self):
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.sampling_rate = 16000
+        self.max_duration = 5  # seconds for non-whisper models
         self.models = {}
         self.processors = {}
         self.current_model = None
@@ -129,20 +138,29 @@ class MultiModelProfiler:
         processor = self.processors[model_name]
         is_whisper = MODELS_CONFIG[model_name]["is_whisper"]
 
-        # Load audio using librosa (more compatible)
+        # Set max duration based on model type
+        if is_whisper:
+            max_duration = 30  # Whisper requires 30 seconds
+        else:
+            max_duration = self.max_duration
+
+        # Load audio using librosa
         waveform, sr = librosa.load(audio_path, sr=self.sampling_rate, mono=True)
 
+        # Trim to max duration
+        max_samples = int(max_duration * self.sampling_rate)
+        if len(waveform) > max_samples:
+            waveform = waveform[:max_samples]
+
         # Process based on model type
         if is_whisper:
-            # Whisper requires exactly 30 seconds of audio
-            whisper_length = self.sampling_rate * 30  # 480000 samples
+            # Whisper requires exactly 30 seconds - pad if needed
+            whisper_length = self.sampling_rate * 30
             if len(waveform) < whisper_length:
-                waveform_padded = np.pad(waveform, (0, whisper_length - len(waveform)))
-            else:
-                waveform_padded = waveform[:whisper_length]
+                waveform = np.pad(waveform, (0, whisper_length - len(waveform)))
 
             inputs = processor(
-                waveform_padded,
+                waveform,
                 sampling_rate=self.sampling_rate,
                 return_tensors="pt"
             )
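The new length handling is trim-then-pad: every input is first capped at the model's budget (5 s for the non-Whisper models, 30 s for Whisper), and only the Whisper branch pads back up to exactly 30 s. A self-contained sketch of the same logic (fit_length is a hypothetical helper, not a function in app.py):

```python
import numpy as np

def fit_length(waveform: np.ndarray, sr: int, max_seconds: float,
               pad_to_exact: bool = False) -> np.ndarray:
    """Cap waveform at max_seconds; optionally zero-pad up to that length."""
    target = int(max_seconds * sr)
    if len(waveform) > target:
        waveform = waveform[:target]  # trim the tail
    if pad_to_exact and len(waveform) < target:
        waveform = np.pad(waveform, (0, target - len(waveform)))  # pad with silence
    return waveform

# Whisper branch: exactly 30 s (480000 samples at 16 kHz).
# Other models: at most 5 s; shorter clips pass through unpadded.
w = np.random.randn(16000 * 7).astype(np.float32)  # a 7-second clip
assert len(fit_length(w, 16000, 30, pad_to_exact=True)) == 480000
assert len(fit_length(w, 16000, 5)) == 80000
```

Note the pad-only Whisper branch is now correct on its own: the earlier trim step guarantees the waveform never exceeds 30 s, so the removed `else: waveform[:whisper_length]` is redundant.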
@@ -163,14 +181,14 @@
         gender_logits = outputs['gender_logits']
         dialect_logits = outputs['dialect_logits']
 
-        gender_probs = torch.softmax(gender_logits, dim=-1)
-        dialect_probs = torch.softmax(dialect_logits, dim=-1)
+        gender_probs = torch.softmax(gender_logits, dim=-1).cpu().numpy()[0]
+        dialect_probs = torch.softmax(dialect_logits, dim=-1).cpu().numpy()[0]
 
-        gender_idx = gender_probs.argmax(dim=-1).item()
-        dialect_idx = dialect_probs.argmax(dim=-1).item()
+        gender_idx = int(np.argmax(gender_probs))
+        dialect_idx = int(np.argmax(dialect_probs))
 
-        gender_conf = gender_probs[0, gender_idx].item() * 100
-        dialect_conf = dialect_probs[0, dialect_idx].item() * 100
+        gender_conf = float(gender_probs[gender_idx]) * 100
+        dialect_conf = float(dialect_probs[dialect_idx]) * 100
 
         gender_result = f"{GENDER_LABELS[gender_idx]} ({gender_conf:.1f}%)"
         dialect_result = f"{DIALECT_LABELS[dialect_idx]} ({dialect_conf:.1f}%)"
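The rewritten post-processing does one softmax per head, moves the result to NumPy, and only then indexes, which also avoids indexing a CUDA tensor from Python. Worked numbers for a batch-of-one dialect head (logits invented for illustration):

```python
import numpy as np
import torch

DIALECT_LABELS = {0: "North", 1: "Central", 2: "South"}

logits = torch.tensor([[0.2, 1.4, -0.7]])               # shape (1, 3)
probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]  # 1-D array, sums to 1
idx = int(np.argmax(probs))                             # -> 1
conf = float(probs[idx]) * 100                          # -> ~70.2

print(f"{DIALECT_LABELS[idx]} ({conf:.1f}%)")  # Central (70.2%)
```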
@@ -223,8 +241,7 @@ def create_interface():
             gr.Markdown("### Input")
             audio_input = gr.Audio(
                 label="Upload or Record Audio",
-                type="filepath",
-                sources=["upload", "microphone"]
+                type="filepath"
             )
 
             model_dropdown = gr.Dropdown(
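Dropping the explicit `sources` argument falls back to the component's default; in recent Gradio releases gr.Audio defaults to both upload and microphone, so the UI keeps both input modes (this is version-dependent and worth verifying against the pinned Gradio):

```python
import gradio as gr

# No explicit `sources`: recent Gradio versions default to
# ["upload", "microphone"], so both input modes remain available.
audio_input = gr.Audio(label="Upload or Record Audio", type="filepath")
```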
@@ -247,9 +264,9 @@
             gr.Markdown(
                 """
                 ### Dialect Regions
-                - **Northern**: Hanoi and surrounding areas
+                - **North**: Hanoi and surrounding areas
                 - **Central**: Hue, Da Nang, and Central Vietnam
-                - **Southern**: Ho Chi Minh City and Southern Vietnam
+                - **South**: Ho Chi Minh City and Southern Vietnam
                 """
             )
 