# GenD-Sentinel — src/model/ForAda.py
# (HuggingFace page artifacts: uploaded by yermandy, commit "init", c29babb)
from typing import override
import cv2
import numpy as np
import torch
import yaml
from PIL import Image
from torchvision import transforms as T
from src.config import Config
from src.heads.head import HeadOutput
from src.model.base import BaseDeepakeDetectionModel, OutputsForMetrics
from src.model.forada.ds import DS
from src.utils import logger
class ForAda(BaseDeepakeDetectionModel):
    """Forensics-Adapter deepfake detection model.

    Wraps the DS (CLIP backbone + adapter ViT) network and adapts it to the
    project's BaseDeepakeDetectionModel test/inference interface.
    """

    def __init__(self, config: Config):
        """Build the DS model from the YAML config shipped next to this module.

        Args:
            config: Project-level configuration passed to the base class.
        """
        super().__init__(config, verbose=True)
        # Resolve the model YAML relative to this file.
        # BUG FIX: the original replaced "forensics_adapter.py", which never
        # matches this file's name ("ForAda.py"), so config_path stayed equal
        # to __file__ and the YAML load targeted this .py file itself.
        config_path = __file__.replace("ForAda.py", "forensics_adapter_model/config.yaml")
        with open(config_path, "r") as f:
            # Separate name so the `config: Config` argument is not shadowed.
            model_config = yaml.safe_load(f)
        self.model = DS(
            clip_name=model_config["clip_model_name"],
            adapter_vit_name=model_config["vit_name"],
            num_quires=model_config["num_quires"],
            fusion_map=model_config["fusion_map"],
            mlp_dim=model_config["mlp_dim"],
            mlp_out_dim=model_config["mlp_out_dim"],
            head_num=model_config["head_num"],
        )
        # Inference-only wrapper: start in eval mode.
        self.eval()

    @override
    def forward(self, inputs: torch.Tensor) -> HeadOutput:
        """Run inference on a batch of preprocessed images; return raw logits."""
        outputs = self.model({"image": inputs}, inference=True)
        return HeadOutput(logits_labels=outputs["logits"])

    @override
    def on_test_epoch_start(self):
        """Reset metric accumulators and move the model to the test device."""
        self.test_step_outputs = OutputsForMetrics()
        # Move model to the device selected by the trainer strategy.
        self.model.to(self.trainer.strategy.root_device)

    @override
    def test_step(self, batch, batch_idx):
        """Evaluate one test batch and accumulate outputs for metrics."""
        batch = self.get_batch(batch)
        outputs = self.forward(batch.images)
        probs = outputs.logits_labels.softmax(dim=1)
        # Save outputs for metrics calculation
        self.test_step_outputs.labels.update(batch.labels)
        self.test_step_outputs.probs.update(probs.detach())
        self.test_step_outputs.idx.update(batch.idx)

    @override
    def load_checkpoint(self, checkpoint_path: str):
        """Load the model checkpoint (non-strict; incompatible keys reported)."""
        logger.print_info(f"Loading checkpoint from {checkpoint_path}")
        # SECURITY NOTE(review): torch.load unpickles the file — only load
        # trusted checkpoints, or pass weights_only=True on torch >= 2.0.
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        incompatible_keys = self.model.load_state_dict(state_dict, strict=False)
        self.print_checkpoint_keys(incompatible_keys)

    @override
    def get_preprocessing(self):
        """Return the PIL-image -> tensor preprocessing callable for this model."""
        # The module-level function already has the right signature; no need
        # for the extra closure the original wrapped it in.
        return preprocessing
# Tensor conversion + channel normalization. The mean/std values are the
# OpenAI CLIP image-preprocessing statistics, matching the CLIP backbone
# used by the DS model.
_preprocess = T.Compose(
    [
        T.ToTensor(),
        T.Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073],
            std=[0.26862954, 0.26130258, 0.27577711],
        ),
    ]
)
def preprocessing(image: Image.Image) -> torch.Tensor:
    """Preprocess a PIL image for the ForAda model.

    Resizes to 256x256 with bicubic interpolation, then converts to a
    CLIP-normalized float tensor via the module-level `_preprocess` pipeline.

    Args:
        image: Input PIL image (presumably RGB — confirm against callers).

    Returns:
        A normalized float tensor of shape (C, 256, 256).
    """
    # cv2.resize expects a numpy (H, W, C) array; dtype is preserved.
    resized = cv2.resize(np.array(image), (256, 256), interpolation=cv2.INTER_CUBIC)
    # Keep the uint8 dtype guarantee: ToTensor only scales to [0, 1] for
    # uint8 input. np.asarray avoids the copy the original np.array made
    # when the array is already uint8. (Dead commented-out interpolation
    # code removed.)
    resized = np.asarray(resized, dtype=np.uint8)
    return _preprocess(resized)
if __name__ == "__main__":
    #! Run as module:
    #! python -m src.model.ForAda
    # NOTE: the original re-imported Image, Config and ForAda here; all three
    # are already defined in this module (ForAda is this very file — importing
    # it from itself would re-execute the module under `-m`).

    config = Config()
    model = ForAda(config)
    model.load_checkpoint("weights/forensics_adapter/ForensicsAdapter.pth")

    path = "datasets/FF/real/000/000.png"
    image = Image.open(path)  # load the sample image
    preprocessed_image = model.get_preprocessing()(image)  # PIL -> normalized tensor
    batch = preprocessed_image.unsqueeze(0)  # add batch dimension
    outputs = model(batch)
    print(outputs.logits_labels)  # raw logits
    print(outputs.logits_labels.softmax(dim=1))  # class probabilities