deepfake-server / src /train /adversarial.py
DevQueen's picture
Sync from GitHub via hub-sync
1dc2504 verified
Raw
History Blame Contribute Delete
1.55 kB
from __future__ import annotations
import torch
import torch.nn.functional as F
def fgsm_attack(model, frames, blink, labels, eps: float):
    """Craft single-step FGSM adversarial examples for ``frames``.

    A detached copy of ``frames`` is moved by ``eps`` in the direction of the
    sign of the cross-entropy gradient with respect to the input, then clamped
    to ``[0, 1]``.

    Args:
        model: Callable invoked as ``model(frames, blink)``; the first element
            of its return value is used as the classification logits.
        frames: Input tensor to perturb. It is not modified. Presumably holds
            values in ``[0, 1]`` given the final clamp -- confirm with caller.
        blink: Auxiliary input forwarded to ``model`` unchanged.
        labels: Targets passed to ``F.cross_entropy``.
        eps: Size of the signed-gradient step.

    Returns:
        A detached tensor with the same shape as ``frames``, clamped to
        ``[0, 1]``.
    """
    frames_adv = frames.detach().clone().requires_grad_(True)
    logits, _ = model(frames_adv, blink)
    loss = F.cross_entropy(logits, labels)
    # Differentiate with respect to the input only. ``loss.backward()`` would
    # also accumulate into every model parameter's ``.grad``, silently
    # polluting the next optimiser step of the surrounding training loop.
    (input_grad,) = torch.autograd.grad(loss, frames_adv)
    perturbed = frames_adv + eps * input_grad.sign()
    return perturbed.clamp(0.0, 1.0).detach()
def pgd_attack(model, frames, blink, labels, eps: float, alpha: float, steps: int):
    """Craft adversarial examples with iterative projected gradient ascent.

    Starting from a copy of ``frames``, each iteration takes a signed-gradient
    step of size ``alpha`` on the cross-entropy loss, projects the total
    perturbation back into the ``[-eps, eps]`` box around the original input,
    and clamps the result to ``[0, 1]``.

    Args:
        model: Callable invoked as ``model(frames, blink)``; the first element
            of its return value is used as the classification logits.
        frames: Input tensor to perturb. It is not modified. Presumably holds
            values in ``[0, 1]`` given the clamp -- confirm with caller.
        blink: Auxiliary input forwarded to ``model`` unchanged.
        labels: Targets passed to ``F.cross_entropy``.
        eps: Maximum absolute per-element deviation from ``frames``.
        alpha: Step size of each iteration.
        steps: Number of iterations; with ``0`` an unperturbed copy of
            ``frames`` is returned.

    Returns:
        A detached tensor with the same shape as ``frames``.
    """
    ori = frames.detach()
    adv = ori.clone()
    for _ in range(steps):
        adv.requires_grad_(True)
        logits, _ = model(adv, blink)
        loss = F.cross_entropy(logits, labels)
        # Differentiate with respect to the current iterate only.
        # ``loss.backward()`` would add to every model parameter's ``.grad``
        # on each step, corrupting the caller's subsequent optimiser update.
        (input_grad,) = torch.autograd.grad(loss, adv)
        stepped = adv.detach() + alpha * input_grad.sign()
        # Project back into the eps-ball around the original input, then into
        # the valid value range.
        delta = torch.clamp(stepped - ori, min=-eps, max=eps)
        adv = torch.clamp(ori + delta, 0.0, 1.0).detach()
    return adv
def attention_consistency_loss(clean_feat: torch.Tensor, adv_feat: torch.Tensor) -> torch.Tensor:
    """Return the mean squared error between clean and adversarial features.

    Args:
        clean_feat: Features computed from the unperturbed input.
        adv_feat: Features computed from the adversarial input.

    Returns:
        Scalar tensor: the element-wise squared difference averaged over all
        elements.
    """
    consistency = F.mse_loss(input=clean_feat, target=adv_feat, reduction="mean")
    return consistency
def blink_timing_regularizer(
    blink_seq: torch.Tensor, fps: float, min_seconds: float, max_seconds: float
) -> torch.Tensor:
    """Penalise sequences whose below-mean sample count leaves a frame range.

    For every row of ``blink_seq`` the samples strictly below that row's mean
    are counted along dimension 1. The penalty is the distance by which this
    count falls short of ``min_seconds * fps`` or exceeds
    ``max_seconds * fps``, averaged over rows.

    NOTE(review): the count is the total number of below-mean samples in the
    row, not the length of each contiguous run -- confirm this is the intended
    notion of "blink duration".
    NOTE(review): the hard ``<`` threshold means no gradient reaches
    ``blink_seq`` through this term.

    Args:
        blink_seq: Tensor reduced along dimension 1; presumably shaped
            (batch, time) -- confirm with caller.
        fps: Frame rate used to convert seconds into frames.
        min_seconds: Lower bound of the accepted duration, in seconds.
        max_seconds: Upper bound of the accepted duration, in seconds.

    Returns:
        Scalar tensor holding the mean penalty.
    """
    lower_frames = min_seconds * fps
    upper_frames = max_seconds * fps

    row_mean = blink_seq.mean(dim=1, keepdim=True)
    below_mean_count = blink_seq.lt(row_mean).float().sum(dim=1)

    shortfall = F.relu(lower_frames - below_mean_count)
    excess = F.relu(below_mean_count - upper_frames)
    penalty = shortfall + excess
    return penalty.mean()