File size: 1,548 Bytes
1dc2504
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from __future__ import annotations

import torch
import torch.nn.functional as F


def fgsm_attack(model, frames, blink, labels, eps: float):
    """Craft a one-step FGSM adversarial version of ``frames``.

    Takes a single step of size ``eps`` along the sign of the gradient of the
    cross-entropy loss with respect to ``frames``, then clamps the result to
    ``[0, 1]``. ``blink`` is forwarded to the model unchanged and is not
    perturbed. The model's train/eval mode is left as the caller set it.

    The input gradient is computed with ``torch.autograd.grad`` instead of
    ``loss.backward()`` so the attack does not accumulate gradients into the
    model parameters' ``.grad`` fields. Otherwise those gradients would leak
    into the caller's next optimizer step unless it zeroed them first.

    Args:
        model: Callable invoked as ``model(frames, blink)`` that returns a
            2-tuple whose first element is the class logits.
        frames: Input tensor to perturb. Values are presumably in ``[0, 1]``
            given the final clamp; confirm against the data pipeline.
        blink: Auxiliary model input, passed through unchanged.
        labels: Target class indices for ``F.cross_entropy``.
        eps: Step size (L-infinity perturbation budget).

    Returns:
        A tensor shaped like ``frames`` that does not require grad.
    """
    # Fresh leaf tensor so the gradient is taken w.r.t. a copy, not the caller's input.
    frames_adv = frames.detach().clone().requires_grad_(True)
    logits, _ = model(frames_adv, blink)
    loss = F.cross_entropy(logits, labels)
    # Only d(loss)/d(frames_adv) is computed; parameter .grad fields stay untouched.
    (grad,) = torch.autograd.grad(loss, frames_adv)
    perturbed = frames_adv.detach() + eps * grad.sign()
    return perturbed.clamp(0.0, 1.0)


def pgd_attack(model, frames, blink, labels, eps: float, alpha: float, steps: int):
    """Craft an adversarial version of ``frames`` with L-infinity PGD.

    Starts from the clean input (no random start) and repeats ``steps`` times:
    step by ``alpha`` along the sign of the cross-entropy gradient, project
    back into the ``eps``-ball around the clean input, and clamp to
    ``[0, 1]``. ``blink`` is forwarded to the model unchanged and is not
    perturbed.

    Input gradients come from ``torch.autograd.grad`` rather than
    ``loss.backward()``, so running the attack does not accumulate gradients
    into the model parameters' ``.grad`` fields.

    Args:
        model: Callable invoked as ``model(frames, blink)`` that returns a
            2-tuple whose first element is the class logits.
        frames: Clean input tensor. Values are presumably in ``[0, 1]`` given
            the clamp; confirm against the data pipeline.
        blink: Auxiliary model input, passed through unchanged.
        labels: Target class indices for ``F.cross_entropy``.
        eps: Radius of the L-infinity ball around ``frames``.
        alpha: Per-step size.
        steps: Number of PGD iterations. ``0`` returns a detached copy of ``frames``.

    Returns:
        A tensor shaped like ``frames`` that does not require grad.
    """
    ori = frames.detach()
    adv = ori.clone()
    for _ in range(steps):
        adv.requires_grad_(True)
        logits, _ = model(adv, blink)
        loss = F.cross_entropy(logits, labels)
        (grad,) = torch.autograd.grad(loss, adv)
        # The update/projection needs no graph; each iterate is a fresh leaf.
        with torch.no_grad():
            adv = adv + alpha * grad.sign()
            delta = torch.clamp(adv - ori, min=-eps, max=eps)
            adv = torch.clamp(ori + delta, 0.0, 1.0)
    return adv.detach()


def attention_consistency_loss(clean_feat: torch.Tensor, adv_feat: torch.Tensor) -> torch.Tensor:
    """Return the mean squared difference between clean and adversarial features.

    Neither argument is detached, so gradients flow into both.
    """
    mean_sq_err = F.mse_loss(clean_feat, adv_feat, reduction="mean")
    return mean_sq_err


def blink_timing_regularizer(
    blink_seq: torch.Tensor,
    fps: float,
    min_seconds: float,
    max_seconds: float,
    *,
    temperature: float | None = None,
) -> torch.Tensor:
    """Penalize blink durations outside a physiologic range.

    A frame counts as "blinking" when its value is below the mean of its own
    sequence, with the mean taken along ``dim=1``. The per-sequence count of
    such frames is compared with ``[min_seconds * fps, max_seconds * fps]``.
    The hinge distance outside that window is then averaged with ``.mean()``.

    Note: the count is the *total* number of below-mean frames in a sequence.
    It is not the length of any single blink.

    Args:
        blink_seq: Blink signal reduced along ``dim=1``. It is presumably
            (batch, time), with low values meaning eyes closed; TODO confirm
            against the caller.
        fps: Frames per second, used to convert seconds to frames.
        min_seconds: Shortest acceptable duration, in seconds.
        max_seconds: Longest acceptable duration, in seconds.
        temperature: ``None`` (default) selects the original hard threshold.
            That boolean comparison carries no gradient, so in this mode the
            result does not backpropagate into ``blink_seq``. A positive value
            instead uses ``sigmoid((mean - x) / temperature)``. This is a
            differentiable soft count that approaches the hard count as
            ``temperature`` goes to 0.

    Returns:
        0-d tensor holding the mean out-of-range penalty, measured in frames.

    Raises:
        ValueError: If ``temperature`` is given and is not positive.
    """
    min_frames = min_seconds * fps
    max_frames = max_seconds * fps
    threshold = blink_seq.mean(dim=1, keepdim=True)
    if temperature is None:
        # Exact count, but the bool->float cast has no grad_fn.
        blink_frames = (blink_seq < threshold).float()
    else:
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        blink_frames = torch.sigmoid((threshold - blink_seq) / temperature)
    durations = blink_frames.sum(dim=1)
    low_pen = torch.relu(min_frames - durations)
    high_pen = torch.relu(durations - max_frames)
    return (low_pen + high_pen).mean()