| import torch |
| import torch.nn as nn |
| |
| from torch_geometric.nn import GATv2Conv |
| from torch_geometric.data import HeteroData |
| from typing import Dict, List, Any |
| from collections import defaultdict |
| from PIL import Image |
|
|
| from models.helper_encoders import ContextualTimeEncoder |
| |
| from models.vocabulary import ID_TO_LINK_TYPE |
| |
| import models.vocabulary |
| from models.wallet_encoder import WalletEncoder |
| from models.token_encoder import TokenEncoder |
| from models.multi_modal_processor import MultiModalEncoder |
|
|
|
|
| class _TransferLinkEncoder(nn.Module): |
| """Encodes: transfer amount only (timestamps removed).""" |
| def __init__(self, out_dim: int, dtype: torch.dtype): |
| super().__init__() |
| self.proj = nn.Sequential( |
| nn.Linear(1, out_dim), |
| nn.GELU(), |
| nn.Linear(out_dim, out_dim) |
| ) |
| self.dtype = dtype |
|
|
| def _safe_signed_log(self, x: torch.Tensor) -> torch.Tensor: |
| return torch.sign(x) * torch.log1p(torch.abs(x)) |
|
|
| def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor: |
| amounts = torch.tensor([[l.get('amount', 0.0)] for l in links], device=device, dtype=self.dtype) |
| features = self._safe_signed_log(amounts) |
|
|
| return self.proj(features) |
|
|
| class _BundleTradeLinkEncoder(nn.Module): |
| """Encodes: total_amount across bundle (timestamps removed).""" |
| def __init__(self, out_dim: int, dtype: torch.dtype): |
| super().__init__() |
| self.proj = nn.Sequential( |
| nn.Linear(1, out_dim), |
| nn.GELU(), |
| nn.Linear(out_dim, out_dim) |
| ) |
| self.dtype = dtype |
|
|
| def _safe_signed_log(self, x: torch.Tensor) -> torch.Tensor: |
| return torch.sign(x) * torch.log1p(torch.abs(x)) |
|
|
| def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor: |
| totals = torch.tensor([[l.get('total_amount', 0.0)] for l in links], device=device, dtype=self.dtype) |
| total_embeds = self._safe_signed_log(totals) |
|
|
| return self.proj(total_embeds) |
|
|
| class _CopiedTradeLinkEncoder(nn.Module): |
| """ Encodes: 10 numerical features """ |
| def __init__(self, in_features: int, out_dim: int, dtype: torch.dtype): |
| super().__init__() |
| self.in_features = in_features |
| self.norm = nn.LayerNorm(in_features) |
| self.mlp = nn.Sequential( |
| nn.Linear(in_features, out_dim * 2), nn.GELU(), |
| nn.Linear(out_dim * 2, out_dim) |
| ) |
| self.dtype = dtype |
|
|
| def _safe_signed_log(self, x: torch.Tensor) -> torch.Tensor: |
| return torch.sign(x) * torch.log1p(torch.abs(x)) |
|
|
| def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor: |
| num_data = [] |
| for l in links: |
| |
| num_data.append([ |
| l.get('time_gap_on_buy_sec', 0), l.get('time_gap_on_sell_sec', 0), |
| l.get('leader_pnl', 0), l.get('follower_pnl', 0), |
| l.get('follower_buy_total', 0), l.get('follower_sell_total', 0) |
| ]) |
| |
| x = torch.tensor(num_data, device=device, dtype=self.dtype) |
| |
| x_norm = self.norm(self._safe_signed_log(x)) |
| return self.mlp(x_norm) |
|
|
| class _CoordinatedActivityLinkEncoder(nn.Module): |
| """ Encodes: 2 numerical features """ |
| def __init__(self, in_features: int, out_dim: int, dtype: torch.dtype): |
| super().__init__() |
| self.in_features = in_features |
| self.norm = nn.LayerNorm(in_features) |
| self.mlp = nn.Sequential( |
| nn.Linear(in_features, out_dim), nn.GELU(), |
| nn.Linear(out_dim, out_dim) |
| ) |
| self.dtype = dtype |
|
|
| def _safe_signed_log(self, x: torch.Tensor) -> torch.Tensor: |
| return torch.sign(x) * torch.log1p(torch.abs(x)) |
|
|
| def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor: |
| num_data = [] |
| for l in links: |
| num_data.append([ |
| l.get('time_gap_on_first_sec', 0), l.get('time_gap_on_second_sec', 0) |
| ]) |
| |
| x = torch.tensor(num_data, device=device, dtype=self.dtype) |
| x_norm = self.norm(self._safe_signed_log(x)) |
| return self.mlp(x_norm) |
|
|
| class _MintedLinkEncoder(nn.Module): |
| """Encodes: buy_amount only (timestamps removed).""" |
| def __init__(self, out_dim: int, dtype: torch.dtype): |
| super().__init__() |
| self.proj = nn.Sequential( |
| nn.Linear(1, out_dim), |
| nn.GELU(), |
| nn.Linear(out_dim, out_dim) |
| ) |
| self.dtype = dtype |
|
|
| def _safe_signed_log(self, x: torch.Tensor) -> torch.Tensor: |
| return torch.sign(x) * torch.log1p(torch.abs(x)) |
|
|
| def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor: |
| nums = torch.tensor([[l.get('buy_amount', 0.0)] for l in links], device=device, dtype=self.dtype) |
|
|
| num_embeds = self._safe_signed_log(nums) |
|
|
| return self.proj(num_embeds) |
|
|
| class _SnipedLinkEncoder(nn.Module): |
| """ Encodes: rank, sniped_amount """ |
| def __init__(self, in_features: int, out_dim: int, dtype: torch.dtype): |
| super().__init__() |
| self.norm = nn.LayerNorm(in_features) |
| self.mlp = nn.Sequential(nn.Linear(in_features, out_dim), nn.GELU(), nn.Linear(out_dim, out_dim)) |
| self.dtype = dtype |
|
|
| def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor: |
| num_data = [[l.get('rank', 0), l.get('sniped_amount', 0)] for l in links] |
| |
| x = torch.tensor(num_data, device=device, dtype=self.dtype) |
|
|
| |
| |
| x[:, 0] = 1.0 / torch.clamp(x[:, 0], min=1.0) |
| x[:, 1] = torch.sign(x[:, 1]) * torch.log1p(torch.abs(x[:, 1])) |
|
|
| x_norm = self.norm(x) |
| return self.mlp(x_norm) |
|
|
| class _LockedSupplyLinkEncoder(nn.Module): |
| """ Encodes: amount """ |
| def __init__(self, out_dim: int, dtype: torch.dtype): |
| super().__init__() |
| self.proj = nn.Sequential( |
| nn.Linear(1, out_dim), |
| nn.GELU(), |
| nn.Linear(out_dim, out_dim) |
| ) |
| self.dtype = dtype |
|
|
| def _safe_signed_log(self, x: torch.Tensor) -> torch.Tensor: |
| return torch.sign(x) * torch.log1p(torch.abs(x)) |
|
|
| def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor: |
| nums = torch.tensor([[l.get('amount', 0.0)] for l in links], device=device, dtype=self.dtype) |
| num_embeds = self._safe_signed_log(nums) |
| return self.proj(num_embeds) |
|
|
| class _BurnedLinkEncoder(nn.Module): |
| """Encodes: burned amount (timestamps removed).""" |
| def __init__(self, out_dim: int, dtype: torch.dtype): |
| super().__init__() |
| self.proj = nn.Sequential( |
| nn.Linear(1, out_dim), |
| nn.GELU(), |
| nn.Linear(out_dim, out_dim) |
| ) |
| self.dtype = dtype |
|
|
| def _safe_signed_log(self, x: torch.Tensor) -> torch.Tensor: |
| return torch.sign(x) * torch.log1p(torch.abs(x)) |
|
|
| def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor: |
| amounts = torch.tensor([[l.get('amount', 0.0)] for l in links], device=device, dtype=self.dtype) |
| amount_embeds = self._safe_signed_log(amounts) |
|
|
| return self.proj(amount_embeds) |
|
|
| class _ProvidedLiquidityLinkEncoder(nn.Module): |
| """Encodes: quote amount (timestamps removed).""" |
| def __init__(self, out_dim: int, dtype: torch.dtype): |
| super().__init__() |
| self.proj = nn.Sequential( |
| nn.Linear(1, out_dim), |
| nn.GELU(), |
| nn.Linear(out_dim, out_dim) |
| ) |
| self.dtype = dtype |
|
|
| def _safe_signed_log(self, x: torch.Tensor) -> torch.Tensor: |
| return torch.sign(x) * torch.log1p(torch.abs(x)) |
|
|
| def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor: |
| quote_amounts = torch.tensor([[l.get('amount_quote', 0.0)] for l in links], device=device, dtype=self.dtype) |
| quote_embeds = self._safe_signed_log(quote_amounts) |
|
|
| return self.proj(quote_embeds) |
|
|
| class _WhaleOfLinkEncoder(nn.Module): |
| """ Encodes: holding_pct_at_creation """ |
| def __init__(self, out_dim: int, dtype: torch.dtype): |
| super().__init__() |
| self.mlp = nn.Sequential( |
| nn.Linear(1, out_dim), |
| nn.GELU(), |
| nn.Linear(out_dim, out_dim) |
| ) |
| self.dtype = dtype |
|
|
| def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor: |
| vals = torch.tensor([[l.get('holding_pct_at_creation', 0.0)] for l in links], device=device, dtype=self.dtype) |
| vals_log = torch.sign(vals) * torch.log1p(torch.abs(vals)) |
| return self.mlp(vals_log) |
|
|
| class _TopTraderOfLinkEncoder(nn.Module): |
| """ Encodes: pnl_at_creation """ |
| def __init__(self, out_dim: int, dtype: torch.dtype): |
| super().__init__() |
| self.mlp = nn.Sequential(nn.Linear(1, out_dim), nn.GELU(), nn.Linear(out_dim, out_dim)) |
| self.dtype = dtype |
|
|
| def forward(self, links: List[Dict[str, Any]], device) -> torch.Tensor: |
| num_data = [[l.get('pnl_at_creation', 0)] for l in links] |
| x = torch.tensor(num_data, device=device, dtype=self.dtype) |
| log_scaled_x = torch.sign(x) * torch.log1p(torch.abs(x)) |
| return self.mlp(log_scaled_x) |
|
|
|
|
class RelationalGATBlock(nn.Module):
    """
    A single GATv2 convolution shared across several relations between one
    (src, dst) node-type pair. Relation awareness is preserved by appending a
    learned per-relation embedding to every edge attribute before the conv.
    """

    def __init__(
        self,
        node_dim: int,
        edge_attr_dim: int,
        n_heads: int,
        relations: List[str],
        dtype: torch.dtype,
    ):
        super().__init__()
        # Stable relation-name -> row index into the relation embedding table.
        self.rel_to_id = {rel: i for i, rel in enumerate(relations)}
        self.edge_attr_dim = edge_attr_dim
        self.rel_emb = nn.Embedding(len(relations), edge_attr_dim)
        # edge_dim doubles because [edge_attr | relation embedding] is fed in.
        self.conv = GATv2Conv(
            in_channels=node_dim,
            out_channels=node_dim,
            heads=n_heads,
            concat=False,
            dropout=0.1,
            add_self_loops=False,
            edge_dim=edge_attr_dim * 2,
        ).to(dtype)

    def forward(
        self,
        x_src: torch.Tensor,
        x_dst: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        rel_type: str,
    ) -> torch.Tensor:
        n_edges = edge_index.size(1)

        # Relations without an attribute encoder fall back to zero attributes.
        if edge_attr is None:
            edge_attr = torch.zeros(
                n_edges,
                self.edge_attr_dim,
                device=edge_index.device,
                dtype=x_src.dtype,
            )

        rel_id = self.rel_to_id.get(rel_type)
        if rel_id is None:
            raise KeyError(f"Relation '{rel_type}' not registered in RelationalGATBlock.")

        # Broadcast this relation's embedding across all edges and append it.
        rel_vec = self.rel_emb.weight[rel_id].to(edge_attr.dtype).expand(n_edges, -1)
        combined = torch.cat([edge_attr, rel_vec], dim=-1)

        return self.conv((x_src, x_dst), edge_index, edge_attr=combined)
| |
| |
| |
|
|
class GraphUpdater(nn.Module):
    """
    FIXED: Manually implements Heterogeneous GNN logic using separate GATv2Conv
    layers for each edge type, bypassing the problematic `to_hetero` wrapper.

    Per forward pass:
      1. Each link type's raw link dicts are encoded into edge-attribute
         tensors by a dedicated per-type encoder.
      2. A reverse edge (relation name ``rev_<rel>``) is synthesized for every
         relation, reusing the forward edge attributes, so messages flow in
         both directions.
      3. `num_layers` rounds of relation-aware GATv2 message passing run, each
         followed by a residual add and a shared LayerNorm.
    """


    # NOTE(review): `time_encoder` is accepted but never stored or used in the
    # code below (the link encoders no longer consume timestamps) — confirm
    # whether the parameter can be dropped from callers.
    def __init__(self,time_encoder: ContextualTimeEncoder, edge_attr_dim: int = 64,
                 n_heads: int = 4, num_layers: int = 2, node_dim: int = 2048, dtype: torch.dtype = torch.float16):
        super().__init__()
        self.node_dim = node_dim
        self.edge_attr_dim = edge_attr_dim
        self.num_layers = num_layers
        self.dtype = dtype


        # One attribute encoder per link type. Keys must match the link names
        # appearing in `edge_data_dict` at forward time; wallet->wallet and
        # wallet->token transfers share the same encoder architecture.
        self.edge_encoders = nn.ModuleDict({
            'TransferLink': _TransferLinkEncoder(edge_attr_dim, dtype=dtype),
            'TransferLinkToken': _TransferLinkEncoder(edge_attr_dim, dtype=dtype),
            'BundleTradeLink': _BundleTradeLinkEncoder(edge_attr_dim, dtype=dtype),
            'CopiedTradeLink': _CopiedTradeLinkEncoder(6, edge_attr_dim, dtype=dtype),
            'CoordinatedActivityLink': _CoordinatedActivityLinkEncoder(2, edge_attr_dim, dtype=dtype),
            'MintedLink': _MintedLinkEncoder(edge_attr_dim, dtype=dtype),
            'SnipedLink': _SnipedLinkEncoder(2, edge_attr_dim, dtype=dtype),
            'LockedSupplyLink': _LockedSupplyLinkEncoder(edge_attr_dim, dtype=dtype),
            'BurnedLink': _BurnedLinkEncoder(edge_attr_dim, dtype=dtype),
            'ProvidedLiquidityLink': _ProvidedLiquidityLinkEncoder(edge_attr_dim, dtype=dtype),
            'WhaleOfLink': _WhaleOfLinkEncoder(edge_attr_dim, dtype=dtype),
            'TopTraderOfLink': _TopTraderOfLinkEncoder(edge_attr_dim, dtype=dtype),
        }).to(dtype)


        # One RelationalGATBlock per (src_type, dst_type) pair per layer; each
        # block is shared by every relation in that pair's group.
        self.edge_groups = self._build_edge_groups()
        self.conv_layers = nn.ModuleList()
        for _ in range(num_layers):
            conv_dict = nn.ModuleDict()
            for (src_type, dst_type), relations in self.edge_groups.items():
                conv_dict[f"{src_type}__{dst_type}"] = RelationalGATBlock(
                    node_dim=node_dim,
                    edge_attr_dim=edge_attr_dim,
                    n_heads=n_heads,
                    relations=relations,
                    dtype=dtype,
                )
            self.conv_layers.append(conv_dict)


        # Single LayerNorm shared by all node types (all embeddings are node_dim wide).
        self.norm = nn.LayerNorm(node_dim)
        self.to(dtype)


        # Log the parameter count once at construction time.
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"[GraphUpdater] Params: {total_params:,} (Trainable: {trainable_params:,})")


    def _build_edge_groups(self) -> Dict[tuple, List[str]]:
        """Group relations by (src_type, dst_type) so conv weights can be shared."""
        groups: Dict[tuple, List[str]] = defaultdict(list)


        wallet_wallet_links = ['TransferLink', 'BundleTradeLink', 'CopiedTradeLink', 'CoordinatedActivityLink']
        wallet_token_links = [
            'TransferLinkToken', 'MintedLink', 'SnipedLink', 'LockedSupplyLink',
            'BurnedLink', 'ProvidedLiquidityLink', 'WhaleOfLink', 'TopTraderOfLink'
        ]


        # wallet<->wallet: forward and reverse relations share the same
        # (wallet, wallet) group and hence the same conv block.
        for link in wallet_wallet_links:
            groups[('wallet', 'wallet')].append(link)
            groups[('wallet', 'wallet')].append(f"rev_{link}")


        # wallet->token relations get their own group; their reverses flow
        # token->wallet and therefore live in a separate group.
        for link in wallet_token_links:
            groups[('wallet', 'token')].append(link)
            groups[('token', 'wallet')].append(f"rev_{link}")


        return groups


    def forward(
        self,
        x_dict: Dict[str, torch.Tensor],
        edge_data_dict: Dict[str, Dict[str, Any]]
    ) -> Dict[str, torch.Tensor]:
        """Run edge encoding + `num_layers` rounds of message passing.

        Args:
            x_dict: node-type -> (num_nodes, node_dim) embedding tensor;
                must contain at least a 'wallet' entry (used for the device).
            edge_data_dict: link-name -> {'edge_index': LongTensor(2, E),
                'links': list of raw link dicts}, one entry per link type.

        Returns:
            Updated node-type -> embedding dict, same keys/shapes as `x_dict`.
        """
        device = x_dict['wallet'].device


        # ---- Stage 1: encode edge attributes and build forward/reverse edges.
        edge_index_dict = {}
        edge_attr_dict = {}


        for link_name, data in edge_data_dict.items():
            edge_index = data.get('edge_index')
            links = data.get('links', [])


            # Skip link types with no edges or no raw link payloads.
            if edge_index is None or edge_index.numel() == 0 or not links:
                continue


            edge_index = edge_index.to(device)


            # Resolve (src_type, rel_type, dst_type) from the shared vocabulary;
            # unknown link names are skipped with a warning rather than raising.
            if link_name not in models.vocabulary.LINK_NAME_TO_TRIPLET:
                print(f"Warning: Link name '{link_name}' not found in vocabulary.LINK_NAME_TO_TRIPLET. Skipping.")
                continue
            src_type, rel_type, dst_type = models.vocabulary.LINK_NAME_TO_TRIPLET[link_name]


            # Encode edge attributes; a missing encoder degrades gracefully to
            # attribute-less edges (RelationalGATBlock zero-fills in that case).
            if link_name not in self.edge_encoders:
                print(f"Warning: No edge encoder found for link type '{link_name}'. Skipping edge attributes.")
                edge_attr = None
            else:
                edge_attr = self.edge_encoders[link_name](links, device).to(self.dtype)




            # Register the forward edge.
            fwd_key = (src_type, rel_type, dst_type)
            edge_index_dict[fwd_key] = edge_index
            if edge_attr is not None:
                edge_attr_dict[fwd_key] = edge_attr


            # Synthesize the reverse edge by swapping the index rows; the
            # reverse relation reuses the forward edge attributes.
            if edge_index.shape[0] == 2:
                rev_edge_index = edge_index[[1, 0]]
                rev_rel_type = f'rev_{rel_type}'
                rev_key = (dst_type, rev_rel_type, src_type)
                edge_index_dict[rev_key] = rev_edge_index
                if edge_attr is not None:
                    
                    edge_attr_dict[rev_key] = edge_attr
            else:
                print(f"Warning: Edge index for {link_name} has unexpected shape {edge_index.shape}. Cannot create reverse edge.")




        # ---- Stage 2: message passing with per-layer residual + LayerNorm.
        x_out = x_dict
        for i in range(self.num_layers):
            # Sum-aggregate messages per destination node type.
            msg_aggregates = {
                node_type: torch.zeros_like(x_node)
                for node_type, x_node in x_out.items()
            }


            
            for edge_type_tuple in edge_index_dict.keys():
                src_type, rel_type, dst_type = edge_type_tuple
                edge_index = edge_index_dict[edge_type_tuple]
                edge_attr = edge_attr_dict.get(edge_type_tuple)


                x_src = x_out.get(src_type)
                x_dst = x_out.get(dst_type)
                if x_src is None or x_dst is None:
                    print(f"Warning: Missing node embeddings for types {src_type}->{dst_type}. Skipping.")
                    continue


                block_key = f"{src_type}__{dst_type}"
                if block_key not in self.conv_layers[i]:
                    print(f"Warning: Relational block for {block_key} not found in layer {i}. Skipping.")
                    continue
                block = self.conv_layers[i][block_key]


                # KeyError here means the relation was never registered with
                # this block (see RelationalGATBlock); skip it, don't crash.
                try:
                    messages = block(x_src, x_dst, edge_index, edge_attr, rel_type)
                except KeyError:
                    print(f"Warning: Relation '{rel_type}' missing in block {block_key}. Skipping.")
                    continue


                # Accumulate messages into the destination type's buffer.
                msg_aggregates[dst_type] += messages


            # Residual update: x <- LayerNorm(x + aggregated messages).
            x_next = {}
            for node_type, x_original in x_out.items():
                
                if node_type in msg_aggregates and msg_aggregates[node_type].shape[0] > 0:
                    aggregated_msgs = msg_aggregates[node_type]
                    
                    if x_original.shape == aggregated_msgs.shape:
                        x_next[node_type] = self.norm(x_original + aggregated_msgs)
                    else:
                        print(f"Warning: Shape mismatch for node type {node_type} during update. Original: {x_original.shape}, Aggregated: {aggregated_msgs.shape}. Skipping residual connection.")
                        x_next[node_type] = x_original
                else:
                    x_next[node_type] = x_original


            x_out = x_next


        return x_out
|
|