ayaka68 committed on
Commit
acd8897
·
verified ·
1 Parent(s): c36c080

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -17
app.py CHANGED
@@ -299,8 +299,9 @@ def _normalize_label(lbl: str) -> str:
299
 
300
  def predict_emotion_ai(audio_bytes):
301
  """
302
- S3PRL Featurizer で必ず [B,T,H] を取得 → 各サンプルの有効長 reps_len で時間平均 → [B,H]。
303
- 後、線形ヘッド(W,b)分類。
 
304
  """
305
  try:
306
  featurizer, head, id2label, device = load_kushinada_s3prl()
@@ -321,25 +322,51 @@ def predict_emotion_ai(audio_bytes):
321
  y = y[:max_samples]
322
  st.warning("音声が30秒を超えたため、最初の30秒のみを分析します。")
323
 
324
- # S3PRLは list[Tensor], list[int] を想定
325
  wavs = [torch.tensor(y, dtype=torch.float32)]
326
  wavs_len = [int(len(y))]
327
 
328
  with torch.no_grad():
329
- reps, reps_len = featurizer(wavs, wavs_len) # reps: [B,T,H], reps_len: list[int] or Tensor[B]
330
- if isinstance(reps_len, torch.Tensor):
331
- reps_len = reps_len.tolist()
332
-
333
- # 有効長のみで平均(パディングを無視)
334
- pooled = []
335
- for i in range(reps.shape[0]):
336
- Ti = int(reps_len[i]) if reps_len else reps.shape[1]
337
- Ti = max(1, min(Ti, reps.shape[1])) # 安全側
338
- pooled.append(reps[i, :Ti].mean(dim=0))
339
- pooled = torch.stack(pooled, dim=0) # [B,H]
340
-
341
- # 線形ヘッドで分類(head.fcに直接入れる)
342
- logits = head.fc(pooled.to(device)) # [B,C]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
343
  probs = torch.softmax(logits, dim=-1)[0].detach().cpu().numpy()
344
 
345
  pred_id = int(np.argmax(probs))
@@ -360,6 +387,7 @@ def predict_emotion_ai(audio_bytes):
360
  st.warning(f"AI予測中にエラーが発生: {e}")
361
  return predict_emotion_features(audio_bytes)
362
 
 
363
  # ===== 推薦 =====
364
  def score_places(emo_label, top_k=4, diversity=True):
365
  EMO_MAP_PRIORS = {
 
299
 
300
  def predict_emotion_ai(audio_bytes):
301
  """
302
+ S3PRL Featurizer [B,T,H] reps_len を受け取り、
303
+ reps_len が int / list / tuple / Tensor / None いずれも動くよう正規化して
304
+ 有効長のみ平均化([B,H])→ 線形ヘッドで分類。
305
  """
306
  try:
307
  featurizer, head, id2label, device = load_kushinada_s3prl()
 
322
  y = y[:max_samples]
323
  st.warning("音声が30秒を超えたため、最初の30秒のみを分析します。")
324
 
325
+ # Featurizer想定の入力(CPU list でOK)
326
  wavs = [torch.tensor(y, dtype=torch.float32)]
327
  wavs_len = [int(len(y))]
328
 
329
  with torch.no_grad():
330
+ reps, reps_len = featurizer(wavs, wavs_len) # reps: [B,T,H] か [T,H]、reps_len: list/int/Tensor/None
331
+
332
+ # --- reps を [B,T,H] に統一 ---
333
+ if isinstance(reps, torch.Tensor):
334
+ if reps.dim() == 2: # [T,H] → [1,T,H]
335
+ reps = reps.unsqueeze(0)
336
+ elif reps.dim() != 3:
337
+ raise RuntimeError(f"Unexpected reps.dim(): {reps.dim()}")
338
+
339
+ else:
340
+ # 念のため、非テンソルなら失敗扱い(通常ここには来ない)
341
+ raise RuntimeError(f"Unexpected reps type: {type(reps)}")
342
+
343
+ B, T, H = reps.shape
344
+
345
+ # --- reps_len を [B] のリストに正規化 ---
346
+ if reps_len is None:
347
+ reps_len_list = [T] * B
348
+ elif isinstance(reps_len, int):
349
+ reps_len_list = [int(reps_len)] * B
350
+ elif isinstance(reps_len, (list, tuple)):
351
+ reps_len_list = [int(x) for x in reps_len]
352
+ if len(reps_len_list) != B:
353
+ # 長さが合わなければ T で埋める
354
+ reps_len_list = [T] * B
355
+ elif isinstance(reps_len, torch.Tensor):
356
+ reps_len_list = reps_len.view(-1).tolist()
357
+ if len(reps_len_list) != B:
358
+ reps_len_list = [T] * B
359
+ else:
360
+ reps_len_list = [T] * B
361
+
362
+ # 安全に 1..T にクリップ
363
+ reps_len_list = [max(1, min(int(li), T)) for li in reps_len_list]
364
+
365
+ # --- 有効長のみ平均して [B,H] ---
366
+ pooled = torch.stack([reps[i, :reps_len_list[i]].mean(dim=0) for i in range(B)], dim=0) # [B,H]
367
+
368
+ # --- 線形ヘッドで分類 ---
369
+ logits = head.fc(pooled.to(device)) # [B,C]
370
  probs = torch.softmax(logits, dim=-1)[0].detach().cpu().numpy()
371
 
372
  pred_id = int(np.argmax(probs))
 
387
  st.warning(f"AI予測中にエラーが発生: {e}")
388
  return predict_emotion_features(audio_bytes)
389
 
390
+
391
  # ===== 推薦 =====
392
  def score_places(emo_label, top_k=4, diversity=True):
393
  EMO_MAP_PRIORS = {