bichnhan2701 committed on
Commit
b55d9e8
·
1 Parent(s): bd471fc

Update model version

Browse files
app/api/transcribe.py CHANGED
@@ -10,6 +10,7 @@ from fastapi.responses import JSONResponse
10
  from rq import Queue, Retry
11
 
12
  from app.config import settings
 
13
  from app.infra.redis_client import redis_client
14
  from app.infra.metrics import REQUEST_COUNT, REQUEST_LATENCY, ASR_DURATION
15
  from app.schemas.transcribe import TranscribeResponse
@@ -126,7 +127,7 @@ async def _run_sync_pipeline(tmp_wav: str, note_id: str, audio_url: str | None =
126
  "duration": info.get("duration"),
127
  "sample_rate": info.get("samplerate"),
128
  "chunks": chunks,
129
- "asr_model": "PhoWhisper-base",
130
  }
131
  },
132
  "generate": ["normalize", "keywords", "summary", "mindmap"],
@@ -156,7 +157,7 @@ async def _create_placeholder_note(note_id: str, duration: float, audio_url: str
156
  "audio": {
157
  "duration": duration,
158
  "chunks": [],
159
- "asr_model": "PhoWhisper-base",
160
  }
161
  },
162
  # ❌ KHÔNG generate ở đây
 
10
  from rq import Queue, Retry
11
 
12
  from app.config import settings
13
+ from app.config.settings import MODEL_NAME
14
  from app.infra.redis_client import redis_client
15
  from app.infra.metrics import REQUEST_COUNT, REQUEST_LATENCY, ASR_DURATION
16
  from app.schemas.transcribe import TranscribeResponse
 
127
  "duration": info.get("duration"),
128
  "sample_rate": info.get("samplerate"),
129
  "chunks": chunks,
130
+ "asr_model": MODEL_NAME,
131
  }
132
  },
133
  "generate": ["normalize", "keywords", "summary", "mindmap"],
 
157
  "audio": {
158
  "duration": duration,
159
  "chunks": [],
160
+ "asr_model": MODEL_NAME,
161
  }
162
  },
163
  # ❌ KHÔNG generate ở đây
app/config/settings.py CHANGED
@@ -2,7 +2,8 @@ import os
2
 
3
  MAX_UPLOAD_BYTES = int(os.getenv("MAX_UPLOAD_BYTES", 100 * 1024 * 1024))
4
  MAX_DURATION_SECS = int(os.getenv("MAX_DURATION_SECS", 60 * 60))
5
- MODEL_NAME = os.getenv("MODEL_NAME", "vinai/PhoWhisper-base")
 
6
 
7
  TMP_DIR = os.getenv("TMP_DIR", "/tmp/uploads")
8
  os.makedirs(TMP_DIR, exist_ok=True)
 
2
 
3
  MAX_UPLOAD_BYTES = int(os.getenv("MAX_UPLOAD_BYTES", 100 * 1024 * 1024))
4
  MAX_DURATION_SECS = int(os.getenv("MAX_DURATION_SECS", 60 * 60))
5
+ # MODEL_NAME = os.getenv("MODEL_NAME", "vinai/PhoWhisper-base")
6
+ MODEL_NAME = os.getenv("MODEL_NAME", "vinai/PhoWhisper-small")
7
 
8
  TMP_DIR = os.getenv("TMP_DIR", "/tmp/uploads")
9
  os.makedirs(TMP_DIR, exist_ok=True)
app/core/asr_engine.py CHANGED
@@ -10,6 +10,7 @@ import os
10
 
11
  from app.core.chunking import split_audio_to_chunks
12
  from app.core.audio_utils import get_audio_info
 
13
 
14
  logger = logging.getLogger(__name__)
15
 
@@ -28,7 +29,7 @@ def load_model(chunk_length_s: float = 30.0):
28
  if _ASR_MODEL is not None:
29
  return _ASR_MODEL
30
 
31
- logger.info("Loading ASR model PhoWhisper-base")
32
 
33
  device = 0 if torch.cuda.is_available() else -1
34
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
@@ -43,7 +44,7 @@ def load_model(chunk_length_s: float = 30.0):
43
 
44
  _ASR_MODEL = pipeline(
45
  task="automatic-speech-recognition",
46
- model="vinai/PhoWhisper-base",
47
  device=device,
48
  dtype=dtype,
49
  chunk_length_s=chunk_length_s,
 
10
 
11
  from app.core.chunking import split_audio_to_chunks
12
  from app.core.audio_utils import get_audio_info
13
+ from app.config.settings import MODEL_NAME
14
 
15
  logger = logging.getLogger(__name__)
16
 
 
29
  if _ASR_MODEL is not None:
30
  return _ASR_MODEL
31
 
32
+ logger.info("Loading ASR model %s", MODEL_NAME)
33
 
34
  device = 0 if torch.cuda.is_available() else -1
35
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 
44
 
45
  _ASR_MODEL = pipeline(
46
  task="automatic-speech-recognition",
47
+ model=MODEL_NAME,
48
  device=device,
49
  dtype=dtype,
50
  chunk_length_s=chunk_length_s,
app/jobs/transcribe_job.py CHANGED
@@ -7,6 +7,7 @@ import httpx
7
  import time
8
 
9
  from app.core.asr_engine import load_model, transcribe_file, transcribe_file_chunks, transcribe_file_unified
 
10
  from app.services.note_client import NoteServiceClient
11
  from app.core.audio_utils import get_audio_info
12
  from app.core.audio_utils import ensure_wav_16k_mono, make_temp_path
@@ -103,7 +104,7 @@ def transcribe_job(audio_url: str, note_id: str, user_id: str | None = None):
103
  "duration": info.get("duration"),
104
  "sample_rate": info.get("samplerate"),
105
  "chunks": chunks,
106
- "asr_model": "PhoWhisper-base",
107
  },
108
  "client": {"user_id": user_id},
109
  },
 
7
  import time
8
 
9
  from app.core.asr_engine import load_model, transcribe_file, transcribe_file_chunks, transcribe_file_unified
10
+ from app.config.settings import MODEL_NAME
11
  from app.services.note_client import NoteServiceClient
12
  from app.core.audio_utils import get_audio_info
13
  from app.core.audio_utils import ensure_wav_16k_mono, make_temp_path
 
104
  "duration": info.get("duration"),
105
  "sample_rate": info.get("samplerate"),
106
  "chunks": chunks,
107
+ "asr_model": MODEL_NAME,
108
  },
109
  "client": {"user_id": user_id},
110
  },