File size: 14,006 Bytes
2180e31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
import torch
import torch.nn.functional as F
import json
import os
from collections import OrderedDict
from typing import Tuple, Optional
from tqdm import tqdm

from transformers import PreTrainedTokenizer
from model import load_encoder_components, ProteinMoleculeDualEncoder
from train import train_model
from train_ddp import train_model_ddp

# Default backbone paths, used when the caller does not supply explicit
# protein / molecule model paths.
DEFAULT_PROTEIN_PATH = "./SaProt_650M_AF2"
DEFAULT_MOLECULE_PATH = "./ChemBERTa-zinc-base-v1"
def load_dual_tower_model(
    protein_model_path: Optional[str] = None,
    molecule_model_path: Optional[str] = None,
    pt_path: Optional[str] = None,
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
) -> Tuple[ProteinMoleculeDualEncoder, PreTrainedTokenizer, PreTrainedTokenizer]:
    """
    Unified loader for the dual-tower model and its two tokenizers.

    Flow:
    1. Resolve the backbone paths (fall back to the defaults when not given).
    2. Build the model structure (load_encoder_components +
       ProteinMoleculeDualEncoder).
    3. If ``pt_path`` is provided, strip any DDP ``module.`` prefix and load
       the checkpoint weights over the freshly initialized ones.

    Args:
        protein_model_path: Optional protein backbone path; defaults to
            ``DEFAULT_PROTEIN_PATH``.
        molecule_model_path: Optional molecule backbone path; defaults to
            ``DEFAULT_MOLECULE_PATH``.
        pt_path: Optional path to a trained ``.pt`` state-dict checkpoint.
        device: Target device string (auto-detected at import time).

    Returns:
        (model, protein_tokenizer, molecule_tokenizer), with the model moved
        to ``device`` and set to eval mode.

    Raises:
        FileNotFoundError: If ``pt_path`` is given but the file is missing.
        RuntimeError: If the checkpoint keys do not exactly match the model
            structure (strict loading).
    """

    # --- 1. Resolve backbone paths ---
    # Fall back to defaults when a path is not supplied. This supports both
    # "only pass pt_path, use default backbones" and "pass pt_path together
    # with specific backbones".
    p_path = protein_model_path if protein_model_path else DEFAULT_PROTEIN_PATH
    m_path = molecule_model_path if molecule_model_path else DEFAULT_MOLECULE_PATH

    print("Step 1: Initializing model structure...")
    print(f"  - Protein Backbone: {p_path}")
    print(f"  - Molecule Backbone: {m_path}")

    # --- 2. Build the structure (written once, shared by all call paths) ---
    p_encoder, p_tokenizer, m_encoder, m_tokenizer = load_encoder_components(
        p_path, m_path
    )

    model = ProteinMoleculeDualEncoder(
        protein_encoder=p_encoder,
        molecule_encoder=m_encoder,
        projection_dim=256  # must match the projection dim used at training time
    )

    # --- 3. Load checkpoint weights (if provided) ---
    if pt_path is not None:
        if not os.path.exists(pt_path):
            raise FileNotFoundError(f"Checkpoint not not found at: {pt_path}" if False else f"Checkpoint not found at: {pt_path}")

        print(f"Step 2: Loading weights from {pt_path} ...")
        # map_location avoids device mismatches / OOM when the checkpoint was
        # saved on a different device.
        # NOTE(review): consider torch.load(..., weights_only=True) on torch
        # >= 1.13 to avoid unpickling arbitrary objects from the checkpoint.
        state_dict = torch.load(pt_path, map_location=device)

        # --- 3.1 Strip the DDP 'module.' prefix, if present ---
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            # Drop the leading 'module.' (7 chars) that DDP-saved keys carry.
            name = k[7:] if k.startswith('module.') else k
            new_state_dict[name] = v

        # --- 3.2 Overwrite the initial weights ---
        # strict=True raises immediately on any key mismatch, so reaching the
        # next line guarantees a complete, exact load. (The previous version
        # also printed the returned missing/unexpected key lists, but with
        # strict=True those are always empty -- dead code removed.)
        model.load_state_dict(new_state_dict, strict=True)
        print("  - Weights loaded successfully.")
    else:
        print("Using User-Given Encoders.")

    # Move to the requested device; default to eval mode. Callers that want
    # to train should switch back with model.train().
    model.to(device)
    model.eval()
    return model, p_tokenizer, m_tokenizer

def _load_and_extract_data(json_path, extractor_func, desc="Data"):
    """
    通用数据加载辅助函数
    Args:
        json_path: JSON文件路径
        extractor_func: 一个函数,接收 (key, value),返回 (id, sequence_text)
                        如果 sequence_text 无效,应该返回 None
    """
    print(f"Loading {desc} from {json_path}...")
    if not os.path.exists(json_path):
        print(f"Warning: {json_path} does not exist.")
        return [], []

    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    
    ids = []
    seqs = []
    
    # 遍历 JSON,利用传入的 extractor_func 提取数据
    for k, v in data.items():
        extracted_id, extracted_seq = extractor_func(k, v)
        
        # 简单的有效性检查
        if extracted_seq and isinstance(extracted_seq, str) and len(extracted_seq.strip()) > 0:
            ids.append(extracted_id)
            seqs.append(extracted_seq)
            
    print(f"Found {len(ids)} valid items for {desc}.")
    return ids, seqs

def _compute_tower_vectors(ids, seqs, tokenizer, encoder, projector, batch_size, device, max_len, desc):
    """
    Generic single-tower inference helper: batch -> tokenize -> encoder
    forward -> [CLS] pooling -> projection -> L2 normalization.

    Returns (ids, FloatTensor[N, dim]) on success, or (None, None) when
    there is nothing to encode. Batch results are moved to CPU as they are
    produced so GPU memory stays bounded.
    """
    if not ids:
        return None, None

    chunks = []
    total = len(ids)

    # Inference only: no gradients needed.
    with torch.no_grad():
        for start in tqdm(range(0, total, batch_size), desc=f"Encoding {desc}"):
            batch = seqs[start : start + batch_size]

            # 1) Tokenize the batch and move the tensors to the target device.
            enc_inputs = tokenizer(
                batch,
                padding=True,
                truncation=True,
                max_length=max_len,
                return_tensors='pt'
            ).to(device)

            # 2) Forward chain: backbone, then pool the [CLS] token (index 0),
            #    then project into the shared retrieval space.
            hidden = encoder(**enc_inputs).last_hidden_state
            pooled = hidden[:, 0, :]
            projected = projector(pooled)
            # L2-normalize -- required for cosine-similarity retrieval.
            normalized = F.normalize(projected, p=2, dim=1)

            # Offload each batch to CPU to avoid accumulating on the GPU.
            chunks.append(normalized.cpu())

    # Concatenate all batch results into one tensor.
    if not chunks:
        return None, None

    return ids, torch.cat(chunks, dim=0)

def update_candidate_vectors(
    model, 
    protein_tokenizer, 
    molecule_tokenizer, 
    protein_json_path, 
    molecule_json_path, 
    output_dir, 
    batch_size=64, 
    device='cuda',
    max_prot_len=1024,
    max_mol_len=512
):
    """
    Orchestrate the full candidate-vector refresh: for each tower (protein
    and molecule), load its metadata JSON, encode the sequences, and save
    {ids, vectors} as a .pt file in *output_dir*.
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    model.eval()
    model.to(device)

    # --- Per-tower configuration ---
    # Each entry describes how Protein vs Molecule data differs: how to
    # read the JSON, which tokenizer / encoder / projection head to use,
    # and where to save the result.
    tower_specs = [
        {
            "desc": "Protein",
            "path": protein_json_path,
            # Protein extraction rule:
            #   ID  = key (UniProt ID)
            #   Seq = value["target__foldseek_seq"]
            "extract_fn": lambda k, v: (k, v.get('target__foldseek_seq')),
            "tokenizer": protein_tokenizer,
            "encoder": model.protein_encoder,
            "projector": model.prot_proj,
            "max_len": max_prot_len,
            "save_name": "protein_candidates.pt"
        },
        {
            "desc": "Molecule",
            "path": molecule_json_path,
            # Molecule extraction rule:
            #   ID  = key (SMILES)
            #   Seq = key (SMILES) -- the SMILES string is both the ID and
            #   the sequence to encode.
            "extract_fn": lambda k, v: (k, k),
            "tokenizer": molecule_tokenizer,
            "encoder": model.molecule_encoder,
            "projector": model.mol_proj,
            "max_len": max_mol_len,
            "save_name": "molecule_candidates.pt"
        }
    ]

    # --- Unified pipeline over both towers ---
    for spec in tower_specs:
        # 1) Load metadata and extract (id, sequence) pairs.
        item_ids, sequences = _load_and_extract_data(
            spec['path'],
            spec['extract_fn'],
            spec['desc']
        )

        # 2) Run the tower to produce normalized embeddings.
        kept_ids, vectors = _compute_tower_vectors(
            item_ids, sequences,
            spec['tokenizer'],
            spec['encoder'],
            spec['projector'],
            batch_size,
            device,
            spec['max_len'],
            spec['desc']
        )

        # 3) Persist the result, or skip when no valid data was found.
        if vectors is not None:
            save_path = os.path.join(output_dir, spec['save_name'])
            torch.save({
                "ids": kept_ids,   # list of strings (UniProt IDs or SMILES)
                "vectors": vectors  # FloatTensor [N, Dim]
            }, save_path)
            print(f"Saved {spec['desc']} Vectors: {vectors.shape} to {save_path}\n")
        else:
            print(f"Skipping save for {spec['desc']} (No valid data).\n")

def recompute_candidate_vectors(
    protein_json_path: str,
    molecule_json_path: str,
    output_dir: str,
    pt_path: str = None,
    protein_model_path: str = None,
    molecule_model_path: str = None,
    batch_size: int = 64,
    max_prot_len: int = 1024,
    max_mol_len: int = 512,
    device: str = None
):
    """
    End-to-end pipeline: load the dual-tower model -> read candidate
    metadata -> run inference to build the vector banks -> save .pt files.

    Args:
        protein_json_path: Path to the unique-target JSON.
        molecule_json_path: Path to the unique-compound JSON.
        output_dir: Directory for the output vector files.
        pt_path: (optional) Trained .pt checkpoint path. If None, the
            projection layers keep their random initialization.
        protein_model_path: (optional) Explicit protein backbone path.
        molecule_model_path: (optional) Explicit molecule backbone path.
        batch_size: Inference batch size.
        max_prot_len: Max protein sequence length.
        max_mol_len: Max molecule sequence length.
        device: 'cuda' or 'cpu'; auto-detected when None.

    Returns:
        None. This is a best-effort entry point: errors are reported (with
        full traceback) and the function returns instead of raising.
    """
    import traceback  # local import: only needed for error reporting below

    # 0. Auto-detect the device.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    print(f"={' Start Recomputing Candidate Vectors ':=^60}")
    print(f"Device: {device}")

    # 1. Load the model and tokenizers (load_dual_tower_model defined above).
    print(f"\n>>> [Stage 1/2] Loading Model & Tokenizers...")
    try:
        model, p_tokenizer, m_tokenizer = load_dual_tower_model(
            protein_model_path=protein_model_path,
            molecule_model_path=molecule_model_path,
            pt_path=pt_path,
            device=device
        )
    except Exception as e:
        # Best-effort: report and bail out instead of crashing, but keep the
        # full traceback so the failure is actually debuggable (the previous
        # version printed only str(e) and discarded the stack).
        traceback.print_exc()
        print(f"Error loading model: {e}")
        return

    # 2. Compute and save the candidate vectors.
    print(f"\n>>> [Stage 2/2] Computing & Saving Vectors...")
    try:
        update_candidate_vectors(
            model=model,
            protein_tokenizer=p_tokenizer,
            molecule_tokenizer=m_tokenizer,
            protein_json_path=protein_json_path,
            molecule_json_path=molecule_json_path,
            output_dir=output_dir,
            batch_size=batch_size,
            device=device,
            max_prot_len=max_prot_len,
            max_mol_len=max_mol_len
        )
    except Exception as e:
        traceback.print_exc()
        print(f"Error computing vectors: {e}")
        return

    print(f"\n={' All Done ':=^60}")
    print(f"Check output files in: {output_dir}")

def continuous_train(
        dataset_path: str,
        model_save_dir: str = 'Dual_Tower_Model/customized_checkpoints',
        protein_model_path:str = None,
        molecule_model_path:str = None,
        best_model_path:str = None,
        device:str = "cuda" if torch.cuda.is_available() else "cpu",
        epochs: int = 5,
        lr: float = 1e-4,
        batch_size: int = 16,
        use_ddp: bool = False
):
    """
    Run the continual-training / fine-tuning pipeline.

    When *best_model_path* is given, training resumes from that checkpoint;
    otherwise the model starts from *protein_model_path* /
    *molecule_model_path* (or the default backbones).
    """
    model, p_tokenizer, m_tokenizer = load_dual_tower_model(
        protein_model_path=protein_model_path,
        molecule_model_path=molecule_model_path,
        pt_path=best_model_path,
        device=device
    )

    # Single-GPU and DDP training share the same call signature, so simply
    # select the appropriate trainer.
    trainer = train_model_ddp if use_ddp else train_model
    trainer(
        model_and_tokenizers=[model, p_tokenizer, m_tokenizer],
        dataset_path=dataset_path,
        model_save_dir=model_save_dir,
        epochs=epochs,
        batch_size=batch_size,
        lr=lr
    )
    return True

if __name__ == "__main__":
    # Assumes load_dual_tower_model and update_candidate_vectors are
    # defined above.

    # Two possible kinds of user input, yielding three scenarios:
    # 1. User supplies a model (the two encoders).
    # 2. User supplies data (a dataset with 'compound__smiles',
    #    'target__foldseek_seq', 'outcome_potency_pxc50',
    #    'outcome_is_active' fields).
    # (3. Whether to use DDP: train_model_ddp vs train_model.)

    # **Inference only**
    # 1. No data given, model given
    #    --> recompute_candidate_vectors(protein_model_path, molecule_model_path)
    # **Training + inference**
    # 2. Data given, model given
    #   --> continuous_train(protein_model_path, molecule_model_path) --> recompute_candidate_vectors(best_model_path)
    # 3. Data given, no model given
    #   --> continuous_train(best_model_path) --> recompute_candidate_vectors(best_model_path)
    
    protein_json_path = 'drug_target_activity/candidates/unique_targets.json'
    molecule_json_path = 'drug_target_activity/candidates/unique_compounds.json'
    best_model_path = 'Dual_Tower_Model/output_checkpoints_ddp/model_epoch_7_acc_0.3259.pt'
    recompute_candidate_vectors(
        protein_json_path=protein_json_path,
        molecule_json_path=molecule_json_path,
        output_dir='drug_target_activity/candidates',
        pt_path=best_model_path, # load the trained weights
        batch_size=32
    )


# If only data is given without a model, retrain and then recompute the vectors.