import os
from typing import Any, Dict

import pandas as pd
import torch

from kronos import Kronos, KronosTokenizer, KronosPredictor  # type: ignore

# Prefer GPU when one is visible to torch; otherwise run on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def _load_components(model_dir: str = "."):
    """
    Load tokenizer, model, and predictor from a local directory.

    This is called once at module import time on HF Inference Endpoints,
    so model weights are resident before the first request arrives.

    Args:
        model_dir: Directory containing the pretrained tokenizer/model
            artifacts (defaults to the repository root, ".").

    Returns:
        Tuple of (tokenizer, model, predictor).
    """
    tokenizer = KronosTokenizer.from_pretrained(model_dir)
    # .eval() is idempotent; it guards against dropout/batch-norm staying in
    # training mode in case Kronos.from_pretrained does not set it — TODO
    # confirm against the kronos package.
    model = Kronos.from_pretrained(model_dir).to(DEVICE).eval()
    # Context window is tunable via env var without redeploying code.
    max_context = int(os.getenv("KRONOS_MAX_CONTEXT", "512"))
    predictor = KronosPredictor(
        model=model,
        tokenizer=tokenizer,
        device=DEVICE,
        max_context=max_context,
    )
    return tokenizer, model, predictor


# Loaded eagerly at import time (HF Inference Endpoints convention).
TOKENIZER, MODEL, PREDICTOR = _load_components(".")

# Request fields that must be present; everything else has a default.
_REQUIRED_FIELDS = ("df", "x_timestamp", "y_timestamp", "pred_len")


def predict(request: Dict[str, Any]) -> Dict[str, Any]:
    """
    Entry point for Hugging Face Inference Endpoints.

    Expected input format:
        {
            "inputs": {
                "df": [
                    {"open": ..., "high": ..., "low": ..., "close": ...},
                    ...
                ],
                "x_timestamp": [...],  # list of ISO8601 strings or timestamps
                "y_timestamp": [...],  # list of ISO8601 strings or timestamps
                "pred_len": 120,
                "T": 1.0,              # optional, sampling temperature
                "top_p": 0.9,          # optional, nucleus-sampling threshold
                "sample_count": 1      # optional, number of sampled paths
            }
        }

    Returns:
        {"predictions": [...]} where each element is one predicted row
        as a plain dict (JSON-serializable).

    Raises:
        ValueError: If a required field is missing, "df" is empty, or
            "pred_len" is not a positive integer.
    """
    # Accept either the HF envelope ({"inputs": {...}}) or a bare payload.
    inputs = request.get("inputs", request)

    # Fail fast with an actionable message instead of an opaque KeyError.
    missing = [field for field in _REQUIRED_FIELDS if field not in inputs]
    if missing:
        raise ValueError(f"missing required input field(s): {', '.join(missing)}")

    df = pd.DataFrame(inputs["df"])
    if df.empty:
        raise ValueError("'df' must contain at least one row")

    x_timestamp = pd.to_datetime(inputs["x_timestamp"])
    y_timestamp = pd.to_datetime(inputs["y_timestamp"])

    pred_len = int(inputs["pred_len"])
    if pred_len <= 0:
        raise ValueError("'pred_len' must be a positive integer")

    T = float(inputs.get("T", 1.0))
    top_p = float(inputs.get("top_p", 0.9))
    sample_count = int(inputs.get("sample_count", 1))

    # Pure inference: inference_mode skips autograd bookkeeping, saving
    # memory and time versus running the forward pass with grad enabled.
    with torch.inference_mode():
        result_df = PREDICTOR.predict(
            df=df,
            x_timestamp=x_timestamp,
            y_timestamp=y_timestamp,
            pred_len=pred_len,
            T=T,
            top_p=top_p,
            sample_count=sample_count,
        )

    # Return a plain dict for JSON serialization
    return {
        "predictions": result_df.to_dict(orient="records"),
    }