import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
import pandas as pd
import numpy as np
from typing import List, Dict, Tuple


class ProteinMoleculeDataset(Dataset):
    """Dataset pairing compound SMILES strings with protein Foldseek sequences.

    Each sample carries a binary activity label (``outcome_is_active``) and a
    pXC50 potency score (``outcome_potency_pxc50``). Raw strings are returned;
    tokenization is deferred to the DataLoader's collate_fn (see
    ``DualTowerCollator``).
    """

    def __init__(self, dataset_path, min_pxc50=5.0):
        """
        Args:
            dataset_path (str): Path to a parquet / csv / json table.
            min_pxc50 (float): Fallback pXC50 assigned to samples that are
                marked active but have no recorded potency.

        Raises:
            RuntimeError: If the file cannot be read (wraps the underlying
                parser error, including unsupported formats).
        """
        super().__init__()

        # 1. Load the raw table. Format is chosen by extension, with a
        #    last-resort CSV attempt for unknown extensions.
        print(f"Loading dataset from {dataset_path}...")
        try:
            if dataset_path.endswith('.parquet'):
                self.df = pd.read_parquet(dataset_path)
            elif dataset_path.endswith('.csv'):
                self.df = pd.read_csv(dataset_path)
            elif dataset_path.endswith('.json'):
                self.df = pd.read_json(dataset_path)
            else:
                # Generic fallback: try CSV before giving up.
                try:
                    self.df = pd.read_csv(dataset_path)
                except Exception as fallback_err:  # was a bare `except:`
                    raise ValueError(
                        f"Unsupported file format for: {dataset_path}"
                    ) from fallback_err
        except Exception as e:
            raise RuntimeError(f"Failed to read file {dataset_path}: {e}") from e

        # 2. Basic cleaning: drop rows missing either core input string.
        initial_len = len(self.df)
        self.df = self.df.dropna(subset=['compound__smiles', 'target__foldseek_seq'])

        # 3. Quality-control filters: drop cytotoxicity artifacts
        #    (viability_flag) and frequent hitters (frequency_flag).
        #    `!= True` deliberately keeps NaN/missing flag values.
        if 'viability_flag' in self.df.columns:
            self.df = self.df[self.df['viability_flag'] != True]  # noqa: E712
        if 'frequency_flag' in self.df.columns:
            self.df = self.df[self.df['frequency_flag'] != True]  # noqa: E712
        print(f"Dataset loaded. Filtered {initial_len - len(self.df)} rows. Remaining: {len(self.df)}")

        # 4. Precompute flat arrays so __getitem__ is cheap.
        # Reset the index so integer idx access is contiguous.
        self.df = self.df.reset_index(drop=True)
        self.smiles = self.df['compound__smiles'].values
        self.proteins = self.df['target__foldseek_seq'].values

        # Potency (pXC50): missing values become 0.0.
        # astype(float) guarantees a float dtype AND a fresh writable copy:
        # without it, an integer-typed column would truncate min_pxc50 on the
        # in-place write below, and `.values` could alias the DataFrame.
        self.potency = (
            self.df['outcome_potency_pxc50'].fillna(0.0).values.astype(float)
        )
        self.is_active = self.df['outcome_is_active'].values.astype(float)  # 0.0 or 1.0
        # Active samples with no recorded potency get the floor value.
        mask_active_no_potency = (self.is_active == 1.0) & (self.potency == 0.0)
        self.potency[mask_active_no_potency] = min_pxc50

    def __len__(self):
        """Number of samples after cleaning/filtering."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return one sample as a dict of raw values.

        Note: raw strings are returned here; tokenize in the DataLoader's
        collate_fn instead.

        Returns:
            dict with keys:
                'molecule_str' (str): SMILES, for the molecule encoder.
                'protein_str'  (str): Foldseek sequence, for the protein encoder.
                'label' (float): 0.0 or 1.0 (BCE loss / contrastive mask).
                'score' (float): pXC50 (regression loss / margin ranking).
        """
        smiles = str(self.smiles[idx])
        protein_seq = str(self.proteins[idx])
        label = float(self.is_active[idx])
        score = float(self.potency[idx])

        return {
            'molecule_str': smiles,       # for the Molecule Encoder
            'protein_str': protein_seq,   # for the Protein Encoder
            'label': label,               # 0 or 1 (BCE Loss / Contrastive Mask)
            'score': score                # pXC50 (Regression / Margin Ranking)
        }


class DualTowerCollator:
    """Collate function for a dual-tower (protein / molecule) model.

    Tokenizes both string streams with their respective HuggingFace-style
    tokenizers and stacks labels/scores into tensors.
    """

    def __init__(self, protein_tokenizer, molecule_tokenizer,
                 max_prot_len=1024, max_mol_len=512):
        """
        Args:
            protein_tokenizer: HuggingFace-style tokenizer for protein sequences.
            molecule_tokenizer: HuggingFace-style tokenizer for SMILES strings.
            max_prot_len (int): Truncation length for protein sequences.
            max_mol_len (int): Truncation length for molecule strings.
        """
        self.protein_tokenizer = protein_tokenizer
        self.molecule_tokenizer = molecule_tokenizer
        self.max_prot_len = max_prot_len
        self.max_mol_len = max_mol_len

    def __call__(self, batch: List[Dict]) -> Dict:
        """Collate a list of ProteinMoleculeDataset samples into a batch dict."""
        # 1. Gather the raw string lists and stack labels/scores.
        molecule_strs = [item['molecule_str'] for item in batch]
        protein_strs = [item['protein_str'] for item in batch]
        labels = torch.tensor([item['label'] for item in batch], dtype=torch.long)
        scores = torch.tensor([item['score'] for item in batch], dtype=torch.float)

        # 2. Tokenize proteins (HuggingFace tokenizer call convention assumed).
        prot_inputs = self.protein_tokenizer(
            protein_strs,
            padding=True,
            truncation=True,
            max_length=self.max_prot_len,
            return_tensors='pt'
        )

        # 3. Tokenize molecules.
        mol_inputs = self.molecule_tokenizer(
            molecule_strs,
            padding=True,
            truncation=True,
            max_length=self.max_mol_len,
            return_tensors='pt'
        )

        return {
            'protein_inputs': prot_inputs,
            'molecule_inputs': mol_inputs,
            'labels': labels,
            'scores': scores
        }