Upload CrossDNA 8.1M (model files only)
Browse files- 8.1M/config.json +74 -0
- 8.1M/configuration_crossdna.py +105 -0
- 8.1M/model.safetensors +3 -0
- 8.1M/modeling_crossdna.py +916 -0
- 8.1M/special_tokens_map.json +9 -0
- 8.1M/tokenization_crossdna.py +159 -0
- 8.1M/tokenizer_config.json +32 -0
8.1M/config.json
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_name_or_path": "CrossDNA-pretrain",
|
| 3 |
+
"model_type": "crossdna",
|
| 4 |
+
"architectures": ["CrossDNAForMaskedLM"],
|
| 5 |
+
|
| 6 |
+
"auto_map": {
|
| 7 |
+
"AutoConfig": "configuration_crossdna.CrossDNAConfig",
|
| 8 |
+
"AutoModelForMaskedLM": "modeling_crossdna.CrossDNAForMaskedLM",
|
| 9 |
+
"AutoTokenizer": "tokenization_crossdna.CrossDNATokenizer"
|
| 10 |
+
},
|
| 11 |
+
|
| 12 |
+
"torch_dtype": "float32",
|
| 13 |
+
|
| 14 |
+
"alphabet_size": 5,
|
| 15 |
+
"d_model": 128,
|
| 16 |
+
"block_size": 1024,
|
| 17 |
+
"depth": 6,
|
| 18 |
+
"drop_path_rates": [0.0, 0.05],
|
| 19 |
+
"dropout": 0.15,
|
| 20 |
+
|
| 21 |
+
"pretrain": true,
|
| 22 |
+
"for_representation": false,
|
| 23 |
+
"use_s_scan": true,
|
| 24 |
+
"use_bridge": true,
|
| 25 |
+
"use_mem": false,
|
| 26 |
+
"use_rc_kl": false,
|
| 27 |
+
"use_barlow": false,
|
| 28 |
+
"use_tv": false,
|
| 29 |
+
|
| 30 |
+
"sem_max_weight": 0.12,
|
| 31 |
+
"sem_warmup_steps": 10000,
|
| 32 |
+
"aux_ce_weight": 0.0,
|
| 33 |
+
"gate_freeze_steps": 5000,
|
| 34 |
+
"detach_gate": false,
|
| 35 |
+
"gate_sup_weight": 0.02,
|
| 36 |
+
"gate_sup_warmup_steps": 500,
|
| 37 |
+
"gate_temp": 2.0,
|
| 38 |
+
|
| 39 |
+
"transformer_cfg": {
|
| 40 |
+
"hidden_size": 128,
|
| 41 |
+
"norm_eps": 1e-5,
|
| 42 |
+
"max_position_embeddings": 1024,
|
| 43 |
+
"hidden_ratio": 4.0,
|
| 44 |
+
"hidden_act": "swish",
|
| 45 |
+
"fuse_swiglu": true,
|
| 46 |
+
"attn": {
|
| 47 |
+
"num_heads": 8,
|
| 48 |
+
"num_kv_heads": 8,
|
| 49 |
+
"qkv_bias": false,
|
| 50 |
+
"window_size": 2048,
|
| 51 |
+
"rope_theta": 10000
|
| 52 |
+
}
|
| 53 |
+
},
|
| 54 |
+
"comba_cfg": {
|
| 55 |
+
"hidden_size": 128,
|
| 56 |
+
"expand_v": 1,
|
| 57 |
+
"head_dim": 64,
|
| 58 |
+
"num_heads": 8,
|
| 59 |
+
"use_gate": true,
|
| 60 |
+
"mode": "chunk",
|
| 61 |
+
"use_short_conv": true,
|
| 62 |
+
"correction_factor": 0.02,
|
| 63 |
+
"conv_size": 4,
|
| 64 |
+
"norm_eps": 1e-5
|
| 65 |
+
},
|
| 66 |
+
|
| 67 |
+
"pad_token_id": 4,
|
| 68 |
+
"bos_token_id": 2,
|
| 69 |
+
"sep_token_id": 1,
|
| 70 |
+
"cls_token_id": 0,
|
| 71 |
+
"mask_token_id": 3,
|
| 72 |
+
|
| 73 |
+
"vocab_size": 5
|
| 74 |
+
}
|
8.1M/configuration_crossdna.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import PretrainedConfig
|
| 2 |
+
|
| 3 |
+
class CrossDNAConfig(PretrainedConfig):
    """Configuration for the CrossDNA dual-strand hybrid (Comba + sliding-window
    attention) model.

    Holds the model dimensions (alphabet_size, d_model, block_size, depth),
    regularization knobs (dropout, drop_path_rates), feature switches
    (use_s_scan / use_bridge / use_mem / use_rc_kl / use_barlow / use_tv),
    auxiliary-loss schedules (sem_* / aux_ce_weight / gate_*), and the nested
    sub-configs for the attention and Comba mixers.

    BUGFIX: sep_token_id / cls_token_id / mask_token_id were previously
    accepted by __init__ but never forwarded to super().__init__ nor stored,
    so configs loaded from config.json (which sets all three) silently lost
    them; they are now passed through to PretrainedConfig.
    """
    model_type = "crossdna"

    def __init__(
        self,
        alphabet_size=5,
        d_model=128,
        block_size=1024,
        depth=6,
        drop_path_rates=(0.0, 0.05),
        dropout=0.15,

        pretrain=True,
        for_representation=False,
        use_s_scan=True,
        use_bridge=True,
        use_mem=False,
        use_rc_kl=False,
        use_barlow=False,
        use_tv=False,

        sem_max_weight=0.12,
        sem_warmup_steps=10000,
        aux_ce_weight=0.0,
        gate_freeze_steps=5000,
        detach_gate=False,
        gate_sup_weight=0.02,
        gate_sup_warmup_steps=500,
        gate_temp=2.0,

        transformer_cfg=None,
        comba_cfg=None,

        pad_token_id=4,   # [PAD]
        bos_token_id=2,   # [BOS]
        sep_token_id=1,   # [SEP]
        cls_token_id=0,   # [CLS]
        mask_token_id=3,  # [MASK]
        **kwargs
    ):
        # Forward every special-token id; PretrainedConfig stores unknown
        # kwargs (cls_token_id / mask_token_id) as plain attributes.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            sep_token_id=sep_token_id,
            cls_token_id=cls_token_id,
            mask_token_id=mask_token_id,
            **kwargs
        )
        self.alphabet_size = alphabet_size
        self.d_model = d_model
        self.block_size = block_size
        self.depth = depth
        self.drop_path_rates = list(drop_path_rates) if drop_path_rates is not None else None
        self.dropout = dropout

        self.pretrain = pretrain
        self.for_representation = for_representation
        self.use_s_scan = use_s_scan
        self.use_bridge = use_bridge
        self.use_mem = use_mem
        self.use_rc_kl = use_rc_kl
        self.use_barlow = use_barlow
        self.use_tv = use_tv

        self.sem_max_weight = sem_max_weight
        self.sem_warmup_steps = sem_warmup_steps
        self.aux_ce_weight = aux_ce_weight
        self.gate_freeze_steps = gate_freeze_steps
        self.detach_gate = detach_gate
        self.gate_sup_weight = gate_sup_weight
        self.gate_sup_warmup_steps = gate_sup_warmup_steps
        self.gate_temp = gate_temp

        # NOTE: an explicitly passed empty dict is also replaced by the
        # defaults (the original used `or`); kept for backward compatibility.
        self.transformer_cfg = transformer_cfg or {
            "hidden_size": d_model,
            "norm_eps": 1e-5,
            "max_position_embeddings": 1024,  # replaced by dataset.max_length; default given here
            "hidden_ratio": 4.0,
            "hidden_act": "swish",
            "fuse_swiglu": True,
            "attn": {
                "num_heads": 8,
                "num_kv_heads": 8,
                "qkv_bias": False,
                "window_size": 2048,
                "rope_theta": 10000
            }
        }
        self.comba_cfg = comba_cfg or {
            "hidden_size": d_model,
            "expand_v": 1,
            "head_dim": 64,
            "num_heads": 8,
            "use_gate": True,
            "mode": "chunk",
            "use_short_conv": True,
            "correction_factor": 0.02,
            "conv_size": 4,
            "norm_eps": 1e-5,
        }

        # Convenience for AutoModel to infer vocab_size.
        self.vocab_size = self.alphabet_size
|
8.1M/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:886fc527a7103aa27d4f7f41626a9a283f400983e91d5e09247243d79b1dac63
|
| 3 |
+
size 64417904
|
8.1M/modeling_crossdna.py
ADDED
|
@@ -0,0 +1,916 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import math
|
| 3 |
+
import torch
|
| 4 |
+
import copy
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from torch import amp
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from functools import partial
|
| 9 |
+
from contextlib import contextmanager
|
| 10 |
+
from collections import namedtuple
|
| 11 |
+
from typing import Dict, Optional, Tuple, Any
|
| 12 |
+
|
| 13 |
+
from transformers import PreTrainedModel
|
| 14 |
+
from transformers.modeling_outputs import MaskedLMOutput
|
| 15 |
+
|
| 16 |
+
# Open this patch when compiling a ckpt into a .bin file (enabled by default
# via the DISABLE_TORCH_COMPILE env var): replaces torch.compile with an
# identity shim so export does not trigger graph compilation.
import os as _os, torch as _torch
if _os.environ.get("DISABLE_TORCH_COMPILE", "1") == "1" and hasattr(_torch, "compile"):
    # Shim supports both call forms: torch.compile(fn) returns fn unchanged,
    # and torch.compile(**kwargs) returns a pass-through decorator.
    def _no_compile(fn=None, *args, **kwargs):
        if fn is None:
            def deco(f): return f
            return deco
        return fn
    _torch.compile = _no_compile
    # NOTE(review): debug print confirming the patch took effect; exact
    # original indentation is ambiguous in this scrape — assumed inside the if.
    print("torch.compile =>", torch.compile)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
from fla.layers import comba
|
| 29 |
+
from fla.layers.attn import Attention
|
| 30 |
+
from fla.modules import GatedMLP as SambaMLP
|
| 31 |
+
from fla.modules import RMSNorm
|
| 32 |
+
from torch.cuda.amp import autocast
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
try:
|
| 36 |
+
# 在 Transformers 的动态包 transformers_modules.<hash>.* 环境下走相对导入
|
| 37 |
+
from .configuration_crossdna import CrossDNAConfig
|
| 38 |
+
except ImportError:
|
| 39 |
+
# 直接本地运行(不是通过 from_pretrained 动态加载)时也能跑
|
| 40 |
+
from configuration_crossdna import CrossDNAConfig
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# ========================
|
| 45 |
+
# Utils
|
| 46 |
+
# ========================
|
| 47 |
+
def complement(seq: torch.Tensor) -> torch.Tensor:
    """Return the base-wise complement of an integer-encoded DNA tensor.

    Encoding: A=0, C=1, G=2, T=3, N=4.  A<->T and C<->G are swapped via
    ``3 - seq``; the ambiguous base N (4) stays fixed.  The input tensor is
    never modified.
    """
    swapped = 3 - seq
    return torch.where(seq == 4, seq, swapped)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def reverse_complement(seq: torch.Tensor) -> torch.Tensor:
    """Complement every base, then reverse along the sequence axis (dim 1)."""
    return torch.flip(complement(seq), dims=[1])
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def make_complement_perm(C=5, device=None, dtype=torch.float32):
    """Build the complement permutation and its one-hot matrix P.

    Encoding: A=0, C=1, G=2, T=3, N=4.  perm[i] is the complement of base i
    (A<->T, C<->G, N fixed), consistent with ``complement()`` (3 - i, N kept).
    P satisfies P[i, perm[i]] = 1, so ``probs @ P.t()`` re-indexes a class
    distribution under complementation.

    NOTE(review): the previous hard-coded perm was [3, 0, 2, 1, 4], which
    mapped C->A and left G fixed — inconsistent with both its own comment and
    ``complement()``; replaced with the involution [3, 2, 1, 0, 4].  Only the
    optional RC-KL losses consume P, so default training is unaffected.
    The perm is defined for the 5-letter DNA alphabet; C != 5 is unsupported
    (same limitation as the original).
    """
    perm = torch.tensor([3, 2, 1, 0, 4], device=device)
    P = torch.zeros(C, C, device=device, dtype=dtype)
    P[torch.arange(C, device=device), perm] = 1.0
    return P, perm
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def ensure_finite(x: torch.Tensor, name: str):
    """Raise FloatingPointError if *x* contains NaN/Inf; otherwise return it.

    This is a pure check — it deliberately never sanitizes the values.
    """
    if torch.isfinite(x).all():
        return x
    raise FloatingPointError(f"Non-finite values detected in {name}")
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def linear_warmup_weight(step: int, warmup_steps: int, max_w: float):
    """Linearly ramp a loss weight from 0 to *max_w* over *warmup_steps*.

    A non-positive warmup means "no warmup" and returns max_w immediately;
    non-positive steps return 0.0.
    """
    if warmup_steps <= 0 or step >= warmup_steps:
        return max_w
    if step <= 0:
        return 0.0
    return max_w * (step / warmup_steps)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def preferred_amp_dtype():
    """Pick bfloat16 when the CUDA device supports it, otherwise float16.

    Wrapped in a broad try/except so CPU-only or partial CUDA builds fall
    back to float16 instead of raising.
    """
    try:
        wants_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    except Exception:
        wants_bf16 = False
    return torch.bfloat16 if wants_bf16 else torch.float16
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
# ========================
|
| 94 |
+
# RC 一致性(可选损失)
|
| 95 |
+
# ========================
|
| 96 |
+
def rc_consistency_kl(logits_A, logits_B_fwd, P, tau: float = 1.0, eps: float = 1e-6):
|
| 97 |
+
zA = logits_A.float() / tau
|
| 98 |
+
zB = logits_B_fwd.float() / tau
|
| 99 |
+
pA = F.softmax(zA, dim=-1)
|
| 100 |
+
logpA = F.log_softmax(zA, dim=-1)
|
| 101 |
+
pB = F.softmax(zB, dim=-1)
|
| 102 |
+
pB_comp = torch.matmul(pB, P.t()).clamp_min(eps)
|
| 103 |
+
logpB_comp = pB_comp.log()
|
| 104 |
+
kl = (pA * (logpA - logpB_comp)).sum(dim=-1).mean()
|
| 105 |
+
return kl * (tau * tau)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def rc_consistency_bidirectional_stopgrad(logits_A, logits_B_fwd, P, tau: float = 1.5, eps: float = 1e-6):
    """Symmetric RC-consistency KL with stop-gradient targets.

    Each strand's tempered distribution is matched via KL to the
    complement-permuted distribution of the other strand; the target side of
    each term is built under no_grad, so gradients flow only into the
    "student" logits of that term.  The two directions are averaged and
    rescaled by tau^2 as usual for tempered distillation.
    """
    zA = logits_A.float() / tau
    zB = logits_B_fwd.float() / tau
    with torch.no_grad():
        # Target for A: B's distribution re-indexed under complementation,
        # clamped so the log stays finite.
        pB_t = torch.matmul(F.softmax(zB, dim=-1), P.t()).clamp_min(eps)
        logpB_t = pB_t.log()
    loss_A = F.kl_div(F.log_softmax(zA, dim=-1), logpB_t, reduction="batchmean", log_target=True)
    with torch.no_grad():
        # Mirror image: target for B built from A's detached distribution.
        pA_t = torch.matmul(F.softmax(zA, dim=-1), P.t()).clamp_min(eps)
        logpA_t = pA_t.log()
    loss_B = F.kl_div(F.log_softmax(zB, dim=-1), logpA_t, reduction="batchmean", log_target=True)
    return 0.5 * (tau * tau) * (loss_A + loss_B)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
# ========================
|
| 123 |
+
# Barlow & TV(可选)
|
| 124 |
+
# ========================
|
| 125 |
+
def barlow_strand_loss_v2(z1, z2, λ_off=0.04, λ_diag=0.04, eps=1e-3):
    """Stable Barlow-Twins-style strand agreement loss.

    Combines a variance (anti-collapse) hinge with diagonal / off-diagonal
    cross-correlation penalties.  z1, z2: [B, L, H], already aligned to the
    forward strand.
    """
    B, L, H = z1.shape
    n = B * L
    z1 = z1.reshape(n, H)
    z2 = z2.reshape(n, H)

    def _std(z):
        # Per-dimension std with eps inside the sqrt for stability.
        var = z.var(dim=0, unbiased=False)
        return torch.sqrt(var + eps)

    std1, std2 = _std(z1), _std(z2)
    # VICReg-style hinge: push each dimension's std above 1 to avoid collapse.
    var_term = (F.relu(1 - std1).pow(2).mean() + F.relu(1 - std2).pow(2).mean())

    # Standardize both views, then build the HxH cross-correlation matrix.
    z1 = (z1 - z1.mean(0)) / (std1 + eps)
    z2 = (z2 - z2.mean(0)) / (std2 + eps)
    c = (z1.t() @ z2) / max(1, n)  # [H,H]
    diag = torch.diagonal(c)
    off = c - torch.diag_embed(diag)
    # Diagonal pulled toward 1 (agreement), off-diagonal toward 0 (decorrelation).
    cov = λ_diag * (1 - diag).pow(2).mean() + λ_off * off.pow(2).mean()
    return var_term + cov
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def tv_mixed(h: torch.Tensor):
    """Mixed total-variation smoothness penalty for h: [B, L, H].

    First-order differences along L are penalized with L1, second-order
    differences with L2.
    """
    first = h[:, 1:, :] - h[:, :-1, :]
    second = first[:, 1:, :] - first[:, :-1, :]
    return first.abs().mean() + second.pow(2).mean()
|
| 153 |
+
|
| 154 |
+
class Mlp(nn.Module):
    """Two-layer perceptron: Linear -> activation -> Linear.

    Hidden and output widths default to the input width.  When
    ``return_residual`` is True, ``forward`` returns ``(output, input)`` so
    the caller can fuse the residual add.
    """

    def __init__(self, input_dimension, hidden_dimension=None, output_dimension=None,
                 activation=F.gelu, return_residual=False):
        super().__init__()
        self.return_residual = return_residual
        width_hidden = hidden_dimension or input_dimension
        width_out = output_dimension or input_dimension
        # Submodule creation order (linear1 then linear2) preserved so RNG
        # initialization matches the original.
        self.linear1 = nn.Linear(input_dimension, width_hidden)
        self.activation = activation
        self.linear2 = nn.Linear(width_hidden, width_out)

    def forward(self, x: torch.Tensor):
        out = self.linear2(self.activation(self.linear1(x)))
        if self.return_residual:
            return out, x
        return out
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def create_comba_cls(comba_kwargs=None, device=None, dtype=None):
    """Safe factory: return a callable that constructs a Comba mixer.

    Binds ``comba_kwargs`` plus optional device/dtype into a ``partial`` over
    ``comba.Comba``.  If the Comba implementation is unavailable, a no-op
    identity module factory is returned so the surrounding model still runs.

    BUGFIX: the fallback previously only caught ImportError, which can never
    fire here — the ``fla`` import happens at module load, and a missing
    ``Comba`` attribute raises AttributeError instead.  Both are now handled
    so the documented fallback path is actually reachable.
    """
    factory_kwargs = {}
    if device is not None:
        factory_kwargs["device"] = device
    if dtype is not None:
        factory_kwargs["dtype"] = dtype
    try:
        base_kwargs = dict(comba_kwargs or {})
        mixer_cls = partial(comba.Comba, **base_kwargs, **factory_kwargs)
    except (ImportError, AttributeError):
        class FallbackComba(nn.Module):
            # Identity mixer: passes hidden states through unchanged.
            def forward(self, x, *args, **kwargs):
                return x
        mixer_cls = lambda *args, **kwargs: FallbackComba()
    return mixer_cls
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
# ========================
|
| 192 |
+
# SWA block(去手动 FP16,统一 AMP)
|
| 193 |
+
# ========================
|
| 194 |
+
class SlidingWindowAttention(nn.Module):
    """
    RMSNorm -> Sliding-window Attention -> Residual -> RMSNorm -> Gated MLP -> Residual

    Both sub-blocks run under CUDA autocast with the preferred AMP dtype;
    the attention input is pre-scaled by 1/sqrt(2) and rescaled back after,
    and every sub-block output is checked for NaN/Inf.
    """

    def __init__(self, config: Any):
        super().__init__()

        # Accept either a plain dict or any object whose attributes can be
        # read via vars() (e.g. a simple namespace).
        if isinstance(config, dict):
            c = config
        else:
            try:
                c = vars(config)
            except Exception as e:
                raise TypeError(f"transformer_cfg must be dict-like, got {type(config)}") from e

        attn_cfg = c["attn"]

        self.mixer_norm = RMSNorm(hidden_size=c["hidden_size"], eps=c.get("norm_eps", 1e-5))
        self.mixer = Attention(
            hidden_size=c["hidden_size"],
            num_heads=attn_cfg["num_heads"],
            num_kv_heads=attn_cfg["num_kv_heads"],
            qkv_bias=attn_cfg["qkv_bias"],
            window_size=attn_cfg["window_size"],
            rope_theta=attn_cfg["rope_theta"],
            max_position_embeddings=c["max_position_embeddings"]
        )
        self.mlp_norm = RMSNorm(c["hidden_size"], eps=c.get("norm_eps", 1e-5))
        self.mlp = SambaMLP(
            hidden_size=c["hidden_size"],
            hidden_ratio=c["hidden_ratio"],
            hidden_act=c["hidden_act"],
            fuse_swiglu=c["fuse_swiglu"]
        )
        # Pre-scale applied before attention and undone after; presumably for
        # numerical headroom under fp16 autocast — TODO confirm.
        self.pre_scale = 1.0 / math.sqrt(2.0)

    def forward(self, hidden_states: torch.Tensor,
                cache_params: Optional[Any] = None, **kwargs) -> Tuple[torch.Tensor, Any]:
        residual = hidden_states
        x = self.mixer_norm(hidden_states)

        amp_dtype = preferred_amp_dtype()
        with amp.autocast("cuda", enabled=True, dtype=amp_dtype):
            x_scaled = x * self.pre_scale
            # fla Attention returns (output, attn_weights, updated_cache);
            # the middle element is discarded here.
            attn_out, _, cache_params = self.mixer(
                hidden_states=x_scaled,
                past_key_values=cache_params,
                **kwargs
            )
            attn_out = attn_out / self.pre_scale

        ensure_finite(attn_out, "attention_out")
        # Residual add in the norm's dtype (cast back from the AMP dtype).
        h = residual + attn_out.to(x.dtype)

        residual = h
        x = self.mlp_norm(h)
        with amp.autocast("cuda", enabled=True, dtype=amp_dtype):
            x = self.mlp(x, **kwargs)
        h = residual + x
        ensure_finite(h, "block_output")
        return h, cache_params
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
# ========================
|
| 260 |
+
# Enhanced Hybrid Core (Comba + SWA + gating)
|
| 261 |
+
# ========================
|
| 262 |
+
class EnhancedHybridCore(nn.Module):
    """Hybrid layer: Comba mixer followed by sliding-window attention, fused
    per position/channel by a learned sigmoid gate, then LayerNorm."""

    def __init__(self, hidden_dim, comba_cfg, transformer_cfg, layer_idx=0, device=None, dtype=None):
        super().__init__()
        self.comba_cls = create_comba_cls(comba_kwargs=comba_cfg, device=device, dtype=dtype)
        # Some Comba variants accept layer_idx, others do not — try, then
        # fall back to the zero-arg call.
        try:
            self.comba = self.comba_cls(layer_idx=layer_idx)
        except TypeError:
            self.comba = self.comba_cls()
        self.transformer = SlidingWindowAttention(config=transformer_cfg)
        # Gate consumes [comba_out ; attn_out] and emits per-channel weights.
        self.gate = nn.Linear(hidden_dim * 2, hidden_dim)
        self.out_norm = nn.LayerNorm(hidden_dim)

    @staticmethod
    def _first(x):
        # Comba may return a tuple (output, state, ...); keep only the output.
        return x[0] if isinstance(x, tuple) else x

    def forward(self, x):
        # x: [B, l, H]
        m_out = self._first(self.comba(x))
        # Attention runs on top of the Comba output (sequential, not parallel).
        t_out, _ = self.transformer(m_out)
        concat = torch.cat([m_out, t_out], dim=-1)
        g = torch.sigmoid(self.gate(concat))
        # g -> 1 favors the attention path, g -> 0 the Comba path.
        fused = g * t_out + (1 - g) * m_out
        y = self.out_norm(fused)
        ensure_finite(y, "EnhancedHybridCore.out")
        return y
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
# ========================
|
| 291 |
+
# Deep Branch —— 与原版一致
|
| 292 |
+
# ========================
|
| 293 |
+
class DeepEnhancedBranch(nn.Module):
    """Stack of `depth` EnhancedHybridCore layers with a final LayerNorm.

    drop_path_rates may be None (linear ramp 0 -> 0.05 across the stack), a
    scalar (constant rate for all layers), or a sequence (padded with its last
    value out to `depth` entries).
    """

    def __init__(self, hidden_dim: int, comba_cfg: Dict | None, transformer_cfg: Any, depth: int = 4,
                 drop_path_rates=None, *, device=None, dtype=None):
        super().__init__()
        self.layers = nn.ModuleList()
        if drop_path_rates is None:
            # Linear ramp; max(1, depth-1) guards the depth == 1 case.
            rates = [0.05 * (i / max(1, depth - 1)) for i in range(depth)]
        elif isinstance(drop_path_rates, (float, int)):
            rates = [float(drop_path_rates)] * depth
        else:
            # NOTE(review): assumes a non-empty sequence — an empty list would
            # raise on [-1] here.
            rates = list(drop_path_rates) + [drop_path_rates[-1]] * (depth - len(drop_path_rates))
        for i in range(depth):
            # Per-layer copy so each layer carries its own drop_path_prob;
            # assumes transformer_cfg is a dict (has .copy() and item assignment).
            layer_cfg = transformer_cfg.copy()
            layer_cfg["drop_path_prob"] = rates[i]
            self.layers.append(EnhancedHybridCore(hidden_dim, comba_cfg, layer_cfg, i, device, dtype))
        self.output_norm = nn.LayerNorm(hidden_dim)

    def forward(self, x: torch.Tensor):  # x:[B,l,H]
        for layer in self.layers:
            x = layer(x)
        y = self.output_norm(x)
        ensure_finite(y, "DeepEnhancedBranch.out")
        return y
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
# ========================
|
| 319 |
+
# TokenBridge:多尺度膨胀卷积聚合 + 按位线性交换 + 门控
|
| 320 |
+
# ========================
|
| 321 |
+
class TokenBridge(nn.Module):
    """
    Lightweight cross-strand bridge: depthwise dilated convolutions fold the
    opposite branch's neighborhood / long-range context into each position,
    followed by a per-position linear exchange and a gated residual injection.
    Requires xA / xB to already be aligned to forward coordinates.
    """
    def __init__(self, hidden_dim: int, dropout: float = 0.0,
                 kernel_size: int = 9, dilations=(1, 2, 4, 8, 16),
                 use_global_token: bool = True):
        super().__init__()
        h = hidden_dim
        # "Same" padding for a given dilation; assumes an odd kernel_size.
        pad = lambda d: d * (kernel_size // 2)

        # Multi-scale context of the opposite branch, B -> A.
        self.dw_B = nn.ModuleList([
            nn.Conv1d(h, h, kernel_size, padding=pad(d), dilation=d,
                      groups=h, bias=False) for d in dilations
        ])
        self.mix_B = nn.Conv1d(h * len(dilations), h, 1)

        # Multi-scale context of the opposite branch, A -> B.
        self.dw_A = nn.ModuleList([
            nn.Conv1d(h, h, kernel_size, padding=pad(d), dilation=d,
                      groups=h, bias=False) for d in dilations
        ])
        self.mix_A = nn.Conv1d(h * len(dilations), h, 1)

        # Per-position linear exchange.
        self.proj_B2A = nn.Linear(h, h)
        self.proj_A2B = nn.Linear(h, h)

        # Optional: very cheap "global token" (sequence mean) broadcast.
        self.use_global_token = use_global_token
        if use_global_token:
            self.glb_B2A = nn.Linear(h, h)
            self.glb_A2B = nn.Linear(h, h)

        # Gating (per token, per channel).
        self.gate = nn.Linear(h * 4, h * 2)  # -> [gA, gB]
        self.dropout = nn.Dropout(dropout)
        self.normA = nn.LayerNorm(h)
        self.normB = nn.LayerNorm(h)

    @staticmethod
    def _agg(x: torch.Tensor, branches: nn.ModuleList, mix: nn.Module) -> torch.Tensor:
        # x:[B,L,H] -> [B,L,H]; run every dilated depthwise conv, concat on
        # the channel axis, then 1x1-mix back down to H channels.
        xch = x.transpose(1, 2)                   # [B,H,L]
        ys = [conv(xch) for conv in branches]     # list of [B,H,L]
        y = torch.cat(ys, dim=1)                  # [B,H*len,L]
        y = mix(y).transpose(1, 2).contiguous()   # [B,L,H]
        return y

    def forward(self, xA: torch.Tensor, xB: torch.Tensor):
        # 1) Multi-scale aggregation of each branch's context.
        ctxB = self._agg(xB, self.dw_B, self.mix_B)  # B's context, injected into A
        ctxA = self._agg(xA, self.dw_A, self.mix_A)  # A's context, injected into B

        # 2) Per-position linear exchange (opposite context added first).
        locA = self.proj_B2A(xB + ctxB)  # B -> A
        locB = self.proj_A2B(xA + ctxA)  # A -> B

        # 3) Optional: broadcast the cheap global (mean) token.
        if self.use_global_token:
            gB = self.glb_B2A(xB.mean(dim=1, keepdim=True))  # [B,1,H]
            gA = self.glb_A2B(xA.mean(dim=1, keepdim=True))  # [B,1,H]
            locA = locA + gB.expand(-1, xA.size(1), -1)
            locB = locB + gA.expand(-1, xB.size(1), -1)

        # 4) Gated injection with residual + LayerNorm.
        z = torch.cat([xA, xB, xA - xB, xA * xB], dim=-1)  # [B,L,4H]
        gA, gB = self.gate(z).chunk(2, dim=-1)
        gA = torch.sigmoid(gA)
        gB = torch.sigmoid(gB)

        yA = self.normA(xA + self.dropout(gA * locA))
        yB = self.normB(xB + self.dropout(gB * locB))

        ensure_finite(yA, "TokenBridgeLite.A")
        ensure_finite(yB, "TokenBridgeLite.B")
        return yA, yB
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
# ========================
|
| 403 |
+
# Semantic Preservation Loss
|
| 404 |
+
# ========================
|
| 405 |
+
def semantic_preservation_loss(R_plus: torch.Tensor, H_S_plus: torch.Tensor,
                               λ_recon: float = 1.0, λ_local: float = 0.5, λ_global: float = 0.2):
    """Semantic preservation loss between a reference representation R_plus
    and a student representation H_S_plus (both [B, L, H], forward-aligned).

    Three weighted terms:
      - recon:  pointwise MSE between the two representations;
      - local:  MSE between their first-order differences along L (preserves
                the local trajectory shape; skipped when L < 2);
      - global: MSE between their Frobenius-normalized Gram matrices
                (pairwise position-similarity structure).
    """
    recon = F.mse_loss(H_S_plus, R_plus)
    if R_plus.size(1) >= 2:
        d_ref = R_plus[:, 1:] - R_plus[:, :-1]
        d_S = H_S_plus[:, 1:] - H_S_plus[:, :-1]
        local = F.mse_loss(d_S, d_ref)
    else:
        local = torch.tensor(0., device=R_plus.device)

    def gram_norm(x):
        # [B, L, L] Gram matrix, normalized per batch element; +1e-6 guards
        # against division by a zero norm.
        G = torch.einsum("b i d, b j d -> b i j", x, x)
        return G / (G.norm(dim=(1, 2), keepdim=True) + 1e-6)

    glob = F.mse_loss(gram_norm(H_S_plus), gram_norm(R_plus))
    return λ_recon * recon + λ_local * local + λ_global * glob
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
@contextmanager
def eval_mode(*modules):
    """Temporarily switch the given modules to eval(); restore each module's
    previous train/eval state on exit, even if the body raises."""
    saved = [m.training for m in modules]
    try:
        for m in modules:
            m.eval()
        yield
    finally:
        for m, was_training in zip(modules, saved):
            m.train(was_training)
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
class SSScanDNAHybridModel(nn.Module):
    """S-scan chunked, memory-optimized dual-strand DNA model.

    Processing is done block by block along the sequence:
    branch -> (flip when needed) -> bridge -> projection -> gated fusion
    -> (optional final_conv) -> output.

    - In pretraining, only small tensors (logits / masks) are concatenated;
      in fine-tuning, fused features are collected per block and concatenated
      at the end.
    - The semantic-preservation teacher is also evaluated per block, on demand,
      to avoid holding a full-sequence branch pass in memory.
    """

    def __init__(
        self,
        alphabet_size=5,
        d_model=128,
        block_size=2048,
        comba_cfg=None,
        transformer_cfg=None,
        depth=4,
        drop_path_rates=None,
        pretrain=False,
        for_representation=False,
        use_final_conv=False,

        use_s_scan: bool = True,
        use_mem: bool = False,      # NOTE(review): accepted but never stored/used in this class — confirm it is intentionally a no-op
        use_rc_kl: bool = False,
        use_barlow: bool = False,
        use_tv: bool = False,

        sem_max_weight: float = 0.2,
        sem_warmup_steps: int = 3000,

        rc_max_weight: float = 0.2,
        rc_warmup_steps: int = 2000,
        rc_tau: float = 1.5,
        rc_bidirectional_stopgrad: bool = True,

        aux_ce_weight: float = 0.1,  # stored below but not referenced in forward() — presumably consumed by an external loss; verify

        gate_freeze_steps: int = 1000,
        detach_gate: bool = False,
        gate_sup_weight: float = 0.005,
        gate_sup_warmup_steps: int = 500,
        gate_temp: float = 2.0,

        dropout=0.1,

        use_ema_teacher: bool = True,
        ema_decay: float = 0.999,
        auto_update_ema_in_forward: bool = True,

        use_bridge: bool = True,
        bridge_dropout: float = 0.0,
    ):
        super().__init__()
        self.alphabet_size = alphabet_size
        self.pretrain = pretrain
        self.for_representation = for_representation
        self.block_size = block_size
        self.use_final_conv = use_final_conv
        self.d_model = d_model

        # Global training-step counter (persisted as a buffer so it checkpoints).
        self.register_buffer("g_step", torch.zeros(1, dtype=torch.long))

        # Input embeddings: one conv stem per strand pipeline.
        # NOTE(review): kernel_size=9 with padding=4 preserves sequence length.
        self.linear = nn.Conv1d(alphabet_size, d_model, kernel_size=9, padding=4)
        self.rc_linear = nn.Conv1d(alphabet_size, d_model, kernel_size=9, padding=4)

        # Two parallel branches (A / B).
        self.branchA_core = DeepEnhancedBranch(
            hidden_dim=d_model, comba_cfg=comba_cfg, transformer_cfg=transformer_cfg,
            depth=depth, drop_path_rates=drop_path_rates
        )
        self.branchB_core = DeepEnhancedBranch(
            hidden_dim=d_model, comba_cfg=comba_cfg, transformer_cfg=transformer_cfg,
            depth=depth, drop_path_rates=drop_path_rates
        )

        # TokenBridge: cross-branch token-level exchange.
        self.use_bridge = use_bridge
        if self.use_bridge:
            self.bridge = TokenBridge(d_model, dropout=bridge_dropout)

        # EMA teacher copies of the branches (and bridge), gradient-free.
        self.use_ema_teacher = use_ema_teacher
        self.ema_decay = ema_decay
        self.auto_update_ema_in_forward = auto_update_ema_in_forward
        if self.use_ema_teacher:
            self.branchA_core_ema = copy.deepcopy(self.branchA_core)
            self.branchB_core_ema = copy.deepcopy(self.branchB_core)
            for p in self.branchA_core_ema.parameters(): p.requires_grad_(False)
            for p in self.branchB_core_ema.parameters(): p.requires_grad_(False)
            if self.use_bridge:
                self.bridge_ema = copy.deepcopy(self.bridge)
                for p in self.bridge_ema.parameters(): p.requires_grad_(False)

        # Final per-branch projections + gated fusion head.
        self.proj_A = Mlp(d_model, d_model * 2, d_model, activation=F.gelu, return_residual=True)
        self.proj_B = Mlp(d_model, d_model * 2, d_model, activation=F.gelu, return_residual=True)
        self.gate_fuse = nn.Linear(2 * d_model, d_model)
        self.out_linear = nn.Linear(d_model, alphabet_size)
        self.dropout = nn.Dropout(dropout)

        # Reverse-complement permutation matrix over the alphabet.
        P_comp, _ = make_complement_perm(self.alphabet_size)
        self.register_buffer("P_comp", P_comp)

        # Feature switches & loss-weight schedules.
        self.use_s_scan = use_s_scan
        self.use_rc_kl = use_rc_kl
        self.use_barlow = use_barlow
        self.use_tv = use_tv
        self.sem_max_weight = sem_max_weight
        self.sem_warmup_steps = sem_warmup_steps
        self.rc_max_weight = rc_max_weight
        self.rc_warmup_steps = rc_warmup_steps
        self.rc_tau = rc_tau
        self.rc_bidirectional_stopgrad = rc_bidirectional_stopgrad
        self.aux_ce_weight = aux_ce_weight
        self.gate_freeze_steps = gate_freeze_steps
        self.detach_gate = detach_gate
        self.gate_sup_weight = gate_sup_weight
        self.gate_sup_warmup_steps = gate_sup_warmup_steps
        self.gate_temp = gate_temp

        if use_final_conv:
            self.final_conv = nn.Conv1d(d_model, d_model, kernel_size=3, padding=1)

    @torch.no_grad()
    def update_ema(self):
        """Exponential-moving-average update of the teacher branches/bridge.

        ``p_ema <- d * p_ema + (1 - d) * p`` via in-place ``lerp_``. No-op when
        the EMA teacher is disabled.
        """
        if not getattr(self, "use_ema_teacher", False): return
        d = float(getattr(self, "ema_decay", 0.999))
        for m_ema, m in [(self.branchA_core_ema, self.branchA_core),
                         (self.branchB_core_ema, self.branchB_core)]:
            for p_ema, p in zip(m_ema.parameters(), m.parameters()):
                p_ema.data.lerp_(p.data, 1.0 - d)
        if getattr(self, "use_bridge", False) and hasattr(self, "bridge_ema"):
            for p_ema, p in zip(self.bridge_ema.parameters(), self.bridge.parameters()):
                p_ema.data.lerp_(p.data, 1.0 - d)

    # Original signature kept; `return_embedding` is not used as a control flag.
    def forward(self, seq, t=None, cls=None, return_embedding=False, state=None):
        """Run the dual-strand model.

        Args:
            seq: in pretraining, a pair ``(token_ids, mlm_mask)``; otherwise a
                [B, L] LongTensor of ids in ``[0, alphabet_size)``.
            t, cls, return_embedding, state: accepted for interface
                compatibility; unused in this implementation.

        Returns:
            - ``for_representation=True``: ``(fused [B, L, H], None)``.
            - ``pretrain=True``: ``(HybridOutput(logits=tuple_of_8), None)``
              where the tuple is (logits, mask, total_aux, logits_A, logits_B,
              mask_A_rc, mask_B_rc, current_step).
            - otherwise: ``(logits [B, L, alphabet_size], None)``.
        """
        step = int(self.g_step.item())
        if self.training:
            self.g_step += 1

        if self.pretrain:
            # Pretraining callers pass (token_ids, mlm_mask).
            mask = seq[1]
            seq = seq[0]
        else:
            mask = None

        mn, mx = int(seq.min()), int(seq.max())
        assert 0 <= mn and mx < self.alphabet_size, f"seq ids out of range: [{mn}, {mx}] vs alpha={self.alphabet_size}"

        # ===== input: one-hot -> conv1d -> [B,L,H] =====
        rc_seq = reverse_complement(seq)
        seq_oh = F.one_hot(seq, num_classes=self.alphabet_size).float()
        rc_oh = F.one_hot(rc_seq, num_classes=self.alphabet_size).float()

        h = F.gelu(self.linear(seq_oh.permute(0, 2, 1)))        # [B,H,L]
        rc_h = F.gelu(self.rc_linear(rc_oh.permute(0, 2, 1)))   # [B,H,L]
        feat = self.dropout(h).permute(0, 2, 1)                 # [B,L,H]
        rc_feat = self.dropout(rc_h).permute(0, 2, 1)

        fused = None  # used by the fine-tuning/representation return path

        # ===== trunk: S-scan (chunked) vs full-sequence =====
        if self.use_s_scan:
            B, L, H = feat.shape
            l = self.block_size
            K = (L + l - 1) // l  # number of blocks (ceil division)

            # Decide which outputs to collect, to avoid needless concatenation
            # and GPU-memory residency.
            collect_fused = bool(self.for_representation)
            collect_logits = (not self.for_representation) or self.pretrain
            collect_ab_logits = (self.pretrain or self.use_rc_kl)

            fused_chunks = [] if collect_fused else None
            logits_chunks = [] if collect_logits else None
            logitsA_chunks = [] if collect_ab_logits else None
            logitsB_chunks = [] if collect_ab_logits else None
            maskA_list, maskB_list = [], []

            total_aux = feat.new_zeros([])  # accumulated only in pretraining
            mem_A = mem_B = None  # NOTE(review): never assigned elsewhere — leftover of a memory feature?

            for t_block in range(K):
                start = t_block * l
                end = min(start + l, L)

                X_fwd = feat[:, start:end, :]
                X_rc = rc_feat[:, start:end, :]

                # Alternate which branch sees the forward strand vs the RC
                # strand on even/odd blocks (the "S-scan" pattern); the mask
                # records, per token, which branch received RC input.
                if (t_block % 2) == 0:
                    X_A, X_B = X_fwd, X_rc
                    rev_A, rev_B = False, True
                    maskA_rc_blk = torch.zeros(B, end - start, dtype=torch.bool, device=feat.device)
                    maskB_rc_blk = torch.ones(B, end - start, dtype=torch.bool, device=feat.device)
                else:
                    X_A, X_B = X_rc, X_fwd
                    rev_A, rev_B = True, False
                    maskA_rc_blk = torch.ones(B, end - start, dtype=torch.bool, device=feat.device)
                    maskB_rc_blk = torch.zeros(B, end - start, dtype=torch.bool, device=feat.device)

                # Branch passes; flip RC outputs back to forward orientation.
                H_A = self.branchA_core(X_A)
                H_B = self.branchB_core(X_B)
                if rev_A: H_A = torch.flip(H_A, dims=[1])
                if rev_B: H_B = torch.flip(H_B, dims=[1])

                if self.use_bridge:
                    H_A, H_B = self.bridge(H_A, H_B)

                # Projection (+ residual) and gated fusion.
                fA, rA = self.proj_A(H_A); FA = fA + rA
                fB, rB = self.proj_B(H_B); FB = fB + rB

                gate_in_blk = torch.cat([FA, FB], dim=-1)
                g_logits_blk = self.gate_fuse(gate_in_blk)
                g_raw_blk = torch.sigmoid(g_logits_blk / getattr(self, "gate_temp", 1.0))
                if step < getattr(self, "gate_freeze_steps", 0):
                    # Gate frozen at 0.5 early in training (even mix of A/B).
                    g_blk = 0.5 * torch.ones_like(g_raw_blk)
                else:
                    g_blk = g_raw_blk
                if getattr(self, "detach_gate", False):
                    # Stop gradients through the gate values themselves.
                    mix_blk = g_blk.detach() * FA + (1 - g_blk.detach()) * FB
                else:
                    mix_blk = g_blk * FA + (1 - g_blk) * FB
                fused_blk = F.layer_norm(mix_blk, (mix_blk.size(-1),))
                fused_blk = ensure_finite(fused_blk, "fused_blk")

                if self.use_final_conv:
                    fused_blk = self.final_conv(fused_blk.permute(0, 2, 1)).permute(0, 2, 1)

                if collect_fused:
                    fused_chunks.append(fused_blk)

                if collect_logits:
                    logits_blk = self.out_linear(fused_blk)
                    logits_chunks.append(logits_blk)

                if collect_ab_logits:
                    logitsA_chunks.append(self.out_linear(FA))
                    logitsB_chunks.append(self.out_linear(FB))

                maskA_list.append(maskA_rc_blk)
                maskB_list.append(maskB_rc_blk)

                # ===== pretraining auxiliary losses (per block) =====
                if self.pretrain:
                    with torch.no_grad():
                        # Clean (non-dropout) per-strand features for the teacher.
                        A_feat_blk = F.gelu(self.linear(
                            F.one_hot(seq[:, start:end], self.alphabet_size).float().permute(0, 2, 1)
                        )).permute(0, 2, 1)
                        B_feat_blk_rc = F.gelu(self.rc_linear(
                            F.one_hot(rc_seq[:, start:end], self.alphabet_size).float().permute(0, 2, 1)
                        )).permute(0, 2, 1)

                        teacherA = self.branchA_core_ema if self.use_ema_teacher else self.branchA_core
                        teacherB = self.branchB_core_ema if self.use_ema_teacher else self.branchB_core
                        tbridge = self.bridge_ema if (self.use_bridge and self.use_ema_teacher and hasattr(self, "bridge_ema")) else (self.bridge if self.use_bridge else None)

                        mods = [teacherA, teacherB] + ([tbridge] if tbridge is not None else [])
                        with eval_mode(*mods):
                            # Plus-strand teacher targets.
                            R_plus_A_blk = teacherA(A_feat_blk)
                            # NOTE(review): teacherB is also fed A_feat_blk (forward-strand features
                            # from self.linear) rather than a B-pipeline feature — confirm intentional.
                            R_plus_B_blk = teacherB(A_feat_blk)
                            if tbridge is not None:
                                R_plus_A_blk, R_plus_B_blk = tbridge(R_plus_A_blk, R_plus_B_blk)

                            # Minus-strand teacher targets, flipped back to forward orientation.
                            R_minus_A_blk_rc = teacherA(B_feat_blk_rc)
                            R_minus_B_blk_rc = teacherB(B_feat_blk_rc)
                            R_minus_A_blk_fwd = torch.flip(R_minus_A_blk_rc, dims=[1])
                            R_minus_B_blk_fwd = torch.flip(R_minus_B_blk_rc, dims=[1])
                            if tbridge is not None:
                                R_minus_A_blk_fwd, R_minus_B_blk_fwd = tbridge(R_minus_A_blk_fwd, R_minus_B_blk_fwd)

                        # Pick the teacher target matching the strand each branch consumed.
                        R_A_teacher_blk = torch.where(maskA_rc_blk.unsqueeze(-1), R_minus_A_blk_fwd, R_plus_A_blk)
                        R_B_teacher_blk = torch.where(maskB_rc_blk.unsqueeze(-1), R_minus_B_blk_fwd, R_plus_B_blk)

                    # Semantic-preservation distillation, with linear warmup weight.
                    sem_A = semantic_preservation_loss(R_A_teacher_blk, FA)
                    sem_B = semantic_preservation_loss(R_B_teacher_blk, FB)
                    w_sem = linear_warmup_weight(step, getattr(self, "sem_warmup_steps", 0),
                                                 getattr(self, "sem_max_weight", 1.0))
                    total_aux = total_aux + w_sem * (sem_A + sem_B)

                    # Gate supervision: after the freeze phase, push the mean gate
                    # logit toward "prefer the branch that saw the forward strand".
                    if (getattr(self, "gate_sup_weight", 0.0) > 0.0) and (step >= getattr(self, "gate_freeze_steps", 0)):
                        g_target_blk = (~maskA_rc_blk).float().unsqueeze(-1)
                        g_token_logits_blk = g_logits_blk.mean(dim=-1, keepdim=True) / getattr(self, "gate_temp", 1.0)
                        w_gate = linear_warmup_weight(
                            step - getattr(self, "gate_freeze_steps", 0),
                            getattr(self, "gate_sup_warmup_steps", 0),
                            getattr(self, "gate_sup_weight", 0.0),
                        )
                        total_aux = total_aux + w_gate * F.binary_cross_entropy_with_logits(
                            g_token_logits_blk, g_target_blk
                        )

                    # Reverse-complement consistency between A/B logits of this block.
                    if self.use_rc_kl and getattr(self, "rc_max_weight", 0.0) > 0:
                        if getattr(self, "rc_bidirectional_stopgrad", True):
                            rc = rc_consistency_bidirectional_stopgrad(
                                logitsA_chunks[-1], logitsB_chunks[-1], self.P_comp, tau=getattr(self, "rc_tau", 1.5)
                            )
                        else:
                            rc = rc_consistency_kl(
                                logitsA_chunks[-1], logitsB_chunks[-1], self.P_comp, tau=getattr(self, "rc_tau", 1.5)
                            )
                        w_rc = linear_warmup_weight(step, getattr(self, "rc_warmup_steps", 0),
                                                    getattr(self, "rc_max_weight", 0.0))
                        total_aux = total_aux + w_rc * rc

                    if self.use_barlow:
                        total_aux = total_aux + barlow_strand_loss_v2(H_A, H_B)
                    if self.use_tv:
                        total_aux = total_aux + tv_mixed(fused_blk)

            # Concatenate collected chunks (and trim to L).
            logits = torch.cat(logits_chunks, dim=1)[:, :L, :] if collect_logits else None
            logits_A_only = torch.cat(logitsA_chunks, dim=1)[:, :L, :] if collect_ab_logits else None
            logits_B_only = torch.cat(logitsB_chunks, dim=1)[:, :L, :] if collect_ab_logits else None
            mask_A_rc = torch.cat(maskA_list, dim=1)[:, :L]
            mask_B_rc = torch.cat(maskB_list, dim=1)[:, :L]
            if collect_fused:
                fused = torch.cat(fused_chunks, dim=1)[:, :L, :]

        else:
            # Non-S-scan: single full-sequence pass (short sequences / inference).
            H_A = self.branchA_core(feat)
            H_Br = self.branchB_core(rc_feat)
            R_A = H_A
            R_B = torch.flip(H_Br, dims=[1])  # flip RC branch back to forward orientation

            if self.use_bridge:
                R_A, R_B = self.bridge(R_A, R_B)

            fA, rA = self.proj_A(R_A); FA = fA + rA
            fB, rB = self.proj_B(R_B); FB = fB + rB

            gate_in = torch.cat([FA, FB], dim=-1)
            g_logits = self.gate_fuse(gate_in)
            g_raw = torch.sigmoid(g_logits / getattr(self, "gate_temp", 1.0))
            if step < getattr(self, "gate_freeze_steps", 0):
                g = 0.5 * torch.ones_like(g_raw)
            else:
                g = g_raw
            if getattr(self, "detach_gate", False):
                mix = g.detach() * FA + (1 - g.detach()) * FB
            else:
                mix = g * FA + (1 - g) * FB
            fused = F.layer_norm(mix, (mix.size(-1),))
            fused = ensure_finite(fused, "fused")
            if self.use_final_conv:
                fused = self.final_conv(fused.permute(0, 2, 1)).permute(0, 2, 1)

            logits = self.out_linear(fused) if (not self.for_representation or self.pretrain) else None
            logits_A_only = self.out_linear(FA) if (self.pretrain or self.use_rc_kl) else None
            logits_B_only = self.out_linear(FB) if (self.pretrain or self.use_rc_kl) else None
            # No strand alternation here: neither branch saw RC input position-wise.
            mask_A_rc = torch.zeros(FA.size()[:2], dtype=torch.bool, device=FA.device)
            mask_B_rc = torch.zeros_like(mask_A_rc)
            total_aux = logits.new_zeros(()) if self.pretrain else None

        # ===== return, preserving the original behavior =====
        if self.for_representation:
            # Fine-tuning path: no EMA/aux handling, teacher not updated.
            return fused, None

        if self.training and self.use_ema_teacher and self.auto_update_ema_in_forward:
            self.update_ema()

        current_step = int(step)

        if self.pretrain:
            # Compatibility with the existing loss: guarantee A/B logits exist.
            if logits_A_only is None: logits_A_only = self.out_linear(FA)
            if logits_B_only is None: logits_B_only = self.out_linear(FB)
            HybridOutput = namedtuple("HybridOutput", ["logits"])
            return HybridOutput(
                logits=(logits,
                        mask,
                        total_aux,
                        logits_A_only.detach(),
                        logits_B_only.detach(),
                        mask_A_rc.detach(),
                        mask_B_rc.detach(),
                        current_step)
            ), None

        return logits, None

    @property
    def d_output(self):
        # Output feature dimension expected by SequenceModule-style consumers.
        if getattr(self, "d_model", None) is None:
            raise NotImplementedError("SequenceModule instantiation must set d_output")
        return self.d_model
|
| 831 |
+
|
| 832 |
+
|
| 833 |
+
|
| 834 |
+
class CrossDNAForMaskedLM(PreTrainedModel):
    """HuggingFace wrapper exposing SSScanDNAHybridModel as a masked-LM."""

    config_class = CrossDNAConfig
    base_model_prefix = "backbone"

    def __init__(self, config: CrossDNAConfig):
        super().__init__(config)
        self.config = config

        # All backbone hyperparameters come from the config; bridge_dropout is
        # pinned to 0.0 here (not configurable at the moment).
        self.backbone = SSScanDNAHybridModel(
            alphabet_size=config.alphabet_size,
            d_model=config.d_model,
            block_size=config.block_size,
            comba_cfg=config.comba_cfg,
            transformer_cfg=config.transformer_cfg,
            depth=config.depth,
            drop_path_rates=config.drop_path_rates,
            pretrain=config.pretrain,
            for_representation=config.for_representation,
            use_s_scan=config.use_s_scan,
            use_mem=config.use_mem,
            use_rc_kl=config.use_rc_kl,
            use_barlow=config.use_barlow,
            use_tv=config.use_tv,
            sem_max_weight=config.sem_max_weight,
            sem_warmup_steps=config.sem_warmup_steps,
            aux_ce_weight=config.aux_ce_weight,
            gate_freeze_steps=config.gate_freeze_steps,
            detach_gate=config.detach_gate,
            gate_sup_weight=config.gate_sup_weight,
            gate_sup_warmup_steps=config.gate_sup_warmup_steps,
            gate_temp=config.gate_temp,
            dropout=config.dropout,
            use_bridge=config.use_bridge,
            bridge_dropout=0.0,
        )

        self.post_init()

    @property
    def mask_token_id(self) -> int:
        # Defaults to 3 when the config does not declare mask_token_id.
        # NOTE(review): the bundled tokenizer maps [MASK] to id 3 but bases A..N
        # at ids 7+, while this model expects ids 0..alphabet_size-1 — confirm
        # the id remapping happens upstream.
        return getattr(self.config, "mask_token_id", 3)

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,  # accepted for API compatibility; unused
        labels: Optional[torch.LongTensor] = None,
        **kwargs
    ) -> MaskedLMOutput:
        """
        input_ids: [B, L], values in 0..alphabet_size-1 (A=0, C=1, G=2, T=3, N=4)
        labels: [B, L] MLM labels; non-masked positions should be -100 (ignored)
        """

        # MLM mask: prefer labels != -100; otherwise fall back to
        # input_ids == mask_token_id.
        if labels is not None:
            mlm_mask = (labels != -100)
        else:
            mlm_mask = (input_ids == self.mask_token_id)

        if self.config.pretrain:
            # Pretraining path: the backbone expects (seq, mask) when pretrain=True.
            outputs, _ = self.backbone((input_ids, mlm_mask))
            # outputs is the HybridOutput namedtuple; logits[0] holds the main logits.
            logits = outputs.logits[0]  # [B, L, vocab_size]
        else:
            logits, _ = self.backbone(input_ids)  # [B, L, vocab_size]

        loss = None
        if labels is not None:
            # Standard MLM loss: cross-entropy only where labels != -100.
            vocab = self.config.alphabet_size
            logits_2d = logits.view(-1, vocab)
            labels_1d = labels.view(-1)
            loss = F.cross_entropy(logits_2d, labels_1d, ignore_index=-100)

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )
|
8.1M/special_tokens_map.json
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token": "[BOS]",
|
| 3 |
+
"cls_token": "[CLS]",
|
| 4 |
+
"eos_token": "[SEP]",
|
| 5 |
+
"mask_token": "[MASK]",
|
| 6 |
+
"pad_token": "[PAD]",
|
| 7 |
+
"sep_token": "[SEP]",
|
| 8 |
+
"unk_token": "[UNK]"
|
| 9 |
+
}
|
8.1M/tokenization_crossdna.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
CharacterTokenizer for Hugging Face Transformers.
|
| 3 |
+
This is heavily inspired from CanineTokenizer in transformers package.
|
| 4 |
+
"""
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Dict, List, Optional, Sequence, Union
|
| 9 |
+
|
| 10 |
+
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class CrossDNATokenizer(PreTrainedTokenizer):
    """Single-character tokenizer with a fixed special-token layout (ids 0-6)."""

    def __init__(self, characters: Sequence[str], model_max_length: int, padding_side: str='left', **kwargs):
        """Character tokenizer for Hugging Face transformers.
        Args:
            characters (Sequence[str]): List of desired characters. Any character which
                is not included in this list will be replaced by a special token called
                [UNK] with id=6. Following are list of all of the special tokens with
                their corresponding ids:
                    "[CLS]": 0
                    "[SEP]": 1
                    "[BOS]": 2
                    "[MASK]": 3
                    "[PAD]": 4
                    "[RESERVED]": 5
                    "[UNK]": 6
                an id (starting at 7) will be assigned to each character.
            model_max_length (int): Model maximum sequence length.
        """
        # Vocab must exist before super().__init__, which may query it.
        self.characters = characters
        self.model_max_length = model_max_length

        self._vocab_str_to_int = {
            "[CLS]": 0,
            "[SEP]": 1,
            "[BOS]": 2,
            "[MASK]": 3,
            "[PAD]": 4,
            "[RESERVED]": 5,
            "[UNK]": 6,
            **{ch: i + 7 for i, ch in enumerate(characters)},
        }
        self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()}

        bos_token = AddedToken("[BOS]", lstrip=False, rstrip=False)
        eos_token = AddedToken("[SEP]", lstrip=False, rstrip=False)  # NOTE(review): shadowed below — super() receives sep_token as eos
        sep_token = AddedToken("[SEP]", lstrip=False, rstrip=False)
        cls_token = AddedToken("[CLS]", lstrip=False, rstrip=False)
        pad_token = AddedToken("[PAD]", lstrip=False, rstrip=False)
        unk_token = AddedToken("[UNK]", lstrip=False, rstrip=False)

        # lstrip=True so "x [MASK]" tokenizes the mask without a leading space token.
        mask_token = AddedToken("[MASK]", lstrip=True, rstrip=False)

        # Drop add_special_tokens if a caller forwards it; the base __init__
        # does not accept it as a constructor kwarg.
        if "add_special_tokens" in kwargs:
            kwargs.pop("add_special_tokens")

        super().__init__(
            bos_token=bos_token,
            eos_token=sep_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            unk_token=unk_token,
            add_prefix_space=False,
            model_max_length=model_max_length,
            padding_side=padding_side,
            **kwargs,
        )

    def __len__(self):
        return len(self._vocab_str_to_int)

    @property
    def vocab_size(self) -> int:
        return len(self._vocab_str_to_int)

    def get_vocab(self):
        """Return the token-to-id mapping (standard HuggingFace interface)."""
        return self._vocab_str_to_int

    def _tokenize(self, text: str) -> List[str]:
        # Character-level tokenization: one token per character.
        return list(text)

    def _convert_token_to_id(self, token: str) -> int:
        # Unknown characters map to [UNK] (id 6).
        return self._vocab_str_to_int.get(token, self._vocab_str_to_int["[UNK]"])

    def _convert_id_to_token(self, index: int) -> str:
        return self._vocab_int_to_str[index]

    def convert_tokens_to_string(self, tokens):
        return "".join(tokens)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        # Layout: tokens_0 [SEP] (tokens_1 [SEP]) — no leading [CLS].
        sep = [self.sep_token_id]
        # cls = [self.cls_token_id]
        result = token_ids_0 + sep
        if token_ids_1 is not None:
            result += token_ids_1 + sep
        return result

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        # 1 marks special tokens added by build_inputs_with_special_tokens.
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )

        result = ([0] * len(token_ids_0)) + [1]
        if token_ids_1 is not None:
            result += ([0] * len(token_ids_1)) + [1]
        return result

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        # NOTE(review): counts a [CLS] slot although build_inputs_with_special_tokens
        # does not prepend one — the returned mask is one element longer than the
        # built input for single sequences; confirm intended.
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        result = len(cls + token_ids_0 + sep) * [0]
        if token_ids_1 is not None:
            result += len(token_ids_1 + sep) * [1]
        return result

    def get_config(self) -> Dict:
        # Characters stored as code points so the config stays pure-ASCII JSON.
        return {
            "char_ords": [ord(ch) for ch in self.characters],
            "model_max_length": self.model_max_length,
        }

    @classmethod
    def from_config(cls, config: Dict) -> "CrossDNATokenizer":
        cfg = {}
        cfg["characters"] = [chr(i) for i in config["char_ords"]]
        cfg["model_max_length"] = config["model_max_length"]
        return cls(**cfg)

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        # Minimal persistence: only tokenizer_config.json is written
        # (overrides the richer base-class serialization).
        cfg_file = Path(save_directory) / "tokenizer_config.json"
        cfg = self.get_config()
        with open(cfg_file, "w") as f:
            json.dump(cfg, f, indent=4)

    @classmethod
    def from_pretrained(cls, save_directory: Union[str, os.PathLike], **kwargs):
        # Counterpart of save_pretrained: rebuild solely from tokenizer_config.json.
        cfg_file = Path(save_directory) / "tokenizer_config.json"
        with open(cfg_file) as f:
            cfg = json.load(f)
        return cls.from_config(cfg)
|
8.1M/tokenizer_config.json
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"char_ords": [65, 67, 71, 84, 78],
|
| 3 |
+
"model_max_length": 1024,
|
| 4 |
+
|
| 5 |
+
"add_prefix_space": false,
|
| 6 |
+
"padding_side": "left",
|
| 7 |
+
|
| 8 |
+
"bos_token": "[BOS]",
|
| 9 |
+
"eos_token": "[SEP]",
|
| 10 |
+
"sep_token": "[SEP]",
|
| 11 |
+
"cls_token": "[CLS]",
|
| 12 |
+
"pad_token": "[PAD]",
|
| 13 |
+
"mask_token": "[MASK]",
|
| 14 |
+
"unk_token": "[UNK]",
|
| 15 |
+
|
| 16 |
+
"added_tokens_decoder": {
|
| 17 |
+
"0": { "content": "[CLS]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true },
|
| 18 |
+
"1": { "content": "[SEP]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true },
|
| 19 |
+
"2": { "content": "[BOS]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true },
|
| 20 |
+
"3": { "content": "[MASK]", "lstrip": true, "normalized": false, "rstrip": false, "single_word": false, "special": true },
|
| 21 |
+
"4": { "content": "[PAD]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true },
|
| 22 |
+
"6": { "content": "[UNK]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true }
|
| 23 |
+
},
|
| 24 |
+
|
| 25 |
+
"tokenizer_class": "CrossDNATokenizer",
|
| 26 |
+
"auto_map": {
|
| 27 |
+
"AutoTokenizer": [
|
| 28 |
+
"tokenization_crossdna.CrossDNATokenizer",
|
| 29 |
+
null
|
| 30 |
+
]
|
| 31 |
+
}
|
| 32 |
+
}
|