"""Smoke-test the full Nes2Net model: build it, load weights, run inference.

Builds the wav2vec2 (XLS-R) + Nes2Net-X detector, loads the averaged
ASVspoof 2021 checkpoint into it, and runs one forward pass on random noise.
The script exits with a non-zero status if the checkpoint file is missing,
if any model state_dict entry is left unloaded, or if the output has an
unexpected shape or non-finite values.

Paths default to the original project layout; override them with
``--project``, ``--checkpoint`` and ``--xlsr`` (see ``--help``).
"""
import argparse
import os
import sys

import torch
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

# Root of the Multimodal Deepfake Detection project (original author's layout).
DEFAULT_PROJECT = r'C:\E\Project\Project B.tech\Multimodal Deepfake Detection'
XLSR_FILENAME = 'xlsr2_300m.pt'
CKPT_FILENAME = (
    'ASVspoof_2021_wav2vec2_Nes2Net_X_e100_bz12_lr2.5e_07_algo4'
    '_avg_ckpt_ep56_60_62_76_95.pth'
)

# Logit column holding the bona fide ("Real") score; the other one is spoof.
# NOTE(review): assumes the ASVspoof / SSL_Anti-spoofing training convention
# Nes2Net builds on (label 1 = bonafide, 0 = spoof, score = logits[:, 1]);
# confirm against the Nes2Net training data loader.
BONAFIDE_INDEX = 1
SPOOF_INDEX = 0

# Prefix that nn.DataParallel adds to every state_dict key.
_DATAPARALLEL_PREFIX = 'module.'


def parse_args(argv=None):
    """Parse command-line options; every option defaults to the original setup."""
    parser = argparse.ArgumentParser(
        description='Test full Nes2Net model loading and inference.')
    parser.add_argument(
        '--project', default=DEFAULT_PROJECT,
        help='project root that contains audio_detection/')
    parser.add_argument(
        '--checkpoint', default=None,
        help='Nes2Net checkpoint (.pth); defaults to '
             'audio_detection/checkpoints/' + CKPT_FILENAME)
    parser.add_argument(
        '--xlsr', default=None,
        help='XLS-R checkpoint exported as XLSR_CHECKPOINT_PATH; defaults to '
             'audio_detection/checkpoints/' + XLSR_FILENAME)
    parser.add_argument(
        '--num-samples', type=int, default=32000,
        help='length of the random test waveform (default: 32000)')
    return parser.parse_args(argv)


def strip_module_prefix(state_dict):
    """Remove a DataParallel ``module.`` key prefix from *state_dict* in place.

    The prefix is stripped only when every key carries it, so a checkpoint
    saved from an unwrapped model is left untouched. Returns *state_dict*.
    """
    if state_dict and all(k.startswith(_DATAPARALLEL_PREFIX) for k in state_dict):
        # torch's helper also rewrites the state_dict's `_metadata`, if present.
        consume_prefix_in_state_dict_if_present(state_dict, _DATAPARALLEL_PREFIX)
    return state_dict


def predict_label(logits):
    """Map one 2-element logit row to 'Real' (bona fide) or 'Fake' (spoof).

    Ties resolve to 'Fake' because the comparison is strict.
    """
    return 'Real' if logits[BONAFIDE_INDEX] > logits[SPOOF_INDEX] else 'Fake'


def _print_keys(title, keys, limit=10):
    """Print the number of *keys*, then at most *limit* of them."""
    print(f"{title}: {len(keys)}")
    for key in keys[:limit]:
        print(f"  {key}")
    if len(keys) > limit:
        print(f"  ... and {len(keys) - limit} more")


def main(argv=None):
    """Run the smoke test; exits with a FAILED message on any failed check."""
    opts = parse_args(argv)
    audio_dir = os.path.join(opts.project, 'audio_detection')
    nes2net_dir = os.path.join(audio_dir, 'Nes2Net_ASVspoof_ITW')
    ckpt_dir = os.path.join(audio_dir, 'checkpoints')
    ckpt_path = opts.checkpoint or os.path.join(ckpt_dir, CKPT_FILENAME)
    xlsr_path = opts.xlsr or os.path.join(ckpt_dir, XLSR_FILENAME)

    # Fail fast, before the (slow) model build, if the weights are absent.
    if not os.path.isfile(ckpt_path):
        sys.exit(f"Nes2Net checkpoint not found: {ckpt_path}")

    # Make the Nes2Net repo importable. Inserting model_scripts first and the
    # repo root second leaves the repo root at sys.path[0].
    sys.path.insert(0, os.path.join(nes2net_dir, 'model_scripts'))
    sys.path.insert(0, nes2net_dir)
    # Set before importing the model module in case it reads the variable
    # at import time.
    os.environ['XLSR_CHECKPOINT_PATH'] = xlsr_path

    # Deferred import: only resolvable after the sys.path setup above.
    from wav2vec2_Nes2Net_X import wav2vec2_Nes2Net_no_Res_w_allT

    # Architecture settings for the model constructor; presumably they must
    # match the configuration the checkpoint was trained with.
    model_args = argparse.Namespace(
        n_output_logits=2,
        dilation=2,
        pool_func='mean',
        Nes_ratio=[8, 8],
        SE_ratio=[1],
    )

    print("Building full Nes2Net model...")
    model = wav2vec2_Nes2Net_no_Res_w_allT(args=model_args, device='cpu')

    print("\nLoading Nes2Net checkpoint...")
    # weights_only=True restricts unpickling to tensors and plain containers.
    state_dict = torch.load(ckpt_path, map_location='cpu', weights_only=True)
    print(f"Checkpoint has {len(state_dict)} keys")
    state_dict = strip_module_prefix(state_dict)

    # strict=False so mismatches are listed here instead of raised; missing
    # entries are then treated as a failure explicitly.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    _print_keys("Missing keys", missing)
    _print_keys("Unexpected keys", unexpected)
    if missing:
        # Those entries keep their initial values, so any inference result
        # below would not reflect the trained model.
        sys.exit(f"FAILED: {len(missing)} model state_dict entries were not "
                 "loaded from the checkpoint")

    print("\nTesting forward pass...")
    x = torch.randn(1, opts.num_samples)
    model.eval()
    with torch.no_grad():
        out = model(x)
    print(f"Output shape: {out.shape}")
    print(f"Output values: {out}")

    expected_shape = (1, model_args.n_output_logits)
    if tuple(out.shape) != expected_shape:
        sys.exit(f"FAILED: expected output shape {expected_shape}, "
                 f"got {tuple(out.shape)}")
    if not torch.isfinite(out).all():
        sys.exit("FAILED: output contains NaN or Inf")

    probs = torch.softmax(out, dim=1)
    print(f"P(real): {probs[0, BONAFIDE_INDEX].item():.4f}")
    print(f"Prediction: {predict_label(out[0])}")
    print("\nFull pipeline test PASSED!")


if __name__ == "__main__":
    main()
|
|