| | import os |
| | from typing import Any, Dict |
| |
|
| | import pandas as pd |
| | import torch |
| |
|
| | from kronos import Kronos, KronosTokenizer, KronosPredictor |
| |
|
| |
|
# Prefer GPU when one is visible to torch; otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
| |
|
| |
|
def _load_components(model_dir: str = "."):
    """
    Instantiate the Kronos tokenizer, model, and predictor from *model_dir*.

    Invoked a single time when the module is imported, which is the
    initialization pattern expected by HF Inference Endpoints workers.
    """
    tok = KronosTokenizer.from_pretrained(model_dir)
    net = Kronos.from_pretrained(model_dir).to(DEVICE)

    # Context window is tunable via environment variable; defaults to 512.
    context_window = int(os.getenv("KRONOS_MAX_CONTEXT", "512"))

    runner = KronosPredictor(
        model=net,
        tokenizer=tok,
        device=DEVICE,
        max_context=context_window,
    )

    return tok, net, runner
| |
|
| |
|
# Eagerly loaded once at import time so every request reuses the same
# tokenizer/model/predictor instances (HF Inference Endpoints pattern).
TOKENIZER, MODEL, PREDICTOR = _load_components(".")
| |
|
| |
|
def predict(request: Dict[str, Any]) -> Dict[str, Any]:
    """
    Entry point for Hugging Face Inference Endpoints.

    Expected input format:

        {
            "inputs": {
                "df": [
                    {"open": ..., "high": ..., "low": ..., "close": ...},
                    ...
                ],
                "x_timestamp": [...],  # list of ISO8601 strings or timestamps
                "y_timestamp": [...],  # list of ISO8601 strings or timestamps
                "pred_len": 120,       # optional; defaults to len(y_timestamp)
                "T": 1.0,              # optional
                "top_p": 0.9,          # optional
                "sample_count": 1      # optional
            }
        }

    Returns a dict with a "predictions" key holding the forecast rows as a
    list of records.

    Raises:
        ValueError: if any required field (df, x_timestamp, y_timestamp)
            is missing from the payload.
    """
    # Accept either the wrapped {"inputs": {...}} envelope or a bare payload.
    inputs = request.get("inputs", request)

    # Fail fast with an actionable message instead of a bare KeyError.
    missing = [k for k in ("df", "x_timestamp", "y_timestamp") if k not in inputs]
    if missing:
        raise ValueError(f"Missing required input field(s): {', '.join(missing)}")

    df = pd.DataFrame(inputs["df"])
    x_timestamp = pd.to_datetime(inputs["x_timestamp"])
    y_timestamp = pd.to_datetime(inputs["y_timestamp"])

    # pred_len is derivable from the target timestamps, so it is optional;
    # an explicit value in the payload still takes precedence.
    pred_len = int(inputs.get("pred_len", len(y_timestamp)))
    T = float(inputs.get("T", 1.0))
    top_p = float(inputs.get("top_p", 0.9))
    sample_count = int(inputs.get("sample_count", 1))

    result_df = PREDICTOR.predict(
        df=df,
        x_timestamp=x_timestamp,
        y_timestamp=y_timestamp,
        pred_len=pred_len,
        T=T,
        top_p=top_p,
        sample_count=sample_count,
    )

    return {
        "predictions": result_df.to_dict(orient="records"),
    }
| |
|