| |
| from pathlib import Path |
| import os |
| import argparse |
|
|
| from huggingface_hub import snapshot_download |
| from funasr import AutoModel |
| from funasr.utils.postprocess_utils import rich_transcription_postprocess |
|
|
# Hugging Face Hub repository id of the model snapshot to download.
HF_REPO_ID = "AeiROBOT/SenseVoice-Small-ko"
# Local snapshot directory. Resolved against the current user's home so the
# script is not tied to one particular account's absolute path.
LOCAL_DIR = str(Path.home() / ".aeirobot_models" / "SenseVoice-Small-ko")
|
|
| |
# Tag vocabularies recognised at the start of a transcript, one set per category.
LANG_TOKENS = {"<|zh|>", "<|en|>", "<|yue|>", "<|ja|>", "<|ko|>", "<|nospeech|>"}
# "<|EMO_UNKNOWN|>" is the placeholder tag for an undetermined emotion (it is
# listed in FunASR's postprocess_utils token table).
EMO_TOKENS = {
    "<|HAPPY|>",
    "<|SAD|>",
    "<|ANGRY|>",
    "<|NEUTRAL|>",
    "<|FEARFUL|>",
    "<|DISGUSTED|>",
    "<|SURPRISED|>",
    "<|EMO_UNKNOWN|>",
}
# "<|Event_UNK|>" is the matching placeholder tag for an undetermined event.
EVENT_TOKENS = {
    "<|BGM|>",
    "<|Speech|>",
    "<|Applause|>",
    "<|Laughter|>",
    "<|Cry|>",
    "<|Sneeze|>",
    "<|Breath|>",
    "<|Cough|>",
    "<|Event_UNK|>",
}
WITH_ITN_TOKENS = {"<|withitn|>", "<|woitn|>"}
|
|
|
|
| def _consume(prefixes, text: str): |
| for p in prefixes: |
| if text.startswith(p): |
| return p, text[len(p):] |
| return None, text |
|
|
|
|
def parse_sensevoice_text(raw: str):
    """Split a raw SenseVoice transcript into its leading tags and the text.

    Example:
        "<|en|><|NEUTRAL|><|Speech|><|withitn|>Hello there." ->
        {
            "language": "<|en|>",
            "emo": "<|NEUTRAL|>",
            "event": "<|Speech|>",
            "with_itn": "<|withitn|>",
            "text": "Hello there.",
        }

    Leading ``<|...|>`` tags are read in whatever order they appear, with
    optional whitespace between them. A tag that belongs to none of the four
    vocabularies is skipped, so it neither stops the parse nor ends up in
    ``text``. When a category occurs more than once, the first tag wins.

    Args:
        raw: Raw model output string; ``None`` or ``""`` is accepted.

    Returns:
        A dict with keys ``language``, ``emo``, ``event``, ``with_itn`` (each
        the matched tag or ``None``) and ``text`` (the stripped remainder).
    """
    if not raw:
        return {"language": None, "emo": None, "event": None, "with_itn": None, "text": ""}

    slots = (
        ("language", LANG_TOKENS),
        ("emo", EMO_TOKENS),
        ("event", EVENT_TOKENS),
        ("with_itn", WITH_ITN_TOKENS),
    )
    parsed = {key: None for key, _ in slots}

    rest = raw.strip()
    while rest.startswith("<|"):
        for key, vocab in slots:
            token, remainder = _consume(vocab, rest)
            if token:
                if parsed[key] is None:
                    parsed[key] = token
                break
        else:
            # Not a known tag: skip it only if it really looks like a tag
            # (closed, non-empty, no whitespace or nested "<"); otherwise the
            # remainder is ordinary text and parsing stops here.
            close = rest.find("|>", 2)
            inner = rest[2:close] if close != -1 else ""
            if not inner or "<" in inner or any(ch.isspace() for ch in inner):
                break
            remainder = rest[close + 2:]
        rest = remainder.lstrip()

    parsed["text"] = rest.strip()
    return parsed
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads the arguments from ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``; ``wav_file`` holds the audio path
        and defaults to ``"./test.wav"``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--wav_file",
        default="./test.wav",
        # The previous help text described a model name/directory, which is
        # not what this option takes.
        help="음성 인식에 사용할 WAV 파일 경로",
    )
    return parser.parse_args(argv)
|
|
def get_model(device="cuda:0"):
    """Download the model snapshot and build a FunASR ``AutoModel`` from it.

    The snapshot of ``HF_REPO_ID`` is downloaded into ``LOCAL_DIR`` and the
    resulting directory is used both as the model path and as the location of
    the ``model.py`` passed as ``remote_code``.

    Args:
        device: Device string forwarded to ``AutoModel``. Defaults to
            ``"cuda:0"``, the value that used to be hard-coded; pass another
            device (e.g. ``"cpu"``) on machines without that GPU.

    Returns:
        The constructed ``AutoModel`` instance.
    """
    model_dir = snapshot_download(
        repo_id=HF_REPO_ID,
        repo_type="model",
        local_dir=LOCAL_DIR,
        local_dir_use_symlinks=False,
        # Evaluates to None when the environment variable is not set.
        token=os.environ.get("HUGGINGFACE_HUB_TOKEN"),
    )
    print("다운로드 경로:", model_dir)

    return AutoModel(
        model=model_dir,
        trust_remote_code=True,
        remote_code=str(Path(model_dir) / "model.py"),
        vad_model=None,
        device=device,
    )
|
|
def main():
    """Transcribe the WAV file named on the command line and print the result.

    Loads the model, runs ``generate`` on the file, splits the raw output into
    its tags and text with ``parse_sensevoice_text``, and prints the raw
    string followed by the individual fields.

    Raises:
        SystemExit: If the model returns no result for the input.
    """
    args = parse_args()
    wav_path = args.wav_file

    model = get_model()

    res = model.generate(
        input=wav_path,
        cache={},
        language="auto",
        use_itn=True,
        batch_size=1,
    )

    # Exit with a readable message instead of a bare IndexError on res[0].
    if not res:
        raise SystemExit(f"No transcription result returned for {wav_path!r}")

    raw_text = res[0]["text"]
    parsed = parse_sensevoice_text(raw_text)

    # Post-process only the text part; skip the call when it is empty.
    pretty_text = rich_transcription_postprocess(parsed["text"]) if parsed["text"] else ""

    print("=== Raw ===")
    print(raw_text)
    print("=== Parsed ===")
    print("lang :", parsed["language"])
    print("emo :", parsed["emo"])
    print("event :", parsed["event"])
    print("withitn:", parsed["with_itn"])
    print("text :", pretty_text)
|
|
|
|
# Run the transcription only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|