evannh committed on
Commit
0d92d48
·
verified ·
1 Parent(s): 54a7679

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -83
app.py CHANGED
@@ -1,88 +1,89 @@
1
- import gradio as gr
2
- import torch
3
  import os
4
- from whisperx import load_model, load_align_model, align
5
- from resemblyzer import preprocess_wav, VoiceEncoder
6
- from sklearn.cluster import AgglomerativeClustering
7
- import librosa
8
- import numpy as np
9
-
10
# Run on GPU with float16 when available; otherwise fall back to CPU with
# int8 quantization to keep memory and latency manageable.
device = "cuda" if torch.cuda.is_available() else "cpu"
compute_type = "float16" if device == "cuda" else "int8"

# Load every model once at startup: Whisper ASR ("medium"), the French
# forced-alignment model, and the Resemblyzer speaker encoder.
whisper_model = load_model("medium", device=device, compute_type=compute_type)
align_model, metadata = load_align_model(language_code="fr", device=device)
voice_encoder = VoiceEncoder()
16
-
17
def get_speaker_segments(audio_path, window_size=1.0, step_size=0.5, num_speakers=2):
    """Lightweight speaker diarization without any external token.

    Slides a window over the audio, embeds each window with the Resemblyzer
    voice encoder, then clusters the embeddings into speakers.

    Parameters
    ----------
    audio_path : str
        Path to an audio file readable by librosa.
    window_size : float
        Length of each analysis window, in seconds.
    step_size : float
        Hop between consecutive windows, in seconds.
    num_speakers : int
        Desired number of speakers; clamped to the number of usable windows.

    Returns
    -------
    list[dict]
        Dicts with ``start``/``end`` (seconds) and a ``speaker`` label.
        Falls back to a single speaker spanning the whole file when too few
        windows could be embedded.
    """
    wav, sr = librosa.load(audio_path, sr=16000, mono=True)
    wav = librosa.util.normalize(wav)
    duration = librosa.get_duration(y=wav, sr=sr)

    segments = []
    embeddings = []

    for start in np.arange(0, duration - window_size, step_size):
        end = start + window_size
        clip = wav[int(start * sr):int(end * sr)]
        if len(clip) == 0:
            continue
        try:
            embeddings.append(voice_encoder.embed_utterance(clip))
            segments.append((start, end))
        except Exception as e:
            # A window can be too short or silent for the encoder; skip it.
            print(f"⚠️ Skipped segment {start}-{end}s: {e}")

    if len(embeddings) < 2:
        # Not enough material to cluster: attribute everything to one speaker.
        print("⚠️ Pas assez de segments pour la diarisation. Diarisation annulée.")
        return [{"start": 0, "end": duration, "speaker": "speaker_00"}]

    # FIX: AgglomerativeClustering raises ValueError when n_clusters exceeds
    # the number of samples, so clamp the requested speaker count.
    n_clusters = min(num_speakers, len(embeddings))
    clustering = AgglomerativeClustering(n_clusters=n_clusters)
    # Stack the 1-D embeddings into the (n_samples, n_features) matrix
    # sklearn expects, rather than relying on implicit conversion.
    labels = clustering.fit_predict(np.vstack(embeddings))

    return [
        {"start": start, "end": end, "speaker": f"speaker_{label:02d}"}
        for (start, end), label in zip(segments, labels)
    ]
49
-
50
def process_audio(audio_file):
    """Transcribe, diarize and word-align a French audio file.

    Returns one formatted line per aligned segment:
    ``[start - end] speaker: text``.
    """
    tmp_path = audio_file

    # Step 1: Transcription
    result = whisper_model.transcribe(tmp_path, language="fr", word_timestamps=False, verbose=False)

    # Step 2: Diarization via resemblyzer
    speaker_segments = get_speaker_segments(tmp_path)

    # Step 3: Word-level alignment
    result_aligned = align(result["segments"], align_model, metadata, tmp_path, return_char_alignments=False)

    # Attach a speaker label to each aligned segment: first diarization
    # window whose span contains the segment start wins.
    for segment in result_aligned["segments"]:
        seg_start = segment["start"]
        segment["speaker"] = next(
            (sp["speaker"] for sp in speaker_segments if sp["start"] <= seg_start <= sp["end"]),
            "speaker_??",
        )

    # Final formatting: one line per segment.
    lines = []
    for seg in result_aligned["segments"]:
        lines.append(
            f"[{seg['start']:.2f}s - {seg['end']:.2f}s] {seg['speaker']}: {seg['text'].strip()}\n"
        )
    return "".join(lines)
 
78
 
79
# Gradio UI: one audio file in, formatted transcript out (timestamps +
# speaker labels, French transcription).
iface = gr.Interface(
    fn=process_audio,
    inputs=gr.Audio(type="filepath", label="Audio (.wav, .mp3...)"),
    outputs=gr.Textbox(label="Transcription + Diarisation + Alignement"),
    title="🎙️ Transcription enrichie avec WhisperX + Resemblyzer",
    description="Transcription française, diarisation légère (sans token), alignement mot à mot."
)

if __name__ == "__main__":
    iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import tempfile
3
+ import pandas as pd
4
+ import gradio as gr
5
+ from pydub import AudioSegment
6
+ from faster_whisper import WhisperModel
7
+ from pyannote.audio import Pipeline as DiarizationPipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
# Model initialisation (CPU-friendly settings).
# int8 quantization keeps the large-v2 model usable without a GPU.
whisper_model = WhisperModel("large-v2", device="cpu", compute_type="int8")

# pyannote's diarization model is gated and requires a Hugging Face token.
# FIX: read it from the environment instead of hard-coding a secret in the
# source (set HF_TOKEN in the Space/host environment).
diari_pipeline = DiarizationPipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1",
    use_auth_token=os.environ.get("HF_TOKEN"),
)
16
 
17
def convert_mp3_to_wav(mp3_path):
    """Convert an audio file to a 16 kHz mono WAV suitable for the models.

    Despite the name, pydub's format sniffing lets this accept any container
    ffmpeg understands (mp3, wav, m4a, ...), which matches the Gradio
    ``type="filepath"`` input that is not guaranteed to be an .mp3.

    Returns the path of a newly created temporary .wav file; the caller is
    responsible for deleting it.
    """
    # FIX: mkstemp instead of the deprecated, race-prone tempfile.mktemp.
    fd, wav_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)  # pydub's export reopens the path itself
    # FIX: let pydub/ffmpeg detect the format rather than forcing "mp3".
    audio = AudioSegment.from_file(mp3_path)
    # 16 kHz mono is what the ASR and diarization models expect.
    audio = audio.set_channels(1).set_frame_rate(16000)
    audio.export(wav_path, format="wav")
    return wav_path
23
+
24
def transcribe_and_diarize(audio_file):
    """Transcribe a French audio file and attribute each segment to a speaker.

    Pipeline: convert to 16 kHz mono WAV -> faster-whisper transcription ->
    pyannote diarization -> merge by segment start time -> export TXT + CSV.

    Parameters
    ----------
    audio_file : str
        Path to the uploaded audio file (Gradio ``type="filepath"``).

    Returns
    -------
    tuple[str, str, str]
        (full transcript text, path to the CSV export, path to the TXT export).
    """
    wav_path = convert_mp3_to_wav(audio_file)
    try:
        # Transcription with Whisper (segments is a lazy generator).
        segments, _ = whisper_model.transcribe(wav_path, language="fr", beam_size=5)
        transcript = [
            {"start": seg.start, "end": seg.end, "text": seg.text.strip()}
            for seg in segments
        ]

        # Speaker diarization with pyannote.
        diarization = diari_pipeline(wav_path)
        speakers = [
            {"start": turn.start, "end": turn.end, "speaker": speaker}
            for turn, _, speaker in diarization.itertracks(yield_label=True)
        ]

        # Merge transcription + speakers: first diarization turn containing
        # the segment start wins. NOTE(review): a segment spanning two turns
        # keeps the first speaker only.
        final_output = []
        for t in transcript:
            speaker = "Inconnu"
            for d in speakers:
                if d["start"] <= t["start"] <= d["end"]:
                    speaker = d["speaker"]
                    break
            final_output.append({
                "start": t["start"],
                "end": t["end"],
                "speaker": speaker,
                "text": t["text"],
            })

        df = pd.DataFrame(final_output)

        # Export TXT format.
        txt_lines = [
            f"[{row['start']:.2f}s - {row['end']:.2f}s] {row['speaker']} : {row['text']}"
            for _, row in df.iterrows()
        ]
        txt_output = "\n".join(txt_lines)
        # FIX: mkstemp instead of the deprecated, race-prone tempfile.mktemp.
        fd, txt_path = tempfile.mkstemp(suffix=".txt")
        os.close(fd)
        with open(txt_path, "w", encoding="utf-8") as f:
            f.write(txt_output)

        # Export CSV format.
        fd, csv_path = tempfile.mkstemp(suffix=".csv")
        os.close(fd)
        df.to_csv(csv_path, index=False)

        return txt_output, csv_path, txt_path
    finally:
        # FIX: the intermediate WAV was previously leaked on every call.
        try:
            os.remove(wav_path)
        except OSError:
            pass
77
+
78
# Gradio UI: an MP3 file in; the annotated transcript plus downloadable
# CSV and TXT exports out.
demo = gr.Interface(
    fn=transcribe_and_diarize,
    inputs=gr.Audio(type="filepath", label="Fichier audio MP3"),
    outputs=[
        gr.Textbox(label="Transcription avec locuteurs"),
        gr.File(label="Télécharger le CSV"),
        gr.File(label="Télécharger le TXT"),
    ],
    title="Transcription + Diarisation (FR)",
    description="Charge un fichier MP3. Transcription FR + séparation des locuteurs + export CSV et TXT.",
)
demo.launch()