# AIDE_image_detector / inference.py
# Uploaded by meet4150 via huggingface_hub (commit dd33601, verified).
from __future__ import annotations
import argparse
from pathlib import Path
from typing import Iterable, List
import numpy as np
import torch
from PIL import Image
from safetensors.torch import load_file as load_safetensors
from torchvision import transforms
from data.dct import DCT_base_Rec_Module
from models import AIDE as build_aide_model
IMAGE_SIZE = 256
TO_TENSOR = transforms.ToTensor()
NORMALIZE_AND_RESIZE = transforms.Compose(
[
transforms.Resize([IMAGE_SIZE, IMAGE_SIZE]),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
),
]
)
def build_aide_input_from_pil(image: Image.Image, dct_module: DCT_base_Rec_Module) -> torch.Tensor:
image = image.convert("RGB")
image_tensor = TO_TENSOR(image)
x_minmin, x_maxmax, x_minmin1, x_maxmax1 = dct_module(image_tensor)
x_0 = NORMALIZE_AND_RESIZE(image_tensor)
x_minmin = NORMALIZE_AND_RESIZE(x_minmin)
x_maxmax = NORMALIZE_AND_RESIZE(x_maxmax)
x_minmin1 = NORMALIZE_AND_RESIZE(x_minmin1)
x_maxmax1 = NORMALIZE_AND_RESIZE(x_maxmax1)
return torch.stack([x_minmin, x_maxmax, x_minmin1, x_maxmax1, x_0], dim=0)
def load_model(
repo_dir: str | Path,
device: str | None = None,
weights_name: str = "model.safetensors",
) -> torch.nn.Module:
repo_dir = Path(repo_dir)
weights_path = repo_dir / weights_name
device = device or ("cuda" if torch.cuda.is_available() else "cpu")
model = build_aide_model(resnet_path=None, convnext_path=None)
state_dict = load_safetensors(str(weights_path))
model.load_state_dict(state_dict, strict=True)
model.to(device)
model.eval()
return model
@torch.inference_mode()
def predict_pil_images(
model: torch.nn.Module,
images: Iterable[Image.Image],
device: str | None = None,
) -> List[dict]:
device = device or next(model.parameters()).device.type
dct_module = DCT_base_Rec_Module()
batch = torch.stack([build_aide_input_from_pil(img, dct_module) for img in images], dim=0).to(device)
logits = model(batch)
probs = torch.softmax(logits, dim=-1).cpu().numpy()
outputs = []
for prob in probs:
real_prob = float(prob[0])
fake_prob = float(prob[1])
label = "fake" if fake_prob >= real_prob else "real"
outputs.append(
{
"label": label,
"real_probability": round(real_prob, 6),
"fake_probability": round(fake_prob, 6),
}
)
return outputs
def _load_images(paths: Iterable[str]) -> List[Image.Image]:
return [Image.open(path).convert("RGB") for path in paths]
def main() -> None:
parser = argparse.ArgumentParser(description="Run AIDE image detector inference.")
parser.add_argument("--repo_dir", type=str, default=".", help="Local path to the model repository.")
parser.add_argument("--image", type=str, nargs="+", required=True, help="One or more image paths.")
parser.add_argument("--device", type=str, default=None, help="cuda or cpu")
args = parser.parse_args()
model = load_model(args.repo_dir, device=args.device)
images = _load_images(args.image)
predictions = predict_pil_images(model, images, device=args.device)
for image_path, prediction in zip(args.image, predictions):
print(
{
"image": str(image_path),
**prediction,
}
)
if __name__ == "__main__":
main()