# SenseVoice-Small-ko / experiments.py
# Uploaded by HueyWoo using huggingface_hub (commit 7857ca0, verified)
#!/usr/bin/env python3
import os
import json
import argparse
import unicodedata
import time
from pathlib import Path
from typing import List, Dict, Tuple, Optional
import torch
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess
# =======================
# SenseVoice token parser
# =======================
# Special tokens SenseVoice may prefix to a transcription. parse_sensevoice_text
# consumes them in this order: language, emotion, acoustic event, ITN flag.
# Every token has the "<|...|>" shape, so no token is a proper prefix of
# another and the (unordered) set iteration in _consume cannot pick a wrong one.
LANG_TOKENS = {"<|zh|>", "<|en|>", "<|yue|>", "<|ja|>", "<|ko|>", "<|nospeech|>"}
# "<|EMO_UNKNOWN|>" and "<|Event_UNK|>" are also part of SenseVoice's output
# vocabulary (funasr's rich_transcription_postprocess handles both). Without
# them, e.g. "<|ko|><|EMO_UNKNOWN|><|Speech|><|woitn|>..." stops parsing after
# the language token and leaves the remaining tokens inside the text.
EMO_TOKENS = {
    "<|HAPPY|>",
    "<|SAD|>",
    "<|ANGRY|>",
    "<|NEUTRAL|>",
    "<|FEARFUL|>",
    "<|DISGUSTED|>",
    "<|SURPRISED|>",
    "<|EMO_UNKNOWN|>",
}
EVENT_TOKENS = {
    "<|BGM|>",
    "<|Speech|>",
    "<|Applause|>",
    "<|Laughter|>",
    "<|Cry|>",
    "<|Sneeze|>",
    "<|Breath|>",
    "<|Cough|>",
    "<|Event_UNK|>",
}
WITH_ITN_TOKENS = {"<|withitn|>", "<|woitn|>"}
def _consume(prefixes, text: str):
for p in prefixes:
if text.startswith(p):
return p, text[len(p):]
return None, text
def parse_sensevoice_text(raw: str) -> Dict[str, Optional[str]]:
    """Split SenseVoice's leading special tokens off a raw transcription.

    Tokens are consumed only from the very start of the (stripped) string, in
    a fixed order — language, emotion, event, ITN flag — each at most once.
    A field whose token is absent is ``None``. Whatever remains, stripped, is
    returned under ``"text"``. Empty or ``None`` input gives all-``None``
    fields and ``""`` text.
    """
    result: Dict[str, Optional[str]] = {
        "language": None,
        "emo": None,
        "event": None,
        "with_itn": None,
    }
    if not raw:
        result["text"] = ""
        return result

    remainder = raw.strip()
    stages = (
        ("language", LANG_TOKENS),
        ("emo", EMO_TOKENS),
        ("event", EVENT_TOKENS),
        ("with_itn", WITH_ITN_TOKENS),
    )
    for field, vocab in stages:
        result[field], remainder = _consume(vocab, remainder)
    result["text"] = remainder.strip()
    return result
# =======================
# Text normalization & metrics
# =======================
def normalize_text(s: str, lower: bool, strip_punct: bool, strip_spaces: bool) -> str:
    """Normalize a transcript before it is scored.

    Args:
        s: Input text; ``None`` is treated as empty.
        lower: Lower-case the text with ``str.lower``.
        strip_punct: Drop every character whose Unicode category is a
            punctuation class (``P*``).
        strip_spaces: Remove all whitespace, including internal whitespace.

    Returns:
        The normalized string.
    """
    if s is None:
        return ""
    out = s.lower() if lower else s
    if strip_punct:
        # Unicode categories are two letters; "P?" covers all punctuation classes.
        out = "".join(filter(lambda ch: unicodedata.category(ch)[0] != "P", out))
    if strip_spaces:
        out = "".join(out.split())
    return out
def _levenshtein(a: List[str], b: List[str]) -> int:
n, m = len(a), len(b)
if n == 0:
return m
if m == 0:
return n
prev = list(range(m + 1))
for i in range(1, n + 1):
curr = [i] + [0] * m
ai = a[i - 1]
for j in range(1, m + 1):
cost = 0 if ai == b[j - 1] else 1
curr[j] = min(
prev[j] + 1,
curr[j - 1] + 1,
prev[j - 1] + cost,
)
prev = curr
return prev[m]
def cer(ref: str, hyp: str) -> float:
    """Character error rate: character edit distance over the reference length.

    The denominator is clamped to at least 1, so an empty reference yields
    the raw number of inserted characters.
    """
    edits = _levenshtein(list(ref), list(hyp))
    return edits / (len(ref) or 1)
def wer(ref: str, hyp: str) -> float:
    """Word error rate over whitespace-separated tokens.

    The denominator is the reference word count, clamped to at least 1.
    """
    ref_words, hyp_words = ref.split(), hyp.split()
    return _levenshtein(ref_words, hyp_words) / (len(ref_words) or 1)
def norm_emo(label: Optional[str]) -> str:
    """Canonicalize an emotion label for comparison.

    Surrounding whitespace is stripped. A ``<|...|>`` wrapper is removed only
    when both ends are present, and the result is upper-cased. Falsy input
    returns ``""``.
    """
    if not label:
        return ""
    core = label.strip()
    wrapped = core.startswith("<|") and core.endswith("|>")
    return (core[2:-2] if wrapped else core).upper()
# =======================
# IO & argparse
# =======================
def parse_args(argv: Optional[List[str]] = None):
    """Build the command-line parser and parse *argv*.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which matches the original call-site
            behaviour; passing a list allows programmatic use.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    p = argparse.ArgumentParser()
    # Help strings are Korean (repaired from a mis-encoded form). English gloss:
    # model directory; model file name (e.g. ...); wav scp path; text file
    # path; emotion file path; base directory for wav paths; path to the
    # SenseVoice model implementation; device (auto if unset); forced
    # language, default ko; lower-case / strip punctuation / strip all
    # whitespace when scoring; inference-result JSONL.
    p.add_argument("--model-dir", default="SenseVoice-Small-ko", help="모델 디렉터리")
    p.add_argument("--model-file", default="model.pt", help="모델 파일명 (예: model.pt, model.pt.best)")
    p.add_argument("--wav-scp", default="dataset/dataset/train_wav.scp", help="wav scp 파일 경로")
    p.add_argument("--text-file", default="dataset/dataset/train_text.txt", help="text 파일 경로")
    p.add_argument("--emo-file", default="dataset/dataset/train_emo.txt", help="emotion 파일 경로")
    p.add_argument("--base-audio-dir", default=".", help="wav 파일 기준 디렉터리")
    p.add_argument("--remote-code", default="SenseVoice-Small-ko/model.py", help="SenseVoice 모델 구현 경로")
    p.add_argument("--device", default=None, help="cuda:0 / cpu (미지정 시 자동 결정)")
    p.add_argument(
        "--lang",
        default="ko",
        choices=["auto", "zh", "en", "yue", "ja", "ko", "nospeech"],
        help="언어 강제 설정. 기본 ko",
    )
    p.add_argument("--lower", action="store_true", help="정밀도 계산 시 소문자화")
    p.add_argument("--strip-punct", action="store_true", help="정밀도 계산 시 문장부호 제거")
    p.add_argument("--strip-spaces", action="store_true", help="정밀도 계산 시 모든 공백 제거")
    p.add_argument("--out", default="results.jsonl", help="추론 결과 JSONL")
    return p.parse_args(argv)
def _read_kv_lines(path: str, min_fields: int) -> Dict[str, str]:
    """Read a Kaldi-style ``key value`` file into a dict.

    Each line is stripped and split on the first run of whitespace. Blank
    lines and lines with fewer than *min_fields* fields are skipped. A missing
    value (possible only when ``min_fields == 1``) becomes ``""``. Later
    duplicate keys overwrite earlier ones.

    ``utf-8-sig`` is used so that a UTF-8 BOM at the start of the file (common
    in files saved on Windows) is not glued onto the first key. For files
    without a BOM it behaves exactly like ``utf-8``.
    """
    entries: Dict[str, str] = {}
    with open(path, "r", encoding="utf-8-sig") as f:
        for line in f:
            parts = line.strip().split(maxsplit=1)
            if not parts or len(parts) < min_fields:
                continue
            entries[parts[0]] = parts[1] if len(parts) > 1 else ""
    return entries


def load_dataset(wav_scp: str, text_file: str, emo_file: str, base_audio_dir: str) -> List[Dict]:
    """Load the evaluation set from wav-scp, text and emotion files.

    Args:
        wav_scp: ``key path`` lines; *path* is joined onto *base_audio_dir*.
        text_file: ``key transcript`` lines; a key with no transcript maps to "".
        emo_file: ``key label`` lines; lines without a label are ignored.
        base_audio_dir: Directory prepended to each wav path (skipped if empty).

    Returns:
        One dict per utterance, in wav-scp order, with keys ``key``,
        ``abs_source``, ``target`` and ``emo_target``. Only utterances that
        have both a text and an emotion entry and whose audio file exists are
        included. Missing audio files are warned about individually; entries
        lacking labels are reported as a single count.
    """
    wav_paths = _read_kv_lines(wav_scp, min_fields=2)
    texts = _read_kv_lines(text_file, min_fields=1)
    emos = _read_kv_lines(emo_file, min_fields=2)

    valid_data = []
    missing_labels = 0
    for key, path in wav_paths.items():
        full_path = os.path.join(base_audio_dir, path) if base_audio_dir else path
        if key not in texts or key not in emos:
            missing_labels += 1
            continue
        if not os.path.exists(full_path):
            print(f"[warn] File not found: {full_path}")
            continue
        valid_data.append({
            'key': key,
            'abs_source': full_path,
            'target': texts[key],
            'emo_target': emos[key],
        })

    if missing_labels:
        print(f"[warn] {missing_labels} wav entries skipped: missing text and/or emotion label")
    return valid_data
def prepare_model_file(model_dir: Path, model_file: str) -> Path:
    """Make ``model_dir/model.pt`` refer to ``model_dir/model_file``.

    The model loader reads ``model.pt``, so a different checkpoint
    (e.g. ``model.pt.best``) is exposed under that name. A symlink is tried
    first, with a copy as fallback where symlinks are unavailable. An
    existing *regular* ``model.pt`` is moved to ``model.pt.orig`` (or
    ``model.pt.origN``) instead of being deleted, so real weights are never
    lost. An existing symlink is simply replaced.

    Args:
        model_dir: Model directory (``str`` is accepted as well).
        model_file: Checkpoint file name/path relative to *model_dir*.

    Returns:
        The path of the selected checkpoint (``model_dir / model_file``).

    Raises:
        SystemExit: If the checkpoint is missing or ``model.pt`` cannot be prepared.
    """
    import shutil

    model_dir = Path(model_dir)
    source = model_dir / model_file
    target = model_dir / "model.pt"

    if not source.exists():
        raise SystemExit(
            f"[fatal] Model file not found: {source}. Program will exit."
        )

    # Nothing to do if both names already denote the same file. Comparing
    # resolved paths also covers an absolute model_file pointing at model.pt,
    # and a model.pt that is already a link to model_file. A plain ``==`` on
    # the unresolved paths would then unlink the very checkpoint we want.
    if source.resolve() == target.resolve():
        print(f"[info] using model file: {model_file}")
        return source

    if target.is_symlink():
        # Only a pointer: safe to remove.
        try:
            target.unlink()
        except OSError as e:
            raise SystemExit(
                f"[fatal] failed to remove existing {target}: {e}. Program will exit."
            )
    elif target.exists():
        # A real checkpoint: keep it under a free backup name instead of deleting it.
        backup = target.with_name(target.name + ".orig")
        n = 1
        while backup.exists() or backup.is_symlink():
            backup = target.with_name(f"{target.name}.orig{n}")
            n += 1
        try:
            target.rename(backup)
        except OSError as e:
            raise SystemExit(
                f"[fatal] failed to move existing {target} to {backup}: {e}. Program will exit."
            )
        print(f"[warn] existing {target} was a regular file; moved it to {backup}")

    try:
        # Link text is relative to model_dir, so the directory stays relocatable.
        target.symlink_to(model_file)
        print(f"[info] using model file: {model_file} (linked as model.pt)")
    except (OSError, NotImplementedError) as e:
        # Some filesystems/permission setups forbid symlinks: fall back to a copy.
        print(f"[warn] symlink failed ({e}), will try to copy instead.")
        try:
            shutil.copy2(str(source), str(target))
            print(f"[info] using model file: {model_file} (copied to model.pt)")
        except OSError as e2:
            raise SystemExit(
                f"[fatal] failed to prepare model file at {target}: {e2}. Program will exit."
            )
    return source
# =======================
# main
# =======================
def _text_scores(
    ref_text: str,
    hyp_text: str,
    lower: bool,
    strip_punct: bool,
    strip_spaces: bool,
) -> Tuple[bool, float, float]:
    """Return ``(exact_match, cer, wer)`` for one reference/hypothesis pair.

    Exact match and CER use the fully normalized strings, honouring
    *strip_spaces*. WER always keeps whitespace: removing every space would
    collapse each utterance into a single "word", making WER a meaningless
    0-or-1 value per utterance.
    """
    nt_ref = normalize_text(ref_text, lower, strip_punct, strip_spaces)
    nt_hyp = normalize_text(hyp_text, lower, strip_punct, strip_spaces)
    if strip_spaces:
        w_ref = normalize_text(ref_text, lower, strip_punct, False)
        w_hyp = normalize_text(hyp_text, lower, strip_punct, False)
    else:
        w_ref, w_hyp = nt_ref, nt_hyp
    return nt_ref == nt_hyp, cer(nt_ref, nt_hyp), wer(w_ref, w_hyp)


def _print_summary(stats: Dict[str, float], out_path: Path) -> None:
    """Print the end-of-run summary from the accumulated *stats* counters."""
    written = stats["written"]
    print("\n===== Summary =====")
    print(f"Samples inferred: {written}")
    if stats["failed"]:
        print(f"Samples failed: {stats['failed']}")
    if written > 0:
        print(f"Total inference time: {stats['time']:.2f}s")
        print(f"Avg inference time: {stats['time'] / written:.4f}s")

    text_pairs = stats["text_pairs"]
    if text_pairs > 0:
        exact_acc = stats["exact"] / text_pairs * 100.0
        avg_cer = stats["cer_sum"] / text_pairs
        avg_wer = stats["wer_sum"] / text_pairs
        print(f"Text pairs (with ref): {text_pairs}")
        print(f"- Exact match accuracy: {exact_acc:.2f}%")
        print(f"- Avg CER: {avg_cer:.4f}")
        print(f"- Avg WER: {avg_wer:.4f}")
    else:
        print("No text references found; text metrics skipped.")

    emo_total = stats["emo_total"]
    if emo_total > 0:
        emo_acc = stats["emo_correct"] / emo_total * 100.0
        print(f"Emotion pairs: {emo_total}")
        print(f"- Emotion accuracy: {emo_acc:.2f}%")
    else:
        print("No emotion references found; emotion metrics skipped.")
    print(f"Results saved to: {out_path}")


def main():
    """Run SenseVoice inference over the dataset and report metrics.

    The run:
      1. Validates the model directory and points ``model.pt`` at ``--model-file``.
      2. Loads the model through FunASR ``AutoModel`` with the local
         ``--remote-code`` implementation.
      3. Runs inference one utterance at a time and writes one JSON object per
         utterance to ``--out``.
      4. Prints text metrics (exact match, CER, WER) and emotion accuracy.

    Utterances whose inference raises are logged, counted and skipped.
    """
    args = parse_args()
    model_dir = Path(args.model_dir)

    # Prepare the checkpoint.
    if not model_dir.exists():
        raise SystemExit(f"[fatal] Model directory not found: {model_dir}. Program will exit.")
    model_file = prepare_model_file(model_dir, args.model_file)
    print(f"[info] final model file: {model_file}")

    device = args.device or ("cuda:0" if torch.cuda.is_available() else "cpu")

    # model.py (remote_code) is mandatory; exit immediately if it is missing.
    remote_code_path = Path(args.remote_code)
    if not remote_code_path.exists():
        raise SystemExit(
            f"[fatal] remote_code not found at {remote_code_path}. "
            f"Expected model.py for SenseVoice. Program will exit."
        )

    print("모델 로딩 중...")  # "Loading model..."
    model = AutoModel(
        model=str(model_dir),  # local directory only
        trust_remote_code=True,
        remote_code=str(remote_code_path),
        device=device,
        vad_model=None,
    )

    print("데이터셋 로딩 중...")  # "Loading dataset..."
    items = load_dataset(args.wav_scp, args.text_file, args.emo_file, args.base_audio_dir)
    total = len(items)
    print(f"[info] total inputs used: {total}, device: {device}, language: {args.lang}")
    if total == 0:
        print("[exit] No valid data found. Check file paths.")
        return

    out_path = Path(args.out)
    # For a bare file name the parent is "."; mkdir(exist_ok=True) is then a no-op.
    out_path.parent.mkdir(parents=True, exist_ok=True)

    # Running totals for the summary (ints for counts, floats for sums).
    stats = {
        "exact": 0,
        "cer_sum": 0.0,
        "wer_sum": 0.0,
        "text_pairs": 0,
        "emo_correct": 0,
        "emo_total": 0,
        "written": 0,
        "failed": 0,
        "time": 0.0,
    }

    with out_path.open("w", encoding="utf-8") as wf:
        for it in items:
            key = it.get("key")
            try:
                # perf_counter is monotonic and high-resolution, unlike time.time().
                t0 = time.perf_counter()
                res = model.generate(
                    input=it["abs_source"],
                    cache={},
                    language=args.lang,
                    use_itn=True,
                    batch_size=1,
                )
                dur = time.perf_counter() - t0
            except Exception as e:
                # Best effort: one failing utterance must not abort the whole run.
                stats["failed"] += 1
                print(f"[error] inference failed on key={key}: {e}")
                continue
            stats["time"] += dur

            # generate() returns a list; with a single input, use the first entry.
            r = res[0] if res else {}
            raw_text = r.get("text", "") or ""
            parsed = parse_sensevoice_text(raw_text)
            pretty_text = rich_transcription_postprocess(parsed["text"]) if parsed["text"] else ""
            ref_text = it.get("target") or ""

            # Text metrics (only when a reference transcript exists).
            if ref_text:
                exact, c, w = _text_scores(
                    ref_text, pretty_text, args.lower, args.strip_punct, args.strip_spaces
                )
                stats["exact"] += int(exact)
                stats["cer_sum"] += c
                stats["wer_sum"] += w
                stats["text_pairs"] += 1

            # Emotion metrics (only when a reference label exists).
            tgt_emo_n = norm_emo(it.get("emo_target"))
            pred_emo_n = norm_emo(parsed["emo"])
            if tgt_emo_n:
                stats["emo_total"] += 1
                if pred_emo_n == tgt_emo_n:
                    stats["emo_correct"] += 1

            out_obj = {
                "key": key,
                "audio": it.get("abs_source"),
                "pred_raw": raw_text,
                "pred_text": pretty_text,
                "ref_text": ref_text,
                "pred_language": parsed["language"],
                "pred_emo": pred_emo_n or parsed["emo"] or "",
                "ref_emo": tgt_emo_n or it.get("emo_target") or "",
                "pred_event": parsed["event"] or "",
                "with_itn": parsed["with_itn"] or "",
                "inference_time": round(dur, 4),
            }
            wf.write(json.dumps(out_obj, ensure_ascii=False) + "\n")

            # Human-readable per-sample output.
            idx = stats["written"] + 1
            print("\n[{}] key={}".format(idx, key))
            print("REF_TEXT :", ref_text)
            print("REF_EMO :", tgt_emo_n or it.get("emo_target"))
            print("PRED_TEXT:", pretty_text)
            print("PRED_EMO :", pred_emo_n or parsed["emo"])
            print("PRED_EVT :", parsed["event"])
            print("PRED_time:", round(dur, 4))
            print("-" * 80)
            stats["written"] += 1

    _print_summary(stats, out_path)
# Script entry point: run main() only when executed directly, not when imported.
if __name__ == "__main__":
    main()