import os
import sys
import tempfile
import time
import logging
import gc
import io
import threading
from dataclasses import dataclass
from typing import Optional, Tuple, List, Any, Dict
from contextlib import contextmanager

import gradio as gr
import torch
import psutil
from dotenv import load_dotenv
import numpy as np
from pydub import AudioSegment
from pydub.silence import split_on_silence
import soundfile as sf
import noisereduce
from huggingface_hub import snapshot_download
from transformers import pipeline

load_dotenv()
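
# Configuration is read from environment variables (optionally via a .env file).
# Illustrative example only; the values below are placeholders, not required defaults:
#
#   HF_MODEL_ID=ReportAId/whisper-medium-it-finetuned
#   BASE_WHISPER_MODEL_ID=openai/whisper-medium
#   HF_TOKEN=hf_xxx                      # or HUGGINGFACEHUB_API_TOKEN
#   HF_MODEL_CACHE_DIR=/path/to/hf_models
#   TEXT_POSTPROCESS_ENABLED=true
#   TEXT_POSTPROCESS_MODEL_ID=google/medgemma-4b-it
#   TEXT_POSTPROCESS_DEVICE=auto
#   TEXT_POSTPROCESS_MAX_NEW=200
#   TEXT_POSTPROCESS_HF_TOKEN=hf_xxx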

PREPROCESSING_AVAILABLE = True

DEFAULT_TEXT_POSTPROCESS_MODEL = "google/medgemma-4b-it"

TEXT_POSTPROCESS_PROMPT = (
    "Agisci come assistente editoriale clinico. Prendi la trascrizione fornita, correggi"
    " eventuali errori di riconoscimento automatico e migliora la grammatica mantenendo"
    " il significato. Anonimizza inoltre il testo sostituendo nomi propri di persone con"
    " segnaposto [PAZIENTE] o [MEDICO] a seconda del ruolo implicato. Non inventare"
    " informazioni nuove, non tradurre. Restituisci solo la versione finale pulita"
    " e pseudonimizzata in italiano, senza preamboli né spiegazioni."
    "\nEsempio 1 - Input: 'Buongiorno dottor Rossi, sono Maria Bianchi e ho prenotato l'holter.'"
    "\nEsempio 1 - Output: 'Buongiorno [MEDICO], sono [PAZIENTE] e ho prenotato l'Holter.'"
    "\nEsempio 2 - Input: 'Il paziente Claudio Caletti riferisce che la dottoressa Neri gli ha prescritto Coumadin.'"
    "\nEsempio 2 - Output: '[PAZIENTE] riferisce che [MEDICO] gli ha prescritto Coumadin.'"
    "\nEsempio 3 - Input: 'Dott.ssa Gallo, ho parlato con la collega Francesca e confermiamo l'intervento.'"
    "\nEsempio 3 - Output: '[MEDICO], ho parlato con [MEDICO] e confermiamo l'intervento.'"
    "\nTesto originale:\n"
)

PIPELINE_CACHE: Dict[Tuple[str, str, str], Tuple[Any, str, str]] = {}
PIPELINE_CACHE_LOCK = threading.Lock()
MODEL_PATH_CACHE: Dict[str, str] = {}
MODEL_PATH_CACHE_LOCK = threading.Lock()

TEXT_POSTPROCESS_PIPELINE: Optional[Any] = None
TEXT_POSTPROCESS_MODEL_ID: Optional[str] = None
TEXT_POSTPROCESS_PIPELINE_LOCK = threading.Lock()


def get_env_or_secret(key: str, default: Optional[str] = None) -> Optional[str]:
    """Get environment variable or default."""
    return os.environ.get(key, default)


@dataclass
class InferenceMetrics:
    """Track inference performance metrics."""

    processing_time: float
    memory_usage: float
    device_used: str
    dtype_used: str
    model_size_mb: Optional[float] = None


@dataclass
class PreprocessingConfig:
    """Configuration for audio preprocessing pipeline."""

    normalize_format: bool = True
    normalize_volume: bool = True
    reduce_noise: bool = True
    remove_silence: bool = True


def ensure_local_model(model_id: str, hf_token: Optional[str] = None) -> str:
    """Ensure a model snapshot is available locally and return its path."""

    if os.path.isdir(model_id):
        return model_id

    with MODEL_PATH_CACHE_LOCK:
        cached_path = MODEL_PATH_CACHE.get(model_id)
        if cached_path and os.path.isdir(cached_path):
            return cached_path

    logger = logging.getLogger(__name__)

    cache_root = get_env_or_secret("HF_MODEL_CACHE_DIR")
    if not cache_root:
        cache_root = os.path.join(os.path.dirname(__file__), "hf_models")

    os.makedirs(cache_root, exist_ok=True)
    local_dir = os.path.join(cache_root, model_id.replace("/", "__"))

    try:
        downloaded_path = snapshot_download(
            repo_id=model_id,
            token=hf_token,
            local_dir=local_dir,
            local_dir_use_symlinks=False,
            resume_download=True,
        )
        target_path = downloaded_path

        config_path = os.path.join(target_path, "config.json")
        if not os.path.isfile(config_path):
            snapshots_dir = os.path.join(target_path, "snapshots")
            if os.path.isdir(snapshots_dir):
                snapshot_candidates = sorted(
                    (
                        os.path.join(snapshots_dir, name)
                        for name in os.listdir(snapshots_dir)
                        if os.path.isdir(os.path.join(snapshots_dir, name))
                    ),
                    key=os.path.getmtime,
                    reverse=True,
                )
                for candidate in snapshot_candidates:
                    if os.path.isfile(os.path.join(candidate, "config.json")):
                        target_path = candidate
                        break

        downloaded_path = target_path
    except Exception as download_error:
        if os.path.isdir(local_dir) and os.listdir(local_dir):
            logger.warning(
                "Unable to refresh model %s from hub (%s), using existing files",
                model_id,
                download_error,
            )
            # Fall back to whatever is already on disk.
            downloaded_path = local_dir

            snapshots_dir = os.path.join(local_dir, "snapshots")
            if os.path.isdir(snapshots_dir):
                snapshot_candidates = sorted(
                    (
                        os.path.join(snapshots_dir, name)
                        for name in os.listdir(snapshots_dir)
                        if os.path.isdir(os.path.join(snapshots_dir, name))
                    ),
                    key=os.path.getmtime,
                    reverse=True,
                )
                for candidate in snapshot_candidates:
                    if os.path.isfile(os.path.join(candidate, "config.json")):
                        downloaded_path = candidate
                        break
        else:
            raise

    with MODEL_PATH_CACHE_LOCK:
        MODEL_PATH_CACHE[model_id] = downloaded_path

    return downloaded_path
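
# Illustrative usage (hypothetical path shown):
#   local_path = ensure_local_model("openai/whisper-medium", hf_token=None)
#   -> e.g. ".../hf_models/openai__whisper-medium" (or a resolved snapshot dir)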


def warm_model_cache() -> None:
    """Ensure the configured models are ready on disk."""

    logger = logging.getLogger(__name__)

    model_id = get_env_or_secret("HF_MODEL_ID", "ReportAId/whisper-medium-it-finetuned")
    base_model_id = get_env_or_secret("BASE_WHISPER_MODEL_ID", "openai/whisper-medium")
    hf_token = get_env_or_secret("HF_TOKEN") or get_env_or_secret(
        "HUGGINGFACEHUB_API_TOKEN"
    )

    models_to_check: List[Tuple[str, str]] = []
    if base_model_id:
        models_to_check.append((base_model_id, "base"))
    if model_id and model_id != base_model_id:
        models_to_check.append((model_id, "fine-tuned"))

    text_postprocess_enabled = get_env_or_secret(
        "TEXT_POSTPROCESS_ENABLED", "false"
    ).lower() in {
        "1",
        "true",
        "yes",
    }

    text_model_id = get_env_or_secret(
        "TEXT_POSTPROCESS_MODEL_ID", DEFAULT_TEXT_POSTPROCESS_MODEL
    )
    if text_postprocess_enabled and text_model_id:
        models_to_check.append((text_model_id, "text-postprocess"))

    for model_name, label in models_to_check:
        try:
            logger.info("Verifying %s model cache for %s", label, model_name)
            local_path = ensure_local_model(model_name, hf_token=hf_token)
            logger.info("Model %s ready at %s", model_name, local_path)
        except Exception:
            logger.exception("Failed to prepare model %s", model_name)
            raise


def normalize_audio(audio_bytes: bytes) -> bytes:
    """
    Convert an audio chunk (bytes) to the standard format expected by Whisper
    (16 kHz, mono, 16-bit PCM WAV).
    """

    audio_segment = AudioSegment.from_file(io.BytesIO(audio_bytes))

    audio_segment = audio_segment.set_frame_rate(16000)
    audio_segment = audio_segment.set_channels(1)
    audio_segment = audio_segment.set_sample_width(2)

    buffer = io.BytesIO()
    audio_segment.export(buffer, format="wav")
    return buffer.getvalue()


def normalize_volume(audio_bytes: bytes) -> bytes:
    """
    Normalize the volume of a WAV audio chunk.
    """

    audio_segment = AudioSegment.from_wav(io.BytesIO(audio_bytes))

    normalized_segment = audio_segment.normalize(headroom=0.1)

    buffer = io.BytesIO()
    normalized_segment.export(buffer, format="wav")
    return buffer.getvalue()


def reduce_background_noise(audio_bytes: bytes) -> bytes:
    """
    Reduce background noise in a WAV audio chunk.
    """

    buffer_read = io.BytesIO(audio_bytes)
    # soundfile.read returns (data, samplerate) in that order.
    data, rate = sf.read(buffer_read)

    if data.ndim > 1:
        data = np.mean(data, axis=1)

    reduced_noise_data = noisereduce.reduce_noise(y=data, sr=rate)

    buffer_write = io.BytesIO()
    sf.write(buffer_write, reduced_noise_data, rate, format="wav")
    return buffer_write.getvalue()


def remove_silence(audio_bytes: bytes) -> bytes:
    """
    Remove silent segments from a WAV audio chunk.
    """

    audio_segment = AudioSegment.from_wav(io.BytesIO(audio_bytes))

    chunks = split_on_silence(
        audio_segment,
        min_silence_len=100,
        silence_thresh=-35,
        keep_silence=80,
    )

    if not chunks:
        return b""

    processed_segment = sum(chunks, AudioSegment.empty())

    buffer = io.BytesIO()
    processed_segment.export(buffer, format="wav")
    return buffer.getvalue()


def preprocess_audio_pipeline(audio_path: str) -> str:
    """
    Apply the full audio preprocessing pipeline and return the path of the
    preprocessed audio file.
    """
    logger = logging.getLogger(__name__)
    logger.info("Avvio pipeline di preprocessing audio")

    try:
        with open(audio_path, "rb") as f:
            audio_bytes = f.read()

        logger.info("1. Normalizzazione formato audio...")
        audio_bytes = normalize_audio(audio_bytes)

        logger.info("2. Normalizzazione volume...")
        audio_bytes = normalize_volume(audio_bytes)

        logger.info("3. Riduzione rumore di fondo...")
        audio_bytes = reduce_background_noise(audio_bytes)

        logger.info("4. Rimozione silenzi...")
        audio_bytes = remove_silence(audio_bytes)

        if not audio_bytes:
            logger.warning(
                "Audio vuoto dopo rimozione silenzi, utilizzo audio originale"
            )
            with open(audio_path, "rb") as f:
                audio_bytes = f.read()
            audio_bytes = normalize_audio(audio_bytes)
            audio_bytes = normalize_volume(audio_bytes)

        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
            temp_file.write(audio_bytes)
            preprocessed_path = temp_file.name

        logger.info(f"Preprocessing completato: {preprocessed_path}")
        return preprocessed_path

    except Exception as e:
        logger.error(f"Errore durante preprocessing: {e}")
        logger.info("Utilizzo audio originale senza preprocessing")
        return audio_path
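
# Illustrative usage (assumes an existing audio file on disk):
#   clean_path = preprocess_audio_pipeline("/tmp/recording.wav")
#   # On any preprocessing error the original path is returned unchanged.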


def load_asr_pipeline(
    model_id: str,
    base_model_id: str,
    device_pref: str = "auto",
    hf_token: Optional[str] = None,
    dtype_pref: str = "auto",
    chunk_length_s: Optional[int] = None,
    return_timestamps: bool = False,
):
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    logger.info(f"Loading ASR pipeline for model: {model_id}")
    logger.info(
        f"Device preference: {device_pref}, Token provided: {hf_token is not None}"
    )

    device_str = "cpu"
    if device_pref == "auto":
        if torch.cuda.is_available():
            device_str = "cuda"
            logger.info(f"Using CUDA: {torch.cuda.get_device_name()}")
        elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
            device_str = "mps"
            logger.info("Using Apple Silicon MPS for inference")
        else:
            device_str = "cpu"
            logger.info("Using CPU for inference")
    else:
        device_str = device_pref

    dtype = None
    if dtype_pref == "auto":
        if "whisper-medium" in model_id:
            dtype = torch.float32
            logger.info(
                f"Using float32 for {model_id} (medical transcription stability)"
            )
        elif device_str == "cuda":
            dtype = torch.float16
            logger.info("Using float16 on CUDA for faster inference")
        else:
            dtype = torch.float32
    else:
        dtype = {"float32": torch.float32, "float16": torch.float16}.get(
            dtype_pref, torch.float32
        )

    logger.info("Pipeline configuration:")
    logger.info(f"  Model: {model_id}")
    logger.info(f"  Base model: {base_model_id}")
    logger.info(f"  Dtype: {dtype}")
    logger.info(f"  Device: {device_str}")
    logger.info(f"  Chunk length: {chunk_length_s}s")
    logger.info(f"  Return timestamps: {return_timestamps}")

    dtype_name = str(dtype).replace("torch.", "") if dtype is not None else "auto"
    cache_key = (model_id, device_str, dtype_name)

    with PIPELINE_CACHE_LOCK:
        cached_pipeline = PIPELINE_CACHE.get(cache_key)
        if cached_pipeline:
            logger.info(
                "Reusing cached pipeline for %s on %s (%s)",
                model_id,
                device_str,
                dtype_name,
            )
            return cached_pipeline

    model_source = ensure_local_model(model_id, hf_token=hf_token)
    logger.info(f"Using local model files from: {model_source}")

    device_argument: Any = 0 if device_str == "cuda" else device_str

    pipeline_kwargs = {
        "task": "automatic-speech-recognition",
        "device": device_argument,
    }
    if dtype is not None:
        pipeline_kwargs["torch_dtype"] = dtype

    def build_pipeline_with_recovery(model_path: str, kwargs: Dict[str, Any]) -> Any:
        try:
            return pipeline(**{**kwargs, "model": model_path})
        except Exception as build_error:
            logger.error(
                "Failed to load pipeline for %s from %s: %s",
                model_id,
                model_path,
                build_error,
            )
            raise

    try:
        logger.info(
            "Setting up ultra-simplified pipeline to avoid forced_decoder_ids conflicts..."
        )

        asr = build_pipeline_with_recovery(model_source, pipeline_kwargs)

        if hasattr(asr.model, "generation_config") and hasattr(
            asr.model.generation_config, "forced_decoder_ids"
        ):
            logger.info("Removing forced_decoder_ids from model generation config")
            asr.model.generation_config.forced_decoder_ids = None

        if chunk_length_s:
            logger.info(f"Setting chunk_length_s to {chunk_length_s}")

        final_device = device_str
        final_dtype = dtype
        final_dtype_name = dtype_name

        logger.info(f"Successfully created ultra-simplified pipeline for: {model_id}")

    except Exception as e:
        logger.error(f"Ultra-simplified pipeline creation failed: {e}")
        logger.info("Falling back to absolute minimal settings...")

        fallback_device = "cpu"
        fallback_dtype = torch.float32
        fallback_dtype_name = str(fallback_dtype).replace("torch.", "")
        fallback_key = (model_id, fallback_device, fallback_dtype_name)

        with PIPELINE_CACHE_LOCK:
            cached_pipeline = PIPELINE_CACHE.get(fallback_key)
            if cached_pipeline:
                logger.info(
                    "Reusing cached fallback pipeline for %s (%s)",
                    model_id,
                    fallback_dtype_name,
                )
                return cached_pipeline

        try:
            fallback_kwargs = {
                "task": "automatic-speech-recognition",
                "device": fallback_device,
                "torch_dtype": fallback_dtype,
            }
            asr = build_pipeline_with_recovery(model_source, fallback_kwargs)

            if hasattr(asr.model, "generation_config") and hasattr(
                asr.model.generation_config, "forced_decoder_ids"
            ):
                logger.info("Removing forced_decoder_ids from fallback model")
                asr.model.generation_config.forced_decoder_ids = None

            final_device = fallback_device
            final_dtype = fallback_dtype
            final_dtype_name = fallback_dtype_name
            logger.info(
                f"Minimal fallback pipeline created with dtype: {fallback_dtype}"
            )

        except Exception as fallback_error:
            logger.error(f"Minimal fallback failed: {fallback_error}")
            raise

    cache_key = (model_id, final_device, final_dtype_name)
    with PIPELINE_CACHE_LOCK:
        PIPELINE_CACHE[cache_key] = (asr, final_device, final_dtype_name)

    return asr, final_device, final_dtype_name
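
# The cached and returned value is the tuple (asr, device, dtype_name).
# A hypothetical call could look like:
#   asr, device, dtype_name = load_asr_pipeline(
#       model_id="openai/whisper-medium",
#       base_model_id="openai/whisper-medium",
#   )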


def get_text_postprocess_pipeline(
    model_id: str,
    device_pref: Optional[str],
    hf_token: Optional[str],
) -> Any:
    """Load a minimal text-generation pipeline for post-processing."""

    logger = logging.getLogger(__name__)
    if not model_id:
        raise ValueError("Model id for text post-processing is not configured")

    normalized_device_pref = (device_pref or "auto").lower()
    if normalized_device_pref == "auto":
        if torch.cuda.is_available():
            device_choice = "cuda"
        elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
            device_choice = "mps"
        else:
            device_choice = "cpu"
    else:
        device_choice = normalized_device_pref

    device_argument: Any
    dtype: Optional[torch.dtype] = None
    if device_choice.startswith("cuda") and torch.cuda.is_available():
        device_argument = device_choice
        dtype = torch.bfloat16
    elif (
        device_choice == "mps"
        and getattr(torch.backends, "mps", None)
        and torch.backends.mps.is_available()
    ):
        device_argument = "mps"
        dtype = torch.float16
    else:
        device_argument = "cpu"
        dtype = None

    global TEXT_POSTPROCESS_PIPELINE, TEXT_POSTPROCESS_MODEL_ID

    with TEXT_POSTPROCESS_PIPELINE_LOCK:
        if (
            TEXT_POSTPROCESS_PIPELINE is not None
            and TEXT_POSTPROCESS_MODEL_ID == model_id
        ):
            return TEXT_POSTPROCESS_PIPELINE

        model_source = ensure_local_model(model_id, hf_token=hf_token)

        is_medgemma = "medgemma" in model_id.lower()

        if is_medgemma:
            pipe_kwargs: Dict[str, Any] = {
                "task": "image-text-to-text",
                "model": model_source,
                "device": device_argument,
            }
            if dtype is not None:
                pipe_kwargs["torch_dtype"] = dtype
        else:
            pipe_kwargs = {
                "task": "text-generation",
                "model": model_source,
                "device": device_argument,
                "tokenizer": model_source,
            }
            if dtype is not None:
                pipe_kwargs["torch_dtype"] = dtype
            if device_argument != "cpu":
                pipe_kwargs["device_map"] = "auto"

        logger.info(
            "Loading postprocess pipeline for %s with device=%s, dtype=%s",
            model_id,
            device_argument,
            str(dtype) if dtype is not None else "auto",
        )

        try:
            postprocess_pipe = pipeline(**pipe_kwargs)
        except Exception as primary_error:
            logger.warning(
                "Postprocess pipeline init failed on %s (%s). Falling back to CPU.",
                device_argument,
                primary_error,
            )
            pipe_kwargs["device"] = "cpu"
            pipe_kwargs.pop("torch_dtype", None)
            pipe_kwargs.pop("device_map", None)
            postprocess_pipe = pipeline(**pipe_kwargs)

        TEXT_POSTPROCESS_PIPELINE = postprocess_pipe
        TEXT_POSTPROCESS_MODEL_ID = model_id
        return postprocess_pipe


def postprocess_transcription_text(
    text: str,
    context_label: str,
) -> str:
    """Run MedGemma post-processing to clean transcription text."""

    if not text or not text.strip():
        return text

    logger = logging.getLogger(__name__)

    text_postprocess_enabled = get_env_or_secret(
        "TEXT_POSTPROCESS_ENABLED", "false"
    ).lower() in {
        "1",
        "true",
        "yes",
    }
    if not text_postprocess_enabled:
        logger.debug(
            "Text post-processing skipped for %s: feature disabled",
            context_label,
        )
        return text

    model_id = get_env_or_secret(
        "TEXT_POSTPROCESS_MODEL_ID", DEFAULT_TEXT_POSTPROCESS_MODEL
    )
    if not model_id:
        logger.info("Text post-processing disabled: no model configured")
        return text

    hf_token = get_env_or_secret("TEXT_POSTPROCESS_HF_TOKEN") or get_env_or_secret(
        "HF_TOKEN"
    )
    device_pref = get_env_or_secret("TEXT_POSTPROCESS_DEVICE", "auto")
    max_new_tokens = int(get_env_or_secret("TEXT_POSTPROCESS_MAX_NEW", "200"))

    prompt_body = text.strip()
    prompt = f"{TEXT_POSTPROCESS_PROMPT}{prompt_body}\nRisultato:"
    is_medgemma = "medgemma" in model_id.lower()

    try:
        postprocess_pipe = get_text_postprocess_pipeline(
            model_id=model_id,
            device_pref=device_pref,
            hf_token=hf_token,
        )

        if is_medgemma:
            system_prompt, separator, _ = TEXT_POSTPROCESS_PROMPT.partition(
                "\nTesto originale:\n"
            )
            if not separator:
                system_prompt = TEXT_POSTPROCESS_PROMPT
                user_prefix = ""
            else:
                user_prefix = "Testo originale:\n"
                system_prompt = system_prompt.strip()
            messages = [
                {
                    "role": "system",
                    "content": [{"type": "text", "text": system_prompt.strip()}],
                },
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": f"{user_prefix}{prompt_body}\nRisultato:",
                        }
                    ],
                },
            ]

            outputs = postprocess_pipe(
                text=messages,
                max_new_tokens=max_new_tokens,
            )

            generated_text = ""
            if isinstance(outputs, list) and outputs:
                first = outputs[0]
                if isinstance(first, dict):
                    generated = first.get("generated_text")
                    if isinstance(generated, list):
                        for msg in reversed(generated):
                            if not isinstance(msg, dict):
                                continue
                            role = msg.get("role")
                            if role not in {"assistant", "model", None}:
                                continue
                            content = msg.get("content")
                            if isinstance(content, list):
                                for block in content:
                                    if (
                                        isinstance(block, dict)
                                        and block.get("type") == "text"
                                    ):
                                        text_block = (block.get("text") or "").strip()
                                        if text_block:
                                            generated_text = text_block
                                            break
                                if generated_text:
                                    break
                            elif isinstance(content, str) and content.strip():
                                generated_text = content.strip()
                                break
                        if not generated_text:
                            for msg in reversed(generated):
                                if not isinstance(msg, dict):
                                    continue
                                content = msg.get("content")
                                if isinstance(content, list):
                                    for block in content:
                                        if (
                                            isinstance(block, dict)
                                            and block.get("type") == "text"
                                            and block.get("text")
                                        ):
                                            generated_text = block["text"].strip()
                                            break
                                    if generated_text:
                                        break
                                elif isinstance(content, str) and content.strip():
                                    generated_text = content.strip()
                                    break
                    elif isinstance(generated, str):
                        generated_text = generated.strip()
            elif isinstance(outputs, dict):
                generated = outputs.get("generated_text")
                if isinstance(generated, list):
                    for msg in reversed(generated):
                        if isinstance(msg, dict):
                            text_block = msg.get("text") or msg.get("content") or ""
                            if isinstance(text_block, str) and text_block.strip():
                                generated_text = text_block.strip()
                                break
                elif isinstance(generated, str):
                    generated_text = generated.strip()

            cleaned = generated_text
        else:
            outputs = postprocess_pipe(
                prompt,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                return_full_text=False,
            )

            generated_text = ""
            if isinstance(outputs, list) and outputs:
                first = outputs[0]
                if isinstance(first, dict):
                    candidate = first.get("generated_text") or first.get("text")
                    if isinstance(candidate, str):
                        generated_text = candidate
                    elif isinstance(candidate, list):
                        generated_text = " ".join(
                            part for part in candidate if isinstance(part, str)
                        )
                elif isinstance(first, str):
                    generated_text = first
            elif isinstance(outputs, dict):
                candidate = outputs.get("generated_text") or outputs.get("text")
                if isinstance(candidate, str):
                    generated_text = candidate
            elif isinstance(outputs, str):
                generated_text = outputs

            generated_text = (generated_text or "").strip()

            if generated_text.startswith(prompt):
                cleaned = generated_text[len(prompt) :].strip()
            else:
                cleaned = generated_text

        if cleaned:
            if cleaned.startswith(prompt_body):
                cleaned = cleaned[len(prompt_body) :].strip()
            if cleaned.startswith("Risultato:"):
                cleaned = cleaned[len("Risultato:") :].strip()
            if cleaned.lower().startswith("risultato:"):
                cleaned = cleaned[len("risultato:") :].strip()
            logger.debug("Post-processing successful for %s", context_label)
            return cleaned or text

        logger.warning("Post-processing returned empty output for %s", context_label)
        return text

    except Exception as exc:
        logger.warning(
            "Text post-processing failed for %s with model %s: %s",
            context_label,
            model_id,
            exc,
        )
        return text


@contextmanager
def memory_monitor():
    """Context manager to monitor memory usage (MB) during inference.

    Yields a dict whose "delta_mb" entry is filled in on exit; a value returned
    from a @contextmanager generator would otherwise be discarded.
    """
    process = psutil.Process()
    start_memory = process.memory_info().rss / 1024 / 1024
    usage = {"delta_mb": 0.0}
    try:
        yield usage
    finally:
        end_memory = process.memory_info().rss / 1024 / 1024
        usage["delta_mb"] = end_memory - start_memory
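
# Minimal usage sketch (hypothetical; the helper is not wired into the flows below):
#   with memory_monitor() as usage:
#       run_inference()  # placeholder for any workload
#   print(usage["delta_mb"])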


def transcribe_local(
    audio_path: str,
    model_id: str,
    base_model_id: str,
    language: Optional[str],
    task: str,
    device_pref: str,
    dtype_pref: str,
    hf_token: Optional[str],
    chunk_length_s: Optional[int],
    stride_length_s: Optional[int],
    return_timestamps: bool,
) -> Dict[str, Any]:
    logger = logging.getLogger(__name__)
    logger.info(f"Starting transcription: {os.path.basename(audio_path)}")
    logger.info(f"Model: {model_id}")

    if audio_path is None:
        raise ValueError("Audio path is None")
    if not isinstance(audio_path, (str, bytes, os.PathLike)):
        raise TypeError(
            f"Audio path must be str, bytes or os.PathLike, got {type(audio_path)}"
        )
    if not os.path.exists(audio_path):
        raise FileNotFoundError(f"Audio file not found: {audio_path}")

    preprocessed_audio_path = audio_path
    if PREPROCESSING_AVAILABLE:
        try:
            logger.info("Applicazione preprocessing audio...")
            preprocessed_audio_path = preprocess_audio_pipeline(audio_path)
            logger.info(
                f"Preprocessing completato. File processato: {os.path.basename(preprocessed_audio_path)}"
            )
        except Exception as e:
            logger.warning(
                f"Errore durante preprocessing, utilizzo audio originale: {e}"
            )
            preprocessed_audio_path = audio_path
    else:
        logger.info("Preprocessing audio non disponibile, utilizzo audio originale")

    start_time = time.time()

    asr, device_str, dtype_str = load_asr_pipeline(
        model_id=model_id,
        base_model_id=base_model_id,
        device_pref=device_pref,
        hf_token=hf_token,
        dtype_pref=dtype_pref,
        chunk_length_s=chunk_length_s,
        return_timestamps=return_timestamps,
    )

    load_time = time.time() - start_time
    logger.info(f"Model loaded in {load_time:.2f}s")

    logger.info("Using simplified configuration to avoid model compatibility issues")

    try:
        asr_kwargs = {}

        if return_timestamps:
            asr_kwargs["return_timestamps"] = return_timestamps
            logger.info("Timestamps enabled")

        if chunk_length_s:
            try:
                asr_kwargs["chunk_length_s"] = chunk_length_s
                logger.info(f"Using chunking strategy: {chunk_length_s}s")
            except Exception as chunk_error:
                logger.warning(f"Chunking not supported: {chunk_error}")

        if stride_length_s is not None:
            try:
                asr_kwargs["stride_length_s"] = stride_length_s
                logger.info(f"Using stride: {stride_length_s}s")
            except Exception as stride_error:
                logger.warning(f"Stride not supported: {stride_error}")

        generate_kwargs: Dict[str, Any] = {}
        if language:
            generate_kwargs["language"] = language
            logger.info(f"Forcing ASR language: {language}")
        if task:
            generate_kwargs["task"] = task
            logger.info(f"Forcing ASR task: {task}")
        if generate_kwargs:
            asr_kwargs["generate_kwargs"] = generate_kwargs

        logger.info(f"Inference parameters configured: {list(asr_kwargs.keys())}")

        inference_start = time.time()
        memory_before = psutil.Process().memory_info().rss / 1024 / 1024

        try:
            if asr_kwargs:
                result = asr(preprocessed_audio_path, **asr_kwargs)
            else:
                result = asr(preprocessed_audio_path)

            inference_time = time.time() - inference_start
            memory_after = psutil.Process().memory_info().rss / 1024 / 1024
            memory_used = memory_after - memory_before

            logger.info(f"Inference completed successfully in {inference_time:.2f}s")
            logger.info(f"Memory used: {memory_used:.1f}MB")

        except Exception as e:
            error_msg = str(e)
            logger.warning(f"Inference failed with parameters: {error_msg}")

            if "forced_decoder_ids" in error_msg:
                logger.info(
                    "Detected forced_decoder_ids error, trying with no parameters..."
                )
            elif (
                "probability tensor contains either inf, nan or element < 0"
                in error_msg
            ):
                logger.info(
                    "Detected numerical instability, trying with no parameters..."
                )
            else:
                logger.info("Unknown error, trying with no parameters...")

            try:
                inference_start = time.time()
                result = asr(preprocessed_audio_path)
                inference_time = time.time() - inference_start
                memory_used = 0

                logger.info(f"Minimal inference completed in {inference_time:.2f}s")
            except Exception as final_error:
                logger.error(f"All inference attempts failed: {final_error}")
                raise

    except Exception as e:
        logger.error(f"Inference failed: {e}")
        raise

    if device_str == "cuda":
        torch.cuda.empty_cache()
        gc.collect()

    if preprocessed_audio_path != audio_path:
        try:
            os.unlink(preprocessed_audio_path)
            logger.info("File audio preprocessato temporaneo rimosso")
        except Exception as e:
            logger.warning(f"Errore rimozione file temporaneo: {e}")

    meta = {
        "device": device_str,
        "dtype": dtype_str,
        "inference_time": inference_time,
        "memory_used_mb": memory_used,
        "model_type": "original" if model_id == base_model_id else "fine-tuned",
        "preprocessing_applied": preprocessed_audio_path != audio_path,
    }

    return {"result": result, "meta": meta}


def handle_whisper_problematic_output(text: str, model_name: str = "Whisper") -> dict:
    """Handle problematic Whisper outputs such as '!', '.', empty strings, etc."""
    if not text:
        return {
            "text": "[WHISPER ISSUE: Output vuoto - Audio troppo corto o silenzioso]",
            "is_problematic": True,
            "original": text,
            "issue_type": "empty",
        }

    text_stripped = text.strip()

    problematic_outputs = {
        "!": "Audio troppo corto/silenzioso",
        ".": "Audio di bassa qualità",
        "?": "Audio incomprensibile",
        "...": "Audio troppo lungo senza parlato",
        "--": "Audio distorto",
        "—": "Audio con troppo rumore",
        " per!": "Audio parzialmente comprensibile",
        "per!": "Audio parzialmente comprensibile",
    }

    if text_stripped in problematic_outputs:
        return {
            "text": f"[WHISPER ISSUE: '{text_stripped}' - {problematic_outputs[text_stripped]}]",
            "is_problematic": True,
            "original": text,
            "issue_type": text_stripped,
            "suggestion": problematic_outputs[text_stripped],
        }

    if len(text_stripped) <= 2 and not text_stripped.isalpha():
        return {
            "text": f"[WHISPER ISSUE: '{text_stripped}' - Output troppo corto/simbolico]",
            "is_problematic": True,
            "original": text,
            "issue_type": "short_symbolic",
        }

    return {"text": text, "is_problematic": False, "original": text}


def transcribe_comparison(audio_file):
    """Main function for Gradio interface."""
    if audio_file is None:
        warning = "❌ Nessun file audio fornito"
        return warning, warning, warning

    model_id = get_env_or_secret("HF_MODEL_ID")
    base_model_id = get_env_or_secret("BASE_WHISPER_MODEL_ID")
    hf_token = get_env_or_secret("HF_TOKEN") or get_env_or_secret(
        "HUGGINGFACEHUB_API_TOKEN"
    )

    if not model_id or not base_model_id:
        error_msg = "❌ Modelli non configurati. Impostare HF_MODEL_ID e BASE_WHISPER_MODEL_ID nelle variabili d'ambiente"
        return error_msg, error_msg, error_msg

    language = "it"
    task = "transcribe"
    return_ts = True
    device_pref = "auto"
    dtype_pref = "auto"
    chunk_len = 7
    stride_len = 1

    try:
        tmp_path = audio_file

        original_result = None
        finetuned_result = None
        original_text = ""
        finetuned_text = ""
        postprocessed_text = ""

        try:
            original_result = transcribe_local(
                audio_path=tmp_path,
                model_id=base_model_id,
                base_model_id=base_model_id,
                language=language,
                task=task,
                device_pref=device_pref,
                dtype_pref=dtype_pref,
                hf_token=None,
                chunk_length_s=int(chunk_len) if chunk_len else None,
                stride_length_s=int(stride_len) if stride_len else None,
                return_timestamps=return_ts,
            )

            if isinstance(original_result["result"], dict):
                original_text = original_result["result"].get(
                    "text"
                ) or original_result["result"].get("transcription")
            elif isinstance(original_result["result"], str):
                original_text = original_result["result"]

            if original_text:
                result = handle_whisper_problematic_output(
                    original_text, "Original Whisper"
                )
                if result["is_problematic"]:
                    original_text = f"⚠️ {result['text']}\n\n💡 Suggerimenti:\n• Registra almeno 5-10 secondi di audio\n• Parla chiaramente e ad alto volume\n• Avvicinati al microfono\n• Evita rumori di fondo"
                else:
                    original_text = result["text"]
            else:
                original_text = "❌ Nessun testo restituito dal modello originale"

        except Exception as e:
            original_text = f"❌ Errore modello originale: {str(e)}"

        try:
            finetuned_result = transcribe_local(
                audio_path=tmp_path,
                model_id=model_id,
                base_model_id=base_model_id,
                language=language,
                task=task,
                device_pref=device_pref,
                dtype_pref=dtype_pref,
                hf_token=hf_token or None,
                chunk_length_s=int(chunk_len) if chunk_len else None,
                stride_length_s=int(stride_len) if stride_len else None,
                return_timestamps=return_ts,
            )

            if isinstance(finetuned_result["result"], dict):
                finetuned_text = finetuned_result["result"].get(
                    "text"
                ) or finetuned_result["result"].get("transcription")
            elif isinstance(finetuned_result["result"], str):
                finetuned_text = finetuned_result["result"]

            if finetuned_text:
                result = handle_whisper_problematic_output(
                    finetuned_text, "Fine-tuned Model"
                )
                if result["is_problematic"]:
                    finetuned_text = f"⚠️ {result['text']}\n\n💡 Suggerimenti:\n• Registra almeno 5-10 secondi di audio\n• Parla chiaramente e ad alto volume\n• Avvicinati al microfono\n• Evita rumori di fondo"
                else:
                    finetuned_text = result["text"]
            else:
                finetuned_text = "❌ Nessun testo restituito dal modello fine-tuned"

        except Exception as e:
            finetuned_text = f"❌ Errore modello fine-tuned: {str(e)}"

        postprocessed_text = finetuned_text or ""

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            gc.collect()

        return original_text, finetuned_text, postprocessed_text

    except Exception as e:
        error_msg = f"❌ Errore generale: {str(e)}"
        return error_msg, error_msg, error_msg


def create_interface():
    """Create and configure the Gradio interface."""

    warm_model_cache()

    model_id = get_env_or_secret("HF_MODEL_ID", "ReportAId/whisper-medium-it-finetuned")
    base_model_id = get_env_or_secret("BASE_WHISPER_MODEL_ID", "openai/whisper-medium")

    logo_html = None
    try:
        assets_dir = os.path.join(os.path.dirname(__file__), "assets")
        light_path = os.path.join(assets_dir, "RaidLight.svg")
        dark_path = os.path.join(assets_dir, "RaidDark.svg")

        with open(light_path, "r", encoding="utf-8") as f:
            light_svg = f.read()
        with open(dark_path, "r", encoding="utf-8") as f:
            dark_svg = f.read()

        logo_html = f"""
        <style>
        .logo-container {{
            text-align: center;
            margin: 16px 0 8px;
        }}
        .logo-container .sr-only {{
            position: absolute;
            width: 1px;
            height: 1px;
            padding: 0;
            margin: -1px;
            overflow: hidden;
            clip: rect(0, 0, 0, 0);
            white-space: nowrap;
            border: 0;
        }}
        .logo-container svg {{
            height: 72px;
            width: auto;
            max-width: 100%;
        }}
        .logo-container .logo-dark {{
            display: none;
        }}
        @media (prefers-color-scheme: dark) {{
            .logo-container .logo-light {{
                display: none !important;
            }}
            .logo-container .logo-dark {{
                display: inline-block !important;
            }}
        }}
        </style>
        <div class="logo-container">
            <div class="logo-light" aria-hidden="true">{light_svg}</div>
            <div class="logo-dark" aria-hidden="true">{dark_svg}</div>
            <span class="sr-only">ReportAId</span>
        </div>
        """
    except Exception:
        logo_html = """
        <style>
        .logo-container { text-align: center; margin: 16px 0 8px; }
        .logo-container .sr-only {
            position: absolute;
            width: 1px;
            height: 1px;
            padding: 0;
            margin: -1px;
            overflow: hidden;
            clip: rect(0, 0, 0, 0);
            white-space: nowrap;
            border: 0;
        }
        .logo-container img { height: 72px; width: auto; max-width: 100%; }
        .logo-container .logo-dark { display: none; }
        @media (prefers-color-scheme: dark) {
            .logo-container .logo-light { display: none !important; }
            .logo-container .logo-dark { display: inline-block !important; }
        }
        </style>
        <div class="logo-container">
            <img class="logo-light" src="file=assets/RaidLight.svg" alt="ReportAId">
            <img class="logo-dark" src="file=assets/RaidDark.svg" alt="ReportAId">
            <span class="sr-only">ReportAId</span>
        </div>
        """

    with gr.Blocks(
        title="Medical Transcription",
        theme=gr.themes.Default(primary_hue="blue"),
        css=".gradio-container{max-width: 900px !important; margin: 0 auto !important;} .center-col{display:flex;flex-direction:column;align-items:center;} .center-col .wrap{width:100%;}",
    ) as demo:

        gr.HTML(logo_html)
        gr.Markdown("""
        Questa demo confronta MedWhisper Large ITA con Whisper Large v3 Turbo su parlato clinico in italiano. MedWhisper è una variante domain-adapted (LoRA) del modello base, addestrata su registrazioni sintetiche ricche di gergo medico, acronimi e formule ricorrenti. Carica o registra audio per ottenere trascrizioni affiancate; noterai una resa migliore della terminologia specialistica (es. “Holter delle 24 ore”, “fibrillazione atriale”). Sul nostro held-out clinico, la WER scende dal 7,9% al 4,5% rispetto al checkpoint base.

        Riferimento al MedWhisper: https://huggingface.co/ReportAId/medwhisper-large-v3-ita
        """)

        with gr.Row():
            with gr.Column():
                gr.Markdown(f"""
                **⚙️ Impostazioni**
                - Modello originale: `{base_model_id}`
                - Modello fine-tuned: `{model_id}`
                - Lingua: Italiano (it)
                - Preprocessing audio: **ATTIVO** (normalizzazione, riduzione rumore, rimozione silenzi)
                """)

        gr.Markdown("---")

        gr.Markdown("## Input")

        audio_input = gr.Audio(
            label="📥 Registra dal microfono o carica un file",
            type="filepath",
            sources=["microphone", "upload"],
            format="wav",
            streaming=False,
            interactive=True,
        )
        transcribe_btn = gr.Button("🚀 Trascrivi e Confronta", variant="primary")

        gr.Markdown("---")

        gr.Markdown("## Output")

        with gr.Row():
            with gr.Column():
                gr.Markdown("### Modello base (Whisper V3)")
                original_output = gr.Textbox(
                    label="Transcription",
                    lines=12,
                    interactive=False,
                    show_copy_button=True,
                )

            with gr.Column():
                gr.Markdown("### Modello fine-tuned ReportAId")
                finetuned_output = gr.Textbox(
                    label="Transcription",
                    lines=12,
                    interactive=False,
                    show_copy_button=True,
                )

        medgemma_output = gr.Textbox(
            label="Testo finale",
            lines=12,
            interactive=False,
            show_copy_button=True,
            visible=False,
        )

        transcribe_btn.click(
            fn=transcribe_comparison,
            inputs=[audio_input],
            outputs=[original_output, finetuned_output, medgemma_output],
            show_progress=True,
        )

    return demo


if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    demo = create_interface()

    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
        inbrowser=False,
        quiet=False,
    )