SkywalkerLu committed
Commit 3330148 · verified · 1 Parent(s): 276662d

Update modeling_transhla2.py

Files changed (1): modeling_transhla2.py (+19 −4)
modeling_transhla2.py CHANGED
@@ -27,6 +27,7 @@ class TransHLA2Config(PretrainedConfig):
         lora_inference_mode=False,
         target_modules=None,
         return_prob=True,  # whether forward returns probabilities (softmax); otherwise logits
+        pad_token_id=1,  # ESM default pad id
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -50,6 +51,7 @@ class TransHLA2Config(PretrainedConfig):
         self.target_modules = target_modules or ['query', 'out_proj', 'value', 'key', 'dense', 'regression']
 
         self.return_prob = return_prob
+        self.pad_token_id = pad_token_id
 
 
 class TransHLA2(PreTrainedModel):
@@ -165,11 +167,24 @@ class TransHLA2(PreTrainedModel):
         x = px + x
         return x
 
+    def _ensure_mapping_input(self, x):
+        # Two input forms are accepted:
+        # 1) a dict: {"input_ids": ..., "attention_mask": ...}
+        # 2) a raw input_ids tensor of shape (B, L)
+        if isinstance(x, torch.Tensor):
+            # Use input_ids only; to build the attention_mask automatically, uncomment:
+            # pad_id = self.config.pad_token_id
+            # return {"input_ids": x, "attention_mask": (x != pad_id).long()}
+            return {"input_ids": x}
+        elif isinstance(x, dict):
+            return x
+        else:
+            raise TypeError(f"Unsupported input type: {type(x)}; expected Tensor or dict.")
+
     def forward(self, epitope_in, hla_in, return_dict=None):
-        # epitope_in, hla_in: ESM inputs, either dicts or tensors (typically input_ids/attention_mask)
-        # here the inputs are assumed to be standard ESM input dicts, e.g.:
-        # epitope_in = {"input_ids": ..., "attention_mask": ...}
-        # hla_in = {"input_ids": ..., "attention_mask": ...}
+        # accept either tensor or dict inputs
+        epitope_in = self._ensure_mapping_input(epitope_in)
+        hla_in = self._ensure_mapping_input(hla_in)
 
         epitope_outputs = self.epitope_lora(**epitope_in)
         hla_outputs = self.hla_lora(**hla_in)
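
For illustration, a standalone sketch of the input-wrapping behaviour this commit adds, with the attention_mask construction that the diff leaves commented out enabled. The free function ensure_mapping_input is a stand-in for the new _ensure_mapping_input method, and pad_token_id=1 mirrors the new config default; this snippet is not part of the commit.

import torch

def ensure_mapping_input(x, pad_token_id=1):
    # Wrap a bare (B, L) input_ids tensor into an ESM-style dict; the mask
    # marks non-pad positions. Dict inputs are passed through unchanged.
    if isinstance(x, torch.Tensor):
        return {"input_ids": x, "attention_mask": (x != pad_token_id).long()}
    elif isinstance(x, dict):
        return x
    raise TypeError(f"Unsupported input type: {type(x)}; expected Tensor or dict.")

# Example: a batch of two sequences, the first right-padded with the ESM pad id (1).
ids = torch.tensor([[0, 5, 6, 7, 2, 1, 1],
                    [0, 8, 9, 10, 11, 12, 2]])
print(ensure_mapping_input(ids)["attention_mask"])
# tensor([[1, 1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1, 1, 1]])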