GenD-Sentinel / src /model /GenDHF.py
yermandy's picture
init
c29babb
import os
from typing import override
import torch
import torchvision.transforms as T
from PIL import Image
from src.config import Config, CustomPreprocessing
from src.heads.head import HeadOutput
from src.hf.modeling_gend import GenD
from src.model.base import BaseDeepakeDetectionModel
from src.utils import logger
class GenDHF(BaseDeepakeDetectionModel):
    """Detection model backed by a pretrained ``GenD`` network.

    The wrapped network is obtained through ``GenD.from_pretrained`` using
    ``config.checkpoint`` and is put into eval mode at construction time.
    """

    def __init__(self, config: Config):
        """Build the base model, then load the pretrained network.

        Args:
            config: Project configuration; ``config.checkpoint`` is passed
                to ``GenD.from_pretrained``.
        """
        super().__init__(config, verbose=True)
        network = GenD.from_pretrained(config.checkpoint)
        self.model = network
        self.model.eval()

    @override
    def forward(self, inputs: torch.Tensor) -> HeadOutput:
        """Run the wrapped network and wrap its result as label logits.

        Args:
            inputs: Tensor handed unchanged to the wrapped network.

        Returns:
            A ``HeadOutput`` whose ``logits_labels`` is the network output.
        """
        logits = self.model(inputs)
        return HeadOutput(logits_labels=logits)

    @override
    def test_step(self, batch, batch_idx):
        """Evaluate one test batch and record the results for metrics.

        Args:
            batch: Raw batch; converted with ``self.get_batch`` before use.
            batch_idx: Index of the batch (unused here).
        """
        unpacked = self.get_batch(batch)
        head_output = self.forward(unpacked.images)
        # Turn the logits into probabilities along dimension 1.
        class_probs = head_output.logits_labels.softmax(dim=1)

        # Accumulate everything needed for the metrics calculation.
        collected = self.test_step_outputs
        collected.labels.update(unpacked.labels)
        collected.probs.update(class_probs.detach())
        collected.idx.update(unpacked.idx)

    @override
    def load_checkpoint(self, checkpoint_path: str):
        """Do nothing: the weights are loaded by ``from_pretrained``.

        Args:
            checkpoint_path: Ignored.
        """
        return None

    @override
    def get_preprocessing(self):
        """Return the ``preprocess`` callable of the model's feature extractor."""
        feature_extractor = self.model.feature_extractor
        return feature_extractor.preprocess
if __name__ == "__main__":
    # Manual smoke test: run the pretrained detector on a single image and
    # print the resulting class probabilities.
    config = Config(
        checkpoint="yermandy/GenD_CLIP_L_14",
    )
    model = GenDHF(config)
    # No-op for GenDHF: the weights were already loaded by from_pretrained
    # in the constructor.
    model.load_checkpoint(config.checkpoint)

    # Open the image in a context manager so the underlying file handle is
    # closed as soon as preprocessing has consumed the pixel data.
    # Alternative input: "datasets/FF/real/001/000.png"
    with Image.open("datasets/FF/DF/001_870/000.png") as image:
        # NOTE(review): assumes the preprocessing callable returns a tensor,
        # since ``unsqueeze`` is called on its result below; confirm.
        preprocessed_image = model.get_preprocessing()(image)

    batch = preprocessed_image.unsqueeze(0)  # Add batch dimension

    # Inference only: disable autograd so no graph is built for this pass.
    with torch.no_grad():
        outputs = model.forward(batch)
    print(outputs.logits_labels.softmax(dim=-1))