| import torch | |
| import torch.nn as nn | |
| from transformers import PreTrainedModel | |
| from configuration_rrf import RRFConfig | |
class RRFModel(PreTrainedModel):
    """Reconstruct inputs through a fixed, non-trainable eigenvector basis.

    The basis lives in the ``eigenvectors`` buffer, a square
    ``(config.num_nodes, config.num_nodes)`` tensor that is zero-filled at
    construction. Nothing in this class writes real values into it. Presumably
    it is populated from a saved checkpoint or by external assignment (TODO
    confirm). Until then, ``forward`` returns all zeros.
    """

    config_class = RRFConfig
    base_model_prefix = "rrf"

    def __init__(self, config):
        super().__init__(config)
        n_nodes = config.num_nodes
        # A buffer, unlike an nn.Parameter, is part of the state_dict (so it is
        # saved and loaded with the model) but is not exposed to optimizers.
        zero_basis = torch.zeros((n_nodes, n_nodes))
        self.register_buffer("eigenvectors", zero_basis)
        # transformers' post-construction hook; must run after all submodules
        # and buffers have been registered.
        self.post_init()

    def forward(self, x):
        """Map ``x`` into the eigenvector basis and back again.

        Args:
            x: Input tensor, described by the original author as
                ``[batch, num_nodes]``. Its last dimension must match
                ``num_nodes`` for the matrix products to be valid.

        Returns:
            ``(x @ V) @ V.T``, where ``V`` is the ``eigenvectors`` buffer. For a
            ``[batch, num_nodes]`` input the result has the same shape. If the
            columns of ``V`` are orthonormal, this is the orthogonal projection
            of ``x`` onto their span.
        """
        basis = self.eigenvectors
        # Spectral coefficients: one weight per basis vector (column of V).
        spectral_coeffs = x @ basis
        # Synthesize the signal back from its spectral coefficients.
        return spectral_coeffs @ basis.T