Spaces:
Sleeping
Sleeping
| """ | |
| Heterogeneous Graph Transformer (HGT) for note-level message passing. | |
| Uses PyG's HGTConv to enrich note embeddings with structural context from | |
| the score graph. Only note-to-note relationships are used (onset, consecutive, | |
| during, rest and their reverses). The graph structure comes from graphmuse's | |
| create_score_graph with add_beats=False. | |
| Reference: "Heterogeneous Graph Transformer" (Hu et al., WWW 2020) | |
| """ | |
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
from torch import Tensor
from torch_geometric.nn import HGTConv
from torch_geometric.utils import to_dense_batch
# Node types in the score graph (note-only)
NODE_TYPES = ['note']

# Note-to-note edge types (from graphmuse with add_reverse=True).
# These are the exact edge types stored in the HeteroData graph.
NOTE_EDGE_TYPES = [
    ('note', 'onset', 'note'),
    ('note', 'consecutive', 'note'),
    ('note', 'consecutive_rev', 'note'),
    ('note', 'during', 'note'),
    ('note', 'during_rev', 'note'),
    ('note', 'rest', 'note'),
    ('note', 'rest_rev', 'note'),
]


class NoteHGT(nn.Module):
    """
    Multi-layer HGT for enriching note embeddings with graph structure.

    Uses PyG's HGTConv for heterogeneous message passing across
    note-to-note musical relationships (onset, consecutive, during, rest).

    The graph structure comes from graphmuse's create_score_graph with:
        add_beats=False -> only note nodes
        add_reverse=True -> bidirectional edges

    Args:
        note_dim: Dimension of note embeddings (from feature embedder)
        hidden_dim: Hidden dimension for HGT (if None, uses note_dim)
        num_layers: Number of HGT layers
        num_heads: Number of attention heads per layer
        dropout: Dropout rate applied to each layer's output before the
            residual connection
    """

    def __init__(
        self,
        note_dim: int,
        hidden_dim: Optional[int] = None,
        num_layers: int = 2,
        num_heads: int = 4,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.note_dim = note_dim
        self.hidden_dim = hidden_dim or note_dim
        self.num_layers = num_layers

        # Project into/out of the hidden dim only when it differs from
        # note_dim; otherwise both projections are identity no-ops.
        if note_dim != self.hidden_dim:
            self.note_proj = nn.Linear(note_dim, self.hidden_dim)
            self.note_out_proj = nn.Linear(self.hidden_dim, note_dim)
        else:
            self.note_proj = nn.Identity()
            self.note_out_proj = nn.Identity()

        # Metadata for HGTConv: (node_types, edge_types)
        self.metadata = (NODE_TYPES, NOTE_EDGE_TYPES)

        # HGTConv has no built-in dropout, so it is applied externally.
        self.dropout = nn.Dropout(dropout)

        # Stack of HGT convolution layers, each followed by a LayerNorm.
        self.convs = nn.ModuleList(
            HGTConv(
                in_channels=self.hidden_dim,
                out_channels=self.hidden_dim,
                metadata=self.metadata,
                heads=num_heads,
            )
            for _ in range(num_layers)
        )
        self.norms = nn.ModuleList(
            nn.LayerNorm(self.hidden_dim) for _ in range(num_layers)
        )

    @staticmethod
    def extract_edge_dict(graph) -> Dict[Tuple[str, str, str], Tensor]:
        """
        Extract note-to-note edge indices from a HeteroData graph.

        Since we use graphmuse with add_beats=False, the graph only contains
        note nodes and note-to-note edges. This method simply filters to
        ensure we only use the expected edge types; empty edge sets are
        dropped as well.

        NOTE(fix): declared as @staticmethod. The previous definition had
        neither `self` nor @staticmethod, so the instance-bound call
        `self.extract_edge_dict(g)` in forward_graphs raised a TypeError
        (self was passed as `graph` plus the real graph as an extra arg).
        Class-level calls `NoteHGT.extract_edge_dict(g)` keep working.

        Args:
            graph: PyG HeteroData with edge_index for each edge type

        Returns:
            Dict mapping edge_type tuple -> edge_index (2, E)
        """
        edge_dict = {}
        for edge_type in NOTE_EDGE_TYPES:
            if edge_type in graph.edge_types:
                edge_index = graph[edge_type].edge_index
                if edge_index.numel() > 0:
                    edge_dict[edge_type] = edge_index
        return edge_dict

    def forward(
        self,
        note_features: Tensor,
        edge_dict: Dict[Tuple[str, str, str], Tensor],
    ) -> Tensor:
        """
        Apply HGT message passing to enrich note embeddings.

        Args:
            note_features: Note embeddings (N_notes, note_dim)
            edge_dict: Dict mapping edge_type_tuple -> edge_index (2, E)

        Returns:
            Updated note embeddings (N_notes, note_dim)
        """
        device = note_features.device

        # Node feature dict — only 'note' nodes exist in this graph.
        x_dict = {
            'note': self.note_proj(note_features),  # (N_notes, hidden_dim)
        }

        # Move edges to the feature device, dropping empty edge sets.
        edge_index_dict = {
            et: ei.to(device) for et, ei in edge_dict.items() if ei.numel() > 0
        }

        # HGT layers: conv, then dropout + residual + LayerNorm.
        for conv, norm in zip(self.convs, self.norms):
            x_dict_new = conv(x_dict, edge_index_dict)
            # HGTConv may omit / return None for node types that received
            # no messages; keep the previous features in that case.
            if x_dict_new.get('note') is not None:
                x_dict['note'] = norm(
                    x_dict['note'] + self.dropout(x_dict_new['note'])
                )

        # Project back to the original note dimension.
        return self.note_out_proj(x_dict['note'])

    def forward_batch(
        self,
        note_features: Tensor,
        edge_dicts: List[Dict[Tuple[str, str, str], Tensor]],
        num_notes_list: List[int],
        mask: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Apply HGT to a batch of graphs using sparse batching for efficiency.

        Pipeline:
            1. Dense (B, N_max, D) -> Sparse (total_notes, D) with batch vector
            2. Concatenate edges with proper node offsets
            3. Run HGT once on the batched sparse graph
            4. Sparse (total_notes, D) -> Dense (B, N_max, D) via to_dense_batch

        Args:
            note_features: Batched note embeddings (B, N_max, note_dim)
            edge_dicts: List of edge_dict per sample (from extract_edge_dict)
            num_notes_list: Number of valid notes per sample
            mask: Optional (B, N_max) validity mask (unused; num_notes_list
                is authoritative)

        Returns:
            Updated note embeddings (B, N_max, note_dim)
        """
        B, N_max, _ = note_features.shape
        device = note_features.device

        # 1. Dense -> Sparse: flatten valid notes and record sample indices.
        note_list = []
        note_batch = []
        for b in range(B):
            n = num_notes_list[b]
            note_list.append(note_features[b, :n])  # (n, D)
            note_batch.append(torch.full((n,), b, dtype=torch.long, device=device))
        notes_sparse = self.note_proj(torch.cat(note_list, dim=0))  # (total, hidden)
        note_batch = torch.cat(note_batch, dim=0)  # (total_notes,)

        # 2. Per-sample node offsets (cumulative valid-note counts).
        offsets = [0]
        for n in num_notes_list[:-1]:
            offsets.append(offsets[-1] + n)

        # Gather edges per type, shifting node ids by each sample's offset.
        # Edge types outside NOTE_EDGE_TYPES are skipped instead of raising
        # a KeyError (extract_edge_dict already filters, but callers may
        # pass hand-built dicts).
        edge_lists = {et: [] for et in NOTE_EDGE_TYPES}
        for b, edge_dict in enumerate(edge_dicts):
            offset = offsets[b]
            for edge_type, ei in edge_dict.items():
                if edge_type in edge_lists and ei.numel() > 0:
                    edge_lists[edge_type].append(ei.to(device) + offset)

        # Concatenate edges per type (keep only non-empty types).
        final_edge_dict = {
            et: torch.cat(eis, dim=1)
            for et, eis in edge_lists.items() if eis
        }

        # 3. Run HGT layers on the single batched sparse graph.
        x_dict = {'note': notes_sparse}
        for conv, norm in zip(self.convs, self.norms):
            x_dict_new = conv(x_dict, final_edge_dict)
            if x_dict_new.get('note') is not None:
                x_dict['note'] = norm(
                    x_dict['note'] + self.dropout(x_dict_new['note'])
                )

        # 4. Sparse -> Dense: project back and zero-pad to (B, N_max, note_dim).
        note_out = self.note_out_proj(x_dict['note'])  # (total_notes, note_dim)
        out_dense, _ = to_dense_batch(
            note_out, note_batch, max_num_nodes=N_max, batch_size=B
        )
        return out_dense

    def forward_graphs(
        self,
        note_features: Tensor,
        graphs: List,
        num_notes_list: List[int],
    ) -> Tensor:
        """
        Convenience method that takes HeteroData graphs directly.

        This extracts edge_dicts from the graphs and calls forward_batch.
        Use this when working with the collated batch from
        ScoreGraphMultiFeatureDataset.

        Args:
            note_features: Batched note embeddings (B, N_max, note_dim)
            graphs: List of HeteroData graphs from the dataset
            num_notes_list: Number of valid notes per sample

        Returns:
            Updated note embeddings (B, N_max, note_dim)
        """
        edge_dicts = [self.extract_edge_dict(g) for g in graphs]
        return self.forward_batch(note_features, edge_dicts, num_notes_list)