Spaces:
Configuration error
Configuration error
File size: 3,764 Bytes
c29babb | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 | from typing import override
import cv2
import numpy as np
import torch
import yaml
from PIL import Image
from torchvision import transforms as T
from src.config import Config
from src.heads.head import HeadOutput
from src.model.base import BaseDeepakeDetectionModel, OutputsForMetrics
from src.model.forada.ds import DS
from src.utils import logger
class ForAda(BaseDeepakeDetectionModel):
def __init__(self, config: Config):
super().__init__(config, verbose=True)
# load yaml file relative to the current file
config_path = __file__.replace("forensics_adapter.py", "forensics_adapter_model/config.yaml")
with open(config_path, "r") as f:
config = yaml.safe_load(f)
self.model = DS(
clip_name=config["clip_model_name"],
adapter_vit_name=config["vit_name"],
num_quires=config["num_quires"],
fusion_map=config["fusion_map"],
mlp_dim=config["mlp_dim"],
mlp_out_dim=config["mlp_out_dim"],
head_num=config["head_num"],
)
self.eval()
@override
def forward(self, inputs: torch.Tensor) -> HeadOutput:
outputs = self.model({"image": inputs}, inference=True)
return HeadOutput(logits_labels=outputs["logits"])
@override
def on_test_epoch_start(self):
self.test_step_outputs = OutputsForMetrics()
# move model to the device
self.model.to(self.trainer.strategy.root_device)
@override
def test_step(self, batch, batch_idx):
batch = self.get_batch(batch)
outputs = self.forward(batch.images)
probs = outputs.logits_labels.softmax(dim=1)
# Save outputs for metrics calculation
self.test_step_outputs.labels.update(batch.labels)
self.test_step_outputs.probs.update(probs.detach())
self.test_step_outputs.idx.update(batch.idx)
@override
def load_checkpoint(self, checkpoint_path: str):
"""Load the model checkpoint."""
logger.print_info(f"Loading checkpoint from {checkpoint_path}")
state_dict = torch.load(checkpoint_path, map_location="cpu")
incompatible_keys = self.model.load_state_dict(state_dict, strict=False)
self.print_checkpoint_keys(incompatible_keys)
@override
def get_preprocessing(self):
def preprocess(image: Image) -> torch.Tensor:
return preprocessing(image)
return preprocess
_preprocess = T.Compose(
[
T.ToTensor(),
T.Normalize(
mean=[0.48145466, 0.4578275, 0.40821073],
std=[0.26862954, 0.26130258, 0.27577711],
),
]
)
def preprocessing(image: Image) -> torch.Tensor:
image = np.array(image)
image = cv2.resize(image, (256, 256), interpolation=cv2.INTER_CUBIC)
image = np.array(image, dtype=np.uint8)
image = _preprocess(image)
# image = F.interpolate(
# image.unsqueeze(0),
# size=(224, 224),
# mode="bilinear",
# align_corners=False,
# )[0]
return image
if __name__ == "__main__":
#! Run as module:
#! python -m src.model.forensics_adapter
from PIL import Image
from src.config import Config
from src.model.ForAda import ForAda
config = Config()
model = ForAda(config)
model.load_checkpoint("weights/forensics_adapter/ForensicsAdapter.pth")
path = "datasets/FF/real/000/000.png"
image = Image.open(path) # Load image
preprocessed_image = model.get_preprocessing()(image) # Convert to tensor
batch = preprocessed_image.unsqueeze(0) # Add batch dimension
outputs = model(batch)
print(outputs.logits_labels) # Print logits labels
print(outputs.logits_labels.softmax(dim=1)) # Print probabilities
|