Commit ·
e81feeb
1
Parent(s): 55a56ca
Delete modeling_customefficientnetv2.py
Browse files
modeling_customefficientnetv2.py
DELETED
|
@@ -1,27 +0,0 @@
|
|
| 1 |
-
from transformers import PretrainedModel
|
| 2 |
-
from configuration_customefficientnetv2 import CustomEfficientNetV2Config
|
| 3 |
-
import torch
|
| 4 |
-
|
| 5 |
-
class CustomEfficientNetV2(PretrainedModel):
|
| 6 |
-
config_class = CustomEfficientNetV2Config
|
| 7 |
-
|
| 8 |
-
def __init__(self, config):
|
| 9 |
-
super().__init__(config)
|
| 10 |
-
|
| 11 |
-
self.url = config.url
|
| 12 |
-
file_name = self.url.split('/')[-1]
|
| 13 |
-
self.model = torch.load(file_name)
|
| 14 |
-
|
| 15 |
-
self.input_size = config.input_size
|
| 16 |
-
shape = [2] + self.input_size
|
| 17 |
-
example_inputs = torch.randn(shape)
|
| 18 |
-
example_inputs = (example_inputs - example_inputs.min()) / (example_inputs.max() - example_inputs.min())
|
| 19 |
-
|
| 20 |
-
self.num_classes = config.num_classes
|
| 21 |
-
if self.num_classes != 1000:
|
| 22 |
-
self.model.classifier = torch.nn.Linear(in_features=1984, out_features=self.num_classes, bias=True)
|
| 23 |
-
|
| 24 |
-
traced_model = torch.jit.trace(self.model, example_inputs)
|
| 25 |
-
traced_model.save(file_name)
|
| 26 |
-
|
| 27 |
-
self.model = torch.jit.load(file_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|