Spaces:
Runtime error
Runtime error
File size: 7,624 Bytes
cda5d9a 9c4550c cda5d9a 9c4550c cda5d9a 9c4550c 65e4224 cda5d9a e3855c5 9c4550c cda5d9a 14d684f 9c4550c 14d684f 65e4224 9c4550c 14d684f e3855c5 65e4224 9c4550c 14d684f 65e4224 9c4550c f9e9621 9c4550c 14d684f 9c4550c cda5d9a 14d684f cda5d9a e3855c5 14d684f 9c4550c 65e4224 9c4550c cda5d9a 14d684f e3855c5 14d684f 65e4224 14d684f cda5d9a 65e4224 e3855c5 cda5d9a 14d684f cda5d9a 9c4550c 65e4224 9c4550c e3855c5 9c4550c cda5d9a 14d684f cda5d9a 65e4224 9c4550c e3855c5 65e4224 cda5d9a e3855c5 cda5d9a e3855c5 f9e9621 e3855c5 9c4550c e3855c5 9c4550c cda5d9a 14d684f cda5d9a 14d684f cda5d9a 14d684f cda5d9a 14d684f cda5d9a 14d684f 9c4550c e3855c5 14d684f cda5d9a 65e4224 9c4550c e3855c5 14d684f cda5d9a 14d684f 9c4550c cda5d9a 14d684f cda5d9a 14d684f 65e4224 14d684f 65e4224 14d684f 65e4224 14d684f 65e4224 14d684f 65e4224 14d684f 65e4224 14d684f 65e4224 14d684f 65e4224 14d684f 65e4224 14d684f 65e4224 14d684f 65e4224 14d684f 9c4550c 65e4224 14d684f cda5d9a 14d684f cda5d9a 14d684f 9c4550c e3855c5 9c4550c 14d684f cda5d9a 14d684f cda5d9a e3855c5 14d684f e3855c5 14d684f e3855c5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 | # app.py
import gradio as gr
import yfinance as yf
import pandas as pd
import numpy as np
import plotly.graph_objs as go
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import RobustScaler
from lightgbm import LGBMRegressor
from loguru import logger
import threading
import time
import traceback
from ta.momentum import RSIIndicator
from ta.trend import EMAIndicator
from ta.volume import OnBalanceVolumeIndicator
# Logging setup: send DEBUG-and-above records to app.log, rotating the file
# at 1 MB, with loguru's `backtrace` and `diagnose` exception options enabled.
# NOTE(review): loguru documents diagnose=True as able to expose variable
# values in logged tracebacks -- consider disabling it in production.
logger.add(
    "app.log",
    level="DEBUG",
    rotation="1 MB",
    backtrace=True,
    diagnose=True,
)
def safe_feature_engineering(df):
    """Turn raw OHLCV rows into the feature table consumed by the model.

    Args:
        df: DataFrame that must contain ``Open``, ``High``, ``Low``,
            ``Close`` and ``Volume`` columns (values castable to float).

    Returns:
        A new DataFrame with columns ``Close``, ``Returns``, ``Volatility``
        (5-row rolling std of returns), ``RSI_14``, ``EMA_12``, ``EMA_26``
        and ``OBV``. Rows holding NaN or +/-inf in any OHLCV or feature
        column are dropped (this removes the rolling-window warm-up rows).

    Raises:
        ValueError: if any required column is missing.
        Any other exception from the computation is logged and re-raised.
    """
    needed = ['Open', 'High', 'Low', 'Close', 'Volume']
    output_columns = ['Close', 'Returns', 'Volatility', 'RSI_14', 'EMA_12', 'EMA_26', 'OBV']
    try:
        # Validate up front: every OHLCV column has to be there.
        if any(name not in df.columns for name in needed):
            raise ValueError("缺失必要数据列")

        # Work on a float-typed copy restricted to the OHLCV columns.
        frame = df[needed].copy().astype(dict.fromkeys(needed, float))
        close = frame['Close']

        # Simple return/volatility features.
        frame['Returns'] = close.pct_change()
        frame['Volatility'] = frame['Returns'].rolling(5).std()

        # Technical indicators from the `ta` package.
        frame['RSI_14'] = RSIIndicator(close, window=14).rsi()
        frame['EMA_12'] = EMAIndicator(close, window=12).ema_indicator()
        frame['EMA_26'] = EMAIndicator(close, window=26).ema_indicator()
        frame['OBV'] = OnBalanceVolumeIndicator(
            close=close,
            volume=frame['Volume'],
        ).on_balance_volume()

        # Treat infinities as missing, then drop every incomplete row.
        cleaned = frame.replace([np.inf, -np.inf], np.nan).dropna()
        return cleaned[output_columns]
    except Exception as err:
        logger.error(f"特征工程异常: {str(err)}")
        raise
def _flatten_price_columns(data):
    """Return *data* with a two-level ``(field, ticker)`` column index collapsed to its field level.

    NOTE(review): assumes level 0 holds the price-field names (``Close``,
    ``Volume``, ...), i.e. the default column layout of recent
    ``yf.download`` releases -- confirm against the pinned yfinance version.
    Frames with flat columns are returned unchanged; the input is never mutated.

    Raises:
        ValueError: if collapsing yields duplicate field names, meaning the
            download contained more than one ticker.
    """
    if not isinstance(data.columns, pd.MultiIndex):
        return data
    flat = data.set_axis(data.columns.get_level_values(0), axis=1)
    if flat.columns.duplicated().any():
        raise ValueError("仅支持单个股票代码")
    return flat


def _recursive_forecast(model, last_features, last_close, horizon):
    """Predict *horizon* successive closes, feeding each prediction back into the features.

    Args:
        model: fitted estimator mapping one feature row to the NEXT close.
        last_features: single-row DataFrame of features for the most recent
            session, with the same columns (and order) the model was fit on.
        last_close: close price of that most recent session.
        horizon: number of steps to predict.

    Returns:
        List of ``horizon`` predicted closes as floats.
    """
    features = last_features.copy()
    prev_close = float(last_close)
    predictions = []
    for _ in range(horizon):
        pred = float(model.predict(features.values)[0])
        predictions.append(pred)
        # Roll the price-derived features forward using the predicted close.
        features['Returns'] = (pred - prev_close) / prev_close
        for column, span in (('EMA_12', 12), ('EMA_26', 26)):
            # One recursive EMA step with alpha = 2 / (span + 1).
            # NOTE(review): matches ta's EMAIndicator only if it is computed
            # with pandas ewm(adjust=False) -- confirm for the pinned version.
            alpha = 2.0 / (span + 1)
            features[column] = alpha * pred + (1.0 - alpha) * features[column]
        # Volatility, RSI_14 and OBV cannot be advanced without future
        # history/volume, so they stay at their last observed values.
        prev_close = pred
    return predictions


def safe_training(ticker):
    """Fetch ~1 year of daily prices for *ticker*, fit a model and forecast 7 sessions.

    The model learns to map the features of session t to the close of
    session t+1, then is rolled forward recursively from the last session.

    Args:
        ticker: symbol passed straight to ``yf.download``.

    Returns:
        ``{'historical': DataFrame, 'predictions': Series, 'time_used': float}``.
        ``historical`` is the downloaded price frame with flat columns.
        ``predictions`` holds the 7 predicted closes, indexed by the business
        days following the last usable session. ``time_used`` is the elapsed
        wall-clock time in seconds.
        Returns ``None`` on any failure; the error is logged with its traceback.
    """
    start_time = time.time()
    horizon = 7
    try:
        # Download
        logger.info(f"正在获取 [{ticker}] 数据...")
        data = yf.download(
            ticker,
            period="1y",
            interval="1d",
            progress=False,
            auto_adjust=True,
        )

        # Validate before use. Flattening happens before any data['Close']
        # access, so that access yields a Series and not a one-column DataFrame.
        if data.empty:
            raise ValueError("获取数据为空")
        data = _flatten_price_columns(data)
        if len(data) < 30:
            raise ValueError("数据不足30天")
        if 'Close' not in data.columns:
            raise ValueError("缺少Close列")
        nan_count = int(data['Close'].isna().sum())
        if nan_count > 5:
            raise ValueError(f"检测到{nan_count}个空值")

        # Features
        processed_data = safe_feature_engineering(data)
        if len(processed_data) < 2:
            raise ValueError("有效数据不足")
        X = processed_data.drop(columns=['Close'])
        close = processed_data['Close']

        # Target is the NEXT session's close. The last row has no label, so it
        # is excluded from training and used as the forecast starting point.
        next_close = close.shift(-1)
        X_train = X.iloc[:-1].tail(200)  # most recent 200 labelled rows
        y_train = next_close.iloc[:-1].tail(200)

        model = make_pipeline(
            RobustScaler(),
            LGBMRegressor(
                n_estimators=80,
                max_depth=4,
                learning_rate=0.15,
                verbosity=-1,
                force_row_wise=True,
                random_state=42,
            ),
        )
        logger.info("开始模型训练...")
        model.fit(X_train.values, y_train.values)

        predictions = _recursive_forecast(model, X.iloc[-1:], close.iloc[-1], horizon)
        # Next `horizon` business days after the last usable session, so that
        # forecasts do not land on weekends.
        future_dates = pd.bdate_range(close.index[-1] + pd.offsets.BDay(1), periods=horizon)
        return {
            'historical': data,
            'predictions': pd.Series(predictions, index=future_dates),
            'time_used': time.time() - start_time,
        }
    except Exception as e:
        # logger.exception keeps ERROR level but also records the traceback.
        logger.exception(f"训练异常: {str(e)}")
        return None
def create_safe_plot(result):
    """Plot historical closes and the model forecast as two lines on one chart.

    Args:
        result: dict produced by the training step; reads
            ``result['historical']`` (frame with a ``Close`` column, plotted
            against its index) and ``result['predictions']`` (Series plotted
            against its index).

    Returns:
        A plotly ``Figure`` with both traces and a ``plotly_white`` layout.

    NOTE(review): the y-axis label is hard-coded to USD, while the UI
    placeholder also suggests non-USD tickers such as 00700.HK.
    """
    history = result['historical']
    forecast = result['predictions']

    history_trace = go.Scatter(
        x=history.index,
        y=history['Close'],
        name='历史价格',
        line=dict(color='#1f77b4', width=2),
    )
    forecast_trace = go.Scatter(
        x=forecast.index,
        y=forecast.values,
        name='AI预测',
        line=dict(color='#ff7f0e', width=2, dash='dot'),
    )

    figure = go.Figure(data=[history_trace, forecast_trace])
    figure.update_layout(
        title="股票价格预测",
        xaxis_title="日期",
        yaxis_title="价格 (USD)",
        hovermode="x unified",
        template="plotly_white",
    )
    return figure
def safe_predict_flow(ticker):
    """Gradio event handler that streams progress, a forecast chart and a risk note.

    Runs ``safe_training`` on a background thread and yields
    ``(status_markdown, figure_or_None, risk_markdown_or_None)`` tuples,
    one value per output component wired to the submit button.

    Args:
        ticker: symbol entered by the user.
    """
    try:
        start_time = time.time()
        yield "⌛ 正在快速分析中(30秒内完成)...", None, None
        result = None

        def train_task():
            # Runs on the worker thread; hands the outcome back via the closure.
            nonlocal result
            result = safe_training(ticker)

        # daemon=True so a still-running job never blocks interpreter shutdown.
        thread = threading.Thread(target=train_task, daemon=True)
        thread.start()

        # Soft deadline ~28 s after start (2 s buffer under the advertised 30 s).
        thread.join(timeout=max(0.0, start_time + 28 - time.time()))
        if thread.is_alive():
            # Tell the user the job is still running, then keep waiting for it.
            # Giving up here would discard a result that is still being
            # computed and misreport it as a bad ticker.
            yield "⏳ 系统正在处理,请稍候...", None, None
            thread.join()

        # join() above guarantees the worker's write to `result` is visible.
        if result is None:
            yield "⚠️ 分析失败,请检查股票代码", None, None
            return

        predictions = result['predictions']
        info = (
            "\n"
            f"✅ 分析成功(耗时:{result['time_used']:.1f}秒)\n"
            f"📅 最新预测:{predictions.index[-1].strftime('%Y-%m-%d')}\n"
            f"💵 预测价格:{predictions.iloc[-1]:.2f} USD\n"
        )
        risk = (
            "\n"
            "**风险提示**\n"
            "1. 本预测仅供参考,不构成投资建议\n"
            "2. 实际价格可能受市场波动影响\n"
            "3. 预测误差可能随时间增加\n"
        )
        yield info, create_safe_plot(result), risk
    except Exception:
        logger.critical(f"系统异常: {traceback.format_exc()}")
        yield "⚠️ 系统繁忙,请稍后再试", None, None
# Build the UI: ticker input and submit button on the left (scale 2), status,
# chart and risk outputs on the right (scale 3). Clicking the button streams
# safe_predict_flow's yielded tuples into the three output components.
with gr.Blocks(title="稳定版股票预测", theme=gr.themes.Soft()) as app:
    gr.Markdown("# 📊 股票价格预测系统")
    with gr.Row():
        with gr.Column(scale=2) as input_col:
            stock_input = gr.Textbox(
                label="股票代码",
                placeholder="输入代码 (如: AAPL, 00700.HK)",
                max_lines=1,
            )
            submit_btn = gr.Button("开始分析", variant="primary")
        with gr.Column(scale=3) as output_col:
            status = gr.Markdown("## 状态")
            plot = gr.Plot(label="价格走势")
            risk = gr.Markdown()
    submit_btn.click(
        fn=safe_predict_flow,
        inputs=[stock_input],
        outputs=[status, plot, risk],
    )
if __name__ == "__main__":
    # Launch configuration
    import warnings

    # Silence every Python warning for the running app.
    warnings.filterwarnings("ignore")

    # Gradio 4 removed queue(concurrency_count=...), and passing it raises
    # TypeError at startup. This file already targets Gradio 5 (ssr_mode below).
    # default_concurrency_limit is the replacement: at most 2 concurrent runs
    # per event listener.
    app.queue(default_concurrency_limit=2)
    app.launch(
        server_port=7860,
        show_error=True,
        ssr_mode=False,  # disable Gradio's (experimental) server-side rendering
    )