Commit 9bd97fd
Parent(s): e1d57cb

Add postprocessing commented
app.py CHANGED
@@ -20,11 +20,28 @@ from pydub.silence import split_on_silence
 import soundfile as sf
 import noisereduce
 from huggingface_hub import snapshot_download
+from transformers import pipeline
 
 load_dotenv()
 
 # Audio preprocessing available with required dependencies
 PREPROCESSING_AVAILABLE = True
+DEFAULT_TEXT_POSTPROCESS_MODEL = "google/medgemma-4b-it"
+TEXT_POSTPROCESS_PROMPT = (
+    "Agisci come assistente editoriale clinico. Prendi la trascrizione fornita, correggi"
+    " eventuali errori di riconoscimento automatico e migliora la grammatica mantenendo"
+    " il significato. Anonimizza inoltre il testo sostituendo nomi propri di persone con"
+    " segnaposto [PAZIENTE] o [MEDICO] a seconda del ruolo implicato. Non inventare"
+    " informazioni nuove, non tradurre. Restituisci solo la versione finale pulita"
+    " e pseudonimizzata in italiano, senza preamboli né spiegazioni."
+    "\nEsempio 1 - Input: 'Buongiorno dottor Rossi, sono Maria Bianchi e ho prenotato l'holter.'"
+    "\nEsempio 1 - Output: 'Buongiorno [MEDICO], sono [PAZIENTE] e ho prenotato l'Holter.'"
+    "\nEsempio 2 - Input: 'Il paziente Claudio Caletti riferisce che la dottoressa Neri gli ha prescritto Coumadin.'"
+    "\nEsempio 2 - Output: '[PAZIENTE] riferisce che [MEDICO] gli ha prescritto Coumadin.'"
+    "\nEsempio 3 - Input: 'Dott.ssa Gallo, ho parlato con la collega Francesca e confermiamo l'intervento.'"
+    "\nEsempio 3 - Output: '[MEDICO], ho parlato con [MEDICO] e confermiamo l'intervento.'"
+    "\nTesto originale:\n"
+)
 
 
 # Shared caches to keep models/pipelines in memory across requests
@@ -33,6 +50,10 @@ PIPELINE_CACHE_LOCK = threading.Lock()
 MODEL_PATH_CACHE: Dict[str, str] = {}
 MODEL_PATH_CACHE_LOCK = threading.Lock()
 
+TEXT_POSTPROCESS_PIPELINE: Optional[Any] = None
+TEXT_POSTPROCESS_MODEL_ID: Optional[str] = None
+TEXT_POSTPROCESS_PIPELINE_LOCK = threading.Lock()
+
 
 def get_env_or_secret(key: str, default: Optional[str] = None) -> Optional[str]:
     """Get environment variable or default."""
@@ -161,6 +182,16 @@ def warm_model_cache() -> None:
     if model_id and model_id != base_model_id:
         models_to_check.append((model_id, "fine-tuned"))
 
+    text_postprocess_enabled = get_env_or_secret("TEXT_POSTPROCESS_ENABLED", "false").lower() in {
+        "1",
+        "true",
+        "yes",
+    }
+
+    text_model_id = get_env_or_secret("TEXT_POSTPROCESS_MODEL_ID", DEFAULT_TEXT_POSTPROCESS_MODEL)
+    if text_postprocess_enabled and text_model_id:
+        models_to_check.append((text_model_id, "text-postprocess"))
+
     for model_name, label in models_to_check:
         try:
             logger.info("Verifying %s model cache for %s", label, model_name)
@@ -481,6 +512,298 @@ def load_asr_pipeline(
     return asr, final_device, final_dtype_name
 
 
+def get_text_postprocess_pipeline(
+    model_id: str,
+    device_pref: Optional[str],
+    hf_token: Optional[str],
+) -> Any:
+    """Load a minimal text-generation pipeline for post-processing."""
+
+    logger = logging.getLogger(__name__)
+    if not model_id:
+        raise ValueError("Model id for text post-processing is not configured")
+
+    normalized_device_pref = (device_pref or "auto").lower()
+    if normalized_device_pref == "auto":
+        if torch.cuda.is_available():
+            device_choice = "cuda"
+        elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
+            device_choice = "mps"
+        else:
+            device_choice = "cpu"
+    else:
+        device_choice = normalized_device_pref
+
+    device_argument: Any
+    dtype: Optional[torch.dtype] = None
+    if device_choice.startswith("cuda") and torch.cuda.is_available():
+        device_argument = device_choice
+        dtype = torch.bfloat16
+    elif device_choice == "mps" and getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
+        device_argument = "mps"
+        dtype = torch.float16
+    else:
+        device_argument = "cpu"
+        dtype = None
+
+    global TEXT_POSTPROCESS_PIPELINE, TEXT_POSTPROCESS_MODEL_ID
+
+    with TEXT_POSTPROCESS_PIPELINE_LOCK:
+        if (
+            TEXT_POSTPROCESS_PIPELINE is not None
+            and TEXT_POSTPROCESS_MODEL_ID == model_id
+        ):
+            return TEXT_POSTPROCESS_PIPELINE
+
+        model_source = ensure_local_model(model_id, hf_token=hf_token)
+
+        is_medgemma = "medgemma" in model_id.lower()
+
+        if is_medgemma:
+            pipe_kwargs: Dict[str, Any] = {
+                "task": "image-text-to-text",
+                "model": model_source,
+                "device": device_argument,
+            }
+            if dtype is not None:
+                pipe_kwargs["torch_dtype"] = dtype
+        else:
+            pipe_kwargs = {
+                "task": "text-generation",
+                "model": model_source,
+                "device": device_argument,
+                "tokenizer": model_source,
+            }
+            if dtype is not None:
+                pipe_kwargs["torch_dtype"] = dtype
+            if device_argument != "cpu":
+                pipe_kwargs["device_map"] = "auto"
+
+        logger.info(
+            "Loading postprocess pipeline for %s with device=%s, dtype=%s",
+            model_id,
+            device_argument,
+            str(dtype) if dtype is not None else "auto",
+        )
+
+        try:
+            postprocess_pipe = pipeline(**pipe_kwargs)
+        except Exception as primary_error:
+            logger.warning(
+                "Postprocess pipeline init failed on %s (%s). Falling back to CPU.",
+                device_argument,
+                primary_error,
+            )
+            pipe_kwargs["device"] = "cpu"
+            pipe_kwargs.pop("torch_dtype", None)
+            pipe_kwargs.pop("device_map", None)
+            postprocess_pipe = pipeline(**pipe_kwargs)
+
+        TEXT_POSTPROCESS_PIPELINE = postprocess_pipe
+        TEXT_POSTPROCESS_MODEL_ID = model_id
+        return postprocess_pipe
+
+
+def postprocess_transcription_text(
+    text: str,
+    context_label: str,
+) -> str:
+    """Run MedGemma post-processing to clean transcription text."""
+
+    if not text or not text.strip():
+        return text
+
+    logger = logging.getLogger(__name__)
+
+    text_postprocess_enabled = get_env_or_secret("TEXT_POSTPROCESS_ENABLED", "false").lower() in {
+        "1",
+        "true",
+        "yes",
+    }
+    if not text_postprocess_enabled:
+        logger.debug(
+            "Text post-processing skipped for %s: feature disabled",
+            context_label,
+        )
+        return text
+
+    model_id = get_env_or_secret("TEXT_POSTPROCESS_MODEL_ID", DEFAULT_TEXT_POSTPROCESS_MODEL)
+    if not model_id:
+        logger.info("Text post-processing disabled: no model configured")
+        return text
+
+    hf_token = get_env_or_secret("TEXT_POSTPROCESS_HF_TOKEN") or get_env_or_secret(
+        "HF_TOKEN"
+    )
+    device_pref = get_env_or_secret("TEXT_POSTPROCESS_DEVICE", "auto")
+    max_new_tokens = int(get_env_or_secret("TEXT_POSTPROCESS_MAX_NEW", "200"))
+
+    prompt_body = text.strip()
+    prompt = f"{TEXT_POSTPROCESS_PROMPT}{prompt_body}\nRisultato:"
+    is_medgemma = "medgemma" in model_id.lower()
+
+    try:
+        postprocess_pipe = get_text_postprocess_pipeline(
+            model_id=model_id,
+            device_pref=device_pref,
+            hf_token=hf_token,
+        )
+
+        if is_medgemma:
+            system_prompt, separator, _ = TEXT_POSTPROCESS_PROMPT.partition("\nTesto originale:\n")
+            if not separator:
+                system_prompt = TEXT_POSTPROCESS_PROMPT
+                user_prefix = ""
+            else:
+                user_prefix = "Testo originale:\n"
+                system_prompt = system_prompt.strip()
+            messages = [
+                {
+                    "role": "system",
+                    "content": [{"type": "text", "text": system_prompt.strip()}],
+                },
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "text",
+                            "text": f"{user_prefix}{prompt_body}\nRisultato:",
+                        }
+                    ],
+                },
+            ]
+
+            outputs = postprocess_pipe(
+                text=messages,
+                max_new_tokens=max_new_tokens,
+            )
+
+            generated_text = ""
+            if isinstance(outputs, list) and outputs:
+                first = outputs[0]
+                if isinstance(first, dict):
+                    generated = first.get("generated_text")
+                    if isinstance(generated, list):
+                        # Prefer the latest assistant-like turn
+                        for msg in reversed(generated):
+                            if not isinstance(msg, dict):
+                                continue
+                            role = msg.get("role")
+                            if role not in {"assistant", "model", None}:
+                                continue
+                            content = msg.get("content")
+                            if isinstance(content, list):
+                                for block in content:
+                                    if (
+                                        isinstance(block, dict)
+                                        and block.get("type") == "text"
+                                    ):
+                                        text_block = (block.get("text") or "").strip()
+                                        if text_block:
+                                            generated_text = text_block
+                                            break
+                                if generated_text:
+                                    break
+                            elif isinstance(content, str) and content.strip():
+                                generated_text = content.strip()
+                                break
+                        if not generated_text:
+                            # Fallback: use the last text block regardless of role
+                            for msg in reversed(generated):
+                                if not isinstance(msg, dict):
+                                    continue
+                                content = msg.get("content")
+                                if isinstance(content, list):
+                                    for block in content:
+                                        if (
+                                            isinstance(block, dict)
+                                            and block.get("type") == "text"
+                                            and block.get("text")
+                                        ):
+                                            generated_text = block["text"].strip()
+                                            break
+                                    if generated_text:
+                                        break
+                                elif isinstance(content, str) and content.strip():
+                                    generated_text = content.strip()
+                                    break
+                    elif isinstance(generated, str):
+                        generated_text = generated.strip()
+            elif isinstance(outputs, dict):
+                generated = outputs.get("generated_text")
+                if isinstance(generated, list):
+                    for msg in reversed(generated):
+                        if isinstance(msg, dict):
+                            text_block = (
+                                msg.get("text")
+                                or msg.get("content")
+                                or ""
+                            )
+                            if isinstance(text_block, str) and text_block.strip():
+                                generated_text = text_block.strip()
+                                break
+                elif isinstance(generated, str):
+                    generated_text = generated.strip()
+
+            cleaned = generated_text
+        else:
+            outputs = postprocess_pipe(
+                prompt,
+                max_new_tokens=max_new_tokens,
+                do_sample=False,
+                return_full_text=False,
+            )
+
+            generated_text = ""
+            if isinstance(outputs, list) and outputs:
+                first = outputs[0]
+                if isinstance(first, dict):
+                    candidate = first.get("generated_text") or first.get("text")
+                    if isinstance(candidate, str):
+                        generated_text = candidate
+                    elif isinstance(candidate, list):
+                        generated_text = " ".join(
+                            part for part in candidate if isinstance(part, str)
+                        )
+                elif isinstance(first, str):
+                    generated_text = first
+            elif isinstance(outputs, dict):
+                candidate = outputs.get("generated_text") or outputs.get("text")
+                if isinstance(candidate, str):
+                    generated_text = candidate
+            elif isinstance(outputs, str):
+                generated_text = outputs
+
+            generated_text = (generated_text or "").strip()
+
+            if generated_text.startswith(prompt):
+                cleaned = generated_text[len(prompt) :].strip()
+            else:
+                cleaned = generated_text
+
+        if cleaned:
+            if cleaned.startswith(prompt_body):
+                cleaned = cleaned[len(prompt_body) :].strip()
+            if cleaned.startswith("Risultato:"):
+                cleaned = cleaned[len("Risultato:") :].strip()
+            if cleaned.lower().startswith("risultato:"):
+                cleaned = cleaned[len("risultato:") :].strip()
+            logger.debug("Post-processing successful for %s", context_label)
+            return cleaned or text
+
+        logger.warning("Post-processing returned empty output for %s", context_label)
+        return text
+
+    except Exception as exc:
+        logger.warning(
+            "Text post-processing failed for %s with model %s: %s",
+            context_label,
+            model_id,
+            exc,
+        )
+        return text
+
+
 @contextmanager
 def memory_monitor():
     """Context manager to monitor memory usage during inference."""
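Note that none of this commit's hunks actually invoke postprocess_transcription_text: as the remaining hunks show, transcribe_comparison only mirrors finetuned_text into a new third return value, and the new output widget stays hidden, which matches the commit message. A minimal sketch of the intended call site, assuming the helper above is in scope; the input string and the context_label value are illustrative, not from this commit:

# Hypothetical call site (not part of this commit's hunks). With
# TEXT_POSTPROCESS_ENABLED unset or false, the helper returns its input unchanged,
# so calling it is a no-op by default.
raw_text = "Buongiorno dottor Rossi, sono Maria Bianchi."
final_text = postprocess_transcription_text(
    raw_text,
    context_label="fine-tuned",  # illustrative label, used only in log messages
)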
@@ -709,7 +1032,8 @@ def handle_whisper_problematic_output(text: str, model_name: str = "Whisper") ->
 def transcribe_comparison(audio_file):
     """Main function for Gradio interface."""
     if audio_file is None:
-
+        warning = "❌ Nessun file audio fornito"
+        return warning, warning, warning
 
     # Model configuration
     model_id = get_env_or_secret("HF_MODEL_ID")
@@ -720,7 +1044,7 @@ def transcribe_comparison(audio_file):
 
     if not model_id or not base_model_id:
         error_msg = "❌ Modelli non configurati. Impostare HF_MODEL_ID e BASE_WHISPER_MODEL_ID nelle variabili d'ambiente"
-        return error_msg, error_msg
+        return error_msg, error_msg, error_msg
 
     # Preprocessing sempre attivo: normalizzazione formato, volume, riduzione rumore, rimozione silenzi
     # Viene applicato automaticamente prima della trascrizione con entrambi i modelli
@@ -742,6 +1066,7 @@ def transcribe_comparison(audio_file):
     finetuned_result = None
     original_text = ""
     finetuned_text = ""
+    postprocessed_text = ""
 
     try:
         # Transcribe with original model
@@ -819,16 +1144,18 @@ def transcribe_comparison(audio_file):
         except Exception as e:
             finetuned_text = f"❌ Errore modello fine-tuned: {str(e)}"
 
+        postprocessed_text = finetuned_text or ""
+
         # GPU memory cleanup
         if torch.cuda.is_available():
            torch.cuda.empty_cache()
            gc.collect()
 
-        return original_text, finetuned_text
+        return original_text, finetuned_text, postprocessed_text
 
     except Exception as e:
         error_msg = f"❌ Errore generale: {str(e)}"
-        return error_msg, error_msg
+        return error_msg, error_msg, error_msg
 
 
 # Gradio interface
@@ -920,11 +1247,20 @@ def create_interface():
                 show_copy_button=True,
             )
 
+            # Post-processing disabilitato temporaneamente: manteniamo il widget ma nascosto
+            medgemma_output = gr.Textbox(
+                label="Testo finale",
+                lines=12,
+                interactive=False,
+                show_copy_button=True,
+                visible=False,
+            )
+
             # Click event
             transcribe_btn.click(
                 fn=transcribe_comparison,
                 inputs=[audio_input],
-                outputs=[original_output, finetuned_output],
+                outputs=[original_output, finetuned_output, medgemma_output],
                 show_progress=True,
             )
 
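For reference, the feature added here is gated entirely by environment variables read through get_env_or_secret (the module calls load_dotenv, so a .env file works). A sketch of a .env that would enable it, using only the variables this diff reads; the model, device, and token-budget values shown are the defaults the code falls back to:

TEXT_POSTPROCESS_ENABLED=true
TEXT_POSTPROCESS_MODEL_ID=google/medgemma-4b-it
TEXT_POSTPROCESS_DEVICE=auto
TEXT_POSTPROCESS_MAX_NEW=200
# Optional: TEXT_POSTPROCESS_HF_TOKEN; if unset, the code falls back to HF_TOKEN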