File size: 2,489 Bytes
0f96bb5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
"""Test full Nes2Net model on GPU with a real audio sample."""
import sys, os, torch, argparse, numpy as np

# Root of the project checkout on the author's machine; every other path in
# this script is derived from it.
PROJECT = r'C:\E\Project\Project B.tech\Multimodal Deepfake Detection'
NES2NET = os.path.join(PROJECT, 'audio_detection', 'Nes2Net_ASVspoof_ITW')

# Prepend the Nes2Net repo root, then its model_scripts/ directory, to the
# import search path so the model module below can be imported directly.
sys.path[:0] = [NES2NET, os.path.join(NES2NET, 'model_scripts')]

# Exported before the model import. Presumably the Nes2Net code reads this to
# locate the XLS-R 300M front-end weights (TODO confirm in wav2vec2_Nes2Net_X).
os.environ['XLSR_CHECKPOINT_PATH'] = os.path.join(
    PROJECT, 'audio_detection', 'checkpoints', 'xlsr2_300m.pt')

from wav2vec2_Nes2Net_X import wav2vec2_Nes2Net_no_Res_w_allT

# Model hyper-parameters. These must match the settings the checkpoint loaded
# below was built with, otherwise its state dict will not line up.
args = argparse.Namespace(
    n_output_logits=2,
    dilation=2,
    pool_func='mean',
    Nes_ratio=[8, 8],
    SE_ratio=[1],
)

# Use CUDA when PyTorch reports it available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")

if device == 'cuda':
    # Report GPU 0's name and its total memory (bytes -> GiB).
    gpu_name = torch.cuda.get_device_name(0)
    vram_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"  GPU: {gpu_name}")
    print(f"  VRAM: {vram_gb:.1f} GB")

print("\nBuilding model...")
model = wav2vec2_Nes2Net_no_Res_w_allT(args=args, device=device)

CKPT_PATH = os.path.join(PROJECT, 'audio_detection', 'checkpoints',
    'ASVspoof_2021_wav2vec2_Nes2Net_X_e100_bz12_lr2.5e_07_algo4_avg_ckpt_ep56_60_62_76_95.pth')

print("Loading Nes2Net checkpoint...")
# weights_only=True restricts unpickling to tensors and plain containers, so
# loading the file cannot execute arbitrary code.
ckpt = torch.load(CKPT_PATH, map_location=device, weights_only=True)

# strict=False still lets a partially matching checkpoint load, but the
# mismatches are reported instead of silently discarded. Parameters listed as
# missing keep whatever values the constructor gave them, which would make
# every score below meaningless.
load_result = model.load_state_dict(ckpt, strict=False)
if load_result.missing_keys:
    print(f"  WARNING: {len(load_result.missing_keys)} model keys missing from checkpoint, "
          f"e.g. {load_result.missing_keys[:5]}")
if load_result.unexpected_keys:
    print(f"  WARNING: {len(load_result.unexpected_keys)} checkpoint keys not used by model, "
          f"e.g. {load_result.unexpected_keys[:5]}")

# load_state_dict copied the values into the model's own parameters, so the
# loaded dict is now a redundant second copy of every weight (placed on
# `device` by map_location). Free it so it does not inflate the GPU memory
# figures reported at the end of the script.
del ckpt

model = model.to(device)
model.eval()

print(f"\nModel on device: {next(model.parameters()).device}")

# Synthetic test inputs: 2 seconds of audio at a 16 kHz sample rate.
SAMPLE_RATE = 16000
NUM_SAMPLES = 2 * SAMPLE_RATE


def _run_and_report(waveform):
    """Run the model on one waveform batch and print its logits and verdict.

    Args:
        waveform: input batch for ``model``. Both callers below pass a
            ``(1, NUM_SAMPLES)`` float tensor that is already on ``device``.

    Returns:
        The score ``logit[0] - logit[1]`` for the first batch item (printed as
        "real-fake"). A positive score is reported as Real, otherwise Fake.

    Raises:
        RuntimeError: if the first item's logits contain NaN or inf.
    """
    with torch.no_grad():
        out = model(waveform)
    logits = out[0]
    # A NaN score would still print "Fake" (NaN > 0 is False), and the run would
    # end with "PASSED". Fail loudly instead.
    if not torch.isfinite(logits).all():
        raise RuntimeError(f"Model returned non-finite logits: {logits.cpu().tolist()}")
    score = logits[0].item() - logits[1].item()
    print(f"  Raw output: {logits.cpu().tolist()}")
    print(f"  Score (real-fake): {score:.4f}")
    print(f"  Prediction: {'Real' if score > 0 else 'Fake'}")
    return score


print("\n=== Test 1: Random noise (should lean toward fake) ===")
x = torch.randn(1, NUM_SAMPLES, device=device)
score = _run_and_report(x)

# Test with a 440 Hz sine wave. The time axis is built from the sample rate:
# torch.linspace(0, 2, 32000) includes both endpoints, so its step would be
# 2/31999 s instead of 1/16000 s and the tone would be slightly detuned.
print("\n=== Test 2: 440Hz sine wave ===")
t = torch.arange(NUM_SAMPLES, device=device) / SAMPLE_RATE
sine = (torch.sin(2 * np.pi * 440 * t) * 0.5).unsqueeze(0)
score = _run_and_report(sine)

if device == 'cuda':
    # "used" = memory held by live tensors (memory_allocated); "cached" = memory
    # PyTorch's caching allocator has reserved from the driver (memory_reserved).
    print(f"\nGPU memory used: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
    print(f"GPU memory cached: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")

# Name the backend that actually ran the test, so a CPU fallback is not
# reported as a GPU pass. The GPU output is unchanged.
backend = 'GPU' if device == 'cuda' else 'CPU'
print(f"\n{backend} inference test PASSED!")