chengCCC committed on
Commit
fc89b2b
·
verified ·
1 Parent(s): 672998f

Upload CrossDNA 519M pretrained files

Browse files
519M/README.md ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: gpl-3.0
3
+ ---
4
+
5
+ ## Using CrossDNA 519M
6
+ ```python
7
+ import os
8
+ os.environ.setdefault("DISABLE_TORCH_COMPILE", "1")
9
+
10
+ import torch
11
+ if hasattr(torch, "compile"):
12
+ def _no_compile(fn=None, *args, **kwargs):
13
+ if fn is None:
14
+ def deco(f): return f
15
+ return deco
16
+ return fn
17
+ torch.compile = _no_compile
18
+
19
+ import torch
20
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
21
+
22
+ # Hugging Face Remote Repository Example
23
+ # repo_id = "chengCCC/CrossDNA_pretrain"
24
+ # subdir = "519M"
25
+
26
+ # tok = AutoTokenizer.from_pretrained(repo_id, subfolder=subdir, trust_remote_code=True, local_files_only=True)
27
+ # model = AutoModelForMaskedLM.from_pretrained(repo_id, subfolder=subdir, trust_remote_code=True, local_files_only=True).eval()
28
+
29
+ # Local Model Example
30
+ MODEL_DIR = "/path/to/crossdna"  # replace with your local 519M model directory
31
+ tok = AutoTokenizer.from_pretrained(MODEL_DIR, trust_remote_code=True, local_files_only=True)
32
+ model = AutoModelForMaskedLM.from_pretrained(MODEL_DIR, trust_remote_code=True, local_files_only=True).eval()
33
+
34
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
35
+ model.to(device)
36
+
37
+ # 512bp DNA sequence
38
+ seq = "ACGT" * 128
39
+
40
+
41
+ enc = tok(seq, return_tensors="pt", add_special_tokens=False)
42
+
43
+
44
+ x = enc["input_ids"].to(device) # [1, L]
45
+
46
+ # Key: map tokenizer IDs back to the model-required range [0..4]
47
+ # A, C, G, T, N: 7..11 -> 0..4
48
+ # All other tokens (UNK/PAD/SEP/CLS/...) are treated as N (=4)
49
+ x = torch.where(x >= 7, x - 7, torch.full_like(x, 4))
50
+
51
+ # ====== embedding ======
52
+ was_pretrain = getattr(model.backbone, "pretrain", False)
53
+ was_for_repr = getattr(model.backbone, "for_representation", False)
54
+ model.backbone.pretrain = False
55
+ model.backbone.for_representation = True
56
+
57
+ with torch.inference_mode():
58
+ embeddings, _ = model.backbone(x) # [B, L, H]
59
+
60
+ print("input_ids.shape =", tuple(x.shape)) # input_ids.shape = (1, 512)
61
+ print("embeddings.shape =", tuple(embeddings.shape)) # embeddings.shape = (1, 512, 1024)
62
+
63
+ model.backbone.pretrain = was_pretrain
64
+ model.backbone.for_representation = was_for_repr
65
+
66
+ ```
519M/config.json ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "CrossDNA-pretrain",
3
+ "model_type": "crossdna",
4
+ "architectures": ["CrossDNAForMaskedLM"],
5
+
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_crossdna.CrossDNAConfig",
8
+ "AutoModelForMaskedLM": "modeling_crossdna.CrossDNAForMaskedLM",
9
+ "AutoTokenizer": "tokenization_crossdna.CrossDNATokenizer"
10
+ },
11
+
12
+ "torch_dtype": "float32",
13
+
14
+ "alphabet_size": 5,
15
+ "d_model": 1024,
16
+ "block_size": 1024,
17
+ "depth": 12,
18
+ "drop_path_rates": [0.0, 0.05],
19
+ "dropout": 0.15,
20
+
21
+ "pretrain": true,
22
+ "for_representation": false,
23
+ "use_s_scan": true,
24
+ "use_bridge": true,
25
+ "use_mem": false,
26
+ "use_rc_kl": false,
27
+ "use_barlow": false,
28
+ "use_tv": false,
29
+
30
+ "sem_max_weight": 0.12,
31
+ "sem_warmup_steps": 10000,
32
+ "aux_ce_weight": 0.0,
33
+ "gate_freeze_steps": 5000,
34
+ "detach_gate": false,
35
+ "gate_sup_weight": 0.02,
36
+ "gate_sup_warmup_steps": 1000,
37
+ "gate_temp": 2.0,
38
+
39
+ "transformer_cfg": {
40
+ "hidden_size": 1024,
41
+ "norm_eps": 1e-5,
42
+ "max_position_embeddings": 1024,
43
+ "hidden_ratio": 4.0,
44
+ "hidden_act": "swish",
45
+ "fuse_swiglu": true,
46
+ "attn": {
47
+ "num_heads": 8,
48
+ "num_kv_heads": 8,
49
+ "qkv_bias": false,
50
+ "window_size": 512,
51
+ "rope_theta": 10000
52
+ }
53
+ },
54
+ "comba_cfg": {
55
+ "hidden_size": 1024,
56
+ "expand_v": 1,
57
+ "head_dim": 128,
58
+ "num_heads": 8,
59
+ "use_gate": true,
60
+ "mode": "chunk",
61
+ "use_short_conv": true,
62
+ "correction_factor": 0.02,
63
+ "conv_size": 4,
64
+ "norm_eps": 1e-5
65
+ },
66
+
67
+ "pad_token_id": 4,
68
+ "bos_token_id": 2,
69
+ "sep_token_id": 1,
70
+ "cls_token_id": 0,
71
+ "mask_token_id": 3,
72
+
73
+ "vocab_size": 5
74
+ }
519M/configuration_crossdna.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
class CrossDNAConfig(PretrainedConfig):
    """HF configuration for CrossDNA hybrid (Comba + sliding-window attention) models.

    Special-token convention (matches the shipped tokenizer): CLS=0, SEP=1,
    BOS=2, MASK=3, PAD=4. Sequence symbols are mapped to A=0, C=1, G=2, T=3,
    N=4 before entering the backbone, so ``vocab_size == alphabet_size``.
    """

    model_type = "crossdna"

    def __init__(
        self,

        # --- backbone geometry ---
        alphabet_size=5,
        d_model=1024,
        block_size=1024,
        depth=12,
        drop_path_rates=(0.0, 0.05),
        dropout=0.15,

        # --- operating mode & optional auxiliary-loss switches ---
        pretrain=True,
        for_representation=False,
        use_s_scan=True,
        use_bridge=True,
        use_mem=False,
        use_rc_kl=False,
        use_barlow=False,
        use_tv=False,

        # --- auxiliary-loss schedules / gate hyper-parameters ---
        sem_max_weight=0.12,
        sem_warmup_steps=10000,
        aux_ce_weight=0.0,
        gate_freeze_steps=5000,
        detach_gate=False,
        gate_sup_weight=0.02,
        gate_sup_warmup_steps=1000,
        gate_temp=2.0,

        # --- sub-module configs (plain dicts); defaults filled in below ---
        transformer_cfg=None,
        comba_cfg=None,

        # --- special token ids ---
        pad_token_id=4,   # [PAD]
        bos_token_id=2,   # [BOS]
        sep_token_id=1,   # [SEP]
        cls_token_id=0,   # [CLS]
        mask_token_id=3,  # [MASK]
        **kwargs
    ):
        # Bug fix: sep/cls/mask token ids were previously accepted as
        # parameters but never forwarded to PretrainedConfig (nor stored),
        # so they were silently dropped when a config was round-tripped
        # through from_pretrained/save_pretrained. Forward all five special
        # ids; PretrainedConfig stores them as attributes.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            sep_token_id=sep_token_id,
            cls_token_id=cls_token_id,
            mask_token_id=mask_token_id,
            **kwargs
        )
        self.alphabet_size = alphabet_size
        self.d_model = d_model
        self.block_size = block_size
        self.depth = depth
        self.drop_path_rates = list(drop_path_rates) if drop_path_rates is not None else None
        self.dropout = dropout

        self.pretrain = pretrain
        self.for_representation = for_representation
        self.use_s_scan = use_s_scan
        self.use_bridge = use_bridge
        self.use_mem = use_mem
        self.use_rc_kl = use_rc_kl
        self.use_barlow = use_barlow
        self.use_tv = use_tv

        self.sem_max_weight = sem_max_weight
        self.sem_warmup_steps = sem_warmup_steps
        self.aux_ce_weight = aux_ce_weight
        self.gate_freeze_steps = gate_freeze_steps
        self.detach_gate = detach_gate
        self.gate_sup_weight = gate_sup_weight
        self.gate_sup_warmup_steps = gate_sup_warmup_steps
        self.gate_temp = gate_temp

        self.transformer_cfg = transformer_cfg or {
            "hidden_size": d_model,
            "norm_eps": 1e-5,
            # Default only; training code replaces this with dataset.max_length.
            "max_position_embeddings": 1024,
            "hidden_ratio": 4.0,
            "hidden_act": "swish",
            "fuse_swiglu": True,
            "attn": {
                "num_heads": 8,
                "num_kv_heads": 8,
                "qkv_bias": False,
                "window_size": 512,
                "rope_theta": 10000
            }
        }
        self.comba_cfg = comba_cfg or {
            "hidden_size": d_model,
            "expand_v": 1,
            "head_dim": 128,
            "num_heads": 8,
            "use_gate": True,
            "mode": "chunk",
            "use_short_conv": True,
            "correction_factor": 0.02,
            "conv_size": 4,
            "norm_eps": 1e-5,
        }

        # Lets AutoModel infer vocab_size.
        self.vocab_size = self.alphabet_size
519M/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b9f7d43de5bfeaf07f71a560c516e683c9cde9427648a1de6b09aab32d11b42f
3
+ size 4110024080
519M/modeling_crossdna.py ADDED
@@ -0,0 +1,916 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import math
3
+ import torch
4
+ import copy
5
+ import torch.nn as nn
6
+ from torch import amp
7
+ import torch.nn.functional as F
8
+ from functools import partial
9
+ from contextlib import contextmanager
10
+ from collections import namedtuple
11
+ from typing import Dict, Optional, Tuple, Any
12
+
13
+ from transformers import PreTrainedModel
14
+ from transformers.modeling_outputs import MaskedLMOutput
15
+
16
# Neutralize torch.compile by default (set DISABLE_TORCH_COMPILE=0 when
# compiling a ckpt into bin files).
import os as _os, torch as _torch
if _os.environ.get("DISABLE_TORCH_COMPILE", "1") == "1" and hasattr(_torch, "compile"):
    def _no_compile(fn=None, *args, **kwargs):
        # Supports both usages: @torch.compile and @torch.compile(**opts).
        if fn is None:
            return lambda f: f
        return fn
    _torch.compile = _no_compile
    print("torch.compile =>", torch.compile)
26
+
27
+
28
+ from fla.layers import comba
29
+ from fla.layers.attn import Attention
30
+ from fla.modules import GatedMLP as SambaMLP
31
+ from fla.modules import RMSNorm
32
+ from torch.cuda.amp import autocast
33
+
34
+
35
+ try:
36
+ # 在 Transformers 的动态包 transformers_modules.<hash>.* 环境下走相对导入
37
+ from .configuration_crossdna import CrossDNAConfig
38
+ except ImportError:
39
+ # 直接本地运行(不是通过 from_pretrained 动态加载)时也能跑
40
+ from configuration_crossdna import CrossDNAConfig
41
+
42
+
43
+
44
+ # ========================
45
+ # Utils
46
+ # ========================
47
def complement(seq: torch.Tensor) -> torch.Tensor:
    """Complement base ids: A=0<->T=3, C=1<->G=2; N (=4) maps to itself."""
    return torch.where(seq == 4, seq, 3 - seq)
52
+
53
+
54
def reverse_complement(seq: torch.Tensor) -> torch.Tensor:
    """Reverse-complement along the sequence axis (dim 1); N stays N."""
    return torch.flip(complement(seq), dims=[1])
57
+
58
+
59
def make_complement_perm(C=5, device=None, dtype=torch.float32):
    """Build the complement permutation used by the RC-consistency losses.

    Returns (P, perm) where perm[i] is the complement of base i under the
    A=0, C=1, G=2, T=3, N=4 encoding (same convention as ``complement``),
    and P is the one-hot permutation matrix with P[i, perm[i]] = 1.
    """
    # Bug fix: the previous permutation [3, 0, 2, 1, 4] mapped C->A and G->G,
    # which is not a nucleotide complement and disagrees with complement()
    # (3 - id: A<->T, C<->G). The true complement permutation is below; it is
    # an involution (applying it twice yields the identity).
    perm = torch.tensor([3, 2, 1, 0, 4], device=device)
    P = torch.zeros(C, C, device=device, dtype=dtype)
    P[torch.arange(C, device=device), perm] = 1.0
    return P, perm
65
+
66
+
67
def ensure_finite(x: torch.Tensor, name: str):
    """Return *x* unchanged, raising FloatingPointError if it has NaN/Inf.

    Check-only: numerical problems are surfaced, never silently patched.
    """
    if torch.isfinite(x).all():
        return x
    raise FloatingPointError(f"Non-finite values detected in {name}")
72
+
73
+
74
def linear_warmup_weight(step: int, warmup_steps: int, max_w: float):
    """Linearly ramp a loss weight from 0 up to *max_w* over *warmup_steps*.

    No warmup (warmup_steps <= 0) means the full weight immediately.
    """
    if warmup_steps <= 0 or step >= warmup_steps:
        return max_w
    if step <= 0:
        return 0.0
    return max_w * (step / warmup_steps)
82
+
83
+
84
def preferred_amp_dtype():
    """Pick bfloat16 when the CUDA device supports it, otherwise float16.

    Wrapped in try/except so the probe is safe on CPU-only / odd builds.
    """
    try:
        want_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    except Exception:
        want_bf16 = False
    return torch.bfloat16 if want_bf16 else torch.float16
91
+
92
+
93
+ # ========================
94
+ # RC 一致性(可选损失)
95
+ # ========================
96
def rc_consistency_kl(logits_A, logits_B_fwd, P, tau: float = 1.0, eps: float = 1e-6):
    """KL(p_A || P^T p_B): pull strand-A token distributions toward the
    complement-mapped strand-B distributions (both in forward coordinates).

    tau softens both distributions; the tau^2 factor keeps the gradient
    scale temperature-independent. eps guards log(0) after the mapping.
    """
    za = logits_A.float() / tau
    zb = logits_B_fwd.float() / tau
    prob_a = F.softmax(za, dim=-1)
    log_prob_a = F.log_softmax(za, dim=-1)
    # Map B's probabilities through the complement permutation, clamp for log.
    prob_b_mapped = (F.softmax(zb, dim=-1) @ P.t()).clamp_min(eps)
    kl = (prob_a * (log_prob_a - prob_b_mapped.log())).sum(dim=-1).mean()
    return (tau * tau) * kl
106
+
107
+
108
def rc_consistency_bidirectional_stopgrad(logits_A, logits_B_fwd, P, tau: float = 1.5, eps: float = 1e-6):
    """Symmetric RC consistency: each strand is pulled toward the
    (stop-gradient) complement-mapped distribution of the other strand.
    """
    za = logits_A.float() / tau
    zb = logits_B_fwd.float() / tau

    def _teacher(z):
        # Complement-map the softened distribution; detached so it only
        # serves as a target, never receives gradient.
        with torch.no_grad():
            return (F.softmax(z, dim=-1) @ P.t()).clamp_min(eps).log()

    loss_a = F.kl_div(F.log_softmax(za, dim=-1), _teacher(zb),
                      reduction="batchmean", log_target=True)
    loss_b = F.kl_div(F.log_softmax(zb, dim=-1), _teacher(za),
                      reduction="batchmean", log_target=True)
    return 0.5 * (tau * tau) * (loss_a + loss_b)
120
+
121
+
122
+ # ========================
123
+ # Barlow & TV(可选)
124
+ # ========================
125
def barlow_strand_loss_v2(z1, z2, λ_off=0.04, λ_diag=0.04, eps=1e-3):
    """Stable Barlow-Twins-style strand-agreement loss.

    z1, z2: [B, L, H] embeddings already aligned to the forward strand.
    Combines a variance hinge (per-feature std kept near 1) with
    cross-correlation terms: diagonal pulled to 1, off-diagonal to 0.
    """
    batch, length, hidden = z1.shape
    count = batch * length
    z1 = z1.reshape(count, hidden)
    z2 = z2.reshape(count, hidden)

    std1 = torch.sqrt(z1.var(dim=0, unbiased=False) + eps)
    std2 = torch.sqrt(z2.var(dim=0, unbiased=False) + eps)
    # Hinge: penalize features whose std collapses below 1.
    var_term = (F.relu(1 - std1).pow(2).mean() + F.relu(1 - std2).pow(2).mean())

    z1 = (z1 - z1.mean(0)) / (std1 + eps)
    z2 = (z2 - z2.mean(0)) / (std2 + eps)
    corr = (z1.t() @ z2) / max(1, count)  # [H, H] cross-correlation
    diag = torch.diagonal(corr)
    off = corr - torch.diag_embed(diag)
    cov = λ_diag * (1 - diag).pow(2).mean() + λ_off * off.pow(2).mean()
    return var_term + cov
146
+
147
+
148
def tv_mixed(h: torch.Tensor):
    """Mixed total-variation smoothness penalty on h: [B, L, H].

    First-order differences are penalized with L1, second-order with L2,
    encouraging piecewise-smooth token embeddings.
    """
    first = h[:, 1:, :] - h[:, :-1, :]
    second = first[:, 1:, :] - first[:, :-1, :]
    return first.abs().mean() + second.pow(2).mean()
153
+
154
class Mlp(nn.Module):
    """Two-layer feed-forward block: Linear -> activation -> Linear.

    With return_residual=True, forward returns (output, input) so the caller
    can add the skip connection itself.
    """

    def __init__(self, input_dimension, hidden_dimension=None, output_dimension=None,
                 activation=F.gelu, return_residual=False):
        super().__init__()
        self.return_residual = return_residual
        hidden = hidden_dimension if hidden_dimension else input_dimension
        out = output_dimension if output_dimension else input_dimension
        self.linear1 = nn.Linear(input_dimension, hidden)
        self.activation = activation
        self.linear2 = nn.Linear(hidden, out)

    def forward(self, x: torch.Tensor):
        y = self.linear2(self.activation(self.linear1(x)))
        if self.return_residual:
            return y, x
        return y
171
+
172
+
173
def create_comba_cls(comba_kwargs=None, device=None, dtype=None):
    """Safe factory: return a callable that constructs a Comba mixer.

    device/dtype are forwarded only when given; on ImportError an identity
    module is used so the surrounding network still runs.
    """
    factory_kwargs = {}
    if device is not None:
        factory_kwargs["device"] = device
    if dtype is not None:
        factory_kwargs["dtype"] = dtype
    try:
        mixer_cls = partial(comba.Comba, **dict(comba_kwargs or {}), **factory_kwargs)
    except ImportError:
        # NOTE(review): partial() itself never raises ImportError, so this
        # branch is effectively dead once the module-level `fla` import
        # succeeded — kept for behavioral parity.
        class FallbackComba(nn.Module):
            def forward(self, x, *args, **kwargs):
                return x
        mixer_cls = lambda *args, **kwargs: FallbackComba()
    return mixer_cls
189
+
190
+
191
+ # ========================
192
+ # SWA block(去手动 FP16,统一 AMP)
193
+ # ========================
194
class SlidingWindowAttention(nn.Module):
    """Pre-norm transformer block with sliding-window attention.

    Layout: RMSNorm -> windowed attention -> residual -> RMSNorm -> gated MLP
    -> residual. The attention and MLP run under autocast using the best
    available AMP dtype (bf16 when supported, else fp16).
    """

    def __init__(self, config: Any):
        super().__init__()
        # Accept a plain dict or any object exposing the same fields.
        if isinstance(config, dict):
            c = config
        else:
            try:
                c = vars(config)
            except Exception as e:
                raise TypeError(f"transformer_cfg must be dict-like, got {type(config)}") from e

        attn_cfg = c["attn"]
        eps = c.get("norm_eps", 1e-5)

        self.mixer_norm = RMSNorm(hidden_size=c["hidden_size"], eps=eps)
        self.mixer = Attention(
            hidden_size=c["hidden_size"],
            num_heads=attn_cfg["num_heads"],
            num_kv_heads=attn_cfg["num_kv_heads"],
            qkv_bias=attn_cfg["qkv_bias"],
            window_size=attn_cfg["window_size"],
            rope_theta=attn_cfg["rope_theta"],
            max_position_embeddings=c["max_position_embeddings"]
        )
        self.mlp_norm = RMSNorm(c["hidden_size"], eps=eps)
        self.mlp = SambaMLP(
            hidden_size=c["hidden_size"],
            hidden_ratio=c["hidden_ratio"],
            hidden_act=c["hidden_act"],
            fuse_swiglu=c["fuse_swiglu"]
        )
        # Scale attention inputs down (and outputs back up) for headroom
        # under reduced-precision autocast.
        self.pre_scale = 1.0 / math.sqrt(2.0)

    def forward(self, hidden_states: torch.Tensor,
                cache_params: Optional[Any] = None, **kwargs) -> Tuple[torch.Tensor, Any]:
        amp_dtype = preferred_amp_dtype()

        # --- attention sub-block ---
        normed = self.mixer_norm(hidden_states)
        with amp.autocast("cuda", enabled=True, dtype=amp_dtype):
            attn_out, _, cache_params = self.mixer(
                hidden_states=normed * self.pre_scale,
                past_key_values=cache_params,
                **kwargs
            )
            attn_out = attn_out / self.pre_scale
        ensure_finite(attn_out, "attention_out")
        h = hidden_states + attn_out.to(normed.dtype)

        # --- MLP sub-block ---
        normed = self.mlp_norm(h)
        with amp.autocast("cuda", enabled=True, dtype=amp_dtype):
            mlp_out = self.mlp(normed, **kwargs)
        h = h + mlp_out
        ensure_finite(h, "block_output")
        return h, cache_params
257
+
258
+
259
+ # ========================
260
+ # Enhanced Hybrid Core (Comba + SWA + gating)
261
+ # ========================
262
class EnhancedHybridCore(nn.Module):
    """One hybrid layer: a Comba mixer feeding a sliding-window attention
    block, fused per channel by a learned sigmoid gate, then LayerNorm.
    """

    def __init__(self, hidden_dim, comba_cfg, transformer_cfg, layer_idx=0, device=None, dtype=None):
        super().__init__()
        self.comba_cls = create_comba_cls(comba_kwargs=comba_cfg, device=device, dtype=dtype)
        # Some Comba versions accept layer_idx; fall back for those that don't.
        try:
            self.comba = self.comba_cls(layer_idx=layer_idx)
        except TypeError:
            self.comba = self.comba_cls()
        self.transformer = SlidingWindowAttention(config=transformer_cfg)
        self.gate = nn.Linear(hidden_dim * 2, hidden_dim)
        self.out_norm = nn.LayerNorm(hidden_dim)

    @staticmethod
    def _first(x):
        # Comba may return (output, state); keep only the output.
        return x[0] if isinstance(x, tuple) else x

    def forward(self, x):
        """x: [B, l, H] -> [B, l, H]."""
        ssm_out = self._first(self.comba(x))
        attn_out, _ = self.transformer(ssm_out)
        gate = torch.sigmoid(self.gate(torch.cat([ssm_out, attn_out], dim=-1)))
        fused = gate * attn_out + (1 - gate) * ssm_out
        out = self.out_norm(fused)
        ensure_finite(out, "EnhancedHybridCore.out")
        return out
288
+
289
+
290
+ # ========================
291
+ # Deep Branch —— 与原版一致
292
+ # ========================
293
class DeepEnhancedBranch(nn.Module):
    """Stack of EnhancedHybridCore layers with per-layer drop-path rates,
    followed by a final LayerNorm.
    """

    def __init__(self, hidden_dim: int, comba_cfg: Dict | None, transformer_cfg: Any, depth: int = 4,
                 drop_path_rates=None, *, device=None, dtype=None):
        super().__init__()
        self.layers = nn.ModuleList()
        # Resolve one drop-path rate per layer.
        if drop_path_rates is None:
            # Default: linear ramp 0 -> 0.05 across the stack.
            rates = [0.05 * (i / max(1, depth - 1)) for i in range(depth)]
        elif isinstance(drop_path_rates, (float, int)):
            rates = [float(drop_path_rates)] * depth
        else:
            # Pad a short list by repeating its last value.
            rates = list(drop_path_rates) + [drop_path_rates[-1]] * (depth - len(drop_path_rates))
        for i in range(depth):
            layer_cfg = transformer_cfg.copy()
            layer_cfg["drop_path_prob"] = rates[i]
            self.layers.append(EnhancedHybridCore(hidden_dim, comba_cfg, layer_cfg, i, device, dtype))
        self.output_norm = nn.LayerNorm(hidden_dim)

    def forward(self, x: torch.Tensor):
        """x: [B, l, H] -> [B, l, H]."""
        for layer in self.layers:
            x = layer(x)
        out = self.output_norm(x)
        ensure_finite(out, "DeepEnhancedBranch.out")
        return out
316
+
317
+
318
+ # ========================
319
+ # TokenBridge:多尺度膨胀卷积聚合 + 按位线性交换 + 门控
320
+ # ========================
321
class TokenBridge(nn.Module):
    """Lightweight cross-strand exchange module.

    Each branch's multi-scale (depthwise dilated conv) context is "kneaded"
    into the other branch via a position-wise linear exchange, an optional
    cheap global (mean) token broadcast, and a per-token, per-channel gate.
    Requires xA/xB to already be aligned to forward coordinates.
    """

    def __init__(self, hidden_dim: int, dropout: float = 0.0,
                 kernel_size: int = 9, dilations=(1, 2, 4, 8, 16),
                 use_global_token: bool = True):
        super().__init__()
        h = hidden_dim
        same_pad = lambda d: d * (kernel_size // 2)  # keeps sequence length

        # Multi-scale context of branch B, to be injected into A.
        self.dw_B = nn.ModuleList([
            nn.Conv1d(h, h, kernel_size, padding=same_pad(d), dilation=d,
                      groups=h, bias=False) for d in dilations
        ])
        self.mix_B = nn.Conv1d(h * len(dilations), h, 1)

        # Multi-scale context of branch A, to be injected into B.
        self.dw_A = nn.ModuleList([
            nn.Conv1d(h, h, kernel_size, padding=same_pad(d), dilation=d,
                      groups=h, bias=False) for d in dilations
        ])
        self.mix_A = nn.Conv1d(h * len(dilations), h, 1)

        # Position-wise linear exchange.
        self.proj_B2A = nn.Linear(h, h)
        self.proj_A2B = nn.Linear(h, h)

        # Optional, very cheap "global token" (mean embedding) broadcast.
        self.use_global_token = use_global_token
        if use_global_token:
            self.glb_B2A = nn.Linear(h, h)
            self.glb_A2B = nn.Linear(h, h)

        # Per-token, per-channel gates for both directions: -> [gA, gB].
        self.gate = nn.Linear(h * 4, h * 2)
        self.dropout = nn.Dropout(dropout)
        self.normA = nn.LayerNorm(h)
        self.normB = nn.LayerNorm(h)

    @staticmethod
    def _agg(x: torch.Tensor, branches: nn.ModuleList, mix: nn.Module) -> torch.Tensor:
        """Apply every dilated conv and 1x1-mix back to H channels; [B,L,H] -> [B,L,H]."""
        channels_first = x.transpose(1, 2)                       # [B,H,L]
        stacked = torch.cat([conv(channels_first) for conv in branches], dim=1)
        return mix(stacked).transpose(1, 2).contiguous()         # [B,L,H]

    def forward(self, xA: torch.Tensor, xB: torch.Tensor):
        # 1) Multi-scale context of the opposite branch.
        ctxB = self._agg(xB, self.dw_B, self.mix_B)  # B context, injected into A
        ctxA = self._agg(xA, self.dw_A, self.mix_A)  # A context, injected into B

        # 2) Position-wise linear exchange on top of that context.
        locA = self.proj_B2A(xB + ctxB)  # B -> A
        locB = self.proj_A2B(xA + ctxA)  # A -> B

        # 3) Optional global (mean) token broadcast.
        if self.use_global_token:
            gB = self.glb_B2A(xB.mean(dim=1, keepdim=True))  # [B,1,H]
            gA = self.glb_A2B(xA.mean(dim=1, keepdim=True))  # [B,1,H]
            locA = locA + gB.expand(-1, xA.size(1), -1)
            locB = locB + gA.expand(-1, xB.size(1), -1)

        # 4) Gated residual injection.
        z = torch.cat([xA, xB, xA - xB, xA * xB], dim=-1)  # [B,L,4H]
        gA, gB = self.gate(z).chunk(2, dim=-1)
        yA = self.normA(xA + self.dropout(torch.sigmoid(gA) * locA))
        yB = self.normB(xB + self.dropout(torch.sigmoid(gB) * locB))

        ensure_finite(yA, "TokenBridgeLite.A")
        ensure_finite(yB, "TokenBridgeLite.B")
        return yA, yB
400
+
401
+
402
+ # ========================
403
+ # Semantic Preservation Loss
404
+ # ========================
405
def semantic_preservation_loss(R_plus: torch.Tensor, H_S_plus: torch.Tensor,
                               λ_recon: float = 1.0, λ_local: float = 0.5, λ_global: float = 0.2):
    """Keep student features H_S_plus semantically close to reference R_plus.

    Three weighted terms: pointwise reconstruction (MSE), local structure
    (MSE of first-order token differences), and global structure (MSE of
    normalized Gram matrices). Inputs are [B, L, H].
    """
    recon = F.mse_loss(H_S_plus, R_plus)

    if R_plus.size(1) >= 2:
        local = F.mse_loss(H_S_plus[:, 1:] - H_S_plus[:, :-1],
                           R_plus[:, 1:] - R_plus[:, :-1])
    else:
        # A single token carries no neighbourhood structure.
        local = torch.tensor(0., device=R_plus.device)

    def _normed_gram(x):
        G = torch.einsum("b i d, b j d -> b i j", x, x)
        return G / (G.norm(dim=(1, 2), keepdim=True) + 1e-6)

    glob = F.mse_loss(_normed_gram(H_S_plus), _normed_gram(R_plus))
    return λ_recon * recon + λ_local * local + λ_global * glob
421
+
422
+
423
+
424
+
425
+
426
@contextmanager
def eval_mode(*modules):
    """Temporarily switch the given modules to eval(); each module's previous
    training flag is restored on exit, even if the body raises.
    """
    saved = [(m, m.training) for m in modules]
    try:
        for m in modules:
            m.eval()
        yield
    finally:
        for m, was_training in saved:
            m.train(was_training)
434
+
435
+
436
+ class SSScanDNAHybridModel(nn.Module):
437
+ """
438
+ S-scan 分块内存优化版:
439
+ - 按块就地完成:分支→(必要时翻转)→桥接→投影→门控融合→(可选 final_conv)→输出
440
+ - 预训练仅拼接小体量 logits / 掩码;微调则按块收集 fused 并在末尾拼成整段
441
+ - 语义保持 teacher 也按块按需计算,避免整段 _run_branch 占显存
442
+ """
443
+
444
+ def __init__(
445
+ self,
446
+ alphabet_size=5,
447
+ d_model=128,
448
+ block_size=2048,
449
+ comba_cfg=None,
450
+ transformer_cfg=None,
451
+ depth=4,
452
+ drop_path_rates=None,
453
+ pretrain=False,
454
+ for_representation=False,
455
+ use_final_conv=False,
456
+
457
+ use_s_scan: bool = True,
458
+ use_mem: bool = False,
459
+ use_rc_kl: bool = False,
460
+ use_barlow: bool = False,
461
+ use_tv: bool = False,
462
+
463
+ sem_max_weight: float = 0.2,
464
+ sem_warmup_steps: int = 3000,
465
+
466
+ rc_max_weight: float = 0.2,
467
+ rc_warmup_steps: int = 2000,
468
+ rc_tau: float = 1.5,
469
+ rc_bidirectional_stopgrad: bool = True,
470
+
471
+ aux_ce_weight: float = 0.1,
472
+
473
+ gate_freeze_steps: int = 1000,
474
+ detach_gate: bool = False,
475
+ gate_sup_weight: float = 0.005,
476
+ gate_sup_warmup_steps: int = 500,
477
+ gate_temp: float = 2.0,
478
+
479
+ dropout=0.1,
480
+
481
+ use_ema_teacher: bool = True,
482
+ ema_decay: float = 0.999,
483
+ auto_update_ema_in_forward: bool = True,
484
+
485
+ use_bridge: bool = True,
486
+ bridge_dropout: float = 0.0,
487
+ ):
488
+ super().__init__()
489
+ self.alphabet_size = alphabet_size
490
+ self.pretrain = pretrain
491
+ self.for_representation = for_representation
492
+ self.block_size = block_size
493
+ self.use_final_conv = use_final_conv
494
+ self.d_model = d_model
495
+
496
+ # 训练步计数
497
+ self.register_buffer("g_step", torch.zeros(1, dtype=torch.long))
498
+
499
+ # 输入映射(两条管线)
500
+ self.linear = nn.Conv1d(alphabet_size, d_model, kernel_size=9, padding=4)
501
+ self.rc_linear = nn.Conv1d(alphabet_size, d_model, kernel_size=9, padding=4)
502
+
503
+ # 双分支
504
+ self.branchA_core = DeepEnhancedBranch(
505
+ hidden_dim=d_model, comba_cfg=comba_cfg, transformer_cfg=transformer_cfg,
506
+ depth=depth, drop_path_rates=drop_path_rates
507
+ )
508
+ self.branchB_core = DeepEnhancedBranch(
509
+ hidden_dim=d_model, comba_cfg=comba_cfg, transformer_cfg=transformer_cfg,
510
+ depth=depth, drop_path_rates=drop_path_rates
511
+ )
512
+
513
+ # TokenBridge
514
+ self.use_bridge = use_bridge
515
+ if self.use_bridge:
516
+ self.bridge = TokenBridge(d_model, dropout=bridge_dropout)
517
+
518
+ # EMA teacher
519
+ self.use_ema_teacher = use_ema_teacher
520
+ self.ema_decay = ema_decay
521
+ self.auto_update_ema_in_forward = auto_update_ema_in_forward
522
+ if self.use_ema_teacher:
523
+ self.branchA_core_ema = copy.deepcopy(self.branchA_core)
524
+ self.branchB_core_ema = copy.deepcopy(self.branchB_core)
525
+ for p in self.branchA_core_ema.parameters(): p.requires_grad_(False)
526
+ for p in self.branchB_core_ema.parameters(): p.requires_grad_(False)
527
+ if self.use_bridge:
528
+ self.bridge_ema = copy.deepcopy(self.bridge)
529
+ for p in self.bridge_ema.parameters(): p.requires_grad_(False)
530
+
531
+ # 末端投影 + 融合
532
+ self.proj_A = Mlp(d_model, d_model * 2, d_model, activation=F.gelu, return_residual=True)
533
+ self.proj_B = Mlp(d_model, d_model * 2, d_model, activation=F.gelu, return_residual=True)
534
+ self.gate_fuse = nn.Linear(2 * d_model, d_model)
535
+ self.out_linear = nn.Linear(d_model, alphabet_size)
536
+ self.dropout = nn.Dropout(dropout)
537
+
538
+ # RC 置换矩阵
539
+ P_comp, _ = make_complement_perm(self.alphabet_size)
540
+ self.register_buffer("P_comp", P_comp)
541
+
542
+ # 开关 & 调度
543
+ self.use_s_scan = use_s_scan
544
+ self.use_rc_kl = use_rc_kl
545
+ self.use_barlow = use_barlow
546
+ self.use_tv = use_tv
547
+ self.sem_max_weight = sem_max_weight
548
+ self.sem_warmup_steps = sem_warmup_steps
549
+ self.rc_max_weight = rc_max_weight
550
+ self.rc_warmup_steps = rc_warmup_steps
551
+ self.rc_tau = rc_tau
552
+ self.rc_bidirectional_stopgrad = rc_bidirectional_stopgrad
553
+ self.aux_ce_weight = aux_ce_weight
554
+ self.gate_freeze_steps = gate_freeze_steps
555
+ self.detach_gate = detach_gate
556
+ self.gate_sup_weight = gate_sup_weight
557
+ self.gate_sup_warmup_steps = gate_sup_warmup_steps
558
+ self.gate_temp = gate_temp
559
+
560
+ if use_final_conv:
561
+ self.final_conv = nn.Conv1d(d_model, d_model, kernel_size=3, padding=1)
562
+
563
+ @torch.no_grad()
564
+ def update_ema(self):
565
+ if not getattr(self, "use_ema_teacher", False): return
566
+ d = float(getattr(self, "ema_decay", 0.999))
567
+ for m_ema, m in [(self.branchA_core_ema, self.branchA_core),
568
+ (self.branchB_core_ema, self.branchB_core)]:
569
+ for p_ema, p in zip(m_ema.parameters(), m.parameters()):
570
+ p_ema.data.lerp_(p.data, 1.0 - d)
571
+ if getattr(self, "use_bridge", False) and hasattr(self, "bridge_ema"):
572
+ for p_ema, p in zip(self.bridge_ema.parameters(), self.bridge.parameters()):
573
+ p_ema.data.lerp_(p.data, 1.0 - d)
574
+
575
+ # 保持原签名;不使用 return_embedding 作为控制变量
576
+ def forward(self, seq, t=None, cls=None, return_embedding=False, state=None):
577
+ step = int(self.g_step.item())
578
+ if self.training:
579
+ self.g_step += 1
580
+
581
+ if self.pretrain:
582
+ mask = seq[1]
583
+ seq = seq[0]
584
+ else:
585
+ mask = None
586
+
587
+ mn, mx = int(seq.min()), int(seq.max())
588
+ assert 0 <= mn and mx < self.alphabet_size, f"seq ids out of range: [{mn}, {mx}] vs alpha={self.alphabet_size}"
589
+
590
+ # # ===== 输入 one-hot -> conv1d -> [B,L,H] =====
591
+ rc_seq = reverse_complement(seq)
592
+ seq_oh = F.one_hot(seq, num_classes=self.alphabet_size).float()
593
+ rc_oh = F.one_hot(rc_seq, num_classes=self.alphabet_size).float()
594
+
595
+
596
+ h = F.gelu(self.linear(seq_oh.permute(0, 2, 1))) # [B,H,L]
597
+ rc_h = F.gelu(self.rc_linear(rc_oh.permute(0, 2, 1))) # [B,H,L]
598
+ feat = self.dropout(h).permute(0, 2, 1) # [B,L,H]
599
+ rc_feat = self.dropout(rc_h).permute(0, 2, 1)
600
+
601
+ fused = None # 用于微调返回
602
+
603
+ # ===== 主干:S-scan / 非 S-scan =====
604
+ if self.use_s_scan:
605
+ B, L, H = feat.shape
606
+ l = self.block_size
607
+ K = (L + l - 1) // l
608
+
609
+ # 是否收集各类输出(减少不必要的拼接/驻留)
610
+ collect_fused = bool(self.for_representation)
611
+ collect_logits = (not self.for_representation) or self.pretrain
612
+ collect_ab_logits = (self.pretrain or self.use_rc_kl)
613
+
614
+ fused_chunks = [] if collect_fused else None
615
+ logits_chunks = [] if collect_logits else None
616
+ logitsA_chunks = [] if collect_ab_logits else None
617
+ logitsB_chunks = [] if collect_ab_logits else None
618
+ maskA_list, maskB_list = [], []
619
+
620
+ total_aux = feat.new_zeros([]) # 仅预训练使用
621
+ mem_A = mem_B = None
622
+
623
+ for t_block in range(K):
624
+ start = t_block * l
625
+ end = min(start + l, L)
626
+
627
+ X_fwd = feat[:, start:end, :]
628
+ X_rc = rc_feat[:, start:end, :]
629
+
630
+ if (t_block % 2) == 0:
631
+ X_A, X_B = X_fwd, X_rc
632
+ rev_A, rev_B = False, True
633
+ maskA_rc_blk = torch.zeros(B, end - start, dtype=torch.bool, device=feat.device)
634
+ maskB_rc_blk = torch.ones (B, end - start, dtype=torch.bool, device=feat.device)
635
+ else:
636
+ X_A, X_B = X_rc, X_fwd
637
+ rev_A, rev_B = True, False
638
+ maskA_rc_blk = torch.ones (B, end - start, dtype=torch.bool, device=feat.device)
639
+ maskB_rc_blk = torch.zeros(B, end - start, dtype=torch.bool, device=feat.device)
640
+
641
+
642
+ # 分支
643
+ H_A = self.branchA_core(X_A)
644
+ H_B = self.branchB_core(X_B)
645
+ if rev_A: H_A = torch.flip(H_A, dims=[1])
646
+ if rev_B: H_B = torch.flip(H_B, dims=[1])
647
+
648
+ if self.use_bridge:
649
+ H_A, H_B = self.bridge(H_A, H_B)
650
+
651
+ # 投影 + 融合
652
+ fA, rA = self.proj_A(H_A); FA = fA + rA
653
+ fB, rB = self.proj_B(H_B); FB = fB + rB
654
+
655
+ gate_in_blk = torch.cat([FA, FB], dim=-1)
656
+ g_logits_blk = self.gate_fuse(gate_in_blk)
657
+ g_raw_blk = torch.sigmoid(g_logits_blk / getattr(self, "gate_temp", 1.0))
658
+ if step < getattr(self, "gate_freeze_steps", 0):
659
+ g_blk = 0.5 * torch.ones_like(g_raw_blk)
660
+ else:
661
+ g_blk = g_raw_blk
662
+ if getattr(self, "detach_gate", False):
663
+ mix_blk = g_blk.detach() * FA + (1 - g_blk.detach()) * FB
664
+ else:
665
+ mix_blk = g_blk * FA + (1 - g_blk) * FB
666
+ fused_blk = F.layer_norm(mix_blk, (mix_blk.size(-1),))
667
+ fused_blk = ensure_finite(fused_blk, "fused_blk")
668
+
669
+ if self.use_final_conv:
670
+ fused_blk = self.final_conv(fused_blk.permute(0, 2, 1)).permute(0, 2, 1)
671
+
672
+ if collect_fused:
673
+ fused_chunks.append(fused_blk)
674
+
675
+ if collect_logits:
676
+ logits_blk = self.out_linear(fused_blk)
677
+ logits_chunks.append(logits_blk)
678
+
679
+ if collect_ab_logits:
680
+ logitsA_chunks.append(self.out_linear(FA))
681
+ logitsB_chunks.append(self.out_linear(FB))
682
+
683
+ maskA_list.append(maskA_rc_blk)
684
+ maskB_list.append(maskB_rc_blk)
685
+
686
+ # ===== 预训练 aux(按块)=====
687
+ if self.pretrain:
688
+ with torch.no_grad():
689
+ A_feat_blk = F.gelu(self.linear(
690
+ F.one_hot(seq[:, start:end], self.alphabet_size).float().permute(0, 2, 1)
691
+ )).permute(0, 2, 1)
692
+ B_feat_blk_rc = F.gelu(self.rc_linear(
693
+ F.one_hot(rc_seq[:, start:end], self.alphabet_size).float().permute(0, 2, 1)
694
+ )).permute(0, 2, 1)
695
+
696
+ teacherA = self.branchA_core_ema if self.use_ema_teacher else self.branchA_core
697
+ teacherB = self.branchB_core_ema if self.use_ema_teacher else self.branchB_core
698
+ tbridge = self.bridge_ema if (self.use_bridge and self.use_ema_teacher and hasattr(self, "bridge_ema")) else (self.bridge if self.use_bridge else None)
699
+
700
+ mods = [teacherA, teacherB] + ([tbridge] if tbridge is not None else [])
701
+ with eval_mode(*mods):
702
+ R_plus_A_blk = teacherA(A_feat_blk)
703
+ R_plus_B_blk = teacherB(A_feat_blk)
704
+ if tbridge is not None:
705
+ R_plus_A_blk, R_plus_B_blk = tbridge(R_plus_A_blk, R_plus_B_blk)
706
+
707
+ R_minus_A_blk_rc = teacherA(B_feat_blk_rc)
708
+ R_minus_B_blk_rc = teacherB(B_feat_blk_rc)
709
+ R_minus_A_blk_fwd = torch.flip(R_minus_A_blk_rc, dims=[1])
710
+ R_minus_B_blk_fwd = torch.flip(R_minus_B_blk_rc, dims=[1])
711
+ if tbridge is not None:
712
+ R_minus_A_blk_fwd, R_minus_B_blk_fwd = tbridge(R_minus_A_blk_fwd, R_minus_B_blk_fwd)
713
+
714
+ R_A_teacher_blk = torch.where(maskA_rc_blk.unsqueeze(-1), R_minus_A_blk_fwd, R_plus_A_blk)
715
+ R_B_teacher_blk = torch.where(maskB_rc_blk.unsqueeze(-1), R_minus_B_blk_fwd, R_plus_B_blk)
716
+
717
+ sem_A = semantic_preservation_loss(R_A_teacher_blk, FA)
718
+ sem_B = semantic_preservation_loss(R_B_teacher_blk, FB)
719
+ w_sem = linear_warmup_weight(step, getattr(self, "sem_warmup_steps", 0),
720
+ getattr(self, "sem_max_weight", 1.0))
721
+ total_aux = total_aux + w_sem * (sem_A + sem_B)
722
+
723
+ if (getattr(self, "gate_sup_weight", 0.0) > 0.0) and (step >= getattr(self, "gate_freeze_steps", 0)):
724
+ g_target_blk = (~maskA_rc_blk).float().unsqueeze(-1)
725
+ g_token_logits_blk = g_logits_blk.mean(dim=-1, keepdim=True) / getattr(self, "gate_temp", 1.0)
726
+ w_gate = linear_warmup_weight(
727
+ step - getattr(self, "gate_freeze_steps", 0),
728
+ getattr(self, "gate_sup_warmup_steps", 0),
729
+ getattr(self, "gate_sup_weight", 0.0),
730
+ )
731
+ total_aux = total_aux + w_gate * F.binary_cross_entropy_with_logits(
732
+ g_token_logits_blk, g_target_blk
733
+ )
734
+
735
+ if self.use_rc_kl and getattr(self, "rc_max_weight", 0.0) > 0:
736
+ if getattr(self, "rc_bidirectional_stopgrad", True):
737
+ rc = rc_consistency_bidirectional_stopgrad(
738
+ logitsA_chunks[-1], logitsB_chunks[-1], self.P_comp, tau=getattr(self, "rc_tau", 1.5)
739
+ )
740
+ else:
741
+ rc = rc_consistency_kl(
742
+ logitsA_chunks[-1], logitsB_chunks[-1], self.P_comp, tau=getattr(self, "rc_tau", 1.5)
743
+ )
744
+ w_rc = linear_warmup_weight(step, getattr(self, "rc_warmup_steps", 0),
745
+ getattr(self, "rc_max_weight", 0.0))
746
+ total_aux = total_aux + w_rc * rc
747
+
748
+ if self.use_barlow:
749
+ total_aux = total_aux + barlow_strand_loss_v2(H_A, H_B)
750
+ if self.use_tv:
751
+ total_aux = total_aux + tv_mixed(fused_blk)
752
+
753
+ # 拼接所需结果(并裁到 L)
754
+ logits = torch.cat(logits_chunks, dim=1)[:, :L, :] if collect_logits else None
755
+ logits_A_only = torch.cat(logitsA_chunks, dim=1)[:, :L, :] if collect_ab_logits else None
756
+ logits_B_only = torch.cat(logitsB_chunks, dim=1)[:, :L, :] if collect_ab_logits else None
757
+ mask_A_rc = torch.cat(maskA_list, dim=1)[:, :L]
758
+ mask_B_rc = torch.cat(maskB_list, dim=1)[:, :L]
759
+ if collect_fused:
760
+ fused = torch.cat(fused_chunks, dim=1)[:, :L, :]
761
+
762
+ else:
763
+ # 非 S-scan:整段走一遍(短序列或推理场景)
764
+ H_A = self.branchA_core(feat)
765
+ H_Br = self.branchB_core(rc_feat)
766
+ R_A = H_A
767
+ R_B = torch.flip(H_Br, dims=[1])
768
+
769
+ if self.use_bridge:
770
+ R_A, R_B = self.bridge(R_A, R_B)
771
+
772
+ fA, rA = self.proj_A(R_A); FA = fA + rA
773
+ fB, rB = self.proj_B(R_B); FB = fB + rB
774
+
775
+ gate_in = torch.cat([FA, FB], dim=-1)
776
+ g_logits = self.gate_fuse(gate_in)
777
+ g_raw = torch.sigmoid(g_logits / getattr(self, "gate_temp", 1.0))
778
+ if step < getattr(self, "gate_freeze_steps", 0):
779
+ g = 0.5 * torch.ones_like(g_raw)
780
+ else:
781
+ g = g_raw
782
+ if getattr(self, "detach_gate", False):
783
+ mix = g.detach() * FA + (1 - g.detach()) * FB
784
+ else:
785
+ mix = g * FA + (1 - g) * FB
786
+ fused = F.layer_norm(mix, (mix.size(-1),))
787
+ fused = ensure_finite(fused, "fused")
788
+ if self.use_final_conv:
789
+ fused = self.final_conv(fused.permute(0, 2, 1)).permute(0, 2, 1)
790
+
791
+ logits = self.out_linear(fused) if (not self.for_representation or self.pretrain) else None
792
+ logits_A_only = self.out_linear(FA) if (self.pretrain or self.use_rc_kl) else None
793
+ logits_B_only = self.out_linear(FB) if (self.pretrain or self.use_rc_kl) else None
794
+ mask_A_rc = torch.zeros(FA.size()[:2], dtype=torch.bool, device=FA.device)
795
+ mask_B_rc = torch.zeros_like(mask_A_rc)
796
+ total_aux = logits.new_zeros(()) if self.pretrain else None
797
+
798
+ # ===== 按原行为返回 =====
799
+ if self.for_representation:
800
+ # 微调时不做 EMA/aux,不更新 teacher
801
+ return fused, None
802
+
803
+ if self.training and self.use_ema_teacher and self.auto_update_ema_in_forward:
804
+ self.update_ema()
805
+
806
+ current_step = int(step)
807
+
808
+ if self.pretrain:
809
+ # 兼容原有 loss:确保 A/B logits 存在
810
+ if logits_A_only is None: logits_A_only = self.out_linear(FA)
811
+ if logits_B_only is None: logits_B_only = self.out_linear(FB)
812
+ HybridOutput = namedtuple("HybridOutput", ["logits"])
813
+ return HybridOutput(
814
+ logits=(logits,
815
+ mask,
816
+ total_aux,
817
+ logits_A_only.detach(),
818
+ logits_B_only.detach(),
819
+ mask_A_rc.detach(),
820
+ mask_B_rc.detach(),
821
+ current_step)
822
+ ), None
823
+
824
+ return logits, None
825
+
826
+ @property
827
+ def d_output(self):
828
+ if getattr(self, "d_model", None) is None:
829
+ raise NotImplementedError("SequenceModule instantiation must set d_output")
830
+ return self.d_model
831
+
832
+
833
+
834
class CrossDNAForMaskedLM(PreTrainedModel):
    """Hugging Face wrapper exposing ``SSScanDNAHybridModel`` as a masked LM.

    ``input_ids`` must already be in the backbone's alphabet range
    0..alphabet_size-1 (A=0, C=1, G=2, T=3, N=4); tokenizer special-token ids
    have to be remapped by the caller (see the repository README example).
    """

    config_class = CrossDNAConfig
    base_model_prefix = "backbone"

    def __init__(self, config: CrossDNAConfig):
        super().__init__(config)
        self.config = config

        # Backbone hyper-parameters are forwarded 1:1 from the HF config.
        self.backbone = SSScanDNAHybridModel(
            alphabet_size=config.alphabet_size,
            d_model=config.d_model,
            block_size=config.block_size,
            comba_cfg=config.comba_cfg,
            transformer_cfg=config.transformer_cfg,
            depth=config.depth,
            drop_path_rates=config.drop_path_rates,
            pretrain=config.pretrain,
            for_representation=config.for_representation,
            use_s_scan=config.use_s_scan,
            use_mem=config.use_mem,
            use_rc_kl=config.use_rc_kl,
            use_barlow=config.use_barlow,
            use_tv=config.use_tv,
            sem_max_weight=config.sem_max_weight,
            sem_warmup_steps=config.sem_warmup_steps,
            aux_ce_weight=config.aux_ce_weight,
            gate_freeze_steps=config.gate_freeze_steps,
            detach_gate=config.detach_gate,
            gate_sup_weight=config.gate_sup_weight,
            gate_sup_warmup_steps=config.gate_sup_warmup_steps,
            gate_temp=config.gate_temp,
            dropout=config.dropout,
            use_bridge=config.use_bridge,
            bridge_dropout=0.0,
        )

        self.post_init()

    @property
    def mask_token_id(self) -> int:
        # Falls back to 3 ([MASK] in the tokenizer vocab) when absent from config.
        return getattr(self.config, "mask_token_id", 3)

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,  # accepted for HF API compatibility; unused
        labels: Optional[torch.LongTensor] = None,
        **kwargs
    ) -> MaskedLMOutput:
        """
        Args:
            input_ids: [B, L] ids in 0..alphabet_size-1 (A=0, C=1, G=2, T=3, N=4).
            labels: [B, L] MLM targets; non-masked positions must be -100 (ignored).

        Returns:
            MaskedLMOutput with ``logits`` of shape [B, L, alphabet_size] and,
            when ``labels`` is given, the cross-entropy ``loss`` over masked positions.
        """
        # MLM mask: prefer labels != -100; otherwise fall back to
        # input_ids == mask_token_id.
        if labels is not None:
            mlm_mask = (labels != -100)
        else:
            mlm_mask = (input_ids == self.mask_token_id)

        if self.config.pretrain:
            # Pretraining path: the backbone expects a (seq, mask) tuple when
            # pretrain=True.
            outputs, _ = self.backbone((input_ids, mlm_mask))
            # outputs is the HybridOutput namedtuple; logits[0] is the main
            # logits tensor.
            logits = outputs.logits[0]  # [B, L, vocab_size]
        else:
            logits, _ = self.backbone(input_ids)  # [B, L, vocab_size]

        loss = None
        if labels is not None:
            # Standard MLM loss: cross-entropy only where labels != -100.
            vocab = self.config.alphabet_size
            # BUGFIX: use reshape instead of view — the backbone assembles
            # logits via torch.cat/slicing, so the tensor may be
            # non-contiguous and .view(-1, vocab) would raise.
            logits_2d = logits.reshape(-1, vocab)
            labels_1d = labels.reshape(-1)
            loss = F.cross_entropy(logits_2d, labels_1d, ignore_index=-100)

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )
519M/special_tokens_map.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "[BOS]",
3
+ "cls_token": "[CLS]",
4
+ "eos_token": "[SEP]",
5
+ "mask_token": "[MASK]",
6
+ "pad_token": "[PAD]",
7
+ "sep_token": "[SEP]",
8
+ "unk_token": "[UNK]"
9
+ }
519M/tokenization_crossdna.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CharacterTokenizer for Hugging Face Transformers.
3
+ This is heavily inspired from CanineTokenizer in transformers package.
4
+ """
5
+ import json
6
+ import os
7
+ from pathlib import Path
8
+ from typing import Dict, List, Optional, Sequence, Union
9
+
10
+ from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
11
+
12
+
13
class CrossDNATokenizer(PreTrainedTokenizer):
    def __init__(self, characters: Sequence[str], model_max_length: int, padding_side: str = 'left', **kwargs):
        """Character tokenizer for Hugging Face transformers.

        Args:
            characters (Sequence[str]): List of desired characters. Any character which
                is not included in this list will be replaced by a special token called
                [UNK] with id=6. Following are list of all of the special tokens with
                their corresponding ids:
                    "[CLS]": 0
                    "[SEP]": 1
                    "[BOS]": 2
                    "[MASK]": 3
                    "[PAD]": 4
                    "[RESERVED]": 5
                    "[UNK]": 6
                an id (starting at 7) will be assigned to each character.
            model_max_length (int): Model maximum sequence length.
        """
        self.characters = characters
        self.model_max_length = model_max_length

        # Fixed special-token ids 0..6, then one id per character from 7 up.
        self._vocab_str_to_int = {
            "[CLS]": 0,
            "[SEP]": 1,
            "[BOS]": 2,
            "[MASK]": 3,
            "[PAD]": 4,
            "[RESERVED]": 5,
            "[UNK]": 6,
            **{ch: i + 7 for i, ch in enumerate(characters)},
        }
        self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()}

        bos_token = AddedToken("[BOS]", lstrip=False, rstrip=False)
        sep_token = AddedToken("[SEP]", lstrip=False, rstrip=False)
        cls_token = AddedToken("[CLS]", lstrip=False, rstrip=False)
        pad_token = AddedToken("[PAD]", lstrip=False, rstrip=False)
        unk_token = AddedToken("[UNK]", lstrip=False, rstrip=False)

        # lstrip=True so "x [MASK]" round-trips without a stray leading space.
        mask_token = AddedToken("[MASK]", lstrip=True, rstrip=False)

        # PreTrainedTokenizer.__init__ does not accept add_special_tokens.
        kwargs.pop("add_special_tokens", None)

        super().__init__(
            bos_token=bos_token,
            eos_token=sep_token,  # EOS deliberately shares the [SEP] surface form
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            unk_token=unk_token,
            add_prefix_space=False,
            model_max_length=model_max_length,
            padding_side=padding_side,
            **kwargs,
        )

    def __len__(self):
        return len(self._vocab_str_to_int)

    @property
    def vocab_size(self) -> int:
        return len(self._vocab_str_to_int)

    def get_vocab(self):
        """Return the token-to-id mapping (standard HuggingFace interface)."""
        return self._vocab_str_to_int

    def _tokenize(self, text: str) -> List[str]:
        # One character per token.
        return list(text)

    def _convert_token_to_id(self, token: str) -> int:
        return self._vocab_str_to_int.get(token, self._vocab_str_to_int["[UNK]"])

    def _convert_id_to_token(self, index: int) -> str:
        return self._vocab_int_to_str[index]

    def convert_tokens_to_string(self, tokens):
        return "".join(tokens)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        # Layout: tokens_0 + [SEP] (+ tokens_1 + [SEP]); deliberately no leading [CLS].
        sep = [self.sep_token_id]
        result = token_ids_0 + sep
        if token_ids_1 is not None:
            result += token_ids_1 + sep
        return result

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )

        # Mirrors build_inputs_with_special_tokens: a 1 for each trailing [SEP].
        result = ([0] * len(token_ids_0)) + [1]
        if token_ids_1 is not None:
            result += ([0] * len(token_ids_1)) + [1]
        return result

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        sep = [self.sep_token_id]

        # BUGFIX: the original counted a leading [CLS] that
        # build_inputs_with_special_tokens never adds, so token_type_ids was
        # one element longer than input_ids. Mirror the actual layout:
        # tokens_0 + [SEP] (+ tokens_1 + [SEP]).
        result = len(token_ids_0 + sep) * [0]
        if token_ids_1 is not None:
            result += len(token_ids_1 + sep) * [1]
        return result

    def get_config(self) -> Dict:
        # Minimal persisted state: character codepoints + max length.
        return {
            "char_ords": [ord(ch) for ch in self.characters],
            "model_max_length": self.model_max_length,
        }

    @classmethod
    def from_config(cls, config: Dict) -> "CrossDNATokenizer":
        cfg = {}
        cfg["characters"] = [chr(i) for i in config["char_ords"]]
        cfg["model_max_length"] = config["model_max_length"]
        return cls(**cfg)

    def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
        # Persists only tokenizer_config.json; there is no separate vocab file.
        cfg_file = Path(save_directory) / "tokenizer_config.json"
        cfg = self.get_config()
        with open(cfg_file, "w") as f:
            json.dump(cfg, f, indent=4)

    @classmethod
    def from_pretrained(cls, save_directory: Union[str, os.PathLike], **kwargs):
        cfg_file = Path(save_directory) / "tokenizer_config.json"
        with open(cfg_file) as f:
            cfg = json.load(f)
        return cls.from_config(cfg)
519M/tokenizer_config.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "char_ords": [65, 67, 71, 84, 78],
3
+ "model_max_length": 1024,
4
+
5
+ "add_prefix_space": false,
6
+ "padding_side": "left",
7
+
8
+ "bos_token": "[BOS]",
9
+ "eos_token": "[SEP]",
10
+ "sep_token": "[SEP]",
11
+ "cls_token": "[CLS]",
12
+ "pad_token": "[PAD]",
13
+ "mask_token": "[MASK]",
14
+ "unk_token": "[UNK]",
15
+
16
+ "added_tokens_decoder": {
17
+ "0": { "content": "[CLS]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true },
18
+ "1": { "content": "[SEP]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true },
19
+ "2": { "content": "[BOS]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true },
20
+ "3": { "content": "[MASK]", "lstrip": true, "normalized": false, "rstrip": false, "single_word": false, "special": true },
21
+ "4": { "content": "[PAD]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true },
22
+ "6": { "content": "[UNK]", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true }
23
+ },
24
+
25
+ "tokenizer_class": "CrossDNATokenizer",
26
+ "auto_map": {
27
+ "AutoTokenizer": [
28
+ "tokenization_crossdna.CrossDNATokenizer",
29
+ null
30
+ ]
31
+ }
32
+ }