Spaces:
Running
Running
| """Single-image inference for the Streamlit demo. | |
| Loads a trained checkpoint (local or from HuggingFace Hub) and returns a | |
| probability that an uploaded image is AI-generated. | |
| Label convention follows the training pipeline in src/dataset.py: | |
| 0 = real, 1 = AI-generated. | |
| """ | |
| from __future__ import annotations | |
| import os | |
| from dataclasses import dataclass | |
| import torch | |
| from huggingface_hub import hf_hub_download | |
| from PIL import Image | |
| from src.models import create_model | |
| from src.transforms import get_eval_transforms | |
@dataclass(frozen=True)
class ModelSpec:
    """Static description of one selectable model in the demo.

    Instances are created with keyword arguments and stored in the model
    registry. They are frozen because they describe fixed, trained
    checkpoints and are not meant to be mutated at runtime.
    """

    # Unique identifier; also used to locate the checkpoint file/directory.
    tag: str
    # Human-readable label for the model.
    display_name: str
    # Architecture name forwarded to ``create_model``.
    model_name: str
    # Forwarded to ``create_model`` as its ``modified`` flag.
    modified: bool
    # Reported AUC for this checkpoint (presumably on the held-out test
    # split -- confirm against the training pipeline).
    test_auc: float
# Every model the demo can load, keyed by its tag. Keys are derived from each
# spec's own ``tag`` so the key and the spec can never drift apart.
MODEL_REGISTRY: dict[str, ModelSpec] = {
    spec.tag: spec
    for spec in (
        ModelSpec(
            tag="vit_base_patch16_224_transfer",
            display_name="ViT-B/16 (transfer learning) — best accuracy",
            model_name="vit_base_patch16_224",
            modified=False,
            test_auc=0.9950,
        ),
        ModelSpec(
            tag="densenet121_transfer",
            display_name="DenseNet-121 (transfer learning)",
            model_name="densenet121",
            modified=False,
            test_auc=0.9854,
        ),
        ModelSpec(
            tag="resnet50_transfer",
            display_name="ResNet-50 (transfer learning)",
            model_name="resnet50",
            modified=False,
            test_auc=0.9749,
        ),
        ModelSpec(
            tag="resnet50_scratch_gelu",
            display_name="ResNet-50 (from scratch, GELU activation)",
            model_name="resnet50",
            modified=True,
            test_auc=0.9349,
        ),
    )
}

# Hub repository used for checkpoint downloads when the HF_WEIGHTS_REPO
# environment variable is not set.
WEIGHTS_REPO_DEFAULT = "Joergenator/imagerecogniser-weights"

# Evaluation-time preprocessing, built once at import and reused for every
# prediction.
_EVAL_TRANSFORM = get_eval_transforms()
def _resolve_checkpoint(tag: str) -> str:
    """Return a filesystem path to the checkpoint for ``tag``.

    A local file at ``results/checkpoints/<tag>/best.pt`` (relative to the
    current working directory) takes precedence. If it does not exist, the
    file ``<tag>.pt`` is fetched via ``hf_hub_download`` from the repo named
    by the ``HF_WEIGHTS_REPO`` environment variable, falling back to
    ``WEIGHTS_REPO_DEFAULT``.
    """
    candidate = os.path.join("results", "checkpoints", tag, "best.pt")
    if not os.path.isfile(candidate):
        # No local training output available: pull the published weights.
        weights_repo = os.environ.get("HF_WEIGHTS_REPO", WEIGHTS_REPO_DEFAULT)
        return hf_hub_download(repo_id=weights_repo, filename=f"{tag}.pt")
    return candidate
def load_model(tag: str, device: str | torch.device = "cpu") -> torch.nn.Module:
    """Build the architecture registered under ``tag`` and load its weights.

    Args:
        tag: Key into ``MODEL_REGISTRY``.
        device: Device the model is moved to after its weights are loaded.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Raises:
        KeyError: If ``tag`` is not present in ``MODEL_REGISTRY``.
    """
    spec = MODEL_REGISTRY.get(tag)
    if spec is None:
        raise KeyError(f"Unknown model tag: {tag}. Known tags: {list(MODEL_REGISTRY)}")

    checkpoint_file = _resolve_checkpoint(tag)

    # Weights come from the checkpoint, so no pretrained initialisation is
    # requested here.
    network = create_model(
        model_name=spec.model_name,
        pretrained=False,
        dropout=0.3,
        modified=spec.modified,
    )

    # NOTE(security): weights_only=False performs full unpickling, which can
    # execute arbitrary code. The file may have been downloaded from the Hub
    # repo named by HF_WEIGHTS_REPO, so only point that at trusted repos.
    payload = torch.load(checkpoint_file, map_location="cpu", weights_only=False)

    # Some checkpoints wrap the weights in a dict; unwrap the first matching
    # wrapper key, checking "state_dict" before "model".
    for wrapper_key in ("state_dict", "model"):
        if isinstance(payload, dict) and wrapper_key in payload:
            payload = payload[wrapper_key]
            break

    network.load_state_dict(payload)
    network.to(device).eval()
    return network
def predict_image(
    model: torch.nn.Module,
    image: Image.Image,
    device: str | torch.device = "cpu",
) -> tuple[float, str]:
    """Classify a single image as real or AI-generated.

    Args:
        model: Network to run; called on a batch containing one image.
        image: Input image; converted to RGB before preprocessing.
        device: Device the input tensor is moved to. It should match the
            device the model lives on.

    Returns:
        ``(prob, label)`` where ``prob`` is the sigmoid of the model output
        and ``label`` is ``"AI-generated"`` when ``prob >= 0.5``, otherwise
        ``"Real"``.
    """
    rgb = image.convert("RGB")
    # unsqueeze(0) adds the leading batch dimension for the single image.
    batch = _EVAL_TRANSFORM(rgb).unsqueeze(0).to(device)

    # Inference only: disable autograd so no graph is built or retained for
    # the forward pass.
    with torch.no_grad():
        # .item() below requires the squeezed output to hold exactly one
        # value, i.e. the model must emit a single logit for the image.
        logit = model(batch).squeeze()
        prob = torch.sigmoid(logit).item()

    label = "AI-generated" if prob >= 0.5 else "Real"
    return prob, label