multi-xy-regressor / inference.py
sagkho's picture
Upload inference.py with huggingface_hub
8ae87c4 verified
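"""Inference script for Multi-XY Regressor (TorchScript).

Usage:
    from inference import predict, load_model
    model, stats, config = load_model()
    # Pass `stats` to get predictions in the original target units;
    # without it, predict() returns the model's normalized outputs.
    predictions = predict(model, your_input_features, stats)
"""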
import json
import numpy as np
import torch
from huggingface_hub import hf_hub_download
def load_model(repo_id="sagkho/multi-xy-regressor", revision=None):
    """Download and load the model, target statistics and dataset config.

    Args:
        repo_id: Hugging Face Hub repository to fetch the artifacts from.
        revision: optional branch name, tag or commit hash forwarded to
            hf_hub_download to pin the artifacts; None keeps the library's
            default revision.

    Returns:
        (model, stats, config):
            model: TorchScript module loaded onto the CPU, in eval mode.
            stats: dict mapping every array name stored in target_stats.npz
                (predict() expects 'mean' and 'std') to a NumPy array.
            config: parsed contents of dataset_config.json.
    """

    def _fetch(filename):
        # Return the local path of `filename` downloaded from the repo.
        return hf_hub_download(repo_id=repo_id, filename=filename, revision=revision)

    config_path = _fetch("dataset_config.json")
    model_path = _fetch("model.pt")
    stats_path = _fetch("target_stats.npz")

    # JSON is UTF-8 by spec; don't depend on the platform default encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    model = torch.jit.load(model_path, map_location="cpu")
    model.eval()

    # np.load on an .npz returns a lazy NpzFile that keeps the archive open
    # and reads arrays from it on each key access. Materialize the arrays
    # into a plain dict and close the file immediately.
    with np.load(stats_path) as npz:
        stats = {name: npz[name] for name in npz.files}

    return model, stats, config
def predict(model, features, stats=None):
    """Run the model on a batch of features and optionally denormalize.

    Args:
        model: loaded TorchScript model (any callable that maps a float32
            tensor to a tensor works, since it is simply called on the input).
        features: array-like of input features (list, tuple, np.ndarray, ...)
            shaped [batch_size, n_features], or [n_features] for a single
            sample.
        stats: optional dict-like (e.g. from np.load) with 'mean' and 'std'
            entries used to undo target normalization.

    Returns:
        np.ndarray of shape [batch_size, n_outputs]. When `stats` is given
        the predictions are denormalized (pred * std + mean); otherwise the
        raw normalized model outputs are returned.
    """
    # np.asarray accepts any array-like, not just lists/ndarrays, and avoids
    # a copy when `features` is already a float32 array.
    features = np.asarray(features, dtype=np.float32)
    # Promote a single sample (or a scalar) to a batch of one.
    features = np.atleast_2d(features)
    x = torch.tensor(features, dtype=torch.float32)
    with torch.no_grad():
        pred_norm = model(x).numpy()
    if stats is None:
        return pred_norm
    return pred_norm * stats["std"] + stats["mean"]