thibautmodrin commited on
Commit
cc9632a
·
1 Parent(s): ac0f7d0
Files changed (1) hide show
  1. app.py +53 -22
app.py CHANGED
@@ -52,14 +52,23 @@ class TrainingResponse(BaseModel):
52
  message: str
53
  details: Dict[str, Any]
54
 
 
 
 
 
 
55
  def load_models():
56
  """Charge les modèles existants"""
57
  try:
 
 
 
58
  model_profit = joblib.load(MODELS_DIR / "model_profit.joblib")
59
  model_drawdown = joblib.load(MODELS_DIR / "model_drawdown.joblib")
60
  model_params = joblib.load(MODELS_DIR / "model_params.joblib")
61
  return model_profit, model_drawdown, model_params
62
- except:
 
63
  return None, None, None
64
 
65
  @app.get("/")
@@ -79,11 +88,10 @@ async def health_check():
79
  """
80
  Endpoint pour vérifier l'état de l'API et des modèles
81
  """
82
- model_profit, model_drawdown, model_params = load_models()
83
-
84
  return {
85
  "status": "healthy",
86
- "models_loaded": all([model_profit, model_drawdown, model_params]),
87
  "models_path": str(MODELS_DIR),
88
  "available_endpoints": [
89
  "/train (POST) - Entraîner les modèles",
@@ -98,12 +106,22 @@ async def train_from_csv(file: UploadFile = File(...)) -> TrainingResponse:
98
  Endpoint pour entraîner les modèles à partir des données CSV
99
  """
100
  try:
101
- # Lire le contenu du fichier
102
  content = await file.read()
103
  content_str = content.decode('utf-8')
104
 
105
  # Convertir en DataFrame
106
- df = pd.read_csv(StringIO(content_str), parse_dates=['Date'], index_col='Date')
 
 
 
 
 
 
 
 
 
 
107
  print(f"Données reçues : {len(df)} lignes")
108
 
109
  # Prétraiter les données
@@ -145,18 +163,22 @@ async def predict(data: MarketData):
145
  Endpoint pour faire des prédictions avec les modèles entraînés
146
  """
147
  try:
 
 
 
 
 
 
 
 
 
 
 
148
  # Charger les modèles
149
  model_profit, model_drawdown, model_params = load_models()
150
 
151
- if not all([model_profit, model_drawdown, model_params]):
152
- raise HTTPException(
153
- status_code=400,
154
- detail="Modèles non disponibles. Veuillez d'abord entraîner les modèles."
155
- )
156
-
157
  # Préparer les features dans le bon ordre
158
- features = model_params['features']
159
- X = np.array([[
160
  data.Ichimoku_ADX_Volatility_Signal,
161
  data.BB_Stoch_ATR_Signal,
162
  data.Chikou_MACD_Pente_Signal,
@@ -165,14 +187,16 @@ async def predict(data: MarketData):
165
  data.ADX,
166
  data.Volatility_20,
167
  data.MACD
168
- ]])
 
 
169
 
170
  # Faire les prédictions
171
- profit_strategy_idx = model_profit.predict(X)[0]
172
- drawdown_strategy_idx = model_drawdown.predict(X)[0]
173
 
174
  # Obtenir les noms des stratégies
175
- strategies = model_params['strategies']
176
 
177
  # Récupérer les signaux correspondants
178
  signals = [
@@ -183,6 +207,7 @@ async def predict(data: MarketData):
183
  ]
184
 
185
  return {
 
186
  "best_profit_strategy": strategies[profit_strategy_idx],
187
  "best_profit_signal": signals[profit_strategy_idx],
188
  "best_drawdown_strategy": strategies[drawdown_strategy_idx],
@@ -190,10 +215,16 @@ async def predict(data: MarketData):
190
  }
191
 
192
  except Exception as e:
193
- raise HTTPException(
194
- status_code=500,
195
- detail=f"Erreur lors de la prédiction : {str(e)}"
196
- )
 
 
 
 
 
 
197
 
198
  if __name__ == "__main__":
199
  import uvicorn
 
52
  message: str
53
  details: Dict[str, Any]
54
 
55
+ def are_models_available():
56
+ """Vérifie si tous les modèles nécessaires sont disponibles"""
57
+ required_files = ["model_profit.joblib", "model_drawdown.joblib", "model_params.joblib"]
58
+ return all((MODELS_DIR / file).exists() for file in required_files)
59
+
60
  def load_models():
61
  """Charge les modèles existants"""
62
  try:
63
+ if not are_models_available():
64
+ return None, None, None
65
+
66
  model_profit = joblib.load(MODELS_DIR / "model_profit.joblib")
67
  model_drawdown = joblib.load(MODELS_DIR / "model_drawdown.joblib")
68
  model_params = joblib.load(MODELS_DIR / "model_params.joblib")
69
  return model_profit, model_drawdown, model_params
70
+ except Exception as e:
71
+ print(f"Erreur lors du chargement des modèles : {str(e)}")
72
  return None, None, None
73
 
74
  @app.get("/")
 
88
  """
89
  Endpoint pour vérifier l'état de l'API et des modèles
90
  """
91
+ models_available = are_models_available()
 
92
  return {
93
  "status": "healthy",
94
+ "models_available": models_available,
95
  "models_path": str(MODELS_DIR),
96
  "available_endpoints": [
97
  "/train (POST) - Entraîner les modèles",
 
106
  Endpoint pour entraîner les modèles à partir des données CSV
107
  """
108
  try:
109
+ # Lire et valider le contenu du fichier
110
  content = await file.read()
111
  content_str = content.decode('utf-8')
112
 
113
  # Convertir en DataFrame
114
+ df = pd.read_csv(StringIO(content_str))
115
+
116
+ # Vérifier les colonnes requises
117
+ required_columns = ['Date', 'Open', 'High', 'Low', 'Close']
118
+ if not all(col in df.columns for col in required_columns):
119
+ raise ValueError(f"Colonnes manquantes. Requis: {required_columns}")
120
+
121
+ # Configurer l'index temporel
122
+ df['Date'] = pd.to_datetime(df['Date'])
123
+ df.set_index('Date', inplace=True)
124
+
125
  print(f"Données reçues : {len(df)} lignes")
126
 
127
  # Prétraiter les données
 
163
  Endpoint pour faire des prédictions avec les modèles entraînés
164
  """
165
  try:
166
+ # Vérifier si les modèles sont disponibles
167
+ if not are_models_available():
168
+ return {
169
+ "status": "error",
170
+ "message": "Modèles non disponibles",
171
+ "best_profit_strategy": "Ichimoku_ADX_Volatility_Signal",
172
+ "best_profit_signal": 0,
173
+ "best_drawdown_strategy": "BB_Stoch_ATR_Signal",
174
+ "best_drawdown_signal": 0
175
+ }
176
+
177
  # Charger les modèles
178
  model_profit, model_drawdown, model_params = load_models()
179
 
 
 
 
 
 
 
180
  # Préparer les features dans le bon ordre
181
+ features = [
 
182
  data.Ichimoku_ADX_Volatility_Signal,
183
  data.BB_Stoch_ATR_Signal,
184
  data.Chikou_MACD_Pente_Signal,
 
187
  data.ADX,
188
  data.Volatility_20,
189
  data.MACD
190
+ ]
191
+
192
+ X = np.array(features).reshape(1, -1)
193
 
194
  # Faire les prédictions
195
+ profit_strategy_idx = int(model_profit.predict(X)[0])
196
+ drawdown_strategy_idx = int(model_drawdown.predict(X)[0])
197
 
198
  # Obtenir les noms des stratégies
199
+ strategies = model_params["strategies"]
200
 
201
  # Récupérer les signaux correspondants
202
  signals = [
 
207
  ]
208
 
209
  return {
210
+ "status": "success",
211
  "best_profit_strategy": strategies[profit_strategy_idx],
212
  "best_profit_signal": signals[profit_strategy_idx],
213
  "best_drawdown_strategy": strategies[drawdown_strategy_idx],
 
215
  }
216
 
217
  except Exception as e:
218
+ print(f"Erreur lors de la prédiction : {str(e)}")
219
+ # Retourner une réponse par défaut en cas d'erreur
220
+ return {
221
+ "status": "error",
222
+ "message": str(e),
223
+ "best_profit_strategy": "Ichimoku_ADX_Volatility_Signal",
224
+ "best_profit_signal": 0,
225
+ "best_drawdown_strategy": "BB_Stoch_ATR_Signal",
226
+ "best_drawdown_signal": 0
227
+ }
228
 
229
  if __name__ == "__main__":
230
  import uvicorn