Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import torch | |
| import torch.nn.functional as F | |
def fgsm_attack(model, frames, blink, labels, eps: float):
    """Return a one-step FGSM adversarial copy of ``frames``.

    The cross-entropy loss of ``model(frames, blink)`` against ``labels`` is
    differentiated with respect to ``frames`` only. The input is then moved by
    ``eps`` in the direction of the gradient sign and clamped to ``[0, 1]``.

    The gradient is obtained with ``torch.autograd.grad`` instead of
    ``loss.backward()``. As a result, the model parameters' ``.grad`` buffers
    (and ``blink.grad``) are left untouched, and calling the attack inside a
    training loop cannot leak attack gradients into the optimizer step. The
    forward/backward pass runs under ``torch.enable_grad()``, so the attack
    also works when the caller is inside ``torch.no_grad()``.

    Args:
        model: callable returning a ``(logits, _)`` pair for ``(frames, blink)``.
        frames: input tensor to perturb; it is not modified in place.
        blink: auxiliary input passed through to ``model`` unchanged.
        labels: class targets accepted by ``F.cross_entropy`` for ``logits``.
        eps: step size (L-infinity magnitude of the perturbation).

    Returns:
        A detached tensor with the same shape as ``frames``, clamped to [0, 1].
    """
    frames_adv = frames.detach().clone().requires_grad_(True)
    with torch.enable_grad():
        logits, _ = model(frames_adv, blink)
        loss = F.cross_entropy(logits, labels)
        # Gradient w.r.t. the input only; parameter .grad fields stay untouched.
        (grad,) = torch.autograd.grad(loss, frames_adv)
    perturbed = frames_adv.detach() + eps * grad.sign()
    # Hard-coded [0, 1] range: presumably frames are normalized images — confirm.
    return perturbed.clamp(0.0, 1.0).detach()
def pgd_attack(model, frames, blink, labels, eps: float, alpha: float, steps: int):
    """Return an L-infinity PGD adversarial copy of ``frames``.

    The attack starts from the clean input (no random start). It takes
    ``steps`` signed-gradient ascent steps of size ``alpha`` on the
    cross-entropy loss. After each step the total perturbation is projected
    back into ``[-eps, eps]`` and the result is clamped to ``[0, 1]``.

    Gradients are computed with ``torch.autograd.grad`` with respect to the
    adversarial input only, so the model parameters' ``.grad`` buffers (and
    ``blink.grad``) are not accumulated across the ``steps`` iterations. The
    loop runs under ``torch.enable_grad()``, so it also works when called
    inside ``torch.no_grad()``.

    Args:
        model: callable returning a ``(logits, _)`` pair for ``(frames, blink)``.
        frames: clean input tensor; it is not modified in place.
        blink: auxiliary input passed through to ``model`` unchanged.
        labels: class targets accepted by ``F.cross_entropy`` for ``logits``.
        eps: radius of the L-infinity ball around ``frames``.
        alpha: per-step size.
        steps: number of ascent steps; ``0`` returns a detached copy of ``frames``.

    Returns:
        A detached tensor with the same shape as ``frames``.
    """
    ori = frames.detach()
    adv = ori.clone()
    with torch.enable_grad():
        for _ in range(steps):
            adv.requires_grad_(True)
            logits, _ = model(adv, blink)
            loss = F.cross_entropy(logits, labels)
            # Gradient w.r.t. the input only; parameter .grad fields stay untouched.
            (grad,) = torch.autograd.grad(loss, adv)
            adv = adv.detach() + alpha * grad.sign()
            # Project the accumulated perturbation back into the eps-ball.
            delta = torch.clamp(adv - ori, min=-eps, max=eps)
            # Hard-coded [0, 1] range: presumably normalized frames — confirm.
            adv = torch.clamp(ori + delta, 0.0, 1.0).detach()
    return adv
def attention_consistency_loss(clean_feat: torch.Tensor, adv_feat: torch.Tensor) -> torch.Tensor:
    """Return the mean squared difference between clean and adversarial features.

    This is a thin wrapper over ``F.mse_loss`` with the default ``"mean"``
    reduction. Neither argument is detached, so gradients flow into both.
    """
    mean_sq_error = F.mse_loss(input=clean_feat, target=adv_feat, reduction="mean")
    return mean_sq_error
def blink_timing_regularizer(
    blink_seq: torch.Tensor,
    fps: float,
    min_seconds: float,
    max_seconds: float,
    *,
    temperature: float = 0.0,
) -> torch.Tensor:
    """Penalize per-sequence blink frame counts outside a physiologic range.

    Each sequence is thresholded at its own mean along dim 1, and frames whose
    value is below that mean count as blink frames. Presumably ``blink_seq`` is
    an eye-openness signal where lower means more closed — confirm with callers.
    The resulting count is compared against
    ``[min_seconds * fps, max_seconds * fps]``. The hinge distance outside that
    range is then averaged over all remaining elements.

    Note: the count is the total number of below-mean frames in the sequence,
    not the length of any single contiguous blink.

    Args:
        blink_seq: signal whose dim 1 is treated as time (presumably
            ``(batch, frames)`` or ``(batch, frames, ...)`` — TODO confirm).
        fps: frames per second, used to convert seconds to frame counts.
        min_seconds: lower bound of the allowed range, in seconds.
        max_seconds: upper bound of the allowed range, in seconds.
        temperature: if ``<= 0`` (the default), use the original hard
            threshold ``blink_seq < mean``. A comparison has no gradient, so
            in this mode the penalty does not backpropagate into
            ``blink_seq``. If ``> 0``, use the soft surrogate
            ``sigmoid((mean - blink_seq) / temperature)`` so the penalty is
            differentiable; smaller values approach the hard threshold.

    Returns:
        A scalar tensor: the mean out-of-range penalty, in frames.
    """
    min_frames = min_seconds * fps
    max_frames = max_seconds * fps
    seq_mean = blink_seq.mean(dim=1, keepdim=True)
    if temperature > 0:
        # Differentiable relaxation of the step function below.
        blink_weight = torch.sigmoid((seq_mean - blink_seq) / temperature)
    else:
        # Bool comparison -> float: no autograd path back to blink_seq.
        blink_weight = (blink_seq < seq_mean).float()
    durations = blink_weight.sum(dim=1)
    low_pen = torch.relu(min_frames - durations)
    high_pen = torch.relu(durations - max_frames)
    return (low_pen + high_pen).mean()