File size: 4,796 Bytes
2180e31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aab3f3d
 
 
 
 
 
 
 
 
 
 
 
2180e31
aab3f3d
2180e31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
import pandas as pd
import numpy as np
from typing import List, Dict, Tuple

class ProteinMoleculeDataset(Dataset):
    """Paired protein/molecule dataset for a dual-tower (two-encoder) model.

    Loads a tabular file (.parquet / .csv / .json), drops rows missing the
    two core inputs, filters quality-control flags, and precomputes label and
    potency arrays so that ``__getitem__`` is a cheap array lookup.
    """

    def __init__(self, dataset_path, min_pxc50=5.0):
        """
        Args:
            dataset_path (str): Path to a .parquet / .csv / .json table.
            min_pxc50 (float): Fill value for pXC50 when a sample is marked
                active but its potency value is missing.

        Raises:
            RuntimeError: If the file cannot be read in any supported format.
        """
        super().__init__()

        # 1. Load the table, dispatching on file extension.
        print(f"Loading dataset from {dataset_path}...")
        try:
            if dataset_path.endswith('.parquet'):
                self.df = pd.read_parquet(dataset_path)
            elif dataset_path.endswith('.csv'):
                self.df = pd.read_csv(dataset_path)
            elif dataset_path.endswith('.json'):
                self.df = pd.read_json(dataset_path)
            else:
                # Unknown extension: attempt a generic CSV read before giving up.
                try:
                    self.df = pd.read_csv(dataset_path)
                except Exception:
                    # Was a bare `except:` — narrowed so KeyboardInterrupt /
                    # SystemExit propagate instead of being swallowed.
                    raise ValueError(f"Unsupported file format for: {dataset_path}")
        except Exception as e:
            # Chain the cause so the original parser error stays visible.
            raise RuntimeError(f"Failed to read file {dataset_path}: {e}") from e

        # 2. Basic cleaning: both core model inputs must be present.
        initial_len = len(self.df)
        self.df = self.df.dropna(subset=['compound__smiles', 'target__foldseek_seq'])

        # 3. Quality-control filtering: drop cytotoxicity artifacts
        #    (viability_flag) and frequent hitters (frequency_flag).
        #    NOTE: `!= True` deliberately keeps rows whose flag is NaN/missing.
        if 'viability_flag' in self.df.columns:
            self.df = self.df[self.df['viability_flag'] != True]  # noqa: E712
        if 'frequency_flag' in self.df.columns:
            self.df = self.df[self.df['frequency_flag'] != True]  # noqa: E712
        print(f"Dataset loaded. Filtered {initial_len - len(self.df)} rows. Remaining: {len(self.df)}")

        # 4. Precompute arrays so __getitem__ avoids per-row DataFrame access.
        #    reset_index makes positional idx access contiguous.
        self.df = self.df.reset_index(drop=True)
        self.smiles = self.df['compound__smiles'].values
        self.proteins = self.df['target__foldseek_seq'].values

        # Potency (pXC50): missing values become 0.0 (inactive convention).
        # astype(float) guards against an object-dtype column, which would
        # otherwise break the boolean-mask assignment below.
        self.potency = self.df['outcome_potency_pxc50'].fillna(0.0).values.astype(float)
        # Binary activity label as float (0.0 or 1.0).
        self.is_active = self.df['outcome_is_active'].values.astype(float)
        # Active samples with a missing potency get the conservative floor.
        mask_active_no_potency = (self.is_active == 1.0) & (self.potency == 0.0)
        self.potency[mask_active_no_potency] = min_pxc50

    def __len__(self):
        """Number of samples after cleaning/filtering."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return one sample as a dict of raw strings plus labels.

        Tokenization is intentionally deferred to the DataLoader's
        collate_fn (see DualTowerCollator elsewhere in this file).
        """
        smiles = str(self.smiles[idx])
        protein_seq = str(self.proteins[idx])
        label = float(self.is_active[idx])
        score = float(self.potency[idx])

        return {
            'molecule_str': smiles,       # input to the molecule encoder
            'protein_str': protein_seq,   # input to the protein encoder
            'label': label,               # 0/1 (BCE loss / contrastive mask)
            'score': score                # pXC50 (regression / margin ranking)
        }

class DualTowerCollator:
    """Collate function object for dual-tower training batches.

    Turns a list of sample dicts (as produced by the dataset's __getitem__)
    into tokenized tensors plus label/score tensors. Both tokenizers are
    assumed to follow the HuggingFace callable interface.
    """

    def __init__(self, protein_tokenizer, molecule_tokenizer, max_prot_len=1024, max_mol_len=512):
        self.protein_tokenizer = protein_tokenizer
        self.molecule_tokenizer = molecule_tokenizer
        self.max_prot_len = max_prot_len
        self.max_mol_len = max_mol_len

    def __call__(self, batch: List[Dict]) -> Dict:
        """Collate *batch* into a dict of model-ready inputs and targets."""
        # Target tensors: integer labels for classification, float scores
        # for regression / ranking losses.
        labels = torch.tensor([sample['label'] for sample in batch], dtype=torch.long)
        scores = torch.tensor([sample['score'] for sample in batch], dtype=torch.float)

        def _encode(tokenizer, texts, limit):
            # Single tokenization path shared by both towers
            # (HuggingFace-style keyword interface assumed).
            return tokenizer(
                texts,
                padding=True,
                truncation=True,
                max_length=limit,
                return_tensors='pt'
            )

        # Tokenize the protein tower first, then the molecule tower,
        # matching the original call order.
        prot_inputs = _encode(
            self.protein_tokenizer,
            [sample['protein_str'] for sample in batch],
            self.max_prot_len,
        )
        mol_inputs = _encode(
            self.molecule_tokenizer,
            [sample['molecule_str'] for sample in batch],
            self.max_mol_len,
        )

        return {
            'protein_inputs': prot_inputs,
            'molecule_inputs': mol_inputs,
            'labels': labels,
            'scores': scores
        }