File size: 3,056 Bytes
e1aaaac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import json
import warnings
from dataclasses import asdict, dataclass, field, fields
from datetime import datetime
from pathlib import Path
from typing import Optional, Set

import torch


@dataclass
class EvaluationState:
    """Checkpointable progress of an evaluation run.

    Tracks the attacks that must be run, the attacks already run, one boolean
    flag per example and the clean accuracy. When ``path`` is set, the state
    is written there as JSON (see ``to_disk``) and can be restored with
    ``from_disk``.
    """

    # Attacks this evaluation is expected to run; read-only after construction.
    _attacks_to_run: Set[str]
    # JSON file the state is saved to; ``None`` disables saving entirely.
    path: Optional[Path] = None
    # Attacks that have already been run.
    _run_attacks: Set[str] = field(default_factory=set)
    # One boolean flag per example; robust accuracy is their mean.
    _robust_flags: Optional[torch.Tensor] = None
    # Time of the last write. The year-1 default means the first non-forced
    # save is never throttled.
    _last_saved: datetime = datetime(1, 1, 1)
    # Minimum number of seconds between two non-forced saves.
    _SAVE_TIMEOUT: int = 60
    _clean_accuracy: float = float("nan")

    def to_disk(self, force: bool = False) -> None:
        """Write the state to ``self.path`` as JSON.

        Does nothing when ``path`` is ``None``. Unless ``force`` is true, the
        write is also skipped when the previous save happened less than
        ``_SAVE_TIMEOUT`` seconds ago.

        The JSON is first written to a sibling ``<name>.tmp`` file and then
        moved over ``path``, so an interrupted write leaves the previous
        checkpoint intact instead of a truncated file.
        """
        if self.path is None:
            return
        now = datetime.now()
        if not force and (now - self._last_saved).total_seconds() < self._SAVE_TIMEOUT:
            return
        self._last_saved = now

        # Shallow copy of the fields. Unlike ``asdict`` this does not
        # deep-copy the flags tensor, which is only needed as a list below.
        d = {f.name: getattr(self, f.name) for f in fields(self)}
        d["path"] = str(self.path)
        # Sets are not JSON-serializable. Store them as sorted lists, which
        # also makes the file deterministic.
        d["_attacks_to_run"] = sorted(self._attacks_to_run)
        d["_run_attacks"] = sorted(self._run_attacks)
        d["_last_saved"] = self._last_saved.isoformat()
        if self._robust_flags is not None:
            d["_robust_flags"] = self._robust_flags.cpu().tolist()

        tmp_path = self.path.with_name(self.path.name + ".tmp")
        with tmp_path.open("w") as f:
            # default=str is kept as a best-effort fallback for values of
            # unexpected types (e.g. a non-float clean accuracy).
            json.dump(d, f, default=str)
        tmp_path.replace(self.path)

    @classmethod
    def from_disk(cls, path: Path) -> "EvaluationState":
        """Load a state previously written by ``to_disk``.

        Args:
            path: JSON state file to read (a ``str`` is accepted too).

        Returns:
            The restored state. It keeps the ``path`` recorded in the file; a
            ``UserWarning`` is emitted when that differs from ``path``.
        """
        path = Path(path)
        with path.open("r") as f:
            d = json.load(f)
        # JSON has no set type: restore the sets written as lists.
        d["_attacks_to_run"] = set(d["_attacks_to_run"])
        d["_run_attacks"] = set(d["_run_attacks"])
        # The flags are null when the state was saved before they were set.
        if d.get("_robust_flags") is not None:
            d["_robust_flags"] = torch.tensor(d["_robust_flags"], dtype=torch.bool)
        if d.get("path") is not None:
            d["path"] = Path(d["path"])
            if path != d["path"]:
                warnings.warn(
                    UserWarning(
                        "The given path is different from the one found in the state file."
                    ))
        d["_last_saved"] = datetime.fromisoformat(d["_last_saved"])
        return cls(**d)

    @property
    def robust_flags(self) -> Optional[torch.Tensor]:
        """Per-example boolean flags, or ``None`` if not set yet."""
        return self._robust_flags

    @robust_flags.setter
    def robust_flags(self, robust_flags: torch.Tensor) -> None:
        # Always persisted immediately (forced save).
        self._robust_flags = robust_flags
        self.to_disk(force=True)

    @property
    def run_attacks(self) -> Set[str]:
        """Names of the attacks already run."""
        return self._run_attacks

    def add_run_attack(self, attack: str) -> None:
        """Record ``attack`` as run; the save is subject to the save timeout."""
        self._run_attacks.add(attack)
        self.to_disk()

    @property
    def attacks_to_run(self) -> Set[str]:
        """Names of the attacks this evaluation should run."""
        return self._attacks_to_run

    @attacks_to_run.setter
    def attacks_to_run(self, _: Set[str]) -> None:
        raise ValueError("attacks_to_run cannot be set outside of the constructor")

    @property
    def clean_accuracy(self) -> float:
        """Clean accuracy; ``nan`` until it is set."""
        return self._clean_accuracy

    @clean_accuracy.setter
    def clean_accuracy(self, accuracy: float) -> None:
        # Always persisted immediately (forced save).
        self._clean_accuracy = accuracy
        self.to_disk(force=True)

    @property
    def robust_accuracy(self) -> float:
        """Mean of ``robust_flags`` as a Python float.

        Raises:
            ValueError: if ``robust_flags`` has not been set yet.

        Emits a warning when some attacks in ``attacks_to_run`` have not
        been run yet.
        """
        if self.robust_flags is None:
            raise ValueError("robust_flags is not set yet. Start the attack first.")
        if self.attacks_to_run - self.run_attacks:
            warnings.warn("You are checking `robust_accuracy` before all the attacks"
                          " have been run.")
        return self.robust_flags.float().mean().item()