Spaces: Runtime error

liuyang committed
Commit 99ff812 · 1 Parent(s): 77abe68

fast whisper

Browse files:
- app.py +118 -73
- requirements.txt +6 -3
app.py CHANGED

@@ -28,64 +28,38 @@ import subprocess
 import os
 import tempfile
 import spaces
-from transformers import pipeline
+from faster_whisper import WhisperModel
+from faster_whisper.vad import VadOptions
 from pyannote.audio import Pipeline
 import requests
 import base64
 
-# …
-'''
-try:
-    subprocess.run(
-        "pip install flash-attn --no-build-isolation",
-        env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
-        shell=True,
-        check=True
-    )
-except subprocess.CalledProcessError:
-    print("Warning: Could not install flash-attn, falling back to default attention")
-'''
-
-# Create global Whisper pipeline
-pipe = pipeline(
-    "automatic-speech-recognition",
-    model="openai/whisper-large-v3-turbo",
-    torch_dtype=torch.float16,
+# Create global Whisper model
+print("Loading Whisper model...")
+model = WhisperModel(
+    "large-v3-turbo",
     device="cuda",
-    …
-    return_timestamps=True,
+    compute_type="float16",
 )
+print("Whisper model loaded successfully")
 
 # Create global diarization pipeline
 diarization_pipe = None
 try:
     print("Loading diarization model...")
-    torch.backends.cuda.matmul.allow_tf32 = True
-    torch.backends.cudnn.allow_tf32 = True
-    torch.set_float32_matmul_precision('high')
-
     diarization_pipe = Pipeline.from_pretrained(
         "pyannote/speaker-diarization-3.1",
         use_auth_token=os.getenv("HF_TOKEN"),
         torch_dtype=torch.float16,
     ).to(torch.device("cuda"))
-    pipe.model.half()  # FP16
-
-    for m in pipe.model.modules():  # compact LSTM weights
-        if isinstance(m, torch.nn.LSTM):
-            m.flatten_parameters()
-
-    pipe.model = torch.compile(pipe.model, mode="reduce-overhead")
     print("Diarization model loaded successfully")
 except Exception as e:
-    import traceback
-    traceback.print_exc()
     print(f"Could not load diarization model: {e}")
     diarization_pipe = None
 
 class WhisperTranscriber:
     def __init__(self):
-        self.pipe = pipe  # Use global Whisper pipeline
+        self.model = model  # Use global Whisper model
         self.diarization_model = diarization_pipe  # Use global diarization pipeline
 
     def convert_audio_format(self, audio_path):

@@ -137,42 +111,65 @@ class WhisperTranscriber:
 
     @spaces.GPU
     def transcribe_audio_segments(self, audio_segments, language=None, translate=False, prompt=None):
-        """Transcribe multiple audio segments"""
+        """Transcribe multiple audio segments using faster_whisper"""
         print(f"Transcribing {len(audio_segments)} audio segments...")
         start_time = time.time()
 
-        # Prepare …
-        …
+        # Prepare options similar to replicate.py
+        options = dict(
+            language=language,
+            beam_size=5,
+            vad_filter=True,
+            vad_parameters=VadOptions(
+                max_speech_duration_s=self.model.feature_extractor.chunk_length,
+                min_speech_duration_ms=100,
+                speech_pad_ms=100,
+                threshold=0.25,
+                neg_threshold=0.2,
+            ),
+            word_timestamps=True,
+            initial_prompt=prompt,
+            language_detection_segments=1,
+            task="translate" if translate else "transcribe",
+        )
 
         results = []
+        detected_language = None
+
         for i, segment in enumerate(audio_segments):
             print(f"Processing segment {i+1}/{len(audio_segments)}")
 
             # Transcribe this segment
-            …
-            …
-                return_timestamps=True,
-                generate_kwargs=generate_kwargs,
-                chunk_length_s=30,
-                batch_size=128,
-            )
+            segments, transcript_info = self.model.transcribe(segment["audio_path"], **options)
+            segments = list(segments)
 
-            # …
-            …
+            # Get detected language from first segment
+            if detected_language is None:
+                detected_language = transcript_info.language
 
-            # …
-            …
+            # Process each transcribed segment
+            for seg in segments:
+                # Create result entry with detailed format like replicate.py
+                words_list = []
+                if seg.words:
+                    for word in seg.words:
+                        words_list.append({
+                            "start": float(word.start) + segment["start"],
+                            "end": float(word.end) + segment["start"],
+                            "word": word.word,
+                            "probability": word.probability,
+                            "speaker": segment["speaker"]
+                        })
+
+                results.append({
+                    "start": float(seg.start) + segment["start"],
+                    "end": float(seg.end) + segment["start"],
+                    "text": seg.text,
+                    "speaker": segment["speaker"],
+                    "avg_logprob": seg.avg_logprob,
+                    "words": words_list,
+                    "duration": float(seg.end - seg.start)
+                })
 
         # Clean up temporary files
         for segment in audio_segments:

@@ -182,7 +179,7 @@ class WhisperTranscriber:
         transcription_time = time.time() - start_time
         print(f"All segments transcribed in {transcription_time:.2f} seconds")
 
-        return results
+        return results, detected_language
 
     def perform_diarization(self, audio_path, num_speakers=None):
         """Perform speaker diarization"""

@@ -228,6 +225,47 @@ class WhisperTranscriber:
 
         return diarize_segments, detected_num_speakers
 
+    def group_segments_by_speaker(self, segments, max_gap=1.0, max_duration=30.0):
+        """Group consecutive segments from the same speaker"""
+        if not segments:
+            return segments
+
+        grouped_segments = []
+        current_group = segments[0].copy()
+        sentence_end_pattern = r"[.!?]+"
+
+        for segment in segments[1:]:
+            time_gap = segment["start"] - current_group["end"]
+            current_duration = current_group["end"] - current_group["start"]
+
+            # Conditions for combining segments
+            can_combine = (
+                segment["speaker"] == current_group["speaker"] and
+                time_gap <= max_gap and
+                current_duration < max_duration and
+                not re.search(sentence_end_pattern, current_group["text"][-1:])
+            )
+
+            if can_combine:
+                # Merge segments
+                current_group["end"] = segment["end"]
+                current_group["text"] += " " + segment["text"]
+                current_group["words"].extend(segment["words"])
+                current_group["duration"] = current_group["end"] - current_group["start"]
+            else:
+                # Start new group
+                grouped_segments.append(current_group)
+                current_group = segment.copy()
+
+        grouped_segments.append(current_group)
+
+        # Clean up text
+        for segment in grouped_segments:
+            segment["text"] = re.sub(r"\s+", " ", segment["text"]).strip()
+            segment["text"] = re.sub(r"\s+([.,!?])", r"\1", segment["text"])
+
+        return grouped_segments
+
     @spaces.GPU
     def process_audio(self, audio_file, num_speakers=None, language=None,
                       translate=False, prompt=None, group_segments=True):

@@ -252,14 +290,19 @@ class WhisperTranscriber:
             audio_segments = self.cut_audio_segments(converted_audio_path, diarization_segments)
 
             # Step 4: Transcribe each segment
-            transcription_results = self.transcribe_audio_segments(
+            transcription_results, detected_language = self.transcribe_audio_segments(
                 audio_segments, language, translate, prompt
            )
 
-            # Step 5: …
+            # Step 5: Group segments if requested
+            if group_segments:
+                transcription_results = self.group_segments_by_speaker(transcription_results)
+
+            # Step 6: Return in replicate.py format
             return {
-                "…
-                "…
+                "segments": transcription_results,
+                "language": detected_language,
+                "num_speakers": detected_num_speakers
             }
 
         except Exception as e:

@@ -280,19 +323,21 @@ def format_segments_for_display(result):
     if "error" in result:
         return f"❌ Error: {result['error']}"
 
-    …
+    segments = result.get("segments", [])
+    language = result.get("language", "unknown")
+    num_speakers = result.get("num_speakers", 1)
 
     output = f"🎯 **Detection Results:**\n"
-    output += f"- …
-    output += f"- …
+    output += f"- Language: {language}\n"
+    output += f"- Speakers: {num_speakers}\n"
+    output += f"- Segments: {len(segments)}\n\n"
 
     output += "📝 **Transcription:**\n\n"
 
-    for i, segment in enumerate(…
-        start_time = str(datetime.timedelta(seconds=int(segment["…
-        end_time = str(datetime.timedelta(seconds=int(segment["…
-        speaker = segment.get("…
+    for i, segment in enumerate(segments, 1):
+        start_time = str(datetime.timedelta(seconds=int(segment["start"])))
+        end_time = str(datetime.timedelta(seconds=int(segment["end"])))
+        speaker = segment.get("speaker", "SPEAKER_00")
         text = segment["text"]
 
         output += f"**{speaker}** ({start_time} → {end_time})\n"
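For reference, here is a minimal standalone sketch of the faster-whisper call path this commit switches to. It is not the app code above: it assumes faster-whisper >= 1.1 (for the "large-v3-turbo" alias and the neg_threshold VAD option), a CUDA GPU, and a placeholder file "audio.wav" standing in for one diarized segment.

# Minimal sketch of the new transcription path (assumptions noted above).
from faster_whisper import WhisperModel
from faster_whisper.vad import VadOptions

model = WhisperModel("large-v3-turbo", device="cuda", compute_type="float16")

# transcribe() returns a lazy generator of segments plus a TranscriptionInfo.
segments, info = model.transcribe(
    "audio.wav",  # placeholder path, not from the commit
    beam_size=5,
    vad_filter=True,
    vad_parameters=VadOptions(threshold=0.25, speech_pad_ms=100),
    word_timestamps=True,
)

print(f"Detected language: {info.language}")
for seg in segments:  # iterating consumes the generator
    print(f"[{seg.start:.2f}s -> {seg.end:.2f}s] {seg.text}")

The lazy generator is why the diff calls list(segments) inside transcribe_audio_segments: decoding happens during iteration, not at the transcribe() call, so materializing the list keeps the per-segment timing prints honest.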
requirements.txt CHANGED

@@ -1,11 +1,14 @@
 # 1. Do NOT pin torch/torchaudio here - keep the CUDA builds that come with the image
 torch==2.4.0
 transformers==4.48.0
-# …
-https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.0.8/flash_attn-2.7.4.post1+cu126torch2.4-cp310-cp310-linux_x86_64.whl
+# Removed flash-attention since faster-whisper handles this internally
+# https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.0.8/flash_attn-2.7.4.post1+cu126torch2.4-cp310-cp310-linux_x86_64.whl
 pydantic==2.10.6
 
-# 2. Extra libs your app really needs
+# 2. Main whisper model
+faster-whisper>=1.0.0
+
+# 3. Extra libs your app really needs
 gradio==5.0.1
 spaces>=0.19.0
 pyannote.audio>=3.1.0
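Before redeploying the Space, a quick smoke test of the updated dependency set can catch resolver or import problems locally. This script is hypothetical (not part of the commit); it loads the small "tiny" model on CPU with int8 so no GPU is needed.

# Hypothetical smoke test for the updated requirements.txt.
# Verifies that faster-whisper and pyannote.audio import cleanly and
# that a CTranslate2 Whisper model can be instantiated without CUDA.
from faster_whisper import WhisperModel
from pyannote.audio import Pipeline  # noqa: F401 (import check only)

model = WhisperModel("tiny", device="cpu", compute_type="int8")
print("faster-whisper model loaded OK")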