SupremoUGH commited on
Commit
bd8b79c
·
unverified ·
1 Parent(s): 2beb768

more changes

Browse files
.gitignore CHANGED
@@ -7,4 +7,5 @@ model_weights.pth
7
  # Python metadata
8
  venv
9
  __pycache__
10
- symmetric_test.egg-info
 
 
7
  # Python metadata
8
  venv
9
  __pycache__
10
+ symmetric_test.egg-info
11
+ .ruff_cache
config.json CHANGED
@@ -1,9 +1,14 @@
1
  {
2
  "architectures": ["DigitClassifier"],
3
- "model_type": "pytorch",
4
  "num_labels": 10,
5
  "id2label": {
6
- "0": "0", "1": "1", "2": "2", "3": "3", "4": "4", "5": "5", "6": "6", "7": "7", "8": "8", "9": "9"
 
7
  },
8
- "preprocessor": "symmetric_test.preprocessor.get_preprocessor"
 
 
 
 
9
  }
 
1
  {
2
  "architectures": ["DigitClassifier"],
3
+ "model_type": "custom",
4
  "num_labels": 10,
5
  "id2label": {
6
+ "0": "0", "1": "1", "2": "2", "3": "3", "4": "4",
7
+ "5": "5", "6": "6", "7": "7", "8": "8", "9": "9"
8
  },
9
+ "preprocessor": "symmetric_test.preprocessor.get_transform",
10
+ "auto_map": {
11
+ "AutoModel": "symmetric_test.model.DigitClassifier",
12
+ "AutoConfig": "symmetric_test.model.DigitClassifierConfig"
13
+ }
14
  }
model_weights.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:7dc1c5d0edd67dcbd3746a7d79567c4825cf261944f7be55ed55d1386d3b7339
3
- size 903624
 
 
 
 
symmetric_test/model.py CHANGED
@@ -1,9 +1,20 @@
 
1
  import torch.nn as nn
 
 
 
 
 
 
 
 
 
2
 
3
 
4
  class DigitClassifier(nn.Module):
5
- def __init__(self):
6
  super().__init__()
 
7
  self.conv_block = nn.Sequential(
8
  nn.Conv2d(1, 32, 3),
9
  nn.ReLU(),
@@ -20,3 +31,12 @@ class DigitClassifier(nn.Module):
20
  x = self.conv_block(x)
21
  x = x.view(x.size(0), -1)
22
  return self.classifier(x)
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
  import torch.nn as nn
3
+ from transformers import PretrainedConfig
4
+
5
+
6
+ class DigitClassifierConfig(PretrainedConfig):
7
+ model_type = "custom"
8
+
9
+ def __init__(self, **kwargs):
10
+ super().__init__(**kwargs)
11
+ self.num_labels = 10
12
 
13
 
14
  class DigitClassifier(nn.Module):
15
+ def __init__(self, config):
16
  super().__init__()
17
+ self.config = config
18
  self.conv_block = nn.Sequential(
19
  nn.Conv2d(1, 32, 3),
20
  nn.ReLU(),
 
31
  x = self.conv_block(x)
32
  x = x.view(x.size(0), -1)
33
  return self.classifier(x)
34
+
35
+ @classmethod
36
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
37
+ config = DigitClassifierConfig()
38
+ model = cls(config)
39
+ model.load_state_dict(
40
+ torch.load(f"{pretrained_model_name_or_path}/pytorch_model.bin")
41
+ )
42
+ return model
symmetric_test/preprocessor.py CHANGED
@@ -1,12 +1,9 @@
1
  from torchvision import transforms
2
 
3
-
4
- def get_preprocessor():
5
- return transforms.Compose(
6
- [
7
- transforms.Resize((28, 28)),
8
- transforms.Grayscale(num_output_channels=1),
9
- transforms.ToTensor(),
10
- transforms.Normalize((0.1307,), (0.3081,)),
11
- ]
12
- )
 
1
  from torchvision import transforms
2
 
3
+ def get_transform():
4
+ return transforms.Compose([
5
+ transforms.Resize((28, 28)),
6
+ transforms.Grayscale(num_output_channels=1),
7
+ transforms.ToTensor(),
8
+ transforms.Normalize((0.1307,), (0.3081,))
9
+ ])
 
 
 
symmetric_test/train.py CHANGED
@@ -2,7 +2,7 @@ import torch
2
  import torch.nn as nn
3
  from torch import optim
4
  from torchvision import datasets, transforms
5
- from huggingface_hub import HfApi, Repository
6
  from .model import DigitClassifier
7
 
8
  # Config (better to put in separate config.yaml)
 
2
  import torch.nn as nn
3
  from torch import optim
4
  from torchvision import datasets, transforms
5
+ from huggingface_hub import HfApi
6
  from .model import DigitClassifier
7
 
8
  # Config (better to put in separate config.yaml)