"""
HF Inference Endpoints handler for ChatTS-14B.
Expected request JSON:
{
"inputs": {
"prompt": "Describe the trend of this series.",
"timeseries": [[0.1, 0.2, 0.3, ...]], # list of float lists, one per <ts><ts/>
"max_new_tokens": 300
}
}
The prompt MUST contain one `<ts><ts/>` placeholder per series in `timeseries`.
Response:
{"generated_text": "..."}
"""
from __future__ import annotations

from typing import Any

import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer


class EndpointHandler:
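    """Pairs a natural-language prompt with raw time series for ChatTS-14B generation."""
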
    def __init__(self, path: str = "") -> None:
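        # Prefer fp16 on GPU; fall back to fp32 on CPU, where half precision
        # is unsupported or very slow for most ops.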
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        dtype = torch.float16 if self.device == "cuda" else torch.float32
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            trust_remote_code=True,
            torch_dtype=dtype,
            device_map=0 if self.device == "cuda" else None,
        )
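        # ChatTS ships a custom processor (hence trust_remote_code) that encodes
        # the raw series alongside the tokenized text.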
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        self.processor = AutoProcessor.from_pretrained(
            path, trust_remote_code=True, tokenizer=self.tokenizer
        )
        self.model.eval()

    def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
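        # Plain-string payloads (the default HF "inputs" shape) cannot carry the
        # series, so reject them with a hint at the expected structure.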
        inputs = data.get("inputs", {})
        if isinstance(inputs, str):
            return {
                "error": "ChatTS requires structured inputs. "
                "Use {'inputs': {'prompt': str, 'timeseries': [[...]], 'max_new_tokens': int}}"
            }
        prompt: str = inputs["prompt"]
        ts_lists = inputs["timeseries"]
        max_new_tokens: int = int(inputs.get("max_new_tokens", 300))
        ts_arrays = [np.asarray(t, dtype=np.float64) for t in ts_lists]
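        # Defensive check enforcing the contract in the module docstring: each
        # series needs a matching `<ts><ts/>` placeholder, otherwise text and
        # data silently misalign.
        n_placeholders = prompt.count("<ts><ts/>")
        if n_placeholders != len(ts_arrays):
            return {
                "error": f"Prompt has {n_placeholders} <ts><ts/> placeholder(s) "
                f"for {len(ts_arrays)} series."
            }

        # Wrap the prompt in the Qwen-style ChatML template ChatTS expects.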
        formatted = (
            "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
            f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
        )
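        # The processor tokenizes the text and splices each encoded series in
        # at its <ts><ts/> placeholder.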
        proc_inputs = self.processor(
            text=[formatted],
            timeseries=ts_arrays,
            padding=True,
            return_tensors="pt",
        )
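        # Move every tensor (input_ids, attention_mask, encoded series) onto
        # the model's device.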
        proc_inputs = {k: v.to(self.device) for k, v in proc_inputs.items()}
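        # Greedy decoding (do_sample=False) keeps endpoint responses deterministic.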
        with torch.no_grad():
            outputs = self.model.generate(
                **proc_inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
            )
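        # Slice off the prompt tokens so only newly generated text is decoded.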
        generated = self.tokenizer.batch_decode(
            outputs[:, proc_inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
        )
        return {"generated_text": generated[0]}
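

# Local smoke test: a minimal sketch, not part of the endpoint contract.
# Assumes a local model copy at "./ChatTS-14B"; adjust the path as needed.
if __name__ == "__main__":
    handler = EndpointHandler("./ChatTS-14B")
    payload = {
        "inputs": {
            "prompt": "This is a time series: <ts><ts/>. Describe its overall trend.",
            "timeseries": [[float(i) for i in range(64)]],
            "max_new_tokens": 64,
        }
    }
    print(handler(payload))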