Fast-Api / app.py
thibautmodrin's picture
new push
118bc22
from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import pandas as pd
import numpy as np
import joblib
import os
from pathlib import Path
from io import StringIO
from typing import Dict, Any, Optional
import json
import logging
# Configuration des logs
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Import des fonctions de train_strategy_models.py
from train_strategy_models import (
preprocess_data,
train_models,
save_models,
predict_best_strategy
)
# Créer l'application FastAPI
app = FastAPI(title="Strategy Selector API")
# Ajouter le support CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Obtenir le chemin absolu du dossier models
BASE_DIR = Path(__file__).resolve().parent
MODELS_DIR = BASE_DIR / "models"
MODELS_DIR.mkdir(exist_ok=True)
logger.info(f"Dossier de base : {BASE_DIR}")
logger.info(f"Dossier des modèles : {MODELS_DIR}")
# Mise à jour de la classe MarketData pour correspondre aux features
class MarketData(BaseModel):
# Features continus
RSI: float
ADX: float
Volatility_20: float
MACD: float
# Signaux des stratégies
Ichimoku_ADX_Volatility_Signal: int
BB_Stoch_ATR_Signal: int
Chikou_MACD_Pente_Signal: int
ADX_Stoch_Volatility_MA_Signal: int
class Config:
schema_extra = {
"example": {
"RSI": 50.0,
"ADX": 25.0,
"Volatility_20": 0.001,
"MACD": 0.0,
"Ichimoku_ADX_Volatility_Signal": 0,
"BB_Stoch_ATR_Signal": 0,
"Chikou_MACD_Pente_Signal": 0,
"ADX_Stoch_Volatility_MA_Signal": 0
}
}
class TrainingResponse(BaseModel):
status: str
message: str
details: Dict[str, Any]
def are_models_available() -> bool:
"""Vérifie si tous les modèles nécessaires sont disponibles"""
required_files = ["model_profit.joblib", "model_drawdown.joblib", "model_params.joblib"]
return all((MODELS_DIR / file).exists() for file in required_files)
def load_models():
"""Charge les modèles existants"""
try:
if not are_models_available():
logger.warning("Modèles non disponibles")
return None, None, None
model_profit = joblib.load(MODELS_DIR / "model_profit.joblib")
model_drawdown = joblib.load(MODELS_DIR / "model_drawdown.joblib")
model_params = joblib.load(MODELS_DIR / "model_params.joblib")
logger.info("Modèles chargés avec succès")
return model_profit, model_drawdown, model_params
except Exception as e:
logger.error(f"Erreur lors du chargement des modèles : {str(e)}")
return None, None, None
@app.get("/")
async def read_root():
return {
"message": "Strategy Selector API",
"version": "1.0",
"status": "running",
"models_available": are_models_available(),
"endpoints": {
"predict": "/predict (POST)",
"train": "/train (POST)",
"health": "/health (GET)"
}
}
@app.get("/health")
async def health_check():
"""Endpoint pour vérifier l'état de l'API et des modèles"""
models_available = are_models_available()
return {
"status": "healthy",
"models_available": models_available,
"models_path": str(MODELS_DIR),
"available_endpoints": [
"/train (POST) - Entraîner les modèles",
"/predict (POST) - Faire des prédictions",
"/health (GET) - Vérifier l'état"
]
}
@app.post("/train")
async def train_from_csv(file: UploadFile = File(...)) -> TrainingResponse:
"""Endpoint pour entraîner les modèles à partir des données CSV"""
try:
logger.info(f"Réception du fichier : {file.filename}")
# Lire et valider le contenu du fichier
content = await file.read()
content_str = content.decode('utf-8')
# Convertir en DataFrame
df = pd.read_csv(StringIO(content_str))
# Vérifier les colonnes requises
required_columns = ['Date', 'Open', 'High', 'Low', 'Close']
missing_columns = [col for col in required_columns if col not in df.columns]
if missing_columns:
raise ValueError(f"Colonnes manquantes : {missing_columns}")
# Configurer l'index temporel
df['Date'] = pd.to_datetime(df['Date'])
df.set_index('Date', inplace=True)
logger.info(f"Données reçues : {len(df)} lignes")
# Prétraiter les données
df_processed = preprocess_data(df)
logger.info("Prétraitement terminé")
# Entraîner les modèles
model_profit, model_drawdown, features, strategies, split_info = train_models(df_processed)
logger.info("Entraînement terminé")
# Sauvegarder les modèles
save_models(model_profit, model_drawdown, features, strategies, split_info)
logger.info("Modèles sauvegardés")
return TrainingResponse(
status="success",
message="Modèles entraînés avec succès",
details={
"data_shape": df.shape,
"training_period": {
"start": str(df.index[0]),
"end": str(df.index[-1])
},
"split_info": split_info,
"features": features,
"strategies": strategies
}
)
except Exception as e:
logger.error(f"Erreur lors de l'entraînement : {str(e)}")
raise HTTPException(
status_code=500,
detail=f"Erreur lors de l'entraînement : {str(e)}"
)
@app.post("/predict")
async def predict(data: MarketData):
"""Endpoint pour faire des prédictions avec les modèles entraînés"""
try:
logger.info("Réception d'une requête de prédiction")
# Vérifier si les modèles sont disponibles
if not are_models_available():
logger.warning("Modèles non disponibles")
return {
"status": "error",
"message": "Modèles non disponibles - utilisation des valeurs par défaut",
"best_profit_strategy": "Ichimoku_ADX_Volatility_Signal",
"best_profit_signal": 0,
"best_drawdown_strategy": "BB_Stoch_ATR_Signal",
"best_drawdown_signal": 0
}
# Créer un DataFrame avec une seule ligne
df = pd.DataFrame([{
'RSI': data.RSI,
'ADX': data.ADX,
'Volatility_20': data.Volatility_20,
'MACD': data.MACD,
'Ichimoku_ADX_Volatility_Signal': data.Ichimoku_ADX_Volatility_Signal,
'BB_Stoch_ATR_Signal': data.BB_Stoch_ATR_Signal,
'Chikou_MACD_Pente_Signal': data.Chikou_MACD_Pente_Signal,
'ADX_Stoch_Volatility_MA_Signal': data.ADX_Stoch_Volatility_MA_Signal
}])
# Utiliser la fonction de prédiction
result = predict_best_strategy(df)
if result is None:
raise Exception("Erreur lors de la prédiction")
logger.info(f"Prédiction réussie : {result}")
return {
"status": "success",
**result
}
except Exception as e:
logger.error(f"Erreur lors de la prédiction : {str(e)}")
return {
"status": "error",
"message": str(e),
"best_profit_strategy": "Ichimoku_ADX_Volatility_Signal",
"best_profit_signal": 0,
"best_drawdown_strategy": "BB_Stoch_ATR_Signal",
"best_drawdown_signal": 0
}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)