colab-user committed on
Commit
9a8a554
·
1 Parent(s): 65ee821

input stream

Browse files
Files changed (1) hide show
  1. app/services/processor.py +82 -64
app/services/processor.py CHANGED
@@ -60,13 +60,9 @@ def convert_audio_to_wav(audio_path: Path) -> Path:
60
 
61
 
62
  def format_timestamp(seconds: float) -> str:
63
- """Format seconds to MM:SS.ms or HH:MM:SS.ms"""
64
- hours = int(seconds // 3600)
65
- minutes = int((seconds % 3600) // 60)
66
- secs = seconds % 60
67
- if hours > 0:
68
- return f"{hours:02d}:{minutes:02d}:{secs:05.2f}"
69
- return f"{minutes:02d}:{secs:05.2f}"
70
 
71
 
72
  def pad_and_refine_tensor(
@@ -84,37 +80,34 @@ def pad_and_refine_tensor(
84
  """
85
  total_len = waveform.shape[1]
86
 
87
- start_s = max(start_s - pad_ms / 1000, 0)
88
- end_s = min(end_s + pad_ms / 1000, total_len / sr)
89
-
90
- start_idx = int(start_s * sr)
91
- end_idx = int(end_s * sr)
92
 
93
  if end_idx <= start_idx:
94
  return None
95
 
96
- seg = waveform[:, start_idx:end_idx]
97
- if seg.numel() == 0:
98
  return None
99
 
100
  # RMS energy
101
- rms = torch.sqrt(torch.mean(seg ** 2, dim=0))
102
- if rms.numel() == 0:
103
- return None
104
 
105
- threshold = torch.quantile(rms, 0.2)
106
- valid = torch.where(rms > threshold)[0]
107
 
108
  if valid.numel() == 0:
109
  return None
110
 
111
- refined_start = start_idx + valid[0].item()
112
- refined_end = start_idx + valid[-1].item()
113
 
114
- if refined_end - refined_start < (min_duration_ms / 1000) * sr:
 
 
 
115
  return None
116
 
117
- return refined_start, refined_end
118
 
119
 
120
  # =========================
@@ -143,7 +136,7 @@ class Processor:
143
 
144
  import asyncio
145
 
146
- total_start = time.time()
147
 
148
  # Step 1: Convert to WAV
149
  logger.info("Step 1: Converting audio to WAV 16kHz...")
@@ -158,37 +151,35 @@ class Processor:
158
 
159
  # Step 3: Diarization
160
  logger.info("Step 3: Running diarization...")
 
161
  try:
162
- diar_segments = await DiarizationService.diarize_async(wav_path)
163
  except Exception as e:
164
  logger.error(f"Diarization failed: {e}")
165
- # Fallback: create single segment for whole audio
166
- diar_segments = [SpeakerSegment(
167
- start=0.0,
168
- end=duration,
169
- speaker="Speaker 1"
170
- )]
171
 
172
  # Sort by start time
173
- diar_segments.sort(key=lambda x: x.start)
174
 
175
 
176
  # Step 4: Refine segment boundaries
177
  refined_segments: List[SpeakerSegment] = []
178
 
179
- for seg in diar_segments:
180
- start, end = seg.start, seg.end
 
181
 
182
  if pad_refine:
183
- refined = pad_and_refine_tensor(waveform, sr, start, end)
184
- if refined is None:
185
- start_idx = int(start * sr)
186
- end_idx = int(end * sr)
187
- else:
188
  start_idx, end_idx = refined
189
- else:
190
- start_idx = int(start * sr)
191
- end_idx = int(end * sr)
192
 
193
  if end_idx <= start_idx:
194
  continue
@@ -197,12 +188,17 @@ class Processor:
197
  SpeakerSegment(
198
  start=start_idx / sr,
199
  end=end_idx / sr,
200
- speaker=seg.speaker
201
  )
202
  )
 
 
 
 
 
 
203
 
204
  # Step 5: Transcribe
205
- logger.info(f"Step 5: Transcribing {len(refined_segments)} segments...")
206
  vad_options = None
207
  if vad_filter:
208
  vad_options = {
@@ -213,49 +209,71 @@ class Processor:
213
  }
214
 
215
  processed_segments: List[TranscriptSegment] = []
216
- unique_speakers = set()
 
 
 
 
 
 
 
217
 
218
- for idx, seg in enumerate(refined_segments):
219
- logger.info(f"Transcribing segment {idx+1}/{len(refined_segments)} ({seg.speaker})...")
220
- start_sample = int(seg.start * sr)
221
- end_sample = int(seg.end * sr)
222
- if end_sample <= start_sample:
223
  continue
224
- y_seg = waveform[:, start_sample:end_sample]
225
  try:
226
  text = await TranscriptionService.transcribe_segment_async(
227
- audio_array=y_seg,
228
  model_name=model_name,
229
  language=language,
230
  vad_options=vad_options,
231
  beam_size=beam_size,
232
  temperature=temperature,
233
  best_of=best_of,
234
- initial_prompt=initial_prompt
235
  )
236
- if text.strip():
237
- unique_speakers.add(seg.speaker)
238
- processed_segments.append(TranscriptSegment(start=seg.start, end=seg.end, speaker=seg.speaker, text=text.strip()))
239
  except Exception as e:
240
- logger.error(f"Error transcribing segment {idx}: {e}")
 
 
 
241
  continue
242
 
243
- processing_time = time.time() - total_start
244
- logger.info(f"Processing complete: {len(processed_segments)} segments, {len(unique_speakers)} speakers in {processing_time:.1f}s")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
 
246
- # Step 6: Generate outputs
247
- txt_content = cls._generate_txt(processed_segments, unique_speakers, processing_time, duration)
248
  csv_content = cls._generate_csv(processed_segments)
249
 
250
  return ProcessingResult(
251
  segments=processed_segments,
252
- speaker_count=len(unique_speakers),
253
  duration=duration,
254
  processing_time=processing_time,
255
  txt_content=txt_content,
256
- csv_content=csv_content
257
  )
258
-
259
  @classmethod
260
  def _generate_txt(cls, segments: List[TranscriptSegment], speakers: set, processing_time: float, duration: float) -> str:
261
  lines = [
 
60
 
61
 
62
  def format_timestamp(seconds: float) -> str:
63
+ m = int(seconds // 60)
64
+ s = seconds % 60
65
+ return f"{m:02d}:{s:06.3f}"
 
 
 
 
66
 
67
 
68
  def pad_and_refine_tensor(
 
80
  """
81
  total_len = waveform.shape[1]
82
 
83
+ start_idx = max(int((start_s - pad_ms / 1000) * sr), 0)
84
+ end_idx = min(int((end_s + pad_ms / 1000) * sr), total_len)
 
 
 
85
 
86
  if end_idx <= start_idx:
87
  return None
88
 
89
+ segment = waveform[:, start_idx:end_idx]
90
+ if segment .numel() == 0:
91
  return None
92
 
93
  # RMS energy
94
+ rms = torch.sqrt(torch.mean(segment ** 2) + 1e-9)
95
+ threshold = rms / silence_db_delta
 
96
 
97
+ energy = torch.abs(segment)
98
+ valid = torch.where(energy > threshold)[0]
99
 
100
  if valid.numel() == 0:
101
  return None
102
 
 
 
103
 
104
+ new_start = start_idx + valid[0].item()
105
+ new_end = start_idx + valid[-1].item()
106
+
107
+ if new_end - new_start < int(min_duration_ms / 1000 * sr):
108
  return None
109
 
110
+ return new_start, new_end
111
 
112
 
113
  # =========================
 
136
 
137
  import asyncio
138
 
139
+ t0= time.time()
140
 
141
  # Step 1: Convert to WAV
142
  logger.info("Step 1: Converting audio to WAV 16kHz...")
 
151
 
152
  # Step 3: Diarization
153
  logger.info("Step 3: Running diarization...")
154
+
155
  try:
156
+ diarization_segments = await DiarizationService.diarize_async(wav_path)
157
  except Exception as e:
158
  logger.error(f"Diarization failed: {e}")
159
+ diarization_segments = []
160
+
161
+ if not diarization_segments:
162
+ diarization_segments = [
163
+ SpeakerSegment(0.0, duration, "Speaker 1")
164
+ ]
165
 
166
  # Sort by start time
167
+ diarization_segments.sort(key=lambda x: x.start)
168
 
169
 
170
  # Step 4: Refine segment boundaries
171
  refined_segments: List[SpeakerSegment] = []
172
 
173
+ for seg in diarization_segments:
174
+ start_idx = int(seg.start * sr)
175
+ end_idx = int(seg.end * sr)
176
 
177
  if pad_refine:
178
+ refined = pad_and_refine_tensor(
179
+ waveform, sr, seg.start, seg.end
180
+ )
181
+ if refined:
 
182
  start_idx, end_idx = refined
 
 
 
183
 
184
  if end_idx <= start_idx:
185
  continue
 
188
  SpeakerSegment(
189
  start=start_idx / sr,
190
  end=end_idx / sr,
191
+ speaker=seg.speaker or "Speaker 1"
192
  )
193
  )
194
+ if not refined_segments:
195
+ refined_segments = [
196
+ SpeakerSegment(0.0, duration, "Speaker 1")
197
+ ]
198
+
199
+ logger.info(f"Refined segments: {len(refined_segments)}")
200
 
201
  # Step 5: Transcribe
 
202
  vad_options = None
203
  if vad_filter:
204
  vad_options = {
 
209
  }
210
 
211
  processed_segments: List[TranscriptSegment] = []
212
+ speakers = set()
213
+
214
+ for seg in refined_segments:
215
+ start = int(seg.start * sr)
216
+ end = int(seg.end * sr)
217
+
218
+ if end <= start:
219
+ continue
220
 
221
+ audio_slice = y[start:end]
222
+ if audio_slice.size < sr * 0.25:
 
 
 
223
  continue
224
+
225
  try:
226
  text = await TranscriptionService.transcribe_segment_async(
227
+ audio_array=audio_slice,
228
  model_name=model_name,
229
  language=language,
230
  vad_options=vad_options,
231
  beam_size=beam_size,
232
  temperature=temperature,
233
  best_of=best_of,
234
+ initial_prompt=initial_prompt,
235
  )
 
 
 
236
  except Exception as e:
237
+ logger.error(f"Transcribe error: {e}")
238
+ continue
239
+
240
+ if not text or not text.strip():
241
  continue
242
 
243
+ processed_segments.append(
244
+ TranscriptSegment(
245
+ start=seg.start,
246
+ end=seg.end,
247
+ speaker=seg.speaker,
248
+ text=text.strip(),
249
+ )
250
+ )
251
+ speakers.add(seg.speaker)
252
+
253
+ if not processed_segments:
254
+ processed_segments = [
255
+ TranscriptSegment(
256
+ start=0.0,
257
+ end=duration,
258
+ speaker="Speaker 1",
259
+ text="(No speech detected)"
260
+ )
261
+ ]
262
+ speakers.add("Speaker 1")
263
+
264
+ processing_time = time.time() - t0
265
 
266
+ txt_content = cls._generate_txt(processed_segments, speakers, processing_time, duration)
 
267
  csv_content = cls._generate_csv(processed_segments)
268
 
269
  return ProcessingResult(
270
  segments=processed_segments,
271
+ speaker_count=len(speakers),
272
  duration=duration,
273
  processing_time=processing_time,
274
  txt_content=txt_content,
275
+ csv_content=csv_content,
276
  )
 
277
  @classmethod
278
  def _generate_txt(cls, segments: List[TranscriptSegment], speakers: set, processing_time: float, duration: float) -> str:
279
  lines = [