#!/usr/bin/env python3
"""
Unified processing module for STT + diarization.
Used by the Gradio Space.
"""
import os
import sys
from pathlib import Path
from typing import List, Dict, Any
import json
# pyannote import
try:
from pyannote.audio import Pipeline
HAS_PYANNOTE = True
except ImportError:
HAS_PYANNOTE = False
# Whisper and Transformers imports
try:
import whisper
import torch
HAS_WHISPER = True
except ImportError:
HAS_WHISPER = False
try:
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
HAS_TRANSFORMERS = True
except ImportError:
HAS_TRANSFORMERS = False
# Work around the PyTorch 2.6 change to torch.load's weights_only default.
# Guard on HAS_WHISPER so this does not raise a NameError when torch failed to import.
if HAS_WHISPER and hasattr(torch.serialization, 'add_safe_globals'):
    try:
        torch.serialization.add_safe_globals([torch.torch_version.TorchVersion])
    except Exception:
        pass
import numpy as np
import librosa
import soundfile as sf
def convert_audio_if_needed(audio_path: str) -> str:
"""
Convertit l'audio en WAV si nécessaire.
Returns:
Chemin vers le fichier audio (WAV si conversion nécessaire)
"""
ext = Path(audio_path).suffix.lower()
supported_formats = {'.wav', '.flac', '.ogg'}
if ext in supported_formats:
return audio_path
if ext in {'.m4a', '.mp3', '.mp4', '.aac'}:
wav_path = str(Path(audio_path).with_suffix('.wav'))
if os.path.exists(wav_path):
return wav_path
try:
y, sr = librosa.load(audio_path, sr=16000, mono=True)
sf.write(wav_path, y, sr)
return wav_path
except Exception as e:
return audio_path
return audio_path
def run_diarization(audio_path: str, hf_token: str, model_name: str = "pyannote/speaker-diarization-community-1") -> List[Dict[str, Any]]:
"""Exécute la diarisation avec pyannote."""
if not HAS_PYANNOTE:
raise ImportError("pyannote.audio n'est pas installé")
    # Convert the audio to WAV if needed
audio_path_converted = convert_audio_if_needed(audio_path)
    # Configure the Hugging Face token
if hf_token:
try:
from huggingface_hub import login
login(token=hf_token, add_to_git_credential=False)
except Exception:
pass
try:
pipeline = Pipeline.from_pretrained(model_name, token=hf_token)
except Exception as e:
if "plda" in str(e).lower() or "unexpected keyword" in str(e).lower():
pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", token=hf_token)
else:
raise
if torch.cuda.is_available():
pipeline = pipeline.to(torch.device("cuda"))
diarization = pipeline(audio_path_converted)
    # Convert the annotation into a list of speaker segments
segments = []
speakers = sorted(diarization.labels())
speaker_mapping = {speaker: f"SPEAKER_{idx:02d}" for idx, speaker in enumerate(speakers)}
for segment, track, speaker in diarization.itertracks(yield_label=True):
normalized_speaker = speaker_mapping.get(speaker, speaker)
segments.append({
"speaker": normalized_speaker,
"start": segment.start,
"end": segment.end
})
segments.sort(key=lambda x: x["start"])
return segments
def run_transcription(audio_path: str, device: str = None, hf_token: str = None) -> List[Dict[str, Any]]:
"""Exécute la transcription avec le modèle Whisper Large V3 French."""
if not HAS_WHISPER:
raise ImportError("whisper n'est pas installé")
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "bofenghuang/whisper-large-v3-french"
    # Load the model with Transformers (native Whisper is used as a fallback below)
try:
if HAS_TRANSFORMERS:
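            # Note: if Transformers is unavailable, `segments` is never defined, so the
            # final `return segments` raises NameError and the except branch below
            # falls back to native Whisper.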
processor = AutoProcessor.from_pretrained(model_id, token=hf_token)
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id,
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
low_cpu_mem_usage=True,
token=hf_token
)
model.to(device)
model.eval()
            # Load the audio (converted to WAV if needed)
audio_path_converted = convert_audio_if_needed(audio_path)
waveform, sample_rate = librosa.load(audio_path_converted, sr=16000, mono=True)
            # Prepare the model inputs
inputs = processor(
waveform,
sampling_rate=sample_rate,
return_tensors="pt"
)
            # Move the features to the model's device and dtype (float16 on GPU)
            # to avoid a dtype mismatch during generation
            input_features = inputs["input_features"].to(device=device, dtype=model.dtype)
            # Generate the transcription with timestamp tokens
            with torch.no_grad():
                generated_ids = model.generate(
                    input_features,
                    language="fr",
                    task="transcribe",
                    return_timestamps=True
                )
            # Timestamps are recovered by parsing the generated tokens directly:
            # Whisper interleaves timestamp tokens of the form <|X.XX|> with the text tokens
tokens = generated_ids[0].cpu().numpy()
segments = []
current_segment = {"start": None, "end": None, "text": []}
            # Walk the tokens and cut a new segment at each timestamp token
for token_id in tokens:
token_text = processor.tokenizer.decode([token_id], skip_special_tokens=False)
                # Timestamp (and special) tokens look like <|...|>
if "<|" in token_text and "|>" in token_text:
try:
start_idx = token_text.find("<|") + 2
end_idx = token_text.find("|>")
if start_idx < end_idx:
timestamp_str = token_text[start_idx:end_idx]
timestamp = float(timestamp_str)
if current_segment["start"] is None:
current_segment["start"] = timestamp
else:
current_segment["end"] = timestamp
text = " ".join(current_segment["text"]).strip()
if text:
segments.append({
"start": current_segment["start"],
"end": current_segment["end"],
"text": text
})
current_segment = {"start": timestamp, "end": None, "text": []}
except (ValueError, IndexError):
pass
else:
if token_text.strip() and not any(x in token_text for x in ["<|", "|>", "<|startof", "<|endof", "<|notimestamps"]):
current_segment["text"].append(token_text)
            # Append the last open segment
if current_segment["text"]:
text = " ".join(current_segment["text"]).strip()
if text:
duration = len(waveform) / sample_rate
segments.append({
"start": current_segment["start"] if current_segment["start"] is not None else 0.0,
"end": current_segment["end"] if current_segment["end"] is not None else duration,
"text": text
})
            # Fallback when no timestamps could be extracted from the tokens
if not segments or all(seg.get("start") is None for seg in segments):
                # Decode the full text
result_text = processor.decode(generated_ids[0], skip_special_tokens=True)
                # Split into sentences
sentences = []
for sent in result_text.split('. '):
if sent.strip():
sentences.append(sent.strip() + ('.' if not sent.strip().endswith('.') else ''))
if not sentences:
sentences = [result_text.strip()]
                # Spread the sentences evenly over the audio duration
duration = len(waveform) / sample_rate
segments = []
time_per_sentence = duration / len(sentences)
for i, sentence in enumerate(sentences):
start_time = i * time_per_sentence
end_time = min((i + 1) * time_per_sentence, duration)
segments.append({
"start": start_time,
"end": end_time,
"text": sentence
})
return segments
except Exception as e:
        # Fall back to native Whisper
model = whisper.load_model("large-v3", device=device)
audio_path_converted = convert_audio_if_needed(audio_path)
result = model.transcribe(
audio_path_converted,
language="fr",
task="transcribe",
verbose=False
)
segments = []
for seg in result["segments"]:
segments.append({
"start": seg["start"],
"end": seg["end"],
"text": seg["text"].strip()
})
return segments
def combine_diarization_transcription(
diarization_segments: List[Dict[str, Any]],
transcription_segments: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""Combine diarisation et transcription."""
combined = []
    # Build a diarization timeline
diar_timeline = [
(seg["start"], seg["end"], seg["speaker"])
for seg in diarization_segments
]
diar_timeline.sort()
def get_speaker_for_segment(seg_start: float, seg_end: float) -> str:
"""Détermine le locuteur pour un segment."""
speaker_time = {}
for diar_start, diar_end, speaker in diar_timeline:
overlap_start = max(seg_start, diar_start)
overlap_end = min(seg_end, diar_end)
overlap_duration = max(0, overlap_end - overlap_start)
if overlap_duration > 0:
speaker_time[speaker] = speaker_time.get(speaker, 0) + overlap_duration
if speaker_time:
return max(speaker_time, key=speaker_time.get)
else:
            # No overlap: fall back to the speaker closest in time
center_time = (seg_start + seg_end) / 2.0
min_dist = float('inf')
closest_speaker = "SPEAKER_00"
for diar_start, diar_end, speaker in diar_timeline:
if center_time < diar_start:
dist = diar_start - center_time
elif center_time >= diar_end:
dist = center_time - diar_end
else:
return speaker
if dist < min_dist:
min_dist = dist
closest_speaker = speaker
return closest_speaker
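    # Illustrative example (hypothetical values): with diar_timeline
    # [(0.0, 5.0, "SPEAKER_00"), (5.0, 9.0, "SPEAKER_01")], a transcription segment
    # spanning 4.0-7.0 s overlaps SPEAKER_00 for 1.0 s and SPEAKER_01 for 2.0 s,
    # so get_speaker_for_segment returns "SPEAKER_01".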
    # Assign a speaker to each transcription segment
for trans_seg in transcription_segments:
speaker = get_speaker_for_segment(trans_seg["start"], trans_seg["end"])
combined.append({
"speaker": speaker,
"start": trans_seg["start"],
"end": trans_seg["end"],
"text": trans_seg["text"]
})
return combined
def format_output(combined_segments: List[Dict[str, Any]]) -> str:
"""Formate la sortie en texte lisible: "Speaker A : blabla"."""
output_lines = []
current_speaker = None
current_texts = []
for seg in combined_segments:
speaker = seg["speaker"]
text = seg["text"]
if speaker != current_speaker:
            # Flush the previous speaker's block
if current_speaker and current_texts:
speaker_num = int(current_speaker.replace("SPEAKER_", ""))
speaker_name = f"Speaker {chr(65 + speaker_num)}"
output_lines.append(f"{speaker_name} : {' '.join(current_texts)}")
            # Start a new speaker block
current_speaker = speaker
current_texts = [text]
else:
            # Same speaker: accumulate the text
current_texts.append(text)
    # Flush the last block
if current_speaker and current_texts:
speaker_num = int(current_speaker.replace("SPEAKER_", ""))
speaker_name = f"Speaker {chr(65 + speaker_num)}"
output_lines.append(f"{speaker_name} : {' '.join(current_texts)}")
return "\n\n".join(output_lines)