feat: 기존의 정확도 대신 CRR로 계산

#1
by daeunn - opened
Files changed (1) hide show
  1. app.py +29 -11
app.py CHANGED
@@ -5,6 +5,7 @@ import yaml
5
  import tempfile
6
  import os
7
  import traceback
 
8
  from model.wav2vec2 import Wav2Vec2
9
 
10
  # ---------------- 설정 로드 ----------------
@@ -25,36 +26,53 @@ app = FastAPI(
25
  class TranscriptionResponse(BaseModel):
26
  transcription: str
27
  status: str
 
28
 
29
  # ---------------- API: 파일 업로드 POST ----------------
30
  @app.post("/transcribe", response_model=TranscriptionResponse)
31
- async def transcribe_audio(file: UploadFile = File(...)):
32
- """오디오 파일을 업로드하여 음성 인식 수행"""
33
-
34
  # 파일 형식 검증
35
  if not file.filename.lower().endswith(('.wav', '.mp3', '.flac', '.m4a')):
36
  return TranscriptionResponse(
37
  transcription="",
38
- status="error: 지원되지 않는 파일 형식입니다. wav, mp3, flac, m4a 파일만 지원됩니다."
 
39
  )
40
-
41
  try:
42
- # 파일 내용 읽기
43
  audio_bytes = await file.read()
44
-
45
- # 음성 인식 수행
46
  result = wav2vec2_model.transcribe_from_bytes(audio_bytes, file.filename)
47
 
 
 
 
 
 
 
48
  return TranscriptionResponse(
49
  transcription=result,
50
- status="success"
 
51
  )
52
-
53
  except Exception as e:
54
  return TranscriptionResponse(
55
  transcription="",
56
- status=f"error: {str(e)}"
 
57
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  # ---------------- HTML UI ----------------
60
  @app.get("/", response_class=HTMLResponse)
 
5
  import tempfile
6
  import os
7
  import traceback
8
+ from model.cer_module import calculate_korean_crr
9
  from model.wav2vec2 import Wav2Vec2
10
 
11
  # ---------------- 설정 로드 ----------------
 
26
  class TranscriptionResponse(BaseModel):
27
  transcription: str
28
  status: str
29
+ crr: float = None # CRR 값, 선택적 필드
30
 
31
  # ---------------- API: 파일 업로드 POST ----------------
32
  @app.post("/transcribe", response_model=TranscriptionResponse)
33
+ async def transcribe_audio(file: UploadFile = File(...), reference: str = None):
 
 
34
  # 파일 형식 검증
35
  if not file.filename.lower().endswith(('.wav', '.mp3', '.flac', '.m4a')):
36
  return TranscriptionResponse(
37
  transcription="",
38
+ status="error: 지원되지 않는 파일 형식입니다. wav, mp3, flac, m4a 파일만 지원됩니다.",
39
+ crr=None
40
  )
 
41
  try:
 
42
  audio_bytes = await file.read()
 
 
43
  result = wav2vec2_model.transcribe_from_bytes(audio_bytes, file.filename)
44
 
45
+ # reference가 전달된 경우 CRR 계산
46
+ crr = None
47
+ if reference:
48
+ crr_result = calculate_korean_crr(reference, result)
49
+ crr = crr_result['crr']
50
+
51
  return TranscriptionResponse(
52
  transcription=result,
53
+ status="success",
54
+ crr=crr
55
  )
 
56
  except Exception as e:
57
  return TranscriptionResponse(
58
  transcription="",
59
+ status=f"error: {str(e)}",
60
+ crr=None
61
  )
62
+ class CRRRequest(BaseModel):
63
+ original: str
64
+ corrected: str
65
+
66
+ class CRRResponse(BaseModel):
67
+ crr: float
68
+
69
+ @app.post("/calculate-crr", response_model=CRRResponse)
70
+ async def calculate_crr_api(data: CRRRequest):
71
+ """
72
+ 두 문장(original, corrected)을 받아 CRR(정확도)만 계산해서 반환
73
+ """
74
+ result = calculate_korean_crr(data.original, data.corrected)
75
+ return CRRResponse(crr=result['crr'])
76
 
77
  # ---------------- HTML UI ----------------
78
  @app.get("/", response_class=HTMLResponse)