Upload model
Browse files- model.safetensors +1 -1
- modeling_lenet.py +6 -6
model.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 247728
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:997619d28e6a8e5a1c7ceaab745e1593fa761e00bd59afb6086c3aa480975bf6
|
| 3 |
size 247728
|
modeling_lenet.py
CHANGED
|
@@ -27,8 +27,8 @@ class LeNet(torch.nn.Module):
|
|
| 27 |
)
|
| 28 |
)
|
| 29 |
|
| 30 |
-
def forward(self,
|
| 31 |
-
return self.model(
|
| 32 |
|
| 33 |
|
| 34 |
class LeNetModel(PreTrainedModel):
|
|
@@ -38,8 +38,8 @@ class LeNetModel(PreTrainedModel):
|
|
| 38 |
super().__init__(config)
|
| 39 |
self.model = LeNet()
|
| 40 |
|
| 41 |
-
def forward(self,
|
| 42 |
-
return self.model.forward_features(
|
| 43 |
|
| 44 |
|
| 45 |
class LeNetModelForImageClassification(PreTrainedModel):
|
|
@@ -49,8 +49,8 @@ class LeNetModelForImageClassification(PreTrainedModel):
|
|
| 49 |
super().__init__(config)
|
| 50 |
self.model = LeNet()
|
| 51 |
|
| 52 |
-
def forward(self,
|
| 53 |
-
logits = self.model(
|
| 54 |
if labels is not None:
|
| 55 |
loss = torch.nn.functional.cross_entropy(logits, labels)
|
| 56 |
return {"loss": loss, "logits": logits}
|
|
|
|
| 27 |
)
|
| 28 |
)
|
| 29 |
|
| 30 |
+
def forward(self, pixel_values) -> torch.Tensor:
|
| 31 |
+
return self.model(pixel_values)
|
| 32 |
|
| 33 |
|
| 34 |
class LeNetModel(PreTrainedModel):
|
|
|
|
| 38 |
super().__init__(config)
|
| 39 |
self.model = LeNet()
|
| 40 |
|
| 41 |
+
def forward(self, pixel_values):
|
| 42 |
+
return self.model.forward_features(pixel_values)
|
| 43 |
|
| 44 |
|
| 45 |
class LeNetModelForImageClassification(PreTrainedModel):
|
|
|
|
| 49 |
super().__init__(config)
|
| 50 |
self.model = LeNet()
|
| 51 |
|
| 52 |
+
def forward(self, pixel_values, labels=None):
|
| 53 |
+
logits = self.model(pixel_values)
|
| 54 |
if labels is not None:
|
| 55 |
loss = torch.nn.functional.cross_entropy(logits, labels)
|
| 56 |
return {"loss": loss, "logits": logits}
|