| from __future__ import annotations | |
| import json | |
| import sys | |
| from pathlib import Path | |
| import numpy as np | |
| import onnxruntime as ort | |
| from PIL import Image | |
# Directory that contains this file. predict() resolves "metadata.json" and
# the ONNX model file relative to it, so both must sit next to this module.
MODEL_DIR = Path(__file__).resolve().parent
| def _resize_center_crop(image: Image.Image, size: int) -> Image.Image: | |
| resize_short_edge = int(size * 1.14) | |
| width, height = image.size | |
| scale = resize_short_edge / min(width, height) | |
| resized = image.resize((round(width * scale), round(height * scale))) | |
| left = (resized.width - size) // 2 | |
| top = (resized.height - size) // 2 | |
| return resized.crop((left, top, left + size, top + size)) | |
def _preprocess(image: Image.Image, metadata: dict) -> np.ndarray:
    """Turn *image* into a normalised float32 batch of one.

    Steps: convert to RGB, resize/centre-crop to ``metadata["input_size"]``,
    scale pixel values by 1/255, move the last axis to the front, apply
    per-channel ``(x - mean) / std`` using ``metadata["mean"]`` and
    ``metadata["std"]``, and prepend a leading batch axis.

    Args:
        image: Input image; it is converted with ``convert("RGB")`` first.
        metadata: Mapping providing ``"input_size"``, ``"mean"`` and
            ``"std"``.

    Returns:
        The normalised array with one extra leading axis.
    """
    side = int(metadata["input_size"])
    cropped = _resize_center_crop(image.convert("RGB"), side)

    # Scale to [0, 1] and move the trailing axis (the colour channels of the
    # RGB image) to the front.
    pixels = np.transpose(np.asarray(cropped, dtype=np.float32) / 255.0, (2, 0, 1))

    # Give the per-channel statistics two trailing singleton axes so they
    # broadcast across the remaining two axes of `pixels`.
    channel_mean = np.asarray(metadata["mean"], dtype=np.float32)[:, np.newaxis, np.newaxis]
    channel_std = np.asarray(metadata["std"], dtype=np.float32)[:, np.newaxis, np.newaxis]

    normalized = (pixels - channel_mean) / channel_std
    return normalized[np.newaxis, ...]
def predict(image_path: str | Path) -> dict:
    """Classify the image at *image_path* with the bundled ONNX model.

    Loads ``metadata.json`` and ``shit_detector.onnx`` from ``MODEL_DIR`` on
    every call, preprocesses the image, runs the model on the CPU execution
    provider, and applies a softmax to the (optionally scaled) logits.

    Args:
        image_path: Path of the image file to classify.

    Returns:
        A dict with the keys ``"status"`` and ``"label"`` (both ``"shit"`` or
        ``"not_shit"``), ``"is_shit"`` (bool), ``"confidence"`` (largest
        softmax probability), ``"shit_probability"`` (softmax probability at
        index 0) and ``"threshold"`` (the decision threshold that was used).

    Raises:
        Whatever the underlying file reads, ``json.loads``, ONNX Runtime or
        Pillow raise for missing or invalid inputs; nothing is caught here.
    """
    metadata = json.loads((MODEL_DIR / "metadata.json").read_text(encoding="utf-8"))
    session = ort.InferenceSession(
        str(MODEL_DIR / "shit_detector.onnx"),
        providers=["CPUExecutionProvider"],
    )

    # Open the image in a context manager so the underlying file handle is
    # always released, even if preprocessing raises. The batch is fully
    # materialised as a NumPy array before the file is closed.
    with Image.open(image_path) as image:
        batch = _preprocess(image, metadata)

    input_name = session.get_inputs()[0].name
    logits = session.run(None, {input_name: batch})[0]
    logits = logits * float(metadata.get("logit_scale", 1.0))

    # Numerically stable softmax: subtract the row maximum before
    # exponentiating, and compute the exponentials only once.
    exp_shifted = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs = exp_shifted / exp_shifted.sum(axis=-1, keepdims=True)

    # NOTE(review): assumes output index 0 is the "shit" class - confirm
    # against the model's training label order.
    shit_probability = float(probs[0, 0])
    # NOTE(review): confidence is the largest class probability, which is not
    # necessarily the probability of the returned label when the threshold
    # differs from 0.5 - confirm this is intended.
    confidence = float(probs.max(axis=-1)[0])
    threshold = float(metadata.get("shit_threshold", 0.5))
    label = "shit" if shit_probability >= threshold else "not_shit"

    return {
        "status": label,
        "label": label,
        "is_shit": label == "shit",
        "confidence": confidence,
        "shit_probability": shit_probability,
        "threshold": threshold,
    }
if __name__ == "__main__":
    # Command-line entry point: exactly one argument, the image path.
    cli_args = sys.argv[1:]
    if len(cli_args) != 1:
        raise SystemExit("Usage: python inference.py path/to/image")
    result = predict(cli_args[0])
    print(json.dumps(result, indent=2))