changcheng967 committed on
Commit
14d684f
·
verified ·
1 Parent(s): f9e9621

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -106
app.py CHANGED
@@ -6,7 +6,6 @@ import numpy as np
6
  import plotly.graph_objs as go
7
  from sklearn.pipeline import make_pipeline
8
  from sklearn.preprocessing import RobustScaler
9
- from sklearn.model_selection import TimeSeriesSplit
10
  from lightgbm import LGBMRegressor
11
  from loguru import logger
12
  import threading
@@ -19,21 +18,23 @@ from ta.volume import OnBalanceVolumeIndicator
19
  # 日志配置
20
  logger.add("app.log", rotation="1 MB", level="DEBUG", backtrace=True, diagnose=True)
21
 
22
- def enhanced_feature_engineering(df):
23
- """安全可靠的特征工程函数"""
24
  try:
25
- # 基础数据准备
26
- df = df[['Open', 'High', 'Low', 'Close', 'Volume']].copy()
27
- df = df.astype({
28
- 'Open': float, 'High': float, 'Low': float,
29
- 'Close': float, 'Volume': float
30
- })
 
 
31
 
32
  # 基础特征
33
  df['Returns'] = df['Close'].pct_change()
34
  df['Volatility'] = df['Returns'].rolling(5).std()
35
 
36
- # 技术指标
37
  df['RSI_14'] = RSIIndicator(df['Close'], window=14).rsi()
38
  df['EMA_12'] = EMAIndicator(df['Close'], window=12).ema_indicator()
39
  df['EMA_26'] = EMAIndicator(df['Close'], window=26).ema_indicator()
@@ -42,48 +43,51 @@ def enhanced_feature_engineering(df):
42
  volume=df['Volume']
43
  ).on_balance_volume()
44
 
45
- # 清理异常值
46
  df.replace([np.inf, -np.inf], np.nan, inplace=True)
47
  df.dropna(inplace=True)
48
 
49
  return df[['Close', 'Returns', 'Volatility', 'RSI_14', 'EMA_12', 'EMA_26', 'OBV']]
50
 
51
  except Exception as e:
52
- logger.error(f"特征工程失败: {traceback.format_exc()}")
53
  raise
54
 
55
- def robust_training(ticker):
56
- """安全可靠的训练函数"""
57
  start_time = time.time()
58
  try:
59
  # 数据获取
60
- logger.info(f"开始处理股票代码: {ticker}")
61
  data = yf.download(
62
  ticker,
63
  period="1y",
64
  interval="1d",
65
- prepost=False,
66
- threads=False,
67
  progress=False,
68
  auto_adjust=True
69
  )
70
 
71
- # 数据验证
72
- if data.empty or len(data) < 30:
73
- raise ValueError("有效数据不足(最少需要30个交易日数据)")
 
 
74
  if 'Close' not in data.columns:
75
- raise ValueError("无效的股票数据格式")
76
- if data['Close'].isnull().sum() > 5:
77
- raise ValueError("存在过多缺失")
 
 
 
78
 
79
  # 特征处理
80
- processed_data = enhanced_feature_engineering(data)
81
 
82
  # 准备训练数据
83
  X = processed_data.drop(columns=['Close'])
84
  y = processed_data['Close']
85
 
86
- # 模型配置
87
  model = make_pipeline(
88
  RobustScaler(),
89
  LGBMRegressor(
@@ -96,14 +100,9 @@ def robust_training(ticker):
96
  )
97
  )
98
 
99
- # 训练流程
100
- tscv = TimeSeriesSplit(n_splits=2)
101
- for train_index, _ in tscv.split(X):
102
- if (time.time() - start_time) > 25:
103
- break
104
- X_train = X.iloc[train_index].values
105
- y_train = y.iloc[train_index].values
106
- model.fit(X_train, y_train)
107
 
108
  # 生成预测
109
  current_features = X.iloc[-1:].copy()
@@ -121,26 +120,26 @@ def robust_training(ticker):
121
  current_features['Volatility'] = current_features['Volatility'].values[0]
122
 
123
  return {
124
- 'historical_data': data,
125
  'predictions': pd.Series(predictions, index=future_dates),
126
- 'training_time': time.time() - start_time
127
  }
128
 
129
  except Exception as e:
130
- logger.error(f"训练失败: {str(e)}\n{traceback.format_exc()}")
131
  return None
132
 
133
- def create_visualization(result):
134
- """创建可视化表(修复括号问题)"""
135
  fig = go.Figure()
136
 
137
  # 历史价格
138
  fig.add_trace(go.Scatter(
139
- x=result['historical_data'].index,
140
- y=result['historical_data']['Close'],
141
  name='历史价格',
142
  line=dict(color='#1f77b4', width=2)
143
- )) # 正确闭合括号
144
 
145
  # 预测价格
146
  fig.add_trace(go.Scatter(
@@ -148,103 +147,88 @@ def create_visualization(result):
148
  y=result['predictions'].values,
149
  name='AI预测',
150
  line=dict(color='#ff7f0e', width=2, dash='dot')
151
- )) # 正确闭合括号
152
 
153
  fig.update_layout(
154
- title="股票价格预测结果",
155
  xaxis_title="日期",
156
  yaxis_title="价格 (USD)",
157
  hovermode="x unified",
158
- template="plotly_white",
159
- legend=dict(
160
- orientation="h",
161
- yanchor="bottom",
162
- y=1.02,
163
- xanchor="right",
164
- x=1
165
- )
166
  )
167
  return fig
168
 
169
- def prediction_workflow(ticker):
170
- """完整的预测工作流"""
171
  try:
172
  start_time = time.time()
173
- yield "⌛ 正在快速分析市场数据预计30秒内完成)...", None, None
174
 
175
  result = None
176
- error_msg = ""
177
 
178
- # 后台训练线程
179
- def training_job():
180
  nonlocal result
181
- result = robust_training(ticker)
182
 
183
- thread = threading.Thread(target=training_job)
184
  thread.start()
185
 
186
- # 等待处理
187
  while thread.is_alive():
188
- if time.time() - start_time > 30:
189
- error_msg = " 系统响应超时,请稍后重试"
190
  break
191
  time.sleep(0.1)
192
 
193
- if error_msg:
194
- yield error_msg, None, None
195
  return
196
 
197
- if not result or result['predictions'].empty:
198
- yield "⚠️ 无法生成预测,请检查股票代码有效性", None, None
199
- return
200
-
201
- # 构建输出信息
202
- training_time = f"{result['training_time']:.1f}秒"
203
- latest_pred = f"{result['predictions'].iloc[-1]:.2f}"
204
- info_content = f"""
205
- ## 分析结果
206
- ✅ 成功完成分析(耗时:{training_time})
207
- 📅 最新预测日期:{result['predictions'].index[-1].strftime('%Y-%m-%d')}
208
- 💵 预测收盘价:{latest_pred} USD
209
  """
210
 
211
- # 风险提示
212
- risk_content = """
213
  **风险提示**
214
- 1. 本预测基于历史数据建模,不构成投资建议
215
- 2. 实际价格可能受市场突发事件影响
216
- 3. 预测准确率预测时间跨度增加而降低
217
- 4. 请结合其他信息进行综合判断
218
  """
219
 
220
- yield info_content, create_visualization(result), risk_content
221
 
222
  except Exception as e:
223
- logger.critical(f"系统级错误: {traceback.format_exc()}")
224
- yield "⚠️ 发生意外错误,请联系技术支持", None, None
225
 
226
- # 创建Gradio界面
227
- with gr.Blocks(theme=gr.themes.Soft(), title="智能股票预测系统") as demo:
228
- gr.Markdown("# 📈 智能股票预测系统")
229
 
230
  with gr.Row():
231
- with gr.Column(scale=2):
232
- ticker_input = gr.Textbox(
233
- label="输入股票代码",
234
- placeholder="例如:AAPL (苹果), 00700.HK (腾讯)",
 
235
  max_lines=1
236
  )
237
  submit_btn = gr.Button("开始分析", variant="primary")
238
 
239
- with gr.Column(scale=3):
240
- status_display = gr.Markdown("## 当前状态")
241
- plot_display = gr.Plot(label="价格走势")
242
- risk_display = gr.Markdown()
243
-
 
244
  submit_btn.click(
245
- prediction_workflow,
246
- inputs=ticker_input,
247
- outputs=[status_display, plot_display, risk_display]
248
  )
249
 
250
  if __name__ == "__main__":
@@ -252,11 +236,10 @@ if __name__ == "__main__":
252
  import warnings
253
  warnings.filterwarnings("ignore")
254
 
255
- # 正确用队列配置
256
- demo.queue() # 先配置队列
257
-
258
- # 正确启动参数
259
- demo.launch(
260
  server_port=7860,
261
- show_error=True
 
262
  )
 
6
  import plotly.graph_objs as go
7
  from sklearn.pipeline import make_pipeline
8
  from sklearn.preprocessing import RobustScaler
 
9
  from lightgbm import LGBMRegressor
10
  from loguru import logger
11
  import threading
 
18
  # 日志配置
19
  logger.add("app.log", rotation="1 MB", level="DEBUG", backtrace=True, diagnose=True)
20
 
21
+ def safe_feature_engineering(df):
22
+ """安全稳定的特征工程函数"""
23
  try:
24
+ # 基础数据校验
25
+ required_cols = ['Open', 'High', 'Low', 'Close', 'Volume']
26
+ if not all(col in df.columns for col in required_cols):
27
+ raise ValueError("缺失必要数据列")
28
+
29
+ # 数据类型强制转换
30
+ df = df[required_cols].copy()
31
+ df = df.astype({col: float for col in required_cols})
32
 
33
  # 基础特征
34
  df['Returns'] = df['Close'].pct_change()
35
  df['Volatility'] = df['Returns'].rolling(5).std()
36
 
37
+ # 技术指标计算
38
  df['RSI_14'] = RSIIndicator(df['Close'], window=14).rsi()
39
  df['EMA_12'] = EMAIndicator(df['Close'], window=12).ema_indicator()
40
  df['EMA_26'] = EMAIndicator(df['Close'], window=26).ema_indicator()
 
43
  volume=df['Volume']
44
  ).on_balance_volume()
45
 
46
+ # 清理数据
47
  df.replace([np.inf, -np.inf], np.nan, inplace=True)
48
  df.dropna(inplace=True)
49
 
50
  return df[['Close', 'Returns', 'Volatility', 'RSI_14', 'EMA_12', 'EMA_26', 'OBV']]
51
 
52
  except Exception as e:
53
+ logger.error(f"特征工程异常: {str(e)}")
54
  raise
55
 
56
+ def safe_training(ticker):
57
+ """稳定可靠的训练函数"""
58
  start_time = time.time()
59
  try:
60
  # 数据获取
61
+ logger.info(f"正在获取 [{ticker}] 数据...")
62
  data = yf.download(
63
  ticker,
64
  period="1y",
65
  interval="1d",
 
 
66
  progress=False,
67
  auto_adjust=True
68
  )
69
 
70
+ # 数据有效性验证
71
+ if data.empty:
72
+ raise ValueError("获取数据为空")
73
+ if len(data) < 30:
74
+ raise ValueError("数据不足30天")
75
  if 'Close' not in data.columns:
76
+ raise ValueError("缺少Close列")
77
+
78
+ # 显式处理空
79
+ nan_count = data['Close'].isna().sum()
80
+ if nan_count > 5:
81
+ raise ValueError(f"检测到{n}个空值")
82
 
83
  # 特征处理
84
+ processed_data = safe_feature_engineering(data)
85
 
86
  # 准备训练数据
87
  X = processed_data.drop(columns=['Close'])
88
  y = processed_data['Close']
89
 
90
+ # 初始化模型
91
  model = make_pipeline(
92
  RobustScaler(),
93
  LGBMRegressor(
 
100
  )
101
  )
102
 
103
+ # 快速训练
104
+ logger.info("开始模型训练...")
105
+ model.fit(X.values[-200:], y.values[-200:]) # 使用最近200个数据点
 
 
 
 
 
106
 
107
  # 生成预测
108
  current_features = X.iloc[-1:].copy()
 
120
  current_features['Volatility'] = current_features['Volatility'].values[0]
121
 
122
  return {
123
+ 'historical': data,
124
  'predictions': pd.Series(predictions, index=future_dates),
125
+ 'time_used': time.time() - start_time
126
  }
127
 
128
  except Exception as e:
129
+ logger.error(f"训练异常: {str(e)}")
130
  return None
131
 
132
+ def create_safe_plot(result):
133
+ """安全绘函数"""
134
  fig = go.Figure()
135
 
136
  # 历史价格
137
  fig.add_trace(go.Scatter(
138
+ x=result['historical'].index,
139
+ y=result['historical']['Close'],
140
  name='历史价格',
141
  line=dict(color='#1f77b4', width=2)
142
+ ))
143
 
144
  # 预测价格
145
  fig.add_trace(go.Scatter(
 
147
  y=result['predictions'].values,
148
  name='AI预测',
149
  line=dict(color='#ff7f0e', width=2, dash='dot')
150
+ ))
151
 
152
  fig.update_layout(
153
+ title="股票价格预测",
154
  xaxis_title="日期",
155
  yaxis_title="价格 (USD)",
156
  hovermode="x unified",
157
+ template="plotly_white"
 
 
 
 
 
 
 
158
  )
159
  return fig
160
 
161
+ def safe_predict_flow(ticker):
162
+ """安全预测流"""
163
  try:
164
  start_time = time.time()
165
+ yield "⌛ 正在快速分析(30秒内完成)...", None, None
166
 
167
  result = None
 
168
 
169
+ # 后台训练
170
+ def train_task():
171
  nonlocal result
172
+ result = safe_training(ticker)
173
 
174
+ thread = threading.Thread(target=train_task)
175
  thread.start()
176
 
177
+ # 等待结果
178
  while thread.is_alive():
179
+ if time.time() - start_time > 28: # 预留2秒缓冲
180
+ yield " 系统正在处理,请稍候...", None, None
181
  break
182
  time.sleep(0.1)
183
 
184
+ if not result:
185
+ yield "⚠️ 分析失败,请检查股票代码", None, None
186
  return
187
 
188
+ # 构建结果
189
+ info = f"""
190
+ ✅ 分析成功(耗时:{result['time_used']:.1f}秒)
191
+ 📅 最新预测:{result['predictions'].index[-1].strftime('%Y-%m-%d')}
192
+ 💵 预测价格:{result['predictions'].iloc[-1]:.2f} USD
 
 
 
 
 
 
 
193
  """
194
 
195
+ risk = """
 
196
  **风险提示**
197
+ 1. 本预测仅供参考,不构成投资建议
198
+ 2. 实际价格可能受市场波动影响
199
+ 3. 预测误差可能随时间增加
 
200
  """
201
 
202
+ yield info, create_safe_plot(result), risk
203
 
204
  except Exception as e:
205
+ logger.critical(f"系统异常: {traceback.format_exc()}")
206
+ yield "⚠️ 系统繁忙,请稍后再试", None, None
207
 
208
+ # 创建界面
209
+ with gr.Blocks(title="稳定版股票预测", theme=gr.themes.Soft()) as app:
210
+ gr.Markdown("# 📊 股票价格预测系统")
211
 
212
  with gr.Row():
213
+ input_col = gr.Column(scale=2)
214
+ with input_col:
215
+ stock_input = gr.Textbox(
216
+ label="股票代码",
217
+ placeholder="输入代码 (如: AAPL, 00700.HK)",
218
  max_lines=1
219
  )
220
  submit_btn = gr.Button("开始分析", variant="primary")
221
 
222
+ output_col = gr.Column(scale=3)
223
+ with output_col:
224
+ status = gr.Markdown("## 状态")
225
+ plot = gr.Plot(label="价格走势")
226
+ risk = gr.Markdown()
227
+
228
  submit_btn.click(
229
+ safe_predict_flow,
230
+ inputs=stock_input,
231
+ outputs=[status, plot, risk]
232
  )
233
 
234
  if __name__ == "__main__":
 
236
  import warnings
237
  warnings.filterwarnings("ignore")
238
 
239
+ # 稳定
240
+ app.queue(concurrency_count=2)
241
+ app.launch(
 
 
242
  server_port=7860,
243
+ show_error=True,
244
+ ssr_mode=False # 禁用实验性功能
245
  )