File size: 2,550 Bytes
c29babb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from typing import override

import torch
import torchvision.transforms as T
from PIL import Image

from src.config import Config
from src.heads.head import HeadOutput
from src.model.base import BaseDeepakeDetectionModel, OutputsForMetrics
from src.model.effort.model import EffortModel
from src.utils import logger

# Spatial size every image is resized to before being converted to a tensor.
_INPUT_SIZE = (224, 224)

# Per-channel normalization statistics applied after ``ToTensor``.
# NOTE(review): these values look like the statistics published with OpenAI
# CLIP — confirm against the backbone actually used by the detector.
_NORM_MEAN = [0.48145466, 0.4578275, 0.40821073]
_NORM_STD = [0.26862954, 0.26130258, 0.27577711]

# Image -> normalized tensor pipeline: bilinear resize, tensor conversion,
# then channel-wise normalization.
_resize = T.Resize(size=_INPUT_SIZE, interpolation=T.InterpolationMode.BILINEAR)
_to_tensor = T.ToTensor()
_normalize = T.Normalize(mean=_NORM_MEAN, std=_NORM_STD)

preprocessing_alternative = T.Compose([_resize, _to_tensor, _normalize])


class Effort(BaseDeepakeDetectionModel):
    """Deepfake-detection model that wraps an ``EffortModel`` detector.

    The wrapper forwards image tensors to the detector, accumulates test-time
    predictions in ``test_step_outputs`` and knows how to load detector
    weights from a checkpoint file.
    """

    def __init__(self, config: Config):
        """Create the detector and the container for test outputs.

        Args:
            config: Project configuration, forwarded to the base class.
        """
        super().__init__(config, verbose=True)
        self.detector = EffortModel()
        self.test_step_outputs = OutputsForMetrics()

        # The detector is switched to evaluation mode as soon as it is built.
        self.detector.eval()

    @override
    def forward(self, inputs: torch.Tensor) -> HeadOutput:
        """Run the detector and package its two outputs.

        Args:
            inputs: Batch of preprocessed images passed straight to the
                detector.

        Returns:
            A ``HeadOutput`` carrying the detector's logits and L2 embeddings.
        """
        logits, l2_embeddings = self.detector(inputs)
        return HeadOutput(logits_labels=logits, l2_embeddings=l2_embeddings)

    @override
    def test_step(self, batch, batch_idx):
        """Evaluate one test batch and record its predictions.

        Args:
            batch: Raw batch; normalized through ``self.get_batch``.
            batch_idx: Index of the batch (unused here).
        """
        batch = self.get_batch(batch)
        outputs = self.forward(batch.images)
        # Turn logits into per-class probabilities along dim 1.
        probs = outputs.logits_labels.softmax(dim=1)

        # Save outputs for metrics calculation
        self.test_step_outputs.labels.update(batch.labels)
        self.test_step_outputs.probs.update(probs.detach())
        self.test_step_outputs.idx.update(batch.idx)

    @override
    def load_checkpoint(self, checkpoint_path: str):
        """Load detector weights from ``checkpoint_path``.

        The file may hold either a bare state dict or a dict with a
        ``"state_dict"`` entry. A leading ``"module."`` on a key is stripped
        before loading. Loading is non-strict; missing and unexpected keys
        are reported through ``print_checkpoint_keys``.

        Args:
            checkpoint_path: Path of the checkpoint file to read.
        """
        logger.print_info(f"Loading checkpoint from {checkpoint_path}")
        # NOTE(security): depending on the installed PyTorch version and its
        # ``weights_only`` default, ``torch.load`` may unpickle arbitrary
        # objects. Only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        state_dict = checkpoint.get("state_dict", checkpoint)
        # Strip only a *leading* "module." so that keys which merely contain
        # that substring (e.g. "vision_module.weight") are left intact.
        state_dict = {
            key.removeprefix("module."): value for key, value in state_dict.items()
        }
        incompatible_keys = self.detector.load_state_dict(state_dict, strict=False)
        self.print_checkpoint_keys(incompatible_keys)

    @override
    def get_preprocessing(self):
        """Return a callable that converts a PIL image into a model input.

        Returns:
            A function mapping a ``PIL.Image.Image`` to the tensor produced by
            the module-level ``preprocessing_alternative`` pipeline.
        """

        def preprocess(image: Image.Image) -> torch.Tensor:
            return preprocessing_alternative(image)

        return preprocess


if __name__ == "__main__":
    # Example usage: build the model, load weights, and classify one image.
    # NOTE(review): ``Effort`` requires a config argument; this assumes
    # ``Config`` can be constructed without arguments — confirm.
    model = Effort(Config())
    print(model)

    model.load_checkpoint("weights/effort/effort_clip_L14_trainOn_FaceForensic.pth")

    # Close the file handle deterministically and force three channels so the
    # 3-channel normalization in the preprocessing pipeline always applies.
    with Image.open("datasets/FF/real/000/000.png") as image:
        tensor = preprocessing_alternative(image.convert("RGB")).unsqueeze(0)  # Add batch dimension

    # ``forward`` takes the image tensor directly; no gradients are needed
    # for a single inference pass.
    with torch.no_grad():
        outputs = model(tensor)

    print(outputs)