# shit-detector / inference.py
# cstria0106's picture
# Upload folder using huggingface_hub
# c2d0005 verified
from __future__ import annotations
import json
import sys
from pathlib import Path
import numpy as np
import onnxruntime as ort
from PIL import Image
# Absolute directory containing this script. predict() resolves
# "metadata.json" and "shit_detector.onnx" relative to it, so the script
# works regardless of the caller's current working directory.
MODEL_DIR = Path(__file__).resolve().parent
def _resize_center_crop(image: Image.Image, size: int) -> Image.Image:
    """Resize *image* by its shorter edge, then cut out the central square.

    The image is scaled (same factor on both axes, rounded to whole pixels)
    so that its shorter edge becomes ``int(size * 1.14)`` pixels. A
    ``size`` x ``size`` window centred on the scaled image is then returned.

    Args:
        image: Source image.
        size: Edge length, in pixels, of the square crop to return.

    Returns:
        The centre crop of the rescaled image.
    """
    # Resize slightly larger than the crop so the crop keeps a margin.
    short_edge_target = int(size * 1.14)

    src_w, src_h = image.size
    factor = short_edge_target / min(src_w, src_h)
    scaled = image.resize((round(src_w * factor), round(src_h * factor)))

    # Top-left corner of the centred crop window.
    x0 = (scaled.width - size) // 2
    y0 = (scaled.height - size) // 2
    crop_box = (x0, y0, x0 + size, y0 + size)
    return scaled.crop(crop_box)
def _preprocess(image: Image.Image, metadata: dict) -> np.ndarray:
    """Turn an image into a normalised, batched, channel-first float32 array.

    Args:
        image: Input image; it is converted to RGB before any other step.
        metadata: Mapping that must provide ``"input_size"`` (crop edge
            length) plus ``"mean"`` and ``"std"`` (1-D per-channel sequences,
            presumably one value per RGB channel).

    Returns:
        Array of shape ``(1, 3, input_size, input_size)`` holding
        ``(pixel / 255 - mean) / std``.
    """
    side = int(metadata["input_size"])
    cropped = _resize_center_crop(image.convert("RGB"), side)

    # HWC uint8 pixels -> float32 in [0, 1] -> CHW layout.
    pixels = np.asarray(cropped, dtype=np.float32) / 255.0
    chw = pixels.transpose(2, 0, 1)

    # Reshape the per-channel statistics to (C, 1, 1) so they broadcast
    # across the spatial axes.
    channel_mean = np.asarray(metadata["mean"], dtype=np.float32)[:, None, None]
    channel_std = np.asarray(metadata["std"], dtype=np.float32)[:, None, None]

    normalised = (chw - channel_mean) / channel_std
    # Prepend the batch axis.
    return normalised[None, ...]
def predict(image_path: str | Path) -> dict:
    """Classify a single image with the ONNX model stored next to this script.

    ``metadata.json`` and ``shit_detector.onnx`` are read from ``MODEL_DIR``
    on every call. The image is preprocessed, run through the model on the
    CPU execution provider, and the (optionally scaled) logits are turned
    into probabilities with a softmax.

    Args:
        image_path: Path of the image file to classify.

    Returns:
        A JSON-serialisable dict with these keys:

        * ``"status"`` / ``"label"``: ``"shit"`` or ``"not_shit"``.
        * ``"is_shit"``: ``True`` when the label is ``"shit"``.
        * ``"confidence"``: the largest softmax probability. It is
          independent of the threshold, so it may belong to the class that
          was *not* chosen when the threshold is not 0.5.
        * ``"shit_probability"``: softmax probability at output index 0,
          which this function treats as the "shit" class.
        * ``"threshold"``: decision threshold taken from the metadata key
          ``"shit_threshold"`` (0.5 when absent).

    Raises:
        Nothing is caught here: errors from reading the metadata, loading
        the model, or opening/decoding the image propagate to the caller.
    """
    metadata = json.loads((MODEL_DIR / "metadata.json").read_text(encoding="utf-8"))
    session = ort.InferenceSession(
        str(MODEL_DIR / "shit_detector.onnx"),
        providers=["CPUExecutionProvider"],
    )

    # Open the image in a context manager so the underlying file handle is
    # always released, including for multi-frame formats and when
    # preprocessing raises. The resulting array does not reference the file.
    with Image.open(image_path) as image:
        batch = _preprocess(image, metadata)

    input_name = session.get_inputs()[0].name
    logits = session.run(None, {input_name: batch})[0]
    logits = logits * float(metadata.get("logit_scale", 1.0))

    # Numerically stable softmax: subtract the row maximum before
    # exponentiating, and compute the exponentials only once.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exps = np.exp(shifted)
    probs = exps / exps.sum(axis=-1, keepdims=True)

    shit_probability = float(probs[0, 0])
    confidence = float(probs.max(axis=-1)[0])
    threshold = float(metadata.get("shit_threshold", 0.5))
    label = "shit" if shit_probability >= threshold else "not_shit"

    return {
        "status": label,
        "label": label,
        "is_shit": label == "shit",
        "confidence": confidence,
        "shit_probability": shit_probability,
        "threshold": threshold,
    }
if __name__ == "__main__":
    # Command-line entry point: exactly one argument, the image path.
    cli_args = sys.argv[1:]
    if len(cli_args) != 1:
        raise SystemExit("Usage: python inference.py path/to/image")

    report = predict(cli_args[0])
    print(json.dumps(report, indent=2))