| # SenseVoice-Small-ko (Fine-tuned SenseVoiceSmall on EDIE dataset) |
|
|
| ์ด ๋ฆฌํฌ์งํฐ๋ฆฌ๋ SenseVoiceSmall๋ฅผ ํ๊ตญ์ด ์์ฑ/๊ฐ์ /์ด๋ฒคํธ ์ธ์์ฉ EDIE ๋ฐ์ดํฐ์
์ผ๋ก ํ์ธํ๋ํ ๋ชจ๋ธ์
๋๋ค. |
|
|
| - 베이스 모델: iic/SenseVoiceSmall |
| - 태스크: STT (ASR) + Emotion (SER) + Event (AED) |
| - 주요 라벨: |
| - 텍스트 라벨 |
| - 감정 라벨: <|HAPPY|>, <|SAD|>, <|ANGRY|>, <|NEUTRAL|>, <|FEARFUL|>, <|DISGUSTED|>, <|SURPRISED|> |
|
|
|
|
| ## 0. 모델 입출력 포맷 |
|
|
| **입력** |
| - input: 단일 wav 경로 또는 경로 리스트 |
|
|
| **출력** |
| 출력 예시 (AutoModel 결과의 `text`를 아래 예제의 `parse_sensevoice_text`로 분리한 결과) |
| - text: 인식된 텍스트 |
| - language: 언어 ID (<|ko|> 등) |
| - emo: 감정 라벨 (<|HAPPY|>, <|SAD|> 등) |
| - event: 이벤트 라벨 (<|Speech|>, <|BGM|> 등) |
|
|
|
|
| ## 1. 설치 |
|
|
| ```bash |
| pip install -U "funasr>=1.2.7" torch |
| ``` |
|
|
| GPU๋ฅผ ์ฌ์ฉํ ๊ฒฝ์ฐ ์ฌ์ ์ CUDA ํธํ PyTorch๋ฅผ ์ค์นํด ์ฃผ์ธ์ |
|
|
| ## 2. 간단하게 모델 사용하기 |
|
|
| FunASR์ AutoModel์ ์ด์ฉํ์ฌ ํ๊น
ํ์ด์ค ๋ชจ๋ธ ํ๋ธ์์ ๋ชจ๋ธ ๋ ํ์งํ ๋ฆฌ์ ๋ชจ๋ธ์ ๋ฐ๋ก ๋ก๋ํด์ ์ฌ์ฉํ ์ ์์ต๋๋ค. |
|
|
| ```python |
| #!/usr/bin/env python3 |
| from pathlib import Path |
| import os |
| import argparse |
| |
| from huggingface_hub import snapshot_download |
| from funasr import AutoModel |
| from funasr.utils.postprocess_utils import rich_transcription_postprocess |
| |
HF_REPO_ID = "AeiROBOT/SenseVoice-Small-ko"  # HF repo ID the model was uploaded to
LOCAL_DIR = "/home/khw/Workspace/SenseVoice/hf_models/SenseVoice-Small-ko"  # download target; adjust per machine

# ----- SenseVoice token parser -----
# Special tokens SenseVoice places in front of a transcription, grouped by slot.
# "<|EMO_UNKNOWN|>" and "<|Event_UNK|>" are the base SenseVoice "unknown"
# emotion/event tokens. If they are missing from these sets, a single unknown
# tag stops parse_sensevoice_text() from consuming every tag that follows it.
LANG_TOKENS = {"<|zh|>", "<|en|>", "<|yue|>", "<|ja|>", "<|ko|>", "<|nospeech|>"}
EMO_TOKENS = {
    "<|HAPPY|>", "<|SAD|>", "<|ANGRY|>", "<|NEUTRAL|>",
    "<|FEARFUL|>", "<|DISGUSTED|>", "<|SURPRISED|>", "<|EMO_UNKNOWN|>",
}
EVENT_TOKENS = {
    "<|BGM|>", "<|Speech|>", "<|Applause|>", "<|Laughter|>",
    "<|Cry|>", "<|Sneeze|>", "<|Breath|>", "<|Cough|>", "<|Event_UNK|>",
}
WITH_ITN_TOKENS = {"<|withitn|>", "<|woitn|>"}
| |
| |
| def _consume(prefixes, text: str): |
| for p in prefixes: |
| if text.startswith(p): |
| return p, text[len(p):] |
| return None, text |
| |
| |
def parse_sensevoice_text(raw: str):
    """Split a raw SenseVoice transcription into its leading tags and the text.

    The stripped string is scanned from the front for four slots, in order:
    language, emotion, event, ITN flag. Each slot consumes at most one token.
    A slot whose token is not next in the string is ``None``. Tags that occur
    later in the string are left inside ``text``.

    Example:
        "<|ko|><|NEUTRAL|><|Speech|><|withitn|>안녕하세요." ->
        {
            "language": "<|ko|>",
            "emo": "<|NEUTRAL|>",
            "event": "<|Speech|>",
            "with_itn": "<|withitn|>",
            "text": "안녕하세요."
        }
    """
    if not raw:
        return {"language": None, "emo": None, "event": None, "with_itn": None, "text": ""}

    remainder = raw.strip()
    slots = {}
    for slot_name, vocabulary in (
        ("language", LANG_TOKENS),
        ("emo", EMO_TOKENS),
        ("event", EVENT_TOKENS),
        ("with_itn", WITH_ITN_TOKENS),
    ):
        slots[slot_name], remainder = _consume(vocabulary, remainder)
    slots["text"] = remainder.strip()
    return slots
| |
| |
def parse_args(argv=None):
    """Parse command-line options for the quick inference example.

    Args:
        argv: arguments to parse; ``None`` (the default) parses ``sys.argv[1:]``,
            exactly like the original zero-argument call.

    Returns:
        ``argparse.Namespace`` with ``wav_file``, the path of the audio to transcribe.
    """
    p = argparse.ArgumentParser()
    p.add_argument(
        "--wav_file",
        default="dataset/wav_dataset/DISGUSTED/test_2025_12_12_040201.wav",
        # Previously described as "pretrained model name or local directory",
        # but this option is the input wav path passed to model.generate().
        help="추론할 입력 wav 파일 경로",
    )
    return p.parse_args(argv)
| |
def get_model(device=None, local_dir=None):
    """Download the fine-tuned model from the Hugging Face Hub and load it with FunASR.

    Args:
        device: torch device string such as ``"cuda:0"`` or ``"cpu"``. When
            omitted, ``"cuda:0"`` is used if CUDA is available, otherwise
            ``"cpu"``. The previous hard-coded ``"cuda:0"`` failed on CPU-only hosts.
        local_dir: directory to download the snapshot into; defaults to ``LOCAL_DIR``.

    Returns:
        A FunASR ``AutoModel``. It uses the repository's own ``model.py`` as
        remote code and ``fsmn-vad`` for segmentation.
    """
    if device is None:
        import torch  # torch is a FunASR runtime dependency (installed in step 1)

        device = "cuda:0" if torch.cuda.is_available() else "cpu"
    if local_dir is None:
        local_dir = LOCAL_DIR

    # 1) Download (or refresh) the repository snapshot.
    local_path = snapshot_download(
        repo_id=HF_REPO_ID,
        repo_type="model",
        local_dir=local_dir,
        local_dir_use_symlinks=False,
        token=os.environ.get("HUGGINGFACE_HUB_TOKEN"),  # needed because the repo is private
    )
    print("다운로드 경로:", local_path)

    # 2) Pass the local path to AutoModel.
    model_dir = local_path
    model = AutoModel(
        model=model_dir,
        trust_remote_code=True,
        remote_code=str(Path(model_dir) / "model.py"),  # use the model.py shipped in the HF repo
        vad_model="fsmn-vad",
        vad_kwargs={"max_single_segment_time": 30000},
        device=device,
    )
    return model
| |
def main():
    """Transcribe one wav file and print the raw output plus its parsed fields."""
    wav_path = parse_args().wav_file
    model = get_model()

    results = model.generate(
        input=wav_path,
        cache={},
        language="auto",  # or "ko"
        use_itn=True,
        batch_size_s=60,
        merge_vad=True,
        merge_length_s=15,
    )

    raw_text = results[0]["text"]
    fields = parse_sensevoice_text(raw_text)

    # Rich-transcription post-processing of the tag-free text (skipped when empty).
    body = fields["text"]
    pretty_text = rich_transcription_postprocess(body) if body else ""

    print("=== Raw ===")
    print(raw_text)
    print("=== Parsed ===")
    for label, value in (
        ("lang :", fields["language"]),
        ("emo :", fields["emo"]),
        ("event :", fields["event"]),
        ("withitn:", fields["with_itn"]),
        ("text :", pretty_text),
    ):
        print(label, value)


if __name__ == "__main__":
    main()
| ``` |
|
|
|
|
| ## 3. 학습 데이터셋으로 평가하기 |
|
|
| ```python |
| #!/usr/bin/env python3 |
| import os |
| import json |
| import argparse |
| import unicodedata |
| from pathlib import Path |
| from typing import List, Dict, Tuple, Optional |
| |
| import torch |
| from funasr import AutoModel |
| from funasr.utils.postprocess_utils import rich_transcription_postprocess |
| |
| |
# =======================
# SenseVoice token parser
# =======================
# Special tokens SenseVoice places in front of a transcription, grouped by slot.
# "<|EMO_UNKNOWN|>" and "<|Event_UNK|>" are the base SenseVoice "unknown"
# emotion/event tokens. If they are missing from these sets, a single unknown
# tag stops parse_sensevoice_text() from consuming every tag that follows it.
LANG_TOKENS = {"<|zh|>", "<|en|>", "<|yue|>", "<|ja|>", "<|ko|>", "<|nospeech|>"}
EMO_TOKENS = {
    "<|HAPPY|>", "<|SAD|>", "<|ANGRY|>", "<|NEUTRAL|>",
    "<|FEARFUL|>", "<|DISGUSTED|>", "<|SURPRISED|>", "<|EMO_UNKNOWN|>",
}
EVENT_TOKENS = {
    "<|BGM|>", "<|Speech|>", "<|Applause|>", "<|Laughter|>",
    "<|Cry|>", "<|Sneeze|>", "<|Breath|>", "<|Cough|>", "<|Event_UNK|>",
}
WITH_ITN_TOKENS = {"<|withitn|>", "<|woitn|>"}
| |
| |
| def _consume(prefixes, text: str): |
| for p in prefixes: |
| if text.startswith(p): |
| return p, text[len(p):] |
| return None, text |
| |
| |
def parse_sensevoice_text(raw: str) -> Dict[str, Optional[str]]:
    """Split a raw SenseVoice transcription into its leading tags and the text.

    The stripped string is scanned from the front for four slots, in order:
    language, emotion, event, ITN flag. Each slot consumes at most one token;
    a missing token leaves that slot ``None``. The rest (stripped) is ``text``.
    """
    if not raw:
        return {"language": None, "emo": None, "event": None, "with_itn": None, "text": ""}

    remainder = raw.strip()
    slots: Dict[str, Optional[str]] = {}
    for slot_name, vocabulary in (
        ("language", LANG_TOKENS),
        ("emo", EMO_TOKENS),
        ("event", EVENT_TOKENS),
        ("with_itn", WITH_ITN_TOKENS),
    ):
        slots[slot_name], remainder = _consume(vocabulary, remainder)
    slots["text"] = remainder.strip()
    return slots
| |
| |
# =======================
# Text normalization & metrics
# =======================

def normalize_text(s: str, lower: bool, strip_punct: bool, strip_spaces: bool) -> str:
    """Normalize a transcript before computing exact match / CER / WER.

    The text is first converted to Unicode NFC. Otherwise a precomposed Hangul
    syllable (e.g. U+D55C) and the equivalent decomposed jamo sequence
    (U+1112 U+1161 U+11AB) render identically but compare as different
    characters, which inflates CER. NFC is a no-op for text that is already composed.

    Args:
        s: text to normalize; ``None`` is treated as the empty string.
        lower: lowercase the text.
        strip_punct: drop every character whose Unicode category starts with "P".
        strip_spaces: remove all whitespace.

    Returns:
        The normalized string.
    """
    if s is None:
        return ""
    t = unicodedata.normalize("NFC", s)
    if lower:
        t = t.lower()
    if strip_punct:
        t = "".join(ch for ch in t if not unicodedata.category(ch).startswith("P"))
    if strip_spaces:
        t = "".join(t.split())
    return t
| |
| |
| def _levenshtein(a: List[str], b: List[str]) -> int: |
| n, m = len(a), len(b) |
| if n == 0: |
| return m |
| if m == 0: |
| return n |
| prev = list(range(m + 1)) |
| for i in range(1, n + 1): |
| curr = [i] + [0] * m |
| ai = a[i - 1] |
| for j in range(1, m + 1): |
| cost = 0 if ai == b[j - 1] else 1 |
| curr[j] = min( |
| prev[j] + 1, |
| curr[j - 1] + 1, |
| prev[j - 1] + cost, |
| ) |
| prev = curr |
| return prev[m] |
| |
| |
def cer(ref: str, hyp: str) -> float:
    """Character error rate: character-level edit distance divided by len(ref).

    The denominator is clamped to at least 1, so an empty reference cannot divide by zero.
    """
    return _levenshtein(list(ref), list(hyp)) / max(1, len(ref))
| |
| |
def wer(ref: str, hyp: str) -> float:
    """Word error rate over whitespace-separated tokens.

    Normalized by the reference word count, clamped to at least 1.
    """
    ref_words = ref.split()
    return _levenshtein(ref_words, hyp.split()) / max(1, len(ref_words))
| |
| |
def norm_emo(label: Optional[str]) -> str:
    """Reduce an emotion label to its bare upper-case name.

    ``"<|happy|>"`` -> ``"HAPPY"``; ``" sad "`` -> ``"SAD"``; ``None`` or ``""`` -> ``""``.
    """
    if not label:
        return ""
    core = label.strip()
    is_wrapped = core.startswith("<|") and core.endswith("|>")
    return (core[2:-2] if is_wrapped else core).upper()
| |
| |
# =======================
# IO & argparse
# =======================

def parse_args(argv=None):
    """Parse command-line options for the evaluation script.

    Args:
        argv: arguments to parse; ``None`` (the default) parses ``sys.argv[1:]``,
            exactly like the original zero-argument call.

    Returns:
        ``argparse.Namespace`` holding paths, device, batch size, forced
        language, text-normalization switches and the output path.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--model-dir", default="/home/khw/Workspace/SenseVoice/outputs", help="finetune 산출물 디렉터리")
    p.add_argument("--jsonl", default="/home/khw/Workspace/SenseVoice/data/train.jsonl", help="입력 JSONL 경로")
    p.add_argument("--base-audio-dir", default="/home/khw/Workspace/SenseVoice", help="source 상대경로의 기준 디렉터리")
    p.add_argument("--remote-code", default="/home/khw/Workspace/SenseVoice/model.py", help="SenseVoice 모델 구현 경로")
    p.add_argument("--device", default=None, help="cuda:0 / cpu (미지정 시 자동 결정)")
    p.add_argument("--batch-size", type=int, default=64, help="배치 크기(짧은 음원들을 가정)")
    # Nothing reads args.use_best_ckpt: checkpoint selection always prefers
    # model.pt.best. The flag is kept only so existing command lines still parse.
    p.add_argument(
        "--use-best-ckpt",
        action="store_true",
        help="(사용되지 않음) 체크포인트는 항상 model.pt.best > model.pt.ep* > model.pt 순으로 선택됩니다",
    )
    p.add_argument("--lang", default="ko", choices=["auto", "zh", "en", "yue", "ja", "ko", "nospeech"], help="언어 강제 설정. 기본 ko")
    p.add_argument("--lower", action="store_true", help="정밀도 계산 시 소문자화")
    p.add_argument("--strip-punct", action="store_true", help="정밀도 계산 시 문장부호 제거")
    p.add_argument("--strip-spaces", action="store_true", help="정밀도 계산 시 모든 공백 제거")
    p.add_argument("--out", default="/home/khw/Workspace/SenseVoice/results/preds_train.jsonl", help="추론 결과 JSONL")
    return p.parse_args(argv)
| |
| |
| def _find_latest_epoch_ckpt(model_dir: Path) -> Optional[Path]: |
| """model.pt.ep* ์ค์์ ๊ฐ์ฅ ํฐ epoch ๋ฒํธ๋ฅผ ๊ฐ์ง ์ฒดํฌํฌ์ธํธ๋ฅผ ์ฐพ๋๋ค.""" |
| candidates = [] |
| for p in model_dir.glob("model.pt.ep*"): |
| name = p.name |
| try: |
| # ์ด๋ฆ์์ ์ซ์ ๋ถ๋ถ๋ง ํ์ฑ: model.pt.ep50 -> 50 |
| ep_str = name.split("model.pt.ep", 1)[1] |
| ep = int(ep_str) |
| candidates.append((ep, p)) |
| except (IndexError, ValueError): |
| # ํจํด์ด ์ ๋ง์ผ๋ฉด ๋ฌด์ |
| continue |
| |
| if not candidates: |
| return None |
| |
| candidates.sort(key=lambda x: x[0]) # epoch ์ค๋ฆ์ฐจ์ ์ ๋ ฌ |
| return candidates[-1][1] # ๊ฐ์ฅ ํฐ epoch |
| |
| |
def prepare_checkpoint(model_dir: Path) -> Path:
    """Choose the checkpoint to use inside ``model_dir`` and expose it as ``model.pt``.

    Priority:
        1) model.pt.best
        2) the model.pt.ep* file with the largest epoch number
        3) model.pt (an already existing file)

    If none of these exists, the program exits via SystemExit.

    If the chosen file is not model.pt itself, model.pt is replaced by a
    symbolic link to it. The link target is the bare file name, so it resolves
    inside ``model_dir``. If the link cannot be created, the file is copied instead.

    NOTE(review): the replacement first unlinks any existing model.pt, even a
    regular file rather than a symlink. Its previous contents are therefore lost
    unless another copy exists.

    Args:
        model_dir: directory holding the training checkpoints.

    Returns:
        Path of the chosen checkpoint (e.g. ``model_dir / "model.pt.best"``), not of model.pt.

    Raises:
        SystemExit: when no checkpoint is found, or when both the symlink and
            the copy fallback fail.
    """
    best = model_dir / "model.pt.best"
    target = model_dir / "model.pt"  # the file AutoModel is expected to load in the end

    chosen: Optional[Path] = None

    # 1) model.pt.best has top priority
    if best.exists():
        chosen = best
        reason = "model.pt.best"
    else:
        # 2) model.pt.ep* with the latest epoch
        latest_ep = _find_latest_epoch_ckpt(model_dir)
        if latest_ep is not None:
            chosen = latest_ep
            reason = latest_ep.name
        # 3) an existing model.pt
        elif target.exists():
            chosen = target
            reason = "existing model.pt"
        else:
            reason = "(none)"

    if chosen is None:
        raise SystemExit(
            f"[fatal] No checkpoint found in {model_dir}. "
            f"Expected one of: model.pt.best, model.pt.ep*, model.pt. Program will exit."
        )

    # Expose the chosen checkpoint as model.pt (symlink, or a copy as fallback)
    if chosen != target:
        # is_symlink() also catches a dangling link, for which exists() is False
        if target.exists() or target.is_symlink():
            try:
                target.unlink()
            except Exception as e:
                print(f"[warn] failed to remove existing {target}: {e}")

        try:
            # Relative symlink: the link stores only the file name
            target.symlink_to(chosen.name)
            print(f"[info] using checkpoint: {chosen.name} (linked as model.pt)")
        except Exception as e:
            # Some filesystems / permission setups do not allow symlinks, so fall back to copying
            print(f"[warn] symlink failed ({e}), will try to copy instead.")
            import shutil
            try:
                shutil.copy2(str(chosen), str(target))
                print(f"[info] using checkpoint: {chosen.name} (copied to model.pt)")
            except Exception as e2:
                raise SystemExit(
                    f"[fatal] failed to prepare checkpoint at {target}: {e2}. Program will exit."
                )
    else:
        print(f"[info] using checkpoint: {reason}")

    return chosen
| |
| |
def load_items(jsonl_path: Path) -> List[Dict]:
    """Read one JSON object per line from a JSONL file.

    Blank lines are ignored. Some lines are skipped with a warning instead of
    aborting the run:
    - lines that are not valid JSON;
    - lines that decode to something other than a JSON object. Callers access
      every item with ``dict.get``, so a list or number would crash later.

    The file is opened as ``utf-8-sig`` so that a leading UTF-8 BOM does not
    make the first line unparsable. Files without a BOM decode exactly as
    plain ``utf-8``.

    Args:
        jsonl_path: path of the JSONL manifest.

    Returns:
        The decoded objects, in file order.
    """
    items = []
    with jsonl_path.open("r", encoding="utf-8-sig") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"[warn] skip bad line: {e}")
                continue
            if not isinstance(obj, dict):
                print(f"[warn] skip non-object line {lineno}: {type(obj).__name__}")
                continue
            items.append(obj)
    return items
| |
| |
def to_abs_paths(items: List[Dict], base_audio_dir: Path) -> Tuple[List[Dict], int]:
    """Resolve each item's ``source`` against ``base_audio_dir`` into ``abs_source``.

    Items are updated in place. ``abs_source`` becomes the resolved path
    string, or None when ``source`` is absent or empty.

    Returns:
        The same list, plus the number of items with no source or whose
        resolved file does not exist.
    """
    missing = 0
    for item in items:
        source = item.get("source")
        if not source:
            item["abs_source"] = None
            missing += 1
            continue
        resolved = (base_audio_dir / source).resolve()
        missing += 0 if resolved.exists() else 1
        item["abs_source"] = str(resolved)
    return items, missing
| |
| |
def batched(iterable, n: int):
    """Yield consecutive lists of ``n`` elements from ``iterable``; the last list may be shorter."""
    chunk = []
    for element in iterable:
        chunk.append(element)
        if len(chunk) != n:
            continue
        yield chunk
        chunk = []
    if chunk:
        yield chunk
| |
| |
# =======================
# main
# =======================

def main():
    """Evaluate the fine-tuned model on a JSONL dataset.

    Steps:
    - pick a checkpoint and load the model;
    - resolve the audio paths and run batched inference;
    - write one JSON object per sample to ``--out`` and print each sample.

    The final summary reports:
    - exact-match accuracy and average CER/WER over samples with a reference text;
    - emotion accuracy over samples with an ``emo_target``;
    - how many samples received no prediction.
    """
    args = parse_args()

    model_dir = Path(args.model_dir)
    jsonl_path = Path(args.jsonl)
    base_audio_dir = Path(args.base_audio_dir)

    # Checkpoint priority: model.pt.best > model.pt.ep* (largest epoch) > model.pt
    ckpt = prepare_checkpoint(model_dir)
    print(f"[info] final checkpoint file: {ckpt}")

    device = args.device or ("cuda:0" if torch.cuda.is_available() else "cpu")

    # model.py (remote_code) must exist; exit right away otherwise.
    remote_code_path = Path(args.remote_code)
    if not remote_code_path.exists():
        raise SystemExit(
            f"[fatal] remote_code not found at {remote_code_path}. "
            f"Expected model.py for SenseVoice. Program will exit."
        )

    model = AutoModel(
        model=str(model_dir),  # local directory only
        trust_remote_code=True,
        remote_code=str(remote_code_path),
        device=device,
        vad_model=None,
    )

    items = load_items(jsonl_path)
    items, _ = to_abs_paths(items, base_audio_dir)

    valid_items = [it for it in items if it.get("abs_source") and Path(it["abs_source"]).exists()]
    missing = len(items) - len(valid_items)
    if missing:
        print(f"[warn] {missing} items skipped due to missing files")

    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    total = len(valid_items)
    print(f"[info] total inputs used: {total}, device: {device}, language: {args.lang}")
    if total == 0:
        print("[exit] No valid audio found. Check --base-audio-dir or 'source' paths.")
        # Still create/truncate the results file so it never holds stale output.
        with out_path.open("w", encoding="utf-8"):
            pass
        return

    # Metric accumulators
    exact_matches = 0
    cer_sum = 0.0
    wer_sum = 0.0
    text_pairs = 0

    emo_correct = 0
    emo_total = 0

    written = 0
    # Samples that got no prediction: failed batches, or results missing from
    # a batch. zip() below would otherwise drop such samples without a trace.
    no_prediction = 0
    with out_path.open("w", encoding="utf-8") as wf:
        for batch in batched(valid_items, args.batch_size):
            wav_list = [b["abs_source"] for b in batch]
            first_key = batch[0].get("key")

            try:
                res = model.generate(
                    input=wav_list,
                    cache={},
                    language=args.lang,
                    use_itn=True,
                    batch_size=len(wav_list),
                )
            except Exception as e:
                # Keep evaluating the remaining batches, but account for the lost samples.
                print(f"[error] inference failed on batch starting key={first_key}: {e}")
                no_prediction += len(batch)
                continue

            if len(res) != len(batch):
                print(
                    f"[warn] {len(res)} results for {len(batch)} inputs "
                    f"(batch starting key={first_key}); results are paired by position"
                )
                no_prediction += max(0, len(batch) - len(res))

            for it, r in zip(batch, res):
                raw_text = r.get("text", "") or ""
                parsed = parse_sensevoice_text(raw_text)
                pretty_text = rich_transcription_postprocess(parsed["text"]) if parsed["text"] else ""

                ref_text = it.get("target") or ""

                # Text metrics (only for samples that have a reference transcript)
                if ref_text:
                    nt_ref = normalize_text(ref_text, args.lower, args.strip_punct, args.strip_spaces)
                    nt_hyp = normalize_text(pretty_text, args.lower, args.strip_punct, args.strip_spaces)

                    if nt_ref == nt_hyp:
                        exact_matches += 1
                    cer_sum += cer(nt_ref, nt_hyp)
                    wer_sum += wer(nt_ref, nt_hyp)
                    text_pairs += 1

                # Emotion metrics (only for samples that have an emotion label)
                tgt_emo_n = norm_emo(it.get("emo_target"))
                pred_emo_n = norm_emo(parsed["emo"])
                if tgt_emo_n:
                    emo_total += 1
                    if pred_emo_n == tgt_emo_n:
                        emo_correct += 1

                out_obj = {
                    "key": it.get("key"),
                    "audio": it.get("abs_source"),
                    "pred_raw": raw_text,
                    "pred_text": pretty_text,
                    "ref_text": ref_text,
                    "pred_language": parsed["language"],
                    "pred_emo": pred_emo_n or parsed["emo"] or "",
                    "ref_emo": tgt_emo_n or it.get("emo_target") or "",
                    "pred_event": parsed["event"] or "",
                    "with_itn": parsed["with_itn"] or "",
                }
                wf.write(json.dumps(out_obj, ensure_ascii=False) + "\n")

                # ===== Human-readable per-sample output =====
                idx = written + 1
                print("\n[{}] key={}".format(idx, it.get("key")))
                print("REF_TEXT :", ref_text)
                print("REF_EMO :", tgt_emo_n or it.get("emo_target"))
                print("PRED_TEXT:", pretty_text)
                print("PRED_EMO :", pred_emo_n or parsed["emo"])  # falls back to the raw token
                print("PRED_EVT :", parsed["event"])  # event tag shown for inspection
                print("-" * 80)

                written += 1

    # Summary
    print("\n===== Summary =====")
    print(f"Samples inferred: {written}")
    if no_prediction:
        print(f"Samples without prediction: {no_prediction}")
    if text_pairs > 0:
        exact_acc = exact_matches / text_pairs * 100.0
        avg_cer = cer_sum / text_pairs
        avg_wer = wer_sum / text_pairs
        print(f"Text pairs (with ref): {text_pairs}")
        print(f"- Exact match accuracy: {exact_acc:.2f}%")
        print(f"- Avg CER: {avg_cer:.4f}")
        print(f"- Avg WER: {avg_wer:.4f}")
    else:
        print("No text references found; text metrics skipped.")

    if emo_total > 0:
        emo_acc = emo_correct / emo_total * 100.0
        print(f"Emotion pairs: {emo_total}")
        print(f"- Emotion accuracy: {emo_acc:.2f}%")
    else:
        print("No emotion references found; emotion metrics skipped.")

    print(f"Results saved to: {out_path}")


if __name__ == "__main__":
    main()
| |
| |
| ``` |
|
|
|
|
| ## 4. 학습 후 허깅페이스에 모델 업로드 |
|
|
| upload_model_to_huggingface.py |
| |
| ```python |
| |
| #!/usr/bin/env python3 |
| import os |
| from pathlib import Path |
| |
| from huggingface_hub import HfApi, create_repo, upload_folder |
|
|
# ===== User settings =====
# Hugging Face model repo ID to create / upload to (example value)
REPO_ID = "AeiROBOT/SenseVoice-Small-ko"  # <-- change to the name you want

_WORKSPACE = Path("/home/khw/Workspace/SenseVoice")

# Local folder to upload (training output)
MODEL_DIR = _WORKSPACE / "outputs"

# Extra local files to upload with the model (e.g. the FunASR/SenseVoice model.py).
# They can be left out if they were already copied into MODEL_DIR.
EXTRA_FILES = [
    _WORKSPACE / "model.py",  # comment out if the file does not exist
]
| |
| |
def main():
    """Create the Hugging Face model repo (if needed) and upload MODEL_DIR to it.

    Steps:
        1) read the access token from HUGGINGFACE_HUB_TOKEN;
        2) create the private model repo REPO_ID (passes if it already exists);
        3) copy EXTRA_FILES into MODEL_DIR, refreshing copies whose contents
           differ, and copy ./README_huggingface.md to MODEL_DIR/README.md;
        4) upload MODEL_DIR to the repo root, excluding per-epoch checkpoints.

    Raises:
        RuntimeError: if HUGGINGFACE_HUB_TOKEN is unset or empty.
    """
    # 1) Token from the environment (recommended); run
    #    `export HUGGINGFACE_HUB_TOKEN=hf_xxx` beforehand.
    token = os.environ.get("HUGGINGFACE_HUB_TOKEN")
    if not token:  # an empty value would only fail later, at repo creation / upload time
        raise RuntimeError(
            "HUGGINGFACE_HUB_TOKEN 환경변수가 설정되어 있지 않습니다. "
            "https://huggingface.co/settings/tokens 에서 토큰을 만들고,\n"
            "export HUGGINGFACE_HUB_TOKEN=hf_xxx 로 설정한 뒤 다시 실행하세요."
        )

    # 2) Create the repository (exist_ok=True: simply passes if it already exists)
    create_repo(
        repo_id=REPO_ID,
        token=token,
        private=True,  # True creates a private repo
        exist_ok=True,
        repo_type="model",
    )

    # 3) Copy extra files (model.py etc.) into MODEL_DIR (optional).
    #    Recommended: README.md, model.pt, config.yaml, configuration.json and
    #    model.py side by side at the HF repo root.
    for extra in EXTRA_FILES:
        if not extra.is_file():
            print(f"[warn] extra file not found: {extra}")
            continue
        target = MODEL_DIR / extra.name
        data = extra.read_bytes()
        if not target.exists():
            print(f"[info] copy {extra} -> {target}")
            target.write_bytes(data)
        elif target.read_bytes() != data:
            # Without this refresh, a stale earlier copy would be uploaded instead.
            print(f"[info] update {target} from {extra}")
            target.write_bytes(data)

    # 3-1) Model card: copy README_huggingface.md from the working directory to
    #      MODEL_DIR/README.md; the Hub shows the repo-root README.md as the model card.
    readme_src = Path.cwd() / "README_huggingface.md"
    readme_dst = MODEL_DIR / "README.md"
    if readme_src.is_file():
        print(f"[info] copy {readme_src} -> {readme_dst}")
        readme_dst.write_text(readme_src.read_text(encoding="utf-8"), encoding="utf-8")
    else:
        print(f"[warn] README_huggingface.md not found in CWD: {Path.cwd()}")

    # 4) Upload the whole folder
    print(f"[info] uploading folder: {MODEL_DIR} -> {REPO_ID}")
    upload_folder(
        repo_id=REPO_ID,
        folder_path=str(MODEL_DIR),
        path_in_repo=".",  # upload straight to the repo root
        token=token,
        repo_type="model",
        ignore_patterns=[
            "model.pt.ep*",  # exclude per-epoch checkpoints
            "*.pt.ep*",  # also exclude similarly named checkpoint files
        ],
    )

    print("[done] uploaded to:", f"https://huggingface.co/{REPO_ID}")


if __name__ == "__main__":
    main()
| |
|
|
| ``` |