Starchik1 commited on
Commit
e71fc62
·
verified ·
1 Parent(s): e6886e3

Upload train.py

Browse files
Files changed (1) hide show
  1. train.py +525 -0
train.py ADDED
@@ -0,0 +1,525 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ train.py
5
+ Продвинутая система обучения моделей с итеративным улучшением.
6
+ Обучается до достижения минимальной точности (по умолчанию 0.80).
7
+ Сохраняет лучшие модели и метаданные в папку models/
8
+ """
9
+ import pandas as pd
10
+ import numpy as np
11
+ import requests
12
+ import joblib
13
+ import os
14
+ import time
15
+ import logging
16
+ import threading
17
+ from datetime import datetime
18
+ from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold
19
+ from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, ExtraTreesClassifier
20
+ from sklearn.linear_model import LogisticRegression
21
+ from sklearn.svm import SVC
22
+ from sklearn.preprocessing import StandardScaler, RobustScaler, MinMaxScaler
23
+ from sklearn.metrics import accuracy_score
24
+ import warnings
25
+ warnings.filterwarnings('ignore')
26
+
27
# TA-Lib is imported here; if it is not installed, fail immediately with a
# clear, actionable error instead of a NameError deep inside feature code.
try:
    import talib
except Exception as e:
    raise ImportError("TA-Lib не найден. Установите TA-Lib (системная библиотека + pip install TA-Lib).") from e

# Logging: every record goes both to a UTF-8 log file in the working
# directory and to the console stream.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('training_log.txt', encoding='utf-8'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)
43
+
44
class AdvancedCryptoModelTrainer:
    """Iteratively train several classifiers to predict next-candle direction.

    Each iteration downloads klines from Binance, derives TA-Lib indicators
    and lag/rolling features, trains five scikit-learn models and remembers
    the best instance of every model type.  Iterations progressively use more
    data, more features and larger models until ``target_accuracy`` is reached
    on the hold-out set or ``max_iterations`` is exhausted.
    """

    # Largest page size requested from the klines endpoint per call.
    _MAX_KLINES_PER_REQUEST = 1000
    # Upper bound for log10(C); keeps the regularisation constant finite for
    # late iterations, like every other progressive parameter is capped.
    _MAX_C_EXPONENT = 3.0
    # Models that are trained on scaled features.
    _SCALED_MODELS = ('Logistic Regression', 'SVM')
    # Raw kline columns and label columns that must never be used as features.
    _NON_FEATURE_COLUMNS = frozenset([
        'timestamp', 'open', 'high', 'low', 'close', 'volume',
        'close_time', 'quote_asset_volume', 'number_of_trades',
        'taker_buy_base_asset_volume', 'taker_buy_quote_asset_volume',
        'ignore', 'future_price', 'target',
    ])

    def __init__(self, symbol='BTCUSDT', interval='1h', target_accuracy=0.80, max_iterations=50):
        """Store the training configuration.

        Args:
            symbol: Binance trading pair, e.g. ``'BTCUSDT'``.
            interval: Kline interval string understood by Binance.
            target_accuracy: Hold-out accuracy (0..1) at which training stops.
            max_iterations: Maximum number of training iterations.
        """
        self.symbol = symbol
        self.interval = interval
        self.target_accuracy = target_accuracy
        self.models = {}
        self.best_models = {}
        self.feature_names = []
        self.training_history = []
        self.current_iteration = 0
        self.max_iterations = max_iterations

        # Progressive schedules, indexed by iteration in
        # train_until_target_accuracy().
        self.data_limits = [1000, 2000, 3000, 5000]
        self.feature_complexity_levels = [1, 2, 3, 4, 5]
        self.scaler_types = ['standard', 'robust', 'minmax']

        logger.info(f"Инициализация тренера для {symbol}, целевая точность: {target_accuracy*100:.2f}%")

    def fetch_binance_data(self, limit=2000):
        """Download the most recent ``limit`` klines from the Binance REST API.

        Pages are fetched newest-first; each request asks only for the rows
        still missing, so the result is one contiguous time range even when
        ``limit`` is not a multiple of the page size.

        Args:
            limit: Number of candles wanted.

        Returns:
            A DataFrame sorted by ascending ``timestamp`` with float64 OHLCV
            columns, or ``None`` when nothing could be downloaded.
        """
        url = "https://api.binance.com/api/v3/klines"
        all_data = []
        end_time = None

        while len(all_data) < limit:
            requested = min(limit - len(all_data), self._MAX_KLINES_PER_REQUEST)
            params = {
                'symbol': self.symbol,
                'interval': self.interval,
                'limit': requested,
            }
            if end_time is not None:
                params['endTime'] = end_time

            try:
                response = requests.get(url, params=params, timeout=10)
                response.raise_for_status()
                data = response.json()
            except Exception as e:
                # Best effort: keep whatever pages were already downloaded.
                logger.error(f"Ошибка при получении данных: {e}")
                break

            if not isinstance(data, list) or not data:
                break

            all_data.extend(data)
            # The first candle of a page is its earliest one; ask the next
            # request for candles that end 1 ms before it.
            end_time = data[0][0] - 1

            # A short page means the exchange has no older history.
            if len(data) < requested or len(all_data) >= limit:
                break

            time.sleep(0.2)

        if not all_data:
            logger.error("Не удалось получить данные с Binance.")
            return None

        df = pd.DataFrame(all_data, columns=[
            'timestamp', 'open', 'high', 'low', 'close', 'volume',
            'close_time', 'quote_asset_volume', 'number_of_trades',
            'taker_buy_base_asset_volume', 'taker_buy_quote_asset_volume', 'ignore'
        ])
        for col in ('open', 'high', 'low', 'close', 'volume'):
            # Force float64 so the indicator functions always get doubles.
            df[col] = pd.to_numeric(df[col], errors='coerce').astype('float64')
        df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
        # Sort first, then keep the newest `limit` rows (defensive: never
        # truncate the unsorted, newest-first page list).
        df = df.sort_values('timestamp').tail(limit).reset_index(drop=True)
        logger.info(f"Получено {len(df)} записей для {self.symbol}")
        return df

    def calculate_advanced_technical_indicators(self, df, complexity_level=1):
        """Add TA-Lib indicator columns to a copy of ``df``.

        Args:
            df: Frame with ``open``, ``high``, ``low``, ``close``, ``volume``.
            complexity_level: 1..5; each level adds indicators on top of the
                previous one.

        Returns:
            A new DataFrame with the indicator columns appended.
        """
        df = df.copy()
        open_, high, low = df['open'], df['high'], df['low']
        close, volume = df['close'], df['volume']

        # Level 1: moving averages, RSI, MACD and Bollinger bands.
        for period in (5, 10, 20, 50):
            df[f'sma_{period}'] = talib.SMA(close, timeperiod=period)
        df['ema_12'] = talib.EMA(close, timeperiod=12)
        df['ema_26'] = talib.EMA(close, timeperiod=26)
        df['rsi'] = talib.RSI(close, timeperiod=14)

        macd, macd_signal, macd_hist = talib.MACD(close)
        df['macd'] = macd
        df['macd_signal'] = macd_signal
        df['macd_hist'] = macd_hist

        bb_upper, bb_middle, bb_lower = talib.BBANDS(close)
        df['bb_upper'] = bb_upper
        df['bb_middle'] = bb_middle
        df['bb_lower'] = bb_lower
        # Zero denominators become NaN to avoid division by zero.
        df['bb_width'] = (bb_upper - bb_lower) / (bb_middle.replace(0, np.nan))
        df['bb_position'] = (close - bb_lower) / ((bb_upper - bb_lower).replace(0, np.nan))

        if complexity_level >= 2:
            df['stoch_k'], df['stoch_d'] = talib.STOCH(high, low, close)
            df['williams_r'] = talib.WILLR(high, low, close)
            df['cci'] = talib.CCI(high, low, close)
            df['atr'] = talib.ATR(high, low, close)
            df['adx'] = talib.ADX(high, low, close)
            df['ad'] = talib.AD(high, low, close, volume)
            df['obv'] = talib.OBV(close, volume)

        if complexity_level >= 3:
            df['mfi'] = talib.MFI(high, low, close, volume)
            df['roc'] = talib.ROC(close)
            df['tema'] = talib.TEMA(close)
            df['dema'] = talib.DEMA(close)
            # A small sample of candlestick patterns.
            df['doji'] = talib.CDLDOJI(open_, high, low, close)
            df['engulfing'] = talib.CDLENGULFING(open_, high, low, close)

        if complexity_level >= 4:
            df['ht_trendline'] = talib.HT_TRENDLINE(close)
            # NOTE: ADOSC is the Chaikin A/D oscillator; the column name
            # 'cmf' is kept so previously saved feature lists stay valid.
            df['cmf'] = talib.ADOSC(high, low, close, volume)

        if complexity_level >= 5:
            for period in [7, 14, 21, 30]:
                df[f'sma_{period}'] = talib.SMA(close, timeperiod=period)
                df[f'ema_{period}'] = talib.EMA(close, timeperiod=period)
                df[f'rsi_{period}'] = talib.RSI(close, timeperiod=period)

        return df

    def create_progressive_features(self, df, complexity_level=1):
        """Add lag, rolling, momentum and structural features to a copy of ``df``.

        All features at row *t* use only rows ``<= t`` so that they can be
        computed at prediction time without knowing the future.

        Args:
            df: Frame with at least ``close`` and ``volume`` (``high``/``low``
                are required from level 4); ``rsi``, ``macd``, ``sma_20`` and
                ``sma_50`` are used when present.
            complexity_level: 1..5; each level adds features.

        Returns:
            A new DataFrame with the feature columns appended.
        """
        df = df.copy()
        has_rsi = 'rsi' in df.columns
        has_macd = 'macd' in df.columns

        basic_lags = [1, 2, 3, 5]
        windows = [5, 10]
        momentum_periods = [5, 10]
        if complexity_level >= 2:
            basic_lags += [10, 20]
            windows += [20, 30]
            momentum_periods += [20, 30]
        if complexity_level >= 3:
            basic_lags += [30, 50]
            windows += [50, 100]
            momentum_periods += [50, 100]

        for lag in basic_lags:
            df[f'close_lag_{lag}'] = df['close'].shift(lag)
            df[f'volume_lag_{lag}'] = df['volume'].shift(lag)
            if has_rsi:
                df[f'rsi_lag_{lag}'] = df['rsi'].shift(lag)
            if has_macd:
                df[f'macd_lag_{lag}'] = df['macd'].shift(lag)

        for window in windows:
            if has_rsi:
                df[f'rsi_sma_{window}'] = df['rsi'].rolling(window).mean()
                df[f'rsi_std_{window}'] = df['rsi'].rolling(window).std()
            if has_macd:
                df[f'macd_sma_{window}'] = df['macd'].rolling(window).mean()
                df[f'macd_std_{window}'] = df['macd'].rolling(window).std()
            df[f'volume_ema_{window}'] = df['volume'].ewm(span=window).mean()
            df[f'price_std_{window}'] = df['close'].rolling(window).std()

        for period in momentum_periods:
            df[f'momentum_{period}'] = df['close'] / df['close'].shift(period) - 1
            try:
                df[f'roc_{period}'] = talib.ROC(df['close'], timeperiod=period)
            except Exception:
                # ROC is defined as 100 * (price / previous price - 1), i.e.
                # momentum in percent.  Falling back to that formula keeps the
                # column usable; an all-NaN column would make dropna() discard
                # every training row later on.
                df[f'roc_{period}'] = df[f'momentum_{period}'] * 100.0
            df[f'volatility_{period}'] = df['close'].pct_change().rolling(period).std()

        if complexity_level >= 3:
            if has_rsi and has_macd:
                df['rsi_macd_corr'] = df['rsi'].rolling(20).corr(df['macd'])
            if 'sma_20' in df.columns and 'sma_50' in df.columns:
                df['sma_ratio_20_50'] = df['sma_20'] / df['sma_50'].replace(0, np.nan)
            for col in ['close', 'volume', 'rsi']:
                if col in df.columns:
                    mean = df[col].rolling(50).mean()
                    std = df[col].rolling(50).std()
                    df[f'{col}_zscore'] = (df[col] - mean) / (std.replace(0, np.nan))

        if complexity_level >= 4:
            # A fractal needs the bar *after* the pivot to be confirmed.
            # Flag it on the confirming bar (pivot = previous bar) instead of
            # on the pivot itself, which would read the next bar's high/low
            # and leak the future into the features.
            prev_high = df['high'].shift(1)
            prev_low = df['low'].shift(1)
            df['fractal_high'] = ((prev_high > df['high'].shift(2)) & (prev_high > df['high'])).astype(int)
            df['fractal_low'] = ((prev_low < df['low'].shift(2)) & (prev_low < df['low'])).astype(int)
            df['support'] = df['low'].rolling(20).min()
            df['resistance'] = df['high'].rolling(20).max()
            df['support_distance'] = (df['close'] - df['support']) / df['close']
            df['resistance_distance'] = (df['resistance'] - df['close']) / df['close']

        if complexity_level >= 5:
            # raw=True hands plain ndarrays to the lambdas (same values,
            # far less per-window overhead than building a Series).
            df['wave_trend'] = df['close'].rolling(50).apply(lambda w: 1 if w[-1] > w[0] else 0, raw=True)
            if has_rsi:
                price_trend = df['close'].rolling(10).apply(lambda w: w[-1] - w[0], raw=True)
                rsi_trend = df['rsi'].rolling(10).apply(lambda w: w[-1] - w[0], raw=True)
                diverging = ((price_trend > 0) & (rsi_trend < 0)) | ((price_trend < 0) & (rsi_trend > 0))
                # Stored as 0/1 like the other flags so the feature matrix
                # stays purely numeric.
                df['price_rsi_divergence'] = diverging.astype(int)

        return df

    def create_target_variable(self, df, prediction_horizon=1):
        """Label each row with 1 if the close ``prediction_horizon`` bars ahead is higher.

        The last ``prediction_horizon`` rows have no future price; their
        ``future_price`` is NaN (so ``prepare_features`` drops them) and their
        ``target`` is 0.

        Returns:
            A copy of ``df`` with ``future_price`` and ``target`` columns.
        """
        labelled = df.copy()
        labelled['future_price'] = labelled['close'].shift(-prediction_horizon)
        labelled['target'] = (labelled['future_price'] > labelled['close']).astype(int)
        return labelled

    def prepare_features(self, df):
        """Split a fully-featured frame into the feature matrix and labels.

        Rows containing NaN (indicator warm-up, missing future price) or
        infinite feature values are discarded.  ``self.feature_names`` is
        updated to the feature columns of this call.

        Returns:
            ``(X, y)``, or ``(None, None)`` when no complete row remains.
        """
        feature_columns = [col for col in df.columns if col not in self._NON_FEATURE_COLUMNS]
        df_clean = df.copy()
        if feature_columns:
            # Ratios can produce +/-inf, which scikit-learn rejects.
            df_clean[feature_columns] = df_clean[feature_columns].replace([np.inf, -np.inf], np.nan)
        df_clean = df_clean.dropna()
        if df_clean.empty:
            logger.error("Все строки содержат NaN после очистки!")
            return None, None
        X = df_clean[feature_columns]
        y = df_clean['target']
        self.feature_names = feature_columns
        logger.info(f"Подготовлено {len(X)} образцов с {len(feature_columns)} признаками")
        return X, y

    def get_progressive_model_params(self, model_name, iteration):
        """Return constructor kwargs for ``model_name`` at a given iteration.

        Model capacity grows with ``iteration``; every growing parameter is
        clamped, including the regularisation constant ``C`` (at most
        ``10 ** _MAX_C_EXPONENT``).

        Args:
            model_name: One of ``'Random Forest'``, ``'Gradient Boosting'``,
                ``'Extra Trees'``, ``'Logistic Regression'``, ``'SVM'``.
            iteration: Zero-based iteration index.

        Returns:
            A kwargs dict, or ``{}`` for an unknown model name.
        """
        min_samples_split = max(5 - iteration, 2)
        min_samples_leaf = max(2 - iteration // 2, 1)
        max_exp = self._MAX_C_EXPONENT

        base_params = {
            'Random Forest': {
                'n_estimators': min(100 + iteration * 100, 1000),
                'max_depth': min(10 + iteration * 2, 25),
                'min_samples_split': min_samples_split,
                'min_samples_leaf': min_samples_leaf,
                'max_features': 'sqrt',
                'bootstrap': True,
                'random_state': 42,
                'n_jobs': -1,
                'class_weight': 'balanced'
            },
            'Gradient Boosting': {
                'n_estimators': min(100 + iteration * 50, 500),
                'learning_rate': max(0.1 - iteration * 0.01, 0.01),
                'max_depth': min(6 + iteration, 12),
                'min_samples_split': min_samples_split,
                'min_samples_leaf': min_samples_leaf,
                'subsample': 0.8,
                'max_features': 'sqrt',
                'random_state': 42
            },
            'Extra Trees': {
                'n_estimators': min(100 + iteration * 100, 1000),
                'max_depth': min(10 + iteration * 2, 25),
                'min_samples_split': min_samples_split,
                'min_samples_leaf': min_samples_leaf,
                'max_features': 'sqrt',
                'bootstrap': False,
                'random_state': 42,
                'n_jobs': -1,
                'class_weight': 'balanced'
            },
            'Logistic Regression': {
                'random_state': 42,
                'max_iter': min(1000 + iteration * 500, 5000),
                'C': 10 ** min(-2 + iteration * 0.5, max_exp),
                'penalty': 'l2',
                'solver': 'liblinear',
                'class_weight': 'balanced'
            },
            'SVM': {
                'kernel': 'rbf',
                'C': 10 ** min(max(0, iteration * 0.5), max_exp),
                'gamma': 'scale',
                'probability': True,
                # probability=True uses internal randomised cross-validation;
                # seed it like the other models for reproducible results.
                'random_state': 42,
                'class_weight': 'balanced'
            }
        }
        return base_params.get(model_name, {})

    def train_iteration(self, data_limit, complexity_level, scaler_type='standard'):
        """Run one full fetch / feature / train / evaluate cycle.

        The data is split chronologically (first 80 % train, last 20 % test)
        and cross-validated with forward-chaining folds: the features are
        lagged and rolling statistics of a time series, so shuffling rows
        would put near-duplicates of test rows into the training set and
        inflate the reported accuracy.

        Args:
            data_limit: Number of candles to download.
            complexity_level: Feature complexity level (1..5).
            scaler_type: ``'standard'``, ``'robust'`` or anything else for
                min-max scaling; only used for the linear model and the SVM.

        Returns:
            ``True`` when the best hold-out accuracy of this iteration reaches
            ``self.target_accuracy``, otherwise ``False``.
        """
        # Not part of the file-level imports; imported here so this method
        # stays self-contained.
        from sklearn.model_selection import TimeSeriesSplit

        logger.info(f"Итерация {self.current_iteration + 1}: данных={data_limit}, сложность={complexity_level}, скейлер={scaler_type}")

        df = self.fetch_binance_data(limit=data_limit)
        if df is None or len(df) < 100:
            logger.error("Недостаточно данных для обучения")
            return False

        df = self.calculate_advanced_technical_indicators(df, complexity_level)
        df = self.create_progressive_features(df, complexity_level)
        df = self.create_target_variable(df)

        X, y = self.prepare_features(df)
        if X is None:
            return False
        # Snapshot: self.feature_names is overwritten by later iterations.
        feature_names = list(self.feature_names)

        # At least two classes are required to fit a classifier.
        if y.nunique() < 2:
            logger.error("Целевая переменная содержит только один класс. Нельзя обучить модель.")
            return False

        # Rows are in ascending time order, so shuffle=False gives a
        # past/future split.
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, shuffle=False
        )
        if y_train.nunique() < 2:
            logger.error("Целевая переменная содержит только один класс. Нельзя обучить модель.")
            return False

        scaler_classes = {'standard': StandardScaler, 'robust': RobustScaler}
        scaler = scaler_classes.get(scaler_type, MinMaxScaler)()
        X_train_scaled = scaler.fit_transform(X_train)
        X_test_scaled = scaler.transform(X_test)

        model_classes = {
            'Random Forest': RandomForestClassifier,
            'Gradient Boosting': GradientBoostingClassifier,
            'Extra Trees': ExtraTreesClassifier,
            'Logistic Regression': LogisticRegression,
            'SVM': SVC,
        }
        models = {
            name: cls(**self.get_progressive_model_params(name, self.current_iteration))
            for name, cls in model_classes.items()
        }

        iteration_results = {}
        cv = TimeSeriesSplit(n_splits=5)

        for name, model in models.items():
            logger.info(f"Обучение {name}...")
            uses_scaler = name in self._SCALED_MODELS
            # The linear model and the SVM get scaled features; tree
            # ensembles are trained on the raw values.
            fit_X = X_train_scaled if uses_scaler else X_train
            eval_X = X_test_scaled if uses_scaler else X_test
            try:
                cv_scores = cross_val_score(model, fit_X, y_train, cv=cv, scoring='accuracy', n_jobs=-1)
                model.fit(fit_X, y_train)
                y_pred = model.predict(eval_X)

                accuracy = accuracy_score(y_test, y_pred)
                cv_mean = float(np.mean(cv_scores))
                cv_std = float(np.std(cv_scores))

                iteration_results[name] = {
                    'model': model,
                    'scaler': scaler if uses_scaler else None,
                    'accuracy': accuracy,
                    'cv_mean': cv_mean,
                    'cv_std': cv_std,
                    # Everything needed to rebuild this model's inputs.
                    'feature_names': feature_names,
                    'complexity_level': complexity_level,
                    'scaler_type': scaler_type,
                }

                logger.info(f"{name}: Точность={accuracy:.4f}, CV={cv_mean:.4f}±{cv_std:.4f}")

                if name not in self.best_models or accuracy > self.best_models[name]['accuracy']:
                    self.best_models[name] = iteration_results[name].copy()
                    logger.info(f"Новая лучшая модель {name}: {accuracy:.4f}")

            except Exception as e:
                # One failing model must not abort the other models.
                logger.error(f"Ошибка при обучении {name}: {e}")

        best_accuracy = max((r['accuracy'] for r in iteration_results.values()), default=0)

        self.training_history.append({
            'iteration': self.current_iteration + 1,
            'data_limit': data_limit,
            'complexity_level': complexity_level,
            'scaler_type': scaler_type,
            'results': {k: {'accuracy': v['accuracy'], 'cv_mean': v['cv_mean'], 'cv_std': v['cv_std']} for k, v in iteration_results.items()},
            'best_accuracy': best_accuracy
        })

        logger.info(f"Лучшая точность на итерации: {best_accuracy:.4f}")
        return best_accuracy >= self.target_accuracy

    def train_until_target_accuracy(self):
        """Repeat ``train_iteration`` until the target is hit or iterations run out.

        Data volume and feature complexity advance every second iteration;
        the scaler type cycles every iteration.

        Returns:
            ``True`` if the target accuracy was reached, otherwise ``False``.
        """
        logger.info(f"Начинаем обучение до достижения {self.target_accuracy*100:.2f}% (макс итераций {self.max_iterations})")
        target_reached = False
        iteration = 0

        while not target_reached and iteration < self.max_iterations:
            self.current_iteration = iteration
            stage = iteration // 2
            data_limit = self.data_limits[min(stage, len(self.data_limits) - 1)]
            complexity_level = self.feature_complexity_levels[min(stage, len(self.feature_complexity_levels) - 1)]
            scaler_type = self.scaler_types[iteration % len(self.scaler_types)]

            logger.info("\n" + "=" * 60)
            logger.info(f"ИТЕРАЦИЯ {iteration + 1}")
            logger.info("=" * 60)

            try:
                target_reached = self.train_iteration(data_limit, complexity_level, scaler_type)
            except Exception as e:
                # Keep going: a later iteration uses different data/params.
                logger.error(f"Критическая ошибка на итерации {iteration+1}: {e}")
                target_reached = False

            if target_reached:
                logger.info(f"🎉 ЦЕЛЕВАЯ ТОЧНОСТЬ ДОСТИГНУТА НА ИТЕРАЦИИ {iteration + 1}!")
                break

            iteration += 1
            time.sleep(1)

        if not target_reached:
            logger.warning("Не удалось достичь целевой точности в отведённом числе итераций.")
        return target_reached

    def save_best_models(self):
        """Persist the best model of every type plus metadata under ``models/``.

        Per model: ``<name>_model.joblib``, ``<name>_scaler.joblib`` (scaled
        models only) and ``<name>_features.joblib`` holding the exact feature
        list that model was trained on.  ``feature_names.joblib`` (feature
        list of the most recent iteration) is still written for existing
        consumers, but it may not match models kept from earlier iterations.

        Returns:
            ``True`` on success, ``False`` when there is nothing to save.
        """
        if not self.best_models:
            logger.error("Нет моделей для сохранения!")
            return False

        models_dir = 'models'
        os.makedirs(models_dir, exist_ok=True)

        model_features = {}
        for name, model_data in self.best_models.items():
            slug = name.lower().replace(' ', '_')
            joblib.dump(model_data['model'], os.path.join(models_dir, f"{slug}_model.joblib"))
            if model_data['scaler'] is not None:
                joblib.dump(model_data['scaler'], os.path.join(models_dir, f"{slug}_scaler.joblib"))
            # Best models can stem from different complexity levels, so each
            # one carries its own feature list.
            features = list(model_data.get('feature_names', self.feature_names))
            model_features[name] = features
            joblib.dump(features, os.path.join(models_dir, f"{slug}_features.joblib"))
            logger.info(f"Сохранена модель {name} с точностью {model_data['accuracy']:.4f}")

        # Feature names of the latest iteration and run metadata.
        features_path = os.path.join(models_dir, 'feature_names.joblib')
        joblib.dump(self.feature_names, features_path)

        metadata = {
            'symbol': self.symbol,
            'interval': self.interval,
            'target_accuracy': self.target_accuracy,
            'training_date': datetime.now().isoformat(),
            'total_iterations': self.current_iteration + 1,
            'best_accuracies': {name: data['accuracy'] for name, data in self.best_models.items()},
            'feature_count': len(self.feature_names),
            'model_features': model_features,
            'training_history': self.training_history
        }
        metadata_path = os.path.join(models_dir, 'metadata.joblib')
        joblib.dump(metadata, metadata_path)

        logger.info("Модели и метаданные успешно сохранены.")
        return True
487
+
488
def _parse_target_accuracy(text, default=0.80):
    """Parse a target accuracy given as a fraction ("0.8") or a percentage ("80", "80%")."""
    cleaned = text.strip().rstrip('%').strip().replace(',', '.')
    try:
        value = float(cleaned)
    except ValueError:
        return default
    if value > 1:
        value = value / 100.0
    # Rejects zero, negatives, NaN and anything above 100 %.
    if not 0 < value <= 1:
        return default
    return value


def _parse_max_iterations(text, default=50):
    """Parse a positive iteration count, falling back to ``default`` on bad input."""
    try:
        value = int(text.strip())
    except ValueError:
        return default
    return value if value > 0 else default


def main():
    """Interactive entry point: ask for settings, train, then save the best models.

    The best models found so far are saved on normal completion, on Ctrl+C
    and after an unexpected error.
    """
    print("Продвинутая система обучения моделей — train.py")
    symbol = input("Введите торговую пару (по умолчанию BTCUSDT): ").strip().upper() or 'BTCUSDT'
    interval = input("Интервал (1m,5m,1h,4h,1d), по умолчанию 1h: ").strip() or '1h'
    target_accuracy = _parse_target_accuracy(
        input("Целевая точность (по умолчанию 0.80 или 80%): ").strip() or '0.80'
    )
    max_iters = _parse_max_iterations(
        input("Максимум итераций (по умолчанию 50): ").strip() or '50'
    )

    trainer = AdvancedCryptoModelTrainer(symbol=symbol, interval=interval, target_accuracy=target_accuracy, max_iterations=max_iters)
    start_time = time.time()
    try:
        success = trainer.train_until_target_accuracy()
        trainer.save_best_models()
        mins = (time.time() - start_time) / 60.0
        print(f"\nОбучение завершено за {mins:.1f} минут")
        if success:
            print("🎉 Целевая точность достигнута!")
        else:
            print("⚠️ Цель не достигнута — сохранены лучшие модели.")
    except KeyboardInterrupt:
        print("\nПрерывание пользователем. Сохранение лучших моделей (если есть)...")
        trainer.save_best_models()
    except Exception as e:
        # logger.exception keeps the traceback in the log file.
        logger.exception(f"Критическая ошибка: {e}")
        trainer.save_best_models()


if __name__ == "__main__":
    main()