LiamKhoaLe committed on
Commit
4c24458
·
1 Parent(s): 4727f1c

Update abort time and smart chunk-batcher #3

Browse files
Files changed (1) hide show
  1. app.py +72 -16
app.py CHANGED
@@ -36,6 +36,20 @@ def _concat_text(chunks):
36
  return " ".join([c.strip() for c in chunks if c and c.strip()])
37
 
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  def _robust_transcribe_array(audio_array: np.ndarray, sr: int, task: str) -> str:
40
  """Transcribe long/large audio by chunking sequentially to minimize GPU memory.
41
 
@@ -49,22 +63,44 @@ def _robust_transcribe_array(audio_array: np.ndarray, sr: int, task: str) -> str
49
  win = int(chunk_s * sr)
50
  texts = []
51
  if len(audio_array) <= win:
52
- inputs = {"array": audio_array, "sampling_rate": sr}
53
- out = pipe(inputs, batch_size=1, generate_kwargs={"task": task})
54
- return out["text"]
55
  start = 0
56
  while start < len(audio_array):
57
  end = min(start + win, len(audio_array))
58
  chunk = audio_array[start:end]
59
- inputs = {"array": chunk, "sampling_rate": sr}
60
- out = pipe(inputs, batch_size=1, generate_kwargs={"task": task})
61
- texts.append(out["text"])
62
  if end == len(audio_array):
63
  break
64
  start += step
65
  return _concat_text(texts)
66
 
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  def _robust_transcribe_path(path: str, task: str) -> str:
69
  sr = pipe.feature_extractor.sampling_rate
70
  # ffmpeg_read expects raw bytes, not a file path
@@ -97,22 +133,42 @@ def _robust_transcribe_path(path: str, task: str) -> str:
97
  def transcribe(inputs, task, summarize=False):
98
  if inputs is None:
99
  raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  try:
101
  if isinstance(inputs, str):
102
- text = _robust_transcribe_path(inputs, task)
 
 
 
 
 
103
  elif isinstance(inputs, dict) and "array" in inputs:
104
- text = _robust_transcribe_array(inputs["array"], inputs.get("sampling_rate", pipe.feature_extractor.sampling_rate), task)
 
105
  else:
106
- text = pipe(inputs, batch_size=1, generate_kwargs={"task": task})["text"]
 
 
 
 
 
 
 
 
107
  except Exception as e:
108
  raise gr.Error(f"Transcription failed: {e}")
109
- if summarize:
110
- try:
111
- summary = summarize_with_gemini(text)
112
- except Exception as e:
113
- summary = f"Summary error: {e}"
114
- return text, summary
115
- return text, ""
116
 
117
 
118
  def _return_yt_html_embed(yt_url):
 
36
  return " ".join([c.strip() for c in chunks if c and c.strip()])
37
 
38
 
39
def _transcribe_chunk(chunk: np.ndarray, sr: int, task: str, max_retries: int = 3) -> str:
    """Transcribe a single audio chunk, retrying with a growing delay on failure.

    Args:
        chunk: Audio samples for one window; passed to ``pipe`` as ``"array"``.
        sr: Sampling rate of ``chunk``; passed to ``pipe`` as ``"sampling_rate"``.
        task: Forwarded to the pipeline as ``generate_kwargs={"task": task}``.
        max_retries: Total number of attempts. Values below 1 are treated as 1,
            so the function can never fall out of the loop and return ``None``.

    Returns:
        The ``"text"`` field of the pipeline output.

    Raises:
        Exception: Whatever the pipeline raised on the final attempt.
    """
    attempts = max(1, max_retries)
    delay = 2.0
    for attempt in range(1, attempts + 1):
        try:
            out = pipe({"array": chunk, "sampling_rate": sr}, batch_size=1, generate_kwargs={"task": task})
            return out["text"]
        except Exception:
            # Out of attempts: re-raise the pipeline's own error, traceback intact.
            if attempt == attempts:
                raise
            time.sleep(delay)
            delay *= 1.8  # waits of 2.0s, 3.6s, 6.48s, ...
51
+
52
+
53
  def _robust_transcribe_array(audio_array: np.ndarray, sr: int, task: str) -> str:
54
  """Transcribe long/large audio by chunking sequentially to minimize GPU memory.
55
 
 
63
  win = int(chunk_s * sr)
64
  texts = []
65
  if len(audio_array) <= win:
66
+ return _transcribe_chunk(audio_array, sr, task)
 
 
67
  start = 0
68
  while start < len(audio_array):
69
  end = min(start + win, len(audio_array))
70
  chunk = audio_array[start:end]
71
+ txt = _transcribe_chunk(chunk, sr, task)
72
+ texts.append(txt)
 
73
  if end == len(audio_array):
74
  break
75
  start += step
76
  return _concat_text(texts)
77
 
78
 
79
def _robust_transcribe_array_stream(
    audio_array: np.ndarray,
    sr: int,
    task: str,
    *,
    chunk_s: float = 20,
    overlap_s: float = 2,
):
    """Yield the cumulative transcription after each chunk of ``audio_array``.

    The audio is cut into windows of ``chunk_s`` seconds that advance by
    ``chunk_s - overlap_s`` seconds. Each window is transcribed with
    ``_transcribe_chunk`` and the texts collected so far are joined with
    ``_concat_text``. Text from the overlapping region is not de-duplicated.

    Args:
        audio_array: Audio samples. Arrays with more than one dimension are
            averaged over axis 1 (assumes a ``(samples, channels)`` layout;
            confirm against the caller).
        sr: Sampling rate of ``audio_array``.
        task: Forwarded unchanged to ``_transcribe_chunk``.
        chunk_s: Window length in seconds.
        overlap_s: Overlap between consecutive windows in seconds.

    Yields:
        The joined transcription of all chunks processed so far. Empty audio
        yields a single empty string without invoking the model.

    Raises:
        ValueError: If the window or step works out to a non-positive number
            of samples, or the step is longer than the window. The former
            would loop forever; the latter would silently skip audio.
    """
    win = int(chunk_s * sr)
    step = int((chunk_s - overlap_s) * sr)
    if win <= 0 or not 0 < step <= win:
        raise ValueError(
            "chunk_s and sr must give a positive window, and overlap_s must "
            "satisfy 0 <= overlap_s < chunk_s"
        )
    if audio_array.ndim > 1:
        audio_array = np.mean(audio_array, axis=1)
    total = len(audio_array)
    if total == 0:
        # Nothing to transcribe; still emit one value so consumers get an update.
        yield ""
        return
    texts = []
    # Audio that fits in one window is handled by the first iteration: the
    # window is clamped to `total`, so the loop stops after a single chunk.
    for start in range(0, total, step):
        end = min(start + win, total)
        texts.append(_transcribe_chunk(audio_array[start:end], sr, task))
        yield _concat_text(texts)
        if end == total:
            break
102
+
103
+
104
  def _robust_transcribe_path(path: str, task: str) -> str:
105
  sr = pipe.feature_extractor.sampling_rate
106
  # ffmpeg_read expects raw bytes, not a file path
 
133
  def transcribe(inputs, task, summarize=False):
134
  if inputs is None:
135
  raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
136
+ # Stream outputs incrementally: yield tuples (transcription_so_far, summary_so_far)
137
+ def _stream(gen):
138
+ running_text = ""
139
+ running_summary = ""
140
+ for partial in gen:
141
+ running_text = partial
142
+ if summarize and partial.strip():
143
+ try:
144
+ running_summary += ("\n\n" if running_summary else "") + summarize_with_gemini(partial)
145
+ except Exception:
146
+ pass
147
+ yield running_text, (running_summary if summarize else "")
148
+
149
  try:
150
  if isinstance(inputs, str):
151
+ # File path handed by Gradio
152
+ sr = pipe.feature_extractor.sampling_rate
153
+ with open(inputs, "rb") as _f:
154
+ payload = _f.read()
155
+ audio = ffmpeg_read(payload, sr)
156
+ return _stream(_robust_transcribe_array_stream(audio, sr, task))
157
  elif isinstance(inputs, dict) and "array" in inputs:
158
+ sr = inputs.get("sampling_rate", pipe.feature_extractor.sampling_rate)
159
+ return _stream(_robust_transcribe_array_stream(inputs["array"], sr, task))
160
  else:
161
+ # Fallback single shot
162
+ out = pipe(inputs, batch_size=1, generate_kwargs={"task": task})["text"]
163
+ if summarize:
164
+ try:
165
+ summ = summarize_with_gemini(out)
166
+ except Exception as e:
167
+ summ = f"Summary error: {e}"
168
+ return out, summ
169
+ return out, ""
170
  except Exception as e:
171
  raise gr.Error(f"Transcription failed: {e}")
 
 
 
 
 
 
 
172
 
173
 
174
  def _return_yt_html_embed(yt_url):