# wav2vec2-server/model/cer_module.py
# CER/CRR evaluation utilities for the wav2vec2 Korean STT model.
# (Hugging Face file-viewer residue removed; original upload note:
#  "feat: cer_module.py 추가", commit a341d0e, 6.53 kB)
import re
import numpy as np
import torch
import torchaudio
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
def preprocess_text(text, remove_spaces=False, remove_punctuation=False):
    """Normalize a text string before character-level comparison.

    Args:
        text (str): text to normalize.
        remove_spaces (bool): if True, strip every space character.
        remove_punctuation (bool): if True, keep only word characters
            (Hangul, Latin letters, digits, underscore) and whitespace.

    Returns:
        str: the normalized text.
    """
    cleaned = text
    if remove_punctuation:
        # Drop punctuation/symbols; word characters and whitespace survive.
        cleaned = re.sub(r'[^\w\s]', '', cleaned)
    if remove_spaces:
        # Remove every space so scoring ignores spacing differences.
        cleaned = cleaned.replace(' ', '')
    return cleaned
def calculate_levenshtein(u, v):
    """Compute the Levenshtein distance between two sequences.

    Runs the classic two-row dynamic program while also tracking, for each
    cell, how many substitutions/deletions/insertions the optimal path used
    to transform ``u`` into ``v``.

    Args:
        u (list): source sequence (list of characters).
        v (list): target sequence (list of characters).

    Returns:
        tuple: (edit distance, (substitutions, deletions, insertions)).
    """
    n_cols = len(v) + 1
    # Row 0: turning the empty prefix of u into a prefix of v is pure insertion.
    costs = list(range(n_cols))
    ops = [(0, 0, j) for j in range(n_cols)]
    for i, u_char in enumerate(u, start=1):
        prev_costs, prev_ops = costs, ops
        # Column 0: turning the i-char prefix of u into "" is pure deletion.
        costs = [i] + [None] * len(v)
        ops = [(0, i, 0)] + [None] * len(v)
        for j, v_char in enumerate(v, start=1):
            mismatch = 1 if u_char != v_char else 0
            sub_cost = prev_costs[j - 1] + mismatch
            del_cost = prev_costs[j] + 1
            ins_cost = costs[j - 1] + 1
            best = min(sub_cost, del_cost, ins_cost)
            costs[j] = best
            # Tie-break in the same priority order as the cost comparison:
            # substitution/match first, then deletion, then insertion.
            if best == sub_cost:
                n_s, n_d, n_i = prev_ops[j - 1]
                ops[j] = (n_s + mismatch, n_d, n_i)
            elif best == del_cost:
                n_s, n_d, n_i = prev_ops[j]
                ops[j] = (n_s, n_d + 1, n_i)
            else:
                n_s, n_d, n_i = ops[j - 1]
                ops[j] = (n_s, n_d, n_i + 1)
    return costs[len(v)], ops[len(v)]
def calculate_korean_cer(reference, hypothesis, remove_spaces=True, remove_punctuation=True):
    """Compute the Character Error Rate (CER) of a Korean sentence.

    CER = (S + D + I) / (H + S + D + I), rounded to 4 decimal places,
    where the edit operations align the reference to the hypothesis.

    Bug fix: the previous version aligned hypothesis -> reference but still
    computed ``hits = len(ref) - (S + D)``. With that argument order each
    reference character is consumed by a hit, a substitution, or an
    *insertion*, so the hit count (and therefore the CER denominator) was
    wrong whenever the hypothesis was shorter than the reference. Aligning
    reference -> hypothesis makes the formula correct and makes the reported
    deletion/insertion counts follow the standard reference-based convention.

    Args:
        reference (str): ground-truth sentence.
        hypothesis (str): predicted sentence.
        remove_spaces (bool): strip spaces before scoring.
        remove_punctuation (bool): strip punctuation before scoring.

    Returns:
        dict: {'cer', 'substitutions', 'deletions', 'insertions'}.
    """
    # Normalize both sides identically before character alignment.
    ref = preprocess_text(reference, remove_spaces, remove_punctuation)
    hyp = preprocess_text(hypothesis, remove_spaces, remove_punctuation)
    ref_chars = list(ref)
    hyp_chars = list(hyp)
    # Align reference -> hypothesis: deletions are reference characters the
    # hypothesis missed; insertions are extra hypothesis characters.
    _, (substitutions, deletions, insertions) = calculate_levenshtein(ref_chars, hyp_chars)
    # Every reference character is either a hit, a substitution, or a deletion.
    hits = len(ref_chars) - (substitutions + deletions)
    incorrect = substitutions + deletions + insertions
    total = hits + substitutions + deletions + insertions
    # Guard against division by zero when both strings normalize to empty.
    cer = round(incorrect / total, 4) if total > 0 else 0
    return {
        'cer': cer,
        'substitutions': substitutions,
        'deletions': deletions,
        'insertions': insertions
    }
def calculate_korean_crr(reference, hypothesis, remove_spaces=True, remove_punctuation=True):
    """Compute the Character Recognition Rate (CRR) of a Korean sentence.

    CRR = 1 - CER; the edit-operation breakdown from the CER computation is
    passed through unchanged.

    Args:
        reference (str): ground-truth sentence.
        hypothesis (str): predicted sentence.
        remove_spaces (bool): strip spaces before scoring.
        remove_punctuation (bool): strip punctuation before scoring.

    Returns:
        dict: {'crr', 'substitutions', 'deletions', 'insertions'}.
    """
    details = calculate_korean_cer(reference, hypothesis, remove_spaces, remove_punctuation)
    # The rounding precision here controls how many decimal places CRR reports.
    accuracy = round(1 - details['cer'], 4)
    return {
        'crr': accuracy,  # recognition accuracy
        'substitutions': details['substitutions'],
        'deletions': details['deletions'],
        'insertions': details['insertions']
    }
def transcribe_audio(file_path, model_name="daeunn/wav2vec2-korean-finetuned2"):
    """Transcribe a speech audio file with a fine-tuned Wav2Vec2 CTC model.

    Args:
        file_path (str): path to an audio file readable by torchaudio.
        model_name (str): Hugging Face model id (or local path) for both the
            processor and the CTC model.

    Returns:
        str: greedy (argmax) CTC decoding of the audio.
    """
    # Load the processor and model (note: reloaded on every call; hoist to a
    # module-level cache if this is called in a loop).
    processor = Wav2Vec2Processor.from_pretrained(model_name)
    model = Wav2Vec2ForCTC.from_pretrained(model_name)
    model.eval()  # inference only: disable dropout etc.
    # Load audio and resample to the 16 kHz rate Wav2Vec2 expects.
    waveform, sample_rate = torchaudio.load(file_path)
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    # Bug fix: .squeeze() left stereo files as a 2-D (channels, time) tensor,
    # which the processor would mis-handle. Downmix multi-channel audio to
    # mono; for mono (1, time) input this is identical to squeeze().
    if waveform.dim() > 1:
        waveform = waveform.mean(dim=0)
    input_values = processor(waveform.numpy(), sampling_rate=16000, return_tensors="pt", padding=True)
    # Greedy CTC decoding of the highest-probability token at each frame.
    with torch.no_grad():
        logits = model(**input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.decode(predicted_ids[0])
    return transcription
if __name__ == "__main__":
    # Score two recordings of the same sentence against one reference
    # transcript and report how the accuracy changes between takes.
    reference = "์ œ๊ฐ€ ์Šค์›จ๋ด์—์„œ ์™”๊ณ , ์šฐ๋ฆฌ๋‚˜๋ผ๊ฐ€ ํฐ ๋‚˜๋ผ์ด์ง€๋งŒ ์ธ๊ตฌ๊ฐ€ ์ข€ ์ ์–ด์„œ ํ•™์ƒ์ด๋ผ๋„ ์žฌ๋ฐŒ๊ฒŒ ํ•  ์ˆ˜ ์žˆ๋Š”๊ฒŒ ๋งŽ์ด ์—†๊ณ  ์นดํŽ˜๋‚˜ ์ˆ ์ง‘์ด๋‚˜ ์ด๋Ÿฐ๊ฒŒ ๋งŽ์ด ์—†์–ด์„œ ๊ทธ๋Ÿฐ ๊ฑฐ ํ•œ๊ตญ์— ๋งŽ์ด ์žˆ๋‹ค๊ณ  ๋“ค์—ˆ๊ณ  ๊ทธ๊ฑฐ ๋•Œ๋ฌธ์— ํ•œ๊ตญ์— ๊ณต๋ถ€ํ•˜๋Ÿฌ ์™”์–ด์š”."
    audio_path1 = "../data/stt_test.wav"
    audio_path2 = "../data/stt_test.wav"  # TODO: point this at the second recording

    # First utterance: STT then accuracy evaluation.
    print("\n[์ฒซ ๋ฒˆ์งธ ๋ฐœํ™” STT ๋ฐ ์ •ํ™•๋„ ํ‰๊ฐ€]")
    pred1 = transcribe_audio(audio_path1)
    print("์˜ˆ์ธก 1:", pred1)
    score1 = calculate_korean_crr(reference, pred1)
    print(f"CRR #1: {score1['crr']} (๋Œ€์ฒด={score1['substitutions']}, ์‚ญ์ œ={score1['deletions']}, ์‚ฝ์ž…={score1['insertions']})")

    # Second utterance: STT then accuracy evaluation.
    print("\n[๋‘ ๋ฒˆ์งธ ๋ฐœํ™” STT ๋ฐ ์ •ํ™•๋„ ํ‰๊ฐ€]")
    pred2 = transcribe_audio(audio_path2)
    print("์˜ˆ์ธก 2:", pred2)
    score2 = calculate_korean_crr(reference, pred2)
    print(f"CRR #2: {score2['crr']} (๋Œ€์ฒด={score2['substitutions']}, ์‚ญ์ œ={score2['deletions']}, ์‚ฝ์ž…={score2['insertions']})")

    # Report the accuracy delta between the two takes.
    diff = round(score2['crr'] - score1['crr'], 4)
    print(f"\n ๋™์ผ ๋ฌธ์žฅ ์žฌ๋ฐœํ™”์— ๋”ฐ๋ฅธ CRR ๋ณ€ํ™”๋Ÿ‰: {diff:+.4f}")