Update embedding_model.py
Browse files- embedding_model.py +4 -5
embedding_model.py
CHANGED
|
@@ -10,9 +10,8 @@ MODEL_PATH_IN_MODULE = "0_Model"
|
|
| 10 |
|
| 11 |
|
| 12 |
class EmbeddingModel(models.Transformer):
|
| 13 |
-
def __init__(self,
|
| 14 |
-
|
| 15 |
-
super().__init__(model_name_or_path, *args, **kwargs)
|
| 16 |
|
| 17 |
def tokenize(
|
| 18 |
self,
|
|
@@ -72,7 +71,7 @@ class EmbeddingModel(models.Transformer):
|
|
| 72 |
json.dump(self.get_config_dict(), f, indent=4)
|
| 73 |
|
| 74 |
@staticmethod
|
| 75 |
-
def load(load_dir: str,
|
| 76 |
with open(os.path.join(load_dir, "config.json")) as fIn:
|
| 77 |
config = json.load(fIn)
|
| 78 |
-
return EmbeddingModel(**config,
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
class EmbeddingModel(models.Transformer):
|
| 13 |
+
def __init__(self, *args, **kwargs):
|
| 14 |
+
super().__init__(*args, **kwargs)
|
|
|
|
| 15 |
|
| 16 |
def tokenize(
|
| 17 |
self,
|
|
|
|
| 71 |
json.dump(self.get_config_dict(), f, indent=4)
|
| 72 |
|
| 73 |
@staticmethod
|
| 74 |
+
def load(load_dir: str, **kwargs) -> "EmbeddingModel":
|
| 75 |
with open(os.path.join(load_dir, "config.json")) as fIn:
|
| 76 |
config = json.load(fIn)
|
| 77 |
+
return EmbeddingModel(**config, **kwargs)
|