#!/usr/bin/env python3
"""
Unified processing module for STT + diarization.
Used by the Gradio Space.
"""
import os
import sys
from pathlib import Path
from typing import List, Dict, Any
import json

# Imports for pyannote
try:
    from pyannote.audio import Pipeline
    HAS_PYANNOTE = True
except ImportError:
    HAS_PYANNOTE = False

# Imports for Whisper and Transformers
try:
    import whisper
    import torch
    HAS_WHISPER = True
except ImportError:
    HAS_WHISPER = False

try:
    from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
    HAS_TRANSFORMERS = True
except ImportError:
    HAS_TRANSFORMERS = False

# Work around the PyTorch 2.6 weights_only default by allow-listing TorchVersion
# (only if torch was actually imported above)
if HAS_WHISPER and hasattr(torch.serialization, 'add_safe_globals'):
    try:
        torch.serialization.add_safe_globals([torch.torch_version.TorchVersion])
    except Exception:
        pass

import numpy as np
import librosa
import soundfile as sf
def convert_audio_if_needed(audio_path: str) -> str:
    """
    Convert the audio to WAV if needed.

    Returns:
        Path to the audio file (WAV if a conversion was required).
    """
    ext = Path(audio_path).suffix.lower()
    supported_formats = {'.wav', '.flac', '.ogg'}
    if ext in supported_formats:
        return audio_path

    if ext in {'.m4a', '.mp3', '.mp4', '.aac'}:
        wav_path = str(Path(audio_path).with_suffix('.wav'))
        if os.path.exists(wav_path):
            return wav_path
        try:
            # Resample to 16 kHz mono, which both pyannote and Whisper expect
            y, sr = librosa.load(audio_path, sr=16000, mono=True)
            sf.write(wav_path, y, sr)
            return wav_path
        except Exception:
            # Conversion failed: fall back to the original file
            return audio_path

    return audio_path
def run_diarization(audio_path: str, hf_token: str, model_name: str = "pyannote/speaker-diarization-community-1") -> List[Dict[str, Any]]:
    """Run speaker diarization with pyannote."""
    if not HAS_PYANNOTE:
        raise ImportError("pyannote.audio is not installed")

    # Convert the audio to WAV if needed
    audio_path_converted = convert_audio_if_needed(audio_path)

    # Configure the Hugging Face token
    if hf_token:
        try:
            from huggingface_hub import login
            login(token=hf_token, add_to_git_credential=False)
        except Exception:
            pass

    try:
        pipeline = Pipeline.from_pretrained(model_name, token=hf_token)
    except Exception as e:
        # Some pyannote versions fail to load the community model; fall back to 3.1
        if "plda" in str(e).lower() or "unexpected keyword" in str(e).lower():
            pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", token=hf_token)
        else:
            raise

    if torch.cuda.is_available():
        pipeline = pipeline.to(torch.device("cuda"))

    diarization = pipeline(audio_path_converted)

    # Convert the annotation into a list of segments with normalized speaker labels
    segments = []
    speakers = sorted(diarization.labels())
    speaker_mapping = {speaker: f"SPEAKER_{idx:02d}" for idx, speaker in enumerate(speakers)}
    for segment, track, speaker in diarization.itertracks(yield_label=True):
        normalized_speaker = speaker_mapping.get(speaker, speaker)
        segments.append({
            "speaker": normalized_speaker,
            "start": segment.start,
            "end": segment.end
        })

    segments.sort(key=lambda x: x["start"])
    return segments
def run_transcription(audio_path: str, device: str = None, hf_token: str = None) -> List[Dict[str, Any]]:
    """Run transcription with the Whisper Large V3 French model."""
    if not HAS_WHISPER:
        raise ImportError("whisper is not installed")

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    model_id = "bofenghuang/whisper-large-v3-french"

    # Use Transformers to load the fine-tuned model; any failure falls back to native Whisper
    try:
        if not HAS_TRANSFORMERS:
            # Raise so the native Whisper fallback below takes over
            raise ImportError("transformers is not installed")

        processor = AutoProcessor.from_pretrained(model_id, token=hf_token)
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            low_cpu_mem_usage=True,
            token=hf_token
        )
        model.to(device)
        model.eval()
        # Load the audio
        audio_path_converted = convert_audio_if_needed(audio_path)
        waveform, sample_rate = librosa.load(audio_path_converted, sr=16000, mono=True)

        # Prepare the inputs
        inputs = processor(
            waveform,
            sampling_rate=sample_rate,
            return_tensors="pt"
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        if device == "cuda":
            # Match the model's float16 dtype to avoid a half/float mismatch
            inputs["input_features"] = inputs["input_features"].half()

        # Transcription
        with torch.no_grad():
            generated_ids = model.generate(
                inputs["input_features"],
                language="fr",
                task="transcribe",
                return_timestamps=True
            )

        # Full decoded string (timestamp tokens included), kept for inspection;
        # the segments below are reconstructed directly from the token ids
        result = processor.batch_decode(
            generated_ids,
            skip_special_tokens=False
        )[0]
        # Extract timestamped segments from the generated tokens
        tokens = generated_ids[0].cpu().numpy()
        segments = []
        current_segment = {"start": None, "end": None, "text": []}

        # Parse the tokens to pick out the <|X.XX|> timestamp markers
        for token_id in tokens:
            token_text = processor.tokenizer.decode([token_id], skip_special_tokens=False)
            # Look for timestamp tokens of the form <|X.XX|>
            if "<|" in token_text and "|>" in token_text:
                try:
                    start_idx = token_text.find("<|") + 2
                    end_idx = token_text.find("|>")
                    if start_idx < end_idx:
                        timestamp_str = token_text[start_idx:end_idx]
                        timestamp = float(timestamp_str)
                        if current_segment["start"] is None:
                            current_segment["start"] = timestamp
                        else:
                            # Closing timestamp: flush the segment and start a new one
                            current_segment["end"] = timestamp
                            text = " ".join(current_segment["text"]).strip()
                            if text:
                                segments.append({
                                    "start": current_segment["start"],
                                    "end": current_segment["end"],
                                    "text": text
                                })
                            current_segment = {"start": timestamp, "end": None, "text": []}
                except (ValueError, IndexError):
                    # Not a numeric timestamp (e.g. <|fr|>, <|transcribe|>): ignore it
                    pass
            else:
                if token_text.strip() and not any(x in token_text for x in ["<|", "|>", "<|startof", "<|endof", "<|notimestamps"]):
                    current_segment["text"].append(token_text)

        # Append the last open segment
        if current_segment["text"]:
            text = " ".join(current_segment["text"]).strip()
            if text:
                duration = len(waveform) / sample_rate
                segments.append({
                    "start": current_segment["start"] if current_segment["start"] is not None else 0.0,
                    "end": current_segment["end"] if current_segment["end"] is not None else duration,
                    "text": text
                })
        # If no timestamps could be extracted, fall back to splitting by sentence
        if not segments or all(seg.get("start") is None for seg in segments):
            # Decode the full text
            result_text = processor.decode(generated_ids[0], skip_special_tokens=True)

            # Split into sentences
            sentences = []
            for sent in result_text.split('. '):
                if sent.strip():
                    sentences.append(sent.strip() + ('.' if not sent.strip().endswith('.') else ''))
            if not sentences:
                sentences = [result_text.strip()]

            # Create time segments spread evenly over the audio duration
            duration = len(waveform) / sample_rate
            segments = []
            time_per_sentence = duration / len(sentences)
            for i, sentence in enumerate(sentences):
                start_time = i * time_per_sentence
                end_time = min((i + 1) * time_per_sentence, duration)
                segments.append({
                    "start": start_time,
                    "end": end_time,
                    "text": sentence
                })

        return segments
    except Exception:
        # Fall back to native Whisper
        model = whisper.load_model("large-v3", device=device)
        audio_path_converted = convert_audio_if_needed(audio_path)
        result = model.transcribe(
            audio_path_converted,
            language="fr",
            task="transcribe",
            verbose=False
        )
        segments = []
        for seg in result["segments"]:
            segments.append({
                "start": seg["start"],
                "end": seg["end"],
                "text": seg["text"].strip()
            })
        return segments
def combine_diarization_transcription(
    diarization_segments: List[Dict[str, Any]],
    transcription_segments: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Combine diarization and transcription."""
    combined = []

    # Build a diarization timeline
    diar_timeline = [
        (seg["start"], seg["end"], seg["speaker"])
        for seg in diarization_segments
    ]
    diar_timeline.sort()

    def get_speaker_for_segment(seg_start: float, seg_end: float) -> str:
        """Pick the speaker for a segment (largest temporal overlap wins)."""
        speaker_time = {}
        for diar_start, diar_end, speaker in diar_timeline:
            overlap_start = max(seg_start, diar_start)
            overlap_end = min(seg_end, diar_end)
            overlap_duration = max(0, overlap_end - overlap_start)
            if overlap_duration > 0:
                speaker_time[speaker] = speaker_time.get(speaker, 0) + overlap_duration

        if speaker_time:
            return max(speaker_time, key=speaker_time.get)
        else:
            # No overlap at all: pick the speaker closest in time
            center_time = (seg_start + seg_end) / 2.0
            min_dist = float('inf')
            closest_speaker = "SPEAKER_00"
            for diar_start, diar_end, speaker in diar_timeline:
                if center_time < diar_start:
                    dist = diar_start - center_time
                elif center_time >= diar_end:
                    dist = center_time - diar_end
                else:
                    return speaker
                if dist < min_dist:
                    min_dist = dist
                    closest_speaker = speaker
            return closest_speaker

    # Combine the segments
    for trans_seg in transcription_segments:
        speaker = get_speaker_for_segment(trans_seg["start"], trans_seg["end"])
        combined.append({
            "speaker": speaker,
            "start": trans_seg["start"],
            "end": trans_seg["end"],
            "text": trans_seg["text"]
        })

    return combined
def format_output(combined_segments: List[Dict[str, Any]]) -> str:
    """Format the output as readable text: "Speaker A : blabla"."""
    output_lines = []
    current_speaker = None
    current_texts = []

    for seg in combined_segments:
        speaker = seg["speaker"]
        text = seg["text"]

        if speaker != current_speaker:
            # Flush the previous speaker's group
            if current_speaker and current_texts:
                speaker_num = int(current_speaker.replace("SPEAKER_", ""))
                speaker_name = f"Speaker {chr(65 + speaker_num)}"
                output_lines.append(f"{speaker_name} : {' '.join(current_texts)}")
            # New speaker
            current_speaker = speaker
            current_texts = [text]
        else:
            # Same speaker: append the text
            current_texts.append(text)

    # Flush the last group
    if current_speaker and current_texts:
        speaker_num = int(current_speaker.replace("SPEAKER_", ""))
        speaker_name = f"Speaker {chr(65 + speaker_num)}"
        output_lines.append(f"{speaker_name} : {' '.join(current_texts)}")

    return "\n\n".join(output_lines)