dusen0528 commited on
Commit
02117d5
·
verified ·
1 Parent(s): fc2e8fe

Upload folder using huggingface_hub

Browse files
README.md CHANGED
@@ -1,12 +1,12 @@
1
  ---
2
- title: Kws Fp Test
3
- emoji: ⚡
4
- colorFrom: yellow
5
- colorTo: gray
6
  sdk: gradio
7
- sdk_version: 6.7.0
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: KWS FP Test
3
+ emoji: 🎤
4
+ colorFrom: green
5
+ colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: 6.5.1
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
+ KWS FP Test Space
app.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from gradio_kws_test import build_demo
2
+
3
+ demo = build_demo()
4
+
5
+ if __name__ == "__main__":
6
+ demo.launch()
checkpoint_io.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # checkpoint_io.py: KWS ์ฒดํฌํฌ์ธํŠธ Safetensors ์ €์žฅ/๋กœ๋“œ + ๊ธฐ์กด .pth ํ˜ธํ™˜.
2
+ # Single Source of Truth: ํ•™์Šต์€ .safetensors + training_config.json, ๋กœ๋“œ๋Š” .safetensors ์šฐ์„  ํ›„ .pth ํด๋ฐฑ.
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import os
7
+ from typing import Any, Dict, List, Optional, Tuple
8
+
9
+ import torch
10
+
11
+
12
def resolve_checkpoint_path(path: str) -> Tuple[str, str]:
    """
    Resolve a checkpoint location to ``(state_file_path, config_file_path)``.

    A ``.safetensors`` file is preferred over ``.pth``; the config is always
    ``training_config.json`` in the same directory.  When neither state file
    exists, the ``.pth`` path is returned anyway so callers get a
    deterministic target.
    """
    normalized = os.path.normpath(path)
    if os.path.isdir(normalized):
        folder, stem = normalized, "best_kws_model"
    else:
        folder = os.path.dirname(normalized) or "."
        stem = os.path.splitext(os.path.basename(normalized))[0] or "best_kws_model"
    config_file = os.path.join(folder, "training_config.json")
    # Candidate order encodes the format preference (safetensors first).
    candidates = [os.path.join(folder, stem + ext) for ext in (".safetensors", ".pth")]
    existing = next((c for c in candidates if os.path.isfile(c)), None)
    return (existing if existing is not None else candidates[-1]), config_file
30
+
31
+
32
def load_state_dict(path: str, map_location: Optional[Any] = None) -> Dict[str, torch.Tensor]:
    """
    Load only the model state_dict from *path* (or ``best_kws_model.*`` in the
    same directory).  ``.safetensors`` is preferred; a ``.pth`` checkpoint may
    be either a bare state_dict or a full checkpoint dict with a ``"model"`` key.

    Raises FileNotFoundError when no state file exists.
    """
    def _safetensors_device(loc: Optional[Any]) -> str:
        # safetensors wants a device *string*; translate torch.device / None.
        if loc is None:
            return "cpu"
        if isinstance(loc, torch.device):
            if loc.type != "cuda":
                return loc.type
            return "cuda" if loc.index is None else f"cuda:{loc.index}"
        return loc if isinstance(loc, str) else "cpu"

    state_path, _ = resolve_checkpoint_path(path)
    if not os.path.isfile(state_path):
        raise FileNotFoundError(f"์ฒดํฌํฌ์ธํŠธ ์—†์Œ: {state_path}")
    if state_path.endswith(".safetensors"):
        from safetensors.torch import load_file
        return load_file(state_path, device=_safetensors_device(map_location))
    loaded = torch.load(state_path, map_location=map_location)
    # A full training checkpoint stores the weights under "model"; otherwise
    # the file itself is the state_dict.
    if isinstance(loaded, dict) and "model" in loaded:
        return loaded["model"]
    return loaded
59
+
60
+
61
def load_label_map(path: str) -> Optional[Dict[str, int]]:
    """Return ``label_map`` from the sibling training_config.json, or None."""
    _, config_path = resolve_checkpoint_path(path)
    if not os.path.isfile(config_path):
        return None
    try:
        with open(config_path, "r", encoding="utf-8") as handle:
            return json.load(handle).get("label_map")
    except Exception:
        # A malformed / unreadable config is treated the same as a missing one.
        return None
72
+
73
+
74
def load_checkpoint(
    path: str,
    map_location: Optional[Any] = None,
    default_label_map: Optional[Dict[str, int]] = None,
) -> Tuple[Dict[str, torch.Tensor], Dict[str, int], List[str]]:
    """
    Return ``(state_dict, label_map, class_names)``.

    Label-map resolution order: training_config.json, then a ``label_map``
    stored inside a legacy ``.pth`` checkpoint, then *default_label_map*
    (falling back to the built-in 3-class mapping).
    """
    state = load_state_dict(path, map_location=map_location)
    label_map = load_label_map(path)
    if label_map is None:
        # Legacy .pth checkpoints may carry the label_map inline.
        state_path, _ = resolve_checkpoint_path(path)
        if state_path.endswith(".pth") and os.path.isfile(state_path):
            legacy = torch.load(state_path, map_location=map_location)
            if isinstance(legacy, dict) and "label_map" in legacy:
                label_map = legacy["label_map"]
    if label_map is None:
        label_map = default_label_map or {"normal": 0, "help_me": 1, "save_me": 2}
    # class_names ordered by class id; assumes ids are exactly 0..N-1
    # (a gap raises KeyError, surfacing the inconsistent mapping early).
    by_id = {idx: name for name, idx in label_map.items()}
    class_names: List[str] = [by_id[i] for i in range(len(label_map))]
    return state, label_map, class_names
97
+
98
+
99
def save_checkpoint(
    output_dir: str,
    state_dict: Dict[str, torch.Tensor],
    config_dict: Dict[str, Any],
    *,
    also_save_pth: bool = False,
) -> None:
    """
    Write ``best_kws_model.safetensors`` plus ``training_config.json`` into
    *output_dir*.  With ``also_save_pth=True`` a legacy ``best_kws_model.pth``
    (state_dict only) is written as well for backward compatibility.
    """
    os.makedirs(output_dir, exist_ok=True)
    from safetensors.torch import save_file

    save_file(state_dict, os.path.join(output_dir, "best_kws_model.safetensors"))
    with open(os.path.join(output_dir, "training_config.json"), "w", encoding="utf-8") as handle:
        json.dump(config_dict, handle, indent=2, ensure_ascii=False)
    if also_save_pth:
        torch.save(state_dict, os.path.join(output_dir, "best_kws_model.pth"))
gradio_kws_test.py ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import time
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ import gradio as gr
9
+ import librosa
10
+ import numpy as np
11
+
12
+ from kws_inference import KWSLongInference, SR_MODEL
13
+ from train_kws import WINDOW_SEC
14
+
15
+
16
+ def _default_model_path() -> str:
17
+ candidates = [
18
+ Path("./kws_models_fpfix/best_kws_model.safetensors"),
19
+ Path("./kws_models/best_kws_model.safetensors"),
20
+ Path("./kws_models_fixed/best_kws_model.safetensors"),
21
+ Path("./kws_models/best_kws_model.pth"),
22
+ ]
23
+ for p in candidates:
24
+ if p.exists():
25
+ return str(p)
26
+ return str(candidates[0])
27
+
28
+
29
# Cache of inference engines keyed by (resolved model path, step, rms threshold)
# so repeated UI calls with unchanged settings reuse the already-loaded model.
_ENGINE_CACHE: dict[tuple[str, float, float], KWSLongInference] = {}


def _get_engine(model_path: str, step_sec: float, rms_threshold: float) -> KWSLongInference:
    """Return a cached KWSLongInference for these settings, creating it on first use."""
    cache_key = (str(Path(model_path).resolve()), float(step_sec), float(rms_threshold))
    engine = _ENGINE_CACHE.get(cache_key)
    if engine is None:
        engine = KWSLongInference(
            checkpoint_path=model_path,
            step_sec=step_sec,
            rms_threshold=rms_threshold,
        )
        _ENGINE_CACHE[cache_key] = engine
    return engine
41
+
42
+
43
+ def _simulate_alerts(
44
+ window_results: list[Any],
45
+ threshold: float,
46
+ n_consecutive: int,
47
+ cooldown_sec: float,
48
+ ) -> list[dict[str, Any]]:
49
+ alerts: list[dict[str, Any]] = []
50
+ consecutive = 0
51
+ last_alert_sec = -1e9
52
+ for w in window_results:
53
+ help_p = float(w.probs.get("help_me", 0.0))
54
+ save_p = float(w.probs.get("save_me", 0.0))
55
+ kw_p = max(help_p, save_p)
56
+ kw_label = "save_me" if save_p >= help_p else "help_me"
57
+ in_cooldown = (w.start_sec - last_alert_sec) < cooldown_sec
58
+ if in_cooldown:
59
+ consecutive = 0
60
+ continue
61
+ if kw_p >= threshold:
62
+ consecutive += 1
63
+ else:
64
+ consecutive = 0
65
+ if consecutive >= n_consecutive:
66
+ alerts.append(
67
+ {
68
+ "t_sec": round(float(w.start_sec), 2),
69
+ "label": kw_label,
70
+ "score": round(float(kw_p), 4),
71
+ "rule": f"{n_consecutive}x >= {threshold:.2f}",
72
+ }
73
+ )
74
+ last_alert_sec = float(w.start_sec)
75
+ consecutive = 0
76
+ return alerts
77
+
78
+
79
def run_test(
    audio_path: str,
    model_path: str,
    step_sec: float,
    rms_threshold: float,
    alert_threshold: float,
    n_consecutive: int,
    cooldown_sec: float,
):
    """Run file-based long-audio inference and build all UI outputs.

    Returns ``(summary_markdown, alerts_markdown, table_rows, details_dict)``.
    Raises gr.Error (shown to the user by Gradio) when the audio or the model
    file is missing.
    """
    if not audio_path:
        raise gr.Error("์˜ค๋””์˜ค ํŒŒ์ผ(๋˜๋Š” ๋งˆ์ดํฌ ์ž…๋ ฅ)์„ ๋„ฃ์–ด์ฃผ์„ธ์š”.")
    if not model_path or not Path(model_path).exists():
        raise gr.Error(f"๋ชจ๋ธ ํŒŒ์ผ์ด ์—†์Šต๋‹ˆ๋‹ค: {model_path}")

    # Engine instances are cached per (path, step, threshold) combination.
    engine = _get_engine(model_path=model_path, step_sec=step_sec, rms_threshold=rms_threshold)
    result = engine.predict_long(audio_path)

    # One table row per sliding window: times, prediction, per-class probabilities.
    rows: list[list[Any]] = []
    for i, w in enumerate(result.window_results, start=1):
        rows.append(
            [
                i,
                round(float(w.start_sec), 3),
                round(float(w.end_sec), 3),
                w.class_name,
                round(float(w.confidence), 4),
                round(float(w.probs.get("normal", 0.0)), 4),
                round(float(w.probs.get("help_me", 0.0)), 4),
                round(float(w.probs.get("save_me", 0.0)), 4),
            ]
        )

    # Offline replay of the alerting rule (threshold + consecutive hits + cooldown).
    alerts = _simulate_alerts(
        window_results=result.window_results,
        threshold=alert_threshold,
        n_consecutive=n_consecutive,
        cooldown_sec=cooldown_sec,
    )

    agg = result.aggregated_probs
    summary = (
        f"### ์ถ”๋ก  ์š”์•ฝ\n"
        f"- Aggregated label: `{result.aggregated_label}`\n"
        f"- Duration: `{result.duration_sec:.2f}s`\n"
        f"- Aggregated probs: normal `{agg.get('normal', 0.0):.4f}` | "
        f"help_me `{agg.get('help_me', 0.0):.4f}` | "
        f"save_me `{agg.get('save_me', 0.0):.4f}`\n"
        f"- Windows: `{len(result.window_results)}` (step `{step_sec:.2f}s`, rms_threshold `{rms_threshold:.4f}`)"
    )

    if alerts:
        alert_lines = [
            f"- t={a['t_sec']:.2f}s | `{a['label']}` | score={a['score']:.4f} | {a['rule']}"
            for a in alerts
        ]
        alerts_md = "### ๊ฒฝ๋ณด ์‹œ๋ฎฌ๋ ˆ์ด์…˜\n" + "\n".join(alert_lines)
    else:
        alerts_md = "### ๊ฒฝ๋ณด ์‹œ๋ฎฌ๋ ˆ์ด์…˜\n- ์กฐ๊ฑด์„ ๋งŒ์กฑํ•œ ๊ฒฝ๋ณด ์—†์Œ"

    # Raw values for the JSON panel (mirrors the markdown summary).
    details = {
        "aggregated_label": result.aggregated_label,
        "aggregated_probs": result.aggregated_probs,
        "duration_sec": result.duration_sec,
        "num_windows": len(result.window_results),
        "alerts": alerts,
    }
    return summary, alerts_md, rows, details
146
+
147
+
148
+ def _init_stream_state() -> dict[str, Any]:
149
+ return {
150
+ "buffer": np.array([], dtype=np.float32),
151
+ "consecutive": 0,
152
+ "last_alert_time": -1e9,
153
+ "events": [],
154
+ "last_probs": {"normal": 1.0, "help_me": 0.0, "save_me": 0.0},
155
+ "last_decision": "NORMAL",
156
+ }
157
+
158
+
159
+ def _normalize_chunk(chunk: Any) -> tuple[int, np.ndarray] | None:
160
+ if chunk is None:
161
+ return None
162
+ if isinstance(chunk, np.ndarray):
163
+ y = np.asarray(chunk, dtype=np.float32)
164
+ if y.ndim == 2:
165
+ y = y.mean(axis=1 if y.shape[1] <= 8 else 0)
166
+ return SR_MODEL, y.astype(np.float32)
167
+ if isinstance(chunk, (tuple, list)) and len(chunk) == 2:
168
+ sr, y = chunk
169
+ y = np.asarray(y, dtype=np.float32)
170
+ if y.ndim == 2:
171
+ y = y.mean(axis=1 if y.shape[1] <= 8 else 0)
172
+ if y.dtype == np.int16:
173
+ y = y.astype(np.float32) / 32768.0
174
+ elif y.dtype == np.int32:
175
+ y = y.astype(np.float32) / 2147483648.0
176
+ return int(sr), y.astype(np.float32)
177
+ if isinstance(chunk, dict):
178
+ # gradio ๋ฒ„์ „์— ๋”ฐ๋ผ {"sample_rate": ..., "data": ...} ํ˜น์€ {"path": ...} ํ˜•ํƒœ๊ฐ€ ์˜ฌ ์ˆ˜ ์žˆ์Œ
179
+ if "sample_rate" in chunk and ("data" in chunk or "array" in chunk):
180
+ sr = int(chunk.get("sample_rate", SR_MODEL))
181
+ y = np.asarray(chunk.get("data", chunk.get("array")), dtype=np.float32)
182
+ if y.ndim == 2:
183
+ y = y.mean(axis=1 if y.shape[1] <= 8 else 0)
184
+ return sr, y.astype(np.float32)
185
+ if "path" in chunk and chunk["path"]:
186
+ try:
187
+ y, sr = librosa.load(str(chunk["path"]), sr=None, mono=True)
188
+ return int(sr), y.astype(np.float32)
189
+ except Exception:
190
+ return None
191
+ return None
192
+
193
+
194
def stream_infer(
    state: dict[str, Any] | None,
    chunk: Any,
    model_path: str,
    step_sec: float,
    rms_threshold: float,
    alert_threshold: float,
    margin_threshold: float,
    n_consecutive: int,
    cooldown_sec: float,
):
    """Process one live-microphone chunk and update the session state.

    Maintains a rolling audio buffer (capped at WINDOW_SEC), runs the KWS
    engine on it, and applies the alert rule: keyword score >= alert_threshold
    AND (keyword - normal) margin >= margin_threshold for n_consecutive calls,
    with a wall-clock cooldown after each alert.

    Returns ``(state, status_text, decision, probs_dict, events_list)``.
    Decision is one of NORMAL / HOLD / COOLDOWN / CANDIDATE / ALERT.
    """
    # Gradio may hand back a stale/empty state; fall back to a fresh one.
    st = state if isinstance(state, dict) else _init_stream_state()
    parsed = _normalize_chunk(chunk)
    if parsed is None:
        probs = st.get("last_probs", {"normal": 1.0, "help_me": 0.0, "save_me": 0.0})
        return st, "๋Œ€๊ธฐ ์ค‘...", st.get("last_decision", "NORMAL"), probs, list(st.get("events", []))

    sr, y = parsed
    if len(y) == 0:
        probs = st.get("last_probs", {"normal": 1.0, "help_me": 0.0, "save_me": 0.0})
        return st, "์ž…๋ ฅ ์—†์Œ", st.get("last_decision", "NORMAL"), probs, list(st.get("events", []))

    if sr != SR_MODEL:
        y = librosa.resample(y, orig_sr=sr, target_sr=SR_MODEL)

    # Rolling buffer: keep only the most recent WINDOW_SEC of audio.
    buffer = np.concatenate((st["buffer"], y))
    max_len = int(SR_MODEL * WINDOW_SEC)
    if len(buffer) > max_len:
        buffer = buffer[-max_len:]
    st["buffer"] = buffer

    # Wait until at least 0.6 s of audio has accumulated before inferring.
    if len(buffer) < int(SR_MODEL * 0.6):
        probs = st.get("last_probs", {"normal": 1.0, "help_me": 0.0, "save_me": 0.0})
        return st, "๋ฐ์ดํ„ฐ ์ˆ˜์ง‘ ์ค‘...", st.get("last_decision", "NORMAL"), probs, list(st.get("events", []))

    engine = _get_engine(model_path=model_path, step_sec=step_sec, rms_threshold=rms_threshold)
    result = engine.predict_long((SR_MODEL, buffer))
    probs = {
        "normal": float(result.aggregated_probs.get("normal", 0.0)),
        "help_me": float(result.aggregated_probs.get("help_me", 0.0)),
        "save_me": float(result.aggregated_probs.get("save_me", 0.0)),
    }
    st["last_probs"] = probs

    now = time.time()
    help_p, save_p = probs["help_me"], probs["save_me"]
    normal_p = probs["normal"]
    kw_score = max(help_p, save_p)
    kw_label = "save_me" if save_p >= help_p else "help_me"
    # Margin filters cases where a keyword barely beats "normal".
    margin = kw_score - normal_p
    in_cooldown = (now - float(st["last_alert_time"])) < cooldown_sec

    meets_score = kw_score >= alert_threshold
    meets_margin = margin >= margin_threshold

    # The streak only advances when BOTH conditions hold outside cooldown.
    if not in_cooldown and meets_score and meets_margin:
        st["consecutive"] = int(st["consecutive"]) + 1
    else:
        st["consecutive"] = 0

    status = (
        f"normal={probs['normal']:.3f} | help_me={probs['help_me']:.3f} | "
        f"save_me={probs['save_me']:.3f} | margin={margin:.3f} | "
        f"consec={st['consecutive']}/{n_consecutive}"
    )
    if in_cooldown:
        remain = max(0.0, cooldown_sec - (now - float(st["last_alert_time"])))
        status = f"์ฟจ๋‹ค์šด {remain:.1f}s | " + status

    decision = "NORMAL"
    probs_txt = f"n={normal_p:.3f} h={help_p:.3f} s={save_p:.3f} m={margin:.3f}"

    # HOLD: score was high enough but the margin filter rejected it (logged).
    if not in_cooldown and meets_score and not meets_margin:
        decision = "HOLD"
        event = (
            f"{time.strftime('%H:%M:%S')} HOLD {kw_label} "
            f"score={kw_score:.3f} {probs_txt}"
        )
        st["events"] = ([event] + list(st.get("events", [])))[:30]
    elif in_cooldown:
        decision = "COOLDOWN"
    elif meets_score and meets_margin:
        decision = "CANDIDATE"

    # ALERT: streak reached the configured length; start the cooldown.
    if not in_cooldown and int(st["consecutive"]) >= int(n_consecutive):
        st["last_alert_time"] = now
        st["consecutive"] = 0
        event = f"{time.strftime('%H:%M:%S')} ALERT {kw_label} score={kw_score:.3f} {probs_txt}"
        st["events"] = ([event] + list(st.get("events", [])))[:30]
        status = "ALERT ๋ฐœ์ƒ | " + status
        decision = "ALERT"

    st["last_decision"] = decision
    return st, status, decision, probs, list(st.get("events", []))
288
+
289
+
290
def build_demo() -> gr.Blocks:
    """Build the Gradio Blocks UI: shared tuning sliders plus two tabs
    (file-based testing and live microphone streaming)."""
    with gr.Blocks(title="KWS FP Test Page") as demo:
        gr.Markdown("# KWS FP Test Page")
        gr.Markdown("ํŒŒ์ผ ํ…Œ์ŠคํŠธ + ๋ธŒ๋ผ์šฐ์ € ๋งˆ์ดํฌ ์‹ค์‹œ๊ฐ„ ๋ถ„์„")

        model_path = gr.Textbox(label="๋ชจ๋ธ ๊ฒฝ๋กœ", value=_default_model_path())

        # Tuning knobs shared by both tabs.
        with gr.Row():
            step_sec = gr.Slider(label="์Šฌ๋ผ์ด๋”ฉ ์Šคํ…(์ดˆ)", minimum=0.1, maximum=1.0, value=0.25, step=0.05)
            rms_threshold = gr.Slider(label="RMS ์ž„๊ณ„๊ฐ’", minimum=0.0, maximum=0.05, value=0.005, step=0.001)
            alert_threshold = gr.Slider(label="๊ฒฝ๋ณด ์ž„๊ณ„๊ฐ’", minimum=0.5, maximum=0.99, value=0.9, step=0.01)
            margin_threshold = gr.Slider(label="๋งˆ์ง„ ์ž„๊ณ„๊ฐ’(ํ‚ค์›Œ๋“œ-normal)", minimum=0.0, maximum=0.6, value=0.2, step=0.01)
            n_consecutive = gr.Slider(label="์—ฐ์† ํšŸ์ˆ˜", minimum=1, maximum=5, value=3, step=1)
            cooldown_sec = gr.Slider(label="์ฟจ๋‹ค์šด(์ดˆ)", minimum=0, maximum=10, value=5, step=0.5)

        # Tab 1: upload/record a file, run one long-audio inference pass.
        with gr.Tab("ํŒŒ์ผ ํ…Œ์ŠคํŠธ"):
            audio = gr.Audio(label="ํ…Œ์ŠคํŠธ ์˜ค๋””์˜ค", type="filepath", sources=["upload", "microphone"])
            run_btn = gr.Button("ํ…Œ์ŠคํŠธ ์‹คํ–‰", variant="primary")
            summary = gr.Markdown()
            alerts_md = gr.Markdown()
            table = gr.Dataframe(
                headers=["idx", "start_sec", "end_sec", "pred", "conf", "normal", "help_me", "save_me"],
                datatype=["number", "number", "number", "str", "number", "number", "number", "number"],
                label="์œˆ๋„์šฐ๋ณ„ ๊ฒฐ๊ณผ",
            )
            details = gr.JSON(label="์ƒ์„ธ JSON")
            run_btn.click(
                fn=run_test,
                inputs=[audio, model_path, step_sec, rms_threshold, alert_threshold, n_consecutive, cooldown_sec],
                outputs=[summary, alerts_md, table, details],
            )

        # Tab 2: streaming microphone input fed through stream_infer.
        with gr.Tab("์‹ค์‹œ๊ฐ„ ๋งˆ์ดํฌ"):
            gr.Markdown("๋งˆ์ดํฌ๋ฅผ ์ผ  ์ƒํƒœ์—์„œ ์‹ค์‹œ๊ฐ„์œผ๋กœ ํ™•๋ฅ /๊ฒฝ๋ณด ์ด๋ ฅ์„ ๋ด…๋‹ˆ๋‹ค.")
            # NOTE(review): the default state dict is built once at app startup;
            # presumably Gradio copies gr.State defaults per session — confirm.
            st = gr.State(value=_init_stream_state())
            live_audio = gr.Audio(label="๋งˆ์ดํฌ ์‹ค์‹œ๊ฐ„ ์ž…๋ ฅ", type="numpy", sources=["microphone"], streaming=True)
            live_status = gr.Textbox(label="์‹ค์‹œ๊ฐ„ ์ƒํƒœ")
            live_decision = gr.Textbox(label="์ตœ์ข… ํŒ๋‹จ")
            live_probs = gr.JSON(label="ํ˜„์žฌ ํ™•๋ฅ ")
            live_events = gr.JSON(label="๊ฒฝ๋ณด ์ด๋ ฅ(์ตœ์‹ ์ˆœ)")
            live_audio.stream(
                fn=stream_infer,
                inputs=[
                    st,
                    live_audio,
                    model_path,
                    step_sec,
                    rms_threshold,
                    alert_threshold,
                    margin_threshold,
                    n_consecutive,
                    cooldown_sec,
                ],
                outputs=[st, live_status, live_decision, live_probs, live_events],
                show_progress=False,
                time_limit=300,
            )
    return demo
348
+
349
+
350
def main() -> None:
    """CLI entry point: parse host/port/share flags and launch the demo."""
    cli = argparse.ArgumentParser(description="KWS FP ํ…Œ์ŠคํŠธ์šฉ Gradio ํŽ˜์ด์ง€")
    cli.add_argument("--host", type=str, default="0.0.0.0")
    cli.add_argument("--port", type=int, default=7861)
    cli.add_argument("--share", action="store_true")
    options = cli.parse_args()
    build_demo().launch(server_name=options.host, server_port=options.port, share=options.share)


if __name__ == "__main__":
    main()
kws_inference.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ KWS ๊ธด ์˜ค๋””์˜ค ์ถ”๋ก : ์Šฌ๋ผ์ด๋”ฉ ์œˆ๋„์šฐ + Max-over-windows ์ง‘๊ณ„.
3
+ [์—…๋ฐ์ดํŠธ] ์†Œ๋ฆฌ ํฌ๊ธฐ(RMS) ํ•„ํ„ฐ๋ง ์ถ”๊ฐ€๋กœ ๋ฌด์Œ/๋…ธ์ด์ฆˆ ๊ตฌ๊ฐ„ ์˜คํƒ ๋ฐฉ์ง€.
4
+ """
5
+ from __future__ import annotations
6
+
7
+ import os
8
+ from dataclasses import dataclass
9
+ from typing import Any, Dict, List, Optional, Tuple, Union
10
+
11
+ import librosa
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn.functional as F
15
+
16
+ # train_kws์™€ ๋™์ผ ํŒŒ๋ผ๋ฏธํ„ฐ๋กœ melยท๋ชจ๋ธ ํ˜ธํ™˜ ์œ ์ง€
17
+ from train_kws import (
18
+ HOP_LENGTH,
19
+ KWSModel,
20
+ MAX_TIME_FRAMES,
21
+ N_FFT,
22
+ N_MELS,
23
+ SR_MODEL,
24
+ WINDOW_SEC,
25
+ )
26
+
27
+ # ํ•™์Šต ์‹œ KWSDataset._preprocess์™€ ๋™์ผ (์ผ๋ฐ˜์Œ์„ฑโ†’์œ„๊ธ‰ ์˜คํƒ ๋ฐฉ์ง€)
28
+ RMS_NORM_TARGET = 0.05
29
+ PRE_EMPHASIS = 0.97
30
+
31
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
+
33
@dataclass
class WindowResult:
    """Inference result for a single sliding window."""
    start_sec: float          # window start time within the source audio (seconds)
    end_sec: float            # window end time (seconds)
    pred_id: int              # predicted class id (argmax over probs)
    class_name: str           # class name corresponding to pred_id
    probs: Dict[str, float]   # per-class probabilities, keyed by class name
    confidence: float         # probability of the predicted class
42
+
43
@dataclass
class LongAudioResult:
    """Aggregated inference result over an entire (long) audio clip."""
    aggregated_label: str               # label chosen from max-over-windows probabilities
    aggregated_probs: Dict[str, float]  # per-class max probability across windows
    aggregated_pred_id: int             # class id of aggregated_label
    window_results: List[WindowResult]  # per-window detail, in time order
    duration_sec: float                 # total audio duration in seconds
51
+
52
def _load_checkpoint(path: str) -> Tuple[Dict[str, torch.Tensor], Dict[str, int], List[str]]:
    """Thin wrapper: delegate to checkpoint_io.load_checkpoint on the global DEVICE."""
    from checkpoint_io import load_checkpoint

    return load_checkpoint(path, map_location=DEVICE)
56
+
57
def _preprocess_wav(y: np.ndarray) -> np.ndarray:
    """Match training-time KWSDataset._preprocess: RMS-normalize, then pre-emphasize.

    Keeping inference preprocessing identical to training reduces false
    positives on ordinary speech.  Empty input is returned unchanged.
    """
    wav = np.asarray(y, dtype=np.float32)
    if len(wav) == 0:
        return wav
    loudness = np.sqrt(np.mean(wav ** 2)) + 1e-8  # epsilon guards all-zero input
    wav = wav * (RMS_NORM_TARGET / loudness)
    emphasized = np.append(wav[0], wav[1:] - PRE_EMPHASIS * wav[:-1])
    return emphasized.astype(np.float32)
66
+
67
+
68
def _audio_to_mel(y: np.ndarray, sr: int) -> np.ndarray:
    """Convert a waveform to a normalized log-mel spectrogram.

    Resamples to SR_MODEL if needed, applies training-matched preprocessing,
    then returns an array of shape (N_MELS, time) with values in [0, 1]
    (dB clipped to [-80, 0] and rescaled).
    """
    wav = y.astype(np.float32)
    if sr != SR_MODEL:
        wav = librosa.resample(wav, orig_sr=sr, target_sr=SR_MODEL)
    wav = _preprocess_wav(wav)
    power = librosa.feature.melspectrogram(
        y=wav, sr=SR_MODEL, n_mels=N_MELS, n_fft=N_FFT, hop_length=HOP_LENGTH
    )
    db = np.clip(librosa.power_to_db(power, ref=1.0), -80.0, 0.0)
    return ((db + 80.0) / 80.0).astype(np.float32)
78
+
79
def _mel_window_to_tensor(mel: np.ndarray, start_f: int, end_f: int) -> torch.Tensor:
    """Slice ``mel[:, start_f:end_f]``, pad/crop to MAX_TIME_FRAMES, and return
    a (1, 1, n_mels, T) float tensor on DEVICE ready for the model."""
    window = mel[:, start_f:end_f]
    width = window.shape[1]
    if width < MAX_TIME_FRAMES:
        # Zero-pad short tails on the right to keep a fixed model input size.
        window = np.pad(window, ((0, 0), (0, MAX_TIME_FRAMES - width)), mode="constant", constant_values=0)
    else:
        window = window[:, :MAX_TIME_FRAMES]
    return torch.from_numpy(window).float().unsqueeze(0).unsqueeze(0).to(DEVICE)
88
+
89
def _frame_to_sec(frame_index: int) -> float:
    """Convert an STFT frame index to seconds at the model sample rate."""
    seconds = (frame_index * HOP_LENGTH) / SR_MODEL
    return float(seconds)
91
+
92
class KWSLongInference:
    """Sliding-window KWS inference over long audio, with max-over-windows
    aggregation and an RMS loudness gate that forces quiet windows to "normal"."""

    def __init__(
        self,
        checkpoint_path: str,
        window_sec: float = WINDOW_SEC,
        step_sec: float = 0.25,
        rms_threshold: float = 0.005,  # loudness gate threshold (deliberately low)
    ) -> None:
        self.checkpoint_path = checkpoint_path
        self.window_sec = window_sec
        self.step_sec = step_sec
        self.rms_threshold = rms_threshold

        # Load weights + label metadata, then build the model in eval mode.
        state, label_map, class_names = _load_checkpoint(checkpoint_path)
        self.label_map = label_map
        self.class_names = class_names
        self.num_classes = len(class_names)
        self.model = KWSModel(num_classes=self.num_classes).to(DEVICE)
        self.model.load_state_dict(state, strict=True)
        self.model.eval()

        # Window/step sizes expressed in mel frames.
        self.window_frames = MAX_TIME_FRAMES
        self.step_frames = max(1, int(round(step_sec * SR_MODEL / HOP_LENGTH)))

    def _get_windows(self, total_frames: int) -> List[Tuple[int, int]]:
        """Return (start, end) frame pairs for every full window; a trailing
        partial window is dropped (audio shorter than one window yields none)."""
        windows: List[Tuple[int, int]] = []
        start = 0
        while start + self.window_frames <= total_frames:
            windows.append((start, start + self.window_frames))
            start += self.step_frames
        return windows

    def predict_long(self, audio: Union[str, Tuple[int, np.ndarray]]) -> LongAudioResult:
        """Run inference over *audio* (a file path or (sr, samples) pair) and
        return per-window results plus max-over-windows aggregation."""
        if isinstance(audio, str):
            y, sr = librosa.load(audio, sr=SR_MODEL)
        else:
            sr, y = audio
            y = np.asarray(y, dtype=np.float32)
            if sr != SR_MODEL:
                y = librosa.resample(y, orig_sr=sr, target_sr=SR_MODEL)

        duration_sec = len(y) / SR_MODEL
        mel = _audio_to_mel(y, SR_MODEL)
        total_frames = mel.shape[1]
        windows = self._get_windows(total_frames)

        all_probs: List[np.ndarray] = []
        window_results: List[WindowResult] = []

        with torch.no_grad():
            for start_f, end_f in windows:
                # RMS is computed on the raw (pre-normalization) samples of
                # this window so the gate reflects the true input loudness.
                start_sample = int(start_f * HOP_LENGTH)
                end_sample = int(end_f * HOP_LENGTH)
                y_chunk = y[start_sample:end_sample]
                rms = np.sqrt(np.mean(y_chunk**2)) if len(y_chunk) > 0 else 0

                start_sec = _frame_to_sec(start_f)
                end_sec = _frame_to_sec(end_f)

                if rms < self.rms_threshold:
                    # Loudness gate: silent/noise-only windows are forced to
                    # "normal" with probability 1 to avoid false alarms.
                    probs = np.zeros(self.num_classes)
                    probs[self.label_map.get("normal", 0)] = 1.0
                    pred_id = self.label_map.get("normal", 0)
                else:
                    # Model inference on this window's mel slice.
                    t = _mel_window_to_tensor(mel, start_f, end_f)
                    logits = self.model(t)
                    probs = F.softmax(logits, dim=1).cpu().numpy()[0]
                    pred_id = int(np.argmax(probs))

                all_probs.append(probs)
                conf = float(probs[pred_id])
                probs_dict = {name: float(probs[i]) for i, name in enumerate(self.class_names)}

                window_results.append(
                    WindowResult(
                        start_sec=start_sec,
                        end_sec=end_sec,
                        pred_id=pred_id,
                        class_name=self.class_names[pred_id],
                        probs=probs_dict,
                        confidence=conf,
                    )
                )

        # Audio shorter than one window produces no windows at all.
        if not all_probs:
            return LongAudioResult("normal", {}, 0, [], duration_sec)

        # Max-over-windows aggregation: each class keeps its best window score.
        stacked = np.array(all_probs)
        max_per_class = np.max(stacked, axis=0)
        agg_pred_id = int(np.argmax(max_per_class))
        agg_label = self.class_names[agg_pred_id]
        agg_probs = {name: float(max_per_class[i]) for i, name in enumerate(self.class_names)}

        return LongAudioResult(agg_label, agg_probs, agg_pred_id, window_results, duration_sec)
190
+
191
def run_long_inference(
    wav_path: str,
    checkpoint_path: str = "kws_final_model/best_kws_model.pth",
    step_sec: float = 0.25,
) -> LongAudioResult:
    """Convenience one-shot: build an engine for *checkpoint_path* and score *wav_path*."""
    one_shot = KWSLongInference(checkpoint_path=checkpoint_path, step_sec=step_sec)
    return one_shot.predict_long(wav_path)
kws_models_fpfix/best_kws_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:77212523c6e434385d36c1bd2a972f8a49706af0125357b4f6070d221f5d3283
3
+ size 1693516
kws_models_fpfix/training_config.json ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "manifest": "./dataset_hf/manifests/kws.jsonl",
3
+ "dataset_root": "./dataset_hf",
4
+ "data_dir": "./kws_data",
5
+ "output_dir": "./kws_models_fpfix",
6
+ "epochs": 80,
7
+ "lr": 0.0005,
8
+ "batch_size": 64,
9
+ "keywords": [
10
+ "help_me",
11
+ "save_me"
12
+ ],
13
+ "include_normal": true,
14
+ "normal_dir": "/home/dusen0528/LastResNet/data",
15
+ "trim_silence": false,
16
+ "rms_norm": 0.05,
17
+ "augment_noise_scale": 0.005,
18
+ "augment_shift": 0.1,
19
+ "augment_pitch": "0,2",
20
+ "augment_time_stretch": "0.9,1.1",
21
+ "spec_augment_on": true,
22
+ "fast_augment": true,
23
+ "mel_gpu": true,
24
+ "class_weight": "inverse",
25
+ "dropout": 0.3,
26
+ "weight_decay": 0.0001,
27
+ "early_stop_patience": 20,
28
+ "lr_scheduler_patience": 5,
29
+ "hf_repo": "",
30
+ "hf_token": "",
31
+ "wandb_project": "kws-fpfix",
32
+ "wandb_entity": "",
33
+ "amp": true,
34
+ "workers": 4,
35
+ "best_epoch": 51,
36
+ "val_recall_positive_min": 0.926829268292683,
37
+ "class_names": [
38
+ "normal",
39
+ "help_me",
40
+ "save_me"
41
+ ],
42
+ "label_map": {
43
+ "normal": 0,
44
+ "help_me": 1,
45
+ "save_me": 2
46
+ },
47
+ "normal_dir_added": 1000
48
+ }
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ gradio>=6.5.1
2
+ numpy
3
+ librosa
4
+ torch
5
+ safetensors
6
+ scikit-learn
7
+ wandb
8
+ tqdm
train_kws.py ADDED
@@ -0,0 +1,730 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Keyword Spotting (KWS) ๋ชจ๋ธ ํ•™์Šต ์Šคํฌ๋ฆฝํŠธ - Production Optimized
3
+ - WandB Sweep ์—ฐ๋™ (augment_shift, noise, class_weight ๋“ฑ)
4
+ - False Positive Rate (์˜คํƒ๋ฅ ) ๋ฉ”ํŠธ๋ฆญ ์ถ”๊ฐ€
5
+ - Audio Shift Augmentation ๊ตฌํ˜„
6
+ """
7
+ import os
8
+ import json
9
+ import random
10
+ from datetime import datetime
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.utils.data import Dataset, DataLoader
14
+ import numpy as np
15
+ import librosa
16
+ import argparse
17
+ from pathlib import Path
18
+ from sklearn.metrics import f1_score, precision_recall_fscore_support, accuracy_score, confusion_matrix
19
+ from sklearn.model_selection import train_test_split
20
+ import wandb
21
+ from tqdm import tqdm
22
+
23
# --- Global Config ---
SR_MODEL = 16000    # model sample rate (Hz)
N_MELS = 128        # mel filterbank size
N_FFT = 1024        # STFT window size (samples)
HOP_LENGTH = 512    # STFT hop (samples)
WINDOW_SEC = 1.5    # classification window length (seconds)
MAX_TIME_FRAMES = int(round((WINDOW_SEC * SR_MODEL - N_FFT) / HOP_LENGTH + 1))
# Number of raw samples needed for the fixed input frame count
# (fixed-length input for GPU-side mel computation).
MAX_AUDIO_SAMPLES = (MAX_TIME_FRAMES - 1) * HOP_LENGTH + N_FFT
32
+
33
def _compute_mel_fast(y: np.ndarray, sr: int) -> np.ndarray:
    """Compute a normalized log-mel spectrogram, shape (n_mels, time), values in [0, 1].

    Tries torchaudio (faster than librosa) and falls back to librosa on any
    failure; both paths clip to [-80, 0] dB and rescale to [0, 1].
    """
    try:
        import torchaudio

        # BUGFIX: torchaudio has no `functional.mel_spectrogram`; the old call
        # raised AttributeError so the "fast path" silently fell back to
        # librosa on every call.  Use the supported transforms API instead.
        transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sr, n_fft=N_FFT, hop_length=HOP_LENGTH, win_length=N_FFT,
            f_min=0.0, f_max=float(sr) / 2, n_mels=N_MELS, power=2.0,
        )
        mel = transform(torch.from_numpy(y).float().unsqueeze(0)).squeeze(0).numpy()
        mel_db = np.clip(10.0 * np.log10(np.maximum(mel, 1e-10)), -80.0, 0.0)
        return ((mel_db + 80.0) / 80.0).astype(np.float32)
    except Exception:
        # Fallback: librosa's power_to_db(ref=1.0) matches 10*log10(power).
        S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=N_MELS, n_fft=N_FFT, hop_length=HOP_LENGTH)
        S_db = np.clip(librosa.power_to_db(S, ref=1.0), -80.0, 0.0)
        return ((S_db + 80.0) / 80.0).astype(np.float32)
49
+
def _load_audio_fast(path: str) -> tuple[np.ndarray, int]:
    """Load an audio file as mono float32 at SR_MODEL Hz.

    Prefers torchaudio (typically faster to decode than librosa); any failure
    falls back to librosa. Returns ``(waveform, SR_MODEL)``.
    """
    try:
        import torchaudio
        wav, native_sr = torchaudio.load(path)
        # Down-mix multi-channel audio to mono before resampling.
        if wav.shape[0] > 1:
            wav = wav.mean(dim=0, keepdim=True)
        if native_sr != SR_MODEL:
            wav = torchaudio.functional.resample(wav, native_sr, SR_MODEL)
        return wav.squeeze(0).numpy().astype(np.float32), SR_MODEL
    except Exception:
        # librosa resamples to SR_MODEL on load.
        audio, loaded_sr = librosa.load(path, sr=SR_MODEL)
        return audio.astype(np.float32), loaded_sr
def str2bool(v):
    """Interpret *v* as a boolean (argparse-friendly).

    Real bools pass through; strings are truthy when they match a common
    affirmative spelling (case-insensitive).
    """
    if isinstance(v, bool):
        return v
    return v.lower() in {'yes', 'true', 't', 'y', '1'}
68
+ def _parse_range(s):
69
+ """'0.8,1.2' ๊ฐ™์€ ๋ฌธ์ž์—ด์„ (0.8, 1.2) ํŠœํ”Œ๋กœ ๋ณ€ํ™˜"""
70
+ if isinstance(s, list): return tuple(s)
71
+ try:
72
+ s = str(s).strip("[]").split(",")
73
+ return (float(s[0]), float(s[1]))
74
+ except: return (0, 0)
75
+
class KWSDataset(Dataset):
    """Keyword-spotting dataset yielding ``(features, label_id)`` pairs.

    Two output modes:
      * ``mel_gpu=False``: returns a ``(1, N_MELS, MAX_TIME_FRAMES)`` normalized
        mel spectrogram computed on the CPU (optionally SpecAugment-ed/cached).
      * ``mel_gpu=True``: returns a fixed-length raw waveform of
        ``MAX_AUDIO_SAMPLES`` samples; the mel spectrogram is computed per
        batch on the GPU by ``MelExtractorGPU`` in the training loop.

    ``items`` is a list of dicts with at least ``'wav_path'`` and ``'label_id'``.
    """
    def __init__(self, items, label_map=None, trim_silence=False, trim_top_db=30.0,
                 rms_norm_target=0.05, pre_emphasis=0.97, augment_noise_scale=0.0,
                 augment_time_stretch=(0.0, 0.0), augment_pitch=(0.0, 0.0),
                 augment_shift=0.0, spec_augment=(0, 0), preprocess_cache=None, mel_cache=None,
                 skip_heavy_augment=False, mel_gpu=False):
        self.label_map = label_map
        self.items = items
        self.trim_silence = trim_silence
        self.trim_top_db = trim_top_db
        self.rms_norm_target = rms_norm_target
        self.pre_emphasis = pre_emphasis
        # preprocess_cache: path -> (preprocessed waveform, sr), shared across datasets
        self.preprocess_cache = preprocess_cache
        self.mel_cache = mel_cache if not mel_gpu else None  # in mel_gpu mode the Dataset never computes mels
        self.skip_heavy_augment = skip_heavy_augment
        self.mel_gpu = mel_gpu

        # Augmentations
        self.augment_noise_scale = augment_noise_scale
        self.augment_time_stretch = augment_time_stretch
        self.augment_pitch = augment_pitch
        self.augment_shift = augment_shift
        self.spec_augment = spec_augment

    def __len__(self): return len(self.items)

    def _preprocess(self, y):
        """Deterministic waveform preprocessing: trim -> RMS norm -> pre-emphasis."""
        # 1. Silence trim (optional; usually kept False to avoid false positives)
        if self.trim_silence and len(y) > 0:
            y, _ = librosa.effects.trim(y, top_db=self.trim_top_db)
            if len(y) == 0: y = np.zeros(1024, dtype=np.float32)

        # 2. RMS normalization (volume leveling)
        if self.rms_norm_target > 0:
            rms = np.sqrt(np.mean(y ** 2)) + 1e-8
            y = y * (self.rms_norm_target / rms)

        # 3. Pre-emphasis (boosts high frequencies)
        if self.pre_emphasis != 0:
            y = np.append(y[0], y[1:] - self.pre_emphasis * y[:-1]).astype(np.float32)
        return y

    def _apply_augment(self, y, sr):
        """Random waveform augmentations (train-time only)."""
        # A. Time stretch (heavy; skipped with --fast-augment)
        if not self.skip_heavy_augment and self.augment_time_stretch[0] != self.augment_time_stretch[1]:
            rate = random.uniform(self.augment_time_stretch[0], self.augment_time_stretch[1])
            if abs(rate - 1.0) > 0.01:
                y = librosa.effects.time_stretch(y, rate=rate)

        # B. Pitch shift (heavy; skipped with --fast-augment)
        if not self.skip_heavy_augment and self.augment_pitch[0] != self.augment_pitch[1]:
            n_steps = random.uniform(self.augment_pitch[0], self.augment_pitch[1])
            if abs(n_steps) > 0.01:
                y = librosa.effects.pitch_shift(y, sr=sr, n_steps=n_steps)

        # C. Time shift (position jitter) - matches the sweep YAML's augment_shift.
        # Shift the audio left/right and zero-fill the gap with silence
        # (np.roll would wrap content around, which is undesirable here).
        if self.augment_shift > 0:
            shift_sec = random.uniform(-self.augment_shift, self.augment_shift)
            shift_samples = int(shift_sec * sr)
            if abs(shift_samples) > 0:
                y_shifted = np.zeros_like(y)
                if shift_samples > 0:  # shift right (silence prepended)
                    if shift_samples < len(y):
                        y_shifted[shift_samples:] = y[:-shift_samples]
                else:  # shift left (silence appended)
                    shift_samples = abs(shift_samples)
                    if shift_samples < len(y):
                        y_shifted[:-shift_samples] = y[shift_samples:]
                y = y_shifted

        # D. Noise injection - applied after normalization so the SNR is consistent
        if self.augment_noise_scale > 0:
            rms = np.sqrt(np.mean(y ** 2)) + 1e-8
            # Gaussian white noise (mixing real environment noise would be better)
            noise = np.random.randn(len(y)).astype(np.float32)
            y = y + self.augment_noise_scale * rms * noise

        return y

    def _apply_spec_augment(self, mel):
        """SpecAugment: mask one random time span and/or mel-bin span with the mean."""
        t_max, f_max = self.spec_augment
        if t_max <= 0 and f_max <= 0: return mel
        n_mels, T = mel.shape
        # Time masking
        if t_max > 0 and T > t_max:
            t0 = random.randint(0, T - t_max)
            mel[:, t0 : t0 + t_max] = mel.mean()  # fill with the global mean
        # Freq masking
        if f_max > 0 and n_mels > f_max:
            f0 = random.randint(0, n_mels - f_max)
            mel[f0 : f0 + f_max, :] = mel.mean()
        return mel

    def __getitem__(self, idx):
        item = self.items[idx]
        path = os.path.normpath(item['wav_path'])
        try:
            # Not in GPU-mel mode and the (validation) mel cache hits -> return it
            if not self.mel_gpu and self.mel_cache is not None and path in self.mel_cache:
                return self.mel_cache[path].clone(), item['label_id']

            # 1. Load + preprocess (optionally via the shared preprocess cache)
            if self.preprocess_cache is not None:
                if path in self.preprocess_cache:
                    y, sr = self.preprocess_cache[path]
                else:
                    y, sr = _load_audio_fast(item['wav_path'])
                    y = self._preprocess(y)
                    self.preprocess_cache[path] = (y.copy(), sr)
            else:
                y, sr = _load_audio_fast(item['wav_path'])
                y = self._preprocess(y)

            # 2. Augment
            y = self._apply_augment(y, sr)

            if self.mel_gpu:
                # 3a. GPU-mel mode: pad/crop the waveform only; the mel is
                # computed per batch on the GPU inside the training loop.
                if len(y) > MAX_AUDIO_SAMPLES:
                    y = y[:MAX_AUDIO_SAMPLES]
                else:
                    y = np.pad(y, (0, MAX_AUDIO_SAMPLES - len(y)), mode="constant", constant_values=0)
                return torch.from_numpy(y.astype(np.float32)), item['label_id']

            # 3b. CPU-mel mode: compute the mel here as before
            norm_mel = _compute_mel_fast(y, sr)
            norm_mel = self._apply_spec_augment(norm_mel.copy())
            if norm_mel.shape[1] > MAX_TIME_FRAMES:
                norm_mel = norm_mel[:, :MAX_TIME_FRAMES]
            else:
                norm_mel = np.pad(norm_mel, ((0, 0), (0, MAX_TIME_FRAMES - norm_mel.shape[1])))
            out = torch.from_numpy(norm_mel).float().unsqueeze(0)
            if self.mel_cache is not None:
                self.mel_cache[path] = out.clone()
            return out, item['label_id']
        except Exception as e:
            # One broken file should not kill training: log and yield a zero sample.
            print(f"โš ๏ธ Load Error ({item['wav_path']}): {e}")
            if self.mel_gpu:
                return torch.zeros(MAX_AUDIO_SAMPLES, dtype=torch.float32), item['label_id']
            return torch.zeros(1, N_MELS, MAX_TIME_FRAMES), item['label_id']
class KWSModel(nn.Module):
    """1-D CNN keyword-spotting classifier over mel-spectrogram inputs.

    Input: ``(B, 1, n_mels, T)`` — the channel axis is squeezed and the mel
    bins act as Conv1d channels. Output: ``(B, num_classes)`` logits.
    """

    def __init__(self, num_classes=3, n_mels=128, dropout=0.0):
        super().__init__()
        self.conv1 = nn.Conv1d(n_mels, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(64)
        self.relu = nn.ReLU()
        # Temporal downsampling happens via stride inside the conv stacks.
        self.layer1 = self._make_layer(64, 64, 1)
        self.layer2 = self._make_layer(64, 128, 2)
        self.layer3 = self._make_layer(128, 256, 2)
        self.adaptive_pool = nn.AdaptiveAvgPool1d(1)
        self.dropout = nn.Dropout(p=dropout) if dropout > 0 else nn.Identity()
        self.fc = nn.Linear(256, num_classes)

    def _make_layer(self, in_c, out_c, stride):
        """Two Conv1d+BN+ReLU stages; the first may stride to shrink time."""
        stages = [
            nn.Conv1d(in_c, out_c, 3, stride, 1, bias=False),
            nn.BatchNorm1d(out_c),
            nn.ReLU(),
            nn.Conv1d(out_c, out_c, 3, 1, 1, bias=False),
            nn.BatchNorm1d(out_c),
            nn.ReLU(),
        ]
        return nn.Sequential(*stages)

    def forward(self, x):
        feats = x.squeeze(1)  # (B, 1, F, T) -> (B, F, T)
        feats = self.relu(self.bn1(self.conv1(feats)))
        feats = self.layer1(feats)
        feats = self.layer2(feats)
        feats = self.layer3(feats)
        pooled = self.adaptive_pool(feats).flatten(1)
        return self.fc(self.dropout(pooled))
245
+
class MelExtractorGPU(nn.Module):
    """Batch mel-spectrogram (+ optional SpecAugment) computed on the GPU.

    With this module the Dataset only returns fixed-length raw waveforms; the
    trainer calls the extractor on each ``(B, samples)`` batch. Output matches
    the CPU path: power mel -> dB clipped to [-80, 0] -> rescaled to [0, 1],
    padded/cropped to ``MAX_TIME_FRAMES``, shape ``(B, 1, N_MELS, MAX_TIME_FRAMES)``.
    """

    def __init__(self, spec_augment: tuple[int, int] = (0, 0)):
        super().__init__()
        self.spec_augment = spec_augment  # (max time-mask width, max freq-mask width)
        try:
            import torchaudio
            self._mel_fn = torchaudio.transforms.MelSpectrogram(
                sample_rate=SR_MODEL,
                n_fft=N_FFT,
                win_length=N_FFT,
                hop_length=HOP_LENGTH,
                f_min=0.0,
                f_max=float(SR_MODEL) / 2,
                n_mels=N_MELS,
                power=2.0,
            )
        except Exception:
            # torchaudio unavailable: forward() will raise RuntimeError.
            self._mel_fn = None

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """waveform: (B, samples) float tensor -> (B, 1, N_MELS, MAX_TIME_FRAMES)."""
        if self._mel_fn is None:
            raise RuntimeError("MelExtractorGPU requires torchaudio")
        mel = self._mel_fn(waveform)
        mel_db = 10.0 * torch.log10(torch.clamp(mel, min=1e-10))
        mel_db = torch.clamp(mel_db, -80.0, 0.0)
        norm_mel = (mel_db + 80.0) / 80.0
        # (B, n_mels, time) -> pad/crop to the fixed MAX_TIME_FRAMES width
        T = norm_mel.shape[2]
        if T > MAX_TIME_FRAMES:
            norm_mel = norm_mel[:, :, :MAX_TIME_FRAMES]
        elif T < MAX_TIME_FRAMES:
            norm_mel = torch.nn.functional.pad(norm_mel, (0, MAX_TIME_FRAMES - T))
        # FIX: apply SpecAugment when EITHER mask size is set (was `and`, which
        # silently skipped augmentation for configs like (15, 0)); this matches
        # the CPU path in KWSDataset._apply_spec_augment, where time and freq
        # masks are applied independently.
        if self.training and (self.spec_augment[0] > 0 or self.spec_augment[1] > 0):
            norm_mel = self._spec_augment(norm_mel)
        return norm_mel.unsqueeze(1)

    def _spec_augment(self, mel: torch.Tensor) -> torch.Tensor:
        """Per-sample time/frequency masking, filled with that sample's mean."""
        B, F, T = mel.shape
        t_max, f_max = self.spec_augment
        if t_max > 0 and T > t_max:
            t0 = torch.randint(0, T - t_max + 1, (B,), device=mel.device)
            for i in range(B):
                mel[i, :, t0[i] : t0[i] + t_max] = mel[i].mean()
        if f_max > 0 and F > f_max:
            f0 = torch.randint(0, F - f_max + 1, (B,), device=mel.device)
            for i in range(B):
                mel[i, f0[i] : f0[i] + f_max, :] = mel[i].mean()
        return mel
297
+
298
+ def _serializable_config(args, best_epoch, best_metric, class_names, label_map, normal_dir_added=0):
299
+ out = {}
300
+ for k, v in vars(args).items():
301
+ if v is None or (isinstance(v, float) and (v != v)): continue
302
+ if isinstance(v, (str, int, float, bool)) or (isinstance(v, (list, tuple)) and all(isinstance(x, (str, int, float, bool)) for x in (v or []))):
303
+ out[k] = v
304
+ else: out[k] = str(v)
305
+ out.update({
306
+ "best_epoch": best_epoch,
307
+ "val_recall_positive_min": best_metric,
308
+ "class_names": class_names,
309
+ "label_map": label_map,
310
+ "normal_dir_added": normal_dir_added
311
+ })
312
+ return out
313
+
def main():
    """End-to-end KWS training entry point.

    Pipeline: build label map -> gather records (manifest / folder fallback /
    extra normal dir) -> WandB init (sweep-aware) -> datasets & loaders ->
    train with optional AMP and GPU mel extraction -> track recall_positive_min
    (miss prevention) and false_positive_rate (false-alarm prevention) ->
    checkpoint the best model -> upload run history and best/ to HF Hub.
    """
    parser = argparse.ArgumentParser()
    # Path
    parser.add_argument('--manifest', type=str, default='')
    parser.add_argument('--dataset-root', type=str, default='')
    parser.add_argument('--data-dir', type=str, default='./kws_data')
    parser.add_argument('--output-dir', type=str, default='./kws_models')

    # Train
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--keywords', nargs='+', default=['help_me', 'save_me'])
    parser.add_argument('--include-normal', action='store_true')
    parser.add_argument('--normal-dir', type=str, default='', help='์ผ๋ฐ˜(Normal) ๋ฐ์ดํ„ฐ ์ถ”๊ฐ€ ๊ฒฝ๋กœ')

    # Augment
    parser.add_argument('--trim_silence', type=str2bool, default=False)
    parser.add_argument('--rms_norm', type=float, default=0.05)
    parser.add_argument('--augment_noise_scale', type=float, default=0.0)
    parser.add_argument('--augment_shift', type=float, default=0.0, help='[New] Time Shift Augmentation (sec)')
    parser.add_argument('--augment_pitch', type=str, default='0,0')
    parser.add_argument('--augment_time_stretch', type=str, default='0,0')
    parser.add_argument('--spec_augment_on', type=str2bool, default=False)
    parser.add_argument('--fast-augment', action='store_true', help='time_stretch/pitch_shift ์ƒ๋žต โ†’ ์›Œ์ปค CPU ๋ถ€ํ•˜ ๊ฐ์†Œ, GPU ํ™œ์šฉ๋„ ์ƒ์Šน (sweep ๋ณ‘๋ชฉ ์‹œ ์‚ฌ์šฉ)')
    parser.add_argument('--mel-gpu', action='store_true', help='๋ฉœ ์ŠคํŽ™ํŠธ๋กœ๊ทธ๋žจ์„ GPU์—์„œ ๋ฐฐ์น˜ ๊ณ„์‚ฐ โ†’ ์›Œ์ปค CPU ๋ถ€ํ•˜ ๊ฐ์†Œ, ํ•™์Šต ๊ฐ€์†')

    # Regularization (overfitting mitigation)
    parser.add_argument('--class_weight', type=str, default='inverse')
    parser.add_argument('--dropout', type=float, default=0.0)
    parser.add_argument('--weight_decay', type=float, default=1e-4, help='L2 ์ •๊ทœํ™” (Adam). 0์ด๋ฉด ๋น„ํ™œ์„ฑํ™”')
    parser.add_argument('--early_stop_patience', type=int, default=10)
    parser.add_argument('--lr_scheduler_patience', type=int, default=0, help='val_loss ๊ฐœ์„  ์—†์„ ๋•Œ LR ๊ฐ์†Œ ๋Œ€๊ธฐ ์—ํญ. 0์ด๋ฉด ๋น„ํ™œ์„ฑํ™”')

    # HF / WandB (HF default: dusen0528/kws)
    parser.add_argument('--hf-repo', type=str, default='', help='Hugging Face ๋ชจ๋ธ repo (๊ธฐ๋ณธ: dusen0528/kws)')
    parser.add_argument('--hf-token', type=str, default='')
    parser.add_argument('--wandb-project', type=str, default='kws')
    parser.add_argument('--wandb-entity', type=str, default='')
    parser.add_argument('--amp', type=str2bool, default=True, help='GPU์—์„œ ํ˜ผํ•ฉ์ •๋ฐ€(AMP) ์‚ฌ์šฉ โ†’ ํ•™์Šต ์†๋„ ํ–ฅ์ƒ')
    parser.add_argument('--workers', type=int, default=4, help='DataLoader worker ์ˆ˜ (0์ด๋ฉด ๋ฉ”์ธ๋งŒ)')

    args, _ = parser.parse_known_args()

    # --- 1. Label Map & Data Loading ---
    label_map = {}
    # If a 'normal' class participates, pin it to label id 0 (convention)
    if args.include_normal or (args.normal_dir and os.path.isdir(args.normal_dir)):
        label_map['normal'] = 0
    for kw in args.keywords: label_map[kw] = len(label_map)
    num_classes = len(label_map)
    class_names = [k for k, v in sorted(label_map.items(), key=lambda x: x[1])]

    records = []
    # A. Manifest Load (JSONL: one {"label": ..., "audio_path": ...} per line)
    if args.manifest and os.path.exists(args.manifest):
        root = Path(args.dataset_root)
        with open(args.manifest, "r", encoding="utf-8") as f:
            for line in f:
                try:
                    obj = json.loads(line)
                    lbl = obj.get("label")
                    if lbl in label_map:
                        ap = obj["audio_path"]
                        full_p = ap if os.path.isabs(ap) else str(root / ap)
                        if os.path.exists(full_p):
                            records.append({"wav_path": full_p, "label": lbl, "label_id": label_map[lbl]})
                # NOTE(review): bare except silently skips malformed lines — consider narrowing
                except: pass

    # B. Folder Load (Fallback): data_dir/<label_name>/*.wav|*.mp3
    if not records:
        for ln, lid in label_map.items():
            dp = os.path.join(args.data_dir, ln)
            if os.path.isdir(dp):
                for f in os.listdir(dp):
                    if f.lower().endswith(('.wav', '.mp3')):
                        records.append({"wav_path": os.path.join(dp, f), "label": ln, "label_id": lid})

    # C. Normal Dir Augmentation: extra 'normal' samples from a separate folder
    normal_dir_added = 0
    if args.normal_dir and os.path.isdir(args.normal_dir) and 'normal' in label_map:
        n_before = len(records)
        for f in os.listdir(args.normal_dir):
            if f.lower().endswith(('.wav', '.mp3')):
                full_p = os.path.join(args.normal_dir, f)
                if os.path.isfile(full_p):
                    records.append({"wav_path": full_p, "label": "normal", "label_id": label_map["normal"]})
        normal_dir_added = len(records) - n_before

    if not records: print("โŒ ๋ฐ์ดํ„ฐ ์—†์Œ"); return

    os.makedirs(args.output_dir, exist_ok=True)
    hf_repo = (args.hf_repo or os.environ.get("HF_REPO", "")).strip() or "dusen0528/kws"
    hf_token = args.hf_token or os.environ.get("HF_TOKEN", "")

    # --- 2. WandB Init ---
    run_name = datetime.now().strftime("%Y%m%d_%H%M%S")
    if os.environ.get("WANDB_SWEEP_ID"):
        wandb.init(config=vars(args), name=run_name)
        # Sweep parameters override the CLI args
        for k, v in wandb.config.items():
            setattr(args, k, v)
    else:
        wandb.init(config=vars(args), project=args.wandb_project, entity=args.wandb_entity or None, name=run_name)

    # --- 3. Dataset & Loader ---
    stratify_ids = [r['label_id'] for r in records]
    try:
        train_items, val_items = train_test_split(records, test_size=0.2, random_state=42, stratify=stratify_ids)
    except ValueError:
        # Stratification fails when a class has < 2 samples; fall back to random split
        train_items, val_items = train_test_split(records, test_size=0.2, random_state=42)

    preprocess_cache: dict = {}
    fast_augment = getattr(args, 'fast_augment', False)
    mel_gpu = getattr(args, "mel_gpu", False)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if mel_gpu and device.type != "cuda":
        mel_gpu = False
        print(" โš ๏ธ --mel-gpu๋Š” CUDA์—์„œ๋งŒ ๋™์ž‘, ๋น„ํ™œ์„ฑํ™”ํ•จ")
    train_ds = KWSDataset(train_items, label_map,
                          trim_silence=args.trim_silence,
                          rms_norm_target=args.rms_norm,
                          augment_noise_scale=args.augment_noise_scale,
                          augment_pitch=_parse_range(args.augment_pitch),
                          augment_time_stretch=_parse_range(args.augment_time_stretch),
                          augment_shift=args.augment_shift,
                          spec_augment=(15, 10) if args.spec_augment_on else (0, 0),
                          preprocess_cache=preprocess_cache,
                          skip_heavy_augment=fast_augment,
                          mel_gpu=mel_gpu)

    # NOTE(review): both branches are {} — the conditional is redundant
    val_mel_cache: dict = {} if not mel_gpu else {}
    val_ds = KWSDataset(val_items, label_map,
                        trim_silence=args.trim_silence,
                        rms_norm_target=args.rms_norm,
                        preprocess_cache=preprocess_cache,
                        mel_cache=val_mel_cache,
                        mel_gpu=mel_gpu)

    use_amp = args.amp and device.type == 'cuda'
    num_workers = args.workers if device.type == 'cuda' else 0
    pin = (device.type == 'cuda')
    loader_kw = dict(batch_size=args.batch_size, pin_memory=pin, num_workers=num_workers)
    if num_workers > 0:
        # Keep workers alive across epochs and prefetch aggressively
        loader_kw["persistent_workers"] = True
        loader_kw["prefetch_factor"] = 8
    train_loader = DataLoader(train_ds, shuffle=True, **loader_kw)
    val_loader = DataLoader(val_ds, **loader_kw)
    if use_amp:
        print("โšก AMP(ํ˜ผํ•ฉ์ •๋ฐ€) ์‚ฌ์šฉ")

    # --- 4. Model & Loss ---
    model = KWSModel(num_classes=num_classes, dropout=args.dropout).to(device)
    mel_extractor = None
    if mel_gpu and device.type == "cuda":
        mel_extractor = MelExtractorGPU(
            spec_augment=(15, 10) if args.spec_augment_on else (0, 0)
        ).to(device)
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=args.lr,
        weight_decay=getattr(args, 'weight_decay', 0.0),
    )
    scheduler = None
    if getattr(args, 'lr_scheduler_patience', 0) > 0:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=args.lr_scheduler_patience, min_lr=1e-6
        )
        print(f" ๐Ÿ“‰ LR Scheduler: val_loss {args.lr_scheduler_patience}ep ๊ฐœ์„  ์—†์œผ๋ฉด LRร—0.5")

    class_weights = None
    if args.class_weight == 'inverse':
        from collections import Counter
        counts = Counter([r['label_id'] for r in train_items])
        # N_samples / (N_classes * count); +1 avoids division by zero for empty classes
        weights = [len(train_items) / (num_classes * (counts.get(i, 0) + 1)) for i in range(num_classes)]
        class_weights = torch.tensor(weights, dtype=torch.float32).to(device)
        print(f"โš–๏ธ Class Weights: {class_weights.tolist()}")

    criterion = nn.CrossEntropyLoss(weight=class_weights)
    scaler = torch.amp.GradScaler("cuda") if use_amp else None

    # --- 5. Training Loop ---
    best_rec_min = 0.0   # best min-recall over the positive (keyword) classes
    best_fpr = 1.0       # FPR at the time best_rec_min was achieved
    best_epoch = 0
    patience = 0         # epochs since last improvement (early stopping)

    print(f"๐Ÿš€ ํ•™์Šต ์‹œ์ž‘ (๋ฐ์ดํ„ฐ: {len(records)}๊ฐœ, ํด๋ž˜์Šค: {class_names})")
    print(f"   workers={num_workers} batch={args.batch_size} | GPU ๋‚ฎ๊ณ  pt_data_worker 100%๋ฉด: --batch-size 64 --workers 8, --fast-augment, --mel-gpu")
    if fast_augment:
        print(f"   โšก --fast-augment: time_stretch/pitch_shift ์ƒ๋žต โ†’ ์›Œ์ปค ๋ถ€ํ•˜ ๊ฐ์†Œ")
    if mel_gpu and mel_extractor is not None:
        print(f"   โšก --mel-gpu: ๋ฉœ ์ŠคํŽ™ํŠธ๋กœ๊ทธ๋žจ์„ GPU์—์„œ ๋ฐฐ์น˜ ๊ณ„์‚ฐ โ†’ ์›Œ์ปค CPU ๋ถ€ํ•˜ ๊ฐ์†Œ")

    pbar_epoch = tqdm(range(1, args.epochs + 1), desc="Epoch", unit="ep")
    for e in pbar_epoch:
        # A. Train
        model.train()
        t_loss = 0; t_corr = 0; t_total = 0
        train_pbar = tqdm(train_loader, desc=f"Ep {e}", leave=False, unit="batch")
        for m, l in train_pbar:
            m, l = m.to(device, non_blocking=True), l.to(device, non_blocking=True)
            if mel_extractor is not None:
                # GPU-mel path: m arrives as raw waveform, convert to mel here
                mel_extractor.train()
                m = mel_extractor(m)
            optimizer.zero_grad()
            if use_amp and scaler is not None:
                with torch.amp.autocast("cuda"):
                    out = model(m)
                    loss = criterion(out, l)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                out = model(m)
                loss = criterion(out, l)
                loss.backward()
                optimizer.step()
            t_loss += loss.item()
            _, p = torch.max(out, 1)
            t_corr += (p == l).sum().item()
            t_total += l.size(0)
            train_pbar.set_postfix(loss=f"{t_loss / max(1, train_pbar.n):.3f}")

        train_loss = t_loss / len(train_loader)
        train_acc = 100 * t_corr / t_total if t_total else 0

        # B. Validation
        model.eval()
        v_loss = 0; v_preds, v_labels = [], []
        with torch.no_grad():
            for m, l in tqdm(val_loader, desc="Val", leave=False, unit="batch"):
                m, l = m.to(device, non_blocking=True), l.to(device, non_blocking=True)
                if mel_extractor is not None:
                    mel_extractor.eval()
                    with torch.no_grad():
                        m = mel_extractor(m)
                if use_amp:
                    with torch.amp.autocast("cuda"):
                        out = model(m)
                        loss = criterion(out, l)
                else:
                    out = model(m)
                    loss = criterion(out, l)
                v_loss += loss.item()
                _, p = torch.max(out, 1)
                v_preds.extend(p.cpu().numpy())
                v_labels.extend(l.cpu().numpy())

        val_loss = v_loss / len(val_loader) if len(val_loader) else 0
        val_acc = 100 * accuracy_score(v_labels, v_preds) if v_labels else 0

        # C. Metrics (per-class recall & false positive rate)
        labels_idx = list(range(num_classes))
        prec, rec, f1, _ = precision_recall_fscore_support(
            v_labels, v_preds, labels=labels_idx, average=None, zero_division=0
        )
        rec = np.atleast_1d(rec).astype(float)
        prec = np.atleast_1d(prec).astype(float)
        f1 = np.atleast_1d(f1).astype(float)
        # Defensive padding in case sklearn returns fewer entries than classes
        if len(rec) < num_classes:
            rec = np.pad(rec, (0, num_classes - len(rec)), constant_values=0.0)
            prec = np.pad(prec, (0, num_classes - len(prec)), constant_values=0.0)
            f1 = np.pad(f1, (0, num_classes - len(f1)), constant_values=0.0)

        log_dict = {
            "epoch": e,
            "train/loss": train_loss, "train/acc": train_acc,
            "val/loss": val_loss, "val/acc": val_acc,
            "val/f1_macro": f1_score(v_labels, v_preds, average='macro', zero_division=0),
            "val/train_loss_gap": val_loss - train_loss,  # > 0: val worse than train -> possible overfitting
            "val/train_acc_gap": train_acc - val_acc,     # > 0: train better than val -> possible overfitting
        }
        if scheduler is not None:
            log_dict["train/lr"] = optimizer.param_groups[0]["lr"]

        pos_rec = []
        for i, name in enumerate(class_names):
            safe_name = name.replace("/", "_")
            log_dict[f"val/recall_{safe_name}"] = rec[i]
            log_dict[f"val/f1_{safe_name}"] = f1[i]
            if name in ['help_me', 'save_me']:
                pos_rec.append(rec[i])

        # [Metric 1] Recall Positive Min (goal: prevent missed detections)
        rec_min = min(pos_rec) if pos_rec else 0.0
        log_dict["val/recall_positive_min"] = rec_min

        # [Metric 2] False Positive Rate (goal: prevent false alarms)
        # Fraction of true 'normal' samples predicted as any non-normal (positive) class
        if 'normal' in label_map:
            norm_idx = label_map['normal']
            norm_mask = [i for i, x in enumerate(v_labels) if x == norm_idx]
            if norm_mask:
                norm_preds = [v_preds[i] for i in norm_mask]
                false_alarms = sum(1 for p in norm_preds if p != norm_idx)
                fpr = false_alarms / len(norm_mask)
            else:
                fpr = 0.0
            log_dict["val/false_positive_rate"] = fpr

        wandb.log(log_dict)

        current_fpr = log_dict.get("val/false_positive_rate", 1.0)
        loss_gap = val_loss - train_loss
        acc_gap = train_acc - val_acc
        if scheduler is not None:
            scheduler.step(val_loss)
        pbar_epoch.set_postfix(loss=f"{train_loss:.3f}", rec_min=f"{rec_min:.2f}", fpr=f"{current_fpr:.2f}")
        print(f"[Ep {e}] Loss: {train_loss:.4f} | RecMin: {rec_min:.2f} | FPR: {current_fpr:.2f}")
        # Overfitting warning: thresholds are strict so it doesn't fire every epoch
        # (only when loss_gap > 0.5 or acc_gap > 15%)
        if e > 5 and (loss_gap > 0.5 or acc_gap > 15.0):
            print(f"   โš ๏ธ ์˜ค๋ฒ„ํ”ผํŒ… ๊ฐ€๋Šฅ์„ฑ (val_loss - train_loss = {loss_gap:.3f}, train_acc - val_acc = {acc_gap:.1f}%)")
        if e == 10 and train_loss > 0.5 and val_loss > 0.5:
            print(f"   ๐Ÿ’ก ์–ธ๋”ํ”ผํŒ… ๊ฐ€๋Šฅ์„ฑ (Ep 10์ธ๋ฐ loss ๋‘˜ ๋‹ค ๋†’์Œ โ†’ epoch/๋ชจ๋ธ/ํ•™์Šต๋ฅ  ๊ฒ€ํ† )")

        # D. Checkpoint
        # Save rule 1: higher Recall Min always wins (miss prevention first)
        # Save rule 2: on a Recall Min tie, the model with the lower FPR wins
        is_best = False
        if rec_min > best_rec_min:
            is_best = True
        elif rec_min == best_rec_min and rec_min > 0:
            if current_fpr < best_fpr:
                is_best = True
                print(f"โœจ Recall ๋™์ ({rec_min:.2f})์ด๋‚˜ FPR ๊ฐœ์„ ๋จ ({best_fpr:.2f} -> {current_fpr:.2f})")

        if is_best:
            best_rec_min = rec_min
            best_fpr = current_fpr
            best_epoch = e
            patience = 0
            from checkpoint_io import save_checkpoint
            save_checkpoint(
                args.output_dir,
                model.state_dict(),
                _serializable_config(args, best_epoch, best_rec_min, class_names, label_map, normal_dir_added),
                also_save_pth=False,
            )
            print(f"๐Ÿ’พ Best ์ €์žฅ | Ep {e} | RecMin {best_rec_min:.2f} | FPR {best_fpr:.2f} | โ†’ best_kws_model.safetensors")
        else:
            patience += 1
            if args.early_stop_patience > 0 and patience >= args.early_stop_patience:
                print("๐Ÿ›‘ Early Stopping")
                break

    # Training finished: record the best metrics in the WandB run summary
    if getattr(wandb, "run", None) is not None:
        wandb.run.summary["val/recall_positive_min"] = best_rec_min
        wandb.run.summary["best_epoch"] = best_epoch
        wandb.run.summary["best_fpr"] = best_fpr

    # --- HF: post-training upload (dated run folder + params/metrics) ---
    if hf_repo:
        try:
            from huggingface_hub import HfApi
            api = HfApi(token=hf_token or None)
            api.create_repo(hf_repo, repo_type="model", exist_ok=True)
            ts = datetime.now().strftime("%Y%m%d_%H%M%S")
            folder = f"run_{ts}"
            for f in ["best_kws_model.safetensors", "best_kws_model.pth", "training_config.json"]:
                p = os.path.join(args.output_dir, f)
                if os.path.exists(p):
                    api.upload_file(path_or_fileobj=p, path_in_repo=f"{folder}/{f}", repo_id=hf_repo, repo_type="model")
            metrics = {
                "uploaded_at": datetime.now().isoformat(),
                "type": "run_post_training",
                "best_epoch": best_epoch,
                "val_recall_positive_min": best_rec_min,
                "best_fpr": best_fpr,
            }
            if getattr(wandb, "run", None) is not None:
                run = wandb.run
                metrics["wandb_run_id"] = getattr(run, "id", "") or ""
                metrics["wandb_run_name"] = getattr(run, "name", "") or ""
                metrics["wandb_run_url"] = getattr(run, "url", "") or ""
            metrics_path = os.path.join(args.output_dir, "metrics.json")
            with open(metrics_path, "w", encoding="utf-8") as mf:
                json.dump(metrics, mf, indent=2, ensure_ascii=False)
            api.upload_file(path_or_fileobj=metrics_path, path_in_repo=f"{folder}/metrics.json", repo_id=hf_repo, repo_type="model")
            # best/ keeps only the true best across ALL runs; overwrite only
            # when this run beats the existing best metrics.
            should_update_best = True
            existing_rec, existing_fpr = 0.0, 1.0
            try:
                from huggingface_hub import hf_hub_download
                path = hf_hub_download(
                    repo_id=hf_repo,
                    filename="best/metrics.json",
                    repo_type="model",
                    token=hf_token or None,
                )
                with open(path, "r", encoding="utf-8") as f:
                    existing = json.load(f)
                existing_rec = float(existing.get("val_recall_positive_min", 0.0))
                existing_fpr = float(existing.get("best_fpr", 1.0))
                # Better = higher rec_min, or equal rec_min with lower FPR
                if best_rec_min > existing_rec or (best_rec_min == existing_rec and best_fpr < existing_fpr):
                    should_update_best = True
                else:
                    should_update_best = False
            except Exception:
                # No best/metrics.json yet (first run) -> update best
                should_update_best = True
            if should_update_best:
                for f in ["best_kws_model.safetensors", "best_kws_model.pth", "training_config.json"]:
                    p = os.path.join(args.output_dir, f)
                    if os.path.exists(p):
                        api.upload_file(path_or_fileobj=p, path_in_repo=f"best/{f}", repo_id=hf_repo, repo_type="model")
                api.upload_file(path_or_fileobj=metrics_path, path_in_repo="best/metrics.json", repo_id=hf_repo, repo_type="model")
                print(f"[KWS] HF ์—…๋กœ๋“œ ์™„๋ฃŒ: {hf_repo} -> {folder}/ (ํžˆ์Šคํ† ๋ฆฌ) + best/ (์ „์ฒด run ์ค‘ ์ตœ๊ณ ๋กœ ๊ฐฑ์‹ )")
            else:
                print(f"[KWS] HF ์—…๋กœ๋“œ ์™„๋ฃŒ: {hf_repo} -> {folder}/ (ํžˆ์Šคํ† ๋ฆฌ). best/ ๋ฏธ๊ฐฑ์‹ : ์ด๋ฒˆ rec_min={best_rec_min:.2f} fpr={best_fpr:.2f} vs ๊ธฐ์กด rec_min={existing_rec:.2f} fpr={existing_fpr:.2f}")
        except Exception as e:
            print(f"[KWS] HF Upload Error: {e}")

if __name__ == "__main__": main()