# score-ae/src/model/note_hgt.py
# (Uploaded by hroth, "Upload 90 files", commit b57c46e, 9.52 kB)
"""
Heterogeneous Graph Transformer (HGT) for note-level message passing.
Uses PyG's HGTConv to enrich note embeddings with structural context from
the score graph. Only note-to-note relationships are used (onset, consecutive,
during, rest and their reverses). The graph structure comes from graphmuse's
create_score_graph with add_beats=False.
Reference: "Heterogeneous Graph Transformer" (Hu et al., WWW 2020)
"""
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from torch import Tensor
from torch_geometric.nn import HGTConv
from torch_geometric.utils import to_dense_batch
# Node types in the score graph (note-only: add_beats=False)
NODE_TYPES = ['note']

# Note-to-note relation names produced by graphmuse; the `_rev` variants come
# from add_reverse=True.
_NOTE_RELATIONS = [
    'onset',
    'consecutive',
    'consecutive_rev',
    'during',
    'during_rev',
    'rest',
    'rest_rev',
]

# Exact (src, relation, dst) edge-type triples stored in the HeteroData graph.
NOTE_EDGE_TYPES = [('note', relation, 'note') for relation in _NOTE_RELATIONS]
class NoteHGT(nn.Module):
    """
    Multi-layer HGT for enriching note embeddings with graph structure.

    Uses PyG's HGTConv for heterogeneous message passing across
    note-to-note musical relationships (onset, consecutive, during, rest).

    The graph structure comes from graphmuse's create_score_graph with:
        add_beats=False  -> only note nodes
        add_reverse=True -> bidirectional edges

    Args:
        note_dim: Dimension of note embeddings (from feature embedder)
        hidden_dim: Hidden dimension for HGT (if None, uses note_dim)
        num_layers: Number of HGT layers
        num_heads: Number of attention heads per layer
        dropout: Dropout rate applied to each layer's output
    """

    def __init__(
        self,
        note_dim: int,
        hidden_dim: Optional[int] = None,
        num_layers: int = 2,
        num_heads: int = 4,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.note_dim = note_dim
        self.hidden_dim = hidden_dim or note_dim
        self.num_layers = num_layers

        # Project into / out of the hidden dim only when it differs from
        # note_dim; otherwise both projections are identity.
        if note_dim != self.hidden_dim:
            self.note_proj = nn.Linear(note_dim, self.hidden_dim)
            self.note_out_proj = nn.Linear(self.hidden_dim, note_dim)
        else:
            self.note_proj = nn.Identity()
            self.note_out_proj = nn.Identity()

        # (node_types, edge_types) metadata required by HGTConv.
        self.metadata = (NODE_TYPES, NOTE_EDGE_TYPES)

        # HGTConv has no built-in dropout; we apply it to each layer's output
        # before the residual connection.
        self.dropout = nn.Dropout(dropout)

        # Stack of HGT convolution layers with a LayerNorm per layer.
        self.convs = nn.ModuleList(
            HGTConv(
                in_channels=self.hidden_dim,
                out_channels=self.hidden_dim,
                metadata=self.metadata,
                heads=num_heads,
            )
            for _ in range(num_layers)
        )
        self.norms = nn.ModuleList(
            nn.LayerNorm(self.hidden_dim) for _ in range(num_layers)
        )

    @staticmethod
    def extract_edge_dict(graph) -> Dict[Tuple[str, str, str], Tensor]:
        """
        Extract note-to-note edge indices from a HeteroData graph.

        Since we use graphmuse with add_beats=False, the graph only contains
        note nodes and note-to-note edges. This method filters to the expected
        edge types and drops empty ones.

        Args:
            graph: PyG HeteroData with edge_index for each edge type

        Returns:
            Dict mapping edge_type tuple -> edge_index (2, E); only non-empty
            edge types listed in NOTE_EDGE_TYPES are included.
        """
        edge_dict = {}
        for edge_type in NOTE_EDGE_TYPES:
            if edge_type in graph.edge_types:
                edge_index = graph[edge_type].edge_index
                if edge_index.numel() > 0:
                    edge_dict[edge_type] = edge_index
        return edge_dict

    def forward(
        self,
        note_features: Tensor,
        edge_dict: Dict[Tuple[str, str, str], Tensor],
    ) -> Tensor:
        """
        Apply HGT message passing to enrich note embeddings (single graph).

        Args:
            note_features: Note embeddings (N_notes, note_dim)
            edge_dict: Dict mapping edge_type tuple -> edge_index (2, E)

        Returns:
            Updated note embeddings (N_notes, note_dim)
        """
        device = note_features.device
        # Build node feature dict (note only).
        x_dict = {
            'note': self.note_proj(note_features),  # (N_notes, hidden_dim)
        }
        # Move edges to the features' device, dropping empty edge tensors.
        edge_index_dict = {
            et: ei.to(device) for et, ei in edge_dict.items() if ei.numel() > 0
        }
        # Each layer: HGT conv -> dropout -> residual -> LayerNorm.
        for i, conv in enumerate(self.convs):
            x_dict_new = conv(x_dict, edge_index_dict)
            # HGTConv may omit a node type when it receives no messages.
            if 'note' in x_dict_new and x_dict_new['note'] is not None:
                x_dict['note'] = self.norms[i](
                    x_dict['note'] + self.dropout(x_dict_new['note'])
                )
        # Project back to the original note dimension.
        return self.note_out_proj(x_dict['note'])

    def forward_batch(
        self,
        note_features: Tensor,
        edge_dicts: List[Dict[Tuple[str, str, str], Tensor]],
        num_notes_list: List[int],
        mask: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Apply HGT to a batch of graphs using sparse batching for efficiency.

        Pipeline:
            1. Dense (B, N_max, D) -> Sparse (total_notes, D) with batch vector
            2. Concatenate edges with proper node offsets
            3. Run HGT once on the batched sparse graph
            4. Sparse (total_notes, D) -> Dense (B, N_max, D) via to_dense_batch

        Args:
            note_features: Batched note embeddings (B, N_max, note_dim)
            edge_dicts: List of edge_dict per sample (from extract_edge_dict)
            num_notes_list: Number of valid notes per sample
            mask: Optional (B, N_max) validity mask (unused, num_notes_list
                is used instead)

        Returns:
            Updated note embeddings (B, N_max, note_dim)
        """
        B, N_max, D = note_features.shape
        device = note_features.device

        # 1. Dense -> Sparse: flatten valid notes and record which sample
        # each note belongs to (needed by to_dense_batch at the end).
        note_list = []
        note_batch = []
        for b in range(B):
            n = num_notes_list[b]
            note_list.append(note_features[b, :n])  # (n, D)
            note_batch.append(torch.full((n,), b, dtype=torch.long, device=device))
        notes_sparse = torch.cat(note_list, dim=0)   # (total_notes, D)
        notes_sparse = self.note_proj(notes_sparse)  # (total_notes, hidden_dim)
        note_batch = torch.cat(note_batch, dim=0)    # (total_notes,)

        # 2. Per-sample node offsets into the flattened note tensor.
        offsets = [0]
        for n in num_notes_list[:-1]:
            offsets.append(offsets[-1] + n)

        # Gather edges per type with offset correction. Edge types outside
        # NOTE_EDGE_TYPES (e.g. beat edges from a graph built with
        # add_beats=True) are skipped instead of raising KeyError, matching
        # the filtering done by extract_edge_dict.
        edge_lists = {et: [] for et in NOTE_EDGE_TYPES}
        for b, edge_dict in enumerate(edge_dicts):
            offset = offsets[b]
            for edge_type, ei in edge_dict.items():
                if edge_type in edge_lists and ei.numel() > 0:
                    edge_lists[edge_type].append(ei.to(device) + offset)

        # Concatenate edges per type (only non-empty types are kept).
        final_edge_dict = {
            et: torch.cat(eis, dim=1)
            for et, eis in edge_lists.items() if eis
        }

        # 3. Run HGT layers (conv -> dropout -> residual -> LayerNorm).
        x_dict = {'note': notes_sparse}
        for i, conv in enumerate(self.convs):
            x_dict_new = conv(x_dict, final_edge_dict)
            if 'note' in x_dict_new and x_dict_new['note'] is not None:
                x_dict['note'] = self.norms[i](
                    x_dict['note'] + self.dropout(x_dict_new['note'])
                )

        # 4. Sparse -> Dense: project back and pad to (B, N_max, note_dim).
        note_out = self.note_out_proj(x_dict['note'])  # (total_notes, note_dim)
        out_dense, _ = to_dense_batch(note_out, note_batch, max_num_nodes=N_max, batch_size=B)
        return out_dense

    def forward_graphs(
        self,
        note_features: Tensor,
        graphs: List,
        num_notes_list: List[int],
    ) -> Tensor:
        """
        Convenience method that takes HeteroData graphs directly.

        This extracts edge_dicts from the graphs and calls forward_batch.
        Use this when working with the collated batch from
        ScoreGraphMultiFeatureDataset.

        Args:
            note_features: Batched note embeddings (B, N_max, note_dim)
            graphs: List of HeteroData graphs from the dataset
            num_notes_list: Number of valid notes per sample

        Returns:
            Updated note embeddings (B, N_max, note_dim)
        """
        edge_dicts = [self.extract_edge_dict(g) for g in graphs]
        return self.forward_batch(note_features, edge_dicts, num_notes_list)