# kronos-api / main.py
# jeanno31's picture
# Update main.py
# 899d941 verified
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import pandas as pd
import torch
# This import works because the 'model' folder was cloned by Docker.
from model import Kronos, KronosTokenizer, KronosPredictor
app = FastAPI(title="Kronos Trading API")

# Load the tokenizer and the model once, when the module is imported.
# Any failure is reported on stdout and leaves `predictor` set to None so
# the endpoints can tell that the model is unavailable.
print("📥 Chargement du modèle Kronos-small...")
try:
    tokenizer = KronosTokenizer.from_pretrained(
        "NeoQuasar/Kronos-Tokenizer-base"
    )
    model = Kronos.from_pretrained("NeoQuasar/Kronos-small")
    # max_context is 512 for Kronos-small.
    predictor = KronosPredictor(model, tokenizer, max_context=512)
    print("✅ Modèle chargé et prêt !")
except Exception as load_error:
    print(f"❌ ERREUR CRITIQUE : {load_error!s}")
    predictor = None
class PredictionRequest(BaseModel):
    """Request body accepted by the prediction endpoint."""

    # Historical rows; the endpoint loads them into a pandas DataFrame and
    # requires 'open', 'high', 'low' and 'close' columns. A 'timestamps'
    # column is optional.
    history_data: list

    # Number of most recent rows to feed to the model. The endpoint also
    # caps this at the number of rows supplied and at 512.
    lookback: int = 400

    # Number of future steps to forecast.
    pred_len: int = 24
@app.get("/")
def read_root():
    """Report that the API is up and whether the model was loaded."""
    model_ready = predictor is not None
    return {
        "status": "Kronos API is running 🚀",
        "model_loaded": model_ready,
    }
# Largest number of history rows passed to the model (the Kronos-small
# predictor is created with max_context=512).
_MAX_CONTEXT = 512


def _build_history_frame(history_data: list) -> pd.DataFrame:
    """Turn raw request rows into a DataFrame with a 'timestamps' column.

    Args:
        history_data: Rows as received in the request body.

    Returns:
        A DataFrame holding at least the OHLC columns and a datetime
        'timestamps' column (synthetic hourly stamps when none were sent).

    Raises:
        HTTPException: 400 when the rows cannot be tabulated, an OHLC column
            is missing, or the timestamps are unparseable or missing.
    """
    try:
        df = pd.DataFrame(history_data)
    except (ValueError, TypeError) as exc:
        raise HTTPException(
            status_code=400,
            detail=f"Données historiques invalides : {exc}",
        ) from exc

    required_cols = ['open', 'high', 'low', 'close']
    if not all(col in df.columns for col in required_cols):
        raise HTTPException(status_code=400, detail=f"Colonnes manquantes. Requis : {required_cols}")

    if 'timestamps' not in df.columns:
        # No timestamps supplied: fabricate an hourly axis. Lowercase 'h' is
        # used because the uppercase alias 'H' is deprecated in recent pandas.
        df['timestamps'] = pd.date_range(start='2024-01-01', periods=len(df), freq='h')
        return df

    try:
        df['timestamps'] = pd.to_datetime(df['timestamps'])
    except (ValueError, TypeError) as exc:
        raise HTTPException(
            status_code=400,
            detail=f"Timestamps invalides : {exc}",
        ) from exc
    # A missing timestamp would become NaT and break the forecast time axis.
    if df['timestamps'].isna().any():
        raise HTTPException(status_code=400, detail="Timestamps invalides : valeurs manquantes.")
    return df


# Anti-404 trick: the endpoint is registered under both paths.
@app.post("/predict")
@app.post("/predict/predict")
def predict(request: PredictionRequest):
    """Forecast future closes from OHLC history and report the expected move.

    The forecast time axis is always hourly, starting one hour after the last
    history timestamp. The value returned under "predicted_price_in_24h" is
    the last predicted close, whatever `pred_len` is.

    Args:
        request: History rows plus the `lookback` and `pred_len` settings.

    Returns:
        A dict with "status", "current_price", "predicted_price_in_24h" and
        "expected_change_percent". The latter is None when the current price
        is 0, because the relative change is undefined.

    Raises:
        HTTPException: 500 when the model is not loaded or prediction fails;
            400 when the request data is empty or invalid.
    """
    if predictor is None:
        raise HTTPException(status_code=500, detail="Le modèle n'a pas pu démarrer.")
    if not request.history_data:
        raise HTTPException(status_code=400, detail="Aucune donnée historique (OHLC) fournie.")
    # Non-positive sizes would yield an empty context or an empty forecast
    # axis and crash further down, so reject them up front.
    if request.lookback < 1 or request.pred_len < 1:
        raise HTTPException(
            status_code=400,
            detail="lookback et pred_len doivent être supérieurs ou égaux à 1.",
        )

    df = _build_history_frame(request.history_data)

    actual_lookback = min(len(df), request.lookback, _MAX_CONTEXT)
    x_df = df.tail(actual_lookback).reset_index(drop=True)
    x_timestamp = x_df['timestamps']
    last_time = x_timestamp.iloc[-1]

    # date_range returns a DatetimeIndex; it is wrapped in a Series and the
    # first entry (equal to last_time) is dropped.
    # NOTE(review): presumably the predictor needs a Series rather than a
    # DatetimeIndex (e.g. for `.dt` access) - confirm against KronosPredictor.
    y_timestamp = pd.Series(
        pd.date_range(start=last_time, periods=request.pred_len + 1, freq='h')[1:]
    )

    try:
        pred_df = predictor.predict(
            df=x_df,
            x_timestamp=x_timestamp,
            y_timestamp=y_timestamp,
            pred_len=request.pred_len,
            T=1.0,
            top_p=0.9,
            sample_count=1
        )
        current_price = float(x_df['close'].iloc[-1])
        predicted_price = float(pred_df['close'].iloc[-1])
    except Exception as exc:
        print(f"❌ Erreur lors de la prédiction : {exc}")
        raise HTTPException(status_code=500, detail=f"Erreur interne : {exc}") from exc

    # Guard against dividing by a zero price.
    if current_price != 0:
        expected_change = ((predicted_price - current_price) / current_price) * 100
    else:
        expected_change = None

    return {
        "status": "success",
        "current_price": current_price,
        "predicted_price_in_24h": predicted_price,
        "expected_change_percent": expected_change
    }