Spaces:
Running
Running
colab-user commited on
Commit ·
9a8a554
1
Parent(s): 65ee821
input stream
Browse files- app/services/processor.py +82 -64
app/services/processor.py
CHANGED
|
@@ -60,13 +60,9 @@ def convert_audio_to_wav(audio_path: Path) -> Path:
|
|
| 60 |
|
| 61 |
|
| 62 |
def format_timestamp(seconds: float) -> str:
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
secs = seconds % 60
|
| 67 |
-
if hours > 0:
|
| 68 |
-
return f"{hours:02d}:{minutes:02d}:{secs:05.2f}"
|
| 69 |
-
return f"{minutes:02d}:{secs:05.2f}"
|
| 70 |
|
| 71 |
|
| 72 |
def pad_and_refine_tensor(
|
|
@@ -84,37 +80,34 @@ def pad_and_refine_tensor(
|
|
| 84 |
"""
|
| 85 |
total_len = waveform.shape[1]
|
| 86 |
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
start_idx = int(start_s * sr)
|
| 91 |
-
end_idx = int(end_s * sr)
|
| 92 |
|
| 93 |
if end_idx <= start_idx:
|
| 94 |
return None
|
| 95 |
|
| 96 |
-
|
| 97 |
-
if
|
| 98 |
return None
|
| 99 |
|
| 100 |
# RMS energy
|
| 101 |
-
rms = torch.sqrt(torch.mean(
|
| 102 |
-
|
| 103 |
-
return None
|
| 104 |
|
| 105 |
-
|
| 106 |
-
valid = torch.where(
|
| 107 |
|
| 108 |
if valid.numel() == 0:
|
| 109 |
return None
|
| 110 |
|
| 111 |
-
refined_start = start_idx + valid[0].item()
|
| 112 |
-
refined_end = start_idx + valid[-1].item()
|
| 113 |
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
| 115 |
return None
|
| 116 |
|
| 117 |
-
return
|
| 118 |
|
| 119 |
|
| 120 |
# =========================
|
|
@@ -143,7 +136,7 @@ class Processor:
|
|
| 143 |
|
| 144 |
import asyncio
|
| 145 |
|
| 146 |
-
|
| 147 |
|
| 148 |
# Step 1: Convert to WAV
|
| 149 |
logger.info("Step 1: Converting audio to WAV 16kHz...")
|
|
@@ -158,37 +151,35 @@ class Processor:
|
|
| 158 |
|
| 159 |
# Step 3: Diarization
|
| 160 |
logger.info("Step 3: Running diarization...")
|
|
|
|
| 161 |
try:
|
| 162 |
-
|
| 163 |
except Exception as e:
|
| 164 |
logger.error(f"Diarization failed: {e}")
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
|
| 172 |
# Sort by start time
|
| 173 |
-
|
| 174 |
|
| 175 |
|
| 176 |
# Step 4: Refine segment boundaries
|
| 177 |
refined_segments: List[SpeakerSegment] = []
|
| 178 |
|
| 179 |
-
for seg in
|
| 180 |
-
|
|
|
|
| 181 |
|
| 182 |
if pad_refine:
|
| 183 |
-
refined = pad_and_refine_tensor(
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
else:
|
| 188 |
start_idx, end_idx = refined
|
| 189 |
-
else:
|
| 190 |
-
start_idx = int(start * sr)
|
| 191 |
-
end_idx = int(end * sr)
|
| 192 |
|
| 193 |
if end_idx <= start_idx:
|
| 194 |
continue
|
|
@@ -197,12 +188,17 @@ class Processor:
|
|
| 197 |
SpeakerSegment(
|
| 198 |
start=start_idx / sr,
|
| 199 |
end=end_idx / sr,
|
| 200 |
-
speaker=seg.speaker
|
| 201 |
)
|
| 202 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
|
| 204 |
# Step 5: Transcribe
|
| 205 |
-
logger.info(f"Step 5: Transcribing {len(refined_segments)} segments...")
|
| 206 |
vad_options = None
|
| 207 |
if vad_filter:
|
| 208 |
vad_options = {
|
|
@@ -213,49 +209,71 @@ class Processor:
|
|
| 213 |
}
|
| 214 |
|
| 215 |
processed_segments: List[TranscriptSegment] = []
|
| 216 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
start_sample = int(seg.start * sr)
|
| 221 |
-
end_sample = int(seg.end * sr)
|
| 222 |
-
if end_sample <= start_sample:
|
| 223 |
continue
|
| 224 |
-
|
| 225 |
try:
|
| 226 |
text = await TranscriptionService.transcribe_segment_async(
|
| 227 |
-
audio_array=
|
| 228 |
model_name=model_name,
|
| 229 |
language=language,
|
| 230 |
vad_options=vad_options,
|
| 231 |
beam_size=beam_size,
|
| 232 |
temperature=temperature,
|
| 233 |
best_of=best_of,
|
| 234 |
-
initial_prompt=initial_prompt
|
| 235 |
)
|
| 236 |
-
if text.strip():
|
| 237 |
-
unique_speakers.add(seg.speaker)
|
| 238 |
-
processed_segments.append(TranscriptSegment(start=seg.start, end=seg.end, speaker=seg.speaker, text=text.strip()))
|
| 239 |
except Exception as e:
|
| 240 |
-
logger.error(f"
|
|
|
|
|
|
|
|
|
|
| 241 |
continue
|
| 242 |
|
| 243 |
-
|
| 244 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
|
| 246 |
-
|
| 247 |
-
txt_content = cls._generate_txt(processed_segments, unique_speakers, processing_time, duration)
|
| 248 |
csv_content = cls._generate_csv(processed_segments)
|
| 249 |
|
| 250 |
return ProcessingResult(
|
| 251 |
segments=processed_segments,
|
| 252 |
-
speaker_count=len(
|
| 253 |
duration=duration,
|
| 254 |
processing_time=processing_time,
|
| 255 |
txt_content=txt_content,
|
| 256 |
-
csv_content=csv_content
|
| 257 |
)
|
| 258 |
-
|
| 259 |
@classmethod
|
| 260 |
def _generate_txt(cls, segments: List[TranscriptSegment], speakers: set, processing_time: float, duration: float) -> str:
|
| 261 |
lines = [
|
|
|
|
| 60 |
|
| 61 |
|
| 62 |
def format_timestamp(seconds: float) -> str:
|
| 63 |
+
m = int(seconds // 60)
|
| 64 |
+
s = seconds % 60
|
| 65 |
+
return f"{m:02d}:{s:06.3f}"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
|
| 68 |
def pad_and_refine_tensor(
|
|
|
|
| 80 |
"""
|
| 81 |
total_len = waveform.shape[1]
|
| 82 |
|
| 83 |
+
start_idx = max(int((start_s - pad_ms / 1000) * sr), 0)
|
| 84 |
+
end_idx = min(int((end_s + pad_ms / 1000) * sr), total_len)
|
|
|
|
|
|
|
|
|
|
| 85 |
|
| 86 |
if end_idx <= start_idx:
|
| 87 |
return None
|
| 88 |
|
| 89 |
+
segment = waveform[:, start_idx:end_idx]
|
| 90 |
+
if segment .numel() == 0:
|
| 91 |
return None
|
| 92 |
|
| 93 |
# RMS energy
|
| 94 |
+
rms = torch.sqrt(torch.mean(segment ** 2) + 1e-9)
|
| 95 |
+
threshold = rms / silence_db_delta
|
|
|
|
| 96 |
|
| 97 |
+
energy = torch.abs(segment)
|
| 98 |
+
valid = torch.where(energy > threshold)[0]
|
| 99 |
|
| 100 |
if valid.numel() == 0:
|
| 101 |
return None
|
| 102 |
|
|
|
|
|
|
|
| 103 |
|
| 104 |
+
new_start = start_idx + valid[0].item()
|
| 105 |
+
new_end = start_idx + valid[-1].item()
|
| 106 |
+
|
| 107 |
+
if new_end - new_start < int(min_duration_ms / 1000 * sr):
|
| 108 |
return None
|
| 109 |
|
| 110 |
+
return new_start, new_end
|
| 111 |
|
| 112 |
|
| 113 |
# =========================
|
|
|
|
| 136 |
|
| 137 |
import asyncio
|
| 138 |
|
| 139 |
+
t0= time.time()
|
| 140 |
|
| 141 |
# Step 1: Convert to WAV
|
| 142 |
logger.info("Step 1: Converting audio to WAV 16kHz...")
|
|
|
|
| 151 |
|
| 152 |
# Step 3: Diarization
|
| 153 |
logger.info("Step 3: Running diarization...")
|
| 154 |
+
|
| 155 |
try:
|
| 156 |
+
diarization_segments = await DiarizationService.diarize_async(wav_path)
|
| 157 |
except Exception as e:
|
| 158 |
logger.error(f"Diarization failed: {e}")
|
| 159 |
+
diarization_segments = []
|
| 160 |
+
|
| 161 |
+
if not diarization_segments:
|
| 162 |
+
diarization_segments = [
|
| 163 |
+
SpeakerSegment(0.0, duration, "Speaker 1")
|
| 164 |
+
]
|
| 165 |
|
| 166 |
# Sort by start time
|
| 167 |
+
diarization_segments.sort(key=lambda x: x.start)
|
| 168 |
|
| 169 |
|
| 170 |
# Step 4: Refine segment boundaries
|
| 171 |
refined_segments: List[SpeakerSegment] = []
|
| 172 |
|
| 173 |
+
for seg in diarization_segments:
|
| 174 |
+
start_idx = int(seg.start * sr)
|
| 175 |
+
end_idx = int(seg.end * sr)
|
| 176 |
|
| 177 |
if pad_refine:
|
| 178 |
+
refined = pad_and_refine_tensor(
|
| 179 |
+
waveform, sr, seg.start, seg.end
|
| 180 |
+
)
|
| 181 |
+
if refined:
|
|
|
|
| 182 |
start_idx, end_idx = refined
|
|
|
|
|
|
|
|
|
|
| 183 |
|
| 184 |
if end_idx <= start_idx:
|
| 185 |
continue
|
|
|
|
| 188 |
SpeakerSegment(
|
| 189 |
start=start_idx / sr,
|
| 190 |
end=end_idx / sr,
|
| 191 |
+
speaker=seg.speaker or "Speaker 1"
|
| 192 |
)
|
| 193 |
)
|
| 194 |
+
if not refined_segments:
|
| 195 |
+
refined_segments = [
|
| 196 |
+
SpeakerSegment(0.0, duration, "Speaker 1")
|
| 197 |
+
]
|
| 198 |
+
|
| 199 |
+
logger.info(f"Refined segments: {len(refined_segments)}")
|
| 200 |
|
| 201 |
# Step 5: Transcribe
|
|
|
|
| 202 |
vad_options = None
|
| 203 |
if vad_filter:
|
| 204 |
vad_options = {
|
|
|
|
| 209 |
}
|
| 210 |
|
| 211 |
processed_segments: List[TranscriptSegment] = []
|
| 212 |
+
speakers = set()
|
| 213 |
+
|
| 214 |
+
for seg in refined_segments:
|
| 215 |
+
start = int(seg.start * sr)
|
| 216 |
+
end = int(seg.end * sr)
|
| 217 |
+
|
| 218 |
+
if end <= start:
|
| 219 |
+
continue
|
| 220 |
|
| 221 |
+
audio_slice = y[start:end]
|
| 222 |
+
if audio_slice.size < sr * 0.25:
|
|
|
|
|
|
|
|
|
|
| 223 |
continue
|
| 224 |
+
|
| 225 |
try:
|
| 226 |
text = await TranscriptionService.transcribe_segment_async(
|
| 227 |
+
audio_array=audio_slice,
|
| 228 |
model_name=model_name,
|
| 229 |
language=language,
|
| 230 |
vad_options=vad_options,
|
| 231 |
beam_size=beam_size,
|
| 232 |
temperature=temperature,
|
| 233 |
best_of=best_of,
|
| 234 |
+
initial_prompt=initial_prompt,
|
| 235 |
)
|
|
|
|
|
|
|
|
|
|
| 236 |
except Exception as e:
|
| 237 |
+
logger.error(f"Transcribe error: {e}")
|
| 238 |
+
continue
|
| 239 |
+
|
| 240 |
+
if not text or not text.strip():
|
| 241 |
continue
|
| 242 |
|
| 243 |
+
processed_segments.append(
|
| 244 |
+
TranscriptSegment(
|
| 245 |
+
start=seg.start,
|
| 246 |
+
end=seg.end,
|
| 247 |
+
speaker=seg.speaker,
|
| 248 |
+
text=text.strip(),
|
| 249 |
+
)
|
| 250 |
+
)
|
| 251 |
+
speakers.add(seg.speaker)
|
| 252 |
+
|
| 253 |
+
if not processed_segments:
|
| 254 |
+
processed_segments = [
|
| 255 |
+
TranscriptSegment(
|
| 256 |
+
start=0.0,
|
| 257 |
+
end=duration,
|
| 258 |
+
speaker="Speaker 1",
|
| 259 |
+
text="(No speech detected)"
|
| 260 |
+
)
|
| 261 |
+
]
|
| 262 |
+
speakers.add("Speaker 1")
|
| 263 |
+
|
| 264 |
+
processing_time = time.time() - t0
|
| 265 |
|
| 266 |
+
txt_content = cls._generate_txt(processed_segments, speakers, processing_time, duration)
|
|
|
|
| 267 |
csv_content = cls._generate_csv(processed_segments)
|
| 268 |
|
| 269 |
return ProcessingResult(
|
| 270 |
segments=processed_segments,
|
| 271 |
+
speaker_count=len(speakers),
|
| 272 |
duration=duration,
|
| 273 |
processing_time=processing_time,
|
| 274 |
txt_content=txt_content,
|
| 275 |
+
csv_content=csv_content,
|
| 276 |
)
|
|
|
|
| 277 |
@classmethod
|
| 278 |
def _generate_txt(cls, segments: List[TranscriptSegment], speakers: set, processing_time: float, duration: float) -> str:
|
| 279 |
lines = [
|