import logging from typing import List, Dict import torch from transformers import pipeline from transformers import logging as transformers_logging import warnings import os from typing import Tuple from app.core.chunking import split_audio_to_chunks from app.core.audio_utils import get_audio_info logger = logging.getLogger(__name__) # =============================== # Global model cache # =============================== _ASR_MODEL = None def load_model(chunk_length_s: float = 30.0): """ Load ASR model once and reuse. Safe to call multiple times. """ global _ASR_MODEL if _ASR_MODEL is not None: return _ASR_MODEL logger.info("Loading ASR model PhoWhisper-base") device = 0 if torch.cuda.is_available() else -1 dtype = torch.float16 if torch.cuda.is_available() else torch.float32 # Reduce noisy transformer logs and warnings about experimental chunking try: transformers_logging.set_verbosity_error() except Exception: pass # filter the noisy chunk_length_s warnings (regex) warnings.filterwarnings("ignore", message=r".*chunk_length_s.*") _ASR_MODEL = pipeline( task="automatic-speech-recognition", model="vinai/PhoWhisper-base", device=device, dtype=dtype, chunk_length_s=chunk_length_s, return_timestamps=True, ignore_warning=True, ) logger.info( "ASR model loaded (device=%s)", "cuda" if device >= 0 else "cpu" ) return _ASR_MODEL # =============================== # Transcribe full text # =============================== def transcribe_file( model, wav_path: str, chunk_length_s: float = 30.0, stride_s: float = 5.0, ) -> str: """ Return full transcript text. """ if not wav_path: return "" # If audio is long, prefer chunked inference to avoid memory/time issues info = get_audio_info(wav_path) or {} duration = info.get("duration", 0) if duration and duration > chunk_length_s: try: text, _chunks = transcribe_long_audio( model, wav_path, chunk_length_s=chunk_length_s, overlap_s=stride_s ) return text except Exception: logger.exception("transcribe_long_audio failed, falling back to pipeline") out = model( wav_path, chunk_length_s=chunk_length_s, stride_length_s=stride_s, # return_timestamps may be ignored for full-text outputs but safe to pass ) # Primary: pipeline may return 'text' text = (out.get("text") or "").strip() if text: return text # Fallback: some pipeline versions return detailed segments/chunks segs = out.get("chunks") or out.get("segments") or [] if segs: parts = [ (s.get("text") or "").strip() for s in segs ] joined = " ".join([p for p in parts if p]) return joined.strip() return "" def transcribe_long_audio( model, wav_path: str, chunk_length_s: float = 30.0, overlap_s: float = 5.0, ) -> Tuple[str, List[Dict]]: """ Split `wav_path` into chunks and run inference on each chunk sequentially. Returns (full_text, chunks) where chunks have global start/end timestamps. """ if not wav_path: return "", [] # prefer VAD-based splitting if available try: from app.core.chunking import split_audio_with_vad chunk_paths = split_audio_with_vad(wav_path) except Exception: chunk_paths = split_audio_to_chunks(wav_path, chunk_length_s=chunk_length_s, overlap_s=overlap_s) logger.debug("transcribe_long_audio: split into %d chunk_paths", len(chunk_paths)) combined_text_parts = [] combined_chunks: List[Dict] = [] step = chunk_length_s - overlap_s try: for i, cp in enumerate(chunk_paths): base_offset = i * step try: cinfo = get_audio_info(cp) or {} logger.debug( "chunk[%d]=%s duration=%.3fs samplerate=%s", i, cp, cinfo.get("duration"), cinfo.get("samplerate") ) except Exception: logger.debug("chunk[%d]=%s (info unavailable)", i, cp) try: out = model( cp, chunk_length_s=chunk_length_s, stride_length_s=overlap_s, return_timestamps=True, ) except Exception: logger.exception("model inference failed for chunk %s", cp) continue # debug: log output shape/keys (only first few chunks to avoid huge logs) try: if i < 5: logger.debug("model out keys for chunk[%d]: %s", i, list(out.keys()) if isinstance(out, dict) else type(out)) except Exception: logger.debug("failed to log model out keys for chunk %d", i) part_text = (out.get("text") or "").strip() if not part_text: segs = out.get("chunks") or out.get("segments") or [] parts = [ (s.get("text") or "").strip() for s in segs ] part_text = " ".join([p for p in parts if p]).strip() if part_text: combined_text_parts.append(part_text) raw_segs = out.get("chunks") or out.get("segments") or [] if raw_segs: for s in raw_segs: start = None end = None if isinstance(s.get("timestamp"), (list, tuple)) and len(s.get("timestamp")) >= 2: ts = s.get("timestamp") start, end = ts[0], ts[1] elif s.get("start") is not None and s.get("end") is not None: start, end = s.get("start"), s.get("end") text = (s.get("text") or "").strip() if not text or start is None or end is None: continue try: combined_chunks.append( {"start": float(start) + base_offset, "end": float(end) + base_offset, "text": text} ) except Exception: continue else: # If model returned text but no timestamped segments for this chunk, # create a fallback chunk spanning the chunk file duration. if part_text: try: cinfo = get_audio_info(cp) or {} cdur = cinfo.get("duration") or chunk_length_s combined_chunks.append({ "start": float(base_offset), "end": float(base_offset) + float(cdur), "text": part_text, }) except Exception: logger.exception("failed to create fallback chunk for %s", cp) finally: for p in chunk_paths: try: if p and os.path.exists(p): os.remove(p) except Exception: logger.debug("Failed to remove chunk file %s", p) full_text = " ".join([p for p in combined_text_parts if p]).strip() return full_text, combined_chunks # =============================== # Transcribe chunks with timestamps # =============================== def transcribe_file_chunks( model, wav_path: str, chunk_length_s: float = 30.0, stride_s: float = 5.0, ) -> List[Dict]: """ Return list of chunks: [{ start, end, text }] """ if not wav_path: return [] # For long audio prefer explicit chunked inference (split + per-chunk inference) info = get_audio_info(wav_path) or {} duration = info.get("duration", 0) if duration and duration > chunk_length_s: try: _, combined = transcribe_long_audio( model, wav_path, chunk_length_s=chunk_length_s, overlap_s=stride_s ) return combined except Exception: logger.exception("transcribe_long_audio failed in transcribe_file_chunks, falling back to pipeline") out = model( wav_path, chunk_length_s=chunk_length_s, stride_length_s=stride_s, return_timestamps=True, ) # Pipeline output can vary across transformers versions/models: # - some return `chunks` (with `timestamp` list), # - others return `segments` (with `start`/end), # so be permissive and handle both shapes. raw_segments = out.get("chunks") or out.get("segments") or [] chunks = [] for c in raw_segments: # try multiple timestamp shapes start = None end = None if isinstance(c.get("timestamp"), (list, tuple)) and len(c.get("timestamp")) >= 2: ts = c.get("timestamp") start, end = ts[0], ts[1] elif c.get("start") is not None and c.get("end") is not None: start, end = c.get("start"), c.get("end") text = (c.get("text") or "").strip() if not text: continue # If timestamps are missing, skip (we don't want chunks without timing) if start is None or end is None: continue try: chunks.append({"start": float(start), "end": float(end), "text": text}) except Exception: # be robust against unexpected types continue return chunks