LucaGroup committed
Commit 6c59832 · verified · 1 Parent(s): 28e914a

Update weights and modeling code to latest version

Files changed (2):
  1. modeling_lucaone.py +6 -2
  2. tokenization_lucaone.py +21 -3
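
Both files look like custom remote-code modules, so the updated code in this commit only takes effect when the repository is reloaded with remote code enabled. A minimal loading sketch; the repo id below is a placeholder, not the actual Hub path:

```python
from transformers import AutoModel, AutoTokenizer

# Placeholder repo id; substitute the actual LucaOne repository path on the Hub.
repo_id = "LucaGroup/LucaOne"

# trust_remote_code=True is required so that the updated modeling_lucaone.py
# and tokenization_lucaone.py from this commit are downloaded and executed.
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
```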
modeling_lucaone.py CHANGED

@@ -1141,8 +1141,10 @@ class LucaGPLMForMaskedLM(LucaGPLMPreTrainedModel):
 
 class LucaGPLMForSequenceClassification(LucaGPLMPreTrainedModel):
     def __init__(self, config):
+        if hasattr(config, "classifier_num_labels") and config.classifier_num_labels > 0:
+            config.num_labels = config.classifier_num_labels
         super().__init__(config)
-        self.num_labels = config.classifier_num_labels
+        self.num_labels = config.num_labels
         self.task_level = config.task_level
         self.task_type = config.task_type
         assert self.task_level == "seq_level"

@@ -1247,8 +1249,10 @@ class LucaGPLMForSequenceClassification(LucaGPLMPreTrainedModel):
 
 class LucaGPLMForTokenClassification(LucaGPLMPreTrainedModel):
     def __init__(self, config):
+        if hasattr(config, "classifier_num_labels") and config.classifier_num_labels > 0:
+            config.num_labels = config.classifier_num_labels
         super().__init__(config)
-        self.num_labels = config.classifier_num_labels
+        self.num_labels = config.num_labels
         self.task_level = config.task_level
         self.task_type = config.task_type
         assert self.task_level == "token_level"
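
The fix in both classification heads routes the label count through the standard `config.num_labels` attribute while still honoring the repo-specific `classifier_num_labels` override when it is set to a positive value; the override is applied before `super().__init__(config)`, presumably so the value is already in place when the parent constructor runs. A minimal runnable sketch of that resolution logic, with a hypothetical `DummyConfig` standing in for the real LucaOne config class:

```python
# Minimal sketch of the num_labels resolution introduced above.
# DummyConfig is a hypothetical stand-in for the real LucaOne config.
class DummyConfig:
    def __init__(self, num_labels=2, classifier_num_labels=0):
        self.num_labels = num_labels
        self.classifier_num_labels = classifier_num_labels

def resolve_num_labels(config):
    # Prefer the repo-specific override when it is a positive value,
    # otherwise fall back to the standard num_labels attribute.
    if hasattr(config, "classifier_num_labels") and config.classifier_num_labels > 0:
        config.num_labels = config.classifier_num_labels
    return config.num_labels

print(resolve_num_labels(DummyConfig()))                          # 2 (fallback)
print(resolve_num_labels(DummyConfig(classifier_num_labels=10)))  # 10 (override wins)
```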
tokenization_lucaone.py CHANGED

@@ -133,7 +133,7 @@ class LucaGPLMTokenizer(PreTrainedTokenizer):
         self.cls_idx = self.tok_to_idx.get("[CLS]", 2)
         self.mask_idx = self.tok_to_idx.get("[MASK]", 4)
         self.eos_idx = self.tok_to_idx.get("[SEP]", 3)
-
+
         super().__init__(
             unk_token=unk_token,
             pad_token=pad_token,

@@ -295,7 +295,23 @@ class LucaGPLMTokenizer(PreTrainedTokenizer):
 
     def batch_encode_plus(self, *args, **kwargs):
         # Explicitly call the parent class, or keep your original implementation, as long as it internally calls the fixed encode_plus
-        return super().batch_encode_plus(*args, **kwargs)
+        # return super().batch_encode_plus(*args, **kwargs)
+        # Changed:
+        # process each entry in a loop (pop the consumed kwargs so they are not passed twice)
+        batch_outputs = []
+        batch_text = kwargs.pop("text")
+        seq_type = kwargs.pop("seq_type")
+        for text in batch_text:
+            batch_outputs.append(self.encode_plus(text, seq_type=seq_type, **kwargs))
+
+        # Merge the results into Dict[str, List[List[int]]]
+        # so that Dataset.map(batched=True) can parse them correctly
+        combined = {key: [] for key in batch_outputs[0].keys()}
+        for output in batch_outputs:
+            for key, value in output.items():
+                combined[key].append(value)
+
+        return combined
 
     def encode_plus(
         self,

@@ -311,7 +327,9 @@ class LucaGPLMTokenizer(PreTrainedTokenizer):
         truncation: bool = False,
         **kwargs
     ) -> Dict[str, Any]:
-
+        # Changed:
+        # ignore unrecognized arguments such as text_pair
+        kwargs.pop("text_pair", None)
         # Call the fixed encode, which now handles truncation correctly
         token_ids = self.encode(
             text,
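
The reworked `batch_encode_plus` returns a dict of per-example lists, which is the columnar shape that `datasets.Dataset.map(batched=True)` expects. A toy, self-contained sketch of that contract; `fake_encode_plus` and its fields are illustrative stand-ins, not part of this repository:

```python
# Toy illustration of the Dict[str, List[List[int]]] contract that
# Dataset.map(batched=True) expects; fake_encode_plus is a hypothetical
# stand-in for LucaGPLMTokenizer.encode_plus.
def fake_encode_plus(text):
    ids = [ord(c) % 32 for c in text]          # pretend token ids
    return {"input_ids": ids, "attention_mask": [1] * len(ids)}

def fake_batch_encode_plus(batch_text):
    # Encode each sequence individually, then merge column-wise.
    batch_outputs = [fake_encode_plus(t) for t in batch_text]
    combined = {key: [] for key in batch_outputs[0]}
    for output in batch_outputs:
        for key, value in output.items():
            combined[key].append(value)
    return combined

batch = fake_batch_encode_plus(["MKTV", "GAVL"])
print(batch["input_ids"])       # [[13, 11, 20, 22], [7, 1, 22, 12]]
print(batch["attention_mask"])  # [[1, 1, 1, 1], [1, 1, 1, 1]]
```

Each key maps to one list entry per input sequence, so `datasets` can zip the columns back into rows when the mapped batch is written out.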