# app.py
"""Stock price forecasting demo.

Downloads one year of daily prices with yfinance, derives a handful of
technical-indicator features, fits a small LightGBM regressor and shows a
recursive multi-day close-price forecast in a Gradio UI.
"""
import threading
import time
import traceback

import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objs as go
import yfinance as yf
from lightgbm import LGBMRegressor
from loguru import logger
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import RobustScaler
from ta.momentum import RSIIndicator
from ta.trend import EMAIndicator
from ta.volume import OnBalanceVolumeIndicator

# Logging: rotating DEBUG log file with extended tracebacks.
logger.add("app.log", rotation="1 MB", level="DEBUG", backtrace=True, diagnose=True)

# Raw OHLCV columns the feature pipeline needs.
REQUIRED_COLS = ['Open', 'High', 'Low', 'Close', 'Volume']
# Columns returned by safe_feature_engineering ('Close' is the model target).
FEATURE_COLS = ['Close', 'Returns', 'Volatility', 'RSI_14', 'EMA_12', 'EMA_26', 'OBV']
# EMA windows; must match the EMA_* columns in FEATURE_COLS.
EMA_WINDOWS = (12, 26)
FORECAST_DAYS = 7    # number of future business days to predict
TRAIN_WINDOW = 200   # train on the most recent N feature rows
TIME_BUDGET_S = 28   # UI wait budget, leaving ~2 s of the promised 30 s

RISK_NOTICE = """**风险提示**

1. 本预测仅供参考,不构成投资建议
2. 实际价格可能受市场波动影响
3. 预测误差可能随时间增加
"""


def safe_feature_engineering(df):
    """Build the model's feature frame from raw OHLCV prices.

    Args:
        df: DataFrame containing at least the Open/High/Low/Close/Volume columns.

    Returns:
        DataFrame with the FEATURE_COLS columns. Rows containing NaN or +/-inf
        (e.g. the indicator warm-up period) are dropped.

    Raises:
        ValueError: if a required column is missing. Any other error is
        logged and re-raised.
    """
    try:
        missing = [col for col in REQUIRED_COLS if col not in df.columns]
        if missing:
            raise ValueError(f"缺失必要数据列: {missing}")

        # Work on a float copy of just the columns we need.
        df = df[REQUIRED_COLS].astype(float)

        # Basic features.
        df['Returns'] = df['Close'].pct_change()
        df['Volatility'] = df['Returns'].rolling(5).std()

        # Technical indicators.
        df['RSI_14'] = RSIIndicator(df['Close'], window=14).rsi()
        for window in EMA_WINDOWS:
            df[f'EMA_{window}'] = EMAIndicator(df['Close'], window=window).ema_indicator()
        df['OBV'] = OnBalanceVolumeIndicator(
            close=df['Close'], volume=df['Volume']
        ).on_balance_volume()

        # Drop rows the model cannot consume.
        df = df.replace([np.inf, -np.inf], np.nan).dropna()
        return df[FEATURE_COLS]
    except Exception as e:
        logger.error(f"特征工程异常: {e}")
        raise


def _flatten_columns(data):
    """Return ``data`` with single-level columns (newer yfinance adds a Ticker level)."""
    if isinstance(data.columns, pd.MultiIndex):
        data = data.copy()
        # Level 0 is the price field ('Close', 'Open', ...), level 1 the ticker.
        data.columns = data.columns.get_level_values(0)
    return data


def _recursive_forecast(model, last_features, last_close, horizon=FORECAST_DAYS):
    """Predict ``horizon`` closes, feeding each prediction back in as input.

    Args:
        model: fitted estimator trained on the non-Close feature columns, in
            the same column order as ``last_features``.
        last_features: Series holding the most recent feature row (no 'Close').
        last_close: observed close price belonging to that row.
        horizon: number of steps to predict.

    Returns:
        List of ``horizon`` predicted closes.

    Only features that follow directly from the predicted close are advanced
    (Returns and the EMAs). Volatility, RSI_14 and OBV are held at their last
    observed values.
    """
    features = last_features.astype(float).copy()
    close = float(last_close)
    preds = []
    for _ in range(horizon):
        pred = float(model.predict(features.to_numpy().reshape(1, -1))[0])
        preds.append(pred)
        features['Returns'] = (pred - close) / close
        for window in EMA_WINDOWS:
            # ta's EMAIndicator is ewm(span=window, adjust=False): alpha = 2 / (window + 1).
            alpha = 2.0 / (window + 1)
            col = f'EMA_{window}'
            features[col] = alpha * pred + (1.0 - alpha) * features[col]
        close = pred
    return preds


def safe_training(ticker):
    """Download a year of daily data for ``ticker``, train, and forecast.

    Args:
        ticker: yfinance ticker symbol, e.g. "AAPL" or "0700.HK".

    Returns:
        dict with 'historical' (downloaded price DataFrame), 'predictions'
        (Series of FORECAST_DAYS predicted closes indexed by business day) and
        'time_used' (seconds elapsed), or None on any failure (the error is
        logged with its traceback).
    """
    start_time = time.time()
    try:
        logger.info(f"正在获取 [{ticker}] 数据...")
        data = yf.download(
            ticker, period="1y", interval="1d", progress=False, auto_adjust=True
        )
        data = _flatten_columns(data)

        # Validate the download.
        if data.empty:
            raise ValueError("获取数据为空")
        if len(data) < 30:
            raise ValueError("数据不足30天")
        if 'Close' not in data.columns:
            raise ValueError("缺少Close列")

        nan_count = int(data['Close'].isna().sum())
        if nan_count > 5:
            raise ValueError(f"检测到{nan_count}个空值")

        processed_data = safe_feature_engineering(data)
        if processed_data.empty:
            raise ValueError("特征数据为空")

        X = processed_data.drop(columns=['Close'])
        y = processed_data['Close']

        model = make_pipeline(
            RobustScaler(),
            LGBMRegressor(
                n_estimators=80,
                max_depth=4,
                learning_rate=0.15,
                verbosity=-1,
                force_row_wise=True,
                random_state=42,
            ),
        )

        logger.info("开始模型训练...")
        model.fit(X.values[-TRAIN_WINDOW:], y.values[-TRAIN_WINDOW:])

        predictions = _recursive_forecast(model, X.iloc[-1], y.iloc[-1], FORECAST_DAYS)
        # Markets are closed at weekends: forecast the next business days after
        # the last row the features were computed from.
        future_dates = pd.bdate_range(
            processed_data.index[-1] + pd.offsets.BDay(1), periods=FORECAST_DAYS
        )

        return {
            'historical': data,
            'predictions': pd.Series(predictions, index=future_dates),
            'time_used': time.time() - start_time,
        }
    except Exception as e:
        logger.exception(f"训练异常: {e}")
        return None


def create_safe_plot(result):
    """Plot historical closes and the forecast from a ``safe_training`` result."""
    fig = go.Figure()

    # Historical prices.
    fig.add_trace(go.Scatter(
        x=result['historical'].index,
        y=result['historical']['Close'],
        name='历史价格',
        line=dict(color='#1f77b4', width=2),
    ))

    # Forecast.
    fig.add_trace(go.Scatter(
        x=result['predictions'].index,
        y=result['predictions'].values,
        name='AI预测',
        line=dict(color='#ff7f0e', width=2, dash='dot'),
    ))

    fig.update_layout(
        title="股票价格预测",
        xaxis_title="日期",
        yaxis_title="价格 (USD)",
        hovermode="x unified",
        template="plotly_white",
    )
    return fig


def safe_predict_flow(ticker):
    """Gradio generator: train in a worker thread and stream UI updates.

    Yields:
        (status_markdown, plotly_figure_or_None, risk_markdown_or_None) tuples.
    """
    try:
        ticker = (ticker or "").strip()
        if not ticker:
            yield "⚠️ 请输入股票代码", None, None
            return

        yield "⌛ 正在快速分析中(30秒内完成)...", None, None

        result = None

        def train_task():
            nonlocal result
            result = safe_training(ticker)

        # Daemon thread so a hung download cannot block interpreter shutdown.
        worker = threading.Thread(target=train_task, daemon=True)
        worker.start()
        worker.join(timeout=TIME_BUDGET_S)

        if worker.is_alive():
            # Still running: report a timeout instead of a bogus "bad ticker".
            logger.warning(f"[{ticker}] 分析超时")
            yield "⏳ 分析超时,请稍后再试", None, None
            return

        if not result:
            yield "⚠️ 分析失败,请检查股票代码", None, None
            return

        predictions = result['predictions']
        # Blank-line separated so Markdown renders each item on its own line.
        info = "\n\n".join([
            f"✅ 分析成功(耗时:{result['time_used']:.1f}秒)",
            f"📅 最新预测:{predictions.index[-1].strftime('%Y-%m-%d')}",
            f"💵 预测价格:{predictions.iloc[-1]:.2f} USD",
        ])
        yield info, create_safe_plot(result), RISK_NOTICE
    except Exception:
        logger.critical(f"系统异常: {traceback.format_exc()}")
        yield "⚠️ 系统繁忙,请稍后再试", None, None


# UI definition.
with gr.Blocks(title="稳定版股票预测", theme=gr.themes.Soft()) as app:
    gr.Markdown("# 📊 股票价格预测系统")

    with gr.Row():
        input_col = gr.Column(scale=2)
        with input_col:
            stock_input = gr.Textbox(
                label="股票代码",
                placeholder="输入代码 (如: AAPL, 00700.HK)",
                max_lines=1,
            )
            submit_btn = gr.Button("开始分析", variant="primary")

        output_col = gr.Column(scale=3)
        with output_col:
            status = gr.Markdown("## 状态")
            plot = gr.Plot(label="价格走势")
            risk = gr.Markdown()

    submit_btn.click(
        safe_predict_flow,
        inputs=stock_input,
        outputs=[status, plot, risk],
    )


if __name__ == "__main__":
    import warnings

    warnings.filterwarnings("ignore")

    # Gradio 4+ removed `concurrency_count`; `default_concurrency_limit` replaces it.
    app.queue(default_concurrency_limit=2)
    app.launch(
        server_port=7860,
        show_error=True,
        ssr_mode=False,  # disable the experimental server-side rendering
    )