"""Inference script for Multi-XY Regressor (TorchScript). Usage: from inference import predict, load_model model, stats, config = load_model() predictions = predict(model, your_input_features) """ import json import numpy as np import torch from huggingface_hub import hf_hub_download def load_model(repo_id="sagkho/multi-xy-regressor"): config_path = hf_hub_download(repo_id=repo_id, filename="dataset_config.json") model_path = hf_hub_download(repo_id=repo_id, filename="model.pt") stats_path = hf_hub_download(repo_id=repo_id, filename="target_stats.npz") with open(config_path, "r") as f: config = json.load(f) model = torch.jit.load(model_path, map_location="cpu") model.eval() stats = np.load(stats_path) return model, stats, config def predict(model, features, stats=None): """ Args: model: loaded TorchScript model features: np.ndarray or list of input features [batch_size, n_features] stats: dict-like from np.load with 'mean' and 'std' Returns: np.ndarray of shape [batch_size, n_outputs] - denormalized predictions """ if isinstance(features, list): features = np.array(features, dtype=np.float32) if features.ndim == 1: features = features.reshape(1, -1) x = torch.tensor(features, dtype=torch.float32) with torch.no_grad(): pred_norm = model(x).numpy() if stats is not None: pred = pred_norm * stats["std"] + stats["mean"] else: pred = pred_norm return pred