|
|
from fastapi import FastAPI, HTTPException |
|
|
from pydantic import BaseModel |
|
|
import pandas as pd |
|
|
import torch |
|
|
import sys |
|
|
import os |
|
|
from datetime import timedelta |
|
|
|
|
|
|
|
|
# Make the vendored "Kronos" checkout (sitting next to this file) importable,
# so the `from model import ...` below resolves to Kronos/model.py rather
# than any installed package named `model`.
sys.path.append(os.path.join(os.path.dirname(__file__), "Kronos"))

from model import Kronos, KronosTokenizer, KronosPredictor
|
|
|
|
|
app = FastAPI()

# Prefer GPU when one is visible; the model and all inference inputs are
# placed on this device.
device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Loading model on {device}...")

# NOTE: these run at import time (module-level side effect). from_pretrained
# pulls weights from the Hugging Face hub, so the first start needs network
# access (later starts presumably hit the local HF cache — TODO confirm).
tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")

model = Kronos.from_pretrained("NeoQuasar/Kronos-base")

model = model.to(device)

# Inference-only service: switch off dropout/batch-norm training behavior.
model.eval()

# max_context=512 caps how many input candles the predictor will use as
# context for a forecast.
predictor = KronosPredictor(model, tokenizer, device=device, max_context=512)

print("Model loaded!")
|
|
|
|
|
class CandleData(BaseModel):
    """Request body for POST /predict: one pair's OHLCV candle history."""

    # Trading-pair identifier (presumably something like "BTC/USDT" —
    # TODO confirm expected format; it is not used by /predict itself).
    pair: str

    # Candle rows, oldest first (assumed — TODO confirm ordering). /predict
    # reads the keys "open", "high", "low", "close", "volume" and a
    # datetime-parseable "date" column from each row.
    data: list[dict]
|
|
|
|
|
@app.post("/predict")
def predict(payload: CandleData):
    """Forecast the next 3 candles for a pair and return the mean close.

    Expects ``payload.data`` to be a list of candle dicts with "open",
    "high", "low", "close", "volume" and a datetime-parseable timestamp
    column ("date", or one of the accepted aliases below).

    Returns:
        ``{"prediction": <float>}`` — the mean of the forecasted "close"
        values over the prediction horizon.

    Raises:
        HTTPException: 400 when no usable timestamp column is present;
            500 for any other failure during preprocessing or inference.
    """
    try:
        df = pd.DataFrame(payload.data)

        # BUG FIX: the original code dereferenced df['date'] in BOTH
        # branches of the `if 'date' in df.columns` check, so a payload
        # without a "date" column always crashed with a KeyError and
        # surfaced as an opaque 500. Accept common aliases and fail with
        # an explicit 400 instead.
        ts_col = next(
            (c for c in ("date", "timestamp", "timestamps", "time")
             if c in df.columns),
            None,
        )
        if ts_col is None:
            raise HTTPException(
                status_code=400,
                detail="payload.data must include a 'date' (or 'timestamp') column",
            )
        df['timestamps'] = pd.to_datetime(df[ts_col])

        input_cols = ["open", "high", "low", "close", "volume"]

        x_timestamp = df["timestamps"]
        pred_steps = 3
        last_time = x_timestamp.iloc[-1]

        # Future timestamps for the forecast horizon. The 5-minute step is
        # hard-coded — TODO(review): confirm all served pairs are on the
        # 5m timeframe, or derive the step from the input spacing.
        y_timestamp = pd.Series([
            last_time + timedelta(minutes=5 * (i + 1))
            for i in range(pred_steps)
        ])

        # Inference only — no gradient tracking needed.
        with torch.no_grad():
            forecast = predictor.predict(
                df=df[input_cols],
                x_timestamp=x_timestamp,
                y_timestamp=y_timestamp,
                pred_len=pred_steps,
                T=0.8,           # sampling temperature
                top_p=0.9,       # nucleus-sampling threshold
                sample_count=10  # trajectories averaged by the predictor
            )

        # Collapse the forecast window to a single scalar: mean close price.
        raw_pred = forecast["close"].mean()
        return {"prediction": float(raw_pred)}

    except HTTPException:
        # Re-raise deliberate HTTP errors unchanged. HTTPException derives
        # from Exception, so without this clause the 400 above would be
        # swallowed by the generic handler and rewrapped as a 500.
        raise
    except Exception as e:
        print(f"Error: {e}")
        raise HTTPException(status_code=500, detail=str(e))