revert: undo CRR calculation & cer_module.py addition

#2
by bigeco - opened
Files changed (2) hide show
  1. app.py +29 -11
  2. model/cer_module.py +178 -0
app.py CHANGED
@@ -5,6 +5,7 @@ import yaml
5
  import tempfile
6
  import os
7
  import traceback
 
8
  from model.wav2vec2 import Wav2Vec2
9
 
10
  # ---------------- 설정 로드 ----------------
@@ -25,36 +26,53 @@ app = FastAPI(
25
  class TranscriptionResponse(BaseModel):
26
  transcription: str
27
  status: str
 
28
 
29
  # ---------------- API: 파일 업로드 POST ----------------
30
  @app.post("/transcribe", response_model=TranscriptionResponse)
31
- async def transcribe_audio(file: UploadFile = File(...)):
32
- """์˜ค๋””์˜ค ํŒŒ์ผ์„ ์—…๋กœ๋“œํ•˜์—ฌ ์Œ์„ฑ ์ธ์‹ ์ˆ˜ํ–‰"""
33
-
34
  # 파일 형식 검증
35
  if not file.filename.lower().endswith(('.wav', '.mp3', '.flac', '.m4a')):
36
  return TranscriptionResponse(
37
  transcription="",
38
- status="error: ์ง€์›๋˜์ง€ ์•Š๋Š” ํŒŒ์ผ ํ˜•์‹์ž…๋‹ˆ๋‹ค. wav, mp3, flac, m4a ํŒŒ์ผ๋งŒ ์ง€์›๋ฉ๋‹ˆ๋‹ค."
 
39
  )
40
-
41
  try:
42
- # ํŒŒ์ผ ๋‚ด์šฉ ์ฝ๊ธฐ
43
  audio_bytes = await file.read()
44
-
45
- # ์Œ์„ฑ ์ธ์‹ ์ˆ˜ํ–‰
46
  result = wav2vec2_model.transcribe_from_bytes(audio_bytes, file.filename)
47
 
 
 
 
 
 
 
48
  return TranscriptionResponse(
49
  transcription=result,
50
- status="success"
 
51
  )
52
-
53
  except Exception as e:
54
  return TranscriptionResponse(
55
  transcription="",
56
- status=f"error: {str(e)}"
 
57
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  # ---------------- HTML UI ----------------
60
  @app.get("/", response_class=HTMLResponse)
 
5
import os
import tempfile
import traceback
from typing import Optional

from model.cer_module import calculate_korean_crr
from model.wav2vec2 import Wav2Vec2
10
 
11
  # ---------------- 설정 로드 ----------------
 
26
  class TranscriptionResponse(BaseModel):
27
  transcription: str
28
  status: str
29
+ crr: float = None # CRR ๊ฐ’, ์„ ํƒ์  ํ•„๋“œ
30
 
31
  # ---------------- API: 파일 업로드 POST ----------------
32
  @app.post("/transcribe", response_model=TranscriptionResponse)
33
async def transcribe_audio(file: UploadFile = File(...), reference: str = None):
    """Transcribe an uploaded audio file and optionally score it.

    Args:
        file: Uploaded audio; only .wav, .mp3, .flac and .m4a names are accepted.
        reference: Optional ground-truth sentence.  When given (and non-empty),
            the CRR of the transcription against it is included in the response.

    Returns:
        TranscriptionResponse: ``status`` is "success" on success, otherwise a
        message starting with "error: " and an empty ``transcription``.
    """
    # The upload may arrive without a filename; treat that as an unsupported
    # type instead of raising AttributeError before the try block is entered.
    filename = file.filename or ""

    # Validate the file type by extension.
    if not filename.lower().endswith(('.wav', '.mp3', '.flac', '.m4a')):
        return TranscriptionResponse(
            transcription="",
            status="error: 지원되지 않는 파일 형식입니다. wav, mp3, flac, m4a 파일만 지원됩니다.",
            crr=None
        )

    try:
        audio_bytes = await file.read()
        result = wav2vec2_model.transcribe_from_bytes(audio_bytes, filename)

        # Compute the CRR only when a reference sentence was supplied.
        crr = None
        if reference:
            crr_result = calculate_korean_crr(reference, result)
            crr = crr_result['crr']

        return TranscriptionResponse(
            transcription=result,
            status="success",
            crr=crr
        )
    except Exception as e:
        # API boundary: report the failure in the response body rather than
        # letting the exception propagate.
        return TranscriptionResponse(
            transcription="",
            status=f"error: {str(e)}",
            crr=None
        )
62
class CRRRequest(BaseModel):
    """Request body for ``/calculate-crr``: the two sentences to compare."""

    # Handed to calculate_korean_crr as its first argument (the reference).
    original: str
    # Handed to calculate_korean_crr as its second argument (the hypothesis).
    corrected: str
65
+
66
class CRRResponse(BaseModel):
    """Response body for ``/calculate-crr``."""

    # The 'crr' entry of the dict returned by calculate_korean_crr.
    crr: float
68
+
69
@app.post("/calculate-crr", response_model=CRRResponse)
async def calculate_crr_api(data: CRRRequest):
    """Return only the CRR (accuracy) for the two sentences in the request.

    ``data.original`` is used as the reference sentence and
    ``data.corrected`` as the hypothesis.
    """
    score = calculate_korean_crr(data.original, data.corrected)['crr']
    return CRRResponse(crr=score)
76
 
77
  # ---------------- HTML UI ----------------
78
  @app.get("/", response_class=HTMLResponse)
model/cer_module.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torchaudio
6
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
7
+
8
def preprocess_text(text, remove_spaces=False, remove_punctuation=False):
    """Normalise text before character-level comparison.

    Args:
        text (str): Text to normalise.
        remove_spaces (bool): Remove all whitespace (spaces, tabs, newlines, ...).
        remove_punctuation (bool): Remove everything that is neither a
            letter/digit (Hangul included) nor whitespace.

    Returns:
        str: The normalised text.
    """
    if remove_punctuation:
        # ``\w`` also matches "_", which is not a letter or digit, so it is
        # stripped explicitly together with the other punctuation.
        text = re.sub(r'[^\w\s]|_', '', text)

    if remove_spaces:
        # str.split() with no argument splits on every whitespace run, so
        # tabs and newlines are removed as well as plain spaces.
        text = ''.join(text.split())

    return text
29
+
30
def calculate_levenshtein(u, v):
    """Compute the Levenshtein distance from ``u`` to ``v`` with an operation breakdown.

    The operations describe how to turn ``u`` into ``v``: a deletion drops an
    element of ``u``, an insertion adds an element of ``v``.  When several
    operations are equally cheap, substitution/match is preferred, then
    deletion, then insertion.

    Args:
        u (list): First sequence (list of characters).
        v (list): Second sequence (list of characters).

    Returns:
        tuple: (edit distance, (substitutions, deletions, insertions))
    """
    width = len(v)
    # Row for the empty prefix of u: reaching v[:j] takes j insertions.
    costs = list(range(width + 1))
    ops = [(0, 0, j) for j in range(width + 1)]

    for i, u_item in enumerate(u, start=1):
        above_costs, above_ops = costs, ops
        # Column 0: reaching the empty prefix of v takes i deletions.
        costs = [i]
        ops = [(0, i, 0)]

        for j, v_item in enumerate(v, start=1):
            mismatch = int(u_item != v_item)
            via_sub = above_costs[j - 1] + mismatch
            via_del = above_costs[j] + 1
            via_ins = costs[j - 1] + 1
            best = min(via_sub, via_del, via_ins)

            if best == via_sub:
                subs, dels, inss = above_ops[j - 1]
                step = (subs + mismatch, dels, inss)
            elif best == via_del:
                subs, dels, inss = above_ops[j]
                step = (subs, dels + 1, inss)
            else:
                subs, dels, inss = ops[j - 1]
                step = (subs, dels, inss + 1)

            costs.append(best)
            ops.append(step)

    return costs[width], ops[width]
69
+
70
def calculate_korean_cer(reference, hypothesis, remove_spaces=True, remove_punctuation=True):
    """Compute the character error rate (CER) of a Korean sentence.

    Both sentences are normalised with ``preprocess_text`` and compared
    character by character.  The rate is
    ``(S + D + I) / (hits + S + D + I)``, which keeps it within [0, 1].

    Args:
        reference (str): Ground-truth sentence.
        hypothesis (str): Predicted sentence.
        remove_spaces (bool): Whether to strip spaces before comparing.
        remove_punctuation (bool): Whether to strip punctuation before comparing.

    Returns:
        dict: ``cer`` (rounded to 4 decimals) and the ``substitutions``,
        ``deletions`` (reference characters missing from the hypothesis) and
        ``insertions`` (extra hypothesis characters) counts.
    """
    ref_chars = list(preprocess_text(reference, remove_spaces, remove_punctuation))
    hyp_chars = list(preprocess_text(hypothesis, remove_spaces, remove_punctuation))

    # Align reference -> hypothesis.  calculate_levenshtein counts deletions on
    # its first argument, and the hit count below subtracts deletions from the
    # reference length, so the reference must come first.  With the arguments
    # swapped, hits can go negative and the rate can exceed 1.
    _, (substitutions, deletions, insertions) = calculate_levenshtein(ref_chars, hyp_chars)

    hits = len(ref_chars) - (substitutions + deletions)
    incorrect = substitutions + deletions + insertions
    total = hits + incorrect

    # Two empty inputs have nothing to get wrong.
    cer = round(incorrect / total, 4) if total > 0 else 0.0

    return {
        'cer': cer,
        'substitutions': substitutions,
        'deletions': deletions,
        'insertions': insertions,
    }
106
+
107
def calculate_korean_crr(reference, hypothesis, remove_spaces=True, remove_punctuation=True):
    """Compute the character recognition rate (accuracy) of a Korean sentence.

    CRR = 1 - CER, where the CER comes from ``calculate_korean_cer``.

    Args:
        reference (str): Ground-truth sentence.
        hypothesis (str): Predicted sentence.
        remove_spaces (bool): Whether to strip spaces before comparing.
        remove_punctuation (bool): Whether to strip punctuation before comparing.

    Returns:
        dict: ``crr`` plus the ``substitutions``, ``deletions`` and
        ``insertions`` counts reported by ``calculate_korean_cer``.
    """
    details = calculate_korean_cer(reference, hypothesis, remove_spaces, remove_punctuation)

    # The second argument of round() sets how many decimals are reported.
    report = {'crr': round(1 - details['cer'], 4)}
    for key in ('substitutions', 'deletions', 'insertions'):
        report[key] = details[key]

    return report
132
+
133
def transcribe_audio(file_path, model_name="daeunn/wav2vec2-korean-finetuned2"):
    """Transcribe one audio file with a Wav2Vec2 CTC model.

    Args:
        file_path (str): Path of the audio file to load with torchaudio.
        model_name (str): Name or path passed to ``from_pretrained`` for both
            the processor and the model.

    Returns:
        str: The decoded transcription.
    """
    # NOTE: the processor and model are loaded again on every call; callers
    # that transcribe many files may want to load them once instead.
    processor = Wav2Vec2Processor.from_pretrained(model_name)
    model = Wav2Vec2ForCTC.from_pretrained(model_name)

    waveform, sample_rate = torchaudio.load(file_path)

    # Downmix multi-channel audio to mono.  Otherwise squeeze() leaves a 2-D
    # array, the processor treats each channel as a separate utterance, and
    # only the first one is decoded below.
    if waveform.dim() > 1 and waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # The processor is called with sampling_rate=16000, so resample to match.
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
    input_values = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt", padding=True)

    # Inference: greedy (argmax) decoding of the logits.
    with torch.no_grad():
        logits = model(**input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.decode(predicted_ids[0])

    return transcription
151
+
152
+
153
if __name__ == "__main__":
    # Demo: score two recordings of the same sentence against one reference
    # and print how much the CRR changed between them.
    reference = "제가 스웨덴에서 왔고, 우리나라가 큰 나라이지만 인구가 좀 적어서 학생이라도 재밌게 할 수 있는게 많이 없고 카페나 술집이나 이런게 많이 없어서 그런 거 한국에 많이 있다고 들었고 그거 때문에 한국에 공부하러 왔어요."

    audio_path1 = "../data/stt_test.wav"
    audio_path2 = "../data/stt_test.wav"  # TODO: point this at the second recording

    # First recording
    print("\n[첫 번째 발화 STT 및 정확도 평가]")
    hypothesis1 = transcribe_audio(audio_path1)
    print("예측 1:", hypothesis1)

    crr_result1 = calculate_korean_crr(reference, hypothesis1)
    print(f"CRR #1: {crr_result1['crr']} (대체={crr_result1['substitutions']}, 삭제={crr_result1['deletions']}, 삽입={crr_result1['insertions']})")

    # Second recording
    print("\n[두 번째 발화 STT 및 정확도 평가]")
    hypothesis2 = transcribe_audio(audio_path2)
    print("예측 2:", hypothesis2)

    crr_result2 = calculate_korean_crr(reference, hypothesis2)
    print(f"CRR #2: {crr_result2['crr']} (대체={crr_result2['substitutions']}, 삭제={crr_result2['deletions']}, 삽입={crr_result2['insertions']})")

    # Change in accuracy between the two recordings
    diff = round(crr_result2['crr'] - crr_result1['crr'], 4)
    print(f"\n 동일 문장 재발화에 따른 CRR 변화량: {diff:+.4f}")