File size: 2,339 Bytes
c2d0005
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from __future__ import annotations

import json
import sys
from functools import lru_cache
from pathlib import Path

import numpy as np
import onnxruntime as ort
from PIL import Image


# Absolute directory holding this script; predict() looks up metadata.json and
# the .onnx model relative to it.
MODEL_DIR = Path(__file__).resolve().parents[0]


def _resize_center_crop(image: Image.Image, size: int) -> Image.Image:
    resize_short_edge = int(size * 1.14)
    width, height = image.size
    scale = resize_short_edge / min(width, height)
    resized = image.resize((round(width * scale), round(height * scale)))
    left = (resized.width - size) // 2
    top = (resized.height - size) // 2
    return resized.crop((left, top, left + size, top + size))


def _preprocess(image: Image.Image, metadata: dict) -> np.ndarray:
    """Turn a PIL image into a normalized NCHW float batch of one.

    Steps: convert to RGB, resize + center-crop to ``metadata["input_size"]``,
    scale pixel values into [0, 1], reorder HWC -> CHW, then normalize each
    channel with ``metadata["mean"]`` and ``metadata["std"]``.

    Args:
        image: Source image (any mode Pillow can convert to RGB).
        metadata: Mapping providing ``input_size``, ``mean`` and ``std``.
            ``mean``/``std`` are presumably one value per RGB channel
            (TODO confirm against metadata.json).

    Returns:
        float32 array of shape ``(1, 3, input_size, input_size)``.
    """
    side = int(metadata["input_size"])
    rgb = image.convert("RGB")
    cropped = _resize_center_crop(rgb, side)
    hwc = np.asarray(cropped, dtype=np.float32) / 255.0
    chw = hwc.transpose(2, 0, 1)
    # Reshape per-channel stats to (C, 1, 1) so they broadcast over H and W.
    channel_mean = np.asarray(metadata["mean"], dtype=np.float32)[:, np.newaxis, np.newaxis]
    channel_std = np.asarray(metadata["std"], dtype=np.float32)[:, np.newaxis, np.newaxis]
    return np.expand_dims((chw - channel_mean) / channel_std, axis=0)


@lru_cache(maxsize=1)
def _load_model() -> tuple[dict, ort.InferenceSession]:
    """Parse ``metadata.json`` and build the ONNX session once per process.

    Cached so repeated ``predict`` calls reuse the parsed metadata and the
    inference session instead of rebuilding them every time. Edits to the
    files on disk after the first successful call are not picked up until
    the process restarts. If loading raises, nothing is cached and the next
    call retries.
    """
    metadata = json.loads((MODEL_DIR / "metadata.json").read_text(encoding="utf-8"))
    session = ort.InferenceSession(
        str(MODEL_DIR / "shit_detector.onnx"),
        providers=["CPUExecutionProvider"],
    )
    return metadata, session


def _softmax(logits: np.ndarray) -> np.ndarray:
    """Return the numerically stable softmax of *logits* over the last axis."""
    # Subtracting the row max avoids overflow in exp without changing the result.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / exps.sum(axis=-1, keepdims=True)


def predict(image_path: str | Path) -> dict:
    """Classify the image at *image_path* with the bundled ONNX model.

    Args:
        image_path: Path to an image file that Pillow can open.

    Returns:
        A dict with keys:
          - ``status`` / ``label``: ``"shit"`` or ``"not_shit"`` (same value).
          - ``is_shit``: ``True`` when the label is ``"shit"``.
          - ``confidence``: probability assigned to the returned label. It is
            ``shit_probability`` for ``"shit"``, otherwise the summed
            probability of all other output classes.
          - ``shit_probability``: softmax probability of output index 0.
          - ``threshold``: cut-off applied to ``shit_probability``
            (``metadata["shit_threshold"]``, default 0.5).
    """
    metadata, session = _load_model()
    # The context manager releases the file handle Image.open keeps; the image
    # is fully decoded inside _preprocess (via convert) before the file closes.
    with Image.open(image_path) as image:
        batch = _preprocess(image, metadata)
    logits = session.run(None, {session.get_inputs()[0].name: batch})[0]
    # Optional scale factor stored with the model, applied before softmax.
    logits = logits * float(metadata.get("logit_scale", 1.0))
    probs = _softmax(logits)
    # NOTE(review): assumes output index 0 is the "shit" class and batch size 1
    # (as produced by _preprocess) — confirm against the training label order.
    shit_probability = float(probs[0, 0])
    threshold = float(metadata.get("shit_threshold", 0.5))
    is_shit = shit_probability >= threshold
    label = "shit" if is_shit else "not_shit"
    # Report the probability of the label actually returned. The previous
    # argmax-based value referred to the opposite class whenever the threshold
    # moved the decision away from the argmax (threshold != 0.5).
    confidence = shit_probability if is_shit else float(probs[0, 1:].sum())
    return {
        "status": label,
        "label": label,
        "is_shit": is_shit,
        "confidence": confidence,
        "shit_probability": shit_probability,
        "threshold": threshold,
    }


if __name__ == "__main__":
    # CLI entry point: exactly one positional argument, the image to classify;
    # the prediction dict is printed as indented JSON.
    cli_args = sys.argv[1:]
    if len(cli_args) != 1:
        raise SystemExit("Usage: python inference.py path/to/image")
    print(json.dumps(predict(cli_args[0]), indent=2))