# rrf-v2-36node-manifold / modeling_rrf.py
# Uploaded by antonypamo with huggingface_hub (commit 0251fac, verified).
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from configuration_rrf import RRFConfig
class RRFModel(PreTrainedModel):
    """Reconstruct node signals through a stored square basis matrix.

    The model owns one non-trainable ``eigenvectors`` buffer of shape
    ``(config.num_nodes, config.num_nodes)``.  The buffer is created as all
    zeros, so ``forward`` returns zeros until real values are written into it
    (presumably by loading a checkpoint -- confirm against the loading code).
    """

    config_class = RRFConfig
    base_model_prefix = "rrf"

    def __init__(self, config: RRFConfig):
        """Create the zero-filled ``eigenvectors`` buffer from ``config``.

        Args:
            config: Model configuration; only ``num_nodes`` is read here.
        """
        super().__init__(config)

        # A buffer (rather than an nn.Parameter) is saved with the model
        # state but is not treated as a trainable weight.
        n = config.num_nodes
        self.register_buffer("eigenvectors", torch.zeros(n, n))

        # Standard transformers post-construction hook.
        self.post_init()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` onto the stored basis and map it back.

        Args:
            x: Input whose last dimension is multiplied against the
                ``(num_nodes, num_nodes)`` buffer; the original comment
                describes it as ``[batch, num_nodes]``.

        Returns:
            ``(x @ eigenvectors) @ eigenvectors.T``, with the same shape
            as ``x``.
        """
        # NOTE(review): if the loaded basis is a complete orthonormal set,
        # this product is the identity map -- confirm that is intended.
        basis = self.eigenvectors
        spectral = torch.matmul(x, basis)  # coefficients in the basis
        return torch.matmul(spectral, basis.t())  # back to node space