Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,11 +1,13 @@
|
|
|
|
|
| 1 |
"""
|
| 2 |
Voice→Place Recommender (Streamlit / Hugging Face Spaces)
|
| 3 |
-
-
|
| 4 |
-
-
|
|
|
|
| 5 |
"""
|
| 6 |
|
| 7 |
# ===== 基本インポート =====
|
| 8 |
-
import io,
|
| 9 |
import numpy as np
|
| 10 |
import soundfile as sf
|
| 11 |
from pydub import AudioSegment
|
|
@@ -13,14 +15,18 @@ from pydub import AudioSegment
|
|
| 13 |
import streamlit as st
|
| 14 |
from audiorecorder import audiorecorder
|
| 15 |
|
|
|
|
| 16 |
import matplotlib
|
| 17 |
matplotlib.use('Agg')
|
| 18 |
import matplotlib.pyplot as plt
|
| 19 |
from matplotlib import rcParams
|
| 20 |
import japanize_matplotlib
|
| 21 |
|
|
|
|
| 22 |
import torch
|
| 23 |
-
|
|
|
|
|
|
|
| 24 |
|
| 25 |
# ===== フォント設定 =====
|
| 26 |
rcParams["font.family"] = "DejaVu Sans"
|
|
@@ -45,14 +51,33 @@ PLACES = [
|
|
| 45 |
]
|
| 46 |
REASON_TAGS = ["静けさ","緑","水辺","発散","創作","交流","体験","学習","屋内","屋外","没入","回復"]
|
| 47 |
|
| 48 |
-
# =====
|
| 49 |
-
|
| 50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
@st.cache_resource(show_spinner=False)
|
| 52 |
-
def
|
| 53 |
"""
|
| 54 |
-
|
| 55 |
-
|
| 56 |
"""
|
| 57 |
token = os.getenv("HF_TOKEN")
|
| 58 |
if not token:
|
|
@@ -60,17 +85,71 @@ def load_model():
|
|
| 60 |
|
| 61 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 62 |
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
)
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
# ===== ユーティリティ =====
|
| 76 |
def to_wav_bytes(any_bytes: bytes, target_sr=16000, mono=True) -> bytes:
|
|
@@ -92,6 +171,7 @@ def to_wav_bytes(any_bytes: bytes, target_sr=16000, mono=True) -> bytes:
|
|
| 92 |
return buf.getvalue()
|
| 93 |
|
| 94 |
def audio_player_bytes(b: bytes, mime="audio/wav"):
|
|
|
|
| 95 |
if not b:
|
| 96 |
return
|
| 97 |
b64 = base64.b64encode(b).decode("utf-8")
|
|
@@ -104,8 +184,9 @@ def audio_player_bytes(b: bytes, mime="audio/wav"):
|
|
| 104 |
unsafe_allow_html=True,
|
| 105 |
)
|
| 106 |
|
| 107 |
-
# ===== フォールバック用簡易特徴量 =====
|
| 108 |
def extract_features(y, sr):
|
|
|
|
| 109 |
abs_y = np.abs(y)
|
| 110 |
thr = 0.01 * (abs_y.max() + 1e-9)
|
| 111 |
idx = np.where(abs_y > thr)[0]
|
|
@@ -125,6 +206,7 @@ def extract_features(y, sr):
|
|
| 125 |
zc = (y[:-1] * y[1:] < 0).astype(np.float32)
|
| 126 |
zcr_mean = float(zc.mean()) if zc.size else 0.0
|
| 127 |
|
|
|
|
| 128 |
fmin, fmax = 80.0, 600.0
|
| 129 |
if len(y) < int(sr / fmin) + 2:
|
| 130 |
f0_est = 0.0
|
|
@@ -148,6 +230,7 @@ def extract_features(y, sr):
|
|
| 148 |
}
|
| 149 |
|
| 150 |
def predict_emotion_features(audio_bytes):
|
|
|
|
| 151 |
wav_bytes_16k = to_wav_bytes(audio_bytes, target_sr=16000)
|
| 152 |
y, sr = sf.read(io.BytesIO(wav_bytes_16k), dtype="float32")
|
| 153 |
feat = extract_features(y, sr)
|
|
@@ -172,23 +255,17 @@ def predict_emotion_features(audio_bytes):
|
|
| 172 |
scores["neutral"] += 0.3
|
| 173 |
return label, scores, "Features"
|
| 174 |
|
| 175 |
-
# ===== AI推定
|
| 176 |
-
def
|
| 177 |
-
"""
|
| 178 |
-
m = {
|
| 179 |
-
"happy": "happiness",
|
| 180 |
-
"happiness": "happiness",
|
| 181 |
-
"angry": "anger",
|
| 182 |
-
"anger": "anger",
|
| 183 |
-
"sad": "sadness",
|
| 184 |
-
"sadness": "sadness",
|
| 185 |
-
"neutral": "neutral"
|
| 186 |
-
}
|
| 187 |
return m.get(lbl.lower(), lbl)
|
| 188 |
|
| 189 |
def predict_emotion_ai(audio_bytes):
|
|
|
|
|
|
|
|
|
|
| 190 |
try:
|
| 191 |
-
|
| 192 |
except Exception as e:
|
| 193 |
st.error(f"モデルのロードに失敗しました: {e}")
|
| 194 |
st.info("音声特徴量ベースの分析に切り替えます。")
|
|
@@ -204,23 +281,32 @@ def predict_emotion_ai(audio_bytes):
|
|
| 204 |
y = y[:max_samples]
|
| 205 |
st.warning("音声が30秒を超えたため、最初の30秒のみを分析します。")
|
| 206 |
|
| 207 |
-
|
| 208 |
-
inputs = {k: v.to(device) for k, v in inputs.items()}
|
| 209 |
|
| 210 |
with torch.no_grad():
|
| 211 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
raw_label = model.config.id2label[pred_id]
|
| 216 |
-
label = normalize_label(raw_label)
|
| 217 |
|
| 218 |
-
|
| 219 |
-
|
|
|
|
|
|
|
| 220 |
for k in list(scores.keys()):
|
| 221 |
scores[k] = max(0.0, min(1.0, scores[k]))
|
| 222 |
-
|
| 223 |
-
return label, scores, "AI"
|
| 224 |
|
| 225 |
except Exception as e:
|
| 226 |
st.warning(f"AI予測中にエラーが発生: {e}")
|
|
@@ -330,9 +416,8 @@ def main():
|
|
| 330 |
if key not in st.session_state: st.session_state[key] = default
|
| 331 |
|
| 332 |
st.subheader("1) 録音またはアップロード")
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
st.markdown("**🎤 録音** → PC/スマホで直接話す or 端末で音声再生しながら録音")
|
| 336 |
|
| 337 |
tab_rec, tab_upload = st.tabs(["🎤 録音する(推奨)", "📁 ファイルを使う"])
|
| 338 |
|
|
|
|
| 1 |
+
# app.py
|
| 2 |
"""
|
| 3 |
Voice→Place Recommender (Streamlit / Hugging Face Spaces)
|
| 4 |
+
- 日本語音声感情認識:S3PRL(HuBERT base) + HFの下流(.pt) チェックポイントを用いた推論
|
| 5 |
+
- Spaces の Settings → Secrets に HF_TOKEN を設定してください
|
| 6 |
+
- ffmpeg が必要(apt.txtに ffmpeg を記載)
|
| 7 |
"""
|
| 8 |
|
| 9 |
# ===== 基本インポート =====
|
| 10 |
+
import io, json, base64, random, os
|
| 11 |
import numpy as np
|
| 12 |
import soundfile as sf
|
| 13 |
from pydub import AudioSegment
|
|
|
|
| 15 |
import streamlit as st
|
| 16 |
from audiorecorder import audiorecorder
|
| 17 |
|
| 18 |
+
# Matplotlib
|
| 19 |
import matplotlib
|
| 20 |
matplotlib.use('Agg')
|
| 21 |
import matplotlib.pyplot as plt
|
| 22 |
from matplotlib import rcParams
|
| 23 |
import japanize_matplotlib
|
| 24 |
|
| 25 |
+
# Torch / Hugging Face Hub / S3PRL
|
| 26 |
import torch
|
| 27 |
+
import torch.nn as nn
|
| 28 |
+
from huggingface_hub import list_repo_files, hf_hub_download
|
| 29 |
+
from s3prl.nn import S3PRLUpstream
|
| 30 |
|
| 31 |
# ===== フォント設定 =====
|
| 32 |
rcParams["font.family"] = "DejaVu Sans"
|
|
|
|
| 51 |
]
|
| 52 |
REASON_TAGS = ["静けさ","緑","水辺","発散","創作","交流","体験","学習","屋内","屋外","没入","回復"]
|
| 53 |
|
| 54 |
+
# ===== KUSHINADA 定義(HF の gated モデルのリポ)=====
|
| 55 |
+
KUSHINADA_REPO = "imprt/kushinada-hubert-base-jtes-er"
|
| 56 |
|
| 57 |
+
# ===== S3PRL 下流ヘッド(線形) =====
|
| 58 |
+
class SimpleLinearHead(nn.Module):
|
| 59 |
+
"""
|
| 60 |
+
チェックポイント中の線形分類器 (W, b) を復元する簡易ヘッド。
|
| 61 |
+
入力: [B, T, H] → mean-pool → [B, H] → Linear(H, C)
|
| 62 |
+
"""
|
| 63 |
+
def __init__(self, in_dim: int, num_classes: int, W: torch.Tensor, b: torch.Tensor):
|
| 64 |
+
super().__init__()
|
| 65 |
+
self.pool = lambda x: x.mean(dim=1) # 時系列平均
|
| 66 |
+
self.fc = nn.Linear(in_dim, num_classes)
|
| 67 |
+
with torch.no_grad():
|
| 68 |
+
self.fc.weight.copy_(W) # [C, H]
|
| 69 |
+
self.fc.bias.copy_(b) # [C]
|
| 70 |
+
|
| 71 |
+
def forward(self, reps): # reps: [B, T, H]
|
| 72 |
+
x = self.pool(reps)
|
| 73 |
+
return self.fc(x)
|
| 74 |
+
|
| 75 |
+
# ===== KUSHINADA (S3PRL) ローダ =====
|
| 76 |
@st.cache_resource(show_spinner=False)
|
| 77 |
+
def load_kushinada_s3prl():
|
| 78 |
"""
|
| 79 |
+
S3PRL上流(HuBERT base) + HFの下流(.pt)を自動取得して復元。
|
| 80 |
+
チェックポイント中から (weight,bias) を推定して線形ヘッドを構築。
|
| 81 |
"""
|
| 82 |
token = os.getenv("HF_TOKEN")
|
| 83 |
if not token:
|
|
|
|
| 85 |
|
| 86 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 87 |
|
| 88 |
+
# 1) S3PRL 上流:HuBERT base(kushinada はHuBERT系想定)
|
| 89 |
+
upstream = S3PRLUpstream("hubert_base").to(device).eval()
|
| 90 |
+
|
| 91 |
+
# 2) HFから .pt を探してダウンロード
|
| 92 |
+
files = list_repo_files(KUSHINADA_REPO, token=token)
|
| 93 |
+
pt_files = [f for f in files if f.endswith(".pt")]
|
| 94 |
+
if not pt_files:
|
| 95 |
+
raise FileNotFoundError("下流チェックポイント(.pt)が見つかりません。モデルページの Files を確認してください。")
|
| 96 |
+
|
| 97 |
+
# 最初の .pt を採用(必要なら固定のファイル名に変更)
|
| 98 |
+
ckpt_path = hf_hub_download(repo_id=KUSHINADA_REPO, filename=pt_files[0], token=token)
|
| 99 |
+
|
| 100 |
+
# 3) チェックポイント読込
|
| 101 |
+
ckpt = torch.load(ckpt_path, map_location="cpu")
|
| 102 |
+
|
| 103 |
+
# 4) state_dict から線形層の W, b を推定
|
| 104 |
+
state = None
|
| 105 |
+
if isinstance(ckpt, dict):
|
| 106 |
+
for key in ["state_dict", "Downstream", "model", "downstream", "net", "weights"]:
|
| 107 |
+
if key in ckpt and isinstance(ckpt[key], dict):
|
| 108 |
+
state = ckpt[key]
|
| 109 |
+
break
|
| 110 |
+
if state is None:
|
| 111 |
+
# そのままstate dictの可能性
|
| 112 |
+
# S3PRLのスクリプトにより出力形式は複数パターンありうる
|
| 113 |
+
state = ckpt
|
| 114 |
+
|
| 115 |
+
if not isinstance(state, dict):
|
| 116 |
+
raise RuntimeError("チェックポイント形式を解釈できませんでした。")
|
| 117 |
+
|
| 118 |
+
# W, b らしきテンソルを探索([C,H], [C] っぽい組を探す)
|
| 119 |
+
linear_W, linear_b = None, None
|
| 120 |
+
for k, v in state.items():
|
| 121 |
+
if isinstance(v, torch.Tensor) and v.ndim == 2:
|
| 122 |
+
base = k.rsplit(".", 1)[0] # 例: "classifier.fc.weight" → "classifier.fc"
|
| 123 |
+
bias_key = base + ".bias"
|
| 124 |
+
if bias_key in state and isinstance(state[bias_key], torch.Tensor) and state[bias_key].ndim == 1:
|
| 125 |
+
linear_W = v
|
| 126 |
+
linear_b = state[bias_key]
|
| 127 |
+
break
|
| 128 |
+
|
| 129 |
+
if linear_W is None:
|
| 130 |
+
# 次善策: "weight"と"bias"という名前のペアを総当たり
|
| 131 |
+
twos = [(k,v) for k,v in state.items() if isinstance(v, torch.Tensor) and v.ndim==2 and k.endswith("weight")]
|
| 132 |
+
for wk, w in twos:
|
| 133 |
+
bk = wk.replace("weight", "bias")
|
| 134 |
+
if bk in state and isinstance(state[bk], torch.Tensor) and state[bk].ndim == 1:
|
| 135 |
+
linear_W, linear_b = w, state[bk]
|
| 136 |
+
break
|
| 137 |
+
|
| 138 |
+
if linear_W is None:
|
| 139 |
+
raise RuntimeError("線形分類器の重みが見つかりません。S3PRLの公式手順に沿ったDownstream再現が必要です。")
|
| 140 |
|
| 141 |
+
num_classes, hidden_dim = linear_W.shape # [C, H]
|
| 142 |
+
head = SimpleLinearHead(in_dim=hidden_dim, num_classes=num_classes,
|
| 143 |
+
W=linear_W, b=linear_b).to(device).eval()
|
| 144 |
+
|
| 145 |
+
# JTES想定:4クラス(angry/happy/neutral/sad)※順序は環境/学習で異なる可能性あり
|
| 146 |
+
default_labels = ["angry", "happy", "neutral", "sad"]
|
| 147 |
+
if num_classes == 4:
|
| 148 |
+
id2label = {i: default_labels[i] for i in range(4)}
|
| 149 |
+
else:
|
| 150 |
+
id2label = {i: f"class_{i}" for i in range(num_classes)}
|
| 151 |
+
|
| 152 |
+
return upstream, head, id2label, device
|
| 153 |
|
| 154 |
# ===== ユーティリティ =====
|
| 155 |
def to_wav_bytes(any_bytes: bytes, target_sr=16000, mono=True) -> bytes:
|
|
|
|
| 171 |
return buf.getvalue()
|
| 172 |
|
| 173 |
def audio_player_bytes(b: bytes, mime="audio/wav"):
|
| 174 |
+
"""音声プレイヤーを表示"""
|
| 175 |
if not b:
|
| 176 |
return
|
| 177 |
b64 = base64.b64encode(b).decode("utf-8")
|
|
|
|
| 184 |
unsafe_allow_html=True,
|
| 185 |
)
|
| 186 |
|
| 187 |
+
# ===== フォールバック用:簡易特徴量ベース =====
|
| 188 |
def extract_features(y, sr):
|
| 189 |
+
"""音声から簡易特徴量を抽出"""
|
| 190 |
abs_y = np.abs(y)
|
| 191 |
thr = 0.01 * (abs_y.max() + 1e-9)
|
| 192 |
idx = np.where(abs_y > thr)[0]
|
|
|
|
| 206 |
zc = (y[:-1] * y[1:] < 0).astype(np.float32)
|
| 207 |
zcr_mean = float(zc.mean()) if zc.size else 0.0
|
| 208 |
|
| 209 |
+
# F0推定(非常に簡易)
|
| 210 |
fmin, fmax = 80.0, 600.0
|
| 211 |
if len(y) < int(sr / fmin) + 2:
|
| 212 |
f0_est = 0.0
|
|
|
|
| 230 |
}
|
| 231 |
|
| 232 |
def predict_emotion_features(audio_bytes):
|
| 233 |
+
"""音声特徴量から感情を推定(フォールバック)"""
|
| 234 |
wav_bytes_16k = to_wav_bytes(audio_bytes, target_sr=16000)
|
| 235 |
y, sr = sf.read(io.BytesIO(wav_bytes_16k), dtype="float32")
|
| 236 |
feat = extract_features(y, sr)
|
|
|
|
| 255 |
scores["neutral"] += 0.3
|
| 256 |
return label, scores, "Features"
|
| 257 |
|
| 258 |
+
# ===== AI推定(S3PRL)=====
|
| 259 |
+
def _normalize_label(lbl: str) -> str:
|
| 260 |
+
m = {"happy": "happiness", "angry": "anger", "sad": "sadness", "neutral": "neutral"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
return m.get(lbl.lower(), lbl)
|
| 262 |
|
| 263 |
def predict_emotion_ai(audio_bytes):
|
| 264 |
+
"""
|
| 265 |
+
S3PRL上流 + HF下流(.pt) で推論。
|
| 266 |
+
"""
|
| 267 |
try:
|
| 268 |
+
upstream, head, id2label, device = load_kushinada_s3prl()
|
| 269 |
except Exception as e:
|
| 270 |
st.error(f"モデルのロードに失敗しました: {e}")
|
| 271 |
st.info("音声特徴量ベースの分析に切り替えます。")
|
|
|
|
| 281 |
y = y[:max_samples]
|
| 282 |
st.warning("音声が30秒を超えたため、最初の30秒のみを分析します。")
|
| 283 |
|
| 284 |
+
wav = torch.tensor(y, dtype=torch.float32, device=device).unsqueeze(0) # [1, T]
|
|
|
|
| 285 |
|
| 286 |
with torch.no_grad():
|
| 287 |
+
reps_dict = upstream(wav) # S3PRL Upstream の出力
|
| 288 |
+
if isinstance(reps_dict, dict):
|
| 289 |
+
reps = reps_dict.get("last_hidden_state", None)
|
| 290 |
+
if reps is None:
|
| 291 |
+
# 代替:最終層の hidden_states など
|
| 292 |
+
if "hidden_states" in reps_dict and isinstance(reps_dict["hidden_states"], (list, tuple)):
|
| 293 |
+
reps = reps_dict["hidden_states"][-1]
|
| 294 |
+
else:
|
| 295 |
+
# 直接テンソルが来る実装もある
|
| 296 |
+
reps = list(reps_dict.values())[-1]
|
| 297 |
+
else:
|
| 298 |
+
reps = reps_dict # テンソル想定 [B, T, H]
|
| 299 |
|
| 300 |
+
logits = head(reps) # [B, C]
|
| 301 |
+
probs = torch.softmax(logits, dim=-1)[0].detach().cpu().numpy()
|
|
|
|
|
|
|
| 302 |
|
| 303 |
+
pred_id = int(np.argmax(probs))
|
| 304 |
+
raw_label = id2label[pred_id]
|
| 305 |
+
label = _normalize_label(raw_label)
|
| 306 |
+
scores = {_normalize_label(id2label[i]): float(probs[i]) for i in range(len(probs))}
|
| 307 |
for k in list(scores.keys()):
|
| 308 |
scores[k] = max(0.0, min(1.0, scores[k]))
|
| 309 |
+
return label, scores, "AI(S3PRL)"
|
|
|
|
| 310 |
|
| 311 |
except Exception as e:
|
| 312 |
st.warning(f"AI予測中にエラーが発生: {e}")
|
|
|
|
| 416 |
if key not in st.session_state: st.session_state[key] = default
|
| 417 |
|
| 418 |
st.subheader("1) 録音またはアップロード")
|
| 419 |
+
with st.warning("⚠️ ファイルアップロードで403が出る場合は、録音機能をご利用ください。"):
|
| 420 |
+
st.markdown("**🎤 録音** → 直接話す or 端末で音声再生しながら録音")
|
|
|
|
| 421 |
|
| 422 |
tab_rec, tab_upload = st.tabs(["🎤 録音する(推奨)", "📁 ファイルを使う"])
|
| 423 |
|