Capstone04 committed on
Commit
24ed202
·
verified ·
1 Parent(s): 83ddb5c

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. asr_diarization/pipeline.py +15 -3
asr_diarization/pipeline.py CHANGED
@@ -406,16 +406,28 @@ class ASR_Diarization:
406
  results["DER"] = round(der_score * 100, 2)
407
 
408
  if ref_json and os.path.exists(hyp_json):
409
- def load_words(path):
 
410
  data = json.load(open(path))
411
  # Filter out NSE events for WER calculation (only use speech)
412
  speech_segments = [seg for seg in data if seg.get("speaker") != "NSE"]
413
  # NEW: Directly use segment text instead of tokens
414
  return " ".join([seg["text"] for seg in speech_segments])
415
 
416
- ref_text, hyp_text = load_words(ref_json), load_words(hyp_json)
 
 
 
 
 
 
 
 
 
 
 
417
  transform = Compose([ToLowerCase(), RemovePunctuation(),
418
- RemoveMultipleSpaces(), Strip()])
419
  results["WER_raw"] = round(wer(ref_text, hyp_text), 4)
420
  results["WER_normalized"] = round(wer(transform(ref_text), transform(hyp_text)), 4)
421
 
 
406
  results["DER"] = round(der_score * 100, 2)
407
 
408
  if ref_json and os.path.exists(hyp_json):
409
def load_words_from_hypothesis(path):
    """Load text from YOUR pipeline output (has 'text' field).

    Reads the hypothesis JSON at *path*, drops non-speech ("NSE")
    segments, and returns the remaining segments' 'text' fields joined
    with single spaces.

    Args:
        path: Filesystem path to the hypothesis JSON (a list of segment
            dicts, each with 'speaker' and 'text' keys).

    Returns:
        A single space-joined transcript string (empty if no speech).
    """
    # FIX: json.load(open(path)) leaked the file handle until GC;
    # a context manager closes it deterministically.
    with open(path) as fh:
        data = json.load(fh)
    # Filter out NSE events for WER calculation (only use speech)
    speech_segments = [seg for seg in data if seg.get("speaker") != "NSE"]
    # Directly use segment text instead of tokens
    return " ".join(seg["text"] for seg in speech_segments)
416
 
417
def load_words_from_reference(path):
    """Load text from REFERENCE file (has 'tokens' field).

    Reads the reference JSON at *path*, drops non-speech ("NSE")
    segments, and returns every token's 'text' — flattened across the
    remaining segments, in order — joined with single spaces.

    Args:
        path: Filesystem path to the reference JSON (a list of segment
            dicts, each with 'speaker' and a 'tokens' list of dicts
            carrying a 'text' key).

    Returns:
        A single space-joined transcript string (empty if no speech).
    """
    # FIX: json.load(open(path)) leaked the file handle until GC;
    # a context manager closes it deterministically.
    with open(path) as fh:
        data = json.load(fh)
    # Filter out NSE events for WER calculation (only use speech)
    speech_segments = [seg for seg in data if seg.get("speaker") != "NSE"]
    # Reference format stores per-token text under 'tokens', not a direct 'text'
    return " ".join(tok["text"] for seg in speech_segments for tok in seg["tokens"])
424
+
425
+ # Use appropriate loader for each file
426
+ ref_text = load_words_from_reference(ref_json)
427
+ hyp_text = load_words_from_hypothesis(hyp_json)
428
+
429
  transform = Compose([ToLowerCase(), RemovePunctuation(),
430
+ RemoveMultipleSpaces(), Strip()])
431
  results["WER_raw"] = round(wer(ref_text, hyp_text), 4)
432
  results["WER_normalized"] = round(wer(transform(ref_text), transform(hyp_text)), 4)
433