bigeco commited on
Commit
c4d95b5
·
1 Parent(s): a341d0e

Revert "feat: cer_module.py 추가"

Browse files

This reverts commit a341d0e357d63602fd6f4dadbce22bf0fbff96b4.

Files changed (1) hide show
  1. model/cer_module.py +0 -178
model/cer_module.py DELETED
@@ -1,178 +0,0 @@
1
- import re
2
-
3
- import numpy as np
4
- import torch
5
- import torchaudio
6
- from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
7
-
8
- def preprocess_text(text, remove_spaces=False, remove_punctuation=False):
9
- """
10
- 텍스트 전처리 함수
11
-
12
- Args:
13
- text (str): 전처리할 텍스트
14
- remove_spaces (bool): 공백 제거 여부
15
- remove_punctuation (bool): 문장부호 제거 여부
16
-
17
- Returns:
18
- str: 전처리된 텍스트
19
- """
20
- if remove_punctuation:
21
- # 한글, 영문, 숫자를 제외한 문장부호 등 제거
22
- text = re.sub(r'[^\w\s]', '', text)
23
-
24
- if remove_spaces:
25
- # 모든 공백 제거
26
- text = text.replace(' ', '')
27
-
28
- return text
29
-
30
- def calculate_levenshtein(u, v):
31
- """
32
- 두 문자열 간의 레벤슈타인 거리와 작업 세부 정보(대체, 삭제, 삽입)를 계산
33
-
34
- Args:
35
- u (list): 첫 번째 문자열(문자 리스트)
36
- v (list): 두 번째 문자열(문자 리스트)
37
-
38
- Returns:
39
- tuple: (편집 거리, (대체 수, 삭제 수, 삽입 수))
40
- """
41
- prev = None
42
- curr = [0] + list(range(1, len(v) + 1))
43
- # 작업: (대체, 삭제, 삽입)
44
- prev_ops = None
45
- curr_ops = [(0, 0, i) for i in range(len(v) + 1)]
46
-
47
- for x in range(1, len(u) + 1):
48
- prev, curr = curr, [x] + ([None] * len(v))
49
- prev_ops, curr_ops = curr_ops, [(0, x, 0)] + ([None] * len(v))
50
-
51
- for y in range(1, len(v) + 1):
52
- delcost = prev[y] + 1
53
- addcost = curr[y - 1] + 1
54
- subcost = prev[y - 1] + int(u[x - 1] != v[y - 1])
55
-
56
- curr[y] = min(subcost, delcost, addcost)
57
-
58
- if curr[y] == subcost:
59
- (n_s, n_d, n_i) = prev_ops[y - 1]
60
- curr_ops[y] = (n_s + int(u[x - 1] != v[y - 1]), n_d, n_i)
61
- elif curr[y] == delcost:
62
- (n_s, n_d, n_i) = prev_ops[y]
63
- curr_ops[y] = (n_s, n_d + 1, n_i)
64
- else:
65
- (n_s, n_d, n_i) = curr_ops[y - 1]
66
- curr_ops[y] = (n_s, n_d, n_i + 1)
67
-
68
- return curr[len(v)], curr_ops[len(v)]
69
-
70
- def calculate_korean_cer(reference, hypothesis, remove_spaces=True, remove_punctuation=True):
71
- """
72
- 한국어 문장의 CER(Character Error Rate)을 계산
73
-
74
- Args:
75
- reference (str): 정답 문장
76
- hypothesis (str): 예측 문장
77
- remove_spaces (bool): 공백 제거 여부
78
- remove_punctuation (bool): 문장부호 제거 여부
79
-
80
- Returns:
81
- dict: CER 값과 세부 정보 (대체, 삭제, 삽입 수)
82
- """
83
- # preprocessing
84
- ref = preprocess_text(reference, remove_spaces, remove_punctuation)
85
- hyp = preprocess_text(hypothesis, remove_spaces, remove_punctuation)
86
-
87
- ref_chars = list(ref)
88
- hyp_chars = list(hyp)
89
-
90
- _, (substitutions, deletions, insertions) = calculate_levenshtein(hyp_chars, ref_chars)
91
-
92
- hits = len(ref_chars) - (substitutions + deletions)
93
- incorrect = substitutions + deletions + insertions
94
- total = substitutions + deletions + hits + insertions
95
-
96
- cer = round(incorrect / total, 4) if total > 0 else 0
97
-
98
- result = {
99
- 'cer': cer,
100
- 'substitutions': substitutions,
101
- 'deletions': deletions,
102
- 'insertions': insertions
103
- }
104
-
105
- return result
106
-
107
- def calculate_korean_crr(reference, hypothesis, remove_spaces=True, remove_punctuation=True):
108
- """
109
- 한국어 문장의 CRR(정확도)을 계산
110
- CRR = 1 - CER
111
-
112
- Args:
113
- reference (str): 정답 문장
114
- hypothesis (str): 예측 문장
115
- remove_spaces (bool): 공백 제거 여부
116
- remove_punctuation (bool): 문장부호 제거 여부
117
-
118
- Returns:
119
- dict: CRR 값과 세부 정보 (대체, 삭제, 삽입 수)
120
- """
121
- cer_result = calculate_korean_cer(reference, hypothesis, remove_spaces, remove_punctuation)
122
- crr = round(1 - cer_result['cer'], 4) # 이 부분에서 소수점 몇 번째 자리까지 나타낼지 설정 가능
123
-
124
- result = {
125
- 'crr': crr, # 정확도
126
- 'substitutions': cer_result['substitutions'],
127
- 'deletions': cer_result['deletions'],
128
- 'insertions': cer_result['insertions']
129
- }
130
-
131
- return result
132
-
133
- def transcribe_audio(file_path, model_name="daeunn/wav2vec2-korean-finetuned2"):
134
- # 모델 및 프로세서 로드
135
- processor = Wav2Vec2Processor.from_pretrained(model_name)
136
- model = Wav2Vec2ForCTC.from_pretrained(model_name)
137
-
138
- # 오디오 파일 로드 및 16kHz 리샘플링
139
- waveform, sample_rate = torchaudio.load(file_path)
140
- if sample_rate != 16000:
141
- waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
142
- input_values = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt", padding=True)
143
-
144
- # 추론
145
- with torch.no_grad():
146
- logits = model(**input_values).logits
147
- predicted_ids = torch.argmax(logits, dim=-1)
148
- transcription = processor.decode(predicted_ids[0])
149
-
150
- return transcription
151
-
152
-
153
- if __name__ == "__main__":
154
- # 동일한 문장에 대한 두 번의 녹음
155
- reference = "제가 스웨덴에서 왔고, 우리나라가 큰 나라이지만 인구가 좀 적어서 학생이라도 재밌게 할 수 있는게 많이 없고 카페나 술집이나 이런게 많이 없어서 그런 거 한국에 많이 있다고 들었고 그거 때문에 한국에 공부하러 왔어요."
156
-
157
- audio_path1 = "../data/stt_test.wav"
158
- audio_path2 = "../data/stt_test.wav" # 파일 변경 필요요
159
-
160
- # 첫 번째 오디오 처리
161
- print("\n[첫 번째 발화 STT 및 정확도 평가]")
162
- hypothesis1 = transcribe_audio(audio_path1)
163
- print("예측 1:", hypothesis1)
164
-
165
- crr_result1 = calculate_korean_crr(reference, hypothesis1)
166
- print(f"CRR #1: {crr_result1['crr']} (대체={crr_result1['substitutions']}, 삭제={crr_result1['deletions']}, 삽입={crr_result1['insertions']})")
167
-
168
- # 두 번째 오디오 처리
169
- print("\n[두 번째 발화 STT 및 정확도 평가]")
170
- hypothesis2 = transcribe_audio(audio_path2)
171
- print("예측 2:", hypothesis2)
172
-
173
- crr_result2 = calculate_korean_crr(reference, hypothesis2)
174
- print(f"CRR #2: {crr_result2['crr']} (대체={crr_result2['substitutions']}, 삭제={crr_result2['deletions']}, 삽입={crr_result2['insertions']})")
175
-
176
- # 정확도 변화량 출력
177
- diff = round(crr_result2['crr'] - crr_result1['crr'], 4)
178
- print(f"\n 동일 문장 재발화에 따른 CRR 변화량: {diff:+.4f}")