changcheng967 commited on
Commit
9c4550c
·
verified ·
1 Parent(s): 1994951

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +143 -79
app.py CHANGED
@@ -3,74 +3,111 @@ import gradio as gr
3
  import yfinance as yf
4
  import pandas as pd
5
  import numpy as np
6
- from sklearn.preprocessing import MinMaxScaler
7
  from sklearn.pipeline import make_pipeline
8
- from sklearn.linear_model import Ridge
 
 
9
  from loguru import logger
10
- import time
11
  import threading
12
- import plotly.graph_objs as go
 
13
 
14
  # 配置日志
15
- logger.add("app.log", rotation="1 MB", level="DEBUG")
16
 
17
- def quick_feature_engineering(df):
18
- """快速特征工程"""
19
- df = df.copy()
20
- # 基础特征
21
- df['Returns'] = df['Close'].pct_change()
22
- df['Volatility'] = df['Returns'].rolling(5).std()
23
- # 简化时间特征
24
- df['Day'] = df.index.dayofweek
25
- df['Month'] = df.index.month
26
- return df.dropna()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- def rapid_training(ticker):
29
- """快速训练流程(必须在30秒内完成)"""
30
  start_time = time.time()
31
 
32
  try:
33
- # 获取数据(限制为1年
34
  logger.info(f"Fetching data for {ticker}")
35
- data = yf.download(ticker, period="1y", progress=False)
36
- if data.empty:
37
- raise ValueError("No data available")
 
 
 
 
 
 
 
38
 
39
  # 特征工程
40
  logger.debug("Processing features")
41
- data = quick_feature_engineering(data)
42
 
43
  # 准备训练数据
44
  X = data.drop(columns=['Close'])
45
  y = data['Close']
46
 
47
- # 最后7天作为测试集
48
- train_size = -7
49
- X_train, y_train = X.iloc[:train_size], y.iloc[:train_size]
50
 
51
- # 使用轻量级模型管道
52
  model = make_pipeline(
53
- MinMaxScaler(),
54
- Ridge(alpha=1.0) # 调整正则化强度
 
 
 
 
 
 
 
55
 
56
- logger.info("Start training")
57
- model.fit(X_train, y_train)
 
 
 
 
 
 
58
 
59
- # 生成预测(未来7天)
60
  logger.debug("Generating predictions")
61
- last_features = X.iloc[-1:].values
62
  future_dates = pd.date_range(data.index[-1], periods=8)[1:]
63
- predictions = []
64
 
65
- # 递归预测
66
- current_features = last_features.copy()
 
67
  for _ in range(7):
68
  pred = model.predict(current_features)[0]
69
  predictions.append(pred)
70
- # 更新特征(简化处理
71
- current_features[0][0] = pred # 更新Open
72
- current_features[0][3] = pred # 更新Close
73
-
74
  training_time = time.time() - start_time
75
  logger.success(f"Training completed in {training_time:.2f}s")
76
 
@@ -82,11 +119,11 @@ def rapid_training(ticker):
82
  }
83
 
84
  except Exception as e:
85
- logger.error(f"Error in training: {str(e)}")
86
  return None
87
 
88
  def create_plot(result):
89
- """创建交互式图表"""
90
  data = result['data']
91
  pred = result['predictions']
92
 
@@ -96,79 +133,106 @@ def create_plot(result):
96
  fig.add_trace(go.Scatter(
97
  x=data.index,
98
  y=data['Close'],
99
- name='Historical Price',
100
- line=dict(color='blue')
101
  )
102
 
103
  # 预测价格
104
  fig.add_trace(go.Scatter(
105
  x=pred.index,
106
  y=pred.values,
107
- name='Prediction',
108
- line=dict(color='red', dash='dot')
109
  )
110
 
111
  fig.update_layout(
112
- title=f"Stock Price Prediction",
113
- xaxis_title='Date',
114
- yaxis_title='Price (USD)',
115
  hovermode="x unified",
116
- showlegend=True
 
 
117
  )
118
 
119
  return fig
120
 
121
  def predict_stock(ticker):
122
- """预测流程处理"""
123
  start_time = time.time()
124
 
125
- # 显示加载状态
126
- yield "⌛ 正在获取数据并训练模型(最多30秒)...", None
127
 
128
- # 在后台线程中运行训练
129
  result = None
130
- def train_thread():
 
 
131
  nonlocal result
132
- result = rapid_training(ticker)
 
 
 
133
 
134
- thread = threading.Thread(target=train_thread)
135
  thread.start()
136
 
137
- # 等待完成(最多30秒)
138
- thread.join(timeout=30)
 
 
 
 
139
 
140
- if not result:
141
- yield "❌ 训练失败或超时,请尝试其他股票代码", None
142
  return
143
 
144
- if result['training_time'] > 30:
145
- yield "⚠️ 训练超时,结果可能准确", create_plot(result)
146
  return
147
 
148
- # 结果
149
- info_msg = f"✅ 训练成功(耗时{result['training_time']:.1f}秒)\n" \
150
- f"最新预测:{pred.values[-1]:.2f} USD({pred.index[-1].strftime('%Y-%m-%d')})"
151
-
152
- yield info_msg, create_plot(result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
- with gr.Blocks() as demo:
155
- gr.Markdown("# 🚀 实时股票预测系统")
156
 
157
  with gr.Row():
158
- ticker_input = gr.Textbox(
159
- label="输入股票代码",
160
- placeholder="例如:AAPL(美), 0005.HK(港股)",
161
- max_lines=1
162
- )
163
- submit_btn = gr.Button("立即预测", variant="primary")
164
-
165
- status_output = gr.Textbox(label="状态", interactive=False)
166
- plot_output = gr.Plot(label="价格预测")
 
 
 
167
 
168
  submit_btn.click(
169
  predict_stock,
170
  inputs=ticker_input,
171
- outputs=[status_output, plot_output]
172
  )
173
 
174
  if __name__ == "__main__":
 
3
  import yfinance as yf
4
  import pandas as pd
5
  import numpy as np
6
+ import plotly.graph_objs as go
7
  from sklearn.pipeline import make_pipeline
8
+ from sklearn.preprocessing import RobustScaler
9
+ from sklearn.model_selection import TimeSeriesSplit
10
+ from lightgbm import LGBMRegressor
11
  from loguru import logger
 
12
  import threading
13
+ import time
14
+ from ta import add_all_ta_features # 技术指标库
15
 
16
  # 配置日志
17
+ logger.add("app.log", rotation="1 MB", level="DEBUG", backtrace=True, diagnose=True)
18
 
19
+ def enhanced_feature_engineering(df):
20
+ """优化后的特征工程(包含技术指标)"""
21
+ try:
22
+ df = df.copy()
23
+ # 基础特征
24
+ df['Returns'] = df['Close'].pct_change()
25
+ df['Volatility'] = df['Returns'].rolling(5).std()
26
+
27
+ # 使用ta库快速添加技术指标
28
+ df = add_all_ta_features(
29
+ df,
30
+ open="Open", high="High", low="Low", close="Close", volume="Volume",
31
+ fillna=True
32
+ )
33
+
34
+ # 选择关键特征
35
+ selected_features = [
36
+ 'Close', 'Returns', 'Volatility',
37
+ 'trend_ema_fast', 'trend_ema_slow',
38
+ 'momentum_rsi', 'volume_obv'
39
+ ]
40
+
41
+ return df[selected_features].dropna()
42
+
43
+ except Exception as e:
44
+ logger.error(f"Feature engineering failed: {str(e)}")
45
+ raise
46
 
47
+ def robust_training(ticker):
48
+ """增强型训练流程(30秒超时保证)"""
49
  start_time = time.time()
50
 
51
  try:
52
+ # 获取数据(优化API参数)
53
  logger.info(f"Fetching data for {ticker}")
54
+ data = yf.download(
55
+ ticker,
56
+ period="1y",
57
+ interval="1d",
58
+ prepost=False,
59
+ threads=False,
60
+ progress=False
61
+ )
62
+ if data.empty or len(data) < 30:
63
+ raise ValueError("Insufficient data for training")
64
 
65
  # 特征工程
66
  logger.debug("Processing features")
67
+ data = enhanced_feature_engineering(data)
68
 
69
  # 准备训练数据
70
  X = data.drop(columns=['Close'])
71
  y = data['Close']
72
 
73
+ # 时间序列交叉验证
74
+ tscv = TimeSeriesSplit(n_splits=3)
 
75
 
76
+ # 轻量级模型管道
77
  model = make_pipeline(
78
+ RobustScaler(),
79
+ LGBMRegressor(
80
+ n_estimators=100,
81
+ max_depth=5,
82
+ learning_rate=0.1,
83
+ verbosity=-1,
84
+ force_row_wise=True
85
+ )
86
+ )
87
 
88
+ # 快速交叉验证
89
+ logger.info("Starting rapid training")
90
+ for train_index, _ in tscv.split(X):
91
+ X_train = X.iloc[train_index]
92
+ y_train = y.iloc[train_index]
93
+ model.fit(X_train, y_train)
94
+ if (time.time() - start_time) > 25: # 保留5秒预测时间
95
+ break
96
 
97
+ # 生成预测
98
  logger.debug("Generating predictions")
 
99
  future_dates = pd.date_range(data.index[-1], periods=8)[1:]
 
100
 
101
+ # 使用最后有效特征生成预测
102
+ current_features = X.iloc[-1:].copy()
103
+ predictions = []
104
  for _ in range(7):
105
  pred = model.predict(current_features)[0]
106
  predictions.append(pred)
107
+ # 更新特征(简化逻辑
108
+ current_features['Returns'] = (pred - current_features['Close']) / current_features['Close']
109
+ current_features['Close'] = pred
110
+
111
  training_time = time.time() - start_time
112
  logger.success(f"Training completed in {training_time:.2f}s")
113
 
 
119
  }
120
 
121
  except Exception as e:
122
+ logger.error(f"Training error: {str(e)}")
123
  return None
124
 
125
  def create_plot(result):
126
+ """增强型可视化"""
127
  data = result['data']
128
  pred = result['predictions']
129
 
 
133
  fig.add_trace(go.Scatter(
134
  x=data.index,
135
  y=data['Close'],
136
+ name='历史价格',
137
+ line=dict(color='#1f77b4')
138
  )
139
 
140
  # 预测价格
141
  fig.add_trace(go.Scatter(
142
  x=pred.index,
143
  y=pred.values,
144
+ name='AI预测',
145
+ line=dict(color='#ff7f0e', dash='dot')
146
  )
147
 
148
  fig.update_layout(
149
+ title=f"股价预测结果",
150
+ xaxis_title="日期",
151
+ yaxis_title="价格 (USD)",
152
  hovermode="x unified",
153
+ legend=dict(orientation="h", yanchor="bottom", y=1.02),
154
+ margin=dict(t=40, b=20),
155
+ template="plotly_white"
156
  )
157
 
158
  return fig
159
 
160
  def predict_stock(ticker):
161
+ """增强型预测流程"""
162
  start_time = time.time()
163
 
164
+ yield "⌛ 正在快速分析市场数据(预计30秒内完成)...", None, None
 
165
 
 
166
  result = None
167
+ error_msg = ""
168
+
169
+ def training_task():
170
  nonlocal result
171
+ try:
172
+ result = robust_training(ticker)
173
+ except Exception as e:
174
+ logger.error(f"Critical error: {str(e)}")
175
 
176
+ thread = threading.Thread(target=training_task)
177
  thread.start()
178
 
179
+ # 等待线程完成(最多30秒)
180
+ while thread.is_alive():
181
+ if (time.time() - start_time) > 30:
182
+ error_msg = "⏰ 系统响应超时,请简化查询条件后重试"
183
+ break
184
+ time.sleep(0.1)
185
 
186
+ if error_msg:
187
+ yield error_msg, None, None
188
  return
189
 
190
+ if not result or result['predictions'].empty:
191
+ yield "⚠️ 数据足或股票代码无效,请尝试其他代码", None, None
192
  return
193
 
194
+ # 构建风险提
195
+ risk_warning = """
196
+ **风险提示说明**
197
+ 1. 本预测基于历史数据生成,不构成投资建议
198
+ 2. 实际股价受市场环境、公司公告等多因素影响
199
+ 3. 预测误差可能随市场波动增大
200
+ 4. 过去表现不代表未来结果
201
+ 最新预测仅供参考,请理性判断
202
+ """
203
+
204
+ # 格式化输出信息
205
+ time_used = f"{result['training_time']:.1f}秒"
206
+ latest_pred = f"{result['predictions'].iloc[-1]:.2f} USD"
207
+ info_box = f"""
208
+ ✅ 分析完成(耗时:{time_used})
209
+ 📅 最新预测日期:{result['predictions'].index[-1].strftime('%Y-%m-%d')}
210
+ 💵 预测收盘价:{latest_pred}
211
+ """
212
+
213
+ yield info_box, create_plot(result), risk_warning
214
 
215
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
216
+ gr.Markdown("# 📊 智能股票预测系统")
217
 
218
  with gr.Row():
219
+ with gr.Column(scale=2):
220
+ ticker_input = gr.Textbox(
221
+ label="输入票代码",
222
+ placeholder="例如:AAPL (苹果), 00700.HK (腾讯)",
223
+ max_lines=1
224
+ )
225
+ submit_btn = gr.Button("开始分析", variant="primary")
226
+
227
+ with gr.Column(scale=3):
228
+ status_output = gr.Markdown(label="分析进度")
229
+ plot_output = gr.Plot(label="价格趋势")
230
+ risk_output = gr.Markdown()
231
 
232
  submit_btn.click(
233
  predict_stock,
234
  inputs=ticker_input,
235
+ outputs=[status_output, plot_output, risk_output]
236
  )
237
 
238
  if __name__ == "__main__":