ayaka68 committed on
Commit
1981810
·
verified ·
1 Parent(s): d22abdc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -47
app.py CHANGED
@@ -1,11 +1,13 @@
 
1
  """
2
  Voice→Place Recommender (Streamlit / Hugging Face Spaces)
3
- Gated model対応(HF_TOKENをSecretsに登録して使用)
4
- - モデルロードは@st.cache_resource一本化
 
5
  """
6
 
7
  # ===== 基本インポート =====
8
- import io, uuid, datetime as dt, csv, base64, json, random, os
9
  import numpy as np
10
  import soundfile as sf
11
  from pydub import AudioSegment
@@ -13,14 +15,18 @@ from pydub import AudioSegment
13
  import streamlit as st
14
  from audiorecorder import audiorecorder
15
 
 
16
  import matplotlib
17
  matplotlib.use('Agg')
18
  import matplotlib.pyplot as plt
19
  from matplotlib import rcParams
20
  import japanize_matplotlib
21
 
 
22
  import torch
23
- from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
 
 
24
 
25
  # ===== フォント設定 =====
26
  rcParams["font.family"] = "DejaVu Sans"
@@ -45,14 +51,33 @@ PLACES = [
45
  ]
46
  REASON_TAGS = ["静けさ","緑","水辺","発散","創作","交流","体験","学習","屋内","屋外","没入","回復"]
47
 
48
- # ===== Gated model ロード一本化)=====
49
- MODEL_NAME = "imprt/kushinada-hubert-base-jtes-er"
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  @st.cache_resource(show_spinner=False)
52
- def load_model():
53
  """
54
- 日本語音声感情認識モデルロード(gated対応)
55
- SpacesのSecretsに 'HF_TOKEN'保存している前提
56
  """
57
  token = os.getenv("HF_TOKEN")
58
  if not token:
@@ -60,17 +85,71 @@ def load_model():
60
 
61
  device = "cuda" if torch.cuda.is_available() else "cpu"
62
 
63
- feature_extractor = AutoFeatureExtractor.from_pretrained(
64
- MODEL_NAME,
65
- token=token
66
- )
67
- model = AutoModelForAudioClassification.from_pretrained(
68
- MODEL_NAME,
69
- token=token
70
- ).to(device)
71
- model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
- return feature_extractor, model, device
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  # ===== ユーティリティ =====
76
  def to_wav_bytes(any_bytes: bytes, target_sr=16000, mono=True) -> bytes:
@@ -92,6 +171,7 @@ def to_wav_bytes(any_bytes: bytes, target_sr=16000, mono=True) -> bytes:
92
  return buf.getvalue()
93
 
94
  def audio_player_bytes(b: bytes, mime="audio/wav"):
 
95
  if not b:
96
  return
97
  b64 = base64.b64encode(b).decode("utf-8")
@@ -104,8 +184,9 @@ def audio_player_bytes(b: bytes, mime="audio/wav"):
104
  unsafe_allow_html=True,
105
  )
106
 
107
- # ===== フォールバック用簡易特徴量 =====
108
  def extract_features(y, sr):
 
109
  abs_y = np.abs(y)
110
  thr = 0.01 * (abs_y.max() + 1e-9)
111
  idx = np.where(abs_y > thr)[0]
@@ -125,6 +206,7 @@ def extract_features(y, sr):
125
  zc = (y[:-1] * y[1:] < 0).astype(np.float32)
126
  zcr_mean = float(zc.mean()) if zc.size else 0.0
127
 
 
128
  fmin, fmax = 80.0, 600.0
129
  if len(y) < int(sr / fmin) + 2:
130
  f0_est = 0.0
@@ -148,6 +230,7 @@ def extract_features(y, sr):
148
  }
149
 
150
  def predict_emotion_features(audio_bytes):
 
151
  wav_bytes_16k = to_wav_bytes(audio_bytes, target_sr=16000)
152
  y, sr = sf.read(io.BytesIO(wav_bytes_16k), dtype="float32")
153
  feat = extract_features(y, sr)
@@ -172,23 +255,17 @@ def predict_emotion_features(audio_bytes):
172
  scores["neutral"] += 0.3
173
  return label, scores, "Features"
174
 
175
- # ===== AI推定 =====
176
- def normalize_label(lbl: str) -> str:
177
- """モデル出力ラベルをUI想定に正規化"""
178
- m = {
179
- "happy": "happiness",
180
- "happiness": "happiness",
181
- "angry": "anger",
182
- "anger": "anger",
183
- "sad": "sadness",
184
- "sadness": "sadness",
185
- "neutral": "neutral"
186
- }
187
  return m.get(lbl.lower(), lbl)
188
 
189
  def predict_emotion_ai(audio_bytes):
 
 
 
190
  try:
191
- feature_extractor, model, device = load_model()
192
  except Exception as e:
193
  st.error(f"モデルのロードに失敗しました: {e}")
194
  st.info("音声特徴量ベースの分析に切り替えます。")
@@ -204,23 +281,32 @@ def predict_emotion_ai(audio_bytes):
204
  y = y[:max_samples]
205
  st.warning("音声が30秒を超えたため、最初の30秒のみを分析します。")
206
 
207
- inputs = feature_extractor(y, sampling_rate=sr, return_tensors="pt", padding=True)
208
- inputs = {k: v.to(device) for k, v in inputs.items()}
209
 
210
  with torch.no_grad():
211
- logits = model(**inputs).logits
 
 
 
 
 
 
 
 
 
 
 
212
 
213
- probs = torch.softmax(logits, dim=-1)[0].detach().cpu().numpy()
214
- pred_id = int(np.argmax(probs))
215
- raw_label = model.config.id2label[pred_id]
216
- label = normalize_label(raw_label)
217
 
218
- scores = {normalize_label(model.config.id2label[i]): float(probs[i]) for i in range(len(probs))}
219
- # 期待外ラベルが混ざっても可視化で扱えるように0〜1にクリップ
 
 
220
  for k in list(scores.keys()):
221
  scores[k] = max(0.0, min(1.0, scores[k]))
222
-
223
- return label, scores, "AI"
224
 
225
  except Exception as e:
226
  st.warning(f"AI予測中にエラーが発生: {e}")
@@ -330,9 +416,8 @@ def main():
330
  if key not in st.session_state: st.session_state[key] = default
331
 
332
  st.subheader("1) 録音またはアップロード")
333
-
334
- with st.warning("⚠️ ファイルアップロード403出る場合は、録音機能をお使いください。"):
335
- st.markdown("**🎤 録音** → PC/スマホで直接話す or 端末で音声再生しながら録音")
336
 
337
  tab_rec, tab_upload = st.tabs(["🎤 録音する(推奨)", "📁 ファイルを使う"])
338
 
 
1
+ # app.py
2
  """
3
  Voice→Place Recommender (Streamlit / Hugging Face Spaces)
4
+ - 日本語音声感情認識:S3PRL(HuBERT base) + HFの下流(.pt) チェックポイントを用いた推論
5
+ - Spaces の Settings → Secrets HF_TOKEN を設定してください
6
+ - ffmpeg が必要(apt.txtに ffmpeg を記載)
7
  """
8
 
9
  # ===== 基本インポート =====
10
+ import io, json, base64, random, os
11
  import numpy as np
12
  import soundfile as sf
13
  from pydub import AudioSegment
 
15
  import streamlit as st
16
  from audiorecorder import audiorecorder
17
 
18
+ # Matplotlib
19
  import matplotlib
20
  matplotlib.use('Agg')
21
  import matplotlib.pyplot as plt
22
  from matplotlib import rcParams
23
  import japanize_matplotlib
24
 
25
+ # Torch / Hugging Face Hub / S3PRL
26
  import torch
27
+ import torch.nn as nn
28
+ from huggingface_hub import list_repo_files, hf_hub_download
29
+ from s3prl.nn import S3PRLUpstream
30
 
31
  # ===== フォント設定 =====
32
  rcParams["font.family"] = "DejaVu Sans"
 
51
  ]
52
  REASON_TAGS = ["静けさ","緑","水辺","発散","創作","交流","体験","学習","屋内","屋外","没入","回復"]
53
 
54
+ # ===== KUSHINADA 定義HF の gated モデルのリポ)=====
55
+ KUSHINADA_REPO = "imprt/kushinada-hubert-base-jtes-er"
56
 
57
+ # ===== S3PRL 下流ヘッド(線形) =====
58
+ class SimpleLinearHead(nn.Module):
59
+ """
60
+ チェックポイント中の線形分類器 (W, b) を復元する簡易ヘッド。
61
+ 入力: [B, T, H] → mean-pool → [B, H] → Linear(H, C)
62
+ """
63
+ def __init__(self, in_dim: int, num_classes: int, W: torch.Tensor, b: torch.Tensor):
64
+ super().__init__()
65
+ self.pool = lambda x: x.mean(dim=1) # 時系列平均
66
+ self.fc = nn.Linear(in_dim, num_classes)
67
+ with torch.no_grad():
68
+ self.fc.weight.copy_(W) # [C, H]
69
+ self.fc.bias.copy_(b) # [C]
70
+
71
+ def forward(self, reps): # reps: [B, T, H]
72
+ x = self.pool(reps)
73
+ return self.fc(x)
74
+
75
+ # ===== KUSHINADA (S3PRL) ローダ =====
76
  @st.cache_resource(show_spinner=False)
77
+ def load_kushinada_s3prl():
78
  """
79
+ S3PRL上流(HuBERT base) + HFの下流(.pt)自動取得して復元
80
+ チェックポイント中から (weight,bias)推定して線形ヘッドを構築
81
  """
82
  token = os.getenv("HF_TOKEN")
83
  if not token:
 
85
 
86
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # 1) S3PRL upstream: HuBERT base (kushinada is assumed to be
    #    HuBERT-family — TODO confirm against the model card).
    upstream = S3PRLUpstream("hubert_base").to(device).eval()

    # 2) Locate and download the downstream checkpoint (.pt) from the HF repo.
    files = list_repo_files(KUSHINADA_REPO, token=token)
    pt_files = [f for f in files if f.endswith(".pt")]
    if not pt_files:
        raise FileNotFoundError("下流チェックポイント(.pt)が見つかりません。モデルページの Files を確認してください。")

    # Use the first .pt found (pin to a fixed filename if needed).
    ckpt_path = hf_hub_download(repo_id=KUSHINADA_REPO, filename=pt_files[0], token=token)

    # 3) Load the checkpoint on CPU; weights move to `device` via the head below.
    ckpt = torch.load(ckpt_path, map_location="cpu")

    # 4) Heuristically extract the linear classifier's (W, b) from the state
    #    dict — S3PRL scripts save under several possible wrapper keys.
    state = None
    if isinstance(ckpt, dict):
        for key in ["state_dict", "Downstream", "model", "downstream", "net", "weights"]:
            if key in ckpt and isinstance(ckpt[key], dict):
                state = ckpt[key]
                break
        if state is None:
            # The dict itself may already be a plain state dict;
            # S3PRL output formats vary across versions/scripts.
            state = ckpt

    if not isinstance(state, dict):
        raise RuntimeError("チェックポイント形式を解釈できませんでした。")

    # Search for a ([C, H] weight, [C] bias) pair that looks like a linear layer.
    linear_W, linear_b = None, None
    for k, v in state.items():
        if isinstance(v, torch.Tensor) and v.ndim == 2:
            base = k.rsplit(".", 1)[0]  # e.g. "classifier.fc.weight" -> "classifier.fc"
            bias_key = base + ".bias"
            if bias_key in state and isinstance(state[bias_key], torch.Tensor) and state[bias_key].ndim == 1:
                linear_W = v
                linear_b = state[bias_key]
                break

    if linear_W is None:
        # Fallback: brute-force match "...weight" / "...bias" name pairs.
        twos = [(k,v) for k,v in state.items() if isinstance(v, torch.Tensor) and v.ndim==2 and k.endswith("weight")]
        for wk, w in twos:
            bk = wk.replace("weight", "bias")
            if bk in state and isinstance(state[bk], torch.Tensor) and state[bk].ndim == 1:
                linear_W, linear_b = w, state[bk]
                break

    if linear_W is None:
        raise RuntimeError("線形分類器の重みが見つかりません。S3PRLの公式手順に沿ったDownstream再現が必要です。")

    # Rebuild the classification head from the recovered weights.
    num_classes, hidden_dim = linear_W.shape  # [C, H]
    head = SimpleLinearHead(in_dim=hidden_dim, num_classes=num_classes,
                            W=linear_W, b=linear_b).to(device).eval()

    # JTES assumption: 4 classes (angry/happy/neutral/sad).
    # NOTE(review): the class index order may differ per training run — confirm.
    default_labels = ["angry", "happy", "neutral", "sad"]
    if num_classes == 4:
        id2label = {i: default_labels[i] for i in range(4)}
    else:
        id2label = {i: f"class_{i}" for i in range(num_classes)}

    return upstream, head, id2label, device
153
 
154
  # ===== ユーティリティ =====
155
  def to_wav_bytes(any_bytes: bytes, target_sr=16000, mono=True) -> bytes:
 
171
  return buf.getvalue()
172
 
173
  def audio_player_bytes(b: bytes, mime="audio/wav"):
174
+ """音声プレイヤーを表示"""
175
  if not b:
176
  return
177
  b64 = base64.b64encode(b).decode("utf-8")
 
184
  unsafe_allow_html=True,
185
  )
186
 
187
+ # ===== フォールバック用簡易特徴量ベース =====
188
  def extract_features(y, sr):
189
+ """音声から簡易特徴量を抽出"""
190
  abs_y = np.abs(y)
191
  thr = 0.01 * (abs_y.max() + 1e-9)
192
  idx = np.where(abs_y > thr)[0]
 
206
  zc = (y[:-1] * y[1:] < 0).astype(np.float32)
207
  zcr_mean = float(zc.mean()) if zc.size else 0.0
208
 
209
+ # F0推定(非常に簡易)
210
  fmin, fmax = 80.0, 600.0
211
  if len(y) < int(sr / fmin) + 2:
212
  f0_est = 0.0
 
230
  }
231
 
232
  def predict_emotion_features(audio_bytes):
233
+ """音声特徴量から感情を推定(フォールバック)"""
234
  wav_bytes_16k = to_wav_bytes(audio_bytes, target_sr=16000)
235
  y, sr = sf.read(io.BytesIO(wav_bytes_16k), dtype="float32")
236
  feat = extract_features(y, sr)
 
255
  scores["neutral"] += 0.3
256
  return label, scores, "Features"
257
 
258
+ # ===== AI推定(S3PRL)=====
259
+ def _normalize_label(lbl: str) -> str:
260
+ m = {"happy": "happiness", "angry": "anger", "sad": "sadness", "neutral": "neutral"}
 
 
 
 
 
 
 
 
 
261
  return m.get(lbl.lower(), lbl)
262
 
263
  def predict_emotion_ai(audio_bytes):
264
+ """
265
+ S3PRL上流 + HF下流(.pt) で推論。
266
+ """
267
  try:
268
+ upstream, head, id2label, device = load_kushinada_s3prl()
269
  except Exception as e:
270
  st.error(f"モデルのロードに失敗しました: {e}")
271
  st.info("音声特徴量ベースの分析に切り替えます。")
 
281
  y = y[:max_samples]
282
  st.warning("音声が30秒を超えたため、最初の30秒のみを分析します。")
283
 
284
+ wav = torch.tensor(y, dtype=torch.float32, device=device).unsqueeze(0) # [1, T]
 
285
 
286
  with torch.no_grad():
287
+ reps_dict = upstream(wav) # S3PRL Upstream の出力
288
+ if isinstance(reps_dict, dict):
289
+ reps = reps_dict.get("last_hidden_state", None)
290
+ if reps is None:
291
+ # 代替:最終層の hidden_states など
292
+ if "hidden_states" in reps_dict and isinstance(reps_dict["hidden_states"], (list, tuple)):
293
+ reps = reps_dict["hidden_states"][-1]
294
+ else:
295
+ # 直接テンソルが来る実装もある
296
+ reps = list(reps_dict.values())[-1]
297
+ else:
298
+ reps = reps_dict # テンソル想定 [B, T, H]
299
 
300
+ logits = head(reps) # [B, C]
301
+ probs = torch.softmax(logits, dim=-1)[0].detach().cpu().numpy()
 
 
302
 
303
+ pred_id = int(np.argmax(probs))
304
+ raw_label = id2label[pred_id]
305
+ label = _normalize_label(raw_label)
306
+ scores = {_normalize_label(id2label[i]): float(probs[i]) for i in range(len(probs))}
307
  for k in list(scores.keys()):
308
  scores[k] = max(0.0, min(1.0, scores[k]))
309
+ return label, scores, "AI(S3PRL)"
 
310
 
311
  except Exception as e:
312
  st.warning(f"AI予測中にエラーが発生: {e}")
 
416
  if key not in st.session_state: st.session_state[key] = default
417
 
418
  st.subheader("1) 録音またはアップロード")
419
+ with st.warning("⚠️ ファイルアップロードで403が出る場合は、録音機能をご利用ください。"):
420
+ st.markdown("**🎤 録音** → 直接話す or 端末音声再生しな録音")
 
421
 
422
  tab_rec, tab_upload = st.tabs(["🎤 録音する(推奨)", "📁 ファイルを使う"])
423