changcheng967 committed on
Commit
e3855c5
·
verified ·
1 Parent(s): 65e4224

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -74
app.py CHANGED
@@ -16,20 +16,25 @@ from ta.momentum import RSIIndicator
16
  from ta.trend import EMAIndicator
17
  from ta.volume import OnBalanceVolumeIndicator
18
 
19
- # 配置日志
20
  logger.add("app.log", rotation="1 MB", level="DEBUG", backtrace=True, diagnose=True)
21
 
22
  def enhanced_feature_engineering(df):
23
- """优化的特征工程函数"""
24
  try:
25
- df = df[['Open', 'High', 'Low', 'Close', 'Volume']].copy().astype(float)
 
 
 
 
 
26
 
27
  # 基础特征
28
  df['Returns'] = df['Close'].pct_change()
29
  df['Volatility'] = df['Returns'].rolling(5).std()
30
 
31
- # 技术指标
32
- df['RSI'] = RSIIndicator(df['Close'], window=14).rsi()
33
  df['EMA_12'] = EMAIndicator(df['Close'], window=12).ema_indicator()
34
  df['EMA_26'] = EMAIndicator(df['Close'], window=26).ema_indicator()
35
  df['OBV'] = OnBalanceVolumeIndicator(
@@ -37,24 +42,24 @@ def enhanced_feature_engineering(df):
37
  volume=df['Volume']
38
  ).on_balance_volume()
39
 
40
- # 清理数据
41
  df.replace([np.inf, -np.inf], np.nan, inplace=True)
42
  df.dropna(inplace=True)
43
 
44
- # 选择特征
45
- selected_features = ['Close', 'Returns', 'Volatility', 'RSI', 'EMA_12', 'EMA_26', 'OBV']
46
- return df[selected_features]
47
 
48
  except Exception as e:
49
- logger.error(f"特征工程失败: {str(e)}")
50
  raise
51
 
52
  def robust_training(ticker):
53
- """快速训练函数(30秒超时保证)"""
54
  start_time = time.time()
55
  try:
56
- # 获取数据
57
- logger.info(f"正在获取 {ticker} 数据")
58
  data = yf.download(
59
  ticker,
60
  period="1y",
@@ -66,20 +71,21 @@ def robust_training(ticker):
66
  )
67
 
68
  # 数据验证
69
- if data.empty or len(data) < 30 or 'Close' not in data.columns:
70
- raise ValueError("数据不足")
71
- if data['Close'].isnull().sum() > 5:
72
- raise ValueError("数据存在过多缺失值")
 
 
73
 
74
  # 特征处理
75
- logger.debug("处理特征中...")
76
- data = enhanced_feature_engineering(data)
77
 
78
  # 准备训练数据
79
- X = data.drop(columns=['Close'])
80
- y = data['Close']
81
 
82
- # 初始化模型
83
  model = make_pipeline(
84
  RobustScaler(),
85
  LGBMRegressor(
@@ -87,91 +93,102 @@ def robust_training(ticker):
87
  max_depth=4,
88
  learning_rate=0.15,
89
  verbosity=-1,
90
- force_row_wise=True
 
91
  )
92
  )
93
 
94
- # 快速训练
95
- logger.info("开始快速训练")
96
  tscv = TimeSeriesSplit(n_splits=2)
97
  for train_index, _ in tscv.split(X):
98
- if (time.time() - start_time) > 25:
99
  break
100
- model.fit(X.iloc[train_index], y.iloc[train_index])
 
 
101
 
102
  # 生成预测
103
- logger.debug("生成预测结果")
104
- future_dates = pd.date_range(data.index[-1], periods=8)[1:]
105
  current_features = X.iloc[-1:].copy()
106
  predictions = []
 
107
 
108
  for _ in range(7):
109
- pred = model.predict(current_features)[0]
 
110
  predictions.append(pred)
111
- current_features['Returns'] = (pred - current_features['Close']) / current_features['Close']
 
 
112
  current_features['Close'] = pred
 
113
 
114
  return {
115
- 'data': data,
116
  'predictions': pd.Series(predictions, index=future_dates),
117
  'training_time': time.time() - start_time
118
  }
119
 
120
  except Exception as e:
121
- logger.error(f"训练失败: {str(e)}")
122
  return None
123
 
124
- def create_prediction_plot(result):
125
- """创建交互式图表"""
126
  fig = go.Figure()
127
 
128
  # 历史价格
129
  fig.add_trace(go.Scatter(
130
- x=result['data'].index,
131
- y=result['data']['Close'],
132
  name='历史价格',
133
- line=dict(color='#1f77b4')
134
- ))
135
 
136
  # 预测价格
137
  fig.add_trace(go.Scatter(
138
  x=result['predictions'].index,
139
  y=result['predictions'].values,
140
  name='AI预测',
141
- line=dict(color='#ff7f0e', dash='dot')
142
- ))
143
 
144
  fig.update_layout(
145
- title="股票价格预测",
146
  xaxis_title="日期",
147
  yaxis_title="价格 (USD)",
148
  hovermode="x unified",
149
  template="plotly_white",
150
- margin=dict(t=40, b=20)
 
 
 
 
 
 
151
  )
152
  return fig
153
 
154
- def predict_stock(ticker):
155
- """处理预测流"""
156
  try:
157
  start_time = time.time()
158
- yield "⌛ 正在分析市场数据(30秒内完成)...", None, None
159
 
160
  result = None
161
  error_msg = ""
162
 
163
  # 后台训练线程
164
- def training_task():
165
  nonlocal result
166
  result = robust_training(ticker)
167
 
168
- thread = threading.Thread(target=training_task)
169
  thread.start()
170
 
171
- # 等待线程完成
172
  while thread.is_alive():
173
  if time.time() - start_time > 30:
174
- error_msg = "⏰ 响应超时,请稍后重试"
175
  break
176
  time.sleep(0.1)
177
 
@@ -180,56 +197,65 @@ def predict_stock(ticker):
180
  return
181
 
182
  if not result or result['predictions'].empty:
183
- yield "⚠️ 无法生成预测,请检查股票代码", None, None
184
  return
185
 
186
  # 构建输出信息
187
- time_used = f"{result['training_time']:.1f}秒"
188
- latest_pred = f"{result['predictions'].iloc[-1]:.2f} USD"
189
- info_box = f"""
190
- 分析完成(耗时:{time_used})
 
191
  📅 最新预测日期:{result['predictions'].index[-1].strftime('%Y-%m-%d')}
192
- 💵 预测收盘价:{latest_pred}
193
  """
194
 
195
  # 风险提示
196
- risk_warning = """
197
  **风险提示**
198
- 1. 本预测基于历史数据,不构成投资建议
199
- 2. 实际价格受市场因素影响可能大幅波动
200
- 3. 预测误差可能随预测时间增加而扩大
201
- 4. 过去表现不代表未来
202
  """
203
 
204
- yield info_box, create_prediction_plot(result), risk_warning
205
 
206
  except Exception as e:
207
- logger.critical(f"系统错误: {traceback.format_exc()}")
208
- yield "系统发生意外错误,请联系技术支持", None, None
209
 
210
  # 创建Gradio界面
211
- with gr.Blocks(theme=gr.themes.Soft(), title="智能股票预测") as demo:
212
  gr.Markdown("# 📈 智能股票预测系统")
213
 
214
  with gr.Row():
215
  with gr.Column(scale=2):
216
  ticker_input = gr.Textbox(
217
- label="股票代码",
218
- placeholder="输入股票代码 (如:AAPL, 00700.HK)",
219
  max_lines=1
220
  )
221
- submit_btn = gr.Button("开始预测", variant="primary")
222
 
223
  with gr.Column(scale=3):
224
- status_output = gr.Markdown("## 分析状态")
225
- plot_output = gr.Plot(label="价格走势")
226
- risk_output = gr.Markdown()
227
 
228
  submit_btn.click(
229
- predict_stock,
230
  inputs=ticker_input,
231
- outputs=[status_output, plot_output, risk_output]
232
  )
233
 
234
  if __name__ == "__main__":
235
- demo.launch(server_port=7860)
 
 
 
 
 
 
 
 
 
16
  from ta.trend import EMAIndicator
17
  from ta.volume import OnBalanceVolumeIndicator
18
 
19
+ # 日志配置
20
  logger.add("app.log", rotation="1 MB", level="DEBUG", backtrace=True, diagnose=True)
21
 
22
  def enhanced_feature_engineering(df):
23
+ """安全可靠的特征工程函数"""
24
  try:
25
+ # 基础数据准备
26
+ df = df[['Open', 'High', 'Low', 'Close', 'Volume']].copy()
27
+ df = df.astype({
28
+ 'Open': float, 'High': float, 'Low': float,
29
+ 'Close': float, 'Volume': float
30
+ })
31
 
32
  # 基础特征
33
  df['Returns'] = df['Close'].pct_change()
34
  df['Volatility'] = df['Returns'].rolling(5).std()
35
 
36
+ # 技术指标(逐个添加避免冲突)
37
+ df['RSI_14'] = RSIIndicator(df['Close'], window=14).rsi()
38
  df['EMA_12'] = EMAIndicator(df['Close'], window=12).ema_indicator()
39
  df['EMA_26'] = EMAIndicator(df['Close'], window=26).ema_indicator()
40
  df['OBV'] = OnBalanceVolumeIndicator(
 
42
  volume=df['Volume']
43
  ).on_balance_volume()
44
 
45
+ # 清理异常值
46
  df.replace([np.inf, -np.inf], np.nan, inplace=True)
47
  df.dropna(inplace=True)
48
 
49
+ # 特征选择
50
+ final_features = ['Close', 'Returns', 'Volatility', 'RSI_14', 'EMA_12', 'EMA_26', 'OBV']
51
+ return df[final_features]
52
 
53
  except Exception as e:
54
+ logger.error(f"特征工程失败: {traceback.format_exc()}")
55
  raise
56
 
57
  def robust_training(ticker):
58
+ """安全可靠的训练函数"""
59
  start_time = time.time()
60
  try:
61
+ # 数据获取
62
+ logger.info(f"开始处理股票代码: {ticker}")
63
  data = yf.download(
64
  ticker,
65
  period="1y",
 
71
  )
72
 
73
  # 数据验证
74
+ if data.empty or len(data) < 30:
75
+ raise ValueError("有效数据不足(最少需要30个交易日数据)")
76
+ if 'Close' not in data.columns:
77
+ raise ValueError("无效的股票数据格式")
78
+ if data['Close'].isnull().sum() > 5: # 明确数值比较
79
+ raise ValueError("存在过多缺失值")
80
 
81
  # 特征处理
82
+ processed_data = enhanced_feature_engineering(data)
 
83
 
84
  # 准备训练数据
85
+ X = processed_data.drop(columns=['Close'])
86
+ y = processed_data['Close']
87
 
88
+ # 模型配置
89
  model = make_pipeline(
90
  RobustScaler(),
91
  LGBMRegressor(
 
93
  max_depth=4,
94
  learning_rate=0.15,
95
  verbosity=-1,
96
+ force_row_wise=True,
97
+ random_state=42
98
  )
99
  )
100
 
101
+ # 训练流程
 
102
  tscv = TimeSeriesSplit(n_splits=2)
103
  for train_index, _ in tscv.split(X):
104
+ if (time.time() - start_time) > 25: # 保留5秒预测时间
105
  break
106
+ X_train = X.iloc[train_index].values # 转换为numpy数组
107
+ y_train = y.iloc[train_index].values
108
+ model.fit(X_train, y_train)
109
 
110
  # 生成预测
 
 
111
  current_features = X.iloc[-1:].copy()
112
  predictions = []
113
+ future_dates = pd.date_range(data.index[-1], periods=8)[1:]
114
 
115
  for _ in range(7):
116
+ current_close = current_features['Close'].values[0]
117
+ pred = model.predict(current_features.values)[0]
118
  predictions.append(pred)
119
+
120
+ # 更新特征(标量运算)
121
+ current_features['Returns'] = (pred - current_close) / current_close
122
  current_features['Close'] = pred
123
+ current_features['Volatility'] = current_features['Volatility'].values[0]
124
 
125
  return {
126
+ 'historical_data': data,
127
  'predictions': pd.Series(predictions, index=future_dates),
128
  'training_time': time.time() - start_time
129
  }
130
 
131
  except Exception as e:
132
+ logger.error(f"训练失败: {str(e)}\n{traceback.format_exc()}")
133
  return None
134
 
135
+ def create_visualization(result):
136
+ """创建可视化图表"""
137
  fig = go.Figure()
138
 
139
  # 历史价格
140
  fig.add_trace(go.Scatter(
141
+ x=result['historical_data'].index,
142
+ y=result['historical_data']['Close'],
143
  name='历史价格',
144
+ line=dict(color='#1f77b4', width=2)
145
+ )
146
 
147
  # 预测价格
148
  fig.add_trace(go.Scatter(
149
  x=result['predictions'].index,
150
  y=result['predictions'].values,
151
  name='AI预测',
152
+ line=dict(color='#ff7f0e', width=2, dash='dot')
153
+ )
154
 
155
  fig.update_layout(
156
+ title=f"股票价格预测结果",
157
  xaxis_title="日期",
158
  yaxis_title="价格 (USD)",
159
  hovermode="x unified",
160
  template="plotly_white",
161
+ legend=dict(
162
+ orientation="h",
163
+ yanchor="bottom",
164
+ y=1.02,
165
+ xanchor="right",
166
+ x=1
167
+ )
168
  )
169
  return fig
170
 
171
+ def prediction_workflow(ticker):
172
+ """完整的预测工作流"""
173
  try:
174
  start_time = time.time()
175
+ yield "⌛ 正在快速分析市场数据(预计30秒内完成)...", None, None
176
 
177
  result = None
178
  error_msg = ""
179
 
180
  # 后台训练线程
181
+ def training_job():
182
  nonlocal result
183
  result = robust_training(ticker)
184
 
185
+ thread = threading.Thread(target=training_job)
186
  thread.start()
187
 
188
+ # 等待处理
189
  while thread.is_alive():
190
  if time.time() - start_time > 30:
191
+ error_msg = "⏰ 系统响应超时,请稍后重试"
192
  break
193
  time.sleep(0.1)
194
 
 
197
  return
198
 
199
  if not result or result['predictions'].empty:
200
+ yield "⚠️ 无法生成预测,请检查股票代码有效性", None, None
201
  return
202
 
203
  # 构建输出信息
204
+ training_time = f"{result['training_time']:.1f}秒"
205
+ latest_pred = f"{result['predictions'].iloc[-1]:.2f}"
206
+ info_content = f"""
207
+ ## 分析结果
208
+ ✅ 成功完成分析(耗时:{training_time})
209
  📅 最新预测日期:{result['predictions'].index[-1].strftime('%Y-%m-%d')}
210
+ 💵 预测收盘价:{latest_pred} USD
211
  """
212
 
213
  # 风险提示
214
+ risk_content = """
215
  **风险提示**
216
+ 1. 本预测基于历史数据建模,不构成投资建议
217
+ 2. 实际价格可能受市场突发事件影响
218
+ 3. 预测准确率随预测时间跨度增加而降低
219
+ 4. 请结合其他信息进行综合判断
220
  """
221
 
222
+ yield info_content, create_visualization(result), risk_content
223
 
224
  except Exception as e:
225
+ logger.critical(f"系统错误: {traceback.format_exc()}")
226
+ yield "⚠️ 发生意外错误,请联系技术支持", None, None
227
 
228
  # 创建Gradio界面
229
+ with gr.Blocks(theme=gr.themes.Soft(), title="智能股票预测系统") as demo:
230
  gr.Markdown("# 📈 智能股票预测系统")
231
 
232
  with gr.Row():
233
  with gr.Column(scale=2):
234
  ticker_input = gr.Textbox(
235
+ label="输入股票代码",
236
+ placeholder="如:AAPL (苹果), 00700.HK (腾讯)",
237
  max_lines=1
238
  )
239
+ submit_btn = gr.Button("开始分析", variant="primary")
240
 
241
  with gr.Column(scale=3):
242
+ status_display = gr.Markdown("## 当前状态")
243
+ plot_display = gr.Plot(label="价格走势")
244
+ risk_display = gr.Markdown()
245
 
246
  submit_btn.click(
247
+ prediction_workflow,
248
  inputs=ticker_input,
249
+ outputs=[status_display, plot_display, risk_display]
250
  )
251
 
252
  if __name__ == "__main__":
253
+ # 启动配置
254
+ import warnings
255
+ warnings.filterwarnings("ignore")
256
+
257
+ demo.launch(
258
+ server_port=7860,
259
+ show_error=True,
260
+ enable_queue=True
261
+ )