# oracle/models/graph_updater.py
# (Hugging Face Hub upload artifact: user zirobtc, "Upload folder using huggingface_hub", commit 18eb93c)
import torch
import torch.nn as nn
# We still use GATv2Conv, just not the to_hetero wrapper
from torch_geometric.nn import GATv2Conv
from torch_geometric.data import HeteroData
from typing import Dict, List, Any
from collections import defaultdict # For easy aggregation
from PIL import Image
from models.helper_encoders import ContextualTimeEncoder # Type hint for constructor compatibility
# Import the actual ID_TO_LINK_TYPE mapping
from models.vocabulary import ID_TO_LINK_TYPE
# Import other modules needed for the test block
import models.vocabulary
from models.wallet_encoder import WalletEncoder
from models.token_encoder import TokenEncoder
from models.multi_modal_processor import MultiModalEncoder
class _TransferLinkEncoder(nn.Module):
"""Encodes: transfer amount only (timestamps removed)."""
def __init__(self, out_dim: int, dtype: torch.dtype):
super().__init__()
self.proj = nn.Sequential(
nn.Linear(1, out_dim),
nn.GELU(),
nn.Linear(out_dim, out_dim)
)
self.dtype = dtype
def _safe_signed_log(self, x: torch.Tensor) -> torch.Tensor:
return torch.sign(x) * torch.log1p(torch.abs(x))
def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
amounts = torch.tensor([[l.get('amount', 0.0)] for l in links], device=device, dtype=self.dtype)
features = self._safe_signed_log(amounts)
return self.proj(features)
class _BundleTradeLinkEncoder(nn.Module):
"""Encodes: total_amount across bundle (timestamps removed)."""
def __init__(self, out_dim: int, dtype: torch.dtype):
super().__init__()
self.proj = nn.Sequential(
nn.Linear(1, out_dim),
nn.GELU(),
nn.Linear(out_dim, out_dim)
)
self.dtype = dtype
def _safe_signed_log(self, x: torch.Tensor) -> torch.Tensor:
return torch.sign(x) * torch.log1p(torch.abs(x))
def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
totals = torch.tensor([[l.get('total_amount', 0.0)] for l in links], device=device, dtype=self.dtype)
total_embeds = self._safe_signed_log(totals)
return self.proj(total_embeds)
class _CopiedTradeLinkEncoder(nn.Module):
""" Encodes: 10 numerical features """
def __init__(self, in_features: int, out_dim: int, dtype: torch.dtype): # Added dtype
super().__init__()
self.in_features = in_features
self.norm = nn.LayerNorm(in_features)
self.mlp = nn.Sequential(
nn.Linear(in_features, out_dim * 2), nn.GELU(),
nn.Linear(out_dim * 2, out_dim)
)
self.dtype = dtype # Store dtype
def _safe_signed_log(self, x: torch.Tensor) -> torch.Tensor:
return torch.sign(x) * torch.log1p(torch.abs(x))
def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
num_data = []
for l in links:
# --- FIXED: Only use the 6 essential features ---
num_data.append([
l.get('time_gap_on_buy_sec', 0), l.get('time_gap_on_sell_sec', 0),
l.get('leader_pnl', 0), l.get('follower_pnl', 0),
l.get('follower_buy_total', 0), l.get('follower_sell_total', 0)
])
# Create tensor with correct dtype
x = torch.tensor(num_data, device=device, dtype=self.dtype)
# Input to norm must match norm's dtype
x_norm = self.norm(self._safe_signed_log(x))
return self.mlp(x_norm)
class _CoordinatedActivityLinkEncoder(nn.Module):
""" Encodes: 2 numerical features """
def __init__(self, in_features: int, out_dim: int, dtype: torch.dtype): # Added dtype
super().__init__()
self.in_features = in_features
self.norm = nn.LayerNorm(in_features)
self.mlp = nn.Sequential(
nn.Linear(in_features, out_dim), nn.GELU(),
nn.Linear(out_dim, out_dim)
)
self.dtype = dtype # Store dtype
def _safe_signed_log(self, x: torch.Tensor) -> torch.Tensor:
return torch.sign(x) * torch.log1p(torch.abs(x))
def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
num_data = []
for l in links:
num_data.append([
l.get('time_gap_on_first_sec', 0), l.get('time_gap_on_second_sec', 0)
])
# Create tensor with correct dtype
x = torch.tensor(num_data, device=device, dtype=self.dtype)
x_norm = self.norm(self._safe_signed_log(x))
return self.mlp(x_norm)
class _MintedLinkEncoder(nn.Module):
"""Encodes: buy_amount only (timestamps removed)."""
def __init__(self, out_dim: int, dtype: torch.dtype):
super().__init__()
self.proj = nn.Sequential(
nn.Linear(1, out_dim),
nn.GELU(),
nn.Linear(out_dim, out_dim)
)
self.dtype = dtype # Store dtype
def _safe_signed_log(self, x: torch.Tensor) -> torch.Tensor:
return torch.sign(x) * torch.log1p(torch.abs(x))
def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
nums = torch.tensor([[l.get('buy_amount', 0.0)] for l in links], device=device, dtype=self.dtype)
num_embeds = self._safe_signed_log(nums)
return self.proj(num_embeds)
class _SnipedLinkEncoder(nn.Module):
""" Encodes: rank, sniped_amount """
def __init__(self, in_features: int, out_dim: int, dtype: torch.dtype): # Added dtype
super().__init__()
self.norm = nn.LayerNorm(in_features)
self.mlp = nn.Sequential(nn.Linear(in_features, out_dim), nn.GELU(), nn.Linear(out_dim, out_dim))
self.dtype = dtype # Store dtype
def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
num_data = [[l.get('rank', 0), l.get('sniped_amount', 0)] for l in links]
# Create tensor with correct dtype
x = torch.tensor(num_data, device=device, dtype=self.dtype)
# --- FIXED: Selectively log-scale features ---
# Invert rank so 1 is highest, treat as linear. Log-scale sniped_amount.
x[:, 0] = 1.0 / torch.clamp(x[:, 0], min=1.0) # Invert rank, clamp to avoid division by zero
x[:, 1] = torch.sign(x[:, 1]) * torch.log1p(torch.abs(x[:, 1])) # Log-scale amount
x_norm = self.norm(x)
return self.mlp(x_norm)
class _LockedSupplyLinkEncoder(nn.Module):
""" Encodes: amount """
def __init__(self, out_dim: int, dtype: torch.dtype): # Removed time_encoder
super().__init__()
self.proj = nn.Sequential(
nn.Linear(1, out_dim),
nn.GELU(),
nn.Linear(out_dim, out_dim)
)
self.dtype = dtype # Store dtype
def _safe_signed_log(self, x: torch.Tensor) -> torch.Tensor:
return torch.sign(x) * torch.log1p(torch.abs(x))
def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
nums = torch.tensor([[l.get('amount', 0.0)] for l in links], device=device, dtype=self.dtype)
num_embeds = self._safe_signed_log(nums)
return self.proj(num_embeds)
class _BurnedLinkEncoder(nn.Module):
"""Encodes: burned amount (timestamps removed)."""
def __init__(self, out_dim: int, dtype: torch.dtype):
super().__init__()
self.proj = nn.Sequential(
nn.Linear(1, out_dim),
nn.GELU(),
nn.Linear(out_dim, out_dim)
)
self.dtype = dtype
def _safe_signed_log(self, x: torch.Tensor) -> torch.Tensor:
return torch.sign(x) * torch.log1p(torch.abs(x))
def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
amounts = torch.tensor([[l.get('amount', 0.0)] for l in links], device=device, dtype=self.dtype)
amount_embeds = self._safe_signed_log(amounts)
return self.proj(amount_embeds)
class _ProvidedLiquidityLinkEncoder(nn.Module):
"""Encodes: quote amount (timestamps removed)."""
def __init__(self, out_dim: int, dtype: torch.dtype):
super().__init__()
self.proj = nn.Sequential(
nn.Linear(1, out_dim),
nn.GELU(),
nn.Linear(out_dim, out_dim)
)
self.dtype = dtype
def _safe_signed_log(self, x: torch.Tensor) -> torch.Tensor:
return torch.sign(x) * torch.log1p(torch.abs(x))
def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
quote_amounts = torch.tensor([[l.get('amount_quote', 0.0)] for l in links], device=device, dtype=self.dtype)
quote_embeds = self._safe_signed_log(quote_amounts)
return self.proj(quote_embeds)
class _WhaleOfLinkEncoder(nn.Module):
""" Encodes: holding_pct_at_creation """
def __init__(self, out_dim: int, dtype: torch.dtype):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(1, out_dim),
nn.GELU(),
nn.Linear(out_dim, out_dim)
)
self.dtype = dtype
def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
vals = torch.tensor([[l.get('holding_pct_at_creation', 0.0)] for l in links], device=device, dtype=self.dtype)
vals_log = torch.sign(vals) * torch.log1p(torch.abs(vals))
return self.mlp(vals_log)
class _TopTraderOfLinkEncoder(nn.Module):
""" Encodes: pnl_at_creation """
def __init__(self, out_dim: int, dtype: torch.dtype): # Removed in_features
super().__init__()
self.mlp = nn.Sequential(nn.Linear(1, out_dim), nn.GELU(), nn.Linear(out_dim, out_dim))
self.dtype = dtype
def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
num_data = [[l.get('pnl_at_creation', 0)] for l in links]
x = torch.tensor(num_data, device=device, dtype=self.dtype)
log_scaled_x = torch.sign(x) * torch.log1p(torch.abs(x))
return self.mlp(log_scaled_x)
class RelationalGATBlock(nn.Module):
    """
    A single GATv2 convolution shared by several relations.

    Relation awareness is preserved by concatenating a learned per-relation
    embedding onto every edge's attribute vector before message passing,
    instead of instantiating one convolution per relation.
    """

    def __init__(
        self,
        node_dim: int,
        edge_attr_dim: int,
        n_heads: int,
        relations: List[str],
        dtype: torch.dtype,
    ):
        super().__init__()
        # Stable relation-name -> embedding-row lookup.
        self.rel_to_id = {name: idx for idx, name in enumerate(relations)}
        self.edge_attr_dim = edge_attr_dim
        self.rel_emb = nn.Embedding(len(relations), edge_attr_dim)
        self.conv = GATv2Conv(
            in_channels=node_dim,
            out_channels=node_dim,
            heads=n_heads,
            concat=False,
            dropout=0.1,
            add_self_loops=False,
            edge_dim=edge_attr_dim * 2,  # edge attr + relation embedding, concatenated
        ).to(dtype)

    def forward(
        self,
        x_src: torch.Tensor,
        x_dst: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        rel_type: str,
    ) -> torch.Tensor:
        if rel_type not in self.rel_to_id:
            raise KeyError(f"Relation '{rel_type}' not registered in RelationalGATBlock.")
        num_edges = edge_index.size(1)
        # Edges with no attributes still get zeros so the relation embedding applies.
        if edge_attr is None:
            edge_attr = torch.zeros(
                num_edges,
                self.edge_attr_dim,
                device=edge_index.device,
                dtype=x_src.dtype,
            )
        rel_row = self.rel_emb.weight[self.rel_to_id[rel_type]].to(edge_attr.dtype)
        augmented = torch.cat([edge_attr, rel_row.expand(num_edges, -1)], dim=-1)
        return self.conv((x_src, x_dst), edge_index, edge_attr=augmented)
# =============================================================================
# 2. The Main GraphUpdater (GNN) - MANUAL HETEROGENEOUS IMPLEMENTATION
# =============================================================================
class GraphUpdater(nn.Module):
    """
    Manual heterogeneous GNN over 'wallet' and 'token' nodes.

    Instead of PyG's `to_hetero` wrapper, this module keeps one shared
    `RelationalGATBlock` per (src_type, dst_type) pair in each layer and
    routes every relation (plus its auto-generated reverse) through it.
    Per-link-type MLP encoders turn raw link dicts into edge attributes
    before message passing; node embeddings are updated with a residual
    connection followed by a shared LayerNorm.
    """
    def __init__(self, time_encoder: ContextualTimeEncoder, edge_attr_dim: int = 64,
                 n_heads: int = 4, num_layers: int = 2, node_dim: int = 2048, dtype: torch.dtype = torch.float16):
        # NOTE(review): `time_encoder` is accepted only for constructor
        # compatibility — it is never stored or used in this class
        # (timestamp features were removed from the link encoders).
        super().__init__()
        self.node_dim = node_dim
        self.edge_attr_dim = edge_attr_dim
        self.num_layers = num_layers
        self.dtype = dtype
        # --- Instantiate the 12 link-type feature encoders ---
        # (keys must match the link names appearing in `edge_data_dict`)
        self.edge_encoders = nn.ModuleDict({
            'TransferLink': _TransferLinkEncoder(edge_attr_dim, dtype=dtype),
            'TransferLinkToken': _TransferLinkEncoder(edge_attr_dim, dtype=dtype),
            'BundleTradeLink': _BundleTradeLinkEncoder(edge_attr_dim, dtype=dtype),
            'CopiedTradeLink': _CopiedTradeLinkEncoder(6, edge_attr_dim, dtype=dtype),  # 6 numerical features
            'CoordinatedActivityLink': _CoordinatedActivityLinkEncoder(2, edge_attr_dim, dtype=dtype),
            'MintedLink': _MintedLinkEncoder(edge_attr_dim, dtype=dtype),
            'SnipedLink': _SnipedLinkEncoder(2, edge_attr_dim, dtype=dtype),
            'LockedSupplyLink': _LockedSupplyLinkEncoder(edge_attr_dim, dtype=dtype),
            'BurnedLink': _BurnedLinkEncoder(edge_attr_dim, dtype=dtype),
            'ProvidedLiquidityLink': _ProvidedLiquidityLinkEncoder(edge_attr_dim, dtype=dtype),
            'WhaleOfLink': _WhaleOfLinkEncoder(edge_attr_dim, dtype=dtype),
            'TopTraderOfLink': _TopTraderOfLinkEncoder(edge_attr_dim, dtype=dtype),
        }).to(dtype)
        # --- One shared relational GAT block per (src_type, dst_type) pair, per layer ---
        self.edge_groups = self._build_edge_groups()
        self.conv_layers = nn.ModuleList()
        for _ in range(num_layers):
            conv_dict = nn.ModuleDict()
            for (src_type, dst_type), relations in self.edge_groups.items():
                conv_dict[f"{src_type}__{dst_type}"] = RelationalGATBlock(
                    node_dim=node_dim,
                    edge_attr_dim=edge_attr_dim,
                    n_heads=n_heads,
                    relations=relations,
                    dtype=dtype,
                )
            self.conv_layers.append(conv_dict)
        # Single LayerNorm shared by all node types in the residual update.
        self.norm = nn.LayerNorm(node_dim)
        self.to(dtype)  # Move norm layer and ModuleList container to target dtype
        # Log parameter counts for a quick sanity check at construction time.
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"[GraphUpdater] Params: {total_params:,} (Trainable: {trainable_params:,})")

    def _build_edge_groups(self) -> Dict[tuple, List[str]]:
        """Group relations by (src_type, dst_type) so conv weights can be shared.

        Wallet-wallet links and their reverses stay in the ('wallet', 'wallet')
        group; wallet-token links go to ('wallet', 'token') with their reverses
        in ('token', 'wallet').
        """
        groups: Dict[tuple, List[str]] = defaultdict(list)
        wallet_wallet_links = ['TransferLink', 'BundleTradeLink', 'CopiedTradeLink', 'CoordinatedActivityLink']
        wallet_token_links = [
            'TransferLinkToken', 'MintedLink', 'SnipedLink', 'LockedSupplyLink',
            'BurnedLink', 'ProvidedLiquidityLink', 'WhaleOfLink', 'TopTraderOfLink'
        ]
        for link in wallet_wallet_links:
            groups[('wallet', 'wallet')].append(link)
            groups[('wallet', 'wallet')].append(f"rev_{link}")
        for link in wallet_token_links:
            groups[('wallet', 'token')].append(link)
            groups[('token', 'wallet')].append(f"rev_{link}")
        return groups

    def forward(
        self,
        x_dict: Dict[str, torch.Tensor],
        edge_data_dict: Dict[str, Dict[str, Any]]
    ) -> Dict[str, torch.Tensor]:
        """Run `num_layers` rounds of relational message passing.

        Args:
            x_dict: node-type -> embedding tensor (must contain 'wallet';
                its device is used for all edge tensors).
            edge_data_dict: link-name -> {'edge_index': LongTensor[2, E],
                'links': list of raw link dicts for the edge encoders}.

        Returns:
            node-type -> updated embedding tensor (same shapes as input).
        """
        device = x_dict['wallet'].device
        # --- 1. Encode Edge Attributes ---
        edge_index_dict = {}
        edge_attr_dict = {}
        for link_name, data in edge_data_dict.items():
            edge_index = data.get('edge_index')
            links = data.get('links', [])
            # Skip link types with no edges or no raw link payloads.
            if edge_index is None or edge_index.numel() == 0 or not links:
                continue
            edge_index = edge_index.to(device)
            # Resolve the (src, rel, dst) triplet from the shared vocabulary.
            if link_name not in models.vocabulary.LINK_NAME_TO_TRIPLET:
                print(f"Warning: Link name '{link_name}' not found in vocabulary.LINK_NAME_TO_TRIPLET. Skipping.")
                continue
            src_type, rel_type, dst_type = models.vocabulary.LINK_NAME_TO_TRIPLET[link_name]
            # Encode per-edge attributes; missing encoder means attr-less edges.
            if link_name not in self.edge_encoders:
                print(f"Warning: No edge encoder found for link type '{link_name}'. Skipping edge attributes.")
                edge_attr = None
            else:
                edge_attr = self.edge_encoders[link_name](links, device).to(self.dtype)
            # Forward link
            fwd_key = (src_type, rel_type, dst_type)
            edge_index_dict[fwd_key] = edge_index
            if edge_attr is not None:
                edge_attr_dict[fwd_key] = edge_attr
            # Reverse link: flip src/dst rows and prefix the relation with 'rev_'.
            if edge_index.shape[0] == 2:
                rev_edge_index = edge_index[[1, 0]]
                rev_rel_type = f'rev_{rel_type}'
                rev_key = (dst_type, rev_rel_type, src_type)
                edge_index_dict[rev_key] = rev_edge_index
                if edge_attr is not None:
                    # The reverse edge reuses the forward edge's attributes.
                    edge_attr_dict[rev_key] = edge_attr
            else:
                print(f"Warning: Edge index for {link_name} has unexpected shape {edge_index.shape}. Cannot create reverse edge.")
        # --- 2. Run GNN Layers MANUALLY ---
        x_out = x_dict
        for i in range(self.num_layers):
            # Zero-initialized per-node-type accumulators for incoming messages.
            msg_aggregates = {
                node_type: torch.zeros_like(x_node)
                for node_type, x_node in x_out.items()
            }
            # --- Message Passing over the edge types present in this batch ---
            for edge_type_tuple in edge_index_dict.keys():
                src_type, rel_type, dst_type = edge_type_tuple
                edge_index = edge_index_dict[edge_type_tuple]
                edge_attr = edge_attr_dict.get(edge_type_tuple)  # may be None (attr-less edges)
                x_src = x_out.get(src_type)
                x_dst = x_out.get(dst_type)
                if x_src is None or x_dst is None:
                    print(f"Warning: Missing node embeddings for types {src_type}->{dst_type}. Skipping.")
                    continue
                block_key = f"{src_type}__{dst_type}"
                if block_key not in self.conv_layers[i]:
                    print(f"Warning: Relational block for {block_key} not found in layer {i}. Skipping.")
                    continue
                block = self.conv_layers[i][block_key]
                try:
                    messages = block(x_src, x_dst, edge_index, edge_attr, rel_type)
                except KeyError:
                    print(f"Warning: Relation '{rel_type}' missing in block {block_key}. Skipping.")
                    continue
                # GATv2Conv with bipartite (x_src, x_dst) input returns one row per
                # destination node ([num_dst_nodes, node_dim]), so direct += is the
                # aggregation — no scatter needed.
                msg_aggregates[dst_type] += messages
            # --- Aggregation & Update (Residual Connection) ---
            x_next = {}
            for node_type, x_original in x_out.items():
                if node_type in msg_aggregates and msg_aggregates[node_type].shape[0] > 0:
                    aggregated_msgs = msg_aggregates[node_type]
                    # Guard against shape drift before the residual add.
                    if x_original.shape == aggregated_msgs.shape:
                        x_next[node_type] = self.norm(x_original + aggregated_msgs)
                    else:
                        print(f"Warning: Shape mismatch for node type {node_type} during update. Original: {x_original.shape}, Aggregated: {aggregated_msgs.shape}. Skipping residual connection.")
                        x_next[node_type] = x_original  # Fallback: pass through unchanged
                else:
                    x_next[node_type] = x_original
            x_out = x_next
        return x_out