# LPatchTST-NIFTY1 / tokenizer.py
# Uploaded by gulnawaz123 — commit 64fca1b (verified): "Upload LPatchTST checkpoint and source"
# tokenizer.py ─── 100% from shiyu-coder/Kronos model/module.py + model/kronos.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from einops import rearrange, reduce
# ── Official module.py ────────────────────────────────────────────────────────
class DifferentiableEntropyFunction(Function):
    """Entropy of the hard code-usage histogram, with a hand-written gradient.

    ``forward(zq, basis, K, eps)`` turns every code vector (last dim of ``zq``)
    into an integer index by reading ``(zq + 1) / 2`` as bits weighted by
    ``basis``. It histograms those indices over ``2 ** K`` bins, smooths the
    counts with ``eps`` and returns the entropy of the normalised histogram.

    NOTE(review): assumes the entries of ``zq`` are exactly +/-1. Scaled codes
    would be truncated to wrong indices by the int64 cast.
    """

    @staticmethod
    def forward(ctx, zq, basis, K, eps):
        zb = (zq + 1) / 2
        zi = ((zb * basis).sum(-1)).to(torch.int64)
        cnt = torch.scatter_reduce(
            torch.zeros(2 ** K, device=zq.device, dtype=zq.dtype),
            0, zi.flatten(),
            torch.ones_like(zi.flatten()).to(zq.dtype), 'sum')
        prob = (cnt + eps) / (cnt + eps).sum()
        H = -(prob * torch.log(prob)).sum()
        ctx.save_for_backward(zq, zi, prob)
        ctx.K = K
        return H

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient w.r.t. ``zq``; ``basis``, ``K`` and ``eps`` get None."""
        zq, zi, prob = ctx.saved_tensors
        # dH/dp = -(log p + 1). It is normalised by the number of code vectors
        # (zi.numel()) and by K, then routed back to each vector through its bin.
        grad_array = -grad_output * (torch.log(prob) + 1) / zi.numel() / ctx.K
        reord_grad = grad_array[zi.flatten()].reshape(zi.shape)
        grad_input = reord_grad.unsqueeze(-1) * zq
        # Exactly one entry per forward() input: (zq, basis, K, eps).
        return grad_input, None, None, None
def codebook_entropy(zq, basis, K, eps=1e-4):
    """Differentiable entropy of the code-usage histogram of ``zq``.

    Thin wrapper around ``DifferentiableEntropyFunction``; autograd Functions
    are invoked through their ``.apply`` entry point.
    """
    entropy_fn = DifferentiableEntropyFunction.apply
    return entropy_fn(zq, basis, K, eps)
class BinarySphericalQuantizer(nn.Module):
    """Binary spherical quantization (BSQ) with entropy regularisation.

    Each channel of the last dim (``embed_dim``) is sign-quantised to +/-1 with a
    straight-through gradient. When ``l2_norm`` is set the codes are scaled by
    ``1 / sqrt(embed_dim)``. ``forward`` returns the (scaled) codes, a scalar
    auxiliary loss (commitment + weighted entropy penalty) and a metrics dict.
    """

    def __init__(self, embed_dim, beta, gamma0, gamma, zeta,
                 input_format='bchw', soft_entropy=True, group_size=9,
                 persample_entropy_compute='analytical',
                 cb_entropy_compute='group', l2_norm=True, inv_temperature=1):
        super().__init__()
        self.embed_dim = embed_dim
        self.beta = beta
        self.gamma0 = gamma0
        self.gamma = gamma
        self.zeta = zeta
        self.input_format = input_format
        assert self.embed_dim % group_size == 0, \
            f"embed_dim ({embed_dim}) must be divisible by group_size ({group_size})"
        self.num_groups = self.embed_dim // group_size
        self.group_size = group_size
        self.persample_entropy_compute = persample_entropy_compute
        self.cb_entropy_compute = cb_entropy_compute
        self.l2_norm = l2_norm
        self.inv_temperature = inv_temperature
        self.soft_entropy = soft_entropy
        # MSB-first bit weights for full codes and for each group sub-code.
        self.register_buffer('basis', 2 ** torch.arange(embed_dim - 1, -1, -1))
        self.register_buffer('group_basis', 2 ** torch.arange(group_size - 1, -1, -1))
        self.num_dimensions = 2 ** embed_dim
        self.bits_per_index = embed_dim
        # All 2**group_size sub-codes (+/-1), taken from the low-order bits.
        group_codes = torch.arange(2 ** self.group_size)
        group_codebook = self.indexes_to_codes(group_codes).float()[:, -group_size:]
        self.register_buffer('group_codebook', group_codebook, persistent=False)

    def quantize(self, z):
        """Map z > 0 to +1 and everything else to -1, with an identity (straight-through) gradient."""
        assert z.shape[-1] == self.embed_dim
        zhat = torch.where(z > 0,
                           torch.tensor(1, dtype=z.dtype, device=z.device),
                           torch.tensor(-1, dtype=z.dtype, device=z.device))
        return z + (zhat - z).detach()

    def forward(self, z, collect_metrics=True):
        """Quantize ``z`` (last dim ``embed_dim``) and compute the auxiliary loss.

        Returns:
            (zq, loss, metrics). ``zq`` holds the scaled codes. When
            ``collect_metrics`` is False, loss is a zero scalar and metrics is
            ``{}``. Otherwise metrics holds "H" (codebook entropy),
            "used_codes" (unique indices in eval mode, else None), "indices",
            "group_indices" and "avg_prob" (None for the hard-entropy path).
        """
        zhat = self.quantize(z)  # exact +/-1 codes
        q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1.
        zq = zhat * q_scale
        if not collect_metrics:
            return zq, zq.new_zeros(()), {}
        # Indices and hard entropies must use the UNSCALED +/-1 codes. With
        # l2_norm the scaled values +/-1/sqrt(embed_dim) do not map to {0, 1}
        # under (z + 1) / 2, which previously produced wrong indices.
        codes = zhat.detach()
        indices = self.codes_to_indexes(codes)
        group_indices = self.codes_to_group_indexes(codes)
        used_codes = None if self.training else torch.unique(indices, return_counts=False)
        if self.soft_entropy:
            persample_entropy, cb_entropy, avg_prob = self.soft_entropy_loss(z)
        else:
            zb_by_sample = ((zhat + 1) / 2).reshape(z.shape[0], -1, z.shape[-1]).to(torch.float32)
            persample_entropy = self.get_hard_per_sample_entropy(zb_by_sample)
            cb_entropy = codebook_entropy(zhat, self.basis, self.embed_dim)
            avg_prob = None
        entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy
        commit_loss = self.beta * torch.mean(((zq.detach() - z) ** 2).sum(dim=-1))
        return (
            zq,
            commit_loss + self.zeta * entropy_penalty / self.inv_temperature,
            {"H": cb_entropy, "used_codes": used_codes,
             "indices": indices, "group_indices": group_indices, "avg_prob": avg_prob},
        )

    def soft_entropy_loss(self, z):
        """Soft per-sample and codebook entropies computed per group of ``group_size`` channels.

        Returns (per_sample_entropy, summed codebook entropy, avg_prob).
        """
        group_code_book = self.group_codebook / (self.embed_dim ** 0.5 if self.l2_norm else 1)
        # (..., embed_dim) -> (..., num_groups, group_size)
        divided_z = z.reshape(*z.shape[:-1], -1, self.group_size)
        # Only the cross term of ||z - c||^2 varies across the codebook.
        distance = -2 * torch.einsum('... g c, d c ->... g d', divided_z, group_code_book)
        prob = (-distance * self.inv_temperature).softmax(dim=-1)
        if self.persample_entropy_compute == 'analytical':
            if self.l2_norm:
                p = torch.sigmoid(-4 * z / (self.embed_dim ** 0.5) * self.inv_temperature)
            else:
                p = torch.sigmoid(-4 * z * self.inv_temperature)
            # NOTE: prob is replaced by per-bit two-way probabilities here, so
            # avg_prob / cb_entropy below are per bit (g = embed_dim, d = 2).
            prob = torch.stack([p, 1 - p], dim=-1)
            per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean()
        else:
            per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean()
        # Mean over every leading dim -> (g, d).
        avg_prob = prob.reshape(-1, *prob.shape[-2:]).mean(dim=0)
        cb_entropy = self.get_entropy(avg_prob, dim=-1, normalize=False)
        return per_sample_entropy, cb_entropy.sum(), avg_prob

    def get_hard_per_sample_entropy(self, zb_by_sample):
        """Mean over samples of the summed per-bit binary entropies of the bit frequencies along dim 1."""
        probs_per_dim = zb_by_sample.sum(1) / zb_by_sample.shape[1]
        persample_entropy = (
            -probs_per_dim * torch.log(probs_per_dim + 1e-8)
            - (1 - probs_per_dim) * torch.log(1 - probs_per_dim + 1e-8)
        ).sum(-1)
        return persample_entropy.mean()

    def codes_to_indexes(self, zhat):
        """Convert +/-1 codes to integer indices (MSB first)."""
        assert zhat.shape[-1] == self.embed_dim
        return ((zhat + 1) / 2 * self.basis).sum(axis=-1).to(torch.int64)

    def codes_to_group_indexes(self, zhat):
        """Convert +/-1 codes to one integer index per group of ``group_size`` channels."""
        zhat_in_group = zhat.reshape(*zhat.shape[:-1], -1, self.group_size)
        return ((zhat_in_group + 1) / 2 * self.group_basis).sum(axis=-1).to(torch.int64)

    def indexes_to_codes(self, indices):
        """Inverse of ``codes_to_indexes``: integer indices -> +/-1 codes."""
        indices = indices.unsqueeze(-1)
        codes_non_centered = torch.remainder(torch.floor_divide(indices, self.basis), 2)
        return codes_non_centered * 2 - 1

    def group_indexes_to_codes(self, group_indices):
        """Inverse of ``codes_to_group_indexes``: per-group indices -> flat +/-1 codes."""
        group_indices = group_indices.unsqueeze(-1)
        codes_non_centered = torch.remainder(torch.floor_divide(group_indices, self.group_basis), 2)
        codes_non_centered = codes_non_centered.flatten(-2)
        return codes_non_centered * 2 - 1

    def get_entropy(self, count, dim=-1, eps=1e-4, normalize=True):
        """Entropy along ``dim``; optionally normalise (eps-smoothed) counts into probabilities first."""
        if normalize:
            probs = (count + eps) / (count + eps).sum(dim=dim, keepdim=True)
        else:
            probs = count
        return -(probs * torch.log(probs + 1e-8)).sum(dim=dim)

    def get_codebook_entry(self, indices):
        """Integer indices -> scaled codes, matching the scale applied in ``forward``."""
        z_q = self.indexes_to_codes(indices)
        q_scale = 1. / (self.embed_dim ** 0.5) if self.l2_norm else 1.
        return z_q * q_scale
class BSQuantizer(nn.Module):
    """Wrap ``BinarySphericalQuantizer`` and emit integer token ids.

    The ``s1_bits + s2_bits`` code channels are quantized jointly. ``forward``
    can split them into a leading ``s1_bits`` part and a trailing ``s2_bits``
    part, each with its own ids.
    """

    def __init__(self, s1_bits, s2_bits, beta, gamma0, gamma, zeta, group_size):
        super().__init__()
        self.codebook_dim = s1_bits + s2_bits
        self.s1_bits = s1_bits
        self.s2_bits = s2_bits
        self.bsq = BinarySphericalQuantizer(
            self.codebook_dim, beta, gamma0, gamma, zeta, group_size=group_size)

    def bits_to_indices(self, bits):
        """Convert signed code values to integer ids; bit i of the last dim has weight 2**i (LSB first)."""
        # bits is already scaled by q_scale, so recover the sign first (>= 0 -> 1).
        bits = (bits >= 0).to(torch.long)
        indices = 2 ** torch.arange(0, bits.shape[-1], 1,
                                    dtype=torch.long, device=bits.device)
        return (bits * indices).sum(-1)

    def forward(self, z, half=False, collect_metrics=True, apply_normalize=True):
        """Quantize ``z`` (last dim ``codebook_dim``).

        Returns:
            (bsq_loss, quantized, z_indices). ``z_indices`` is one id tensor,
            or ``[s1_ids, s2_ids]`` when ``half`` is True.
        """
        if apply_normalize:
            z = F.normalize(z, dim=-1)
        quantized, bsq_loss, metrics = self.bsq(z, collect_metrics=collect_metrics)
        if half:
            # Split along the channel (last) dim for any number of leading dims.
            q_pre = quantized[..., :self.s1_bits]
            q_post = quantized[..., self.s1_bits:]
            z_indices = [self.bits_to_indices(q_pre), self.bits_to_indices(q_post)]
        else:
            z_indices = self.bits_to_indices(quantized)
        return bsq_loss, quantized, z_indices
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dim with a learnable per-channel gain.

    Normalisation is computed in float32 and cast back to the input dtype
    before the gain is applied.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        mean_square = torch.mean(x * x, dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self, x):
        normalized = self._norm(x.float()).type_as(x)
        return normalized * self.weight
class FeedForward(nn.Module):
    """Gated SiLU feed-forward: ``w2(silu(w1(x)) * w3(x))`` followed by dropout."""

    def __init__(self, d_model, ff_dim, ffn_dropout_p=0.0):
        super().__init__()
        # Creation order (w1, w3, w2) kept so seeded init / state_dict order match.
        self.w1 = nn.Linear(d_model, ff_dim, bias=False)
        self.w3 = nn.Linear(d_model, ff_dim, bias=False)
        self.w2 = nn.Linear(ff_dim, d_model, bias=False)
        self.ffn_dropout = nn.Dropout(ffn_dropout_p)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        gated = gate * self.w3(x)
        projected = self.w2(gated)
        return self.ffn_dropout(projected)
class RotaryPositionalEmbedding(nn.Module):
    """Rotary position embedding applied to query/key tensors.

    ``forward(q, k)`` rotates feature pairs ``(i, i + dim/2)`` of the last dim
    by position-dependent angles, with positions taken along dim -2.
    """

    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        # Plain attributes (not buffers): .to()/.cuda() does NOT move them.
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def _update_cos_sin_cache(self, x, seq_len):
        # Rebuild when the length changes OR when the cached tables live on a
        # different device than x. Otherwise, after moving the module or inputs,
        # stale tables would be returned on the old device.
        stale = (
            seq_len != self.seq_len_cached
            or self.cos_cached is None
            or self.cos_cached.device != x.device
        )
        if stale:
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos()[None, None, :, :]
            self.sin_cached = emb.sin()[None, None, :, :]
        return self.cos_cached, self.sin_cached

    def forward(self, q, k):
        cos, sin = self._update_cos_sin_cache(q, q.shape[-2])
        return (q * cos) + (self._rotate_half(q) * sin), \
            (k * cos) + (self._rotate_half(k) * sin)

    def _rotate_half(self, x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
class MultiHeadAttentionWithRoPE(nn.Module):
    """Causal multi-head self-attention with rotary position embeddings.

    ``key_padding_mask`` (optional, shape (B, T)) is broadcast over heads and
    query positions and combined with the causal mask before being passed to
    ``F.scaled_dot_product_attention`` as ``attn_mask``. It therefore follows
    SDPA's convention: boolean True = the key may be attended to; float masks
    are added to the attention scores.
    NOTE(review): this is the opposite of nn.MultiheadAttention's
    key_padding_mask convention; confirm what callers pass.
    """

    def __init__(self, d_model, n_heads, attn_dropout_p=0.0, resid_dropout_p=0.0):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.rotary = RotaryPositionalEmbedding(self.head_dim)
        self.attn_dropout_p = attn_dropout_p
        self.resid_dropout = nn.Dropout(resid_dropout_p)

    def forward(self, x, key_padding_mask=None):
        B, T, _ = x.shape
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        q, k = self.rotary(q, k)
        attn_mask = None
        is_causal = True
        if key_padding_mask is not None:
            # SDPA raises if both attn_mask and is_causal=True are given, so fold
            # the lower-triangular causal constraint into the mask instead.
            causal = torch.ones(T, T, dtype=torch.bool, device=x.device).tril()
            key_mask = key_padding_mask.unsqueeze(1).unsqueeze(2).expand(-1, self.n_heads, T, -1)
            if key_mask.dtype == torch.bool:
                attn_mask = key_mask & causal
            else:
                attn_mask = key_mask.masked_fill(~causal, float("-inf"))
            is_causal = False
        out = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask,
            dropout_p=self.attn_dropout_p if self.training else 0.0,
            is_causal=is_causal)
        out = out.transpose(1, 2).contiguous().view(B, T, self.d_model)
        return self.resid_dropout(self.out_proj(out))
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer.

    RMSNorm -> causal RoPE self-attention -> residual, then
    RMSNorm -> gated SiLU feed-forward -> residual.
    """

    def __init__(self, d_model, n_heads, ff_dim=1024,
                 ffn_dropout_p=0.0, attn_dropout_p=0.0, resid_dropout_p=0.0):
        super().__init__()
        # Submodule creation order (norm1, self_attn, norm2, ffn) kept for state_dict keys.
        self.norm1 = RMSNorm(d_model)
        self.self_attn = MultiHeadAttentionWithRoPE(
            d_model, n_heads,
            attn_dropout_p=attn_dropout_p,
            resid_dropout_p=resid_dropout_p,
        )
        self.norm2 = RMSNorm(d_model)
        self.ffn = FeedForward(d_model, ff_dim, ffn_dropout_p=ffn_dropout_p)

    def forward(self, x, key_padding_mask=None):
        attn_out = self.self_attn(self.norm1(x), key_padding_mask=key_padding_mask)
        hidden = x + attn_out
        return hidden + self.ffn(self.norm2(hidden))
# ── Official kronos.py ────────────────────────────────────────────────────────
def _infer_config_from_checkpoint(path):
"""
Reads tensor shapes from a checkpoint WITHOUT loading them into any model.
Returns a dict of kwargs sufficient to reconstruct KronosTokenizer exactly.
Shape map (read from checkpoint keys):
embed.weight : (d_model, d_in)
quant_embed.weight : (codebook_dim, d_model)
post_quant_embed_pre.weight : (d_model, s1_bits)
post_quant_embed.weight : (d_model, codebook_dim)
tokenizer.bsq.basis : (codebook_dim,)
tokenizer.bsq.group_basis : (group_size,)
encoder.0.norm1.weight : (d_model,) β†’ n_enc_layers = len(encoder)+1
encoder.0.ffn.w1.weight : (ff_dim, d_model)
"""
if path.endswith(".safetensors"):
from safetensors import safe_open
shapes = {}
with safe_open(path, framework="pt", device="cpu") as f:
for key in f.keys():
shapes[key] = f.get_slice(key).get_shape()
else:
state = torch.load(path, map_location="cpu")
shapes = {k: v.shape for k, v in state.items()}
d_model = shapes["embed.weight"][0]
d_in = shapes["embed.weight"][1]
codebook_dim = shapes["quant_embed.weight"][0]
s1_bits = shapes["post_quant_embed_pre.weight"][1]
s2_bits = codebook_dim - s1_bits
group_size = shapes["tokenizer.bsq.group_basis"][0]
ff_dim = shapes["encoder.0.ffn.w1.weight"][0]
# Count encoder / decoder blocks (keys like encoder.0, encoder.1, ...)
n_enc = sum(1 for k in shapes if k.startswith("encoder.") and k.endswith(".norm1.weight"))
n_dec = sum(1 for k in shapes if k.startswith("decoder.") and k.endswith(".norm1.weight"))
# +1 because __init__ builds (n_layers - 1) blocks
n_enc_layers = n_enc + 1
n_dec_layers = n_dec + 1
# n_heads: head_dim = d_model // n_heads; rotary inv_freq has shape (head_dim//2,)
rotary_dim = shapes["encoder.0.self_attn.rotary.inv_freq"][0] # head_dim // 2
head_dim = rotary_dim * 2
n_heads = d_model // head_dim
cfg = dict(
d_in=d_in,
d_model=d_model,
n_heads=n_heads,
ff_dim=ff_dim,
n_enc_layers=n_enc_layers,
n_dec_layers=n_dec_layers,
s1_bits=s1_bits,
s2_bits=s2_bits,
group_size=group_size,
# dropout values don't affect inference; keep at 0
ffn_dropout_p=0.0,
attn_dropout_p=0.0,
resid_dropout_p=0.0,
)
print(" βœ“ Inferred config from checkpoint:")
for k, v in cfg.items():
print(f" {k:20s} = {v}")
return cfg
class KronosTokenizer(nn.Module):
    """Transformer auto-encoder that tokenizes feature sequences into BSQ codes.

    Pipeline: ``embed`` (d_in -> d_model) -> encoder blocks -> ``quant_embed``
    (d_model -> s1_bits + s2_bits) -> ``BSQuantizer``. The decoder reconstructs
    either from the first ``s1_bits`` only (``post_quant_embed_pre``) or from
    the full code (``post_quant_embed``); ``head`` maps back to d_in.
    Each stack builds ``n_*_layers - 1`` transformer blocks.
    """

    def __init__(self, d_in=4, d_model=128, n_heads=4, ff_dim=512,
                 n_enc_layers=3, n_dec_layers=3,
                 ffn_dropout_p=0.1, attn_dropout_p=0.1, resid_dropout_p=0.1,
                 s1_bits=6, s2_bits=6,
                 beta=0.25, gamma0=0.1, gamma=0.1, zeta=0.1, group_size=6):
        super().__init__()
        self.d_in = d_in
        self.s1_bits = s1_bits
        self.s2_bits = s2_bits
        self.codebook_dim = s1_bits + s2_bits
        self.embed = nn.Linear(d_in, d_model)
        self.head = nn.Linear(d_model, d_in)
        self.encoder = nn.ModuleList([
            TransformerBlock(d_model, n_heads, ff_dim, ffn_dropout_p, attn_dropout_p, resid_dropout_p)
            for _ in range(n_enc_layers - 1)
        ])
        self.decoder = nn.ModuleList([
            TransformerBlock(d_model, n_heads, ff_dim, ffn_dropout_p, attn_dropout_p, resid_dropout_p)
            for _ in range(n_dec_layers - 1)
        ])
        self.quant_embed = nn.Linear(d_model, self.codebook_dim, bias=True)
        self.post_quant_embed_pre = nn.Linear(s1_bits, d_model)
        self.post_quant_embed = nn.Linear(self.codebook_dim, d_model)
        self.tokenizer = BSQuantizer(s1_bits, s2_bits, beta, gamma0, gamma, zeta, group_size)

    # ------------------------------------------------------------------
    # PRIMARY entry-point for loading from a checkpoint file.
    # It reads shapes from the file first, rebuilds the model with the
    # matching architecture, then loads weights — avoiding size-mismatch errors.
    # ------------------------------------------------------------------
    @classmethod
    def from_pretrained(cls, path, device="cpu",
                        beta=0.25, gamma0=0.1, gamma=0.1, zeta=0.1):
        """
        Build a KronosTokenizer whose architecture is inferred from the
        checkpoint at `path`, load its weights, move it to `device` and put it
        in eval mode.

        Usage:
            tok = KronosTokenizer.from_pretrained("model.safetensors")
            tok = KronosTokenizer.from_pretrained("model.safetensors", device="cuda")
        """
        cfg = _infer_config_from_checkpoint(path)
        # BSQ hyper-params only enter the loss computation, not the weights.
        model = cls(**cfg, beta=beta, gamma0=gamma0, gamma=gamma, zeta=zeta)
        model.load_pretrained(path, device=device)
        return model

    def load_pretrained(self, path, device="cpu"):
        """
        Load weights (non-strict; missing/unexpected keys are printed) into an
        already-constructed model, then move it to `device` and set eval mode.
        Use `from_pretrained` unless the model already has the checkpoint's
        architecture.
        """
        if path.endswith(".safetensors"):
            from safetensors.torch import load_model
            missing, unexpected = load_model(self, path, strict=False)
            if missing: print(f" ⚠ Missing keys : {missing}")
            if unexpected: print(f" ⚠ Unexpected keys: {unexpected}")
            print(f" ✓ Loaded weights from {path} (safetensors)")
        else:
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            state = torch.load(path, map_location=device)
            missing, unexpected = self.load_state_dict(state, strict=False)
            if missing: print(f" ⚠ Missing keys : {missing}")
            if unexpected: print(f" ⚠ Unexpected keys: {unexpected}")
            print(f" ✓ Loaded weights from {path} (torch)")
        self.to(device)
        self.eval()

    def forward(self, x):
        """Encode, quantize and decode ``x``.

        Returns:
            ((z_pre, z_full), bsq_loss, quantized, z_indices). ``z_pre`` is
            decoded from the first ``s1_bits`` code channels only; ``z_full``
            is decoded from all channels.
        """
        z = self.embed(x)
        for layer in self.encoder:
            z = layer(z)
        z = self.quant_embed(z)
        bsq_loss, quantized, z_indices = self.tokenizer(z)
        quantized_pre = quantized[..., :self.s1_bits]
        z_pre = self.post_quant_embed_pre(quantized_pre)
        for layer in self.decoder:
            z_pre = layer(z_pre)
        z_pre = self.head(z_pre)
        z_full = self.post_quant_embed(quantized)
        for layer in self.decoder:
            z_full = layer(z_full)
        z_full = self.head(z_full)
        return (z_pre, z_full), bsq_loss, quantized, z_indices

    def encode(self, x, half=True):
        """Return token ids for ``x``: ``[s1_ids, s2_ids]`` when ``half``, else one id tensor."""
        z = self.embed(x)
        for layer in self.encoder:
            z = layer(z)
        z = self.quant_embed(z)
        _, _, z_indices = self.tokenizer(z, half=half, collect_metrics=False)
        return z_indices

    def indices_to_bits(self, x, half=False):
        """Invert ``BSQuantizer.bits_to_indices``: ids -> scaled +/-1 codes (LSB first).

        With ``half`` True, ``x`` is ``(s1_ids, s2_ids)``; each half is decoded
        with its own bit width and the results are concatenated.
        """
        codebook_dim = self.codebook_dim
        q_scale = 1. / (codebook_dim ** 0.5)
        if half:
            x1, x2 = x[0], x[1]
            # Each half has its own width. Using codebook_dim // 2 for both was
            # wrong whenever s1_bits != s2_bits.
            mask1 = 2 ** torch.arange(self.s1_bits, device=x1.device, dtype=torch.long)
            mask2 = 2 ** torch.arange(self.s2_bits, device=x2.device, dtype=torch.long)
            b1 = ((x1.unsqueeze(-1) & mask1) != 0).float() * 2 - 1
            b2 = ((x2.unsqueeze(-1) & mask2) != 0).float() * 2 - 1
            bits = torch.cat([b1, b2], dim=-1)
        else:
            mask = 2 ** torch.arange(codebook_dim, device=x.device, dtype=torch.long)
            bits = ((x.unsqueeze(-1) & mask) != 0).float() * 2 - 1
        return bits * q_scale

    def decode(self, x, half=True):
        """Decode token ids (see ``indices_to_bits``) back to d_in features using the full-code path."""
        quantized = self.indices_to_bits(x, half=half)
        z = self.post_quant_embed(quantized)
        for layer in self.decoder:
            z = layer(z)
        return self.head(z)
def prepare_ohlc_features(df, window=500):
    """
    Build rolling-z-scored log-return features from an OHLC(V) DataFrame.

    Columns are looked up case-insensitively ('open', 'high', 'low', 'close',
    'volume'). Volume is optional; zeros are used when it is absent.

    Args:
        df: DataFrame with Open/High/Low/Close and optionally Volume columns.
        window: length of the trailing rolling window used for z-scoring.

    Returns:
        float32 array of shape (max(len(df) - 1, 0), 6) with columns
        [log_ret_O, log_ret_H, log_ret_L, log_ret_C, log_ret_V, log_ret_A].
        - O/H/L/C are log(price / previous bar's Close).
        - V is the log ratio of Volume to the previous bar's Volume, and A the
          same for amount (Close * Volume), both with +1e-6 smoothing.
        - The first bar is dropped because it has no predecessor.
        - Non-finite returns (e.g. zero prices) are set to 0.
        - Each column is z-scored against its rolling mean/std and clipped to
          [-5, 5]. Leading rows without enough history borrow the first
          available statistics (bfill), i.e. they look ahead.
    """
    import numpy as np
    import pandas as pd

    cols = {c.lower(): c for c in df.columns}
    o_col = cols.get('open', 'Open')
    h_col = cols.get('high', 'High')
    l_col = cols.get('low', 'Low')
    c_col = cols.get('close', 'Close')
    v_col = cols.get('volume', 'Volume')

    close = df[c_col].values
    prev_close = np.roll(close, 1)

    # Volume features (optional, but a 6-input tokenizer needs them)
    if v_col in df.columns:
        volume = df[v_col].values.astype(np.float32)
        amount = close * volume
    else:
        volume = np.zeros_like(close)
        amount = np.zeros_like(close)
    prev_volume = np.roll(volume, 1)
    prev_amount = np.roll(amount, 1)

    with np.errstate(divide='ignore', invalid='ignore'):
        o_ret = np.log(df[o_col].values / prev_close)
        h_ret = np.log(df[h_col].values / prev_close)
        l_ret = np.log(df[l_col].values / prev_close)
        c_ret = np.log(df[c_col].values / prev_close)
        v_ret = np.log((volume + 1e-6) / (prev_volume + 1e-6))
        a_ret = np.log((amount + 1e-6) / (prev_amount + 1e-6))

    # np.roll wraps the last bar into position 0, so row 0 is meaningless: drop it.
    out = np.stack([o_ret, h_ret, l_ret, c_ret, v_ret, a_ret], axis=1)[1:]
    # Map NaN AND +/-inf to 0. With the default nan_to_num, +/-inf became
    # +/-1.8e308, which overflows back to +/-inf in float32. That made the
    # rolling mean/std of every window containing it non-finite (NaN rows).
    out = np.nan_to_num(out, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)

    # ── Per-feature rolling z-score normalization ───────────────────────────
    # Statistics come from a trailing window, so they are local, not global.
    # This stabilises the input distribution the frozen tokenizer sees
    # across different volatility regimes.
    df_out = pd.DataFrame(out)
    # pandas requires min_periods <= window.
    min_p = max(1, min(50, len(df_out) // 10, window))
    roll_mean = df_out.rolling(window, min_periods=min_p).mean().bfill().values
    roll_std = df_out.rolling(window, min_periods=min_p).std().bfill().values
    out = ((out - roll_mean) / (roll_std + 1e-8)).astype(np.float32)
    # The sample std of a lone row is NaN (nothing to bfill from); treat it as 0.
    out = np.nan_to_num(out, nan=0.0)
    out = np.clip(out, -5.0, 5.0)
    return out