import torch
import torch.nn as nn
from transformers import PreTrainedModel

from configuration_rrf import RRFConfig


class RRFModel(PreTrainedModel):
    """Reconstruct node signals by projecting them onto a stored eigenvector basis.

    The model has no trainable parameters; it holds a single ``eigenvectors``
    buffer of shape ``(config.num_nodes, config.num_nodes)``. The columns of
    that buffer are treated as basis vectors: ``forward`` computes the
    coefficients of the input in that basis and maps them back to node space.
    """

    config_class = RRFConfig
    base_model_prefix = "rrf"

    def __init__(self, config):
        super().__init__(config)
        # Registering as a buffer ensures it is saved in (and loaded from) the
        # state dict but is not treated as a trainable parameter.
        # NOTE: it starts as all zeros, so `forward` returns zeros until real
        # eigenvectors are loaded from a checkpoint or assigned.
        self.register_buffer(
            "eigenvectors", torch.zeros((config.num_nodes, config.num_nodes))
        )
        # Standard transformers initialization
        self.post_init()

    def forward(self, x):
        """Project ``x`` onto the eigenvector basis and reconstruct it.

        Args:
            x: Tensor whose last dimension is ``num_nodes`` (intended use is
                ``[batch, num_nodes]``; extra leading dimensions broadcast
                through ``torch.matmul``). Any dtype is accepted and is cast
                to the buffer's dtype before the computation.

        Returns:
            Tensor with the same shape as ``x``, equal to ``x @ V @ V.T`` where
            ``V`` is the ``eigenvectors`` buffer. This reproduces ``x`` only
            when the columns of ``V`` form a complete orthonormal basis.
        """
        basis = self.eigenvectors
        # torch.matmul does not promote mixed dtypes, so e.g. a float64/int
        # input against a float32 buffer (or a float32 input against a buffer
        # loaded in fp16) would raise. Casting is a no-op when dtypes match.
        x = x.to(dtype=basis.dtype)
        # Spectral reconstruction: coefficients in the basis, then back to nodes.
        coeffs = x @ basis
        reconstruction = coeffs @ basis.T
        return reconstruction