# confereai-dev / execution / inference_wav2vec.py
# TEDDyx86
# Fix: Force BASE_MODEL and correct fraud_probability variable in frontend
# bcf4378
import sys
import json
import torch
import librosa
import numpy as np
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
import os
# Local folder intended for a directly uploaded fine-tuned model.
# NOTE(review): run_inference always uses BASE_MODEL and does not consult this.
LOCAL_MODEL_DIR = "./local_finetuned_model"

# Optional Hugging Face repo of a custom fine-tuned model, read from the
# environment (None when the variable is unset).
# NOTE(review): also not consulted by run_inference; the intended priority
# (local folder -> custom repo -> base model) is currently not applied.
CUSTOM_MODEL_REPO = os.environ.get("CUSTOM_MODEL_REPO")

# Model that run_inference loads (also the fallback in get_wav2vec_resources).
BASE_MODEL = "HyperMoon/wav2vec2-base-960h-finetuned-deepfake"

# Module-level cache populated by get_wav2vec_resources so the feature
# extractor and model are loaded once and reused across calls.
_feature_extractor = _model = _last_model_path = None
def get_wav2vec_resources(model_path):
    """Return a cached ``(feature_extractor, model)`` pair for ``model_path``.

    The pair is stored in module-level globals and reused until a different
    ``model_path`` is requested. If loading ``model_path`` fails, the pair is
    loaded from ``BASE_MODEL`` instead. That fallback is cached under the
    *requested* path, so later calls with the same path reuse it instead of
    retrying the failing load and reloading/re-quantizing the base model on
    every call.

    When CUDA is not available, the model's ``torch.nn.Linear`` layers are
    dynamically quantized to int8.

    Args:
        model_path: Anything ``from_pretrained`` accepts (local directory or
            Hub repo id).

    Returns:
        Tuple ``(feature_extractor, model)``, with the model in eval mode.

    Raises:
        Exception: whatever ``from_pretrained`` raises when the ``BASE_MODEL``
            fallback also fails. The cache is left untouched in that case.
    """
    global _feature_extractor, _model, _last_model_path

    # Cache hit: both resources present and loaded for this same key.
    if (
        _feature_extractor is not None
        and _model is not None
        and _last_model_path == model_path
    ):
        return _feature_extractor, _model

    print(f"Carregando motor Wav2Vec2: {model_path}...", file=sys.stderr)
    # Load into locals and publish to the globals only once everything has
    # succeeded, so a failed load can never leave a mismatched
    # extractor/model pair in the cache.
    try:
        feature_extractor = AutoFeatureExtractor.from_pretrained(model_path)
        model = AutoModelForAudioClassification.from_pretrained(model_path)
    except Exception as e:
        print(f"⚠️ Erro ao carregar modelo '{model_path}': {e}. Usando fallback para {BASE_MODEL}...", file=sys.stderr)
        feature_extractor = AutoFeatureExtractor.from_pretrained(BASE_MODEL)
        model = AutoModelForAudioClassification.from_pretrained(BASE_MODEL)

    # CPU optimization: dynamic int8 quantization of Linear layers.
    if not torch.cuda.is_available():
        print("Aplicando Quantização Dinâmica (CPU Optimization)...", file=sys.stderr)
        model = torch.quantization.quantize_dynamic(
            model, {torch.nn.Linear}, dtype=torch.qint8
        )
    model.eval()

    _feature_extractor = feature_extractor
    _model = model
    # Key on the requested path (not the resolved one) so a fallback load is
    # reused by subsequent calls with the same argument.
    _last_model_path = model_path
    return _feature_extractor, _model
def run_inference(audio_path, fallback_model_name=None):
    """Run real inference on ``audio_path`` using only ``BASE_MODEL``.

    ``fallback_model_name`` is accepted for backward compatibility but is
    ignored: the model is always ``BASE_MODEL``.

    Steps:
      1. Load (cached) feature extractor and model via
         ``get_wav2vec_resources``.
      2. Load the audio with librosa at 16 kHz.
      3. Score the whole clip once to get the fraud probability and the
         top-scoring label.
      4. Score 1-second windows in mini-batches of 8 to build
         ``temporal_scores``. A trailing window shorter than half a second is
         dropped; other short windows are zero-padded to one second.

    Returns:
        On success, a dict with keys ``model``, ``prediction`` (upper-cased
        argmax label), ``confidence`` (argmax probability),
        ``deepfake_probability``, ``temporal_scores`` (per-window fraud
        probability, rounded to 3 decimals), ``verdict`` (``"SPOOF"`` when the
        fraud probability exceeds 0.5, else ``"BONAFIDE"``) and ``metadata``.
        ``prediction`` and ``verdict`` are derived independently.
        On any exception, ``{"error": <message>, "verdict": "ERROR"}``.

    All progress and error logging goes to stderr so stdout carries only the
    JSON document printed by the command-line entry point.
    """
    model_path = BASE_MODEL
    model_name = f"HF Model ({model_path})"
    print(f"Rodando inferência REAL [{model_name}] em: {audio_path}", file=sys.stderr)

    # Sampling rate used both for loading and for the feature extractor.
    sample_rate = 16000

    try:
        # 1. Feature extractor and model (cached singleton).
        feature_extractor, model = get_wav2vec_resources(model_path)

        # 2. Load and resample the audio.
        print(f"Lendo áudio: {audio_path}", file=sys.stderr)
        audio, _ = librosa.load(audio_path, sr=sample_rate)

        # 3. Whole-clip inference.
        inputs = feature_extractor(audio, sampling_rate=sample_rate, return_tensors="pt", padding=True)
        with torch.no_grad():
            logits = model(**inputs).logits
        # Single clip in the batch: keep the one row of class probabilities.
        scores = torch.softmax(logits, dim=-1)[0]
        id2label = model.config.id2label

        # Index of the "fraud" class: first label containing fake/spoof/fraud.
        # NOTE(review): defaults to index 0 when no label matches; confirm
        # that class 0 is the spoof class for BASE_MODEL.
        fraud_idx = next(
            (
                int(idx)
                for idx, lbl in id2label.items()
                if any(key in lbl.lower() for key in ("fake", "spoof", "fraud"))
            ),
            0,
        )
        fraud_prob = scores[fraud_idx].item()
        is_fraud = fraud_prob > 0.5
        pred_idx = int(torch.argmax(scores).item())

        # 4. Temporal analysis (XAI): fraud probability per 1-second window.
        print("Iniciando Análise Temporal Otimizada...", file=sys.stderr)
        samples_per_segment = int(1.0 * sample_rate)
        segments = []
        for start in range(0, len(audio), samples_per_segment):
            segment = audio[start:start + samples_per_segment]
            if len(segment) < samples_per_segment // 2:
                continue  # drop a tail shorter than half a window
            if len(segment) < samples_per_segment:
                # Pad so every window in a batch has the same length.
                segment = np.pad(segment, (0, samples_per_segment - len(segment)))
            segments.append(segment)

        # Mini-batches bound peak memory on long recordings.
        temporal_scores = []
        batch_size = 8
        for start in range(0, len(segments), batch_size):
            batch = segments[start:start + batch_size]
            seg_inputs = feature_extractor(batch, sampling_rate=sample_rate, return_tensors="pt", padding=True)
            with torch.no_grad():
                seg_logits = model(**seg_inputs).logits
            seg_probs = torch.softmax(seg_logits, dim=-1)
            temporal_scores.extend(round(p[fraud_idx].item(), 3) for p in seg_probs)

        results = {
            "model": model_name,
            "prediction": id2label[pred_idx].upper(),
            "confidence": scores[pred_idx].item(),
            "deepfake_probability": fraud_prob,
            "temporal_scores": temporal_scores,
            "verdict": "SPOOF" if is_fraud else "BONAFIDE",
            "metadata": {
                "id2label": id2label,
                "quantized": not torch.cuda.is_available(),
            },
        }
    except Exception as e:
        # Best effort: report the failure as data instead of crashing. Log to
        # stderr (not stdout) so the CLI's JSON output stays parseable.
        print(f"Erro na inferência: {e}", file=sys.stderr)
        results = {
            "error": str(e),
            "verdict": "ERROR",
        }
    return results
if __name__ == "__main__":
    # CLI contract: stdout carries only the JSON result of run_inference.
    if len(sys.argv) < 2:
        # Usage goes to stderr with a non-zero exit status so a caller parsing
        # stdout as JSON never receives plain text from a "successful" run.
        print("Uso: python inference_wav2vec.py <audio_path>", file=sys.stderr)
        sys.exit(1)
    # Silence Python warnings (e.g. those emitted by transformers).
    import warnings
    warnings.filterwarnings("ignore")
    print(json.dumps(run_inference(sys.argv[1])))