from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

import models.vocabulary as vocab
|
|
class OHLCEmbedder(nn.Module):
    """
    Embeds a sequence of Open and Close prices together with its sampling interval.

    A stack of (Conv1d -> MaxPool1d -> ReLU) layers extracts chart-pattern
    features from the price sequence; global average pooling collapses the
    time axis, the result is concatenated with a learned interval embedding,
    and an MLP projects the combination to ``output_dim``.

    Note: MaxPool before ReLU is equivalent to ReLU before MaxPool here,
    because max commutes with the monotonic ReLU — order preserved from the
    original implementation.
    """

    def __init__(
        self,
        num_intervals: int,
        input_channels: int = 2,
        cnn_channels: Optional[List[int]] = None,
        kernel_sizes: Optional[List[int]] = None,
        interval_embed_dim: int = 16,
        output_dim: int = 512,
        dtype: torch.dtype = torch.float16,
        sequence_length: int = 300,
    ):
        """
        Args:
            num_intervals: Size of the interval vocabulary (embedding rows).
            input_channels: Number of price channels per timestep (default 2:
                Open and Close).
            cnn_channels: Output channels per conv stage. Defaults to [8, 16, 32].
            kernel_sizes: Kernel size per conv stage; must match
                ``cnn_channels`` in length. Defaults to [3, 3, 3].
            interval_embed_dim: Width of the interval embedding.
            output_dim: Width of the final embedding.
            dtype: Parameter/compute dtype for the whole module.
                NOTE(review): fp16 LayerNorm/conv on CPU can be slow or
                numerically marginal — confirm this is only used on GPU.
            sequence_length: Expected length of the input price sequence
                (previously hard-coded to 300; default preserves old behavior).

        Raises:
            ValueError: If ``cnn_channels`` and ``kernel_sizes`` differ in length.
        """
        super().__init__()

        # Avoid mutable default arguments: materialize per-instance lists here.
        if cnn_channels is None:
            cnn_channels = [8, 16, 32]
        if kernel_sizes is None:
            kernel_sizes = [3, 3, 3]
        # Validate with a real exception — `assert` disappears under `python -O`.
        if len(cnn_channels) != len(kernel_sizes):
            raise ValueError(
                "cnn_channels and kernel_sizes must have the same length"
            )

        self.dtype = dtype
        self.sequence_length = sequence_length
        # Remember the configured channel count so forward() validates against
        # it instead of a hard-coded 2 (the original rejected any model built
        # with input_channels != 2).
        self.input_channels = input_channels
        self.output_dim = output_dim

        # Conv stack: each stage halves the time axis via stride-2 max pooling.
        # 'same' padding keeps the conv itself length-preserving.
        self.cnn_layers = nn.ModuleList()
        in_channels = input_channels
        for out_channels, k_size in zip(cnn_channels, kernel_sizes):
            self.cnn_layers.append(
                nn.Conv1d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=k_size,
                    padding='same',
                )
            )
            self.cnn_layers.append(nn.MaxPool1d(kernel_size=2, stride=2))
            self.cnn_layers.append(nn.ReLU())
            in_channels = out_channels

        # Collapses whatever time length survives pooling to a single value
        # per channel, so the MLP input width depends only on channel counts.
        self.global_pool = nn.AdaptiveAvgPool1d(1)

        # NOTE(review): padding_idx=0 pins interval id 0 to a frozen all-zero
        # vector — presumably 0 is a "padding/unknown" interval; confirm
        # against the vocabulary definition.
        self.interval_embedding = nn.Embedding(
            num_intervals, interval_embed_dim, padding_idx=0
        )

        mlp_input_dim = cnn_channels[-1] + interval_embed_dim
        self.mlp = nn.Sequential(
            nn.Linear(mlp_input_dim, mlp_input_dim * 2),
            nn.GELU(),
            nn.LayerNorm(mlp_input_dim * 2),
            nn.Linear(mlp_input_dim * 2, output_dim),
            nn.LayerNorm(output_dim),
        )

        self.to(dtype)

        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(
            p.numel() for p in self.parameters() if p.requires_grad
        )
        print(f"[OHLCEmbedder] Params: {total_params:,} (Trainable: {trainable_params:,})")

    def forward(self, x: torch.Tensor, interval_ids: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Batch of normalized OHLC sequences, shape
                [batch_size, input_channels, sequence_length].
            interval_ids: Integer interval IDs, shape [batch_size].

        Returns:
            Batch of embeddings, shape [batch_size, output_dim].

        Raises:
            ValueError: If ``x`` does not match the configured
                (input_channels, sequence_length).
        """
        # Validate against the configured geometry, not hard-coded constants.
        if x.shape[1] != self.input_channels or x.shape[2] != self.sequence_length:
            raise ValueError(
                f"Input tensor shape mismatch. Expected "
                f"[B, {self.input_channels}, {self.sequence_length}], got {x.shape}"
            )

        x = x.to(self.dtype)

        # Chart-pattern feature extraction.
        for layer in self.cnn_layers:
            x = layer(x)

        # [B, C, T'] -> [B, C, 1] -> [B, C]
        x = self.global_pool(x).squeeze(-1)

        # [B, interval_embed_dim]
        interval_embed = self.interval_embedding(interval_ids)

        # Fuse pattern features with interval context, then project.
        combined = torch.cat([x, interval_embed], dim=1)
        return self.mlp(combined)
|
|