Spaces:
Configuration error
Configuration error
File size: 1,858 Bytes
c29babb | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 | import os
from typing import override
import torch
import torchvision.transforms as T
from PIL import Image
from src.config import Config, CustomPreprocessing
from src.heads.head import HeadOutput
from src.hf.modeling_gend import GenD
from src.model.base import BaseDeepakeDetectionModel
from src.utils import logger
class GenDHF(BaseDeepakeDetectionModel):
    """Detection model that wraps the Hugging Face ``GenD`` network.

    The network weights are fetched with ``GenD.from_pretrained`` using
    ``config.checkpoint`` during construction, which is why
    :meth:`load_checkpoint` has nothing left to do.
    """

    def __init__(self, config: Config):
        super().__init__(config, verbose=True)
        # Pretrained weights come straight from the checkpoint id/path; the
        # wrapped network is switched to eval mode right away.
        self.model = GenD.from_pretrained(config.checkpoint)
        self.model.eval()

    @override
    def forward(self, inputs: torch.Tensor) -> HeadOutput:
        """Run the wrapped ``GenD`` network and expose its raw output as label logits."""
        raw_logits = self.model(inputs)
        return HeadOutput(logits_labels=raw_logits)

    @override
    def test_step(self, batch, batch_idx):
        """Score one test batch and stash its results for later metric computation.

        Probabilities are the softmax of the logits over dim 1. Labels,
        detached probabilities and sample indices are accumulated into
        ``self.test_step_outputs``. ``batch_idx`` is unused.
        """
        prepared = self.get_batch(batch)
        head_out = self.forward(prepared.images)
        class_probs = head_out.logits_labels.softmax(dim=1)

        # Accumulate everything needed for metrics once the test epoch ends.
        self.test_step_outputs.labels.update(prepared.labels)
        self.test_step_outputs.probs.update(class_probs.detach())
        self.test_step_outputs.idx.update(prepared.idx)

    @override
    def load_checkpoint(self, checkpoint_path: str):
        """No-op: ``__init__`` already loaded the weights via ``from_pretrained``.

        ``checkpoint_path`` is accepted for interface compatibility and ignored.
        """
        return None

    @override
    def get_preprocessing(self):
        """Return the ``preprocess`` callable of the wrapped model's ``feature_extractor``."""
        extractor = self.model.feature_extractor
        return extractor.preprocess
if __name__ == "__main__":
    # Smoke test: score a single image with the pretrained GenD checkpoint.
    # Usage: python <this file> [IMAGE_PATH]
    # Without an argument the frame datasets/FF/DF/001_870/000.png is used;
    # pass e.g. datasets/FF/real/001/000.png to try the real-frame example.
    import sys

    config = Config(
        checkpoint="yermandy/GenD_CLIP_L_14",
    )
    model = GenDHF(config)
    # No-op: GenDHF.__init__ already loaded the weights via from_pretrained.
    model.load_checkpoint(config.checkpoint)

    image_path = sys.argv[1] if len(sys.argv) > 1 else "datasets/FF/DF/001_870/000.png"
    preprocess = model.get_preprocessing()

    # `with` releases the file handle that PIL keeps open after a lazy
    # Image.open. convert("RGB") drops any alpha channel / palette a PNG may
    # carry (assumes the preprocessing expects RGB input — confirm).
    with Image.open(image_path) as image:
        preprocessed_image = preprocess(image.convert("RGB"))  # Convert to tensor
    batch = preprocessed_image.unsqueeze(0)  # Add batch dimension

    # Pure inference: no autograd graph needed.
    with torch.no_grad():
        outputs = model.forward(batch)
    print(outputs.logits_labels.softmax(dim=-1))
|