File size: 457 Bytes
d8953e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4bdad8d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch
from torch import nn
from transformers import PreTrainedModel, PretrainedConfig
from .configuration_modnet import MODNetConfig

from .modnet import MODNet


class HF_MODNet(PreTrainedModel):
    """Expose a ``MODNet`` network through the ``PreTrainedModel`` interface.

    The wrapped network is stored on ``self.modnet``; ``forward`` simply
    delegates to it.
    """

    # Configuration class associated with this model.
    config_class = MODNetConfig

    def __init__(self, config):
        """Build the wrapper and its underlying ``MODNet`` instance.

        Args:
            config: Configuration object forwarded unchanged to
                ``PreTrainedModel.__init__``.
        """
        super().__init__(config)
        # The network is created with ``backbone_pretrained=False``.
        # NOTE(review): presumably weights are expected to come from a saved
        # checkpoint rather than a pretrained backbone — confirm.
        network = MODNet(backbone_pretrained=False)
        self.modnet = network

    def forward(self, x, inference=True):
        """Run the wrapped ``MODNet`` on ``x``.

        Args:
            x: Input passed through to ``self.modnet`` as its first
                positional argument.
            inference: Flag passed through to ``self.modnet`` as its second
                positional argument. Defaults to ``True``.

        Returns:
            Whatever ``self.modnet`` returns for the given arguments.
        """
        output = self.modnet(x, inference)
        return output