File size: 4,380 Bytes
c29babb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import urllib.request
from typing import override

import torch
import torchvision.transforms as T
from PIL import Image

from src.config import Config, CustomPreprocessing
from src.heads.head import HeadOutput
from src.model.base import BaseDeepakeDetectionModel
from src.model.fsfm import models_vit, models_vit_fs_adapter
from src.utils import logger


def download_model_if_needed(checkpoint_path: str, link: str):
    """Fetch the checkpoint from ``link`` into ``checkpoint_path`` if it is missing.

    No-op when the file already exists. Parent directories are created as needed.

    Args:
        checkpoint_path: Local destination path for the weights file.
        link: Direct download URL for the checkpoint.
    """
    if os.path.exists(checkpoint_path):
        return
    logger.print_warning_once(f"Checkpoint '{checkpoint_path}' not found, downloading...")
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    # urllib instead of `os.system("wget ...")`: no shell-injection surface via the
    # URL/path, no dependency on wget being installed, and failures raise instead
    # of being silently ignored.
    urllib.request.urlretrieve(link, checkpoint_path)


class FSFM(BaseDeepakeDetectionModel):
    """FS-VFM (FSFM) deepfake-detection model wrapper.

    Instantiates the ViT-L backbone variant (plain, or with the FS-Adapter)
    matching the configured checkpoint path, downloading the weights on first
    use. Outputs two-class logits reordered to [real, fake].
    """

    def __init__(self, config: Config):
        super().__init__(config, verbose=True)
        self.initialize_model(config.checkpoint)
        # Inference-only wrapper: freeze dropout/batchnorm behavior.
        self.model.eval()

    def initialize_model(self, checkpoint_path: str):
        """Build the backbone matching ``checkpoint_path``, downloading weights if needed.

        Raises:
            ValueError: if ``checkpoint_path`` is not one of the two known variants.
        """
        if checkpoint_path == "weights/FS-VFM/FS-VFM-ViT-L-Adapter.pth":
            link = "https://hf.co/Wolowolo/fsfm-3c/resolve/main/finetuned_models/FS-VFM_extensions/finetune_fs-adapter/cross_dataset_DfD_and_DiFF/ViT-L_VF2_600e/FT_on_FF%2B%2B_c23_32frames/checkpoint-min_val_loss.pth?download=true"
            download_model_if_needed(checkpoint_path, link)
            self.model = models_vit_fs_adapter.vit_large_patch16(num_classes=2, drop_path_rate=0.1, global_pool=True)

        elif checkpoint_path == "weights/FS-VFM/FS-VFM-ViT-L.pth":
            link = "https://hf.co/Wolowolo/fsfm-3c/resolve/main/finetuned_models/FS-VFM_extensions/cross_dataset_DFD_and_DiFF/ViT-L_VF2_600e/FT_on_FF%2B%2B_c23_32frames/checkpoint-min_val_loss.pth?download=true"
            download_model_if_needed(checkpoint_path, link)
            self.model = models_vit.vit_large_patch16(
                num_classes=2,
                drop_path_rate=0.1,
                global_pool=True,
            )

        else:
            raise ValueError(f"Unknown FS-VFM checkpoint path: {checkpoint_path}")

    @override
    def forward(self, inputs: torch.Tensor) -> HeadOutput:
        """Run the backbone and reorder its two class logits to [real, fake]."""
        outputs = self.model(inputs)
        # Swap the two entries of the last (class) dimension: the backbone's
        # native order appears to be [fake, real] — confirm against training code.
        outputs = outputs[..., [1, 0]]
        return HeadOutput(logits_labels=outputs)

    @override
    def test_step(self, batch, batch_idx):
        """Run one evaluation batch and accumulate labels/probs/indices for metrics."""
        batch = self.get_batch(batch)
        outputs = self.forward(batch.images)
        probs = outputs.logits_labels.softmax(dim=1)

        # Save outputs for metrics calculation
        self.test_step_outputs.labels.update(batch.labels)
        self.test_step_outputs.probs.update(probs.detach())
        self.test_step_outputs.idx.update(batch.idx)

    @override
    def load_checkpoint(self, checkpoint_path: str):
        """Load the model checkpoint."""
        logger.print_info(f"Loading checkpoint from {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
        # strict=False: the checkpoint may omit/carry extra keys; report them instead.
        incompatible_keys = self.model.load_state_dict(checkpoint["model"], strict=False)
        self.print_checkpoint_keys(incompatible_keys)

    @override
    def get_preprocessing(self):
        """Return a PIL-image -> normalized tensor callable (custom crop, then resize/normalize)."""
        if self.config.custom_preprocessing is None:
            logger.print_warning_once("This model might expect a zoom in to the facial image. Make sure to tune it.")

        def preprocess(image: Image.Image) -> torch.Tensor:
            image = self.custom_preprocessing(image)
            return transform(image)

        return preprocess


# FS-VFM input pipeline: resize, tensorize, then normalize with the
# dataset-specific per-channel statistics the model was trained with.
_FSFM_MEAN = [0.5482207536697388, 0.42340534925460815, 0.3654651641845703]
_FSFM_STD = [0.2789176106452942, 0.2438540756702423, 0.23493893444538116]

transform = T.Compose([
    T.Resize(224, interpolation=T.InterpolationMode.BILINEAR),
    T.ToTensor(),
    T.Normalize(_FSFM_MEAN, _FSFM_STD),
])


if __name__ == "__main__":
    # Smoke test: run a single FaceForensics++ frame through the detector.
    cfg = Config(
        checkpoint="weights/FS-VFM/FS-VFM-ViT-L.pth",
        custom_preprocessing=CustomPreprocessing(zoom_factor=1.3),
    )
    detector = FSFM(cfg)
    detector.load_checkpoint(cfg.checkpoint)

    img = Image.open("datasets/FF/DF/001_870/000.png")
    # img = Image.open("datasets/FF/real/001/000.png")
    tensor = detector.get_preprocessing()(img)  # PIL image -> normalized tensor
    outputs = detector.forward(tensor.unsqueeze(0))  # add batch dimension
    print(outputs.logits_labels.softmax(dim=-1))