Update weights and modeling code to latest version
Browse files- modeling_lucaone.py +6 -2
- tokenization_lucaone.py +21 -3
modeling_lucaone.py
CHANGED
|
@@ -1141,8 +1141,10 @@ class LucaGPLMForMaskedLM(LucaGPLMPreTrainedModel):
|
|
| 1141 |
|
| 1142 |
class LucaGPLMForSequenceClassification(LucaGPLMPreTrainedModel):
|
| 1143 |
def __init__(self, config):
|
|
|
|
|
|
|
| 1144 |
super().__init__(config)
|
| 1145 |
-
self.num_labels = config.num_labels
|
| 1146 |
self.task_level = config.task_level
|
| 1147 |
self.task_type = config.task_type
|
| 1148 |
assert self.task_level == "seq_level"
|
|
@@ -1247,8 +1249,10 @@ class LucaGPLMForSequenceClassification(LucaGPLMPreTrainedModel):
|
|
| 1247 |
|
| 1248 |
class LucaGPLMForTokenClassification(LucaGPLMPreTrainedModel):
|
| 1249 |
def __init__(self, config):
|
|
|
|
|
|
|
| 1250 |
super().__init__(config)
|
| 1251 |
-
self.num_labels = config.num_labels
|
| 1252 |
self.task_level = config.task_level
|
| 1253 |
self.task_type = config.task_type
|
| 1254 |
assert self.task_level == "token_level"
|
|
|
|
| 1141 |
|
| 1142 |
class LucaGPLMForSequenceClassification(LucaGPLMPreTrainedModel):
|
| 1143 |
def __init__(self, config):
|
| 1144 |
+
if hasattr(config, "classifier_num_labels") and config.classifier_num_labels > 0:
|
| 1145 |
+
config.num_labels = config.classifier_num_labels
|
| 1146 |
super().__init__(config)
|
| 1147 |
+
self.num_labels = config.num_labels
|
| 1148 |
self.task_level = config.task_level
|
| 1149 |
self.task_type = config.task_type
|
| 1150 |
assert self.task_level == "seq_level"
|
|
|
|
| 1249 |
|
| 1250 |
class LucaGPLMForTokenClassification(LucaGPLMPreTrainedModel):
|
| 1251 |
def __init__(self, config):
|
| 1252 |
+
if hasattr(config, "classifier_num_labels") and config.classifier_num_labels > 0:
|
| 1253 |
+
config.num_labels = config.classifier_num_labels
|
| 1254 |
super().__init__(config)
|
| 1255 |
+
self.num_labels = config.num_labels
|
| 1256 |
self.task_level = config.task_level
|
| 1257 |
self.task_type = config.task_type
|
| 1258 |
assert self.task_level == "token_level"
|
tokenization_lucaone.py
CHANGED
|
@@ -133,7 +133,7 @@ class LucaGPLMTokenizer(PreTrainedTokenizer):
|
|
| 133 |
self.cls_idx = self.tok_to_idx.get("[CLS]", 2)
|
| 134 |
self.mask_idx = self.tok_to_idx.get("[MASK]", 4)
|
| 135 |
self.eos_idx = self.tok_to_idx.get("[SEP]", 3)
|
| 136 |
-
|
| 137 |
super().__init__(
|
| 138 |
unk_token=unk_token,
|
| 139 |
pad_token=pad_token,
|
|
@@ -295,7 +295,23 @@ class LucaGPLMTokenizer(PreTrainedTokenizer):
|
|
| 295 |
|
| 296 |
def batch_encode_plus(self, *args, **kwargs):
|
| 297 |
# 显式调用父类,或者保留你原有的实现,只要确保内部调用的是修复后的 encode_plus 即可
|
| 298 |
-
return super().batch_encode_plus(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 299 |
|
| 300 |
def encode_plus(
|
| 301 |
self,
|
|
@@ -311,7 +327,9 @@ class LucaGPLMTokenizer(PreTrainedTokenizer):
|
|
| 311 |
truncation: bool = False,
|
| 312 |
**kwargs
|
| 313 |
) -> Dict[str, Any]:
|
| 314 |
-
|
|
|
|
|
|
|
| 315 |
# 调用修复后的 encode,它现在会正确处理截断
|
| 316 |
token_ids = self.encode(
|
| 317 |
text,
|
|
|
|
| 133 |
self.cls_idx = self.tok_to_idx.get("[CLS]", 2)
|
| 134 |
self.mask_idx = self.tok_to_idx.get("[MASK]", 4)
|
| 135 |
self.eos_idx = self.tok_to_idx.get("[SEP]", 3)
|
| 136 |
+
|
| 137 |
super().__init__(
|
| 138 |
unk_token=unk_token,
|
| 139 |
pad_token=pad_token,
|
|
|
|
| 295 |
|
| 296 |
def batch_encode_plus(self, *args, **kwargs):
|
| 297 |
# 显式调用父类,或者保留你原有的实现,只要确保内部调用的是修复后的 encode_plus 即可
|
| 298 |
+
# return super().batch_encode_plus(*args, **kwargs)
|
| 299 |
+
# 修改
|
| 300 |
+
# 循环处理每一条数据
|
| 301 |
+
batch_outputs = []
|
| 302 |
+
batch_text = kwargs["text"]
|
| 303 |
+
seq_type = kwargs["seq_type"]
|
| 304 |
+
for text in batch_text:
|
| 305 |
+
batch_outputs.append(self.encode_plus(text, seq_type=seq_type, **kwargs))
|
| 306 |
+
|
| 307 |
+
# 将结果合并为 Dict[str, List[List[int]]]
|
| 308 |
+
# 这样 Dataset.map(batched=True) 才能正确解析
|
| 309 |
+
combined = {key: [] for key in batch_outputs[0].keys()}
|
| 310 |
+
for output in batch_outputs:
|
| 311 |
+
for key, value in output.items():
|
| 312 |
+
combined[key].append(value)
|
| 313 |
+
|
| 314 |
+
return combined
|
| 315 |
|
| 316 |
def encode_plus(
|
| 317 |
self,
|
|
|
|
| 327 |
truncation: bool = False,
|
| 328 |
**kwargs
|
| 329 |
) -> Dict[str, Any]:
|
| 330 |
+
# 修改
|
| 331 |
+
# 忽略掉不认识的参数,比如 text_pair
|
| 332 |
+
kwargs.pop("text_pair", None)
|
| 333 |
# 调用修复后的 encode,它现在会正确处理截断
|
| 334 |
token_ids = self.encode(
|
| 335 |
text,
|