File size: 3,764 Bytes
c29babb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from typing import override

import cv2
import numpy as np
import torch
import yaml
from PIL import Image
from torchvision import transforms as T

from src.config import Config
from src.heads.head import HeadOutput
from src.model.base import BaseDeepakeDetectionModel, OutputsForMetrics
from src.model.forada.ds import DS
from src.utils import logger


class ForAda(BaseDeepakeDetectionModel):
    """Forensics-Adapter deepfake detection model wrapping the ``DS`` network.

    Reads the adapter/CLIP hyperparameters from a YAML file shipped next to
    this source file, builds the underlying ``DS`` model, and runs it in
    inference mode only.
    """

    def __init__(self, config: Config):
        """Build the DS model from the bundled YAML config.

        Args:
            config: Project-wide configuration, forwarded to the base class.
        """
        super().__init__(config, verbose=True)

        # Resolve the model YAML relative to this source file.
        # NOTE(review): assumes this file is named ``forensics_adapter.py``;
        # the replace() is a silent no-op otherwise — confirm repo layout.
        config_path = __file__.replace(
            "forensics_adapter.py", "forensics_adapter_model/config.yaml"
        )
        # Use a distinct name so the ``config`` parameter (the project-level
        # Config object) is not shadowed by the YAML dict.
        with open(config_path, "r") as f:
            model_cfg = yaml.safe_load(f)

        self.model = DS(
            clip_name=model_cfg["clip_model_name"],
            adapter_vit_name=model_cfg["vit_name"],
            num_quires=model_cfg["num_quires"],
            fusion_map=model_cfg["fusion_map"],
            mlp_dim=model_cfg["mlp_dim"],
            mlp_out_dim=model_cfg["mlp_out_dim"],
            head_num=model_cfg["head_num"],
        )
        # Inference-only model: fix batchnorm/dropout to eval behavior.
        self.eval()

    @override
    def forward(self, inputs: torch.Tensor) -> HeadOutput:
        """Run the DS model in inference mode and wrap its logits."""
        outputs = self.model({"image": inputs}, inference=True)
        return HeadOutput(logits_labels=outputs["logits"])

    @override
    def on_test_epoch_start(self):
        """Reset metric accumulators and move the model to the test device."""
        self.test_step_outputs = OutputsForMetrics()
        # The wrapped DS model is moved explicitly; Lightning only moves
        # registered submodules of the LightningModule itself.
        self.model.to(self.trainer.strategy.root_device)

    @override
    def test_step(self, batch, batch_idx):
        """Run one test batch and accumulate outputs for epoch-end metrics."""
        batch = self.get_batch(batch)
        outputs = self.forward(batch.images)
        probs = outputs.logits_labels.softmax(dim=1)

        # Save outputs for metrics calculation at epoch end.
        self.test_step_outputs.labels.update(batch.labels)
        self.test_step_outputs.probs.update(probs.detach())
        self.test_step_outputs.idx.update(batch.idx)

    @override
    def load_checkpoint(self, checkpoint_path: str):
        """Load DS weights from ``checkpoint_path`` (non-strict) onto CPU."""
        logger.print_info(f"Loading checkpoint from {checkpoint_path}")
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        # strict=False: tolerate missing/unexpected keys, then report them.
        incompatible_keys = self.model.load_state_dict(state_dict, strict=False)
        self.print_checkpoint_keys(incompatible_keys)

    @override
    def get_preprocessing(self):
        """Return the image -> tensor preprocessing callable for this model."""

        def preprocess(image: Image.Image) -> torch.Tensor:
            return preprocessing(image)

        return preprocess


# Shared tensor conversion + normalization pipeline used by `preprocessing`.
# The mean/std values match OpenAI CLIP's published preprocessing statistics.
_preprocess = T.Compose(
    [
        T.ToTensor(),  # HWC uint8 [0, 255] -> CHW float [0, 1]
        T.Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073],
            std=[0.26862954, 0.26130258, 0.27577711],
        ),
    ]
)


def preprocessing(image: Image.Image) -> torch.Tensor:
    """Resize an image to 256x256 and convert it to a CLIP-normalized tensor.

    Args:
        image: Input PIL image (RGB expected).

    Returns:
        A float tensor of shape (C, 256, 256), scaled to [0, 1] and
        normalized with the CLIP statistics in ``_preprocess``.
    """
    resized = cv2.resize(
        np.array(image), (256, 256), interpolation=cv2.INTER_CUBIC
    )
    # Cast to uint8 so ToTensor's [0, 1] scaling behaves as expected.
    return _preprocess(np.asarray(resized, dtype=np.uint8))


if __name__ == "__main__":
    #! Run as module:
    #! python -m src.model.forensics_adapter

    # ForAda, Config and Image are already defined/imported in this module;
    # the previous re-import via ``src.model.ForAda`` did not match this
    # module's path and was unnecessary.
    config = Config()
    model = ForAda(config)

    model.load_checkpoint("weights/forensics_adapter/ForensicsAdapter.pth")

    path = "datasets/FF/real/000/000.png"
    image = Image.open(path)  # Load image
    preprocessed_image = model.get_preprocessing()(image)  # Convert to tensor
    batch = preprocessed_image.unsqueeze(0)  # Add batch dimension
    with torch.no_grad():  # inference only; no autograd graph needed
        outputs = model(batch)

    print(outputs.logits_labels)  # Raw logits
    print(outputs.logits_labels.softmax(dim=1))  # Class probabilities