from dataclasses import dataclass
from typing import Callable, Literal

import lightning as pl
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import wandb
from lightning import seed_everything
from lightning.pytorch.loggers import WandbLogger
from PIL import Image
from sklearn import metrics as M
from torchmetrics import CatMetric

from src import metrics, plots
from src.config import Config
from src.dataset.base import BaseDataset
from src.utils import logger
from src.utils.decorators import TryExcept

class OutputsForMetrics(nn.Module):
    """Accumulates per-step probs, labels, and sample indices for epoch-end metrics."""

    def __init__(self):
        super().__init__()
        self.probs = CatMetric()
        self.labels = CatMetric()
        self.idx = CatMetric()

    def reset(self):
        self.probs.reset()
        self.labels.reset()
        self.idx.reset()
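
# Note: CatMetric is a torchmetrics aggregation metric that concatenates every
# value passed to `update`; `compute` returns the concatenated tensor and, under
# DDP, gathers values across processes, which makes the epoch-end metrics below
# global rather than per-rank. Illustrative sketch (not part of the training loop):
#   m = CatMetric()
#   m.update(torch.tensor([0.1, 0.9]))
#   m.update(torch.tensor([0.4]))
#   m.compute()  # -> tensor([0.1000, 0.9000, 0.4000])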

@dataclass
class Batch:
    images: torch.Tensor
    labels: None | torch.Tensor
    identity: None | torch.Tensor
    source_uids: None | torch.Tensor
    idx: None | torch.Tensor

    def __getitem__(self, key):
        # Allow dict-style access by attribute name, e.g. batch["images"] -> batch.images
        return getattr(self, key)

    @staticmethod
    def from_dict(batch: dict) -> "Batch":
        assert "image" in batch, "Batch must contain 'image' key"
        return Batch(
            images=batch.get("image"),
            labels=batch.get("label"),
            identity=batch.get("identity"),
            source_uids=batch.get("source_uid"),
            idx=batch.get("idx"),
        )
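
# Illustrative usage (shapes are hypothetical); keys other than "image" are
# optional and default to None:
#   batch = Batch.from_dict({"image": torch.rand(4, 3, 224, 224),
#                            "label": torch.tensor([0, 1, 0, 1])})
#   batch["images"].shape  # -> torch.Size([4, 3, 224, 224])
#   batch.identity         # -> None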

def compute_across_videos(files: list, probs: np.ndarray, labels: np.ndarray, reduce: Literal["mean", "median"]):
    """
    Aggregate frame-level probs and labels into one entry per video.
    """
    # Take everything before the last "/" (the video directory)
    # For example: a/b/c/frame.jpg -> a/b/c
    videos = [f.rsplit("/", 1)[0] for f in files]
    # Group by video: video -> [frame indices]
    video2idx = {v: [] for v in videos}
    for i, v in enumerate(videos):
        video2idx[v].append(i)
    # Reduce frame probs to a single prob vector per video
    video2probs = {v: [] for v in videos}
    video2labels = {v: [] for v in videos}
    for v, idxs in video2idx.items():
        if reduce == "mean":
            video2probs[v] = np.mean(probs[idxs], axis=0)
        elif reduce == "median":
            video2probs[v] = np.median(probs[idxs], axis=0)
        else:
            raise ValueError(f"Unknown reduce method: {reduce}")
        video2labels[v] = labels[idxs[0]]  # Assume all frames of a video share the label
    video_probs = np.array(list(video2probs.values()))
    video_labels = np.array(list(video2labels.values()))
    return video_probs, video_labels
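
# Worked example (illustrative file names): two frames of "vidA" are averaged
# into a single video-level prediction, "vidB" has one frame:
#   files  = ["vidA/f0.jpg", "vidA/f1.jpg", "vidB/f0.jpg"]
#   probs  = np.array([[0.9, 0.1], [0.7, 0.3], [0.2, 0.8]])
#   labels = np.array([0, 0, 1])
#   compute_across_videos(files, probs, labels, "mean")
#   -> (array([[0.8, 0.2], [0.2, 0.8]]), array([0, 1]))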

class BaseDeepakeDetectionModel(pl.LightningModule):
    def __init__(self, config: Config, verbose: bool = False):
        super().__init__()
        self.config = config
        self.save_hyperparameters(config.model_dump())
        self.is_debug_mode = "tmp" in config.run_name
        if verbose:
            logger.print(config)
        seed_everything(self.config.seed, workers=True, verbose=verbose)
        self._init_metrics()

    def _init_metrics(self):
        self.train_step_outputs = OutputsForMetrics()
        self.val_step_outputs = OutputsForMetrics()
        self.test_step_outputs = OutputsForMetrics()

    def get_preprocessing(self) -> Callable[[Image.Image], torch.Tensor]:
        raise NotImplementedError("get_preprocessing must be implemented in the child class")

    def get_batch(self, batch: dict) -> Batch:
        return Batch.from_dict(batch)

    def on_train_epoch_end(self):
        if self.logger.log_dir is None:
            # TODO: figure out why logger.log_dir can be None
            return
        # Log weight norms
        with TryExcept(verbose=False):
            self.log("model/linear-W-norm", self.model.linear.weight.norm().item())
            self.log("model/linear-b-norm", self.model.linear.bias.norm().item())
        # Log learned temperatures
        with TryExcept(verbose=False):
            self.log("model/criterion/compactness_loss/temp", self.criterion.compactness_loss.temp.item())
        with TryExcept(verbose=False):
            self.log("model/criterion/dispersion_loss/temp", self.criterion.dispersion_loss.temp.item())
        dataset = self.trainer.datamodule.train_dataset
        self.log_all_metrics(self.train_step_outputs, "train", dataset)

    def log_metrics(
        self,
        probs: np.ndarray,
        labels: np.ndarray,
        stage: Literal["train", "test", "val"],
        prefix: str,
        level: Literal["frame", "video"],
        dataset: BaseDataset,
    ):
        """
        Images are saved to `log_dir / prefix / {level}_metrics / metric.png`.
        """
        log_dir = self.logger.log_dir
        Stage = stage.capitalize()
        # Compute ROC and PR curves for every class (one-vs-rest)
        fprs, tprs, roc_ths, ovr_macro_auroc = metrics.ovr_roc(labels, probs)
        precs, recs, pr_ths, ovr_macro_ap = metrics.ovr_prc(labels, probs)
        if self.config.num_classes == 2:
            # Compute EER (Equal Error Rate)
            eer, eer_th = metrics.calculate_eer(labels, probs, True)
            self.log(f"{prefix}/eer_{level}", eer)
            self.log(f"{prefix}/eer_th_{level}", eer_th)
            # Compute TPR at selected FPRs: 0.1%, 1%, 5%
            selected_fprs = [0.001, 0.01, 0.05]
            tpr_at_fprs = metrics.calculate_tpr_at_fpr(labels, probs, selected_fprs)
            for target_fpr, tpr in zip(selected_fprs, tpr_at_fprs):
                self.log(f"{prefix}/TPR@FPR={target_fpr}_{level}", tpr)
            plots.plot_fpr_fnr_curve(
                fprs,
                tprs,
                roc_ths,
                title=f"{Stage} FPR vs FNR ({level}-level)",
                path=f"{log_dir}/{prefix}/{level}_metrics/{stage}_fpr_fnr_curve.png",
                eer=eer,
            )
            W1_sep_real, W1_sep_fake, W1_conf_real, W1_conf_fake = metrics.compute_wasserstein1_metrics(probs, labels)
            if W1_sep_real is not None:
                self.log(f"{prefix}/W1-sep-real_{level}", W1_sep_real)
                self.log(f"{prefix}/W1-sep-fake_{level}", W1_sep_fake)
                # Mean of the separation Wasserstein distances
                self.log(f"{prefix}/W1-sep_{level}", (W1_sep_real + W1_sep_fake) / 2)
                self.log(f"{prefix}/W1-conf-real_{level}", W1_conf_real)
                self.log(f"{prefix}/W1-conf-fake_{level}", W1_conf_fake)
                # Mean of the confidence Wasserstein distances
                self.log(f"{prefix}/W1-conf_{level}", (W1_conf_real + W1_conf_fake) / 2)
            # Threshold the fake-class prob at the EER threshold to get predictions
            preds = np.where(probs[:, 1] > eer_th, 1, 0)
        else:
            # Compute predictions by the argmax rule
            preds = probs.argmax(1)
        # Log scalar metrics
        self.log(f"{prefix}/auroc_{level}", ovr_macro_auroc)
        self.log(f"{prefix}/acc_{level}", M.accuracy_score(labels, preds))
        self.log(f"{prefix}/balanced_acc_{level}", M.balanced_accuracy_score(labels, preds))
        self.log(f"{prefix}/f1_score_{level}", M.f1_score(labels, preds, average="macro"))
        self.log(f"{prefix}/mAP_{level}", ovr_macro_ap)
        class_names = dataset.get_class_names()
        plots.plot_probs_distribution(
            probs,
            labels,
            class_names,
            f"{log_dir}/{prefix}/{level}_metrics/{stage}_probs_distribution.png",
        )
        plots.plot_roc_curve(
            fprs,
            tprs,
            roc_ths,
            f"{Stage} ROC ({level}-level)",
            f"{log_dir}/{prefix}/{level}_metrics/{stage}_roc_{level}.png",
            0.01,
            class_names,
        )
        plots.plot_prc_curve(
            precs,
            recs,
            pr_ths,
            f"{Stage} PR Curve ({level}-level)",
            f"{log_dir}/{prefix}/{level}_metrics/{stage}_pr_curve.png",
            0.01,
            class_names,
        )
        plots.plot_f1_curve(
            precs,
            recs,
            pr_ths,
            f"{Stage} F1 Curve ({level}-level)",
            f"{log_dir}/{prefix}/{level}_metrics/{stage}_f1_curve.png",
            0.01,
            class_names,
        )
        # Confusion matrices (raw and normalized)
        conf = M.confusion_matrix(labels, preds)
        plots.plot_confusion_matrix(
            conf,
            class_names,
            f"{Stage} Confusion Matrix ({level}-level)",
            f"{log_dir}/{prefix}/{level}_metrics/{stage}_confusion.png",
        )
        plots.plot_confusion_matrix(
            conf,
            class_names,
            f"{Stage} Confusion Matrix ({level}-level)",
            f"{log_dir}/{prefix}/{level}_metrics/{stage}_confusion_norm.png",
            True,
        )
        wandb_logger = self.get_wandb_logger()
        if wandb_logger is not None:
            wandb_logger.log_metrics(
                {
                    f"confusion/{prefix}/{stage}_{level}": wandb.plot.confusion_matrix(
                        y_true=labels,
                        preds=preds,
                        class_names=class_names,
                        title=f"{Stage} Confusion Matrix {level.capitalize()}",
                    )
                }
            )
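
    # Note on the binary branch in log_metrics: the EER is the operating point
    # where FPR equals FNR, so thresholding probs[:, 1] at eer_th balances both
    # error rates. A minimal sketch of how such a threshold can be derived with
    # sklearn (an assumption; metrics.calculate_eer may differ internally):
    #   fpr, tpr, ths = M.roc_curve(labels, probs[:, 1])
    #   i = np.nanargmin(np.abs(fpr - (1 - tpr)))
    #   eer, eer_th = (fpr[i] + (1 - tpr[i])) / 2, ths[i]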

    def sources_probs_to_binary(self, probs: np.ndarray) -> np.ndarray:
        # probs[:, 0] is the real prob; probs[:, 1:] are the fake probs (one per generator)
        return np.stack([probs[:, 0], probs[:, 1:].max(axis=1)], 1)
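
    # Worked example (illustrative values) for classes [real, fake_gen1, fake_gen2]:
    #   probs = np.array([[0.6, 0.3, 0.1]])
    #   sources_probs_to_binary(probs)  # -> array([[0.6, 0.3]])
    # Note the binary "probs" need not sum to 1, since only the max fake prob is kept.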

    def log_all_metrics(
        self,
        outputs_for_metrics: OutputsForMetrics,
        stage: Literal["train", "test", "val"],
        dataset: BaseDataset,
    ):
        # Merge all predictions and labels across processes
        labels = outputs_for_metrics.labels.compute().cpu().int().numpy()
        probs = outputs_for_metrics.probs.compute().cpu().numpy()
        idx = outputs_for_metrics.idx.compute().cpu().int().numpy()
        files = [dataset.files[i] for i in idx]  # Get files in the same order as the rest
        outputs_for_metrics.reset()
        if self.config.make_binary_before_video_aggregation:
            if probs.shape[1] > 2:
                probs = self.sources_probs_to_binary(probs)
        # Compute probs and labels for videos
        video_probs, video_labels = compute_across_videos(files, probs, labels, self.config.reduce_video_predictions)
        # Convert to binary if source (per-generator) classes are used
        if not self.config.make_binary_before_video_aggregation:
            if probs.shape[1] > 2:
                probs = self.sources_probs_to_binary(probs)
                video_probs = self.sources_probs_to_binary(video_probs)
        self.log_metrics(probs, labels, stage, stage, "frame", dataset)
        self.log_metrics(video_probs, video_labels, stage, stage, "video", dataset)
        # If trn_files / val_files / tst_files is a dict, log separate metrics per dataset
        if dataset.dataset2files is not None:
            if not self.config.make_binary_before_video_aggregation:
                logger.print_warning(
                    "`make_binary_before_video_aggregation=False` is not supported when trn_files / val_files / tst_files is dict"
                )
            file2index = {f: i for i, f in enumerate(files)}
            for dataset_name, dataset_files in dataset.dataset2files.items():
                # Keep only the files that belong to the current dataset
                dataset_files = np.intersect1d(files, dataset_files)
                file_indices = [file2index[f] for f in dataset_files]
                dataset_probs = probs[file_indices]
                dataset_labels = labels[file_indices]
                dataset_files = [files[i] for i in file_indices]
                self.log_metrics(
                    dataset_probs,
                    dataset_labels,
                    stage,
                    f"{stage}/dataset/{dataset_name}",
                    "frame",
                    dataset,
                )
                dataset_video_probs, dataset_video_labels = compute_across_videos(
                    dataset_files, dataset_probs, dataset_labels, self.config.reduce_video_predictions
                )
                self.log_metrics(
                    dataset_video_probs,
                    dataset_video_labels,
                    stage,
                    f"{stage}/dataset/{dataset_name}",
                    "video",
                    dataset,
                )

    def custom_preprocessing(self, image: Image.Image) -> Image.Image:
        if self.config.custom_preprocessing is None:
            return image
        if self.config.custom_preprocessing.zoom_factor != 1.0:
            zoom_factor = self.config.custom_preprocessing.zoom_factor
            width, height = image.size
            # Calculate the crop size (a smaller central region simulates zooming in)
            crop_w = int(width / zoom_factor)
            crop_h = int(height / zoom_factor)
            # Center-crop coordinates
            left = (width - crop_w) // 2
            top = (height - crop_h) // 2
            right = left + crop_w
            bottom = top + crop_h
            # Crop and resize back to the target size
            cropped_img = image.crop((left, top, right, bottom))
            if self.config.custom_preprocessing.image_size is not None:
                image = cropped_img.resize(self.config.custom_preprocessing.image_size, Image.BILINEAR)
            else:
                # Use bilinear interpolation to preserve artifacts
                image = cropped_img.resize((width, height), Image.BILINEAR)
        if self.config.custom_preprocessing.image_size is not None:
            image = image.resize(self.config.custom_preprocessing.image_size, Image.BILINEAR)
        if self.config.custom_preprocessing.flip_left_right:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)
        return image
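
    # Worked example (hypothetical config values): zoom_factor=2.0 on a 224x224
    # input gives
    #   crop_w = crop_h = int(224 / 2.0) = 112; left = top = 56
    #   crop box = (56, 56, 168, 168), i.e. the central quarter of the image,
    # which is then resized back to 224x224, magnifying the center by 2x.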

    def get_wandb_logger(self) -> WandbLogger | None:
        """
        Get the WandbLogger instance from the current loggers.
        Returns None if no WandbLogger is found.
        """
        for lg in self.loggers:
            if isinstance(lg, WandbLogger):
                return lg
        return None

    def on_test_start(self):
        logger.print(f"[blue]Logs: {self.logger.log_dir}")
        self.log("num_test_files", len(self.trainer.datamodule.test_dataset))

    def test_step(self, batch, batch_idx):
        raise NotImplementedError("test_step must be implemented in the child class")

    def on_test_epoch_end(self):
        if self.logger.log_dir is None:
            # TODO: figure out why logger.log_dir can be None
            return
        # Concatenate all predictions and labels
        probs = self.test_step_outputs.probs.compute().cpu().numpy()
        labels = self.test_step_outputs.labels.compute().cpu().int().numpy()
        idx = self.test_step_outputs.idx.compute().cpu().int().numpy()
        dataset = self.trainer.datamodule.test_dataset
        files = [dataset.files[i] for i in idx]
        # probs is a 2D array of shape (num_samples, num_classes); one CSV column per class
        prob_cols = {f"prob_class_{i}": np.round(probs[:, i], 4) for i in range(probs.shape[1])}
        table = pd.DataFrame({"files": files, "labels": labels, **prob_cols})
        # Save predictions to CSV
        table.to_csv(f"{self.logger.log_dir}/test_predictions.csv", index=False, float_format="%.4f")
        self.log_all_metrics(self.test_step_outputs, "test", dataset)

    def load_checkpoint(self, checkpoint: str):
        if checkpoint:
            state_dict = torch.load(checkpoint, map_location="cpu", weights_only=True)["state_dict"]
            incompatible_keys = self.load_state_dict(state_dict, strict=False)
            self.print_checkpoint_keys(incompatible_keys)
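
    # Illustrative call (hypothetical checkpoint path); strict=False tolerates
    # architecture changes, and any mismatch is reported by print_checkpoint_keys:
    #   model.load_checkpoint("checkpoints/last.ckpt")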

    def print_checkpoint_keys(self, incompatible_keys):
        missing_keys = set(incompatible_keys.missing_keys)
        unexpected_keys = set(incompatible_keys.unexpected_keys)
        logger.print("\n[blue bold]Keys in checkpoint:")
        logger.print("[red bold]- Missing")
        logger.print("[yellow bold]? Unexpected")
        logger.print("[green bold]+ Matched\n")
        for key in self.state_dict().keys():
            if key in missing_keys:
                logger.print(f"[red]- {key}")
            else:
                logger.print(f"[green]+ {key}")
        # Unexpected keys exist only in the checkpoint, never in self.state_dict(),
        # so they are printed separately
        for key in unexpected_keys:
            logger.print(f"[yellow]? {key}")