# kronos-small-custom / inference.py
# Author: Faizack — "Initial Kronos-small custom deployment" (commit 360970f)
import os
from typing import Any, Dict
import pandas as pd
import torch
from kronos import Kronos, KronosTokenizer, KronosPredictor # type: ignore
# Prefer GPU when one is visible; the model weights are moved to this device below.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def _load_components(model_dir: str = "."):
    """
    Build the (tokenizer, model, predictor) triple from *model_dir*.

    Invoked a single time at module import on HF Inference Endpoints so the
    heavy load happens once per worker. The predictor's context window can be
    tuned through the KRONOS_MAX_CONTEXT environment variable (default 512).
    """
    tok = KronosTokenizer.from_pretrained(model_dir)
    net = Kronos.from_pretrained(model_dir).to(DEVICE)
    context_limit = int(os.getenv("KRONOS_MAX_CONTEXT", "512"))
    runner = KronosPredictor(
        model=net,
        tokenizer=tok,
        device=DEVICE,
        max_context=context_limit,
    )
    return tok, net, runner
# Load once at import time so HF Inference Endpoints reuse the warm components
# across requests instead of paying the model-load cost per call.
TOKENIZER, MODEL, PREDICTOR = _load_components(".")
def predict(request: Dict[str, Any]) -> Dict[str, Any]:
    """
    Entry point for Hugging Face Inference Endpoints.

    Expected input format:
        {
            "inputs": {
                "df": [
                    {"open": ..., "high": ..., "low": ..., "close": ...},
                    ...
                ],
                "x_timestamp": [...],  # list of ISO8601 strings or timestamps
                "y_timestamp": [...],  # list of ISO8601 strings or timestamps
                "pred_len": 120,
                "T": 1.0,              # optional sampling temperature
                "top_p": 0.9,          # optional nucleus-sampling threshold
                "sample_count": 1      # optional
            }
        }

    Returns:
        {"predictions": [...]}, one dict per predicted row (JSON-serializable).

    Raises:
        ValueError: if any required input key is missing from the payload.
    """
    # Accept both the wrapped ({"inputs": {...}}) and flat payload shapes.
    inputs = request.get("inputs", request)

    # Fail fast with a clear message instead of surfacing an opaque KeyError
    # from deep inside the handler.
    missing = [
        key for key in ("df", "x_timestamp", "y_timestamp", "pred_len")
        if key not in inputs
    ]
    if missing:
        raise ValueError(f"Missing required input key(s): {', '.join(missing)}")

    df = pd.DataFrame(inputs["df"])
    x_timestamp = pd.to_datetime(inputs["x_timestamp"])
    y_timestamp = pd.to_datetime(inputs["y_timestamp"])
    pred_len = int(inputs["pred_len"])
    T = float(inputs.get("T", 1.0))
    top_p = float(inputs.get("top_p", 0.9))
    sample_count = int(inputs.get("sample_count", 1))

    # Inference only: disable autograd to avoid building a gradient graph
    # (saves memory and compute on every request).
    with torch.no_grad():
        result_df = PREDICTOR.predict(
            df=df,
            x_timestamp=x_timestamp,
            y_timestamp=y_timestamp,
            pred_len=pred_len,
            T=T,
            top_p=top_p,
            sample_count=sample_count,
        )

    # Return a plain dict for JSON serialization
    return {
        "predictions": result_df.to_dict(orient="records"),
    }