import torch
import torch.nn as nn
# We still use GATv2Conv, just not the to_hetero wrapper
from torch_geometric.nn import GATv2Conv
from torch_geometric.data import HeteroData
from typing import Dict, List, Any
from collections import defaultdict  # For easy aggregation
from PIL import Image

from models.helper_encoders import ContextualTimeEncoder  # Type hint for constructor compatibility
# Import the actual ID_TO_LINK_TYPE mapping
from models.vocabulary import ID_TO_LINK_TYPE
# Import other modules needed for the test block
import models.vocabulary
from models.wallet_encoder import WalletEncoder
from models.token_encoder import TokenEncoder
from models.multi_modal_processor import MultiModalEncoder


def _signed_log1p(x: torch.Tensor) -> torch.Tensor:
    """Sign-preserving log compression: sign(x) * log1p(|x|).

    Shared by the link encoders below (replaces the identical
    `_safe_signed_log` method that was copy-pasted into each class).
    Keeps large positive/negative magnitudes on a comparable scale
    while preserving sign and mapping 0 -> 0.
    """
    return torch.sign(x) * torch.log1p(torch.abs(x))


# =============================================================================
# 1. Per-relation link-feature encoders
# =============================================================================

class _TransferLinkEncoder(nn.Module):
    """Encodes: transfer amount only (timestamps removed)."""

    def __init__(self, out_dim: int, dtype: torch.dtype):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(1, out_dim),
            nn.GELU(),
            nn.Linear(out_dim, out_dim),
        )
        self.dtype = dtype

    def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
        # One scalar amount per link; log-compress before projecting.
        amounts = torch.tensor(
            [[l.get('amount', 0.0)] for l in links],
            device=device, dtype=self.dtype,
        )
        return self.proj(_signed_log1p(amounts))


class _BundleTradeLinkEncoder(nn.Module):
    """Encodes: total_amount across bundle (timestamps removed)."""

    def __init__(self, out_dim: int, dtype: torch.dtype):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(1, out_dim),
            nn.GELU(),
            nn.Linear(out_dim, out_dim),
        )
        self.dtype = dtype

    def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
        totals = torch.tensor(
            [[l.get('total_amount', 0.0)] for l in links],
            device=device, dtype=self.dtype,
        )
        return self.proj(_signed_log1p(totals))


class _CopiedTradeLinkEncoder(nn.Module):
    """Encodes: 6 numerical copy-trading features.

    NOTE: the original docstring claimed 10 features, but this encoder
    builds exactly 6 (two time gaps, two PnLs, two follower totals) and
    is instantiated with in_features=6.
    """

    def __init__(self, in_features: int, out_dim: int, dtype: torch.dtype):  # Added dtype
        super().__init__()
        self.in_features = in_features
        self.norm = nn.LayerNorm(in_features)
        self.mlp = nn.Sequential(
            nn.Linear(in_features, out_dim * 2),
            nn.GELU(),
            nn.Linear(out_dim * 2, out_dim),
        )
        self.dtype = dtype  # Store dtype

    def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
        # --- FIXED: Only use the 6 essential features ---
        num_data = [
            [
                l.get('time_gap_on_buy_sec', 0),
                l.get('time_gap_on_sell_sec', 0),
                l.get('leader_pnl', 0),
                l.get('follower_pnl', 0),
                l.get('follower_buy_total', 0),
                l.get('follower_sell_total', 0),
            ]
            for l in links
        ]
        # Create tensor with correct dtype; input to norm must match its dtype.
        x = torch.tensor(num_data, device=device, dtype=self.dtype)
        x_norm = self.norm(_signed_log1p(x))
        return self.mlp(x_norm)


class _CoordinatedActivityLinkEncoder(nn.Module):
    """Encodes: 2 numerical features (time gaps on first/second action)."""

    def __init__(self, in_features: int, out_dim: int, dtype: torch.dtype):  # Added dtype
        super().__init__()
        self.in_features = in_features
        self.norm = nn.LayerNorm(in_features)
        self.mlp = nn.Sequential(
            nn.Linear(in_features, out_dim),
            nn.GELU(),
            nn.Linear(out_dim, out_dim),
        )
        self.dtype = dtype  # Store dtype

    def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
        num_data = [
            [
                l.get('time_gap_on_first_sec', 0),
                l.get('time_gap_on_second_sec', 0),
            ]
            for l in links
        ]
        # Create tensor with correct dtype
        x = torch.tensor(num_data, device=device, dtype=self.dtype)
        x_norm = self.norm(_signed_log1p(x))
        return self.mlp(x_norm)


class _MintedLinkEncoder(nn.Module):
    """Encodes: buy_amount only (timestamps removed)."""

    def __init__(self, out_dim: int, dtype: torch.dtype):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(1, out_dim),
            nn.GELU(),
            nn.Linear(out_dim, out_dim),
        )
        self.dtype = dtype  # Store dtype

    def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
        nums = torch.tensor(
            [[l.get('buy_amount', 0.0)] for l in links],
            device=device, dtype=self.dtype,
        )
        return self.proj(_signed_log1p(nums))


class _SnipedLinkEncoder(nn.Module):
    """Encodes: rank, sniped_amount."""

    def __init__(self, in_features: int, out_dim: int, dtype: torch.dtype):  # Added dtype
        super().__init__()
        self.norm = nn.LayerNorm(in_features)
        self.mlp = nn.Sequential(
            nn.Linear(in_features, out_dim),
            nn.GELU(),
            nn.Linear(out_dim, out_dim),
        )
        self.dtype = dtype  # Store dtype

    def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
        num_data = [[l.get('rank', 0), l.get('sniped_amount', 0)] for l in links]
        # Create tensor with correct dtype
        x = torch.tensor(num_data, device=device, dtype=self.dtype)
        # --- FIXED: Selectively log-scale features ---
        # Invert rank so 1 is highest (clamp avoids division by zero for a
        # missing/zero rank); log-scale the sniped amount.
        x[:, 0] = 1.0 / torch.clamp(x[:, 0], min=1.0)
        x[:, 1] = _signed_log1p(x[:, 1])
        x_norm = self.norm(x)
        return self.mlp(x_norm)


class _LockedSupplyLinkEncoder(nn.Module):
    """Encodes: amount."""

    def __init__(self, out_dim: int, dtype: torch.dtype):  # Removed time_encoder
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(1, out_dim),
            nn.GELU(),
            nn.Linear(out_dim, out_dim),
        )
        self.dtype = dtype  # Store dtype

    def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
        nums = torch.tensor(
            [[l.get('amount', 0.0)] for l in links],
            device=device, dtype=self.dtype,
        )
        return self.proj(_signed_log1p(nums))


class _BurnedLinkEncoder(nn.Module):
    """Encodes: burned amount (timestamps removed)."""

    def __init__(self, out_dim: int, dtype: torch.dtype):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(1, out_dim),
            nn.GELU(),
            nn.Linear(out_dim, out_dim),
        )
        self.dtype = dtype

    def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
        amounts = torch.tensor(
            [[l.get('amount', 0.0)] for l in links],
            device=device, dtype=self.dtype,
        )
        return self.proj(_signed_log1p(amounts))


class _ProvidedLiquidityLinkEncoder(nn.Module):
    """Encodes: quote amount (timestamps removed)."""

    def __init__(self, out_dim: int, dtype: torch.dtype):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(1, out_dim),
            nn.GELU(),
            nn.Linear(out_dim, out_dim),
        )
        self.dtype = dtype

    def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
        quote_amounts = torch.tensor(
            [[l.get('amount_quote', 0.0)] for l in links],
            device=device, dtype=self.dtype,
        )
        return self.proj(_signed_log1p(quote_amounts))


class _WhaleOfLinkEncoder(nn.Module):
    """Encodes: holding_pct_at_creation."""

    def __init__(self, out_dim: int, dtype: torch.dtype):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(1, out_dim),
            nn.GELU(),
            nn.Linear(out_dim, out_dim),
        )
        self.dtype = dtype

    def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
        vals = torch.tensor(
            [[l.get('holding_pct_at_creation', 0.0)] for l in links],
            device=device, dtype=self.dtype,
        )
        return self.mlp(_signed_log1p(vals))


class _TopTraderOfLinkEncoder(nn.Module):
    """Encodes: pnl_at_creation."""

    def __init__(self, out_dim: int, dtype: torch.dtype):  # Removed in_features
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(1, out_dim),
            nn.GELU(),
            nn.Linear(out_dim, out_dim),
        )
        self.dtype = dtype

    def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor:
        num_data = [[l.get('pnl_at_creation', 0)] for l in links]
        x = torch.tensor(num_data, device=device, dtype=self.dtype)
        return self.mlp(_signed_log1p(x))


class RelationalGATBlock(nn.Module):
    """
    Shared GATv2Conv that remains relation-aware by concatenating a learned
    relation embedding to every edge attribute before message passing.
    """

    def __init__(
        self,
        node_dim: int,
        edge_attr_dim: int,
        n_heads: int,
        relations: List[str],
        dtype: torch.dtype,
    ):
        super().__init__()
        self.rel_to_id = {name: idx for idx, name in enumerate(relations)}
        self.edge_attr_dim = edge_attr_dim
        # One learned embedding per relation name sharing this block.
        self.rel_emb = nn.Embedding(len(relations), edge_attr_dim)
        self.conv = GATv2Conv(
            in_channels=node_dim,
            out_channels=node_dim,
            heads=n_heads,
            concat=False,
            dropout=0.1,
            add_self_loops=False,
            edge_dim=edge_attr_dim * 2,  # concat of edge attr + relation emb
        ).to(dtype)

    def forward(
        self,
        x_src: torch.Tensor,
        x_dst: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        rel_type: str,
    ) -> torch.Tensor:
        num_edges = edge_index.size(1)
        device = edge_index.device
        # Relations without features still need an attribute tensor so the
        # conv's edge_dim stays consistent.
        if edge_attr is None:
            edge_attr = torch.zeros(
                num_edges,
                self.edge_attr_dim,
                device=device,
                dtype=x_src.dtype,
            )
        rel_id = self.rel_to_id.get(rel_type)
        if rel_id is None:
            raise KeyError(f"Relation '{rel_type}' not registered in RelationalGATBlock.")
        rel_feat = self.rel_emb.weight[rel_id].to(edge_attr.dtype)
        rel_feat = rel_feat.expand(num_edges, -1)
        augmented_attr = torch.cat([edge_attr, rel_feat], dim=-1)
        # Bipartite input -> output is per-destination-node.
        return self.conv((x_src, x_dst), edge_index, edge_attr=augmented_attr)


# =============================================================================
# 2. The Main GraphUpdater (GNN) - MANUAL HETEROGENEOUS IMPLEMENTATION
# =============================================================================

class GraphUpdater(nn.Module):
    """
    FIXED: Manually implements Heterogeneous GNN logic using separate
    GATv2Conv layers for each edge type, bypassing the problematic
    `to_hetero` wrapper.
    """

    def __init__(
        self,
        time_encoder: ContextualTimeEncoder,  # kept for constructor compatibility; unused
        edge_attr_dim: int = 64,
        n_heads: int = 4,
        num_layers: int = 2,
        node_dim: int = 2048,
        dtype: torch.dtype = torch.float16,
    ):
        super().__init__()
        self.node_dim = node_dim
        self.edge_attr_dim = edge_attr_dim
        self.num_layers = num_layers
        self.dtype = dtype

        # --- Instantiate all 11 Link Feature Encoders --- (Unchanged)
        self.edge_encoders = nn.ModuleDict({
            'TransferLink': _TransferLinkEncoder(edge_attr_dim, dtype=dtype),
            'TransferLinkToken': _TransferLinkEncoder(edge_attr_dim, dtype=dtype),
            'BundleTradeLink': _BundleTradeLinkEncoder(edge_attr_dim, dtype=dtype),
            'CopiedTradeLink': _CopiedTradeLinkEncoder(6, edge_attr_dim, dtype=dtype),  # FIXED: in_features=6
            'CoordinatedActivityLink': _CoordinatedActivityLinkEncoder(2, edge_attr_dim, dtype=dtype),
            'MintedLink': _MintedLinkEncoder(edge_attr_dim, dtype=dtype),
            'SnipedLink': _SnipedLinkEncoder(2, edge_attr_dim, dtype=dtype),
            'LockedSupplyLink': _LockedSupplyLinkEncoder(edge_attr_dim, dtype=dtype),  # FIXED: No time_encoder
            'BurnedLink': _BurnedLinkEncoder(edge_attr_dim, dtype=dtype),
            'ProvidedLiquidityLink': _ProvidedLiquidityLinkEncoder(edge_attr_dim, dtype=dtype),
            'WhaleOfLink': _WhaleOfLinkEncoder(edge_attr_dim, dtype=dtype),  # FIXED: No in_features
            'TopTraderOfLink': _TopTraderOfLinkEncoder(edge_attr_dim, dtype=dtype),  # FIXED: No in_features
        }).to(dtype)

        # --- Define shared relational GNN blocks per meta edge direction ---
        self.edge_groups = self._build_edge_groups()
        self.conv_layers = nn.ModuleList()
        for _ in range(num_layers):
            conv_dict = nn.ModuleDict()
            for (src_type, dst_type), relations in self.edge_groups.items():
                conv_dict[f"{src_type}__{dst_type}"] = RelationalGATBlock(
                    node_dim=node_dim,
                    edge_attr_dim=edge_attr_dim,
                    n_heads=n_heads,
                    relations=relations,
                    dtype=dtype,
                )
            self.conv_layers.append(conv_dict)

        self.norm = nn.LayerNorm(node_dim)
        self.to(dtype)  # Move norm layer and ModuleList container

        # Log params
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"[GraphUpdater] Params: {total_params:,} (Trainable: {trainable_params:,})")

    def _build_edge_groups(self) -> Dict[tuple, List[str]]:
        """Group relations by (src_type, dst_type) so conv weights can be shared."""
        groups: Dict[tuple, List[str]] = defaultdict(list)
        wallet_wallet_links = [
            'TransferLink', 'BundleTradeLink', 'CopiedTradeLink', 'CoordinatedActivityLink'
        ]
        wallet_token_links = [
            'TransferLinkToken', 'MintedLink', 'SnipedLink', 'LockedSupplyLink',
            'BurnedLink', 'ProvidedLiquidityLink', 'WhaleOfLink', 'TopTraderOfLink'
        ]
        for link in wallet_wallet_links:
            # Wallet<->wallet relations keep forward + reverse in one group.
            groups[('wallet', 'wallet')].append(link)
            groups[('wallet', 'wallet')].append(f"rev_{link}")
        for link in wallet_token_links:
            # Reverse wallet->token edges flow token->wallet, a separate group.
            groups[('wallet', 'token')].append(link)
            groups[('token', 'wallet')].append(f"rev_{link}")
        return groups

    def forward(
        self,
        x_dict: Dict[str, torch.Tensor],
        edge_data_dict: Dict[str, Dict[str, Any]]
    ) -> Dict[str, torch.Tensor]:
        device = x_dict['wallet'].device

        # --- 1. Encode Edge Attributes ---
        edge_index_dict = {}
        edge_attr_dict = {}
        for link_name, data in edge_data_dict.items():
            edge_index = data.get('edge_index')
            links = data.get('links', [])
            # Check if edge_index is valid before proceeding
            if edge_index is None or edge_index.numel() == 0 or not links:
                continue  # Skip if no links or index of this type
            edge_index = edge_index.to(device)

            # Use vocabulary to get the triplet (src, rel, dst)
            # Make sure ID_TO_LINK_TYPE is correctly populated
            if link_name not in models.vocabulary.LINK_NAME_TO_TRIPLET:
                print(f"Warning: Link name '{link_name}' not found in vocabulary.LINK_NAME_TO_TRIPLET. Skipping.")
                continue
            src_type, rel_type, dst_type = models.vocabulary.LINK_NAME_TO_TRIPLET[link_name]

            # Check if encoder exists for this link name
            if link_name not in self.edge_encoders:
                print(f"Warning: No edge encoder found for link type '{link_name}'. Skipping edge attributes.")
                edge_attr = None  # Or handle differently if attributes are essential
            else:
                edge_attr = self.edge_encoders[link_name](links, device).to(self.dtype)

            # Forward link
            fwd_key = (src_type, rel_type, dst_type)
            edge_index_dict[fwd_key] = edge_index
            if edge_attr is not None:
                edge_attr_dict[fwd_key] = edge_attr

            # Reverse link
            # Ensure edge_index has the right shape for flipping
            if edge_index.shape[0] == 2:
                rev_edge_index = edge_index[[1, 0]]
                rev_rel_type = f'rev_{rel_type}'
                rev_key = (dst_type, rev_rel_type, src_type)
                edge_index_dict[rev_key] = rev_edge_index
                if edge_attr is not None:
                    # Re-use same attributes for reverse edge
                    edge_attr_dict[rev_key] = edge_attr
            else:
                print(f"Warning: Edge index for {link_name} has unexpected shape {edge_index.shape}. Cannot create reverse edge.")

        # --- 2. Run GNN Layers MANUALLY ---
        x_out = x_dict
        for i in range(self.num_layers):
            # Initialize aggregation tensors for each node type that exists in the input
            msg_aggregates = {
                node_type: torch.zeros_like(x_node)
                for node_type, x_node in x_out.items()
            }

            # --- Message Passing ---
            for edge_type_tuple in edge_index_dict.keys():  # Iterate through edges PRESENT in the batch
                src_type, rel_type, dst_type = edge_type_tuple
                edge_index = edge_index_dict[edge_type_tuple]
                edge_attr = edge_attr_dict.get(edge_type_tuple)  # Use .get() in case attr is None

                x_src = x_out.get(src_type)
                x_dst = x_out.get(dst_type)
                if x_src is None or x_dst is None:
                    print(f"Warning: Missing node embeddings for types {src_type}->{dst_type}. Skipping.")
                    continue

                block_key = f"{src_type}__{dst_type}"
                if block_key not in self.conv_layers[i]:
                    print(f"Warning: Relational block for {block_key} not found in layer {i}. Skipping.")
                    continue
                block = self.conv_layers[i][block_key]
                try:
                    messages = block(x_src, x_dst, edge_index, edge_attr, rel_type)
                except KeyError:
                    print(f"Warning: Relation '{rel_type}' missing in block {block_key}. Skipping.")
                    continue

                # GATv2Conv output is already per-destination-node (shape [num_dst_nodes, node_dim])
                # NOT per-edge. So we directly accumulate, no scatter needed.
                msg_aggregates[dst_type] += messages

            # --- Aggregation & Update (Residual Connection) ---
            x_next = {}
            for node_type, x_original in x_out.items():
                # Check if messages were computed and stored correctly
                if node_type in msg_aggregates and msg_aggregates[node_type].shape[0] > 0:
                    aggregated_msgs = msg_aggregates[node_type]
                    # Ensure dimensions match before adding
                    if x_original.shape == aggregated_msgs.shape:
                        x_next[node_type] = self.norm(x_original + aggregated_msgs)
                    else:
                        print(f"Warning: Shape mismatch for node type {node_type} during update. Original: {x_original.shape}, Aggregated: {aggregated_msgs.shape}. Skipping residual connection.")
                        x_next[node_type] = x_original  # Fallback
                else:
                    x_next[node_type] = x_original
            x_out = x_next

        return x_out