| |
| import os |
| import json |
| import argparse |
| import unicodedata |
| import time |
| from pathlib import Path |
| from typing import List, Dict, Tuple, Optional |
|
|
| import torch |
| from funasr import AutoModel |
| from funasr.utils.postprocess_utils import rich_transcription_postprocess |
|
|
|
|
| |
| |
| |
# Special tokens SenseVoice prepends to each transcript, one slot per group.
# parse_sensevoice_text() below consumes them in this order:
# language, emotion, audio event, then the ITN flag, followed by the text.
LANG_TOKENS = {"<|zh|>", "<|en|>", "<|yue|>", "<|ja|>", "<|ko|>", "<|nospeech|>"}
EMO_TOKENS = {"<|HAPPY|>", "<|SAD|>", "<|ANGRY|>", "<|NEUTRAL|>", "<|FEARFUL|>", "<|DISGUSTED|>", "<|SURPRISED|>"}
EVENT_TOKENS = {"<|BGM|>", "<|Speech|>", "<|Applause|>", "<|Laughter|>", "<|Cry|>", "<|Sneeze|>", "<|Breath|>", "<|Cough|>"}
# With / without inverse text normalization.
WITH_ITN_TOKENS = {"<|withitn|>", "<|woitn|>"}
|
|
|
|
| def _consume(prefixes, text: str): |
| for p in prefixes: |
| if text.startswith(p): |
| return p, text[len(p):] |
| return None, text |
|
|
|
|
def parse_sensevoice_text(raw: str) -> Dict[str, Optional[str]]:
    """Decompose a raw SenseVoice transcript into its leading tags and body.

    The model emits up to four tags before the text, in a fixed order:
    language, emotion, audio event, then the ITN flag. Each tag is optional;
    unmatched slots come back as ``None``.

    Returns:
        Dict with keys ``language``, ``emo``, ``event``, ``with_itn``, ``text``.
    """
    result: Dict[str, Optional[str]] = {
        "language": None,
        "emo": None,
        "event": None,
        "with_itn": None,
        "text": "",
    }
    if not raw:
        return result

    remainder = raw.strip()
    for field, tokens in (
        ("language", LANG_TOKENS),
        ("emo", EMO_TOKENS),
        ("event", EVENT_TOKENS),
        ("with_itn", WITH_ITN_TOKENS),
    ):
        result[field], remainder = _consume(tokens, remainder)

    result["text"] = remainder.strip()
    return result
|
|
|
|
| |
| |
| |
|
|
def normalize_text(s: Optional[str], lower: bool, strip_punct: bool, strip_spaces: bool) -> str:
    """Normalize *s* for lenient accuracy comparison.

    Fix: the parameter was annotated ``str`` although the body explicitly
    handles ``None``; the annotation is now ``Optional[str]``.

    Args:
        s: Input text; ``None`` is treated as the empty string.
        lower: Lowercase the text first.
        strip_punct: Drop every Unicode punctuation character (category ``P*``).
        strip_spaces: Drop all whitespace (so spacing differences, e.g. in
            Korean text, do not count as mismatches).

    Returns:
        The normalized string.
    """
    if s is None:
        return ""
    t = s
    if lower:
        t = t.lower()
    if strip_punct:
        # Category "P*" covers all punctuation classes (Pc, Pd, Ps, Pe, ...).
        t = "".join(ch for ch in t if not unicodedata.category(ch).startswith("P"))
    if strip_spaces:
        # split() with no argument splits on any whitespace run, so joining
        # removes spaces, tabs, and newlines alike.
        t = "".join(t.split())
    return t
|
|
|
|
| def _levenshtein(a: List[str], b: List[str]) -> int: |
| n, m = len(a), len(b) |
| if n == 0: |
| return m |
| if m == 0: |
| return n |
| prev = list(range(m + 1)) |
| for i in range(1, n + 1): |
| curr = [i] + [0] * m |
| ai = a[i - 1] |
| for j in range(1, m + 1): |
| cost = 0 if ai == b[j - 1] else 1 |
| curr[j] = min( |
| prev[j] + 1, |
| curr[j - 1] + 1, |
| prev[j - 1] + cost, |
| ) |
| prev = curr |
| return prev[m] |
|
|
|
|
def cer(ref: str, hyp: str) -> float:
    """Character error rate: edit distance over reference length.

    The denominator is clamped to 1 so an empty reference never divides by zero.
    """
    distance = _levenshtein(list(ref), list(hyp))
    return distance / max(1, len(ref))
|
|
|
|
def wer(ref: str, hyp: str) -> float:
    """Word error rate over whitespace-separated tokens.

    The denominator is clamped to 1 so an empty reference never divides by zero.
    """
    ref_tokens = ref.split()
    hyp_tokens = hyp.split()
    return _levenshtein(ref_tokens, hyp_tokens) / max(1, len(ref_tokens))
|
|
|
|
def norm_emo(label: Optional[str]) -> str:
    """Canonicalize an emotion label for comparison.

    Strips surrounding whitespace and any ``<|...|>`` wrapping, then
    uppercases. Missing labels (``None`` or empty) map to ``""``.
    """
    if not label:
        return ""
    token = label.strip()
    if token.startswith("<|") and token.endswith("|>"):
        token = token[2:-2]
    return token.upper()
|
|
|
|
| |
| |
| |
|
|
def parse_args() -> argparse.Namespace:
    """Define and parse the CLI options for the SenseVoice evaluation run."""
    p = argparse.ArgumentParser()
    p.add_argument("--model-dir", default="SenseVoice-Small-ko", help="๋ชจ๋ธ ๋๋ ํฐ๋ฆฌ")
    p.add_argument("--model-file", default="model.pt", help="๋ชจ๋ธ ํ์ผ๋ช (์: model.pt, model.pt.best)")
    p.add_argument("--wav-scp", default="dataset/dataset/train_wav.scp", help="wav scp ํ์ผ ๊ฒฝ๋ก")
    p.add_argument("--text-file", default="dataset/dataset/train_text.txt", help="text ํ์ผ ๊ฒฝ๋ก")
    p.add_argument("--emo-file", default="dataset/dataset/train_emo.txt", help="emotion ํ์ผ ๊ฒฝ๋ก")
    p.add_argument("--base-audio-dir", default=".", help="wav ํ์ผ ๊ธฐ์ค ๋๋ ํฐ๋ฆฌ")
    p.add_argument("--remote-code", default="SenseVoice-Small-ko/model.py", help="SenseVoice ๋ชจ๋ธ ๊ตฌํ ๊ฒฝ๋ก")
    p.add_argument("--device", default=None, help="cuda:0 / cpu (๋ฏธ์ง์ ์ ์๋ ๊ฒฐ์ )")
    p.add_argument("--lang", default="ko", choices=["auto", "zh", "en", "yue", "ja", "ko", "nospeech"], help="์ธ์ด ๊ฐ์ ์ค์ . ๊ธฐ๋ณธ ko")
    # The three normalization flags feed normalize_text() before CER/WER.
    p.add_argument("--lower", action="store_true", help="์ ๋ฐ๋ ๊ณ์ฐ ์ ์๋ฌธ์ํ")
    p.add_argument("--strip-punct", action="store_true", help="์ ๋ฐ๋ ๊ณ์ฐ ์ ๋ฌธ์ฅ๋ถํธ ์ ๊ฑฐ")
    p.add_argument("--strip-spaces", action="store_true", help="์ ๋ฐ๋ ๊ณ์ฐ ์ ๋ชจ๋ ๊ณต๋ฐฑ ์ ๊ฑฐ")
    p.add_argument("--out", default="results.jsonl", help="์ถ๋ก ๊ฒฐ๊ณผ JSONL")
    return p.parse_args()
|
|
|
|
def load_dataset(wav_scp: str, text_file: str, emo_file: str, base_audio_dir: str) -> List[Dict]:
    """Load and join the wav-scp, transcript, and emotion files by utterance key.

    Only utterances that appear in all three files AND whose audio file exists
    on disk are returned; missing audio files are reported with a warning.

    Returns:
        List of dicts with keys ``key``, ``abs_source``, ``target``,
        ``emo_target``.
    """
    records: Dict[str, Dict] = {}

    # wav.scp lines: "<key> <path>". maxsplit=1 keeps paths with spaces intact.
    with open(wav_scp, 'r', encoding='utf-8') as f:
        for line in f:
            fields = line.strip().split(maxsplit=1)
            if len(fields) != 2:
                continue
            key, rel_path = fields
            full = os.path.join(base_audio_dir, rel_path) if base_audio_dir else rel_path
            records[key] = {'wav_path': full}

    # Transcript lines: "<key> <text>"; a key with no text yields "".
    with open(text_file, 'r', encoding='utf-8') as f:
        for line in f:
            fields = line.strip().split(maxsplit=1)
            if not fields:
                continue
            key = fields[0]
            if key in records:
                records[key]['text'] = fields[1] if len(fields) > 1 else ""

    # Emotion lines: "<key> <label>"; both fields are required.
    with open(emo_file, 'r', encoding='utf-8') as f:
        for line in f:
            fields = line.strip().split(maxsplit=1)
            if len(fields) >= 2 and fields[0] in records:
                records[fields[0]]['emo'] = fields[1]

    # Keep only fully-joined entries whose audio actually exists.
    valid_data: List[Dict] = []
    for key, rec in records.items():
        if 'text' not in rec or 'emo' not in rec:
            continue
        if not os.path.exists(rec['wav_path']):
            print(f"[warn] File not found: {rec['wav_path']}")
            continue
        valid_data.append({
            'key': key,
            'abs_source': rec['wav_path'],
            'target': rec['text'],
            'emo_target': rec['emo'],
        })
    return valid_data
|
|
|
|
def prepare_model_file(model_dir: Path, model_file: str) -> Path:
    """Make the selected checkpoint loadable as <model_dir>/model.pt.

    Verifies the checkpoint exists, then exposes it under the canonical name
    "model.pt" via a symlink, falling back to a copy when symlinking fails.

    Args:
        model_dir: Directory containing the checkpoint.
        model_file: Checkpoint filename inside *model_dir*.

    Returns:
        Path to the original checkpoint file.

    Raises:
        SystemExit: When the checkpoint is missing, or neither the symlink
            nor the copy could be created.
    """
    source = model_dir / model_file
    target = model_dir / "model.pt"

    if not source.exists():
        raise SystemExit(
            f"[fatal] Model file not found: {source}. Program will exit."
        )

    # Nothing to do when the checkpoint already carries the canonical name.
    if source == target:
        print(f"[info] using model file: {model_file}")
        return source

    # Clear any stale model.pt first. is_symlink() also catches broken links,
    # for which exists() reports False.
    if target.exists() or target.is_symlink():
        try:
            target.unlink()
        except Exception as err:
            print(f"[warn] failed to remove existing {target}: {err}")

    try:
        # Link by bare filename (relative) since both files share model_dir.
        target.symlink_to(model_file)
        print(f"[info] using model file: {model_file} (linked as model.pt)")
    except Exception as err:
        # Symlinks can be unavailable (e.g. Windows without privileges).
        print(f"[warn] symlink failed ({err}), will try to copy instead.")
        import shutil
        try:
            shutil.copy2(str(source), str(target))
            print(f"[info] using model file: {model_file} (copied to model.pt)")
        except Exception as copy_err:
            raise SystemExit(
                f"[fatal] failed to prepare model file at {target}: {copy_err}. Program will exit."
            )

    return source
|
|
|
|
| |
| |
| |
|
|
def main():
    """Run SenseVoice inference over the dataset and report text/emotion metrics.

    Pipeline: validate model files -> load AutoModel -> join dataset files ->
    per-utterance inference (writing one JSONL record each) -> summary stats
    (exact match, CER, WER, emotion accuracy, timing).
    """
    args = parse_args()

    model_dir = Path(args.model_dir)

    # Fail fast when the model directory is missing.
    if not model_dir.exists():
        raise SystemExit(f"[fatal] Model directory not found: {model_dir}. Program will exit.")

    model_file = prepare_model_file(model_dir, args.model_file)
    print(f"[info] final model file: {model_file}")

    # Honor an explicit --device; otherwise prefer CUDA when available.
    device = args.device or ("cuda:0" if torch.cuda.is_available() else "cpu")

    # remote_code must point at the SenseVoice model.py implementation
    # that funasr loads alongside the weights.
    remote_code_path = Path(args.remote_code)
    if not remote_code_path.exists():
        raise SystemExit(
            f"[fatal] remote_code not found at {remote_code_path}. "
            f"Expected model.py for SenseVoice. Program will exit."
        )

    trust_remote = True

    print("๋ชจ๋ธ ๋ก๋ฉ ์ค...")  # "loading model..."
    model = AutoModel(
        model=str(model_dir),
        trust_remote_code=trust_remote,
        remote_code=str(remote_code_path),
        device=device,
        vad_model=None,  # no VAD model: each wav is decoded as one utterance
    )

    print("๋ฐ์ดํฐ์ ๋ก๋ฉ ์ค...")  # "loading dataset..."
    items = load_dataset(args.wav_scp, args.text_file, args.emo_file, args.base_audio_dir)

    total = len(items)
    print(f"[info] total inputs used: {total}, device: {device}, language: {args.lang}")
    if total == 0:
        print("[exit] No valid data found. Check file paths.")
        return

    out_path = Path(args.out)
    # Create the output directory only when --out points somewhere nested.
    if out_path.parent != Path("."):
        out_path.parent.mkdir(parents=True, exist_ok=True)

    # Text-metric accumulators (only samples with a non-empty reference count).
    exact_matches = 0
    cer_sum = 0.0
    wer_sum = 0.0
    text_pairs = 0

    # Emotion-metric accumulators.
    emo_correct = 0
    emo_total = 0

    written = 0
    total_inference_time = 0.0

    with out_path.open("w", encoding="utf-8") as wf:
        for it in items:
            wav_path = it["abs_source"]

            try:
                # Time each utterance's forward pass for the summary stats.
                t0 = time.time()
                res = model.generate(
                    input=wav_path,
                    cache={},
                    language=args.lang,
                    use_itn=True,
                    batch_size=1,
                )
                dur = time.time() - t0
                total_inference_time += dur
            except Exception as e:
                # Skip failed utterances but keep evaluating the rest.
                print(f"[error] inference failed on key={it.get('key')}: {e}")
                continue

            # generate() returns a list of result dicts; take the first
            # (batch_size=1), defensively tolerating an empty result.
            r = res[0] if res else {}
            raw_text = r.get("text", "") or ""
            parsed = parse_sensevoice_text(raw_text)
            pretty_text = rich_transcription_postprocess(parsed["text"]) if parsed["text"] else ""

            ref_text = it.get("target") or ""

            # Text metrics, computed on normalized ref/hyp strings.
            if ref_text:
                nt_ref = normalize_text(ref_text, args.lower, args.strip_punct, args.strip_spaces)
                nt_hyp = normalize_text(pretty_text, args.lower, args.strip_punct, args.strip_spaces)

                if nt_ref == nt_hyp:
                    exact_matches += 1
                cer_sum += cer(nt_ref, nt_hyp)
                wer_sum += wer(nt_ref, nt_hyp)
                text_pairs += 1

            # Emotion metrics: compare canonicalized labels.
            tgt_emo_n = norm_emo(it.get("emo_target"))
            pred_emo_n = norm_emo(parsed["emo"])
            if tgt_emo_n:
                emo_total += 1
                if pred_emo_n == tgt_emo_n:
                    emo_correct += 1

            # One JSONL record per utterance.
            out_obj = {
                "key": it.get("key"),
                "audio": it.get("abs_source"),
                "pred_raw": raw_text,
                "pred_text": pretty_text,
                "ref_text": ref_text,
                "pred_language": parsed["language"],
                "pred_emo": pred_emo_n or parsed["emo"] or "",
                "ref_emo": tgt_emo_n or it.get("emo_target") or "",
                "pred_event": parsed["event"] or "",
                "with_itn": parsed["with_itn"] or "",
                "inference_time": round(dur, 4),
            }
            wf.write(json.dumps(out_obj, ensure_ascii=False) + "\n")

            # Per-sample console report.
            idx = written + 1
            print("\n[{}] key={}".format(idx, it.get("key")))
            print("REF_TEXT :", ref_text)
            print("REF_EMO :", tgt_emo_n or it.get("emo_target"))
            print("PRED_TEXT:", pretty_text)
            print("PRED_EMO :", pred_emo_n or parsed["emo"])
            print("PRED_EVT :", parsed["event"])
            print("PRED_time:", round(dur, 4))
            print("-" * 80)

            written += 1

    # Final summary.
    print("\n===== Summary =====")
    print(f"Samples inferred: {written}")
    if written > 0:
        print(f"Total inference time: {total_inference_time:.2f}s")
        print(f"Avg inference time: {total_inference_time / written:.4f}s")
    if text_pairs > 0:
        exact_acc = exact_matches / text_pairs * 100.0
        avg_cer = cer_sum / text_pairs
        avg_wer = wer_sum / text_pairs
        print(f"Text pairs (with ref): {text_pairs}")
        print(f"- Exact match accuracy: {exact_acc:.2f}%")
        print(f"- Avg CER: {avg_cer:.4f}")
        print(f"- Avg WER: {avg_wer:.4f}")
    else:
        print("No text references found; text metrics skipped.")

    if emo_total > 0:
        emo_acc = emo_correct / emo_total * 100.0
        print(f"Emotion pairs: {emo_total}")
        print(f"- Emotion accuracy: {emo_acc:.2f}%")
    else:
        print("No emotion references found; emotion metrics skipped.")

    print(f"Results saved to: {out_path}")


if __name__ == "__main__":
    main()
|
|