Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -299,8 +299,9 @@ def _normalize_label(lbl: str) -> str:
|
|
| 299 |
|
| 300 |
def predict_emotion_ai(audio_bytes):
|
| 301 |
"""
|
| 302 |
-
S3PRL Featurizer
|
| 303 |
-
|
|
|
|
| 304 |
"""
|
| 305 |
try:
|
| 306 |
featurizer, head, id2label, device = load_kushinada_s3prl()
|
|
@@ -321,25 +322,51 @@ def predict_emotion_ai(audio_bytes):
|
|
| 321 |
y = y[:max_samples]
|
| 322 |
st.warning("音声が30秒を超えたため、最初の30秒のみを分析します。")
|
| 323 |
|
| 324 |
-
#
|
| 325 |
wavs = [torch.tensor(y, dtype=torch.float32)]
|
| 326 |
wavs_len = [int(len(y))]
|
| 327 |
|
| 328 |
with torch.no_grad():
|
| 329 |
-
reps, reps_len = featurizer(wavs, wavs_len) # reps: [B,T,H]
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 343 |
probs = torch.softmax(logits, dim=-1)[0].detach().cpu().numpy()
|
| 344 |
|
| 345 |
pred_id = int(np.argmax(probs))
|
|
@@ -360,6 +387,7 @@ def predict_emotion_ai(audio_bytes):
|
|
| 360 |
st.warning(f"AI予測中にエラーが発生: {e}")
|
| 361 |
return predict_emotion_features(audio_bytes)
|
| 362 |
|
|
|
|
| 363 |
# ===== 推薦 =====
|
| 364 |
def score_places(emo_label, top_k=4, diversity=True):
|
| 365 |
EMO_MAP_PRIORS = {
|
|
|
|
| 299 |
|
| 300 |
def predict_emotion_ai(audio_bytes):
|
| 301 |
"""
|
| 302 |
+
S3PRL Featurizer → [B,T,H] と reps_len を受け取り、
|
| 303 |
+
reps_len が int / list / tuple / Tensor / None のいずれでも動くよう正規化して
|
| 304 |
+
有効長のみ平均化([B,H])→ 線形ヘッドで分類。
|
| 305 |
"""
|
| 306 |
try:
|
| 307 |
featurizer, head, id2label, device = load_kushinada_s3prl()
|
|
|
|
| 322 |
y = y[:max_samples]
|
| 323 |
st.warning("音声が30秒を超えたため、最初の30秒のみを分析します。")
|
| 324 |
|
| 325 |
+
# Featurizer想定の入力(CPU list でOK)
|
| 326 |
wavs = [torch.tensor(y, dtype=torch.float32)]
|
| 327 |
wavs_len = [int(len(y))]
|
| 328 |
|
| 329 |
with torch.no_grad():
|
| 330 |
+
reps, reps_len = featurizer(wavs, wavs_len) # reps: [B,T,H] か [T,H]、reps_len: list/int/Tensor/None
|
| 331 |
+
|
| 332 |
+
# --- reps を [B,T,H] に統一 ---
|
| 333 |
+
if isinstance(reps, torch.Tensor):
|
| 334 |
+
if reps.dim() == 2: # [T,H] → [1,T,H]
|
| 335 |
+
reps = reps.unsqueeze(0)
|
| 336 |
+
elif reps.dim() != 3:
|
| 337 |
+
raise RuntimeError(f"Unexpected reps.dim(): {reps.dim()}")
|
| 338 |
+
|
| 339 |
+
else:
|
| 340 |
+
# 念のため、非テンソルなら失敗扱い(通常ここには来ない)
|
| 341 |
+
raise RuntimeError(f"Unexpected reps type: {type(reps)}")
|
| 342 |
+
|
| 343 |
+
B, T, H = reps.shape
|
| 344 |
+
|
| 345 |
+
# --- reps_len を [B] のリストに正規化 ---
|
| 346 |
+
if reps_len is None:
|
| 347 |
+
reps_len_list = [T] * B
|
| 348 |
+
elif isinstance(reps_len, int):
|
| 349 |
+
reps_len_list = [int(reps_len)] * B
|
| 350 |
+
elif isinstance(reps_len, (list, tuple)):
|
| 351 |
+
reps_len_list = [int(x) for x in reps_len]
|
| 352 |
+
if len(reps_len_list) != B:
|
| 353 |
+
# 長さが合わなければ T で埋める
|
| 354 |
+
reps_len_list = [T] * B
|
| 355 |
+
elif isinstance(reps_len, torch.Tensor):
|
| 356 |
+
reps_len_list = reps_len.view(-1).tolist()
|
| 357 |
+
if len(reps_len_list) != B:
|
| 358 |
+
reps_len_list = [T] * B
|
| 359 |
+
else:
|
| 360 |
+
reps_len_list = [T] * B
|
| 361 |
+
|
| 362 |
+
# 安全に 1..T にクリップ
|
| 363 |
+
reps_len_list = [max(1, min(int(li), T)) for li in reps_len_list]
|
| 364 |
+
|
| 365 |
+
# --- 有効長のみ平均して [B,H] ---
|
| 366 |
+
pooled = torch.stack([reps[i, :reps_len_list[i]].mean(dim=0) for i in range(B)], dim=0) # [B,H]
|
| 367 |
+
|
| 368 |
+
# --- 線形ヘッドで分類 ---
|
| 369 |
+
logits = head.fc(pooled.to(device)) # [B,C]
|
| 370 |
probs = torch.softmax(logits, dim=-1)[0].detach().cpu().numpy()
|
| 371 |
|
| 372 |
pred_id = int(np.argmax(probs))
|
|
|
|
| 387 |
st.warning(f"AI予測中にエラーが発生: {e}")
|
| 388 |
return predict_emotion_features(audio_bytes)
|
| 389 |
|
| 390 |
+
|
| 391 |
# ===== 推薦 =====
|
| 392 |
def score_places(emo_label, top_k=4, diversity=True):
|
| 393 |
EMO_MAP_PRIORS = {
|