#!/usr/bin/env python3
"""
Unified processing module for STT + diarization.
Used by the Gradio Space.
"""
import os
from pathlib import Path
from typing import Any, Dict, List, Optional

# pyannote imports
try:
    from pyannote.audio import Pipeline
    HAS_PYANNOTE = True
except ImportError:
    HAS_PYANNOTE = False

# torch is used by both diarization and transcription, so import it
# independently of whisper.
try:
    import torch
    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False

# Whisper and Transformers imports
try:
    import whisper
    HAS_WHISPER = True
except ImportError:
    HAS_WHISPER = False

try:
    from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
    HAS_TRANSFORMERS = True
except ImportError:
    HAS_TRANSFORMERS = False

# Work around the PyTorch 2.6 weights_only default by whitelisting TorchVersion
if HAS_TORCH and hasattr(torch.serialization, 'add_safe_globals'):
    try:
        torch.serialization.add_safe_globals([torch.torch_version.TorchVersion])
    except Exception:
        pass

import librosa
import soundfile as sf


def convert_audio_if_needed(audio_path: str) -> str:
    """
    Convert the audio to WAV if needed.

    Returns:
        Path to the audio file (WAV if a conversion was required).
    """
    ext = Path(audio_path).suffix.lower()
    supported_formats = {'.wav', '.flac', '.ogg'}

    if ext in supported_formats:
        return audio_path

    if ext in {'.m4a', '.mp3', '.mp4', '.aac'}:
        wav_path = str(Path(audio_path).with_suffix('.wav'))
        if os.path.exists(wav_path):
            return wav_path
        try:
            # Decode to 16 kHz mono and write a sibling WAV file
            y, sr = librosa.load(audio_path, sr=16000, mono=True)
            sf.write(wav_path, y, sr)
            return wav_path
        except Exception:
            # Conversion failed; hand the original file back to the caller
            return audio_path

    return audio_path


def run_diarization(audio_path: str,
                    hf_token: str,
                    model_name: str = "pyannote/speaker-diarization-community-1") -> List[Dict[str, Any]]:
    """Run speaker diarization with pyannote."""
    if not HAS_PYANNOTE:
        raise ImportError("pyannote.audio is not installed")

    # Convert the audio to WAV if needed
    audio_path_converted = convert_audio_if_needed(audio_path)

    # Register the token with huggingface_hub
    if hf_token:
        try:
            from huggingface_hub import login
            login(token=hf_token, add_to_git_credential=False)
        except Exception:
            pass

    try:
        pipeline = Pipeline.from_pretrained(model_name, token=hf_token)
    except Exception as e:
        # Older pyannote versions cannot load the community pipeline;
        # fall back to speaker-diarization-3.1
        if "plda" in str(e).lower() or "unexpected keyword" in str(e).lower():
            pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", token=hf_token)
        else:
            raise

    if torch.cuda.is_available():
        pipeline = pipeline.to(torch.device("cuda"))

    diarization = pipeline(audio_path_converted)

    # Convert the annotation into plain segments with normalized speaker labels
    segments = []
    speakers = sorted(diarization.labels())
    speaker_mapping = {speaker: f"SPEAKER_{idx:02d}" for idx, speaker in enumerate(speakers)}

    for segment, track, speaker in diarization.itertracks(yield_label=True):
        normalized_speaker = speaker_mapping.get(speaker, speaker)
        segments.append({
            "speaker": normalized_speaker,
            "start": segment.start,
            "end": segment.end
        })

    segments.sort(key=lambda x: x["start"])
    return segments
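
# Illustrative usage (a sketch, not executed on import; "meeting.wav" and the
# HF_TOKEN environment variable are assumptions for the example):
#
#     segments = run_diarization("meeting.wav", hf_token=os.environ["HF_TOKEN"])
#     # -> [{"speaker": "SPEAKER_00", "start": 0.53, "end": 3.21}, ...]
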
def run_transcription(audio_path: str,
                      device: Optional[str] = None,
                      hf_token: Optional[str] = None) -> List[Dict[str, Any]]:
    """Run transcription with the Whisper Large V3 French model."""
    if not HAS_WHISPER:
        raise ImportError("whisper is not installed")

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    model_id = "bofenghuang/whisper-large-v3-french"

    # Prefer Transformers for the fine-tuned model; fall back to native Whisper
    try:
        if not HAS_TRANSFORMERS:
            raise ImportError("transformers is not installed")

        processor = AutoProcessor.from_pretrained(model_id, token=hf_token)
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            low_cpu_mem_usage=True,
            token=hf_token
        )
        model.to(device)
        model.eval()

        # Load the audio
        audio_path_converted = convert_audio_if_needed(audio_path)
        waveform, sample_rate = librosa.load(audio_path_converted, sr=16000, mono=True)

        # Prepare the inputs
        inputs = processor(
            waveform,
            sampling_rate=sample_rate,
            return_tensors="pt"
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Transcribe with timestamp tokens enabled
        with torch.no_grad():
            generated_ids = model.generate(
                inputs["input_features"],
                language="fr",
                task="transcribe",
                return_timestamps=True
            )

        # Extract timestamped segments by parsing the generated tokens directly
        tokens = generated_ids[0].cpu().numpy()
        segments = []
        current_segment = {"start": None, "end": None, "text": []}

        for token_id in tokens:
            token_text = processor.tokenizer.decode([token_id], skip_special_tokens=False)

            # Timestamp tokens look like <|X.XX|>
            if "<|" in token_text and "|>" in token_text:
                try:
                    start_idx = token_text.find("<|") + 2
                    end_idx = token_text.find("|>")
                    if start_idx < end_idx:
                        timestamp = float(token_text[start_idx:end_idx])
                        if current_segment["start"] is None:
                            current_segment["start"] = timestamp
                        else:
                            # A second timestamp closes the current segment
                            current_segment["end"] = timestamp
                            text = " ".join(current_segment["text"]).strip()
                            if text:
                                segments.append({
                                    "start": current_segment["start"],
                                    "end": current_segment["end"],
                                    "text": text
                                })
                            current_segment = {"start": timestamp, "end": None, "text": []}
                except (ValueError, IndexError):
                    # Non-numeric special tokens (<|fr|>, <|transcribe|>, ...) land here
                    pass
            elif token_text.strip() and "<|" not in token_text and "|>" not in token_text:
                current_segment["text"].append(token_text)

        # Append the trailing segment, if any
        if current_segment["text"]:
            text = " ".join(current_segment["text"]).strip()
            if text:
                duration = len(waveform) / sample_rate
                segments.append({
                    "start": current_segment["start"] if current_segment["start"] is not None else 0.0,
                    "end": current_segment["end"] if current_segment["end"] is not None else duration,
                    "text": text
                })

        # If no timestamps could be extracted, spread the sentences of the full
        # transcript evenly across the audio duration as a fallback
        if not segments or all(seg.get("start") is None for seg in segments):
            result_text = processor.decode(generated_ids[0], skip_special_tokens=True)

            # Split into sentences
            sentences = []
            for sent in result_text.split('. '):
                if sent.strip():
                    sentences.append(sent.strip() + ('.' if not sent.strip().endswith('.') else ''))
            if not sentences:
                sentences = [result_text.strip()]

            # Build time-based segments from the total duration
            duration = len(waveform) / sample_rate
            time_per_sentence = duration / len(sentences)

            segments = []
            for i, sentence in enumerate(sentences):
                segments.append({
                    "start": i * time_per_sentence,
                    "end": min((i + 1) * time_per_sentence, duration),
                    "text": sentence
                })

        return segments

    except Exception:
        # Fall back to native Whisper
        model = whisper.load_model("large-v3", device=device)
        audio_path_converted = convert_audio_if_needed(audio_path)
        result = model.transcribe(
            audio_path_converted,
            language="fr",
            task="transcribe",
            verbose=False
        )
        segments = []
        for seg in result["segments"]:
            segments.append({
                "start": seg["start"],
                "end": seg["end"],
                "text": seg["text"].strip()
            })
        return segments
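
# Note on the parsing above: Whisper interleaves timestamp tokens with text
# tokens, so a decoded stream looks roughly like this (illustrative, not real
# model output):
#
#     <|0.00|> Bonjour à tous <|2.40|><|2.40|> On peut commencer. <|4.10|>
#
# Each pair of timestamps delimits one segment.
#
# Illustrative usage (file name and token are assumptions for the example):
#
#     segments = run_transcription("meeting.wav", hf_token=os.environ.get("HF_TOKEN"))
#     # -> [{"start": 0.0, "end": 2.4, "text": "Bonjour à tous"}, ...]
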
def combine_diarization_transcription(
    diarization_segments: List[Dict[str, Any]],
    transcription_segments: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Combine diarization and transcription."""
    combined = []

    # Build a sorted diarization timeline
    diar_timeline = [
        (seg["start"], seg["end"], seg["speaker"])
        for seg in diarization_segments
    ]
    diar_timeline.sort()

    def get_speaker_for_segment(seg_start: float, seg_end: float) -> str:
        """Pick the speaker for a segment: the largest temporal overlap wins."""
        speaker_time = {}
        for diar_start, diar_end, speaker in diar_timeline:
            overlap_start = max(seg_start, diar_start)
            overlap_end = min(seg_end, diar_end)
            overlap_duration = max(0, overlap_end - overlap_start)
            if overlap_duration > 0:
                speaker_time[speaker] = speaker_time.get(speaker, 0) + overlap_duration

        if speaker_time:
            return max(speaker_time, key=speaker_time.get)

        # No overlap at all: fall back to the speaker closest in time
        center_time = (seg_start + seg_end) / 2.0
        min_dist = float('inf')
        closest_speaker = "SPEAKER_00"
        for diar_start, diar_end, speaker in diar_timeline:
            if center_time < diar_start:
                dist = diar_start - center_time
            elif center_time >= diar_end:
                dist = center_time - diar_end
            else:
                return speaker
            if dist < min_dist:
                min_dist = dist
                closest_speaker = speaker
        return closest_speaker

    # Assign a speaker to every transcription segment
    for trans_seg in transcription_segments:
        speaker = get_speaker_for_segment(trans_seg["start"], trans_seg["end"])
        combined.append({
            "speaker": speaker,
            "start": trans_seg["start"],
            "end": trans_seg["end"],
            "text": trans_seg["text"]
        })

    return combined


def format_output(combined_segments: List[Dict[str, Any]]) -> str:
    """Format the output as readable text: "Speaker A : blabla"."""
    output_lines = []
    current_speaker = None
    current_texts = []

    def flush():
        # Emit one line for the current speaker's accumulated texts, mapping
        # SPEAKER_00 -> "Speaker A", SPEAKER_01 -> "Speaker B", and so on
        if current_speaker and current_texts:
            speaker_num = int(current_speaker.replace("SPEAKER_", ""))
            speaker_name = f"Speaker {chr(65 + speaker_num)}"
            output_lines.append(f"{speaker_name} : {' '.join(current_texts)}")

    for seg in combined_segments:
        speaker = seg["speaker"]
        text = seg["text"]

        if speaker != current_speaker:
            # Speaker change: write out the previous group, start a new one
            flush()
            current_speaker = speaker
            current_texts = [text]
        else:
            # Same speaker: keep accumulating text
            current_texts.append(text)

    # Write out the final group
    flush()
    return "\n\n".join(output_lines)
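

# Minimal end-to-end sketch (an assumption-laden example, not the Space entry
# point): expects an audio file path on the command line and an HF_TOKEN
# environment variable granting access to the gated models.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="STT + diarization demo")
    parser.add_argument("audio", help="Path to an audio file")
    args = parser.parse_args()

    token = os.environ.get("HF_TOKEN")
    diar_segments = run_diarization(args.audio, hf_token=token)
    trans_segments = run_transcription(args.audio, hf_token=token)
    merged = combine_diarization_transcription(diar_segments, trans_segments)
    print(format_output(merged))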