File size: 7,624 Bytes
cda5d9a
 
 
 
 
9c4550c
cda5d9a
9c4550c
 
cda5d9a
 
9c4550c
65e4224
 
 
 
cda5d9a
e3855c5
9c4550c
cda5d9a
14d684f
 
9c4550c
14d684f
 
 
 
 
 
 
 
65e4224
 
9c4550c
 
 
14d684f
e3855c5
65e4224
 
 
 
 
 
9c4550c
14d684f
65e4224
 
9c4550c
f9e9621
9c4550c
 
14d684f
9c4550c
cda5d9a
14d684f
 
cda5d9a
 
e3855c5
14d684f
9c4550c
 
 
 
65e4224
 
9c4550c
cda5d9a
14d684f
 
 
 
 
e3855c5
14d684f
 
 
 
 
 
65e4224
 
14d684f
cda5d9a
65e4224
e3855c5
 
cda5d9a
14d684f
cda5d9a
9c4550c
 
65e4224
 
 
9c4550c
e3855c5
 
9c4550c
 
cda5d9a
14d684f
 
 
cda5d9a
65e4224
9c4550c
 
e3855c5
65e4224
cda5d9a
e3855c5
 
cda5d9a
e3855c5
f9e9621
e3855c5
9c4550c
e3855c5
9c4550c
cda5d9a
14d684f
cda5d9a
14d684f
cda5d9a
 
 
14d684f
cda5d9a
 
14d684f
 
cda5d9a
 
 
 
14d684f
 
9c4550c
e3855c5
14d684f
cda5d9a
 
 
65e4224
 
9c4550c
e3855c5
14d684f
cda5d9a
 
14d684f
9c4550c
 
cda5d9a
14d684f
cda5d9a
 
 
14d684f
 
65e4224
 
14d684f
65e4224
 
 
14d684f
 
65e4224
14d684f
65e4224
14d684f
65e4224
 
14d684f
65e4224
14d684f
 
65e4224
 
 
14d684f
 
65e4224
 
14d684f
 
 
 
 
65e4224
 
14d684f
65e4224
14d684f
 
 
65e4224
 
14d684f
9c4550c
65e4224
14d684f
 
cda5d9a
14d684f
 
 
cda5d9a
 
14d684f
 
 
 
 
9c4550c
 
e3855c5
9c4550c
14d684f
 
 
 
 
 
cda5d9a
14d684f
 
 
cda5d9a
 
 
e3855c5
 
 
 
14d684f
 
 
e3855c5
14d684f
 
e3855c5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
# app.py
import gradio as gr
import yfinance as yf
import pandas as pd
import numpy as np
import plotly.graph_objs as go
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import RobustScaler
from lightgbm import LGBMRegressor
from loguru import logger
import threading
import time
import traceback
from ta.momentum import RSIIndicator
from ta.trend import EMAIndicator
from ta.volume import OnBalanceVolumeIndicator

# Logging setup: DEBUG-level records go to app.log, rotated at 1 MB.
logger.add(
    "app.log",
    level="DEBUG",
    rotation="1 MB",
    backtrace=True,
    diagnose=True,
)

def safe_feature_engineering(df):
    """Build the model's feature table from raw OHLCV price data.

    Args:
        df: DataFrame holding at least the ``Open``, ``High``, ``Low``,
            ``Close`` and ``Volume`` columns. Two-level (MultiIndex) columns
            are accepted; only the first level is used as the column name.
            The caller's frame is never modified.

    Returns:
        A new DataFrame with the columns ``Close``, ``Returns``,
        ``Volatility``, ``RSI_14``, ``EMA_12``, ``EMA_26`` and ``OBV``.
        Rows with a NaN or infinite value in any input or derived column
        are dropped.

    Raises:
        ValueError: If a required column is missing.
        Exception: Any other failure is logged and re-raised unchanged.
    """
    try:
        required_cols = ['Open', 'High', 'Low', 'Close', 'Volume']

        # Some yfinance releases return (field, ticker) MultiIndex columns
        # even for one ticker; df['Close'] would then be a DataFrame and the
        # indicator classes below would fail. Keep only the field level.
        # NOTE(review): assumes the field name is level 0 (yfinance's default
        # column grouping) - confirm if group_by='ticker' is ever used.
        if isinstance(df.columns, pd.MultiIndex):
            df = df.copy()
            df.columns = df.columns.get_level_values(0)

        # Basic validation
        if not all(col in df.columns for col in required_cols):
            raise ValueError("缺失必要数据列")

        # Work on a float copy restricted to the required columns
        df = df[required_cols].copy()
        df = df.astype({col: float for col in required_cols})

        # Basic features: daily return and its 5-row rolling std
        df['Returns'] = df['Close'].pct_change()
        df['Volatility'] = df['Returns'].rolling(5).std()

        # Technical indicators
        df['RSI_14'] = RSIIndicator(df['Close'], window=14).rsi()
        df['EMA_12'] = EMAIndicator(df['Close'], window=12).ema_indicator()
        df['EMA_26'] = EMAIndicator(df['Close'], window=26).ema_indicator()
        df['OBV'] = OnBalanceVolumeIndicator(
            close=df['Close'],
            volume=df['Volume']
        ).on_balance_volume()

        # Clean up: treat +/-inf as missing, then drop incomplete rows
        df.replace([np.inf, -np.inf], np.nan, inplace=True)
        df.dropna(inplace=True)

        return df[['Close', 'Returns', 'Volatility', 'RSI_14', 'EMA_12', 'EMA_26', 'OBV']]

    except Exception as e:
        logger.error(f"特征工程异常: {str(e)}")
        raise

def safe_training(ticker):
    """Download one year of daily prices, fit a model and forecast 7 days.

    Args:
        ticker: Symbol passed to ``yf.download``.

    Returns:
        On success, a dict with:
            ``'historical'``: the downloaded price DataFrame (columns
                flattened to a single level),
            ``'predictions'``: a ``pd.Series`` of 7 predicted closes indexed
                by the 7 calendar days after the last historical date,
            ``'time_used'``: elapsed seconds as a float.
        ``None`` if anything fails; the error is logged, not raised.
    """
    start_time = time.time()
    try:
        # Fetch data
        logger.info(f"正在获取 [{ticker}] 数据...")
        data = yf.download(
            ticker,
            period="1y",
            interval="1d",
            progress=False,
            auto_adjust=True
        )

        # Validate the download
        if data is None or data.empty:
            raise ValueError("获取数据为空")

        # Some yfinance releases return (field, ticker) MultiIndex columns
        # even for one ticker. Flatten to the field level so data['Close']
        # is a Series here and in the plotting code that reads 'historical'.
        if isinstance(data.columns, pd.MultiIndex):
            data = data.copy()
            data.columns = data.columns.get_level_values(0)

        if len(data) < 30:
            raise ValueError("数据不足30天")
        if 'Close' not in data.columns:
            raise ValueError("缺少Close列")

        # Reject series with too many missing closes.
        # (The original message formatted an undefined name `n`.)
        nan_count = int(data['Close'].isna().sum())
        if nan_count > 5:
            raise ValueError(f"检测到{nan_count}个空值")

        # Feature engineering
        processed_data = safe_feature_engineering(data)
        if processed_data.empty:
            # Indicator warm-up plus dropna can consume every row.
            raise ValueError("特征处理后无有效数据")

        # Training matrices: every engineered column except the target
        X = processed_data.drop(columns=['Close'])
        y = processed_data['Close']

        # Model: robust scaling followed by a small gradient-boosted regressor
        model = make_pipeline(
            RobustScaler(),
            LGBMRegressor(
                n_estimators=80,
                max_depth=4,
                learning_rate=0.15,
                verbosity=-1,
                force_row_wise=True,
                random_state=42
            )
        )

        # Fit on the most recent 200 rows only
        logger.info("开始模型训练...")
        model.fit(X.values[-200:], y.values[-200:])

        # Iterative forecast. The last close is tracked separately because
        # 'Close' is the target and not a column of X; adding it to the
        # feature row would change the feature count the model was fit on.
        current_features = X.iloc[-1:].copy()
        current_close = float(y.iloc[-1])
        predictions = []
        future_dates = pd.date_range(data.index[-1], periods=8)[1:]

        # NOTE(review): the model maps a day's features to that same day's
        # close; the forecasts are produced by feeding the predicted return
        # back in. Confirm this is the intended modelling setup.
        for _ in range(7):
            pred = float(model.predict(current_features.values)[0])
            predictions.append(pred)

            # Only 'Returns' is refreshed; the other features keep the values
            # of the last observed row.
            current_features['Returns'] = (
                (pred - current_close) / current_close if current_close else 0.0
            )
            current_close = pred

        return {
            'historical': data,
            'predictions': pd.Series(predictions, index=future_dates),
            'time_used': time.time() - start_time
        }

    except Exception as e:
        logger.error(f"训练异常: {str(e)}")
        return None

def create_safe_plot(result):
    """Render historical closes and predicted prices as one Plotly figure.

    Args:
        result: Mapping with a ``'historical'`` frame (its index and
            ``'Close'`` column are plotted) and a ``'predictions'`` series
            (its index and values are plotted).

    Returns:
        A ``go.Figure`` with two line traces and a titled layout.
    """
    history = result['historical']
    forecast = result['predictions']

    # (x, y, legend name, line style) for each trace, in draw order:
    # historical prices first, then the model's predictions.
    trace_specs = [
        (
            history.index,
            history['Close'],
            '历史价格',
            {'color': '#1f77b4', 'width': 2},
        ),
        (
            forecast.index,
            forecast.values,
            'AI预测',
            {'color': '#ff7f0e', 'width': 2, 'dash': 'dot'},
        ),
    ]

    fig = go.Figure()
    for xs, ys, label, line_style in trace_specs:
        fig.add_trace(go.Scatter(x=xs, y=ys, name=label, line=line_style))

    layout_options = {
        'title': "股票价格预测",
        'xaxis_title': "日期",
        'yaxis_title': "价格 (USD)",
        'hovermode': "x unified",
        'template': "plotly_white",
    }
    fig.update_layout(**layout_options)
    return fig

def safe_predict_flow(ticker):
    """Drive one prediction request as a generator of UI updates.

    Every yielded item is a ``(status, plot, risk)`` tuple matching the three
    output components wired to the submit button.

    Args:
        ticker: Symbol typed by the user. ``None``, empty and
            whitespace-only values are rejected without starting a worker.

    Yields:
        A progress message first, then exactly one of: the success summary
        with figure and risk notice, a failure message, a timeout message,
        or a generic error message if an unexpected exception occurs.
    """
    # Wait budget for the worker; 2 seconds short of the advertised 30.
    timeout_seconds = 28
    failure = ("⚠️ 分析失败,请检查股票代码", None, None)

    try:
        ticker = (ticker or "").strip()
        if not ticker:
            yield failure
            return

        yield "⌛ 正在快速分析中(30秒内完成)...", None, None

        result = None

        # Background training; the closure publishes its result via nonlocal.
        def train_task():
            nonlocal result
            result = safe_training(ticker)

        # daemon=True so a stuck download cannot keep the process alive.
        thread = threading.Thread(target=train_task, daemon=True)
        thread.start()

        # Block until the worker finishes or the budget is used up,
        # instead of polling is_alive() in a sleep loop.
        thread.join(timeout=timeout_seconds)

        if thread.is_alive():
            # Report the timeout as such; previously this fell through to
            # the "check the ticker" failure message.
            yield "⏳ 分析超时,请稍后再试", None, None
            return

        if result is None:
            yield failure
            return

        # Build the success summary
        info = f"""
        ✅ 分析成功(耗时:{result['time_used']:.1f}秒)
        📅 最新预测:{result['predictions'].index[-1].strftime('%Y-%m-%d')}
        💵 预测价格:{result['predictions'].iloc[-1]:.2f} USD
        """

        risk = """
        **风险提示**
        1. 本预测仅供参考,不构成投资建议
        2. 实际价格可能受市场波动影响
        3. 预测误差可能随时间增加
        """

        yield info, create_safe_plot(result), risk

    except Exception:
        logger.critical(f"系统异常: {traceback.format_exc()}")
        yield "⚠️ 系统繁忙,请稍后再试", None, None

# Build the UI: one row holding an input column (scale 2) and an
# output column (scale 3), with the submit button wired to the predictor.
with gr.Blocks(title="稳定版股票预测", theme=gr.themes.Soft()) as app:
    gr.Markdown("# 📊 股票价格预测系统")

    with gr.Row():
        # Input side: ticker textbox and submit button
        with gr.Column(scale=2) as input_col:
            stock_input = gr.Textbox(
                label="股票代码",
                placeholder="输入代码 (如: AAPL, 00700.HK)",
                max_lines=1,
            )
            submit_btn = gr.Button("开始分析", variant="primary")

        # Output side: status text, price chart, risk notice
        with gr.Column(scale=3) as output_col:
            status = gr.Markdown("## 状态")
            plot = gr.Plot(label="价格走势")
            risk = gr.Markdown()

    # Each tuple yielded by safe_predict_flow fills these three outputs in order.
    submit_btn.click(
        safe_predict_flow,
        inputs=stock_input,
        outputs=[status, plot, risk],
    )

if __name__ == "__main__":
    # Startup configuration: silence library warnings for end users.
    import warnings
    warnings.filterwarnings("ignore")

    # `ssr_mode` below is a newer launch() option; in Gradio releases that
    # have it, queue() no longer accepts `concurrency_count`, so the
    # original call failed before the server started. The replacement
    # limits each event to 2 concurrent runs by default.
    # NOTE(review): `concurrency_count` was a global worker count, so this is
    # the closest equivalent rather than an exact one - confirm the limit.
    app.queue(default_concurrency_limit=2)
    app.launch(
        server_port=7860,
        show_error=True,
        ssr_mode=False  # disable the experimental server-side rendering
    )