Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -26,7 +26,7 @@ import japanize_matplotlib
|
|
| 26 |
import torch
|
| 27 |
import torch.nn as nn
|
| 28 |
from huggingface_hub import list_repo_files, hf_hub_download
|
| 29 |
-
from s3prl.nn import S3PRLUpstream
|
| 30 |
|
| 31 |
# ===== フォント設定 =====
|
| 32 |
rcParams["font.family"] = "DejaVu Sans"
|
|
@@ -80,67 +80,50 @@ KUSHINADA_REPO = "imprt/kushinada-hubert-base-jtes-er"
|
|
| 80 |
@st.cache_resource(show_spinner=False)
|
| 81 |
def load_kushinada_s3prl():
|
| 82 |
"""
|
| 83 |
-
S3PRL上流(HuBERT base)
|
| 84 |
-
|
| 85 |
-
- サブフォルダ内も対象
|
| 86 |
-
- 必要なら KUSHINADA_FILENAME / KUSHINADA_REVISION を Secrets に設定して固定
|
| 87 |
"""
|
| 88 |
token = os.getenv("HF_TOKEN")
|
| 89 |
if not token:
|
| 90 |
raise RuntimeError("環境変数 HF_TOKEN が見つかりません。SpacesのSettings→Secretsで設定してください。")
|
| 91 |
|
| 92 |
revision = os.getenv("KUSHINADA_REVISION", "main")
|
| 93 |
-
prefer_filename = os.getenv("KUSHINADA_FILENAME") # 例: "
|
| 94 |
|
| 95 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 96 |
|
| 97 |
-
# 1)
|
| 98 |
upstream = S3PRLUpstream("hubert_base").to(device).eval()
|
|
|
|
| 99 |
|
| 100 |
-
# 2)
|
| 101 |
api = HfApi()
|
| 102 |
info = api.model_info(KUSHINADA_REPO, token=token, revision=revision)
|
| 103 |
-
all_files = [s.rfilename for s in info.siblings]
|
| 104 |
|
| 105 |
-
#
|
| 106 |
-
with st.expander("📦 モデル内ファイル一覧(デバッグ)", expanded=False):
|
| 107 |
-
st.write(all_files)
|
| 108 |
-
|
| 109 |
-
# 3) 候補ファイルの決定
|
| 110 |
exts = (".pt", ".ckpt", ".pth", ".bin")
|
| 111 |
candidates = [f for f in all_files if f.lower().endswith(exts)]
|
| 112 |
-
|
| 113 |
-
# Secretsで明示指定があればそれを優先
|
| 114 |
filename = None
|
| 115 |
if prefer_filename:
|
| 116 |
if prefer_filename in all_files:
|
| 117 |
filename = prefer_filename
|
| 118 |
else:
|
| 119 |
-
# サブフォルダなしで指定された場合に補正を試みる
|
| 120 |
matches = [f for f in all_files if f.endswith(prefer_filename)]
|
| 121 |
if matches:
|
| 122 |
filename = matches[0]
|
| 123 |
-
|
| 124 |
-
# それでも未決なら候補の先頭を採用
|
| 125 |
if filename is None and candidates:
|
| 126 |
-
# なるべく "downstream", "classifier", "jtes" を含むものを優先
|
| 127 |
ranked = sorted(
|
| 128 |
candidates,
|
| 129 |
key=lambda f: (
|
| 130 |
-
-int(any(k in f.lower() for k in ["downstream", "classifier", "jtes"])),
|
| 131 |
len(f)
|
| 132 |
)
|
| 133 |
)
|
| 134 |
filename = ranked[0] if ranked else None
|
| 135 |
-
|
| 136 |
if filename is None:
|
| 137 |
-
raise FileNotFoundError(
|
| 138 |
-
"下流チェックポイント(.pt/.ckpt/.pth/.bin)が見つかりません。\n"
|
| 139 |
-
"モデルページの Files でファイル名を確認し、SpacesのSecretsに "
|
| 140 |
-
"KUSHINADA_FILENAME として保存してください。"
|
| 141 |
-
)
|
| 142 |
|
| 143 |
-
# 4) チェックポイントを取得
|
| 144 |
ckpt_path = hf_hub_download(
|
| 145 |
repo_id=KUSHINADA_REPO,
|
| 146 |
filename=filename,
|
|
@@ -151,18 +134,16 @@ def load_kushinada_s3prl():
|
|
| 151 |
local_dir_use_symlinks=False,
|
| 152 |
force_download=False
|
| 153 |
)
|
| 154 |
-
|
| 155 |
ckpt = torch.load(ckpt_path, map_location="cpu")
|
| 156 |
|
| 157 |
-
#
|
| 158 |
state = None
|
| 159 |
if isinstance(ckpt, dict):
|
| 160 |
-
for key in ["state_dict",
|
| 161 |
if key in ckpt and isinstance(ckpt[key], dict):
|
| 162 |
state = ckpt[key]; break
|
| 163 |
if state is None:
|
| 164 |
-
state = ckpt
|
| 165 |
-
|
| 166 |
if not isinstance(state, dict):
|
| 167 |
raise RuntimeError("チェックポイント形式を解釈できませんでした。")
|
| 168 |
|
|
@@ -175,26 +156,23 @@ def load_kushinada_s3prl():
|
|
| 175 |
linear_W, linear_b = v, state[bias_key]
|
| 176 |
break
|
| 177 |
if linear_W is None:
|
| 178 |
-
# weight/biasのペア探索(末尾名が weight/bias)
|
| 179 |
twos = [(k,v) for k,v in state.items() if isinstance(v, torch.Tensor) and v.ndim==2 and k.endswith("weight")]
|
| 180 |
for wk, w in twos:
|
| 181 |
-
bk = wk.replace("weight", "bias")
|
| 182 |
if bk in state and isinstance(state[bk], torch.Tensor) and state[bk].ndim == 1:
|
| 183 |
linear_W, linear_b = w, state[bk]
|
| 184 |
break
|
| 185 |
if linear_W is None:
|
| 186 |
-
raise RuntimeError("線形分類器の重みが見つかりません。
|
| 187 |
|
| 188 |
num_classes, hidden_dim = linear_W.shape # [C, H]
|
| 189 |
-
head = SimpleLinearHead(in_dim=hidden_dim, num_classes=num_classes,
|
| 190 |
-
W=linear_W, b=linear_b).to(device).eval()
|
| 191 |
|
| 192 |
-
|
| 193 |
-
default_labels = ["angry", "happy", "neutral", "sad"]
|
| 194 |
id2label = {i: (default_labels[i] if num_classes == 4 and i < 4 else f"class_{i}") for i in range(num_classes)}
|
| 195 |
|
| 196 |
st.info(f"✅ 使うチェックポイント: `{filename}`(revision: {revision})")
|
| 197 |
-
return
|
| 198 |
|
| 199 |
# ===== ユーティリティ =====
|
| 200 |
def to_wav_bytes(any_bytes: bytes, target_sr=16000, mono=True) -> bytes:
|
|
@@ -307,12 +285,11 @@ def _normalize_label(lbl: str) -> str:
|
|
| 307 |
|
| 308 |
def predict_emotion_ai(audio_bytes):
|
| 309 |
"""
|
| 310 |
-
S3PRL
|
| 311 |
-
|
| 312 |
-
出力は最終的に list[Tensor([T_i,H])] に正規化 → 時間平均で [B,H] → 線形ヘッド。
|
| 313 |
"""
|
| 314 |
try:
|
| 315 |
-
|
| 316 |
except Exception as e:
|
| 317 |
st.error(f"モデルのロードに失敗しました: {e}")
|
| 318 |
st.info("音声特徴量ベースの分析に切り替えます。")
|
|
@@ -330,94 +307,25 @@ def predict_emotion_ai(audio_bytes):
|
|
| 330 |
y = y[:max_samples]
|
| 331 |
st.warning("音声が30秒を超えたため、最初の30秒のみを分析します。")
|
| 332 |
|
| 333 |
-
#
|
| 334 |
-
wavs = [torch.tensor(y, dtype=torch.float32)]
|
| 335 |
-
wavs_len = [int(len(y))]
|
| 336 |
|
| 337 |
with torch.no_grad():
|
| 338 |
-
#
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
if isinstance(seqs, torch.Tensor):
|
| 353 |
-
if seqs.dim() == 3:
|
| 354 |
-
return [seqs[i].cpu() for i in range(seqs.size(0))]
|
| 355 |
-
if seqs.dim() == 2:
|
| 356 |
-
return [seqs.cpu()]
|
| 357 |
-
# dict などが来たら再帰
|
| 358 |
-
return as_seq_list(seqs)
|
| 359 |
-
|
| 360 |
-
# 2) Tensor
|
| 361 |
-
if isinstance(obj, torch.Tensor):
|
| 362 |
-
if obj.dim() == 3: # [B,T,H]
|
| 363 |
-
return [obj[i].cpu() for i in range(obj.size(0))]
|
| 364 |
-
if obj.dim() == 2: # [T,H]
|
| 365 |
-
return [obj.cpu()]
|
| 366 |
-
if obj.dim() == 1: # [H](既にプール済み)→T=1として扱う
|
| 367 |
-
return [obj.unsqueeze(0).cpu()]
|
| 368 |
-
|
| 369 |
-
# 3) dict(代表キー優先)
|
| 370 |
-
if isinstance(obj, dict):
|
| 371 |
-
for k in ["last_hidden_state", "hidden_states"]:
|
| 372 |
-
if k in obj:
|
| 373 |
-
v = obj[k]
|
| 374 |
-
# hidden_states がリストなら最終層
|
| 375 |
-
if k == "hidden_states" and isinstance(v, (list, tuple)) and len(v) > 0:
|
| 376 |
-
v = v[-1]
|
| 377 |
-
return as_seq_list(v)
|
| 378 |
-
# 他キーも探索
|
| 379 |
-
for v in obj.values():
|
| 380 |
-
got = as_seq_list(v)
|
| 381 |
-
if got:
|
| 382 |
-
return got
|
| 383 |
-
return []
|
| 384 |
-
|
| 385 |
-
# 4) list / tuple(入れ子を平坦化)
|
| 386 |
-
if isinstance(obj, (list, tuple)):
|
| 387 |
-
out = []
|
| 388 |
-
for it in obj:
|
| 389 |
-
out.extend(as_seq_list(it))
|
| 390 |
-
return out
|
| 391 |
-
|
| 392 |
-
# それ以外は無視
|
| 393 |
-
return []
|
| 394 |
-
|
| 395 |
-
seq_list = as_seq_list(reps_out)
|
| 396 |
-
if not seq_list:
|
| 397 |
-
raise RuntimeError("上流出力を [T,H] の列へ正規化できませんでした。")
|
| 398 |
-
|
| 399 |
-
# ★ 時間平均で [H] にプール → [B,H]
|
| 400 |
-
pooled_list = []
|
| 401 |
-
for t in seq_list:
|
| 402 |
-
if not isinstance(t, torch.Tensor):
|
| 403 |
-
continue
|
| 404 |
-
t = t.to(device)
|
| 405 |
-
if t.dim() == 3: # [?,T,H] が来たら T 次元で平均
|
| 406 |
-
t = t.mean(dim=1)
|
| 407 |
-
if t.dim() == 2: # [T,H]
|
| 408 |
-
pooled_list.append(t.mean(dim=0)) # -> [H]
|
| 409 |
-
elif t.dim() == 1: # [H]
|
| 410 |
-
pooled_list.append(t)
|
| 411 |
-
else:
|
| 412 |
-
raise RuntimeError(f"Unexpected tensor shape from upstream: {tuple(t.size())}")
|
| 413 |
-
|
| 414 |
-
if len(pooled_list) == 0:
|
| 415 |
-
raise RuntimeError("プーリング後テンソルが空です。")
|
| 416 |
-
|
| 417 |
-
pooled = torch.stack(pooled_list, dim=0) # [B,H]
|
| 418 |
-
|
| 419 |
-
# 線形ヘッドで分類
|
| 420 |
-
logits = head.fc(pooled) # [B,C]
|
| 421 |
probs = torch.softmax(logits, dim=-1)[0].detach().cpu().numpy()
|
| 422 |
|
| 423 |
pred_id = int(np.argmax(probs))
|
|
@@ -435,7 +343,6 @@ def predict_emotion_ai(audio_bytes):
|
|
| 435 |
return label, scores, "AI(S3PRL)"
|
| 436 |
|
| 437 |
except Exception as e:
|
| 438 |
-
# デバッグ補助(発生時だけ型を少し表示)
|
| 439 |
st.warning(f"AI予測中にエラーが発生: {e}")
|
| 440 |
return predict_emotion_features(audio_bytes)
|
| 441 |
|
|
|
|
| 26 |
import torch
|
| 27 |
import torch.nn as nn
|
| 28 |
from huggingface_hub import list_repo_files, hf_hub_download
|
| 29 |
+
from s3prl.nn import S3PRLUpstream, Featurizer
|
| 30 |
|
| 31 |
# ===== フォント設定 =====
|
| 32 |
rcParams["font.family"] = "DejaVu Sans"
|
|
|
|
| 80 |
@st.cache_resource(show_spinner=False)
|
| 81 |
def load_kushinada_s3prl():
|
| 82 |
"""
|
| 83 |
+
S3PRL 上流(HuBERT base) → Featurizer で [B,T,H] を得る。
|
| 84 |
+
下流(.ckpt)は線形層の W,b を抽出して SimpleLinearHead を構築。
|
|
|
|
|
|
|
| 85 |
"""
|
| 86 |
token = os.getenv("HF_TOKEN")
|
| 87 |
if not token:
|
| 88 |
raise RuntimeError("環境変数 HF_TOKEN が見つかりません。SpacesのSettings→Secretsで設定してください。")
|
| 89 |
|
| 90 |
revision = os.getenv("KUSHINADA_REVISION", "main")
|
| 91 |
+
prefer_filename = os.getenv("KUSHINADA_FILENAME") # 例: "s3prl/result/downstream/.../dev-best.ckpt"
|
| 92 |
|
| 93 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 94 |
|
| 95 |
+
# 1) 上流 + Featurizer(最終層)
|
| 96 |
upstream = S3PRLUpstream("hubert_base").to(device).eval()
|
| 97 |
+
featurizer = Featurizer(upstream, layer=-1).to(device).eval()
|
| 98 |
|
| 99 |
+
# 2) モデル内のファイル一覧(サブフォルダ込み)
|
| 100 |
api = HfApi()
|
| 101 |
info = api.model_info(KUSHINADA_REPO, token=token, revision=revision)
|
| 102 |
+
all_files = [s.rfilename for s in info.siblings]
|
| 103 |
|
| 104 |
+
# 3) チェックポイント選定
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
exts = (".pt", ".ckpt", ".pth", ".bin")
|
| 106 |
candidates = [f for f in all_files if f.lower().endswith(exts)]
|
|
|
|
|
|
|
| 107 |
filename = None
|
| 108 |
if prefer_filename:
|
| 109 |
if prefer_filename in all_files:
|
| 110 |
filename = prefer_filename
|
| 111 |
else:
|
|
|
|
| 112 |
matches = [f for f in all_files if f.endswith(prefer_filename)]
|
| 113 |
if matches:
|
| 114 |
filename = matches[0]
|
|
|
|
|
|
|
| 115 |
if filename is None and candidates:
|
|
|
|
| 116 |
ranked = sorted(
|
| 117 |
candidates,
|
| 118 |
key=lambda f: (
|
| 119 |
+
-int(any(k in f.lower() for k in ["downstream","classifier","jtes","kushinada"])),
|
| 120 |
len(f)
|
| 121 |
)
|
| 122 |
)
|
| 123 |
filename = ranked[0] if ranked else None
|
|
|
|
| 124 |
if filename is None:
|
| 125 |
+
raise FileNotFoundError("下流チェックポイント(.pt/.ckpt/.pth/.bin)が見つかりません。KUSHINADA_FILENAME を Secrets に指定してください。")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
|
|
|
|
| 127 |
ckpt_path = hf_hub_download(
|
| 128 |
repo_id=KUSHINADA_REPO,
|
| 129 |
filename=filename,
|
|
|
|
| 134 |
local_dir_use_symlinks=False,
|
| 135 |
force_download=False
|
| 136 |
)
|
|
|
|
| 137 |
ckpt = torch.load(ckpt_path, map_location="cpu")
|
| 138 |
|
| 139 |
+
# 4) state_dict から線形層の W, b を抽出
|
| 140 |
state = None
|
| 141 |
if isinstance(ckpt, dict):
|
| 142 |
+
for key in ["state_dict","Downstream","model","downstream","net","weights"]:
|
| 143 |
if key in ckpt and isinstance(ckpt[key], dict):
|
| 144 |
state = ckpt[key]; break
|
| 145 |
if state is None:
|
| 146 |
+
state = ckpt
|
|
|
|
| 147 |
if not isinstance(state, dict):
|
| 148 |
raise RuntimeError("チェックポイント形式を解釈できませんでした。")
|
| 149 |
|
|
|
|
| 156 |
linear_W, linear_b = v, state[bias_key]
|
| 157 |
break
|
| 158 |
if linear_W is None:
|
|
|
|
| 159 |
twos = [(k,v) for k,v in state.items() if isinstance(v, torch.Tensor) and v.ndim==2 and k.endswith("weight")]
|
| 160 |
for wk, w in twos:
|
| 161 |
+
bk = wk.replace("weight","bias")
|
| 162 |
if bk in state and isinstance(state[bk], torch.Tensor) and state[bk].ndim == 1:
|
| 163 |
linear_W, linear_b = w, state[bk]
|
| 164 |
break
|
| 165 |
if linear_W is None:
|
| 166 |
+
raise RuntimeError("線形分類器の重みが見つかりません。Downstream 構造の再現が必要です。")
|
| 167 |
|
| 168 |
num_classes, hidden_dim = linear_W.shape # [C, H]
|
| 169 |
+
head = SimpleLinearHead(in_dim=hidden_dim, num_classes=num_classes, W=linear_W, b=linear_b).to(device).eval()
|
|
|
|
| 170 |
|
| 171 |
+
default_labels = ["angry","happy","neutral","sad"]
|
|
|
|
| 172 |
id2label = {i: (default_labels[i] if num_classes == 4 and i < 4 else f"class_{i}") for i in range(num_classes)}
|
| 173 |
|
| 174 |
st.info(f"✅ 使うチェックポイント: `{filename}`(revision: {revision})")
|
| 175 |
+
return featurizer, head, id2label, device
|
| 176 |
|
| 177 |
# ===== ユーティリティ =====
|
| 178 |
def to_wav_bytes(any_bytes: bytes, target_sr=16000, mono=True) -> bytes:
|
|
|
|
| 285 |
|
| 286 |
def predict_emotion_ai(audio_bytes):
|
| 287 |
"""
|
| 288 |
+
S3PRL Featurizer で必ず [B,T,H] を取得 → 各サンプルの有効長 reps_len で時間平均 → [B,H]。
|
| 289 |
+
その後、線形ヘッド(W,b)で分類。
|
|
|
|
| 290 |
"""
|
| 291 |
try:
|
| 292 |
+
featurizer, head, id2label, device = load_kushinada_s3prl()
|
| 293 |
except Exception as e:
|
| 294 |
st.error(f"モデルのロードに失敗しました: {e}")
|
| 295 |
st.info("音声特徴量ベースの分析に切り替えます。")
|
|
|
|
| 307 |
y = y[:max_samples]
|
| 308 |
st.warning("音声が30秒を超えたため、最初の30秒のみを分析します。")
|
| 309 |
|
| 310 |
+
# S3PRLは list[Tensor], list[int] を想定
|
| 311 |
+
wavs = [torch.tensor(y, dtype=torch.float32)]
|
| 312 |
+
wavs_len = [int(len(y))]
|
| 313 |
|
| 314 |
with torch.no_grad():
|
| 315 |
+
reps, reps_len = featurizer(wavs, wavs_len) # reps: [B,T,H], reps_len: list[int] or Tensor[B]
|
| 316 |
+
if isinstance(reps_len, torch.Tensor):
|
| 317 |
+
reps_len = reps_len.tolist()
|
| 318 |
+
|
| 319 |
+
# 有効長のみで平均(パディングを無視)
|
| 320 |
+
pooled = []
|
| 321 |
+
for i in range(reps.shape[0]):
|
| 322 |
+
Ti = int(reps_len[i]) if reps_len else reps.shape[1]
|
| 323 |
+
Ti = max(1, min(Ti, reps.shape[1])) # 安全側
|
| 324 |
+
pooled.append(reps[i, :Ti].mean(dim=0))
|
| 325 |
+
pooled = torch.stack(pooled, dim=0) # [B,H]
|
| 326 |
+
|
| 327 |
+
# 線形ヘッドで分類(head.fcに直接入れる)
|
| 328 |
+
logits = head.fc(pooled.to(device)) # [B,C]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 329 |
probs = torch.softmax(logits, dim=-1)[0].detach().cpu().numpy()
|
| 330 |
|
| 331 |
pred_id = int(np.argmax(probs))
|
|
|
|
| 343 |
return label, scores, "AI(S3PRL)"
|
| 344 |
|
| 345 |
except Exception as e:
|
|
|
|
| 346 |
st.warning(f"AI予測中にエラーが発生: {e}")
|
| 347 |
return predict_emotion_features(audio_bytes)
|
| 348 |
|