erik-svensson-cm commited on
Commit
bb2e16d
·
verified ·
1 Parent(s): 443b8d6

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +5 -5
handler.py CHANGED
@@ -109,12 +109,13 @@ class EndpointHandler():
109
  stream = torch.cuda.Stream()
110
  with torch.cuda.stream(stream):
111
  try:
112
- diarize_segments = self.diarize_model(
113
  _audio,
114
  min_speakers=parameters.min_speakers,
115
- max_speakers=parameters.max_speakers
 
116
  )
117
- return diarize_segments
118
  except RuntimeError as e:
119
  logger.error(f"Diarization inference error: {str(e)}")
120
  raise HTTPException(status_code=400, detail=f"Diarization inference error: {str(e)}")
@@ -134,14 +135,13 @@ class EndpointHandler():
134
  diarization_future = executor.submit(run_diarization, audio)
135
 
136
  _result = align_future.result()
137
- diarization_output = diarization_future.result()
138
  result = []
139
  if diarization_output is not None and _result:
140
  result = assign_word_speakers(
141
  diarization_output,
142
  _result,
143
  )
144
- embeddings = diarization_output.speaker_embeddings
145
  # Final cleanup
146
  del diarization_output, segments, audio
147
  gc.collect()
 
109
  stream = torch.cuda.Stream()
110
  with torch.cuda.stream(stream):
111
  try:
112
+ diarize_segments, _embeddings = self.diarize_model(
113
  _audio,
114
  min_speakers=parameters.min_speakers,
115
+ max_speakers=parameters.max_speakers,
116
+ return_embeddings=True
117
  )
118
+ return diarize_segments, _embeddings
119
  except RuntimeError as e:
120
  logger.error(f"Diarization inference error: {str(e)}")
121
  raise HTTPException(status_code=400, detail=f"Diarization inference error: {str(e)}")
 
135
  diarization_future = executor.submit(run_diarization, audio)
136
 
137
  _result = align_future.result()
138
+ diarization_output, embeddings = diarization_future.result()
139
  result = []
140
  if diarization_output is not None and _result:
141
  result = assign_word_speakers(
142
  diarization_output,
143
  _result,
144
  )
 
145
  # Final cleanup
146
  del diarization_output, segments, audio
147
  gc.collect()