ADGIT commited on
Commit
3156fde
·
1 Parent(s): 301fd30

praise.global extensions: speaker embedding extraction + matching

Browse files

- Extract per-speaker embeddings from pyannote internal wespeaker model
- Cosine similarity matching against known speaker profiles
- Confidence tiers: HIGH (>=0.55), MEDIUM (>=0.35), LOW
- return_embeddings + known_speakers parameters in InferenceConfig
- Backward compatible: original API unchanged without new params
- Bumped pyannote-audio>=3.3.0 for community-1 support

Files changed (5) hide show
  1. README.md +56 -17
  2. config.py +17 -5
  3. diarization_utils.py +226 -20
  4. handler.py +34 -12
  5. requirements.txt +8 -8
README.md CHANGED
@@ -1,23 +1,62 @@
1
- ASR+Diarization handler that works natively with Inference Endpoints.
 
 
 
2
 
3
- Example payload:
4
- ```python
5
- import base64
6
- import requests
7
 
8
- API_URL = "<your endpoint URL>"
9
- filepath = "/path/to/audio"
10
 
11
- with open(filepath, 'rb') as f:
12
- audio_encoded = base64.b64encode(f.read()).decode("utf-8")
13
 
14
- data = {
15
- "inputs": audio_encoded,
16
- "parameters": {
17
- "batch_size": 24
18
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
- resp = requests.post(API_URL, json=data, headers={"Authorization": "Bearer <your token>"})
22
- print(resp.json())
23
- ```
 
 
 
1
+ ---
2
+ tags:
3
+ - endpoints-compatible
4
+ ---
5
 
6
+ # praise-ml-handler
 
 
 
7
 
8
+ Unified ASR + Diarization + Speaker Embedding + Speaker Matching handler for praise.global.
 
9
 
10
+ Forked from [sergeipetrov/asrdiarization-handler](https://huggingface.co/sergeipetrov/asrdiarization-handler).
 
11
 
12
+ ## Extensions over upstream
13
+
14
+ - **Speaker embedding extraction** — extracts per-speaker embeddings from pyannote's internal wespeaker model as a byproduct of diarization
15
+ - **Speaker matching** — matches diarized speakers against known voice profiles using cosine similarity
16
+ - **Confidence tiers** — HIGH (≥0.55), MEDIUM (≥0.35), LOW (<0.35) calibrated for pyannote embeddings
17
+
18
+ ## API
19
+
20
+ Standard Inference Endpoint `POST /` with `inputs` (base64 audio) and `parameters`:
21
+
22
+ ```json
23
+ {
24
+ "inputs": "<base64_audio>",
25
+ "parameters": {
26
+ "task": "transcribe",
27
+ "language": "en",
28
+ "batch_size": 24,
29
+ "chunk_length_s": 30,
30
+ "min_speakers": 2,
31
+ "max_speakers": 12,
32
+ "return_embeddings": true,
33
+ "known_speakers": [
34
+ {"slug": "bob-ryan", "name": "Bob Ryan", "centroid_b64": "..."}
35
+ ]
36
+ }
37
  }
38
+ ```
39
+
40
+ ## Response
41
+
42
+ ```json
43
+ {
44
+ "text": "full transcript...",
45
+ "chunks": [...],
46
+ "speakers": [...],
47
+ "speaker_embeddings": {
48
+ "SPEAKER_00": {"embedding_b64": "...", "embedding_dim": 512, "total_seconds": 45.2, "num_segments": 12}
49
+ },
50
+ "speaker_matches": {
51
+ "SPEAKER_00": {"matched_slug": "bob-ryan", "matched_name": "Bob Ryan", "confidence": "HIGH", "score": 0.72}
52
+ }
53
+ }
54
+ ```
55
+
56
+ ## Deployment
57
 
58
+ Create via HF Inference Endpoints API with env vars:
59
+ - `ASR_MODEL=openai/whisper-large-v3`
60
+ - `DIARIZATION_MODEL=pyannote/speaker-diarization-3.1`
61
+ - `HF_TOKEN=<your_token>`
62
+ - `ASSISTANT_MODEL=distil-whisper/distil-large-v3` (optional, for speculative decoding)
config.py CHANGED
@@ -2,16 +2,25 @@ import logging
2
 
3
  from pydantic import BaseModel
4
  from pydantic_settings import BaseSettings
5
- from typing import Optional, Literal
6
 
7
  logger = logging.getLogger(__name__)
8
 
9
 
10
  class ModelSettings(BaseSettings):
11
  asr_model: str
12
- assistant_model: Optional[str]
13
- diarization_model: Optional[str]
14
- hf_token: Optional[str]
 
 
 
 
 
 
 
 
 
15
 
16
 
17
  class InferenceConfig(BaseModel):
@@ -24,10 +33,13 @@ class InferenceConfig(BaseModel):
24
  num_speakers: Optional[int] = None
25
  min_speakers: Optional[int] = None
26
  max_speakers: Optional[int] = None
 
 
 
27
 
28
 
29
  model_settings = ModelSettings()
30
 
31
  logger.info(f"asr model: {model_settings.asr_model}")
32
  logger.info(f"assist model: {model_settings.assistant_model}")
33
- logger.info(f"diar model: {model_settings.diarization_model}")
 
2
 
3
  from pydantic import BaseModel
4
  from pydantic_settings import BaseSettings
5
+ from typing import Optional, Literal, List
6
 
7
  logger = logging.getLogger(__name__)
8
 
9
 
10
  class ModelSettings(BaseSettings):
11
  asr_model: str
12
+ assistant_model: Optional[str] = None
13
+ diarization_model: Optional[str] = None
14
+ hf_token: Optional[str] = None
15
+
16
+
17
+ class KnownSpeaker(BaseModel):
18
+ """A known speaker profile for matching."""
19
+ slug: str
20
+ name: str
21
+ centroid_b64: str # base64-encoded float32 embedding
22
+ # Optional additional sample embeddings for best-of-N matching
23
+ samples: Optional[List[dict]] = None
24
 
25
 
26
  class InferenceConfig(BaseModel):
 
33
  num_speakers: Optional[int] = None
34
  min_speakers: Optional[int] = None
35
  max_speakers: Optional[int] = None
36
+ # praise.global extensions
37
+ return_embeddings: bool = False
38
+ known_speakers: Optional[List[dict]] = None # List of KnownSpeaker dicts
39
 
40
 
41
  model_settings = ModelSettings()
42
 
43
  logger.info(f"asr model: {model_settings.asr_model}")
44
  logger.info(f"assist model: {model_settings.assistant_model}")
45
+ logger.info(f"diar model: {model_settings.diarization_model}")
diarization_utils.py CHANGED
@@ -1,16 +1,15 @@
1
  import torch
2
  import numpy as np
 
3
  from torchaudio import functional as F
4
  from transformers.pipelines.audio_utils import ffmpeg_read
5
  from starlette.exceptions import HTTPException
6
  import sys
7
 
8
- # Code from insanely-fast-whisper:
9
- # https://github.com/Vaibhavs10/insanely-fast-whisper
10
-
11
  import logging
12
  logger = logging.getLogger(__name__)
13
 
 
14
  def preprocess_inputs(inputs, sampling_rate):
15
  inputs = ffmpeg_read(inputs, sampling_rate)
16
 
@@ -20,10 +19,10 @@ def preprocess_inputs(inputs, sampling_rate):
20
  ).numpy()
21
 
22
  if len(inputs.shape) != 1:
23
- logger.error(f"Diarization pipeline expecs single channel audio, received {inputs.shape}")
24
  raise HTTPException(
25
  status_code=400,
26
- detail=f"Diarization pipeline expecs single channel audio, received {inputs.shape}"
27
  )
28
 
29
  # diarization model expects float32 torch tensor of shape `(channels, seq_len)`
@@ -51,17 +50,14 @@ def diarize_audio(diarizer_inputs, diarization_pipeline, parameters):
51
  }
52
  )
53
 
54
- # diarizer output may contain consecutive segments from the same speaker (e.g. {(0 -> 1, speaker_1), (1 -> 1.5, speaker_1), ...})
55
- # we combine these segments to give overall timestamps for each speaker's turn (e.g. {(0 -> 1.5, speaker_1), ...})
56
  new_segments = []
57
  prev_segment = cur_segment = segments[0]
58
 
59
  for i in range(1, len(segments)):
60
  cur_segment = segments[i]
61
 
62
- # check if we have changed speaker ("label")
63
  if cur_segment["label"] != prev_segment["label"] and i < len(segments):
64
- # add the start/end times for the super-segment to the new list
65
  new_segments.append(
66
  {
67
  "segment": {
@@ -73,7 +69,6 @@ def diarize_audio(diarizer_inputs, diarization_pipeline, parameters):
73
  )
74
  prev_segment = segments[i]
75
 
76
- # add the last segment(s) if there was no speaker change
77
  new_segments.append(
78
  {
79
  "segment": {
@@ -84,20 +79,196 @@ def diarize_audio(diarizer_inputs, diarization_pipeline, parameters):
84
  }
85
  )
86
 
87
- return new_segments
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
 
90
  def post_process_segments_and_transcripts(new_segments, transcript, group_by_speaker) -> list:
91
- # get the end timestamps for each chunk from the ASR output
92
  end_timestamps = np.array(
93
  [chunk["timestamp"][-1] if chunk["timestamp"][-1] is not None else sys.float_info.max for chunk in transcript])
94
  segmented_preds = []
95
 
96
- # align the diarizer timestamps and the ASR timestamps
97
  for segment in new_segments:
98
- # get the diarizer end timestamp
99
  end_time = segment["segment"]["end"]
100
- # find the ASR end timestamp that is closest to the diarizer's end timestamp and cut the transcript to here
101
  upto_idx = np.argmin(np.abs(end_timestamps - end_time))
102
 
103
  if group_by_speaker:
@@ -117,7 +288,6 @@ def post_process_segments_and_transcripts(new_segments, transcript, group_by_spe
117
  for i in range(upto_idx + 1):
118
  segmented_preds.append({"speaker": segment["speaker"], **transcript[i]})
119
 
120
- # crop the transcripts and timestamp lists according to the latest timestamp (for faster argmin)
121
  transcript = transcript[upto_idx + 1:]
122
  end_timestamps = end_timestamps[upto_idx + 1:]
123
 
@@ -128,14 +298,50 @@ def post_process_segments_and_transcripts(new_segments, transcript, group_by_spe
128
 
129
 
130
  def diarize(diarization_pipeline, file, parameters, asr_outputs):
 
131
  _, diarizer_inputs = preprocess_inputs(file, parameters.sampling_rate)
132
 
133
- segments = diarize_audio(
134
- diarizer_inputs,
135
- diarization_pipeline,
136
  parameters
137
  )
138
 
139
  return post_process_segments_and_transcripts(
140
  segments, asr_outputs["chunks"], group_by_speaker=False
141
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  import numpy as np
3
+ import base64
4
  from torchaudio import functional as F
5
  from transformers.pipelines.audio_utils import ffmpeg_read
6
  from starlette.exceptions import HTTPException
7
  import sys
8
 
 
 
 
9
  import logging
10
  logger = logging.getLogger(__name__)
11
 
12
+
13
  def preprocess_inputs(inputs, sampling_rate):
14
  inputs = ffmpeg_read(inputs, sampling_rate)
15
 
 
19
  ).numpy()
20
 
21
  if len(inputs.shape) != 1:
22
+ logger.error(f"Diarization pipeline expects single channel audio, received {inputs.shape}")
23
  raise HTTPException(
24
  status_code=400,
25
+ detail=f"Diarization pipeline expects single channel audio, received {inputs.shape}"
26
  )
27
 
28
  # diarization model expects float32 torch tensor of shape `(channels, seq_len)`
 
50
  }
51
  )
52
 
53
+ # Combine consecutive segments from the same speaker
 
54
  new_segments = []
55
  prev_segment = cur_segment = segments[0]
56
 
57
  for i in range(1, len(segments)):
58
  cur_segment = segments[i]
59
 
 
60
  if cur_segment["label"] != prev_segment["label"] and i < len(segments):
 
61
  new_segments.append(
62
  {
63
  "segment": {
 
69
  )
70
  prev_segment = segments[i]
71
 
 
72
  new_segments.append(
73
  {
74
  "segment": {
 
79
  }
80
  )
81
 
82
+ return new_segments, diarization
83
+
84
+
85
+ def extract_speaker_embeddings(diarization_pipeline, diarizer_inputs, diarization_result, sampling_rate=16000):
86
+ """
87
+ Extract per-speaker embeddings from pyannote's internal embedding model.
88
+
89
+ pyannote's SpeakerDiarization pipeline has an internal embedding model
90
+ (wespeaker-based, 512-dim) that we can access directly. We use the
91
+ diarization result to identify which audio regions belong to each speaker,
92
+ then extract embeddings for those regions.
93
+ """
94
+ try:
95
+ # Access pyannote's internal embedding model
96
+ embedding_model = diarization_pipeline._embedding
97
+ device = next(embedding_model.parameters()).device
98
+
99
+ # Collect per-speaker audio segments
100
+ speaker_labels = set()
101
+ for segment, _, label in diarization_result.itertracks(yield_label=True):
102
+ speaker_labels.add(label)
103
+
104
+ speaker_embeddings = {}
105
+
106
+ for speaker in speaker_labels:
107
+ # Get all segments for this speaker
108
+ speaker_segments = []
109
+ total_seconds = 0.0
110
+ for segment, _, label in diarization_result.itertracks(yield_label=True):
111
+ if label == speaker:
112
+ speaker_segments.append(segment)
113
+ total_seconds += segment.duration
114
+
115
+ if total_seconds < 0.5:
116
+ logger.warning(f"Speaker {speaker} has only {total_seconds:.1f}s of audio, skipping embedding")
117
+ continue
118
+
119
+ # Extract audio for each segment and compute embeddings
120
+ segment_embeddings = []
121
+ waveform = diarizer_inputs # shape: (1, seq_len)
122
+
123
+ for seg in speaker_segments:
124
+ start_sample = int(seg.start * sampling_rate)
125
+ end_sample = int(seg.end * sampling_rate)
126
+
127
+ if end_sample > waveform.shape[1]:
128
+ end_sample = waveform.shape[1]
129
+ if end_sample - start_sample < sampling_rate * 0.3: # skip < 0.3s
130
+ continue
131
+
132
+ chunk = waveform[:, start_sample:end_sample].to(device)
133
+
134
+ with torch.no_grad():
135
+ emb = embedding_model(chunk)
136
+
137
+ # Normalize
138
+ if emb.dim() > 1:
139
+ emb = emb.squeeze()
140
+ emb = emb / (torch.norm(emb) + 1e-8)
141
+ segment_embeddings.append(emb.cpu().numpy())
142
+
143
+ if len(segment_embeddings) == 0:
144
+ continue
145
+
146
+ # Compute centroid (mean of all segment embeddings)
147
+ centroid = np.mean(segment_embeddings, axis=0).astype(np.float32)
148
+ centroid = centroid / (np.linalg.norm(centroid) + 1e-8)
149
+
150
+ # Encode as base64
151
+ centroid_b64 = base64.b64encode(centroid.tobytes()).decode("utf-8")
152
+
153
+ speaker_embeddings[speaker] = {
154
+ "embedding_b64": centroid_b64,
155
+ "embedding_dim": int(centroid.shape[0]),
156
+ "total_seconds": round(total_seconds, 2),
157
+ "num_segments": len(segment_embeddings),
158
+ }
159
+
160
+ logger.info(f"Speaker {speaker}: {total_seconds:.1f}s, {len(segment_embeddings)} segments, dim={centroid.shape[0]}")
161
+
162
+ return speaker_embeddings
163
+
164
+ except Exception as e:
165
+ logger.error(f"Error extracting speaker embeddings: {str(e)}")
166
+ import traceback
167
+ logger.error(traceback.format_exc())
168
+ return {}
169
+
170
+
171
+ def match_speakers(speaker_embeddings, known_speakers):
172
+ """
173
+ Match diarized speakers against known speaker profiles using cosine similarity.
174
+
175
+ known_speakers: list of dicts with {slug, name, centroid_b64, samples?}
176
+ speaker_embeddings: dict from extract_speaker_embeddings
177
+
178
+ Returns dict mapping SPEAKER_XX -> {matched_slug, matched_name, confidence, score}
179
+ """
180
+ if not known_speakers or not speaker_embeddings:
181
+ return {}
182
+
183
+ # Decode known speaker centroids
184
+ known_profiles = []
185
+ for ks in known_speakers:
186
+ try:
187
+ centroid_bytes = base64.b64decode(ks["centroid_b64"])
188
+ centroid = np.frombuffer(centroid_bytes, dtype=np.float32)
189
+
190
+ # Also decode sample embeddings if present
191
+ samples = []
192
+ if ks.get("samples"):
193
+ for s in ks["samples"]:
194
+ if s.get("embedding_b64"):
195
+ s_bytes = base64.b64decode(s["embedding_b64"])
196
+ samples.append(np.frombuffer(s_bytes, dtype=np.float32))
197
+
198
+ known_profiles.append({
199
+ "slug": ks["slug"],
200
+ "name": ks["name"],
201
+ "centroid": centroid,
202
+ "samples": samples,
203
+ })
204
+ except Exception as e:
205
+ logger.warning(f"Could not decode profile for {ks.get('slug', '?')}: {e}")
206
+ continue
207
+
208
+ if not known_profiles:
209
+ return {}
210
+
211
+ matches = {}
212
+
213
+ for spk_label, spk_data in speaker_embeddings.items():
214
+ try:
215
+ query_bytes = base64.b64decode(spk_data["embedding_b64"])
216
+ query = np.frombuffer(query_bytes, dtype=np.float32)
217
+ except Exception:
218
+ continue
219
+
220
+ best_score = -1.0
221
+ best_profile = None
222
+
223
+ for profile in known_profiles:
224
+ # Cosine similarity with centroid
225
+ centroid_score = float(np.dot(query, profile["centroid"]) /
226
+ (np.linalg.norm(query) * np.linalg.norm(profile["centroid"]) + 1e-8))
227
+
228
+ # Best-of-N: also check individual samples
229
+ best_sample_score = centroid_score
230
+ for sample in profile["samples"]:
231
+ s_score = float(np.dot(query, sample) /
232
+ (np.linalg.norm(query) * np.linalg.norm(sample) + 1e-8))
233
+ best_sample_score = max(best_sample_score, s_score)
234
+
235
+ # Final score = max of centroid and best sample
236
+ final_score = max(centroid_score, best_sample_score)
237
+
238
+ if final_score > best_score:
239
+ best_score = final_score
240
+ best_profile = profile
241
+
242
+ if best_profile is None:
243
+ continue
244
+
245
+ # Confidence tiers (calibrated for pyannote wespeaker embeddings)
246
+ if best_score >= 0.55:
247
+ confidence = "HIGH"
248
+ elif best_score >= 0.35:
249
+ confidence = "MEDIUM"
250
+ else:
251
+ confidence = "LOW"
252
+
253
+ matches[spk_label] = {
254
+ "matched_slug": best_profile["slug"],
255
+ "matched_name": best_profile["name"],
256
+ "confidence": confidence,
257
+ "score": round(best_score, 4),
258
+ }
259
+
260
+ logger.info(f"Speaker {spk_label} -> {best_profile['name']} ({confidence}, {best_score:.4f})")
261
+
262
+ return matches
263
 
264
 
265
  def post_process_segments_and_transcripts(new_segments, transcript, group_by_speaker) -> list:
 
266
  end_timestamps = np.array(
267
  [chunk["timestamp"][-1] if chunk["timestamp"][-1] is not None else sys.float_info.max for chunk in transcript])
268
  segmented_preds = []
269
 
 
270
  for segment in new_segments:
 
271
  end_time = segment["segment"]["end"]
 
272
  upto_idx = np.argmin(np.abs(end_timestamps - end_time))
273
 
274
  if group_by_speaker:
 
288
  for i in range(upto_idx + 1):
289
  segmented_preds.append({"speaker": segment["speaker"], **transcript[i]})
290
 
 
291
  transcript = transcript[upto_idx + 1:]
292
  end_timestamps = end_timestamps[upto_idx + 1:]
293
 
 
298
 
299
 
300
  def diarize(diarization_pipeline, file, parameters, asr_outputs):
301
+ """Original diarize function — backward compatible."""
302
  _, diarizer_inputs = preprocess_inputs(file, parameters.sampling_rate)
303
 
304
+ segments, _ = diarize_audio(
305
+ diarizer_inputs,
306
+ diarization_pipeline,
307
  parameters
308
  )
309
 
310
  return post_process_segments_and_transcripts(
311
  segments, asr_outputs["chunks"], group_by_speaker=False
312
+ )
313
+
314
+
315
+ def diarize_with_embeddings(diarization_pipeline, file, parameters, asr_outputs):
316
+ """
317
+ Extended diarize that also extracts per-speaker embeddings and optionally
318
+ matches against known speaker profiles.
319
+
320
+ Returns: (transcript, speaker_embeddings, speaker_matches)
321
+ """
322
+ _, diarizer_inputs = preprocess_inputs(file, parameters.sampling_rate)
323
+
324
+ segments, diarization_result = diarize_audio(
325
+ diarizer_inputs,
326
+ diarization_pipeline,
327
+ parameters
328
+ )
329
+
330
+ transcript = post_process_segments_and_transcripts(
331
+ segments, asr_outputs["chunks"], group_by_speaker=False
332
+ )
333
+
334
+ # Extract embeddings
335
+ speaker_embeddings = {}
336
+ if parameters.return_embeddings:
337
+ speaker_embeddings = extract_speaker_embeddings(
338
+ diarization_pipeline, diarizer_inputs, diarization_result,
339
+ sampling_rate=parameters.sampling_rate
340
+ )
341
+
342
+ # Match against known speakers
343
+ speaker_matches = {}
344
+ if parameters.known_speakers and speaker_embeddings:
345
+ speaker_matches = match_speakers(speaker_embeddings, parameters.known_speakers)
346
+
347
+ return transcript, speaker_embeddings, speaker_matches
handler.py CHANGED
@@ -5,7 +5,7 @@ import base64
5
 
6
  from pyannote.audio import Pipeline
7
  from transformers import pipeline, AutoModelForCausalLM
8
- from diarization_utils import diarize
9
  from huggingface_hub import HfApi
10
  from pydantic import ValidationError
11
  from starlette.exceptions import HTTPException
@@ -49,8 +49,8 @@ class EndpointHandler():
49
  self.diarization_pipeline.to(device)
50
  else:
51
  self.diarization_pipeline = None
52
-
53
-
54
  def __call__(self, inputs):
55
  file = inputs.pop("inputs")
56
  file = base64.b64decode(file)
@@ -60,15 +60,16 @@ class EndpointHandler():
60
  except ValidationError as e:
61
  logger.error(f"Error validating parameters: {e}")
62
  raise HTTPException(status_code=400, detail=f"Error validating parameters: {e}")
63
-
64
  logger.info(f"inference parameters: {parameters}")
65
 
66
  generate_kwargs = {
67
- "task": parameters.task,
68
  "language": parameters.language,
69
  "assistant_model": self.assistant_model if parameters.assisted else None
70
  }
71
 
 
72
  try:
73
  asr_outputs = self.asr_pipeline(
74
  file,
@@ -81,23 +82,44 @@ class EndpointHandler():
81
  logger.error(f"ASR inference error: {str(e)}")
82
  raise HTTPException(status_code=400, detail=f"ASR inference error: {str(e)}")
83
  except Exception as e:
84
- logger.error(f"Unknown error diring ASR inference: {str(e)}")
85
- raise HTTPException(status_code=500, detail=f"Unknown error diring ASR inference: {str(e)}")
 
 
 
 
 
86
 
87
  if self.diarization_pipeline:
 
 
88
  try:
89
- transcript = diarize(self.diarization_pipeline, file, parameters, asr_outputs)
 
 
 
 
 
90
  except RuntimeError as e:
91
  logger.error(f"Diarization inference error: {str(e)}")
92
  raise HTTPException(status_code=400, detail=f"Diarization inference error: {str(e)}")
93
  except Exception as e:
94
  logger.error(f"Unknown error during diarization: {str(e)}")
95
  raise HTTPException(status_code=500, detail=f"Unknown error during diarization: {str(e)}")
96
- else:
97
- transcript = []
98
 
99
- return {
 
100
  "speakers": transcript,
101
  "chunks": asr_outputs["chunks"],
102
  "text": asr_outputs["text"],
103
- }
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  from pyannote.audio import Pipeline
7
  from transformers import pipeline, AutoModelForCausalLM
8
+ from diarization_utils import diarize, diarize_with_embeddings
9
  from huggingface_hub import HfApi
10
  from pydantic import ValidationError
11
  from starlette.exceptions import HTTPException
 
49
  self.diarization_pipeline.to(device)
50
  else:
51
  self.diarization_pipeline = None
52
+
53
+
54
  def __call__(self, inputs):
55
  file = inputs.pop("inputs")
56
  file = base64.b64decode(file)
 
60
  except ValidationError as e:
61
  logger.error(f"Error validating parameters: {e}")
62
  raise HTTPException(status_code=400, detail=f"Error validating parameters: {e}")
63
+
64
  logger.info(f"inference parameters: {parameters}")
65
 
66
  generate_kwargs = {
67
+ "task": parameters.task,
68
  "language": parameters.language,
69
  "assistant_model": self.assistant_model if parameters.assisted else None
70
  }
71
 
72
+ # --- ASR ---
73
  try:
74
  asr_outputs = self.asr_pipeline(
75
  file,
 
82
  logger.error(f"ASR inference error: {str(e)}")
83
  raise HTTPException(status_code=400, detail=f"ASR inference error: {str(e)}")
84
  except Exception as e:
85
+ logger.error(f"Unknown error during ASR inference: {str(e)}")
86
+ raise HTTPException(status_code=500, detail=f"Unknown error during ASR inference: {str(e)}")
87
+
88
+ # --- Diarization ---
89
+ speaker_embeddings = {}
90
+ speaker_matches = {}
91
+ transcript = []
92
 
93
  if self.diarization_pipeline:
94
+ use_extended = parameters.return_embeddings or parameters.known_speakers
95
+
96
  try:
97
+ if use_extended:
98
+ transcript, speaker_embeddings, speaker_matches = diarize_with_embeddings(
99
+ self.diarization_pipeline, file, parameters, asr_outputs
100
+ )
101
+ else:
102
+ transcript = diarize(self.diarization_pipeline, file, parameters, asr_outputs)
103
  except RuntimeError as e:
104
  logger.error(f"Diarization inference error: {str(e)}")
105
  raise HTTPException(status_code=400, detail=f"Diarization inference error: {str(e)}")
106
  except Exception as e:
107
  logger.error(f"Unknown error during diarization: {str(e)}")
108
  raise HTTPException(status_code=500, detail=f"Unknown error during diarization: {str(e)}")
 
 
109
 
110
+ # --- Response ---
111
+ response = {
112
  "speakers": transcript,
113
  "chunks": asr_outputs["chunks"],
114
  "text": asr_outputs["text"],
115
+ }
116
+
117
+ # Include embeddings if requested
118
+ if speaker_embeddings:
119
+ response["speaker_embeddings"] = speaker_embeddings
120
+
121
+ # Include matches if known speakers were provided
122
+ if speaker_matches:
123
+ response["speaker_matches"] = speaker_matches
124
+
125
+ return response
requirements.txt CHANGED
@@ -1,8 +1,8 @@
1
- accelerate==0.27.2
2
- torch==2.2.1
3
- pyannote-audio==3.1.1
4
- transformers==4.38.2
5
- numpy==1.26.4
6
- torchaudio==2.2.1
7
- pydantic==2.6.3
8
- pydantic-settings==2.2.1
 
1
+ accelerate>=0.27.2
2
+ torch>=2.2.1
3
+ pyannote-audio>=3.3.0
4
+ transformers>=4.38.2
5
+ numpy>=1.26.4
6
+ torchaudio>=2.2.1
7
+ pydantic>=2.6.3
8
+ pydantic-settings>=2.2.1