Update modeling_transhla2.py
modeling_transhla2.py  CHANGED  (+19 -4)
@@ -27,6 +27,7 @@ class TransHLA2Config(PretrainedConfig):
         lora_inference_mode=False,
         target_modules=None,
         return_prob=True,  # whether forward returns probabilities (softmax); otherwise returns logits
+        pad_token_id=1,  # ESM default pad id
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -50,6 +51,7 @@ class TransHLA2Config(PretrainedConfig):
         self.target_modules = target_modules or ['query', 'out_proj', 'value', 'key', 'dense', 'regression']
 
         self.return_prob = return_prob
+        self.pad_token_id = pad_token_id
 
 
 class TransHLA2(PreTrainedModel):
@@ -165,11 +167,24 @@ class TransHLA2(PreTrainedModel):
         x = px + x
         return x
 
+    def _ensure_mapping_input(self, x):
+        # Two input forms are accepted:
+        # 1) a dict: {"input_ids": ..., "attention_mask": ...}
+        # 2) a raw input_ids tensor of shape (B, L)
+        if isinstance(x, torch.Tensor):
+            # Uses input_ids only; to build an attention_mask automatically, uncomment:
+            # pad_id = self.config.pad_token_id
+            # return {"input_ids": x, "attention_mask": (x != pad_id).long()}
+            return {"input_ids": x}
+        elif isinstance(x, dict):
+            return x
+        else:
+            raise TypeError(f"Unsupported input type: {type(x)}; expected Tensor or dict.")
+
     def forward(self, epitope_in, hla_in, return_dict=None):
-        #
-
-
-        # hla_in = {"input_ids": ..., "attention_mask": ...}
+        # Accept either a tensor or a dict for each input
+        epitope_in = self._ensure_mapping_input(epitope_in)
+        hla_in = self._ensure_mapping_input(hla_in)
 
         epitope_outputs = self.epitope_lora(**epitope_in)
         hla_outputs = self.hla_lora(**hla_in)
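A minimal usage sketch of the two accepted input forms, assuming the checkpoint is loaded with trust_remote_code; the repo id, tokenizer checkpoint, and sequences below are placeholders, not names taken from this repository. Note that _ensure_mapping_input checks isinstance(x, dict), and the BatchEncoding returned by a Hugging Face tokenizer subclasses UserDict rather than dict, so the sketch converts it to a plain dict first:

import torch
from transformers import AutoModel, AutoTokenizer

# Placeholder checkpoint ids; substitute the real ones.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
model = AutoModel.from_pretrained("<org>/TransHLA2", trust_remote_code=True)
model.eval()

# Convert BatchEncoding -> plain dict so the isinstance(x, dict) branch is taken.
epitope = dict(tokenizer("KILGFVFTV", return_tensors="pt"))  # example 9-mer peptide
hla = dict(tokenizer("YFAMYGEKVAHTHVDTLYVRYHYYTWAVLAYTWY", return_tensors="pt"))  # example HLA pseudo-sequence

with torch.no_grad():
    out = model(epitope, hla)                            # form 1: dicts
    out = model(epitope["input_ids"], hla["input_ids"])  # form 2: raw input_ids tensors

With return_prob=True (the default) the forward pass returns probabilities after softmax; set return_prob=False in the config to get raw logits instead.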