Spaces:
Running
Running
Ved Gupta
commited on
Commit
·
5ece346
1
Parent(s):
022e710
model parameter added
Browse files- app/api/endpoints/transcribe.py +8 -3
- app/utils/constant.py +5 -0
- app/utils/utils.py +12 -0
app/api/endpoints/transcribe.py
CHANGED
|
@@ -6,10 +6,14 @@ from pydantic import BaseModel
|
|
| 6 |
|
| 7 |
from app.core.database import SessionLocal
|
| 8 |
|
| 9 |
-
from app.utils.utils import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
from app.core.models import AuthTokenController, TranscribeController
|
| 11 |
|
| 12 |
-
|
| 13 |
router = APIRouter()
|
| 14 |
database = SessionLocal()
|
| 15 |
|
|
@@ -24,12 +28,13 @@ async def post_audio(
|
|
| 24 |
background_tasks: BackgroundTasks,
|
| 25 |
request: Request,
|
| 26 |
file: UploadFile = File(...),
|
|
|
|
| 27 |
Authentication: Annotated[Union[str, None], Header()] = None,
|
| 28 |
):
|
| 29 |
try:
|
| 30 |
userId = AuthTokenController(database).get_userid_from_token(Authentication)
|
| 31 |
file_path = save_audio_file(file)
|
| 32 |
-
[data, output_audio_path] = transcribe_file(file_path)
|
| 33 |
background_tasks.add_task(
|
| 34 |
create_transcribe_record, database, userId, data, output_audio_path
|
| 35 |
)
|
|
|
|
| 6 |
|
| 7 |
from app.core.database import SessionLocal
|
| 8 |
|
| 9 |
+
from app.utils.utils import (
|
| 10 |
+
save_audio_file,
|
| 11 |
+
transcribe_file,
|
| 12 |
+
get_audio_duration,
|
| 13 |
+
get_model_name,
|
| 14 |
+
)
|
| 15 |
from app.core.models import AuthTokenController, TranscribeController
|
| 16 |
|
|
|
|
| 17 |
router = APIRouter()
|
| 18 |
database = SessionLocal()
|
| 19 |
|
|
|
|
| 28 |
background_tasks: BackgroundTasks,
|
| 29 |
request: Request,
|
| 30 |
file: UploadFile = File(...),
|
| 31 |
+
model: str = "tiny.en.q5",
|
| 32 |
Authentication: Annotated[Union[str, None], Header()] = None,
|
| 33 |
):
|
| 34 |
try:
|
| 35 |
userId = AuthTokenController(database).get_userid_from_token(Authentication)
|
| 36 |
file_path = save_audio_file(file)
|
| 37 |
+
[data, output_audio_path] = transcribe_file(file_path, get_model_name(model))
|
| 38 |
background_tasks.add_task(
|
| 39 |
create_transcribe_record, database, userId, data, output_audio_path
|
| 40 |
)
|
app/utils/constant.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model_names = {
|
| 2 |
+
"tiny.en": "ggml-tiny.en.bin",
|
| 3 |
+
"tiny.en.q5": "ggml-model-whisper-tiny.en-q5_1.bin",
|
| 4 |
+
"base.en.q5": "ggml-model-whisper-base.en-q5_1.bin",
|
| 5 |
+
}
|
app/utils/utils.py
CHANGED
|
@@ -4,6 +4,8 @@ import uuid
|
|
| 4 |
import logging
|
| 5 |
import wave
|
| 6 |
|
|
|
|
|
|
|
| 7 |
|
| 8 |
def get_all_routes(app):
|
| 9 |
routes = []
|
|
@@ -88,3 +90,13 @@ def get_audio_duration(audio_file):
|
|
| 88 |
rounded_duration = int(round(duration, 0))
|
| 89 |
|
| 90 |
return rounded_duration
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
import logging
|
| 5 |
import wave
|
| 6 |
|
| 7 |
+
from .constant import model_names
|
| 8 |
+
|
| 9 |
|
| 10 |
def get_all_routes(app):
|
| 11 |
routes = []
|
|
|
|
| 90 |
rounded_duration = int(round(duration, 0))
|
| 91 |
|
| 92 |
return rounded_duration
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def get_model_name(model: str = None):
|
| 96 |
+
if model is None:
|
| 97 |
+
model_names["tiny.en.q5"]
|
| 98 |
+
|
| 99 |
+
if model in model_names.keys():
|
| 100 |
+
return model_names[model]
|
| 101 |
+
|
| 102 |
+
return model_names["tiny.en.q5"]
|