import torch
import torch.nn as nn

from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel


class SimpleLinearConfig(PretrainedConfig):
    model_type = "simple_linear_model"

    def __init__(self, input_dim=768, output_dim=512, **kwargs):
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.output_dim = output_dim


class SimpleLinearModel(PreTrainedModel):
    config_class = SimpleLinearConfig
    # Module class names that must never be split across devices when sharding
    # with a device_map; empty here because the model is a single Linear layer.
    _no_split_modules = []

    def __init__(self, config: SimpleLinearConfig):
        super().__init__(config)
        self.linear = nn.Linear(config.input_dim, config.output_dim)
        # post_init() runs the standard PreTrainedModel setup, including the
        # weight-initialization flow that applies _init_weights to each sub-module.
        self.post_init()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

    def _init_weights(self, module):
        # Standard weight init: Xavier-uniform weights, zero biases.
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)


# Register our config class with AutoConfig under its model_type key
AutoConfig.register("simple_linear_model", SimpleLinearConfig)

# Register our model class with AutoModel against that config class
AutoModel.register(SimpleLinearConfig, SimpleLinearModel)
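
# --- Minimal usage sketch (illustrative addition, not part of the model definition above).
# It assumes the registration above has already run in the same process and uses a
# hypothetical local directory "./simple_linear_demo" for the save/load round trip.
if __name__ == "__main__":
    # Build the model through the Auto* machinery now that the config is registered.
    config = SimpleLinearConfig(input_dim=768, output_dim=512)
    model = AutoModel.from_config(config)

    # Plain forward pass: the model is just a single Linear projection.
    x = torch.randn(2, 768)
    print(model(x).shape)  # torch.Size([2, 512])

    # Save and reload via the standard Hugging Face round trip; from_pretrained
    # resolves SimpleLinearModel through the AutoConfig/AutoModel registration.
    model.save_pretrained("./simple_linear_demo")
    reloaded = AutoModel.from_pretrained("./simple_linear_demo")
    print(type(reloaded).__name__)  # SimpleLinearModel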