marcosremar2 commited on
Commit
6df750f
·
1 Parent(s): a8b6268

Add Whisper transcription to speaker diarization

Browse files
Files changed (2) hide show
  1. app.py +65 -26
  2. requirements.txt +3 -1
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import gradio as gr
2
  from pyannote.audio import Pipeline
3
  import torch
4
- import torchaudio
5
  from huggingface_hub import login
6
  import os
7
  import traceback
@@ -10,38 +10,50 @@ import traceback
10
  hf_token = os.environ.get("HF_TOKEN")
11
  if not hf_token:
12
  print("WARNING: HF_TOKEN environment variable not found. Please set it in the Space settings.")
13
- pipeline = None
14
  else:
15
  try:
16
  login(token=hf_token)
17
  print("Successfully logged in to Hugging Face")
18
 
19
- # Initialize the pipeline
20
  print("Loading pyannote/speaker-diarization-3.1 pipeline...")
21
- pipeline = Pipeline.from_pretrained(
22
  "pyannote/speaker-diarization-3.1",
23
  use_auth_token=hf_token
24
  )
25
- print("Pipeline loaded successfully!")
26
 
27
  # Send pipeline to GPU if available
28
  if torch.cuda.is_available():
29
  print("GPU detected, moving pipeline to GPU")
30
- pipeline.to(torch.device("cuda"))
31
  else:
32
  print("No GPU detected, using CPU")
33
 
34
  except Exception as e:
35
- print(f"Error loading pipeline: {e}")
36
  print(f"Error type: {type(e).__name__}")
37
  print("Traceback:")
38
  traceback.print_exc()
39
- pipeline = None
40
 
41
- def diarize_audio(audio_file):
42
- """Process audio file and return diarization results"""
43
- if pipeline is None:
44
- return "❌ Pipeline not loaded. Please ensure HF_TOKEN is set and you have access to pyannote/speaker-diarization-3.1. Check the logs for more details."
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  if audio_file is None:
47
  return "Please upload an audio file."
@@ -49,21 +61,48 @@ def diarize_audio(audio_file):
49
  try:
50
  print(f"Processing audio file: {audio_file}")
51
 
52
- # Apply pretrained pipeline
53
- diarization = pipeline(audio_file)
 
 
 
 
 
 
 
 
54
 
55
- # Format results
56
  results = []
57
- for turn, _, speaker in diarization.itertracks(yield_label=True):
58
- results.append(
59
- f"Speaker {speaker}: {turn.start:.1f}s - {turn.end:.1f}s"
60
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  if not results:
63
- return "No speakers detected in the audio."
 
 
 
 
 
64
 
65
- print(f"Successfully processed audio. Found {len(set([r.split(':')[0] for r in results]))} speakers")
66
- return "\n".join(results)
67
 
68
  except Exception as e:
69
  error_msg = f"Error processing audio: {str(e)}"
@@ -73,11 +112,11 @@ def diarize_audio(audio_file):
73
 
74
  # Create Gradio interface
75
  demo = gr.Interface(
76
- fn=diarize_audio,
77
  inputs=gr.Audio(type="filepath", label="Upload Audio File"),
78
- outputs=gr.Textbox(label="Diarization Results", lines=10),
79
- title="Speaker Diarization with Pyannote 3.1",
80
- description="Upload an audio file to identify different speakers and their speaking times.",
81
  examples=[],
82
  cache_examples=False
83
  )
 
1
  import gradio as gr
2
  from pyannote.audio import Pipeline
3
  import torch
4
+ import whisper
5
  from huggingface_hub import login
6
  import os
7
  import traceback
 
10
  hf_token = os.environ.get("HF_TOKEN")
11
  if not hf_token:
12
  print("WARNING: HF_TOKEN environment variable not found. Please set it in the Space settings.")
13
+ diarization_pipeline = None
14
  else:
15
  try:
16
  login(token=hf_token)
17
  print("Successfully logged in to Hugging Face")
18
 
19
+ # Initialize the diarization pipeline
20
  print("Loading pyannote/speaker-diarization-3.1 pipeline...")
21
+ diarization_pipeline = Pipeline.from_pretrained(
22
  "pyannote/speaker-diarization-3.1",
23
  use_auth_token=hf_token
24
  )
25
+ print("Diarization pipeline loaded successfully!")
26
 
27
  # Send pipeline to GPU if available
28
  if torch.cuda.is_available():
29
  print("GPU detected, moving pipeline to GPU")
30
+ diarization_pipeline.to(torch.device("cuda"))
31
  else:
32
  print("No GPU detected, using CPU")
33
 
34
  except Exception as e:
35
+ print(f"Error loading diarization pipeline: {e}")
36
  print(f"Error type: {type(e).__name__}")
37
  print("Traceback:")
38
  traceback.print_exc()
39
+ diarization_pipeline = None
40
 
41
+ # Load Whisper model
42
+ try:
43
+ print("Loading Whisper model...")
44
+ whisper_model = whisper.load_model("base")
45
+ print("Whisper model loaded successfully!")
46
+ except Exception as e:
47
+ print(f"Error loading Whisper model: {e}")
48
+ whisper_model = None
49
+
50
+ def transcribe_with_diarization(audio_file):
51
+ """Process audio file for both diarization and transcription"""
52
+ if diarization_pipeline is None:
53
+ return "❌ Diarization pipeline not loaded. Please ensure HF_TOKEN is set and you have access to pyannote/speaker-diarization-3.1."
54
+
55
+ if whisper_model is None:
56
+ return "❌ Whisper model not loaded."
57
 
58
  if audio_file is None:
59
  return "Please upload an audio file."
 
61
  try:
62
  print(f"Processing audio file: {audio_file}")
63
 
64
+ # Step 1: Transcribe with Whisper
65
+ print("Transcribing audio with Whisper...")
66
+ transcription_result = whisper_model.transcribe(audio_file, language="pt")
67
+ segments = transcription_result["segments"]
68
+ print(f"Transcription complete. Found {len(segments)} segments")
69
+
70
+ # Step 2: Diarize with pyannote
71
+ print("Performing speaker diarization...")
72
+ diarization = diarization_pipeline(audio_file)
73
+ print("Diarization complete")
74
 
75
+ # Step 3: Match transcription segments with speaker labels
76
  results = []
77
+
78
+ for segment in segments:
79
+ start_time = segment['start']
80
+ end_time = segment['end']
81
+ text = segment['text'].strip()
82
+
83
+ # Find the speaker at this timestamp
84
+ speaker = None
85
+ for turn, _, label in diarization.itertracks(yield_label=True):
86
+ # Check if this segment overlaps with the speaker turn
87
+ if turn.start <= start_time <= turn.end or turn.start <= end_time <= turn.end:
88
+ speaker = label
89
+ break
90
+
91
+ if speaker:
92
+ results.append(f"[{speaker}] ({start_time:.1f}s - {end_time:.1f}s): {text}")
93
+ else:
94
+ results.append(f"[Unknown] ({start_time:.1f}s - {end_time:.1f}s): {text}")
95
 
96
  if not results:
97
+ return "No transcription available."
98
+
99
+ # Add summary
100
+ speakers = set()
101
+ for turn, _, speaker in diarization.itertracks(yield_label=True):
102
+ speakers.add(speaker)
103
 
104
+ summary = f"Found {len(speakers)} speakers in the conversation.\n\n"
105
+ return summary + "\n".join(results)
106
 
107
  except Exception as e:
108
  error_msg = f"Error processing audio: {str(e)}"
 
112
 
113
  # Create Gradio interface
114
  demo = gr.Interface(
115
+ fn=transcribe_with_diarization,
116
  inputs=gr.Audio(type="filepath", label="Upload Audio File"),
117
+ outputs=gr.Textbox(label="Transcription with Speaker Identification", lines=20),
118
+ title="Speaker Diarization + Transcription",
119
+ description="Upload an audio file to identify different speakers and transcribe what they said. Uses pyannote for speaker identification and Whisper for transcription.",
120
  examples=[],
121
  cache_examples=False
122
  )
requirements.txt CHANGED
@@ -3,4 +3,6 @@ torch>=2.0.0
3
  torchaudio>=2.0.0
4
  gradio>=4.0.0
5
  huggingface_hub
6
- speechbrain>=0.5.16
 
 
 
3
  torchaudio>=2.0.0
4
  gradio>=4.0.0
5
  huggingface_hub
6
+ speechbrain>=0.5.16
7
+ openai-whisper
8
+ ffmpeg-python