Samabe1109's picture
download
raw
10.9 kB
"""Adversarial robustness evaluation for quantum syndrome decoder.
Implements FGSM, PGD, and adversarial training defense.
Defensive research: testing model robustness under attack scenarios.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import json
from pathlib import Path
from decoder_model import QuantumSyndromeDecoder
from generate_data import generate_dataset
def fgsm_attack(model, syndrome, error_target, epsilon=0.1, loss_fn=None):
    """Craft a one-step, targeted FGSM perturbation of ``syndrome``.

    The attack aims the model at the bit-flipped labels ``1 - error_target``:
    it takes a single signed-gradient step of size ``epsilon`` that *lowers*
    the loss against those flipped labels, then clips the result to [0, 1].

    Args:
        model: Callable mapping the syndrome tensor to predictions that
            ``loss_fn`` accepts.
        syndrome: Clean input tensor. It is not modified.
        error_target: True labels. May already have the shape of the model
            output; otherwise it is tiled twice along dim 1 (the same
            ``repeat(1, 2, 1, 1, 1)`` used for training targets elsewhere in
            this file), which assumes a tensor of at most 5 dims.
        epsilon: Size of the signed step applied to every input element.
        loss_fn: Loss called as ``loss_fn(predictions, target)``. Defaults to
            ``nn.BCELoss()``.

    Returns:
        The perturbed syndrome, clipped to [0, 1], same shape as ``syndrome``.
    """
    if loss_fn is None:
        loss_fn = nn.BCELoss()

    syndrome_adv = syndrome.clone().detach().requires_grad_(True)
    predictions = model(syndrome_adv)

    flipped_target = 1.0 - error_target
    # Callers differ on whether the target is pre-tiled; tiling an already
    # tiled target would make its size disagree with the predictions.
    if flipped_target.shape != predictions.shape:
        flipped_target = flipped_target.repeat(1, 2, 1, 1, 1)

    loss = loss_fn(predictions, flipped_target)
    # autograd.grad returns only the input gradient, so no stale .grad values
    # are left behind on the model parameters.
    (input_grad,) = torch.autograd.grad(loss, syndrome_adv)

    # Targeted attack: step *down* the loss towards the flipped labels.
    # Stepping up would push the predictions back towards the true labels.
    syndrome_attacked = syndrome.detach() - epsilon * input_grad.sign()
    return torch.clamp(syndrome_attacked, 0.0, 1.0)
def pgd_attack(model, syndrome, error_target, epsilon=0.1, alpha=0.02,
               iterations=20, loss_fn=None):
    """Craft a targeted PGD perturbation of ``syndrome``.

    Starts from a uniformly random point inside the ``epsilon`` box around
    ``syndrome`` and takes ``iterations`` signed-gradient steps of size
    ``alpha`` that *lower* the loss against the bit-flipped labels
    ``1 - error_target``. After every step the perturbation is clipped back
    to ``[-epsilon, epsilon]`` and the input to [0, 1].

    Args:
        model: Callable mapping the syndrome tensor to predictions that
            ``loss_fn`` accepts.
        syndrome: Clean input tensor. It is not modified.
        error_target: True labels. May already have the shape of the model
            output; otherwise it is tiled twice along dim 1 (the same
            ``repeat(1, 2, 1, 1, 1)`` used for training targets elsewhere in
            this file), which assumes a tensor of at most 5 dims.
        epsilon: Maximum absolute perturbation per input element.
        alpha: Step size of each iteration.
        iterations: Number of gradient steps; 0 returns the random start.
        loss_fn: Loss called as ``loss_fn(predictions, target)``. Defaults to
            ``nn.BCELoss()``.

    Returns:
        The detached adversarial syndrome, same shape as ``syndrome``.
    """
    if loss_fn is None:
        loss_fn = nn.BCELoss()

    clean = syndrome.detach()
    flipped_target = 1.0 - error_target

    # Random start inside the epsilon box.
    delta = torch.zeros_like(clean).uniform_(-epsilon, epsilon)
    syndrome_adv = torch.clamp(clean + delta, 0.0, 1.0)

    for _ in range(iterations):
        syndrome_adv = syndrome_adv.detach().requires_grad_(True)
        predictions = model(syndrome_adv)

        # Callers differ on whether the target is pre-tiled; tiling an already
        # tiled target would make its size disagree with the predictions.
        if flipped_target.shape != predictions.shape:
            flipped_target = flipped_target.repeat(1, 2, 1, 1, 1)

        loss = loss_fn(predictions, flipped_target)
        # autograd.grad returns only the input gradient, so the attack leaves
        # no .grad values on the model parameters.
        (input_grad,) = torch.autograd.grad(loss, syndrome_adv)

        # Targeted attack: step *down* the loss towards the flipped labels.
        stepped = syndrome_adv.detach() - alpha * input_grad.sign()

        # Project back onto the epsilon box, then onto the valid input range.
        delta = torch.clamp(stepped - clean, -epsilon, epsilon)
        syndrome_adv = torch.clamp(clean + delta, 0.0, 1.0)

    return syndrome_adv.detach()
def random_noise_attack(syndrome, noise_level=0.1):
    """Baseline perturbation: add zero-mean Gaussian noise and clip to [0, 1].

    Args:
        syndrome: Input tensor to perturb. It is not modified.
        noise_level: Factor the standard-normal noise is scaled by.

    Returns:
        A new tensor shaped like ``syndrome`` with values in [0, 1].
    """
    jitter = noise_level * torch.randn_like(syndrome)
    return (syndrome + jitter).clamp(min=0.0, max=1.0)
def evaluate_under_attack(model, syndromes, errors, attack_fn, attack_name,
                          device='cpu', epsilon=0.1, batch_size=64):
    """Compare element-wise accuracy on clean and attacked syndromes.

    Predictions are thresholded at 0.5 and compared element by element with
    the labels tiled twice along dim 1.

    Args:
        model: ``torch.nn.Module`` to evaluate; it is switched to eval mode.
        syndromes: NumPy array of inputs, 4-D (required by the
            ``permute(0, 3, 1, 2)`` below); presumably (N, H, W, C) -- confirm
            against the data generator.
        errors: NumPy array of labels with the same layout as ``syndromes``.
        attack_fn: Attack callable. Called as ``attack_fn(batch, epsilon)``
            when ``attack_name == 'random_noise'``; otherwise as
            ``attack_fn(model, batch, labels, epsilon=epsilon)`` where
            ``labels`` is the un-tiled label batch.
        attack_name: Label stored in the result; also selects the call
            convention above.
        device: Torch device the batches are moved to.
        epsilon: Attack strength forwarded to ``attack_fn``.
        batch_size: Number of samples per batch.

    Returns:
        Dict with keys ``attack``, ``epsilon``, ``clean_accuracy``,
        ``adversarial_accuracy``, ``accuracy_drop``, ``relative_drop_pct``,
        ``fn_increase`` and ``fp_increase``. Accuracies are 0 when the input
        is empty.
    """
    model.eval()

    total_clean_correct = 0
    total_adv_correct = 0
    total_elements = 0
    total_fn_increase = 0
    total_fp_increase = 0

    for i in range(0, len(syndromes), batch_size):
        batch_syn = torch.from_numpy(syndromes[i:i + batch_size]).float().to(device)
        batch_err = torch.from_numpy(errors[i:i + batch_size]).float().to(device)
        # Move the last axis to dim 1 and append a trailing singleton axis.
        batch_syn = batch_syn.permute(0, 3, 1, 2).unsqueeze(-1)
        batch_err = batch_err.permute(0, 3, 1, 2).unsqueeze(-1)
        target = batch_err.repeat(1, 2, 1, 1, 1)

        with torch.no_grad():
            clean_binary = (model(batch_syn) > 0.5).float()
        total_clean_correct += (clean_binary == target).float().sum().item()

        if attack_name == 'random_noise':
            adv_syn = attack_fn(batch_syn, epsilon)
        else:
            # The gradient attacks in this file tile the labels themselves, so
            # they must receive the un-tiled batch; passing ``target`` would
            # tile it twice and break the loss on a size mismatch.
            adv_syn = attack_fn(model, batch_syn, batch_err, epsilon=epsilon)

        with torch.no_grad():
            adv_binary = (model(adv_syn) > 0.5).float()
        total_adv_correct += (adv_binary == target).float().sum().item()

        # False negatives: label is 1 but the prediction is 0.
        clean_fn = (target * (1 - clean_binary)).sum().item()
        adv_fn = (target * (1 - adv_binary)).sum().item()
        total_fn_increase += (adv_fn - clean_fn)

        # False positives: label is 0 but the prediction is 1.
        clean_fp = ((1 - target) * clean_binary).sum().item()
        adv_fp = ((1 - target) * adv_binary).sum().item()
        total_fp_increase += (adv_fp - clean_fp)

        total_elements += target.numel()

    clean_acc = total_clean_correct / total_elements if total_elements > 0 else 0
    adv_acc = total_adv_correct / total_elements if total_elements > 0 else 0

    return {
        'attack': attack_name,
        'epsilon': epsilon,
        'clean_accuracy': round(clean_acc, 4),
        'adversarial_accuracy': round(adv_acc, 4),
        'accuracy_drop': round(clean_acc - adv_acc, 4),
        # The small constant avoids division by zero when clean accuracy is 0.
        'relative_drop_pct': round((clean_acc - adv_acc) / (clean_acc + 1e-8) * 100, 2),
        'fn_increase': round(total_fn_increase, 2),
        'fp_increase': round(total_fp_increase, 2),
    }
def adversarial_training(model, train_syndromes, train_errors, epochs=30,
                         epsilon=0.05, lr=5e-4, device='cpu'):
    """Fine-tune ``model`` on a per-sample mix of clean and PGD examples.

    For every batch a PGD adversarial version is crafted (5 steps of size
    ``epsilon / 5``); each sample then uses its adversarial version when a
    uniform random draw exceeds 0.5 and its clean version otherwise. The
    model is trained with BCE loss against the true labels tiled twice along
    dim 1, using AdamW and a cosine-annealed learning rate.

    Args:
        model: ``torch.nn.Module`` to train; moved to ``device`` and updated
            in place.
        train_syndromes: NumPy array of inputs, 4-D (required by the
            ``permute(0, 3, 1, 2)`` below); presumably (N, H, W, C) -- confirm
            against the data generator.
        train_errors: NumPy array of labels with the same layout.
        epochs: Number of passes over the training data.
        epsilon: PGD perturbation budget.
        lr: Initial AdamW learning rate.
        device: Torch device used for training.

    Returns:
        The trained model.
    """
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=3e-2)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
    criterion = nn.BCELoss()
    batch_size = 64

    for epoch in range(epochs):
        # NOTE(review): the PGD forward passes below run in train mode, so any
        # dropout / batch-norm layers in the model are active while crafting
        # adversarial examples -- confirm this is intended.
        model.train()
        epoch_loss = 0
        n_batches = 0
        perm = np.random.permutation(len(train_syndromes))

        for i in range(0, len(train_syndromes), batch_size):
            idx = perm[i:i + batch_size]
            batch_syn = torch.from_numpy(train_syndromes[idx]).float().to(device)
            batch_err = torch.from_numpy(train_errors[idx]).float().to(device)
            # Move the last axis to dim 1 and append a trailing singleton axis.
            batch_syn = batch_syn.permute(0, 3, 1, 2).unsqueeze(-1)
            batch_err = batch_err.permute(0, 3, 1, 2).unsqueeze(-1)
            target = batch_err.repeat(1, 2, 1, 1, 1)

            # pgd_attack tiles the labels itself, so it gets the un-tiled
            # batch; passing ``target`` would tile it twice and break the
            # loss on a size mismatch.
            adv_syn = pgd_attack(model, batch_syn, batch_err, epsilon=epsilon,
                                 alpha=epsilon/5, iterations=5)

            # Per-sample coin flip: True -> adversarial, False -> clean.
            mix_mask = torch.rand(len(batch_syn)) > 0.5
            mixed_syn = torch.where(mix_mask.view(-1, 1, 1, 1, 1).to(device),
                                    adv_syn, batch_syn)

            # Clears any gradients the attack may have left on the parameters.
            optimizer.zero_grad()
            predictions = model(mixed_syn)
            loss = criterion(predictions, target)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            n_batches += 1

        scheduler.step()
        # Skip the log line when no batch ran (empty training set) instead of
        # dividing by zero.
        if (epoch + 1) % 10 == 0 and n_batches:
            print(f" Adv-Train Epoch {epoch+1}/{epochs}: loss={epoch_loss/n_batches:.4f}")

    return model
def run_full_evaluation():
    """Run the whole robustness experiment and return the collected results.

    Stages, in order:
      1. Generate a dataset and hold out the last 30% of it for testing.
      2. Train a baseline decoder for 20 epochs.
      3. Evaluate the baseline under FGSM, PGD and random noise at several
         strengths.
      4. Adversarially fine-tune a copy of the baseline and evaluate it under
         PGD at epsilon 0.1.

    Side effects: prints progress and a summary table, writes
    ``outputs/adversarial_results.json`` and saves both models' state dicts
    under ``outputs/``.

    Returns:
        List of the result dicts produced by ``evaluate_under_attack``.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Device: {device}")

    # ---- Stage 1: data ----------------------------------------------------
    print("\n[1/4] Generating dataset...")
    syndromes, errors, _ = generate_dataset(
        distances=[5], error_rates=[0.05], samples_per_config=2000
    )
    n = len(syndromes)
    n_test = int(0.3 * n)
    split = n - n_test
    train_syn, test_syn = syndromes[:split], syndromes[split:]
    train_err, test_err = errors[:split], errors[split:]

    def as_batch(array, rows):
        # Select rows, move the last axis to dim 1, append a singleton axis.
        tensor = torch.from_numpy(array[rows]).float().to(device)
        return tensor.permute(0, 3, 1, 2).unsqueeze(-1)

    # ---- Stage 2: baseline training ---------------------------------------
    print("\n[2/4] Training baseline decoder...")
    model = QuantumSyndromeDecoder(distance=5, channels=128, num_layers=4).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=3e-2)
    bce = nn.BCELoss()
    batch_size = 64
    for epoch in range(20):
        model.train()
        perm = np.random.permutation(len(train_syn))
        for start in range(0, len(train_syn), batch_size):
            idx = perm[start:start + batch_size]
            batch_syn = as_batch(train_syn, idx)
            target = as_batch(train_err, idx).repeat(1, 2, 1, 1, 1)
            optimizer.zero_grad()
            loss = bce(model(batch_syn), target)
            loss.backward()
            optimizer.step()
        # Reports the loss of the last batch of the epoch only.
        if (epoch + 1) % 5 == 0:
            print(f" Baseline Epoch {epoch+1}: loss={loss.item():.4f}")

    # ---- Stage 3: attacks on the baseline ----------------------------------
    print("\n[3/4] Running adversarial attacks...")
    all_results = []
    epsilons = [0.01, 0.05, 0.1, 0.2, 0.3]

    print("\n --- FGSM Attack ---")
    for eps in epsilons:
        r = evaluate_under_attack(model, test_syn, test_err, fgsm_attack,
                                  'FGSM', device, epsilon=eps)
        all_results.append(r)
        print(f" eps={eps}: clean={r['clean_accuracy']:.4f} -> adv={r['adversarial_accuracy']:.4f} (drop={r['relative_drop_pct']:.1f}%)")

    print("\n --- PGD Attack ---")
    for eps in [0.01, 0.05, 0.1, 0.2]:
        r = evaluate_under_attack(model, test_syn, test_err, pgd_attack,
                                  'PGD', device, epsilon=eps)
        all_results.append(r)
        print(f" eps={eps}: clean={r['clean_accuracy']:.4f} -> adv={r['adversarial_accuracy']:.4f} (drop={r['relative_drop_pct']:.1f}%)")

    print("\n --- Random Noise (baseline) ---")
    for eps in epsilons:
        r = evaluate_under_attack(model, test_syn, test_err, random_noise_attack,
                                  'random_noise', device, epsilon=eps)
        all_results.append(r)
        print(f" noise={eps}: clean={r['clean_accuracy']:.4f} -> adv={r['adversarial_accuracy']:.4f} (drop={r['relative_drop_pct']:.1f}%)")

    # ---- Stage 4: adversarial training defense -----------------------------
    print("\n[4/4] Adversarial training defense...")
    # The hardened model starts from a copy of the baseline weights.
    model_adv = QuantumSyndromeDecoder(distance=5, channels=128, num_layers=4).to(device)
    model_adv.load_state_dict(model.state_dict())
    model_adv = adversarial_training(model_adv, train_syn, train_err,
                                     epochs=20, epsilon=0.05, device=device)

    print("\n --- Hardened model under PGD eps=0.1 ---")
    r = evaluate_under_attack(model_adv, test_syn, test_err, pgd_attack,
                              'PGD_on_hardened', device, epsilon=0.1)
    all_results.append(r)
    print(f" Hardened: clean={r['clean_accuracy']:.4f} -> adv={r['adversarial_accuracy']:.4f} (drop={r['relative_drop_pct']:.1f}%)")

    # ---- Persist and summarize ---------------------------------------------
    Path('outputs').mkdir(exist_ok=True)
    with open('outputs/adversarial_results.json', 'w') as f:
        json.dump(all_results, f, indent=2)

    print("\n" + "="*80)
    print(" ADVERSARIAL ROBUSTNESS SUMMARY")
    print("="*80)
    print(f"{'Attack':<22} {'Epsilon':<10} {'Clean Acc':<12} {'Adv Acc':<12} {'Drop %':<10}")
    print("-"*80)
    for r in all_results:
        print(f"{r['attack']:<22} {r['epsilon']:<10} {r['clean_accuracy']:<12} "
              f"{r['adversarial_accuracy']:<12} {r['relative_drop_pct']:<10}")
    print("="*80)

    torch.save(model.state_dict(), 'outputs/baseline_decoder.pt')
    torch.save(model_adv.state_dict(), 'outputs/hardened_decoder.pt')
    print("\nModels saved to outputs/")
    return all_results
# Script entry point: run the full evaluation only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    results = run_full_evaluation()

Xet Storage Details

Size:
10.9 kB
·
Xet hash:
02d8d423b3a661ec56b56cc3b4a843f728137f864a3ac1fa60ff74217fdfed29

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.