#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Kronos Japanese stock market AI forecasting system - full version.

Uses the Kronos model for prediction and falls back to a simple random
walk when the model cannot be loaded or fails at inference time.
"""

import gradio as gr
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from datetime import datetime, timedelta
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
import os

# Kronos model classes
from kronos import Kronos, KronosTokenizer

# Popular Japanese tickers offered in the dropdown
POPULAR_STOCKS = {
    '7203.T': {'name': 'Toyota Motor Corp', 'name_ja': 'トヨタ自動車', 'sector': 'Automobile'},
    '6758.T': {'name': 'Sony Group Corp', 'name_ja': 'ソニーグループ', 'sector': 'Technology'},
    '8306.T': {'name': 'MUFG', 'name_ja': '三菱UFJフィナンシャルグループ', 'sector': 'Finance'},
    '8035.T': {'name': 'Tokyo Electron', 'name_ja': '東京エレクトロン', 'sector': 'Technology'},
    '7201.T': {'name': 'Nissan Motor Co', 'name_ja': '日産自動車', 'sector': 'Automobile'},
    '7267.T': {'name': 'Honda Motor Co', 'name_ja': '本田技研工業', 'sector': 'Automobile'},
    '^N225': {'name': 'Nikkei 225', 'name_ja': '日経平均株価', 'sector': 'Index'},
}

# Column order used everywhere an OHLCV array is built or unpacked
OHLCV_COLUMNS = ['open', 'high', 'low', 'close', 'volume']

# Module-level state populated by load_kronos_model()
kronos_model = None
kronos_tokenizer = None
device = None


def load_kronos_model():
    """Load the Kronos tokenizer and model into the module-level globals.

    Picks CUDA when available, otherwise CPU. Any failure is caught so the
    app can still start and use the simplified forecast.

    Returns:
        A human-readable status string (success or failure message).
    """
    global kronos_model, kronos_tokenizer, device

    try:
        # Device selection
        if torch.cuda.is_available():
            device = torch.device('cuda')
            print("🖥️ 使用 GPU (CUDA)")
        else:
            device = torch.device('cpu')
            print("🖥️ 使用 CPU")

        print("📦 加载 Kronos Tokenizer...")
        kronos_tokenizer = KronosTokenizer.from_pretrained('NeoQuasar/Kronos-Tokenizer-base')
        kronos_tokenizer.to(device)
        kronos_tokenizer.eval()

        print("📦 加载 Kronos 模型...")
        kronos_model = Kronos.from_pretrained('NeoQuasar/Kronos-base')
        kronos_model.to(device)
        kronos_model.eval()

        print("✅ Kronos 模型加载成功")
        return f"✅ Kronos 模型加载成功 (设备: {device})"

    except Exception as e:
        # Never leave a half-loaded pair behind: predict_stock() requires both.
        kronos_model = None
        kronos_tokenizer = None
        print(f"⚠️ Kronos 模型加载失败: {str(e)}")
        print("💡 将使用简化的预测方法")
        return f"⚠️ 模型加载失败: {str(e)}"


def fetch_stock_data(symbol, period='2y'):
    """Download daily OHLCV history for ``symbol`` from Yahoo Finance.

    Args:
        symbol: Yahoo Finance ticker, e.g. ``'7203.T'``.
        period: History window passed to ``Ticker.history``.

    Returns:
        A DataFrame with columns ``date, open, high, low, close, volume``,
        or ``None`` when the download fails, is empty, or lacks a required
        column.
    """
    try:
        # Imported lazily so a missing yfinance falls back to sample data
        # instead of preventing the app from starting.
        import yfinance as yf

        print(f"📥 从 Yahoo Finance 获取 {symbol} 数据...")
        ticker = yf.Ticker(symbol)
        df = ticker.history(period=period)

        if df.empty:
            return None

        df = df.reset_index()
        df.columns = [str(col).lower() for col in df.columns]

        # Normalise the date column name
        if 'date' not in df.columns and 'index' in df.columns:
            df = df.rename(columns={'index': 'date'})

        required_cols = ['date'] + OHLCV_COLUMNS
        if not all(col in df.columns for col in required_cols):
            return None

        # Rows with missing prices would turn the normalisation mean/std
        # into NaN, so drop them up front.
        df = df[required_cols].dropna(subset=['open', 'high', 'low', 'close'])
        df = df.reset_index(drop=True)
        if df.empty:
            return None

        print(f"✅ 成功获取 {len(df)} 行数据")
        return df

    except Exception as e:
        print(f"❌ Yahoo Finance 获取失败: {str(e)}")
        return None


def generate_sample_data():
    """Generate 520 days of synthetic OHLCV data ending today.

    Opens follow a random walk starting at 100. High and low are built
    from the max/min of open and close so every candle is valid
    (``low <= open, close <= high``).

    Returns:
        A DataFrame with columns ``date, open, high, low, close, volume``.
    """
    n_rows = 520
    dates = pd.date_range(end=datetime.now(), periods=n_rows, freq='D')

    steps = np.random.randn(n_rows - 1) * 2
    opens = 100 + np.concatenate(([0.0], np.cumsum(steps)))
    closes = opens + np.random.randn(n_rows) * 0.5
    highs = np.maximum(opens, closes) + np.abs(np.random.randn(n_rows))
    lows = np.minimum(opens, closes) - np.abs(np.random.randn(n_rows))

    return pd.DataFrame({
        'date': dates,
        'open': opens,
        'high': highs,
        'low': lows,
        'close': closes,
        'volume': np.random.randint(1000000, 10000000, n_rows),
    })


def _candle_colors(df, up_color, down_color):
    """Return one colour per row: ``up_color`` where close >= open, else ``down_color``."""
    return np.where(df['close'] >= df['open'], up_color, down_color).tolist()


def create_advanced_chart(historical_df, pred_df=None):
    """Build a two-row Plotly figure: candlesticks on top, volume below.

    Args:
        historical_df: DataFrame with ``date, open, high, low, close, volume``.
        pred_df: Optional DataFrame with the same columns for the forecast;
            drawn with translucent candles and linked to the last
            historical close by a dashed line.

    Returns:
        The assembled ``plotly.graph_objects.Figure``.
    """
    has_prediction = pred_df is not None and len(pred_df) > 0

    # Sub-plots: candlesticks + volume
    fig = make_subplots(
        rows=2,
        cols=1,
        shared_xaxes=True,
        vertical_spacing=0.03,
        row_heights=[0.7, 0.3],
        subplot_titles=('价格走势', '成交量'),
    )

    # === Historical candlesticks ===
    fig.add_trace(
        go.Candlestick(
            x=historical_df['date'],
            open=historical_df['open'],
            high=historical_df['high'],
            low=historical_df['low'],
            close=historical_df['close'],
            name='历史数据',
            increasing_line_color='#26A69A',
            decreasing_line_color='#EF5350',
            increasing_fillcolor='#26A69A',
            decreasing_fillcolor='#EF5350',
        ),
        row=1, col=1,
    )

    # === Predicted candlesticks ===
    if has_prediction:
        fig.add_trace(
            go.Candlestick(
                x=pred_df['date'],
                open=pred_df['open'],
                high=pred_df['high'],
                low=pred_df['low'],
                close=pred_df['close'],
                name='预测数据',
                increasing_line_color='#4CAF50',
                decreasing_line_color='#FF5252',
                increasing_fillcolor='rgba(76, 175, 80, 0.3)',
                decreasing_fillcolor='rgba(255, 82, 82, 0.3)',
                opacity=0.8,
            ),
            row=1, col=1,
        )

        # Dashed connector marking where the forecast starts
        last_historical = historical_df.iloc[-1]
        first_prediction = pred_df.iloc[0]
        fig.add_trace(
            go.Scatter(
                x=[last_historical['date'], first_prediction['date']],
                y=[last_historical['close'], first_prediction['open']],
                mode='lines',
                line=dict(color='yellow', width=2, dash='dash'),
                name='预测起点',
                showlegend=True,
            ),
            row=1, col=1,
        )

    # === Historical volume ===
    fig.add_trace(
        go.Bar(
            x=historical_df['date'],
            y=historical_df['volume'],
            name='历史成交量',
            marker_color=_candle_colors(historical_df, '#26A69A', '#EF5350'),
            opacity=0.7,
        ),
        row=2, col=1,
    )

    # === Predicted volume ===
    if has_prediction:
        fig.add_trace(
            go.Bar(
                x=pred_df['date'],
                y=pred_df['volume'],
                name='预测成交量',
                marker_color=_candle_colors(pred_df, '#4CAF50', '#FF5252'),
                opacity=0.5,
            ),
            row=2, col=1,
        )

    # === Moving averages (only when enough history exists) ===
    for window, label, color in ((20, 'MA20', 'orange'), (60, 'MA60', 'purple')):
        if len(historical_df) >= window:
            moving_avg = historical_df['close'].rolling(window=window).mean()
            fig.add_trace(
                go.Scatter(
                    x=historical_df['date'],
                    y=moving_avg,
                    mode='lines',
                    name=label,
                    line=dict(color=color, width=1.5),
                    opacity=0.7,
                ),
                row=1, col=1,
            )

    # === Layout ===
    fig.update_layout(
        title={
            'text': '📈 Kronos AI 股价预测分析',
            'x': 0.5,
            'xanchor': 'center',
            'font': {'size': 20, 'color': '#FFFFFF'},
        },
        template='plotly_dark',
        height=800,
        showlegend=True,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1,
        ),
        hovermode='x unified',
        xaxis_rangeslider_visible=False,
    )

    # Axis labels
    fig.update_yaxes(title_text="价格 (¥)", row=1, col=1)
    fig.update_yaxes(title_text="成交量", row=2, col=1)
    fig.update_xaxes(title_text="日期", row=2, col=1)

    return fig


def _sample_next_token(logits, temperature=1.0, top_p=1.0):
    """Sample one token id per batch row using temperature and nucleus (top-p) filtering.

    Args:
        logits: Tensor whose last dimension is the vocabulary.
        temperature: Divisor applied to the logits before the softmax.
        top_p: Cumulative-probability cutoff; values >= 1.0 disable filtering.

    Returns:
        Tensor of sampled token ids with a trailing dimension of size 1.
    """
    probs = torch.softmax(logits / temperature, dim=-1)
    if top_p >= 1.0:
        return torch.multinomial(probs, num_samples=1)

    sorted_probs, sorted_ids = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Drop a token only when the mass *before* it already exceeds top_p,
    # so the most likely token always survives.
    outside_nucleus = (cumulative - sorted_probs) > top_p
    sorted_probs = sorted_probs.masked_fill(outside_nucleus, 0.0)
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)

    picked = torch.multinomial(sorted_probs, num_samples=1)
    return sorted_ids.gather(-1, picked)


def predict_with_kronos_model(df, pred_days, temperature=1.0, top_p=0.9):
    """Forecast ``pred_days`` OHLCV rows with the loaded Kronos model.

    The last (up to) 400 rows are z-score normalised per column, tokenised,
    extended autoregressively one token per day, decoded and de-normalised.

    Args:
        df: DataFrame containing ``open, high, low, close, volume`` columns.
        pred_days: Number of future steps to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling cutoff.

    Returns:
        ``numpy.ndarray`` of shape ``(steps, 5)`` in OHLCV column order.

    Raises:
        RuntimeError: If the model or tokenizer has not been loaded.
        Exception: Any inference error is logged and re-raised.
    """
    global kronos_model, kronos_tokenizer, device

    if kronos_model is None or kronos_tokenizer is None:
        raise RuntimeError("Kronos 模型未加载")

    try:
        # Prepare the input window (OHLCV)
        lookback = min(400, len(df))
        input_data = df.iloc[-lookback:][OHLCV_COLUMNS].to_numpy(dtype=float)

        # Per-column z-score; the epsilon guards against constant columns
        mean = input_data.mean(axis=0)
        std = input_data.std(axis=0) + 1e-8
        input_normalized = (input_data - mean) / std

        input_tensor = torch.tensor(input_normalized, dtype=torch.float32).unsqueeze(0).to(device)

        with torch.no_grad():
            # NOTE(review): assumes the tokenizer call returns a 4-tuple whose
            # last element is the token indices, and that the model returns
            # logits shaped (batch, seq, vocab) - confirm against the kronos
            # package in use.
            _, _, _, indices = kronos_tokenizer(input_tensor)

            predictions = []
            current_input = indices

            for _ in range(pred_days):
                output = kronos_model(current_input)
                next_token = _sample_next_token(output[:, -1, :], temperature, top_p)
                predictions.append(next_token)
                current_input = torch.cat([current_input, next_token], dim=1)

            # Decode the generated tokens back to normalised values
            pred_indices = torch.cat(predictions, dim=1)
            pred_output = kronos_tokenizer.decode(pred_indices)

        # De-normalise, then drop the batch dimension so callers can index
        # the result as a 2-D (steps, features) array.
        pred_denorm = pred_output.cpu().numpy() * std + mean
        return np.asarray(pred_denorm).reshape(-1, len(OHLCV_COLUMNS))

    except Exception as e:
        print(f"Kronos 模型预测失败: {str(e)}")
        raise


def _random_walk_forecast(last_close, pred_days, temperature):
    """Simplified fallback forecast: a random walk starting at ``last_close``.

    High/low are derived from the max/min of open and close so every
    generated candle is valid.

    Returns:
        ``numpy.ndarray`` of shape ``(pred_days, 5)`` in OHLCV column order.
    """
    rows = []
    current = last_close
    for _ in range(pred_days):
        current = current + np.random.randn() * 2 * temperature
        volatility = abs(np.random.randn()) * temperature
        close = current + np.random.randn() * 0.5
        rows.append([
            current,                               # open
            max(current, close) + volatility,      # high
            min(current, close) - volatility,      # low
            close,                                 # close
            np.random.randint(1000000, 10000000),  # volume
        ])
    return np.array(rows, dtype=float).reshape(-1, len(OHLCV_COLUMNS))


def predict_stock(symbol, pred_days, temperature, top_p):
    """Gradio callback: fetch data, forecast, and build the chart and summary.

    Args:
        symbol: Ticker selected in the dropdown.
        pred_days: Number of future trading days to forecast.
        temperature: Sampling temperature (also scales the fallback walk).
        top_p: Nucleus-sampling cutoff for the Kronos model.

    Returns:
        ``(figure, markdown_summary)`` on success, or ``(None, error_message)``
        when anything fails.
    """
    try:
        # A slider may deliver a float; range() and date_range need an int.
        pred_days = int(pred_days)
        stock_info = POPULAR_STOCKS.get(symbol, {'name': symbol, 'name_ja': symbol})

        # Live data, with synthetic data as a fallback
        df = fetch_stock_data(symbol, period='2y')
        if df is None:
            print("⚠️ 使用示例数据")
            df = generate_sample_data()
            data_source = "示例数据"
        else:
            data_source = "Yahoo Finance 实时数据"

        # Historical window shown on the chart and fed to the model
        lookback = min(400, len(df))
        historical_df = df.iloc[-lookback:].copy()

        last_date = historical_df['date'].iloc[-1]
        last_close = historical_df['close'].iloc[-1]

        # Forecast: real model first, random walk on any failure
        try:
            if kronos_model is None or kronos_tokenizer is None:
                raise RuntimeError("模型未加载")
            pred_data = predict_with_kronos_model(historical_df, pred_days, temperature, top_p)
            if pred_data.shape[0] != pred_days:
                # A mismatched row count cannot be aligned with the future dates.
                raise ValueError(f"unexpected forecast shape {pred_data.shape}")
            model_used = "Kronos-base (完整模型)"
        except Exception as e:
            print(f"⚠️ 使用简化预测: {str(e)}")
            pred_data = _random_walk_forecast(last_close, pred_days, temperature)
            model_used = "简化算法"

        # Business-day frequency keeps forecast bars off weekends, when the
        # market is closed (exchange holidays are not accounted for).
        future_dates = pd.date_range(start=last_date + timedelta(days=1), periods=pred_days, freq='B')

        # Model output may contain NaN or negative volume; clean before the int cast.
        pred_volume = np.clip(np.nan_to_num(pred_data[:, 4]), 0, None).astype(int)

        pred_df = pd.DataFrame({
            'date': future_dates,
            'open': pred_data[:, 0],
            'high': pred_data[:, 1],
            'low': pred_data[:, 2],
            'close': pred_data[:, 3],
            'volume': pred_volume,
        })

        chart = create_advanced_chart(historical_df, pred_df)

        # Summary statistics
        pred_close = pred_df['close'].values
        highest = max(pred_close)
        lowest = min(pred_close)
        pred_change = (pred_close[-1] - last_close) / last_close * 100
        upside = (highest - last_close) / last_close * 100
        downside = (lowest - last_close) / last_close * 100

        summary = "\n".join([
            "",
            "## 📊 预测结果",
            "",
            f"**股票**: {stock_info['name_ja']} ({symbol})",
            f"**数据来源**: {data_source}",
            f"**模型**: {model_used}",
            f"**预测天数**: {pred_days} 天",
            "",
            "### 💰 价格分析",
            f"- **当前价格**: ¥{last_close:.2f}",
            f"- **预测价格**: ¥{pred_close[-1]:.2f}",
            f"- **预测变化**: {pred_change:+.2f}%",
            "",
            "### 📈 趋势分析",
            f"- **最高预测**: ¥{highest:.2f}",
            f"- **最低预测**: ¥{lowest:.2f}",
            f"- **平均预测**: ¥{np.mean(pred_close):.2f}",
            f"- **波动率**: {np.std(pred_close):.2f}",
            "",
            "### 📉 风险评估",
            f"- **上涨空间**: {upside:+.2f}%",
            f"- **下跌风险**: {downside:+.2f}%",
            "",
            "---",
            "",
            "⚠️ **免责声明**: 预测结果仅供参考,不构成投资建议。投资有风险,入市需谨慎。",
            "",
        ])

        return chart, summary

    except Exception as e:
        error_msg = f"❌ 预测失败: {str(e)}\n\n请检查网络连接或稍后重试。"
        print(error_msg)
        import traceback
        traceback.print_exc()
        return None, error_msg


# Initialise: load the model once at start-up
print("🚀 初始化 Kronos 系统...")
model_status = load_kronos_model()

# Build the Gradio interface
with gr.Blocks(theme=gr.themes.Soft(), title="Kronos 日本株AI予測", css="""
    .gradio-container {
        font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
    }
    .gr-button-primary {
        background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
        border: none;
    }
""") as demo:

    gr.Markdown("""
    # 📈 Kronos 日本株AI予測システム

    使用 **NeoQuasar/Kronos-base** 模型预测日本股票价格走势

    🤖 **模型**: Kronos-base (102.3M 参数,专为金融K线预测设计)
    📡 **数据**: Yahoo Finance 实时数据
    🎓 **论文**: AAAI 2026
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 📊 选择股票")

            stock_dropdown = gr.Dropdown(
                choices=[(f"{info['name_ja']} ({code})", code) for code, info in POPULAR_STOCKS.items()],
                value='7203.T',
                label="股票代码",
                interactive=True,
                info="选择要预测的日本股票"
            )

            pred_days = gr.Slider(
                minimum=7,
                maximum=90,
                value=30,
                step=1,
                label="📅 预测天数",
                info="预测未来的天数"
            )

            gr.Markdown("### ⚙️ 预测参数")

            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=1.0,
                step=0.1,
                label="🌡️ 温度 (Temperature)",
                info="控制预测的随机性,值越高越随机"
            )

            top_p = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                step=0.05,
                label="🎯 Top-p",
                info="控制预测的多样性"
            )

            predict_btn = gr.Button("🚀 开始预测", variant="primary", size="lg")

            gr.Markdown("### 📦 模型状态")
            status_box = gr.Textbox(
                value=model_status,
                label="系统状态",
                interactive=False,
                lines=2
            )

        with gr.Column(scale=2):
            gr.Markdown("### 📈 预测结果")
            chart_output = gr.Plot(label="价格走势图")
            summary_output = gr.Markdown(label="预测摘要")

    # Wire the button to the prediction callback
    predict_btn.click(
        fn=predict_stock,
        inputs=[stock_dropdown, pred_days, temperature, top_p],
        outputs=[chart_output, summary_output]
    )

    gr.Markdown("""
    ---

    ### 📝 关于 Kronos 模型

    **Kronos** 是首个面向金融K线图的开源基础模型,基于全球超过45家交易所的数据训练而成。

    #### 🔬 技术特点
    - **专用分词器**: 将连续的多维K线数据(OHLCV)量化为分层离散令牌
    - **自回归Transformer**: 基于令牌预训练的大型模型
    - **两阶段框架**: Tokenizer + Predictor

    #### 📚 学术信息
    - **开发者**: NeoQuasar (shiyu-coder)
    - **论文**: [arXiv:2508.02739](https://arxiv.org/abs/2508.02739)
    - **会议**: AAAI 2026
    - **模型**: [Hugging Face](https://huggingface.co/NeoQuasar/Kronos-base)

    #### 🔧 技术栈
    - **模型**: NeoQuasar/Kronos-base (102.3M 参数)
    - **框架**: Gradio + PyTorch
    - **数据**: Yahoo Finance API
    - **可视化**: Plotly (交互式K线图)

    #### ⚠️ 免责声明
    本系统仅供学习和研究使用,预测结果不构成投资建议。投资有风险,入市需谨慎。
    """)


if __name__ == "__main__":
    demo.launch()