# stroke-classification/src/model_utils.py
# Committed by melisklc0 in b02f059:
#   "feat: Add initial project structure with Docker support and stroke classification model"
# Size: 2.03 kB
import os
from typing import Any
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from PIL import Image
# Side length (pixels) that preprocess_image resizes every input image to.
IMAGE_SIZE = 299

# Output labels; the i-th model score / probability belongs to CLASS_NAMES[i].
CLASS_NAMES = (
    "No-Stroke",
    "Stroke",
)

# Hugging Face Hub model repository that hosts the ONNX export.
# Can be overridden with the STROKE_MODEL_REPO environment variable.
REPO_ID = os.getenv(
    "STROKE_MODEL_REPO",
    "melisklc0/efficientnet-b0-stroke-distilled",
)

# File name of the ONNX graph inside REPO_ID.
ONNX_FILENAME = "model.onnx"
def _softmax(x: np.ndarray) -> np.ndarray:
x = x.astype(np.float64)
x = x - np.max(x, axis=-1, keepdims=True)
e = np.exp(x)
return (e / e.sum(axis=-1, keepdims=True)).astype(np.float32)
def preprocess_image(img: Image.Image, image_size: int = IMAGE_SIZE) -> np.ndarray:
    """Convert a PIL image into a normalized NCHW float32 batch of one.

    Pipeline: force RGB, bilinear-resize to ``image_size`` x ``image_size``,
    scale 8-bit pixel values into [0, 1], normalize each channel with the
    ImageNet mean/std, move channels first, then add a leading batch axis.

    Returns:
        float32 array of shape (1, 3, image_size, image_size).
    """
    channel_mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    channel_std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    resized = img.convert("RGB").resize(
        (image_size, image_size), Image.Resampling.BILINEAR
    )
    hwc = np.asarray(resized, dtype=np.float32) / 255.0
    chw = hwc.transpose(2, 0, 1)  # (H, W, C) -> (C, H, W)
    # Broadcast the per-channel statistics over the spatial axes.
    normalized = (chw - channel_mean[:, None, None]) / channel_std[:, None, None]
    return normalized[np.newaxis, ...]
def load_stroke_model():
    """Download the ONNX model from the Hub and build an inference session.

    Fetches ``ONNX_FILENAME`` from the model repository ``REPO_ID`` with
    ``hf_hub_download`` and opens it with onnxruntime. The CUDA execution
    provider is requested first when this onnxruntime build offers it; the
    CPU provider is always listed as the fallback.

    Returns:
        ``(session, preprocess)``: the ``ort.InferenceSession`` and the
        ``preprocess_image`` function that prepares its input.
    """
    onnx_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=ONNX_FILENAME,
        repo_type="model",
    )
    # Decide on the provider by name instead of ort.get_device() == "GPU":
    # get_device() is only a coarse device string and does not confirm the
    # CUDA provider itself is present in this build.
    providers: list[str] = ["CPUExecutionProvider"]
    if "CUDAExecutionProvider" in ort.get_available_providers():
        providers.insert(0, "CUDAExecutionProvider")
    session = ort.InferenceSession(onnx_path, providers=providers)
    return session, preprocess_image
def predict(session: ort.InferenceSession, preprocess: Any, img: Image.Image):
    """Classify a single image with the stroke model.

    Args:
        session: inference session, e.g. as returned by ``load_stroke_model``.
        preprocess: callable mapping a PIL image to the model's input array
            (e.g. ``preprocess_image``).
        img: image to classify.

    Returns:
        ``(prediction, confidence, results)``:
        - ``prediction`` is the label with the highest probability.
        - ``confidence`` is that probability as a float.
        - ``results`` maps every label in ``CLASS_NAMES`` to its probability.

    Raises:
        ValueError: if the first batch item of the model's first output does
            not contain exactly one score per entry of ``CLASS_NAMES``.
    """
    x = preprocess(img)
    input_name = session.get_inputs()[0].name
    # Only the first output is consumed, so ask the session for just that one.
    output_name = session.get_outputs()[0].name
    logits = session.run([output_name], {input_name: x})[0]
    # Scores for the first item of the batch (the input holds one image).
    probs = _softmax(logits[0])
    # Fail loudly rather than silently dropping scores or indexing past the
    # label tuple when the model and CLASS_NAMES disagree.
    if probs.shape != (len(CLASS_NAMES),):
        raise ValueError(
            f"Model output has shape {probs.shape}, expected "
            f"({len(CLASS_NAMES)},) to match CLASS_NAMES={CLASS_NAMES}"
        )
    results = {name: float(p) for name, p in zip(CLASS_NAMES, probs)}
    pred_idx = int(np.argmax(probs))
    prediction = CLASS_NAMES[pred_idx]
    confidence = float(probs[pred_idx])
    return prediction, confidence, results