ericmattmann committed on
Commit
6bffdd8
·
1 Parent(s): 8931f77

separate transcription and diarization for longer records

Browse files
Files changed (1) hide show
  1. handler.py +64 -61
handler.py CHANGED
@@ -8,7 +8,7 @@ import torch
8
  # stdout, stderr = process.communicate()
9
 
10
  import whisperx
11
- import os
12
 
13
  import time
14
  import json
@@ -188,7 +188,7 @@ class EndpointHandler:
188
  def __init__(self, path=""):
189
  # load the model
190
  device, batch_size, compute_type, whisper_model = whisper_config()
191
- # self.model = whisperx.load_model(whisper_model, device=device, compute_type=compute_type, language="fr")
192
  # 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 ERIC
193
  # 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 OLD
194
  # logger.info(f"Model {whisper_model} initialized")
@@ -218,17 +218,23 @@ class EndpointHandler:
218
  logger.info(display_gpu_infos())
219
 
220
  # 1. process input
221
- # for diarization without transcription, the transcription is given as input, so data is now a tuple (inputs, transcription)
222
- inputs_encoded, transcription = data.pop("inputs", data)
223
- # inputs_encoded = data.pop("inputs", data)
224
  parameters = data.pop("parameters", None)
225
  options = data.pop("options", None)
226
 
227
  # OPTIONS are given as parameters
228
- info = True if options and "info" in options.keys() and options["info"] else False
229
- alignment = True if options and "alignment" in options.keys() and options["alignment"] else False
230
- diarization = False if options and "diarization" in options.keys() and not options["diarization"] else True
231
- language = parameters["language"] if parameters and "language" in parameters.keys() else "fr"
 
 
 
 
 
 
 
 
 
232
 
233
  inputs = base64.b64decode(inputs_encoded)
234
  logger.info(f"inputs decoded.")
@@ -237,82 +243,79 @@ class EndpointHandler:
237
  w.write(inputs)
238
  logger.info(f"inputs saved.")
239
 
240
- # audio_nparray = ffmpeg_load_audio("/tmp/myfile.tmp", sr=SAMPLE_RATE, mono=True, out_type=np.float32)
241
  audio_nparray = load_audio("/tmp/myfile.tmp", sr=SAMPLE_RATE)
242
  logger.info(f"inputs loaded as mono 16kHz.")
243
  # clean up
244
  os.remove("/tmp/myfile.tmp")
245
  logger.info(f"temp file removed.")
246
 
247
- # audio_nparray = ffmpeg_read(inputs, SAMPLE_RATE)
248
- # audio_tensor = torch.from_numpy(audio_nparray)
249
- # logger.info(f"inputs loaded as mono 16kHz.")
250
-
251
- # get the end time
252
  et = time.time()
253
-
254
- # get the execution time
255
  elapsed_time = et - st
 
256
  logger.info(f"TIME for audio processing : {elapsed_time:.2f} seconds")
257
  if info:
258
  print(f"TIME for audio processing : {elapsed_time:.2f} seconds")
259
 
260
  # 2. transcribe
261
- # logger.info("--------------- STARTING TRANSCRIPTION ------------------------")
262
- # transcription = self.model.transcribe(audio_nparray, batch_size=batch_size, language=language)
263
- # if info:
264
- # print(transcription["segments"][0:10000]) # before alignment
265
- # logger.info(transcription["segments"][0:10000])
266
-
267
- # try:
268
- # first_text = transcription["segments"][0]["text"]
269
- # except:
270
- # logger.warning("No transcription")
271
- # return {"transcription": transcription["segments"]}
272
-
273
- # # get the execution time
274
- # et = time.time()
275
- # elapsed_time = et - st
276
- # st = time.time()
277
- # logger.info(f"TIME for audio transcription : {elapsed_time:.2f} seconds")
278
- # if info:
279
- # print(f"TIME for audio transcription : {elapsed_time:.2f} seconds")
280
-
281
- # # 3. align
282
- # if alignment:
283
- # logger.info("--------------- STARTING ALIGNMENT ------------------------")
284
- # model_a, metadata = whisperx.load_align_model(language_code=transcription["language"], device=device)
285
- # transcription = whisperx.align(
286
- # transcription["segments"], model_a, metadata, audio_nparray, device, return_char_alignments=False
287
- # )
288
- # if info:
289
- # print(transcription["segments"][0:10000])
290
- # logger.info(transcription["segments"][0:10000])
291
-
292
- # # get the execution time
293
- # et = time.time()
294
- # elapsed_time = et - st
295
- # st = time.time()
296
- # logger.info(f"TIME for alignment : {elapsed_time:.2f} seconds")
297
- # if info:
298
- # print(f"TIME for alignment : {elapsed_time:.2f} seconds")
 
 
 
 
 
299
 
300
  # 4. Assign speaker labels
301
  if diarization:
 
 
302
  logger.info("--------------- STARTING DIARIZATION ------------------------")
 
 
303
  # add min/max number of speakers if known
304
- diarize_segments = self.diarize_model(audio_nparray)
305
  if info:
306
  print(diarize_segments)
307
  logger.info(diarize_segments)
308
- # diarize_model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
309
 
310
  transcription = whisperx.assign_word_speakers(diarize_segments, transcription)
311
- # if info:
312
- # print(transcription["segments"][0:10000])
313
- # logger.info(transcription["segments"][0:10000]) # segments are now assigned speaker IDs
314
 
315
- # get the execution time
316
  et = time.time()
317
  elapsed_time = et - st
318
  st = time.time()
 
8
  # stdout, stderr = process.communicate()
9
 
10
  import whisperx
11
+ import os, gc
12
 
13
  import time
14
  import json
 
188
  def __init__(self, path=""):
189
  # load the model
190
  device, batch_size, compute_type, whisper_model = whisper_config()
191
+ self.model = whisperx.load_model(whisper_model, device=device, compute_type=compute_type, language="fr")
192
  # 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 ERIC
193
  # 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 OLD
194
  # logger.info(f"Model {whisper_model} initialized")
 
218
  logger.info(display_gpu_infos())
219
 
220
  # 1. process input
 
 
 
221
  parameters = data.pop("parameters", None)
222
  options = data.pop("options", None)
223
 
224
  # OPTIONS are given as parameters
225
+ info = options.get("info", False)
226
+ transcribe = options.get("transcription", False)
227
+ alignment = options.get("alignment", False)
228
+ diarization = options.get("diarization", False)
229
+ language = parameters.get("language", "fr")
230
+ min_speakers = parameters.get("min_speakers", 2)
231
+ max_speakers = parameters.get("max_speakers", 25)
232
+
233
+ # for diarization without transcription, the transcription is given as input, so data is now a tuple (inputs, transcription)
234
+ if transcribe:
235
+ (inputs_encoded,) = data.pop("inputs", data)
236
+ elif diarization:
237
+ inputs_encoded, transcription = data.pop("inputs", data)
238
 
239
  inputs = base64.b64decode(inputs_encoded)
240
  logger.info(f"inputs decoded.")
 
243
  w.write(inputs)
244
  logger.info(f"inputs saved.")
245
 
 
246
  audio_nparray = load_audio("/tmp/myfile.tmp", sr=SAMPLE_RATE)
247
  logger.info(f"inputs loaded as mono 16kHz.")
248
  # clean up
249
  os.remove("/tmp/myfile.tmp")
250
  logger.info(f"temp file removed.")
251
 
 
 
 
 
 
252
  et = time.time()
 
 
253
  elapsed_time = et - st
254
+
255
  logger.info(f"TIME for audio processing : {elapsed_time:.2f} seconds")
256
  if info:
257
  print(f"TIME for audio processing : {elapsed_time:.2f} seconds")
258
 
259
  # 2. transcribe
260
+ if transcribe:
261
+ gc.collect()
262
+ torch.cuda.empty_cache()
263
+ logger.info("--------------- STARTING TRANSCRIPTION ------------------------")
264
+ transcription = self.model.transcribe(audio_nparray, batch_size=batch_size, language=language)
265
+ if info:
266
+ print(transcription["segments"][0:10_000]) # before alignment
267
+ else:
268
+ logger.info(transcription["segments"][0:1_000])
269
+
270
+ try:
271
+ first_text = transcription["segments"][0]["text"]
272
+ except:
273
+ logger.warning("No transcription")
274
+ return {"transcription": transcription["segments"]}
275
+
276
+ et = time.time()
277
+ elapsed_time = et - st
278
+ st = time.time()
279
+ logger.info(f"TIME for audio transcription : {elapsed_time:.2f} seconds")
280
+ if info:
281
+ print(f"TIME for audio transcription : {elapsed_time:.2f} seconds")
282
+
283
+ # 3. align
284
+ if alignment:
285
+ gc.collect()
286
+ torch.cuda.empty_cache()
287
+ logger.info("--------------- STARTING ALIGNMENT ------------------------")
288
+ model_a, metadata = whisperx.load_align_model(language_code=transcription["language"], device=device)
289
+ transcription = whisperx.align(
290
+ transcription["segments"], model_a, metadata, audio_nparray, device, return_char_alignments=False
291
+ )
292
+ del model_a
293
+ if info:
294
+ print(transcription["segments"][0:10000])
295
+ logger.info(transcription["segments"][0:10000])
296
+
297
+ et = time.time()
298
+ elapsed_time = et - st
299
+ st = time.time()
300
+ logger.info(f"TIME for alignment : {elapsed_time:.2f} seconds")
301
+ if info:
302
+ print(f"TIME for alignment : {elapsed_time:.2f} seconds")
303
 
304
  # 4. Assign speaker labels
305
  if diarization:
306
+ gc.collect()
307
+ torch.cuda.empty_cache()
308
  logger.info("--------------- STARTING DIARIZATION ------------------------")
309
+ if not transcription:
310
+ logger.warning("No transcription to diarize")
311
  # add min/max number of speakers if known
312
+ diarize_segments = self.diarize_model(audio_nparray, min_speakers=min_speakers, max_speakers=max_speakers)
313
  if info:
314
  print(diarize_segments)
315
  logger.info(diarize_segments)
 
316
 
317
  transcription = whisperx.assign_word_speakers(diarize_segments, transcription)
 
 
 
318
 
 
319
  et = time.time()
320
  elapsed_time = et - st
321
  st = time.time()