| """Inference script for Multi-XY Regressor (TorchScript). |
| Usage: |
| from inference import predict, load_model |
| model, stats, config = load_model() |
| predictions = predict(model, your_input_features, stats) |
| """ |
| import json |
| import numpy as np |
| import torch |
| from huggingface_hub import hf_hub_download |
|
|
def load_model(repo_id="sagkho/multi-xy-regressor"):
    """Download and load the TorchScript regressor and its metadata from the Hub.

    Args:
        repo_id: Hugging Face Hub repository that contains
            ``dataset_config.json``, ``model.pt`` and ``target_stats.npz``.

    Returns:
        Tuple ``(model, stats, config)``:
          * model: TorchScript module loaded onto CPU and put in eval mode.
          * stats: dict mapping every array name stored in
            ``target_stats.npz`` to its ``np.ndarray`` (``predict`` reads the
            ``"mean"`` and ``"std"`` entries).
          * config: parsed contents of ``dataset_config.json``.
    """
    config_path = hf_hub_download(repo_id=repo_id, filename="dataset_config.json")
    model_path = hf_hub_download(repo_id=repo_id, filename="model.pt")
    stats_path = hf_hub_download(repo_id=repo_id, filename="target_stats.npz")

    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    model = torch.jit.load(model_path, map_location="cpu")
    model.eval()

    # np.load on an .npz returns a lazily-reading NpzFile that keeps the
    # archive open and re-reads an array from disk on every key access.
    # Materialize all arrays once and close the archive immediately.
    with np.load(stats_path) as npz:
        stats = {name: npz[name] for name in npz.files}

    return model, stats, config
|
|
def predict(model, features, stats=None):
    """Run the regressor on a batch of feature vectors.

    Args:
        model: loaded TorchScript model. It is called as ``model(x)`` with a
            float32 tensor ``x`` of shape [batch_size, n_features] and must
            return a CPU tensor.
        features: array-like of input features (np.ndarray, list, tuple, ...)
            shaped [batch_size, n_features]. A single 1-D feature vector is
            treated as a batch of one.
        stats: optional dict-like mapping with 'mean' and 'std' entries
            (e.g. the stats returned by ``load_model``). When given, outputs
            are denormalized as ``pred * std + mean``.

    Returns:
        np.ndarray of shape [batch_size, n_outputs]: denormalized predictions
        when ``stats`` is provided, otherwise the model's raw (normalized)
        outputs.
    """
    # np.asarray accepts any array-like (list, tuple, ndarray); atleast_2d
    # promotes a single feature vector to a batch of one, (n,) -> (1, n).
    features = np.atleast_2d(np.asarray(features, dtype=np.float32))

    x = torch.tensor(features, dtype=torch.float32)
    with torch.no_grad():
        pred_norm = model(x).numpy()

    if stats is None:
        return pred_norm
    return pred_norm * stats["std"] + stats["mean"]
|
|