l45k commited on
Commit
3f33131
·
verified ·
1 Parent(s): 889a0e6

Upload model

Browse files
Files changed (2) hide show
  1. model.safetensors +1 -1
  2. modeling_lenet.py +6 -6
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:1afa15c915cfba02cf65f9fda685e845a333113fdf807050238d1d8136c8a7ec
3
  size 247728
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:997619d28e6a8e5a1c7ceaab745e1593fa761e00bd59afb6086c3aa480975bf6
3
  size 247728
modeling_lenet.py CHANGED
@@ -27,8 +27,8 @@ class LeNet(torch.nn.Module):
27
  )
28
  )
29
 
30
- def forward(self, tensor) -> torch.Tensor:
31
- return self.model(tensor)
32
 
33
 
34
  class LeNetModel(PreTrainedModel):
@@ -38,8 +38,8 @@ class LeNetModel(PreTrainedModel):
38
  super().__init__(config)
39
  self.model = LeNet()
40
 
41
- def forward(self, tensor):
42
- return self.model.forward_features(tensor)
43
 
44
 
45
  class LeNetModelForImageClassification(PreTrainedModel):
@@ -49,8 +49,8 @@ class LeNetModelForImageClassification(PreTrainedModel):
49
  super().__init__(config)
50
  self.model = LeNet()
51
 
52
- def forward(self, tensor, labels=None):
53
- logits = self.model(tensor)
54
  if labels is not None:
55
  loss = torch.nn.functional.cross_entropy(logits, labels)
56
  return {"loss": loss, "logits": logits}
 
27
  )
28
  )
29
 
30
+ def forward(self, pixel_values) -> torch.Tensor:
31
+ return self.model(pixel_values)
32
 
33
 
34
  class LeNetModel(PreTrainedModel):
 
38
  super().__init__(config)
39
  self.model = LeNet()
40
 
41
+ def forward(self, pixel_values):
42
+ return self.model.forward_features(pixel_values)
43
 
44
 
45
  class LeNetModelForImageClassification(PreTrainedModel):
 
49
  super().__init__(config)
50
  self.model = LeNet()
51
 
52
+ def forward(self, pixel_values, labels=None):
53
+ logits = self.model(pixel_values)
54
  if labels is not None:
55
  loss = torch.nn.functional.cross_entropy(logits, labels)
56
  return {"loss": loss, "logits": logits}