# app/core/asr_engine.py
import logging
import os
import warnings
from typing import Dict, List, Tuple

import torch
from transformers import pipeline
from transformers import logging as transformers_logging

from app.core.audio_utils import get_audio_info
from app.core.chunking import split_audio_to_chunks
logger = logging.getLogger(__name__)
# ===============================
# Global model cache
# ===============================
_ASR_MODEL = None
def load_model(chunk_length_s: float = 30.0):
"""
Load ASR model once and reuse.
Safe to call multiple times.
"""
global _ASR_MODEL
if _ASR_MODEL is not None:
return _ASR_MODEL
logger.info("Loading ASR model PhoWhisper-base")
    use_cuda = torch.cuda.is_available()
    device = 0 if use_cuda else -1
    dtype = torch.float16 if use_cuda else torch.float32
# Reduce noisy transformer logs and warnings about experimental chunking
try:
transformers_logging.set_verbosity_error()
except Exception:
pass
# filter the noisy chunk_length_s warnings (regex)
warnings.filterwarnings("ignore", message=r".*chunk_length_s.*")
_ASR_MODEL = pipeline(
task="automatic-speech-recognition",
model="vinai/PhoWhisper-base",
device=device,
dtype=dtype,
chunk_length_s=chunk_length_s,
return_timestamps=True,
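        # ignore_warning=True suppresses the "chunk_length_s is very
        # experimental" warning the ASR pipeline emits for seq2seq models.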
ignore_warning=True,
)
logger.info(
"ASR model loaded (device=%s)", "cuda" if device >= 0 else "cpu"
)
return _ASR_MODEL
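
# Example usage (a sketch; the exact output shape depends on the installed
# transformers version):
#
#     asr = load_model()
#     out = asr("sample.wav")  # "sample.wav" is a hypothetical local file
#     # out is typically {"text": "...",
#     #                   "chunks": [{"timestamp": (0.0, 4.2), "text": "..."}]}
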
# ===============================
# Transcribe full text
# ===============================
def transcribe_file(
model,
wav_path: str,
chunk_length_s: float = 30.0,
stride_s: float = 5.0,
) -> str:
"""
Return full transcript text.
"""
if not wav_path:
return ""
# If audio is long, prefer chunked inference to avoid memory/time issues
info = get_audio_info(wav_path) or {}
duration = info.get("duration", 0)
if duration and duration > chunk_length_s:
try:
text, _chunks = transcribe_long_audio(
model, wav_path, chunk_length_s=chunk_length_s, overlap_s=stride_s
)
return text
except Exception:
logger.exception("transcribe_long_audio failed, falling back to pipeline")
    out = model(
        wav_path,
        chunk_length_s=chunk_length_s,
        stride_length_s=stride_s,
        # return_timestamps=True was already set at pipeline construction,
        # so `out` may also carry timestamped chunks; only "text" is needed here.
    )
# Primary: pipeline may return 'text'
text = (out.get("text") or "").strip()
if text:
return text
# Fallback: some pipeline versions return detailed segments/chunks
segs = out.get("chunks") or out.get("segments") or []
if segs:
        parts = [(s.get("text") or "").strip() for s in segs]
        return " ".join(p for p in parts if p).strip()
return ""
def transcribe_long_audio(
model,
wav_path: str,
chunk_length_s: float = 30.0,
overlap_s: float = 5.0,
) -> Tuple[str, List[Dict]]:
"""
Split `wav_path` into chunks and run inference on each chunk sequentially.
Returns (full_text, chunks) where chunks have global start/end timestamps.
"""
if not wav_path:
return "", []
    # Prefer VAD-based splitting if available; fall back to fixed-size chunks.
    try:
        from app.core.chunking import split_audio_with_vad

        chunk_paths = split_audio_with_vad(wav_path)
    except Exception:
        logger.debug("VAD splitting unavailable or failed; using fixed-size chunks")
        chunk_paths = split_audio_to_chunks(
            wav_path, chunk_length_s=chunk_length_s, overlap_s=overlap_s
        )
logger.debug("transcribe_long_audio: split into %d chunk_paths", len(chunk_paths))
combined_text_parts = []
combined_chunks: List[Dict] = []
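    # Consecutive fixed-size chunks start `step` seconds apart, so chunk i
    # begins at i * step in the original file (e.g. 30 s chunks with 5 s
    # overlap start at 0 s, 25 s, 50 s, ...). NOTE: this fixed grid is only
    # exact for split_audio_to_chunks; VAD-derived chunks may have other
    # lengths, which would skew the global timestamps computed below.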
step = chunk_length_s - overlap_s
try:
for i, cp in enumerate(chunk_paths):
base_offset = i * step
try:
cinfo = get_audio_info(cp) or {}
                logger.debug(
                    "chunk[%d]=%s duration=%ss samplerate=%s",
                    i, cp, cinfo.get("duration"), cinfo.get("samplerate"),
                )
except Exception:
logger.debug("chunk[%d]=%s (info unavailable)", i, cp)
try:
out = model(
cp,
chunk_length_s=chunk_length_s,
stride_length_s=overlap_s,
return_timestamps=True,
)
except Exception:
logger.exception("model inference failed for chunk %s", cp)
continue
# debug: log output shape/keys (only first few chunks to avoid huge logs)
try:
if i < 5:
logger.debug("model out keys for chunk[%d]: %s", i, list(out.keys()) if isinstance(out, dict) else type(out))
except Exception:
logger.debug("failed to log model out keys for chunk %d", i)
part_text = (out.get("text") or "").strip()
if not part_text:
segs = out.get("chunks") or out.get("segments") or []
                parts = [(s.get("text") or "").strip() for s in segs]
                part_text = " ".join(p for p in parts if p).strip()
if part_text:
combined_text_parts.append(part_text)
raw_segs = out.get("chunks") or out.get("segments") or []
if raw_segs:
for s in raw_segs:
                    start = end = None
                    ts = s.get("timestamp")
                    if isinstance(ts, (list, tuple)) and len(ts) >= 2:
                        start, end = ts[0], ts[1]
                    elif s.get("start") is not None and s.get("end") is not None:
                        start, end = s.get("start"), s.get("end")
text = (s.get("text") or "").strip()
if not text or start is None or end is None:
continue
try:
combined_chunks.append(
{"start": float(start) + base_offset, "end": float(end) + base_offset, "text": text}
)
except Exception:
continue
else:
# If model returned text but no timestamped segments for this chunk,
# create a fallback chunk spanning the chunk file duration.
if part_text:
try:
cinfo = get_audio_info(cp) or {}
cdur = cinfo.get("duration") or chunk_length_s
combined_chunks.append({
"start": float(base_offset),
"end": float(base_offset) + float(cdur),
"text": part_text,
})
except Exception:
logger.exception("failed to create fallback chunk for %s", cp)
finally:
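        # Best-effort cleanup: the chunk files are temporary artifacts of the
        # split step and are removed even if inference failed part-way.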
for p in chunk_paths:
try:
if p and os.path.exists(p):
os.remove(p)
except Exception:
logger.debug("Failed to remove chunk file %s", p)
full_text = " ".join([p for p in combined_text_parts if p]).strip()
return full_text, combined_chunks
# ===============================
# Transcribe chunks with timestamps
# ===============================
def transcribe_file_chunks(
model,
wav_path: str,
chunk_length_s: float = 30.0,
stride_s: float = 5.0,
) -> List[Dict]:
"""
Return list of chunks:
[{ start, end, text }]
"""
if not wav_path:
return []
# For long audio prefer explicit chunked inference (split + per-chunk inference)
info = get_audio_info(wav_path) or {}
duration = info.get("duration", 0)
if duration and duration > chunk_length_s:
try:
_, combined = transcribe_long_audio(
model, wav_path, chunk_length_s=chunk_length_s, overlap_s=stride_s
)
return combined
except Exception:
logger.exception("transcribe_long_audio failed in transcribe_file_chunks, falling back to pipeline")
out = model(
wav_path,
chunk_length_s=chunk_length_s,
stride_length_s=stride_s,
return_timestamps=True,
)
    # Pipeline output varies across transformers versions/models:
    # - some return `chunks` (with a `timestamp` pair),
    # - others return `segments` (with `start`/`end`),
    # so be permissive and handle both shapes.
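    # e.g. {"chunks": [{"timestamp": (0.0, 4.2), "text": "..."}]}
    #  or  {"segments": [{"start": 0.0, "end": 4.2, "text": "..."}]}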
raw_segments = out.get("chunks") or out.get("segments") or []
chunks = []
for c in raw_segments:
        # Try both timestamp shapes.
        start = end = None
        ts = c.get("timestamp")
        if isinstance(ts, (list, tuple)) and len(ts) >= 2:
            start, end = ts[0], ts[1]
        elif c.get("start") is not None and c.get("end") is not None:
            start, end = c.get("start"), c.get("end")
text = (c.get("text") or "").strip()
if not text:
continue
# If timestamps are missing, skip (we don't want chunks without timing)
if start is None or end is None:
continue
try:
chunks.append({"start": float(start), "end": float(end), "text": text})
except Exception:
# be robust against unexpected types
continue
return chunks
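
# ===============================
# Manual smoke test
# ===============================
# A minimal sketch for ad-hoc verification, assuming a local WAV path is
# passed on the command line; not part of the API surface.
if __name__ == "__main__":
    import sys

    logging.basicConfig(level=logging.DEBUG)
    asr = load_model()
    print(transcribe_file(asr, sys.argv[1]))
    for seg in transcribe_file_chunks(asr, sys.argv[1]):
        print(f"[{seg['start']:7.2f} - {seg['end']:7.2f}] {seg['text']}")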