#!/usr/bin/env python3
import os
import json
import argparse
import unicodedata
import time
from pathlib import Path
from typing import List, Dict, Tuple, Optional

import torch
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess

# =======================
# SenseVoice token parser
# =======================
LANG_TOKENS = {"<|zh|>", "<|en|>", "<|yue|>", "<|ja|>", "<|ko|>", "<|nospeech|>"}
EMO_TOKENS = {"<|HAPPY|>", "<|SAD|>", "<|ANGRY|>", "<|NEUTRAL|>", "<|FEARFUL|>", "<|DISGUSTED|>", "<|SURPRISED|>"}
EVENT_TOKENS = {"<|BGM|>", "<|Speech|>", "<|Applause|>", "<|Laughter|>", "<|Cry|>", "<|Sneeze|>", "<|Breath|>", "<|Cough|>"}
WITH_ITN_TOKENS = {"<|withitn|>", "<|woitn|>"}


def _consume(prefixes, text: str):
    """If text starts with one of the prefixes, strip it and return (prefix, remainder)."""
    for p in prefixes:
        if text.startswith(p):
            return p, text[len(p):]
    return None, text


def parse_sensevoice_text(raw: str) -> Dict[str, Optional[str]]:
    """Split SenseVoice raw output into its leading tokens (language, emotion, event, ITN) and the plain text."""
    if not raw:
        return {"language": None, "emo": None, "event": None, "with_itn": None, "text": ""}
    rest = raw.strip()
    lang, rest = _consume(LANG_TOKENS, rest)
    emo, rest = _consume(EMO_TOKENS, rest)
    event, rest = _consume(EVENT_TOKENS, rest)
    with_itn, rest = _consume(WITH_ITN_TOKENS, rest)
    clean_text = rest.strip()
    return {
        "language": lang,
        "emo": emo,
        "event": event,
        "with_itn": with_itn,
        "text": clean_text,
    }


# =======================
# Text normalization & metrics
# =======================
def normalize_text(s: str, lower: bool, strip_punct: bool, strip_spaces: bool) -> str:
    """Optionally lowercase, strip punctuation, and strip whitespace before comparing texts."""
    if s is None:
        return ""
    t = s
    if lower:
        t = t.lower()
    if strip_punct:
        t = "".join(ch for ch in t if not unicodedata.category(ch).startswith("P"))
    if strip_spaces:
        t = "".join(t.split())
    return t


def _levenshtein(a: List[str], b: List[str]) -> int:
    """Edit distance between two token sequences (two-row dynamic programming)."""
    n, m = len(a), len(b)
    if n == 0:
        return m
    if m == 0:
        return n
    prev = list(range(m + 1))
    for i in range(1, n + 1):
        curr = [i] + [0] * m
        ai = a[i - 1]
        for j in range(1, m + 1):
            cost = 0 if ai == b[j - 1] else 1
            curr[j] = min(
                prev[j] + 1,         # deletion
                curr[j - 1] + 1,     # insertion
                prev[j - 1] + cost,  # substitution
            )
        prev = curr
    return prev[m]


def cer(ref: str, hyp: str) -> float:
    """Character error rate: edit distance over reference length."""
    r = list(ref)
    h = list(hyp)
    dist = _levenshtein(r, h)
    return dist / max(1, len(r))


def wer(ref: str, hyp: str) -> float:
    """Word error rate: edit distance over reference word count."""
    r = ref.split()
    h = hyp.split()
    dist = _levenshtein(r, h)
    return dist / max(1, len(r))


def norm_emo(label: Optional[str]) -> str:
    """Normalize an emotion label: strip the <|...|> wrapper and uppercase."""
    if not label:
        return ""
    t = label.strip()
    if t.startswith("<|") and t.endswith("|>"):
        t = t[2:-2]
    return t.upper()


# =======================
# IO & argparse
# =======================
def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--model-dir", default="SenseVoice-Small-ko", help="Model directory")
    p.add_argument("--model-file", default="model.pt", help="Model file name (e.g. model.pt, model.pt.best)")
    p.add_argument("--wav-scp", default="dataset/dataset/train_wav.scp", help="Path to the wav scp file")
    p.add_argument("--text-file", default="dataset/dataset/train_text.txt", help="Path to the text file")
    p.add_argument("--emo-file", default="dataset/dataset/train_emo.txt", help="Path to the emotion file")
    p.add_argument("--base-audio-dir", default=".", help="Base directory for wav paths")
    p.add_argument("--remote-code", default="SenseVoice-Small-ko/model.py", help="Path to the SenseVoice model implementation")
    p.add_argument("--device", default=None, help="cuda:0 / cpu (auto-selected if not given)")
    p.add_argument("--lang", default="ko", choices=["auto", "zh", "en", "yue", "ja", "ko", "nospeech"],
                   help="Force the language. Default: ko")
기본 ko") p.add_argument("--lower", action="store_true", help="정밀도 계산 시 소문자화") p.add_argument("--strip-punct", action="store_true", help="정밀도 계산 시 문장부호 제거") p.add_argument("--strip-spaces", action="store_true", help="정밀도 계산 시 모든 공백 제거") p.add_argument("--out", default="results.jsonl", help="추론 결과 JSONL") return p.parse_args() def load_dataset(wav_scp: str, text_file: str, emo_file: str, base_audio_dir: str) -> List[Dict]: """wav-scp, text-file, emo-file에서 데이터셋 로드""" data = {} # Read WAV SCP with open(wav_scp, 'r', encoding='utf-8') as f: for line in f: parts = line.strip().split(maxsplit=1) if len(parts) == 2: key, path = parts # Resolve path if base_audio_dir: full_path = os.path.join(base_audio_dir, path) else: full_path = path data[key] = {'wav_path': full_path} # Read Text with open(text_file, 'r', encoding='utf-8') as f: for line in f: parts = line.strip().split(maxsplit=1) if len(parts) >= 1: key = parts[0] text = parts[1] if len(parts) > 1 else "" if key in data: data[key]['text'] = text # Read Emo with open(emo_file, 'r', encoding='utf-8') as f: for line in f: parts = line.strip().split(maxsplit=1) if len(parts) >= 2: key, emo = parts if key in data: data[key]['emo'] = emo # Filter complete items valid_data = [] for key, val in data.items(): if 'text' in val and 'emo' in val: if os.path.exists(val['wav_path']): valid_data.append({ 'key': key, 'abs_source': val['wav_path'], 'target': val['text'], 'emo_target': val['emo'] }) else: print(f"[warn] File not found: {val['wav_path']}") return valid_data def prepare_model_file(model_dir: Path, model_file: str) -> Path: """지정된 모델 파일이 존재하는지 확인하고 model.pt로 심볼릭 링크 생성""" source = model_dir / model_file target = model_dir / "model.pt" if not source.exists(): raise SystemExit( f"[fatal] Model file not found: {source}. Program will exit." ) # 같은 파일이면 아무것도 안 함 if source == target: print(f"[info] using model file: {model_file}") return source # 다른 파일이면 model.pt를 해당 파일로 링크 if target.exists() or target.is_symlink(): try: target.unlink() except Exception as e: print(f"[warn] failed to remove existing {target}: {e}") try: # 상대 이름으로 심볼릭 링크 생성 target.symlink_to(model_file) print(f"[info] using model file: {model_file} (linked as model.pt)") except Exception as e: # 일부 파일시스템/권한 환경에서 symlink가 안 될 수 있으므로, 복사로 폴백 print(f"[warn] symlink failed ({e}), will try to copy instead.") import shutil try: shutil.copy2(str(source), str(target)) print(f"[info] using model file: {model_file} (copied to model.pt)") except Exception as e2: raise SystemExit( f"[fatal] failed to prepare model file at {target}: {e2}. Program will exit." ) return source # ======================= # main # ======================= def main(): args = parse_args() model_dir = Path(args.model_dir) # 모델 파일 준비 if not model_dir.exists(): raise SystemExit(f"[fatal] Model directory not found: {model_dir}. Program will exit.") model_file = prepare_model_file(model_dir, args.model_file) print(f"[info] final model file: {model_file}") device = args.device or ("cuda:0" if torch.cuda.is_available() else "cpu") # model.py(remote_code)는 반드시 존재해야 한다. 없으면 바로 종료. remote_code_path = Path(args.remote_code) if not remote_code_path.exists(): raise SystemExit( f"[fatal] remote_code not found at {remote_code_path}. " f"Expected model.py for SenseVoice. Program will exit." 
    trust_remote = True

    print("Loading model...")
    model = AutoModel(
        model=str(model_dir),  # use the local directory only
        trust_remote_code=trust_remote,
        remote_code=str(remote_code_path),
        device=device,
        vad_model=None,
    )

    print("Loading dataset...")
    items = load_dataset(args.wav_scp, args.text_file, args.emo_file, args.base_audio_dir)
    total = len(items)
    print(f"[info] total inputs used: {total}, device: {device}, language: {args.lang}")
    if total == 0:
        print("[exit] No valid data found. Check file paths.")
        return

    out_path = Path(args.out)
    if out_path.parent != Path("."):
        out_path.parent.mkdir(parents=True, exist_ok=True)

    # Metric accumulators
    exact_matches = 0
    cer_sum = 0.0
    wer_sum = 0.0
    text_pairs = 0
    emo_correct = 0
    emo_total = 0
    written = 0
    total_inference_time = 0.0

    with out_path.open("w", encoding="utf-8") as wf:
        for it in items:
            wav_path = it["abs_source"]
            try:
                # Measure inference time
                t0 = time.time()
                res = model.generate(
                    input=wav_path,
                    cache={},
                    language=args.lang,
                    use_itn=True,
                    batch_size=1,
                )
                dur = time.time() - t0
                total_inference_time += dur
            except Exception as e:
                print(f"[error] inference failed on key={it.get('key')}: {e}")
                continue

            # res is a list, so use its first element
            r = res[0] if res else {}
            raw_text = r.get("text", "") or ""
            parsed = parse_sensevoice_text(raw_text)
            pretty_text = rich_transcription_postprocess(parsed["text"]) if parsed["text"] else ""
            ref_text = it.get("target") or ""

            # Text metrics
            if ref_text:
                nt_ref = normalize_text(ref_text, args.lower, args.strip_punct, args.strip_spaces)
                nt_hyp = normalize_text(pretty_text, args.lower, args.strip_punct, args.strip_spaces)
                if nt_ref == nt_hyp:
                    exact_matches += 1
                cer_sum += cer(nt_ref, nt_hyp)
                wer_sum += wer(nt_ref, nt_hyp)
                text_pairs += 1

            # Emotion metrics
            tgt_emo_n = norm_emo(it.get("emo_target"))
            pred_emo_n = norm_emo(parsed["emo"])
            if tgt_emo_n:
                emo_total += 1
                if pred_emo_n == tgt_emo_n:
                    emo_correct += 1

            out_obj = {
                "key": it.get("key"),
                "audio": it.get("abs_source"),
                "pred_raw": raw_text,
                "pred_text": pretty_text,
                "ref_text": ref_text,
                "pred_language": parsed["language"],
                "pred_emo": pred_emo_n or parsed["emo"] or "",
                "ref_emo": tgt_emo_n or it.get("emo_target") or "",
                "pred_event": parsed["event"] or "",
                "with_itn": parsed["with_itn"] or "",
                "inference_time": round(dur, 4),
            }
            wf.write(json.dumps(out_obj, ensure_ascii=False) + "\n")

            # ===== Human-readable per-sample output =====
            idx = written + 1
            print("\n[{}] key={}".format(idx, it.get("key")))
            print("REF_TEXT :", ref_text)
            print("REF_EMO :", tgt_emo_n or it.get("emo_target"))
            print("PRED_TEXT:", pretty_text)
            print("PRED_EMO :", pred_emo_n or parsed["emo"])  # showing the raw token is fine here
            print("PRED_EVT :", parsed["event"])  # also show the detected event
            print("PRED_time:", round(dur, 4))
            print("-" * 80)
            written += 1

    # Print the summary
    print("\n===== Summary =====")
    print(f"Samples inferred: {written}")
    if written > 0:
        print(f"Total inference time: {total_inference_time:.2f}s")
        print(f"Avg inference time: {total_inference_time / written:.4f}s")
    if text_pairs > 0:
        exact_acc = exact_matches / text_pairs * 100.0
        avg_cer = cer_sum / text_pairs
        avg_wer = wer_sum / text_pairs
        print(f"Text pairs (with ref): {text_pairs}")
        print(f"- Exact match accuracy: {exact_acc:.2f}%")
        print(f"- Avg CER: {avg_cer:.4f}")
        print(f"- Avg WER: {avg_wer:.4f}")
    else:
        print("No text references found; text metrics skipped.")
    if emo_total > 0:
        emo_acc = emo_correct / emo_total * 100.0
        print(f"Emotion pairs: {emo_total}")
        print(f"- Emotion accuracy: {emo_acc:.2f}%")
    else:
        print("No emotion references found; emotion metrics skipped.")
    print(f"Results saved to: {out_path}")


if __name__ == "__main__":
    main()
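
# ----------------------------------------------------------------------
# Example usage (illustrative only; the script filename "eval_sensevoice.py"
# and the dataset paths below are assumptions, not part of the original code):
#
#   python eval_sensevoice.py \
#       --model-dir SenseVoice-Small-ko --model-file model.pt.best \
#       --wav-scp dataset/dataset/train_wav.scp \
#       --text-file dataset/dataset/train_text.txt \
#       --emo-file dataset/dataset/train_emo.txt \
#       --lang ko --strip-punct --strip-spaces --out results.jsonl
#
# A minimal sketch for inspecting the JSONL output afterwards (field names
# match out_obj above; "results.jsonl" is the default --out value):
#
#   import json
#   with open("results.jsonl", encoding="utf-8") as f:
#       rows = [json.loads(line) for line in f]
#   mismatches = [r for r in rows if r["ref_emo"] and r["pred_emo"] != r["ref_emo"]]
#   print(f"{len(mismatches)}/{len(rows)} emotion mismatches")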