from __future__ import annotations

import json
import sys
from pathlib import Path

import numpy as np
import onnxruntime as ort
from PIL import Image

# Directory that contains this script; model artifacts are looked up here.
MODEL_DIR = Path(__file__).resolve().parent


def _resize_center_crop(image: Image.Image, size: int) -> Image.Image:
    """Scale *image* by its short edge, then cut a centered square.

    The short edge is resized to ``int(size * 1.14)`` pixels with the aspect
    ratio preserved, and a ``size`` x ``size`` window is cropped around the
    center of the resized image.

    Args:
        image: Source image.
        size: Side length, in pixels, of the returned square crop.

    Returns:
        The cropped image.
    """
    target_short = int(size * 1.14)
    src_w, src_h = image.size
    factor = target_short / min(src_w, src_h)
    new_dims = (round(src_w * factor), round(src_h * factor))
    scaled = image.resize(new_dims)

    # Floor division keeps the crop origin on whole pixels.
    x0 = (scaled.width - size) // 2
    y0 = (scaled.height - size) // 2
    crop_box = (x0, y0, x0 + size, y0 + size)
    return scaled.crop(crop_box)


def _preprocess(image: Image.Image, metadata: dict) -> np.ndarray:
    """Turn *image* into a normalized float32 batch of one.

    Args:
        image: Input image; it is converted to RGB before any resizing.
        metadata: Mapping that provides ``"input_size"`` (crop side length)
            and per-channel ``"mean"`` and ``"std"`` sequences.

    Returns:
        The channel-first, mean/std-normalized pixel array with a leading
        batch axis of length 1.
    """
    side = int(metadata["input_size"])
    cropped = _resize_center_crop(image.convert("RGB"), side)

    # Scale 8-bit pixel values into [0, 1], then move channels to the front.
    pixels = np.asarray(cropped, dtype=np.float32) / 255.0
    channels_first = np.transpose(pixels, (2, 0, 1))

    # Trailing singleton axes let the per-channel stats broadcast over H and W.
    channel_mean = np.asarray(metadata["mean"], dtype=np.float32)[
        :, np.newaxis, np.newaxis
    ]
    channel_std = np.asarray(metadata["std"], dtype=np.float32)[
        :, np.newaxis, np.newaxis
    ]

    normalized = (channels_first - channel_mean) / channel_std
    return normalized[np.newaxis, ...]
def predict(image_path: str | Path) -> dict:
    """Classify the image at *image_path* with the bundled ONNX model.

    Reads ``metadata.json`` and ``shit_detector.onnx`` from ``MODEL_DIR``,
    preprocesses the image, runs the model on the CPU execution provider and
    applies a softmax to the (optionally scaled) logits.

    Args:
        image_path: Path of the image file to classify.

    Returns:
        A dict with the keys ``"status"`` and ``"label"`` (both the predicted
        label, ``"shit"`` or ``"not_shit"``), ``"is_shit"`` (bool),
        ``"confidence"`` (highest class probability), ``"shit_probability"``
        (probability of class index 0) and ``"threshold"`` (the decision
        threshold that was applied).
    """
    metadata = json.loads(
        (MODEL_DIR / "metadata.json").read_text(encoding="utf-8")
    )
    session = ort.InferenceSession(
        str(MODEL_DIR / "shit_detector.onnx"),
        providers=["CPUExecutionProvider"],
    )

    # Image.open keeps the underlying file open for lazy loading; the context
    # manager releases the handle as soon as the pixels have been copied into
    # the NumPy batch, instead of leaking it until garbage collection.
    with Image.open(image_path) as image:
        batch = _preprocess(image, metadata)

    # NOTE(review): assumes the first model output holds (batch, classes)
    # logits -- confirm against the exported model.
    logits = session.run(None, {session.get_inputs()[0].name: batch})[0]
    logits = logits * float(metadata.get("logit_scale", 1.0))

    # Softmax with the row maximum subtracted first so exp() cannot overflow.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp_shifted = np.exp(shifted)
    probs = exp_shifted / exp_shifted.sum(axis=-1, keepdims=True)

    # Class index 0 is read as the "shit" class.
    shit_probability = float(probs[0, 0])
    confidence = float(probs.max(axis=-1)[0])
    threshold = float(metadata.get("shit_threshold", 0.5))
    label = "shit" if shit_probability >= threshold else "not_shit"

    return {
        "status": label,
        "label": label,
        "is_shit": label == "shit",
        "confidence": confidence,
        "shit_probability": shit_probability,
        "threshold": threshold,
    }


if __name__ == "__main__":
    # Exactly one positional argument is expected: the image path.
    if len(sys.argv) != 2:
        raise SystemExit("Usage: python inference.py path/to/image")
    print(json.dumps(predict(sys.argv[1]), indent=2))