from fastapi import FastAPI, HTTPException, UploadFile, File from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel import pandas as pd import numpy as np import joblib import os from pathlib import Path from io import StringIO from typing import Dict, Any, Optional import json import logging # Configuration des logs logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Import des fonctions de train_strategy_models.py from train_strategy_models import ( preprocess_data, train_models, save_models, predict_best_strategy ) # Créer l'application FastAPI app = FastAPI(title="Strategy Selector API") # Ajouter le support CORS app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # Obtenir le chemin absolu du dossier models BASE_DIR = Path(__file__).resolve().parent MODELS_DIR = BASE_DIR / "models" MODELS_DIR.mkdir(exist_ok=True) logger.info(f"Dossier de base : {BASE_DIR}") logger.info(f"Dossier des modèles : {MODELS_DIR}") # Mise à jour de la classe MarketData pour correspondre aux features class MarketData(BaseModel): # Features continus RSI: float ADX: float Volatility_20: float MACD: float # Signaux des stratégies Ichimoku_ADX_Volatility_Signal: int BB_Stoch_ATR_Signal: int Chikou_MACD_Pente_Signal: int ADX_Stoch_Volatility_MA_Signal: int class Config: schema_extra = { "example": { "RSI": 50.0, "ADX": 25.0, "Volatility_20": 0.001, "MACD": 0.0, "Ichimoku_ADX_Volatility_Signal": 0, "BB_Stoch_ATR_Signal": 0, "Chikou_MACD_Pente_Signal": 0, "ADX_Stoch_Volatility_MA_Signal": 0 } } class TrainingResponse(BaseModel): status: str message: str details: Dict[str, Any] def are_models_available() -> bool: """Vérifie si tous les modèles nécessaires sont disponibles""" required_files = ["model_profit.joblib", "model_drawdown.joblib", "model_params.joblib"] return all((MODELS_DIR / file).exists() for file in required_files) def load_models(): """Charge les modèles existants""" try: if not are_models_available(): logger.warning("Modèles non disponibles") return None, None, None model_profit = joblib.load(MODELS_DIR / "model_profit.joblib") model_drawdown = joblib.load(MODELS_DIR / "model_drawdown.joblib") model_params = joblib.load(MODELS_DIR / "model_params.joblib") logger.info("Modèles chargés avec succès") return model_profit, model_drawdown, model_params except Exception as e: logger.error(f"Erreur lors du chargement des modèles : {str(e)}") return None, None, None @app.get("/") async def read_root(): return { "message": "Strategy Selector API", "version": "1.0", "status": "running", "models_available": are_models_available(), "endpoints": { "predict": "/predict (POST)", "train": "/train (POST)", "health": "/health (GET)" } } @app.get("/health") async def health_check(): """Endpoint pour vérifier l'état de l'API et des modèles""" models_available = are_models_available() return { "status": "healthy", "models_available": models_available, "models_path": str(MODELS_DIR), "available_endpoints": [ "/train (POST) - Entraîner les modèles", "/predict (POST) - Faire des prédictions", "/health (GET) - Vérifier l'état" ] } @app.post("/train") async def train_from_csv(file: UploadFile = File(...)) -> TrainingResponse: """Endpoint pour entraîner les modèles à partir des données CSV""" try: logger.info(f"Réception du fichier : {file.filename}") # Lire et valider le contenu du fichier content = await file.read() content_str = content.decode('utf-8') # Convertir en DataFrame df = pd.read_csv(StringIO(content_str)) # Vérifier les colonnes requises required_columns = ['Date', 'Open', 'High', 'Low', 'Close'] missing_columns = [col for col in required_columns if col not in df.columns] if missing_columns: raise ValueError(f"Colonnes manquantes : {missing_columns}") # Configurer l'index temporel df['Date'] = pd.to_datetime(df['Date']) df.set_index('Date', inplace=True) logger.info(f"Données reçues : {len(df)} lignes") # Prétraiter les données df_processed = preprocess_data(df) logger.info("Prétraitement terminé") # Entraîner les modèles model_profit, model_drawdown, features, strategies, split_info = train_models(df_processed) logger.info("Entraînement terminé") # Sauvegarder les modèles save_models(model_profit, model_drawdown, features, strategies, split_info) logger.info("Modèles sauvegardés") return TrainingResponse( status="success", message="Modèles entraînés avec succès", details={ "data_shape": df.shape, "training_period": { "start": str(df.index[0]), "end": str(df.index[-1]) }, "split_info": split_info, "features": features, "strategies": strategies } ) except Exception as e: logger.error(f"Erreur lors de l'entraînement : {str(e)}") raise HTTPException( status_code=500, detail=f"Erreur lors de l'entraînement : {str(e)}" ) @app.post("/predict") async def predict(data: MarketData): """Endpoint pour faire des prédictions avec les modèles entraînés""" try: logger.info("Réception d'une requête de prédiction") # Vérifier si les modèles sont disponibles if not are_models_available(): logger.warning("Modèles non disponibles") return { "status": "error", "message": "Modèles non disponibles - utilisation des valeurs par défaut", "best_profit_strategy": "Ichimoku_ADX_Volatility_Signal", "best_profit_signal": 0, "best_drawdown_strategy": "BB_Stoch_ATR_Signal", "best_drawdown_signal": 0 } # Créer un DataFrame avec une seule ligne df = pd.DataFrame([{ 'RSI': data.RSI, 'ADX': data.ADX, 'Volatility_20': data.Volatility_20, 'MACD': data.MACD, 'Ichimoku_ADX_Volatility_Signal': data.Ichimoku_ADX_Volatility_Signal, 'BB_Stoch_ATR_Signal': data.BB_Stoch_ATR_Signal, 'Chikou_MACD_Pente_Signal': data.Chikou_MACD_Pente_Signal, 'ADX_Stoch_Volatility_MA_Signal': data.ADX_Stoch_Volatility_MA_Signal }]) # Utiliser la fonction de prédiction result = predict_best_strategy(df) if result is None: raise Exception("Erreur lors de la prédiction") logger.info(f"Prédiction réussie : {result}") return { "status": "success", **result } except Exception as e: logger.error(f"Erreur lors de la prédiction : {str(e)}") return { "status": "error", "message": str(e), "best_profit_strategy": "Ichimoku_ADX_Volatility_Signal", "best_profit_signal": 0, "best_drawdown_strategy": "BB_Stoch_ATR_Signal", "best_drawdown_signal": 0 } if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=7860)