remisek's picture
Install dependencies
4c0fb51
"""Utilities around models for the WorldCereal package."""
import json
from functools import lru_cache
import onnxruntime as ort
import requests
@lru_cache(maxsize=2)
def load_model_onnx(model_url: str) -> ort.InferenceSession:
    """Load an ONNX model from a URL.

    The two most recently requested sessions are cached per URL, so
    repeated calls with the same ``model_url`` do not download the model
    again.

    Parameters
    ----------
    model_url: str
        URL to the ONNX model.

    Returns
    -------
    ort.InferenceSession
        ONNX model loaded with ONNX runtime.

    Raises
    ------
    requests.HTTPError
        If the server answers with a 4xx/5xx status code.
    """
    # Two minutes timeout to download the model
    response = requests.get(model_url, timeout=120)
    # Fail early on HTTP errors: otherwise the body of an error page would be
    # handed to ONNX runtime as if it were a serialized model.
    response.raise_for_status()
    return ort.InferenceSession(response.content)
def validate_cb_model(model_url: str) -> ort.InferenceSession:
    """Validate a CatBoost model by loading it and checking its metadata.

    Checks that the `class_names` and `class_to_label` fields are present in
    the JSON-encoded `class_params` entry of the custom metadata of the
    model. By default, the CatBoost module should include those fields when
    exporting a model to ONNX.

    Parameters
    ----------
    model_url : str
        URL to the ONNX model.

    Returns
    -------
    ort.InferenceSession
        ONNX model loaded with ONNX runtime.

    Raises
    ------
    ValueError
        If `class_params` is missing from the custom metadata, or if
        `class_names` or `class_to_label` is missing from `class_params`.
    """
    model = load_model_onnx(model_url=model_url)
    metadata = model.get_modelmeta().custom_metadata_map

    if "class_params" not in metadata:
        # Report the key that is actually missing rather than "class names".
        raise ValueError("Could not find class parameters in the model metadata.")

    # `class_params` is stored as a JSON string inside the metadata map.
    class_params = json.loads(metadata["class_params"])

    if "class_names" not in class_params:
        raise ValueError("Could not find class names in the model metadata.")
    if "class_to_label" not in class_params:
        raise ValueError("Could not find class to labels in the model metadata.")

    return model
def load_model_lut(model_url: str) -> dict:
    """Build the class-name to label look-up table of a CatBoost model.

    The model is validated first, then the `class_names` and
    `class_to_label` entries of its `class_params` metadata are paired up
    positionally.

    Parameters
    ----------
    model_url : str
        URL to the ONNX model.

    Returns
    -------
    dict
        Look-up table mapping class names to labels, ordered by label.
    """
    session = validate_cb_model(model_url=model_url)
    custom_metadata = session.get_modelmeta().custom_metadata_map
    params = json.loads(custom_metadata["class_params"])

    # Pair every class name with the label at the same position.
    name_to_label = dict(zip(params["class_names"], params["class_to_label"]))

    # Dicts keep insertion order, so rebuilding from the label-sorted items
    # yields a table ordered by label.
    ordered_items = sorted(name_to_label.items(), key=lambda pair: pair[1])
    return dict(ordered_items)