wav2vec2-server / app.py
bigeco's picture
feat: 기존의 정확도 대신 CRR로 계산 (#1)
23ecab6 verified
raw
history blame
9.14 kB
import html
import os
import tempfile
import traceback
from typing import Optional

import yaml
from fastapi import FastAPI, Request, File, UploadFile, Form
from fastapi.responses import HTMLResponse
from pydantic import BaseModel

from model.cer_module import calculate_korean_crr
from model.wav2vec2 import Wav2Vec2
# ---------------- Load configuration ----------------
# The encoding is pinned so the YAML file is decoded the same way on every
# platform; without it open() falls back to the locale's preferred encoding.
with open("config/wav2vec2.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# ---------------- Initialize model ----------------
# Built once at import time from the loaded configuration; the request
# handlers in this module use this single instance.
wav2vec2_model = Wav2Vec2(config)
# ---------------- FastAPI app ----------------
# Application object on which every route in this module is registered.
app = FastAPI(
    version="1.0.0",
    title="Korean Speech Recognition API",
    description="FastAPI + Wav2Vec2 기반 ν•œκ΅­μ–΄ μŒμ„± 인식 μ„œλ²„",
)
# ---------------- Response model ----------------
class TranscriptionResponse(BaseModel):
    """Response body returned by the /transcribe endpoint."""

    # Recognized text; the handler sends an empty string on failure.
    transcription: str
    # "success", or a message starting with "error: " describing the failure.
    status: str
    # CRR value, only populated when a reference text was supplied.
    # Declared Optional because the handler passes crr=None explicitly; a
    # plain `float` annotation is not nullable under Pydantic v2 validation.
    crr: Optional[float] = None
# ---------------- API: file upload POST ----------------
@app.post("/transcribe", response_model=TranscriptionResponse)
async def transcribe_audio(file: UploadFile = File(...), reference: Optional[str] = None):
    """Transcribe an uploaded audio file and optionally score it.

    Args:
        file: Uploaded audio file. Only the filename extension is checked
            (wav, mp3, flac, m4a).
        reference: Optional reference text. When non-empty, the transcription
            is passed with it to ``calculate_korean_crr`` and the ``'crr'``
            entry of that result is returned.

    Returns:
        A ``TranscriptionResponse``. Failures are reported in-band: ``status``
        starts with ``"error: "`` and ``transcription`` is empty.
    """
    # The upload may carry no filename; treat that as an unsupported format
    # instead of raising AttributeError on None before the try block.
    filename = file.filename or ""

    # File format validation (extension only).
    if not filename.lower().endswith(('.wav', '.mp3', '.flac', '.m4a')):
        return TranscriptionResponse(
            transcription="",
            status="error: μ§€μ›λ˜μ§€ μ•ŠλŠ” 파일 ν˜•μ‹μž…λ‹ˆλ‹€. wav, mp3, flac, m4a 파일만 μ§€μ›λ©λ‹ˆλ‹€.",
            crr=None
        )

    try:
        audio_bytes = await file.read()
        # NOTE(review): this call runs directly inside the coroutine; if it is
        # CPU-bound it will stall the event loop - consider a threadpool.
        result = wav2vec2_model.transcribe_from_bytes(audio_bytes, filename)

        # Compute the CRR only when a reference text was provided.
        crr = None
        if reference:
            crr_result = calculate_korean_crr(reference, result)
            crr = crr_result['crr']

        return TranscriptionResponse(
            transcription=result,
            status="success",
            crr=crr
        )
    except Exception as e:
        # Deliberate catch-all: API clients receive the failure in the
        # response body rather than as an HTTP error.
        return TranscriptionResponse(
            transcription="",
            status=f"error: {str(e)}",
            crr=None
        )
class CRRRequest(BaseModel):
    """Request body for the /calculate-crr endpoint."""

    # Passed as the first argument to calculate_korean_crr.
    original: str
    # Passed as the second argument to calculate_korean_crr.
    corrected: str
class CRRResponse(BaseModel):
    """Response body for the /calculate-crr endpoint."""

    # The 'crr' entry of the result returned by calculate_korean_crr.
    crr: float
@app.post("/calculate-crr", response_model=CRRResponse)
async def calculate_crr_api(data: CRRRequest):
    """Compute the CRR for a pair of sentences and return only that value.

    Args:
        data: Request body carrying the ``original`` and ``corrected`` texts.

    Returns:
        A ``CRRResponse`` holding the ``'crr'`` entry of the result produced
        by ``calculate_korean_crr``.
    """
    score = calculate_korean_crr(data.original, data.corrected)['crr']
    return CRRResponse(crr=score)
# ---------------- HTML UI ----------------
@app.get("/", response_class=HTMLResponse)
async def main_ui():
    """Serve the landing page: a static upload form that posts to /submit."""
    # The markup is static, so it is kept as one plain (non-f) string literal;
    # the CSS braces below are therefore literal characters.
    page_markup = """
<html>
<head>
<title>Korean Speech Recognition</title>
<meta charset="UTF-8">
<style>
body {
font-family: Arial, sans-serif;
max-width: 800px;
margin: auto;
padding: 2rem;
background-color: #f5f5f5;
}
.container {
background-color: white;
padding: 2rem;
border-radius: 10px;
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
}
.form-group {
margin-bottom: 1.5rem;
}
label {
display: block;
margin-bottom: 0.5rem;
font-weight: bold;
color: #333;
}
input[type="file"] {
padding: 0.5rem;
border: 2px dashed #ccc;
border-radius: 5px;
width: 100%;
box-sizing: border-box;
}
input[type="submit"] {
background-color: #007bff;
color: white;
padding: 1rem 2rem;
border: none;
border-radius: 5px;
cursor: pointer;
font-size: 1rem;
}
input[type="submit"]:hover {
background-color: #0056b3;
}
.info {
background-color: #e7f3ff;
padding: 1rem;
border-radius: 5px;
margin-bottom: 1rem;
border-left: 4px solid #007bff;
}
</style>
</head>
<body>
<div class="container">
<h1>🎀 ν•œκ΅­μ–΄ μŒμ„± 인식</h1>
<div class="info">
<strong>지원 ν˜•μ‹:</strong> WAV, MP3, FLAC, M4A<br>
<strong>λͺ¨λΈ:</strong> Wav2Vec2 Korean Fine-tuned
</div>
<form action="/submit" method="post" enctype="multipart/form-data">
<div class="form-group">
<label for="audio_file">🎡 μ˜€λ””μ˜€ 파일 선택:</label>
<input type="file" id="audio_file" name="audio_file" accept=".wav,.mp3,.flac,.m4a" required>
</div>
<input type="submit" value="μŒμ„± 인식 μ‹€ν–‰">
</form>
</div>
</body>
</html>
"""
    return page_markup
# ---------------- Result rendering ----------------
@app.post("/submit", response_class=HTMLResponse)
async def handle_form(request: Request, audio_file: UploadFile = File(...)):
    """Handle the HTML upload form and render the transcription as a page.

    Args:
        request: Incoming request; not used by the handler body.
        audio_file: Uploaded audio file from the form field ``audio_file``.
            Only the filename extension is checked (wav, mp3, flac, m4a).

    Returns:
        An HTML page: the recognition result, a format-error page, or an
        error page showing the exception message and traceback.
    """
    # The upload may carry no filename; fall back to "" so it is reported as
    # an unsupported format rather than failing on None.
    filename = audio_file.filename or ""

    # File format validation (extension only).
    if not filename.lower().endswith(('.wav', '.mp3', '.flac', '.m4a')):
        return """
<html>
<head><title>μ—λŸ¬</title><meta charset="UTF-8"></head>
<body style="font-family: Arial, sans-serif; max-width: 600px; margin: auto; padding: 2rem;">
<h1>❌ 파일 ν˜•μ‹ 였λ₯˜</h1>
<p>μ§€μ›λ˜μ§€ μ•ŠλŠ” 파일 ν˜•μ‹μž…λ‹ˆλ‹€.</p>
<p><strong>지원 ν˜•μ‹:</strong> WAV, MP3, FLAC, M4A</p>
<br>
<a href="/" style="color: #007bff; text-decoration: none;">← λŒμ•„κ°€κΈ°</a>
</body>
</html>
"""

    try:
        # Read the file contents.
        audio_bytes = await audio_file.read()
        # Run speech recognition.
        result = wav2vec2_model.transcribe_from_bytes(audio_bytes, filename)
    except Exception as e:
        # NOTE(review): the full traceback is rendered to the client, which
        # exposes server internals; kept as-is, but now HTML-escaped.
        error_message = html.escape(str(e))
        error_details = html.escape(traceback.format_exc())
        return f"""
<html>
<head><title>μ—λŸ¬</title><meta charset="UTF-8"></head>
<body style="font-family: Arial, sans-serif; max-width: 600px; margin: auto; padding: 2rem;">
<h1>❌ μ„œλ²„ 였λ₯˜ λ°œμƒ</h1>
<p><strong>였λ₯˜ λ©”μ‹œμ§€:</strong></p>
<pre style="background-color: #f8f9fa; padding: 1rem; border-radius: 5px; overflow-x: auto;">{error_message}</pre>
<hr>
<details>
<summary><strong>μ—λŸ¬ 상세 (ν΄λ¦­ν•˜μ—¬ 펼치기)</strong></summary>
<pre style="background-color: #f8f9fa; padding: 1rem; border-radius: 5px; overflow-x: auto;">{error_details}</pre>
</details>
<br>
<a href="/" style="color: #007bff; text-decoration: none;">← λŒμ•„κ°€κΈ°</a>
</body>
</html>
"""

    # The filename is client-controlled and the transcription is model
    # output; both are escaped so they cannot inject markup into the page.
    safe_filename = html.escape(filename)
    safe_result = html.escape(str(result))
    return f"""
<html>
<head><title>κ²°κ³Ό</title><meta charset="UTF-8"></head>
<body style="font-family: Arial, sans-serif; max-width: 600px; margin: auto; padding: 2rem;">
<h1>βœ… μŒμ„± 인식 κ²°κ³Ό</h1>
<div style="background-color: #f8f9fa; padding: 1rem; border-radius: 5px; margin: 1rem 0;">
<p><strong>μ—…λ‘œλ“œλœ 파일:</strong> {safe_filename}</p>
<p><strong>파일 크기:</strong> {len(audio_bytes):,} bytes</p>
</div>
<hr>
<h2>🎯 μΈμ‹λœ ν…μŠ€νŠΈ:</h2>
<div style="background-color: #e7f3ff; padding: 1.5rem; border-radius: 5px; border-left: 4px solid #007bff;">
<pre style="font-size: 1.1rem; margin: 0; white-space: pre-wrap; word-wrap: break-word;">{safe_result}</pre>
</div>
<br>
<a href="/" style="color: #007bff; text-decoration: none;">← λ‹€μ‹œ μ‹œλ„ν•˜κΈ°</a>
</body>
</html>
"""
# ---------------- Health check ----------------
@app.get("/health")
async def health_check():
    """Return a fixed "ok" status plus the model settings read from config."""
    model_cfg = config["model"]
    payload = {"status": "ok", "model": model_cfg["id"]}
    # Remaining keys are copied through under their config names.
    for key in ("device", "sampling_rate"):
        payload[key] = model_cfg[key]
    return payload
# ---------------- Model info ----------------
@app.get("/info")
async def model_info():
    """Return the configured model settings and the accepted audio formats."""
    model_cfg = config["model"]
    details = {
        "model_id": model_cfg["id"],
        "device": model_cfg["device"],
        "sampling_rate": model_cfg["sampling_rate"],
    }
    # Static metadata that does not come from the config file.
    details["supported_formats"] = ["wav", "mp3", "flac", "m4a"]
    details["description"] = "Korean Speech Recognition using Wav2Vec2"
    return details