# NOTE: file recovered from a Hugging Face Spaces page scrape (status banner removed).
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
class ProteinMoleculeDataset(Dataset):
    """Protein-molecule interaction dataset for dual-tower training.

    Loads a tabular file (parquet / csv / json), drops rows missing the
    core string inputs, applies assay-quality filters, and exposes each
    sample as a dict of raw strings plus an activity label and a potency
    score.  Tokenization is deliberately deferred to the DataLoader's
    collate_fn.
    """

    def __init__(self, dataset_path, min_pxc50=5.0):
        """
        Args:
            dataset_path (str): Path to a .parquet / .csv / .json table.
                Any other extension is read as CSV as a best-effort fallback.
            min_pxc50 (float): Potency value assigned to samples that are
                marked active but have no recorded pXC50.

        Raises:
            RuntimeError: if the file cannot be read at all.
        """
        super().__init__()
        # 1. Load the table, dispatching on file extension.
        print(f"Loading dataset from {dataset_path}...")
        try:
            if dataset_path.endswith('.parquet'):
                self.df = pd.read_parquet(dataset_path)
            elif dataset_path.endswith('.csv'):
                self.df = pd.read_csv(dataset_path)
            elif dataset_path.endswith('.json'):
                self.df = pd.read_json(dataset_path)
            else:
                # Unknown extension: CSV is the most common case, so try it.
                # If this fails the outer handler wraps the real parse error
                # (the old bare `except:` hid it behind a generic message).
                self.df = pd.read_csv(dataset_path)
        except Exception as e:
            raise RuntimeError(f"Failed to read file {dataset_path}: {e}") from e

        # 2. Basic cleaning: both core string inputs must be present.
        initial_len = len(self.df)
        self.df = self.df.dropna(subset=['compound__smiles', 'target__foldseek_seq'])

        # 3. Quality control: drop cytotoxicity artifacts (viability_flag)
        #    and promiscuous frequent hitters (frequency_flag).
        #    `!= True` deliberately KEEPS rows whose flag is NaN (unknown).
        for flag_col in ('viability_flag', 'frequency_flag'):
            if flag_col in self.df.columns:
                self.df = self.df[self.df[flag_col] != True]  # noqa: E712
        print(f"Dataset loaded. Filtered {initial_len - len(self.df)} rows. Remaining: {len(self.df)}")

        # 4. Pre-extract columns to speed up __getitem__; reset the index so
        #    positional access by idx is contiguous.
        self.df = self.df.reset_index(drop=True)
        self.smiles = self.df['compound__smiles'].values
        self.proteins = self.df['target__foldseek_seq'].values

        # Potency (pXC50) policy:
        #   - inactive or missing          -> 0.0
        #   - active but missing potency   -> min_pxc50
        # copy=True guarantees a writable array before the in-place fill.
        potency = self.df['outcome_potency_pxc50'].fillna(0.0).to_numpy(dtype=float, copy=True)
        self.is_active = self.df['outcome_is_active'].values.astype(float)  # 0.0 or 1.0
        mask_active_no_potency = (self.is_active == 1.0) & (potency == 0.0)
        potency[mask_active_no_potency] = min_pxc50
        self.potency = potency

    def __len__(self):
        """Number of samples after filtering."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return one sample as a dict of raw strings and scalars.

        Tokenization is intentionally left to the DataLoader's collate_fn.
        """
        smiles = str(self.smiles[idx])
        protein_seq = str(self.proteins[idx])
        label = float(self.is_active[idx])
        score = float(self.potency[idx])
        return {
            'molecule_str': smiles,       # input to the molecule encoder
            'protein_str': protein_seq,   # input to the protein encoder
            'label': label,               # 0/1 for BCE loss / contrastive mask
            'score': score                # pXC50 for regression / margin ranking
        }
class DualTowerCollator:
    """Collate raw protein/molecule strings into padded tensor batches.

    Both tokenizers are expected to follow the HuggingFace callable
    interface (padding / truncation / max_length / return_tensors kwargs).
    """

    def __init__(self, protein_tokenizer, molecule_tokenizer, max_prot_len=1024, max_mol_len=512):
        self.protein_tokenizer = protein_tokenizer
        self.molecule_tokenizer = molecule_tokenizer
        self.max_prot_len = max_prot_len
        self.max_mol_len = max_mol_len

    def _tokenize(self, tokenizer, texts, max_len):
        # Shared tokenization call: pad to the batch max, truncate to max_len,
        # return PyTorch tensors.
        return tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=max_len,
            return_tensors='pt'
        )

    def __call__(self, batch: List[Dict]) -> Dict:
        """Assemble a batch dict of tokenized inputs, labels, and scores."""
        molecule_texts = []
        protein_texts = []
        raw_labels = []
        raw_scores = []
        for sample in batch:
            molecule_texts.append(sample['molecule_str'])
            protein_texts.append(sample['protein_str'])
            raw_labels.append(sample['label'])
            raw_scores.append(sample['score'])

        return {
            'protein_inputs': self._tokenize(self.protein_tokenizer, protein_texts, self.max_prot_len),
            'molecule_inputs': self._tokenize(self.molecule_tokenizer, molecule_texts, self.max_mol_len),
            'labels': torch.tensor(raw_labels, dtype=torch.long),
            'scores': torch.tensor(raw_scores, dtype=torch.float),
        }