Spaces:
Runtime error
Runtime error
| from fastapi import FastAPI, HTTPException, UploadFile, File | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| import pandas as pd | |
| import numpy as np | |
| import joblib | |
| import os | |
| from pathlib import Path | |
| from io import StringIO | |
| from typing import Dict, Any, Optional | |
| import json | |
| import logging | |
# Logging setup: request INFO-level output on the root logger and expose a
# module-level logger for the rest of this file.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)
| # Import des fonctions de train_strategy_models.py | |
| from train_strategy_models import ( | |
| preprocess_data, | |
| train_models, | |
| save_models, | |
| predict_best_strategy | |
| ) | |
# Create the FastAPI application.
app = FastAPI(title="Strategy Selector API")

# Enable CORS for any origin, method and header.
# allow_credentials stays False: a wildcard origin must not be combined with
# credentialed requests (FastAPI requires explicit origins in that case), and
# none of the endpoints in this file read cookies or authorization headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Resolve the "models" directory that sits next to this file and create it
# when it does not exist yet.
BASE_DIR = Path(__file__).resolve().parents[0]
MODELS_DIR = BASE_DIR.joinpath("models")
MODELS_DIR.mkdir(exist_ok=True)

logger.info(f"Dossier de base : {BASE_DIR}")
logger.info(f"Dossier des modèles : {MODELS_DIR}")
# Example payload attached to the MarketData schema.
_MARKET_DATA_EXAMPLE = {
    "RSI": 50.0,
    "ADX": 25.0,
    "Volatility_20": 0.001,
    "MACD": 0.0,
    "Ichimoku_ADX_Volatility_Signal": 0,
    "BB_Stoch_ATR_Signal": 0,
    "Chikou_MACD_Pente_Signal": 0,
    "ADX_Stoch_Volatility_MA_Signal": 0,
}


# Request body accepted by predict(); the fields mirror the model features.
# (Documented with comments rather than a docstring so the generated schema
# description is left unchanged.)
class MarketData(BaseModel):
    # Continuous features
    RSI: float
    ADX: float
    Volatility_20: float
    MACD: float

    # Signals emitted by each strategy
    Ichimoku_ADX_Volatility_Signal: int
    BB_Stoch_ATR_Signal: int
    Chikou_MACD_Pente_Signal: int
    ADX_Stoch_Volatility_MA_Signal: int

    class Config:
        schema_extra = {"example": _MARKET_DATA_EXAMPLE}
# Response body returned by train_from_csv().
class TrainingResponse(BaseModel):
    # Outcome label; train_from_csv() sets it to "success".
    status: str
    # Human-readable summary of the outcome.
    message: str
    # Free-form training metadata (data shape, period, split info, features,
    # strategies) as filled in by train_from_csv().
    details: Dict[str, Any]
def are_models_available() -> bool:
    """Return True only when every required model file exists in MODELS_DIR."""
    for artifact in ("model_profit.joblib", "model_drawdown.joblib", "model_params.joblib"):
        if not (MODELS_DIR / artifact).exists():
            return False
    return True
def load_models():
    """Load the persisted models from MODELS_DIR.

    Returns:
        A ``(model_profit, model_drawdown, model_params)`` tuple, or
        ``(None, None, None)`` when a file is missing or loading fails.
    """
    nothing_loaded = (None, None, None)
    try:
        if are_models_available():
            # Load in the same order as the returned tuple.
            loaded = tuple(
                joblib.load(MODELS_DIR / artifact)
                for artifact in (
                    "model_profit.joblib",
                    "model_drawdown.joblib",
                    "model_params.joblib",
                )
            )
            logger.info("Modèles chargés avec succès")
            return loaded
        logger.warning("Modèles non disponibles")
        return nothing_loaded
    except Exception as exc:
        # Best effort: report the failure and let the caller handle the Nones.
        logger.error(f"Erreur lors du chargement des modèles : {str(exc)}")
        return nothing_loaded
# Register the handler on the application; without a route decorator it is
# never exposed by the API.
@app.get("/")
async def read_root():
    """Describe the API: name, version, model availability and endpoints."""
    return {
        "message": "Strategy Selector API",
        "version": "1.0",
        "status": "running",
        "models_available": are_models_available(),
        "endpoints": {
            "predict": "/predict (POST)",
            "train": "/train (POST)",
            "health": "/health (GET)"
        }
    }
# Register the handler under the path advertised by the root endpoint;
# without a route decorator it is never exposed by the API.
@app.get("/health")
async def health_check():
    """Report the API status and whether the model files are present."""
    models_available = are_models_available()
    return {
        "status": "healthy",
        "models_available": models_available,
        "models_path": str(MODELS_DIR),
        "available_endpoints": [
            "/train (POST) - Entraîner les modèles",
            "/predict (POST) - Faire des prédictions",
            "/health (GET) - Vérifier l'état"
        ]
    }
class _InvalidTrainingData(ValueError):
    """Raised when an uploaded CSV cannot be used as training data."""


# Columns the uploaded CSV must contain before training starts.
_REQUIRED_TRAINING_COLUMNS = ('Date', 'Open', 'High', 'Low', 'Close')


def _parse_training_csv(raw: bytes) -> pd.DataFrame:
    """Decode an uploaded CSV payload into a DataFrame indexed by ``Date``.

    Raises:
        _InvalidTrainingData: if the payload is not valid UTF-8, cannot be
            parsed as CSV, lacks a required column, has no data rows, or has
            unparseable dates.
    """
    try:
        # utf-8-sig accepts plain UTF-8 too and strips a leading BOM, which
        # would otherwise be glued to the first column name and make the
        # 'Date' check fail.
        df = pd.read_csv(StringIO(raw.decode('utf-8-sig')))

        missing_columns = [
            col for col in _REQUIRED_TRAINING_COLUMNS if col not in df.columns
        ]
        if missing_columns:
            raise ValueError(f"Colonnes manquantes : {missing_columns}")
        if df.empty:
            # A header-only file would later fail on df.index[0].
            raise ValueError("Le fichier CSV ne contient aucune ligne de données")

        df['Date'] = pd.to_datetime(df['Date'])
    except ValueError as exc:
        # Decoding, CSV parsing and date parsing errors are ValueError
        # subclasses; surface them all as one client-side error.
        raise _InvalidTrainingData(str(exc)) from exc

    df.set_index('Date', inplace=True)
    return df


# Register the handler under the path advertised by the root endpoint;
# without a route decorator it is never exposed by the API.
@app.post("/train")
async def train_from_csv(file: UploadFile = File(...)) -> TrainingResponse:
    """Train the models from an uploaded CSV file and save them.

    Args:
        file: Uploaded CSV containing at least the Date, Open, High, Low and
            Close columns.

    Returns:
        A TrainingResponse summarising the data shape, the covered period,
        the split information, the features and the strategies.

    Raises:
        HTTPException: 400 when the uploaded file is not usable training
            data, 500 for any other failure during training or saving.
    """
    try:
        logger.info(f"Réception du fichier : {file.filename}")

        # Read and validate the uploaded content.
        df = _parse_training_csv(await file.read())
        logger.info(f"Données reçues : {len(df)} lignes")

        # NOTE(review): the calls below run synchronously inside an async
        # handler; if training is slow it will stall other requests - confirm.
        df_processed = preprocess_data(df)
        logger.info("Prétraitement terminé")

        model_profit, model_drawdown, features, strategies, split_info = train_models(df_processed)
        logger.info("Entraînement terminé")

        save_models(model_profit, model_drawdown, features, strategies, split_info)
        logger.info("Modèles sauvegardés")

        return TrainingResponse(
            status="success",
            message="Modèles entraînés avec succès",
            details={
                "data_shape": df.shape,
                "training_period": {
                    "start": str(df.index[0]),
                    "end": str(df.index[-1])
                },
                "split_info": split_info,
                "features": features,
                "strategies": strategies
            }
        )
    except _InvalidTrainingData as exc:
        # Bad input is the client's fault: answer 400 rather than 500.
        logger.error(f"Erreur lors de l'entraînement : {exc}")
        raise HTTPException(
            status_code=400,
            detail=f"Erreur lors de l'entraînement : {exc}"
        ) from exc
    except Exception as exc:
        # logger.exception keeps the traceback for server-side failures.
        logger.exception(f"Erreur lors de l'entraînement : {exc}")
        raise HTTPException(
            status_code=500,
            detail=f"Erreur lors de l'entraînement : {exc}"
        ) from exc
def _fallback_prediction(message: str) -> Dict[str, Any]:
    """Build the error payload that carries the default strategies."""
    return {
        "status": "error",
        "message": message,
        "best_profit_strategy": "Ichimoku_ADX_Volatility_Signal",
        "best_profit_signal": 0,
        "best_drawdown_strategy": "BB_Stoch_ATR_Signal",
        "best_drawdown_signal": 0
    }


# Register the handler under the path advertised by the root endpoint;
# without a route decorator it is never exposed by the API.
@app.post("/predict")
async def predict(data: MarketData):
    """Predict the best strategies for one set of market features.

    Args:
        data: Feature values and strategy signals for a single observation.

    Returns:
        A dict with ``"status": "success"`` merged with the result of
        ``predict_best_strategy``, or the default-strategy payload with
        ``"status": "error"`` when the models are missing or prediction fails.
        Failures are reported in the body, never raised.
    """
    try:
        logger.info("Réception d'une requête de prédiction")

        # Without trained models, answer with the default strategies.
        if not are_models_available():
            logger.warning("Modèles non disponibles")
            return _fallback_prediction(
                "Modèles non disponibles - utilisation des valeurs par défaut"
            )

        # Build a single-row DataFrame from the request fields.
        df = pd.DataFrame([{
            'RSI': data.RSI,
            'ADX': data.ADX,
            'Volatility_20': data.Volatility_20,
            'MACD': data.MACD,
            'Ichimoku_ADX_Volatility_Signal': data.Ichimoku_ADX_Volatility_Signal,
            'BB_Stoch_ATR_Signal': data.BB_Stoch_ATR_Signal,
            'Chikou_MACD_Pente_Signal': data.Chikou_MACD_Pente_Signal,
            'ADX_Stoch_Volatility_MA_Signal': data.ADX_Stoch_Volatility_MA_Signal
        }])

        result = predict_best_strategy(df)
        if result is None:
            raise RuntimeError("Erreur lors de la prédiction")

        logger.info(f"Prédiction réussie : {result}")
        return {
            "status": "success",
            **result
        }
    except Exception as e:
        # Deliberate best effort: log with traceback and fall back to defaults.
        logger.exception(f"Erreur lors de la prédiction : {str(e)}")
        return _fallback_prediction(str(e))
# Script entry point: serve the application on all interfaces, port 7860.
if __name__ == "__main__":
    from uvicorn import run as _serve

    _serve(app, host="0.0.0.0", port=7860)