Update README.md
Browse files
README.md
CHANGED
|
@@ -34,6 +34,7 @@ import torch
|
|
| 34 |
import torch.nn as nn
|
| 35 |
import os
|
| 36 |
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
|
|
|
|
| 37 |
|
| 38 |
class CustomConvNeXtConfig(PretrainedConfig):
|
| 39 |
model_type = "custom-convnext"
|
|
@@ -85,7 +86,6 @@ class CustomConvNeXtModel(PreTrainedModel):
|
|
| 85 |
model = cls(config=config, model_name=model_name, num_classes=config.num_labels)
|
| 86 |
|
| 87 |
# Load state_dict from safetensors file
|
| 88 |
-
from safetensors.torch import load_file # Safetensors library
|
| 89 |
state_dict = load_file(model_path)
|
| 90 |
model.load_state_dict(state_dict)
|
| 91 |
|
|
|
|
| 34 |
import torch.nn as nn
|
| 35 |
import os
|
| 36 |
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
|
| 37 |
+
from safetensors.torch import load_file
|
| 38 |
|
| 39 |
class CustomConvNeXtConfig(PretrainedConfig):
|
| 40 |
model_type = "custom-convnext"
|
|
|
|
| 86 |
model = cls(config=config, model_name=model_name, num_classes=config.num_labels)
|
| 87 |
|
| 88 |
# Load state_dict from safetensors file
|
|
|
|
| 89 |
state_dict = load_file(model_path)
|
| 90 |
model.load_state_dict(state_dict)
|
| 91 |
|