# emergency_any_test / train_model.py
# Uploaded via huggingface_hub by dusen0528 (commit 77c8117, verified).
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
==============================================================================
ํ•œ๊ตญ์–ด ํ˜ธ์ถœ์–ด ๋ชจ๋ธ ํ•™์Šต ์Šคํฌ๋ฆฝํŠธ
==============================================================================
openWakeWord์˜ Transfer Learning ๋ฐฉ์‹์„ ์‚ฌ์šฉํ•˜์—ฌ ์ปค์Šคํ…€ ํ˜ธ์ถœ์–ด ๋ชจ๋ธ์„ ํ•™์Šตํ•ฉ๋‹ˆ๋‹ค.
ํ•™์Šต ํ”„๋กœ์„ธ์Šค:
1. positive ๋ฐ์ดํ„ฐ(ํ˜ธ์ถœ์–ด ์Œ์„ฑ)์—์„œ embedding ์ถ”์ถœ
2. negative ๋ฐ์ดํ„ฐ(๋น„ํ˜ธ์ถœ์–ด ์Œ์„ฑ)์—์„œ embedding ์ถ”์ถœ
3. ๊ฐ„๋‹จํ•œ DNN ๋ถ„๋ฅ˜๊ธฐ ํ•™์Šต
4. ONNX ๋ชจ๋ธ๋กœ ๋‚ด๋ณด๋‚ด๊ธฐ
์‚ฌ์šฉ๋ฒ•:
python train_model.py --positive_dir ./positive --model_name my_model
"""
import os
import sys
from pathlib import Path
from typing import List, Tuple, Optional
import glob
import importlib.util
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import soundfile as sf
from scipy import signal as scipy_signal
# openWakeWord ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ ๊ฒฝ๋กœ ์ถ”๊ฐ€ (๋กœ์ปฌ ๋ฆฌํฌ์ง€ํ† ๋ฆฌ ์‚ฌ์šฉ)
sys.path.insert(0, str(Path(__file__).parent.parent / "openWakeWord"))
try:
from openwakeword.utils import AudioFeatures
print("โœ… openwakeword.utils.AudioFeatures ์ž„ํฌํŠธ ์„ฑ๊ณต")
except ImportError as e:
print(f"โŒ openwakeword ์ž„ํฌํŠธ ์‹คํŒจ: {e}")
print(" -> 'pip install openwakeword' ์‹คํ–‰ ํ•„์š”")
sys.exit(1)
# ============================================
# Configuration
# ============================================
POSITIVE_DIR = "./positive"  # directory of positive (wake-word) clips
MODEL_NAME = "my_model"  # output model name
OUTPUT_DIR = "./"  # output directory
SAMPLE_RATE = 16000  # required by openWakeWord
CLIP_DURATION_SAMPLES = 32000  # 2-second clips (16000 * 2)
# Default path for general-conversation negative audio (Keyword-Spotting data)
DEFAULT_NEGATIVE_DIR = "/home/dusen0528/Keyword-Spotting/data"
# Training hyperparameters (BATCH_SIZE is widened automatically on GPU)
EPOCHS = 50
BATCH_SIZE = 32  # CPU default; train_model() widens this to up to 128 on GPU
LEARNING_RATE = 0.001
LAYER_DIM = 128  # DNN hidden layer width
N_BLOCKS = 1  # number of DNN hidden blocks
class WakeWordModel(nn.Module):
    """Binary wake-word classifier head.

    Consumes an openWakeWord embedding window of shape (timesteps, features)
    — e.g. (16, 96) — flattens it, and runs a small fully connected stack
    ending in a single sigmoid probability.
    """

    def __init__(self, input_shape: Tuple[int, int], layer_dim: int = 128, n_blocks: int = 1):
        """
        Args:
            input_shape: (timesteps, features), e.g. (16, 96)
            layer_dim: hidden layer width
            n_blocks: number of additional hidden blocks
        """
        super().__init__()
        self.input_shape = input_shape
        in_features = input_shape[0] * input_shape[1]

        def hidden(n_in: int) -> list:
            # One hidden block: affine -> layer norm -> non-linearity.
            return [nn.Linear(n_in, layer_dim), nn.LayerNorm(layer_dim), nn.ReLU()]

        # First block flattens the (timesteps, features) window.
        stack = [nn.Flatten()] + hidden(in_features)
        for _ in range(n_blocks):
            stack += hidden(layer_dim)
        # Output head: single logit squashed to a probability.
        stack += [nn.Linear(layer_dim, 1), nn.Sigmoid()]
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        return self.model(x)
def load_audio_files(directory: str, max_files: Optional[int] = None) -> List[np.ndarray]:
    """Load WAV files from a directory as fixed-length 16 kHz int16 clips.

    Args:
        directory: directory containing .wav files
        max_files: maximum number of files to load (None loads all)

    Returns:
        List of int16 numpy arrays, each exactly CLIP_DURATION_SAMPLES long.
    """
    wav_files = sorted(Path(directory).glob("*.wav"))
    # BUG FIX: the original `if max_files:` treated max_files=0 as "load all";
    # an explicit None check honours an intentional limit of zero.
    if max_files is not None:
        wav_files = wav_files[:max_files]
    audio_data = []
    for wav_file in tqdm(wav_files, desc=f"{directory} 로드 중"):
        try:
            # Read as float so files at other sample rates (e.g. 24 kHz)
            # can be resampled cleanly before int16 conversion.
            data, sr = sf.read(str(wav_file), dtype='float64')
            if len(data.shape) > 1:
                data = data[:, 0]  # keep only the first channel
            if sr != SAMPLE_RATE:
                data = _resample_to_16k(data, sr)
            else:
                data = (np.clip(data, -1.0, 1.0) * 32767).astype(np.int16)
            # Normalise clip length (pad or crop).
            if len(data) < CLIP_DURATION_SAMPLES:
                # Pad at the front so the wake word sits at the end of the clip.
                padded = np.zeros(CLIP_DURATION_SAMPLES, dtype=np.int16)
                start_idx = CLIP_DURATION_SAMPLES - len(data)
                padded[start_idx:] = data
                data = padded
            elif len(data) > CLIP_DURATION_SAMPLES:
                # Keep the tail (assume the wake word is at the end).
                data = data[-CLIP_DURATION_SAMPLES:]
            audio_data.append(data)
        except Exception as e:
            # Best-effort loading: report and skip unreadable files.
            print(f"❌ {wav_file.name} 로드 실패: {e}")
    return audio_data
def _resample_to_16k(audio: np.ndarray, orig_sr: int) -> np.ndarray:
    """Resample audio to 16 kHz and return int16 PCM.

    Float input is assumed to be normalised to [-1, 1]; integer input already
    at the target rate is passed through with only a dtype cast.
    """
    if orig_sr == SAMPLE_RATE:
        # Already at the target rate: just convert the dtype.
        # BUG FIX: the original only scaled float32/float64 and cast any other
        # float dtype (e.g. float16) raw to int16; np.issubdtype covers all
        # floating-point inputs uniformly.
        if np.issubdtype(audio.dtype, np.floating):
            return (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)
        return audio.astype(np.int16)
    # Resample by the sample-rate ratio, then clip and quantise.
    n = int(len(audio) * SAMPLE_RATE / orig_sr)
    out = scipy_signal.resample(audio, n)
    out = np.clip(out, -1.0, 1.0)
    return (out * 32767).astype(np.int16)
def load_negative_from_dir(
    negative_dir: str,
    num_samples: int,
    clip_length: int = CLIP_DURATION_SAMPLES,
) -> List[np.ndarray]:
    """Load negative clips from a directory of general conversational speech.

    Files at other sample rates (e.g. 44.1 kHz) are resampled to 16 kHz, and
    a 2-second (32000-sample) window starting at a random offset is taken.

    Args:
        negative_dir: directory containing WAV files
        num_samples: number of negative clips required
        clip_length: clip length in samples

    Returns:
        List of 16 kHz, int16 audio arrays of length clip_length. May be
        shorter than num_samples if some files are unreadable/empty.
    """
    wav_files = sorted(Path(negative_dir).glob("*.wav"))
    if not wav_files:
        return []
    negative_data = []
    rng = np.random.default_rng()
    for _ in tqdm(range(num_samples), desc="Negative(일반대화) 로드 중"):
        # Pick a random source file per clip (files may repeat).
        fpath = str(rng.choice(wav_files))
        try:
            data, sr = sf.read(fpath, dtype="float64")
        except Exception:
            # Skip unreadable files silently (best-effort sampling).
            continue
        if len(data) == 0:
            continue
        if data.ndim > 1:
            data = data[:, 0]  # keep only the first channel
        # Resample to 16 kHz int16.
        data_16k = _resample_to_16k(data, sr)
        if len(data_16k) < clip_length:
            # Too short: pad with trailing silence.
            pad = np.zeros(clip_length - len(data_16k), dtype=np.int16)
            data_16k = np.concatenate([data_16k, pad])
        else:
            # Long enough: take a random window.
            start = rng.integers(0, len(data_16k) - clip_length + 1)
            data_16k = data_16k[start : start + clip_length]
        negative_data.append(data_16k)
    return negative_data
def generate_negative_data(num_samples: int) -> List[np.ndarray]:
    """Create synthetic negative clips (random noise, occasionally silence).

    Used only as a fallback when no negative_dir is available.
    """
    clips: List[np.ndarray] = []
    for _ in tqdm(range(num_samples), desc="Negative(노이즈) 생성 중"):
        if np.random.random() < 0.8:
            # Gaussian noise with a randomly chosen amplitude.
            sigma = np.random.uniform(100, 1000)
            clip = np.random.normal(0, sigma, CLIP_DURATION_SAMPLES).astype(np.int16)
        else:
            # Pure silence.
            clip = np.zeros(CLIP_DURATION_SAMPLES, dtype=np.int16)
        clips.append(clip)
    return clips
def extract_embeddings(
    audio_clips: List[np.ndarray],
    feature_extractor: AudioFeatures,
    use_gpu: bool = True,
) -> np.ndarray:
    """Compute openWakeWord embeddings for a list of audio clips.

    Args:
        audio_clips: list of equal-length audio arrays
        feature_extractor: AudioFeatures instance
        use_gpu: when True and CUDA is available, use a larger batch size

    Returns:
        Embedding array of shape (N, timesteps, 96).
    """
    # Stack all clips into one (N, samples) matrix.
    clips_array = np.vstack([clip[None, :] for clip in audio_clips])
    # A larger batch speeds up feature extraction considerably on GPU.
    on_gpu = use_gpu and torch.cuda.is_available()
    batch_size = 128 if on_gpu else 32
    print(f"\n📊 Embedding 추출 중... (총 {len(clips_array)}개 클립, batch_size={batch_size})")
    # Delegate the actual extraction to openWakeWord's embed_clips.
    embeddings = feature_extractor.embed_clips(clips_array, batch_size=batch_size)
    print(f" Embedding 형태: {embeddings.shape}")
    return embeddings
def _resolve_openwakeword_resource_paths() -> Tuple[Optional[str], Optional[str]]:
"""
openwakeword ๊ธฐ๋ณธ ๋ฆฌ์†Œ์Šค ๊ฒฝ๋กœ๋ฅผ ์ฐพ๋Š”๋‹ค.
1) ๋กœ์ปฌ ๋ฆฌํฌ(openWakeWord/openwakeword/resources/models)
2) ํ˜„์žฌ venv site-packages/openwakeword/resources/models
"""
local_models_dir = Path(__file__).parent.parent / "openWakeWord" / "openwakeword" / "resources" / "models"
mel_local = local_models_dir / "melspectrogram.onnx"
emb_local = local_models_dir / "embedding_model.onnx"
if mel_local.exists() and emb_local.exists():
return str(mel_local), str(emb_local)
# venv site-packages fallback
venv_base = Path(__file__).parent / ".venv" / "lib"
candidates = sorted(glob.glob(str(venv_base / "python*" / "site-packages" / "openwakeword" / "resources" / "models")))
for c in candidates:
cdir = Path(c)
mel = cdir / "melspectrogram.onnx"
emb = cdir / "embedding_model.onnx"
if mel.exists() and emb.exists():
return str(mel), str(emb)
return None, None
def _check_onnx_export_dependencies() -> None:
    """Verify that the modules required for ONNX export are importable.

    Failing fast here avoids wasting a full training run (e.g. 50 epochs)
    only to crash at the export step. Exits the process when anything is
    missing.
    """
    required = ("onnx", "onnxscript")
    missing = [name for name in required if importlib.util.find_spec(name) is None]
    if not missing:
        return
    mods = ", ".join(missing)
    print("\n❌ ONNX export 의존성 누락:", mods)
    print(" 현재 Python:", sys.executable)
    print(" 아래를 먼저 실행하세요:")
    print(" python -m ensurepip --upgrade")
    print(" python -m pip install -U onnx onnxscript")
    sys.exit(1)
def train_model(
    positive_embeddings: np.ndarray,
    negative_embeddings: np.ndarray,
    epochs: int = 50,
    batch_size: int = 32,
    learning_rate: float = 0.001,
    use_gpu: bool = True,
) -> Tuple[WakeWordModel, dict]:
    """Train the wake-word classifier on extracted embeddings.

    Args:
        positive_embeddings: embeddings of wake-word clips
        negative_embeddings: embeddings of non-wake-word clips
        epochs: number of training epochs
        batch_size: base batch size (automatically widened on GPU)
        learning_rate: Adam learning rate
        use_gpu: train on CUDA when available

    Returns:
        (model restored to the best-validation-loss checkpoint, history dict
        with 'train_loss', 'val_loss' and 'val_accuracy' lists)
    """
    # Input shape (timesteps, features); openWakeWord heads use 16 frames.
    n_timesteps = 16  # fixed
    n_features = positive_embeddings.shape[-1]  # 96 for the default embedding model
    input_shape = (n_timesteps, n_features)
    print(f"\n{'='*60}")
    print(f"모델 학습 시작")
    print(f"Input shape: {input_shape}")
    print(f"Positive 샘플: {len(positive_embeddings)}개")
    print(f"Negative 샘플: {len(negative_embeddings)}개")
    print(f"{'='*60}\n")

    # Data preparation — keep only the last 16 timesteps of each embedding.
    def prepare_features(embeddings: np.ndarray, n_timesteps: int = 16) -> np.ndarray:
        """Take the trailing n_timesteps frames of each embedding (pad if short)."""
        prepared = []
        for emb in embeddings:
            if emb.shape[0] >= n_timesteps:
                prepared.append(emb[-n_timesteps:])
            else:
                # Shorter than the window: left-pad with zeros.
                padded = np.zeros((n_timesteps, n_features))
                padded[-emb.shape[0]:] = emb
                prepared.append(padded)
        return np.array(prepared)

    X_pos = prepare_features(positive_embeddings, n_timesteps)
    X_neg = prepare_features(negative_embeddings, n_timesteps)
    # Build the dataset: positives labelled 1, negatives 0.
    X = np.vstack([X_pos, X_neg]).astype(np.float32)
    y = np.hstack([
        np.ones(len(X_pos)),
        np.zeros(len(X_neg))
    ]).astype(np.float32)
    # Shuffle.
    indices = np.random.permutation(len(X))
    X, y = X[indices], y[indices]
    # Train/validation split (80/20).
    split_idx = int(len(X) * 0.8)
    X_train, X_val = X[:split_idx], X[split_idx:]
    y_train, y_val = y[:split_idx], y[split_idx:]
    # Pick the device (GPU when use_gpu=True and CUDA is available).
    device = torch.device('cuda' if (use_gpu and torch.cuda.is_available()) else 'cpu')
    if device.type == 'cuda':
        print(f"🖥️ 학습 디바이스: GPU ({torch.cuda.get_device_name(0)})")
        # BUG FIX: scale up from the caller-supplied batch_size; the original
        # used the module-level BATCH_SIZE constant here, silently ignoring
        # the batch_size argument.
        batch_size = min(128, max(batch_size, len(X_train) // 16))
    else:
        # BUG FIX: keep the caller-supplied batch_size (the original reset it
        # to BATCH_SIZE, making the parameter dead on CPU too).
        print(f"🖥️ 학습 디바이스: CPU")
    # DataLoaders (pinned memory speeds up host->GPU copies).
    train_dataset = TensorDataset(
        torch.from_numpy(X_train),
        torch.from_numpy(y_train)
    )
    val_dataset = TensorDataset(
        torch.from_numpy(X_val),
        torch.from_numpy(y_val)
    )
    pin = (device.type == 'cuda')
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=pin)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, pin_memory=pin)
    model = WakeWordModel(input_shape=input_shape, layer_dim=LAYER_DIM, n_blocks=N_BLOCKS)
    model = model.to(device)
    # Loss and optimizer.
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # Training history.
    history = {
        'train_loss': [],
        'val_loss': [],
        'val_accuracy': []
    }
    best_val_loss = float('inf')
    best_model_state = None
    # Training loop.
    for epoch in range(epochs):
        # ---- training pass ----
        model.train()
        train_losses = []
        for X_batch, y_batch in train_loader:
            X_batch = X_batch.to(device)
            y_batch = y_batch.to(device).unsqueeze(1)
            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())
        # ---- validation pass ----
        model.eval()
        val_losses = []
        correct = 0
        total = 0
        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                X_batch = X_batch.to(device)
                y_batch = y_batch.to(device).unsqueeze(1)
                outputs = model(X_batch)
                loss = criterion(outputs, y_batch)
                val_losses.append(loss.item())
                predicted = (outputs > 0.5).float()
                total += y_batch.size(0)
                correct += (predicted == y_batch).sum().item()
        avg_train_loss = np.mean(train_losses)
        avg_val_loss = np.mean(val_losses)
        val_accuracy = correct / total
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['val_accuracy'].append(val_accuracy)
        # Snapshot the best checkpoint so far.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # BUG FIX: state_dict() returns live references to the parameter
            # tensors, so the original shallow dict .copy() was mutated by
            # every later optimizer step and "best" restore was a no-op.
            # Clone each tensor to take a true snapshot.
            best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        # Progress report.
        if (epoch + 1) % 10 == 0 or epoch == 0:
            print(f"Epoch [{epoch+1}/{epochs}] "
                  f"Train Loss: {avg_train_loss:.4f} | "
                  f"Val Loss: {avg_val_loss:.4f} | "
                  f"Val Acc: {val_accuracy:.4f}")
    # Restore the best-performing weights.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    return model, history
def export_to_onnx(
    model: WakeWordModel,
    model_name: str,
    output_dir: str
) -> str:
    """Export the trained classifier to an ONNX file.

    Args:
        model: trained PyTorch model
        model_name: base file name (also used as the ONNX output name)
        output_dir: destination directory

    Returns:
        Path of the written .onnx file.
    """
    # Export always runs on CPU in eval mode.
    model.eval()
    model = model.to('cpu')
    # Single-example dummy input; the batch dimension stays dynamic below.
    example = torch.randn(1, *model.input_shape)
    onnx_path = os.path.join(output_dir, f"{model_name}.onnx")
    torch.onnx.export(
        model,
        example,
        onnx_path,
        export_params=True,
        opset_version=13,
        do_constant_folding=True,
        input_names=['input'],
        output_names=[model_name],
        dynamic_axes={
            'input': {0: 'batch_size'},
            model_name: {0: 'batch_size'}
        }
    )
    print(f"\n✅ ONNX 모델 저장 완료: {onnx_path}")
    return onnx_path
def main():
    """Run the full training pipeline: load data, extract embeddings,
    train the classifier, and export it to ONNX."""
    import argparse
    parser = argparse.ArgumentParser(
        description="한국어 호출어 모델 학습기"
    )
    parser.add_argument(
        "--positive_dir", "-p",
        type=str,
        default=POSITIVE_DIR,
        help=f"Positive 데이터 디렉토리 (기본값: {POSITIVE_DIR})"
    )
    parser.add_argument(
        "--model_name", "-m",
        type=str,
        default=MODEL_NAME,
        help=f"출력 모델 이름 (기본값: {MODEL_NAME})"
    )
    parser.add_argument(
        "--output_dir", "-o",
        type=str,
        default=OUTPUT_DIR,
        help=f"출력 디렉토리 (기본값: {OUTPUT_DIR})"
    )
    parser.add_argument(
        "--epochs", "-e",
        type=int,
        default=EPOCHS,
        help=f"학습 에폭 수 (기본값: {EPOCHS})"
    )
    parser.add_argument(
        "--negative_ratio",
        type=float,
        default=3.0,
        help="Positive 대비 Negative 샘플 비율 (기본값: 3.0)"
    )
    parser.add_argument(
        "--negative_dir", "-n",
        type=str,
        default=DEFAULT_NEGATIVE_DIR,
        help=f"일반 대화 음성(negative) WAV 디렉토리. 비우면 노이즈만 사용 (기본: {DEFAULT_NEGATIVE_DIR})"
    )
    parser.add_argument(
        "--no-gpu",
        action="store_true",
        help="GPU 비활성화 (CPU만 사용)"
    )
    args = parser.parse_args()
    # Pre-flight check of ONNX export deps (fail before training starts).
    _check_onnx_export_dependencies()
    # Decide GPU usage (--no-gpu forces CPU).
    use_gpu_flag = torch.cuda.is_available() and not getattr(args, 'no_gpu', False)
    # Ensure the output directory exists.
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    print("\n" + "="*60)
    print("🎤 한국어 호출어 모델 학습 파이프라인")
    print("="*60)
    # Step 1: load positive (wake-word) data.
    print("\n[Step 1/5] Positive 데이터 로드")
    positive_clips = load_audio_files(args.positive_dir)
    if len(positive_clips) == 0:
        print("❌ Positive 데이터가 없습니다! 먼저 generate_data.py를 실행하세요.")
        sys.exit(1)
    print(f" ✅ {len(positive_clips)}개 Positive 클립 로드 완료")
    # Step 2: negative data — prefer real conversational speech, otherwise
    # fall back to synthetic noise.
    print("\n[Step 2/5] Negative 데이터 준비")
    num_negative = int(len(positive_clips) * args.negative_ratio)
    negative_dir = (args.negative_dir or "").strip()
    if negative_dir and Path(negative_dir).is_dir() and list(Path(negative_dir).glob("*.wav")):
        negative_clips = load_negative_from_dir(negative_dir, num_negative)
        print(f" ✅ {len(negative_clips)}개 Negative 클립 로드 (일반 대화 음성: {negative_dir})")
    else:
        negative_clips = generate_negative_data(num_negative)
        print(f" ✅ {len(negative_clips)}개 Negative 클립 생성 (노이즈)")
    # Step 3: feature extraction (uses GPU automatically when available).
    print("\n[Step 3/5] openWakeWord Embedding 추출")
    feat_device = 'gpu' if use_gpu_flag else 'cpu'
    if use_gpu_flag:
        print(f" 🖥️ GPU 사용: {torch.cuda.get_device_name(0)}")
    else:
        print(" 🖥️ CPU 사용")
    # Prefer explicit model paths when the bundled resources can be found.
    melspec_model_path, embedding_model_path = _resolve_openwakeword_resource_paths()
    if melspec_model_path and embedding_model_path:
        feature_extractor = AudioFeatures(
            inference_framework='onnx',
            device=feat_device,
            melspec_model_path=melspec_model_path,
            embedding_model_path=embedding_model_path,
        )
    else:
        # Fall back to openwakeword's own default resource resolution.
        feature_extractor = AudioFeatures(inference_framework='onnx', device=feat_device)
    positive_embeddings = extract_embeddings(positive_clips, feature_extractor, use_gpu=use_gpu_flag)
    negative_embeddings = extract_embeddings(negative_clips, feature_extractor, use_gpu=use_gpu_flag)
    print(f" ✅ Positive embeddings: {positive_embeddings.shape}")
    print(f" ✅ Negative embeddings: {negative_embeddings.shape}")
    # Step 4: train the classifier head.
    print("\n[Step 4/5] 모델 학습")
    model, history = train_model(
        positive_embeddings,
        negative_embeddings,
        epochs=args.epochs,
        batch_size=BATCH_SIZE,
        learning_rate=LEARNING_RATE,
        use_gpu=use_gpu_flag,
    )
    # Step 5: export to ONNX.
    print("\n[Step 5/5] ONNX 모델 내보내기")
    onnx_path = export_to_onnx(model, args.model_name, args.output_dir)
    # Final summary.
    print("\n" + "="*60)
    print("🎉 학습 완료!")
    print("="*60)
    print(f"📁 ONNX 모델: {onnx_path}")
    print(f"📊 최종 Validation Accuracy: {history['val_accuracy'][-1]:.4f}")
    print(f"\n💡 실시간 추론을 실행하려면:")
    print(f" python run_live.py --model {onnx_path}")
# Script entry point: run the full training pipeline.
if __name__ == "__main__":
    main()