ayaka68 commited on
Commit
d5cddb7
·
verified ·
1 Parent(s): 5453e94

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -132
app.py CHANGED
@@ -26,7 +26,7 @@ import japanize_matplotlib
26
  import torch
27
  import torch.nn as nn
28
  from huggingface_hub import list_repo_files, hf_hub_download
29
- from s3prl.nn import S3PRLUpstream
30
 
31
  # ===== フォント設定 =====
32
  rcParams["font.family"] = "DejaVu Sans"
@@ -80,67 +80,50 @@ KUSHINADA_REPO = "imprt/kushinada-hubert-base-jtes-er"
80
  @st.cache_resource(show_spinner=False)
81
  def load_kushinada_s3prl():
82
  """
83
- S3PRL上流(HuBERT base) + HFの下流(ckpt)を自動取得して復元。
84
- - .pt 以外に .ckpt / .pth / .bin も探索
85
- - サブフォルダ内も対象
86
- - 必要なら KUSHINADA_FILENAME / KUSHINADA_REVISION を Secrets に設定して固定
87
  """
88
  token = os.getenv("HF_TOKEN")
89
  if not token:
90
  raise RuntimeError("環境変数 HF_TOKEN が見つかりません。SpacesのSettings→Secretsで設定してください。")
91
 
92
  revision = os.getenv("KUSHINADA_REVISION", "main")
93
- prefer_filename = os.getenv("KUSHINADA_FILENAME") # 例: "checkpoints/epoch=9-step=1234.ckpt"
94
 
95
  device = "cuda" if torch.cuda.is_available() else "cpu"
96
 
97
- # 1) S3PRL 上流:HuBERT base
98
  upstream = S3PRLUpstream("hubert_base").to(device).eval()
 
99
 
100
- # 2) モデル内のファイル一覧を取得(サブフォルダ込み)
101
  api = HfApi()
102
  info = api.model_info(KUSHINADA_REPO, token=token, revision=revision)
103
- all_files = [s.rfilename for s in info.siblings] # ルート/サブフォルダ含むファイルパス
104
 
105
- # --- デバッグ出力(必要なら見える化)
106
- with st.expander("📦 モデル内ファイル一覧(デバッグ)", expanded=False):
107
- st.write(all_files)
108
-
109
- # 3) 候補ファイルの決定
110
  exts = (".pt", ".ckpt", ".pth", ".bin")
111
  candidates = [f for f in all_files if f.lower().endswith(exts)]
112
-
113
- # Secretsで明示指定があればそれを優先
114
  filename = None
115
  if prefer_filename:
116
  if prefer_filename in all_files:
117
  filename = prefer_filename
118
  else:
119
- # サブフォルダなしで指定された場合に補正を試みる
120
  matches = [f for f in all_files if f.endswith(prefer_filename)]
121
  if matches:
122
  filename = matches[0]
123
-
124
- # それでも未決なら候補の先頭を採用
125
  if filename is None and candidates:
126
- # なるべく "downstream", "classifier", "jtes" を含むものを優先
127
  ranked = sorted(
128
  candidates,
129
  key=lambda f: (
130
- -int(any(k in f.lower() for k in ["downstream", "classifier", "jtes", "kushinada"])),
131
  len(f)
132
  )
133
  )
134
  filename = ranked[0] if ranked else None
135
-
136
  if filename is None:
137
- raise FileNotFoundError(
138
- "下流チェックポイント(.pt/.ckpt/.pth/.bin)が見つかりません。\n"
139
- "モデルページの Files でファイル名を確認し、SpacesのSecretsに "
140
- "KUSHINADA_FILENAME として保存してください。"
141
- )
142
 
143
- # 4) チェックポイントを取得
144
  ckpt_path = hf_hub_download(
145
  repo_id=KUSHINADA_REPO,
146
  filename=filename,
@@ -151,18 +134,16 @@ def load_kushinada_s3prl():
151
  local_dir_use_symlinks=False,
152
  force_download=False
153
  )
154
-
155
  ckpt = torch.load(ckpt_path, map_location="cpu")
156
 
157
- # 5) state_dict を探索し、線形ヘッド (W, b) を復元
158
  state = None
159
  if isinstance(ckpt, dict):
160
- for key in ["state_dict", "Downstream", "model", "downstream", "net", "weights"]:
161
  if key in ckpt and isinstance(ckpt[key], dict):
162
  state = ckpt[key]; break
163
  if state is None:
164
- state = ckpt # そのままstate_dictの場合
165
-
166
  if not isinstance(state, dict):
167
  raise RuntimeError("チェックポイント形式を解釈できませんでした。")
168
 
@@ -175,26 +156,23 @@ def load_kushinada_s3prl():
175
  linear_W, linear_b = v, state[bias_key]
176
  break
177
  if linear_W is None:
178
- # weight/biasのペア探索(末尾名が weight/bias)
179
  twos = [(k,v) for k,v in state.items() if isinstance(v, torch.Tensor) and v.ndim==2 and k.endswith("weight")]
180
  for wk, w in twos:
181
- bk = wk.replace("weight", "bias")
182
  if bk in state and isinstance(state[bk], torch.Tensor) and state[bk].ndim == 1:
183
  linear_W, linear_b = w, state[bk]
184
  break
185
  if linear_W is None:
186
- raise RuntimeError("線形分類器の重みが見つかりません。S3PRLのDownstream構造を再現する必要があります。")
187
 
188
  num_classes, hidden_dim = linear_W.shape # [C, H]
189
- head = SimpleLinearHead(in_dim=hidden_dim, num_classes=num_classes,
190
- W=linear_W, b=linear_b).to(device).eval()
191
 
192
- # ラベル(暫定)。必要なら順序を手動調整してください。
193
- default_labels = ["angry", "happy", "neutral", "sad"]
194
  id2label = {i: (default_labels[i] if num_classes == 4 and i < 4 else f"class_{i}") for i in range(num_classes)}
195
 
196
  st.info(f"✅ 使うチェックポイント: `{filename}`(revision: {revision})")
197
- return upstream, head, id2label, device
198
 
199
  # ===== ユーティリティ =====
200
  def to_wav_bytes(any_bytes: bytes, target_sr=16000, mono=True) -> bytes:
@@ -307,12 +285,11 @@ def _normalize_label(lbl: str) -> str:
307
 
308
  def predict_emotion_ai(audio_bytes):
309
  """
310
- S3PRL上流 + HF下流(.ckpt) で推論(S3PRLはリスト入力/出力前提)。
311
- 入力は CPU の list[Tensor] / list[int] に統一。
312
- 出力は最終的に list[Tensor([T_i,H])] に正規化 → 時間平均で [B,H] → 線形ヘッド。
313
  """
314
  try:
315
- upstream, head, id2label, device = load_kushinada_s3prl()
316
  except Exception as e:
317
  st.error(f"モデルのロードに失敗しました: {e}")
318
  st.info("音声特徴量ベースの分析に切り替えます。")
@@ -330,94 +307,25 @@ def predict_emotion_ai(audio_bytes):
330
  y = y[:max_samples]
331
  st.warning("音声が30秒を超えたため、最初の30秒のみを分析します。")
332
 
333
- # S3PRLは「CPUの list 形式」を想定している実装が多い
334
- wavs = [torch.tensor(y, dtype=torch.float32)] # list[Tensor([T])]
335
- wavs_len = [int(len(y))] # list[int]
336
 
337
  with torch.no_grad():
338
- # 返り値は実装により:
339
- # - (list[Tensor([T_i,H])], list[int]) ←最も一般的
340
- # - Tensor([B,T,H]) / dict / 入れ子
341
- reps_out = upstream(wavs, wavs_len)
342
-
343
- def as_seq_list(obj):
344
- """上流出力を list[Tensor([T_i, H])] に正規化"""
345
- # 1) すでに (seqs, lens) 形式
346
- if isinstance(obj, tuple) and len(obj) == 2:
347
- seqs, lens = obj
348
- # seqs: list of Tensor ならそのまま
349
- if isinstance(seqs, list) and len(seqs) > 0 and isinstance(seqs[0], torch.Tensor):
350
- return seqs
351
- # seqs Tensor [B,T,H] の場合 → バラす
352
- if isinstance(seqs, torch.Tensor):
353
- if seqs.dim() == 3:
354
- return [seqs[i].cpu() for i in range(seqs.size(0))]
355
- if seqs.dim() == 2:
356
- return [seqs.cpu()]
357
- # dict などが来たら再帰
358
- return as_seq_list(seqs)
359
-
360
- # 2) Tensor
361
- if isinstance(obj, torch.Tensor):
362
- if obj.dim() == 3: # [B,T,H]
363
- return [obj[i].cpu() for i in range(obj.size(0))]
364
- if obj.dim() == 2: # [T,H]
365
- return [obj.cpu()]
366
- if obj.dim() == 1: # [H](既にプール済み)→T=1として扱う
367
- return [obj.unsqueeze(0).cpu()]
368
-
369
- # 3) dict(代表キー優先)
370
- if isinstance(obj, dict):
371
- for k in ["last_hidden_state", "hidden_states"]:
372
- if k in obj:
373
- v = obj[k]
374
- # hidden_states がリストなら最終層
375
- if k == "hidden_states" and isinstance(v, (list, tuple)) and len(v) > 0:
376
- v = v[-1]
377
- return as_seq_list(v)
378
- # 他キーも探索
379
- for v in obj.values():
380
- got = as_seq_list(v)
381
- if got:
382
- return got
383
- return []
384
-
385
- # 4) list / tuple(入れ子を平坦化)
386
- if isinstance(obj, (list, tuple)):
387
- out = []
388
- for it in obj:
389
- out.extend(as_seq_list(it))
390
- return out
391
-
392
- # それ以外は無視
393
- return []
394
-
395
- seq_list = as_seq_list(reps_out)
396
- if not seq_list:
397
- raise RuntimeError("上流出力を [T,H] の列へ正規化できませんでした。")
398
-
399
- # ★ 時間平均で [H] にプール → [B,H]
400
- pooled_list = []
401
- for t in seq_list:
402
- if not isinstance(t, torch.Tensor):
403
- continue
404
- t = t.to(device)
405
- if t.dim() == 3: # [?,T,H] が来たら T 次元で平均
406
- t = t.mean(dim=1)
407
- if t.dim() == 2: # [T,H]
408
- pooled_list.append(t.mean(dim=0)) # -> [H]
409
- elif t.dim() == 1: # [H]
410
- pooled_list.append(t)
411
- else:
412
- raise RuntimeError(f"Unexpected tensor shape from upstream: {tuple(t.size())}")
413
-
414
- if len(pooled_list) == 0:
415
- raise RuntimeError("プーリング後テンソルが空です。")
416
-
417
- pooled = torch.stack(pooled_list, dim=0) # [B,H]
418
-
419
- # 線形ヘッドで分類
420
- logits = head.fc(pooled) # [B,C]
421
  probs = torch.softmax(logits, dim=-1)[0].detach().cpu().numpy()
422
 
423
  pred_id = int(np.argmax(probs))
@@ -435,7 +343,6 @@ def predict_emotion_ai(audio_bytes):
435
  return label, scores, "AI(S3PRL)"
436
 
437
  except Exception as e:
438
- # デバッグ補助(発生時だけ型を少し表示)
439
  st.warning(f"AI予測中にエラーが発生: {e}")
440
  return predict_emotion_features(audio_bytes)
441
 
 
26
  import torch
27
  import torch.nn as nn
28
  from huggingface_hub import list_repo_files, hf_hub_download
29
+ from s3prl.nn import S3PRLUpstream, Featurizer
30
 
31
  # ===== フォント設定 =====
32
  rcParams["font.family"] = "DejaVu Sans"
 
80
  @st.cache_resource(show_spinner=False)
81
  def load_kushinada_s3prl():
82
  """
83
+ S3PRL 上流(HuBERT base) Featurizer で [B,T,H] を得る。
84
+ 下流(.ckpt)は線形層の W,b を抽出して SimpleLinearHead を構築。
 
 
85
  """
86
  token = os.getenv("HF_TOKEN")
87
  if not token:
88
  raise RuntimeError("環境変数 HF_TOKEN が見つかりません。SpacesのSettings→Secretsで設定してください。")
89
 
90
  revision = os.getenv("KUSHINADA_REVISION", "main")
91
+ prefer_filename = os.getenv("KUSHINADA_FILENAME") # 例: "s3prl/result/downstream/.../dev-best.ckpt"
92
 
93
  device = "cuda" if torch.cuda.is_available() else "cpu"
94
 
95
+ # 1) 上流 + Featurizer(最終層)
96
  upstream = S3PRLUpstream("hubert_base").to(device).eval()
97
+ featurizer = Featurizer(upstream, layer=-1).to(device).eval()
98
 
99
+ # 2) モデル内のファイル一覧(サブフォルダ込み)
100
  api = HfApi()
101
  info = api.model_info(KUSHINADA_REPO, token=token, revision=revision)
102
+ all_files = [s.rfilename for s in info.siblings]
103
 
104
+ # 3) チェックポイント選定
 
 
 
 
105
  exts = (".pt", ".ckpt", ".pth", ".bin")
106
  candidates = [f for f in all_files if f.lower().endswith(exts)]
 
 
107
  filename = None
108
  if prefer_filename:
109
  if prefer_filename in all_files:
110
  filename = prefer_filename
111
  else:
 
112
  matches = [f for f in all_files if f.endswith(prefer_filename)]
113
  if matches:
114
  filename = matches[0]
 
 
115
  if filename is None and candidates:
 
116
  ranked = sorted(
117
  candidates,
118
  key=lambda f: (
119
+ -int(any(k in f.lower() for k in ["downstream","classifier","jtes","kushinada"])),
120
  len(f)
121
  )
122
  )
123
  filename = ranked[0] if ranked else None
 
124
  if filename is None:
125
+ raise FileNotFoundError("下流チェックポイント(.pt/.ckpt/.pth/.bin)が見つかりません。KUSHINADA_FILENAME を Secrets に指定してください。")
 
 
 
 
126
 
 
127
  ckpt_path = hf_hub_download(
128
  repo_id=KUSHINADA_REPO,
129
  filename=filename,
 
134
  local_dir_use_symlinks=False,
135
  force_download=False
136
  )
 
137
  ckpt = torch.load(ckpt_path, map_location="cpu")
138
 
139
+ # 4) state_dict から線形層の W, b を抽出
140
  state = None
141
  if isinstance(ckpt, dict):
142
+ for key in ["state_dict","Downstream","model","downstream","net","weights"]:
143
  if key in ckpt and isinstance(ckpt[key], dict):
144
  state = ckpt[key]; break
145
  if state is None:
146
+ state = ckpt
 
147
  if not isinstance(state, dict):
148
  raise RuntimeError("チェックポイント形式を解釈できませんでした。")
149
 
 
156
  linear_W, linear_b = v, state[bias_key]
157
  break
158
  if linear_W is None:
 
159
  twos = [(k,v) for k,v in state.items() if isinstance(v, torch.Tensor) and v.ndim==2 and k.endswith("weight")]
160
  for wk, w in twos:
161
+ bk = wk.replace("weight","bias")
162
  if bk in state and isinstance(state[bk], torch.Tensor) and state[bk].ndim == 1:
163
  linear_W, linear_b = w, state[bk]
164
  break
165
  if linear_W is None:
166
+ raise RuntimeError("線形分類器の重みが見つかりません。Downstream 構造の再現が必要です。")
167
 
168
  num_classes, hidden_dim = linear_W.shape # [C, H]
169
+ head = SimpleLinearHead(in_dim=hidden_dim, num_classes=num_classes, W=linear_W, b=linear_b).to(device).eval()
 
170
 
171
+ default_labels = ["angry","happy","neutral","sad"]
 
172
  id2label = {i: (default_labels[i] if num_classes == 4 and i < 4 else f"class_{i}") for i in range(num_classes)}
173
 
174
  st.info(f"✅ 使うチェックポイント: `{filename}`(revision: {revision})")
175
+ return featurizer, head, id2label, device
176
 
177
  # ===== ユーティリティ =====
178
  def to_wav_bytes(any_bytes: bytes, target_sr=16000, mono=True) -> bytes:
 
285
 
286
  def predict_emotion_ai(audio_bytes):
287
  """
288
+ S3PRL Featurizer で必ず [B,T,H] を取得 → 各サンプルの有効長 reps_len で時間平均 → [B,H]。
289
+ その後、線形ヘッド(W,b)で分類。
 
290
  """
291
  try:
292
+ featurizer, head, id2label, device = load_kushinada_s3prl()
293
  except Exception as e:
294
  st.error(f"モデルのロードに失敗しました: {e}")
295
  st.info("音声特徴量ベースの分析に切り替えます。")
 
307
  y = y[:max_samples]
308
  st.warning("音声が30秒を超えたため、最初の30秒のみを分析します。")
309
 
310
+ # S3PRL list[Tensor], list[int] を想定
311
+ wavs = [torch.tensor(y, dtype=torch.float32)]
312
+ wavs_len = [int(len(y))]
313
 
314
  with torch.no_grad():
315
+ reps, reps_len = featurizer(wavs, wavs_len) # reps: [B,T,H], reps_len: list[int] or Tensor[B]
316
+ if isinstance(reps_len, torch.Tensor):
317
+ reps_len = reps_len.tolist()
318
+
319
+ # 有効長のみで平均(パディングを無視)
320
+ pooled = []
321
+ for i in range(reps.shape[0]):
322
+ Ti = int(reps_len[i]) if reps_len else reps.shape[1]
323
+ Ti = max(1, min(Ti, reps.shape[1])) # 安全側
324
+ pooled.append(reps[i, :Ti].mean(dim=0))
325
+ pooled = torch.stack(pooled, dim=0) # [B,H]
326
+
327
+ # 線形ヘッドで分類(head.fcに直接入れる)
328
+ logits = head.fc(pooled.to(device)) # [B,C]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
329
  probs = torch.softmax(logits, dim=-1)[0].detach().cpu().numpy()
330
 
331
  pred_id = int(np.argmax(probs))
 
343
  return label, scores, "AI(S3PRL)"
344
 
345
  except Exception as e:
 
346
  st.warning(f"AI予測中にエラーが発生: {e}")
347
  return predict_emotion_features(audio_bytes)
348