OmidSakaki commited on
Commit
2705a8c
·
verified ·
1 Parent(s): d9381d2

Update src/utils/config.py

Browse files
Files changed (1) hide show
  1. src/utils/config.py +287 -21
src/utils/config.py CHANGED
@@ -1,24 +1,290 @@
 
 
 
 
 
 
 
 
 
 
1
  class TradingConfig:
2
- def __init__(self):
3
- # Environment settings
4
- self.initial_balance = 10000
5
- self.max_steps = 1000
6
- self.transaction_cost = 0.001
7
-
8
- # AI Agent settings
9
- self.learning_rate = 0.001
10
- self.gamma = 0.99
11
- self.epsilon_start = 1.0
12
- self.epsilon_min = 0.1
13
- self.epsilon_decay = 0.995
14
- self.memory_size = 500
15
- self.batch_size = 16
16
-
17
- # Visualization settings
18
- self.chart_width = 800
19
- self.chart_height = 600
20
- self.update_interval = 100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  def to_dict(self):
23
- return {key: value for key, value in self.__dict__.items()
24
- if not key.startswith('_')}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from typing import Dict, Any, Optional
4
+ from dataclasses import dataclass, asdict, field
5
+ from pathlib import Path
6
+ import logging
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+ @dataclass
11
  class TradingConfig:
12
+ """Comprehensive trading configuration with validation and persistence"""
13
+
14
+ # Environment settings
15
+ initial_balance: float = 10000.0
16
+ max_steps: int = 1000
17
+ transaction_cost: float = 0.001
18
+ risk_level: str = "Medium"
19
+ asset_type: str = "Crypto"
20
+
21
+ # AI Agent settings
22
+ learning_rate: float = 0.001
23
+ gamma: float = 0.99
24
+ epsilon_start: float = 1.0
25
+ epsilon_min: float = 0.01
26
+ epsilon_decay: float = 0.9995
27
+ memory_size: int = 10000
28
+ batch_size: int = 32
29
+ target_update_freq: int = 100
30
+ gradient_clip: float = 1.0
31
+
32
+ # Sentiment settings
33
+ use_sentiment: bool = True
34
+ sentiment_influence: float = 0.3
35
+ sentiment_update_freq: int = 5
36
+
37
+ # Visualization settings
38
+ chart_width: int = 800
39
+ chart_height: int = 600
40
+ update_interval: int = 100
41
+ enable_visualization: bool = True
42
+
43
+ # Training settings
44
+ max_episodes: int = 1000
45
+ eval_episodes: int = 10
46
+ eval_freq: int = 100
47
+ save_freq: int = 500
48
+ log_level: str = "INFO"
49
+
50
+ # Paths
51
+ model_dir: str = "models"
52
+ log_dir: str = "logs"
53
+ data_dir: str = "data"
54
+
55
+ # Device settings
56
+ use_cuda: bool = True
57
+ device: str = "auto"
58
+
59
+ def __post_init__(self):
60
+ """Validate and initialize configuration"""
61
+ self._validate()
62
+ self._setup_paths()
63
+ self._setup_device()
64
+ self._setup_logging()
65
+
66
+ def _validate(self):
67
+ """Validate configuration parameters"""
68
+ errors = []
69
 
70
+ # Balance validation
71
+ if self.initial_balance <= 0:
72
+ errors.append("initial_balance must be positive")
73
+
74
+ # Steps validation
75
+ if self.max_steps <= 0:
76
+ errors.append("max_steps must be positive")
77
+
78
+ # Costs validation
79
+ if not 0.0 <= self.transaction_cost <= 0.1:
80
+ errors.append("transaction_cost should be between 0 and 0.1")
81
+
82
+ # Learning rate validation
83
+ if not 0.0001 <= self.learning_rate <= 0.1:
84
+ errors.append("learning_rate should be between 0.0001 and 0.1")
85
+
86
+ # Discount factor validation
87
+ if not 0.0 <= self.gamma <= 1.0:
88
+ errors.append("gamma must be between 0 and 1")
89
+
90
+ # Epsilon validation
91
+ if not 0.0 <= self.epsilon_min <= self.epsilon_start <= 1.0:
92
+ errors.append("epsilon values must satisfy 0 <= epsilon_min <= epsilon_start <= 1")
93
+
94
+ # Batch size validation
95
+ if self.batch_size > self.memory_size:
96
+ errors.append("batch_size cannot exceed memory_size")
97
+
98
+ # Risk level validation
99
+ valid_risks = ["Low", "Medium", "High"]
100
+ if self.risk_level not in valid_risks:
101
+ errors.append(f"risk_level must be one of {valid_risks}")
102
+
103
+ # Asset type validation
104
+ valid_assets = ["Crypto", "Stocks", "Forex", "Commodities"]
105
+ if self.asset_type not in valid_assets:
106
+ errors.append(f"asset_type must be one of {valid_assets}")
107
+
108
+ # Sentiment influence validation
109
+ if not 0.0 <= self.sentiment_influence <= 1.0:
110
+ errors.append("sentiment_influence must be between 0 and 1")
111
+
112
+ if errors:
113
+ logger.error(f"Configuration validation errors: {errors}")
114
+ raise ValueError(f"Invalid configuration: {'; '.join(errors)}")
115
+
116
+ logger.info("Configuration validation passed")
117
+
118
+ def _setup_paths(self):
119
+ """Create necessary directories"""
120
+ for path_attr in ['model_dir', 'log_dir', 'data_dir']:
121
+ path = Path(getattr(self, path_attr))
122
+ path.mkdir(parents=True, exist_ok=True)
123
+ setattr(self, f"{path_attr}_path", path)
124
+
125
+ def _setup_device(self):
126
+ """Setup device configuration"""
127
+ import torch
128
+ if self.device == "auto":
129
+ self.device = "cuda" if self.use_cuda and torch.cuda.is_available() else "cpu"
130
+ else:
131
+ if self.device not in ["cpu", "cuda", "mps"]:
132
+ logger.warning(f"Unknown device {self.device}, defaulting to CPU")
133
+ self.device = "cpu"
134
+
135
+ logger.info(f"Using device: {self.device}")
136
+
137
+ def _setup_logging(self):
138
+ """Setup logging configuration"""
139
+ import logging
140
+ log_level = getattr(logging, self.log_level.upper())
141
+ logging.getLogger().setLevel(log_level)
142
+
143
+ def to_dict(self) -> Dict[str, Any]:
144
+ """Convert config to dictionary, excluding sensitive paths"""
145
+ config_dict = asdict(self)
146
+ # Remove absolute paths for serialization
147
+ for key in list(config_dict.keys()):
148
+ if key.endswith('_path') or 'dir' in key:
149
+ config_dict[key] = str(getattr(self, key)) if isinstance(getattr(self, key), Path) else getattr(self, key)
150
+ return config_dict
151
+
152
+ def to_json(self, filepath: Optional[str] = None) -> str:
153
+ """Serialize config to JSON"""
154
+ config_dict = self.to_dict()
155
+ json_str = json.dumps(config_dict, indent=2, default=str)
156
+
157
+ if filepath:
158
+ with open(filepath, 'w') as f:
159
+ f.write(json_str)
160
+ logger.info(f"Config saved to {filepath}")
161
+
162
+ return json_str
163
+
164
+ @classmethod
165
+ def from_json(cls, filepath: str) -> 'TradingConfig':
166
+ """Load config from JSON file"""
167
+ try:
168
+ with open(filepath, 'r') as f:
169
+ config_dict = json.load(f)
170
+
171
+ # Create dataclass instance
172
+ config = cls(**config_dict)
173
+ logger.info(f"Config loaded from {filepath}")
174
+ return config
175
+ except Exception as e:
176
+ logger.error(f"Error loading config from {filepath}: {e}")
177
+ raise
178
+
179
+ @classmethod
180
+ def from_dict(cls, config_dict: Dict[str, Any]) -> 'TradingConfig':
181
+ """Create config from dictionary"""
182
+ return cls(**config_dict)
183
+
184
+ def save(self, filepath: str):
185
+ """Save config to file"""
186
+ self.to_json(filepath)
187
+
188
+ @staticmethod
189
+ def load(filepath: str) -> 'TradingConfig':
190
+ """Static method to load config"""
191
+ return TradingConfig.from_json(filepath)
192
+
193
+ def update(self, **kwargs):
194
+ """Update config parameters and revalidate"""
195
+ for key, value in kwargs.items():
196
+ if hasattr(self, key):
197
+ setattr(self, key, value)
198
+ else:
199
+ logger.warning(f"Unknown config parameter: {key}")
200
+
201
+ self._validate()
202
+ logger.info("Config updated and validated")
203
+
204
+ def get_agent_params(self) -> Dict[str, Any]:
205
+ """Get parameters specific to agent"""
206
+ return {
207
+ 'learning_rate': self.learning_rate,
208
+ 'gamma': self.gamma,
209
+ 'epsilon_start': self.epsilon_start,
210
+ 'epsilon_min': self.epsilon_min,
211
+ 'epsilon_decay': self.epsilon_decay,
212
+ 'memory_size': self.memory_size,
213
+ 'batch_size': self.batch_size,
214
+ 'target_update_freq': self.target_update_freq,
215
+ 'gradient_clip': self.gradient_clip,
216
+ 'device': self.device
217
+ }
218
+
219
+ def get_env_params(self) -> Dict[str, Any]:
220
+ """Get parameters specific to environment"""
221
+ return {
222
+ 'initial_balance': self.initial_balance,
223
+ 'max_steps': self.max_steps,
224
+ 'transaction_cost': self.transaction_cost,
225
+ 'risk_level': self.risk_level,
226
+ 'asset_type': self.asset_type,
227
+ 'use_sentiment': self.use_sentiment,
228
+ 'sentiment_influence': self.sentiment_influence,
229
+ 'sentiment_update_freq': self.sentiment_update_freq
230
+ }
231
+
232
+ def __str__(self) -> str:
233
+ """String representation of config"""
234
+ return json.dumps(self.to_dict(), indent=2)
235
+
236
+
237
+ # Legacy compatibility
238
+ class LegacyTradingConfig:
239
+ """Wrapper for backward compatibility"""
240
+
241
+ def __init__(self, config_file: Optional[str] = None):
242
+ if config_file and os.path.exists(config_file):
243
+ self.config = TradingConfig.from_json(config_file)
244
+ else:
245
+ self.config = TradingConfig()
246
+
247
+ def __getattr__(self, name):
248
+ return getattr(self.config, name)
249
+
250
  def to_dict(self):
251
+ return self.config.to_dict()
252
+
253
+
254
+ # Default config instance
255
+ DEFAULT_CONFIG = TradingConfig()
256
+
257
+ # Example usage and config loading
258
+ def create_config_from_env() -> TradingConfig:
259
+ """Create config from environment variables"""
260
+ import os
261
+ config_dict = {}
262
+
263
+ env_mappings = {
264
+ 'INITIAL_BALANCE': 'initial_balance',
265
+ 'MAX_STEPS': 'max_steps',
266
+ 'LEARNING_RATE': 'learning_rate',
267
+ 'BATCH_SIZE': 'batch_size',
268
+ 'USE_CUDA': 'use_cuda'
269
+ }
270
+
271
+ for env_var, config_key in env_mappings.items():
272
+ env_value = os.getenv(env_var)
273
+ if env_value is not None:
274
+ try:
275
+ # Try to convert to appropriate type
276
+ if config_key in ['initial_balance', 'learning_rate']:
277
+ config_dict[config_key] = float(env_value)
278
+ elif config_key in ['max_steps', 'batch_size']:
279
+ config_dict[config_key] = int(env_value)
280
+ elif config_key == 'use_cuda':
281
+ config_dict[config_key] = env_value.lower() in ('true', '1', 'yes')
282
+ except ValueError:
283
+ logger.warning(f"Invalid environment variable {env_var}: {env_value}")
284
+
285
+ if config_dict:
286
+ base_config = TradingConfig()
287
+ base_config.update(**config_dict)
288
+ return base_config
289
+
290
+ return DEFAULT_CONFIG