#!/usr/bin/env python3
"""
Unified processing module for STT + diarization.
Used by the Gradio Space.
"""
import os
import sys
from pathlib import Path
from typing import List, Dict, Any
import json

# Imports for pyannote
try:
    from pyannote.audio import Pipeline
    HAS_PYANNOTE = True
except ImportError:
    HAS_PYANNOTE = False

# Imports for Whisper and Transformers
try:
    import whisper
    import torch
    HAS_WHISPER = True
except ImportError:
    HAS_WHISPER = False

try:
    from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
    HAS_TRANSFORMERS = True
except ImportError:
    HAS_TRANSFORMERS = False

# Work around the PyTorch 2.6 weights_only default when loading checkpoints
# (guarded so this is skipped when torch itself failed to import above)
if HAS_WHISPER and hasattr(torch.serialization, 'add_safe_globals'):
    try:
        torch.serialization.add_safe_globals([torch.torch_version.TorchVersion])
    except Exception:
        pass

import numpy as np
import librosa
import soundfile as sf

def convert_audio_if_needed(audio_path: str) -> str:
    """
    Convert the audio to WAV if needed.
    Returns:
        Path to the audio file (a WAV path if a conversion was needed).
    """
    ext = Path(audio_path).suffix.lower()
    supported_formats = {'.wav', '.flac', '.ogg'}
    if ext in supported_formats:
        return audio_path
    if ext in {'.m4a', '.mp3', '.mp4', '.aac'}:
        wav_path = str(Path(audio_path).with_suffix('.wav'))
        if os.path.exists(wav_path):
            return wav_path
        try:
            # Decode and resample to 16 kHz mono, the rate expected downstream
            y, sr = librosa.load(audio_path, sr=16000, mono=True)
            sf.write(wav_path, y, sr)
            return wav_path
        except Exception:
            # Conversion failed: return the original file and let the caller try it
            return audio_path
    return audio_path
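
# Example: a hypothetical "interview.m4a" would be decoded and rewritten as
# "interview.wav" alongside it, while "interview.flac" is returned unchanged.
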
def run_diarization(audio_path: str, hf_token: str, model_name: str = "pyannote/speaker-diarization-community-1") -> List[Dict[str, Any]]:
    """Run speaker diarization with pyannote."""
    if not HAS_PYANNOTE:
        raise ImportError("pyannote.audio is not installed")
    # Convert the audio to WAV if needed
    audio_path_converted = convert_audio_if_needed(audio_path)
    # Register the token
    if hf_token:
        try:
            from huggingface_hub import login
            login(token=hf_token, add_to_git_credential=False)
        except Exception:
            pass
    try:
        pipeline = Pipeline.from_pretrained(model_name, token=hf_token)
    except Exception as e:
        # Fall back to the 3.1 pipeline when the requested checkpoint is incompatible
        if "plda" in str(e).lower() or "unexpected keyword" in str(e).lower():
            pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", token=hf_token)
        else:
            raise
    if torch.cuda.is_available():
        pipeline = pipeline.to(torch.device("cuda"))
    diarization = pipeline(audio_path_converted)
    # Convert the annotation into plain segments with normalized speaker labels
    segments = []
    speakers = sorted(diarization.labels())
    speaker_mapping = {speaker: f"SPEAKER_{idx:02d}" for idx, speaker in enumerate(speakers)}
    for segment, _track, speaker in diarization.itertracks(yield_label=True):
        normalized_speaker = speaker_mapping.get(speaker, speaker)
        segments.append({
            "speaker": normalized_speaker,
            "start": segment.start,
            "end": segment.end
        })
    segments.sort(key=lambda x: x["start"])
    return segments
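
# Illustrative return shape (values made up):
# [{"speaker": "SPEAKER_00", "start": 0.03, "end": 4.71},
#  {"speaker": "SPEAKER_01", "start": 4.92, "end": 7.20}, ...]
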
def run_transcription(audio_path: str, device: str = None, hf_token: str = None) -> List[Dict[str, Any]]:
    """Run transcription with the Whisper Large V3 French model."""
    if not HAS_WHISPER:
        raise ImportError("whisper is not installed")
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    model_id = "bofenghuang/whisper-large-v3-french"
    # Load the model through Transformers
    try:
        if not HAS_TRANSFORMERS:
            raise ImportError("transformers is not available")
        processor = AutoProcessor.from_pretrained(model_id, token=hf_token)
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            low_cpu_mem_usage=True,
            token=hf_token
        )
        model.to(device)
        model.eval()
        # Load the audio
        audio_path_converted = convert_audio_if_needed(audio_path)
        waveform, sample_rate = librosa.load(audio_path_converted, sr=16000, mono=True)
        # Prepare the inputs (the feature extractor pads/truncates to Whisper's 30 s window)
        inputs = processor(
            waveform,
            sampling_rate=sample_rate,
            return_tensors="pt"
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        # Cast the features to the model dtype (float16 on GPU) to avoid a dtype mismatch
        inputs["input_features"] = inputs["input_features"].to(model.dtype)
        # Transcription
        with torch.no_grad():
            generated_ids = model.generate(
                inputs["input_features"],
                language="fr",
                task="transcribe",
                return_timestamps=True
            )
        # Extract timestamped segments from the generated tokens
        tokens = generated_ids[0].cpu().numpy()
        segments = []
        current_segment = {"start": None, "end": None, "text": []}
        # Parse the tokens, looking for Whisper timestamp tokens of the form <|X.XX|>
        for token_id in tokens:
            token_text = processor.tokenizer.decode([token_id], skip_special_tokens=False)
            if "<|" in token_text and "|>" in token_text:
                try:
                    start_idx = token_text.find("<|") + 2
                    end_idx = token_text.find("|>")
                    if start_idx < end_idx:
                        timestamp_str = token_text[start_idx:end_idx]
                        # Non-timestamp special tokens (<|fr|>, <|transcribe|>, ...) fail float() and are skipped
                        timestamp = float(timestamp_str)
                        if current_segment["start"] is None:
                            current_segment["start"] = timestamp
                        else:
                            # A second timestamp closes the current segment and opens the next one
                            current_segment["end"] = timestamp
                            text = " ".join(current_segment["text"]).strip()
                            if text:
                                segments.append({
                                    "start": current_segment["start"],
                                    "end": current_segment["end"],
                                    "text": text
                                })
                            current_segment = {"start": timestamp, "end": None, "text": []}
                except (ValueError, IndexError):
                    pass
            elif token_text.strip() and "<|" not in token_text and "|>" not in token_text:
                # Regular text token
                current_segment["text"].append(token_text)
        # Append the final segment
        if current_segment["text"]:
            text = " ".join(current_segment["text"]).strip()
            if text:
                duration = len(waveform) / sample_rate
                segments.append({
                    "start": current_segment["start"] if current_segment["start"] is not None else 0.0,
                    "end": current_segment["end"] if current_segment["end"] is not None else duration,
                    "text": text
                })
        # If timestamp extraction failed, fall back to evenly spaced sentence segments
        if not segments or all(seg.get("start") is None for seg in segments):
            # Decode the full text
            result_text = processor.decode(generated_ids[0], skip_special_tokens=True)
            # Split into sentences
            sentences = []
            for sent in result_text.split('. '):
                if sent.strip():
                    sentences.append(sent.strip() + ('.' if not sent.strip().endswith('.') else ''))
            if not sentences:
                sentences = [result_text.strip()]
            # Spread the sentences evenly over the audio duration
            duration = len(waveform) / sample_rate
            segments = []
            time_per_sentence = duration / len(sentences)
            for i, sentence in enumerate(sentences):
                start_time = i * time_per_sentence
                end_time = min((i + 1) * time_per_sentence, duration)
                segments.append({
                    "start": start_time,
                    "end": end_time,
                    "text": sentence
                })
        return segments
    except Exception:
        # Fall back to native openai-whisper
        model = whisper.load_model("large-v3", device=device)
        audio_path_converted = convert_audio_if_needed(audio_path)
        result = model.transcribe(
            audio_path_converted,
            language="fr",
            task="transcribe",
            verbose=False
        )
        segments = []
        for seg in result["segments"]:
            segments.append({
                "start": seg["start"],
                "end": seg["end"],
                "text": seg["text"].strip()
            })
        return segments

def combine_diarization_transcription(
    diarization_segments: List[Dict[str, Any]],
    transcription_segments: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Combine diarization and transcription."""
    combined = []
    # Build a diarization timeline
    diar_timeline = [
        (seg["start"], seg["end"], seg["speaker"])
        for seg in diarization_segments
    ]
    diar_timeline.sort()

    def get_speaker_for_segment(seg_start: float, seg_end: float) -> str:
        """Pick the speaker for a segment by largest temporal overlap."""
        speaker_time = {}
        for diar_start, diar_end, speaker in diar_timeline:
            overlap_start = max(seg_start, diar_start)
            overlap_end = min(seg_end, diar_end)
            overlap_duration = max(0, overlap_end - overlap_start)
            if overlap_duration > 0:
                speaker_time[speaker] = speaker_time.get(speaker, 0) + overlap_duration
        if speaker_time:
            return max(speaker_time, key=speaker_time.get)
        else:
            # No overlap: take the speaker whose turn is closest in time
            center_time = (seg_start + seg_end) / 2.0
            min_dist = float('inf')
            closest_speaker = "SPEAKER_00"
            for diar_start, diar_end, speaker in diar_timeline:
                if center_time < diar_start:
                    dist = diar_start - center_time
                elif center_time >= diar_end:
                    dist = center_time - diar_end
                else:
                    # The center falls inside this turn
                    return speaker
                if dist < min_dist:
                    min_dist = dist
                    closest_speaker = speaker
            return closest_speaker
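
    # Worked example: with diarization turns (0.0, 5.0, SPEAKER_00) and
    # (5.0, 9.0, SPEAKER_01), a transcription segment spanning 3.5-6.0 s overlaps
    # SPEAKER_00 for 1.5 s and SPEAKER_01 for 1.0 s, so it is assigned SPEAKER_00.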
    # Combine the segments
    for trans_seg in transcription_segments:
        speaker = get_speaker_for_segment(trans_seg["start"], trans_seg["end"])
        combined.append({
            "speaker": speaker,
            "start": trans_seg["start"],
            "end": trans_seg["end"],
            "text": trans_seg["text"]
        })
    return combined

def format_output(combined_segments: List[Dict[str, Any]]) -> str:
    """Format the output as readable text: "Speaker A : blah blah"."""
    output_lines = []
    current_speaker = None
    current_texts = []
    for seg in combined_segments:
        speaker = seg["speaker"]
        text = seg["text"]
        if speaker != current_speaker:
            # Flush the previous group
            if current_speaker and current_texts:
                # Map SPEAKER_00 -> "Speaker A", SPEAKER_01 -> "Speaker B", etc.
                speaker_num = int(current_speaker.replace("SPEAKER_", ""))
                speaker_name = f"Speaker {chr(65 + speaker_num)}"
                output_lines.append(f"{speaker_name} : {' '.join(current_texts)}")
            # New speaker
            current_speaker = speaker
            current_texts = [text]
        else:
            # Same speaker: keep accumulating text
            current_texts.append(text)
    # Flush the last group
    if current_speaker and current_texts:
        speaker_num = int(current_speaker.replace("SPEAKER_", ""))
        speaker_name = f"Speaker {chr(65 + speaker_num)}"
        output_lines.append(f"{speaker_name} : {' '.join(current_texts)}")
    return "\n\n".join(output_lines)
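

# A minimal usage sketch, assuming an HF_TOKEN environment variable and a
# hypothetical local file "meeting.wav"; the Space itself drives these
# functions from the Gradio UI instead.
if __name__ == "__main__":
    token = os.environ.get("HF_TOKEN")
    audio = "meeting.wav"  # hypothetical example input
    diar_segments = run_diarization(audio, hf_token=token)
    trans_segments = run_transcription(audio, hf_token=token)
    merged = combine_diarization_transcription(diar_segments, trans_segments)
    print(format_output(merged))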