import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Sequence

# --- Import vocabulary for the test block ---
import models.vocabulary as vocab


class OHLCEmbedder(nn.Module):
    """Embeds a normalized Open/Close price sequence together with its interval.

    A stack of Conv1d -> MaxPool1d -> ReLU stages extracts chart-pattern
    features from the price series, global average pooling collapses the
    temporal axis, and the pooled features are concatenated with a learned
    interval embedding before a small MLP projects them to ``output_dim``.
    """

    def __init__(
        self,
        num_intervals: int,
        input_channels: int = 2,  # Open, Close
        cnn_channels: Sequence[int] = (8, 16, 32),  # tuples: avoid mutable defaults
        kernel_sizes: Sequence[int] = (3, 3, 3),
        interval_embed_dim: int = 16,
        output_dim: int = 512,
        dtype: torch.dtype = torch.float16,
        # Previously hardcoded to 300 inside the body; now configurable with
        # the old value as default (backward compatible, trailing keyword).
        sequence_length: int = 300,
    ):
        """
        Args:
            num_intervals: Size of the interval vocabulary (id 0 is padding).
            input_channels: Price channels per timestep (Open, Close).
            cnn_channels: Output channels of each Conv1d stage.
            kernel_sizes: Kernel size of each Conv1d stage; must match
                ``cnn_channels`` in length.
            interval_embed_dim: Dimension of the interval embedding.
            output_dim: Dimension of the final embedding.
            dtype: Dtype the whole module is cast to.
            sequence_length: Expected temporal length of the input sequences.

        Raises:
            ValueError: If ``cnn_channels`` and ``kernel_sizes`` differ in length.
        """
        super().__init__()
        # Raise (not assert): asserts are stripped under `python -O`.
        if len(cnn_channels) != len(kernel_sizes):
            raise ValueError("cnn_channels and kernel_sizes must have the same length")

        self.dtype = dtype
        self.sequence_length = sequence_length
        self.input_channels = input_channels
        self.output_dim = output_dim

        # Conv -> Pool -> ReLU stages; each MaxPool halves the temporal axis.
        self.cnn_layers = nn.ModuleList()
        in_channels = input_channels
        for out_channels, k_size in zip(cnn_channels, kernel_sizes):
            self.cnn_layers.append(
                nn.Conv1d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=k_size,
                    padding='same',
                )
            )
            self.cnn_layers.append(nn.MaxPool1d(kernel_size=2, stride=2))
            self.cnn_layers.append(nn.ReLU())
            in_channels = out_channels

        # Collapse whatever temporal length remains to a single feature vector.
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        final_cnn_channels = cnn_channels[-1]

        # Interval embedding layer; id 0 is reserved for padding.
        self.interval_embedding = nn.Embedding(num_intervals, interval_embed_dim, padding_idx=0)

        # MLP input is the concatenation of CNN features and interval features.
        mlp_input_dim = final_cnn_channels + interval_embed_dim
        self.mlp = nn.Sequential(
            nn.Linear(mlp_input_dim, mlp_input_dim * 2),
            nn.GELU(),
            nn.LayerNorm(mlp_input_dim * 2),
            nn.Linear(mlp_input_dim * 2, output_dim),
            nn.LayerNorm(output_dim),
        )

        self.to(dtype)

        # Log parameter counts.
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"[OHLCEmbedder] Params: {total_params:,} (Trainable: {trainable_params:,})")

    def forward(self, x: torch.Tensor, interval_ids: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Batch of normalized price sequences.
                Shape: [batch_size, input_channels, sequence_length]
            interval_ids: Batch of interval IDs. Shape: [batch_size]

        Returns:
            Batch of OHLC embeddings. Shape: [batch_size, output_dim]

        Raises:
            ValueError: If ``x`` does not match the configured shape.
        """
        # Validate against the configured channels/length instead of hardcoding 2/300.
        if x.shape[1] != self.input_channels or x.shape[2] != self.sequence_length:
            raise ValueError(
                f"Input tensor shape mismatch. Expected "
                f"[B, {self.input_channels}, {self.sequence_length}], got {x.shape}"
            )

        x = x.to(self.dtype)

        # 1. CNN feature extraction.
        for layer in self.cnn_layers:
            x = layer(x)

        # 2-3. Global average pool, then drop the trailing length-1 axis.
        x = self.global_pool(x).squeeze(-1)  # [batch_size, final_cnn_channels]

        # 4. Look up the interval embedding.
        interval_embed = self.interval_embedding(interval_ids)  # [batch_size, interval_embed_dim]

        # 5. Concatenate chart-pattern and interval features.
        combined = torch.cat([x, interval_embed], dim=1)

        # 6. Final projection to the shared embedding space.
        return self.mlp(combined)  # [batch_size, output_dim]