kronos-jp / app.py
superyan's picture
Upload 14 files
ca2f5c8 verified
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Kronos 日本股市AI预测系统 - 完整版
使用真正的 Kronos 模型进行预测
"""
import gradio as gr
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from datetime import datetime, timedelta
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
import os
# 导入 Kronos 模型
from kronos import Kronos, KronosTokenizer
# Curated list of Japanese tickers offered in the UI dropdown, keyed by the
# Yahoo Finance symbol that fetch_stock_data() passes to yfinance.
# Every entry carries an English name, a Japanese display name and a sector.
POPULAR_STOCKS = {
    code: {'name': name_en, 'name_ja': name_ja, 'sector': sector}
    for code, name_en, name_ja, sector in (
        ('7203.T', 'Toyota Motor Corp', 'トヨタ自動車', 'Automobile'),
        ('6758.T', 'Sony Group Corp', 'ソニーグループ', 'Technology'),
        ('8306.T', 'MUFG', '三菱UFJフィナンシャルグループ', 'Finance'),
        ('8035.T', 'Tokyo Electron', '東京エレクトロン', 'Technology'),
        ('7201.T', 'Nissan Motor Co', '日産自動車', 'Automobile'),
        ('7267.T', 'Honda Motor Co', '本田技研工業', 'Automobile'),
        ('^N225', 'Nikkei 225', '日経平均株価', 'Index'),
    )
}

# Module-level model state. load_kronos_model() fills these in; None means
# "not loaded", which makes the prediction code fall back to a simple method.
kronos_model = kronos_tokenizer = device = None
def load_kronos_model():
    """Load the Kronos tokenizer and model onto CUDA if available, else CPU.

    Side effects: sets the module-level ``device``, ``kronos_tokenizer`` and
    ``kronos_model`` globals (tokenizer first, then model; each is moved to
    ``device`` and switched to eval mode).

    Any exception during loading is caught and reported, so the app can
    still start and use its simplified predictor instead.

    Returns:
        str: a status message for the UI, either a success line naming the
        device or a failure line containing the exception text.
    """
    global kronos_model, kronos_tokenizer, device
    try:
        # Choose the device before loading so both components land on it.
        cuda_ok = torch.cuda.is_available()
        device = torch.device('cuda' if cuda_ok else 'cpu')
        print("🖥️ 使用 GPU (CUDA)" if cuda_ok else "🖥️ 使用 CPU")

        print("📦 加载 Kronos Tokenizer...")
        kronos_tokenizer = KronosTokenizer.from_pretrained('NeoQuasar/Kronos-Tokenizer-base')
        kronos_tokenizer.to(device)
        kronos_tokenizer.eval()

        print("📦 加载 Kronos 模型...")
        kronos_model = Kronos.from_pretrained('NeoQuasar/Kronos-base')
        kronos_model.to(device)
        kronos_model.eval()

        print("✅ Kronos 模型加载成功")
        return f"✅ Kronos 模型加载成功 (设备: {device})"
    except Exception as exc:
        reason = str(exc)
        print(f"⚠️ Kronos 模型加载失败: {reason}")
        print("💡 将使用简化的预测方法")
        return f"⚠️ 模型加载失败: {reason}"
def fetch_stock_data(symbol, period='2y'):
    """Fetch OHLCV history for *symbol* from Yahoo Finance (best effort).

    Args:
        symbol: Yahoo Finance ticker symbol, e.g. ``'7203.T'`` or ``'^N225'``.
        period: History window passed to ``yfinance.Ticker.history``
            (default ``'2y'``).

    Returns:
        A DataFrame with exactly the columns ``date, open, high, low, close,
        volume``. It has a fresh ``0..n-1`` index, and rows with any missing
        OHLCV value are removed.

        Returns ``None`` in any of these cases:
        - yfinance cannot be imported;
        - the request raises;
        - the expected columns are absent;
        - no complete rows remain.

        Failures are printed, never raised.
    """
    try:
        import yfinance as yf  # optional dependency, imported lazily
        print(f"📥 从 Yahoo Finance 获取 {symbol} 数据...")
        df = yf.Ticker(symbol).history(period=period)
        if df is None or df.empty:
            return None
        df = df.reset_index()
        # str() guards against non-string column labels.
        df.columns = [str(col).lower() for col in df.columns]
        # reset_index() turns the index into a column whose name depends on
        # the index; accept common aliases for the date column.
        for alias in ('datetime', 'index'):
            if 'date' not in df.columns and alias in df.columns:
                df = df.rename(columns={alias: 'date'})
        required_cols = ['date', 'open', 'high', 'low', 'close', 'volume']
        if not all(col in df.columns for col in required_cols):
            return None
        # Incomplete quotes (NaN) would turn the mean/std normalization and
        # the summary statistics into NaN downstream, so drop those rows.
        df = df[required_cols].dropna().reset_index(drop=True)
        if df.empty:
            return None
        print(f"✅ 成功获取 {len(df)} 行数据")
        return df
    except Exception as e:
        print(f"❌ Yahoo Finance 获取失败: {str(e)}")
        return None
def generate_sample_data(periods=520, base_price=100, seed=None):
    """Generate a synthetic daily OHLCV series for use when real data is unavailable.

    Open prices follow a Gaussian random walk (step std 2) that starts at
    *base_price*. Close is the open plus small noise. The high/low wicks
    extend beyond the candle body, so ``high >= max(open, close)`` and
    ``low <= min(open, close)`` always hold.

    Args:
        periods: Number of daily rows, ending now (default 520).
        base_price: Open price of the first row (default 100).
        seed: Optional seed for a reproducible series. ``None`` draws from
            NumPy's global random state.

    Returns:
        DataFrame with columns ``date, open, high, low, close, volume``.

    Raises:
        ValueError: if *periods* is less than 1.
    """
    periods = int(periods)
    if periods < 1:
        raise ValueError(f"periods must be >= 1, got {periods}")
    rng = np.random if seed is None else np.random.RandomState(seed)

    dates = pd.date_range(end=datetime.now(), periods=periods, freq='D')
    steps = rng.randn(periods - 1) * 2
    opens = base_price + np.concatenate(([0.0], np.cumsum(steps)))
    closes = opens + rng.randn(periods) * 0.5
    # Build wicks around the body so every candle is internally consistent.
    highs = np.maximum(opens, closes) + np.abs(rng.randn(periods))
    lows = np.minimum(opens, closes) - np.abs(rng.randn(periods))
    volumes = rng.randint(1000000, 10000000, periods)

    return pd.DataFrame({
        'date': dates,
        'open': opens,
        'high': highs,
        'low': lows,
        'close': closes,
        'volume': volumes,
    })
def _body_colors(frame, up_color, down_color):
    """Return one colour per row: *up_color* where close >= open, else *down_color*."""
    rising = frame['close'].to_numpy() >= frame['open'].to_numpy()
    return np.where(rising, up_color, down_color).tolist()


def create_advanced_chart(historical_df, pred_df=None):
    """Build a two-row Plotly figure: candlesticks (plus forecast and MAs) above volume bars.

    Args:
        historical_df: DataFrame with ``date, open, high, low, close, volume``.
        pred_df: Optional DataFrame with the same columns holding forecast
            candles. If present and non-empty, they are drawn with translucent
            fills. When there is also historical data, a dashed connector
            joins the last historical close to the first forecast open.

    Returns:
        A ``plotly.graph_objects.Figure`` (dark template, 800 px high).
        MA20/MA60 overlays are added only when the history has at least
        20/60 rows.
    """
    has_prediction = pred_df is not None and len(pred_df) > 0

    # Row 1: price (70% of the height); row 2: volume (30%). The rows share the x axis.
    fig = make_subplots(
        rows=2, cols=1,
        shared_xaxes=True,
        vertical_spacing=0.03,
        row_heights=[0.7, 0.3],
        subplot_titles=('价格走势', '成交量')
    )

    # === Historical candlesticks ===
    fig.add_trace(
        go.Candlestick(
            x=historical_df['date'],
            open=historical_df['open'],
            high=historical_df['high'],
            low=historical_df['low'],
            close=historical_df['close'],
            name='历史数据',
            increasing_line_color='#26A69A',
            decreasing_line_color='#EF5350',
            increasing_fillcolor='#26A69A',
            decreasing_fillcolor='#EF5350'
        ),
        row=1, col=1
    )

    # === Forecast candlesticks ===
    if has_prediction:
        fig.add_trace(
            go.Candlestick(
                x=pred_df['date'],
                open=pred_df['open'],
                high=pred_df['high'],
                low=pred_df['low'],
                close=pred_df['close'],
                name='预测数据',
                increasing_line_color='#4CAF50',
                decreasing_line_color='#FF5252',
                increasing_fillcolor='rgba(76, 175, 80, 0.3)',
                decreasing_fillcolor='rgba(255, 82, 82, 0.3)',
                opacity=0.8
            ),
            row=1, col=1
        )
        # Dashed connector marking where the forecast starts; it needs at
        # least one historical row to anchor on.
        if len(historical_df) > 0:
            last_historical = historical_df.iloc[-1]
            first_prediction = pred_df.iloc[0]
            fig.add_trace(
                go.Scatter(
                    x=[last_historical['date'], first_prediction['date']],
                    y=[last_historical['close'], first_prediction['open']],
                    mode='lines',
                    line=dict(color='yellow', width=2, dash='dash'),
                    name='预测起点',
                    showlegend=True
                ),
                row=1, col=1
            )

    # === Historical volume, coloured by candle direction ===
    fig.add_trace(
        go.Bar(
            x=historical_df['date'],
            y=historical_df['volume'],
            name='历史成交量',
            marker_color=_body_colors(historical_df, '#26A69A', '#EF5350'),
            opacity=0.7
        ),
        row=2, col=1
    )

    # === Forecast volume ===
    if has_prediction:
        fig.add_trace(
            go.Bar(
                x=pred_df['date'],
                y=pred_df['volume'],
                name='预测成交量',
                marker_color=_body_colors(pred_df, '#4CAF50', '#FF5252'),
                opacity=0.5
            ),
            row=2, col=1
        )

    # === Moving averages over the historical closes (only with enough rows) ===
    for window, color in ((20, 'orange'), (60, 'purple')):
        if len(historical_df) >= window:
            fig.add_trace(
                go.Scatter(
                    x=historical_df['date'],
                    y=historical_df['close'].rolling(window=window).mean(),
                    mode='lines',
                    name=f'MA{window}',
                    line=dict(color=color, width=1.5),
                    opacity=0.7
                ),
                row=1, col=1
            )

    # === Layout ===
    fig.update_layout(
        title={
            'text': '📈 Kronos AI 股价预测分析',
            'x': 0.5,
            'xanchor': 'center',
            'font': {'size': 20, 'color': '#FFFFFF'}
        },
        template='plotly_dark',
        height=800,
        showlegend=True,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        ),
        hovermode='x unified',
        # Hide the range slider that candlestick traces add to the top x axis.
        xaxis_rangeslider_visible=False
    )

    # Axis titles.
    fig.update_yaxes(title_text="价格 (¥)", row=1, col=1)
    fig.update_yaxes(title_text="成交量", row=2, col=1)
    fig.update_xaxes(title_text="日期", row=2, col=1)
    return fig
def _sample_next_token(logits, temperature=1.0, top_p=1.0):
    """Draw one token id per row of *logits* using temperature plus nucleus (top-p) sampling.

    Args:
        logits: 2-D tensor of unnormalized scores; the last dimension is the
            vocabulary.
        temperature: Softmax temperature. It is clamped to a tiny positive
            floor to avoid division by zero.
        top_p: Cumulative-probability cutoff. Only values strictly inside
            (0, 1) enable the nucleus filter; otherwise plain sampling is used.

    Returns:
        Tensor of token ids with shape ``(rows, 1)``.
    """
    temperature = max(float(temperature), 1e-5)
    probs = torch.softmax(logits / temperature, dim=-1)
    top_p = 1.0 if top_p is None else float(top_p)
    if not 0.0 < top_p < 1.0:
        return torch.multinomial(probs, num_samples=1)
    sorted_probs, sorted_ids = torch.sort(probs, dim=-1, descending=True)
    # Probability mass of all strictly more likely tokens. A token is dropped
    # once that mass already reaches top_p, so the top token is always kept.
    mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
    sorted_probs = sorted_probs.masked_fill(mass_before >= top_p, 0.0)
    # multinomial accepts unnormalized non-negative weights.
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return torch.gather(sorted_ids, -1, choice)


def predict_with_kronos_model(df, pred_days, temperature=1.0, top_p=0.9):
    """Forecast *pred_days* future steps autoregressively with the loaded Kronos model.

    Steps:
    1. Take the last (up to) 400 OHLCV rows of *df* and z-score normalize
       them per column.
    2. Encode them with ``kronos_tokenizer``.
    3. Extend the sequence one token at a time with ``kronos_model``, using
       temperature plus top-p sampling.
    4. Decode the tokens and map the result back to the original scale.

    NOTE(review): the tokenizer/model call signatures used here are not
    visible from this file; confirm them against the ``kronos`` package:
    - ``kronos_tokenizer(x)`` returning a 4-tuple whose last item is the
      token ids;
    - ``kronos_model(ids)`` returning logits shaped ``(batch, seq, vocab)``;
    - ``kronos_tokenizer.decode(ids)``.

    Args:
        df: DataFrame with at least ``open, high, low, close, volume`` columns.
        pred_days: Number of future steps to generate (>= 1).
        temperature: Sampling temperature (> 0).
        top_p: Nucleus sampling cutoff in (0, 1].

    Returns:
        ``numpy.ndarray`` with 5 columns in OHLCV order; any leading batch
        dimension is flattened away.

    Raises:
        Exception: if the model or tokenizer is not loaded.
        ValueError: if *pred_days* < 1.
        Any error raised while running the model is printed and re-raised.
    """
    if kronos_model is None or kronos_tokenizer is None:
        raise Exception("Kronos 模型未加载")
    pred_days = int(pred_days)
    if pred_days < 1:
        raise ValueError(f"pred_days must be >= 1, got {pred_days}")
    ohlcv_cols = ['open', 'high', 'low', 'close', 'volume']
    try:
        lookback = min(400, len(df))
        input_data = df.iloc[-lookback:][ohlcv_cols].to_numpy(dtype=float)
        # Per-column z-score; the epsilon guards against constant columns.
        mean = input_data.mean(axis=0)
        std = input_data.std(axis=0) + 1e-8
        input_normalized = (input_data - mean) / std
        input_tensor = torch.as_tensor(
            input_normalized, dtype=torch.float32, device=device
        ).unsqueeze(0)

        with torch.no_grad():
            _, _, _, indices = kronos_tokenizer(input_tensor)
            predictions = []
            current_input = indices
            for _ in range(pred_days):
                output = kronos_model(current_input)
                next_token = _sample_next_token(output[:, -1, :], temperature, top_p)
                predictions.append(next_token)
                current_input = torch.cat([current_input, next_token], dim=1)
            pred_indices = torch.cat(predictions, dim=1)
            pred_output = kronos_tokenizer.decode(pred_indices)

        pred_denorm = pred_output.cpu().numpy() * std + mean
        # Flatten any leading batch dimension so callers can index pred[:, j].
        return pred_denorm.reshape(-1, len(ohlcv_cols))
    except Exception as e:
        print(f"Kronos 模型预测失败: {str(e)}")
        raise
def _random_walk_forecast(last_close, pred_days, temperature):
    """Fallback forecast: a Gaussian random walk from *last_close*, scaled by *temperature*.

    Returns:
        ``(pred_days, 5)`` float array in OHLCV order; volume is random.
    """
    rows = []
    current = last_close
    for _ in range(pred_days):
        current = current + np.random.randn() * 2 * temperature
        volatility = abs(np.random.randn()) * temperature
        rows.append([
            current,                               # open
            current + volatility,                  # high
            current - volatility,                  # low
            current + np.random.randn() * 0.5,     # close
            np.random.randint(1000000, 10000000),  # volume
        ])
    return np.array(rows, dtype=float).reshape(-1, 5)


def _build_prediction_frame(pred_data, last_date, pred_days):
    """Turn an OHLCV forecast array into a dated DataFrame of valid candles.

    Dates are consecutive business days (Mon-Fri) after *last_date*, so each
    forecast step lines up with a trading day rather than a weekend.

    Neither the sampler nor the random-walk fallback guarantees a consistent
    candle, so the values are cleaned up:
    - high/low are widened to enclose open and close;
    - missing volume becomes 0 and negative volume is clipped at 0.

    At most *pred_days* rows are kept.
    """
    pred_data = np.asarray(pred_data, dtype=float).reshape(-1, 5)[:pred_days]
    opens, highs, lows, closes, volumes = pred_data.T
    future_dates = pd.date_range(
        start=last_date + timedelta(days=1), periods=len(pred_data), freq='B'
    )
    return pd.DataFrame({
        'date': future_dates,
        'open': opens,
        'high': np.maximum.reduce([opens, highs, lows, closes]),
        'low': np.minimum.reduce([opens, highs, lows, closes]),
        'close': closes,
        'volume': np.clip(np.nan_to_num(volumes), 0, None).astype(int),
    })


def predict_stock(symbol, pred_days, temperature, top_p):
    """Gradio callback: fetch data, forecast *symbol* and render chart plus summary.

    Data source and model selection:
    - Uses Yahoo Finance data when available, otherwise synthetic sample data.
    - Uses the Kronos model when it is loaded and succeeds, otherwise a
      random-walk fallback.

    Args:
        symbol: Yahoo Finance ticker symbol (keys of ``POPULAR_STOCKS`` are
            offered in the UI; other symbols are shown by their raw code).
        pred_days: Number of steps to forecast. Coerced to ``int``, because UI
            sliders may deliver numbers as floats.
        temperature: Sampling temperature forwarded to the model / fallback.
        top_p: Nucleus-sampling cutoff forwarded to the model.

    Returns:
        ``(figure, markdown_summary)`` on success. On any error it returns
        ``(None, error_message)`` and prints the traceback.
    """
    try:
        pred_days = int(pred_days)
        stock_info = POPULAR_STOCKS.get(symbol, {'name': symbol, 'name_ja': symbol})

        # Fetch real data; fall back to synthetic data if that fails.
        df = fetch_stock_data(symbol, period='2y')
        if df is None:
            print("⚠️ 使用示例数据")
            df = generate_sample_data()
            data_source = "示例数据"
        else:
            data_source = "Yahoo Finance 实时数据"

        # Use the most recent (up to) 400 rows as history.
        lookback = min(400, len(df))
        historical_df = df.iloc[-lookback:].copy()
        last_date = historical_df['date'].iloc[-1]
        last_close = historical_df['close'].iloc[-1]

        # Forecast with Kronos when available; any failure switches to the fallback.
        try:
            if kronos_model is None or kronos_tokenizer is None:
                raise Exception("模型未加载")
            pred_data = predict_with_kronos_model(historical_df, pred_days, temperature, top_p)
            model_used = "Kronos-base (完整模型)"
        except Exception as e:
            print(f"⚠️ 使用简化预测: {str(e)}")
            pred_data = _random_walk_forecast(last_close, pred_days, temperature)
            model_used = "简化算法"

        pred_df = _build_prediction_frame(pred_data, last_date, pred_days)
        chart = create_advanced_chart(historical_df, pred_df)

        # Summary statistics over the forecast closes.
        pred_close = pred_df['close'].values
        pred_change = ((pred_close[-1] - last_close) / last_close * 100)
        summary = f"""
## 📊 预测结果
**股票**: {stock_info['name_ja']} ({symbol})
**数据来源**: {data_source}
**模型**: {model_used}
**预测天数**: {pred_days}
### 💰 价格分析
- **当前价格**: ¥{last_close:.2f}
- **预测价格**: ¥{pred_close[-1]:.2f}
- **预测变化**: {pred_change:+.2f}%
### 📈 趋势分析
- **最高预测**: ¥{max(pred_close):.2f}
- **最低预测**: ¥{min(pred_close):.2f}
- **平均预测**: ¥{np.mean(pred_close):.2f}
- **波动率**: {np.std(pred_close):.2f}
### 📉 风险评估
- **上涨空间**: {((max(pred_close) - last_close) / last_close * 100):+.2f}%
- **下跌风险**: {((min(pred_close) - last_close) / last_close * 100):+.2f}%
---
⚠️ **免责声明**: 预测结果仅供参考,不构成投资建议。投资有风险,入市需谨慎。
"""
        return chart, summary
    except Exception as e:
        error_msg = f"❌ 预测失败: {str(e)}\n\n请检查网络连接或稍后重试。"
        print(error_msg)
        import traceback
        traceback.print_exc()
        return None, error_msg
# --- Initialization ------------------------------------------------------------
# Load the Kronos model once at module execution time (this also runs when the
# module is imported, not only under __main__). The returned status string is
# displayed in the "系统状态" textbox below.
print("🚀 初始化 Kronos 系统...")
model_status = load_kronos_model()
# --- Gradio interface ----------------------------------------------------------
# NOTE(review): the nesting below is reconstructed as the conventional layout
# (left column: inputs/status; right column: outputs; event wiring and footer
# at Blocks level). Confirm against the deployed app.
with gr.Blocks(theme=gr.themes.Soft(), title="Kronos 日本株AI予測", css="""
.gradio-container {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
.gr-button-primary {
background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
border: none;
}
""") as demo:
    # Header banner.
    gr.Markdown("""
# 📈 Kronos 日本株AI予測システム
使用 **NeoQuasar/Kronos-base** 模型预测日本股票价格走势
🤖 **模型**: Kronos-base (102.3M 参数,专为金融K线预测设计)
📡 **数据**: Yahoo Finance 实时数据
🎓 **论文**: AAAI 2026
""")
    with gr.Row():
        # Left column: stock selection, forecast horizon, sampling parameters
        # and the model-status box.
        with gr.Column(scale=1):
            gr.Markdown("### 📊 选择股票")
            # Dropdown labels show the Japanese name and code; the value passed
            # to the callback is the ticker code (default Toyota, '7203.T').
            stock_dropdown = gr.Dropdown(
                choices=[(f"{info['name_ja']} ({code})", code) for code, info in POPULAR_STOCKS.items()],
                value='7203.T',
                label="股票代码",
                interactive=True,
                info="选择要预测的日本股票"
            )
            # Forecast horizon in steps (7-90, default 30).
            pred_days = gr.Slider(
                minimum=7,
                maximum=90,
                value=30,
                step=1,
                label="📅 预测天数",
                info="预测未来的天数"
            )
            gr.Markdown("### ⚙️ 预测参数")
            # Sampling temperature forwarded to predict_stock (0.1-2.0).
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=1.0,
                step=0.1,
                label="🌡️ 温度 (Temperature)",
                info="控制预测的随机性,值越高越随机"
            )
            # Top-p value forwarded to predict_stock (0.1-1.0).
            top_p = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                step=0.05,
                label="🎯 Top-p",
                info="控制预测的多样性"
            )
            predict_btn = gr.Button("🚀 开始预测", variant="primary", size="lg")
            gr.Markdown("### 📦 模型状态")
            # Read-only display of the load_kronos_model() result.
            status_box = gr.Textbox(
                value=model_status,
                label="系统状态",
                interactive=False,
                lines=2
            )
        # Right column (twice as wide): chart and markdown summary outputs.
        with gr.Column(scale=2):
            gr.Markdown("### 📈 预测结果")
            chart_output = gr.Plot(label="价格走势图")
            summary_output = gr.Markdown(label="预测摘要")
    # Wire the button: predict_stock returns (figure, summary markdown).
    predict_btn.click(
        fn=predict_stock,
        inputs=[stock_dropdown, pred_days, temperature, top_p],
        outputs=[chart_output, summary_output]
    )
    # Footer: model background, references and disclaimer.
    gr.Markdown("""
---
### 📝 关于 Kronos 模型
**Kronos** 是首个面向金融K线图的开源基础模型,基于全球超过45家交易所的数据训练而成。
#### 🔬 技术特点
- **专用分词器**: 将连续的多维K线数据(OHLCV)量化为分层离散令牌
- **自回归Transformer**: 基于令牌预训练的大型模型
- **两阶段框架**: Tokenizer + Predictor
#### 📚 学术信息
- **开发者**: NeoQuasar (shiyu-coder)
- **论文**: [arXiv:2508.02739](https://arxiv.org/abs/2508.02739)
- **会议**: AAAI 2026
- **模型**: [Hugging Face](https://huggingface.co/NeoQuasar/Kronos-base)
#### 🔧 技术栈
- **模型**: NeoQuasar/Kronos-base (102.3M 参数)
- **框架**: Gradio + PyTorch
- **数据**: Yahoo Finance API
- **可视化**: Plotly (交互式K线图)
#### ⚠️ 免责声明
本系统仅供学习和研究使用,预测结果不构成投资建议。投资有风险,入市需谨慎。
""")
# Start the web server only when run as a script.
if __name__ == "__main__":
    demo.launch()