"""Wav2Vec2-based audio deepfake inference.

Run as a script with an audio file path as the single argument. The result
of ``run_inference`` is printed to stdout as one JSON document; all progress
and error logging goes to stderr so stdout stays machine-parseable.
"""
import json
import os
import sys

import librosa
import numpy as np
import torch
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification

LOCAL_MODEL_DIR = "./local_finetuned_model"
# Intended priority: 1. local folder (direct upload) | 2. custom repo
# (environment variable) | 3. base model.
# NOTE(review): run_inference below always uses BASE_MODEL, so LOCAL_MODEL_DIR
# and CUSTOM_MODEL_REPO are not consulted anywhere in this file.
CUSTOM_MODEL_REPO = os.environ.get("CUSTOM_MODEL_REPO", None)
BASE_MODEL = "HyperMoon/wav2vec2-base-960h-finetuned-deepfake"

# Sampling rate used both when loading audio and when calling the extractor.
_SAMPLE_RATE = 16000
# Substrings that mark a label as the "fraud" class (matched case-insensitively).
_FRAUD_KEYWORDS = ("fake", "spoof", "fraud")
# Number of 1-second segments sent through the model per forward pass.
_SEGMENT_BATCH_SIZE = 8

# Module-level cache so the model and feature extractor are loaded only once.
_feature_extractor = None
_model = None
_last_model_path = None


def get_wav2vec_resources(model_path):
    """Return the cached ``(feature_extractor, model)`` pair for *model_path*.

    The pair is (re)loaded when nothing is cached yet or when *model_path*
    differs from the path recorded by the previous load. If loading
    *model_path* fails for any reason, ``BASE_MODEL`` is loaded instead and
    recorded as the cached path, so a later call with the same failing
    *model_path* attempts the load again.

    When CUDA is not available, dynamic int8 quantization is applied to the
    model's ``torch.nn.Linear`` layers. The model is put in eval mode.

    Raises:
        Exception: whatever ``from_pretrained`` raises if the ``BASE_MODEL``
            fallback load fails as well. The cache is left untouched in
            that case.
    """
    global _feature_extractor, _model, _last_model_path

    cache_is_valid = (
        _feature_extractor is not None
        and _model is not None
        and _last_model_path == model_path
    )
    if cache_is_valid:
        return _feature_extractor, _model

    print(f"Carregando motor Wav2Vec2: {model_path}...", file=sys.stderr)
    # Load into locals first: the globals are only updated once everything
    # succeeded, so a failed load can never leave a new extractor paired
    # with a stale model.
    try:
        feature_extractor = AutoFeatureExtractor.from_pretrained(model_path)
        model = AutoModelForAudioClassification.from_pretrained(model_path)
    except Exception as e:
        print(
            f"⚠️ Erro ao carregar modelo '{model_path}': {e}. Usando fallback para {BASE_MODEL}...",
            file=sys.stderr,
        )
        feature_extractor = AutoFeatureExtractor.from_pretrained(BASE_MODEL)
        model = AutoModelForAudioClassification.from_pretrained(BASE_MODEL)
        model_path = BASE_MODEL

    # --- Optimization: dynamic quantization when running on CPU ---
    if not torch.cuda.is_available():
        print("Aplicando Quantização Dinâmica (CPU Optimization)...", file=sys.stderr)
        model = torch.quantization.quantize_dynamic(
            model, {torch.nn.Linear}, dtype=torch.qint8
        )

    model.eval()
    _feature_extractor = feature_extractor
    _model = model
    _last_model_path = model_path
    return _feature_extractor, _model


def _find_fraud_index(id2label):
    """Return the index of the first label containing a fraud keyword, else 0."""
    for idx, label in id2label.items():
        lowered = label.lower()
        if any(keyword in lowered for keyword in _FRAUD_KEYWORDS):
            return int(idx)
    # No label matched: fall back to class 0, as the original code did.
    return 0


def _split_into_segments(audio, samples_per_segment):
    """Cut *audio* into fixed-length segments, zero-padding a long-enough tail."""
    segments = []
    for start in range(0, len(audio), samples_per_segment):
        segment = audio[start:start + samples_per_segment]
        # Drop a trailing piece shorter than half a segment.
        if len(segment) < samples_per_segment // 2:
            continue
        # Pad a short trailing piece so every segment has the same length.
        if len(segment) < samples_per_segment:
            segment = np.pad(segment, (0, samples_per_segment - len(segment)))
        segments.append(segment)
    return segments


def _score_segments(feature_extractor, model, segments, fraud_idx):
    """Return the fraud-class probability of each segment, rounded to 3 places."""
    temporal_scores = []
    # Mini-batches keep memory bounded for long recordings.
    for start in range(0, len(segments), _SEGMENT_BATCH_SIZE):
        batch_segments = segments[start:start + _SEGMENT_BATCH_SIZE]
        seg_inputs = feature_extractor(
            batch_segments,
            sampling_rate=_SAMPLE_RATE,
            return_tensors="pt",
            padding=True,
        )
        with torch.no_grad():
            seg_logits = model(**seg_inputs).logits
        seg_probs = torch.softmax(seg_logits, dim=-1)
        temporal_scores.extend(round(p[fraud_idx].item(), 3) for p in seg_probs)
    return temporal_scores


def run_inference(audio_path, fallback_model_name=None):
    """Classify the audio file at *audio_path* using only ``BASE_MODEL``.

    Args:
        audio_path: path of the audio file, loaded with ``librosa.load`` at
            16 kHz.
        fallback_model_name: accepted for interface compatibility; unused.

    Returns:
        On success, a dict with the keys ``model``, ``prediction``,
        ``confidence``, ``deepfake_probability``, ``temporal_scores``
        (one fraud probability per 1-second segment), ``verdict``
        (``"SPOOF"`` when the fraud probability exceeds 0.5, otherwise
        ``"BONAFIDE"``) and ``metadata``.
        On any exception, ``{"error": <message>, "verdict": "ERROR"}``;
        the exception is logged to stderr and not re-raised.
    """
    model_path = BASE_MODEL
    model_name = f"HF Model ({model_path})"

    print(f"Rodando inferência REAL [{model_name}] em: {audio_path}", file=sys.stderr)

    try:
        # 1. Load feature extractor and model (cached across calls).
        feature_extractor, model = get_wav2vec_resources(model_path)

        # 2. Load and resample the audio.
        print(f"Lendo áudio: {audio_path}", file=sys.stderr)
        audio, sr = librosa.load(audio_path, sr=_SAMPLE_RATE)

        # 3. Build model inputs for the full recording.
        inputs = feature_extractor(
            audio, sampling_rate=_SAMPLE_RATE, return_tensors="pt", padding=True
        )

        # 4. Main inference pass.
        with torch.no_grad():
            logits = model(**inputs).logits

        # 5. Turn logits into class probabilities and pick the fraud class.
        scores = torch.softmax(logits, dim=-1)
        id2label = model.config.id2label
        fraud_idx = _find_fraud_index(id2label)

        fraud_prob = scores[0][fraud_idx].item()
        is_fraud = fraud_prob > 0.5

        # --- Temporal analysis (XAI): per-segment fraud probability ---
        print("Iniciando Análise Temporal Otimizada...", file=sys.stderr)
        segment_duration = 1.0  # seconds
        samples_per_segment = int(segment_duration * _SAMPLE_RATE)
        segments = _split_into_segments(audio, samples_per_segment)
        temporal_scores = _score_segments(
            feature_extractor, model, segments, fraud_idx
        )

        # Single input item, so the top class is the argmax of row 0.
        top_idx = torch.argmax(scores[0]).item()
        results = {
            "model": model_name,
            "prediction": id2label[top_idx].upper(),
            "confidence": scores[0][top_idx].item(),
            "deepfake_probability": fraud_prob,
            "temporal_scores": temporal_scores,
            "verdict": "SPOOF" if is_fraud else "BONAFIDE",
            "metadata": {
                "id2label": id2label,
                "quantized": not torch.cuda.is_available(),
            },
        }
    except Exception as e:
        # Log to stderr: stdout is reserved for the JSON result printed by
        # the script entry point, and a stray line there breaks parsing.
        print(f"Erro na inferência: {e}", file=sys.stderr)
        results = {
            "error": str(e),
            "verdict": "ERROR",
        }

    return results


if __name__ == "__main__":
    if len(sys.argv) < 2:
        print("Uso: python inference_wav2vec.py ")
    else:
        # Silence warnings (e.g. from transformers) during inference.
        import warnings

        warnings.filterwarnings("ignore")
        print(json.dumps(run_inference(sys.argv[1])))