calettippo committed
Commit 9bd97fd · Parent(s): e1d57cb

Add postprocessing commented

Files changed (1)
  1. app.py +341 -5
app.py CHANGED
@@ -20,11 +20,28 @@ from pydub.silence import split_on_silence
20
  import soundfile as sf
21
  import noisereduce
22
  from huggingface_hub import snapshot_download
 
23
 
24
  load_dotenv()
25
 
26
  # Audio preprocessing available with required dependencies
27
  PREPROCESSING_AVAILABLE = True
28
 
29
 
30
  # Shared caches to keep models/pipelines in memory across requests
@@ -33,6 +50,10 @@ PIPELINE_CACHE_LOCK = threading.Lock()
33
  MODEL_PATH_CACHE: Dict[str, str] = {}
34
  MODEL_PATH_CACHE_LOCK = threading.Lock()
35
 
36
 
37
  def get_env_or_secret(key: str, default: Optional[str] = None) -> Optional[str]:
38
  """Get environment variable or default."""
@@ -161,6 +182,16 @@ def warm_model_cache() -> None:
161
  if model_id and model_id != base_model_id:
162
  models_to_check.append((model_id, "fine-tuned"))
163
 
164
  for model_name, label in models_to_check:
165
  try:
166
  logger.info("Verifying %s model cache for %s", label, model_name)
@@ -481,6 +512,298 @@ def load_asr_pipeline(
481
  return asr, final_device, final_dtype_name
482
 
483
 
484
  @contextmanager
485
  def memory_monitor():
486
  """Context manager to monitor memory usage during inference."""
@@ -709,7 +1032,8 @@ def handle_whisper_problematic_output(text: str, model_name: str = "Whisper") ->
709
  def transcribe_comparison(audio_file):
710
  """Main function for Gradio interface."""
711
  if audio_file is None:
712
- return "❌ Nessun file audio fornito", "❌ Nessun file audio fornito"
713
 
714
  # Model configuration
715
  model_id = get_env_or_secret("HF_MODEL_ID")
@@ -720,7 +1044,7 @@ def transcribe_comparison(audio_file):
720
 
721
  if not model_id or not base_model_id:
722
  error_msg = "❌ Modelli non configurati. Impostare HF_MODEL_ID e BASE_WHISPER_MODEL_ID nelle variabili d'ambiente"
723
- return error_msg, error_msg
724
 
725
  # Preprocessing sempre attivo: normalizzazione formato, volume, riduzione rumore, rimozione silenzi
726
  # Viene applicato automaticamente prima della trascrizione con entrambi i modelli
@@ -742,6 +1066,7 @@ def transcribe_comparison(audio_file):
742
  finetuned_result = None
743
  original_text = ""
744
  finetuned_text = ""
745
 
746
  try:
747
  # Transcribe with original model
@@ -819,16 +1144,18 @@ def transcribe_comparison(audio_file):
819
  except Exception as e:
820
  finetuned_text = f"❌ Errore modello fine-tuned: {str(e)}"
821
 
822
  # GPU memory cleanup
823
  if torch.cuda.is_available():
824
  torch.cuda.empty_cache()
825
  gc.collect()
826
 
827
- return original_text, finetuned_text
828
 
829
  except Exception as e:
830
  error_msg = f"❌ Errore generale: {str(e)}"
831
- return error_msg, error_msg
832
 
833
 
834
  # Gradio interface
@@ -920,11 +1247,20 @@ def create_interface():
920
  show_copy_button=True,
921
  )
922
 
923
  # Click event
924
  transcribe_btn.click(
925
  fn=transcribe_comparison,
926
  inputs=[audio_input],
927
- outputs=[original_output, finetuned_output],
928
  show_progress=True,
929
  )
930
 
 
20
  import soundfile as sf
21
  import noisereduce
22
  from huggingface_hub import snapshot_download
23
+ from transformers import pipeline
24
 
25
  load_dotenv()
26
 
27
  # Audio preprocessing available with required dependencies
28
  PREPROCESSING_AVAILABLE = True
29
+ DEFAULT_TEXT_POSTPROCESS_MODEL = "google/medgemma-4b-it"
30
+ TEXT_POSTPROCESS_PROMPT = (
31
+ "Agisci come assistente editoriale clinico. Prendi la trascrizione fornita, correggi"
32
+ " eventuali errori di riconoscimento automatico e migliora la grammatica mantenendo"
33
+ " il significato. Anonimizza inoltre il testo sostituendo nomi propri di persone con"
34
+ " segnaposto [PAZIENTE] o [MEDICO] a seconda del ruolo implicato. Non inventare"
35
+ " informazioni nuove, non tradurre. Restituisci solo la versione finale pulita"
36
+ " e pseudonimizzata in italiano, senza preamboli né spiegazioni."
37
+ "\nEsempio 1 - Input: 'Buongiorno dottor Rossi, sono Maria Bianchi e ho prenotato l'holter.'"
38
+ "\nEsempio 1 - Output: 'Buongiorno [MEDICO], sono [PAZIENTE] e ho prenotato l'Holter.'"
39
+ "\nEsempio 2 - Input: 'Il paziente Claudio Caletti riferisce che la dottoressa Neri gli ha prescritto Coumadin.'"
40
+ "\nEsempio 2 - Output: '[PAZIENTE] riferisce che [MEDICO] gli ha prescritto Coumadin.'"
41
+ "\nEsempio 3 - Input: 'Dott.ssa Gallo, ho parlato con la collega Francesca e confermiamo l'intervento.'"
42
+ "\nEsempio 3 - Output: '[MEDICO], ho parlato con [MEDICO] e confermiamo l'intervento.'"
43
+ "\nTesto originale:\n"
44
+ )
45
 
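For illustration, the transformation this few-shot prompt targets can be pinned down with the prompt's own first example; a minimal check of this kind (the helper call is hypothetical until post-processing is re-enabled) could live in a smoke test:

# Illustrative sketch only: expected behaviour taken from Esempio 1 inside TEXT_POSTPROCESS_PROMPT.
raw = "Buongiorno dottor Rossi, sono Maria Bianchi e ho prenotato l'holter."
expected = "Buongiorno [MEDICO], sono [PAZIENTE] e ho prenotato l'Holter."
# A test could compare postprocess_transcription_text(raw, "smoke-test") against expected
# once TEXT_POSTPROCESS_ENABLED is set; the model output is not guaranteed to match verbatim.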
46
 
47
  # Shared caches to keep models/pipelines in memory across requests
 
50
  MODEL_PATH_CACHE: Dict[str, str] = {}
51
  MODEL_PATH_CACHE_LOCK = threading.Lock()
52
 
53
+ TEXT_POSTPROCESS_PIPELINE: Optional[Any] = None
54
+ TEXT_POSTPROCESS_MODEL_ID: Optional[str] = None
55
+ TEXT_POSTPROCESS_PIPELINE_LOCK = threading.Lock()
56
+
57
 
58
  def get_env_or_secret(key: str, default: Optional[str] = None) -> Optional[str]:
59
  """Get environment variable or default."""
 
182
  if model_id and model_id != base_model_id:
183
  models_to_check.append((model_id, "fine-tuned"))
184
 
185
+ text_postprocess_enabled = get_env_or_secret("TEXT_POSTPROCESS_ENABLED", "false").lower() in {
186
+ "1",
187
+ "true",
188
+ "yes",
189
+ }
190
+
191
+ text_model_id = get_env_or_secret("TEXT_POSTPROCESS_MODEL_ID", DEFAULT_TEXT_POSTPROCESS_MODEL)
192
+ if text_postprocess_enabled and text_model_id:
193
+ models_to_check.append((text_model_id, "text-postprocess"))
194
+
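A minimal sketch of the environment that drives this warm-up path, assuming a local .env file or Space secrets; the values are illustrative, the variable names are the ones read in this file:

# Illustrative local-testing setup.
import os
os.environ["TEXT_POSTPROCESS_ENABLED"] = "true"                     # "1" / "true" / "yes" all enable the pass
os.environ["TEXT_POSTPROCESS_MODEL_ID"] = "google/medgemma-4b-it"   # default is DEFAULT_TEXT_POSTPROCESS_MODEL
os.environ["TEXT_POSTPROCESS_DEVICE"] = "auto"                      # resolved to cuda / mps / cpu by the loader
os.environ["TEXT_POSTPROCESS_MAX_NEW"] = "200"                      # max_new_tokens for the cleanup generation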
195
  for model_name, label in models_to_check:
196
  try:
197
  logger.info("Verifying %s model cache for %s", label, model_name)
 
512
  return asr, final_device, final_dtype_name
513
 
514
 
515
+ def get_text_postprocess_pipeline(
516
+ model_id: str,
517
+ device_pref: Optional[str],
518
+ hf_token: Optional[str],
519
+ ) -> Any:
520
+ """Load a minimal text-generation pipeline for post-processing."""
521
+
522
+ logger = logging.getLogger(__name__)
523
+ if not model_id:
524
+ raise ValueError("Model id for text post-processing is not configured")
525
+
526
+ normalized_device_pref = (device_pref or "auto").lower()
527
+ if normalized_device_pref == "auto":
528
+ if torch.cuda.is_available():
529
+ device_choice = "cuda"
530
+ elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
531
+ device_choice = "mps"
532
+ else:
533
+ device_choice = "cpu"
534
+ else:
535
+ device_choice = normalized_device_pref
536
+
537
+ device_argument: Any
538
+ dtype: Optional[torch.dtype] = None
539
+ if device_choice.startswith("cuda") and torch.cuda.is_available():
540
+ device_argument = device_choice
541
+ dtype = torch.bfloat16
542
+ elif device_choice == "mps" and getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
543
+ device_argument = "mps"
544
+ dtype = torch.float16
545
+ else:
546
+ device_argument = "cpu"
547
+ dtype = None
548
+
549
+ global TEXT_POSTPROCESS_PIPELINE, TEXT_POSTPROCESS_MODEL_ID
550
+
551
+ with TEXT_POSTPROCESS_PIPELINE_LOCK:
552
+ if (
553
+ TEXT_POSTPROCESS_PIPELINE is not None
554
+ and TEXT_POSTPROCESS_MODEL_ID == model_id
555
+ ):
556
+ return TEXT_POSTPROCESS_PIPELINE
557
+
558
+ model_source = ensure_local_model(model_id, hf_token=hf_token)
559
+
560
+ is_medgemma = "medgemma" in model_id.lower()
561
+
562
+ if is_medgemma:
563
+ pipe_kwargs: Dict[str, Any] = {
564
+ "task": "image-text-to-text",
565
+ "model": model_source,
566
+ "device": device_argument,
567
+ }
568
+ if dtype is not None:
569
+ pipe_kwargs["torch_dtype"] = dtype
570
+ else:
571
+ pipe_kwargs = {
572
+ "task": "text-generation",
573
+ "model": model_source,
574
+ "device": device_argument,
575
+ "tokenizer": model_source,
576
+ }
577
+ if dtype is not None:
578
+ pipe_kwargs["torch_dtype"] = dtype
579
+ if device_argument != "cpu":
580
+ pipe_kwargs["device_map"] = "auto"
581
+
582
+ logger.info(
583
+ "Loading postprocess pipeline for %s with device=%s, dtype=%s",
584
+ model_id,
585
+ device_argument,
586
+ str(dtype) if dtype is not None else "auto",
587
+ )
588
+
589
+ try:
590
+ postprocess_pipe = pipeline(**pipe_kwargs)
591
+ except Exception as primary_error:
592
+ logger.warning(
593
+ "Postprocess pipeline init failed on %s (%s). Falling back to CPU.",
594
+ device_argument,
595
+ primary_error,
596
+ )
597
+ pipe_kwargs["device"] = "cpu"
598
+ pipe_kwargs.pop("torch_dtype", None)
599
+ pipe_kwargs.pop("device_map", None)
600
+ postprocess_pipe = pipeline(**pipe_kwargs)
601
+
602
+ TEXT_POSTPROCESS_PIPELINE = postprocess_pipe
603
+ TEXT_POSTPROCESS_MODEL_ID = model_id
604
+ return postprocess_pipe
605
+
606
+
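A minimal warm-up sketch for the loader above, assuming the same environment variables; the pipeline is cached module-wide behind TEXT_POSTPROCESS_PIPELINE_LOCK, so repeated calls with the same model id are cheap:

# Illustrative warm-up call; device/dtype failures fall back to CPU inside the loader.
import os
pipe = get_text_postprocess_pipeline(
    model_id=os.environ.get("TEXT_POSTPROCESS_MODEL_ID", DEFAULT_TEXT_POSTPROCESS_MODEL),
    device_pref=os.environ.get("TEXT_POSTPROCESS_DEVICE", "auto"),
    hf_token=os.environ.get("TEXT_POSTPROCESS_HF_TOKEN") or os.environ.get("HF_TOKEN"),
)
# A second call with the same model_id returns the cached TEXT_POSTPROCESS_PIPELINE instance.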
607
+ def postprocess_transcription_text(
608
+ text: str,
609
+ context_label: str,
610
+ ) -> str:
611
+ """Run MedGemma post-processing to clean transcription text."""
612
+
613
+ if not text or not text.strip():
614
+ return text
615
+
616
+ logger = logging.getLogger(__name__)
617
+
618
+ text_postprocess_enabled = get_env_or_secret("TEXT_POSTPROCESS_ENABLED", "false").lower() in {
619
+ "1",
620
+ "true",
621
+ "yes",
622
+ }
623
+ if not text_postprocess_enabled:
624
+ logger.debug(
625
+ "Text post-processing skipped for %s: feature disabled",
626
+ context_label,
627
+ )
628
+ return text
629
+
630
+ model_id = get_env_or_secret("TEXT_POSTPROCESS_MODEL_ID", DEFAULT_TEXT_POSTPROCESS_MODEL)
631
+ if not model_id:
632
+ logger.info("Text post-processing disabled: no model configured")
633
+ return text
634
+
635
+ hf_token = get_env_or_secret("TEXT_POSTPROCESS_HF_TOKEN") or get_env_or_secret(
636
+ "HF_TOKEN"
637
+ )
638
+ device_pref = get_env_or_secret("TEXT_POSTPROCESS_DEVICE", "auto")
639
+ max_new_tokens = int(get_env_or_secret("TEXT_POSTPROCESS_MAX_NEW", "200"))
640
+
641
+ prompt_body = text.strip()
642
+ prompt = f"{TEXT_POSTPROCESS_PROMPT}{prompt_body}\nRisultato:"
643
+ is_medgemma = "medgemma" in model_id.lower()
644
+
645
+ try:
646
+ postprocess_pipe = get_text_postprocess_pipeline(
647
+ model_id=model_id,
648
+ device_pref=device_pref,
649
+ hf_token=hf_token,
650
+ )
651
+
652
+ if is_medgemma:
653
+ system_prompt, separator, _ = TEXT_POSTPROCESS_PROMPT.partition("\nTesto originale:\n")
654
+ if not separator:
655
+ system_prompt = TEXT_POSTPROCESS_PROMPT
656
+ user_prefix = ""
657
+ else:
658
+ user_prefix = "Testo originale:\n"
659
+ system_prompt = system_prompt.strip()
660
+ messages = [
661
+ {
662
+ "role": "system",
663
+ "content": [{"type": "text", "text": system_prompt.strip()}],
664
+ },
665
+ {
666
+ "role": "user",
667
+ "content": [
668
+ {
669
+ "type": "text",
670
+ "text": f"{user_prefix}{prompt_body}\nRisultato:",
671
+ }
672
+ ],
673
+ },
674
+ ]
675
+
676
+ outputs = postprocess_pipe(
677
+ text=messages,
678
+ max_new_tokens=max_new_tokens,
679
+ )
680
+
681
+ generated_text = ""
682
+ if isinstance(outputs, list) and outputs:
683
+ first = outputs[0]
684
+ if isinstance(first, dict):
685
+ generated = first.get("generated_text")
686
+ if isinstance(generated, list):
687
+ # Prefer the latest assistant-like turn
688
+ for msg in reversed(generated):
689
+ if not isinstance(msg, dict):
690
+ continue
691
+ role = msg.get("role")
692
+ if role not in {"assistant", "model", None}:
693
+ continue
694
+ content = msg.get("content")
695
+ if isinstance(content, list):
696
+ for block in content:
697
+ if (
698
+ isinstance(block, dict)
699
+ and block.get("type") == "text"
700
+ ):
701
+ text_block = (block.get("text") or "").strip()
702
+ if text_block:
703
+ generated_text = text_block
704
+ break
705
+ if generated_text:
706
+ break
707
+ elif isinstance(content, str) and content.strip():
708
+ generated_text = content.strip()
709
+ break
710
+ if not generated_text:
711
+ # Fallback: use the last text block regardless of role
712
+ for msg in reversed(generated):
713
+ if not isinstance(msg, dict):
714
+ continue
715
+ content = msg.get("content")
716
+ if isinstance(content, list):
717
+ for block in content:
718
+ if (
719
+ isinstance(block, dict)
720
+ and block.get("type") == "text"
721
+ and block.get("text")
722
+ ):
723
+ generated_text = block["text"].strip()
724
+ break
725
+ if generated_text:
726
+ break
727
+ elif isinstance(content, str) and content.strip():
728
+ generated_text = content.strip()
729
+ break
730
+ elif isinstance(generated, str):
731
+ generated_text = generated.strip()
732
+ elif isinstance(outputs, dict):
733
+ generated = outputs.get("generated_text")
734
+ if isinstance(generated, list):
735
+ for msg in reversed(generated):
736
+ if isinstance(msg, dict):
737
+ text_block = (
738
+ msg.get("text")
739
+ or msg.get("content")
740
+ or ""
741
+ )
742
+ if isinstance(text_block, str) and text_block.strip():
743
+ generated_text = text_block.strip()
744
+ break
745
+ elif isinstance(generated, str):
746
+ generated_text = generated.strip()
747
+
748
+ cleaned = generated_text
749
+ else:
750
+ outputs = postprocess_pipe(
751
+ prompt,
752
+ max_new_tokens=max_new_tokens,
753
+ do_sample=False,
754
+ return_full_text=False,
755
+ )
756
+
757
+ generated_text = ""
758
+ if isinstance(outputs, list) and outputs:
759
+ first = outputs[0]
760
+ if isinstance(first, dict):
761
+ candidate = first.get("generated_text") or first.get("text")
762
+ if isinstance(candidate, str):
763
+ generated_text = candidate
764
+ elif isinstance(candidate, list):
765
+ generated_text = " ".join(
766
+ part for part in candidate if isinstance(part, str)
767
+ )
768
+ elif isinstance(first, str):
769
+ generated_text = first
770
+ elif isinstance(outputs, dict):
771
+ candidate = outputs.get("generated_text") or outputs.get("text")
772
+ if isinstance(candidate, str):
773
+ generated_text = candidate
774
+ elif isinstance(outputs, str):
775
+ generated_text = outputs
776
+
777
+ generated_text = (generated_text or "").strip()
778
+
779
+ if generated_text.startswith(prompt):
780
+ cleaned = generated_text[len(prompt) :].strip()
781
+ else:
782
+ cleaned = generated_text
783
+
784
+ if cleaned:
785
+ if cleaned.startswith(prompt_body):
786
+ cleaned = cleaned[len(prompt_body) :].strip()
787
+ if cleaned.startswith("Risultato:"):
788
+ cleaned = cleaned[len("Risultato:") :].strip()
789
+ if cleaned.lower().startswith("risultato:"):
790
+ cleaned = cleaned[len("risultato:") :].strip()
791
+ logger.debug("Post-processing successful for %s", context_label)
792
+ return cleaned or text
793
+
794
+ logger.warning("Post-processing returned empty output for %s", context_label)
795
+ return text
796
+
797
+ except Exception as exc:
798
+ logger.warning(
799
+ "Text post-processing failed for %s with model %s: %s",
800
+ context_label,
801
+ model_id,
802
+ exc,
803
+ )
804
+ return text
805
+
806
+
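A minimal usage sketch: the helper is a no-op when the feature flag is off and returns the original text on any failure, so callers never lose the raw transcription. The sample sentence is taken from Esempio 2 of the prompt:

# Illustrative call; "finetuned" is only a context label used in log messages.
cleaned = postprocess_transcription_text(
    "Il paziente Claudio Caletti riferisce che la dottoressa Neri gli ha prescritto Coumadin.",
    context_label="finetuned",
)
# With TEXT_POSTPROCESS_ENABLED unset or false, cleaned equals the input string unchanged.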
807
  @contextmanager
808
  def memory_monitor():
809
  """Context manager to monitor memory usage during inference."""
 
1032
  def transcribe_comparison(audio_file):
1033
  """Main function for Gradio interface."""
1034
  if audio_file is None:
1035
+ warning = "❌ Nessun file audio fornito"
1036
+ return warning, warning, warning
1037
 
1038
  # Model configuration
1039
  model_id = get_env_or_secret("HF_MODEL_ID")
 
1044
 
1045
  if not model_id or not base_model_id:
1046
  error_msg = "❌ Modelli non configurati. Impostare HF_MODEL_ID e BASE_WHISPER_MODEL_ID nelle variabili d'ambiente"
1047
+ return error_msg, error_msg, error_msg
1048
 
1049
  # Preprocessing sempre attivo: normalizzazione formato, volume, riduzione rumore, rimozione silenzi
1050
  # Viene applicato automaticamente prima della trascrizione con entrambi i modelli
 
1066
  finetuned_result = None
1067
  original_text = ""
1068
  finetuned_text = ""
1069
+ postprocessed_text = ""
1070
 
1071
  try:
1072
  # Transcribe with original model
 
1144
  except Exception as e:
1145
  finetuned_text = f"❌ Errore modello fine-tuned: {str(e)}"
1146
 
1147
+ postprocessed_text = finetuned_text or ""
1148
+
1149
  # GPU memory cleanup
1150
  if torch.cuda.is_available():
1151
  torch.cuda.empty_cache()
1152
  gc.collect()
1153
 
1154
+ return original_text, finetuned_text, postprocessed_text
1155
 
1156
  except Exception as e:
1157
  error_msg = f"❌ Errore generale: {str(e)}"
1158
+ return error_msg, error_msg, error_msg
1159
 
1160
 
1161
  # Gradio interface
 
1247
  show_copy_button=True,
1248
  )
1249
 
1250
+ # Post-processing disabilitato temporaneamente: manteniamo il widget ma nascosto
1251
+ medgemma_output = gr.Textbox(
1252
+ label="Testo finale",
1253
+ lines=12,
1254
+ interactive=False,
1255
+ show_copy_button=True,
1256
+ visible=False,
1257
+ )
1258
+
1259
  # Click event
1260
  transcribe_btn.click(
1261
  fn=transcribe_comparison,
1262
  inputs=[audio_input],
1263
+ outputs=[original_output, finetuned_output, medgemma_output],
1264
  show_progress=True,
1265
  )
1266
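The third output is wired but hidden in this commit; a minimal sketch of what re-enabling it might look like once post-processing is switched back on (illustrative, not part of this commit):

# Illustrative: surface the post-processed text again while keeping the three-value wiring above.
medgemma_output = gr.Textbox(
    label="Testo finale",
    lines=12,
    interactive=False,
    show_copy_button=True,
    visible=True,  # flip from False once TEXT_POSTPROCESS_ENABLED is set
)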