import logging
import os
import warnings
from typing import Dict, List, Tuple

import torch
from transformers import pipeline
from transformers import logging as transformers_logging

from app.core.audio_utils import get_audio_info
from app.core.chunking import split_audio_to_chunks

logger = logging.getLogger(__name__)

# ===============================
# Global model cache
# ===============================
_ASR_MODEL = None


def load_model(chunk_length_s: float = 30.0):
    """
    Load ASR model once and reuse.
    Safe to call multiple times.
    """
    global _ASR_MODEL
    if _ASR_MODEL is not None:
        return _ASR_MODEL

    logger.info("Loading ASR model PhoWhisper-base")
    device = 0 if torch.cuda.is_available() else -1
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    # Reduce noisy transformers logs and warnings about experimental chunking
    try:
        transformers_logging.set_verbosity_error()
    except Exception:
        pass
    # Filter the noisy chunk_length_s warnings (regex match on the message)
    warnings.filterwarnings("ignore", message=r".*chunk_length_s.*")

    _ASR_MODEL = pipeline(
        task="automatic-speech-recognition",
        model="vinai/PhoWhisper-base",
        device=device,
        dtype=dtype,
        chunk_length_s=chunk_length_s,
        return_timestamps=True,
        ignore_warning=True,
    )
    logger.info("ASR model loaded (device=%s)", "cuda" if device >= 0 else "cpu")
    return _ASR_MODEL
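
# Example usage (a minimal sketch): repeated calls return the same cached
# pipeline object, so load_model() is safe to call from multiple handlers:
#
#   asr = load_model()
#   assert asr is load_model()  # second call hits the _ASR_MODEL cache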


# ===============================
# Transcribe full text
# ===============================
def transcribe_file(
    model,
    wav_path: str,
    chunk_length_s: float = 30.0,
    stride_s: float = 5.0,
) -> str:
    """
    Return full transcript text.
    """
    if not wav_path:
        return ""

    # If audio is long, prefer chunked inference to avoid memory/time issues
    info = get_audio_info(wav_path) or {}
    duration = info.get("duration", 0)
    if duration and duration > chunk_length_s:
        try:
            text, _chunks = transcribe_long_audio(
                model, wav_path, chunk_length_s=chunk_length_s, overlap_s=stride_s
            )
            return text
        except Exception:
            logger.exception("transcribe_long_audio failed, falling back to pipeline")

    out = model(
        wav_path,
        chunk_length_s=chunk_length_s,
        stride_length_s=stride_s,
        # return_timestamps is already configured on the pipeline; for
        # full-text output it may be ignored, which is harmless here.
    )
    # Primary: the pipeline usually returns a 'text' key
    text = (out.get("text") or "").strip()
    if text:
        return text
    # Fallback: some pipeline versions return detailed segments/chunks instead
    segs = out.get("chunks") or out.get("segments") or []
    if segs:
        parts = [(s.get("text") or "").strip() for s in segs]
        joined = " ".join([p for p in parts if p])
        return joined.strip()
    return ""


def transcribe_long_audio(
    model,
    wav_path: str,
    chunk_length_s: float = 30.0,
    overlap_s: float = 5.0,
) -> Tuple[str, List[Dict]]:
    """
    Split `wav_path` into chunks and run inference on each chunk sequentially.
    Returns (full_text, chunks) where chunks have global start/end timestamps.
    """
    if not wav_path:
        return "", []

    # Prefer VAD-based splitting if available
    try:
        from app.core.chunking import split_audio_with_vad

        chunk_paths = split_audio_with_vad(wav_path)
    except Exception:
        chunk_paths = split_audio_to_chunks(
            wav_path, chunk_length_s=chunk_length_s, overlap_s=overlap_s
        )
    logger.debug("transcribe_long_audio: split into %d chunk_paths", len(chunk_paths))

    combined_text_parts = []
    combined_chunks: List[Dict] = []
    # Consecutive fixed-length chunks advance by (chunk_length_s - overlap_s)
    # seconds, so chunk i starts at i * step in the original file. Caveat:
    # this assumes fixed-length chunks; VAD-based chunks can have variable
    # lengths, making the derived global timestamps approximate.
    step = chunk_length_s - overlap_s
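    # Worked example (with the defaults): chunk_length_s=30, overlap_s=5
    # gives step=25, so chunk 0 covers ~[0, 30)s, chunk 1 ~[25, 55)s, and
    # chunk 2 gets base_offset = 2 * 25 = 50s added to its local timestamps.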
    try:
        for i, cp in enumerate(chunk_paths):
            base_offset = i * step
            try:
                cinfo = get_audio_info(cp) or {}
                logger.debug(
                    "chunk[%d]=%s duration=%.3fs samplerate=%s",
                    i, cp, cinfo.get("duration"), cinfo.get("samplerate"),
                )
            except Exception:
                logger.debug("chunk[%d]=%s (info unavailable)", i, cp)

            try:
                out = model(
                    cp,
                    chunk_length_s=chunk_length_s,
                    stride_length_s=overlap_s,
                    return_timestamps=True,
                )
            except Exception:
                logger.exception("model inference failed for chunk %s", cp)
                continue

            # Debug: log output keys (only the first few chunks to avoid huge logs)
            try:
                if i < 5:
                    logger.debug(
                        "model out keys for chunk[%d]: %s",
                        i, list(out.keys()) if isinstance(out, dict) else type(out),
                    )
            except Exception:
                logger.debug("failed to log model out keys for chunk %d", i)
            part_text = (out.get("text") or "").strip()
            if not part_text:
                segs = out.get("chunks") or out.get("segments") or []
                parts = [(s.get("text") or "").strip() for s in segs]
                part_text = " ".join([p for p in parts if p]).strip()
            if part_text:
                combined_text_parts.append(part_text)

            raw_segs = out.get("chunks") or out.get("segments") or []
            if raw_segs:
                for s in raw_segs:
                    start = None
                    end = None
                    ts = s.get("timestamp")
                    if isinstance(ts, (list, tuple)) and len(ts) >= 2:
                        start, end = ts[0], ts[1]
                    elif s.get("start") is not None and s.get("end") is not None:
                        start, end = s.get("start"), s.get("end")
                    text = (s.get("text") or "").strip()
                    if not text or start is None or end is None:
                        continue
                    try:
                        combined_chunks.append(
                            {
                                "start": float(start) + base_offset,
                                "end": float(end) + base_offset,
                                "text": text,
                            }
                        )
                    except Exception:
                        continue
            else:
                # If the model returned text but no timestamped segments for
                # this chunk, create a fallback chunk spanning the chunk
                # file's duration.
                if part_text:
                    try:
                        cinfo = get_audio_info(cp) or {}
                        cdur = cinfo.get("duration") or chunk_length_s
                        combined_chunks.append({
                            "start": float(base_offset),
                            "end": float(base_offset) + float(cdur),
                            "text": part_text,
                        })
                    except Exception:
                        logger.exception("failed to create fallback chunk for %s", cp)
    finally:
        # Always remove temporary chunk files, even if inference failed
        for p in chunk_paths:
            try:
                if p and os.path.exists(p):
                    os.remove(p)
            except Exception:
                logger.debug("Failed to remove chunk file %s", p)

    # Note: overlapping regions are transcribed twice; this naive join does
    # not deduplicate repeated words at chunk boundaries.
    full_text = " ".join([p for p in combined_text_parts if p]).strip()
    return full_text, combined_chunks
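
# Example of consuming the returned chunks (a sketch; "meeting.wav" is a
# placeholder path, and timestamps are seconds in the original file):
#
#   text, chunks = transcribe_long_audio(load_model(), "meeting.wav")
#   for c in chunks:
#       print(f"[{c['start']:7.2f} - {c['end']:7.2f}] {c['text']}")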


# ===============================
# Transcribe chunks with timestamps
# ===============================
def transcribe_file_chunks(
    model,
    wav_path: str,
    chunk_length_s: float = 30.0,
    stride_s: float = 5.0,
) -> List[Dict]:
    """
    Return a list of chunks:
        [{ start, end, text }]
    """
    if not wav_path:
        return []

    # For long audio, prefer explicit chunked inference (split + per-chunk inference)
    info = get_audio_info(wav_path) or {}
    duration = info.get("duration", 0)
    if duration and duration > chunk_length_s:
        try:
            _, combined = transcribe_long_audio(
                model, wav_path, chunk_length_s=chunk_length_s, overlap_s=stride_s
            )
            return combined
        except Exception:
            logger.exception(
                "transcribe_long_audio failed in transcribe_file_chunks, falling back to pipeline"
            )

    out = model(
        wav_path,
        chunk_length_s=chunk_length_s,
        stride_length_s=stride_s,
        return_timestamps=True,
    )
    # Pipeline output can vary across transformers versions/models:
    # - some return `chunks` (with a `timestamp` pair),
    # - others return `segments` (with `start`/`end`),
    # so be permissive and handle both shapes.
    raw_segments = out.get("chunks") or out.get("segments") or []
    chunks = []
    for c in raw_segments:
        # Try multiple timestamp shapes
        start = None
        end = None
        ts = c.get("timestamp")
        if isinstance(ts, (list, tuple)) and len(ts) >= 2:
            start, end = ts[0], ts[1]
        elif c.get("start") is not None and c.get("end") is not None:
            start, end = c.get("start"), c.get("end")
        text = (c.get("text") or "").strip()
        if not text:
            continue
        # If timestamps are missing, skip (we don't want chunks without timing)
        if start is None or end is None:
            continue
        try:
            chunks.append({"start": float(start), "end": float(end), "text": text})
        except Exception:
            # Be robust against unexpected types
            continue
    return chunks
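

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the app): load the cached
    # pipeline, transcribe one file, and print the full text plus the
    # timestamped chunks. "sample.wav" is a placeholder path; use any audio
    # file the transformers pipeline can decode (ffmpeg must be installed).
    logging.basicConfig(level=logging.INFO)
    asr = load_model()
    print(transcribe_file(asr, "sample.wav"))
    for c in transcribe_file_chunks(asr, "sample.wav"):
        print(f"[{c['start']:.2f}-{c['end']:.2f}] {c['text']}")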