from __future__ import annotations

import argparse
from pathlib import Path
from typing import Iterable, List

import numpy as np
import torch
from PIL import Image
from safetensors.torch import load_file as load_safetensors
from torchvision import transforms

from data.dct import DCT_base_Rec_Module
from models import AIDE as build_aide_model

# Side length (pixels) that every view is resized to before entering the model.
IMAGE_SIZE = 256

TO_TENSOR = transforms.ToTensor()
# Resize to IMAGE_SIZE x IMAGE_SIZE, then normalize per channel with the
# commonly used ImageNet mean/std values.
NORMALIZE_AND_RESIZE = transforms.Compose(
    [
        transforms.Resize([IMAGE_SIZE, IMAGE_SIZE]),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)


def build_aide_input_from_pil(image: Image.Image, dct_module: DCT_base_Rec_Module) -> torch.Tensor:
    """Build the stacked five-view input tensor for a single image.

    The image is converted to RGB and then to a tensor with ``ToTensor``.
    That tensor is passed through ``dct_module`` to obtain four derived
    views. Each derived view, plus the original tensor, is resized and
    normalized with ``NORMALIZE_AND_RESIZE``.

    Args:
        image: Input PIL image. Any mode is accepted; it is converted to RGB.
        dct_module: Callable that maps the image tensor to four tensors.

    Returns:
        The five views stacked along a new leading dimension, in the order
        ``[x_minmin, x_maxmax, x_minmin1, x_maxmax1, x_0]``. The per-view
        shape depends on what ``dct_module`` returns; it is presumably
        (3, IMAGE_SIZE, IMAGE_SIZE). TODO confirm.
    """
    image = image.convert("RGB")
    image_tensor = TO_TENSOR(image)
    x_minmin, x_maxmax, x_minmin1, x_maxmax1 = dct_module(image_tensor)

    x_0 = NORMALIZE_AND_RESIZE(image_tensor)
    x_minmin = NORMALIZE_AND_RESIZE(x_minmin)
    x_maxmax = NORMALIZE_AND_RESIZE(x_maxmax)
    x_minmin1 = NORMALIZE_AND_RESIZE(x_minmin1)
    x_maxmax1 = NORMALIZE_AND_RESIZE(x_maxmax1)

    # The view order matters: the model consumes the stack positionally.
    return torch.stack([x_minmin, x_maxmax, x_minmin1, x_maxmax1, x_0], dim=0)


def load_model(
    repo_dir: str | Path,
    device: str | None = None,
    weights_name: str = "model.safetensors",
) -> torch.nn.Module:
    """Construct the AIDE model, load its weights, and put it in eval mode.

    Args:
        repo_dir: Directory that contains the weights file.
        device: Target device. When falsy, this is ``"cuda"`` if CUDA is
            available and ``"cpu"`` otherwise.
        weights_name: File name of the safetensors weights inside ``repo_dir``.

    Returns:
        The model, moved to ``device`` and switched to ``eval()`` mode.

    Raises:
        RuntimeError: From ``load_state_dict(strict=True)`` when the checkpoint
            keys do not match the model's parameters.
    """
    repo_dir = Path(repo_dir)
    weights_path = repo_dir / weights_name
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    # No backbone checkpoint paths are passed here; all weights come from the
    # safetensors file loaded below.
    model = build_aide_model(resnet_path=None, convnext_path=None)
    state_dict = load_safetensors(str(weights_path))
    model.load_state_dict(state_dict, strict=True)
    model.to(device)
    model.eval()
    return model


@torch.inference_mode()
def predict_pil_images(
    model: torch.nn.Module,
    images: Iterable[Image.Image],
    device: str | None = None,
) -> List[dict]:
    """Classify PIL images as real or fake in a single batch.

    Args:
        model: A loaded AIDE model, e.g. from ``load_model``.
        images: Images to classify. The iterable is consumed once.
        device: Device to run the batch on. When falsy, the device of the
            model's first parameter is used, including its index
            (e.g. ``cuda:1``).

    Returns:
        One dict per input image, in input order. Each dict has the keys
        ``label`` (``"real"`` or ``"fake"``), ``real_probability`` and
        ``fake_probability``. The probabilities are softmax column 0 and
        column 1 respectively, rounded to 6 decimals. Ties are labeled
        ``"fake"``. An empty input returns an empty list.
    """
    if not device:
        # Keep the full torch.device rather than ``.type``. The bare type
        # string ("cuda") would resolve to cuda:0 even when the model lives on
        # another GPU, which causes a device-mismatch error in the forward pass.
        device = next(model.parameters()).device

    dct_module = DCT_base_Rec_Module()
    samples = [build_aide_input_from_pil(img, dct_module) for img in images]
    if not samples:
        # torch.stack rejects an empty list; no images means no predictions.
        return []

    batch = torch.stack(samples, dim=0).to(device)
    logits = model(batch)
    probs = torch.softmax(logits, dim=-1).cpu().numpy()

    outputs = []
    for prob in probs:
        real_prob = float(prob[0])
        fake_prob = float(prob[1])
        label = "fake" if fake_prob >= real_prob else "real"
        outputs.append(
            {
                "label": label,
                "real_probability": round(real_prob, 6),
                "fake_probability": round(fake_prob, 6),
            }
        )
    return outputs


def _load_images(paths: Iterable[str]) -> List[Image.Image]:
    """Open each path and return an RGB copy of the image.

    The underlying file handle is closed before returning. ``convert``
    produces a new, fully loaded image, so the returned copies do not depend
    on the closed file.
    """
    loaded: List[Image.Image] = []
    for path in paths:
        with Image.open(path) as img:
            loaded.append(img.convert("RGB"))
    return loaded


def main() -> None:
    """Command-line entry point: load the model, classify images, print results."""
    parser = argparse.ArgumentParser(description="Run AIDE image detector inference.")
    parser.add_argument("--repo_dir", type=str, default=".", help="Local path to the model repository.")
    parser.add_argument("--image", type=str, nargs="+", required=True, help="One or more image paths.")
    parser.add_argument("--device", type=str, default=None, help="cuda or cpu")
    args = parser.parse_args()

    model = load_model(args.repo_dir, device=args.device)
    images = _load_images(args.image)
    predictions = predict_pil_images(model, images, device=args.device)

    # predict_pil_images preserves input order, so zip pairs each path with
    # its prediction.
    for image_path, prediction in zip(args.image, predictions):
        print(
            {
                "image": str(image_path),
                **prediction,
            }
        )


if __name__ == "__main__":
    main()