Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- README.md +6 -6
- app.py +6 -0
- checkpoint_io.py +119 -0
- gradio_kws_test.py +362 -0
- kws_inference.py +197 -0
- kws_models_fpfix/best_kws_model.safetensors +3 -0
- kws_models_fpfix/training_config.json +48 -0
- requirements.txt +8 -0
- train_kws.py +730 -0
README.md
CHANGED
|
@@ -1,12 +1,12 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version: 6.
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
---
|
| 11 |
|
| 12 |
-
|
|
|
|
| 1 |
---
|
| 2 |
+
title: KWS FP Test
|
| 3 |
+
emoji: ๐ค
|
| 4 |
+
colorFrom: green
|
| 5 |
+
colorTo: indigo
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 6.5.1
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
---
|
| 11 |
|
| 12 |
+
KWS FP Test Space
|
app.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Hugging Face Spaces entry point: build the Gradio Blocks app and expose it
# as module-level `demo` (Spaces imports `demo` from app.py automatically).
from gradio_kws_test import build_demo

demo = build_demo()

if __name__ == "__main__":
    # Local run: launch the built-in server with default host/port.
    demo.launch()
|
checkpoint_io.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# checkpoint_io.py: KWS ์ฒดํฌํฌ์ธํธ Safetensors ์ ์ฅ/๋ก๋ + ๊ธฐ์กด .pth ํธํ.
|
| 2 |
+
# Single Source of Truth: ํ์ต์ .safetensors + training_config.json, ๋ก๋๋ .safetensors ์ฐ์ ํ .pth ํด๋ฐฑ.
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def resolve_checkpoint_path(path: str) -> Tuple[str, str]:
    """Resolve *path* to a (state file path, config file path) pair.

    *path* may be a directory (then the stem ``best_kws_model`` is assumed)
    or a file path whose extension is ignored. Among existing files a
    ``.safetensors`` state wins over ``.pth``; when neither exists the
    ``.pth`` path is returned anyway so the caller can raise a clear error.
    The config is always ``training_config.json`` in the same directory.
    """
    normalized = os.path.normpath(path)
    if os.path.isdir(normalized):
        directory, stem = normalized, "best_kws_model"
    else:
        directory = os.path.dirname(normalized) or "."
        stem = os.path.splitext(os.path.basename(normalized))[0] or "best_kws_model"

    config_file = os.path.join(directory, "training_config.json")
    for extension in (".safetensors", ".pth"):
        candidate = os.path.join(directory, stem + extension)
        if os.path.isfile(candidate):
            return candidate, config_file
    # Nothing on disk: hand back the legacy .pth location as a best guess.
    return os.path.join(directory, stem + ".pth"), config_file
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def load_state_dict(path: str, map_location: Optional[Any] = None) -> Dict[str, torch.Tensor]:
    """Load only the model ``state_dict`` from *path*.

    Resolution (via ``resolve_checkpoint_path``) prefers ``.safetensors``
    over ``.pth``; a full ``.pth`` checkpoint dict yields its ``"model"``
    entry. *map_location* accepts ``None`` (CPU), a ``torch.device``, or a
    device string; anything else falls back to CPU for safetensors.
    """
    state_path, _ = resolve_checkpoint_path(path)
    if not os.path.isfile(state_path):
        raise FileNotFoundError(f"์ฒดํฌํฌ์ธํธ ์์: {state_path}")

    if state_path.endswith(".safetensors"):
        from safetensors.torch import load_file

        # safetensors wants a device *string*; translate map_location.
        if isinstance(map_location, torch.device):
            if map_location.type == "cuda" and map_location.index is not None:
                target = f"cuda:{map_location.index}"
            else:
                target = map_location.type
        elif isinstance(map_location, str):
            target = map_location
        else:
            target = "cpu"
        return load_file(state_path, device=target)

    checkpoint = torch.load(state_path, map_location=map_location)
    if isinstance(checkpoint, dict) and "model" in checkpoint:
        return checkpoint["model"]
    return checkpoint
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def load_label_map(path: str) -> Optional[Dict[str, int]]:
    """Return the ``label_map`` from the adjacent training_config.json, or None.

    Any read/parse problem is treated as "no label map" rather than an error.
    """
    _, config_path = resolve_checkpoint_path(path)
    if not os.path.isfile(config_path):
        return None
    try:
        with open(config_path, "r", encoding="utf-8") as handle:
            return json.load(handle).get("label_map")
    except Exception:
        # Corrupt/unreadable config: fall back silently, callers have defaults.
        return None
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def load_checkpoint(
    path: str,
    map_location: Optional[Any] = None,
    default_label_map: Optional[Dict[str, int]] = None,
) -> Tuple[Dict[str, torch.Tensor], Dict[str, int], List[str]]:
    """Return ``(state_dict, label_map, class_names)`` for the checkpoint at *path*.

    Label-map priority: training_config.json, then a ``label_map`` key inside
    a legacy ``.pth`` checkpoint, then *default_label_map* (or the built-in
    3-class fallback). ``class_names`` is ordered by class id.
    """
    state = load_state_dict(path, map_location=map_location)

    label_map = load_label_map(path)
    if label_map is None:
        state_path, _ = resolve_checkpoint_path(path)
        if state_path.endswith(".pth") and os.path.isfile(state_path):
            # Legacy full checkpoints may carry the label map inline.
            raw = torch.load(state_path, map_location=map_location)
            if isinstance(raw, dict) and "label_map" in raw:
                label_map = raw["label_map"]
    if label_map is None:
        label_map = default_label_map or {"normal": 0, "help_me": 1, "save_me": 2}

    id_to_name = {class_id: name for name, class_id in label_map.items()}
    class_names: List[str] = [id_to_name[i] for i in range(len(label_map))]
    return state, label_map, class_names
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def save_checkpoint(
    output_dir: str,
    state_dict: Dict[str, torch.Tensor],
    config_dict: Dict[str, Any],
    *,
    also_save_pth: bool = False,
) -> None:
    """Persist a checkpoint into *output_dir*.

    Writes ``best_kws_model.safetensors`` (the single source of truth) and
    ``training_config.json``; when *also_save_pth* is set, additionally
    writes a legacy ``best_kws_model.pth`` containing only the state_dict.
    """
    os.makedirs(output_dir, exist_ok=True)

    from safetensors.torch import save_file

    save_file(state_dict, os.path.join(output_dir, "best_kws_model.safetensors"))

    with open(os.path.join(output_dir, "training_config.json"), "w", encoding="utf-8") as handle:
        json.dump(config_dict, handle, indent=2, ensure_ascii=False)

    if also_save_pth:
        torch.save(state_dict, os.path.join(output_dir, "best_kws_model.pth"))
|
gradio_kws_test.py
ADDED
|
@@ -0,0 +1,362 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import time
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Any
|
| 7 |
+
|
| 8 |
+
import gradio as gr
|
| 9 |
+
import librosa
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
from kws_inference import KWSLongInference, SR_MODEL
|
| 13 |
+
from train_kws import WINDOW_SEC
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _default_model_path() -> str:
|
| 17 |
+
candidates = [
|
| 18 |
+
Path("./kws_models_fpfix/best_kws_model.safetensors"),
|
| 19 |
+
Path("./kws_models/best_kws_model.safetensors"),
|
| 20 |
+
Path("./kws_models_fixed/best_kws_model.safetensors"),
|
| 21 |
+
Path("./kws_models/best_kws_model.pth"),
|
| 22 |
+
]
|
| 23 |
+
for p in candidates:
|
| 24 |
+
if p.exists():
|
| 25 |
+
return str(p)
|
| 26 |
+
return str(candidates[0])
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# Process-wide engine cache so repeated UI calls reuse the loaded model.
_ENGINE_CACHE: dict[tuple[str, float, float], KWSLongInference] = {}


def _get_engine(model_path: str, step_sec: float, rms_threshold: float) -> KWSLongInference:
    """Return a cached ``KWSLongInference`` keyed by (resolved path, step, rms)."""
    cache_key = (str(Path(model_path).resolve()), float(step_sec), float(rms_threshold))
    engine = _ENGINE_CACHE.get(cache_key)
    if engine is None:
        engine = KWSLongInference(
            checkpoint_path=model_path,
            step_sec=step_sec,
            rms_threshold=rms_threshold,
        )
        _ENGINE_CACHE[cache_key] = engine
    return engine
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _simulate_alerts(
|
| 44 |
+
window_results: list[Any],
|
| 45 |
+
threshold: float,
|
| 46 |
+
n_consecutive: int,
|
| 47 |
+
cooldown_sec: float,
|
| 48 |
+
) -> list[dict[str, Any]]:
|
| 49 |
+
alerts: list[dict[str, Any]] = []
|
| 50 |
+
consecutive = 0
|
| 51 |
+
last_alert_sec = -1e9
|
| 52 |
+
for w in window_results:
|
| 53 |
+
help_p = float(w.probs.get("help_me", 0.0))
|
| 54 |
+
save_p = float(w.probs.get("save_me", 0.0))
|
| 55 |
+
kw_p = max(help_p, save_p)
|
| 56 |
+
kw_label = "save_me" if save_p >= help_p else "help_me"
|
| 57 |
+
in_cooldown = (w.start_sec - last_alert_sec) < cooldown_sec
|
| 58 |
+
if in_cooldown:
|
| 59 |
+
consecutive = 0
|
| 60 |
+
continue
|
| 61 |
+
if kw_p >= threshold:
|
| 62 |
+
consecutive += 1
|
| 63 |
+
else:
|
| 64 |
+
consecutive = 0
|
| 65 |
+
if consecutive >= n_consecutive:
|
| 66 |
+
alerts.append(
|
| 67 |
+
{
|
| 68 |
+
"t_sec": round(float(w.start_sec), 2),
|
| 69 |
+
"label": kw_label,
|
| 70 |
+
"score": round(float(kw_p), 4),
|
| 71 |
+
"rule": f"{n_consecutive}x >= {threshold:.2f}",
|
| 72 |
+
}
|
| 73 |
+
)
|
| 74 |
+
last_alert_sec = float(w.start_sec)
|
| 75 |
+
consecutive = 0
|
| 76 |
+
return alerts
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def run_test(
    audio_path: str,
    model_path: str,
    step_sec: float,
    rms_threshold: float,
    alert_threshold: float,
    n_consecutive: int,
    cooldown_sec: float,
):
    """Handle the file-test tab: run windowed KWS inference on one audio file.

    Returns (summary markdown, alerts markdown, per-window table rows,
    details dict for the JSON widget). Raises gr.Error when the audio input
    or the model file is missing.
    """
    if not audio_path:
        raise gr.Error("์ค๋์ค ํ์ผ(๋๋ ๋ง์ดํฌ ์๋ ฅ)์ ๋ฃ์ด์ฃผ์ธ์.")
    if not model_path or not Path(model_path).exists():
        raise gr.Error(f"๋ชจ๋ธ ํ์ผ์ด ์์ต๋๋ค: {model_path}")

    # Engines are cached per (path, step, rms) inside _get_engine.
    engine = _get_engine(model_path=model_path, step_sec=step_sec, rms_threshold=rms_threshold)
    result = engine.predict_long(audio_path)

    # One dataframe row per sliding window: index, bounds, prediction, probs.
    rows: list[list[Any]] = []
    for i, w in enumerate(result.window_results, start=1):
        rows.append(
            [
                i,
                round(float(w.start_sec), 3),
                round(float(w.end_sec), 3),
                w.class_name,
                round(float(w.confidence), 4),
                round(float(w.probs.get("normal", 0.0)), 4),
                round(float(w.probs.get("help_me", 0.0)), 4),
                round(float(w.probs.get("save_me", 0.0)), 4),
            ]
        )

    # Offline replay of the streaming alert rule (threshold/consecutive/cooldown).
    alerts = _simulate_alerts(
        window_results=result.window_results,
        threshold=alert_threshold,
        n_consecutive=n_consecutive,
        cooldown_sec=cooldown_sec,
    )

    agg = result.aggregated_probs
    summary = (
        f"### ์ถ๋ก ์์ฝ\n"
        f"- Aggregated label: `{result.aggregated_label}`\n"
        f"- Duration: `{result.duration_sec:.2f}s`\n"
        f"- Aggregated probs: normal `{agg.get('normal', 0.0):.4f}` | "
        f"help_me `{agg.get('help_me', 0.0):.4f}` | "
        f"save_me `{agg.get('save_me', 0.0):.4f}`\n"
        f"- Windows: `{len(result.window_results)}` (step `{step_sec:.2f}s`, rms_threshold `{rms_threshold:.4f}`)"
    )

    if alerts:
        alert_lines = [
            f"- t={a['t_sec']:.2f}s | `{a['label']}` | score={a['score']:.4f} | {a['rule']}"
            for a in alerts
        ]
        alerts_md = "### ๊ฒฝ๋ณด ์๋ฎฌ๋ ์ด์\n" + "\n".join(alert_lines)
    else:
        alerts_md = "### ๊ฒฝ๋ณด ์๋ฎฌ๋ ์ด์\n- ์กฐ๊ฑด์ ๋ง์กฑํ ๊ฒฝ๋ณด ์์"

    # Raw payload for the JSON widget (mirrors the markdown summaries).
    details = {
        "aggregated_label": result.aggregated_label,
        "aggregated_probs": result.aggregated_probs,
        "duration_sec": result.duration_sec,
        "num_windows": len(result.window_results),
        "alerts": alerts,
    }
    return summary, alerts_md, rows, details
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def _init_stream_state() -> dict[str, Any]:
|
| 149 |
+
return {
|
| 150 |
+
"buffer": np.array([], dtype=np.float32),
|
| 151 |
+
"consecutive": 0,
|
| 152 |
+
"last_alert_time": -1e9,
|
| 153 |
+
"events": [],
|
| 154 |
+
"last_probs": {"normal": 1.0, "help_me": 0.0, "save_me": 0.0},
|
| 155 |
+
"last_decision": "NORMAL",
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def _normalize_chunk(chunk: Any) -> tuple[int, np.ndarray] | None:
|
| 160 |
+
if chunk is None:
|
| 161 |
+
return None
|
| 162 |
+
if isinstance(chunk, np.ndarray):
|
| 163 |
+
y = np.asarray(chunk, dtype=np.float32)
|
| 164 |
+
if y.ndim == 2:
|
| 165 |
+
y = y.mean(axis=1 if y.shape[1] <= 8 else 0)
|
| 166 |
+
return SR_MODEL, y.astype(np.float32)
|
| 167 |
+
if isinstance(chunk, (tuple, list)) and len(chunk) == 2:
|
| 168 |
+
sr, y = chunk
|
| 169 |
+
y = np.asarray(y, dtype=np.float32)
|
| 170 |
+
if y.ndim == 2:
|
| 171 |
+
y = y.mean(axis=1 if y.shape[1] <= 8 else 0)
|
| 172 |
+
if y.dtype == np.int16:
|
| 173 |
+
y = y.astype(np.float32) / 32768.0
|
| 174 |
+
elif y.dtype == np.int32:
|
| 175 |
+
y = y.astype(np.float32) / 2147483648.0
|
| 176 |
+
return int(sr), y.astype(np.float32)
|
| 177 |
+
if isinstance(chunk, dict):
|
| 178 |
+
# gradio ๋ฒ์ ์ ๋ฐ๋ผ {"sample_rate": ..., "data": ...} ํน์ {"path": ...} ํํ๊ฐ ์ฌ ์ ์์
|
| 179 |
+
if "sample_rate" in chunk and ("data" in chunk or "array" in chunk):
|
| 180 |
+
sr = int(chunk.get("sample_rate", SR_MODEL))
|
| 181 |
+
y = np.asarray(chunk.get("data", chunk.get("array")), dtype=np.float32)
|
| 182 |
+
if y.ndim == 2:
|
| 183 |
+
y = y.mean(axis=1 if y.shape[1] <= 8 else 0)
|
| 184 |
+
return sr, y.astype(np.float32)
|
| 185 |
+
if "path" in chunk and chunk["path"]:
|
| 186 |
+
try:
|
| 187 |
+
y, sr = librosa.load(str(chunk["path"]), sr=None, mono=True)
|
| 188 |
+
return int(sr), y.astype(np.float32)
|
| 189 |
+
except Exception:
|
| 190 |
+
return None
|
| 191 |
+
return None
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def stream_infer(
    state: dict[str, Any] | None,
    chunk: Any,
    model_path: str,
    step_sec: float,
    rms_threshold: float,
    alert_threshold: float,
    margin_threshold: float,
    n_consecutive: int,
    cooldown_sec: float,
):
    """Process one live microphone chunk and update the streaming UI.

    Keeps a rolling WINDOW_SEC audio buffer in *state*, runs the cached
    engine on it, and applies the alert rule (score + margin thresholds,
    consecutive count, wall-clock cooldown).

    Returns (state, status text, decision string, current probs dict,
    newest-first event list).
    """
    # Gradio may hand us a stale/None state; fall back to a fresh one.
    st = state if isinstance(state, dict) else _init_stream_state()
    parsed = _normalize_chunk(chunk)
    if parsed is None:
        probs = st.get("last_probs", {"normal": 1.0, "help_me": 0.0, "save_me": 0.0})
        return st, "๋๊ธฐ ์ค...", st.get("last_decision", "NORMAL"), probs, list(st.get("events", []))

    sr, y = parsed
    if len(y) == 0:
        probs = st.get("last_probs", {"normal": 1.0, "help_me": 0.0, "save_me": 0.0})
        return st, "์๋ ฅ ์์", st.get("last_decision", "NORMAL"), probs, list(st.get("events", []))

    if sr != SR_MODEL:
        y = librosa.resample(y, orig_sr=sr, target_sr=SR_MODEL)

    # Append to the rolling buffer, capped at one model window.
    buffer = np.concatenate((st["buffer"], y))
    max_len = int(SR_MODEL * WINDOW_SEC)
    if len(buffer) > max_len:
        buffer = buffer[-max_len:]
    st["buffer"] = buffer

    # Require at least 0.6 s of audio before running the model.
    if len(buffer) < int(SR_MODEL * 0.6):
        probs = st.get("last_probs", {"normal": 1.0, "help_me": 0.0, "save_me": 0.0})
        return st, "๋ฐ์ดํฐ ์์ง ์ค...", st.get("last_decision", "NORMAL"), probs, list(st.get("events", []))

    engine = _get_engine(model_path=model_path, step_sec=step_sec, rms_threshold=rms_threshold)
    result = engine.predict_long((SR_MODEL, buffer))
    probs = {
        "normal": float(result.aggregated_probs.get("normal", 0.0)),
        "help_me": float(result.aggregated_probs.get("help_me", 0.0)),
        "save_me": float(result.aggregated_probs.get("save_me", 0.0)),
    }
    st["last_probs"] = probs

    now = time.time()
    help_p, save_p = probs["help_me"], probs["save_me"]
    normal_p = probs["normal"]
    kw_score = max(help_p, save_p)
    kw_label = "save_me" if save_p >= help_p else "help_me"
    # Margin = best keyword prob minus normal prob; suppresses borderline hits.
    margin = kw_score - normal_p
    in_cooldown = (now - float(st["last_alert_time"])) < cooldown_sec

    meets_score = kw_score >= alert_threshold
    meets_margin = margin >= margin_threshold

    # Streak only advances when both thresholds pass outside the cooldown.
    if not in_cooldown and meets_score and meets_margin:
        st["consecutive"] = int(st["consecutive"]) + 1
    else:
        st["consecutive"] = 0

    status = (
        f"normal={probs['normal']:.3f} | help_me={probs['help_me']:.3f} | "
        f"save_me={probs['save_me']:.3f} | margin={margin:.3f} | "
        f"consec={st['consecutive']}/{n_consecutive}"
    )
    if in_cooldown:
        remain = max(0.0, cooldown_sec - (now - float(st["last_alert_time"])))
        status = f"์ฟจ๋ค์ด {remain:.1f}s | " + status

    decision = "NORMAL"
    probs_txt = f"n={normal_p:.3f} h={help_p:.3f} s={save_p:.3f} m={margin:.3f}"

    # HOLD: score passed but margin did not — logged so the operator can tune it.
    if not in_cooldown and meets_score and not meets_margin:
        decision = "HOLD"
        event = (
            f"{time.strftime('%H:%M:%S')} HOLD {kw_label} "
            f"score={kw_score:.3f} {probs_txt}"
        )
        st["events"] = ([event] + list(st.get("events", [])))[:30]
    elif in_cooldown:
        decision = "COOLDOWN"
    elif meets_score and meets_margin:
        decision = "CANDIDATE"

    # Fire the alert once the streak reaches the required length.
    if not in_cooldown and int(st["consecutive"]) >= int(n_consecutive):
        st["last_alert_time"] = now
        st["consecutive"] = 0
        event = f"{time.strftime('%H:%M:%S')} ALERT {kw_label} score={kw_score:.3f} {probs_txt}"
        st["events"] = ([event] + list(st.get("events", [])))[:30]
        status = "ALERT ๋ฐ์ | " + status
        decision = "ALERT"

    st["last_decision"] = decision
    return st, status, decision, probs, list(st.get("events", []))
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def build_demo() -> gr.Blocks:
    """Assemble the two-tab Gradio app (file test + live microphone).

    Shared controls (model path, step, thresholds) feed both tabs; the mic
    tab streams chunks through stream_infer with per-session gr.State.
    """
    with gr.Blocks(title="KWS FP Test Page") as demo:
        gr.Markdown("# KWS FP Test Page")
        gr.Markdown("ํ์ผ ํ์คํธ + ๋ธ๋ผ์ฐ์ ๋ง์ดํฌ ์ค์๊ฐ ๋ถ์")

        model_path = gr.Textbox(label="๋ชจ๋ธ ๊ฒฝ๋ก", value=_default_model_path())

        # Shared tuning controls used by both tabs.
        with gr.Row():
            step_sec = gr.Slider(label="์ฌ๋ผ์ด๋ฉ ์คํ(์ด)", minimum=0.1, maximum=1.0, value=0.25, step=0.05)
            rms_threshold = gr.Slider(label="RMS ์๊ณ๊ฐ", minimum=0.0, maximum=0.05, value=0.005, step=0.001)
            alert_threshold = gr.Slider(label="๊ฒฝ๋ณด ์๊ณ๊ฐ", minimum=0.5, maximum=0.99, value=0.9, step=0.01)
            margin_threshold = gr.Slider(label="๋ง์ง ์๊ณ๊ฐ(ํค์๋-normal)", minimum=0.0, maximum=0.6, value=0.2, step=0.01)
            n_consecutive = gr.Slider(label="์ฐ์ ํ์", minimum=1, maximum=5, value=3, step=1)
            cooldown_sec = gr.Slider(label="์ฟจ๋ค์ด(์ด)", minimum=0, maximum=10, value=5, step=0.5)

        # Tab 1: upload/record a file, run batched inference over it.
        with gr.Tab("ํ์ผ ํ์คํธ"):
            audio = gr.Audio(label="ํ์คํธ ์ค๋์ค", type="filepath", sources=["upload", "microphone"])
            run_btn = gr.Button("ํ์คํธ ์คํ", variant="primary")
            summary = gr.Markdown()
            alerts_md = gr.Markdown()
            table = gr.Dataframe(
                headers=["idx", "start_sec", "end_sec", "pred", "conf", "normal", "help_me", "save_me"],
                datatype=["number", "number", "number", "str", "number", "number", "number", "number"],
                label="์๋์ฐ๋ณ ๊ฒฐ๊ณผ",
            )
            details = gr.JSON(label="์์ธ JSON")
            run_btn.click(
                fn=run_test,
                inputs=[audio, model_path, step_sec, rms_threshold, alert_threshold, n_consecutive, cooldown_sec],
                outputs=[summary, alerts_md, table, details],
            )

        # Tab 2: live microphone streaming with per-session state.
        with gr.Tab("์ค์๊ฐ ๋ง์ดํฌ"):
            gr.Markdown("๋ง์ดํฌ๋ฅผ ์ผ ์ํ์์ ์ค์๊ฐ์ผ๋ก ํ๋ฅ /๊ฒฝ๋ณด ์ด๋ ฅ์ ๋ด๋๋ค.")
            st = gr.State(value=_init_stream_state())
            live_audio = gr.Audio(label="๋ง์ดํฌ ์ค์๊ฐ ์๋ ฅ", type="numpy", sources=["microphone"], streaming=True)
            live_status = gr.Textbox(label="์ค์๊ฐ ์ํ")
            live_decision = gr.Textbox(label="์ต์ข ํ๋จ")
            live_probs = gr.JSON(label="ํ์ฌ ํ๋ฅ ")
            live_events = gr.JSON(label="๊ฒฝ๋ณด ์ด๋ ฅ(์ต์ ์)")
            live_audio.stream(
                fn=stream_infer,
                inputs=[
                    st,
                    live_audio,
                    model_path,
                    step_sec,
                    rms_threshold,
                    alert_threshold,
                    margin_threshold,
                    n_consecutive,
                    cooldown_sec,
                ],
                outputs=[st, live_status, live_decision, live_probs, live_events],
                show_progress=False,
                time_limit=300,  # cap a single streaming session at 5 minutes
            )
    return demo
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
def main() -> None:
    """CLI entry point: parse host/port/share flags and launch the demo."""
    parser = argparse.ArgumentParser(description="KWS FP ํ์คํธ์ฉ Gradio ํ์ด์ง")
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7861)
    parser.add_argument("--share", action="store_true")
    options = parser.parse_args()

    app = build_demo()
    app.launch(server_name=options.host, server_port=options.port, share=options.share)


if __name__ == "__main__":
    main()
|
kws_inference.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
KWS ๊ธด ์ค๋์ค ์ถ๋ก : ์ฌ๋ผ์ด๋ฉ ์๋์ฐ + Max-over-windows ์ง๊ณ.
|
| 3 |
+
[์
๋ฐ์ดํธ] ์๋ฆฌ ํฌ๊ธฐ(RMS) ํํฐ๋ง ์ถ๊ฐ๋ก ๋ฌด์/๋
ธ์ด์ฆ ๊ตฌ๊ฐ ์คํ ๋ฐฉ์ง.
|
| 4 |
+
"""
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import librosa
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
|
| 16 |
+
# train_kws์ ๋์ผ ํ๋ผ๋ฏธํฐ๋ก melยท๋ชจ๋ธ ํธํ ์ ์ง
|
| 17 |
+
from train_kws import (
|
| 18 |
+
HOP_LENGTH,
|
| 19 |
+
KWSModel,
|
| 20 |
+
MAX_TIME_FRAMES,
|
| 21 |
+
N_FFT,
|
| 22 |
+
N_MELS,
|
| 23 |
+
SR_MODEL,
|
| 24 |
+
WINDOW_SEC,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
# Identical to the training-time KWSDataset._preprocess constants
# (keeps serving features aligned with training; a level mismatch can make
# normal speech score as urgent).
RMS_NORM_TARGET = 0.05  # target RMS level after loudness normalization
PRE_EMPHASIS = 0.97  # first-order pre-emphasis coefficient

# Single global device choice shared by model and input tensors.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 32 |
+
|
| 33 |
+
@dataclass
class WindowResult:
    """Inference result for a single sliding window (one audio segment)."""
    # Window boundaries in seconds within the source audio.
    start_sec: float
    end_sec: float
    # Predicted class id and its human-readable name.
    pred_id: int
    class_name: str
    # Softmax probabilities keyed by class name.
    probs: Dict[str, float]
    # Probability of the predicted class.
    confidence: float
|
| 42 |
+
|
| 43 |
+
@dataclass
class LongAudioResult:
    """Aggregated inference result over an entire long audio clip."""
    # Final label and probabilities aggregated across all windows.
    aggregated_label: str
    aggregated_probs: Dict[str, float]
    aggregated_pred_id: int
    # Per-window results in temporal order.
    window_results: List[WindowResult]
    # Total clip length in seconds.
    duration_sec: float
|
| 51 |
+
|
| 52 |
+
def _load_checkpoint(path: str) -> Tuple[Dict[str, torch.Tensor], Dict[str, int], List[str]]:
    """Thin wrapper around checkpoint_io.load_checkpoint pinned to DEVICE."""
    from checkpoint_io import load_checkpoint

    return load_checkpoint(path, map_location=DEVICE)
|
| 56 |
+
|
| 57 |
+
def _preprocess_wav(y: np.ndarray) -> np.ndarray:
    """Replicate training-time KWSDataset._preprocess: RMS norm + pre-emphasis.

    Keeping this byte-for-byte equivalent to training avoids a train/serve
    feature mismatch at inference time. Empty input passes through unchanged.
    """
    signal = np.asarray(y, dtype=np.float32)
    if len(signal) == 0:
        return signal
    # Scale so the chunk's RMS matches the training target level.
    gain = RMS_NORM_TARGET / (np.sqrt(np.mean(signal ** 2)) + 1e-8)
    signal = signal * gain
    # First-order pre-emphasis filter; first sample kept as-is.
    emphasized = np.append(signal[0], signal[1:] - PRE_EMPHASIS * signal[:-1])
    return emphasized.astype(np.float32)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _audio_to_mel(y: np.ndarray, sr: int) -> np.ndarray:
    """Convert a waveform into the normalized log-mel image the model expects.

    Resamples to SR_MODEL when needed, applies the training-time preprocessing,
    then maps dB values from [-80, 0] onto [0, 1].
    """
    wav = y
    if sr != SR_MODEL:
        wav = librosa.resample(wav.astype(np.float32), orig_sr=sr, target_sr=SR_MODEL)
    wav = _preprocess_wav(wav)
    mel_power = librosa.feature.melspectrogram(
        y=wav, sr=SR_MODEL, n_mels=N_MELS, n_fft=N_FFT, hop_length=HOP_LENGTH
    )
    mel_db = np.clip(librosa.power_to_db(mel_power, ref=1.0), -80.0, 0.0)
    return ((mel_db + 80.0) / 80.0).astype(np.float32)
|
| 78 |
+
|
| 79 |
+
def _mel_window_to_tensor(mel: np.ndarray, start_f: int, end_f: int) -> torch.Tensor:
    """Slice ``mel[:, start_f:end_f]``, fit to MAX_TIME_FRAMES, return (1,1,M,T) on DEVICE.

    Short windows are right-padded with zeros; long ones are truncated.
    """
    window = mel[:, start_f:end_f]
    width = window.shape[1]
    if width < MAX_TIME_FRAMES:
        window = np.pad(
            window, ((0, 0), (0, MAX_TIME_FRAMES - width)), mode="constant", constant_values=0
        )
    else:
        window = window[:, :MAX_TIME_FRAMES]
    # Add batch and channel axes for the CNN input.
    return torch.from_numpy(window).float().unsqueeze(0).unsqueeze(0).to(DEVICE)
|
| 88 |
+
|
| 89 |
+
def _frame_to_sec(frame_index: int) -> float:
    """Convert a mel frame index into seconds using the model hop length."""
    return frame_index * HOP_LENGTH / float(SR_MODEL)
|
| 91 |
+
|
| 92 |
+
class KWSLongInference:
    """Sliding-window keyword spotting over long audio.

    Loads a trained KWS checkpoint, slides a fixed-size mel window over the
    input, classifies each window, and aggregates per-window probabilities
    with a max-over-windows rule.
    """

    def __init__(
        self,
        checkpoint_path: str,
        window_sec: float = WINDOW_SEC,
        step_sec: float = 0.25,
        rms_threshold: float = 0.005,  # loudness threshold, tuned downward
    ) -> None:
        self.checkpoint_path = checkpoint_path
        self.window_sec = window_sec
        self.step_sec = step_sec
        self.rms_threshold = rms_threshold

        # _load_checkpoint is defined elsewhere in this module; presumably it
        # returns (state_dict, label_map, class_names) — verify against checkpoint_io.
        state, label_map, class_names = _load_checkpoint(checkpoint_path)
        self.label_map = label_map
        self.class_names = class_names
        self.num_classes = len(class_names)
        self.model = KWSModel(num_classes=self.num_classes).to(DEVICE)
        self.model.load_state_dict(state, strict=True)
        self.model.eval()

        # Window/step expressed in mel frames.
        self.window_frames = MAX_TIME_FRAMES
        self.step_frames = max(1, int(round(step_sec * SR_MODEL / HOP_LENGTH)))

    def _get_windows(self, total_frames: int) -> List[Tuple[int, int]]:
        """Return [start, end) frame ranges; trailing audio shorter than one window is dropped."""
        windows: List[Tuple[int, int]] = []
        start = 0
        while start + self.window_frames <= total_frames:
            windows.append((start, start + self.window_frames))
            start += self.step_frames
        return windows

    def predict_long(self, audio: Union[str, Tuple[int, np.ndarray]]) -> LongAudioResult:
        """Run windowed inference on a file path or an (sr, samples) tuple.

        Windows quieter than rms_threshold are forced to "normal" without
        touching the model; the final label is the argmax of per-class maxima
        across all windows.
        """
        if isinstance(audio, str):
            y, sr = librosa.load(audio, sr=SR_MODEL)
        else:
            sr, y = audio
            y = np.asarray(y, dtype=np.float32)
            if sr != SR_MODEL:
                y = librosa.resample(y, orig_sr=sr, target_sr=SR_MODEL)

        duration_sec = len(y) / SR_MODEL
        mel = _audio_to_mel(y, SR_MODEL)
        total_frames = mel.shape[1]
        windows = self._get_windows(total_frames)

        all_probs: List[np.ndarray] = []
        window_results: List[WindowResult] = []

        with torch.no_grad():
            for start_f, end_f in windows:
                # Extract the raw audio for this window to compute its RMS loudness.
                start_sample = int(start_f * HOP_LENGTH)
                end_sample = int(end_f * HOP_LENGTH)
                y_chunk = y[start_sample:end_sample]
                rms = np.sqrt(np.mean(y_chunk**2)) if len(y_chunk) > 0 else 0

                start_sec = _frame_to_sec(start_f)
                end_sec = _frame_to_sec(end_f)

                if rms < self.rms_threshold:
                    # [RMS filtering] Too quiet — unconditionally treat as "normal".
                    probs = np.zeros(self.num_classes)
                    probs[self.label_map.get("normal", 0)] = 1.0
                    pred_id = self.label_map.get("normal", 0)
                else:
                    # Model inference on this window.
                    t = _mel_window_to_tensor(mel, start_f, end_f)
                    logits = self.model(t)
                    probs = F.softmax(logits, dim=1).cpu().numpy()[0]
                    pred_id = int(np.argmax(probs))

                all_probs.append(probs)
                conf = float(probs[pred_id])
                probs_dict = {name: float(probs[i]) for i, name in enumerate(self.class_names)}

                window_results.append(
                    WindowResult(
                        start_sec=start_sec,
                        end_sec=end_sec,
                        pred_id=pred_id,
                        class_name=self.class_names[pred_id],
                        probs=probs_dict,
                        confidence=conf,
                    )
                )

        if not all_probs:
            # Audio shorter than one window: no predictions, default to "normal".
            return LongAudioResult("normal", {}, 0, [], duration_sec)

        # Max-over-windows aggregation (note: aggregated probs need not sum to 1).
        stacked = np.array(all_probs)
        max_per_class = np.max(stacked, axis=0)
        agg_pred_id = int(np.argmax(max_per_class))
        agg_label = self.class_names[agg_pred_id]
        agg_probs = {name: float(max_per_class[i]) for i, name in enumerate(self.class_names)}

        return LongAudioResult(agg_label, agg_probs, agg_pred_id, window_results, duration_sec)
|
| 190 |
+
|
| 191 |
+
def run_long_inference(
    wav_path: str,
    checkpoint_path: str = "kws_final_model/best_kws_model.pth",
    step_sec: float = 0.25,
) -> LongAudioResult:
    """Convenience wrapper: build a fresh KWSLongInference engine and run it on one file."""
    return KWSLongInference(
        checkpoint_path=checkpoint_path,
        step_sec=step_sec,
    ).predict_long(wav_path)
|
kws_models_fpfix/best_kws_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:77212523c6e434385d36c1bd2a972f8a49706af0125357b4f6070d221f5d3283
|
| 3 |
+
size 1693516
|
kws_models_fpfix/training_config.json
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"manifest": "./dataset_hf/manifests/kws.jsonl",
|
| 3 |
+
"dataset_root": "./dataset_hf",
|
| 4 |
+
"data_dir": "./kws_data",
|
| 5 |
+
"output_dir": "./kws_models_fpfix",
|
| 6 |
+
"epochs": 80,
|
| 7 |
+
"lr": 0.0005,
|
| 8 |
+
"batch_size": 64,
|
| 9 |
+
"keywords": [
|
| 10 |
+
"help_me",
|
| 11 |
+
"save_me"
|
| 12 |
+
],
|
| 13 |
+
"include_normal": true,
|
| 14 |
+
"normal_dir": "/home/dusen0528/LastResNet/data",
|
| 15 |
+
"trim_silence": false,
|
| 16 |
+
"rms_norm": 0.05,
|
| 17 |
+
"augment_noise_scale": 0.005,
|
| 18 |
+
"augment_shift": 0.1,
|
| 19 |
+
"augment_pitch": "0,2",
|
| 20 |
+
"augment_time_stretch": "0.9,1.1",
|
| 21 |
+
"spec_augment_on": true,
|
| 22 |
+
"fast_augment": true,
|
| 23 |
+
"mel_gpu": true,
|
| 24 |
+
"class_weight": "inverse",
|
| 25 |
+
"dropout": 0.3,
|
| 26 |
+
"weight_decay": 0.0001,
|
| 27 |
+
"early_stop_patience": 20,
|
| 28 |
+
"lr_scheduler_patience": 5,
|
| 29 |
+
"hf_repo": "",
|
| 30 |
+
"hf_token": "",
|
| 31 |
+
"wandb_project": "kws-fpfix",
|
| 32 |
+
"wandb_entity": "",
|
| 33 |
+
"amp": true,
|
| 34 |
+
"workers": 4,
|
| 35 |
+
"best_epoch": 51,
|
| 36 |
+
"val_recall_positive_min": 0.926829268292683,
|
| 37 |
+
"class_names": [
|
| 38 |
+
"normal",
|
| 39 |
+
"help_me",
|
| 40 |
+
"save_me"
|
| 41 |
+
],
|
| 42 |
+
"label_map": {
|
| 43 |
+
"normal": 0,
|
| 44 |
+
"help_me": 1,
|
| 45 |
+
"save_me": 2
|
| 46 |
+
},
|
| 47 |
+
"normal_dir_added": 1000
|
| 48 |
+
}
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio>=6.5.1
|
| 2 |
+
numpy
|
| 3 |
+
librosa
|
| 4 |
+
torch
|
| 5 |
+
safetensors
|
| 6 |
+
scikit-learn
|
| 7 |
+
wandb
|
| 8 |
+
tqdm
|
train_kws.py
ADDED
|
@@ -0,0 +1,730 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Keyword Spotting (KWS) ๋ชจ๋ธ ํ์ต ์คํฌ๋ฆฝํธ - Production Optimized
|
| 3 |
+
- WandB Sweep ์ฐ๋ (augment_shift, noise, class_weight ๋ฑ)
|
| 4 |
+
- False Positive Rate (์คํ๋ฅ ) ๋ฉํธ๋ฆญ ์ถ๊ฐ
|
| 5 |
+
- Audio Shift Augmentation ๊ตฌํ
|
| 6 |
+
"""
|
| 7 |
+
import os
|
| 8 |
+
import json
|
| 9 |
+
import random
|
| 10 |
+
from datetime import datetime
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
from torch.utils.data import Dataset, DataLoader
|
| 14 |
+
import numpy as np
|
| 15 |
+
import librosa
|
| 16 |
+
import argparse
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from sklearn.metrics import f1_score, precision_recall_fscore_support, accuracy_score, confusion_matrix
|
| 19 |
+
from sklearn.model_selection import train_test_split
|
| 20 |
+
import wandb
|
| 21 |
+
from tqdm import tqdm
|
| 22 |
+
|
| 23 |
+
# --- Global Config ---
SR_MODEL = 16000    # model sample rate (Hz)
N_MELS = 128        # mel filterbank size
N_FFT = 1024        # STFT window size (samples)
HOP_LENGTH = 512    # STFT hop (samples)
WINDOW_SEC = 1.5    # classification window length (seconds)
MAX_TIME_FRAMES = int(round((WINDOW_SEC * SR_MODEL - N_FFT) / HOP_LENGTH + 1))
# Number of raw samples needed to fill the fixed input frames (fixed waveform length for GPU mel batching)
MAX_AUDIO_SAMPLES = (MAX_TIME_FRAMES - 1) * HOP_LENGTH + N_FFT
|
| 32 |
+
|
| 33 |
+
def _compute_mel_fast(y: np.ndarray, sr: int) -> np.ndarray:
    """Compute a normalized mel spectrogram, shape (n_mels, time), values in [0, 1].

    Prefers torchaudio (faster than librosa); falls back to librosa on failure.

    BUGFIX: the fast path previously called torchaudio.functional.mel_spectrogram,
    which does not exist in torchaudio's functional API, so every call raised
    and silently fell back to the slow librosa path. Now uses
    torchaudio.transforms.MelSpectrogram with the same parameters as
    MelExtractorGPU, so CPU and GPU mel pipelines agree.
    """
    try:
        import torchaudio
        t = torch.from_numpy(y).float().unsqueeze(0)
        mel_fn = torchaudio.transforms.MelSpectrogram(
            sample_rate=sr, n_fft=N_FFT, hop_length=HOP_LENGTH, win_length=N_FFT,
            f_min=0.0, f_max=float(sr) / 2, n_mels=N_MELS, power=2.0,
        )
        mel = mel_fn(t).squeeze(0).numpy()
        # Power -> dB, clipped to [-80, 0], then min-max scaled to [0, 1].
        mel_db = np.clip(10.0 * np.log10(np.maximum(mel, 1e-10)), -80.0, 0.0)
        return ((mel_db + 80.0) / 80.0).astype(np.float32)
    except Exception:
        # librosa fallback — same dB clipping and normalization.
        S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=N_MELS, n_fft=N_FFT, hop_length=HOP_LENGTH)
        S_db = np.clip(librosa.power_to_db(S, ref=1.0), -80.0, 0.0)
        return ((S_db + 80.0) / 80.0).astype(np.float32)
|
| 49 |
+
|
| 50 |
+
def _load_audio_fast(path: str) -> tuple[np.ndarray, int]:
    """Load mono float32 audio at SR_MODEL; prefer torchaudio (faster), fall back to librosa."""
    try:
        import torchaudio
        wav, src_sr = torchaudio.load(path)
        if wav.shape[0] > 1:
            # Downmix multi-channel audio to mono.
            wav = wav.mean(dim=0, keepdim=True)
        if src_sr != SR_MODEL:
            wav = torchaudio.functional.resample(wav, src_sr, SR_MODEL)
        samples = wav.squeeze(0).numpy().astype(np.float32)
        return samples, SR_MODEL
    except Exception:
        audio, loaded_sr = librosa.load(path, sr=SR_MODEL)
        return audio.astype(np.float32), loaded_sr
|
| 63 |
+
|
| 64 |
+
def str2bool(v):
    """Parse a CLI flag value into a bool; actual booleans pass through unchanged."""
    if isinstance(v, bool):
        return v
    truthy = {'yes', 'true', 't', 'y', '1'}
    return v.lower() in truthy
|
| 67 |
+
|
| 68 |
+
def _parse_range(s):
|
| 69 |
+
"""'0.8,1.2' ๊ฐ์ ๋ฌธ์์ด์ (0.8, 1.2) ํํ๋ก ๋ณํ"""
|
| 70 |
+
if isinstance(s, list): return tuple(s)
|
| 71 |
+
try:
|
| 72 |
+
s = str(s).strip("[]").split(",")
|
| 73 |
+
return (float(s[0]), float(s[1]))
|
| 74 |
+
except: return (0, 0)
|
| 75 |
+
|
| 76 |
+
class KWSDataset(Dataset):
    """Keyword-spotting dataset.

    Yields (waveform_tensor, label_id) when ``mel_gpu`` is set (the mel is
    computed per batch on the GPU by MelExtractorGPU), otherwise
    (mel_tensor, label_id) with the mel computed here on the CPU.
    """

    def __init__(self, items, label_map=None, trim_silence=False, trim_top_db=30.0,
                 rms_norm_target=0.05, pre_emphasis=0.97, augment_noise_scale=0.0,
                 augment_time_stretch=(0.0, 0.0), augment_pitch=(0.0, 0.0),
                 augment_shift=0.0, spec_augment=(0, 0), preprocess_cache=None, mel_cache=None,
                 skip_heavy_augment=False, mel_gpu=False):
        self.label_map = label_map
        self.items = items
        self.trim_silence = trim_silence
        self.trim_top_db = trim_top_db
        self.rms_norm_target = rms_norm_target
        self.pre_emphasis = pre_emphasis
        self.preprocess_cache = preprocess_cache
        self.mel_cache = mel_cache if not mel_gpu else None  # with mel_gpu the Dataset produces no mels, so no mel cache
        self.skip_heavy_augment = skip_heavy_augment
        self.mel_gpu = mel_gpu

        # Augmentations
        self.augment_noise_scale = augment_noise_scale
        self.augment_time_stretch = augment_time_stretch
        self.augment_pitch = augment_pitch
        self.augment_shift = augment_shift
        self.spec_augment = spec_augment

    def __len__(self): return len(self.items)

    def _preprocess(self, y):
        """Trim / volume-normalize / pre-emphasize one waveform (runs before augmentation)."""
        # 1. Silence trim (optional; False is usually recommended to avoid false positives)
        if self.trim_silence and len(y) > 0:
            y, _ = librosa.effects.trim(y, top_db=self.trim_top_db)
            if len(y) == 0: y = np.zeros(1024, dtype=np.float32)

        # 2. RMS norm (volume normalization)
        if self.rms_norm_target > 0:
            rms = np.sqrt(np.mean(y ** 2)) + 1e-8
            y = y * (self.rms_norm_target / rms)

        # 3. Pre-emphasis (high-frequency boost)
        if self.pre_emphasis != 0:
            y = np.append(y[0], y[1:] - self.pre_emphasis * y[:-1]).astype(np.float32)
        return y

    def _apply_augment(self, y, sr):
        """Waveform-level augmentation: time stretch, pitch shift, time shift, noise."""
        # A. Time stretch (heavy; skipped with --fast-augment)
        if not self.skip_heavy_augment and self.augment_time_stretch[0] != self.augment_time_stretch[1]:
            rate = random.uniform(self.augment_time_stretch[0], self.augment_time_stretch[1])
            if abs(rate - 1.0) > 0.01:
                y = librosa.effects.time_stretch(y, rate=rate)

        # B. Pitch shift (heavy; skipped with --fast-augment)
        if not self.skip_heavy_augment and self.augment_pitch[0] != self.augment_pitch[1]:
            n_steps = random.uniform(self.augment_pitch[0], self.augment_pitch[1])
            if abs(n_steps) > 0.01:
                y = librosa.effects.pitch_shift(y, sr=sr, n_steps=n_steps)

        # C. Time shift (position move)
        # Shifts the audio left/right and fills the gap with 0 (silence);
        # np.roll wraps around, which is why it is not used here.
        if self.augment_shift > 0:
            shift_sec = random.uniform(-self.augment_shift, self.augment_shift)
            shift_samples = int(shift_sec * sr)
            if abs(shift_samples) > 0:
                y_shifted = np.zeros_like(y)
                if shift_samples > 0:  # shift right (leading silence)
                    if shift_samples < len(y):
                        y_shifted[shift_samples:] = y[:-shift_samples]
                else:  # shift left (trailing silence)
                    shift_samples = abs(shift_samples)
                    if shift_samples < len(y):
                        y_shifted[:-shift_samples] = y[shift_samples:]
                y = y_shifted

        # D. Noise injection — applied after normalization so the scale is consistent
        if self.augment_noise_scale > 0:
            rms = np.sqrt(np.mean(y ** 2)) + 1e-8
            # Gaussian white noise (mixing real environment noise would be better)
            noise = np.random.randn(len(y)).astype(np.float32)
            y = y + self.augment_noise_scale * rms * noise

        return y

    def _apply_spec_augment(self, mel):
        """SpecAugment-style masking: one time band and one frequency band, filled with the mel mean."""
        t_max, f_max = self.spec_augment
        if t_max <= 0 and f_max <= 0: return mel
        n_mels, T = mel.shape
        # Time Masking
        if t_max > 0 and T > t_max:
            t0 = random.randint(0, T - t_max)
            mel[:, t0 : t0 + t_max] = mel.mean()  # mask with the mean value
        # Freq Masking
        if f_max > 0 and n_mels > f_max:
            f0 = random.randint(0, n_mels - f_max)
            mel[f0 : f0 + f_max, :] = mel.mean()
        return mel

    def __getitem__(self, idx):
        """Return one (tensor, label_id) pair; on any load error, emit a zero tensor instead of crashing the loader."""
        item = self.items[idx]
        path = os.path.normpath(item['wav_path'])
        try:
            # If not mel_gpu and the (validation) mel cache hits, return it directly.
            if not self.mel_gpu and self.mel_cache is not None and path in self.mel_cache:
                return self.mel_cache[path].clone(), item['label_id']

            # 1. Load + preprocess (optionally cached; cache stores the preprocessed waveform)
            if self.preprocess_cache is not None:
                if path in self.preprocess_cache:
                    y, sr = self.preprocess_cache[path]
                else:
                    y, sr = _load_audio_fast(item['wav_path'])
                    y = self._preprocess(y)
                    self.preprocess_cache[path] = (y.copy(), sr)
            else:
                y, sr = _load_audio_fast(item['wav_path'])
                y = self._preprocess(y)

            # 2. Augment
            y = self._apply_augment(y, sr)

            if self.mel_gpu:
                # 3a. GPU mel mode: only pad/crop the waveform here and return it
                # (the mel is computed per batch on the GPU in the training loop).
                if len(y) > MAX_AUDIO_SAMPLES:
                    y = y[:MAX_AUDIO_SAMPLES]
                else:
                    y = np.pad(y, (0, MAX_AUDIO_SAMPLES - len(y)), mode="constant", constant_values=0)
                return torch.from_numpy(y.astype(np.float32)), item['label_id']

            # 3b. CPU mel mode: compute the mel here as before and return it
            norm_mel = _compute_mel_fast(y, sr)
            norm_mel = self._apply_spec_augment(norm_mel.copy())
            if norm_mel.shape[1] > MAX_TIME_FRAMES:
                norm_mel = norm_mel[:, :MAX_TIME_FRAMES]
            else:
                norm_mel = np.pad(norm_mel, ((0, 0), (0, MAX_TIME_FRAMES - norm_mel.shape[1])))
            out = torch.from_numpy(norm_mel).float().unsqueeze(0)
            if self.mel_cache is not None:
                self.mel_cache[path] = out.clone()
            return out, item['label_id']
        except Exception as e:
            print(f"โ ๏ธ Load Error ({item['wav_path']}): {e}")
            if self.mel_gpu:
                return torch.zeros(MAX_AUDIO_SAMPLES, dtype=torch.float32), item['label_id']
            return torch.zeros(1, N_MELS, MAX_TIME_FRAMES), item['label_id']
|
| 217 |
+
|
| 218 |
+
class KWSModel(nn.Module):
|
| 219 |
+
def __init__(self, num_classes=3, n_mels=128, dropout=0.0):
|
| 220 |
+
super(KWSModel, self).__init__()
|
| 221 |
+
self.conv1 = nn.Conv1d(n_mels, 64, kernel_size=3, padding=1)
|
| 222 |
+
self.bn1 = nn.BatchNorm1d(64)
|
| 223 |
+
self.relu = nn.ReLU()
|
| 224 |
+
# Stride๋ฅผ ํ์ฉํด ์๊ฐ ์ฐจ์ ์ถ์
|
| 225 |
+
self.layer1 = self._make_layer(64, 64, 1)
|
| 226 |
+
self.layer2 = self._make_layer(64, 128, 2)
|
| 227 |
+
self.layer3 = self._make_layer(128, 256, 2)
|
| 228 |
+
self.adaptive_pool = nn.AdaptiveAvgPool1d(1)
|
| 229 |
+
self.dropout = nn.Dropout(p=dropout) if dropout > 0 else nn.Identity()
|
| 230 |
+
self.fc = nn.Linear(256, num_classes)
|
| 231 |
+
|
| 232 |
+
def _make_layer(self, in_c, out_c, stride):
|
| 233 |
+
return nn.Sequential(
|
| 234 |
+
nn.Conv1d(in_c, out_c, 3, stride, 1, bias=False), nn.BatchNorm1d(out_c), nn.ReLU(),
|
| 235 |
+
nn.Conv1d(out_c, out_c, 3, 1, 1, bias=False), nn.BatchNorm1d(out_c), nn.ReLU()
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
def forward(self, x):
|
| 239 |
+
x = x.squeeze(1) # [B, 1, F, T] -> [B, F, T]
|
| 240 |
+
x = self.relu(self.bn1(self.conv1(x)))
|
| 241 |
+
x = self.layer3(self.layer2(self.layer1(x)))
|
| 242 |
+
x = self.adaptive_pool(x).flatten(1)
|
| 243 |
+
return self.fc(self.dropout(x))
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
class MelExtractorGPU(nn.Module):
    """Batched mel spectrogram (+ optional SpecAugment) computed on the GPU.

    Lets the Dataset return raw waveforms only; this module turns a
    (B, samples) batch into normalized (B, 1, N_MELS, MAX_TIME_FRAMES) mels.
    """

    def __init__(self, spec_augment: tuple[int, int] = (0, 0)):
        super().__init__()
        self.spec_augment = spec_augment
        try:
            import torchaudio
            self._mel_fn = torchaudio.transforms.MelSpectrogram(
                sample_rate=SR_MODEL,
                n_fft=N_FFT,
                win_length=N_FFT,
                hop_length=HOP_LENGTH,
                f_min=0.0,
                f_max=float(SR_MODEL) / 2,
                n_mels=N_MELS,
                power=2.0,
            )
        except Exception:
            # torchaudio unavailable: forward() will raise when called.
            self._mel_fn = None

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """(B, samples) -> (B, 1, N_MELS, MAX_TIME_FRAMES), values in [0, 1]."""
        if self._mel_fn is None:
            raise RuntimeError("MelExtractorGPU requires torchaudio")
        power_mel = self._mel_fn(waveform)
        # Power -> dB, clipped to [-80, 0], then min-max scaled to [0, 1].
        db = torch.clamp(10.0 * torch.log10(torch.clamp(power_mel, min=1e-10)), -80.0, 0.0)
        scaled = (db + 80.0) / 80.0
        # Pad/crop the time axis to the fixed model input length.
        frames = scaled.shape[2]
        if frames > MAX_TIME_FRAMES:
            scaled = scaled[:, :, :MAX_TIME_FRAMES]
        elif frames < MAX_TIME_FRAMES:
            scaled = torch.nn.functional.pad(scaled, (0, MAX_TIME_FRAMES - frames))
        # SpecAugment only during training and only when both mask sizes are set.
        if self.training and self.spec_augment[0] > 0 and self.spec_augment[1] > 0:
            scaled = self._spec_augment(scaled)
        return scaled.unsqueeze(1)

    def _spec_augment(self, mel: torch.Tensor) -> torch.Tensor:
        """Per-sample time/frequency masking, each band filled with that sample's mean."""
        batch, n_freq, n_time = mel.shape
        t_max, f_max = self.spec_augment
        if t_max > 0 and n_time > t_max:
            t_starts = torch.randint(0, n_time - t_max + 1, (batch,), device=mel.device)
            for i in range(batch):
                mel[i, :, t_starts[i] : t_starts[i] + t_max] = mel[i].mean()
        if f_max > 0 and n_freq > f_max:
            f_starts = torch.randint(0, n_freq - f_max + 1, (batch,), device=mel.device)
            for i in range(batch):
                mel[i, f_starts[i] : f_starts[i] + f_max, :] = mel[i].mean()
        return mel
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def _serializable_config(args, best_epoch, best_metric, class_names, label_map, normal_dir_added=0):
|
| 299 |
+
out = {}
|
| 300 |
+
for k, v in vars(args).items():
|
| 301 |
+
if v is None or (isinstance(v, float) and (v != v)): continue
|
| 302 |
+
if isinstance(v, (str, int, float, bool)) or (isinstance(v, (list, tuple)) and all(isinstance(x, (str, int, float, bool)) for x in (v or []))):
|
| 303 |
+
out[k] = v
|
| 304 |
+
else: out[k] = str(v)
|
| 305 |
+
out.update({
|
| 306 |
+
"best_epoch": best_epoch,
|
| 307 |
+
"val_recall_positive_min": best_metric,
|
| 308 |
+
"class_names": class_names,
|
| 309 |
+
"label_map": label_map,
|
| 310 |
+
"normal_dir_added": normal_dir_added
|
| 311 |
+
})
|
| 312 |
+
return out
|
| 313 |
+
|
| 314 |
+
def main():
|
| 315 |
+
parser = argparse.ArgumentParser()
|
| 316 |
+
# Path
|
| 317 |
+
parser.add_argument('--manifest', type=str, default='')
|
| 318 |
+
parser.add_argument('--dataset-root', type=str, default='')
|
| 319 |
+
parser.add_argument('--data-dir', type=str, default='./kws_data')
|
| 320 |
+
parser.add_argument('--output-dir', type=str, default='./kws_models')
|
| 321 |
+
|
| 322 |
+
# Train
|
| 323 |
+
parser.add_argument('--epochs', type=int, default=50)
|
| 324 |
+
parser.add_argument('--lr', type=float, default=0.001)
|
| 325 |
+
parser.add_argument('--batch-size', type=int, default=32)
|
| 326 |
+
parser.add_argument('--keywords', nargs='+', default=['help_me', 'save_me'])
|
| 327 |
+
parser.add_argument('--include-normal', action='store_true')
|
| 328 |
+
parser.add_argument('--normal-dir', type=str, default='', help='์ผ๋ฐ(Normal) ๋ฐ์ดํฐ ์ถ๊ฐ ๊ฒฝ๋ก')
|
| 329 |
+
|
| 330 |
+
# Augment
|
| 331 |
+
parser.add_argument('--trim_silence', type=str2bool, default=False)
|
| 332 |
+
parser.add_argument('--rms_norm', type=float, default=0.05)
|
| 333 |
+
parser.add_argument('--augment_noise_scale', type=float, default=0.0)
|
| 334 |
+
parser.add_argument('--augment_shift', type=float, default=0.0, help='[New] Time Shift Augmentation (sec)')
|
| 335 |
+
parser.add_argument('--augment_pitch', type=str, default='0,0')
|
| 336 |
+
parser.add_argument('--augment_time_stretch', type=str, default='0,0')
|
| 337 |
+
parser.add_argument('--spec_augment_on', type=str2bool, default=False)
|
| 338 |
+
parser.add_argument('--fast-augment', action='store_true', help='time_stretch/pitch_shift ์๋ต โ ์์ปค CPU ๋ถํ ๊ฐ์, GPU ํ์ฉ๋ ์์น (sweep ๋ณ๋ชฉ ์ ์ฌ์ฉ)')
|
| 339 |
+
parser.add_argument('--mel-gpu', action='store_true', help='๋ฉ ์คํํธ๋ก๊ทธ๋จ์ GPU์์ ๋ฐฐ์น ๊ณ์ฐ โ ์์ปค CPU ๋ถํ ๊ฐ์, ํ์ต ๊ฐ์')
|
| 340 |
+
|
| 341 |
+
# Regularization (๊ณผ์ ํฉ ์ํ)
|
| 342 |
+
parser.add_argument('--class_weight', type=str, default='inverse')
|
| 343 |
+
parser.add_argument('--dropout', type=float, default=0.0)
|
| 344 |
+
parser.add_argument('--weight_decay', type=float, default=1e-4, help='L2 ์ ๊ทํ (Adam). 0์ด๋ฉด ๋นํ์ฑํ')
|
| 345 |
+
parser.add_argument('--early_stop_patience', type=int, default=10)
|
| 346 |
+
parser.add_argument('--lr_scheduler_patience', type=int, default=0, help='val_loss ๊ฐ์ ์์ ๋ LR ๊ฐ์ ๋๊ธฐ ์ํญ. 0์ด๋ฉด ๋นํ์ฑํ')
|
| 347 |
+
|
| 348 |
+
# HF / WandB (HF ๊ธฐ๋ณธ: dusen0528/kws)
|
| 349 |
+
parser.add_argument('--hf-repo', type=str, default='', help='Hugging Face ๋ชจ๋ธ repo (๊ธฐ๋ณธ: dusen0528/kws)')
|
| 350 |
+
parser.add_argument('--hf-token', type=str, default='')
|
| 351 |
+
parser.add_argument('--wandb-project', type=str, default='kws')
|
| 352 |
+
parser.add_argument('--wandb-entity', type=str, default='')
|
| 353 |
+
parser.add_argument('--amp', type=str2bool, default=True, help='GPU์์ ํผํฉ์ ๋ฐ(AMP) ์ฌ์ฉ โ ํ์ต ์๋ ํฅ์')
|
| 354 |
+
parser.add_argument('--workers', type=int, default=4, help='DataLoader worker ์ (0์ด๋ฉด ๋ฉ์ธ๋ง)')
|
| 355 |
+
|
| 356 |
+
args, _ = parser.parse_known_args()
|
| 357 |
+
|
| 358 |
+
# --- 1. Label Map & Data Loading ---
|
| 359 |
+
label_map = {}
|
| 360 |
+
# Normal์ด ์์ผ๋ฉด 0๋ฒ์ผ๋ก ๊ณ ์ (๊ด๋ก)
|
| 361 |
+
if args.include_normal or (args.normal_dir and os.path.isdir(args.normal_dir)):
|
| 362 |
+
label_map['normal'] = 0
|
| 363 |
+
for kw in args.keywords: label_map[kw] = len(label_map)
|
| 364 |
+
num_classes = len(label_map)
|
| 365 |
+
class_names = [k for k, v in sorted(label_map.items(), key=lambda x: x[1])]
|
| 366 |
+
|
| 367 |
+
records = []
|
| 368 |
+
# A. Manifest Load
|
| 369 |
+
if args.manifest and os.path.exists(args.manifest):
|
| 370 |
+
root = Path(args.dataset_root)
|
| 371 |
+
with open(args.manifest, "r", encoding="utf-8") as f:
|
| 372 |
+
for line in f:
|
| 373 |
+
try:
|
| 374 |
+
obj = json.loads(line)
|
| 375 |
+
lbl = obj.get("label")
|
| 376 |
+
if lbl in label_map:
|
| 377 |
+
ap = obj["audio_path"]
|
| 378 |
+
full_p = ap if os.path.isabs(ap) else str(root / ap)
|
| 379 |
+
if os.path.exists(full_p):
|
| 380 |
+
records.append({"wav_path": full_p, "label": lbl, "label_id": label_map[lbl]})
|
| 381 |
+
except: pass
|
| 382 |
+
|
| 383 |
+
# B. Folder Load (Fallback)
|
| 384 |
+
if not records:
|
| 385 |
+
for ln, lid in label_map.items():
|
| 386 |
+
dp = os.path.join(args.data_dir, ln)
|
| 387 |
+
if os.path.isdir(dp):
|
| 388 |
+
for f in os.listdir(dp):
|
| 389 |
+
if f.lower().endswith(('.wav', '.mp3')):
|
| 390 |
+
records.append({"wav_path": os.path.join(dp, f), "label": ln, "label_id": lid})
|
| 391 |
+
|
| 392 |
+
# C. Normal Dir Augmentation
|
| 393 |
+
normal_dir_added = 0
|
| 394 |
+
if args.normal_dir and os.path.isdir(args.normal_dir) and 'normal' in label_map:
|
| 395 |
+
n_before = len(records)
|
| 396 |
+
for f in os.listdir(args.normal_dir):
|
| 397 |
+
if f.lower().endswith(('.wav', '.mp3')):
|
| 398 |
+
full_p = os.path.join(args.normal_dir, f)
|
| 399 |
+
if os.path.isfile(full_p):
|
| 400 |
+
records.append({"wav_path": full_p, "label": "normal", "label_id": label_map["normal"]})
|
| 401 |
+
normal_dir_added = len(records) - n_before
|
| 402 |
+
|
| 403 |
+
if not records: print("โ ๋ฐ์ดํฐ ์์"); return
|
| 404 |
+
|
| 405 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 406 |
+
hf_repo = (args.hf_repo or os.environ.get("HF_REPO", "")).strip() or "dusen0528/kws"
|
| 407 |
+
hf_token = args.hf_token or os.environ.get("HF_TOKEN", "")
|
| 408 |
+
|
| 409 |
+
# --- 2. WandB Init ---
|
| 410 |
+
run_name = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 411 |
+
if os.environ.get("WANDB_SWEEP_ID"):
|
| 412 |
+
wandb.init(config=vars(args), name=run_name)
|
| 413 |
+
# Sweep ํ๋ผ๋ฏธํฐ๊ฐ args๋ฅผ ๋ฎ์ด์
|
| 414 |
+
for k, v in wandb.config.items():
|
| 415 |
+
setattr(args, k, v)
|
| 416 |
+
else:
|
| 417 |
+
wandb.init(config=vars(args), project=args.wandb_project, entity=args.wandb_entity or None, name=run_name)
|
| 418 |
+
|
| 419 |
+
# --- 3. Dataset & Loader ---
|
| 420 |
+
stratify_ids = [r['label_id'] for r in records]
|
| 421 |
+
try:
|
| 422 |
+
train_items, val_items = train_test_split(records, test_size=0.2, random_state=42, stratify=stratify_ids)
|
| 423 |
+
except ValueError:
|
| 424 |
+
train_items, val_items = train_test_split(records, test_size=0.2, random_state=42)
|
| 425 |
+
|
| 426 |
+
preprocess_cache: dict = {}
|
| 427 |
+
fast_augment = getattr(args, 'fast_augment', False)
|
| 428 |
+
mel_gpu = getattr(args, "mel_gpu", False)
|
| 429 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 430 |
+
if mel_gpu and device.type != "cuda":
|
| 431 |
+
mel_gpu = False
|
| 432 |
+
print(" โ ๏ธ --mel-gpu๋ CUDA์์๋ง ๋์, ๋นํ์ฑํํจ")
|
| 433 |
+
train_ds = KWSDataset(train_items, label_map,
|
| 434 |
+
trim_silence=args.trim_silence,
|
| 435 |
+
rms_norm_target=args.rms_norm,
|
| 436 |
+
augment_noise_scale=args.augment_noise_scale,
|
| 437 |
+
augment_pitch=_parse_range(args.augment_pitch),
|
| 438 |
+
augment_time_stretch=_parse_range(args.augment_time_stretch),
|
| 439 |
+
augment_shift=args.augment_shift,
|
| 440 |
+
spec_augment=(15, 10) if args.spec_augment_on else (0, 0),
|
| 441 |
+
preprocess_cache=preprocess_cache,
|
| 442 |
+
skip_heavy_augment=fast_augment,
|
| 443 |
+
mel_gpu=mel_gpu)
|
| 444 |
+
|
| 445 |
+
val_mel_cache: dict = {} if not mel_gpu else {}
|
| 446 |
+
val_ds = KWSDataset(val_items, label_map,
|
| 447 |
+
trim_silence=args.trim_silence,
|
| 448 |
+
rms_norm_target=args.rms_norm,
|
| 449 |
+
preprocess_cache=preprocess_cache,
|
| 450 |
+
mel_cache=val_mel_cache,
|
| 451 |
+
mel_gpu=mel_gpu)
|
| 452 |
+
|
| 453 |
+
use_amp = args.amp and device.type == 'cuda'
|
| 454 |
+
num_workers = args.workers if device.type == 'cuda' else 0
|
| 455 |
+
pin = (device.type == 'cuda')
|
| 456 |
+
loader_kw = dict(batch_size=args.batch_size, pin_memory=pin, num_workers=num_workers)
|
| 457 |
+
if num_workers > 0:
|
| 458 |
+
loader_kw["persistent_workers"] = True
|
| 459 |
+
loader_kw["prefetch_factor"] = 8
|
| 460 |
+
train_loader = DataLoader(train_ds, shuffle=True, **loader_kw)
|
| 461 |
+
val_loader = DataLoader(val_ds, **loader_kw)
|
| 462 |
+
if use_amp:
|
| 463 |
+
print("โก AMP(ํผํฉ์ ๋ฐ) ์ฌ์ฉ")
|
| 464 |
+
|
| 465 |
+
# --- 4. Model & Loss ---
|
| 466 |
+
model = KWSModel(num_classes=num_classes, dropout=args.dropout).to(device)
|
| 467 |
+
mel_extractor = None
|
| 468 |
+
if mel_gpu and device.type == "cuda":
|
| 469 |
+
mel_extractor = MelExtractorGPU(
|
| 470 |
+
spec_augment=(15, 10) if args.spec_augment_on else (0, 0)
|
| 471 |
+
).to(device)
|
| 472 |
+
optimizer = torch.optim.Adam(
|
| 473 |
+
model.parameters(),
|
| 474 |
+
lr=args.lr,
|
| 475 |
+
weight_decay=getattr(args, 'weight_decay', 0.0),
|
| 476 |
+
)
|
| 477 |
+
scheduler = None
|
| 478 |
+
if getattr(args, 'lr_scheduler_patience', 0) > 0:
|
| 479 |
+
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
| 480 |
+
optimizer, mode='min', factor=0.5, patience=args.lr_scheduler_patience, min_lr=1e-6
|
| 481 |
+
)
|
| 482 |
+
print(f" ๐ LR Scheduler: val_loss {args.lr_scheduler_patience}ep ๊ฐ์ ์์ผ๋ฉด LRร0.5")
|
| 483 |
+
|
| 484 |
+
class_weights = None
|
| 485 |
+
if args.class_weight == 'inverse':
|
| 486 |
+
from collections import Counter
|
| 487 |
+
counts = Counter([r['label_id'] for r in train_items])
|
| 488 |
+
# N_samples / (N_classes * count)
|
| 489 |
+
weights = [len(train_items) / (num_classes * (counts.get(i, 0) + 1)) for i in range(num_classes)]
|
| 490 |
+
class_weights = torch.tensor(weights, dtype=torch.float32).to(device)
|
| 491 |
+
print(f"โ๏ธ Class Weights: {class_weights.tolist()}")
|
| 492 |
+
|
| 493 |
+
criterion = nn.CrossEntropyLoss(weight=class_weights)
|
| 494 |
+
scaler = torch.amp.GradScaler("cuda") if use_amp else None
|
| 495 |
+
|
| 496 |
+
# --- 5. Training Loop ---
|
| 497 |
+
best_rec_min = 0.0
|
| 498 |
+
best_fpr = 1.0
|
| 499 |
+
best_epoch = 0
|
| 500 |
+
patience = 0
|
| 501 |
+
|
| 502 |
+
print(f"๐ ํ์ต ์์ (๋ฐ์ดํฐ: {len(records)}๊ฐ, ํด๋์ค: {class_names})")
|
| 503 |
+
print(f" workers={num_workers} batch={args.batch_size} | GPU ๋ฎ๊ณ pt_data_worker 100%๋ฉด: --batch-size 64 --workers 8, --fast-augment, --mel-gpu")
|
| 504 |
+
if fast_augment:
|
| 505 |
+
print(f" โก --fast-augment: time_stretch/pitch_shift ์๋ต โ ์์ปค ๋ถํ ๊ฐ์")
|
| 506 |
+
if mel_gpu and mel_extractor is not None:
|
| 507 |
+
print(f" โก --mel-gpu: ๋ฉ ์คํํธ๋ก๊ทธ๋จ์ GPU์์ ๋ฐฐ์น ๊ณ์ฐ โ ์์ปค CPU ๋ถํ ๊ฐ์")
|
| 508 |
+
|
| 509 |
+
pbar_epoch = tqdm(range(1, args.epochs + 1), desc="Epoch", unit="ep")
|
| 510 |
+
for e in pbar_epoch:
|
| 511 |
+
# A. Train
|
| 512 |
+
model.train()
|
| 513 |
+
t_loss = 0; t_corr = 0; t_total = 0
|
| 514 |
+
train_pbar = tqdm(train_loader, desc=f"Ep {e}", leave=False, unit="batch")
|
| 515 |
+
for m, l in train_pbar:
|
| 516 |
+
m, l = m.to(device, non_blocking=True), l.to(device, non_blocking=True)
|
| 517 |
+
if mel_extractor is not None:
|
| 518 |
+
mel_extractor.train()
|
| 519 |
+
m = mel_extractor(m)
|
| 520 |
+
optimizer.zero_grad()
|
| 521 |
+
if use_amp and scaler is not None:
|
| 522 |
+
with torch.amp.autocast("cuda"):
|
| 523 |
+
out = model(m)
|
| 524 |
+
loss = criterion(out, l)
|
| 525 |
+
scaler.scale(loss).backward()
|
| 526 |
+
scaler.step(optimizer)
|
| 527 |
+
scaler.update()
|
| 528 |
+
else:
|
| 529 |
+
out = model(m)
|
| 530 |
+
loss = criterion(out, l)
|
| 531 |
+
loss.backward()
|
| 532 |
+
optimizer.step()
|
| 533 |
+
t_loss += loss.item()
|
| 534 |
+
_, p = torch.max(out, 1)
|
| 535 |
+
t_corr += (p == l).sum().item()
|
| 536 |
+
t_total += l.size(0)
|
| 537 |
+
train_pbar.set_postfix(loss=f"{t_loss / max(1, train_pbar.n):.3f}")
|
| 538 |
+
|
| 539 |
+
train_loss = t_loss / len(train_loader)
|
| 540 |
+
train_acc = 100 * t_corr / t_total if t_total else 0
|
| 541 |
+
|
| 542 |
+
# B. Validation
|
| 543 |
+
model.eval()
|
| 544 |
+
v_loss = 0; v_preds, v_labels = [], []
|
| 545 |
+
with torch.no_grad():
|
| 546 |
+
for m, l in tqdm(val_loader, desc="Val", leave=False, unit="batch"):
|
| 547 |
+
m, l = m.to(device, non_blocking=True), l.to(device, non_blocking=True)
|
| 548 |
+
if mel_extractor is not None:
|
| 549 |
+
mel_extractor.eval()
|
| 550 |
+
with torch.no_grad():
|
| 551 |
+
m = mel_extractor(m)
|
| 552 |
+
if use_amp:
|
| 553 |
+
with torch.amp.autocast("cuda"):
|
| 554 |
+
out = model(m)
|
| 555 |
+
loss = criterion(out, l)
|
| 556 |
+
else:
|
| 557 |
+
out = model(m)
|
| 558 |
+
loss = criterion(out, l)
|
| 559 |
+
v_loss += loss.item()
|
| 560 |
+
_, p = torch.max(out, 1)
|
| 561 |
+
v_preds.extend(p.cpu().numpy())
|
| 562 |
+
v_labels.extend(l.cpu().numpy())
|
| 563 |
+
|
| 564 |
+
val_loss = v_loss / len(val_loader) if len(val_loader) else 0
|
| 565 |
+
val_acc = 100 * accuracy_score(v_labels, v_preds) if v_labels else 0
|
| 566 |
+
|
| 567 |
+
# C. Metrics (Recall per class & False Positive Rate)
|
| 568 |
+
labels_idx = list(range(num_classes))
|
| 569 |
+
prec, rec, f1, _ = precision_recall_fscore_support(
|
| 570 |
+
v_labels, v_preds, labels=labels_idx, average=None, zero_division=0
|
| 571 |
+
)
|
| 572 |
+
rec = np.atleast_1d(rec).astype(float)
|
| 573 |
+
prec = np.atleast_1d(prec).astype(float)
|
| 574 |
+
f1 = np.atleast_1d(f1).astype(float)
|
| 575 |
+
if len(rec) < num_classes:
|
| 576 |
+
rec = np.pad(rec, (0, num_classes - len(rec)), constant_values=0.0)
|
| 577 |
+
prec = np.pad(prec, (0, num_classes - len(prec)), constant_values=0.0)
|
| 578 |
+
f1 = np.pad(f1, (0, num_classes - len(f1)), constant_values=0.0)
|
| 579 |
+
|
| 580 |
+
log_dict = {
|
| 581 |
+
"epoch": e,
|
| 582 |
+
"train/loss": train_loss, "train/acc": train_acc,
|
| 583 |
+
"val/loss": val_loss, "val/acc": val_acc,
|
| 584 |
+
"val/f1_macro": f1_score(v_labels, v_preds, average='macro', zero_division=0),
|
| 585 |
+
"val/train_loss_gap": val_loss - train_loss, # >0 ์ด๋ฉด val์ด train๋ณด๋ค ๋์จ โ ์ค๋ฒํผํ
์์ฌ
|
| 586 |
+
"val/train_acc_gap": train_acc - val_acc, # >0 ์ด๋ฉด train์ด val๋ณด๋ค ์ข์ โ ์ค๋ฒํผํ
์์ฌ
|
| 587 |
+
}
|
| 588 |
+
if scheduler is not None:
|
| 589 |
+
log_dict["train/lr"] = optimizer.param_groups[0]["lr"]
|
| 590 |
+
|
| 591 |
+
pos_rec = []
|
| 592 |
+
for i, name in enumerate(class_names):
|
| 593 |
+
safe_name = name.replace("/", "_")
|
| 594 |
+
log_dict[f"val/recall_{safe_name}"] = rec[i]
|
| 595 |
+
log_dict[f"val/f1_{safe_name}"] = f1[i]
|
| 596 |
+
if name in ['help_me', 'save_me']:
|
| 597 |
+
pos_rec.append(rec[i])
|
| 598 |
+
|
| 599 |
+
# [Metric 1] Recall Positive Min (๋ชฉํ: ๋ฏธํ ๋ฐฉ์ง)
|
| 600 |
+
rec_min = min(pos_rec) if pos_rec else 0.0
|
| 601 |
+
log_dict["val/recall_positive_min"] = rec_min
|
| 602 |
+
|
| 603 |
+
# [Metric 2] False Positive Rate (๋ชฉํ: ์คํ ๋ฐฉ์ง)
|
| 604 |
+
# Normal ๋ฐ์ดํฐ(์ค์ 0๋ฒ)๊ฐ ๋ค์ด์๋๋ฐ -> 0๋ฒ์ด ์๋๋ผ๊ณ (Positive๋ผ๊ณ ) ์์ธกํ ๋น์จ
|
| 605 |
+
if 'normal' in label_map:
|
| 606 |
+
norm_idx = label_map['normal']
|
| 607 |
+
norm_mask = [i for i, x in enumerate(v_labels) if x == norm_idx]
|
| 608 |
+
if norm_mask:
|
| 609 |
+
norm_preds = [v_preds[i] for i in norm_mask]
|
| 610 |
+
false_alarms = sum(1 for p in norm_preds if p != norm_idx)
|
| 611 |
+
fpr = false_alarms / len(norm_mask)
|
| 612 |
+
else:
|
| 613 |
+
fpr = 0.0
|
| 614 |
+
log_dict["val/false_positive_rate"] = fpr
|
| 615 |
+
|
| 616 |
+
wandb.log(log_dict)
|
| 617 |
+
|
| 618 |
+
current_fpr = log_dict.get("val/false_positive_rate", 1.0)
|
| 619 |
+
loss_gap = val_loss - train_loss
|
| 620 |
+
acc_gap = train_acc - val_acc
|
| 621 |
+
if scheduler is not None:
|
| 622 |
+
scheduler.step(val_loss)
|
| 623 |
+
pbar_epoch.set_postfix(loss=f"{train_loss:.3f}", rec_min=f"{rec_min:.2f}", fpr=f"{current_fpr:.2f}")
|
| 624 |
+
print(f"[Ep {e}] Loss: {train_loss:.4f} | RecMin: {rec_min:.2f} | FPR: {current_fpr:.2f}")
|
| 625 |
+
# ๊ณผ์ ํฉ ๊ฒฝ๊ณ : ๊ธฐ์ค์ ๊ฐํ๊ฒ ํด์ ๋งค ์ํญ ๋จ์ง ์๊ฒ ํจ (loss_gap 0.5 ์ด์ ๋๋ acc_gap 15% ์ด์์ผ ๋๋ง)
|
| 626 |
+
if e > 5 and (loss_gap > 0.5 or acc_gap > 15.0):
|
| 627 |
+
print(f" โ ๏ธ ์ค๋ฒํผํ
๊ฐ๋ฅ์ฑ (val_loss - train_loss = {loss_gap:.3f}, train_acc - val_acc = {acc_gap:.1f}%)")
|
| 628 |
+
if e == 10 and train_loss > 0.5 and val_loss > 0.5:
|
| 629 |
+
print(f" ๐ก ์ธ๋ํผํ
๊ฐ๋ฅ์ฑ (Ep 10์ธ๋ฐ loss ๋ ๋ค ๋์ โ epoch/๋ชจ๋ธ/ํ์ต๋ฅ ๊ฒํ )")
|
| 630 |
+
|
| 631 |
+
# D. Checkpoint
|
| 632 |
+
# ์ ์ฅ ์กฐ๊ฑด 1: Recall Min์ด ๋ ๋์ผ๋ฉด ๋ฌด์กฐ๊ฑด ์ ์ฅ (๋ฏธํ ๋ฐฉ์ง ์ต์ฐ์ )
|
| 633 |
+
# ์ ์ฅ ์กฐ๊ฑด 2: Recall Min์ด ๊ฐ์ผ๋ฉด, ์คํ๋ฅ (FPR)์ด ๋ ๋ฎ์ ๋ชจ๋ธ ์ ์ฅ
|
| 634 |
+
is_best = False
|
| 635 |
+
if rec_min > best_rec_min:
|
| 636 |
+
is_best = True
|
| 637 |
+
elif rec_min == best_rec_min and rec_min > 0:
|
| 638 |
+
if current_fpr < best_fpr:
|
| 639 |
+
is_best = True
|
| 640 |
+
print(f"โจ Recall ๋์ ({rec_min:.2f})์ด๋ FPR ๊ฐ์ ๋จ ({best_fpr:.2f} -> {current_fpr:.2f})")
|
| 641 |
+
|
| 642 |
+
if is_best:
|
| 643 |
+
best_rec_min = rec_min
|
| 644 |
+
best_fpr = current_fpr
|
| 645 |
+
best_epoch = e
|
| 646 |
+
patience = 0
|
| 647 |
+
from checkpoint_io import save_checkpoint
|
| 648 |
+
save_checkpoint(
|
| 649 |
+
args.output_dir,
|
| 650 |
+
model.state_dict(),
|
| 651 |
+
_serializable_config(args, best_epoch, best_rec_min, class_names, label_map, normal_dir_added),
|
| 652 |
+
also_save_pth=False,
|
| 653 |
+
)
|
| 654 |
+
print(f"๐พ Best ์ ์ฅ | Ep {e} | RecMin {best_rec_min:.2f} | FPR {best_fpr:.2f} | โ best_kws_model.safetensors")
|
| 655 |
+
else:
|
| 656 |
+
patience += 1
|
| 657 |
+
if args.early_stop_patience > 0 and patience >= args.early_stop_patience:
|
| 658 |
+
print("๐ Early Stopping")
|
| 659 |
+
break
|
| 660 |
+
|
| 661 |
+
# ํ์ต ์ข
๋ฃ
|
| 662 |
+
if getattr(wandb, "run", None) is not None:
|
| 663 |
+
wandb.run.summary["val/recall_positive_min"] = best_rec_min
|
| 664 |
+
wandb.run.summary["best_epoch"] = best_epoch
|
| 665 |
+
wandb.run.summary["best_fpr"] = best_fpr
|
| 666 |
+
|
| 667 |
+
# --- HF: sweep/ํ์ต ์ข
๋ฃ ํ ์
๋ก๋ (๊ธฐ์ค ๋ ์ง ํด๋ + ํ๋ผ๋ฏธํฐยท์ฑ๋ฅ) ---
|
| 668 |
+
if hf_repo:
|
| 669 |
+
try:
|
| 670 |
+
from huggingface_hub import HfApi
|
| 671 |
+
api = HfApi(token=hf_token or None)
|
| 672 |
+
api.create_repo(hf_repo, repo_type="model", exist_ok=True)
|
| 673 |
+
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 674 |
+
folder = f"run_{ts}"
|
| 675 |
+
for f in ["best_kws_model.safetensors", "best_kws_model.pth", "training_config.json"]:
|
| 676 |
+
p = os.path.join(args.output_dir, f)
|
| 677 |
+
if os.path.exists(p):
|
| 678 |
+
api.upload_file(path_or_fileobj=p, path_in_repo=f"{folder}/{f}", repo_id=hf_repo, repo_type="model")
|
| 679 |
+
metrics = {
|
| 680 |
+
"uploaded_at": datetime.now().isoformat(),
|
| 681 |
+
"type": "run_post_training",
|
| 682 |
+
"best_epoch": best_epoch,
|
| 683 |
+
"val_recall_positive_min": best_rec_min,
|
| 684 |
+
"best_fpr": best_fpr,
|
| 685 |
+
}
|
| 686 |
+
if getattr(wandb, "run", None) is not None:
|
| 687 |
+
run = wandb.run
|
| 688 |
+
metrics["wandb_run_id"] = getattr(run, "id", "") or ""
|
| 689 |
+
metrics["wandb_run_name"] = getattr(run, "name", "") or ""
|
| 690 |
+
metrics["wandb_run_url"] = getattr(run, "url", "") or ""
|
| 691 |
+
metrics_path = os.path.join(args.output_dir, "metrics.json")
|
| 692 |
+
with open(metrics_path, "w", encoding="utf-8") as mf:
|
| 693 |
+
json.dump(metrics, mf, indent=2, ensure_ascii=False)
|
| 694 |
+
api.upload_file(path_or_fileobj=metrics_path, path_in_repo=f"{folder}/metrics.json", repo_id=hf_repo, repo_type="model")
|
| 695 |
+
# best/ = ๋ชจ๋ run ์ค ์ง์ง ์ต๊ณ ๋ง ์ ์ง. ๊ธฐ์กด best ์งํ์ ๋น๊ตํด ๋ ์ข์ ๋๋ง ๋ฎ์ด์ฐ๊ธฐ
|
| 696 |
+
should_update_best = True
|
| 697 |
+
existing_rec, existing_fpr = 0.0, 1.0
|
| 698 |
+
try:
|
| 699 |
+
from huggingface_hub import hf_hub_download
|
| 700 |
+
path = hf_hub_download(
|
| 701 |
+
repo_id=hf_repo,
|
| 702 |
+
filename="best/metrics.json",
|
| 703 |
+
repo_type="model",
|
| 704 |
+
token=hf_token or None,
|
| 705 |
+
)
|
| 706 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 707 |
+
existing = json.load(f)
|
| 708 |
+
existing_rec = float(existing.get("val_recall_positive_min", 0.0))
|
| 709 |
+
existing_fpr = float(existing.get("best_fpr", 1.0))
|
| 710 |
+
# ๋ ์ข์ = rec_min ๋ ๋์, ๋๋ rec_min ๋์ ์ด๋ฉด FPR ๋ ๋ฎ์
|
| 711 |
+
if best_rec_min > existing_rec or (best_rec_min == existing_rec and best_fpr < existing_fpr):
|
| 712 |
+
should_update_best = True
|
| 713 |
+
else:
|
| 714 |
+
should_update_best = False
|
| 715 |
+
except Exception:
|
| 716 |
+
# best/metrics.json ์์ (์ฒซ run) โ best ๊ฐฑ์
|
| 717 |
+
should_update_best = True
|
| 718 |
+
if should_update_best:
|
| 719 |
+
for f in ["best_kws_model.safetensors", "best_kws_model.pth", "training_config.json"]:
|
| 720 |
+
p = os.path.join(args.output_dir, f)
|
| 721 |
+
if os.path.exists(p):
|
| 722 |
+
api.upload_file(path_or_fileobj=p, path_in_repo=f"best/{f}", repo_id=hf_repo, repo_type="model")
|
| 723 |
+
api.upload_file(path_or_fileobj=metrics_path, path_in_repo="best/metrics.json", repo_id=hf_repo, repo_type="model")
|
| 724 |
+
print(f"[KWS] HF ์
๋ก๋ ์๋ฃ: {hf_repo} -> {folder}/ (ํ์คํ ๋ฆฌ) + best/ (์ ์ฒด run ์ค ์ต๊ณ ๋ก ๊ฐฑ์ )")
|
| 725 |
+
else:
|
| 726 |
+
print(f"[KWS] HF ์
๋ก๋ ์๋ฃ: {hf_repo} -> {folder}/ (ํ์คํ ๋ฆฌ). best/ ๋ฏธ๊ฐฑ์ : ์ด๋ฒ rec_min={best_rec_min:.2f} fpr={best_fpr:.2f} vs ๊ธฐ์กด rec_min={existing_rec:.2f} fpr={existing_fpr:.2f}")
|
| 727 |
+
except Exception as e:
|
| 728 |
+
print(f"[KWS] HF Upload Error: {e}")
|
| 729 |
+
|
| 730 |
+
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()
|