# oracle/models/ohlc_embedder.py
# (uploaded via huggingface_hub, commit d195287)
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Sequence

# --- Import vocabulary for the test block ---
import models.vocabulary as vocab
class OHLCEmbedder(nn.Module):
    """
    Embeds a sequence of Open and Close prices AND its interval.

    A stack of Conv1d -> MaxPool1d -> ReLU layers extracts chart-pattern
    features from the price sequence; a learned ``nn.Embedding`` encodes
    the candle interval. The two feature vectors are concatenated and
    projected by a small MLP to ``output_dim``.

    FIXED: takes ``interval_ids`` as input and combines an interval
    embedding with the 1D-CNN chart pattern features. The sequence length
    is now a constructor parameter (default 300, preserving the previous
    hard-coded behavior); thanks to the global average pool the CNN itself
    is length-agnostic, the length is only used for input validation.
    """

    def __init__(
        self,
        # Interval vocab size (index 0 is treated as padding).
        num_intervals: int,
        input_channels: int = 2,  # Open, Close
        # Immutable tuple defaults: mutable list defaults are shared
        # across calls and are a classic Python pitfall.
        cnn_channels: Sequence[int] = (8, 16, 32),
        kernel_sizes: Sequence[int] = (3, 3, 3),
        # Interval embedding dim.
        interval_embed_dim: int = 16,
        output_dim: int = 512,
        dtype: torch.dtype = torch.float16,
        # Expected input length; kept as a parameter instead of the old
        # hard-coded 300 (backward-compatible default).
        sequence_length: int = 300,
    ):
        super().__init__()
        # raise (not assert): asserts are stripped under `python -O`.
        if len(cnn_channels) != len(kernel_sizes):
            raise ValueError("cnn_channels and kernel_sizes must have the same length")
        self.dtype = dtype
        self.sequence_length = sequence_length
        self.input_channels = input_channels
        self.output_dim = output_dim

        # Build the CNN trunk: Conv1d ('same' padding keeps the length) ->
        # MaxPool1d (halves the length) -> ReLU, repeated per stage.
        self.cnn_layers = nn.ModuleList()
        in_channels = input_channels
        for out_channels, k_size in zip(cnn_channels, kernel_sizes):
            self.cnn_layers.append(
                nn.Conv1d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=k_size,
                    padding='same',
                )
            )
            self.cnn_layers.append(nn.MaxPool1d(kernel_size=2, stride=2))
            self.cnn_layers.append(nn.ReLU())
            in_channels = out_channels

        # Collapse whatever length remains to a single feature per channel,
        # which is what makes the trunk independent of sequence_length.
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        final_cnn_channels = cnn_channels[-1]

        # Interval embedding layer (padding_idx=0: id 0 maps to a frozen
        # all-zero vector).
        self.interval_embedding = nn.Embedding(num_intervals, interval_embed_dim, padding_idx=0)

        # MLP input is the concatenation of CNN and interval features.
        mlp_input_dim = final_cnn_channels + interval_embed_dim
        self.mlp = nn.Sequential(
            nn.Linear(mlp_input_dim, mlp_input_dim * 2),
            nn.GELU(),
            nn.LayerNorm(mlp_input_dim * 2),
            nn.Linear(mlp_input_dim * 2, output_dim),
            nn.LayerNorm(output_dim),
        )
        self.to(dtype)

        # Log params
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"[OHLCEmbedder] Params: {total_params:,} (Trainable: {trainable_params:,})")

    def forward(self, x: torch.Tensor, interval_ids: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Batch of normalized OHLC sequences.
                Shape: [batch_size, input_channels, sequence_length]
            interval_ids (torch.Tensor): Batch of interval IDs (long).
                Shape: [batch_size]

        Returns:
            torch.Tensor: Batch of OHLC embeddings.
                Shape: [batch_size, output_dim]

        Raises:
            ValueError: if ``x`` does not match the configured
                channel count / sequence length.
        """
        # Validate against the configured sizes (the old code compared the
        # channel dim against a hard-coded 2, ignoring input_channels).
        if x.shape[1] != self.input_channels or x.shape[2] != self.sequence_length:
            raise ValueError(
                f"Input tensor shape mismatch. Expected [B, {self.input_channels}, "
                f"{self.sequence_length}], got {x.shape}"
            )
        x = x.to(self.dtype)

        # 1. Pass through CNN layers
        for layer in self.cnn_layers:
            x = layer(x)

        # 2. Apply global average pooling -> [B, final_cnn_channels, 1]
        x = self.global_pool(x)

        # 3. Flatten for MLP -> [B, final_cnn_channels]
        x = x.squeeze(-1)

        # 4. Get interval embedding -> [B, interval_embed_dim]
        interval_embed = self.interval_embedding(interval_ids)

        # 5. Combine features -> [B, final_cnn_channels + interval_embed_dim]
        combined = torch.cat([x, interval_embed], dim=1)

        # 6. Pass through final MLP -> [B, output_dim]
        return self.mlp(combined)