Commit 6bffdd8
Parent(s): 8931f77

separate transcription and diarization for longer records

handler.py  (+64 -61)  CHANGED
@@ -8,7 +8,7 @@ import torch
 # stdout, stderr = process.communicate()
 
 import whisperx
-import os
+import os, gc
 
 import time
 import json
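The new `gc` import feeds the explicit memory cleanup the handler now performs before each stage (see the transcription, alignment, and diarization hunks below). As a standalone sketch of that pattern, with the helper name being my own rather than the repo's:

```python
import gc

import torch


def free_gpu_memory():
    # Hypothetical helper (not in handler.py): collect Python garbage first,
    # then return cached CUDA blocks to the driver so the next stage starts
    # with as much free VRAM as possible.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
```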
@@ -188,7 +188,7 @@ class EndpointHandler:
     def __init__(self, path=""):
         # load the model
         device, batch_size, compute_type, whisper_model = whisper_config()
-
+        self.model = whisperx.load_model(whisper_model, device=device, compute_type=compute_type, language="fr")
         # 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 ERIC
         # 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 OLD
         # logger.info(f"Model {whisper_model} initialized")
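`whisper_config()` is defined elsewhere in handler.py and not shown in this diff. Judging only from how its return value is used, a plausible sketch would look like the following (the env-var names and defaults are assumptions, not the repo's code):

```python
import os

import torch


def whisper_config():
    # Hypothetical reconstruction of the unshown helper; values are guesses.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch_size = int(os.environ.get("BATCH_SIZE", "16"))
    compute_type = "float16" if device == "cuda" else "int8"  # common ctranslate2 pairing
    whisper_model = os.environ.get("WHISPER_MODEL", "large-v2")
    return device, batch_size, compute_type, whisper_model
```

Loading the model once in `__init__` rather than per request is the standard pattern for inference endpoints, since the `whisperx.load_model` call is by far the most expensive step.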
@@ -218,17 +218,23 @@
         logger.info(display_gpu_infos())
 
         # 1. process input
-        # for diarization without transcription, the transcription is given as input, so data is now a tuple (inputs, transcription)
-        inputs_encoded, transcription = data.pop("inputs", data)
-        # inputs_encoded = data.pop("inputs", data)
         parameters = data.pop("parameters", None)
         options = data.pop("options", None)
 
         # OPTIONS are given as parameters
-        info =
-
-
-
+        info = options.get("info", False)
+        transcribe = options.get("transcription", False)
+        alignment = options.get("alignment", False)
+        diarization = options.get("diarization", False)
+        language = parameters.get("language", "fr")
+        min_speakers = parameters.get("min_speakers", 2)
+        max_speakers = parameters.get("max_speakers", 25)
+
+        # for diarization without transcription, the transcription is given as input, so data is now a tuple (inputs, transcription)
+        if transcribe:
+            (inputs_encoded,) = data.pop("inputs", data)
+        elif diarization:
+            inputs_encoded, transcription = data.pop("inputs", data)
 
         inputs = base64.b64decode(inputs_encoded)
         logger.info(f"inputs decoded.")
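Given how the handler now unpacks `data["inputs"]`, a client request carries a 1-tuple for a transcription pass and an (audio, transcription) pair for a diarization pass. A minimal sketch of both payloads (field values are illustrative; only the key names come from the diff):

```python
import base64

with open("meeting.wav", "rb") as f:  # illustrative file name
    audio_b64 = base64.b64encode(f.read()).decode("ascii")

# Pass 1: transcription (matches `(inputs_encoded,) = data.pop("inputs", data)`)
transcribe_request = {
    "inputs": [audio_b64],
    "options": {"transcription": True, "alignment": True, "info": False},
    "parameters": {"language": "fr"},
}

# Pass 2: diarization; the transcription from pass 1 rides along with the audio
# (matches `inputs_encoded, transcription = data.pop("inputs", data)`)
previous_transcription = {"segments": [], "language": "fr"}  # placeholder for the pass-1 result
diarize_request = {
    "inputs": [audio_b64, previous_transcription],
    "options": {"diarization": True},
    "parameters": {"min_speakers": 2, "max_speakers": 25},
}
```

Note that JSON turns tuples into lists on the wire; Python unpacking treats both the same, so lists are fine here. One wrinkle worth flagging: `data.pop("inputs", data)` falls back to the whole dict when "inputs" is missing, so a malformed request fails with a confusing unpacking error rather than a clear one.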
@@ -237,82 +243,79 @@
             w.write(inputs)
         logger.info(f"inputs saved.")
 
-        # audio_nparray = ffmpeg_load_audio("/tmp/myfile.tmp", sr=SAMPLE_RATE, mono=True, out_type=np.float32)
         audio_nparray = load_audio("/tmp/myfile.tmp", sr=SAMPLE_RATE)
         logger.info(f"inputs loaded as mono 16kHz.")
         # clean up
         os.remove("/tmp/myfile.tmp")
         logger.info(f"temp file removed.")
 
-        # audio_nparray = ffmpeg_read(inputs, SAMPLE_RATE)
-        # audio_tensor = torch.from_numpy(audio_nparray)
-        # logger.info(f"inputs loaded as mono 16kHz.")
-
-        # get the end time
         et = time.time()
-
-        # get the execution time
         elapsed_time = et - st
+
         logger.info(f"TIME for audio processing : {elapsed_time:.2f} seconds")
         if info:
             print(f"TIME for audio processing : {elapsed_time:.2f} seconds")
 
         # 2. transcribe
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if transcribe:
+            gc.collect()
+            torch.cuda.empty_cache()
+            logger.info("--------------- STARTING TRANSCRIPTION ------------------------")
+            transcription = self.model.transcribe(audio_nparray, batch_size=batch_size, language=language)
+            if info:
+                print(transcription["segments"][0:10_000])  # before alignment
+            else:
+                logger.info(transcription["segments"][0:1_000])
+
+            try:
+                first_text = transcription["segments"][0]["text"]
+            except:
+                logger.warning("No transcription")
+                return {"transcription": transcription["segments"]}
+
+            et = time.time()
+            elapsed_time = et - st
+            st = time.time()
+            logger.info(f"TIME for audio transcription : {elapsed_time:.2f} seconds")
+            if info:
+                print(f"TIME for audio transcription : {elapsed_time:.2f} seconds")
+
+        # 3. align
+        if alignment:
+            gc.collect()
+            torch.cuda.empty_cache()
+            logger.info("--------------- STARTING ALIGNMENT ------------------------")
+            model_a, metadata = whisperx.load_align_model(language_code=transcription["language"], device=device)
+            transcription = whisperx.align(
+                transcription["segments"], model_a, metadata, audio_nparray, device, return_char_alignments=False
+            )
+            del model_a
+            if info:
+                print(transcription["segments"][0:10000])
+            logger.info(transcription["segments"][0:10000])
+
+            et = time.time()
+            elapsed_time = et - st
+            st = time.time()
+            logger.info(f"TIME for alignment : {elapsed_time:.2f} seconds")
+            if info:
+                print(f"TIME for alignment : {elapsed_time:.2f} seconds")
 
         # 4. Assign speaker labels
         if diarization:
+            gc.collect()
+            torch.cuda.empty_cache()
             logger.info("--------------- STARTING DIARIZATION ------------------------")
+            if not transcription:
+                logger.warning("No transcription to diarize")
             # add min/max number of speakers if known
-            diarize_segments = self.diarize_model(audio_nparray)
+            diarize_segments = self.diarize_model(audio_nparray, min_speakers=min_speakers, max_speakers=max_speakers)
             if info:
                 print(diarize_segments)
             logger.info(diarize_segments)
-            # diarize_model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
 
             transcription = whisperx.assign_word_speakers(diarize_segments, transcription)
-            # if info:
-            #     print(transcription["segments"][0:10000])
-            #     logger.info(transcription["segments"][0:10000]) # segments are now assigned speaker IDs
 
-        # get the execution time
         et = time.time()
         elapsed_time = et - st
         st = time.time()
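Putting the pieces together: the commit's point is that a single request doing transcription, alignment, and diarization back to back can exhaust GPU memory on long recordings, while splitting the work into two calls lets the CUDA cache be cleared between stages. A sketch of that two-pass flow from the client side (the endpoint URL, auth header, and the shape of the pass-1 response are assumptions, not shown in this diff):

```python
import base64

import requests

ENDPOINT = "https://your-endpoint.example.com"  # hypothetical URL
HEADERS = {"Authorization": "Bearer hf_..."}    # hypothetical token

with open("long_recording.wav", "rb") as f:
    audio_b64 = base64.b64encode(f.read()).decode("ascii")

# Pass 1: transcribe and align only; diarization stays off
r1 = requests.post(ENDPOINT, headers=HEADERS, json={
    "inputs": [audio_b64],
    "options": {"transcription": True, "alignment": True},
    "parameters": {"language": "fr"},
})
transcription = r1.json()  # assumed to be the aligned transcription dict

# Pass 2: diarize, feeding the pass-1 transcription back in
r2 = requests.post(ENDPOINT, headers=HEADERS, json={
    "inputs": [audio_b64, transcription],
    "options": {"diarization": True},
    "parameters": {"min_speakers": 2, "max_speakers": 25},
})
result = r2.json()  # segments with speaker labels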
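After `whisperx.assign_word_speakers`, each segment in the result typically carries a `speaker` key alongside `start`, `end`, and `text`, so the diarized output can be rendered per speaker. A small sketch of consuming the pass-2 result, assuming that shape:

```python
# Assumed result shape: {"segments": [{"start": ..., "end": ..., "text": ..., "speaker": ...}, ...]}
for seg in result["segments"]:
    speaker = seg.get("speaker", "UNKNOWN")  # segments with no diarization overlap may lack a label
    print(f'[{seg["start"]:7.2f} - {seg["end"]:7.2f}] {speaker}: {seg["text"].strip()}')
```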