import torch
import torch.nn as nn


class QuantOHLCEmbedder(nn.Module):
    def __init__(
        self,
        num_features: int,
        sequence_length: int = 60,
        version_vocab_size: int = 4,
        hidden_dim: int = 320,
        num_layers: int = 3,
        num_heads: int = 8,
        output_dim: int = 1536,
        dtype: torch.dtype = torch.float16,
    ):
        super().__init__()
        self.num_features = num_features
        self.sequence_length = sequence_length
        self.output_dim = output_dim
        self.dtype = dtype

        # Normalize and project per-bar features into the model width.
        self.feature_proj = nn.Sequential(
            nn.LayerNorm(num_features),
            nn.Linear(num_features, hidden_dim),
            nn.GELU(),
        )
        # Learned absolute positions over the fixed-length window.
        self.position_embedding = nn.Parameter(torch.zeros(1, sequence_length, hidden_dim))
        # Per-sequence version tag; index 0 is reserved as padding.
        self.version_embedding = nn.Embedding(version_vocab_size, hidden_dim, padding_idx=0)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=0.0,
            batch_first=True,
            activation="gelu",
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        # MLP head mapping pooled hidden states to the output embedding size.
        self.output_head = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.GELU(),
            nn.LayerNorm(hidden_dim * 2),
            nn.Linear(hidden_dim * 2, output_dim),
            nn.LayerNorm(output_dim),
        )
        self.to(dtype)

    def forward(
        self,
        feature_tokens: torch.Tensor,
        feature_mask: torch.Tensor,
        version_ids: torch.Tensor,
    ) -> torch.Tensor:
        if feature_tokens.ndim != 3:
            raise ValueError(f"Expected [B, T, F], got {feature_tokens.shape}")
        if feature_tokens.shape[1] != self.sequence_length:
            raise ValueError(f"Expected T={self.sequence_length}, got {feature_tokens.shape[1]}")
        if feature_tokens.shape[2] != self.num_features:
            raise ValueError(f"Expected F={self.num_features}, got {feature_tokens.shape[2]}")

        x = self.feature_proj(feature_tokens.to(self.dtype))
        version_embed = self.version_embedding(version_ids).unsqueeze(1)  # [B, 1, H]
        x = x + self.position_embedding[:, : x.shape[1], :].to(x.dtype) + version_embed
        # feature_mask is [B, T] with nonzero = valid; the encoder expects True
        # at padding positions, hence the inversion.
        key_padding_mask = ~(feature_mask > 0)
        # Rows with no valid timesteps would make attention emit NaNs, which
        # `out * valid_any` below cannot zero out (NaN * 0 is NaN). Unmask such
        # rows here; their outputs are still zeroed at the end.
        all_pad = key_padding_mask.all(dim=1, keepdim=True)
        key_padding_mask = key_padding_mask & ~all_pad
        x = self.encoder(x, src_key_padding_mask=key_padding_mask)

        # Mean-pool over valid timesteps only; the clamp guards the division
        # for all-padding rows, whose outputs valid_any then zeroes entirely.
        mask = feature_mask.to(x.dtype).unsqueeze(-1)
        valid_any = (feature_mask.sum(dim=1, keepdim=True) > 0).to(x.dtype)
        denom = mask.sum(dim=1).clamp_min(1.0)
        pooled = (x * mask).sum(dim=1) / denom
        out = self.output_head(pooled)
        return out * valid_any
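

# A minimal usage sketch, not part of the original module: it assumes
# [B, T, F] float features, a [B, T] mask with 1 = valid bar, and [B] integer
# version ids. The batch size, feature count, and mask pattern below are
# hypothetical, and float32 is used so the sketch also runs on CPU (the class
# defaults to float16, which is typically intended for GPU inference).
if __name__ == "__main__":
    batch, seq_len, n_feats = 2, 60, 12  # illustrative sizes only

    model = QuantOHLCEmbedder(num_features=n_feats, dtype=torch.float32)
    model.eval()

    feature_tokens = torch.randn(batch, seq_len, n_feats)
    feature_mask = torch.ones(batch, seq_len)
    feature_mask[1, 40:] = 0  # second sequence has only 40 valid bars
    version_ids = torch.tensor([1, 2])

    with torch.no_grad():
        embeddings = model(feature_tokens, feature_mask, version_ids)
    print(embeddings.shape)  # torch.Size([2, 1536])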