Fix model code
Browse files- modeling_blaser.py +12 -3
modeling_blaser.py
CHANGED
|
@@ -104,12 +104,14 @@ class BlaserCore(nn.Module):
|
|
| 104 |
|
| 105 |
|
| 106 |
# ---------------- HF MODEL WRAPPER ---------------- #
|
|
|
|
| 107 |
class BlaserModel(PreTrainedModel):
|
| 108 |
config_class = BlaserConfig
|
| 109 |
|
| 110 |
def __init__(self, config: BlaserConfig):
|
| 111 |
super().__init__(config)
|
| 112 |
-
self.core
|
|
|
|
| 113 |
embedding_dim=config.embedding_dim,
|
| 114 |
output_dim=config.output_dim,
|
| 115 |
hidden_dims=config.hidden_dims,
|
|
@@ -118,7 +120,14 @@ class BlaserModel(PreTrainedModel):
|
|
| 118 |
input_form=config.input_form,
|
| 119 |
norm_emb=config.norm_emb,
|
| 120 |
output_act=config.output_act,
|
| 121 |
-
)
|
| 122 |
|
| 123 |
def forward(self, src, mt, ref=None):
|
| 124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
|
| 106 |
# ---------------- HF MODEL WRAPPER ---------------- #
|
| 107 |
+
|
| 108 |
class BlaserModel(PreTrainedModel):
|
| 109 |
config_class = BlaserConfig
|
| 110 |
|
| 111 |
def __init__(self, config: BlaserConfig):
|
| 112 |
super().__init__(config)
|
| 113 |
+
# Instead of self.core, assign directly
|
| 114 |
+
self.mlp = BlaserCore(
|
| 115 |
embedding_dim=config.embedding_dim,
|
| 116 |
output_dim=config.output_dim,
|
| 117 |
hidden_dims=config.hidden_dims,
|
|
|
|
| 120 |
input_form=config.input_form,
|
| 121 |
norm_emb=config.norm_emb,
|
| 122 |
output_act=config.output_act,
|
| 123 |
+
).mlp # only take the Sequential MLP
|
| 124 |
|
| 125 |
def forward(self, src, mt, ref=None):
|
| 126 |
+
# The old checkpoint expects the input feature processing inside BlaserCore
|
| 127 |
+
proc = BlaserCore._featurize(
|
| 128 |
+
self.mlp, # pass self as `self` for static call
|
| 129 |
+
src=BlaserCore._norm(self.mlp, src),
|
| 130 |
+
mt=BlaserCore._norm(self.mlp, mt),
|
| 131 |
+
ref=BlaserCore._norm(self.mlp, ref)
|
| 132 |
+
)
|
| 133 |
+
return self.mlp(proc)
|