|
|
|
|
|
|
|
|
""" |
|
|
Kronos 日本股市AI预测系统 - 完整版 |
|
|
使用真正的 Kronos 模型进行预测 |
|
|
""" |
|
|
|
|
|
import gradio as gr |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import plotly.graph_objects as go |
|
|
from plotly.subplots import make_subplots |
|
|
from datetime import datetime, timedelta |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from huggingface_hub import hf_hub_download |
|
|
import os |
|
|
|
|
|
|
|
|
from kronos import Kronos, KronosTokenizer |
|
|
|
|
|
|
|
|
# Tickers offered in the UI dropdown, keyed by Yahoo Finance symbol.
# Each entry carries an English name, a Japanese display name and a sector tag;
# dict insertion order is the order shown in the dropdown.
POPULAR_STOCKS = {
    symbol: {'name': english, 'name_ja': japanese, 'sector': sector}
    for symbol, english, japanese, sector in (
        ('7203.T', 'Toyota Motor Corp', 'トヨタ自動車', 'Automobile'),
        ('6758.T', 'Sony Group Corp', 'ソニーグループ', 'Technology'),
        ('8306.T', 'MUFG', '三菱UFJフィナンシャルグループ', 'Finance'),
        ('8035.T', 'Tokyo Electron', '東京エレクトロン', 'Technology'),
        ('7201.T', 'Nissan Motor Co', '日産自動車', 'Automobile'),
        ('7267.T', 'Honda Motor Co', '本田技研工業', 'Automobile'),
        ('^N225', 'Nikkei 225', '日経平均株価', 'Index'),
    )
}

# Populated by load_kronos_model(); None means "not loaded yet".
kronos_model = kronos_tokenizer = device = None
|
|
|
|
|
def load_kronos_model():
    """Load the Kronos tokenizer and model into the module-level globals.

    Picks CUDA when ``torch.cuda.is_available()``, otherwise CPU, then loads
    ``NeoQuasar/Kronos-Tokenizer-base`` and ``NeoQuasar/Kronos-base`` via
    ``from_pretrained``, moves both to the device and switches them to eval
    mode.

    Returns:
        str: a human-readable status line for the UI. This function never
        raises. On any failure the error text is returned, and both
        ``kronos_model`` and ``kronos_tokenizer`` are reset to ``None`` so
        callers always see a consistent "not loaded" state.
    """
    global kronos_model, kronos_tokenizer, device

    try:
        if torch.cuda.is_available():
            device = torch.device('cuda')
            print("🖥️ 使用 GPU (CUDA)")
        else:
            device = torch.device('cpu')
            print("🖥️ 使用 CPU")

        print("📦 加载 Kronos Tokenizer...")
        tokenizer = KronosTokenizer.from_pretrained('NeoQuasar/Kronos-Tokenizer-base')
        tokenizer.to(device)
        tokenizer.eval()

        print("📦 加载 Kronos 模型...")
        model = Kronos.from_pretrained('NeoQuasar/Kronos-base')
        model.to(device)
        model.eval()
    except Exception as e:
        # Never leave a half-loaded pair (e.g. tokenizer set, model missing)
        # in the globals.
        kronos_model = None
        kronos_tokenizer = None
        print(f"⚠️ Kronos 模型加载失败: {str(e)}")
        print("💡 将使用简化的预测方法")
        return f"⚠️ 模型加载失败: {str(e)}"

    # Publish both objects together, only after every step succeeded.
    kronos_tokenizer = tokenizer
    kronos_model = model
    print("✅ Kronos 模型加载成功")
    return f"✅ Kronos 模型加载成功 (设备: {device})"
|
|
|
|
|
def fetch_stock_data(symbol, period='2y'):
    """Download OHLCV history for ``symbol`` from Yahoo Finance.

    Args:
        symbol: Yahoo Finance ticker, e.g. ``'7203.T'`` or ``'^N225'``.
        period: look-back window passed to ``yfinance.Ticker.history``.

    Returns:
        pandas.DataFrame | None: a frame with columns
        ``date, open, high, low, close, volume``. Rows with any missing price
        are dropped and missing volume is filled with 0.

        Returns ``None`` when:
        - ``yfinance`` cannot be imported,
        - the request fails,
        - the expected columns are absent, or
        - no usable rows remain.
    """
    try:
        import yfinance as yf
        print(f"📥 从 Yahoo Finance 获取 {symbol} 数据...")

        df = yf.Ticker(symbol).history(period=period)
        if df is None or df.empty:
            return None

        df = df.reset_index()
        df.columns = [str(col).lower() for col in df.columns]

        # The index name depends on the bar interval: daily bars use "Date"
        # and intraday bars use "Datetime". An unnamed index becomes "index"
        # after reset_index().
        for alias in ('datetime', 'index'):
            if 'date' not in df.columns and alias in df.columns:
                df = df.rename(columns={alias: 'date'})

        required_cols = ['date', 'open', 'high', 'low', 'close', 'volume']
        if not all(col in df.columns for col in required_cols):
            return None

        # Gaps in the price columns would turn every downstream statistic
        # (normalisation, moving averages, summary numbers) into NaN.
        df = (
            df[required_cols]
            .dropna(subset=['open', 'high', 'low', 'close'])
            .assign(volume=lambda frame: frame['volume'].fillna(0))
            .reset_index(drop=True)
        )
        if df.empty:
            return None

        print(f"✅ 成功获取 {len(df)} 行数据")
        return df

    except Exception as e:
        print(f"❌ Yahoo Finance 获取失败: {str(e)}")
        return None
|
|
|
|
|
def generate_sample_data(n_days=520, base_price=100.0, seed=None):
    """Generate a synthetic daily OHLCV frame, used when real data is unavailable.

    How the prices are built:
    - Opens follow a Gaussian random walk (step std 2.0) that starts at
      ``base_price``.
    - Each close is the open plus N(0, 0.5) noise.
    - High and low are placed outside the open/close body, so every row is a
      valid candle.
    - All prices are floored at 1% of ``base_price``, so the walk never
      produces zero or negative quotes.

    Volume is uniform on [1_000_000, 10_000_000). Dates are consecutive
    calendar days ending today at midnight.

    Args:
        n_days: number of rows (default 520).
        base_price: starting price; must be positive.
        seed: optional seed for ``numpy.random.default_rng``, for
            reproducible output.

    Returns:
        pandas.DataFrame with columns ``date, open, high, low, close, volume``.

    Raises:
        ValueError: if ``n_days < 1`` or ``base_price <= 0``.
    """
    if n_days < 1:
        raise ValueError("n_days must be >= 1")
    if base_price <= 0:
        raise ValueError("base_price must be positive")

    rng = np.random.default_rng(seed)
    floor = base_price * 0.01

    dates = pd.date_range(end=datetime.now(), periods=n_days, freq='D', normalize=True)

    # Random walk: opens[0] == base_price, then cumulative N(0, 2) steps.
    steps = rng.standard_normal(n_days - 1) * 2
    opens = np.maximum(base_price + np.concatenate(([0.0], np.cumsum(steps))), floor)
    closes = np.maximum(opens + rng.standard_normal(n_days) * 0.5, floor)

    # Wicks extend beyond the candle body. Lows are floored below `floor`,
    # which keeps them <= min(open, close) while staying positive.
    highs = np.maximum(opens, closes) + np.abs(rng.standard_normal(n_days))
    lows = np.maximum(np.minimum(opens, closes) - np.abs(rng.standard_normal(n_days)), floor / 2)

    return pd.DataFrame({
        'date': dates,
        'open': opens,
        'high': highs,
        'low': lows,
        'close': closes,
        'volume': rng.integers(1_000_000, 10_000_000, size=n_days),
    })
|
|
|
|
|
def _candle_colors(df, up_color, down_color):
    """Return one colour per row: ``up_color`` where close >= open, else ``down_color``."""
    rising = df['close'].to_numpy() >= df['open'].to_numpy()
    return [up_color if up else down_color for up in rising]


def create_advanced_chart(historical_df, pred_df=None):
    """Build the two-row Plotly figure for historical and forecast bars.

    Row 1 shows candlesticks, the forecast connector and the moving averages.
    Row 2 shows volume. The two rows share one x axis.

    Args:
        historical_df: frame with ``date, open, high, low, close, volume``
            columns.
        pred_df: optional frame with the same columns holding forecast bars.
            It is ignored when ``None`` or empty.

    Returns:
        plotly.graph_objects.Figure
    """
    has_prediction = pred_df is not None and len(pred_df) > 0

    fig = make_subplots(
        rows=2, cols=1,
        shared_xaxes=True,
        vertical_spacing=0.03,
        row_heights=[0.7, 0.3],
        subplot_titles=('价格走势', '成交量')
    )

    # Row 1: historical candles.
    fig.add_trace(
        go.Candlestick(
            x=historical_df['date'],
            open=historical_df['open'],
            high=historical_df['high'],
            low=historical_df['low'],
            close=historical_df['close'],
            name='历史数据',
            increasing_line_color='#26A69A',
            decreasing_line_color='#EF5350',
            increasing_fillcolor='#26A69A',
            decreasing_fillcolor='#EF5350'
        ),
        row=1, col=1
    )

    if has_prediction:
        # Row 1: forecast candles, drawn semi-transparent.
        fig.add_trace(
            go.Candlestick(
                x=pred_df['date'],
                open=pred_df['open'],
                high=pred_df['high'],
                low=pred_df['low'],
                close=pred_df['close'],
                name='预测数据',
                increasing_line_color='#4CAF50',
                decreasing_line_color='#FF5252',
                increasing_fillcolor='rgba(76, 175, 80, 0.3)',
                decreasing_fillcolor='rgba(255, 82, 82, 0.3)',
                opacity=0.8
            ),
            row=1, col=1
        )

        # Dashed connector from the last real close to the first forecast
        # open. It needs at least one bar on each side to index into.
        if len(historical_df) > 0:
            last_historical = historical_df.iloc[-1]
            first_prediction = pred_df.iloc[0]
            fig.add_trace(
                go.Scatter(
                    x=[last_historical['date'], first_prediction['date']],
                    y=[last_historical['close'], first_prediction['open']],
                    mode='lines',
                    line=dict(color='yellow', width=2, dash='dash'),
                    name='预测起点',
                    showlegend=True
                ),
                row=1, col=1
            )

    # Row 2: historical volume, coloured by candle direction.
    fig.add_trace(
        go.Bar(
            x=historical_df['date'],
            y=historical_df['volume'],
            name='历史成交量',
            marker_color=_candle_colors(historical_df, '#26A69A', '#EF5350'),
            opacity=0.7
        ),
        row=2, col=1
    )

    if has_prediction:
        fig.add_trace(
            go.Bar(
                x=pred_df['date'],
                y=pred_df['volume'],
                name='预测成交量',
                marker_color=_candle_colors(pred_df, '#4CAF50', '#FF5252'),
                opacity=0.5
            ),
            row=2, col=1
        )

    # Moving averages are drawn only once the history can fill the window.
    for window, color in ((20, 'orange'), (60, 'purple')):
        if len(historical_df) >= window:
            fig.add_trace(
                go.Scatter(
                    x=historical_df['date'],
                    y=historical_df['close'].rolling(window=window).mean(),
                    mode='lines',
                    name=f'MA{window}',
                    line=dict(color=color, width=1.5),
                    opacity=0.7
                ),
                row=1, col=1
            )

    fig.update_layout(
        title={
            'text': '📈 Kronos AI 股价预测分析',
            'x': 0.5,
            'xanchor': 'center',
            'font': {'size': 20, 'color': '#FFFFFF'}
        },
        template='plotly_dark',
        height=800,
        showlegend=True,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        ),
        hovermode='x unified',
        # Candlestick traces add a range slider by default; hide it.
        xaxis_rangeslider_visible=False
    )

    fig.update_yaxes(title_text="价格 (¥)", row=1, col=1)
    fig.update_yaxes(title_text="成交量", row=2, col=1)
    fig.update_xaxes(title_text="日期", row=2, col=1)

    return fig
|
|
|
|
|
def _sample_next_token(logits, temperature, top_p):
    """Sample one token id per row, using temperature scaling and nucleus (top-p) filtering.

    Args:
        logits: tensor of shape ``(batch, vocab)``.
        temperature: softmax temperature. It is clamped to a small positive
            value so that 0 does not divide by zero.
        top_p: keep only the smallest set of most-likely tokens whose
            cumulative probability reaches ``top_p``. Values outside
            ``(0, 1)``, or ``None``, disable the filter.

    Returns:
        LongTensor of shape ``(batch, 1)``.
    """
    probs = torch.softmax(logits / max(float(temperature), 1e-5), dim=-1)
    if top_p is None or not 0.0 < float(top_p) < 1.0:
        return torch.multinomial(probs, num_samples=1)

    sorted_probs, sorted_ids = torch.sort(probs, dim=-1, descending=True)
    # A token is cut once the probability mass strictly before it reaches
    # top_p. The most likely token has zero mass before it, so it always
    # survives. multinomial accepts unnormalised weights, so no renormalising
    # is needed.
    mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
    sorted_probs = sorted_probs.masked_fill(mass_before >= float(top_p), 0.0)
    picked = torch.multinomial(sorted_probs, num_samples=1)
    return torch.gather(sorted_ids, -1, picked)


def predict_with_kronos_model(df, pred_days, temperature=1.0, top_p=0.9):
    """Autoregressively forecast ``pred_days`` OHLCV bars with the loaded Kronos model.

    Steps:
    1. Take the last (up to) 400 rows of ``open, high, low, close, volume``
       and z-score them per column.
    2. Tokenise them.
    3. Extend the sequence one token per step, sampling from the model's
       last-position logits with temperature and top-p.
    4. Decode the new tokens and map them back to price/volume units.

    NOTE(review): the call conventions below are taken from this file and
    have not been verified against the installed ``kronos`` package:
    - ``tokenizer(x)`` returns a 4-tuple whose last item is the token ids;
    - ``model(ids)`` returns ``(batch, seq, vocab)`` logits;
    - ``tokenizer.decode(ids)`` returns features whose last axis has size 5.

    Args:
        df: frame containing ``open, high, low, close, volume`` columns.
        pred_days: number of tokens to generate; must be >= 1.
        temperature: softmax temperature for sampling.
        top_p: nucleus-sampling threshold in (0, 1); 1.0 disables it.

    Returns:
        numpy.ndarray of shape ``(rows, 5)``, columns in
        ``open, high, low, close, volume`` order.

    Raises:
        Exception: ``"Kronos 模型未加载"`` when the model globals are unset.
        ValueError: if ``pred_days < 1``.
        Any error raised inside the model pipeline is printed and re-raised.
    """
    if kronos_model is None or kronos_tokenizer is None:
        raise Exception("Kronos 模型未加载")

    pred_days = int(pred_days)
    if pred_days < 1:
        raise ValueError("pred_days must be >= 1")

    feature_cols = ['open', 'high', 'low', 'close', 'volume']
    try:
        lookback = min(400, len(df))
        input_data = df.iloc[-lookback:][feature_cols].to_numpy(dtype=np.float64)

        # Per-column z-score. The epsilon keeps flat columns (e.g. constant
        # volume) finite.
        mean = input_data.mean(axis=0)
        std = input_data.std(axis=0) + 1e-8
        input_tensor = torch.as_tensor(
            (input_data - mean) / std, dtype=torch.float32
        ).unsqueeze(0).to(device)

        with torch.no_grad():
            _, _, _, current_input = kronos_tokenizer(input_tensor)

            predictions = []
            for _ in range(pred_days):
                logits = kronos_model(current_input)
                next_token = _sample_next_token(logits[:, -1, :], temperature, top_p)
                predictions.append(next_token)
                current_input = torch.cat([current_input, next_token], dim=1)

            pred_output = kronos_tokenizer.decode(torch.cat(predictions, dim=1))

        # Drop the batch axis: callers index the result as a 2-D
        # (rows, features) array.
        decoded = pred_output.detach().cpu().numpy().reshape(-1, len(feature_cols))
        return decoded * std + mean

    except Exception as e:
        print(f"Kronos 模型预测失败: {str(e)}")
        raise
|
|
|
|
|
def _fallback_forecast(historical_df, pred_days, temperature):
    """Simulate ``pred_days`` OHLCV rows when the Kronos model cannot be used.

    How each bar is built:
    - Closes follow a geometric random walk. The daily return volatility is
      the historical close-to-close volatility, falling back to 2% when the
      history is too short or flat, scaled by ``temperature``.
    - Each bar opens at the previous close.
    - High and low bracket the open/close body, so every bar is a valid
      candle with positive prices.
    - Volume is drawn around the mean of the last 20 bars.

    Returns:
        numpy.ndarray of shape ``(pred_days, 5)``, columns in
        ``open, high, low, close, volume`` order.
    """
    closes = historical_df['close'].astype(float)
    returns = closes.pct_change().replace([np.inf, -np.inf], np.nan).dropna()
    daily_vol = float(returns.std()) if len(returns) > 1 else float('nan')
    if not np.isfinite(daily_vol) or daily_vol <= 0:
        daily_vol = 0.02
    sigma = daily_vol * float(temperature)

    recent_volume = float(historical_df['volume'].tail(20).mean())
    if not np.isfinite(recent_volume) or recent_volume <= 0:
        recent_volume = 5_000_000.0

    rows = []
    prev_close = float(closes.iloc[-1])
    for _ in range(pred_days):
        open_price = prev_close
        close_price = open_price * float(np.exp(np.random.randn() * sigma))
        high_price = max(open_price, close_price) * (1 + abs(np.random.randn()) * sigma / 2)
        # Cap the lower wick at 99% of the body bottom so lows stay positive.
        low_price = min(open_price, close_price) * (1 - min(abs(np.random.randn()) * sigma / 2, 0.99))
        volume = recent_volume * float(np.exp(np.random.randn() * 0.3))
        rows.append([open_price, high_price, low_price, close_price, volume])
        prev_close = close_price
    return np.array(rows, dtype=float).reshape(-1, 5)


def _build_summary(stock_info, symbol, data_source, model_used, pred_days, last_close, pred_close):
    """Render the Markdown report shown next to the chart."""
    pred_change = (pred_close[-1] - last_close) / last_close * 100
    upside = (pred_close.max() - last_close) / last_close * 100
    downside = (pred_close.min() - last_close) / last_close * 100
    return f"""
## 📊 预测结果

**股票**: {stock_info['name_ja']} ({symbol})

**数据来源**: {data_source}

**模型**: {model_used}

**预测天数**: {pred_days} 天

### 💰 价格分析

- **当前价格**: ¥{last_close:.2f}
- **预测价格**: ¥{pred_close[-1]:.2f}
- **预测变化**: {pred_change:+.2f}%

### 📈 趋势分析

- **最高预测**: ¥{pred_close.max():.2f}
- **最低预测**: ¥{pred_close.min():.2f}
- **平均预测**: ¥{pred_close.mean():.2f}
- **波动率**: {pred_close.std():.2f}

### 📉 风险评估

- **上涨空间**: {upside:+.2f}%
- **下跌风险**: {downside:+.2f}%

---

⚠️ **免责声明**: 预测结果仅供参考,不构成投资建议。投资有风险,入市需谨慎。
"""


def predict_stock(symbol, pred_days, temperature, top_p):
    """Run one forecast for ``symbol`` and return ``(figure, markdown_summary)``.

    Data source: Yahoo Finance via ``fetch_stock_data``, with
    ``generate_sample_data`` as the fallback.

    Model: the last (up to) 400 bars feed the Kronos model when it is loaded.
    The volatility-scaled random walk in ``_fallback_forecast`` is used
    instead when the model:
    - is not loaded,
    - raises, or
    - returns anything other than a finite ``(pred_days, 5)`` array.

    Forecast bars are placed on business days after the last historical bar.

    Returns:
        tuple: ``(plotly Figure, str)`` on success, or ``(None, error_message)``
        if anything outside the model call fails.
    """
    try:
        # Gradio sliders may deliver floats; range()/date_range need ints.
        pred_days = int(pred_days)
        if pred_days < 1:
            raise ValueError("pred_days must be >= 1")

        stock_info = POPULAR_STOCKS.get(symbol, {'name': symbol, 'name_ja': symbol})

        df = fetch_stock_data(symbol, period='2y')
        if df is None:
            print("⚠️ 使用示例数据")
            df = generate_sample_data()
            data_source = "示例数据"
        else:
            data_source = "Yahoo Finance 实时数据"

        lookback = min(400, len(df))
        historical_df = df.iloc[-lookback:].copy()
        last_date = historical_df['date'].iloc[-1]
        last_close = float(historical_df['close'].iloc[-1])

        try:
            if kronos_model is None or kronos_tokenizer is None:
                raise Exception("模型未加载")
            pred_data = np.asarray(
                predict_with_kronos_model(historical_df, pred_days, temperature, top_p),
                dtype=float,
            )
            # Reject outputs that the frame and summary code below cannot
            # consume.
            if pred_data.shape != (pred_days, 5) or not np.isfinite(pred_data).all():
                raise ValueError(f"unexpected model output shape {pred_data.shape} or non-finite values")
            model_used = "Kronos-base (完整模型)"
        except Exception as e:
            print(f"⚠️ 使用简化预测: {str(e)}")
            pred_data = _fallback_forecast(historical_df, pred_days, temperature)
            model_used = "简化算法"

        # One forecast bar per trading day: weekends are skipped.
        # Exchange holidays are not modelled.
        future_dates = pd.date_range(start=last_date + timedelta(days=1), periods=pred_days, freq='B')
        pred_df = pd.DataFrame({
            'date': future_dates,
            'open': pred_data[:, 0],
            'high': pred_data[:, 1],
            'low': pred_data[:, 2],
            'close': pred_data[:, 3],
            'volume': np.clip(pred_data[:, 4], 0, None).round().astype(int),
        })

        chart = create_advanced_chart(historical_df, pred_df)
        summary = _build_summary(
            stock_info, symbol, data_source, model_used,
            pred_days, last_close, pred_df['close'].to_numpy(),
        )
        return chart, summary

    except Exception as e:
        error_msg = f"❌ 预测失败: {str(e)}\n\n请检查网络连接或稍后重试。"
        print(error_msg)
        import traceback
        traceback.print_exc()
        return None, error_msg
|
|
|
|
|
|
|
|
# --- Application start-up ---------------------------------------------------
# Runs at import time: the model is loaded once, and the returned status string
# is displayed in the UI below. This must therefore run before the Blocks
# context is built.
print("🚀 初始化 Kronos 系统...")
model_status = load_kronos_model()

# --- UI layout ----------------------------------------------------------------
# Components appear in the order they are created inside the `with` contexts,
# so the statements below are order-sensitive. The `css` string restyles the
# container font and primary buttons.
with gr.Blocks(theme=gr.themes.Soft(), title="Kronos 日本株AI予測", css="""


.gradio-container {


font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;


}


.gr-button-primary {


background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);


border: none;


}


""") as demo:
    # Page header.
    gr.Markdown("""


# 📈 Kronos 日本株AI予測システム





使用 **NeoQuasar/Kronos-base** 模型预测日本股票价格走势





🤖 **模型**: Kronos-base (102.3M 参数,专为金融K线预测设计)


📡 **数据**: Yahoo Finance 实时数据


🎓 **论文**: AAAI 2026


""")

    with gr.Row():
        # Left column: stock picker, forecast horizon, sampling knobs,
        # run button and model status.
        with gr.Column(scale=1):
            gr.Markdown("### 📊 选择股票")

            # Each choice is a (label, value) tuple. The label shows the
            # Japanese name plus ticker; the value is the ticker symbol itself.
            stock_dropdown = gr.Dropdown(
                choices=[(f"{info['name_ja']} ({code})", code) for code, info in POPULAR_STOCKS.items()],
                value='7203.T',
                label="股票代码",
                interactive=True,
                info="选择要预测的日本股票"
            )

            # Forecast horizon; passed to predict_stock as `pred_days`.
            pred_days = gr.Slider(
                minimum=7,
                maximum=90,
                value=30,
                step=1,
                label="📅 预测天数",
                info="预测未来的天数"
            )

            gr.Markdown("### ⚙️ 预测参数")

            # Sampling temperature; passed to predict_stock as `temperature`.
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=1.0,
                step=0.1,
                label="🌡️ 温度 (Temperature)",
                info="控制预测的随机性,值越高越随机"
            )

            # Nucleus-sampling threshold; passed to predict_stock as `top_p`.
            top_p = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                step=0.05,
                label="🎯 Top-p",
                info="控制预测的多样性"
            )

            predict_btn = gr.Button("🚀 开始预测", variant="primary", size="lg")

            gr.Markdown("### 📦 模型状态")
            # Read-only display of the string returned by load_kronos_model().
            status_box = gr.Textbox(
                value=model_status,
                label="系统状态",
                interactive=False,
                lines=2
            )

        # Right column: the two outputs filled in by predict_stock.
        with gr.Column(scale=2):
            gr.Markdown("### 📈 预测结果")

            chart_output = gr.Plot(label="价格走势图")
            summary_output = gr.Markdown(label="预测摘要")

    # Clicking the button calls predict_stock(symbol, pred_days, temperature,
    # top_p). Its (chart, summary) return value fills the outputs in order.
    predict_btn.click(
        fn=predict_stock,
        inputs=[stock_dropdown, pred_days, temperature, top_p],
        outputs=[chart_output, summary_output]
    )

    # Footer: model background, references and disclaimer.
    gr.Markdown("""


---


### 📝 关于 Kronos 模型





**Kronos** 是首个面向金融K线图的开源基础模型,基于全球超过45家交易所的数据训练而成。





#### 🔬 技术特点


- **专用分词器**: 将连续的多维K线数据(OHLCV)量化为分层离散令牌


- **自回归Transformer**: 基于令牌预训练的大型模型


- **两阶段框架**: Tokenizer + Predictor





#### 📚 学术信息


- **开发者**: NeoQuasar (shiyu-coder)


- **论文**: [arXiv:2508.02739](https://arxiv.org/abs/2508.02739)


- **会议**: AAAI 2026


- **模型**: [Hugging Face](https://huggingface.co/NeoQuasar/Kronos-base)





#### 🔧 技术栈


- **模型**: NeoQuasar/Kronos-base (102.3M 参数)


- **框架**: Gradio + PyTorch


- **数据**: Yahoo Finance API


- **可视化**: Plotly (交互式K线图)





#### ⚠️ 免责声明


本系统仅供学习和研究使用,预测结果不构成投资建议。投资有风险,入市需谨慎。


""")

# Running the file directly starts the Gradio server. Importing it only
# builds `demo` (and loads the model, above).
if __name__ == "__main__":
    demo.launch()
|
|
|