| """Adversarial robustness evaluation for quantum syndrome decoder. | |
| Implements FGSM, PGD, and adversarial training defense. | |
| Defensive research: testing model robustness under attack scenarios. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
| import json | |
| from pathlib import Path | |
| from decoder_model import QuantumSyndromeDecoder | |
| from generate_data import generate_dataset | |
def fgsm_attack(model, syndrome, error_target, epsilon=0.1, loss_fn=None):
    """Fast Gradient Sign Method (FGSM) attack.

    Takes one targeted signed-gradient step that pushes the model's
    predictions toward the *inverted* labels (``1 - error_target``), then
    clamps the perturbed syndrome back into ``[0, 1]``.

    Args:
        model: Callable mapping a syndrome tensor to predictions. The default
            loss (``nn.BCELoss``) expects these to be probabilities in ``[0, 1]``.
        syndrome: Clean input tensor; it is not modified.
        error_target: Label tensor. It may already have the predictions'
            shape. Otherwise it is repeated twice along dim 1 to match them.
        epsilon: Magnitude of the per-element perturbation.
        loss_fn: Loss used to obtain the input gradient; defaults to
            ``nn.BCELoss()``.

    Returns:
        A detached tensor shaped like ``syndrome``, clamped to ``[0, 1]``.
    """
    if loss_fn is None:
        loss_fn = nn.BCELoss()
    syndrome_adv = syndrome.clone().detach().requires_grad_(True)
    predictions = model(syndrome_adv)
    inverted_target = 1.0 - error_target
    # Callers may pass either the raw labels or labels already expanded to
    # the prediction layout. Expand only when the shapes differ; expanding
    # an expanded target makes the loss reject the mismatched shapes.
    if inverted_target.shape != predictions.shape:
        inverted_target = inverted_target.repeat(1, 2, 1, 1, 1)
    loss = loss_fn(predictions, inverted_target)
    # autograd.grad returns only the input gradient, so the model's
    # parameter .grad buffers are left untouched.
    (grad,) = torch.autograd.grad(loss, syndrome_adv)
    # Targeted step: *descend* the loss toward the inverted labels. Adding
    # the gradient would push predictions away from the wrong labels,
    # i.e. toward the correct ones, helping the model instead of attacking it.
    syndrome_attacked = syndrome - epsilon * grad.sign()
    return torch.clamp(syndrome_attacked, 0.0, 1.0).detach()
def pgd_attack(model, syndrome, error_target, epsilon=0.1, alpha=0.02,
               iterations=20, loss_fn=None):
    """Projected Gradient Descent (PGD) attack.

    Starts from a uniform random point in the L-inf ball of radius
    ``epsilon`` around ``syndrome``. It then takes ``iterations`` signed
    gradient steps of size ``alpha`` toward the *inverted* labels
    (``1 - error_target``). After every step the perturbation is projected
    back into the ball and the input is clamped to ``[0, 1]``.

    Args:
        model: Callable mapping a syndrome tensor to predictions. The default
            loss (``nn.BCELoss``) expects probabilities in ``[0, 1]``.
        syndrome: Clean input tensor; it is not modified.
        error_target: Label tensor. It may already have the predictions'
            shape. Otherwise it is repeated twice along dim 1 to match them.
        epsilon: L-inf radius of the allowed perturbation.
        alpha: Step size of each iteration.
        iterations: Number of gradient steps.
        loss_fn: Loss used to obtain the input gradient; defaults to
            ``nn.BCELoss()``.

    Returns:
        A detached adversarial tensor shaped like ``syndrome``.
    """
    if loss_fn is None:
        loss_fn = nn.BCELoss()
    # Keep the graph from reaching the caller's tensor.
    syndrome = syndrome.detach()
    # Random start inside the epsilon ball, kept within the valid range.
    delta = torch.zeros_like(syndrome).uniform_(-epsilon, epsilon)
    syndrome_adv = torch.clamp(syndrome + delta, 0.0, 1.0)
    inverted_target = 1.0 - error_target
    for _ in range(iterations):
        syndrome_adv = syndrome_adv.detach().requires_grad_(True)
        predictions = model(syndrome_adv)
        # Expand raw labels to the prediction layout once. A target that
        # already matches must not be expanded again.
        if inverted_target.shape != predictions.shape:
            inverted_target = inverted_target.repeat(1, 2, 1, 1, 1)
        loss = loss_fn(predictions, inverted_target)
        # Input gradient only; the model's parameter .grad stays untouched.
        (grad,) = torch.autograd.grad(loss, syndrome_adv)
        # Targeted attack: descend the loss toward the inverted labels.
        syndrome_adv = syndrome_adv.detach() - alpha * grad.sign()
        # Project back onto the epsilon ball, then onto the valid range.
        delta = torch.clamp(syndrome_adv - syndrome, -epsilon, epsilon)
        syndrome_adv = torch.clamp(syndrome + delta, 0.0, 1.0)
    return syndrome_adv.detach()
def random_noise_attack(syndrome, noise_level=0.1):
    """Baseline: random Gaussian noise attack.

    Adds zero-mean Gaussian noise with standard deviation ``noise_level``
    (drawn from torch's global RNG) and clips the result to ``[0, 1]``.
    Unlike the gradient attacks it never looks at the model.

    Args:
        syndrome: Input tensor.
        noise_level: Standard deviation of the added noise.

    Returns:
        A new tensor shaped like ``syndrome``, clamped to ``[0, 1]``.
    """
    gaussian = noise_level * torch.randn_like(syndrome)
    noisy = syndrome + gaussian
    return noisy.clamp(0.0, 1.0)
def evaluate_under_attack(model, syndromes, errors, attack_fn, attack_name,
                          device='cpu', epsilon=0.1, batch_size=64):
    """Evaluate model performance under adversarial attack.

    Runs the model on clean and attacked inputs batch by batch. It compares
    thresholded predictions (``> 0.5``) against the error labels, which are
    repeated twice along dim 1 to match the prediction layout.

    Args:
        model: Decoder module. It is put in eval mode for the evaluation, and
            its previous train/eval mode is restored on exit.
        syndromes: NumPy array of syndromes. Each batch is converted with
            ``permute(0, 3, 1, 2).unsqueeze(-1)``, presumably
            ``(N, H, W, C) -> (N, C, H, W, 1)`` — TODO confirm against the
            decoder.
        errors: NumPy array of error labels in the same layout as ``syndromes``.
        attack_fn: If ``attack_name == 'random_noise'``, it is called as
            ``attack_fn(batch_syn, epsilon)``. Otherwise it is called as
            ``attack_fn(model, batch_syn, batch_err, epsilon=epsilon)``, where
            ``batch_err`` is the un-repeated label tensor.
        attack_name: Label stored in the result; it also selects the call
            form above.
        device: Torch device for the batches.
        epsilon: Attack strength forwarded to ``attack_fn``.
        batch_size: Number of samples per batch.

    Returns:
        Dict with element-wise clean / adversarial bit accuracy, the absolute
        and relative (percent) accuracy drop, and the change in
        false-negative and false-positive counts caused by the attack.
    """
    was_training = model.training
    model.eval()
    total_clean_correct = 0
    total_adv_correct = 0
    total_elements = 0
    total_fn_increase = 0
    total_fp_increase = 0
    try:
        for start in range(0, len(syndromes), batch_size):
            stop = start + batch_size
            batch_syn = torch.from_numpy(syndromes[start:stop]).float().to(device)
            batch_err = torch.from_numpy(errors[start:stop]).float().to(device)
            batch_syn = batch_syn.permute(0, 3, 1, 2).unsqueeze(-1)
            batch_err = batch_err.permute(0, 3, 1, 2).unsqueeze(-1)
            # Labels expanded to the prediction layout (2x along dim 1).
            target = batch_err.repeat(1, 2, 1, 1, 1)
            with torch.no_grad():
                clean_binary = (model(batch_syn) > 0.5).float()
            if attack_name == 'random_noise':
                adv_syn = attack_fn(batch_syn, epsilon)
            else:
                # Pass the un-expanded labels: the gradient attacks in this
                # module take the raw error target and expand it themselves.
                adv_syn = attack_fn(model, batch_syn, batch_err, epsilon=epsilon)
            with torch.no_grad():
                adv_binary = (model(adv_syn) > 0.5).float()
            total_clean_correct += (clean_binary == target).float().sum().item()
            total_adv_correct += (adv_binary == target).float().sum().item()
            # False negatives: true error bits predicted as 0.
            clean_fn = (target * (1 - clean_binary)).sum().item()
            adv_fn = (target * (1 - adv_binary)).sum().item()
            total_fn_increase += (adv_fn - clean_fn)
            # False positives: non-error bits predicted as 1.
            clean_fp = ((1 - target) * clean_binary).sum().item()
            adv_fp = ((1 - target) * adv_binary).sum().item()
            total_fp_increase += (adv_fp - clean_fp)
            total_elements += target.numel()
    finally:
        # Don't leave the caller's model silently switched to eval mode.
        model.train(was_training)
    clean_acc = total_clean_correct / total_elements if total_elements > 0 else 0
    adv_acc = total_adv_correct / total_elements if total_elements > 0 else 0
    return {
        'attack': attack_name,
        'epsilon': epsilon,
        'clean_accuracy': round(clean_acc, 4),
        'adversarial_accuracy': round(adv_acc, 4),
        'accuracy_drop': round(clean_acc - adv_acc, 4),
        'relative_drop_pct': round((clean_acc - adv_acc) / (clean_acc + 1e-8) * 100, 2),
        'fn_increase': round(total_fn_increase, 2),
        'fp_increase': round(total_fp_increase, 2),
    }
def adversarial_training(model, train_syndromes, train_errors, epochs=30,
                         epsilon=0.05, lr=5e-4, device='cpu'):
    """Adversarial training: mix clean + adversarial examples.

    For every mini-batch, a 5-step PGD attack (``alpha = epsilon / 5``) is
    crafted against the current model. Each sample is then replaced by its
    adversarial version with probability 0.5 before the BCE update.
    Optimisation uses AdamW (weight decay 3e-2) with a cosine-annealed
    learning rate over ``epochs`` epochs.

    Args:
        model: Decoder to fine-tune; moved to ``device`` and updated in place.
        train_syndromes: NumPy array of syndromes, indexed along axis 0.
        train_errors: NumPy array of matching error labels.
        epochs: Number of passes over the training set.
        epsilon: L-inf budget of the PGD attack.
        lr: AdamW learning rate.
        device: Torch device string.

    Returns:
        The model returned by ``model.to(device)``, after training.
    """
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=3e-2)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    criterion = nn.BCELoss()
    batch_size = 64
    for epoch in range(epochs):
        epoch_loss = 0.0
        n_batches = 0
        perm = np.random.permutation(len(train_syndromes))
        for i in range(0, len(train_syndromes), batch_size):
            idx = perm[i:i + batch_size]
            batch_syn = torch.from_numpy(train_syndromes[idx]).float().to(device)
            batch_err = torch.from_numpy(train_errors[idx]).float().to(device)
            batch_syn = batch_syn.permute(0, 3, 1, 2).unsqueeze(-1)
            batch_err = batch_err.permute(0, 3, 1, 2).unsqueeze(-1)
            target = batch_err.repeat(1, 2, 1, 1, 1)
            # Craft the attack in eval mode. Dropout / batch-norm layers (if
            # the decoder has any) then don't fire or update running
            # statistics on the extra PGD forward passes.
            model.eval()
            # pgd_attack expands the raw error labels to the prediction layout.
            adv_syn = pgd_attack(model, batch_syn, batch_err, epsilon=epsilon,
                                 alpha=epsilon / 5, iterations=5)
            model.train()
            # Per-sample coin flip: adversarial (True) or clean (False) input.
            mix_mask = torch.rand(len(batch_syn)) > 0.5
            mixed_syn = torch.where(mix_mask.view(-1, 1, 1, 1, 1).to(device),
                                    adv_syn, batch_syn)
            # zero_grad also discards any gradients left by the attack.
            optimizer.zero_grad()
            predictions = model(mixed_syn)
            loss = criterion(predictions, target)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            n_batches += 1
        scheduler.step()
        if (epoch + 1) % 10 == 0:
            # max(..., 1) avoids ZeroDivisionError on an empty training set.
            print(f" Adv-Train Epoch {epoch+1}/{epochs}: loss={epoch_loss / max(n_batches, 1):.4f}")
    return model
def _train_baseline(model, train_syn, train_err, device, epochs=20, batch_size=64):
    """Train ``model`` in place with plain BCE, printing the mean loss every 5 epochs."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=3e-2)
    criterion = nn.BCELoss()  # built once instead of once per batch
    for epoch in range(epochs):
        model.train()
        perm = np.random.permutation(len(train_syn))
        epoch_loss = 0.0
        n_batches = 0
        for i in range(0, len(train_syn), batch_size):
            idx = perm[i:i + batch_size]
            batch_syn = (torch.from_numpy(train_syn[idx]).float().to(device)
                         .permute(0, 3, 1, 2).unsqueeze(-1))
            batch_err = (torch.from_numpy(train_err[idx]).float().to(device)
                         .permute(0, 3, 1, 2).unsqueeze(-1))
            target = batch_err.repeat(1, 2, 1, 1, 1)
            optimizer.zero_grad()
            loss = criterion(model(batch_syn), target)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            n_batches += 1
        if (epoch + 1) % 5 == 0:
            # Report the epoch mean. The last batch alone is noisy and may be a
            # short remainder batch; max(..., 1) also handles an empty set.
            print(f" Baseline Epoch {epoch+1}: loss={epoch_loss / max(n_batches, 1):.4f}")
    return model


def _attack_sweep(model, test_syn, test_err, attack_fn, attack_name, epsilons,
                  device, label='eps'):
    """Evaluate one attack at each epsilon, printing one line per setting."""
    results = []
    for eps in epsilons:
        r = evaluate_under_attack(model, test_syn, test_err, attack_fn,
                                  attack_name, device, epsilon=eps)
        results.append(r)
        print(f" {label}={eps}: clean={r['clean_accuracy']:.4f} -> adv={r['adversarial_accuracy']:.4f} (drop={r['relative_drop_pct']:.1f}%)")
    return results


def _print_summary(all_results):
    """Print a fixed-width table with one row per attack result."""
    print("\n" + "="*80)
    print(" ADVERSARIAL ROBUSTNESS SUMMARY")
    print("="*80)
    print(f"{'Attack':<22} {'Epsilon':<10} {'Clean Acc':<12} {'Adv Acc':<12} {'Drop %':<10}")
    print("-"*80)
    for r in all_results:
        print(f"{r['attack']:<22} {r['epsilon']:<10} {r['clean_accuracy']:<12} "
              f"{r['adversarial_accuracy']:<12} {r['relative_drop_pct']:<10}")
    print("="*80)


def run_full_evaluation():
    """Run complete adversarial robustness evaluation.

    Steps:
      1. Generate a distance-5 dataset and hold out its last 30% as a test set.
      2. Train a baseline decoder for 20 epochs.
      3. Sweep FGSM, PGD and random-noise attacks over several epsilons.
      4. Adversarially fine-tune a copy of the baseline and evaluate it under
         PGD with eps=0.1.

    Writes ``outputs/adversarial_results.json``,
    ``outputs/baseline_decoder.pt`` and ``outputs/hardened_decoder.pt``.

    Returns:
        List of result dicts (see ``evaluate_under_attack``).
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Device: {device}")

    print("\n[1/4] Generating dataset...")
    syndromes, errors, _metadata = generate_dataset(
        distances=[5], error_rates=[0.05], samples_per_config=2000
    )
    # Unshuffled tail split: the last 30% of samples form the test set.
    n = len(syndromes)
    n_test = int(0.3 * n)
    train_syn, test_syn = syndromes[:n - n_test], syndromes[n - n_test:]
    train_err, test_err = errors[:n - n_test], errors[n - n_test:]

    print("\n[2/4] Training baseline decoder...")
    model = QuantumSyndromeDecoder(distance=5, channels=128, num_layers=4).to(device)
    _train_baseline(model, train_syn, train_err, device)

    print("\n[3/4] Running adversarial attacks...")
    all_results = []
    epsilons = [0.01, 0.05, 0.1, 0.2, 0.3]
    print("\n --- FGSM Attack ---")
    all_results += _attack_sweep(model, test_syn, test_err, fgsm_attack,
                                 'FGSM', epsilons, device)
    print("\n --- PGD Attack ---")
    # The PGD sweep deliberately omits eps=0.3.
    all_results += _attack_sweep(model, test_syn, test_err, pgd_attack,
                                 'PGD', [0.01, 0.05, 0.1, 0.2], device)
    print("\n --- Random Noise (baseline) ---")
    all_results += _attack_sweep(model, test_syn, test_err, random_noise_attack,
                                 'random_noise', epsilons, device, label='noise')

    print("\n[4/4] Adversarial training defense...")
    # Start the hardened model from the baseline weights.
    model_adv = QuantumSyndromeDecoder(distance=5, channels=128, num_layers=4).to(device)
    model_adv.load_state_dict(model.state_dict())
    model_adv = adversarial_training(model_adv, train_syn, train_err,
                                     epochs=20, epsilon=0.05, device=device)
    print("\n --- Hardened model under PGD eps=0.1 ---")
    r = evaluate_under_attack(model_adv, test_syn, test_err, pgd_attack,
                              'PGD_on_hardened', device, epsilon=0.1)
    all_results.append(r)
    print(f" Hardened: clean={r['clean_accuracy']:.4f} -> adv={r['adversarial_accuracy']:.4f} (drop={r['relative_drop_pct']:.1f}%)")

    out_dir = Path('outputs')
    out_dir.mkdir(exist_ok=True)
    with open(out_dir / 'adversarial_results.json', 'w') as f:
        json.dump(all_results, f, indent=2)
    _print_summary(all_results)
    torch.save(model.state_dict(), out_dir / 'baseline_decoder.pt')
    torch.save(model_adv.state_dict(), out_dir / 'hardened_decoder.pt')
    print("\nModels saved to outputs/")
    return all_results
if __name__ == "__main__":
    # Script entry point; the returned result list is not needed here.
    run_full_evaluation()
Xet Storage Details
- Size: 10.9 kB
- Xet hash: 02d8d423b3a661ec56b56cc3b4a843f728137f864a3ac1fa60ff74217fdfed29

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.