""" Heterogeneous Graph Transformer (HGT) for note-level message passing. Uses PyG's HGTConv to enrich note embeddings with structural context from the score graph. Only note-to-note relationships are used (onset, consecutive, during, rest and their reverses). The graph structure comes from graphmuse's create_score_graph with add_beats=False. Reference: "Heterogeneous Graph Transformer" (Hu et al., WWW 2020) """ from typing import Dict, List, Optional, Tuple import torch import torch.nn as nn from torch import Tensor from torch_geometric.nn import HGTConv from torch_geometric.utils import to_dense_batch # Node types in the score graph (note-only) NODE_TYPES = ['note'] # Note-to-note edge types (from graphmuse with add_reverse=True) # These are the exact edge types stored in the HeteroData graph NOTE_EDGE_TYPES = [ ('note', 'onset', 'note'), ('note', 'consecutive', 'note'), ('note', 'consecutive_rev', 'note'), ('note', 'during', 'note'), ('note', 'during_rev', 'note'), ('note', 'rest', 'note'), ('note', 'rest_rev', 'note'), ] class NoteHGT(nn.Module): """ Multi-layer HGT for enriching note embeddings with graph structure. Uses PyG's HGTConv for heterogeneous message passing across note-to-note musical relationships (onset, consecutive, during, rest). The graph structure comes from graphmuse's create_score_graph with: add_beats=False -> only note nodes add_reverse=True -> bidirectional edges Args: note_dim: Dimension of note embeddings (from feature embedder) hidden_dim: Hidden dimension for HGT (if None, uses note_dim) num_layers: Number of HGT layers num_heads: Number of attention heads per layer dropout: Dropout rate """ def __init__( self, note_dim: int, hidden_dim: Optional[int] = None, num_layers: int = 2, num_heads: int = 4, dropout: float = 0.1, ): super().__init__() self.note_dim = note_dim self.hidden_dim = hidden_dim or note_dim self.num_layers = num_layers # Project note features to hidden dim if different if note_dim != self.hidden_dim: self.note_proj = nn.Linear(note_dim, self.hidden_dim) else: self.note_proj = nn.Identity() # Project back to note_dim at the end if note_dim != self.hidden_dim: self.note_out_proj = nn.Linear(self.hidden_dim, note_dim) else: self.note_out_proj = nn.Identity() # Metadata for HGTConv self.metadata = (NODE_TYPES, NOTE_EDGE_TYPES) # Dropout layer (HGTConv doesn't have built-in dropout) self.dropout = nn.Dropout(dropout) # Stack of HGT convolution layers self.convs = nn.ModuleList() for _ in range(num_layers): conv = HGTConv( in_channels=self.hidden_dim, out_channels=self.hidden_dim, metadata=self.metadata, heads=num_heads, ) self.convs.append(conv) # Layer norms after each layer self.norms = nn.ModuleList() for _ in range(num_layers): self.norms.append(nn.LayerNorm(self.hidden_dim)) @staticmethod def extract_edge_dict(graph) -> Dict[Tuple[str, str, str], Tensor]: """ Extract edge indices from a HeteroData graph. Since we use graphmuse with add_beats=False, the graph only contains note nodes and note-to-note edges. This method simply filters to ensure we only use the expected edge types. Args: graph: PyG HeteroData with edge_index for each edge type Returns: Dict mapping edge_type tuple -> edge_index (2, E) """ edge_dict = {} for edge_type in NOTE_EDGE_TYPES: if edge_type in graph.edge_types: edge_index = graph[edge_type].edge_index if edge_index.numel() > 0: edge_dict[edge_type] = edge_index return edge_dict def forward( self, note_features: Tensor, edge_dict: Dict[Tuple[str, str, str], Tensor], ) -> Tensor: """ Apply HGT message passing to enrich note embeddings. Args: note_features: Note embeddings (N_notes, note_dim) edge_dict: Dict mapping edge_type_tuple -> edge_index (2, E) Returns: Updated note embeddings (N_notes, note_dim) """ device = note_features.device # Build node feature dict (note only) x_dict = { 'note': self.note_proj(note_features), # (N_notes, hidden_dim) } # Move edges to device edge_index_dict = { et: ei.to(device) for et, ei in edge_dict.items() if ei.numel() > 0 } # Apply HGT layers for i, conv in enumerate(self.convs): x_dict_new = conv(x_dict, edge_index_dict) # Apply dropout, layer norm and residual connection if 'note' in x_dict_new and x_dict_new['note'] is not None: x_dict['note'] = self.norms[i]( x_dict['note'] + self.dropout(x_dict_new['note']) ) # Return note features, projected back to original dim return self.note_out_proj(x_dict['note']) def forward_batch( self, note_features: Tensor, edge_dicts: List[Dict[Tuple[str, str, str], Tensor]], num_notes_list: List[int], mask: Optional[Tensor] = None, ) -> Tensor: """ Apply HGT to a batch of graphs using sparse batching for efficiency. Pipeline: 1. Dense (B, N_max, D) -> Sparse (total_notes, D) with batch vector 2. Concatenate edges with proper node offsets 3. Run HGT once on the batched sparse graph 4. Sparse (total_notes, D) -> Dense (B, N_max, D) via to_dense_batch Args: note_features: Batched note embeddings (B, N_max, note_dim) edge_dicts: List of edge_dict per sample (from extract_edge_dict) num_notes_list: Number of valid notes per sample mask: Optional (B, N_max) validity mask (unused, num_notes_list used instead) Returns: Updated note embeddings (B, N_max, note_dim) """ B, N_max, D = note_features.shape device = note_features.device # 1. Dense -> Sparse: flatten valid notes with batch vector note_list = [] note_batch = [] for b in range(B): n = num_notes_list[b] note_list.append(note_features[b, :n]) # (n, D) note_batch.append(torch.full((n,), b, dtype=torch.long, device=device)) notes_sparse = torch.cat(note_list, dim=0) # (total_notes, D) notes_sparse = self.note_proj(notes_sparse) # (total_notes, hidden_dim) note_batch = torch.cat(note_batch, dim=0) # (total_notes,) # 2. Compute node offsets and concatenate edges offsets = [0] for n in num_notes_list[:-1]: offsets.append(offsets[-1] + n) # Gather edges per type with offset correction edge_lists = {et: [] for et in NOTE_EDGE_TYPES} for b, edge_dict in enumerate(edge_dicts): offset = offsets[b] for edge_type, ei in edge_dict.items(): if ei.numel() > 0: edge_lists[edge_type].append(ei.to(device) + offset) # Concatenate edges per type (only non-empty) final_edge_dict = { et: torch.cat(eis, dim=1) for et, eis in edge_lists.items() if eis } # 3. Run HGT layers x_dict = {'note': notes_sparse} for i, conv in enumerate(self.convs): x_dict_new = conv(x_dict, final_edge_dict) if 'note' in x_dict_new and x_dict_new['note'] is not None: x_dict['note'] = self.norms[i]( x_dict['note'] + self.dropout(x_dict_new['note']) ) # 4. Sparse -> Dense: project back and pad to (B, N_max, D) note_out = self.note_out_proj(x_dict['note']) # (total_notes, note_dim) out_dense, _ = to_dense_batch(note_out, note_batch, max_num_nodes=N_max, batch_size=B) return out_dense def forward_graphs( self, note_features: Tensor, graphs: List, num_notes_list: List[int], ) -> Tensor: """ Convenience method that takes HeteroData graphs directly. This extracts edge_dicts from the graphs and calls forward_batch. Use this when working with the collated batch from ScoreGraphMultiFeatureDataset. Args: note_features: Batched note embeddings (B, N_max, note_dim) graphs: List of HeteroData graphs from the dataset num_notes_list: Number of valid notes per sample Returns: Updated note embeddings (B, N_max, note_dim) """ edge_dicts = [self.extract_edge_dict(g) for g in graphs] return self.forward_batch(note_features, edge_dicts, num_notes_list)