File size: 1,876 Bytes
0f96bb5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
"""Test full Nes2Net model loading and inference."""
import sys, os, torch, argparse

# Setup paths. The project root defaults to the original author's checkout; set
# DEEPFAKE_PROJECT_ROOT to run this script from another machine or directory.
PROJECT = os.environ.get(
    'DEEPFAKE_PROJECT_ROOT',
    r'C:\E\Project\Project B.tech\Multimodal Deepfake Detection',
)
NES2NET = os.path.join(PROJECT, 'audio_detection', 'Nes2Net_ASVspoof_ITW')
CHECKPOINT_DIR = os.path.join(PROJECT, 'audio_detection', 'checkpoints')
XLSR_CKPT_PATH = os.path.join(CHECKPOINT_DIR, 'xlsr2_300m.pt')
CKPT_PATH = os.path.join(
    CHECKPOINT_DIR,
    'ASVspoof_2021_wav2vec2_Nes2Net_X_e100_bz12_lr2.5e_07_algo4_avg_ckpt_ep56_60_62_76_95.pth',
)

# Number of mismatched keys listed per category before the report is truncated.
KEY_REPORT_LIMIT = 10


def build_args():
    """Return the Nes2Net-X model config.

    Values match the Nes2Net-X defaults from easy_inference_demo.py.
    """
    return argparse.Namespace(
        n_output_logits=2,
        dilation=2,
        pool_func='mean',
        Nes_ratio=[8, 8],
        SE_ratio=[1],
    )


def format_key_report(label, keys, limit=KEY_REPORT_LIMIT):
    """Return report lines for a list of state-dict keys.

    The first line is ``"<label>: <count>"``, followed by at most *limit* keys
    (indented by two spaces). If keys were cut off, a final
    ``"  ... and N more"`` line makes the truncation explicit.
    """
    lines = [f"{label}: {len(keys)}"]
    lines.extend(f"  {k}" for k in keys[:limit])
    if len(keys) > limit:
        lines.append(f"  ... and {len(keys) - limit} more")
    return lines


def predict_label(out):
    """Return 'Real' or 'Fake' for the first row of the model output.

    'Real' when logit 0 is strictly greater than logit 1, otherwise 'Fake'.
    Assumes *out* is indexed as out[sample][logit] with 2 logits per sample.
    NOTE(review): code derived from SSL_Anti-spoofing usually treats index 1
    as the bonafide class; confirm this ordering against easy_inference_demo.py.
    """
    return 'Real' if out[0][0] > out[0][1] else 'Fake'


def main():
    """Build Nes2Net-X, load its checkpoint, and smoke-test one forward pass.

    Exits with status 1 if any model key was not loaded from the checkpoint,
    because the forward pass would then run partly on untrained weights.
    """
    # NES2NET is inserted last, so it ends up ahead of model_scripts on sys.path.
    sys.path.insert(0, os.path.join(NES2NET, 'model_scripts'))
    sys.path.insert(0, NES2NET)
    # Set before importing the model module, presumably because it reads
    # XLSR_CHECKPOINT_PATH at import/construction time (TODO confirm).
    os.environ['XLSR_CHECKPOINT_PATH'] = XLSR_CKPT_PATH

    from wav2vec2_Nes2Net_X import wav2vec2_Nes2Net_no_Res_w_allT

    print("Building full Nes2Net model...")
    model = wav2vec2_Nes2Net_no_Res_w_allT(args=build_args(), device='cpu')

    print("\nLoading Nes2Net checkpoint...")
    ckpt = torch.load(CKPT_PATH, map_location='cpu', weights_only=True)

    # Keys are loaded as-is (no remapping); the checkpoint is expected to hold
    # ssl_model.model.* and Nested_Res2Net_TDNN.* entries. strict=False returns
    # mismatches instead of raising, so they can be listed below.
    print(f"Checkpoint has {len(ckpt)} keys")
    missing, unexpected = model.load_state_dict(ckpt, strict=False)
    print("\n".join(format_key_report("Missing keys", missing)))
    print("\n".join(format_key_report("Unexpected keys", unexpected)))

    # Smoke test only: the input is random noise, so the verdict is not meaningful.
    print("\nTesting forward pass...")
    x = torch.randn(1, 32000)  # 2 seconds of audio at 16kHz
    model.eval()
    with torch.no_grad():
        out = model(x)
    print(f"Output shape: {out.shape}")
    print(f"Output values: {out}")
    print(f"Prediction: {predict_label(out)}")

    # Only report success if every model key actually received checkpoint weights.
    if missing:
        print(f"\nFull pipeline test FAILED: {len(missing)} model keys "
              "were not loaded from the checkpoint.")
        sys.exit(1)
    if unexpected:
        print(f"\nWarning: {len(unexpected)} checkpoint keys were not used by the model.")
    print("\nFull pipeline test PASSED!")


if __name__ == "__main__":
    main()