imagerecogniser / src /predict.py
Joergenator's picture
Upload 4 files
604f286 verified
"""Single-image inference for the Streamlit demo.
Loads a trained checkpoint (local or from HuggingFace Hub) and returns a
probability that an uploaded image is AI-generated.
Label convention follows the training pipeline in src/dataset.py:
0 = real, 1 = AI-generated.
"""
from __future__ import annotations
import os
from dataclasses import dataclass
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from src.models import create_model
from src.transforms import get_eval_transforms
@dataclass(frozen=True)
class ModelSpec:
    """Immutable description of one checkpoint the demo can load.

    Fields, in positional-constructor order:
        tag: Registry key; also names the checkpoint on disk
            (``results/checkpoints/<tag>/best.pt``) and on the Hub (``<tag>.pt``).
        display_name: Human-readable label for the model.
        model_name: Architecture name passed to ``create_model``.
        modified: Forwarded to ``create_model(modified=...)``; presumably
            selects the project's modified variant (e.g. GELU) — confirm in
            ``src.models``.
        test_auc: Recorded test AUC; informational, not read in this module.
    """

    tag: str
    display_name: str
    model_name: str
    modified: bool
    test_auc: float
# Checkpoints selectable in the demo, keyed by tag. Each key is taken from the
# spec's own ``tag`` so the two cannot drift apart; insertion order follows the
# tuple below, which is also the order ``list(MODEL_REGISTRY)`` reports.
MODEL_REGISTRY: dict[str, ModelSpec] = {
    spec.tag: spec
    for spec in (
        ModelSpec(
            tag="vit_base_patch16_224_transfer",
            display_name="ViT-B/16 (transfer learning) — best accuracy",
            model_name="vit_base_patch16_224",
            modified=False,
            test_auc=0.9950,
        ),
        ModelSpec(
            tag="densenet121_transfer",
            display_name="DenseNet-121 (transfer learning)",
            model_name="densenet121",
            modified=False,
            test_auc=0.9854,
        ),
        ModelSpec(
            tag="resnet50_transfer",
            display_name="ResNet-50 (transfer learning)",
            model_name="resnet50",
            modified=False,
            test_auc=0.9749,
        ),
        ModelSpec(
            tag="resnet50_scratch_gelu",
            display_name="ResNet-50 (from scratch, GELU activation)",
            model_name="resnet50",
            modified=True,
            test_auc=0.9349,
        ),
    )
}
# HuggingFace Hub repo consulted when no local checkpoint exists; the
# HF_WEIGHTS_REPO environment variable overrides it (see _resolve_checkpoint).
WEIGHTS_REPO_DEFAULT: str = "Joergenator/imagerecogniser-weights"
# Evaluation preprocessing, built once at import time and reused by every
# predict_image call.
_EVAL_TRANSFORM = get_eval_transforms()
def _resolve_checkpoint(tag: str) -> str:
local = os.path.join("results", "checkpoints", tag, "best.pt")
if os.path.isfile(local):
return local
repo_id = os.environ.get("HF_WEIGHTS_REPO", WEIGHTS_REPO_DEFAULT)
return hf_hub_download(repo_id=repo_id, filename=f"{tag}.pt")
def _extract_state_dict(checkpoint):
    """Unwrap a loaded checkpoint to the mapping ``load_state_dict`` expects.

    Accepts a bare state dict, or a dict wrapping it under ``"state_dict"``,
    ``"model"`` or ``"model_state_dict"`` (checked in that order). Anything
    else is returned unchanged so ``load_state_dict`` reports the mismatch.
    """
    if isinstance(checkpoint, dict):
        for key in ("state_dict", "model", "model_state_dict"):
            if key in checkpoint:
                return checkpoint[key]
    return checkpoint


def load_model(tag: str, device: str | torch.device = "cpu") -> torch.nn.Module:
    """Build the architecture registered under *tag* and load its weights.

    Args:
        tag: Key into ``MODEL_REGISTRY``.
        device: Device the model is moved to after loading; the checkpoint is
            always deserialised onto CPU first.

    Returns:
        The model on *device*, switched to eval mode.

    Raises:
        KeyError: If *tag* is not a registered model tag.
    """
    spec = MODEL_REGISTRY.get(tag)
    if spec is None:
        raise KeyError(f"Unknown model tag: {tag}. Known tags: {list(MODEL_REGISTRY)}")
    ckpt_path = _resolve_checkpoint(tag)
    # Weights come from the checkpoint, so no pretrained download is needed.
    # dropout=0.3 presumably mirrors the training configuration — confirm
    # against the training pipeline.
    model = create_model(
        model_name=spec.model_name,
        pretrained=False,
        dropout=0.3,
        modified=spec.modified,
    )
    # SECURITY: weights_only=False unpickles the file and can execute arbitrary
    # code. The checkpoint may be fetched from a Hub repo chosen via the
    # HF_WEIGHTS_REPO environment variable, so only point it at trusted repos.
    # Switch to weights_only=True once the checkpoints are confirmed to hold
    # only tensors and plain containers.
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    model.load_state_dict(_extract_state_dict(checkpoint))
    model.to(device).eval()
    return model
@torch.inference_mode()
def predict_image(
    model: torch.nn.Module,
    image: Image.Image,
    device: str | torch.device = "cpu",
    *,
    threshold: float = 0.5,
) -> tuple[float, str]:
    """Score one image and return ``(probability, label)``.

    Args:
        model: Binary classifier producing a single logit; expected to be in
            eval mode (as returned by ``load_model``) and located on *device*.
        image: Any PIL image; converted to RGB before preprocessing.
        device: Device the input tensor is moved to.
        threshold: Probability at or above which the image is labelled
            "AI-generated". Defaults to 0.5, the previous fixed cut-off.

    Returns:
        The sigmoid of the model's logit (P(AI-generated) under the
        0 = real / 1 = AI-generated convention) and the label
        ``"AI-generated"`` or ``"Real"``.

    Raises:
        ValueError: If *threshold* is not within [0, 1] (NaN included).
    """
    if not 0.0 <= threshold <= 1.0:
        raise ValueError(f"threshold must be in [0, 1], got {threshold}")
    rgb = image.convert("RGB")
    batch = _EVAL_TRANSFORM(rgb).unsqueeze(0).to(device)  # batch of one
    logit = model(batch).squeeze()
    prob = torch.sigmoid(logit).item()  # .item() requires a single-element output
    label = "AI-generated" if prob >= threshold else "Real"
    return prob, label