# GenD-Sentinel: src/model/base.py
# Last commit c29babb ("init") by yermandy
from dataclasses import dataclass
from typing import Callable, Literal
import lightning as pl
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import wandb
from lightning import seed_everything
from lightning.pytorch.loggers import WandbLogger
from PIL import Image
from sklearn import metrics as M
from torchmetrics import CatMetric
from src import metrics, plots
from src.config import Config
from src.dataset.base import BaseDataset
from src.utils import logger
from src.utils.decorators import TryExcept
class OutputsForMetrics(nn.Module):
    """Per-stage accumulators for the outputs needed to compute epoch metrics.

    Holds three torchmetrics `CatMetric` accumulators: predicted
    probabilities, ground-truth labels and dataset sample indices. Being an
    `nn.Module`, they are registered as submodules of whatever owns this
    object.
    """

    def __init__(self):
        super().__init__()
        self.probs = CatMetric()  # predicted probabilities
        self.labels = CatMetric()  # ground-truth labels
        self.idx = CatMetric()  # indices into the dataset's file list

    def reset(self):
        """Clear every accumulator (probs, then labels, then idx)."""
        for accumulator in (self.probs, self.labels, self.idx):
            accumulator.reset()
@dataclass
class Batch:
images: torch.Tensor
labels: None | torch.Tensor
identity: None | torch.Tensor
source_uids: None | torch.Tensor
idx: None | torch.Tensor
def __getitem__(self, key):
# if batch["image"] is called, return batch.images
return getattr(self, key)
@staticmethod
def from_dict(batch: dict):
assert "image" in batch, "Batch must contain 'image' key"
return Batch(
images=batch.get("image"),
labels=batch.get("label"),
identity=batch.get("identity"),
source_uids=batch.get("source_uid"),
idx=batch.get("idx"),
)
def compute_across_videos(files: list, probs: np.ndarray, labels: np.ndarray, reduce: Literal["mean", "median"]):
    """
    Aggregate frame-level predictions into video-level predictions.

    A frame's video is the part of its path before the last "/"
    (e.g. "a/b/c/d.png" -> "a/b/c"). A path without "/" is its own video.

    Args:
        files: frame paths, aligned with the rows of `probs` and `labels`.
        probs: frame probabilities; rows are reduced along axis 0 per video.
        labels: frame labels; each video takes the label of its first frame.
        reduce: "mean" or "median" reduction over a video's frames.

    Returns:
        (video_probs, video_labels), one row per video in order of first appearance.

    Raises:
        ValueError: if `reduce` is not "mean" or "median".
    """
    reducers = {"mean": np.mean, "median": np.median}
    if reduce not in reducers:
        raise ValueError(f"Unknown reduce method: {reduce}")
    reducer = reducers[reduce]

    # Group by video: video -> [frame indices]; dicts keep first-appearance order
    video2idx: dict[str, list[int]] = {}
    for i, f in enumerate(files):
        video2idx.setdefault(f.rsplit("/", 1)[0], []).append(i)

    if not video2idx:
        # Keep the trailing dimensions so callers can still index e.g. probs[:, 1]
        probs, labels = np.asarray(probs), np.asarray(labels)
        return probs[:0], labels[:0]

    video_probs = np.array([reducer(probs[idxs], axis=0) for idxs in video2idx.values()])
    # Assume all frames of a video share the same label
    video_labels = np.array([labels[idxs[0]] for idxs in video2idx.values()])
    return video_probs, video_labels
class BaseDeepakeDetectionModel(pl.LightningModule):
    """Base LightningModule shared by the deepfake detection models.

    Handles config/hyperparameter bookkeeping, seeding, epoch-level metric
    computation at frame and video level, plotting, test-prediction export
    and checkpoint loading. Child classes must implement `get_preprocessing`
    and `test_step`. The `*_step_outputs` accumulators read here are
    presumably filled by the child's step methods (TODO confirm).
    """

    def __init__(self, config: Config, verbose: bool = False):
        super().__init__()
        self.config = config
        self.save_hyperparameters(config.model_dump())
        # Runs whose name contains "tmp" are flagged as debug runs
        self.is_debug_mode = "tmp" in config.run_name
        if verbose:
            logger.print(config)
        seed_everything(self.config.seed, workers=True, verbose=verbose)
        self._init_metrics()

    def _init_metrics(self):
        """Create one output accumulator per stage (train / val / test)."""
        self.train_step_outputs = OutputsForMetrics()
        self.val_step_outputs = OutputsForMetrics()
        self.test_step_outputs = OutputsForMetrics()

    def get_preprocessing(self) -> Callable[[Image.Image], torch.Tensor]:
        """Return the transform that maps a PIL image to a model input tensor."""
        raise NotImplementedError("get_preprocessing must be implemented in the child class")

    def get_batch(self, batch: dict) -> Batch:
        """Wrap a dataloader batch dict in a `Batch`."""
        return Batch.from_dict(batch)

    def on_train_epoch_end(self):
        """Log weight norms, learned temperatures and all training metrics."""
        if self.logger.log_dir is None:
            # TODO: figure out why logger.log_dir can be None
            # Still drop this epoch's outputs, otherwise they would be mixed into
            # the next epoch's metrics (and keep accumulating in memory).
            self.train_step_outputs.reset()
            return

        # Log weights norms (best-effort: presumably TryExcept suppresses the
        # error when the model has no `linear` head)
        with TryExcept(verbose=False):
            self.log("model/linear-W-norm", self.model.linear.weight.norm().item())
            self.log("model/linear-b-norm", self.model.linear.bias.norm().item())

        # Log learned temperature (best-effort, same as above)
        with TryExcept(verbose=False):
            self.log("model/criterion/compactness_loss/temp", self.criterion.compactness_loss.temp.item())
        with TryExcept(verbose=False):
            self.log("model/criterion/dispersion_loss/temp", self.criterion.dispersion_loss.temp.item())

        dataset = self.trainer.datamodule.train_dataset
        self.log_all_metrics(self.train_step_outputs, "train", dataset)

    def log_metrics(
        self,
        probs: np.ndarray,
        labels: np.ndarray,
        stage: Literal["train", "test", "val"],
        prefix: str,
        level: Literal["frame", "video"],
        dataset: BaseDataset,
    ):
        """
        Log scalar metrics and save plots for one set of predictions.

        Scalars are logged under `{prefix}/<metric>_{level}`. Images are saved to
        `{log_dir}/{prefix}/{level}_metrics/{stage}_<plot>.png`.
        """
        log_dir = self.logger.log_dir
        Stage = stage.capitalize()

        # Compute ROC and PR curves for every class
        fprs, tprs, roc_ths, ovr_macro_auroc = metrics.ovr_roc(labels, probs)
        precs, recs, pr_ths, ovr_macro_ap = metrics.ovr_prc(labels, probs)

        # NOTE(review): this branches on config.num_classes, not probs.shape[1];
        # probs passed from log_all_metrics may already be 2-column after
        # sources_probs_to_binary even when num_classes > 2 — confirm intent.
        if self.config.num_classes == 2:
            # Compute EER (Equal Error Rate)
            eer, eer_th = metrics.calculate_eer(labels, probs, True)
            self.log(f"{prefix}/eer_{level}", eer)
            self.log(f"{prefix}/eer_th_{level}", eer_th)

            # Compute TPR at selected FPRs, e.g., 0.1%, 1%, 5%
            selected_fprs = [0.001, 0.01, 0.05]
            tpr_at_fprs = metrics.calculate_tpr_at_fpr(labels, probs, selected_fprs)
            for target_fpr, tpr in zip(selected_fprs, tpr_at_fprs):
                self.log(f"{prefix}/TPR@FPR={target_fpr}_{level}", tpr)

            plots.plot_fpr_fnr_curve(
                fprs,
                tprs,
                roc_ths,
                title=f"{Stage} FPR vs FNR ({level}-level)",
                path=f"{log_dir}/{prefix}/{level}_metrics/{stage}_fpr_fnr_curve.png",
                eer=eer,
            )

            W1_sep_real, W1_sep_fake, W1_conf_real, W1_conf_fake = metrics.compute_wasserstein1_metrics(probs, labels)
            if W1_sep_real is not None:
                self.log(f"{prefix}/W1-sep-real_{level}", W1_sep_real)
                self.log(f"{prefix}/W1-sep-fake_{level}", W1_sep_fake)
                # A mean of Wasserstein distances
                self.log(f"{prefix}/W1-sep_{level}", (W1_sep_real + W1_sep_fake) / 2)
                self.log(f"{prefix}/W1-conf-real_{level}", W1_conf_real)
                self.log(f"{prefix}/W1-conf-fake_{level}", W1_conf_fake)
                # A mean of Wasserstein distances
                self.log(f"{prefix}/W1-conf_{level}", (W1_conf_real + W1_conf_fake) / 2)

            # Compute predictions by EER threshold
            preds = np.where(probs[:, 1] > eer_th, 1, 0)
        else:
            # Compute predictions by argmax rule
            preds = probs.argmax(1)

        # Log metrics
        self.log(f"{prefix}/auroc_{level}", ovr_macro_auroc)
        self.log(f"{prefix}/acc_{level}", M.accuracy_score(labels, preds))
        self.log(f"{prefix}/balanced_acc_{level}", M.balanced_accuracy_score(labels, preds))
        self.log(f"{prefix}/f1_score_{level}", M.f1_score(labels, preds, average="macro"))
        self.log(f"{prefix}/mAP_{level}", ovr_macro_ap)

        class_names = dataset.get_class_names()

        plots.plot_probs_distribution(
            probs,
            labels,
            class_names,
            f"{log_dir}/{prefix}/{level}_metrics/{stage}_probs_distribution.png",
        )
        plots.plot_roc_curve(
            fprs,
            tprs,
            roc_ths,
            f"{Stage} ROC ({level}-level)",
            f"{log_dir}/{prefix}/{level}_metrics/{stage}_roc_{level}.png",
            0.01,
            class_names,
        )
        plots.plot_prc_curve(
            precs,
            recs,
            pr_ths,
            f"{Stage} PR Curve ({level}-level)",
            f"{log_dir}/{prefix}/{level}_metrics/{stage}_pr_curve.png",
            0.01,
            class_names,
        )
        plots.plot_f1_curve(
            precs,
            recs,
            pr_ths,
            f"{Stage} F1 Curve ({level}-level)",
            f"{log_dir}/{prefix}/{level}_metrics/{stage}_f1_curve.png",
            0.01,
            class_names,
        )

        # Confusion matrix (raw counts and row-normalized)
        conf = M.confusion_matrix(labels, preds)
        plots.plot_confusion_matrix(
            conf,
            class_names,
            f"{Stage} Confusion Matrix ({level}-level)",
            f"{log_dir}/{prefix}/{level}_metrics/{stage}_confusion.png",
        )
        plots.plot_confusion_matrix(
            conf,
            class_names,
            f"{Stage} Confusion Matrix ({level}-level)",
            f"{log_dir}/{prefix}/{level}_metrics/{stage}_confusion_norm.png",
            True,
        )

        wandb_logger = self.get_wandb_logger()
        if wandb_logger is not None:
            wandb_logger.log_metrics(
                {
                    f"confusion/{prefix}/{stage}_{level}": wandb.plot.confusion_matrix(
                        y_true=labels,
                        preds=preds,
                        # Same class names as the local plots (was hard-coded to
                        # ["real", "fake"], which breaks for > 2 classes)
                        class_names=class_names,
                        title=f"{Stage} Confusion Matrix {level.capitalize()}",
                    )
                }
            )

    def sources_probs_to_binary(self, probs: np.ndarray) -> np.ndarray:
        """Collapse per-source probs into two columns: [real, max over fake sources].

        Column 0 is taken as the real class and columns 1: as fake generators.
        The resulting rows are not renormalized, so they need not sum to 1.
        """
        return np.stack([probs[:, 0], probs[:, 1:].max(axis=1)], 1)

    def log_all_metrics(
        self,
        outputs_for_metrics: OutputsForMetrics,
        stage: Literal["train", "test", "val"],
        dataset: BaseDataset,
    ):
        """Compute and log frame- and video-level metrics from accumulated outputs.

        Resets `outputs_for_metrics` after reading it. When `dataset.dataset2files`
        is set, metrics are additionally logged per sub-dataset under
        `{stage}/dataset/{name}`.
        """
        # Merge all predictions and labels across processes
        labels = outputs_for_metrics.labels.compute().cpu().int().numpy()
        probs = outputs_for_metrics.probs.compute().cpu().numpy()
        idx = outputs_for_metrics.idx.compute().cpu().int().numpy()
        files = [dataset.files[i] for i in idx]  # Get files in the same order as the rest
        outputs_for_metrics.reset()

        binary_first = self.config.make_binary_before_video_aggregation
        if binary_first and probs.shape[1] > 2:
            probs = self.sources_probs_to_binary(probs)

        # Compute probs and labels for videos
        video_probs, video_labels = compute_across_videos(files, probs, labels, self.config.reduce_video_predictions)

        # Convert to binary if sources are used
        if not binary_first and probs.shape[1] > 2:
            probs = self.sources_probs_to_binary(probs)
            video_probs = self.sources_probs_to_binary(video_probs)

        self.log_metrics(probs, labels, stage, stage, "frame", dataset)
        self.log_metrics(video_probs, video_labels, stage, stage, "video", dataset)

        # if trn_files / val_files / tst_files is dict, separate metrics for each dataset
        if dataset.dataset2files is None:
            return

        if not binary_first:
            # probs were already binarized above, so per-dataset video
            # aggregation effectively runs on binary probs
            logger.print_warning(
                "`make_binary_before_video_aggregation=False` is not supported when trn_files / val_files / tst_files is dict"
            )

        file2index = {f: i for i, f in enumerate(files)}
        for dataset_name, dataset_files in dataset.dataset2files.items():
            # Get files only for current dataset (np.intersect1d sorts and deduplicates)
            dataset_files = np.intersect1d(files, dataset_files)
            file_indices = [file2index[f] for f in dataset_files]
            if not file_indices:
                # Metrics (ROC, EER, ...) cannot be computed on an empty set
                logger.print_warning(f"No {stage} predictions for dataset '{dataset_name}', skipping its metrics")
                continue
            dataset_probs = probs[file_indices]
            dataset_labels = labels[file_indices]
            dataset_files = [files[i] for i in file_indices]

            self.log_metrics(
                dataset_probs,
                dataset_labels,
                stage,
                f"{stage}/dataset/{dataset_name}",
                "frame",
                dataset,
            )

            dataset_video_probs, dataset_video_labels = compute_across_videos(
                dataset_files, dataset_probs, dataset_labels, self.config.reduce_video_predictions
            )

            self.log_metrics(
                dataset_video_probs,
                dataset_video_labels,
                stage,
                f"{stage}/dataset/{dataset_name}",
                "video",
                dataset,
            )

    def custom_preprocessing(self, image: Image.Image) -> Image.Image:
        """Apply the optional center-zoom, resize and horizontal flip from the config.

        Returns the image unchanged when `config.custom_preprocessing` is None.
        """
        cfg = self.config.custom_preprocessing
        if cfg is None:
            return image

        if cfg.zoom_factor != 1.0:
            width, height = image.size
            # Calculate crop size (smaller portion of the image to simulate zoom-in)
            crop_w = int(width // cfg.zoom_factor)
            crop_h = int(height // cfg.zoom_factor)
            # Center crop coordinates
            left = (width - crop_w) // 2
            top = (height - crop_h) // 2
            cropped_img = image.crop((left, top, left + crop_w, top + crop_h))
            # Resize once, straight to the final size: image_size if configured,
            # otherwise the original size. Bilinear interpolation preserves artifacts.
            target_size = cfg.image_size if cfg.image_size is not None else (width, height)
            image = cropped_img.resize(target_size, Image.BILINEAR)
        elif cfg.image_size is not None:
            image = image.resize(cfg.image_size, Image.BILINEAR)

        if cfg.flip_left_right:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)

        return image

    def get_wandb_logger(self) -> WandbLogger | None:
        """
        Get the WandbLogger instance from the current loggers.
        Returns None if no WandbLogger is found.
        """
        return next((lg for lg in self.loggers if isinstance(lg, WandbLogger)), None)

    def on_test_start(self):
        logger.print(f"[blue]Logs: {self.logger.log_dir}")
        self.log("num_test_files", len(self.trainer.datamodule.test_dataset))

    def test_step(self, batch, batch_idx):
        raise NotImplementedError("test_step must be implemented in the child class")

    def on_test_epoch_end(self):
        """Save per-file test predictions to CSV and log all test metrics."""
        if self.logger.log_dir is None:
            # TODO: figure out why logger.log_dir can be None
            # Reset anyway so a later test run does not see stale outputs.
            self.test_step_outputs.reset()
            return

        # Concatenate all predictions and labels
        probs = self.test_step_outputs.probs.compute().cpu().numpy()
        labels = self.test_step_outputs.labels.compute().cpu().int().numpy()
        idx = self.test_step_outputs.idx.compute().cpu().int().numpy()

        dataset = self.trainer.datamodule.test_dataset
        files = [dataset.files[i] for i in idx]

        # compute() already gathered results across processes, so one rank writing
        # the CSV is enough and avoids concurrent writes to the same file.
        if self.trainer.is_global_zero:
            # probs is a 2D array of shape (num_samples, num_classes)
            prob_columns = {f"prob_class_{i}": np.round(probs[:, i], 4) for i in range(probs.shape[1])}
            table = pd.DataFrame({"files": files, "labels": labels, **prob_columns})
            table.to_csv(f"{self.logger.log_dir}/test_predictions.csv", index=False, float_format="%.4f")

        self.log_all_metrics(self.test_step_outputs, "test", dataset)

    def load_checkpoint(self, checkpoint: str):
        """Non-strictly load `checkpoint["state_dict"]` (if a path is given) and report key matches."""
        if checkpoint:
            state_dict = torch.load(checkpoint, map_location="cpu", weights_only=True)["state_dict"]
            incompatible_keys = self.load_state_dict(state_dict, strict=False)
            self.print_checkpoint_keys(incompatible_keys)

    def print_checkpoint_keys(self, incompatible_keys):
        """Print how a checkpoint's keys matched this model's state dict.

        Args:
            incompatible_keys: result of `load_state_dict(..., strict=False)`,
                exposing `missing_keys` and `unexpected_keys`.
        """
        missing_keys = set(incompatible_keys.missing_keys)

        logger.print("\n[blue bold]Keys in checkpoint:")
        logger.print("[red bold]- Missing")
        logger.print("[yellow bold]? Unexpected")
        logger.print("[green bold]+ Matched\n")

        for key in self.state_dict().keys():
            if key in missing_keys:
                logger.print(f"[red]- {key}")
            else:
                logger.print(f"[green]+ {key}")

        # Unexpected keys exist only in the checkpoint, so they never appear in
        # self.state_dict(); list them separately (same color as the legend).
        for key in incompatible_keys.unexpected_keys:
            logger.print(f"[yellow]? {key}")