File size: 2,189 Bytes
360970f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
from typing import Any, Dict

import pandas as pd
import torch

from kronos import Kronos, KronosTokenizer, KronosPredictor  # type: ignore


# Prefer GPU when one is visible; the model weights and all inference run on this device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def _load_components(model_dir: str = "."):
    """
    Build the tokenizer, model, and predictor from files in *model_dir*.

    Invoked a single time at module import on HF Inference Endpoints.
    """
    tok = KronosTokenizer.from_pretrained(model_dir)
    net = Kronos.from_pretrained(model_dir).to(DEVICE)

    # Context window is tunable via env var; falls back to 512 when unset.
    ctx_len = int(os.getenv("KRONOS_MAX_CONTEXT", "512"))

    pred = KronosPredictor(
        model=net,
        tokenizer=tok,
        device=DEVICE,
        max_context=ctx_len,
    )

    return tok, net, pred


# Loaded once at import time (see _load_components docstring) so every request
# reuses the same warm model instead of reloading weights per call.
TOKENIZER, MODEL, PREDICTOR = _load_components(".")


def predict(request: Dict[str, Any]) -> Dict[str, Any]:
    """
    Entry point for Hugging Face Inference Endpoints.

    Expected input format:

    {
        "inputs": {
            "df": [
                {"open": ..., "high": ..., "low": ..., "close": ...},
                ...
            ],
            "x_timestamp": [...],   # list of ISO8601 strings or timestamps
            "y_timestamp": [...],   # list of ISO8601 strings or timestamps
            "pred_len": 120,
            "T": 1.0,               # optional
            "top_p": 0.9,           # optional
            "sample_count": 1       # optional
        }
    }

    Returns:
        {"predictions": [...]} where each element is one predicted row
        (DataFrame records) — plain types, ready for JSON serialization.

    Raises:
        ValueError: if a required field is missing from the payload.
    """
    # Accept both the HF-standard {"inputs": {...}} envelope and a bare payload.
    inputs = request.get("inputs", request)

    # Fail fast with an actionable message instead of an opaque KeyError
    # bubbling out of the dict lookups below.
    missing = [k for k in ("df", "x_timestamp", "y_timestamp", "pred_len") if k not in inputs]
    if missing:
        raise ValueError(f"Missing required field(s): {', '.join(missing)}")

    df = pd.DataFrame(inputs["df"])
    x_timestamp = pd.to_datetime(inputs["x_timestamp"])
    y_timestamp = pd.to_datetime(inputs["y_timestamp"])

    pred_len = int(inputs["pred_len"])
    T = float(inputs.get("T", 1.0))
    top_p = float(inputs.get("top_p", 0.9))
    sample_count = int(inputs.get("sample_count", 1))

    # Inference only: disable autograd to avoid tracking gradients through the
    # forward pass (saves memory; output is unchanged for a forward-only call).
    with torch.no_grad():
        result_df = PREDICTOR.predict(
            df=df,
            x_timestamp=x_timestamp,
            y_timestamp=y_timestamp,
            pred_len=pred_len,
            T=T,
            top_p=top_p,
            sample_count=sample_count,
        )

    # Return a plain dict for JSON serialization
    return {
        "predictions": result_df.to_dict(orient="records"),
    }