bichnhan2701 committed on
Commit
580c470
·
1 Parent(s): 7d9eaeb
Files changed (1) hide show
  1. app/core/asr_engine.py +105 -0
app/core/asr_engine.py CHANGED
@@ -3,6 +3,11 @@ from typing import List, Dict
3
 
4
  import torch
5
  from transformers import pipeline
 
 
 
 
 
6
 
7
  logger = logging.getLogger(__name__)
8
 
@@ -56,6 +61,18 @@ def transcribe_file(
56
  if not wav_path:
57
  return ""
58
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  out = model(
60
  wav_path,
61
  chunk_length_s=chunk_length_s,
@@ -78,6 +95,81 @@ def transcribe_file(
78
  return ""
79
 
80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  # ===============================
82
  # Transcribe chunks with timestamps
83
  # ===============================
@@ -132,4 +224,17 @@ def transcribe_file_chunks(
132
  # be robust against unexpected types
133
  continue
134
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  return chunks
 
3
 
4
  import torch
5
  from transformers import pipeline
6
+ import os
7
+ from typing import Tuple
8
+
9
+ from app.core.chunking import split_audio_to_chunks
10
+ from app.core.audio_utils import get_audio_info
11
 
12
  logger = logging.getLogger(__name__)
13
 
 
61
  if not wav_path:
62
  return ""
63
 
64
+ # If audio is long, prefer chunked inference to avoid memory/time issues
65
+ info = get_audio_info(wav_path) or {}
66
+ duration = info.get("duration", 0)
67
+ if duration and duration > chunk_length_s:
68
+ try:
69
+ text, _chunks = transcribe_long_audio(
70
+ model, wav_path, chunk_length_s=chunk_length_s, overlap_s=stride_s
71
+ )
72
+ return text
73
+ except Exception:
74
+ logger.exception("transcribe_long_audio failed, falling back to pipeline")
75
+
76
  out = model(
77
  wav_path,
78
  chunk_length_s=chunk_length_s,
 
95
  return ""
96
 
97
 
98
+ def transcribe_long_audio(
99
+ model,
100
+ wav_path: str,
101
+ chunk_length_s: float = 30.0,
102
+ overlap_s: float = 5.0,
103
+ ) -> Tuple[str, List[Dict]]:
104
+ """
105
+ Split `wav_path` into chunks and run inference on each chunk sequentially.
106
+ Returns (full_text, chunks) where chunks have global start/end timestamps.
107
+ """
108
+ if not wav_path:
109
+ return "", []
110
+
111
+ chunk_paths = split_audio_to_chunks(wav_path, chunk_length_s=chunk_length_s, overlap_s=overlap_s)
112
+ combined_text_parts = []
113
+ combined_chunks: List[Dict] = []
114
+
115
+ step = chunk_length_s - overlap_s
116
+ try:
117
+ for i, cp in enumerate(chunk_paths):
118
+ base_offset = i * step
119
+
120
+ try:
121
+ out = model(
122
+ cp,
123
+ chunk_length_s=chunk_length_s,
124
+ stride_length_s=overlap_s,
125
+ return_timestamps=True,
126
+ )
127
+ except Exception:
128
+ logger.exception("model inference failed for chunk %s", cp)
129
+ continue
130
+
131
+ part_text = (out.get("text") or "").strip()
132
+ if not part_text:
133
+ segs = out.get("chunks") or out.get("segments") or []
134
+ parts = [ (s.get("text") or "").strip() for s in segs ]
135
+ part_text = " ".join([p for p in parts if p]).strip()
136
+
137
+ if part_text:
138
+ combined_text_parts.append(part_text)
139
+
140
+ raw_segs = out.get("chunks") or out.get("segments") or []
141
+ for s in raw_segs:
142
+ start = None
143
+ end = None
144
+ if isinstance(s.get("timestamp"), (list, tuple)) and len(s.get("timestamp")) >= 2:
145
+ ts = s.get("timestamp")
146
+ start, end = ts[0], ts[1]
147
+ elif s.get("start") is not None and s.get("end") is not None:
148
+ start, end = s.get("start"), s.get("end")
149
+
150
+ text = (s.get("text") or "").strip()
151
+ if not text or start is None or end is None:
152
+ continue
153
+
154
+ try:
155
+ combined_chunks.append(
156
+ {"start": float(start) + base_offset, "end": float(end) + base_offset, "text": text}
157
+ )
158
+ except Exception:
159
+ continue
160
+
161
+ finally:
162
+ for p in chunk_paths:
163
+ try:
164
+ if p and os.path.exists(p):
165
+ os.remove(p)
166
+ except Exception:
167
+ logger.debug("Failed to remove chunk file %s", p)
168
+
169
+ full_text = " ".join([p for p in combined_text_parts if p]).strip()
170
+ return full_text, combined_chunks
171
+
172
+
173
  # ===============================
174
  # Transcribe chunks with timestamps
175
  # ===============================
 
224
  # be robust against unexpected types
225
  continue
226
 
227
+ # If no timestamped chunks found and file is long, try chunked inference
228
+ if not chunks:
229
+ info = get_audio_info(wav_path) or {}
230
+ duration = info.get("duration", 0)
231
+ if duration and duration > chunk_length_s:
232
+ try:
233
+ _, combined = transcribe_long_audio(
234
+ model, wav_path, chunk_length_s=chunk_length_s, overlap_s=stride_s
235
+ )
236
+ return combined
237
+ except Exception:
238
+ logger.exception("transcribe_long_audio fallback failed for %s", wav_path)
239
+
240
  return chunks