Commit
·
16b4531
1
Parent(s):
0ac03c0
Fix __init__ to call original CrossEncoder initialization
Browse filesThe previous implementation was completely overriding CrossEncoder's
__init__ method, preventing proper base class initialization. Now we
call the original __init__ first before loading our custom model.
- modeling_zeranker.py +7 -0
modeling_zeranker.py
CHANGED
|
@@ -108,7 +108,14 @@ def load_model(
|
|
| 108 |
return tokenizer, model
|
| 109 |
|
| 110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
|
|
|
|
|
|
|
|
|
| 112 |
# Load the model immediately on instantiation
|
| 113 |
self.inner_tokenizer, self.inner_model = load_model(global_device)
|
| 114 |
self.inner_model.eval()
|
|
|
|
| 108 |
return tokenizer, model
|
| 109 |
|
| 110 |
|
| 111 |
+
# Store the original __init__ method
|
| 112 |
+
_original_init = _CE.__init__
|
| 113 |
+
|
| 114 |
+
|
| 115 |
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
| 116 |
+
# Call the original CrossEncoder __init__ first
|
| 117 |
+
_original_init(self, *args, **kwargs)
|
| 118 |
+
|
| 119 |
# Load the model immediately on instantiation
|
| 120 |
self.inner_tokenizer, self.inner_model = load_model(global_device)
|
| 121 |
self.inner_model.eval()
|