changcheng967 commited on
Commit
cda5d9a
·
verified ·
1 Parent(s): 990125c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +175 -0
app.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import gradio as gr
3
+ import yfinance as yf
4
+ import pandas as pd
5
+ import numpy as np
6
+ from sklearn.preprocessing import MinMaxScaler
7
+ from sklearn.pipeline import make_pipeline
8
+ from sklearn.linear_model import Ridge
9
+ from loguru import logger
10
+ import time
11
+ import threading
12
+ import plotly.graph_objs as go
13
+
14
+ # 配置日志
15
+ logger.add("app.log", rotation="1 MB", level="DEBUG")
16
+
17
+ def quick_feature_engineering(df):
18
+ """快速特征工程"""
19
+ df = df.copy()
20
+ # 基础特征
21
+ df['Returns'] = df['Close'].pct_change()
22
+ df['Volatility'] = df['Returns'].rolling(5).std()
23
+ # 简化时间特征
24
+ df['Day'] = df.index.dayofweek
25
+ df['Month'] = df.index.month
26
+ return df.dropna()
27
+
28
+ def rapid_training(ticker):
29
+ """快速训练流程(必须在30秒内完成)"""
30
+ start_time = time.time()
31
+
32
+ try:
33
+ # 获取数据(限制为1年数据)
34
+ logger.info(f"Fetching data for {ticker}")
35
+ data = yf.download(ticker, period="1y", progress=False)
36
+ if data.empty:
37
+ raise ValueError("No data available")
38
+
39
+ # 特征工程
40
+ logger.debug("Processing features")
41
+ data = quick_feature_engineering(data)
42
+
43
+ # 准备训练数据
44
+ X = data.drop(columns=['Close'])
45
+ y = data['Close']
46
+
47
+ # 最后7天作为测试集
48
+ train_size = -7
49
+ X_train, y_train = X.iloc[:train_size], y.iloc[:train_size]
50
+
51
+ # 使用轻量级模型管道
52
+ model = make_pipeline(
53
+ MinMaxScaler(),
54
+ Ridge(alpha=1.0) # 调整正则化强度
55
+
56
+ logger.info("Start training")
57
+ model.fit(X_train, y_train)
58
+
59
+ # 生成预测(未来7天)
60
+ logger.debug("Generating predictions")
61
+ last_features = X.iloc[-1:].values
62
+ future_dates = pd.date_range(data.index[-1], periods=8)[1:]
63
+ predictions = []
64
+
65
+ # 递归预测
66
+ current_features = last_features.copy()
67
+ for _ in range(7):
68
+ pred = model.predict(current_features)[0]
69
+ predictions.append(pred)
70
+ # 更新特征(简化处理)
71
+ current_features[0][0] = pred # 更新Open
72
+ current_features[0][3] = pred # 更新Close
73
+
74
+ training_time = time.time() - start_time
75
+ logger.success(f"Training completed in {training_time:.2f}s")
76
+
77
+ return {
78
+ 'data': data,
79
+ 'model': model,
80
+ 'predictions': pd.Series(predictions, index=future_dates),
81
+ 'training_time': training_time
82
+ }
83
+
84
+ except Exception as e:
85
+ logger.error(f"Error in training: {str(e)}")
86
+ return None
87
+
88
+ def create_plot(result):
89
+ """创建交互式图表"""
90
+ data = result['data']
91
+ pred = result['predictions']
92
+
93
+ fig = go.Figure()
94
+
95
+ # 历史价格
96
+ fig.add_trace(go.Scatter(
97
+ x=data.index,
98
+ y=data['Close'],
99
+ name='Historical Price',
100
+ line=dict(color='blue')
101
+ )
102
+
103
+ # 预测价格
104
+ fig.add_trace(go.Scatter(
105
+ x=pred.index,
106
+ y=pred.values,
107
+ name='Prediction',
108
+ line=dict(color='red', dash='dot')
109
+ )
110
+
111
+ fig.update_layout(
112
+ title=f"Stock Price Prediction",
113
+ xaxis_title='Date',
114
+ yaxis_title='Price (USD)',
115
+ hovermode="x unified",
116
+ showlegend=True
117
+ )
118
+
119
+ return fig
120
+
121
+ def predict_stock(ticker):
122
+ """预测流程处理"""
123
+ start_time = time.time()
124
+
125
+ # 显示加载状态
126
+ yield "⌛ 正在获取数据并训练模型(最多30秒)...", None
127
+
128
+ # 在后台线程中运行训练
129
+ result = None
130
+ def train_thread():
131
+ nonlocal result
132
+ result = rapid_training(ticker)
133
+
134
+ thread = threading.Thread(target=train_thread)
135
+ thread.start()
136
+
137
+ # 等待完成(最多30秒)
138
+ thread.join(timeout=30)
139
+
140
+ if not result:
141
+ yield "❌ 训练失败或超时,请尝试其他股票代码", None
142
+ return
143
+
144
+ if result['training_time'] > 30:
145
+ yield "⚠️ 训练超时,结果可能不准确", create_plot(result)
146
+ return
147
+
148
+ # 显示结果
149
+ info_msg = f"✅ 训练成功(耗时{result['training_time']:.1f}秒)\n" \
150
+ f"最新预测:{pred.values[-1]:.2f} USD({pred.index[-1].strftime('%Y-%m-%d')})"
151
+
152
+ yield info_msg, create_plot(result)
153
+
154
+ with gr.Blocks() as demo:
155
+ gr.Markdown("# 🚀 实时股票预测系统")
156
+
157
+ with gr.Row():
158
+ ticker_input = gr.Textbox(
159
+ label="输入股票代码",
160
+ placeholder="例如:AAPL(美股), 0005.HK(港股)",
161
+ max_lines=1
162
+ )
163
+ submit_btn = gr.Button("立即预测", variant="primary")
164
+
165
+ status_output = gr.Textbox(label="状态", interactive=False)
166
+ plot_output = gr.Plot(label="价格预测")
167
+
168
+ submit_btn.click(
169
+ predict_stock,
170
+ inputs=ticker_input,
171
+ outputs=[status_output, plot_output]
172
+ )
173
+
174
+ if __name__ == "__main__":
175
+ demo.launch(server_port=7860)