Ved Gupta committed on
Commit
5ece346
·
1 Parent(s): 022e710

model parameter added

Browse files
app/api/endpoints/transcribe.py CHANGED
@@ -6,10 +6,14 @@ from pydantic import BaseModel
6
 
7
  from app.core.database import SessionLocal
8
 
9
- from app.utils.utils import save_audio_file, transcribe_file, get_audio_duration
 
 
 
 
 
10
  from app.core.models import AuthTokenController, TranscribeController
11
 
12
-
13
  router = APIRouter()
14
  database = SessionLocal()
15
 
@@ -24,12 +28,13 @@ async def post_audio(
24
  background_tasks: BackgroundTasks,
25
  request: Request,
26
  file: UploadFile = File(...),
 
27
  Authentication: Annotated[Union[str, None], Header()] = None,
28
  ):
29
  try:
30
  userId = AuthTokenController(database).get_userid_from_token(Authentication)
31
  file_path = save_audio_file(file)
32
- [data, output_audio_path] = transcribe_file(file_path)
33
  background_tasks.add_task(
34
  create_transcribe_record, database, userId, data, output_audio_path
35
  )
 
6
 
7
  from app.core.database import SessionLocal
8
 
9
+ from app.utils.utils import (
10
+ save_audio_file,
11
+ transcribe_file,
12
+ get_audio_duration,
13
+ get_model_name,
14
+ )
15
  from app.core.models import AuthTokenController, TranscribeController
16
 
 
17
  router = APIRouter()
18
  database = SessionLocal()
19
 
 
28
  background_tasks: BackgroundTasks,
29
  request: Request,
30
  file: UploadFile = File(...),
31
+ model: str = "tiny.en.q5",
32
  Authentication: Annotated[Union[str, None], Header()] = None,
33
  ):
34
  try:
35
  userId = AuthTokenController(database).get_userid_from_token(Authentication)
36
  file_path = save_audio_file(file)
37
+ [data, output_audio_path] = transcribe_file(file_path, get_model_name(model))
38
  background_tasks.add_task(
39
  create_transcribe_record, database, userId, data, output_audio_path
40
  )
app/utils/constant.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ model_names = {
2
+ "tiny.en": "ggml-tiny.en.bin",
3
+ "tiny.en.q5": "ggml-model-whisper-tiny.en-q5_1.bin",
4
+ "base.en.q5": "ggml-model-whisper-base.en-q5_1.bin",
5
+ }
app/utils/utils.py CHANGED
@@ -4,6 +4,8 @@ import uuid
4
  import logging
5
  import wave
6
 
 
 
7
 
8
  def get_all_routes(app):
9
  routes = []
@@ -88,3 +90,13 @@ def get_audio_duration(audio_file):
88
  rounded_duration = int(round(duration, 0))
89
 
90
  return rounded_duration
 
 
 
 
 
 
 
 
 
 
 
4
  import logging
5
  import wave
6
 
7
+ from .constant import model_names
8
+
9
 
10
  def get_all_routes(app):
11
  routes = []
 
90
  rounded_duration = int(round(duration, 0))
91
 
92
  return rounded_duration
93
+
94
+
95
def get_model_name(model: str = None):
    """Resolve a short model identifier to its entry in ``model_names``.

    Args:
        model: Key to look up in ``model_names`` (for example ``"tiny.en"``).
            ``None`` or an unrecognised key selects the default entry.

    Returns:
        The value stored in ``model_names`` for ``model``, or the value
        stored under ``"tiny.en.q5"`` when ``model`` is ``None`` or is not
        a known key.
    """
    default_name = model_names["tiny.en.q5"]

    if model is None:
        # The previous version evaluated the default here but never returned
        # it, relying on fall-through to the final return; make it explicit.
        return default_name

    # Single lookup with a fallback instead of a membership test + indexing.
    return model_names.get(model, default_name)