File size: 567 Bytes
233e6e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
###import os,sys
###sys.path.insert(1,os.path.join(sys.path[0],".."))
from network import Net
from config import Config
from transformers import PreTrainedModel
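# `utils` is not used directly, but importing it forces the upload to the Hugging Face Hub to include it.
# NOTE(review): no `import utils` statement is present in this file — restore it if the Hub upload still needs utils.py.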


class MNIST_Classifier(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the project's ``Net``.

    Declaring ``config_class`` lets the ``transformers`` loading machinery
    (e.g. ``from_pretrained``) know which configuration class pairs with this
    model. All layer sizes are read from that config object.
    """

    config_class = Config

    def __init__(self, config):
        """Build the wrapped ``Net`` from the sizes stored on *config*.

        Args:
            config: A ``Config`` instance. It must expose ``input_size``,
                ``hidden_size1``, ``hidden_size2`` and ``output_size``.
        """
        super().__init__(config)
        # Sizes are passed to Net positionally, in this exact order.
        layer_sizes = (
            config.input_size,
            config.hidden_size1,
            config.hidden_size2,
            config.output_size,
        )
        self.classifier = Net(*layer_sizes)

    def forward(self, input):
        """Run *input* through the wrapped ``Net`` and return its output unchanged.

        NOTE(review): the expected shape/dtype of *input* is defined by ``Net``,
        which is not visible here. Presumably it is a batch of flattened MNIST
        images of width ``config.input_size``; confirm against ``network.py``.
        """
        # The parameter name ``input`` shadows the builtin, but it is kept
        # because callers may pass it by keyword.
        net = self.classifier
        return net(input)