# MNIST_Classifier / model.py
# Source: Hugging Face Hub repository by hikmatfarhat, commit 233e6e6 ("new version").
###import os,sys
###sys.path.insert(1,os.path.join(sys.path[0],".."))
from network import Net
from config import Config
from transformers import PreTrainedModel
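# NOTE: `utils` is not used here, but importing it forces the upload to the
# Hugging Face Hub to include it. (No `utils` import appears in this file as
# shown — TODO confirm whether that import was dropped by mistake.)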
class MNIST_Classifier(PreTrainedModel):
    """``transformers.PreTrainedModel`` wrapper around the project's ``Net``.

    Subclassing ``PreTrainedModel`` lets this network be used through the
    ``transformers`` model API. The layer sizes of the wrapped ``Net`` are
    read from a ``Config`` instance.
    """

    # Configuration class that ``transformers`` associates with this model.
    config_class = Config

    def __init__(self, config):
        """Build the wrapped ``Net`` from the sizes stored on ``config``.

        Args:
            config: a ``Config`` instance exposing ``input_size``,
                ``hidden_size1``, ``hidden_size2`` and ``output_size``.
        """
        super().__init__(config)
        # Passed positionally, so this order must match Net's constructor.
        layer_sizes = (
            config.input_size,
            config.hidden_size1,
            config.hidden_size2,
            config.output_size,
        )
        self.classifier = Net(*layer_sizes)

    def forward(self, input):
        """Feed ``input`` to the wrapped ``Net`` and return its result unchanged.

        The parameter keeps the name ``input`` (shadowing the builtin) so that
        existing keyword callers such as ``model(input=x)`` keep working.
        """
        result = self.classifier(input)
        return result