Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -310,9 +310,8 @@ def _normalize_label(lbl: str) -> str:
|
|
| 310 |
|
| 311 |
def predict_emotion_ai(audio_bytes):
|
| 312 |
"""
|
| 313 |
-
S3PRL Featurizer → [B,T,H]
|
| 314 |
-
|
| 315 |
-
有効長のみ平均化([B,H])→ 線形ヘッドで分類。
|
| 316 |
"""
|
| 317 |
try:
|
| 318 |
featurizer, head, id2label, device = load_kushinada_s3prl()
|
|
@@ -333,27 +332,29 @@ def predict_emotion_ai(audio_bytes):
|
|
| 333 |
y = y[:max_samples]
|
| 334 |
st.warning("音声が30秒を超えたため、最初の30秒のみを分析します。")
|
| 335 |
|
| 336 |
-
# Featurizer
|
| 337 |
wavs = [torch.tensor(y, dtype=torch.float32)]
|
| 338 |
wavs_len = [int(len(y))]
|
| 339 |
|
| 340 |
with torch.no_grad():
|
| 341 |
-
reps, reps_len = featurizer(wavs, wavs_len) # reps:
|
| 342 |
-
|
| 343 |
-
#
|
| 344 |
-
if isinstance(reps, torch.Tensor):
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
|
|
|
|
|
|
|
|
|
| 350 |
else:
|
| 351 |
-
|
| 352 |
-
raise RuntimeError(f"Unexpected reps type: {type(reps)}")
|
| 353 |
|
| 354 |
B, T, H = reps.shape
|
| 355 |
|
| 356 |
-
#
|
| 357 |
if reps_len is None:
|
| 358 |
reps_len_list = [T] * B
|
| 359 |
elif isinstance(reps_len, int):
|
|
@@ -361,7 +362,6 @@ def predict_emotion_ai(audio_bytes):
|
|
| 361 |
elif isinstance(reps_len, (list, tuple)):
|
| 362 |
reps_len_list = [int(x) for x in reps_len]
|
| 363 |
if len(reps_len_list) != B:
|
| 364 |
-
# 長さが合わなければ T で埋める
|
| 365 |
reps_len_list = [T] * B
|
| 366 |
elif isinstance(reps_len, torch.Tensor):
|
| 367 |
reps_len_list = reps_len.view(-1).tolist()
|
|
@@ -370,13 +370,16 @@ def predict_emotion_ai(audio_bytes):
|
|
| 370 |
else:
|
| 371 |
reps_len_list = [T] * B
|
| 372 |
|
| 373 |
-
#
|
| 374 |
reps_len_list = [max(1, min(int(li), T)) for li in reps_len_list]
|
| 375 |
|
| 376 |
-
#
|
| 377 |
-
pooled = torch.stack(
|
|
|
|
|
|
|
|
|
|
| 378 |
|
| 379 |
-
#
|
| 380 |
logits = head.fc(pooled.to(device)) # [B,C]
|
| 381 |
probs = torch.softmax(logits, dim=-1)[0].detach().cpu().numpy()
|
| 382 |
|
|
@@ -384,7 +387,7 @@ def predict_emotion_ai(audio_bytes):
|
|
| 384 |
raw_label = id2label[pred_id]
|
| 385 |
|
| 386 |
def _norm(lbl: str) -> str:
|
| 387 |
-
m = {"happy":"happiness", "angry":"anger", "sad":"sadness", "neutral":"neutral"}
|
| 388 |
return m.get(lbl.lower(), lbl)
|
| 389 |
|
| 390 |
label = _norm(raw_label)
|
|
|
|
| 310 |
|
| 311 |
def predict_emotion_ai(audio_bytes):
|
| 312 |
"""
|
| 313 |
+
S3PRL Featurizer → reps([B,T,H] | [T,H] | [H])と reps_len(int/list/tensor/None)を受け取り、
|
| 314 |
+
形を正規化して有効長で時間平均 → [B,H] → 線形ヘッドで分類。
|
|
|
|
| 315 |
"""
|
| 316 |
try:
|
| 317 |
featurizer, head, id2label, device = load_kushinada_s3prl()
|
|
|
|
| 332 |
y = y[:max_samples]
|
| 333 |
st.warning("音声が30秒を超えたため、最初の30秒のみを分析します。")
|
| 334 |
|
| 335 |
+
# Featurizer は list 入力想定
|
| 336 |
wavs = [torch.tensor(y, dtype=torch.float32)]
|
| 337 |
wavs_len = [int(len(y))]
|
| 338 |
|
| 339 |
with torch.no_grad():
|
| 340 |
+
reps, reps_len = featurizer(wavs, wavs_len) # reps: Tensor, reps_len: int/list/Tensor/None になる
|
| 341 |
+
|
| 342 |
+
# ---- reps を必ず [B,T,H] に統一 ----
|
| 343 |
+
if not isinstance(reps, torch.Tensor):
|
| 344 |
+
raise RuntimeError(f"Unexpected reps type: {type(reps)} (Tensor想定)")
|
| 345 |
+
|
| 346 |
+
if reps.dim() == 3: # [B,T,H] そのまま
|
| 347 |
+
pass
|
| 348 |
+
elif reps.dim() == 2: # [T,H] -> [1,T,H]
|
| 349 |
+
reps = reps.unsqueeze(0)
|
| 350 |
+
elif reps.dim() == 1: # [H] -> [1,1,H] ← ★今回ここを追加
|
| 351 |
+
reps = reps.unsqueeze(0).unsqueeze(0)
|
| 352 |
else:
|
| 353 |
+
raise RuntimeError(f"Unexpected reps.dim(): {reps.dim()}")
|
|
|
|
| 354 |
|
| 355 |
B, T, H = reps.shape
|
| 356 |
|
| 357 |
+
# ---- reps_len を [B] リストに正規化 ----
|
| 358 |
if reps_len is None:
|
| 359 |
reps_len_list = [T] * B
|
| 360 |
elif isinstance(reps_len, int):
|
|
|
|
| 362 |
elif isinstance(reps_len, (list, tuple)):
|
| 363 |
reps_len_list = [int(x) for x in reps_len]
|
| 364 |
if len(reps_len_list) != B:
|
|
|
|
| 365 |
reps_len_list = [T] * B
|
| 366 |
elif isinstance(reps_len, torch.Tensor):
|
| 367 |
reps_len_list = reps_len.view(-1).tolist()
|
|
|
|
| 370 |
else:
|
| 371 |
reps_len_list = [T] * B
|
| 372 |
|
| 373 |
+
# [1,1,H] になったケースでも安全に 1..T へクリップ
|
| 374 |
reps_len_list = [max(1, min(int(li), T)) for li in reps_len_list]
|
| 375 |
|
| 376 |
+
# ---- 有効長のみで平均して [B,H] ----
|
| 377 |
+
pooled = torch.stack(
|
| 378 |
+
[reps[i, :reps_len_list[i]].mean(dim=0) for i in range(B)],
|
| 379 |
+
dim=0
|
| 380 |
+
) # [B,H]
|
| 381 |
|
| 382 |
+
# ---- 線形ヘッドで分類 ----
|
| 383 |
logits = head.fc(pooled.to(device)) # [B,C]
|
| 384 |
probs = torch.softmax(logits, dim=-1)[0].detach().cpu().numpy()
|
| 385 |
|
|
|
|
| 387 |
raw_label = id2label[pred_id]
|
| 388 |
|
| 389 |
def _norm(lbl: str) -> str:
|
| 390 |
+
m = {"happy": "happiness", "angry": "anger", "sad": "sadness", "neutral": "neutral"}
|
| 391 |
return m.get(lbl.lower(), lbl)
|
| 392 |
|
| 393 |
label = _norm(raw_label)
|