deepfake-server / src /train /train.py
DevQueen's picture
Sync from GitHub via hub-sync
1dc2504 verified
Raw
History Blame Contribute Delete
5.7 kB
from __future__ import annotations
import argparse
from pathlib import Path
from typing import Dict
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm import tqdm
from src.data.dataset import EyeSequenceDataset
from src.models.lrcn_vit import LRCNViT
from src.train.adversarial import (
attention_consistency_loss,
blink_timing_regularizer,
fgsm_attack,
pgd_attack,
)
from src.utils import ensure_dir, load_yaml, set_seed
def build_model(cfg: Dict) -> LRCNViT:
    """Instantiate the LRCN-ViT classifier described by a merged config.

    Architecture hyper-parameters come from ``cfg["model"]`` and the input
    resolution from ``cfg["data"]["image_size"]``. ``use_blink_head`` is
    optional and defaults to ``True``; every other key is required, so a
    missing one raises ``KeyError``.
    """
    arch, data = cfg["model"], cfg["data"]
    # Keyword order mirrors the lookup order so a missing key surfaces the
    # same KeyError it always has.
    kwargs = dict(
        backbone_name=arch["backbone"],
        backbone_pretrained=arch["backbone_pretrained"],
        lstm_hidden=arch["lstm_hidden"],
        lstm_layers=arch["lstm_layers"],
        dropout=arch["dropout"],
        num_classes=arch["num_classes"],
        use_blink_head=arch.get("use_blink_head", True),
        image_size=data["image_size"],
    )
    return LRCNViT(**kwargs)
def _merge_section_into(dst: Dict, src: Dict) -> None:
    """Overlay ``src`` onto ``dst`` one level deep, ignoring the ``inherits`` key."""
    for key, value in src.items():
        if key == "inherits":
            continue
        existing = dst.get(key)
        if isinstance(value, dict) and isinstance(existing, dict):
            # Same-named sections are merged key by key; later files win.
            existing.update(value)
        elif isinstance(value, dict):
            # Copy so later updates never mutate a mapping returned by load_yaml.
            # This also replaces a non-dict value instead of crashing on .update().
            dst[key] = dict(value)
        else:
            dst[key] = value


def _resolve_config(config_path: str, chain: tuple) -> Dict:
    """Load ``config_path`` and recursively apply its ``inherits`` parents.

    ``chain`` holds the resolved paths currently being loaded and is used to
    detect inheritance cycles.
    """
    key = str(Path(config_path).resolve())
    if key in chain:
        raise ValueError(
            "Circular config inheritance: " + " -> ".join(chain + (key,))
        )
    cfg = load_yaml(config_path)
    if "inherits" not in cfg:
        return cfg
    parents = cfg["inherits"]
    # Accept a single path as well as a list; iterating a bare string would
    # otherwise walk it character by character.
    if isinstance(parents, (str, Path)):
        parents = [parents]
    merged: Dict = {}
    for parent_path in parents:
        # Parents are resolved recursively, so they may themselves inherit.
        _merge_section_into(merged, _resolve_config(parent_path, chain + (key,)))
    _merge_section_into(merged, cfg)
    return merged


def merge_config(config_path: str) -> Dict:
    """Load a YAML config and resolve its ``inherits`` chain.

    If the file has no ``inherits`` key, the object returned by ``load_yaml``
    is returned unchanged. Otherwise each path listed under ``inherits`` (a
    single path string is also accepted) is resolved recursively and overlaid
    in order, followed by the file's own keys. Mapping-valued top-level
    sections are merged one level deep, with later keys winning. Any other
    value, including a section whose type differs between files, is replaced
    outright. The ``inherits`` key itself does not appear in the merged result.

    Parent paths are handed to ``load_yaml`` exactly as written. They are not
    re-based onto the child file's directory.

    Raises:
        ValueError: if the inheritance chain contains a cycle.
    """
    return _resolve_config(config_path, ())
def _select_device() -> str:
    """Pick the first available device: CUDA, then Apple MPS, else CPU."""
    if torch.cuda.is_available():
        return "cuda"
    # getattr guards against torch builds that lack the ``mps`` backend module.
    if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def _training_loss(
    model: torch.nn.Module,
    frames: torch.Tensor,
    blink: torch.Tensor,
    labels: torch.Tensor,
    cfg: Dict,
) -> torch.Tensor:
    """Return the weighted training objective for one batch.

    The objective always includes ``adv.clean_weight * CE(clean logits)``.

    When ``adv.enabled`` is true, adversarial frames are generated and
    ``adv.adv_weight * CE(adversarial logits)`` is added. FGSM is used if
    ``adv.attack`` is ``"fgsm"`` (case-insensitive); PGD is used otherwise.

    The ``aat.*`` terms (attention consistency and blink timing) need the
    adversarial forward pass. They are therefore only applied when
    ``adv.enabled`` is also true.
    """
    adv_cfg = cfg["adv"]
    logits_clean, aux_clean = model(frames, blink)
    loss = adv_cfg["clean_weight"] * F.cross_entropy(logits_clean, labels)
    if not adv_cfg["enabled"]:
        return loss

    if adv_cfg["attack"].lower() == "fgsm":
        adv_frames = fgsm_attack(model, frames, blink, labels, eps=adv_cfg["fgsm_eps"])
    else:
        adv_frames = pgd_attack(
            model,
            frames,
            blink,
            labels,
            eps=adv_cfg["eps"],
            alpha=adv_cfg["alpha"],
            steps=adv_cfg["steps"],
        )
    logits_adv, aux_adv = model(adv_frames, blink)
    loss = loss + adv_cfg["adv_weight"] * F.cross_entropy(logits_adv, labels)

    aat_cfg = cfg["aat"]
    if aat_cfg["enabled"]:
        attn_loss = attention_consistency_loss(
            aux_clean["temporal_feat"], aux_adv["temporal_feat"]
        )
        # NOTE(review): this regularizer receives only the input blink tensor
        # and config scalars, with no model outputs. It presumably contributes
        # no parameter gradient. Confirm whether a predicted blink signal was
        # intended here.
        blink_loss = blink_timing_regularizer(
            blink,
            fps=aat_cfg["fps"],
            min_seconds=aat_cfg["blink_min_seconds"],
            max_seconds=aat_cfg["blink_max_seconds"],
        )
        loss = (
            loss
            + aat_cfg["attention_consistency_weight"] * attn_loss
            + aat_cfg["blink_timing_weight"] * blink_loss
        )
    return loss


def _train_one_epoch(
    model: torch.nn.Module,
    loader: DataLoader,
    optim: AdamW,
    cfg: Dict,
    device: str,
    epoch: int,
) -> None:
    """Run one optimisation pass over ``loader`` with gradient-norm clipping."""
    model.train()
    pbar = tqdm(loader, desc=f"epoch {epoch + 1}")
    for batch in pbar:
        frames = batch["frames"].to(device)
        blink = batch["blink"].to(device)
        labels = batch["label"].to(device)
        loss = _training_loss(model, frames, blink, labels, cfg)
        # Zero only after the attack helpers have run. Any parameter gradients
        # they may have accumulated are then discarded before the real
        # backward pass.
        optim.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg["train"]["grad_clip"])
        optim.step()
        pbar.set_postfix({"loss": f"{loss.item():.4f}"})


def main() -> None:
    """Train LRCN-ViT from the YAML config given by ``--config``.

    After every epoch the model is scored on the ``val`` split. The state dict
    with the best validation accuracy so far is saved to
    ``<project.output_dir>/best.pt``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    args = parser.parse_args()

    cfg = merge_config(args.config)
    set_seed(cfg["project"]["seed"])
    # ensure_dir's return type isn't visible here. Wrapping in Path makes the
    # ``/`` join below work whether it returns str or Path.
    out_dir = Path(ensure_dir(cfg["project"]["output_dir"]))

    device = _select_device()
    print(f"Using device: {device}")

    data_cfg = cfg["data"]
    metadata_csv = data_cfg.get("metadata_csv", "data/metadata.csv")
    train_ds = EyeSequenceDataset(metadata_csv, split="train")
    val_ds = EyeSequenceDataset(metadata_csv, split="val")
    train_loader = DataLoader(
        train_ds,
        batch_size=data_cfg["batch_size"],
        shuffle=True,
        num_workers=data_cfg["num_workers"],
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=data_cfg["batch_size"],
        shuffle=False,
        num_workers=data_cfg["num_workers"],
    )

    model = build_model(cfg).to(device)
    optim = AdamW(
        model.parameters(),
        lr=cfg["train"]["lr"],
        weight_decay=cfg["train"]["weight_decay"],
    )

    # Start below any achievable accuracy so the first epoch always writes
    # best.pt. The previous 0.0 never saved anything if accuracy stayed at 0.
    best_val = float("-inf")
    for epoch in range(cfg["train"]["epochs"]):
        _train_one_epoch(model, train_loader, optim, cfg, device, epoch)
        val_acc = evaluate_simple(model, val_loader, device)
        if val_acc > best_val:
            best_val = val_acc
            torch.save(model.state_dict(), out_dir / "best.pt")
        print(f"Epoch {epoch + 1}: val_acc={val_acc:.4f}, best={best_val:.4f}")
@torch.no_grad()
def evaluate_simple(model, loader, device: str) -> float:
    """Return top-1 accuracy of ``model`` over every batch in ``loader``.

    Each batch must provide ``"frames"``, ``"blink"`` and ``"label"`` tensors.
    ``model(frames, blink)`` must return a ``(logits, aux)`` pair, and the
    prediction is the argmax of ``logits`` along dim 1.

    Gradients are disabled for the whole pass. The model runs in eval mode,
    and its previous train/eval mode is restored afterwards so callers don't
    inherit a silently changed mode.

    Returns 0.0 for an empty loader rather than dividing by zero.
    """
    was_training = model.training
    model.eval()
    total = 0
    correct = 0
    try:
        for batch in loader:
            frames = batch["frames"].to(device)
            blink = batch["blink"].to(device)
            labels = batch["label"].to(device)
            logits, _ = model(frames, blink)
            pred = logits.argmax(dim=1)
            total += labels.size(0)
            correct += (pred == labels).sum().item()
    finally:
        model.train(was_training)
    return correct / max(total, 1)
# Script entry point: run training only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()