ayaka68 committed on
Commit
84e84dd
·
verified ·
1 Parent(s): 574f756

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -22
app.py CHANGED
@@ -310,9 +310,8 @@ def _normalize_label(lbl: str) -> str:
310
 
311
  def predict_emotion_ai(audio_bytes):
312
  """
313
- S3PRL Featurizer → [B,T,H] reps_len を受け取り、
314
- reps_len int / list / tuple / Tensor / None のいずれでも動くよう正規化して
315
- 有効長のみ平均化([B,H])→ 線形ヘッドで分類。
316
  """
317
  try:
318
  featurizer, head, id2label, device = load_kushinada_s3prl()
@@ -333,27 +332,29 @@ def predict_emotion_ai(audio_bytes):
333
  y = y[:max_samples]
334
  st.warning("音声が30秒を超えたため、最初の30秒のみを分析します。")
335
 
336
- # Featurizer想定の入力(CPU list でOK)
337
  wavs = [torch.tensor(y, dtype=torch.float32)]
338
  wavs_len = [int(len(y))]
339
 
340
  with torch.no_grad():
341
- reps, reps_len = featurizer(wavs, wavs_len) # reps: [B,T,H] か [T,H]、reps_len: list/int/Tensor/None
342
-
343
- # --- reps [B,T,H] に統一 ---
344
- if isinstance(reps, torch.Tensor):
345
- if reps.dim() == 2: # [T,H] → [1,T,H]
346
- reps = reps.unsqueeze(0)
347
- elif reps.dim() != 3:
348
- raise RuntimeError(f"Unexpected reps.dim(): {reps.dim()}")
349
-
 
 
 
350
  else:
351
- # 念のため、非テンソルなら失敗扱い(通常ここには来ない)
352
- raise RuntimeError(f"Unexpected reps type: {type(reps)}")
353
 
354
  B, T, H = reps.shape
355
 
356
- # --- reps_len を [B] のリストに正規化 ---
357
  if reps_len is None:
358
  reps_len_list = [T] * B
359
  elif isinstance(reps_len, int):
@@ -361,7 +362,6 @@ def predict_emotion_ai(audio_bytes):
361
  elif isinstance(reps_len, (list, tuple)):
362
  reps_len_list = [int(x) for x in reps_len]
363
  if len(reps_len_list) != B:
364
- # 長さが合わなければ T で埋める
365
  reps_len_list = [T] * B
366
  elif isinstance(reps_len, torch.Tensor):
367
  reps_len_list = reps_len.view(-1).tolist()
@@ -370,13 +370,16 @@ def predict_emotion_ai(audio_bytes):
370
  else:
371
  reps_len_list = [T] * B
372
 
373
- # 安全に 1..T にクリップ
374
  reps_len_list = [max(1, min(int(li), T)) for li in reps_len_list]
375
 
376
- # --- 有効長のみ平均して [B,H] ---
377
- pooled = torch.stack([reps[i, :reps_len_list[i]].mean(dim=0) for i in range(B)], dim=0) # [B,H]
 
 
 
378
 
379
- # --- 線形ヘッドで分類 ---
380
  logits = head.fc(pooled.to(device)) # [B,C]
381
  probs = torch.softmax(logits, dim=-1)[0].detach().cpu().numpy()
382
 
@@ -384,7 +387,7 @@ def predict_emotion_ai(audio_bytes):
384
  raw_label = id2label[pred_id]
385
 
386
  def _norm(lbl: str) -> str:
387
- m = {"happy":"happiness", "angry":"anger", "sad":"sadness", "neutral":"neutral"}
388
  return m.get(lbl.lower(), lbl)
389
 
390
  label = _norm(raw_label)
 
310
 
311
  def predict_emotion_ai(audio_bytes):
312
  """
313
+ S3PRL Featurizer → reps([B,T,H] | [T,H] | [H])と reps_len(int/list/tensor/None)を受け取り、
314
+ 形を正規化して有効長で時間平均([B,H])→ 線形ヘッドで分類。
 
315
  """
316
  try:
317
  featurizer, head, id2label, device = load_kushinada_s3prl()
 
332
  y = y[:max_samples]
333
  st.warning("音声が30秒を超えたため、最初の30秒のみを分析します。")
334
 
335
+ # Featurizer list 入力想定
336
  wavs = [torch.tensor(y, dtype=torch.float32)]
337
  wavs_len = [int(len(y))]
338
 
339
  with torch.no_grad():
340
+ reps, reps_len = featurizer(wavs, wavs_len) # reps: Tensor, reps_len: int/list/Tensor/None になる
341
+
342
+ # ---- reps を必ず [B,T,H] に統一 ----
343
+ if not isinstance(reps, torch.Tensor):
344
+ raise RuntimeError(f"Unexpected reps type: {type(reps)} (Tensor想定)")
345
+
346
+ if reps.dim() == 3: # [B,T,H] そのまま
347
+ pass
348
+ elif reps.dim() == 2: # [T,H] -> [1,T,H]
349
+ reps = reps.unsqueeze(0)
350
+ elif reps.dim() == 1: # [H] -> [1,1,H] ← ★今回ここを追加
351
+ reps = reps.unsqueeze(0).unsqueeze(0)
352
  else:
353
+ raise RuntimeError(f"Unexpected reps.dim(): {reps.dim()}")
 
354
 
355
  B, T, H = reps.shape
356
 
357
+ # ---- reps_len を [B] リストに正規化 ----
358
  if reps_len is None:
359
  reps_len_list = [T] * B
360
  elif isinstance(reps_len, int):
 
362
  elif isinstance(reps_len, (list, tuple)):
363
  reps_len_list = [int(x) for x in reps_len]
364
  if len(reps_len_list) != B:
 
365
  reps_len_list = [T] * B
366
  elif isinstance(reps_len, torch.Tensor):
367
  reps_len_list = reps_len.view(-1).tolist()
 
370
  else:
371
  reps_len_list = [T] * B
372
 
373
+ # [1,1,H] になったケースでも安全に 1..T へクリップ
374
  reps_len_list = [max(1, min(int(li), T)) for li in reps_len_list]
375
 
376
+ # ---- 有効長のみで平均して [B,H] ----
377
+ pooled = torch.stack(
378
+ [reps[i, :reps_len_list[i]].mean(dim=0) for i in range(B)],
379
+ dim=0
380
+ ) # [B,H]
381
 
382
+ # ---- 線形ヘッドで分類 ----
383
  logits = head.fc(pooled.to(device)) # [B,C]
384
  probs = torch.softmax(logits, dim=-1)[0].detach().cpu().numpy()
385
 
 
387
  raw_label = id2label[pred_id]
388
 
389
  def _norm(lbl: str) -> str:
390
+ m = {"happy": "happiness", "angry": "anger", "sad": "sadness", "neutral": "neutral"}
391
  return m.get(lbl.lower(), lbl)
392
 
393
  label = _norm(raw_label)