# Kronosdemo / model.py
# leehao163's picture
# Update model.py
# 703fe7b verified
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from huggingface_hub import hf_hub_download # 仅保留hf_hub_download
import os
# ------------- KronosTokenizer: tokenizer class -------------
class KronosTokenizer(nn.Module):
    """Quantize OHLCV values into integer token ids and embed them.

    Every input value is shifted and scaled per feature, rounded with
    ``torch.round``, clamped into ``[0, vocab_size - 1]``, cast to an
    integer id and looked up in an embedding table.
    """

    def __init__(self, vocab_size=1024, embed_dim=128):
        """Create the embedding table and the per-feature quantization parameters.

        Args:
            vocab_size: Number of distinct token ids (rows of the embedding).
            embed_dim: Size of each embedding vector.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.embed = nn.Embedding(vocab_size, embed_dim)
        # Quantization parameters, one entry per OHLCV feature (5 in total).
        self.scale = nn.Parameter(torch.ones(5))
        self.shift = nn.Parameter(torch.zeros(5))

    @staticmethod
    def _load_weights(model, model_id, filename, cache_dir):
        """Download ``filename`` from the Hub repo and load it into ``model``."""
        weight_path = hf_hub_download(
            repo_id=model_id,
            filename=filename,
            cache_dir=cache_dir,  # local cache, avoids repeated downloads
        )
        state_dict = torch.load(weight_path, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict)

    @classmethod
    def from_pretrained(cls, model_id, cache_dir="./cache", **kwargs):
        """Load a pretrained tokenizer from the Hugging Face Hub.

        ``tokenizer_weights.bin`` is tried first; if downloading or loading
        it fails, ``pytorch_model.bin`` is tried instead.

        Args:
            model_id: Hub repository id passed to ``hf_hub_download``.
            cache_dir: Directory used as the download cache.
            **kwargs: Forwarded to the class constructor.

        Returns:
            The tokenizer with the downloaded weights loaded.

        Raises:
            Exception: Whatever the fallback download/load raises when both
                attempts fail.
        """
        model = cls(**kwargs)
        try:
            cls._load_weights(model, model_id, "tokenizer_weights.bin", cache_dir)
        except Exception:
            # Deliberate best-effort fallback for repos that use a different
            # weight file name. ``Exception`` (not a bare ``except``) keeps
            # KeyboardInterrupt / SystemExit from being swallowed here.
            cls._load_weights(model, model_id, "pytorch_model.bin", cache_dir)
        return model

    def forward(self, x):
        """Quantize OHLCV data into token ids and return their embeddings.

        Args:
            x: Tensor whose last dimension broadcasts against the 5-element
                ``shift`` / ``scale`` parameters.

        Returns:
            The embeddings of the quantized ids; the result has one extra
            trailing dimension of size ``embed_dim`` compared with ``x``.
        """
        normalized = (x - self.shift) / self.scale
        # Round to integer ids and keep them inside the embedding table.
        token_ids = torch.clamp(torch.round(normalized), 0, self.vocab_size - 1).long()
        return self.embed(token_ids)
# ------------- Kronos: main model class -------------
class Kronos(nn.Module):
    """Transformer decoder stack followed by a linear head with 5 outputs."""

    def __init__(self, d_model=256, nhead=8, num_layers=6):
        """Build the decoder stack and the output projection.

        Args:
            d_model: Feature size expected on the last input dimension.
            nhead: Number of attention heads per decoder layer.
            num_layers: Number of stacked decoder layers.
        """
        super().__init__()
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model, nhead=nhead, batch_first=True
        )
        self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(d_model, 5)  # one output per OHLCV feature

    @classmethod
    def from_pretrained(cls, model_id, torch_dtype=torch.float32, cache_dir="./cache", **kwargs):
        """Load a pretrained Kronos model from the Hugging Face Hub.

        Args:
            model_id: Hub repository id passed to ``hf_hub_download``.
            torch_dtype: Dtype the model is cast to after loading.
            cache_dir: Directory used as the download cache.
            **kwargs: Forwarded to the class constructor.

        Returns:
            The model with ``pytorch_model.bin`` loaded, cast to ``torch_dtype``.
        """
        model = cls(**kwargs)
        # Download the weight file (or reuse the cached copy) and load it on CPU.
        weight_path = hf_hub_download(
            repo_id=model_id, filename="pytorch_model.bin", cache_dir=cache_dir
        )
        state_dict = torch.load(weight_path, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict)
        return model.to(dtype=torch_dtype)

    def forward(self, x):
        """Run the decoder over ``x`` and project every position to 5 values.

        Args:
            x: Input whose last dimension is ``d_model``.

        Returns:
            Tensor shaped like ``x`` except that the last dimension is 5.
        """
        # The input is used as both the target sequence and the memory.
        # NOTE(review): no causal mask is passed, so each position can attend
        # to later positions -- confirm this is intended for autoregressive use.
        decoded = self.transformer(x, x)
        return self.fc(decoded)
# ------------- KronosPredictor: predictor class -------------
class KronosPredictor:
    """Run a tokenizer and a model over OHLCV data read from a CSV source."""

    def __init__(self, model, tokenizer, device="cpu", max_context=512):
        """Move both modules to ``device`` and put the model in eval mode.

        Args:
            model: Module called on the tokenizer output.
            tokenizer: Module called on the preprocessed OHLCV tensor.
            device: Device both modules and the input tensor are placed on.
            max_context: Maximum number of most recent rows kept as input.
        """
        self.model = model.to(device).eval()
        self.tokenizer = tokenizer.to(device)
        self.device = device
        self.max_context = max_context

    def preprocess(self, df):
        """Extract the OHLCV columns of ``df`` as a float32 tensor.

        Only the last ``max_context`` rows are kept when ``df`` is longer.

        Args:
            df: DataFrame with ``open``, ``high``, ``low``, ``close`` and
                ``volume`` columns.

        Returns:
            Tensor on ``self.device`` with one row per kept record and 5 columns.
        """
        ohlcv = df[["open", "high", "low", "close", "volume"]].values.astype(np.float32)
        # Truncate to the model's maximum context length (keep the newest rows).
        if len(ohlcv) > self.max_context:
            ohlcv = ohlcv[-self.max_context:]
        return torch.tensor(ohlcv, device=self.device)

    def predict(self, csv_data, prediction_length=5, num_samples=10):
        """Read a CSV, run the model and average the tail of its output.

        Args:
            csv_data: Anything ``pandas.read_csv`` accepts (path or buffer).
            prediction_length: Number of trailing entries (along the first
                axis of the model output) to return; must be at least 1.
            num_samples: Number of forward passes averaged; must be at least 1.

        Returns:
            ``numpy`` array holding the mean over ``num_samples`` passes of
            the last ``prediction_length`` entries of the model output. If the
            output is shorter than ``prediction_length``, fewer entries are
            returned.

        Raises:
            ValueError: If ``prediction_length`` or ``num_samples`` is below 1,
                or the CSV contains no data rows.
        """
        # ``pred[-0:]`` would return the whole sequence, and averaging zero
        # samples would yield NaN, so reject both up front.
        if prediction_length < 1:
            raise ValueError("prediction_length must be at least 1")
        if num_samples < 1:
            raise ValueError("num_samples must be at least 1")

        # Read the CSV and preprocess it.
        df = pd.read_csv(csv_data)
        x = self.preprocess(df)
        if x.shape[0] == 0:
            raise ValueError("csv_data contains no OHLCV rows")

        samples = []
        with torch.no_grad():
            # Tokenize inside no_grad so no autograd graph is built for inference.
            x_embed = self.tokenizer(x)
            # The model was put in eval mode in __init__; the passes are
            # repeated and averaged as in the original sampling scheme.
            for _ in range(num_samples):
                pred = self.model(x_embed)
                samples.append(pred[-prediction_length:].cpu().numpy())
        # The mean over all passes is the final prediction.
        return np.mean(samples, axis=0)