daeunn commited on
Commit
a341d0e
·
verified ·
1 Parent(s): 23ecab6

feat: cer_module.py 추가

Browse files
Files changed (1) hide show
  1. model/cer_module.py +178 -0
model/cer_module.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torchaudio
6
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
7
+
8
+ def preprocess_text(text, remove_spaces=False, remove_punctuation=False):
9
+ """
10
+ 텍스트 전처리 함수
11
+
12
+ Args:
13
+ text (str): 전처리할 텍스트
14
+ remove_spaces (bool): 공백 제거 여부
15
+ remove_punctuation (bool): 문장부호 제거 여부
16
+
17
+ Returns:
18
+ str: 전처리된 텍스트
19
+ """
20
+ if remove_punctuation:
21
+ # 한글, 영문, 숫자를 제외한 문장부호 등 제거
22
+ text = re.sub(r'[^\w\s]', '', text)
23
+
24
+ if remove_spaces:
25
+ # 모든 공백 제거
26
+ text = text.replace(' ', '')
27
+
28
+ return text
29
+
30
+ def calculate_levenshtein(u, v):
31
+ """
32
+ 두 문자열 간의 레벤슈타인 거리와 작업 세부 정보(대체, 삭제, 삽입)를 계산
33
+
34
+ Args:
35
+ u (list): 첫 번째 문자열(문자 리스트)
36
+ v (list): 두 번째 문자열(문자 리스트)
37
+
38
+ Returns:
39
+ tuple: (편집 거리, (대체 수, 삭제 수, 삽입 수))
40
+ """
41
+ prev = None
42
+ curr = [0] + list(range(1, len(v) + 1))
43
+ # 작업: (대체, 삭제, 삽입)
44
+ prev_ops = None
45
+ curr_ops = [(0, 0, i) for i in range(len(v) + 1)]
46
+
47
+ for x in range(1, len(u) + 1):
48
+ prev, curr = curr, [x] + ([None] * len(v))
49
+ prev_ops, curr_ops = curr_ops, [(0, x, 0)] + ([None] * len(v))
50
+
51
+ for y in range(1, len(v) + 1):
52
+ delcost = prev[y] + 1
53
+ addcost = curr[y - 1] + 1
54
+ subcost = prev[y - 1] + int(u[x - 1] != v[y - 1])
55
+
56
+ curr[y] = min(subcost, delcost, addcost)
57
+
58
+ if curr[y] == subcost:
59
+ (n_s, n_d, n_i) = prev_ops[y - 1]
60
+ curr_ops[y] = (n_s + int(u[x - 1] != v[y - 1]), n_d, n_i)
61
+ elif curr[y] == delcost:
62
+ (n_s, n_d, n_i) = prev_ops[y]
63
+ curr_ops[y] = (n_s, n_d + 1, n_i)
64
+ else:
65
+ (n_s, n_d, n_i) = curr_ops[y - 1]
66
+ curr_ops[y] = (n_s, n_d, n_i + 1)
67
+
68
+ return curr[len(v)], curr_ops[len(v)]
69
+
70
+ def calculate_korean_cer(reference, hypothesis, remove_spaces=True, remove_punctuation=True):
71
+ """
72
+ 한국어 문장의 CER(Character Error Rate)을 계산
73
+
74
+ Args:
75
+ reference (str): 정답 문장
76
+ hypothesis (str): 예측 문장
77
+ remove_spaces (bool): 공백 제거 여부
78
+ remove_punctuation (bool): 문장부호 제거 여부
79
+
80
+ Returns:
81
+ dict: CER 값과 세부 정보 (대체, 삭제, 삽입 수)
82
+ """
83
+ # preprocessing
84
+ ref = preprocess_text(reference, remove_spaces, remove_punctuation)
85
+ hyp = preprocess_text(hypothesis, remove_spaces, remove_punctuation)
86
+
87
+ ref_chars = list(ref)
88
+ hyp_chars = list(hyp)
89
+
90
+ _, (substitutions, deletions, insertions) = calculate_levenshtein(hyp_chars, ref_chars)
91
+
92
+ hits = len(ref_chars) - (substitutions + deletions)
93
+ incorrect = substitutions + deletions + insertions
94
+ total = substitutions + deletions + hits + insertions
95
+
96
+ cer = round(incorrect / total, 4) if total > 0 else 0
97
+
98
+ result = {
99
+ 'cer': cer,
100
+ 'substitutions': substitutions,
101
+ 'deletions': deletions,
102
+ 'insertions': insertions
103
+ }
104
+
105
+ return result
106
+
107
+ def calculate_korean_crr(reference, hypothesis, remove_spaces=True, remove_punctuation=True):
108
+ """
109
+ 한국어 문장의 CRR(정확도)을 계산
110
+ CRR = 1 - CER
111
+
112
+ Args:
113
+ reference (str): 정답 문장
114
+ hypothesis (str): 예측 문장
115
+ remove_spaces (bool): 공백 제거 여부
116
+ remove_punctuation (bool): 문장부호 제거 여부
117
+
118
+ Returns:
119
+ dict: CRR 값과 세부 정보 (대체, 삭제, 삽입 수)
120
+ """
121
+ cer_result = calculate_korean_cer(reference, hypothesis, remove_spaces, remove_punctuation)
122
+ crr = round(1 - cer_result['cer'], 4) # 이 부분에서 소수점 몇 번째 자리까지 나타낼지 설정 가능
123
+
124
+ result = {
125
+ 'crr': crr, # 정확도
126
+ 'substitutions': cer_result['substitutions'],
127
+ 'deletions': cer_result['deletions'],
128
+ 'insertions': cer_result['insertions']
129
+ }
130
+
131
+ return result
132
+
133
+ def transcribe_audio(file_path, model_name="daeunn/wav2vec2-korean-finetuned2"):
134
+ # 모델 및 프로세서 로드
135
+ processor = Wav2Vec2Processor.from_pretrained(model_name)
136
+ model = Wav2Vec2ForCTC.from_pretrained(model_name)
137
+
138
+ # 오디오 파일 로드 및 16kHz 리샘플링
139
+ waveform, sample_rate = torchaudio.load(file_path)
140
+ if sample_rate != 16000:
141
+ waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
142
+ input_values = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt", padding=True)
143
+
144
+ # 추론
145
+ with torch.no_grad():
146
+ logits = model(**input_values).logits
147
+ predicted_ids = torch.argmax(logits, dim=-1)
148
+ transcription = processor.decode(predicted_ids[0])
149
+
150
+ return transcription
151
+
152
+
153
+ if __name__ == "__main__":
154
+ # 동일한 문장에 대한 두 번의 녹음
155
+ reference = "제가 스웨덴에서 왔고, 우리나라가 큰 나라이지만 인구가 좀 적어서 학생이라도 재밌게 할 수 있는게 많이 없고 카페나 술집이나 이런게 많이 없어서 그런 거 한국에 많이 있다고 들었고 그거 때문에 한국에 공부하러 왔어요."
156
+
157
+ audio_path1 = "../data/stt_test.wav"
158
+ audio_path2 = "../data/stt_test.wav" # 파일 변경 필요요
159
+
160
+ # 첫 번째 오디오 처리
161
+ print("\n[첫 번째 발화 STT 및 정확도 평가]")
162
+ hypothesis1 = transcribe_audio(audio_path1)
163
+ print("예측 1:", hypothesis1)
164
+
165
+ crr_result1 = calculate_korean_crr(reference, hypothesis1)
166
+ print(f"CRR #1: {crr_result1['crr']} (대체={crr_result1['substitutions']}, 삭제={crr_result1['deletions']}, 삽입={crr_result1['insertions']})")
167
+
168
+ # 두 번째 오디오 처리
169
+ print("\n[두 번째 발화 STT 및 정확도 평가]")
170
+ hypothesis2 = transcribe_audio(audio_path2)
171
+ print("예측 2:", hypothesis2)
172
+
173
+ crr_result2 = calculate_korean_crr(reference, hypothesis2)
174
+ print(f"CRR #2: {crr_result2['crr']} (대체={crr_result2['substitutions']}, 삭제={crr_result2['deletions']}, 삽입={crr_result2['insertions']})")
175
+
176
+ # 정확도 변화량 출력
177
+ diff = round(crr_result2['crr'] - crr_result1['crr'], 4)
178
+ print(f"\n 동일 문장 재발화에 따른 CRR 변화량: {diff:+.4f}")