#!/usr/bin/env python3
"""
Technical indicator calculation module - pure pandas implementation.

Computes common technical-analysis indicators using only pandas and numpy
(no dedicated technical-analysis library required).
"""
import pandas as pd
import numpy as np


def calculate_sma(data, window):
    """Simple Moving Average.

    Uses ``min_periods=1``, so the first ``window - 1`` rows are averaged over
    the rows available so far instead of being NaN.
    """
    return data.rolling(window=window, min_periods=1).mean()


def calculate_ema(data, window):
    """Exponential Moving Average with span ``window`` (``adjust=False``, recursive form)."""
    return data.ewm(span=window, adjust=False).mean()


def calculate_rsi(close, window=14):
    """Relative Strength Index.

    Average gain and average loss are simple rolling means over ``window``
    rows with ``min_periods=1`` (not Wilder's exponential smoothing).

    Returns:
        Series with values in [0, 100]. A row is NaN when both the average
        gain and the average loss are zero (the first row, or a window with
        no price changes).
    """
    delta = close.diff()
    # where() also maps the leading NaN produced by diff() to 0.
    gain = delta.where(delta > 0, 0)
    loss = -delta.where(delta < 0, 0)
    avg_gain = gain.rolling(window=window, min_periods=1).mean()
    avg_loss = loss.rolling(window=window, min_periods=1).mean()
    # avg_loss == 0 with avg_gain > 0 gives rs = inf, which yields RSI = 100.
    rs = avg_gain / avg_loss
    return 100 - (100 / (1 + rs))


def calculate_macd(close, fast=12, slow=26, signal=9):
    """MACD (Moving Average Convergence Divergence).

    Returns:
        Tuple ``(macd, macd_signal, macd_hist)`` where ``macd`` is
        EMA(fast) - EMA(slow), ``macd_signal`` is EMA(signal) of ``macd`` and
        ``macd_hist`` is their difference.
    """
    macd = calculate_ema(close, fast) - calculate_ema(close, slow)
    macd_signal = calculate_ema(macd, signal)
    macd_hist = macd - macd_signal
    return macd, macd_signal, macd_hist


def calculate_bollinger_bands(close, window=20, std_dev=2):
    """Bollinger Bands.

    Returns:
        Tuple ``(upper, middle, lower)``. ``middle`` is the SMA and the bands
        are ``std_dev`` rolling sample standard deviations away from it. The
        first row's bands are NaN because a one-element sample std is NaN.
    """
    sma = calculate_sma(close, window)
    std = close.rolling(window=window, min_periods=1).std()
    upper = sma + (std * std_dev)
    lower = sma - (std * std_dev)
    return upper, sma, lower


def calculate_stochastic(high, low, close, k_window=14, d_window=3):
    """Stochastic Oscillator.

    Returns:
        Tuple ``(k_percent, d_percent)``. ``%D`` is the ``d_window`` SMA of
        ``%K``. Rows where the high-low range is zero are NaN.
    """
    lowest_low = low.rolling(window=k_window, min_periods=1).min()
    highest_high = high.rolling(window=k_window, min_periods=1).max()
    # Avoid division by zero: a flat range becomes NaN instead of inf.
    range_hl = (highest_high - lowest_low).replace(0, np.nan)
    k_percent = 100 * ((close - lowest_low) / range_hl)
    d_percent = k_percent.rolling(window=d_window, min_periods=1).mean()
    return k_percent, d_percent


def calculate_atr(high, low, close, window=14):
    """Average True Range.

    True range is the max of (high - low, |high - prev close|,
    |low - prev close|). The ATR is its simple rolling mean over ``window``
    rows (not Wilder's smoothing).
    """
    prev_close = close.shift(1)
    tr1 = high - low
    tr2 = (high - prev_close).abs()
    tr3 = (low - prev_close).abs()
    # max(axis=1) skips NaN, so the first row (no previous close) uses high - low.
    tr = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1)
    return tr.rolling(window=window, min_periods=1).mean()


def calculate_williams_r(high, low, close, window=14):
    """Williams %R, in [-100, 0]. Rows where the high-low range is zero are NaN."""
    highest_high = high.rolling(window=window, min_periods=1).max()
    lowest_low = low.rolling(window=window, min_periods=1).min()
    # Avoid division by zero: a flat range becomes NaN instead of inf.
    range_hl = (highest_high - lowest_low).replace(0, np.nan)
    return -100 * ((highest_high - close) / range_hl)


def _as_periods(value):
    """Normalize a period-list config entry (int or iterable of ints) to a list."""
    if value is None or value is False:
        return []
    # bool is an int subclass; True is not a valid period and falls through
    # to list(True), which raises like the original iteration did.
    if isinstance(value, (int, np.integer)) and not isinstance(value, bool):
        return [int(value)]
    return list(value)


def _as_params(value):
    """Return keyword overrides for a toggle entry (``True`` -> no overrides, dict -> its items)."""
    return dict(value) if isinstance(value, dict) else {}


def _compute_indicators(df, config):
    """Compute every configured indicator and return ``{column name: Series}`` in output order."""
    high, low, close = df['high'], df['low'], df['close']
    out = {}

    for period in _as_periods(config.get('sma')):
        out[f'sma_{period}'] = calculate_sma(close, period)
    for period in _as_periods(config.get('ema')):
        out[f'ema_{period}'] = calculate_ema(close, period)
    for period in _as_periods(config.get('rsi')):
        out[f'rsi_{period}'] = calculate_rsi(close, period)

    if config.get('macd'):
        macd, macd_signal, macd_hist = calculate_macd(close, **_as_params(config['macd']))
        out['macd'] = macd
        out['macd_signal'] = macd_signal
        out['macd_hist'] = macd_hist

    if config.get('bollinger'):
        bb_upper, bb_middle, bb_lower = calculate_bollinger_bands(
            close, **_as_params(config['bollinger']))
        out['bb_upper'] = bb_upper
        out['bb_middle'] = bb_middle
        out['bb_lower'] = bb_lower

    for period in _as_periods(config.get('atr')):
        out[f'atr_{period}'] = calculate_atr(high, low, close, period)

    if config.get('stochastic'):
        stoch_k, stoch_d = calculate_stochastic(
            high, low, close, **_as_params(config['stochastic']))
        out['stoch_k'] = stoch_k
        out['stoch_d'] = stoch_d

    for period in _as_periods(config.get('williams_r')):
        out[f'williams_r_{period}'] = calculate_williams_r(high, low, close, period)

    return out


def add_technical_indicators(df, indicators_config=None):
    """
    Return a copy of ``df`` with technical indicator columns added.

    Args:
        df: DataFrame containing at least 'open', 'high', 'low', 'close'.
        indicators_config: dict selecting indicators.
            - Period-list entries ('sma', 'ema', 'rsi', 'atr', 'williams_r')
              accept a list of ints or a single int.
            - Toggle entries ('macd', 'bollinger', 'stochastic') accept True
              or a dict of keyword overrides for the matching calculate_*
              function.
            Defaults to sma [5, 10, 20], ema [12, 26], rsi [14], macd,
            bollinger and atr [14].

    Returns:
        The new DataFrame. NaNs in the added indicator columns are back-filled
        and then forward-filled rather than dropping rows. The input data
        columns are left untouched. If an indicator computation raises, the
        error is printed and a copy of the input columns is returned instead.

    Raises:
        ValueError: if a required OHLC column is missing.
    """
    if indicators_config is None:
        # Default configuration - kept small to limit the number of indicators.
        indicators_config = {
            'sma': [5, 10, 20],
            'ema': [12, 26],
            'rsi': [14],
            'macd': True,
            'bollinger': True,
            'atr': [14],
        }

    df = df.copy()

    required_cols = ['open', 'high', 'low', 'close']
    if not all(col in df.columns for col in required_cols):
        raise ValueError(f"DataFrame must contain columns: {required_cols}")

    original_cols = list(df.columns)

    try:
        indicators = _compute_indicators(df, indicators_config)
        for name, series in indicators.items():
            # Fill warm-up NaNs instead of dropping rows. Only indicator
            # columns are filled, so raw input data is never overwritten.
            df[name] = series.bfill().ffill()
        print(f"✅ 技术指标计算完成,添加了 {len(indicators)} 个指标")
    except Exception as e:
        # Best effort: report the failure and fall back to the original data.
        print(f"❌ 技术指标计算失败: {e}")
        return df[original_cols]

    return df


def get_available_indicators():
    """Return the supported indicator config keys, grouped by category."""
    return {
        'trend': ['sma', 'ema', 'macd', 'bollinger'],
        'momentum': ['rsi', 'stochastic', 'williams_r'],
        'volatility': ['atr'],
        'volume': [],
    }