Spaces:
Paused
Paused
| import logging | |
| from typing import Dict, List, Optional, Tuple, Any | |
| import yaml | |
| import time | |
| from datetime import datetime, timedelta | |
# Module-level logger named after this module; RiskManager reports all warnings and errors through it.
logger = logging.getLogger(__name__)
class RiskManager:
    """Risk gate and bookkeeping for a single trading account.

    Responsibilities:
      * size new positions from the USDT wallet balance and ``risk_per_trade``;
      * reject entry signals on emergency stop, low confidence, an already-open
        position, poor liquidity or high short-term volatility;
      * track open positions locally, book realised P&L on close, and enforce
        the daily loss, daily trade-count and open-position limits.

    All market and account data comes from the injected ``exchange_client``.
    The response field names read here (``walletBalance``, ``lastPrice``,
    ``bid1Price`` ...) are whatever that client returns — presumably the raw
    exchange payload; confirm against the client implementation.
    """

    def __init__(self, exchange_client, settings_path: str = "config/settings.yaml"):
        """Load trading settings and initialise limits and counters.

        Args:
            exchange_client: Client providing ``get_balance``, ``get_ticker``,
                ``get_kline_data`` and ``close_position``.
            settings_path: YAML settings file; must contain
                ``trading.leverage``, ``trading.tp_percent``,
                ``trading.sl_percent`` and ``risk.risk_per_trade``.

        Raises:
            OSError: if the settings file cannot be opened.
            KeyError / TypeError: if a required setting is missing.
        """
        self.exchange = exchange_client
        # Context manager so the settings file handle is closed deterministically.
        with open(settings_path, encoding="utf-8") as settings_file:
            self.settings = yaml.safe_load(settings_file)
        self.leverage = self.settings["trading"]["leverage"]
        self.tp_percent = self.settings["trading"]["tp_percent"]
        self.sl_percent = self.settings["trading"]["sl_percent"]
        self.risk_per_trade = self.settings["risk"]["risk_per_trade"]
        # Hard-coded limits; new trades are refused once daily_pnl < -max_daily_loss.
        self.max_daily_loss = 100
        self.daily_pnl = 0.0
        self.daily_trades = 0
        self.max_daily_trades = 50
        self.max_open_positions = 1
        self.min_order_size = 0.001
        # symbol -> {'entry_time', 'entry_price', 'size', 'side', 'leverage'}
        self.open_positions = {}
        self.emergency_stop = False
        self.last_reset_time = datetime.now()
        # Closed-trade records; _record_trade keeps only the most recent 1000.
        self.trade_history = []
        self.win_count = 0
        self.loss_count = 0

    def calculate_position_size(self, symbol: str, entry_price: float, side: str) -> float:
        """Return the order quantity for a new position, or 0.0 if it must not be opened.

        The quantity is ``risk_per_trade`` of the USDT wallet balance divided by
        ``entry_price * sl_percent * leverage``, floored at ``min_order_size``
        and then clamped/rounded by ``_validate_position_size``.

        Args:
            symbol: Trading pair, e.g. "BTCUSDT".
            entry_price: Expected entry price; must be positive.
            side: Order side (not used by the sizing maths).

        Returns:
            The position size, or 0.0 when a daily limit, the open-position
            limit, invalid inputs, a missing balance or an error blocks the trade.
        """
        try:
            # Local guards first: no balance request for a trade that would be refused.
            if self.daily_pnl < -self.max_daily_loss:
                logger.warning(f"Daily loss limit reached: {self.daily_pnl}")
                return 0.0
            if self.daily_trades >= self.max_daily_trades:
                logger.warning("Max daily trades reached")
                return 0.0
            if symbol in self.open_positions:
                logger.warning(f"Position already open for {symbol}")
                return 0.0
            # max_open_positions was configured but previously never enforced.
            if len(self.open_positions) >= self.max_open_positions:
                logger.warning(f"Max open positions reached: {len(self.open_positions)}")
                return 0.0
            if entry_price <= 0 or self.sl_percent <= 0:
                logger.error(f"Invalid entry price or stop-loss percent for {symbol}")
                return 0.0

            balance_data = self.exchange.get_balance()
            if not balance_data:
                logger.error("Failed to get account balance")
                return 0.0
            # First USDT entry wins, matching the previous loop-and-break.
            usdt_balance = next(
                (float(coin.get("walletBalance", 0))
                 for coin in balance_data if coin.get("coin") == "USDT"),
                0.0,
            )
            if usdt_balance <= 0:
                logger.warning("Insufficient USDT balance")
                return 0.0

            risk_amount = usdt_balance * self.risk_per_trade
            sl_distance = entry_price * self.sl_percent
            # NOTE(review): check_tp_sl_hit treats sl_percent as a price move, so the
            # loss at the stop is size * sl_distance regardless of leverage. Dividing
            # by leverage makes positions smaller than the risk budget implies.
            # Left unchanged pending confirmation of the intended sizing model.
            position_size = risk_amount / (sl_distance * self.leverage)
            position_size = max(position_size, self.min_order_size)
            position_size = self._validate_position_size(symbol, position_size, entry_price)
            logger.info(f"Calculated position size: {position_size:.4f} for {symbol}")
            return position_size
        except Exception as e:
            logger.error(f"Error calculating position size: {e}")
            return 0.0

    def _validate_position_size(self, symbol: str, size: float, price: float) -> float:
        """Clamp ``size`` to [min_order_size, 100.0] and round it to the symbol's precision.

        BTC* symbols round to 3 decimals, ETH* to 2, everything else to 1.
        Note that the 1-decimal rounding can yield 0.0 for small sizes.
        ``price`` is currently unused.
        """
        try:
            max_size = 100.0
            validated_size = max(self.min_order_size, min(size, max_size))
            if symbol.startswith('BTC'):
                validated_size = round(validated_size, 3)
            elif symbol.startswith('ETH'):
                validated_size = round(validated_size, 2)
            else:
                validated_size = round(validated_size, 1)
            return validated_size
        except Exception as e:
            logger.error(f"Error validating position size: {e}")
            return 0.0

    def validate_entry_signal(self, symbol: str, signal: str, confidence: float) -> bool:
        """Return True if an entry on ``symbol`` may be taken for this signal.

        Rejects when the emergency stop is active, ``confidence`` < 0.6, a
        position on ``symbol`` is already tracked, market conditions fail, or
        volatility is high. ``signal`` itself is not inspected.
        """
        try:
            if self.emergency_stop:
                logger.warning("Emergency stop activated")
                return False
            min_confidence = 0.6
            if confidence < min_confidence:
                logger.info(f"Signal confidence too low: {confidence}")
                return False
            if symbol in self.open_positions:
                logger.warning(f"Position already exists for {symbol}")
                return False
            if not self._check_market_conditions(symbol):
                return False
            if self._is_high_volatility(symbol):
                logger.warning(f"High volatility detected for {symbol}")
                return False
            return True
        except Exception as e:
            logger.error(f"Error validating entry signal: {e}")
            return False

    def _check_market_conditions(self, symbol: str) -> bool:
        """Return True if 24h volume >= 100000 and the bid/ask spread <= 0.1% of bid."""
        try:
            ticker = self.exchange.get_ticker(symbol)
            if not ticker:
                return False
            # NOTE(review): unit of volume24h (base vs quote) depends on the client — confirm.
            volume_24h = float(ticker.get("volume24h", 0))
            if volume_24h < 100000:
                logger.warning(f"Low volume for {symbol}: {volume_24h}")
                return False
            bid_price = float(ticker.get("bid1Price", 0))
            ask_price = float(ticker.get("ask1Price", 0))
            if bid_price == 0 or ask_price == 0:
                return False
            spread = (ask_price - bid_price) / bid_price
            max_spread = 0.001
            if spread > max_spread:
                logger.warning(f"Spread too wide for {symbol}: {spread}")
                return False
            return True
        except Exception as e:
            logger.error(f"Error checking market conditions: {e}")
            return False

    def _is_high_volatility(self, symbol: str) -> bool:
        """Return True if the high-low range of the last 10 one-minute candles exceeds 2%.

        Fails safe: missing data (< 5 candles) or any error counts as high volatility.
        """
        try:
            prices = self.exchange.get_kline_data(symbol, interval="1", limit=10)
            if not prices or len(prices) < 5:
                return True
            # assumes candle = [start, open, high, low, close, ...] — confirm with client.
            high = max(float(candle[2]) for candle in prices)
            low = min(float(candle[3]) for candle in prices)
            # NOTE(review): prices[-1] is treated as the latest candle; some APIs
            # return klines newest-first — confirm get_kline_data's ordering.
            current_price = float(prices[-1][4])
            volatility = (high - low) / current_price
            high_vol_threshold = 0.02
            return volatility > high_vol_threshold
        except Exception as e:
            logger.error(f"Error checking volatility: {e}")
            return True

    def update_position(self, symbol: str, position_data: Dict[str, Any]):
        """Record (or overwrite) the locally tracked position for ``symbol``.

        Reads ``avgPrice``, ``qty`` and ``side`` from ``position_data``; the
        entry time is taken as now and leverage from the settings.
        """
        try:
            self.open_positions[symbol] = {
                'entry_time': datetime.now(),
                'entry_price': float(position_data.get('avgPrice', 0)),
                'size': float(position_data.get('qty', 0)),
                'side': position_data.get('side', ''),
                'leverage': self.leverage
            }
            logger.info(f"Position updated for {symbol}: {self.open_positions[symbol]}")
        except Exception as e:
            logger.error(f"Error updating position for {symbol}: {e}")

    def _get_last_price(self, symbol: str) -> Optional[float]:
        """Return the ticker's ``lastPrice`` as a positive float, or None if unusable."""
        ticker = self.exchange.get_ticker(symbol)
        if not ticker:
            return None
        try:
            price = float(ticker.get("lastPrice", 0))
        except (TypeError, ValueError):
            return None
        # Rejects 0 (the missing-field default), negatives and NaN.
        return price if price > 0 else None

    def close_position(self, symbol: str, reason: str = "manual"):
        """Book the local close of ``symbol`` at the current last price.

        Updates daily P&L, win/loss counts, trade history and the daily trade
        count, then stops tracking the position. Does not place any order.

        Returns:
            True if the close was booked; False if no position is tracked, no
            valid price is available, or an error occurred. In those cases the
            position stays in ``open_positions``.
        """
        try:
            if symbol not in self.open_positions:
                logger.warning(f"No position found for {symbol}")
                return False
            position = self.open_positions[symbol]
            exit_price = self._get_last_price(symbol)
            if exit_price is None:
                # Never book an exit at a made-up price of 0.
                logger.error(f"No valid exit price for {symbol}; position left open")
                return False
            entry_price = position['entry_price']
            # P&L in quote currency = price change x quantity; leverage changes the
            # margin used, not the P&L of a given quantity.
            # assumes 'size' is a base-asset quantity on a linear contract — confirm.
            direction = 1.0 if position['side'] == 'Buy' else -1.0
            pnl = direction * (exit_price - entry_price) * position['size']
            self.daily_pnl += pnl
            if pnl > 0:
                self.win_count += 1
            else:
                # Break-even trades are counted as losses.
                self.loss_count += 1
            self._record_trade(symbol, position, exit_price, pnl, reason)
            del self.open_positions[symbol]
            self.daily_trades += 1
            logger.info(f"Position closed for {symbol}, reason: {reason}")
            return True
        except Exception as e:
            logger.error(f"Error closing position for {symbol}: {e}")
            return False

    def _record_trade(self, symbol: str, position: Dict, exit_price: float, pnl: float, reason: str):
        """Append a closed-trade record to ``trade_history``, keeping the last 1000."""
        try:
            trade = {
                'symbol': symbol,
                'entry_time': position['entry_time'],
                'exit_time': datetime.now(),
                'entry_price': position['entry_price'],
                'exit_price': exit_price,
                'size': position['size'],
                'side': position['side'],
                'pnl': pnl,
                'reason': reason,
                'leverage': position['leverage']
            }
            self.trade_history.append(trade)
            if len(self.trade_history) > 1000:
                self.trade_history = self.trade_history[-1000:]
        except Exception as e:
            logger.error(f"Error recording trade: {e}")

    def check_tp_sl_hit(self, symbol: str) -> Optional[str]:
        """Return "TP" or "SL" if the last price crossed the target/stop, else None.

        Targets are ``entry * (1 ± tp_percent)`` and ``entry * (1 ∓ sl_percent)``
        depending on side. Returns None when no position is tracked or no valid
        price is available.
        """
        try:
            if symbol not in self.open_positions:
                return None
            position = self.open_positions[symbol]
            current_price = self._get_last_price(symbol)
            if current_price is None:
                # A missing price previously read as 0 and triggered SL/TP falsely.
                return None
            entry_price = position['entry_price']
            if position['side'] == 'Buy':
                tp_price = entry_price * (1 + self.tp_percent)
                sl_price = entry_price * (1 - self.sl_percent)
                if current_price >= tp_price:
                    return "TP"
                elif current_price <= sl_price:
                    return "SL"
            else:
                tp_price = entry_price * (1 - self.tp_percent)
                sl_price = entry_price * (1 + self.sl_percent)
                if current_price <= tp_price:
                    return "TP"
                elif current_price >= sl_price:
                    return "SL"
            return None
        except Exception as e:
            logger.error(f"Error checking TP/SL for {symbol}: {e}")
            return None

    def emergency_stop_all(self):
        """Block new entries and close every tracked position on the exchange.

        Each symbol is handled independently. A failed exchange close is logged
        and does not stop the remaining closes. The local record is kept for a
        symbol whose exchange close raised.
        """
        self.emergency_stop = True
        failed = []
        for symbol in list(self.open_positions):
            try:
                self.exchange.close_position(symbol)
            except Exception as e:
                logger.error(f"Error in emergency stop: {e}")
                failed.append(symbol)
                continue
            if not self.close_position(symbol, "emergency_stop"):
                logger.warning(f"Closed {symbol} on exchange but could not book it locally")
        if failed:
            logger.critical(f"Emergency stop activated - failed to close: {failed}")
        else:
            logger.critical("Emergency stop activated - all positions closed")

    def reset_daily_stats(self):
        """Zero the daily P&L and trade count once the local calendar date has advanced."""
        try:
            now = datetime.now()
            if now.date() > self.last_reset_time.date():
                self.daily_pnl = 0.0
                self.daily_trades = 0
                self.last_reset_time = now
                logger.info("Daily statistics reset")
        except Exception as e:
            logger.error(f"Error resetting daily stats: {e}")

    def get_risk_status(self) -> Dict[str, Any]:
        """Return a snapshot of limits, counters, win rate and tracked symbols."""
        try:
            total_trades = self.win_count + self.loss_count
            win_rate = self.win_count / total_trades if total_trades > 0 else 0.0
            status = {
                'emergency_stop': self.emergency_stop,
                'open_positions': len(self.open_positions),
                'daily_pnl': self.daily_pnl,
                'daily_trades': self.daily_trades,
                'win_rate': win_rate,
                'total_trades': total_trades,
                'positions': list(self.open_positions.keys())
            }
            return status
        except Exception as e:
            logger.error(f"Error getting risk status: {e}")
            return {'error': str(e)}