File size: 5,746 Bytes
e3bdc52
 
 
 
72892b6
e3bdc52
 
 
 
 
ffd044a
ac7fd5c
 
e3bdc52
ea97e04
 
 
 
 
 
 
 
 
 
89dd351
 
 
 
 
 
 
 
 
ea97e04
ffd044a
 
 
 
 
 
 
ea97e04
 
 
 
 
 
adfa4a4
e3bdc52
bcf4378
e3bdc52
bcf4378
 
e3bdc52
 
 
 
ea97e04
 
 
72892b6
e3bdc52
 
 
 
72892b6
e3bdc52
 
72892b6
e3bdc52
 
 
72892b6
e3bdc52
 
 
72892b6
e3bdc52
 
 
72892b6
e3bdc52
 
 
72892b6
e3bdc52
72892b6
 
e3bdc52
 
 
 
72892b6
e3bdc52
 
72892b6
 
 
 
 
 
 
89dd351
 
 
 
 
 
 
 
 
72892b6
e3bdc52
 
 
72892b6
 
e3bdc52
72892b6
e3bdc52
 
 
72892b6
e3bdc52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import sys
import json
import torch
import librosa
import numpy as np
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification

import os

# Candidate model locations.
# Priority: 1. local folder (direct upload) | 2. custom repo (environment variable) | 3. base model
LOCAL_MODEL_DIR = "./local_finetuned_model"
CUSTOM_MODEL_REPO = os.environ.get("CUSTOM_MODEL_REPO")  # None when the variable is unset
BASE_MODEL = "HyperMoon/wav2vec2-base-960h-finetuned-deepfake"

# Singleton state so the model and feature extractor are loaded only once;
# _last_model_path records which path the cached pair belongs to.
_feature_extractor = _model = _last_model_path = None

def get_wav2vec_resources(model_path):
    """Return the cached ``(feature_extractor, model)`` pair for *model_path*.

    The pair is loaded on first use and whenever *model_path* differs from
    the path of the previous request. If loading *model_path* fails, the
    error is logged to stderr and ``BASE_MODEL`` is loaded instead. When
    CUDA is not available, the model's ``torch.nn.Linear`` layers are
    dynamically quantized to int8.

    Args:
        model_path: Identifier passed to ``from_pretrained`` (local
            directory or hub repository name).

    Returns:
        Tuple ``(feature_extractor, model)``; ``eval()`` has been called on
        the model.

    Raises:
        Exception: Whatever ``from_pretrained`` raises when the fallback
            load of ``BASE_MODEL`` fails too. The previously cached pair is
            left untouched in that case.
    """
    global _feature_extractor, _model, _last_model_path

    cache_hit = (
        _feature_extractor is not None
        and _model is not None
        and _last_model_path == model_path
    )
    if cache_hit:
        return _feature_extractor, _model

    # Cache invalidation: the requested path changed (or nothing is loaded
    # yet), so load into locals first. The globals are only updated once
    # everything succeeded, so a failed load can never leave a new extractor
    # paired with a stale model.
    print(f"Carregando motor Wav2Vec2: {model_path}...", file=sys.stderr)
    try:
        feature_extractor = AutoFeatureExtractor.from_pretrained(model_path)
        model = AutoModelForAudioClassification.from_pretrained(model_path)
    except Exception as e:
        print(f"⚠️ Erro ao carregar modelo '{model_path}': {e}. Usando fallback para {BASE_MODEL}...", file=sys.stderr)
        feature_extractor = AutoFeatureExtractor.from_pretrained(BASE_MODEL)
        model = AutoModelForAudioClassification.from_pretrained(BASE_MODEL)

    # --- OPTIMIZATION: dynamic quantization for CPU ---
    if not torch.cuda.is_available():
        print("Aplicando Quantização Dinâmica (CPU Optimization)...", file=sys.stderr)
        model = torch.quantization.quantize_dynamic(
            model, {torch.nn.Linear}, dtype=torch.qint8
        )

    model.eval()

    # Commit atomically. The cache is keyed on the *requested* path: storing
    # the fallback path instead would make every later call with the same
    # failing path miss the cache and reload the fallback model again.
    _feature_extractor = feature_extractor
    _model = model
    _last_model_path = model_path

    return _feature_extractor, _model

def run_inference(audio_path, fallback_model_name=None):
    """Run deepfake detection on an audio file using only the base model.

    The file is loaded with ``librosa.load`` at 16 kHz and classified as a
    whole; afterwards it is cut into 1-second segments that are scored in
    mini-batches to produce a per-segment fraud probability timeline.

    Args:
        audio_path: Path of the audio file to analyse.
        fallback_model_name: Accepted for backward compatibility; not used,
            the model is always ``BASE_MODEL``.

    Returns:
        On success a dict with the keys ``model``, ``prediction``,
        ``confidence``, ``deepfake_probability``, ``temporal_scores``,
        ``verdict`` (``"SPOOF"`` or ``"BONAFIDE"``) and ``metadata``.
        On any failure a dict ``{"error": <message>, "verdict": "ERROR"}``;
        exceptions raised inside the inference steps are not propagated.
    """
    sample_rate = 16000
    model_path = BASE_MODEL
    model_name = f"HF Model ({model_path})"

    print(f"Rodando inferência REAL [{model_name}] em: {audio_path}", file=sys.stderr)

    try:
        # 1. Load feature extractor and model (singleton)
        feature_extractor, model = get_wav2vec_resources(model_path)

        # 2. Load and pre-process the audio
        print(f"Lendo áudio: {audio_path}", file=sys.stderr)
        audio, _ = librosa.load(audio_path, sr=sample_rate)

        # 3. Prepare inputs for the complete audio
        inputs = feature_extractor(audio, sampling_rate=sample_rate, return_tensors="pt", padding=True)

        # 4. Main inference
        with torch.no_grad():
            logits = model(**inputs).logits

        # 5. Process results
        scores = torch.softmax(logits, dim=-1)
        id2label = model.config.id2label

        # Index of the first label that looks like the fraud class; falls
        # back to index 0 when no label matches.
        fraud_idx = next(
            (
                int(idx)
                for idx, lbl in id2label.items()
                if any(tag in lbl.lower() for tag in ('fake', 'spoof', 'fraud'))
            ),
            0,
        )

        fraud_prob = scores[0][fraud_idx].item()
        is_fraud = fraud_prob > 0.5
        top_idx = torch.argmax(scores).item()

        # --- OPTIMIZATION: batched temporal analysis (XAI) ---
        print("Iniciando Análise Temporal Otimizada...", file=sys.stderr)
        temporal_scores = []
        segment_duration = 1.0  # 1 second
        samples_per_segment = int(segment_duration * sample_rate)

        segments = []
        for start in range(0, len(audio), samples_per_segment):
            segment = audio[start:start + samples_per_segment]
            # Drop a trailing fragment shorter than half a segment.
            if len(segment) < samples_per_segment // 2:
                continue
            # Zero-pad a short trailing segment so the batch is rectangular.
            if len(segment) < samples_per_segment:
                segment = np.pad(segment, (0, samples_per_segment - len(segment)))
            segments.append(segment)

        # Score segments in mini-batches to avoid running out of memory.
        batch_size = 8
        for start in range(0, len(segments), batch_size):
            batch_segments = segments[start:start + batch_size]
            seg_inputs = feature_extractor(batch_segments, sampling_rate=sample_rate, return_tensors="pt", padding=True)
            with torch.no_grad():
                seg_logits = model(**seg_inputs).logits
                seg_probs = torch.softmax(seg_logits, dim=-1)
            temporal_scores.extend(round(p[fraud_idx].item(), 3) for p in seg_probs)
        # ---------------------------------------------------

        results = {
            "model": model_name,
            "prediction": id2label[top_idx].upper(),
            "confidence": scores[0][top_idx].item(),
            "deepfake_probability": fraud_prob,
            "temporal_scores": temporal_scores,
            "verdict": "SPOOF" if is_fraud else "BONAFIDE",
            "metadata": {
                "id2label": id2label,
                "quantized": not torch.cuda.is_available()
            }
        }

    except Exception as e:
        # Log to stderr like every other message: stdout carries the JSON
        # result when this module runs as a script and must stay parseable.
        print(f"Erro na inferência: {e}", file=sys.stderr)
        results = {
            "error": str(e),
            "verdict": "ERROR"
        }

    return results

if __name__ == "__main__":
    # CLI entry point: prints the inference result as a single JSON document
    # on stdout. Diagnostics go to stderr so stdout stays machine-readable.
    if len(sys.argv) < 2:
        print("Uso: python inference_wav2vec.py <audio_path>", file=sys.stderr)
        sys.exit(2)  # conventional exit status for a command-line usage error

    # Silence warnings (e.g. from transformers) during inference.
    import warnings
    warnings.filterwarnings("ignore")
    print(json.dumps(run_inference(sys.argv[1])))