| import torch | |
| import torch.nn as nn | |
| from transformers import PreTrainedModel | |
| from transformers import AutoConfig, AutoModel, PretrainedConfig | |
class SimpleLinearConfig(PretrainedConfig):
    """Configuration for a model made of a single linear projection.

    Args:
        input_dim: Stored as ``self.input_dim``; defaults to 768.
        output_dim: Stored as ``self.output_dim``; defaults to 512.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "simple_linear_model"
    # NOTE(review): `_no_split_modules` is normally declared on the model
    # class rather than on the config -- confirm anything actually reads it
    # from here before relying on it.
    _no_split_modules = ["linear"]

    def __init__(self, input_dim: int = 768, output_dim: int = 512, **kwargs) -> None:
        # The base initializer consumes the generic keyword arguments first;
        # the layer sizes are recorded on the instance afterwards.
        super().__init__(**kwargs)
        self.input_dim, self.output_dim = input_dim, output_dim
class SimpleLinearModel(PreTrainedModel):
    """A single ``nn.Linear`` projection packaged as a ``PreTrainedModel``.

    The layer maps ``config.input_dim`` features to ``config.output_dim``
    features. Weight initialization is supplied through the per-module
    ``_init_weights`` hook instead of overriding ``init_weights()``, so the
    initialization routine inherited from ``PreTrainedModel`` (triggered by
    ``post_init()``) stays in charge of when and whether it runs.
    """

    config_class = SimpleLinearConfig
    _no_split_modules = []

    def __init__(self, config: SimpleLinearConfig):
        """Build the linear layer from ``config`` and initialize its weights.

        Args:
            config: Provides ``input_dim`` and ``output_dim`` for the layer.
        """
        super().__init__(config)
        self.linear = nn.Linear(config.input_dim, config.output_dim)
        # post_init() runs the inherited init_weights(), which is expected to
        # apply _init_weights() to each submodule.
        self.post_init()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear projection to ``x`` and return the result."""
        return self.linear(x)

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize one submodule: Xavier-uniform weight, zero bias.

        Only ``nn.Linear`` modules are touched. Dispatching on the module
        type (rather than on parameter-name substrings) avoids calling
        ``xavier_uniform_`` on a 1-D parameter, which it cannot handle.

        Args:
            module: The submodule handed in by the base-class initializer.
        """
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
# Make the custom classes discoverable through the Auto* factories:
# first map the model-type string to the config class ...
AutoConfig.register(SimpleLinearConfig.model_type, SimpleLinearConfig)
# ... then map the config class to the model class.
AutoModel.register(SimpleLinearConfig, SimpleLinearModel)