| import torch | |
| from torch import nn | |
| from transformers import PreTrainedModel, PretrainedConfig | |
| from .configuration_modnet import MODNetConfig | |
| from .modnet import MODNet | |
class HF_MODNet(PreTrainedModel):
    """Expose ``MODNet`` through the Hugging Face ``PreTrainedModel`` interface.

    The wrapped network is stored as ``self.modnet`` and every forward call is
    delegated to it unchanged.
    """

    # Configuration class associated with this model.
    config_class = MODNetConfig

    def __init__(self, config):
        """Build the wrapper and its inner ``MODNet`` network.

        Args:
            config: Configuration object handed to ``PreTrainedModel.__init__``
                (presumably a ``MODNetConfig``, per ``config_class`` --
                TODO confirm).
        """
        super().__init__(config)
        # The backbone is created with ``backbone_pretrained=False``
        # (presumably the weights arrive with the checkpoint -- TODO confirm).
        network = MODNet(backbone_pretrained=False)
        self.modnet = network

    def forward(self, x, inference=True):
        """Run the wrapped ``MODNet`` on ``x``.

        Args:
            x: Input forwarded as-is to ``MODNet`` (assumed to be an image
                batch tensor -- TODO confirm against ``MODNet``).
            inference: Flag passed positionally as the second argument of
                ``MODNet``; defaults to ``True``.

        Returns:
            Whatever ``MODNet`` returns for these arguments.
        """
        outputs = self.modnet(x, inference)
        return outputs