Spaces:
Runtime error
Runtime error
| # app.py | |
| import gradio as gr | |
| import yfinance as yf | |
| import pandas as pd | |
| import numpy as np | |
| import plotly.graph_objs as go | |
| from sklearn.pipeline import make_pipeline | |
| from sklearn.preprocessing import RobustScaler | |
| from lightgbm import LGBMRegressor | |
| from loguru import logger | |
| import threading | |
| import time | |
| import traceback | |
| from ta.momentum import RSIIndicator | |
| from ta.trend import EMAIndicator | |
| from ta.volume import OnBalanceVolumeIndicator | |
| # 日志配置 | |
| logger.add("app.log", rotation="1 MB", level="DEBUG", backtrace=True, diagnose=True) | |
| def safe_feature_engineering(df): | |
| """安全稳定的特征工程函数""" | |
| try: | |
| # 基础数据校验 | |
| required_cols = ['Open', 'High', 'Low', 'Close', 'Volume'] | |
| if not all(col in df.columns for col in required_cols): | |
| raise ValueError("缺失必要数据列") | |
| # 数据类型强制转换 | |
| df = df[required_cols].copy() | |
| df = df.astype({col: float for col in required_cols}) | |
| # 基础特征 | |
| df['Returns'] = df['Close'].pct_change() | |
| df['Volatility'] = df['Returns'].rolling(5).std() | |
| # 技术指标计算 | |
| df['RSI_14'] = RSIIndicator(df['Close'], window=14).rsi() | |
| df['EMA_12'] = EMAIndicator(df['Close'], window=12).ema_indicator() | |
| df['EMA_26'] = EMAIndicator(df['Close'], window=26).ema_indicator() | |
| df['OBV'] = OnBalanceVolumeIndicator( | |
| close=df['Close'], | |
| volume=df['Volume'] | |
| ).on_balance_volume() | |
| # 清理数据 | |
| df.replace([np.inf, -np.inf], np.nan, inplace=True) | |
| df.dropna(inplace=True) | |
| return df[['Close', 'Returns', 'Volatility', 'RSI_14', 'EMA_12', 'EMA_26', 'OBV']] | |
| except Exception as e: | |
| logger.error(f"特征工程异常: {str(e)}") | |
| raise | |
| def safe_training(ticker): | |
| """稳定可靠的训练函数""" | |
| start_time = time.time() | |
| try: | |
| # 数据获取 | |
| logger.info(f"正在获取 [{ticker}] 数据...") | |
| data = yf.download( | |
| ticker, | |
| period="1y", | |
| interval="1d", | |
| progress=False, | |
| auto_adjust=True | |
| ) | |
| # 数据有效性验证 | |
| if data.empty: | |
| raise ValueError("获取数据为空") | |
| if len(data) < 30: | |
| raise ValueError("数据不足30天") | |
| if 'Close' not in data.columns: | |
| raise ValueError("缺少Close列") | |
| # 显式处理空值 | |
| nan_count = data['Close'].isna().sum() | |
| if nan_count > 5: | |
| raise ValueError(f"检测到{n}个空值") | |
| # 特征处理 | |
| processed_data = safe_feature_engineering(data) | |
| # 准备训练数据 | |
| X = processed_data.drop(columns=['Close']) | |
| y = processed_data['Close'] | |
| # 初始化模型 | |
| model = make_pipeline( | |
| RobustScaler(), | |
| LGBMRegressor( | |
| n_estimators=80, | |
| max_depth=4, | |
| learning_rate=0.15, | |
| verbosity=-1, | |
| force_row_wise=True, | |
| random_state=42 | |
| ) | |
| ) | |
| # 快速训练 | |
| logger.info("开始模型训练...") | |
| model.fit(X.values[-200:], y.values[-200:]) # 使用最近200个数据点 | |
| # 生成预测 | |
| current_features = X.iloc[-1:].copy() | |
| predictions = [] | |
| future_dates = pd.date_range(data.index[-1], periods=8)[1:] | |
| for _ in range(7): | |
| current_close = current_features['Close'].values[0] | |
| pred = model.predict(current_features.values)[0] | |
| predictions.append(pred) | |
| # 更新特征 | |
| current_features['Returns'] = (pred - current_close) / current_close | |
| current_features['Close'] = pred | |
| current_features['Volatility'] = current_features['Volatility'].values[0] | |
| return { | |
| 'historical': data, | |
| 'predictions': pd.Series(predictions, index=future_dates), | |
| 'time_used': time.time() - start_time | |
| } | |
| except Exception as e: | |
| logger.error(f"训练异常: {str(e)}") | |
| return None | |
| def create_safe_plot(result): | |
| """安全绘图函数""" | |
| fig = go.Figure() | |
| # 历史价格 | |
| fig.add_trace(go.Scatter( | |
| x=result['historical'].index, | |
| y=result['historical']['Close'], | |
| name='历史价格', | |
| line=dict(color='#1f77b4', width=2) | |
| )) | |
| # 预测价格 | |
| fig.add_trace(go.Scatter( | |
| x=result['predictions'].index, | |
| y=result['predictions'].values, | |
| name='AI预测', | |
| line=dict(color='#ff7f0e', width=2, dash='dot') | |
| )) | |
| fig.update_layout( | |
| title="股票价格预测", | |
| xaxis_title="日期", | |
| yaxis_title="价格 (USD)", | |
| hovermode="x unified", | |
| template="plotly_white" | |
| ) | |
| return fig | |
| def safe_predict_flow(ticker): | |
| """安全预测流程""" | |
| try: | |
| start_time = time.time() | |
| yield "⌛ 正在快速分析中(30秒内完成)...", None, None | |
| result = None | |
| # 后台训练 | |
| def train_task(): | |
| nonlocal result | |
| result = safe_training(ticker) | |
| thread = threading.Thread(target=train_task) | |
| thread.start() | |
| # 等待结果 | |
| while thread.is_alive(): | |
| if time.time() - start_time > 28: # 预留2秒缓冲 | |
| yield "⏳ 系统正在处理,请稍候...", None, None | |
| break | |
| time.sleep(0.1) | |
| if not result: | |
| yield "⚠️ 分析失败,请检查股票代码", None, None | |
| return | |
| # 构建结果 | |
| info = f""" | |
| ✅ 分析成功(耗时:{result['time_used']:.1f}秒) | |
| 📅 最新预测:{result['predictions'].index[-1].strftime('%Y-%m-%d')} | |
| 💵 预测价格:{result['predictions'].iloc[-1]:.2f} USD | |
| """ | |
| risk = """ | |
| **风险提示** | |
| 1. 本预测仅供参考,不构成投资建议 | |
| 2. 实际价格可能受市场波动影响 | |
| 3. 预测误差可能随时间增加 | |
| """ | |
| yield info, create_safe_plot(result), risk | |
| except Exception as e: | |
| logger.critical(f"系统异常: {traceback.format_exc()}") | |
| yield "⚠️ 系统繁忙,请稍后再试", None, None | |
| # 创建界面 | |
| with gr.Blocks(title="稳定版股票预测", theme=gr.themes.Soft()) as app: | |
| gr.Markdown("# 📊 股票价格预测系统") | |
| with gr.Row(): | |
| input_col = gr.Column(scale=2) | |
| with input_col: | |
| stock_input = gr.Textbox( | |
| label="股票代码", | |
| placeholder="输入代码 (如: AAPL, 00700.HK)", | |
| max_lines=1 | |
| ) | |
| submit_btn = gr.Button("开始分析", variant="primary") | |
| output_col = gr.Column(scale=3) | |
| with output_col: | |
| status = gr.Markdown("## 状态") | |
| plot = gr.Plot(label="价格走势") | |
| risk = gr.Markdown() | |
| submit_btn.click( | |
| safe_predict_flow, | |
| inputs=stock_input, | |
| outputs=[status, plot, risk] | |
| ) | |
| if __name__ == "__main__": | |
| # 启动配置 | |
| import warnings | |
| warnings.filterwarnings("ignore") | |
| # 稳定启动 | |
| app.queue(concurrency_count=2) | |
| app.launch( | |
| server_port=7860, | |
| show_error=True, | |
| ssr_mode=False # 禁用实验性功能 | |
| ) |