import torch
from torch import nn
from transformers import PreTrainedModel, PretrainedConfig

from .configuration_modnet import MODNetConfig
from .modnet import MODNet


class HF_MODNet(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around a ``MODNet`` network.

    Exposes the MODNet module through the ``transformers`` model API, using
    ``MODNetConfig`` as its configuration class.
    """

    # Tells transformers which config class belongs to this model.
    config_class = MODNetConfig

    def __init__(self, config):
        """Construct the wrapper and its inner MODNet.

        Args:
            config: Configuration object handed to ``PreTrainedModel``'s
                constructor. No config fields are read directly here.
        """
        super().__init__(config)
        # The backbone is created without its pretrained weights. Presumably
        # the complete model weights are loaded afterwards (e.g. through
        # ``from_pretrained``) — TODO confirm.
        self.modnet = MODNet(backbone_pretrained=False)

    def forward(self, x, inference=True):
        """Run the wrapped MODNet.

        Args:
            x: Input forwarded unchanged to ``MODNet``. Its expected shape is
                not visible here — presumably an image batch tensor; confirm.
            inference: Passed positionally as MODNet's second argument.

        Returns:
            Exactly what the inner ``MODNet`` call returns.
        """
        network = self.modnet
        return network(x, inference)