Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,92 +1,26 @@
|
|
| 1 |
# =========================
|
| 2 |
-
#
|
| 3 |
# =========================
|
| 4 |
import os
|
| 5 |
-
import tempfile
|
| 6 |
-
import warnings
|
| 7 |
-
import logging
|
| 8 |
import io
|
| 9 |
import uuid
|
| 10 |
import datetime as dt
|
| 11 |
import csv
|
| 12 |
import base64
|
| 13 |
-
import json
|
| 14 |
import random
|
|
|
|
| 15 |
|
| 16 |
-
# ---
|
| 17 |
-
logging.getLogger('matplotlib.font_manager').setLevel(logging.ERROR)
|
| 18 |
-
logging.getLogger('matplotlib').setLevel(logging.ERROR)
|
| 19 |
warnings.filterwarnings('ignore')
|
| 20 |
|
| 21 |
-
# --- 権限/キャッシュ対策 ---
|
| 22 |
-
os.environ["STREAMLIT_BROWSER_GATHERUSAGESTATS"] = "false"
|
| 23 |
-
os.environ["NUMBA_DISABLE_JIT"] = "1"
|
| 24 |
-
os.environ["NUMBA_CACHE_DIR"] = "/tmp/numba_cache"
|
| 25 |
-
|
| 26 |
-
# --- Matplotlibのバックエンド設定 ---
|
| 27 |
-
mpl_config_dir = tempfile.mkdtemp()
|
| 28 |
-
os.environ["MPLCONFIGDIR"] = mpl_config_dir
|
| 29 |
-
matplotlibrc_path = os.path.join(mpl_config_dir, 'matplotlibrc')
|
| 30 |
-
with open(matplotlibrc_path, 'w') as f:
|
| 31 |
-
f.write("""
|
| 32 |
-
backend: Agg
|
| 33 |
-
font.family: sans-serif
|
| 34 |
-
axes.unicode_minus: False
|
| 35 |
-
""")
|
| 36 |
-
|
| 37 |
# --- ライブラリのインポート ---
|
| 38 |
-
import matplotlib
|
| 39 |
-
matplotlib.use('Agg')
|
| 40 |
-
import matplotlib.pyplot as plt
|
| 41 |
-
import japanize_matplotlib # 日本語化ライブラリ
|
| 42 |
-
|
| 43 |
import numpy as np
|
| 44 |
import soundfile as sf
|
| 45 |
import streamlit as st
|
| 46 |
from audiorecorder import audiorecorder
|
| 47 |
from pydub import AudioSegment
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
plt.ioff()
|
| 51 |
-
try:
|
| 52 |
-
os.makedirs("/tmp/numba_cache", exist_ok=True)
|
| 53 |
-
except:
|
| 54 |
-
pass
|
| 55 |
-
|
| 56 |
-
# Matplotlibのマイナス記号の文字化け対策
|
| 57 |
-
rcParams["axes.unicode_minus"] = False
|
| 58 |
-
|
| 59 |
-
# =========================
|
| 60 |
-
# アプリケーション設定値
|
| 61 |
-
# =========================
|
| 62 |
-
class AppConfig:
|
| 63 |
-
"""アプリ内の閾値や係数を管理するクラス"""
|
| 64 |
-
# A/V推定の閾値
|
| 65 |
-
VALENCE_DEAD_ZONE = 0.12
|
| 66 |
-
AROUSAL_DEAD_ZONE = 0.12
|
| 67 |
-
JOY_VALENCE_THRESHOLD = 0.22
|
| 68 |
-
JOY_AROUSAL_THRESHOLD = 0.22
|
| 69 |
-
HIGH_AROUSAL_NEG_VALENCE_THRESHOLD = 0.10
|
| 70 |
-
HIGH_AROUSAL_THRESHOLD = 0.30
|
| 71 |
-
SURPRISE_AROUSAL_THRESHOLD = 0.22
|
| 72 |
-
|
| 73 |
-
# A/V計算の係数
|
| 74 |
-
ENERGY_TO_AROUSAL_COEF = 160.0
|
| 75 |
-
ZCR_TO_AROUSAL_COEF = 4.0
|
| 76 |
-
F0_TO_VALENCE_OFFSET = 170.0
|
| 77 |
-
F0_TO_VALENCE_SCALE = 120.0
|
| 78 |
-
ENERGY_TO_VALENCE_COEF = 15.0
|
| 79 |
-
|
| 80 |
-
# refine_labelの閾値
|
| 81 |
-
ANGER_AROUSAL = 0.35
|
| 82 |
-
ANGER_SC = 1500
|
| 83 |
-
ANGER_ZCR = 0.05
|
| 84 |
-
ANGER_ENERGY = 0.015
|
| 85 |
-
SADNESS_AROUSAL = 0.18
|
| 86 |
-
SADNESS_ENERGY = 0.012
|
| 87 |
-
SADNESS_SC = 800
|
| 88 |
-
TENSION_AROUSAL = 0.30
|
| 89 |
-
TENSION_ZCR = 0.06
|
| 90 |
|
| 91 |
# =========================
|
| 92 |
# 架空の場所データ
|
|
@@ -108,91 +42,79 @@ PLACES = [
|
|
| 108 |
{"place_id":"urban_track", "name":"アーバントラック", "tags":["身体活動","発散","屋外"], "emo_key":"release", "image":"images/urban_track.png"},
|
| 109 |
]
|
| 110 |
REASON_TAGS = ["静けさ","緑","水辺","発散","創作","交流","体験","学習","屋内","屋外","没入","回復"]
|
| 111 |
-
EMO_MAP_PRIORS = {
|
| 112 |
-
"joy": ["joy","surprise"], "calm": ["calm","joy"], "surprise": ["surprise","joy"],
|
| 113 |
-
"arousal_high_neg": ["release","surprise"], "neutral": ["calm","joy","surprise"],
|
| 114 |
-
"anger": ["release","surprise"], "sadness": ["calm","joy"], "tension": ["calm","surprise"],
|
| 115 |
-
}
|
| 116 |
|
| 117 |
# =========================
|
| 118 |
-
#
|
| 119 |
# =========================
|
| 120 |
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
|
| 130 |
-
|
| 131 |
-
""
|
| 132 |
-
abs_y = np.abs(y)
|
| 133 |
-
thr = 0.01 * (abs_y.max() + 1e-9)
|
| 134 |
-
idx = np.where(abs_y > thr)[0]
|
| 135 |
-
if idx.size >= 2: y = y[idx[0]:idx[-1]+1]
|
| 136 |
|
| 137 |
-
|
|
|
|
|
|
|
| 138 |
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
mag = np.abs(spec) + 1e-12
|
| 143 |
-
freqs = np.fft.rfftfreq(len(y * win), d=1.0/sr)
|
| 144 |
-
sc_mean = float((freqs * mag).sum() / mag.sum())
|
| 145 |
|
| 146 |
-
|
| 147 |
-
|
|
|
|
| 148 |
|
| 149 |
-
|
| 150 |
-
if len(y) < int(sr / fmin) + 2:
|
| 151 |
-
f0_est = 0.0
|
| 152 |
-
else:
|
| 153 |
-
corr = np.correlate(y, y, mode='full')[len(y)-1:]
|
| 154 |
-
lmin, lmax = max(1, int(sr / fmax)), min(len(corr) - 1, int(sr / fmin))
|
| 155 |
-
seg = corr[lmin:lmax] if lmax > lmin else np.array([])
|
| 156 |
-
f0_est = float(sr / (lmin + np.argmax(seg))) if seg.size > 0 and (lmin + np.argmax(seg)) > 0 else 0.0
|
| 157 |
-
|
| 158 |
-
return {"f0_mean": f0_est, "energy_mean": energy_mean, "spec_centroid": sc_mean, "zcr_mean": zcr_mean, "duration": len(y)/sr}
|
| 159 |
-
|
| 160 |
-
def av_from_features(feat):
|
| 161 |
-
"""特徴量からArousal/Valenceを推定"""
|
| 162 |
-
f0, en, z = feat["f0_mean"], feat["energy_mean"], feat["zcr_mean"]
|
| 163 |
-
arousal = float(np.tanh(AppConfig.ENERGY_TO_AROUSAL_COEF * en + AppConfig.ZCR_TO_AROUSAL_COEF * z))
|
| 164 |
-
valence_term = ((f0 - AppConfig.F0_TO_VALENCE_OFFSET) / AppConfig.F0_TO_VALENCE_SCALE if AppConfig.F0_TO_VALENCE_SCALE != 0 else 0)
|
| 165 |
-
valence = float(np.tanh(valence_term + AppConfig.ENERGY_TO_VALENCE_COEF * en))
|
| 166 |
-
return arousal, valence
|
| 167 |
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
a = 0.0 if arousal < AppConfig.AROUSAL_DEAD_ZONE else arousal
|
| 172 |
-
if v >= AppConfig.JOY_VALENCE_THRESHOLD and a >= AppConfig.JOY_AROUSAL_THRESHOLD: return "joy"
|
| 173 |
-
if v >= AppConfig.JOY_VALENCE_THRESHOLD and a < AppConfig.JOY_AROUSAL_THRESHOLD: return "calm"
|
| 174 |
-
if v < AppConfig.HIGH_AROUSAL_NEG_VALENCE_THRESHOLD and a >= AppConfig.HIGH_AROUSAL_THRESHOLD: return "arousal_high_neg"
|
| 175 |
-
if a >= AppConfig.SURPRISE_AROUSAL_THRESHOLD: return "surprise"
|
| 176 |
-
return "neutral"
|
| 177 |
|
| 178 |
-
def
|
| 179 |
-
"""
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
|
| 193 |
-
def score_places(emo_label, top_k=8, show_k=4, diversity=True):
|
| 194 |
-
"""感情ラベルに基づいて場所をスコアリングし、推薦リストを生成"""
|
| 195 |
-
priors = EMO_MAP_PRIORS.get(emo_label, ["calm", "joy", "surprise"])
|
| 196 |
scored = []
|
| 197 |
for p in PLACES:
|
| 198 |
base = 0.5
|
|
@@ -201,112 +123,30 @@ def score_places(emo_label, top_k=8, show_k=4, diversity=True):
|
|
| 201 |
scored.append((base + random.uniform(-0.02, 0.02), p))
|
| 202 |
|
| 203 |
scored.sort(key=lambda x: x[0], reverse=True)
|
| 204 |
-
candidates = [p for _, p in scored[:max(top_k, show_k)]]
|
| 205 |
-
if not diversity: return candidates[:show_k]
|
| 206 |
|
|
|
|
|
|
|
| 207 |
picked, seen = [], set()
|
| 208 |
for p in candidates:
|
| 209 |
if p["emo_key"] not in seen:
|
| 210 |
-
picked.append(p)
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
if len(picked) < show_k:
|
| 215 |
for p in candidates:
|
| 216 |
if p not in picked: picked.append(p)
|
| 217 |
-
if len(picked) >=
|
| 218 |
return picked
|
| 219 |
|
| 220 |
-
def ensure_logs_path():
|
| 221 |
-
"""ログファイルのパスを返し、なければヘッダーを書き込む"""
|
| 222 |
-
path_dir = "/tmp/logs"
|
| 223 |
-
os.makedirs(path_dir, exist_ok=True)
|
| 224 |
-
path = os.path.join(path_dir, "oc_sessions.csv")
|
| 225 |
-
if not os.path.exists(path):
|
| 226 |
-
with open(path, "w", newline="", encoding="utf-8") as f:
|
| 227 |
-
csv.writer(f).writerow(["session_id","ts","consent_research","save_audio","f0_mean","energy_mean","spec_centroid","zcr_mean","duration","arousal","valence","emo_label","exposed_ids","choice_id","rating_like","rating_vibe","reason_tags","comment"])
|
| 228 |
-
return path
|
| 229 |
-
|
| 230 |
-
def append_log(row_dict):
|
| 231 |
-
"""ログファイルに1行追記"""
|
| 232 |
-
path = ensure_logs_path()
|
| 233 |
-
header = ["session_id","ts","consent_research","save_audio","f0_mean","energy_mean","spec_centroid","zcr_mean","duration","arousal","valence","emo_label","exposed_ids","choice_id","rating_like","rating_vibe","reason_tags","comment"]
|
| 234 |
-
row_values = []
|
| 235 |
-
for key in header:
|
| 236 |
-
if key == "exposed_ids": row_values.append(",".join(row_dict.get(key, [])))
|
| 237 |
-
elif key == "reason_tags": row_values.append("|".join(row_dict.get(key, [])))
|
| 238 |
-
else: row_values.append(row_dict.get(key, ""))
|
| 239 |
-
with open(path, "a", newline="", encoding="utf-8") as f:
|
| 240 |
-
csv.writer(f).writerow(row_values)
|
| 241 |
-
|
| 242 |
-
def to_wav_bytes(any_bytes: bytes, target_sr=16000, mono=True) -> bytes:
|
| 243 |
-
"""様々な形式の音声をWAV形式のbytesに変換"""
|
| 244 |
-
if not any_bytes: st.error("音声が空です。"); st.stop()
|
| 245 |
-
try:
|
| 246 |
-
seg = AudioSegment.from_file(io.BytesIO(any_bytes))
|
| 247 |
-
seg = seg.set_channels(1) if mono else seg
|
| 248 |
-
seg = seg.set_frame_rate(target_sr) if target_sr else seg
|
| 249 |
-
buf = io.BytesIO()
|
| 250 |
-
seg.export(buf, format="wav")
|
| 251 |
-
return buf.getvalue()
|
| 252 |
-
except Exception as e:
|
| 253 |
-
st.error(f"音声ファイルを処理できませんでした: {e}"); st.stop()
|
| 254 |
-
|
| 255 |
-
def plot_av_map(points, current=None, size=(6.5, 6.5), dpi=200):
|
| 256 |
-
"""Arousal/Valenceマップを描画"""
|
| 257 |
-
fig, ax = plt.subplots(figsize=size, dpi=dpi)
|
| 258 |
-
fig.patch.set_facecolor('#FFFFFF')
|
| 259 |
-
ax.set_facecolor('#FAFBFC')
|
| 260 |
-
|
| 261 |
-
quads = [
|
| 262 |
-
((0, 0), (1, 1), "#FFE4E6", "#FF6B6B", "喜び・興奮", "Joy/Excitement"),
|
| 263 |
-
((-1, 0), (0, 1), "#FFF3CD", "#FFA94D", "緊張・怒り", "Tension/Anger"),
|
| 264 |
-
((-1, -1), (0, 0), "#E8EAED", "#868E96", "悲しみ・低覚醒", "Sadness/Low"),
|
| 265 |
-
((0, -1), (1, 0), "#D3F9D8", "#51CF66", "落ち着き・満足", "Calm/Content"),
|
| 266 |
-
]
|
| 267 |
-
for (x0, y0), (x1, y1), c_base, c_accent, label_jp, label_en in quads:
|
| 268 |
-
ax.fill_between([x0, x1], y0, y1, color=c_base, alpha=0.15, zorder=0)
|
| 269 |
-
cx, cy = (x0 + x1) / 2, (y0 + y1) / 2
|
| 270 |
-
bbox_props = dict(boxstyle="round,pad=0.3", facecolor='white', edgecolor='none', alpha=0.7)
|
| 271 |
-
ax.text(cx, cy + 0.08, label_jp, fontsize=13, fontweight='bold', ha="center", va="center", color="#2D3436", bbox=bbox_props)
|
| 272 |
-
ax.text(cx, cy - 0.08, label_en, fontsize=10, style='italic', ha="center", va="center", color="#636E72", alpha=0.8)
|
| 273 |
-
|
| 274 |
-
ax.add_artist(plt.Circle((0, 0), 1.0, fill=False, lw=2, color="#2D3436", alpha=0.8))
|
| 275 |
-
ax.grid(True, alpha=0.2, linestyle=':', color='#BDC3C7')
|
| 276 |
-
ax.axhline(0, color="#495057", lw=1.5, alpha=0.8); ax.axvline(0, color="#495057", lw=1.5, alpha=0.8)
|
| 277 |
-
|
| 278 |
-
ax.set_xlim(-1.15, 1.15); ax.set_ylim(-1.15, 1.15); ax.set_aspect("equal", adjustable="box")
|
| 279 |
-
ax.set_xlabel("Valence (価値感情)\n← ネガティブ ポジティブ →", fontsize=13, labelpad=10, color="#2D3436")
|
| 280 |
-
ax.set_ylabel("Arousal (覚醒度)\n↑ 高い\n\n↓ 低い", fontsize=13, labelpad=10, color="#2D3436")
|
| 281 |
-
ax.set_xticks([-1, -0.5, 0, 0.5, 1]); ax.set_yticks([-1, -0.5, 0, 0.5, 1])
|
| 282 |
-
ax.tick_params(labelsize=10, colors="#495057")
|
| 283 |
-
|
| 284 |
-
if points:
|
| 285 |
-
xs, ys, n = [p["v"] for p in points], [p["a"] for p in points], len(points)
|
| 286 |
-
for i in range(n):
|
| 287 |
-
alpha = 0.2 + 0.4 * (i / max(n - 1, 1))
|
| 288 |
-
ax.scatter(xs[i], ys[i], s=30, alpha=alpha, color="#3498DB", edgecolors="none", zorder=2)
|
| 289 |
-
|
| 290 |
-
if current:
|
| 291 |
-
v, a, lab = float(current["v"]), float(current["a"]), current.get("label", "現在")
|
| 292 |
-
ax.scatter([v], [a], s=120, facecolors="white", edgecolors="#FF6348", linewidths=3, zorder=5)
|
| 293 |
-
ax.scatter([v], [a], s=60, color="#FF6348", zorder=6)
|
| 294 |
-
bbox_props = dict(boxstyle="round,pad=0.5", facecolor='#FF6348', edgecolor='none', alpha=0.9)
|
| 295 |
-
ax.annotate(f" {lab} ", xy=(v, a), xytext=(v + 0.15, a + 0.15), fontsize=12, weight="bold", color="white", bbox=bbox_props, arrowprops=dict(arrowstyle="-|>", connectionstyle="arc3,rad=0.3", color="#FF6348", lw=2))
|
| 296 |
-
|
| 297 |
-
ax.set_title("感情分析マップ", fontsize=16, fontweight='bold', pad=20, color="#2D3436")
|
| 298 |
-
plt.tight_layout(pad=0.5)
|
| 299 |
-
return fig
|
| 300 |
-
|
| 301 |
# =========================
|
| 302 |
# Streamlit UI
|
| 303 |
# =========================
|
| 304 |
st.set_page_config(page_title="Voice→Place Recommender", page_icon="🎙️", layout="centered")
|
| 305 |
-
st.title("🎙️ 声の感情で『架空の場所』をレコメンド")
|
| 306 |
-
st.caption("
|
| 307 |
|
| 308 |
# ---- Session state 初期化 ----
|
| 309 |
-
for key, default in [("wav_bytes", None), ("recs", None), ("
|
| 310 |
if key not in st.session_state: st.session_state[key] = default
|
| 311 |
|
| 312 |
# ---- 1) 録音 / アップロード ----
|
|
@@ -318,57 +158,51 @@ with tab_rec:
|
|
| 318 |
st.session_state["wav_bytes"] = audio.export().read()
|
| 319 |
audio_player_bytes(st.session_state["wav_bytes"])
|
| 320 |
if st.button("🧹 クリアして新しく録音", use_container_width=True):
|
| 321 |
-
for k in ["wav_bytes","recs","
|
| 322 |
st.session_state["rec_key"] += 1
|
| 323 |
st.rerun()
|
| 324 |
|
| 325 |
with tab_upload:
|
| 326 |
-
up = st.file_uploader("WAV/MP3/M4A を選択", type=["wav","mp3","m4a"])
|
| 327 |
if up:
|
| 328 |
st.session_state["wav_bytes"] = up.read()
|
| 329 |
audio_player_bytes(st.session_state["wav_bytes"])
|
| 330 |
|
| 331 |
# ---- 2) 同意 ----
|
| 332 |
st.subheader("2) 同意")
|
| 333 |
-
consent = st.radio("
|
| 334 |
-
save_audio = st.checkbox("音声ファイルも保存する(任意)", value=False)
|
| 335 |
|
| 336 |
# ---- 推定 & レコメンド実行 ----
|
| 337 |
-
if st.button("🔍
|
| 338 |
-
with st.spinner('
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
feat = extract_features(y, sr)
|
| 344 |
-
arousal, valence = av_from_features(feat)
|
| 345 |
-
emo_label = refine_label(arousal, valence, feat)
|
| 346 |
-
|
| 347 |
st.session_state.update({
|
| 348 |
-
"
|
| 349 |
-
"
|
|
|
|
| 350 |
})
|
| 351 |
-
st.session_state["av_hist"].append({"a": arousal, "v": valence, "label": emo_label})
|
| 352 |
st.success("分析が完了しました!")
|
| 353 |
|
| 354 |
# ---- 結果表示 ----
|
| 355 |
-
if st.session_state
|
| 356 |
-
|
|
|
|
|
|
|
| 357 |
|
| 358 |
st.subheader("分析結果")
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
with st.expander("詳細な特徴量"):
|
| 371 |
-
st.json({k: f"{v:.3f}" if isinstance(v, float) else v for k, v in feat.items()})
|
| 372 |
|
| 373 |
st.subheader("3) おすすめ(上位4件)")
|
| 374 |
cols = st.columns(4)
|
|
@@ -388,33 +222,12 @@ if st.session_state["recs"]:
|
|
| 388 |
comment = st.text_input("ひとことコメント(任意・20字)", max_chars=20)
|
| 389 |
|
| 390 |
if st.form_submit_button("💾 ログ保存", use_container_width=True):
|
| 391 |
-
|
| 392 |
-
st.info("体験のみモードのため、ログは保存しません。")
|
| 393 |
-
else:
|
| 394 |
-
choice_id = next((p["place_id"] for p in recs if p["name"] == choice_name), None)
|
| 395 |
-
row = {
|
| 396 |
-
"session_id": f"oc-{uuid.uuid4().hex[:8]}", "ts": dt.datetime.now().isoformat(timespec="seconds"),
|
| 397 |
-
"consent_research": True, "save_audio": save_audio, **feat,
|
| 398 |
-
"arousal": arousal, "valence": valence, "emo_label": emo_label,
|
| 399 |
-
"exposed_ids": [p["place_id"] for p in recs[:4]], "choice_id": choice_id,
|
| 400 |
-
"rating_like": rating_like, "rating_vibe": rating_vibe, "reason_tags": reasons, "comment": comment,
|
| 401 |
-
}
|
| 402 |
-
append_log(row)
|
| 403 |
-
if save_audio:
|
| 404 |
-
out_path = os.path.join("/tmp/logs", f"{row['session_id']}.wav")
|
| 405 |
-
with open(out_path, "wb") as f: f.write(st.session_state["wav_bytes"])
|
| 406 |
-
st.success("ログを保存しました!ご協力ありがとうございます。")
|
| 407 |
|
| 408 |
-
# ----
|
| 409 |
st.divider()
|
| 410 |
-
if st.button("▶
|
| 411 |
-
for k in
|
| 412 |
-
if k
|
| 413 |
-
del st.session_state[k]
|
| 414 |
st.session_state["rec_key"] += 1
|
| 415 |
-
st.rerun()
|
| 416 |
-
|
| 417 |
-
csv_path = ensure_logs_path()
|
| 418 |
-
if os.path.exists(csv_path) and os.path.getsize(csv_path) > 0:
|
| 419 |
-
with open(csv_path, "rb") as f:
|
| 420 |
-
st.download_button("🔻 これまでの評価ログをダウンロード", f, file_name="oc_sessions.csv", mime="text/csv")
|
|
|
|
| 1 |
# =========================
|
| 2 |
+
# app.py (AIモデル搭載版)
|
| 3 |
# =========================
|
| 4 |
import os
|
|
|
|
|
|
|
|
|
|
| 5 |
import io
|
| 6 |
import uuid
|
| 7 |
import datetime as dt
|
| 8 |
import csv
|
| 9 |
import base64
|
|
|
|
| 10 |
import random
|
| 11 |
+
import warnings
|
| 12 |
|
| 13 |
+
# --- 警告の抑制 ---
|
|
|
|
|
|
|
| 14 |
warnings.filterwarnings('ignore')
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
# --- ライブラリのインポート ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
import numpy as np
|
| 18 |
import soundfile as sf
|
| 19 |
import streamlit as st
|
| 20 |
from audiorecorder import audiorecorder
|
| 21 |
from pydub import AudioSegment
|
| 22 |
+
import torch
|
| 23 |
+
from transformers import AutoModelForAudioClassification, AutoFeatureExtractor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
# =========================
|
| 26 |
# 架空の場所データ
|
|
|
|
| 42 |
{"place_id":"urban_track", "name":"アーバントラック", "tags":["身体活動","発散","屋外"], "emo_key":"release", "image":"images/urban_track.png"},
|
| 43 |
]
|
| 44 |
REASON_TAGS = ["静けさ","緑","水辺","発散","創作","交流","体験","学習","屋内","屋外","没入","回復"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
# =========================
|
| 47 |
+
# AIモデル関連の関数
|
| 48 |
# =========================
|
| 49 |
|
| 50 |
+
@st.cache_resource
def load_model():
    """Load the audio-emotion classifier once and cache it via Streamlit.

    Returns:
        (feature_extractor, model): the Hugging Face preprocessing object and
        the audio-classification model, shared across reruns and sessions
        thanks to ``st.cache_resource``.
    """
    repo_id = "Mizuiro-inc/emotion2vec-base-japanese"
    # Pull both halves of the pipeline from the same checkpoint so the
    # preprocessing always matches the model weights.
    extractor = AutoFeatureExtractor.from_pretrained(repo_id)
    classifier = AutoModelForAudioClassification.from_pretrained(repo_id)
    return extractor, classifier
|
| 57 |
+
|
| 58 |
+
def predict_emotion(audio_bytes):
    """Predict the speaker's emotion from raw audio bytes with the cached AI model.

    Args:
        audio_bytes: encoded audio in any format pydub can decode.

    Returns:
        (predicted_label, all_scores): the most likely emotion label (str)
        and a dict mapping every label to its softmax probability.
    """
    extractor, classifier = load_model()

    # Normalize the input to 16 kHz mono WAV, then decode to float32 samples.
    wav16 = to_wav_bytes(audio_bytes, target_sr=16000)
    samples, rate = sf.read(io.BytesIO(wav16), dtype="float32")

    # Turn the waveform into model-ready PyTorch tensors.
    batch = extractor(samples, sampling_rate=rate, return_tensors="pt", padding=True)

    # Inference only — no gradients needed.
    with torch.no_grad():
        logits = classifier(**batch).logits

    id2label = classifier.config.id2label
    top_idx = torch.argmax(logits, dim=-1).item()

    # Expose the full probability distribution as well (used for display).
    probs = torch.softmax(logits, dim=-1)[0]
    scores = {id2label[i]: float(p) for i, p in enumerate(probs)}

    return id2label[top_idx], scores
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
|
| 83 |
+
# =========================
|
| 84 |
+
# 汎用関数
|
| 85 |
+
# =========================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
|
| 87 |
+
def to_wav_bytes(any_bytes: bytes, target_sr=16000, mono=True) -> bytes:
    """Convert audio bytes in any pydub-readable format to WAV bytes.

    Args:
        any_bytes: raw encoded audio (wav/mp3/m4a/...).
        target_sr: resample rate in Hz; a falsy value keeps the original rate.
        mono: downmix to a single channel when True.

    Returns:
        The re-encoded audio as WAV bytes. On empty or undecodable input the
        Streamlit script run is halted with an error message instead.
    """
    if not any_bytes:
        st.error("音声が空です。")
        st.stop()
    try:
        segment = AudioSegment.from_file(io.BytesIO(any_bytes))
        if mono:
            segment = segment.set_channels(1)
        if target_sr:
            segment = segment.set_frame_rate(target_sr)
        out = io.BytesIO()
        segment.export(out, format="wav")
        return out.getvalue()
    except Exception as e:
        # Surface the decode/convert failure to the user and stop this run.
        st.error(f"音声ファイルを処理できませんでした: {e}")
        st.stop()
|
| 99 |
|
| 100 |
+
def audio_player_bytes(b: bytes, mime="audio/wav"):
    """Render an inline HTML <audio> player for raw audio bytes.

    Args:
        b: encoded audio bytes; nothing is rendered when empty.
        mime: MIME type advertised to the browser player.
    """
    if not b:
        return
    # Embed the audio as a base64 data URI so no separate file must be served.
    encoded = base64.b64encode(b).decode("utf-8")
    tag = f'<audio controls preload="metadata" style="width:100%"><source src="data:{mime};base64,{encoded}" type="{mime}"></audio>'
    st.markdown(tag, unsafe_allow_html=True)
|
| 105 |
+
|
| 106 |
+
# AIの予測結果を場所の推薦に繋げるための新しい関数
|
| 107 |
+
def score_places_by_ai(emo_label, top_k=4):
|
| 108 |
+
"""AIの感情ラベルに基づいて場所を推薦する"""
|
| 109 |
+
# AIのラベルと場所のカテゴリを対応付ける
|
| 110 |
+
label_to_emo_key = {
|
| 111 |
+
'happy': ['joy', 'surprise'],
|
| 112 |
+
'sad': ['calm', 'joy'],
|
| 113 |
+
'angry': ['release', 'calm'],
|
| 114 |
+
'neutral': ['calm', 'surprise', 'joy']
|
| 115 |
+
}
|
| 116 |
+
priors = label_to_emo_key.get(emo_label, ['calm', 'joy']) # 未知のラベルはcalm/joyに
|
| 117 |
|
|
|
|
|
|
|
|
|
|
| 118 |
scored = []
|
| 119 |
for p in PLACES:
|
| 120 |
base = 0.5
|
|
|
|
| 123 |
scored.append((base + random.uniform(-0.02, 0.02), p))
|
| 124 |
|
| 125 |
scored.sort(key=lambda x: x[0], reverse=True)
|
|
|
|
|
|
|
| 126 |
|
| 127 |
+
# 多様性を確保するロジック
|
| 128 |
+
candidates = [p for _, p in scored]
|
| 129 |
picked, seen = [], set()
|
| 130 |
for p in candidates:
|
| 131 |
if p["emo_key"] not in seen:
|
| 132 |
+
picked.append(p)
|
| 133 |
+
seen.add(p["emo_key"])
|
| 134 |
+
if len(picked) >= top_k: break
|
| 135 |
+
if len(picked) < top_k:
|
|
|
|
| 136 |
for p in candidates:
|
| 137 |
if p not in picked: picked.append(p)
|
| 138 |
+
if len(picked) >= top_k: break
|
| 139 |
return picked
|
| 140 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
# =========================
|
| 142 |
# Streamlit UI
|
| 143 |
# =========================
|
| 144 |
st.set_page_config(page_title="Voice→Place Recommender", page_icon="🎙️", layout="centered")
|
| 145 |
+
st.title("🎙️ 声の感情で『架空の場所』をレコメンド (AI版)")
|
| 146 |
+
st.caption("録音→AI感情推定→上位スポット→評価→CSV保存(匿名)")
|
| 147 |
|
| 148 |
# ---- Session state 初期化 ----
|
| 149 |
+
for key, default in [("wav_bytes", None), ("recs", None), ("emo_label", None), ("scores", None), ("rec_key", 0)]:
|
| 150 |
if key not in st.session_state: st.session_state[key] = default
|
| 151 |
|
| 152 |
# ---- 1) 録音 / アップロード ----
|
|
|
|
| 158 |
st.session_state["wav_bytes"] = audio.export().read()
|
| 159 |
audio_player_bytes(st.session_state["wav_bytes"])
|
| 160 |
if st.button("🧹 クリアして新しく録音", use_container_width=True):
|
| 161 |
+
for k in ["wav_bytes", "recs", "emo_label", "scores"]: st.session_state[k] = None
|
| 162 |
st.session_state["rec_key"] += 1
|
| 163 |
st.rerun()
|
| 164 |
|
| 165 |
with tab_upload:
|
| 166 |
+
up = st.file_uploader("WAV/MP3/M4A を選択", type=["wav", "mp3", "m4a"])
|
| 167 |
if up:
|
| 168 |
st.session_state["wav_bytes"] = up.read()
|
| 169 |
audio_player_bytes(st.session_state["wav_bytes"])
|
| 170 |
|
| 171 |
# ---- 2) 同意 ----
|
| 172 |
st.subheader("2) 同意")
|
| 173 |
+
consent = st.radio("研究利用の同意", ["保存しない(体験のみ)", "匿名で保存する"], horizontal=True)
|
|
|
|
| 174 |
|
| 175 |
# ---- 推定 & レコメンド実行 ----
|
| 176 |
+
if st.button("🔍 AIで推定 & レコメンド", type="primary", use_container_width=True, disabled=(st.session_state["wav_bytes"] is None)):
|
| 177 |
+
with st.spinner('AIが感情を分析中...🤖'):
|
| 178 |
+
raw_bytes = st.session_state["wav_bytes"]
|
| 179 |
+
emo_label, all_scores = predict_emotion(raw_bytes)
|
| 180 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
st.session_state.update({
|
| 182 |
+
"emo_label": emo_label,
|
| 183 |
+
"scores": all_scores,
|
| 184 |
+
"recs": score_places_by_ai(emo_label)
|
| 185 |
})
|
|
|
|
| 186 |
st.success("分析が完了しました!")
|
| 187 |
|
| 188 |
# ---- 結果表示 ----
|
| 189 |
+
if st.session_state.get("recs"):
|
| 190 |
+
emo_label = st.session_state["emo_label"]
|
| 191 |
+
scores = st.session_state["scores"]
|
| 192 |
+
recs = st.session_state["recs"]
|
| 193 |
|
| 194 |
st.subheader("分析結果")
|
| 195 |
+
col1, col2 = st.columns([0.6, 0.4])
|
| 196 |
+
with col1:
|
| 197 |
+
st.success(f"**AIの推定感情: {emo_label}**")
|
| 198 |
+
st.write("感情スコアの詳細:")
|
| 199 |
+
st.bar_chart(scores)
|
| 200 |
+
with col2:
|
| 201 |
+
st.write("この感情におすすめの場所:")
|
| 202 |
+
if recs:
|
| 203 |
+
st.image(recs[0]["image"], use_container_width=True)
|
| 204 |
+
st.markdown(f"**{recs[0]['name']}**")
|
| 205 |
+
st.caption(f"タグ: {', '.join(recs[0]['tags'])}")
|
|
|
|
|
|
|
| 206 |
|
| 207 |
st.subheader("3) おすすめ(上位4件)")
|
| 208 |
cols = st.columns(4)
|
|
|
|
| 222 |
comment = st.text_input("ひとことコメント(任意・20字)", max_chars=20)
|
| 223 |
|
| 224 |
if st.form_submit_button("💾 ログ保存", use_container_width=True):
|
| 225 |
+
st.info("ログ保存機能は現在開発中です。")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
|
| 227 |
+
# ---- フッター ----
|
| 228 |
st.divider()
|
| 229 |
+
if st.button("▶ 次の人を試す(状態をクリア)", use_container_width=True):
|
| 230 |
+
for k in ["wav_bytes", "recs", "emo_label", "scores"]:
|
| 231 |
+
if st.session_state.get(k): st.session_state[k] = None
|
|
|
|
| 232 |
st.session_state["rec_key"] += 1
|
| 233 |
+
st.rerun()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|