leehao163 commited on
Commit
a2c7d6b
·
verified ·
1 Parent(s): bd59cf2

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +91 -0
model.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import pandas as pd
5
+ from huggingface_hub import from_pretrained_keras, hf_hub_download
6
+ import os
7
+
8
# ------------- KronosTokenizer tokenizer class -------------
class KronosTokenizer(nn.Module):
    """Quantize continuous OHLCV features into discrete token embeddings.

    Core Kronos idea: each of the 5 OHLCV channels passes through a learned
    affine transform (shift/scale), is rounded to an integer token id, and
    the id is looked up in an embedding table.
    """

    def __init__(self, vocab_size=1024, embed_dim=128):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed = nn.Embedding(vocab_size, embed_dim)
        # Quantization parameters: one scale/shift per OHLCV channel.
        self.scale = nn.Parameter(torch.ones(5))
        self.shift = nn.Parameter(torch.zeros(5))

    @classmethod
    def from_pretrained(cls, model_id, **kwargs):
        """Load a pretrained tokenizer from the Hugging Face Hub.

        Args:
            model_id: Hub repo id containing ``tokenizer_weights.bin``.
            **kwargs: Forwarded to ``__init__`` (``vocab_size``, ``embed_dim``).

        Returns:
            A ``KronosTokenizer`` with the downloaded weights loaded.
        """
        model = cls(**kwargs)
        # Download the pretrained weights from the Hub.
        weight_path = hf_hub_download(repo_id=model_id, filename="tokenizer_weights.bin")
        # Fix: use weights_only=True to avoid arbitrary-code pickle
        # deserialization, consistent with Kronos.from_pretrained.
        model.load_state_dict(torch.load(weight_path, map_location="cpu", weights_only=True))
        return model

    def forward(self, x):
        """Quantize OHLCV values ``x`` (last dim = 5) into token embeddings.

        Returns a tensor of shape ``(*x.shape, embed_dim)``.
        """
        x = (x - self.shift) / self.scale
        # Round to the nearest token id and clamp into the vocabulary range.
        x = torch.clamp(torch.round(x), 0, self.vocab_size - 1).long()
        return self.embed(x)
32
+
33
# ------------- Kronos main model class -------------
class Kronos(nn.Module):
    """Transformer-decoder forecaster mapping token embeddings to OHLCV."""

    def __init__(self, d_model=256, nhead=8, num_layers=6):
        super().__init__()
        self.transformer = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, batch_first=True),
            num_layers=num_layers,
        )
        self.fc = nn.Linear(d_model, 5)  # project to the 5 OHLCV features

    @classmethod
    def from_pretrained(cls, model_id, torch_dtype=torch.float32, **kwargs):
        """Load a pretrained Kronos model from the Hugging Face Hub.

        Args:
            model_id: Hub repo id containing ``pytorch_model.bin``.
            torch_dtype: dtype to cast the loaded parameters to.
            **kwargs: Forwarded to ``__init__``.

        Returns:
            A ``Kronos`` instance with pretrained weights in ``torch_dtype``.
        """
        model = cls(**kwargs)
        # Download the pretrained weights from the Hub.
        weight_path = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin")
        state_dict = torch.load(weight_path, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict)
        # Bug fix: previously torch_dtype was only stored as an attribute and
        # never applied, so the argument had no effect. Actually cast the
        # parameters, and keep the attribute for callers that read it.
        model = model.to(torch_dtype)
        model.dtype = torch_dtype
        return model

    def forward(self, x):
        """Forward pass: token embeddings in, predicted OHLCV features out.

        The input serves as both target and memory of the decoder
        (self-conditioned decoding, as in the original implementation).
        """
        out = self.transformer(x, x)
        return self.fc(out)
58
+
59
# ------------- KronosPredictor prediction wrapper -------------
class KronosPredictor:
    """End-to-end predictor: CSV OHLCV history in, averaged forecast out.

    Wires a tokenizer and a Kronos model together on one device and exposes
    a single ``predict`` entry point over CSV data.
    """

    def __init__(self, model, tokenizer, device="cpu", max_context=512):
        self.model = model.to(device).eval()
        self.tokenizer = tokenizer.to(device)
        self.device = device
        self.max_context = max_context

    def preprocess(self, df):
        """Extract the OHLCV matrix, keep the most recent rows, move to device."""
        columns = ["open", "high", "low", "close", "volume"]
        window = df[columns].values.astype(np.float32)
        # Keep only the trailing max_context rows (no-op for shorter inputs).
        window = window[-self.max_context:]
        return torch.tensor(window, device=self.device)

    def predict(self, csv_data, prediction_length=5, num_samples=10):
        """Predict the next ``prediction_length`` OHLCV rows from CSV history.

        Args:
            csv_data: path or buffer accepted by ``pd.read_csv``.
            prediction_length: number of trailing output steps to return.
            num_samples: number of forward passes averaged together.

        Returns:
            ``np.ndarray`` — mean of ``num_samples`` model outputs, last
            ``prediction_length`` steps only.
        """
        history = self.preprocess(pd.read_csv(csv_data))
        embedded = self.tokenizer(history)
        # NOTE(review): with the model in eval() mode the forward pass looks
        # deterministic, so the num_samples average may be redundant —
        # confirm whether stochastic sampling was intended here.
        samples = []
        with torch.no_grad():
            for _ in range(num_samples):
                output = self.model(embedded)
                samples.append(output[-prediction_length:].cpu().numpy())
        return np.mean(samples, axis=0)