Isk5434's picture
feat: aggressive augmentation + moderate ridge + final retrain on all data
25695c2 verified
# -*- coding: utf-8 -*-
import numpy as np
import librosa
import gradio as gr
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
"""音声分類.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1_wgDxDY2B0aYe_zbc1yZ603NBXlRc6Xe
"""
# ESN
# tutorial1〜3 を完全に踏襲した実装
# - W_in, b_in: input_scaling で初期化(tutorial1/2/3 共通)
# - W: density マスク + 固有値分解でスペクトル半径調整(tutorial1/2/3 共通)
# - fit(): H = [1, x(t)] をフレーム毎に積んでリッジ回帰(tutorial3 準拠)
# - predict_proba_sequence(): H @ W_out のフレーム平均で発話レベル推論(tutorial3 準拠)
# - predict_step_proba(): ストリーミング用、現フレームの H @ W_out を返す
from dataclasses import dataclass
@dataclass
class ESNConfig:
    """Hyperparameters of an echo state network classifier.

    Field order and defaults form the public constructor signature, so
    positional construction such as ``ESNConfig(50, 0.5)`` keeps working.
    """

    # Number of reservoir units.
    n_res: int = 300
    # Largest |eigenvalue| the recurrent matrix W is rescaled to.
    spectral_radius: float = 0.9
    # Leak rate a in x <- (1 - a) * x + a * tanh(...).
    leaking_rate: float = 0.3
    # Half-width of the uniform range used to draw W_in and b_in.
    input_scale: float = 0.5
    # Fraction of nonzero entries kept in W (mask applied when 0 < density < 1).
    density: float = 0.1
    # L2 penalty of the ridge-regression readout; larger = stronger regularisation.
    ridge_alpha: float = 1.0
    # Seed for the weight-initialisation RNG.
    seed: int = 42
class ESNClassifier:
    """Echo state network with a ridge-regression readout for sequence classification.

    The reservoir (W_in, b_in, W) is random and fixed after construction; only
    the linear readout W_out, applied to the augmented state h(t) = [1, x(t)],
    is learned by fit(). Follows the tutorial1-3 recipe: uniform input weights
    scaled by input_scale, a density-masked recurrent matrix rescaled to the
    requested spectral radius, and frame-level ridge regression.
    """

    def __init__(self, cfg: "ESNConfig", n_in: int, n_classes: int):
        self.cfg = cfg
        self.n_in = n_in
        self.n_classes = n_classes
        rng = np.random.default_rng(cfg.seed)

        # Input weights (n_res, n_in) and bias (n_res,), uniform in
        # [-input_scale, input_scale].
        self.W_in = (rng.uniform(-1, 1, (cfg.n_res, n_in))
                     * cfg.input_scale).astype(np.float32)
        self.b_in = (rng.uniform(-1, 1, cfg.n_res)
                     * cfg.input_scale).astype(np.float32)

        # Recurrent weights, sparsified by a random mask when 0 < density < 1.
        W = rng.uniform(-1, 1, (cfg.n_res, cfg.n_res)).astype(np.float32)
        if 0.0 < cfg.density < 1.0:
            mask = rng.uniform(0, 1, (cfg.n_res, cfg.n_res)) < cfg.density
            W = W * mask
        # Rescale so the largest |eigenvalue| equals spectral_radius
        # (left as-is when W has no nonzero eigenvalue).
        sr = np.max(np.abs(np.linalg.eigvals(W.astype(np.float64))))
        if sr > 0:
            W = (W * (cfg.spectral_radius / sr)).astype(np.float32)
        self.W = W

        self.W_out = None  # (1 + n_res, n_classes); set by fit()
        self.reset_state()

    def reset_state(self):
        """Zero the reservoir state."""
        self.x = np.zeros(self.cfg.n_res, dtype=np.float32)

    def _step(self, u):
        """Advance the reservoir by one frame and return a copy of the new state.

        x <- (1 - a) * x + a * tanh(W @ x + W_in @ u + b_in), a = leaking_rate.
        """
        pre = (self.W @ self.x
               + self.W_in @ u.astype(np.float32)
               + self.b_in).astype(np.float32)
        a = self.cfg.leaking_rate
        self.x = ((1 - a) * self.x + a * np.tanh(pre)).astype(np.float32)
        return self.x.copy()

    def _forward_states(self, U):
        """Run one utterance from a zero state.

        U: (T, n_in) feature sequence. Returns the (T, n_res) float32 states.
        """
        U = np.asarray(U, dtype=np.float32)
        self.reset_state()
        states = np.zeros((U.shape[0], self.cfg.n_res), dtype=np.float32)
        for t in range(U.shape[0]):
            states[t] = self._step(U[t])
        return states

    def _design_matrix(self, states):
        """Prepend a bias column of ones: (T, n_res) -> (T, 1 + n_res)."""
        H = np.empty((states.shape[0], 1 + self.cfg.n_res), dtype=states.dtype)
        H[:, 0] = 1.0
        H[:, 1:] = states
        return H

    @staticmethod
    def _softmax(z):
        """Numerically stable softmax of a 1-D score vector."""
        e = np.exp(z - z.max())
        return e / (e.sum() + 1e-12)

    def _require_fitted(self):
        """Raise a clear error instead of failing on ``H @ None``."""
        if self.W_out is None:
            raise RuntimeError("ESNClassifier is not fitted; call fit() first")

    def fit(self, X_list, y_list):
        """Fit the readout by ridge regression over every frame of every utterance.

        Each frame of utterance i becomes a row h(t) = [1, x(t)] with the one-hot
        target of class index y_list[i]; W_out solves
        (H^T H + ridge_alpha * I) W_out = H^T Y.

        The normal equations are accumulated utterance by utterance in float64.
        This avoids materialising the full (total_frames, 1 + n_res) design
        matrix, and keeps H^T H from losing precision to float32 accumulation.
        """
        dim = 1 + self.cfg.n_res
        A = np.zeros((dim, dim), dtype=np.float64)
        B = np.zeros((dim, self.n_classes), dtype=np.float64)
        for U, label in zip(X_list, y_list):
            H = self._design_matrix(self._forward_states(U).astype(np.float64))
            A += H.T @ H
            # All target rows of this utterance are one-hot at `label`, so
            # H^T Y only touches that column, with the column sums of H.
            B[:, label] += H.sum(axis=0)
        A[np.diag_indices(dim)] += self.cfg.ridge_alpha
        try:
            W_out = np.linalg.solve(A, B)
        except np.linalg.LinAlgError:
            # Singular system: fall back to least squares (pseudo-inverse).
            W_out = np.linalg.lstsq(A, B, rcond=None)[0]
        self.W_out = W_out.astype(np.float32)

    def predict_proba_sequence(self, U):
        """Classify a whole utterance: softmax of the frame-averaged readout.

        Returns a (n_classes,) probability vector. An empty sequence carries no
        evidence and yields the uniform distribution instead of NaNs.
        Raises RuntimeError if called before fit().
        """
        self._require_fitted()
        states = self._forward_states(U)
        if states.shape[0] == 0:
            return np.full(self.n_classes, 1.0 / self.n_classes, dtype=np.float32)
        z = (self._design_matrix(states) @ self.W_out).mean(axis=0)
        return self._softmax(z)

    def predict_step_proba(self, u):
        """Streaming inference: advance one frame and classify from that state alone.

        The reservoir state carries over between calls; call reset_state() to
        start a new stream. Raises RuntimeError if called before fit().
        """
        self._require_fitted()
        h = self._design_matrix(self._step(u)[None, :])[0]
        return self._softmax(h @ self.W_out)
# Audio -> MFCC feature-extraction settings
SR = 16000  # target sample rate (Hz); incoming audio is resampled to this
N_MFCC = 13  # MFCC coefficients per frame; also the ESN input size (n_in)
HOP = int(0.01 * SR) # hop length in samples (10 ms)
WIN = int(0.025 * SR) # window / FFT size in samples (25 ms)
SILENCE_RMS = 0.02 # frame-RMS threshold of _trim_silence (peak-normalised audio); used for both training and inference
def _mono_float32(y):
y = y.astype(np.float32)
if y.ndim > 1:
y = y.mean(axis=1)
return y
def _rms(y: np.ndarray) -> float:
return float(np.sqrt(np.mean(y.astype(np.float32) ** 2)))
def _trim_silence(y: np.ndarray, sr: int) -> np.ndarray:
    """Cut leading and trailing silence from *y* using a frame-RMS threshold.

    Frames whose RMS exceeds SILENCE_RMS count as sound. The waveform spanning
    the first to the last such frame is returned as a slice (view) of *y*.
    If no frame is loud enough, *y* itself is returned unchanged.
    *sr* is accepted for API symmetry but not used.
    """
    rms = librosa.feature.rms(y=y, frame_length=WIN, hop_length=HOP)[0]
    loud = np.flatnonzero(rms > SILENCE_RMS)
    if loud.size == 0:
        return y  # entirely silent: keep the original waveform
    # librosa.feature.rms defaults to center=True: the signal is padded by
    # WIN // 2, so frame k covers samples [k*HOP - WIN//2, k*HOP + WIN - WIN//2).
    # Map frame indices to sample bounds with that offset.
    half = WIN // 2
    start = max(0, int(loud[0]) * HOP - half)
    end = min(len(y), int(loud[-1]) * HOP + (WIN - half))
    return y[start:end]
def normalize_audio_tuple(audio):
    """Convert a ``(sr, samples)`` tuple to mono float32 at SR, peak-normalised.

    Returns ``(sr, y)``: sr is SR when resampling was needed, otherwise the
    original rate. y is divided by (max|y| + 1e-9), so silence stays zero.
    """
    rate, samples = audio
    samples = _mono_float32(samples)
    if rate != SR:
        samples = librosa.resample(samples, orig_sr=rate, target_sr=SR)
        rate = SR
    peak = np.max(np.abs(samples))
    samples /= (peak + 1e-9)
    return (rate, samples)
def _wav_to_mfcc(y, sr):
    """Waveform -> silence trim -> peak re-normalisation -> MFCC.

    Used by audio_to_sequence and augment_sequences. Returns a
    (frames, N_MFCC) float32 array. The caller's array is never modified.
    """
    y = _trim_silence(y, sr)
    # Out-of-place on purpose: _trim_silence returns a slice of (or the very
    # same) input array, so an in-place `/=` would rescale the caller's data.
    y = y / (np.max(np.abs(y)) + 1e-9)
    mfcc = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=N_MFCC, hop_length=HOP, n_fft=WIN)
    return mfcc.T.astype(np.float32)
def audio_to_sequence(audio):
    """Training-side feature path: normalise a ``(sr, samples)`` tuple, then MFCC."""
    rate, wav = normalize_audio_tuple(audio)
    return _wav_to_mfcc(wav, rate)
# ── Data augmentation ─────────────────────────────────
# With only a few recordings, training also uses randomly perturbed copies
# (e.g. noise, time stretch, pitch shift) so the model cannot simply memorise
# individual recordings.
N_AUGMENT = 6 # augmented copies attempted per recording (augment_sequences drops copies under 5 frames)
def _augment_wav(y, sr, rng):
    """Return one randomly perturbed float32 copy of waveform *y* (aggressive settings).

    Steps:
      1. additive Gaussian noise at a random SNR of 5-20 dB,
      2. time stretch by 0.8-1.2x,
      3. pitch shift by up to +/-3 semitones,
      4. peak re-normalisation.

    *rng* is a ``numpy.random.Generator``; *y* is not modified.

    No random gain is applied. Every step after the noise is linear in
    amplitude, and the output is peak-normalised here and again in
    _wav_to_mfcc, so any gain would cancel out exactly.
    """
    # 1) Gaussian noise at a random SNR (strong noise).
    snr_db = rng.uniform(5, 20)
    sig_power = np.mean(y ** 2) + 1e-12
    noise_power = sig_power / (10 ** (snr_db / 10))
    aug = y + rng.normal(0, np.sqrt(noise_power), len(y)).astype(np.float32)
    # 2) Time stretch.
    aug = librosa.effects.time_stretch(aug, rate=rng.uniform(0.8, 1.2))
    # 3) Pitch shift (simulates different voices).
    aug = librosa.effects.pitch_shift(aug, sr=sr, n_steps=rng.uniform(-3, 3))
    # 4) Peak re-normalisation.
    aug = aug / (np.max(np.abs(aug)) + 1e-9)
    return aug.astype(np.float32)
def augment_sequences(audio_tuple):
    """Build up to N_AUGMENT augmented MFCC sequences from one recording.

    The recording is normalised, then perturbed N_AUGMENT times with a fresh,
    unseeded RNG and converted to MFCC. Copies with fewer than 5 frames are
    dropped, so the list may be shorter than N_AUGMENT.
    """
    rate, wav = normalize_audio_tuple(audio_tuple)
    generator = np.random.default_rng()
    candidates = (_wav_to_mfcc(_augment_wav(wav, rate, generator), rate)
                  for _ in range(N_AUGMENT))
    return [seq for seq in candidates if seq is not None and len(seq) >= 5]
def chunk_to_seq(chunk):
    """Convert a ``(sr, samples)`` chunk to an MFCC sequence.

    Returns None when the chunk or its samples are missing, or when the mono
    signal is shorter than one analysis window (WIN samples).
    """
    if chunk is None:
        return None
    rate, samples = chunk
    if samples is None:
        return None
    mono = _mono_float32(samples)
    return None if len(mono) < WIN else audio_to_sequence((rate, mono))
# In-memory data store (supports replay and relabelling).
# Each DATA item: {"audio": (sr, y) as returned by normalize_audio_tuple,
#                  "U": MFCC sequence, "label": str}
DATA = []
LABELS = []  # class names; a label's list position is its class index for the model
MODEL = None  # trained ESNClassifier, or None until train_random() succeeds / after reset
# Startup warm-up: trigger librosa's lazy loading now so the first saved sample is not delayed.
def _warmup():
    """Best-effort: compute MFCCs of one second of silence at import time."""
    try:
        silence = np.zeros(SR, dtype=np.float32)
        librosa.feature.mfcc(y=silence, sr=SR, n_mfcc=N_MFCC, hop_length=HOP, n_fft=WIN)
    except Exception:
        # Warm-up is optional; a failure here must never stop the app from starting.
        pass


_warmup()
def probs_dict_from_p(p):
    """Map label names to float probabilities (pairs up to the shorter length).

    Returns an empty dict when *p* is None or no labels exist.
    """
    if p is None or not LABELS:
        return {}
    return {name: float(prob) for name, prob in zip(LABELS, p)}
def dataset_table():
    """Rows for the dataset table: [index, label, duration in seconds rounded to 2 dp]."""
    def _duration(audio):
        rate, samples = audio
        return (len(samples) / rate) if (rate and samples is not None) else 0.0

    return [[idx, entry["label"], round(_duration(entry["audio"]), 2)]
            for idx, entry in enumerate(DATA)]
def dataset_stats_text():
    """One-line summary: total sample count followed by per-label counts."""
    tally = dict.fromkeys(LABELS, 0)
    for entry in DATA:
        tally[entry["label"]] = tally.get(entry["label"], 0) + 1
    if LABELS:
        parts = [f"{name}:{tally.get(name, 0)}" for name in LABELS]
    else:
        parts = ["(no labels)"]
    return f"n={len(DATA)} | " + " ".join(parts)
# Training (split first -> augment the train split only -> random search)
def _split_train_val(y_orig):
    """Split sample indices 70/30, stratified by label when possible.

    sklearn raises ValueError when stratification is impossible (e.g. a label
    with a single sample); fall back to a plain split with the same seed.
    """
    indices = list(range(len(y_orig)))
    try:
        idx_tr, idx_val, _, _ = train_test_split(
            indices, y_orig, test_size=0.3, random_state=0, stratify=y_orig)
    except ValueError:
        idx_tr, idx_val, _, _ = train_test_split(
            indices, y_orig, test_size=0.3, random_state=0)
    return idx_tr, idx_val


def _with_augmentation(indices, X_orig, y_orig, audio_orig, aug_cache):
    """Return (X, y): each original sequence followed by its augmented copies.

    Augmented copies are cached per index in *aug_cache*, so a recording used
    both in the search and in the final refit is augmented only once.
    """
    X_out, y_out = [], []
    for i in indices:
        if i not in aug_cache:
            aug_cache[i] = augment_sequences(audio_orig[i])
        copies = [X_orig[i], *aug_cache[i]]
        X_out.extend(copies)
        y_out.extend([y_orig[i]] * len(copies))
    return X_out, y_out


def _sample_config():
    """Draw one random ESNConfig for the hyperparameter search."""
    return ESNConfig(
        n_res=int(np.random.choice([50, 100, 150])),
        spectral_radius=float(np.random.uniform(0.5, 1.0)),
        leaking_rate=float(np.random.uniform(0.1, 0.8)),
        input_scale=float(np.random.uniform(0.1, 1.0)),
        density=float(np.random.uniform(0.05, 0.3)),
        ridge_alpha=float(10 ** np.random.uniform(-2, 1)),
        seed=int(np.random.randint(0, 10_000_000)),
    )


def train_random(n_trials):
    """Random-search ESN hyperparameters, then refit the best config on all data.

    1. Split the original recordings 70/30 before augmenting, so no augmented
       copy of a validation recording leaks into training.
    2. Train split = originals + augmented copies; validation = originals only.
    3. Try ``n_trials`` random configs (at least one) and keep the one with
       the best validation accuracy.
    4. Refit that config on every recording plus its augmented copies and
       publish it as the global MODEL.

    Returns a status string. Errors during training are reported in that
    string (with a traceback) instead of being raised.
    """
    global MODEL
    if len(LABELS) < 2:
        return "ラベル不足(2種類以上)"
    if len(DATA) < 6:
        return "データ不足(目安: 合計6以上、各ラベル3以上)"

    try:
        # At least one trial: with zero trials there is no config to refit.
        trials = max(1, int(n_trials))

        X_orig = [d["U"] for d in DATA]
        y_orig = [LABELS.index(d["label"]) for d in DATA]
        audio_orig = [d["audio"] for d in DATA]
        aug_cache = {}

        # 1) Split the originals first (prevents augmentation leakage).
        idx_tr, idx_val = _split_train_val(y_orig)
        # 2) Validation: original recordings only.
        X_val = [X_orig[i] for i in idx_val]
        y_val = [y_orig[i] for i in idx_val]
        # 3) Training: originals + augmented copies.
        X_tr, y_tr = _with_augmentation(idx_tr, X_orig, y_orig, audio_orig, aug_cache)

        best_acc, best_cfg = -1.0, None
        for _ in range(trials):
            cfg = _sample_config()
            model = ESNClassifier(cfg, N_MFCC, len(LABELS))
            model.fit(X_tr, y_tr)
            preds = [int(np.argmax(model.predict_proba_sequence(U))) for U in X_val]
            acc = accuracy_score(y_val, preds)
            if acc > best_acc:
                best_acc, best_cfg = acc, cfg

        # 4) Refit the best config on all recordings (+ augmented copies).
        X_final, y_final = _with_augmentation(
            range(len(DATA)), X_orig, y_orig, audio_orig, aug_cache)
        final_model = ESNClassifier(best_cfg, N_MFCC, len(LABELS))
        final_model.fit(X_final, y_final)
    except Exception:
        import traceback
        return f"❌ エラー:\n{traceback.format_exc()}"

    MODEL = final_model
    return (f"val acc: {best_acc:.3f} | "
            f"train: {len(idx_tr)}→{len(X_tr)}(x{N_AUGMENT}aug) | "
            f"val: {len(idx_val)}(元データのみ) | "
            f"最終学習: 全{len(X_final)}サンプル | "
            f"cfg: {best_cfg}")
# Inference: record -> auto-stop on silence -> classify the whole utterance
# with the training feature pipeline (the JS auto-stop is shared with the collection tab).
def infer_on_stop(audio):
    """Classify a finished recording as a whole (same scoring as validation).

    Returns (predicted label or status text, {label: probability},
    gr.update(value=None) for the recorder).
    """
    cleared = gr.update(value=None)
    if MODEL is None:
        return "未学習", {}, cleared
    if audio is None:
        return "...", {}, cleared
    sr, y = audio
    if y is None or len(y) == 0:
        return "...", {}, cleared

    # Same steps as training: normalise -> trim silence -> re-normalise -> MFCC.
    _, wav = normalize_audio_tuple((sr, y))
    wav = _trim_silence(wav, SR)
    if len(wav) < WIN:
        return "(音声が短すぎます)", {}, cleared
    wav /= (np.max(np.abs(wav)) + 1e-9)
    mfcc = librosa.feature.mfcc(
        y=wav, sr=SR, n_mfcc=N_MFCC, hop_length=HOP, n_fft=WIN
    ).T.astype(np.float32)
    if len(mfcc) < 2:
        return "(音声が短すぎます)", {}, cleared

    probs = MODEL.predict_proba_sequence(mfcc)
    return LABELS[int(np.argmax(probs))], probs_dict_from_p(probs), cleared
# UI callbacks
def add_label_cb(label):
    """Register a new (stripped) label and refresh both label selectors.

    Blank input leaves the widgets unchanged; an existing label is just selected.
    """
    name = (label or "").strip()
    if not name:
        return gr.update(), dataset_table(), gr.update()
    if name not in LABELS:
        LABELS.append(name)
    return (gr.update(choices=LABELS, value=name),
            dataset_table(),
            gr.update(choices=LABELS))
def add_sample_cb(audio, label):
    """Store a recording under *label*; always clears the recorder.

    Unknown labels, missing audio and clips shorter than 5 MFCC frames are ignored.
    """
    name = (label or "").strip()
    if name in LABELS and audio is not None:
        normalized = normalize_audio_tuple(audio)
        seq = audio_to_sequence(normalized)
        if seq is not None and len(seq) >= 5:
            DATA.append({"audio": normalized, "U": seq, "label": name})
    return dataset_table(), gr.update(value=None)
def auto_add_sample_cb(audio, label):
    """Save a recording automatically when recording stops (stop_recording event).

    Returns (table rows, recorder reset, status message).
    """
    name = (label or "").strip()
    reset = gr.update(value=None)
    if name not in LABELS:
        message = "⚠ ラベルを先に選択してください"
    elif audio is None:
        message = "待機中..."
    else:
        normalized = normalize_audio_tuple(audio)
        seq = audio_to_sequence(normalized)
        if seq is None or len(seq) < 5:
            message = "⚠ 音声が短すぎます(もう少し話してください)"
        else:
            DATA.append({"audio": normalized, "U": seq, "label": name})
            message = f"✓ 保存完了! (idx={len(DATA)-1}, label={name})"
    return dataset_table(), reset, message
def undo_last_cb():
    """Drop the most recently added sample (no-op when empty); return the table."""
    if DATA:
        DATA.pop()
    return dataset_table()
def reset_all_cb():
    """Forget all samples, labels and the trained model, and reset the widgets.

    Returns (empty table, two selector updates with no choices, a cleared
    value update, None) — presumably wired to the label selectors, the
    recorder and a state; confirm against the event wiring.
    """
    global MODEL
    for store in (DATA, LABELS):
        store.clear()
    MODEL = None
    return (
        dataset_table(),
        gr.update(choices=[], value=None),
        gr.update(choices=[], value=None),
        gr.update(value=None),
        None,
    )
def clear_rec_cb():
    """Return an update that empties the bound component (the recorder)."""
    cleared = gr.update(value=None)
    return cleared
def on_select_row(evt: gr.SelectData):
    """Handle a click on a dataset-table row.

    Returns (stored audio for playback, selector update set to the sample's
    label, status text). For a Dataframe, evt.index is (row, col); a bare
    integer index is accepted too.
    """
    if evt is None or evt.index is None:
        return None, gr.update(value=None), "未選択"
    idx = evt.index
    row = idx[0] if isinstance(idx, (tuple, list)) else int(idx)
    if not 0 <= row < len(DATA):
        return None, gr.update(value=None), "範囲外"
    sample = DATA[row]
    return sample["audio"], gr.update(value=sample["label"]), f"selected idx={row}"
def apply_relabel_cb(table, new_label):
    """Unused placeholder; both arguments are ignored.

    The table is not a trustworthy source for the selection, so relabelling is
    meant to go through relabel_selected_cb with the selected index kept in a
    gr.State. Always returns the same marker string.
    """
    marker = "内部: use state idx"
    return marker
def relabel_selected_cb(selected_idx, new_label):
    """Move the sample at *selected_idx* to *new_label* (which must already exist).

    Invalid indices and unknown labels are ignored; the table is returned either way.
    """
    target = (new_label or "").strip()
    valid_idx = selected_idx is not None and 0 <= selected_idx < len(DATA)
    if valid_idx and target in LABELS:
        DATA[selected_idx]["label"] = target
    return dataset_table()
def delete_selected_cb(selected_idx):
    """Delete the sample at *selected_idx* when it is a valid index.

    Returns (table, None — presumably clearing the selected-index state,
    value-clearing update) regardless of whether anything was deleted.
    """
    if selected_idx is not None and 0 <= selected_idx < len(DATA):
        DATA.pop(selected_idx)
    return dataset_table(), None, gr.update(value=None)
# Compact UI — "Copy.py"-style feminine pastel design with a photo-like
# light-blue -> mint-green gradient background.
#
# HEAD holds <head> markup (presumably passed to Gradio as the page head —
# confirm at the launch/Blocks call): viewport meta, Google Fonts links, and a
# script that auto-stops recording after silence. The script wraps
# navigator.mediaDevices.getUserMedia. For each audio stream it computes RMS on
# an AnalyserNode every animation frame. After speech is detected
# (RMS > 0.015), a silence timer starts when RMS drops below 0.01; only RMS
# > 0.015 resets it. After 300 ms it clicks the first `.stop-button` in the
# document and stops monitoring. Status text goes to the element with
# id "rec_status_js".
# NOTE(review): document.querySelector('.stop-button') returns the first match in
# the DOM, which may not be the recorder in the active tab when several Audio
# components exist — confirm.
# NOTE(review): with SILENCE_MS = 300 the countdown text uses
# Math.ceil(left / 1000) seconds, so it only ever shows 1 or 0 — confirm intent.
HEAD = """
<meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1, viewport-fit=cover, user-scalable=no">
<link rel="preconnect" href="https://fonts.googleapis.com">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
<link href="https://fonts.googleapis.com/css2?family=Cormorant+Garamond:wght@300;400;500;600&display=swap" rel="stylesheet">
<script>
/* ── 無音自動停止(Gradio公式 .stop-button/.record-button セレクタ使用) ── */
(function(){
const SILENCE_THRESH = 0.01;
const SPEECH_THRESH = 0.015;
const SILENCE_MS = 300;
let ctx, analyser, src, silStart, spoke, on;
function stat(m){ var e=document.getElementById('rec_status_js'); if(e) e.textContent=m; }
function stop(){
on=false;
try{src&&src.disconnect()}catch(e){}
try{ctx&&ctx.close()}catch(e){}
src=null;ctx=null;analyser=null;
}
function monitor(stream){
stop();
on=true; spoke=false; silStart=null;
ctx=new(window.AudioContext||window.webkitAudioContext)();
analyser=ctx.createAnalyser(); analyser.fftSize=512;
src=ctx.createMediaStreamSource(stream); src.connect(analyser);
var buf=new Float32Array(analyser.fftSize);
function tick(){
if(!on)return;
analyser.getFloatTimeDomainData(buf);
var s=0; for(var i=0;i<buf.length;i++) s+=buf[i]*buf[i];
var rms=Math.sqrt(s/buf.length);
if(rms>SPEECH_THRESH){
spoke=true; silStart=null;
stat('録音中... 🎙️');
} else if(spoke && rms<SILENCE_THRESH){
if(!silStart) silStart=Date.now();
var left=Math.max(0,SILENCE_MS-(Date.now()-silStart));
stat('無音検出中... あと'+Math.ceil(left/1000)+'秒');
if(left<=0){
stat('自動停止 & 保存中...');
/* Gradio公式: .stop-button で停止ボタンを取得 */
var btn=document.querySelector('.stop-button');
console.log('[auto-stop] .stop-button found:', btn);
if(btn){ btn.click(); }
stop();
return;
}
} else if(!spoke){
stat('待機中... 話してください 🎤');
}
requestAnimationFrame(tick);
}
tick();
}
var orig=navigator.mediaDevices.getUserMedia.bind(navigator.mediaDevices);
navigator.mediaDevices.getUserMedia=function(c){
return orig(c).then(function(stream){
if(c&&c.audio){
console.log('[auto-stop] getUserMedia intercepted, starting monitor');
monitor(stream);
stream.getAudioTracks().forEach(function(t){t.addEventListener('ended',stop);});
}
return stream;
});
};
})();
</script>
"""
CSS = """
/* ========================================
グローバル: フェミニン・パステルテーマ
写真風 淡いブルー→ミントグリーン グラデーション
======================================== */
html {
scroll-behavior: smooth !important;
-webkit-overflow-scrolling: touch !important;
}
.gradio-container {
background: linear-gradient(175deg,
#c5dff0 0%,
#b8e0e8 15%,
#b2e5d8 30%,
#c8ecd0 50%,
#d5f0cc 70%,
#e0f4d8 85%,
#e8f6e0 100%) !important;
color: #1a1a1a !important;
font-family: 'Inter', 'Hiragino Kaku Gothic ProN', 'Noto Sans JP', -apple-system, BlinkMacSystemFont, sans-serif !important;
font-weight: 400 !important;
max-width: 100% !important;
padding: 0 !important;
min-height: 100vh !important;
}
footer { display: none !important; }
/* ========================================
ヘッダーヒーローセクション
======================================== */
.hero-section {
text-align: center !important;
padding: 48px 24px 32px 24px !important;
}
.hero-section h1 {
font-family: 'Cormorant Garamond', 'Georgia', 'Times New Roman', serif !important;
font-size: clamp(24px, 7vw, 38px) !important;
font-weight: 400 !important;
letter-spacing: 0.1em !important;
color: #2a3a2a !important;
text-transform: uppercase !important;
margin: 0 0 8px 0 !important;
line-height: 1.2 !important;
}
.hero-section p {
font-size: clamp(12px, 3vw, 14px) !important;
color: #4a5a4a !important;
letter-spacing: 0.04em !important;
margin: 0 !important;
}
/* ========================================
タブナビゲーション: ピル型パステル
======================================== */
div.tab-nav {
background: rgba(255,255,255,0.75) !important;
border: 1px solid rgba(0,0,0,0.06) !important;
border-radius: 22px !important;
padding: 5px !important;
margin: 16px 16px 20px 16px !important;
display: flex !important;
justify-content: center !important;
gap: 3px !important;
box-shadow: 0 2px 12px rgba(100,140,120,0.08) !important;
backdrop-filter: blur(8px) !important;
-webkit-backdrop-filter: blur(8px) !important;
}
div.tab-nav button {
background: transparent !important;
color: #3a4a3a !important;
border: none !important;
border-radius: 18px !important;
padding: 10px 16px !important;
font-family: 'Cormorant Garamond', 'Georgia', 'Times New Roman', serif !important;
font-size: 13px !important;
font-weight: 500 !important;
transition: all 0.25s ease !important;
letter-spacing: 0.1em !important;
}
div.tab-nav button.selected {
background: rgba(178,216,210,0.4) !important;
color: #1a2a1a !important;
box-shadow: 0 1px 8px rgba(178,216,210,0.25) !important;
}
/* ========================================
タブコンテンツ: 二重フレーム(Copy.py風)
======================================== */
.tabitem {
background: transparent !important;
border: none !important;
}
.tabitem > div {
position: relative !important;
background: rgba(255,255,255,0.92) !important;
border-radius: 0 !important;
padding: 32px 20px !important;
margin: 16px 18px 28px 18px !important;
border: 1.25px solid #8aaa8a !important;
box-shadow: 8px 8px 0px 0px #c0d8c0 !important;
backdrop-filter: blur(4px) !important;
-webkit-backdrop-filter: blur(4px) !important;
animation: softFadeIn 0.5s ease forwards;
}
/* ========================================
ラベル: セリフ体
======================================== */
label span, .label-wrap span {
color: #2a3a2a !important;
font-family: 'Cormorant Garamond', 'Georgia', 'Times New Roman', serif !important;
font-weight: 500 !important;
font-size: 13px !important;
letter-spacing: 0.06em !important;
}
/* ========================================
テキスト入力
======================================== */
input[type="text"], textarea, select {
background: rgba(255,255,255,0.8) !important;
border: 1px solid rgba(138,170,138,0.35) !important;
border-radius: 0 !important;
color: #1a1a1a !important;
font-weight: 400 !important;
}
input[type="text"]:focus, textarea:focus {
border-color: rgba(138,170,138,0.6) !important;
box-shadow: 0 0 0 3px rgba(178,216,210,0.2) !important;
outline: none !important;
}
/* Number input: シャープスタイル */
input[type="number"] {
background: #ffffff !important;
border: 1.25px solid #8aaa8a !important;
border-radius: 0 !important;
color: #1a1a1a !important;
font-family: 'Cormorant Garamond', 'Georgia', serif !important;
font-weight: 500 !important;
font-size: 13px !important;
text-align: center !important;
padding: 4px 6px !important;
box-shadow: 3px 3px 0px 0px #c0d8c0 !important;
outline: none !important;
-moz-appearance: textfield !important;
}
input[type="number"]:focus {
border-color: #6a8a6a !important;
box-shadow: 4px 4px 0px 0px #a8c8a8 !important;
}
input[type="number"]::-webkit-inner-spin-button,
input[type="number"]::-webkit-outer-spin-button {
-webkit-appearance: none !important;
margin: 0 !important;
}
/* ========================================
Dropdown: Copy.py風シャープスタイル
======================================== */
[data-testid="dropdown"] {
background: transparent !important;
border: none !important;
box-shadow: none !important;
border-radius: 0 !important;
}
[data-testid="dropdown"] > div,
[data-testid="dropdown"] .wrap,
[data-testid="dropdown"] .wrap-inner,
[data-testid="dropdown"] .secondary-wrap,
[data-testid="dropdown"] input,
[data-testid="dropdown"] .multiselect {
background: #ffffff !important;
border: none !important;
border-radius: 0 !important;
box-shadow: none !important;
outline: none !important;
}
[data-testid="dropdown"] .wrap,
[data-testid="dropdown"] .secondary-wrap {
border: 1px solid #5a7a5a !important;
box-shadow: 3px 3px 0px 0px #c0d8c0 !important;
padding: 8px 10px !important;
}
[data-testid="dropdown"] *:not(ul):not(ul *) {
background: #ffffff !important;
color: #1a1a1a !important;
border-radius: 0 !important;
}
ul.options, ul.options li, .options, .options .item,
.secondary-wrap .item, .secondary-wrap ul li {
background: #2a3a2a !important;
color: #ffffff !important;
border-radius: 0 !important;
}
ul.options li:hover, .options .item:hover, .secondary-wrap .item:hover {
background: #3a4a3a !important;
color: #ffffff !important;
}
/* ========================================
Slider: ラグジュアリー仕様
======================================== */
input[type="range"] {
-webkit-appearance: none !important;
appearance: none !important;
height: 3px !important;
background: linear-gradient(90deg,
#b2d8e8 0%,
#b2e0d4 50%,
#c8e8c0 100%) !important;
border-radius: 0 !important;
outline: none !important;
cursor: pointer !important;
overflow: visible !important;
margin: 12px 0 !important;
}
input[type="range"]::-webkit-slider-thumb {
-webkit-appearance: none !important;
appearance: none !important;
width: 14px !important;
height: 14px !important;
background: #3a5a3a !important;
border: 1.5px solid #8aaa8a !important;
border-radius: 0 !important;
transform: rotate(45deg) !important;
cursor: pointer !important;
box-shadow: 2px 2px 4px rgba(0,0,0,0.15) !important;
margin-top: -6px !important;
position: relative !important;
}
input[type="range"]::-moz-range-thumb {
width: 14px !important;
height: 14px !important;
background: #3a5a3a !important;
border: 1.5px solid #8aaa8a !important;
border-radius: 0 !important;
transform: rotate(45deg) !important;
cursor: pointer !important;
}
input[type="range"]::-webkit-slider-runnable-track {
height: 3px !important;
background: linear-gradient(90deg,
#b2d8e8 0%,
#b2e0d4 50%,
#c8e8c0 100%) !important;
border-radius: 0 !important;
overflow: visible !important;
}
[data-testid="slider"] {
background: rgba(255,255,255,0.6) !important;
border: 1.25px solid #a8c8a8 !important;
border-radius: 0 !important;
padding: 12px 14px 16px 14px !important;
box-shadow: 3px 3px 0px 0px #c0d8c0 !important;
overflow: visible !important;
}
[data-testid="slider"] > div,
[data-testid="slider"] .wrap,
[data-testid="slider"] .wrap-inner {
overflow: visible !important;
}
[data-testid="slider"] .label-wrap span,
[data-testid="slider"] label span {
font-family: 'Cormorant Garamond', 'Georgia', serif !important;
font-size: 11px !important;
letter-spacing: 0.12em !important;
text-transform: uppercase !important;
color: #4a6a4a !important;
}
[data-testid="slider"] button {
border-radius: 0 !important;
border: 1.25px solid #a8c8a8 !important;
background: #f0f5f0 !important;
padding: 4px 6px !important;
}
/* ========================================
Radioボタン: ダイヤモンド◆スタイル
======================================== */
.diamond-radio,
.diamond-radio *:not(input):not(span) {
background: transparent !important;
background-color: transparent !important;
box-shadow: none !important;
border-color: transparent !important;
}
.diamond-radio {
border: 2px solid transparent !important;
border-image: linear-gradient(135deg, #b2d8e8 0%, #b2e0d4 50%, #c8e8c0 100%) 1 !important;
border-radius: 0 !important;
padding: 12px 14px !important;
}
.diamond-radio label,
.diamond-radio .wrap label,
.diamond-radio label.svelte-1kcyvh9 {
display: flex !important;
align-items: center !important;
gap: 10px !important;
padding: 10px 12px !important;
cursor: pointer !important;
transition: all 0.2s ease !important;
border-bottom: 1px solid rgba(138,170,138,0.15) !important;
font-family: 'Cormorant Garamond', 'Georgia', serif !important;
font-size: 14px !important;
font-weight: 500 !important;
letter-spacing: 0.08em !important;
}
.diamond-radio label:last-child {
border-bottom: none !important;
}
.diamond-radio label:hover {
background: rgba(178,216,210,0.15) !important;
}
/* ラジオ丸→ダイヤモンド: あらゆるinput[type=radio]をカバー */
.diamond-radio input[type="radio"],
.diamond-radio input[type="radio"]::before,
.diamond-radio input[type="radio"]::after {
-webkit-appearance: none !important;
appearance: none !important;
border-radius: 0 !important;
}
.diamond-radio input[type="radio"] {
width: 14px !important;
height: 14px !important;
min-width: 14px !important;
border: 1.5px solid #8aaa8a !important;
background: #ffffff !important;
transform: rotate(45deg) !important;
cursor: pointer !important;
transition: all 0.25s ease !important;
box-shadow: 1px 1px 3px rgba(0,0,0,0.1) !important;
margin: 0 4px 0 0 !important;
padding: 0 !important;
flex-shrink: 0 !important;
}
.diamond-radio input[type="radio"]:checked {
background: linear-gradient(135deg, #a0cfe0 0%, #90d4c8 50%, #b0e0a0 100%) !important;
border-color: #7abaa0 !important;
box-shadow: 2px 2px 4px rgba(0,0,0,0.2) !important;
}
/* Gradio 6: 丸いsvg/spanインジケーターも上書き */
.diamond-radio .radio-circle,
.diamond-radio [class*="radio"] span:first-child,
.diamond-radio .item > div:first-child,
.diamond-radio .choice > div:first-child {
width: 14px !important;
height: 14px !important;
min-width: 14px !important;
border: 1.5px solid #8aaa8a !important;
border-radius: 0 !important;
background: #ffffff !important;
transform: rotate(45deg) !important;
box-shadow: 1px 1px 3px rgba(0,0,0,0.1) !important;
transition: all 0.25s ease !important;
}
.diamond-radio .selected .radio-circle,
.diamond-radio .selected [class*="radio"] span:first-child,
.diamond-radio .selected .item > div:first-child,
.diamond-radio .selected .choice > div:first-child {
background: linear-gradient(135deg, #a0cfe0 0%, #90d4c8 50%, #b0e0a0 100%) !important;
border-color: #7abaa0 !important;
box-shadow: 2px 2px 4px rgba(0,0,0,0.2) !important;
}
/* ========================================
Gradioボタン: VIEW MORE風 / セリフ体 / シャープ
======================================== */
button[class*="primary"], button[class*="secondary"],
button.lg {
border-radius: 0 !important;
font-family: 'Cormorant Garamond', 'Georgia', 'Times New Roman', 'YuMincho', serif !important;
font-weight: 500 !important;
font-size: 13px !important;
letter-spacing: 0.18em !important;
text-transform: uppercase !important;
padding: 15px 24px !important;
transition: all 0.3s ease !important;
touch-action: manipulation !important;
-webkit-tap-highlight-color: transparent !important;
}
/* Primary: ダーク背景 */
button[class*="primary"] {
background: #3a5a3a !important;
color: #d8e8d8 !important;
border: 1px solid #3a5a3a !important;
box-shadow: none !important;
}
button[class*="primary"]:hover {
background: #4a6a4a !important;
}
button[class*="primary"]:active {
background: #5a7a5a !important;
}
/* Secondary: 白背景 + 細線ボーダー */
button[class*="secondary"] {
background: #ffffff !important;
color: #4a5a4a !important;
border: 1px solid #8aaa8a !important;
box-shadow: none !important;
}
button[class*="secondary"]:hover {
background: #f5faf5 !important;
border-color: #6a8a6a !important;
}
button[class*="secondary"]:active {
background: #eef5ee !important;
}
/* ========================================
Textbox / Markdown
======================================== */
textarea {
background: rgba(255,255,255,0.7) !important;
color: #1a1a1a !important;
border-radius: 0 !important;
border: 1px solid #a8c8a8 !important;
font-family: 'SF Mono', 'Fira Code', ui-monospace, monospace !important;
font-size: 12px !important;
}
.prose, .markdown-text, .md {
color: #1a1a1a !important;
}
.prose h2, .prose h3 {
font-family: 'Cormorant Garamond', 'Georgia', 'Times New Roman', serif !important;
color: #2a3a2a !important;
font-weight: 400 !important;
letter-spacing: 0.08em !important;
}
/* ========================================
JSON表示
======================================== */
.json-holder, [data-testid="json"] {
background: rgba(255,255,255,0.5) !important;
border-radius: 0 !important;
border: 1px solid rgba(138,170,138,0.2) !important;
}
/* ========================================
Dataframe
======================================== */
.dataframe, table {
border-radius: 0 !important;
}
.dataframe th {
background: rgba(178,216,210,0.2) !important;
font-family: 'Cormorant Garamond', 'Georgia', serif !important;
font-weight: 500 !important;
letter-spacing: 0.06em !important;
}
/* ========================================
Audio コンポーネント
======================================== */
.audio-container, [data-testid="audio"] {
border-radius: 0 !important;
border: 1px solid #a8c8a8 !important;
}
/* ========================================
Label (確率表示)
======================================== */
[data-testid="label"] {
border-radius: 0 !important;
}
/* ========================================
セクションタイトル
======================================== */
.section-title {
text-align: center !important;
padding: 8px 16px !important;
}
.section-title h3 {
font-family: 'Cormorant Garamond', 'Georgia', 'Times New Roman', serif !important;
font-size: clamp(13px, 3.5vw, 16px) !important;
font-weight: 400 !important;
color: #4a6a4a !important;
letter-spacing: 0.15em !important;
text-transform: uppercase !important;
margin: 0 !important;
}
/* ========================================
アニメーション
======================================== */
@keyframes softFadeIn {
from { opacity: 0; transform: translateY(12px); }
to { opacity: 1; transform: translateY(0); }
}
/* ========================================
自動録音: ステータス表示
======================================== */
.rec-status {
text-align: center;
padding: 8px 12px;
font-family: 'Cormorant Garamond', 'Georgia', serif;
font-size: 14px;
color: #4a6a4a;
letter-spacing: 0.06em;
}
@keyframes recPulse {
0%, 100% { opacity: 1; }
50% { opacity: 0.6; }
}
/* ========================================
スクロールバー
======================================== */
::-webkit-scrollbar { width: 4px; }
::-webkit-scrollbar-track { background: transparent; }
::-webkit-scrollbar-thumb {
background: rgba(138,170,138,0.3);
border-radius: 4px;
}
/* ========================================
Gradio内部padding/border補正
======================================== */
.block {
border: none !important;
background: transparent !important;
padding: 0 !important;
}
.form {
background: transparent !important;
border: none !important;
}
.container {
background: transparent !important;
}
.tabs {
background: transparent !important;
}
/* ========================================
テキスト色: 全体黒、Primaryボタンのみ白
======================================== */
* {
color: #1a1a1a;
}
button[class*="primary"],
button[class*="primary"] * {
color: #ffffff !important;
}
/* ========================================
レスポンシブ: PC = 中央固定幅
======================================== */
@media (min-width: 768px) {
.gradio-container > .main,
.gradio-container > div > .main {
max-width: 520px !important;
margin: 0 auto !important;
}
.tabitem > div {
margin: 16px auto 28px auto !important;
max-width: 480px !important;
}
div.tab-nav {
max-width: 480px !important;
margin: 16px auto 20px auto !important;
}
}
/* ========================================
スマホ特化
======================================== */
@media (max-width: 767px) {
div.tab-nav {
margin: 12px 12px 16px 12px !important;
padding: 4px !important;
}
div.tab-nav button {
padding: 9px 10px !important;
font-size: 12px !important;
}
.tabitem > div {
margin: 12px 12px 24px 12px !important;
padding: 24px 14px !important;
box-shadow: 6px 6px 0px 0px #c0d8c0 !important;
}
button[class*="primary"], button[class*="secondary"],
button.lg {
width: 100% !important;
padding: 14px 18px !important;
font-size: 14px !important;
}
label span, .label-wrap span {
font-size: 12px !important;
}
input[type="text"], input[type="number"], textarea, select {
font-size: 16px !important;
}
}
"""
with gr.Blocks() as demo:
    # Build the whole UI. Gradio lays components out in creation order, so the
    # statements below are kept in their original sequence.

    # Hero section + global reset button.
    gr.HTML("""
<div class="hero-section">
<h1>Sound Classify</h1>
<p>Streaming ESN &#183; Record &#183; Learn &#183; Predict</p>
</div>
""")
    reset_btn = gr.Button("Reset All", size="sm")
    with gr.Tabs():
        # -- Collect tab (label management + recording + data list combined) --
        with gr.Tab("収集"):
            # Add a new label.
            gr.Markdown("### ラベル追加")
            with gr.Row():
                label_box = gr.Textbox(label="新ラベル", placeholder="例: yes", scale=3)
                add_btn = gr.Button("追加", size="lg", scale=1)
            # Record & add a sample (saved automatically when recording stops).
            gr.Markdown("### 録音")
            label_dd = gr.Radio(choices=LABELS, label="ラベル選択", interactive=True, elem_classes=["diamond-radio"])
            with gr.Column(elem_id="auto_rec_area"):
                audio_rec = gr.Audio(sources=["microphone"], type="numpy", label="マイク(録音→自動停止→自動保存)")
                gr.HTML('<div id="rec_status_js" class="rec-status">待機中... 録音ボタンを押してください</div>')
                rec_status_md = gr.Markdown("", elem_classes=["rec-status"])
            undo_btn = gr.Button("Undo", size="lg")
            # Data list & editing.
            gr.Markdown("### データ一覧")
            table = gr.Dataframe(
                headers=["idx", "label", "sec"],
                value=dataset_table(),
                datatype=["number", "str", "number"],
                row_count=(6, "dynamic"),
                column_count=(3, "fixed"),
                interactive=False,
                elem_id="data_table"
            )
            # Row index (into DATA) of the currently selected table row, or None.
            selected_idx_state = gr.State(None)
            replay_audio = gr.Audio(type="numpy", label="選択した音声(再生)", interactive=False)
            relabel_dd = gr.Radio(choices=LABELS, label="ラベル修正", interactive=True, elem_classes=["diamond-radio"])
            with gr.Row():
                relabel_btn = gr.Button("ラベル更新", size="lg")
                del_btn = gr.Button("削除", size="lg")
        # -- Train tab --
        with gr.Tab("学習"):
            trials = gr.Slider(3, 30, value=8, step=1, label="Trials")
            train_btn = gr.Button("学習", variant="primary", size="lg")
            train_log = gr.Textbox(label="学習ログ", interactive=False, lines=2)
        # -- Inference tab --
        with gr.Tab("推論"):
            infer_audio = gr.Audio(sources=["microphone"], type="numpy",
                                   label="マイク(録音→無音自動停止→判定)")
            gr.HTML('<div class="rec-status">録音ボタンを押して話してください(無音で自動停止)</div>')
            pred_box = gr.Textbox(label="推定", interactive=False)
            prob_box = gr.Label(label="確率", num_top_classes=10)

    # ---- wiring ----

    def _clear_selection():
        """Reset the row selection and the widgets derived from it.

        Returns values for (selected_idx_state, replay_audio, relabel_dd).
        """
        return None, None, gr.update(value=None)

    add_btn.click(add_label_cb, inputs=[label_box], outputs=[label_dd, table, relabel_dd])
    # Undo removes a row, so a remembered selection may now point past the end
    # of DATA (or at a different sample); drop it so relabel/delete can't act on
    # a stale index and the replay player doesn't keep the removed audio.
    undo_btn.click(undo_last_cb, inputs=[], outputs=[table]).then(
        _clear_selection, inputs=None,
        outputs=[selected_idx_state, replay_audio, relabel_dd],
    )
    # reset_all_cb does not output replay_audio; clear it afterwards so the
    # wiped data can no longer be played back.
    reset_btn.click(
        reset_all_cb, inputs=[],
        outputs=[table, label_dd, relabel_dd, audio_rec, selected_idx_state],
    ).then(lambda: None, inputs=None, outputs=[replay_audio])
    # Auto-save when recording completes (stop_recording = stop button pressed).
    audio_rec.stop_recording(auto_add_sample_cb, inputs=[audio_rec, label_dd], outputs=[table, audio_rec, rec_status_md])

    def _select_and_store(evt: gr.SelectData):
        """Handle a click on the data table.

        Returns (row index, that row's DATA["audio"], relabel radio update set to
        the row's DATA["label"]). All three are cleared when the event carries no
        index or the index does not map to an existing DATA row.
        """
        if evt is None or evt.index is None:
            return _clear_selection()
        # A list/tuple index is treated as (row, col); only the row is used.
        if isinstance(evt.index, (tuple, list)):
            row = int(evt.index[0])
        else:
            row = int(evt.index)
        if row < 0 or row >= len(DATA):
            return _clear_selection()
        item = DATA[row]
        return row, item["audio"], gr.update(value=item["label"])

    # Selecting a row -> store its index, load it into the replay player and
    # preselect its label in the relabel radio.
    table.select(_select_and_store, inputs=None, outputs=[selected_idx_state, replay_audio, relabel_dd])
    relabel_btn.click(relabel_selected_cb, inputs=[selected_idx_state, relabel_dd], outputs=[table])
    # delete_selected_cb does not output replay_audio; clear it afterwards so the
    # deleted sample cannot still be played.
    del_btn.click(
        delete_selected_cb, inputs=[selected_idx_state],
        outputs=[table, selected_idx_state, relabel_dd],
    ).then(lambda: None, inputs=None, outputs=[replay_audio])
    train_btn.click(train_random, inputs=[trials], outputs=[train_log])
    # Inference: run once on the whole clip when recording stops.
    infer_audio.stop_recording(
        infer_on_stop,
        inputs=[infer_audio],
        outputs=[pred_box, prob_box, infer_audio],
    )
if __name__ == "__main__":
    # Light base theme recoloured to the app's dark-ink / sage-green palette.
    # Inputs keep a white fill and green border in dark mode as well.
    app_theme = gr.themes.Base(
        text_size=gr.themes.sizes.text_md,
        font=[gr.themes.GoogleFont("Inter"), gr.themes.GoogleFont("Noto Sans JP")],
    )
    app_theme = app_theme.set(
        body_text_color="#1a1a1a",
        body_text_color_subdued="#4a5a4a",
        block_label_text_color="#2a3a2a",
        block_title_text_color="#1a1a1a",
        checkbox_label_text_color="#1a1a1a",
        table_text_color="#1a1a1a",
        link_text_color="#3a5a3a",
        color_accent_soft="#c8e0d0",
        input_background_fill="#ffffff",
        input_background_fill_dark="#ffffff",
        input_border_color="#8aaa8a",
        input_border_color_dark="#8aaa8a",
    )
    # Listen on all interfaces at port 7860; CSS and HEAD are the custom
    # stylesheet and <head> markup defined earlier in this file.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        css=CSS,
        head=HEAD,
        theme=app_theme,
    )