Spaces:
Sleeping
Sleeping
Commit ·
b55d9e8
1
Parent(s): bd471fc
Update model version
Browse files
- app/api/transcribe.py +3 -2
- app/config/settings.py +2 -1
- app/core/asr_engine.py +3 -2
- app/jobs/transcribe_job.py +2 -1
app/api/transcribe.py
CHANGED
|
@@ -10,6 +10,7 @@ from fastapi.responses import JSONResponse
|
|
| 10 |
from rq import Queue, Retry
|
| 11 |
|
| 12 |
from app.config import settings
|
|
|
|
| 13 |
from app.infra.redis_client import redis_client
|
| 14 |
from app.infra.metrics import REQUEST_COUNT, REQUEST_LATENCY, ASR_DURATION
|
| 15 |
from app.schemas.transcribe import TranscribeResponse
|
|
@@ -126,7 +127,7 @@ async def _run_sync_pipeline(tmp_wav: str, note_id: str, audio_url: str | None =
|
|
| 126 |
"duration": info.get("duration"),
|
| 127 |
"sample_rate": info.get("samplerate"),
|
| 128 |
"chunks": chunks,
|
| 129 |
-
"asr_model":
|
| 130 |
}
|
| 131 |
},
|
| 132 |
"generate": ["normalize", "keywords", "summary", "mindmap"],
|
|
@@ -156,7 +157,7 @@ async def _create_placeholder_note(note_id: str, duration: float, audio_url: str
|
|
| 156 |
"audio": {
|
| 157 |
"duration": duration,
|
| 158 |
"chunks": [],
|
| 159 |
-
"asr_model":
|
| 160 |
}
|
| 161 |
},
|
| 162 |
# ❌ KHÔNG generate ở đây
|
|
|
|
| 10 |
from rq import Queue, Retry
|
| 11 |
|
| 12 |
from app.config import settings
|
| 13 |
+
from app.config.settings import MODEL_NAME
|
| 14 |
from app.infra.redis_client import redis_client
|
| 15 |
from app.infra.metrics import REQUEST_COUNT, REQUEST_LATENCY, ASR_DURATION
|
| 16 |
from app.schemas.transcribe import TranscribeResponse
|
|
|
|
| 127 |
"duration": info.get("duration"),
|
| 128 |
"sample_rate": info.get("samplerate"),
|
| 129 |
"chunks": chunks,
|
| 130 |
+
"asr_model": MODEL_NAME,
|
| 131 |
}
|
| 132 |
},
|
| 133 |
"generate": ["normalize", "keywords", "summary", "mindmap"],
|
|
|
|
| 157 |
"audio": {
|
| 158 |
"duration": duration,
|
| 159 |
"chunks": [],
|
| 160 |
+
"asr_model": MODEL_NAME,
|
| 161 |
}
|
| 162 |
},
|
| 163 |
# ❌ KHÔNG generate ở đây
|
app/config/settings.py
CHANGED
|
@@ -2,7 +2,8 @@ import os
|
|
| 2 |
|
| 3 |
MAX_UPLOAD_BYTES = int(os.getenv("MAX_UPLOAD_BYTES", 100 * 1024 * 1024))
|
| 4 |
MAX_DURATION_SECS = int(os.getenv("MAX_DURATION_SECS", 60 * 60))
|
| 5 |
-
MODEL_NAME = os.getenv("MODEL_NAME", "vinai/PhoWhisper-base")
|
|
|
|
| 6 |
|
| 7 |
TMP_DIR = os.getenv("TMP_DIR", "/tmp/uploads")
|
| 8 |
os.makedirs(TMP_DIR, exist_ok=True)
|
|
|
|
| 2 |
|
| 3 |
MAX_UPLOAD_BYTES = int(os.getenv("MAX_UPLOAD_BYTES", 100 * 1024 * 1024))
|
| 4 |
MAX_DURATION_SECS = int(os.getenv("MAX_DURATION_SECS", 60 * 60))
|
| 5 |
+
# MODEL_NAME = os.getenv("MODEL_NAME", "vinai/PhoWhisper-base")
|
| 6 |
+
MODEL_NAME = os.getenv("MODEL_NAME", "vinai/PhoWhisper-small")
|
| 7 |
|
| 8 |
TMP_DIR = os.getenv("TMP_DIR", "/tmp/uploads")
|
| 9 |
os.makedirs(TMP_DIR, exist_ok=True)
|
app/core/asr_engine.py
CHANGED
|
@@ -10,6 +10,7 @@ import os
|
|
| 10 |
|
| 11 |
from app.core.chunking import split_audio_to_chunks
|
| 12 |
from app.core.audio_utils import get_audio_info
|
|
|
|
| 13 |
|
| 14 |
logger = logging.getLogger(__name__)
|
| 15 |
|
|
@@ -28,7 +29,7 @@ def load_model(chunk_length_s: float = 30.0):
|
|
| 28 |
if _ASR_MODEL is not None:
|
| 29 |
return _ASR_MODEL
|
| 30 |
|
| 31 |
-
logger.info("Loading ASR model
|
| 32 |
|
| 33 |
device = 0 if torch.cuda.is_available() else -1
|
| 34 |
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
|
@@ -43,7 +44,7 @@ def load_model(chunk_length_s: float = 30.0):
|
|
| 43 |
|
| 44 |
_ASR_MODEL = pipeline(
|
| 45 |
task="automatic-speech-recognition",
|
| 46 |
-
model=
|
| 47 |
device=device,
|
| 48 |
dtype=dtype,
|
| 49 |
chunk_length_s=chunk_length_s,
|
|
|
|
| 10 |
|
| 11 |
from app.core.chunking import split_audio_to_chunks
|
| 12 |
from app.core.audio_utils import get_audio_info
|
| 13 |
+
from app.config.settings import MODEL_NAME
|
| 14 |
|
| 15 |
logger = logging.getLogger(__name__)
|
| 16 |
|
|
|
|
| 29 |
if _ASR_MODEL is not None:
|
| 30 |
return _ASR_MODEL
|
| 31 |
|
| 32 |
+
logger.info("Loading ASR model %s", MODEL_NAME)
|
| 33 |
|
| 34 |
device = 0 if torch.cuda.is_available() else -1
|
| 35 |
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
|
|
|
| 44 |
|
| 45 |
_ASR_MODEL = pipeline(
|
| 46 |
task="automatic-speech-recognition",
|
| 47 |
+
model=MODEL_NAME,
|
| 48 |
device=device,
|
| 49 |
dtype=dtype,
|
| 50 |
chunk_length_s=chunk_length_s,
|
app/jobs/transcribe_job.py
CHANGED
|
@@ -7,6 +7,7 @@ import httpx
|
|
| 7 |
import time
|
| 8 |
|
| 9 |
from app.core.asr_engine import load_model, transcribe_file, transcribe_file_chunks, transcribe_file_unified
|
|
|
|
| 10 |
from app.services.note_client import NoteServiceClient
|
| 11 |
from app.core.audio_utils import get_audio_info
|
| 12 |
from app.core.audio_utils import ensure_wav_16k_mono, make_temp_path
|
|
@@ -103,7 +104,7 @@ def transcribe_job(audio_url: str, note_id: str, user_id: str | None = None):
|
|
| 103 |
"duration": info.get("duration"),
|
| 104 |
"sample_rate": info.get("samplerate"),
|
| 105 |
"chunks": chunks,
|
| 106 |
-
"asr_model":
|
| 107 |
},
|
| 108 |
"client": {"user_id": user_id},
|
| 109 |
},
|
|
|
|
| 7 |
import time
|
| 8 |
|
| 9 |
from app.core.asr_engine import load_model, transcribe_file, transcribe_file_chunks, transcribe_file_unified
|
| 10 |
+
from app.config.settings import MODEL_NAME
|
| 11 |
from app.services.note_client import NoteServiceClient
|
| 12 |
from app.core.audio_utils import get_audio_info
|
| 13 |
from app.core.audio_utils import ensure_wav_16k_mono, make_temp_path
|
|
|
|
| 104 |
"duration": info.get("duration"),
|
| 105 |
"sample_rate": info.get("samplerate"),
|
| 106 |
"chunks": chunks,
|
| 107 |
+
"asr_model": MODEL_NAME,
|
| 108 |
},
|
| 109 |
"client": {"user_id": user_id},
|
| 110 |
},
|