# NOTE(review): the three lines below ("Spaces: / Sleeping / Sleeping") look like
# residue from a Hugging Face Spaces page scrape, not part of the program.
| import re | |
| import numpy as np | |
| import torch | |
| import torchaudio | |
| from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC | |
def preprocess_text(text, remove_spaces=False, remove_punctuation=False):
    """Normalize a text string before error-rate scoring.

    Args:
        text (str): Text to preprocess.
        remove_spaces (bool): If True, strip every space character.
        remove_punctuation (bool): If True, drop punctuation — anything that
            is neither a word character (Hangul/letters/digits/underscore)
            nor whitespace.

    Returns:
        str: The preprocessed text.
    """
    cleaned = text
    if remove_punctuation:
        # \w is Unicode-aware, so Hangul syllables survive this filter.
        cleaned = re.sub(r'[^\w\s]', '', cleaned)
    if remove_spaces:
        # Only the plain space character is removed (not tabs/newlines),
        # matching the original behavior.
        cleaned = cleaned.replace(' ', '')
    return cleaned
def calculate_levenshtein(u, v):
    """Compute the Levenshtein distance between two sequences, plus an
    operation breakdown (substitutions, deletions, insertions).

    Deletions/insertions are counted from the perspective of editing ``u``
    into ``v``. Ties are resolved in the fixed order substitution, then
    deletion, then insertion.

    Args:
        u (list): First sequence (e.g. a list of characters).
        v (list): Second sequence (e.g. a list of characters).

    Returns:
        tuple: (edit distance, (n_substitutions, n_deletions, n_insertions))
    """
    n_cols = len(v) + 1
    # Row-by-row DP over cost and the matching op-count triples.
    cost_row = list(range(n_cols))
    ops_row = [(0, 0, j) for j in range(n_cols)]
    for i in range(1, len(u) + 1):
        prev_cost, cost_row = cost_row, [i] + [None] * len(v)
        prev_ops, ops_row = ops_row, [(0, i, 0)] + [None] * len(v)
        for j in range(1, n_cols):
            mismatch = int(u[i - 1] != v[j - 1])
            sub_cost = prev_cost[j - 1] + mismatch
            del_cost = prev_cost[j] + 1
            ins_cost = cost_row[j - 1] + 1
            best = min(sub_cost, del_cost, ins_cost)
            cost_row[j] = best
            # Tie-break order (sub > del > ins) must stay fixed so the
            # reported op counts are deterministic.
            if best == sub_cost:
                s, d, a = prev_ops[j - 1]
                ops_row[j] = (s + mismatch, d, a)
            elif best == del_cost:
                s, d, a = prev_ops[j]
                ops_row[j] = (s, d + 1, a)
            else:
                s, d, a = ops_row[j - 1]
                ops_row[j] = (s, d, a + 1)
    return cost_row[len(v)], ops_row[len(v)]
def calculate_korean_cer(reference, hypothesis, remove_spaces=True, remove_punctuation=True):
    """Compute the Character Error Rate (CER) for a Korean sentence pair.

    Both strings are preprocessed first; the rate is
    (S + D + I) / (S + D + I + hits), rounded to four decimals.

    Args:
        reference (str): Ground-truth sentence.
        hypothesis (str): Predicted sentence.
        remove_spaces (bool): Strip spaces before scoring.
        remove_punctuation (bool): Strip punctuation before scoring.

    Returns:
        dict: {'cer', 'substitutions', 'deletions', 'insertions'}
    """
    ref_chars = list(preprocess_text(reference, remove_spaces, remove_punctuation))
    hyp_chars = list(preprocess_text(hypothesis, remove_spaces, remove_punctuation))
    # NOTE: hypothesis is passed as the first sequence, so "deletions" are
    # extra hypothesis characters and "insertions" are missing ones.
    _, (n_sub, n_del, n_ins) = calculate_levenshtein(hyp_chars, ref_chars)
    n_hit = len(ref_chars) - (n_sub + n_del)
    n_err = n_sub + n_del + n_ins
    denom = n_err + n_hit
    return {
        'cer': round(n_err / denom, 4) if denom > 0 else 0,
        'substitutions': n_sub,
        'deletions': n_del,
        'insertions': n_ins,
    }
def calculate_korean_crr(reference, hypothesis, remove_spaces=True, remove_punctuation=True):
    """Compute the Character Recognition Rate (CRR) for a Korean sentence pair.

    CRR = 1 - CER.

    Args:
        reference (str): Ground-truth sentence.
        hypothesis (str): Predicted sentence.
        remove_spaces (bool): Strip spaces before scoring.
        remove_punctuation (bool): Strip punctuation before scoring.

    Returns:
        dict: {'crr', 'substitutions', 'deletions', 'insertions'}
    """
    cer_info = calculate_korean_cer(reference, hypothesis, remove_spaces, remove_punctuation)
    return {
        # Rounded to four decimals; adjust the precision here if needed.
        'crr': round(1 - cer_info['cer'], 4),
        'substitutions': cer_info['substitutions'],
        'deletions': cer_info['deletions'],
        'insertions': cer_info['insertions'],
    }
# Cache of loaded (processor, model) pairs keyed by model name, so repeated
# calls do not re-download and re-initialize the same checkpoint.
_ASR_CACHE = {}


def transcribe_audio(file_path, model_name="daeunn/wav2vec2-korean-finetuned2"):
    """Transcribe a Korean speech file with a fine-tuned Wav2Vec2 CTC model.

    Args:
        file_path (str): Path to an audio file readable by torchaudio.
        model_name (str): Hugging Face model id of the Wav2Vec2 checkpoint.

    Returns:
        str: The greedily decoded transcription.
    """
    # Load (or reuse) the processor and model. The previous version reloaded
    # both on every call, which is very expensive when transcribing several
    # files in one run (as __main__ does).
    if model_name not in _ASR_CACHE:
        processor = Wav2Vec2Processor.from_pretrained(model_name)
        model = Wav2Vec2ForCTC.from_pretrained(model_name)
        model.eval()  # inference mode: disables dropout
        _ASR_CACHE[model_name] = (processor, model)
    processor, model = _ASR_CACHE[model_name]

    # Load audio and resample to the 16 kHz rate Wav2Vec2 expects.
    waveform, sample_rate = torchaudio.load(file_path)
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    input_values = processor(waveform.squeeze().numpy(), sampling_rate=16000,
                             return_tensors="pt", padding=True)

    # Greedy CTC decoding: argmax over logits, then collapse with the tokenizer.
    with torch.no_grad():
        logits = model(**input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.decode(predicted_ids[0])
if __name__ == "__main__":
    # Two recordings of the same sentence (the second path currently points
    # at the same file; swap it once another recording is available).
    reference = "์ ๊ฐ ์ค์จ๋ด์์ ์๊ณ , ์ฐ๋ฆฌ๋๋ผ๊ฐ ํฐ ๋๋ผ์ด์ง๋ง ์ธ๊ตฌ๊ฐ ์ข ์ ์ด์ ํ์์ด๋ผ๋ ์ฌ๋ฐ๊ฒ ํ ์ ์๋๊ฒ ๋ง์ด ์๊ณ ์นดํ๋ ์ ์ง์ด๋ ์ด๋ฐ๊ฒ ๋ง์ด ์์ด์ ๊ทธ๋ฐ ๊ฑฐ ํ๊ตญ์ ๋ง์ด ์๋ค๊ณ ๋ค์๊ณ ๊ทธ๊ฑฐ ๋๋ฌธ์ ํ๊ตญ์ ๊ณต๋ถํ๋ฌ ์์ด์."
    first_audio = "../data/stt_test.wav"
    second_audio = "../data/stt_test.wav"

    # First utterance: transcribe, then score against the reference.
    print("\n[์ฒซ ๋ฒ์งธ ๋ฐํ STT ๋ฐ ์ ํ๋ ํ๊ฐ]")
    hypothesis1 = transcribe_audio(first_audio)
    print("์์ธก 1:", hypothesis1)
    result1 = calculate_korean_crr(reference, hypothesis1)
    print(f"CRR #1: {result1['crr']} (๋์ฒด={result1['substitutions']}, ์ญ์ ={result1['deletions']}, ์ฝ์ ={result1['insertions']})")

    # Second utterance: same pipeline.
    print("\n[๋ ๋ฒ์งธ ๋ฐํ STT ๋ฐ ์ ํ๋ ํ๊ฐ]")
    hypothesis2 = transcribe_audio(second_audio)
    print("์์ธก 2:", hypothesis2)
    result2 = calculate_korean_crr(reference, hypothesis2)
    print(f"CRR #2: {result2['crr']} (๋์ฒด={result2['substitutions']}, ์ญ์ ={result2['deletions']}, ์ฝ์ ={result2['insertions']})")

    # Report how much the accuracy changed between the two takes.
    delta = round(result2['crr'] - result1['crr'], 4)
    print(f"\n ๋์ผ ๋ฌธ์ฅ ์ฌ๋ฐํ์ ๋ฐ๋ฅธ CRR ๋ณํ๋: {delta:+.4f}")