import typing

import torch
import torch.nn as nn
import transformers

from .config import MLPConfig


class Backbone(nn.Module):
    """Plain feed-forward network: input -> hidden -> hidden -> output.

    Layer widths come from ``config`` (``input_dim``, ``hidden_dim``,
    ``output_dim``); a ReLU follows each hidden linear layer.
    """

    def __init__(self, config: MLPConfig) -> None:
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(config.input_dim, config.hidden_dim),
            nn.ReLU(),
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.ReLU(),
            nn.Linear(config.hidden_dim, config.output_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to *x* (last dimension must equal ``input_dim``)."""
        return self.model(x)


class MLP(transformers.PreTrainedModel):
    """HF-compatible model wrapping :class:`Backbone`."""

    config_class = MLPConfig
    base_model_prefix = 'mlp'

    def __init__(self, config: MLPConfig) -> None:
        super().__init__(config)
        # NOTE: PreTrainedModel.__init__ already stores `config` as
        # ``self.config``; reassigning it here was redundant and removed.
        self.backbone = Backbone(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone; without this, calling the model raised
        ``NotImplementedError`` from ``nn.Module.forward``."""
        return self.backbone(x)