ayaka68 committed on
Commit 7d80d37 · verified · 1 Parent(s): 1803050

Update app.py

Files changed (1):
  1. app.py +250 -282
app.py CHANGED
@@ -1,13 +1,15 @@
  # app.py
  """
  Voice→Place Recommender (Streamlit / Hugging Face Spaces)
- - Japanese speech emotion recognition: inference with S3PRL (HuBERT base) + an HF downstream (.pt) checkpoint
- - Set HF_TOKEN in Spaces Settings → Secrets
- - ffmpeg is required (add ffmpeg to apt.txt)
+ - Japanese speech emotion recognition: JTES 4-emotion inference with S3PRL (HuBERT base) + an HF downstream (.ckpt)
+ - Set HF_TOKEN (read scope) in Spaces Settings → Secrets
+ - If possible, pin the ckpt explicitly via KUSHINADA_FILENAME (e.g. s3prl/result/downstream/.../dev-best.ckpt)
+ - apt.txt: ffmpeg, (optionally) fonts-ipaexfont, fonts-noto-cjk
+ - requirements.txt: streamlit-audiorecorder, s3prl==0.4.17, torch==2.0.1, torchaudio==2.0.2, etc.
  """

  # ===== Core imports =====
- import io, json, base64, random, os
+ import io, base64, os, random
  import numpy as np
  import soundfile as sf
  from pydub import AudioSegment
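The header docstring says HF_TOKEN must be a Space secret because the KUSHINADA repo is gated. A minimal sketch for checking access before debugging the full app, using only standard huggingface_hub calls (the repo id is the one used below):

    import os
    from huggingface_hub import HfApi

    token = os.getenv("HF_TOKEN")
    api = HfApi()
    # raises if the token is missing or has no access to the gated repo
    info = api.model_info("imprt/kushinada-hubert-base-jtes-er", token=token)
    print([s.rfilename for s in info.siblings][:10])  # first few files in the repo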
@@ -21,29 +23,24 @@ matplotlib.use('Agg')
  import matplotlib.pyplot as plt
  from matplotlib import rcParams
  import japanize_matplotlib
+ import matplotlib.font_manager as fm

- # Torch / Hugging Face Hub / S3PRL
+ # Torch / HF Hub / S3PRL
  import torch
  import torch.nn as nn
  from huggingface_hub import HfApi, hf_hub_download
  from s3prl.nn import S3PRLUpstream, Featurizer

-
- # ===== Font settings =====
- import matplotlib.font_manager as fm
-
- # Prefer Japanese fonts (IPAexGothic → IPAGothic → Noto Sans CJK)
+ # ===== Font settings (Japanese) =====
  jp_candidates = ["IPAexGothic", "IPAGothic", "Noto Sans CJK JP", "Noto Sans CJK"]
  for name in jp_candidates:
      if any(name in f.name for f in fm.fontManager.ttflist):
          rcParams["font.family"] = name
          break
  else:
-     rcParams["font.family"] = "DejaVu Sans"  # last-resort fallback
-
+     rcParams["font.family"] = "DejaVu Sans"
  rcParams["axes.unicode_minus"] = False

-
  # ===== Fictional place data =====
  PLACES = [
      {"place_id":"lib_silent", "name":"無音図書館", "tags":["静けさ","集中","屋内"], "emo_key":"calm", "image":"images/lib_silent.png"},
@@ -63,33 +60,43 @@ PLACES = [
  ]
  REASON_TAGS = ["静けさ","緑","水辺","発散","創作","交流","体験","学習","屋内","屋外","没入","回復"]

- # ===== KUSHINADA definition (gated model repo on HF) =====
+ # ===== Model definitions =====
  KUSHINADA_REPO = "imprt/kushinada-hubert-base-jtes-er"

- # ===== S3PRL downstream head (linear) =====
- class SimpleLinearHead(nn.Module):
+ # ---- Downstream head (1-layer or 2-layer MLP) ----
+ class DownstreamHead(nn.Module):
      """
-     Minimal head that restores the linear classifier (W, b) stored in the checkpoint.
-     Input: [B, T, H] → mean-pool → [B, H] → Linear(H, C)
+     in -> (optional proj Linear) -> (optional ReLU) -> final Linear -> logits
      """
-     def __init__(self, in_dim: int, num_classes: int, W: torch.Tensor, b: torch.Tensor):
+     def __init__(self, in_dim, out_dim, W_final, b_final, proj_W=None, proj_b=None):
          super().__init__()
-         self.pool = lambda x: x.mean(dim=1)  # average over time
-         self.fc = nn.Linear(in_dim, num_classes)
+         self.proj = None
+         if proj_W is not None and proj_b is not None:
+             proj_out, proj_in = proj_W.shape  # [out, in]
+             self.proj = nn.Linear(proj_in, proj_out)
+             with torch.no_grad():
+                 self.proj.weight.copy_(proj_W)
+                 self.proj.bias.copy_(proj_b)
+             in_dim = proj_out  # input dim for the final layer
+         self.fc = nn.Linear(in_dim, out_dim)
          with torch.no_grad():
-             self.fc.weight.copy_(W)  # [C, H]
-             self.fc.bias.copy_(b)    # [C]
-
-     def forward(self, reps):  # reps: [B, T, H]
-         x = self.pool(reps)
+             self.fc.weight.copy_(W_final)
+             self.fc.bias.copy_(b_final)
+
+     @property
+     def expected_in(self):
+         # expected input dim (what the pooled features must match)
+         if self.proj is not None:
+             return self.proj.in_features
+         return self.fc.in_features
+
+     def forward(self, x):  # x: [B, expected_in]
+         if self.proj is not None:
+             x = self.proj(x)
+         # training may have used a nonlinearity here; it is unknown, so it is omitted (add nn.ReLU() etc. if needed)
          return self.fc(x)

- # ===== KUSHINADA (S3PRL) loader =====
- from huggingface_hub import HfApi, hf_hub_download
- from s3prl.nn import S3PRLUpstream, Featurizer
-
- KUSHINADA_REPO = "imprt/kushinada-hubert-base-jtes-er"
-
+ # ====== KUSHINADA loader (upstream + featurizer + downstream head) ======
  @st.cache_resource(show_spinner=False)
  def load_kushinada_s3prl():
      token = os.getenv("HF_TOKEN")
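Both the removed SimpleLinearHead and the new DownstreamHead only re-wrap weight tensors that already exist in the checkpoint. A standalone sketch of that shape contract, with invented dimensions (H=768, C=4):

    import torch
    import torch.nn as nn

    H, C = 768, 4                              # hidden size, class count (illustrative)
    W, b = torch.randn(C, H), torch.randn(C)   # pretend these came from a checkpoint

    fc = nn.Linear(H, C)
    with torch.no_grad():                      # copy_ needs weight [C, H], bias [C]
        fc.weight.copy_(W)
        fc.bias.copy_(b)

    reps = torch.randn(2, 50, H)               # [B, T, H] upstream features
    pooled = reps.mean(dim=1)                  # mean over time → [B, H]
    print(fc(pooled).shape)                    # torch.Size([2, 4]) → logits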
@@ -97,11 +104,11 @@ def load_kushinada_s3prl():
          raise RuntimeError("環境変数 HF_TOKEN が見つかりません。SpacesのSettings→Secretsで設定してください。")

      revision = os.getenv("KUSHINADA_REVISION", "main")
-     prefer_filename = os.getenv("KUSHINADA_FILENAME")  # e.g. "s3prl/result/downstream/.../dev-best.ckpt"
+     prefer_filename = os.getenv("KUSHINADA_FILENAME", "").strip()

      device = "cuda" if torch.cuda.is_available() else "cpu"

-     # upstream + Featurizer (last layer)
+     # 1) upstream + Featurizer (last layer)
      upstream = S3PRLUpstream("hubert_base").to(device).eval()
      try:
          featurizer = Featurizer(upstream)
@@ -112,110 +119,164 @@ def load_kushinada_s3prl():
          featurizer = Featurizer(upstream, feature_selection="last_hidden_state")
      featurizer = featurizer.to(device).eval()

-     # select a checkpoint
+     # 2) ckpt selection (downstream only; exclude upstream/converted files)
      api = HfApi()
      info = api.model_info(KUSHINADA_REPO, token=token, revision=revision)
      all_files = [s.rfilename for s in info.siblings]
-     exts = (".pt", ".ckpt", ".pth", ".bin")
-     candidates = [f for f in all_files if f.lower().endswith(exts)]

+     def is_ckpt(path):
+         p = path.lower()
+         if not (p.endswith(".pt") or p.endswith(".ckpt") or p.endswith(".pth") or p.endswith(".bin")):
+             return False
+         # exclude upstream weights and converted variants
+         bad = ["upstream", "converted", "hubert_base", "s3prl/converted", "wav2vec", "espnet"]
+         if any(b in p for b in bad):
+             return False
+         return True
+
+     candidates = [f for f in all_files if is_ckpt(f)]
+
+     # priority: explicit setting > downstream/dev-best > best > fold > others
      filename = None
      if prefer_filename:
+         # also accept subpath / suffix matches
          if prefer_filename in all_files:
              filename = prefer_filename
          else:
              matches = [f for f in all_files if f.endswith(prefer_filename)]
-             if matches: filename = matches[0]
+             if matches:
+                 filename = matches[0]
      if filename is None and candidates:
-         candidates = sorted(
-             candidates,
-             key=lambda f: (
-                 -int(any(k in f.lower() for k in ["downstream","classifier","jtes","kushinada","dev-best","best"])),
-                 len(f)
-             )
-         )
-         filename = candidates[0] if candidates else None
+         def rank_score(f):
+             f_lower = f.lower()
+             score = 0
+             if "result/downstream" in f_lower: score += 100
+             if "dev-best" in f_lower: score += 50
+             if "best" in f_lower: score += 20
+             if "fold" in f_lower: score += 10
+             if "kushinada" in f_lower: score += 5
+             return -score, len(f)  # higher score first; length as a tiebreaker against overly short names
+         candidates_sorted = sorted(candidates, key=rank_score)
+         filename = candidates_sorted[0]
      if filename is None:
-         raise FileNotFoundError("下流チェックポイント(.pt/.ckpt/.pth/.bin)が見つかりません。KUSHINADA_FILENAME を Secrets に指定してください。")
+         raise FileNotFoundError("下流チェックポイントが見つかりません。KUSHINADA_FILENAME を Secrets に設定してください。")

      ckpt_path = hf_hub_download(
-         repo_id=KUSHINADA_REPO, filename=filename, revision=revision,
-         token=token, repo_type="model", local_dir_use_symlinks=False
+         repo_id=KUSHINADA_REPO,
+         filename=filename,
+         revision=revision,
+         token=token,
+         repo_type="model",
+         local_dir_use_symlinks=False
      )
      ckpt = torch.load(ckpt_path, map_location="cpu")

-     # extract the state_dict
+     # 3) extract the state_dict
      state = None
      if isinstance(ckpt, dict):
-         for key in ["state_dict","Downstream","model","downstream","net","weights"]:
+         for key in ["state_dict", "Downstream", "model", "downstream", "net", "weights"]:
              if key in ckpt and isinstance(ckpt[key], dict):
                  state = ckpt[key]; break
-         if state is None: state = ckpt
+         if state is None:
+             state = ckpt
      if not isinstance(state, dict):
          raise RuntimeError("チェックポイント形式を解釈できませんでした。")

-     # --- find the final classifier layer (small bias length) and orient W ---
-     clf = []
+     # 4) collect every (weight, bias) linear-layer candidate, shaped [out, in]
+     layers = []
      for k, v in state.items():
-         if isinstance(v, torch.Tensor) and v.ndim == 1:
-             C = v.numel()
-             if 2 <= C <= 16:  # JTES usually has 4 classes
-                 # candidate weight keys
-                 prefix = k[:-5] if k.endswith(".bias") else k.rsplit(".", 1)[0]
-                 cand_keys = [prefix + ".weight", k.replace(".bias",".weight"), k.replace("bias","weight")]
-                 for wk in cand_keys:
-                     W = state.get(wk, None)
-                     if isinstance(W, torch.Tensor) and W.ndim == 2:
-                         # normalize the shape to [C, H]
-                         if W.shape[0] == C:
-                             oriented_W, H = W.float(), W.shape[1]
-                         elif W.shape[1] == C:
-                             oriented_W, H = W.t().float(), W.shape[0]
-                         else:
-                             continue
-                         clf.append((k, wk, C, H, oriented_W, v.float()))
-                         break
-
-     if not clf:
-         raise RuntimeError("最終分類層の重みが見つかりません(bias サイズ 2〜16 の層が見当たりません)。")
-
-     # prefer H close to 768
-     clf.sort(key=lambda x: (abs(x[3]-768), -x[2]))  # (smaller |H-768| first, then more classes)
-     _, _, num_classes, hidden_dim, linear_W, linear_b = clf[0]
-
-     head = SimpleLinearHead(in_dim=hidden_dim, num_classes=num_classes,
-                             W=linear_W, b=linear_b).to(device).eval()
-
-     default_labels = ["angry","happy","neutral","sad"]
-     id2label = {i: (default_labels[i] if num_classes == 4 and i < 4 else f"class_{i}") for i in range(num_classes)}
-
-     st.info(f"✅ 使うチェックポイント: `{filename}`(revision: {revision})")
-     st.info(f"✅ 分類器 in_features={hidden_dim}, classes={num_classes}")
+         if isinstance(v, torch.Tensor) and v.ndim == 1:  # bias
+             b = v.float()
+             base = k[:-5] if k.endswith(".bias") else k.rsplit(".", 1)[0]
+             w_key = base + ".weight"
+             if w_key in state and isinstance(state[w_key], torch.Tensor) and state[w_key].ndim == 2:
+                 W = state[w_key].float()
+                 # shape to [out, in]
+                 if W.shape[0] >= 2 and W.shape[1] >= 2:
+                     out, in_ = W.shape
+                     layers.append({
+                         "name": base,
+                         "W": W, "b": b,
+                         "out": out, "in": in_
+                     })
+                 else:
+                     # allow for the reversed orientation [in, out]
+                     Wt = W.t()
+                     out, in_ = Wt.shape
+                     layers.append({
+                         "name": base,
+                         "W": Wt, "b": b,
+                         "out": out, "in": in_
+                     })
+
+     if not layers:
+         raise RuntimeError("線形層の (weight, bias) が見つかりませんでした。")
+
+     # 5) final-layer candidates (prefer layers with a small class count)
+     finals = [L for L in layers if 2 <= L["out"] <= 16]
+     if not finals:
+         raise RuntimeError("最終分類層らしき小クラス数の線形層が見つかりませんでした。")
+
+     # 768 and 256 are common widths, so prefer an "in" close to them; names like classifier score extra
+     def final_rank(L):
+         score = 0
+         if "class" in L["name"].lower() or "out" in L["name"].lower() or "fc" in L["name"].lower():
+             score += 3
+         score -= abs(L["in"] - 256) / 256.0
+         score -= abs(L["in"] - 768) / 768.0
+         return -score
+     finals_sorted = sorted(finals, key=final_rank)
+     final = finals_sorted[0]
+
+     # 6) search for a preceding projection (a layer whose out matches final.in)
+     proj = None
+     proj_candidates = [L for L in layers if L["out"] == final["in"]]
+     if proj_candidates:
+         def proj_rank(L):
+             score = 0
+             if "proj" in L["name"].lower() or "linear" in L["name"].lower() or "fc" in L["name"].lower():
+                 score += 2
+             score -= abs(L["in"] - 768) / 768.0
+             return -score
+         proj = sorted(proj_candidates, key=proj_rank)[0]
+
+     # 7) build the DownstreamHead
+     if proj is not None:
+         head = DownstreamHead(
+             in_dim=proj["in"], out_dim=final["out"],
+             W_final=final["W"], b_final=final["b"],
+             proj_W=proj["W"], proj_b=proj["b"]
+         )
+     else:
+         head = DownstreamHead(
+             in_dim=final["in"], out_dim=final["out"],
+             W_final=final["W"], b_final=final["b"]
+         )
+     head = head.to(device).eval()
+
+     # 8) labels (JTES assumed)
+     default_labels = ["angry", "happy", "neutral", "sad"]
+     id2label = {i: (default_labels[i] if head.fc.out_features == 4 and i < 4 else f"class_{i}") for i in range(head.fc.out_features)}
+
+     st.info(f"✅ ckpt: `{filename}`(rev: {revision})")
+     st.info(f"✅ head.expected_in={head.expected_in}, final_out={head.fc.out_features}")
      return featurizer, head, id2label, device

  # ===== Utilities =====
  def to_wav_bytes(any_bytes: bytes, target_sr=16000, mono=True) -> bytes:
-     """Convert arbitrary audio to WAV (16 kHz / mono)"""
      if not any_bytes:
-         st.error("音声が空です。録音やアップロードを確認してください。")
-         st.stop()
+         st.error("音声が空です。録音やアップロードを確認してください。"); st.stop()
      try:
          seg = AudioSegment.from_file(io.BytesIO(any_bytes))
      except Exception as e:
-         st.error(f"音声読込エラー: {e}")
-         st.stop()
-     if mono:
-         seg = seg.set_channels(1)
-     if target_sr:
-         seg = seg.set_frame_rate(target_sr)
-     buf = io.BytesIO()
-     seg.export(buf, format="wav")
+         st.error(f"音声読込エラー: {e}"); st.stop()
+     if mono: seg = seg.set_channels(1)
+     if target_sr: seg = seg.set_frame_rate(target_sr)
+     buf = io.BytesIO(); seg.export(buf, format="wav")
      return buf.getvalue()

  def audio_player_bytes(b: bytes, mime="audio/wav"):
-     """Render an audio player"""
-     if not b:
-         return
+     if not b: return
      b64 = base64.b64encode(b).decode("utf-8")
      st.markdown(
          f"""
@@ -226,75 +287,50 @@ def audio_player_bytes(b: bytes, mime="audio/wav"):
          unsafe_allow_html=True,
      )

- # ===== Fallback: simple feature-based =====
+ # ===== Fallback (simple features) =====
  def extract_features(y, sr):
-     """Extract simple features from audio"""
      abs_y = np.abs(y)
      thr = 0.01 * (abs_y.max() + 1e-9)
      idx = np.where(abs_y > thr)[0]
-     if idx.size >= 2:
-         y = y[idx[0]:idx[-1]+1]
-
+     if idx.size >= 2: y = y[idx[0]:idx[-1]+1]
      energy_mean = float(np.sqrt(np.mean(y**2) + 1e-12))
-
      n = len(y)
      win = np.hanning(n) if n >= 512 else np.ones_like(y)
      y_win = y * win
-     spec = np.fft.rfft(y_win)
-     mag = np.abs(spec) + 1e-12
+     spec = np.fft.rfft(y_win); mag = np.abs(spec) + 1e-12
      freqs = np.fft.rfftfreq(len(y_win), d=1.0/sr)
      sc_mean = float((freqs * mag).sum() / mag.sum())
-
      zc = (y[:-1] * y[1:] < 0).astype(np.float32)
      zcr_mean = float(zc.mean()) if zc.size else 0.0
-
-     # F0 estimate (very crude)
+     # ultra-simple F0
      fmin, fmax = 80.0, 600.0
      if len(y) < int(sr / fmin) + 2:
          f0_est = 0.0
      else:
          corr = np.correlate(y, y, mode='full')[len(y)-1:]
-         lmin = max(1, int(sr / fmax))
-         lmax = min(len(corr) - 1, int(sr / fmin))
+         lmin = max(1, int(sr / fmax)); lmax = min(len(corr) - 1, int(sr / fmin))
          seg = corr[lmin:lmax] if lmax > lmin else np.array([])
          if seg.size > 0:
-             lag = lmin + int(np.argmax(seg))
-             f0_est = float(sr / lag) if lag > 0 else 0.0
+             lag = lmin + int(np.argmax(seg)); f0_est = float(sr / lag) if lag > 0 else 0.0
          else:
              f0_est = 0.0
-
-     return {
-         "f0_mean": float(f0_est),
-         "energy_mean": energy_mean,
-         "spec_centroid": sc_mean,
-         "zcr_mean": zcr_mean,
-         "duration": len(y)/sr
-     }
+     return {"f0_mean": float(f0_est), "energy_mean": energy_mean, "spec_centroid": sc_mean,
+             "zcr_mean": zcr_mean, "duration": len(y)/sr}

  def predict_emotion_features(audio_bytes):
-     """Estimate emotion from audio features (fallback)"""
      wav_bytes_16k = to_wav_bytes(audio_bytes, target_sr=16000)
      y, sr = sf.read(io.BytesIO(wav_bytes_16k), dtype="float32")
      feat = extract_features(y, sr)
      f0, en, z = feat["f0_mean"], feat["energy_mean"], feat["zcr_mean"]
-
      arousal = float(np.tanh(160*en + 4*z))
      valence = float(np.tanh(((f0-170)/120) + 15*en))
-
-     if valence >= 0.22 and arousal >= 0.22:
-         label = "happiness"
-     elif valence >= 0.22 and arousal < 0.22:
-         label = "neutral"  # calm-ish
-     elif valence < 0.10 and arousal >= 0.30:
-         label = "anger"
-     elif valence < 0.10 and arousal < 0.18:
-         label = "sadness"
-     else:
-         label = "neutral"
-
+     if valence >= 0.22 and arousal >= 0.22: label = "happiness"
+     elif valence >= 0.22 and arousal < 0.22: label = "neutral"
+     elif valence < 0.10 and arousal >= 0.30: label = "anger"
+     elif valence < 0.10 and arousal < 0.18: label = "sadness"
+     else: label = "neutral"
      scores = {k: 0.0 for k in ["happiness","anger","sadness","neutral"]}
-     scores[label] = 0.7
-     scores["neutral"] += 0.3
+     scores[label] = 0.7; scores["neutral"] += 0.3
      return label, scores, "Features"

  # ===== AI inference (S3PRL) =====
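The extract_features F0 estimator above picks the autocorrelation peak in the 80-600 Hz lag window, so it is easy to sanity-check on a synthetic tone. A numpy-only sketch (the tone frequency and duration are arbitrary):

    import numpy as np

    sr = 16000
    t = np.arange(int(0.5 * sr)) / sr
    y = np.sin(2 * np.pi * 220.0 * t).astype(np.float32)   # 220 Hz test tone

    corr = np.correlate(y, y, mode="full")[len(y)-1:]
    lmin, lmax = max(1, int(sr / 600.0)), min(len(corr) - 1, int(sr / 80.0))
    lag = lmin + int(np.argmax(corr[lmin:lmax]))
    print(sr / lag)   # ≈ 220 Hz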
@@ -314,92 +350,69 @@ def predict_emotion_ai(audio_bytes):
          wav_bytes_16k = to_wav_bytes(audio_bytes, target_sr=16000)
          y, sr = sf.read(io.BytesIO(wav_bytes_16k), dtype="float32")

+         # cut at 30 seconds
          max_duration = 30
          max_samples = int(sr * max_duration)
          if len(y) > max_samples:
-             y = y[:max_samples]
-             st.warning("音声が30秒を超えたため、最初の30秒のみを分析します。")
+             y = y[:max_samples]; st.warning("音声が30秒を超えたため、最初の30秒のみ分析します。")

+         # S3PRL expects list[Tensor] and list[int]
          wavs = [torch.tensor(y, dtype=torch.float32)]
          wavs_len = [int(len(y))]

          with torch.no_grad():
-             reps, reps_len = featurizer(wavs, wavs_len)  # Tensor
+             reps, reps_len = featurizer(wavs, wavs_len)  # expected: reps [B,T,H], reps_len list[int]
              if not isinstance(reps, torch.Tensor):
                  raise RuntimeError(f"Unexpected reps type: {type(reps)}")
-
-             # normalize reps to [B,T,H?] or [B,H?,T] (create the batch dim)
-             if reps.dim() == 1:  # [H]
-                 reps = reps.unsqueeze(0).unsqueeze(0)  # [1,1,H]
-             elif reps.dim() == 2:  # [T,H] or [H,T]
-                 # attach a batch dim either way
-                 reps = reps.unsqueeze(0)  # [1,*,*]
-             elif reps.dim() == 3:
-                 pass
-             else:
+             # bring reps to [B,T,H]
+             if reps.dim() == 1: reps = reps.unsqueeze(0).unsqueeze(0)
+             elif reps.dim() == 2: reps = reps.unsqueeze(0)
+             elif reps.dim() != 3:
                  raise RuntimeError(f"Unexpected reps.dim(): {reps.dim()}")

-             B, D1, D2 = reps.shape
-             expected_H = head.fc.in_features
-
-             # normalize reps_len
-             if reps_len is None:
-                 reps_len_list = [max(D1, D2)] * B
-             elif isinstance(reps_len, int):
-                 reps_len_list = [int(reps_len)] * B
-             elif isinstance(reps_len, (list, tuple)):
-                 reps_len_list = [int(x) for x in reps_len]
-                 if len(reps_len_list) != B: reps_len_list = [max(D1, D2)] * B
-             elif isinstance(reps_len, torch.Tensor):
-                 reps_len_list = reps_len.view(-1).tolist()
-                 if len(reps_len_list) != B: reps_len_list = [max(D1, D2)] * B
+             B, T, H = reps.shape
+
+             # normalize reps_len into a [B] list
+             if reps_len is None: reps_len_list = [T]*B
+             elif isinstance(reps_len, int): reps_len_list = [int(reps_len)]*B
+             elif isinstance(reps_len, (list, tuple)): reps_len_list = [int(x) for x in reps_len]
+             elif isinstance(reps_len, torch.Tensor): reps_len_list = reps_len.view(-1).tolist()
+             else: reps_len_list = [T]*B
+             if len(reps_len_list) != B: reps_len_list = [T]*B
+             reps_len_list = [max(1, min(int(li), T)) for li in reps_len_list]
+
+             # average over the valid length → [B,H_feat]
+             pooled = torch.stack([reps[i, :reps_len_list[i]].mean(dim=0) for i in range(B)], dim=0)  # [B,H_feat]
+
+             # align the feature dim with the head's expected input
+             expected_in = head.expected_in
+             H_feat = pooled.shape[1]
+
+             if H_feat == expected_in:
+                 pooled_in = pooled
+             elif H_feat % expected_in == 0:
+                 g = H_feat // expected_in
+                 pooled_in = pooled.view(B, expected_in, g).mean(dim=2)  # reduce by group averaging
+                 st.info(f"ℹ️ 特徴次元を {H_feat}→{expected_in} にグループ平均で整合 (group={g})")
              else:
-                 reps_len_list = [max(D1, D2)] * B
-
-             # --- auto-detect the feature axis & average over time ---
-             def pool_over_time(tensor3d, time_dim):
-                 # time_dim: 1 or 2
-                 pooled = []
-                 for i in range(tensor3d.shape[0]):
-                     Ti = min(max(1, reps_len_list[i]), tensor3d.shape[time_dim])
-                     if time_dim == 1:
-                         pooled.append(tensor3d[i, :Ti].mean(dim=0))  # [H]
-                     else:
-                         pooled.append(tensor3d[i, :, :Ti].mean(dim=1))  # [H]
-                 return torch.stack(pooled, dim=0)  # [B,H]
-
-             if D2 == expected_H:
-                 pooled = pool_over_time(reps, time_dim=1)  # [B,H]
-             elif D1 == expected_H:
-                 pooled = pool_over_time(reps, time_dim=2)  # [B,H]
-             else:
-                 # no axis matches: reshape to recover H when divisible
-                 if D2 % expected_H == 0:
-                     k = D2 // expected_H
-                     reps2 = reps.view(B, D1, expected_H, k).mean(dim=3)  # [B,D1,H]
-                     pooled = pool_over_time(reps2, time_dim=1)
-                 elif D1 % expected_H == 0:
-                     k = D1 // expected_H
-                     reps2 = reps.view(B, expected_H, k, D2).mean(dim=2)  # [B,H,D2]
-                     pooled = pool_over_time(reps2, time_dim=2)
-                 else:
-                     raise RuntimeError(f"特徴次元が一致しません: reps.shape={tuple(reps.shape)}, 期待H={expected_H}")
-
-             logits = head.fc(pooled.to(device))  # [B,C]
+                 # if nothing fits, fall back to a linear projection (minimal adaptation)
+                 proj = nn.Linear(H_feat, expected_in).to(pooled.device)
+                 with torch.no_grad():
+                     nn.init.eye_(proj.weight[:min(H_feat, expected_in), :min(H_feat, expected_in)])
+                     if expected_in > H_feat:
+                         nn.init.zeros_(proj.weight[min(H_feat, expected_in):])
+                     nn.init.zeros_(proj.bias)
+                 pooled_in = proj(pooled)
+                 st.info(f"ℹ️ 線形射影で {H_feat}→{expected_in} に適合")
+
+             logits = head(pooled_in.to(device))  # [B,C]
              probs = torch.softmax(logits, dim=-1)[0].detach().cpu().numpy()

              pred_id = int(np.argmax(probs))
              raw_label = id2label[pred_id]
-
-             def _norm(lbl: str) -> str:
-                 m = {"happy":"happiness", "angry":"anger", "sad":"sadness", "neutral":"neutral"}
-                 return m.get(lbl.lower(), lbl)
-
-             label = _norm(raw_label)
-             scores = {_norm(id2label[i]): float(probs[i]) for i in range(len(probs))}
-             for k in list(scores.keys()):
-                 scores[k] = max(0.0, min(1.0, scores[k]))
-
+             label = _normalize_label(raw_label)
+             scores = {_normalize_label(id2label[i]): float(probs[i]) for i in range(len(probs))}
+             for k in list(scores.keys()): scores[k] = max(0.0, min(1.0, scores[k]))
              return label, scores, "AI(S3PRL)"

      except Exception as e:
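predict_emotion_ai now pools only over each clip's valid frames before the head. A minimal sketch of that masked mean, with invented dimensions and lengths:

    import torch

    B, T, H = 2, 10, 768
    reps = torch.randn(B, T, H)          # padded features [B, T, H]
    lens = [10, 4]                       # valid frames per clip

    pooled = torch.stack(
        [reps[i, :lens[i]].mean(dim=0) for i in range(B)],  # mean over valid frames only
        dim=0,
    )
    print(pooled.shape)   # torch.Size([2, 768])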
@@ -413,10 +426,8 @@ def score_places(emo_label, top_k=4, diversity=True):
          "anger": ["release", "calm"],
          "sadness": ["calm", "joy"],
          "neutral": ["calm", "surprise", "joy"],
-         "joy": ["joy","surprise"],
-         "calm": ["calm","joy"],
-         "surprise": ["surprise","joy"],
-         "release": ["release","calm"],
+         "joy": ["joy","surprise"], "calm": ["calm","joy"],
+         "surprise": ["surprise","joy"], "release": ["release","calm"],
      }
      priors = EMO_MAP_PRIORS.get(emo_label, ["calm","joy","surprise"])
      scored = []
@@ -427,63 +438,38 @@ def score_places(emo_label, top_k=4, diversity=True):
          scored.append((base + random.uniform(-0.02, 0.02), p))
      scored.sort(key=lambda x: x[0], reverse=True)
      candidates = [p for _, p in scored[:max(top_k, 4)]]
-
-     if not diversity:
-         return candidates[:top_k]
-
+     if not diversity: return candidates[:top_k]
      picked, seen = [], set()
      for p in candidates:
-         k = p["emo_key"]
-         if k not in seen:
-             picked.append(p); seen.add(k)
-         if len(picked) >= top_k:
-             break
+         if p["emo_key"] not in seen:
+             picked.append(p); seen.add(p["emo_key"])
+         if len(picked) >= top_k: break
      if len(picked) < top_k:
          for p in candidates:
-             if p not in picked:
-                 picked.append(p)
-             if len(picked) >= top_k:
-                 break
+             if p not in picked: picked.append(p)
+             if len(picked) >= top_k: break
      return picked

  # ===== Visualization =====
  def plot_emotion_map(emotion_label, scores, method="AI"):
      fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5), dpi=150)
-
      emotion_jp = {
-         'happiness': '😊 喜び',
-         'anger': '😠 怒り',
-         'sadness': '😢 悲しみ',
-         'neutral': '😐 中立',
-         'joy': '😊 喜び',
-         'calm': '😌 落ち着き',
-         'surprise': '😲 驚き',
-         'release': '💨 発散'
+         'happiness': '😊 喜び', 'anger': '😠 怒り', 'sadness': '😢 悲しみ', 'neutral': '😐 中立',
+         'joy': '😊 喜び', 'calm': '😌 落ち着き', 'surprise': '😲 驚き', 'release': '💨 発散'
      }
      color_map = {
-         'happiness': '#FF6B6B',
-         'anger': '#FFA94D',
-         'sadness': '#868E96',
-         'neutral': '#51CF66',
-         'joy': '#FF6B6B',
-         'calm': '#51CF66',
-         'surprise': '#74C0FC',
-         'release': '#FFD43B'
+         'happiness': '#FF6B6B','anger': '#FFA94D','sadness': '#868E96','neutral': '#51CF66',
+         'joy': '#FF6B6B','calm': '#51CF66','surprise': '#74C0FC','release': '#FFD43B'
      }
-
-     labels = list(scores.keys())
-     values = [scores[k] for k in labels]
+     labels = list(scores.keys()); values = [scores[k] for k in labels]
      colors = [color_map.get(k, '#74C0FC') for k in labels]
      bars = ax1.bar([emotion_jp.get(k,k) for k in labels], values, color=colors, alpha=0.85)
-     ax1.set_ylim(0, 1)
-     ax1.set_ylabel('Score', fontsize=12)
+     ax1.set_ylim(0, 1); ax1.set_ylabel('Score', fontsize=12)
      ax1.set_title(f'Emotion Scores ({method})', fontsize=14, fontweight='bold')
      ax1.grid(axis='y', alpha=0.3)
-     for bar, value in zip(bars, values):
+     for bar, v in zip(bars, values):
          ax1.text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.01,
-                  f'{value:.2f}', ha='center', va='bottom', fontsize=10)
-
-     # pie chart (hide slices below 0.05)
+                  f'{v:.2f}', ha='center', va='bottom', fontsize=10)
      pairs = [(k,v) for k,v in scores.items() if v > 0.05]
      sizes = [v for _,v in pairs]
      labels_pie = [emotion_jp.get(k,k) for k,_ in pairs]
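score_places first takes at most one place per emo_key, then tops up from the remaining candidates. The pick order can be seen on toy data; a sketch with invented places and top_k=3:

    places = [
        {"name": "A", "emo_key": "calm"},
        {"name": "B", "emo_key": "calm"},
        {"name": "C", "emo_key": "joy"},
        {"name": "D", "emo_key": "joy"},
    ]
    picked, seen = [], set()
    for p in places:                      # first pass: one place per emo_key
        if p["emo_key"] not in seen:
            picked.append(p); seen.add(p["emo_key"])
    for p in places:                      # second pass: fill up to top_k
        if len(picked) >= 3: break
        if p not in picked: picked.append(p)
    print([p["name"] for p in picked])    # ['A', 'C', 'B']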
@@ -492,9 +478,7 @@ def plot_emotion_map(emotion_label, scores, method="AI"):
                  autopct='%1.0f%%', startangle=90, textprops={'fontsize': 11})
      ax2.set_title(f'Result: {emotion_jp.get(emotion_label, emotion_label)}',
                    fontsize=14, fontweight='bold')
-
-     plt.tight_layout()
-     return fig
+     plt.tight_layout(); return fig

  # ===== Main =====
  def main():
@@ -510,7 +494,7 @@ def main():
          if key not in st.session_state: st.session_state[key] = default

      st.subheader("1) 録音またはアップロード")
-     with st.warning("⚠️ ファイルアップロードで403が出る場合は、録音機能をご利用ください。"):
+     with st.warning("⚠️ アップロードで403が出る場合は、録音機能をご利用ください。"):
          st.markdown("**🎤 録音** → 直接話す or 端末で音声再生しながら録音")

      tab_rec, tab_upload = st.tabs(["🎤 録音する(推奨)", "📁 ファイルを使う"])
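The recording tab avoids the upload path, which can return 403 on some Spaces. A minimal sketch of the recorder pattern, assuming the streamlit-audiorecorder package listed in the header docstring; the prompts are placeholders and the widget is assumed to return a pydub AudioSegment:

    import io
    import streamlit as st
    from audiorecorder import audiorecorder

    audio = audiorecorder("🎤 録音開始", "⏹ 停止")   # renders start/stop buttons
    if len(audio) > 0:                               # non-empty AudioSegment
        buf = io.BytesIO()
        audio.export(buf, format="wav")              # same export path the app uses
        st.session_state["wav_bytes"] = buf.getvalue()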
@@ -522,18 +506,14 @@ def main():
              st.session_state["wav_bytes"] = buf.getvalue()
              audio_player_bytes(st.session_state["wav_bytes"], mime="audio/wav")
              st.caption(f"録音サイズ: {len(st.session_state['wav_bytes']) / 1024:.1f} KB")
-
          if st.button("🧹 クリアして新しく録音", width="stretch"):
              for k in ["wav_bytes","recs","feat","emotion_label","scores","method"]:
                  st.session_state[k] = None
-             st.session_state["rec_key"] += 1
-             st.rerun()
+             st.session_state["rec_key"] += 1; st.rerun()

      with tab_upload:
          uploaded_file = st.file_uploader(
-             "音声ファイルを選択(WAV推奨)",
-             type=["wav", "mp3", "m4a"],
-             accept_multiple_files=False
+             "音声ファイルを選択(WAV推奨)", type=["wav", "mp3", "m4a"], accept_multiple_files=False
          )
          if uploaded_file is not None:
              try:
@@ -543,8 +523,7 @@ def main():
                  st.caption(f"ファイルサイズ: {len(bytes_data) / 1024:.1f} KB")
                  audio_player_bytes(bytes_data, mime="audio/wav")
              except Exception as e:
-                 st.error("❌ ファイル読み込みエラー")
-                 st.exception(e)
+                 st.error("❌ ファイル読み込みエラー"); st.exception(e)
                  st.info("💡 代わりに録音機能をお試しください。")

      st.subheader("2) 同意")
@@ -552,9 +531,7 @@ def main():
                         ["保存しない(体験のみ)", "匿名で保存する"], horizontal=True)
      save_audio = st.checkbox("音声ファイルも保存する(任意)", value=False)

-     analysis_method = st.radio("分析方法",
-                                ["AIモデル(推奨)", "音声特徴量ベース"],
-                                horizontal=True)
+     analysis_method = st.radio("分析方法", ["AIモデル(推奨)", "音声特徴量ベース"], horizontal=True)

      if st.button("🔍 推定 & レコメンド", type="primary", width="stretch",
                   disabled=(st.session_state["wav_bytes"] is None)):
@@ -564,12 +541,10 @@ def main():
                  emotion_label, scores, method = predict_emotion_ai(raw_bytes)
              else:
                  emotion_label, scores, method = predict_emotion_features(raw_bytes)
-
              st.session_state["emotion_label"] = emotion_label
              st.session_state["scores"] = scores
              st.session_state["method"] = method
              st.session_state["recs"] = score_places(emotion_label, top_k=4, diversity=True)
-
              st.success("分析が完了しました!")

      if st.session_state["recs"] is not None:
@@ -577,7 +552,6 @@ def main():
          scores = st.session_state["scores"]
          method = st.session_state["method"]
          recs = st.session_state["recs"]
-
          emotion_japanese = {
              'happiness': '喜び', 'anger': '怒り', 'sadness': '悲しみ', 'neutral': '中立',
              'joy': '喜び', 'calm': '落ち着き', 'surprise': '驚き', 'release': '発散'
@@ -606,10 +580,8 @@ def main():
          cols = st.columns(4)
          for i, p in enumerate(recs[:4]):
              with cols[i % 4]:
-                 if "image" in p:
-                     st.image(p["image"], width="stretch")
-                 st.markdown(f"**{p['name']}**")
-                 st.caption(f"タグ: {', '.join(p['tags'])}")
+                 if "image" in p: st.image(p["image"], width="stretch")
+                 st.markdown(f"**{p['name']}**"); st.caption(f"タグ: {', '.join(p['tags'])}")

          st.subheader("4) 評価")
          choice_name = st.selectbox("第一候補を選んでください", [p["name"] for p in recs[:4]])
@@ -617,20 +589,16 @@ def main():
          rating_vibe = st.slider("気分に合う度(🎯)", 1, 5, 4)
          reasons = st.multiselect("理由タグ(1–3個)", REASON_TAGS, max_selections=3)
          comment = st.text_input("ひとことコメント(任意・20字)", max_chars=20)
-
          if st.button("💾 ログ保存", width="stretch"):
              consent_research = (consent == "匿名で保存する")
-             if not consent_research:
-                 st.info("体験のみモードです。研究ログは保存しません。")
-             else:
-                 st.success("保存機能は開発中です。")
+             if not consent_research: st.info("体験のみモードです。研究ログは保存しません。")
+             else: st.success("保存機能は開発中です。")

          st.divider()
          if st.button("▶ 次の人を録音する(状態をクリア)", width="stretch"):
              for k in ["wav_bytes","recs","emotion_label","scores","method"]:
                  st.session_state[k] = None
-             st.session_state["rec_key"] += 1
-             st.rerun()
+             st.session_state["rec_key"] += 1; st.rerun()

  if __name__ == "__main__":
      main()
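Since log saving is still stubbed out, the non-UI path can be smoke-tested without a browser. A sketch assuming app.py and its dependencies (s3prl, japanize_matplotlib, ffmpeg for pydub) are installed locally:

    import io
    import numpy as np
    import soundfile as sf

    from app import predict_emotion_features, score_places

    sr = 16000
    y = (0.1 * np.random.randn(sr)).astype(np.float32)  # 1 s of noise as a stand-in voice
    buf = io.BytesIO()
    sf.write(buf, y, sr, format="WAV")

    label, scores, method = predict_emotion_features(buf.getvalue())
    print(label, method, [p["name"] for p in score_places(label, top_k=4)])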