ayaka68 commited on
Commit
f446a3d
·
verified ·
1 Parent(s): 27e5bec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -294
app.py CHANGED
@@ -1,92 +1,26 @@
1
  # =========================
2
- # streamlit_app.py 最終版
3
  # =========================
4
  import os
5
- import tempfile
6
- import warnings
7
- import logging
8
  import io
9
  import uuid
10
  import datetime as dt
11
  import csv
12
  import base64
13
- import json
14
  import random
 
15
 
16
- # --- ロギング・警告の抑制 ---
17
- logging.getLogger('matplotlib.font_manager').setLevel(logging.ERROR)
18
- logging.getLogger('matplotlib').setLevel(logging.ERROR)
19
  warnings.filterwarnings('ignore')
20
 
21
- # --- 権限/キャッシュ対策 ---
22
- os.environ["STREAMLIT_BROWSER_GATHERUSAGESTATS"] = "false"
23
- os.environ["NUMBA_DISABLE_JIT"] = "1"
24
- os.environ["NUMBA_CACHE_DIR"] = "/tmp/numba_cache"
25
-
26
- # --- Matplotlibのバックエンド設定 ---
27
- mpl_config_dir = tempfile.mkdtemp()
28
- os.environ["MPLCONFIGDIR"] = mpl_config_dir
29
- matplotlibrc_path = os.path.join(mpl_config_dir, 'matplotlibrc')
30
- with open(matplotlibrc_path, 'w') as f:
31
- f.write("""
32
- backend: Agg
33
- font.family: sans-serif
34
- axes.unicode_minus: False
35
- """)
36
-
37
  # --- ライブラリのインポート ---
38
- import matplotlib
39
- matplotlib.use('Agg')
40
- import matplotlib.pyplot as plt
41
- import japanize_matplotlib # 日本語化ライブラリ
42
-
43
  import numpy as np
44
  import soundfile as sf
45
  import streamlit as st
46
  from audiorecorder import audiorecorder
47
  from pydub import AudioSegment
48
- from matplotlib import rcParams
49
-
50
- plt.ioff()
51
- try:
52
- os.makedirs("/tmp/numba_cache", exist_ok=True)
53
- except:
54
- pass
55
-
56
- # Matplotlibのマイナス記号の文字化け対策
57
- rcParams["axes.unicode_minus"] = False
58
-
59
- # =========================
60
- # アプリケーション設定値
61
- # =========================
62
class AppConfig:
    """Central registry of tuning thresholds and coefficients used by the
    arousal/valence (A/V) estimation and emotion-labeling pipeline."""

    # --- Thresholds for A/V-based labeling (label_from_av) ---
    # Magnitudes below the dead zones are snapped to 0 before classification.
    VALENCE_DEAD_ZONE = 0.12
    AROUSAL_DEAD_ZONE = 0.12
    JOY_VALENCE_THRESHOLD = 0.22
    JOY_AROUSAL_THRESHOLD = 0.22
    HIGH_AROUSAL_NEG_VALENCE_THRESHOLD = 0.10
    HIGH_AROUSAL_THRESHOLD = 0.30
    SURPRISE_AROUSAL_THRESHOLD = 0.22

    # --- Coefficients mapping acoustic features to A/V (av_from_features) ---
    ENERGY_TO_AROUSAL_COEF = 160.0
    ZCR_TO_AROUSAL_COEF = 4.0
    F0_TO_VALENCE_OFFSET = 170.0
    F0_TO_VALENCE_SCALE = 120.0
    ENERGY_TO_VALENCE_COEF = 15.0

    # --- Thresholds for refine_label (anger / sadness / tension refinement) ---
    ANGER_AROUSAL = 0.35
    ANGER_SC = 1500       # spectral-centroid floor (Hz) for anger
    ANGER_ZCR = 0.05
    ANGER_ENERGY = 0.015
    SADNESS_AROUSAL = 0.18
    SADNESS_ENERGY = 0.012
    SADNESS_SC = 800      # spectral-centroid ceiling (Hz) for sadness
    TENSION_AROUSAL = 0.30
    TENSION_ZCR = 0.06
90
 
91
  # =========================
92
  # 架空の場所データ
@@ -108,91 +42,79 @@ PLACES = [
108
  {"place_id":"urban_track", "name":"アーバントラック", "tags":["身体活動","発散","屋外"], "emo_key":"release", "image":"images/urban_track.png"},
109
  ]
110
  REASON_TAGS = ["静けさ","緑","水辺","発散","創作","交流","体験","学習","屋内","屋外","没入","回復"]
111
- EMO_MAP_PRIORS = {
112
- "joy": ["joy","surprise"], "calm": ["calm","joy"], "surprise": ["surprise","joy"],
113
- "arousal_high_neg": ["release","surprise"], "neutral": ["calm","joy","surprise"],
114
- "anger": ["release","surprise"], "sadness": ["calm","joy"], "tension": ["calm","surprise"],
115
- }
116
 
117
  # =========================
118
- # コアロジック (関数群)
119
  # =========================
120
 
121
- def audio_player_bytes(b: bytes, mime="audio/wav"):
122
- """音声データをUIに表示するためのHTMLを生成"""
123
- if not b: return
124
- b64 = base64.b64encode(b).decode("utf-8")
125
- st.markdown(
126
- f'<audio controls preload="metadata" style="width:100%"><source src="data:{mime};base64,{b64}" type="{mime}"></audio>',
127
- unsafe_allow_html=True,
128
- )
 
 
 
 
 
 
 
129
 
130
def extract_features(y, sr):
    """Extract low-level acoustic features from a mono waveform.

    Parameters
    ----------
    y : array-like of float
        1-D waveform samples.
    sr : int
        Sampling rate in Hz.

    Returns
    -------
    dict
        ``f0_mean`` (autocorrelation F0 estimate, Hz), ``energy_mean`` (RMS),
        ``spec_centroid`` (Hz), ``zcr_mean`` (zero-crossing rate per sample),
        ``duration`` (seconds, measured after silence trimming).
    """
    y = np.asarray(y)
    # Guard: an empty waveform would crash abs_y.max(); return neutral features.
    if y.size == 0:
        return {"f0_mean": 0.0, "energy_mean": 0.0, "spec_centroid": 0.0,
                "zcr_mean": 0.0, "duration": 0.0}

    # Trim leading/trailing silence (amplitude below 1% of peak).
    abs_y = np.abs(y)
    thr = 0.01 * (abs_y.max() + 1e-9)
    idx = np.where(abs_y > thr)[0]
    if idx.size >= 2:
        y = y[idx[0]:idx[-1] + 1]

    # RMS energy (epsilon keeps sqrt well-defined on pure silence).
    energy_mean = float(np.sqrt(np.mean(y ** 2) + 1e-12))

    # Spectral centroid from a magnitude spectrum (Hann window for long clips).
    n = len(y)
    win = np.hanning(n) if n >= 512 else np.ones_like(y)
    mag = np.abs(np.fft.rfft(y * win)) + 1e-12
    freqs = np.fft.rfftfreq(n, d=1.0 / sr)  # fix: use n, not len(y * win) (redundant O(n) multiply)
    sc_mean = float((freqs * mag).sum() / mag.sum())

    # Zero-crossing rate: fraction of adjacent sample pairs with a sign change.
    zc = (y[:-1] * y[1:] < 0).astype(np.float32)
    zcr_mean = float(zc.mean()) if zc.size else 0.0

    # F0 via the autocorrelation peak restricted to lags for 80-600 Hz.
    fmin, fmax = 80.0, 600.0
    if n < int(sr / fmin) + 2:
        f0_est = 0.0  # clip too short to hold one full period at fmin
    else:
        corr = np.correlate(y, y, mode='full')[n - 1:]
        lmin = max(1, int(sr / fmax))
        lmax = min(len(corr) - 1, int(sr / fmin))
        seg = corr[lmin:lmax] if lmax > lmin else np.array([])
        # Evaluate argmax once (original computed it twice).
        best_lag = lmin + int(np.argmax(seg)) if seg.size > 0 else 0
        f0_est = float(sr / best_lag) if best_lag > 0 else 0.0

    return {"f0_mean": f0_est, "energy_mean": energy_mean, "spec_centroid": sc_mean,
            "zcr_mean": zcr_mean, "duration": n / sr}
159
-
160
def av_from_features(feat):
    """Map acoustic features to an (arousal, valence) pair, each squashed into [-1, 1] by tanh."""
    f0 = feat["f0_mean"]
    energy = feat["energy_mean"]
    zcr = feat["zcr_mean"]

    # Arousal rises with loudness and with zero-crossing density.
    arousal = float(np.tanh(AppConfig.ENERGY_TO_AROUSAL_COEF * energy
                            + AppConfig.ZCR_TO_AROUSAL_COEF * zcr))

    # Valence rises with pitch above the offset frequency (guarding a zero scale)
    # and with loudness.
    if AppConfig.F0_TO_VALENCE_SCALE != 0:
        pitch_term = (f0 - AppConfig.F0_TO_VALENCE_OFFSET) / AppConfig.F0_TO_VALENCE_SCALE
    else:
        pitch_term = 0
    valence = float(np.tanh(pitch_term + AppConfig.ENERGY_TO_VALENCE_COEF * energy))

    return arousal, valence
167
 
168
def label_from_av(arousal, valence):
    """Classify an (arousal, valence) point into a coarse emotion label.

    Values whose magnitude falls inside the configured dead zones are
    snapped to zero before thresholding.
    """
    v = valence if abs(valence) >= AppConfig.VALENCE_DEAD_ZONE else 0.0
    a = arousal if arousal >= AppConfig.AROUSAL_DEAD_ZONE else 0.0

    if v >= AppConfig.JOY_VALENCE_THRESHOLD:
        # Positive valence: high arousal reads as joy, low as calm.
        return "joy" if a >= AppConfig.JOY_AROUSAL_THRESHOLD else "calm"
    if v < AppConfig.HIGH_AROUSAL_NEG_VALENCE_THRESHOLD and a >= AppConfig.HIGH_AROUSAL_THRESHOLD:
        return "arousal_high_neg"
    if a >= AppConfig.SURPRISE_AROUSAL_THRESHOLD:
        return "surprise"
    return "neutral"
177
 
178
def refine_label(arousal, valence, feat):
    """Refine a coarse emotion label into anger / sadness / tension using
    spectral features; other labels pass through unchanged."""
    base = label_from_av(arousal, valence)
    if base not in ("arousal_high_neg", "surprise", "neutral"):
        return base

    energy = feat["energy_mean"]
    zcr = feat["zcr_mean"]
    centroid = feat["spec_centroid"]
    # Apply the same dead-zone snapping as label_from_av.
    a = arousal if arousal >= AppConfig.AROUSAL_DEAD_ZONE else 0.0
    v = valence if abs(valence) >= AppConfig.VALENCE_DEAD_ZONE else 0.0
    negative_side = v <= 0.0

    # Anger: loud, high-arousal, bright (high centroid), noisy (high ZCR).
    if (negative_side and a >= AppConfig.ANGER_AROUSAL
            and centroid >= AppConfig.ANGER_SC
            and zcr >= AppConfig.ANGER_ZCR
            and energy >= AppConfig.ANGER_ENERGY):
        return "anger"
    # Sadness: quiet, low-arousal, dark (low centroid).
    if (negative_side and a < AppConfig.SADNESS_AROUSAL
            and energy < AppConfig.SADNESS_ENERGY
            and centroid < AppConfig.SADNESS_SC):
        return "sadness"
    # Tension: aroused and noisy without the anger energy/brightness profile.
    if (negative_side and a >= AppConfig.TENSION_AROUSAL
            and zcr >= AppConfig.TENSION_ZCR):
        return "tension"
    return base
 
 
 
 
 
 
 
 
 
192
 
193
- def score_places(emo_label, top_k=8, show_k=4, diversity=True):
194
- """感情ラベルに基づいて場所をスコアリングし、推薦リストを生成"""
195
- priors = EMO_MAP_PRIORS.get(emo_label, ["calm", "joy", "surprise"])
196
  scored = []
197
  for p in PLACES:
198
  base = 0.5
@@ -201,112 +123,30 @@ def score_places(emo_label, top_k=8, show_k=4, diversity=True):
201
  scored.append((base + random.uniform(-0.02, 0.02), p))
202
 
203
  scored.sort(key=lambda x: x[0], reverse=True)
204
- candidates = [p for _, p in scored[:max(top_k, show_k)]]
205
- if not diversity: return candidates[:show_k]
206
 
 
 
207
  picked, seen = [], set()
208
  for p in candidates:
209
  if p["emo_key"] not in seen:
210
- picked.append(p); seen.add(p["emo_key"])
211
- if len(picked) >= show_k: break
212
-
213
- # 多様性を確保しつつ、指定された件数まで候補を追加
214
- if len(picked) < show_k:
215
  for p in candidates:
216
  if p not in picked: picked.append(p)
217
- if len(picked) >= show_k: break
218
  return picked
219
 
220
def ensure_logs_path():
    """Return the session-log CSV path under /tmp/logs, creating the
    directory and writing the header row if the file does not yet exist."""
    log_dir = "/tmp/logs"
    os.makedirs(log_dir, exist_ok=True)
    log_path = os.path.join(log_dir, "oc_sessions.csv")
    if not os.path.exists(log_path):
        header = ["session_id","ts","consent_research","save_audio","f0_mean","energy_mean","spec_centroid","zcr_mean","duration","arousal","valence","emo_label","exposed_ids","choice_id","rating_like","rating_vibe","reason_tags","comment"]
        with open(log_path, "w", newline="", encoding="utf-8") as fh:
            csv.writer(fh).writerow(header)
    return log_path
229
-
230
def append_log(row_dict):
    """Append one session record to the CSV log.

    List-valued fields are flattened to delimited strings: ``exposed_ids``
    comma-joined, ``reason_tags`` pipe-joined. Missing keys become "".
    """
    path = ensure_logs_path()
    header = ["session_id","ts","consent_research","save_audio","f0_mean","energy_mean","spec_centroid","zcr_mean","duration","arousal","valence","emo_label","exposed_ids","choice_id","rating_like","rating_vibe","reason_tags","comment"]

    def cell(key):
        # Serialize the two list fields; everything else passes through.
        if key == "exposed_ids":
            return ",".join(row_dict.get(key, []))
        if key == "reason_tags":
            return "|".join(row_dict.get(key, []))
        return row_dict.get(key, "")

    with open(path, "a", newline="", encoding="utf-8") as fh:
        csv.writer(fh).writerow([cell(k) for k in header])
241
-
242
- def to_wav_bytes(any_bytes: bytes, target_sr=16000, mono=True) -> bytes:
243
- """様々な形式の音声をWAV形式のbytesに変換"""
244
- if not any_bytes: st.error("音声が空です。"); st.stop()
245
- try:
246
- seg = AudioSegment.from_file(io.BytesIO(any_bytes))
247
- seg = seg.set_channels(1) if mono else seg
248
- seg = seg.set_frame_rate(target_sr) if target_sr else seg
249
- buf = io.BytesIO()
250
- seg.export(buf, format="wav")
251
- return buf.getvalue()
252
- except Exception as e:
253
- st.error(f"音声ファイルを処理できませんでした: {e}"); st.stop()
254
-
255
def plot_av_map(points, current=None, size=(6.5, 6.5), dpi=200):
    """Draw the Arousal/Valence map as a matplotlib figure.

    Parameters:
        points: list of dicts with keys "v" (valence) and "a" (arousal);
            rendered as a history trail with increasing opacity.
        current: optional dict with "v", "a" and optional "label"; rendered
            as a highlighted, annotated marker.
        size: figure size in inches.
        dpi: figure resolution.

    Returns the created Figure (caller is responsible for closing it).
    """
    fig, ax = plt.subplots(figsize=size, dpi=dpi)
    fig.patch.set_facecolor('#FFFFFF')
    ax.set_facecolor('#FAFBFC')

    # Four emotion quadrants: (corner0, corner1, fill color, accent color, JP label, EN label).
    quads = [
        ((0, 0), (1, 1), "#FFE4E6", "#FF6B6B", "喜び・興奮", "Joy/Excitement"),
        ((-1, 0), (0, 1), "#FFF3CD", "#FFA94D", "緊張・怒り", "Tension/Anger"),
        ((-1, -1), (0, 0), "#E8EAED", "#868E96", "悲しみ・低覚醒", "Sadness/Low"),
        ((0, -1), (1, 0), "#D3F9D8", "#51CF66", "落ち着き・満足", "Calm/Content"),
    ]
    for (x0, y0), (x1, y1), c_base, c_accent, label_jp, label_en in quads:
        # Tint the quadrant and place the bilingual label at its center.
        ax.fill_between([x0, x1], y0, y1, color=c_base, alpha=0.15, zorder=0)
        cx, cy = (x0 + x1) / 2, (y0 + y1) / 2
        bbox_props = dict(boxstyle="round,pad=0.3", facecolor='white', edgecolor='none', alpha=0.7)
        ax.text(cx, cy + 0.08, label_jp, fontsize=13, fontweight='bold', ha="center", va="center", color="#2D3436", bbox=bbox_props)
        ax.text(cx, cy - 0.08, label_en, fontsize=10, style='italic', ha="center", va="center", color="#636E72", alpha=0.8)

    # Unit circle plus axes cross to frame the A/V space.
    ax.add_artist(plt.Circle((0, 0), 1.0, fill=False, lw=2, color="#2D3436", alpha=0.8))
    ax.grid(True, alpha=0.2, linestyle=':', color='#BDC3C7')
    ax.axhline(0, color="#495057", lw=1.5, alpha=0.8); ax.axvline(0, color="#495057", lw=1.5, alpha=0.8)

    ax.set_xlim(-1.15, 1.15); ax.set_ylim(-1.15, 1.15); ax.set_aspect("equal", adjustable="box")
    ax.set_xlabel("Valence (価値感情)\n← ネガティブ    ポジティブ →", fontsize=13, labelpad=10, color="#2D3436")
    ax.set_ylabel("Arousal (覚醒度)\n↑ 高い\n\n↓ 低い", fontsize=13, labelpad=10, color="#2D3436")
    ax.set_xticks([-1, -0.5, 0, 0.5, 1]); ax.set_yticks([-1, -0.5, 0, 0.5, 1])
    ax.tick_params(labelsize=10, colors="#495057")

    if points:
        # History trail: older points are more transparent than recent ones.
        xs, ys, n = [p["v"] for p in points], [p["a"] for p in points], len(points)
        for i in range(n):
            alpha = 0.2 + 0.4 * (i / max(n - 1, 1))
            ax.scatter(xs[i], ys[i], s=30, alpha=alpha, color="#3498DB", edgecolors="none", zorder=2)

    if current:
        # Current estimate: double-ring marker with an annotated arrow.
        v, a, lab = float(current["v"]), float(current["a"]), current.get("label", "現在")
        ax.scatter([v], [a], s=120, facecolors="white", edgecolors="#FF6348", linewidths=3, zorder=5)
        ax.scatter([v], [a], s=60, color="#FF6348", zorder=6)
        bbox_props = dict(boxstyle="round,pad=0.5", facecolor='#FF6348', edgecolor='none', alpha=0.9)
        ax.annotate(f" {lab} ", xy=(v, a), xytext=(v + 0.15, a + 0.15), fontsize=12, weight="bold", color="white", bbox=bbox_props, arrowprops=dict(arrowstyle="-|>", connectionstyle="arc3,rad=0.3", color="#FF6348", lw=2))

    ax.set_title("感情分析マップ", fontsize=16, fontweight='bold', pad=20, color="#2D3436")
    plt.tight_layout(pad=0.5)
    return fig
300
-
301
  # =========================
302
  # Streamlit UI
303
  # =========================
304
  st.set_page_config(page_title="Voice→Place Recommender", page_icon="🎙️", layout="centered")
305
- st.title("🎙️ 声の感情で『架空の場所』をレコメンド")
306
- st.caption("録音→感情推定(Arousal/Valence)→上位スポット→評価→CSV保存(匿名)")
307
 
308
  # ---- Session state 初期化 ----
309
- for key, default in [("wav_bytes", None), ("recs", None), ("feat", None), ("arousal", None), ("valence", None), ("emo_label", None), ("av_hist", []), ("rec_key", 0)]:
310
  if key not in st.session_state: st.session_state[key] = default
311
 
312
  # ---- 1) 録音 / アップロード ----
@@ -318,57 +158,51 @@ with tab_rec:
318
  st.session_state["wav_bytes"] = audio.export().read()
319
  audio_player_bytes(st.session_state["wav_bytes"])
320
  if st.button("🧹 クリアして新しく録音", use_container_width=True):
321
- for k in ["wav_bytes","recs","feat","arousal","valence","emo_label"]: st.session_state[k] = None
322
  st.session_state["rec_key"] += 1
323
  st.rerun()
324
 
325
  with tab_upload:
326
- up = st.file_uploader("WAV/MP3/M4A を選択", type=["wav","mp3","m4a"])
327
  if up:
328
  st.session_state["wav_bytes"] = up.read()
329
  audio_player_bytes(st.session_state["wav_bytes"])
330
 
331
  # ---- 2) 同意 ----
332
  st.subheader("2) 同意")
333
- consent = st.radio("研究利用の同意(匿名IDで特徴量と評価を保存します)", ["保存しない(体験のみ)", "匿名で保存する"], horizontal=True)
334
- save_audio = st.checkbox("音声ファイルも保存する(任意)", value=False)
335
 
336
  # ---- 推定 & レコメンド実行 ----
337
- if st.button("🔍 推定 & レコメンド", type="primary", use_container_width=True, disabled=(st.session_state["wav_bytes"] is None)):
338
- with st.spinner('音声を分析中...🤖'):
339
- wav_bytes_fixed = to_wav_bytes(st.session_state["wav_bytes"])
340
- y, sr = sf.read(io.BytesIO(wav_bytes_fixed), dtype="float32", always_2d=False)
341
- y = y.mean(axis=1) if y.ndim == 2 else y
342
-
343
- feat = extract_features(y, sr)
344
- arousal, valence = av_from_features(feat)
345
- emo_label = refine_label(arousal, valence, feat)
346
-
347
  st.session_state.update({
348
- "feat": feat, "arousal": arousal, "valence": valence, "emo_label": emo_label,
349
- "recs": score_places(emo_label)
 
350
  })
351
- st.session_state["av_hist"].append({"a": arousal, "v": valence, "label": emo_label})
352
  st.success("分析が完了しました!")
353
 
354
  # ---- 結果表示 ----
355
- if st.session_state["recs"]:
356
- feat, arousal, valence, emo_label, recs = (st.session_state[k] for k in ["feat", "arousal", "valence", "emo_label", "recs"])
 
 
357
 
358
  st.subheader("分析結果")
359
- col_map, col_info = st.columns([0.65, 0.35])
360
-
361
- with col_map:
362
- current_pt = {"a": arousal, "v": valence, "label": emo_label}
363
- fig = plot_av_map(st.session_state["av_hist"], current=current_pt, size=(6, 6), dpi=150)
364
- st.pyplot(fig, clear_figure=True)
365
-
366
- with col_info:
367
- st.success(f"**推定感情: {emo_label}**")
368
- st.metric("Arousal (覚醒度)", f"{arousal:.2f}")
369
- st.metric("Valence (価値感情)", f"{valence:.2f}")
370
- with st.expander("詳細な特徴量"):
371
- st.json({k: f"{v:.3f}" if isinstance(v, float) else v for k, v in feat.items()})
372
 
373
  st.subheader("3) おすすめ(上位4件)")
374
  cols = st.columns(4)
@@ -388,33 +222,12 @@ if st.session_state["recs"]:
388
  comment = st.text_input("ひとことコメント(任意・20字)", max_chars=20)
389
 
390
  if st.form_submit_button("💾 ログ保存", use_container_width=True):
391
- if consent != "匿名で保存する":
392
- st.info("体験のみモードのため、ログは保存しません。")
393
- else:
394
- choice_id = next((p["place_id"] for p in recs if p["name"] == choice_name), None)
395
- row = {
396
- "session_id": f"oc-{uuid.uuid4().hex[:8]}", "ts": dt.datetime.now().isoformat(timespec="seconds"),
397
- "consent_research": True, "save_audio": save_audio, **feat,
398
- "arousal": arousal, "valence": valence, "emo_label": emo_label,
399
- "exposed_ids": [p["place_id"] for p in recs[:4]], "choice_id": choice_id,
400
- "rating_like": rating_like, "rating_vibe": rating_vibe, "reason_tags": reasons, "comment": comment,
401
- }
402
- append_log(row)
403
- if save_audio:
404
- out_path = os.path.join("/tmp/logs", f"{row['session_id']}.wav")
405
- with open(out_path, "wb") as f: f.write(st.session_state["wav_bytes"])
406
- st.success("ログを保存しました!ご協力ありがとうございます。")
407
 
408
- # ---- フッターと管理機能 ----
409
  st.divider()
410
- if st.button("▶ 次の人を録音する(状態をクリア)", use_container_width=True):
411
- for k in list(st.session_state.keys()):
412
- if k not in ['av_hist', 'rec_key']: # 履歴は残す
413
- del st.session_state[k]
414
  st.session_state["rec_key"] += 1
415
- st.rerun()
416
-
417
- csv_path = ensure_logs_path()
418
- if os.path.exists(csv_path) and os.path.getsize(csv_path) > 0:
419
- with open(csv_path, "rb") as f:
420
- st.download_button("🔻 これまでの評価ログをダウンロード", f, file_name="oc_sessions.csv", mime="text/csv")
 
1
  # =========================
2
+ # app.py (AIモデル搭載版)
3
  # =========================
4
  import os
 
 
 
5
  import io
6
  import uuid
7
  import datetime as dt
8
  import csv
9
  import base64
 
10
  import random
11
+ import warnings
12
 
13
+ # --- 警告の抑制 ---
 
 
14
  warnings.filterwarnings('ignore')
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  # --- ライブラリのインポート ---
 
 
 
 
 
17
  import numpy as np
18
  import soundfile as sf
19
  import streamlit as st
20
  from audiorecorder import audiorecorder
21
  from pydub import AudioSegment
22
+ import torch
23
+ from transformers import AutoModelForAudioClassification, AutoFeatureExtractor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  # =========================
26
  # 架空の場所データ
 
42
  {"place_id":"urban_track", "name":"アーバントラック", "tags":["身体活動","発散","屋外"], "emo_key":"release", "image":"images/urban_track.png"},
43
  ]
44
  REASON_TAGS = ["静けさ","緑","水辺","発散","創作","交流","体験","学習","屋内","屋外","没入","回復"]
 
 
 
 
 
45
 
46
  # =========================
47
+ # AIモデル関連の関数
48
  # =========================
49
 
50
@st.cache_resource
def load_model():
    """Load the Japanese audio-emotion classifier once, caching it as a Streamlit resource."""
    model_id = "Mizuiro-inc/emotion2vec-base-japanese"
    extractor = AutoFeatureExtractor.from_pretrained(model_id)
    classifier = AutoModelForAudioClassification.from_pretrained(model_id)
    return extractor, classifier
57
+
58
def predict_emotion(audio_bytes):
    """Classify the emotion of raw audio bytes with the cached model.

    Returns a tuple ``(label, scores)``: the top predicted emotion label and
    a dict mapping every label to its softmax probability.
    """
    extractor, classifier = load_model()

    # Normalize the input to 16 kHz WAV before feature extraction.
    normalized = to_wav_bytes(audio_bytes, target_sr=16000)
    samples, rate = sf.read(io.BytesIO(normalized), dtype="float32")

    # Build model inputs as PyTorch tensors.
    features = extractor(samples, sampling_rate=rate, return_tensors="pt", padding=True)

    # Inference only — no gradients needed.
    with torch.no_grad():
        logits = classifier(**features).logits

    id2label = classifier.config.id2label
    best_label = id2label[torch.argmax(logits, dim=-1).item()]

    # Per-label probabilities for display.
    probs = torch.softmax(logits, dim=-1)[0]
    scores = {id2label[i]: p.item() for i, p in enumerate(probs)}

    return best_label, scores
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
+ # =========================
84
+ # 汎用関数
85
+ # =========================
 
 
 
 
 
 
86
 
87
def to_wav_bytes(any_bytes: bytes, target_sr=16000, mono=True) -> bytes:
    """Convert audio bytes of any pydub-readable format into WAV bytes.

    Optionally downmixes to mono and resamples to ``target_sr``. On empty
    input or a decoding failure, shows an error in the UI and stops the
    Streamlit script run.
    """
    if not any_bytes:
        st.error("音声が空です。")
        st.stop()
    try:
        segment = AudioSegment.from_file(io.BytesIO(any_bytes))
        if mono:
            segment = segment.set_channels(1)
        if target_sr:
            segment = segment.set_frame_rate(target_sr)
        out = io.BytesIO()
        segment.export(out, format="wav")
        return out.getvalue()
    except Exception as e:
        st.error(f"音声ファイルを処理できませんでした: {e}")
        st.stop()
99
 
100
def audio_player_bytes(b: bytes, mime="audio/wav"):
    """Render an inline HTML <audio> player for raw audio bytes; no-op on empty input."""
    if not b:
        return
    encoded = base64.b64encode(b).decode("utf-8")
    player = f'<audio controls preload="metadata" style="width:100%"><source src="data:{mime};base64,{encoded}" type="{mime}"></audio>'
    st.markdown(player, unsafe_allow_html=True)
105
+
106
+ # AIの予測結果を場所の推薦に繋げるための新しい関数
107
+ def score_places_by_ai(emo_label, top_k=4):
108
+ """AIの感情ラベルに基づいて場所を推薦する"""
109
+ # AIのラベルと場所のカテゴリを対応付ける
110
+ label_to_emo_key = {
111
+ 'happy': ['joy', 'surprise'],
112
+ 'sad': ['calm', 'joy'],
113
+ 'angry': ['release', 'calm'],
114
+ 'neutral': ['calm', 'surprise', 'joy']
115
+ }
116
+ priors = label_to_emo_key.get(emo_label, ['calm', 'joy']) # 未知のラベルはcalm/joyに
117
 
 
 
 
118
  scored = []
119
  for p in PLACES:
120
  base = 0.5
 
123
  scored.append((base + random.uniform(-0.02, 0.02), p))
124
 
125
  scored.sort(key=lambda x: x[0], reverse=True)
 
 
126
 
127
+ # 多様性を確保するロジック
128
+ candidates = [p for _, p in scored]
129
  picked, seen = [], set()
130
  for p in candidates:
131
  if p["emo_key"] not in seen:
132
+ picked.append(p)
133
+ seen.add(p["emo_key"])
134
+ if len(picked) >= top_k: break
135
+ if len(picked) < top_k:
 
136
  for p in candidates:
137
  if p not in picked: picked.append(p)
138
+ if len(picked) >= top_k: break
139
  return picked
140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  # =========================
142
  # Streamlit UI
143
  # =========================
144
  st.set_page_config(page_title="Voice→Place Recommender", page_icon="🎙️", layout="centered")
145
+ st.title("🎙️ 声の感情で『架空の場所』をレコメンド (AI版)")
146
+ st.caption("録音→AI感情推定→上位スポット→評価→CSV保存(匿名)")
147
 
148
  # ---- Session state 初期化 ----
149
+ for key, default in [("wav_bytes", None), ("recs", None), ("emo_label", None), ("scores", None), ("rec_key", 0)]:
150
  if key not in st.session_state: st.session_state[key] = default
151
 
152
  # ---- 1) 録音 / アップロード ----
 
158
  st.session_state["wav_bytes"] = audio.export().read()
159
  audio_player_bytes(st.session_state["wav_bytes"])
160
  if st.button("🧹 クリアして新しく録音", use_container_width=True):
161
+ for k in ["wav_bytes", "recs", "emo_label", "scores"]: st.session_state[k] = None
162
  st.session_state["rec_key"] += 1
163
  st.rerun()
164
 
165
  with tab_upload:
166
+ up = st.file_uploader("WAV/MP3/M4A を選択", type=["wav", "mp3", "m4a"])
167
  if up:
168
  st.session_state["wav_bytes"] = up.read()
169
  audio_player_bytes(st.session_state["wav_bytes"])
170
 
171
  # ---- 2) 同意 ----
172
  st.subheader("2) 同意")
173
+ consent = st.radio("研究利用の同意", ["保存しない(体験のみ)", "匿名で保存する"], horizontal=True)
 
174
 
175
  # ---- 推定 & レコメンド実行 ----
176
+ if st.button("🔍 AIで推定 & レコメンド", type="primary", use_container_width=True, disabled=(st.session_state["wav_bytes"] is None)):
177
+ with st.spinner('AIが感情を分析中...🤖'):
178
+ raw_bytes = st.session_state["wav_bytes"]
179
+ emo_label, all_scores = predict_emotion(raw_bytes)
180
+
 
 
 
 
 
181
  st.session_state.update({
182
+ "emo_label": emo_label,
183
+ "scores": all_scores,
184
+ "recs": score_places_by_ai(emo_label)
185
  })
 
186
  st.success("分析が完了しました!")
187
 
188
  # ---- 結果表示 ----
189
+ if st.session_state.get("recs"):
190
+ emo_label = st.session_state["emo_label"]
191
+ scores = st.session_state["scores"]
192
+ recs = st.session_state["recs"]
193
 
194
  st.subheader("分析結果")
195
+ col1, col2 = st.columns([0.6, 0.4])
196
+ with col1:
197
+ st.success(f"**AIの推定感情: {emo_label}**")
198
+ st.write("感情スコアの詳細:")
199
+ st.bar_chart(scores)
200
+ with col2:
201
+ st.write("この感情におすすめの場所:")
202
+ if recs:
203
+ st.image(recs[0]["image"], use_container_width=True)
204
+ st.markdown(f"**{recs[0]['name']}**")
205
+ st.caption(f"タグ: {', '.join(recs[0]['tags'])}")
 
 
206
 
207
  st.subheader("3) おすすめ(上位4件)")
208
  cols = st.columns(4)
 
222
  comment = st.text_input("ひとことコメント(任意・20字)", max_chars=20)
223
 
224
  if st.form_submit_button("💾 ログ保存", use_container_width=True):
225
+ st.info("ログ保存機能は現在開発中です。")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
+ # ---- フッター ----
228
  st.divider()
229
+ if st.button("▶ 次の人を試す(状態をクリア)", use_container_width=True):
230
+ for k in ["wav_bytes", "recs", "emo_label", "scores"]:
231
+ if st.session_state.get(k): st.session_state[k] = None
 
232
  st.session_state["rec_key"] += 1
233
+ st.rerun()