"""HF Inference Endpoints handler for ChatTS-14B.

Expected request JSON::

    {
        "inputs": {
            "prompt": "Describe the trend of this series: <ts><ts/>",
            "timeseries": [[0.1, 0.2, 0.3]],
            "max_new_tokens": 300
        }
    }

Fields of ``inputs``:

- ``prompt`` (str, required): the user message. It MUST contain exactly one
  ``<ts><ts/>`` placeholder per series in ``timeseries``.
- ``timeseries`` (list of float lists, required): one non-empty list of numbers
  per series.
- ``max_new_tokens`` (int, optional, default 300): generation budget. It may
  also be supplied as ``{"parameters": {"max_new_tokens": ...}}`` at the top
  level of the request; a value inside ``inputs`` takes precedence.

Response: ``{"generated_text": "..."}`` on success, or ``{"error": "..."}``
when the request is malformed.
"""

from __future__ import annotations

from typing import Any

import numpy as np
import torch


class EndpointHandler:
    """Load ChatTS once at startup and answer greedy-decoding requests."""

    # Placeholder marking where each time series is spliced into the prompt.
    TS_PLACEHOLDER = "<ts><ts/>"
    # Generation budget used when the request does not specify one.
    DEFAULT_MAX_NEW_TOKENS = 300
    # Client-facing description of the expected payload.
    USAGE = (
        "ChatTS requires structured inputs. "
        "Use {'inputs': {'prompt': str, 'timeseries': [[...]], 'max_new_tokens': int}}"
    )

    def __init__(self, path: str = "") -> None:
        """Load model, tokenizer and processor from ``path``.

        Args:
            path: Local model directory provided by the endpoint runtime.
        """
        # Imported here rather than at module level so the request-parsing
        # logic can be imported (e.g. for unit tests) without transformers.
        from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Half precision on GPU, full precision on CPU.
        dtype = torch.float16 if self.device == "cuda" else torch.float32
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            trust_remote_code=True,
            torch_dtype=dtype,
            device_map=0 if self.device == "cuda" else None,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        # The processor receives the tokenizer so text and time-series
        # encoding share one vocabulary.
        self.processor = AutoProcessor.from_pretrained(
            path, trust_remote_code=True, tokenizer=self.tokenizer
        )
        self.model.eval()

    @classmethod
    def _parse_request(cls, data: Any) -> tuple[str, list[np.ndarray], int]:
        """Validate a request payload and extract the generation arguments.

        Args:
            data: The decoded request body.

        Returns:
            ``(prompt, ts_arrays, max_new_tokens)`` where ``ts_arrays`` holds
            one 1-D float64 array per series.

        Raises:
            ValueError: With a client-facing message if the payload is malformed.
        """
        if not isinstance(data, dict):
            raise ValueError(cls.USAGE)
        inputs = data.get("inputs", {})
        if not isinstance(inputs, dict):
            raise ValueError(cls.USAGE)

        prompt = inputs.get("prompt")
        if not isinstance(prompt, str):
            raise ValueError("'prompt' must be a string. " + cls.USAGE)

        ts_lists = inputs.get("timeseries")
        if not isinstance(ts_lists, (list, tuple)):
            raise ValueError("'timeseries' must be a list of float lists. " + cls.USAGE)

        ts_arrays: list[np.ndarray] = []
        for idx, series in enumerate(ts_lists):
            try:
                arr = np.asarray(series, dtype=np.float64)
            except (TypeError, ValueError):
                raise ValueError(f"timeseries[{idx}] must contain only numbers.") from None
            if arr.ndim != 1 or arr.size == 0:
                raise ValueError(
                    f"timeseries[{idx}] must be a non-empty flat list of numbers."
                )
            ts_arrays.append(arr)

        n_placeholders = prompt.count(cls.TS_PLACEHOLDER)
        if n_placeholders != len(ts_arrays):
            raise ValueError(
                f"The prompt contains {n_placeholders} '{cls.TS_PLACEHOLDER}' "
                f"placeholder(s) but {len(ts_arrays)} time series were provided; "
                "these counts must match."
            )

        # Accept the HF-style top-level "parameters" dict as a fallback source.
        params = data.get("parameters")
        if not isinstance(params, dict):
            params = {}
        raw_max = inputs.get(
            "max_new_tokens", params.get("max_new_tokens", cls.DEFAULT_MAX_NEW_TOKENS)
        )
        try:
            max_new_tokens = int(raw_max)
        except (TypeError, ValueError, OverflowError):
            raise ValueError("'max_new_tokens' must be an integer.") from None
        if max_new_tokens < 1:
            raise ValueError("'max_new_tokens' must be a positive integer.")

        return prompt, ts_arrays, max_new_tokens

    def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
        """Run greedy generation for one request.

        Args:
            data: Request body in the format described in the module docstring.

        Returns:
            ``{"generated_text": str}`` on success, ``{"error": str}`` if the
            request fails validation.
        """
        try:
            prompt, ts_arrays, max_new_tokens = self._parse_request(data)
        except ValueError as err:
            return {"error": str(err)}

        formatted = (
            "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
            f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
        )
        proc_inputs = self.processor(
            text=[formatted],
            timeseries=ts_arrays,
            padding=True,
            return_tensors="pt",
        )
        # Only tensors can be moved; leave any other processor outputs as-is.
        proc_inputs = {
            key: value.to(self.device) if isinstance(value, torch.Tensor) else value
            for key, value in proc_inputs.items()
        }
        prompt_len = proc_inputs["input_ids"].shape[1]

        with torch.no_grad():
            outputs = self.model.generate(
                **proc_inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
            )

        # Strip the echoed prompt tokens so only the completion is decoded.
        generated = self.tokenizer.batch_decode(
            outputs[:, prompt_len:],
            skip_special_tokens=True,
        )
        return {"generated_text": generated[0]}