# CSP_0.6B_Inst / cif_tokenizer_fast.py
# VivianKeith's picture
# Upload folder using huggingface_hub
# a4cf5a7 verified
"""
可被AutoTokenizer自动加载的CIF Tokenizer实现(简化版)
通过在tokenize前预处理、decode后后处理的方式,避免复杂的方法签名问题
"""
import os
import re
import json
from typing import List, Optional
from transformers import PreTrainedTokenizerFast
# Load the space-group symbol list shipped next to this module (one symbol
# per line).  A missing file yields an empty list, which disables all
# space-group-specific handling in CIFTokenizerFast.
SPACEGROUPS_PATH = os.path.join(os.path.dirname(__file__), "spacegroups.txt")
if os.path.exists(SPACEGROUPS_PATH):
    # encoding pinned so parsing does not depend on the platform locale;
    # blank lines are dropped — an empty symbol would make the regex
    # alternation built from this list match the empty string.
    with open(SPACEGROUPS_PATH, "rt", encoding="utf-8") as f:
        SPACE_GROUPS = [sg.strip() for sg in f if sg.strip()]
else:
    SPACE_GROUPS = []
class CIFTokenizerFast(PreTrainedTokenizerFast):
    """CIF-specific tokenizer auto-loadable via ``AutoTokenizer``.

    Inherits from :class:`~transformers.PreTrainedTokenizerFast` and is fully
    compatible with the Hugging Face ecosystem.

    Usage notes:
    - Loadable by ``AutoTokenizer`` via the ``auto_map`` entry written by
      :meth:`save_pretrained`.
    - ``cif_tokenizer_fast.py`` must be placed in the model directory.
    - Loading requires ``trust_remote_code=True``.
    """

    def __init__(self, *args, space_groups: Optional[List[str]] = None, **kwargs):
        """Initialize the tokenizer.

        Args:
            space_groups: Optional list of space-group symbols; defaults to
                the module-level ``SPACE_GROUPS`` loaded from
                ``spacegroups.txt``.
        """
        super().__init__(*args, **kwargs)
        self.space_groups = space_groups if space_groups is not None else SPACE_GROUPS
        self._build_spacegroup_pattern()

    def _build_spacegroup_pattern(self) -> None:
        """Build the regex alternation matching any known space-group symbol.

        Symbols are sorted longest-first so a longer symbol wins over any
        shorter symbol that is its prefix, and each one is escaped because
        many contain regex metacharacters (``/``, ``-`` ...).
        """
        if self.space_groups:
            sorted_sgs = sorted(self.space_groups, key=len, reverse=True)
            self.spacegroup_pattern = "|".join([re.escape(sg) for sg in sorted_sgs])
        else:
            self.spacegroup_pattern = None

    def preprocess_cif(self, text: str, single_spaces: bool = True) -> str:
        """Preprocess CIF text before tokenization.

        Args:
            text: Raw CIF text (non-``str`` values are returned unchanged).
            single_spaces: Collapse runs of spaces/tabs into single spaces.

        Returns:
            The preprocessed text. A ``_sg`` suffix is appended to the
            space-group symbol on ``_symmetry_space_group_name_H-M`` lines so
            the symbol tokenizes unambiguously; :meth:`postprocess_cif`
            removes it again after decoding.
        """
        if not isinstance(text, str):
            return text
        # Normalize whitespace
        if single_spaces:
            text = re.sub(r'[ \t]+', ' ', text)
        # Disambiguate space-group symbols
        if self.spacegroup_pattern:
            pattern = fr'(_symmetry_space_group_name_H-M\s*[\'"]?\s*({self.spacegroup_pattern})\s*[\'"]?)'

            def add_sg_suffix(match):
                # group(2) is the exact symbol the alternation matched.
                # Using it directly (instead of scanning self.space_groups
                # for any substring, as before) guarantees the suffix lands
                # on the full matched symbol, not on a shorter symbol that
                # happens to be contained in it.
                full_match = match.group(0)
                sg = match.group(2)
                if full_match.endswith('_sg'):
                    return full_match
                return full_match.replace(sg, sg + '_sg', 1)

            text = re.sub(pattern, add_sg_suffix, text)
        return text

    def postprocess_cif(self, text: str) -> str:
        """Strip the ``_sg`` suffixes added by :meth:`preprocess_cif`."""
        if isinstance(text, str) and self.space_groups:
            for sg in self.space_groups:
                text = text.replace(sg + '_sg', sg)
        return text

    def __call__(self, *args, **kwargs):
        """Tokenize and drop ``token_type_ids`` (Qwen models do not use them).

        Note: the parent returns a ``BatchEncoding``, which is a ``UserDict``
        rather than a ``dict`` subclass, so the previous
        ``isinstance(result, dict)`` guard was always False and the key was
        never actually removed; membership is now tested directly.
        """
        result = super().__call__(*args, **kwargs)
        if 'token_type_ids' in result:
            result.pop('token_type_ids')
        return result

    def decode(self, *args, **kwargs):
        """Decode, then strip the ``_sg`` disambiguation suffixes."""
        result = super().decode(*args, **kwargs)
        return self.postprocess_cif(result)

    def batch_decode(self, *args, **kwargs):
        """Batch-decode, then strip the ``_sg`` suffixes from every item."""
        results = super().batch_decode(*args, **kwargs)
        return [self.postprocess_cif(r) for r in results]

    def save_pretrained(self, save_directory, *args, **kwargs):
        """Save the tokenizer plus the CIF-specific configuration.

        Besides the base tokenizer files, this writes ``cif_config.json``
        (the space-group list), patches ``tokenizer_config.json`` with an
        ``auto_map`` entry so ``AutoTokenizer`` can find this class, and
        copies this source file into the save directory (best-effort).
        """
        result = super().save_pretrained(save_directory, *args, **kwargs)
        # Persist the space-group list so from_pretrained can restore it.
        cif_config_path = os.path.join(save_directory, "cif_config.json")
        with open(cif_config_path, "w", encoding="utf-8") as f:
            json.dump({"space_groups": self.space_groups}, f, ensure_ascii=False, indent=2)
        # Point auto_map at this class so trust_remote_code loading works.
        tokenizer_config_path = os.path.join(save_directory, "tokenizer_config.json")
        if os.path.exists(tokenizer_config_path):
            with open(tokenizer_config_path, "r", encoding="utf-8") as f:
                config = json.load(f)
            config["tokenizer_class"] = "CIFTokenizerFast"
            config["auto_map"] = {
                "AutoTokenizer": ["cif_tokenizer_fast.CIFTokenizerFast", None]
            }
            with open(tokenizer_config_path, "w", encoding="utf-8") as f:
                json.dump(config, f, ensure_ascii=False, indent=2)
        # Ship this module alongside the saved tokenizer (best-effort).
        try:
            import shutil
            src = os.path.abspath(__file__)
            dst = os.path.join(save_directory, "cif_tokenizer_fast.py")
            if src != dst:
                shutil.copy2(src, dst)
                print(f"✓ 已复制 cif_tokenizer_fast.py")
        except Exception as e:
            print(f"⚠ 复制失败: {e}")
        return result

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Load the tokenizer and restore the CIF space-group configuration.

        The space-group list is read from ``cif_config.json`` when present;
        an explicit ``space_groups=`` kwarg takes precedence, and the
        module-level default is used as a last resort.
        """
        space_groups = None
        if isinstance(pretrained_model_name_or_path, str):
            config_path = os.path.join(pretrained_model_name_or_path, "cif_config.json")
            if os.path.exists(config_path):
                # encoding pinned: save_pretrained writes this file as UTF-8.
                with open(config_path, "r", encoding="utf-8") as f:
                    config = json.load(f)
                space_groups = config.get("space_groups")
        # A caller-supplied list wins over the on-disk configuration.
        space_groups = kwargs.pop("space_groups", space_groups)
        tokenizer = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
        tokenizer.space_groups = space_groups if space_groups is not None else SPACE_GROUPS
        tokenizer._build_spacegroup_pattern()
        return tokenizer
def register_cif_tokenizer():
    """Compatibility shim.

    Kept for callers that expect an explicit registration step; the
    ``auto_map`` mechanism makes actual registration unnecessary.
    """
    return None