# Recommend_system / dataset.py
# author: tong
# commit: revise streamlit temp files (aab3f3d)
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
import pandas as pd
import numpy as np
from typing import List, Dict, Tuple
class ProteinMoleculeDataset(Dataset):
    """Protein-molecule interaction dataset for dual-tower training.

    Loads a tabular file containing SMILES strings, FoldSeek protein
    sequences, a binary activity label and an optional pXC50 potency,
    applies basic quality filtering, and serves raw strings — tokenization
    is expected to happen in the DataLoader's collate_fn.
    """

    def __init__(self, dataset_path, min_pxc50=5.0):
        """
        Args:
            dataset_path (str): Path to a .parquet/.csv/.json data file.
            min_pxc50 (float): Fallback pXC50 assigned to samples that are
                active but have no measured potency.

        Raises:
            RuntimeError: If the file cannot be read or has an unsupported
                format.
        """
        super().__init__()
        # 1. Load the raw table.
        print(f"Loading dataset from {dataset_path}...")
        self.df = self._load_dataframe(dataset_path)

        # 2. Basic cleaning: both core model inputs must be present.
        initial_len = len(self.df)
        self.df = self.df.dropna(subset=['compound__smiles', 'target__foldseek_seq'])

        # 3. Quality control: drop cytotoxicity artifacts (viability_flag)
        #    and frequent hitters (frequency_flag). `!= True` deliberately
        #    keeps rows where the flag is NaN (unknown).
        if 'viability_flag' in self.df.columns:
            self.df = self.df[self.df['viability_flag'] != True]
        if 'frequency_flag' in self.df.columns:
            self.df = self.df[self.df['frequency_flag'] != True]
        print(f"Dataset loaded. Filtered {initial_len - len(self.df)} rows. Remaining: {len(self.df)}")

        # 4. Pre-extract columns as arrays so __getitem__ is cheap.
        #    Reset the index so positional idx access is contiguous.
        self.df = self.df.reset_index(drop=True)
        self.smiles = self.df['compound__smiles'].values
        self.proteins = self.df['target__foldseek_seq'].values

        # Potency (pXC50) handling:
        #   inactive or missing          -> 0.0
        #   active with missing potency  -> min_pxc50
        # The missing-value mask is taken *before* filling so that a
        # legitimately measured pXC50 of exactly 0.0 is never clobbered
        # (the original masked on `potency == 0.0` after fillna).
        potency_col = self.df['outcome_potency_pxc50']
        missing_potency = potency_col.isna().to_numpy()
        self.potency = potency_col.fillna(0.0).to_numpy(dtype=float)
        self.is_active = self.df['outcome_is_active'].values.astype(float)  # 0.0 or 1.0
        self.potency[(self.is_active == 1.0) & missing_potency] = min_pxc50

    @staticmethod
    def _load_dataframe(dataset_path):
        """Read the dataset file into a DataFrame, dispatching on extension."""
        try:
            if dataset_path.endswith('.parquet'):
                return pd.read_parquet(dataset_path)
            if dataset_path.endswith('.csv'):
                return pd.read_csv(dataset_path)
            if dataset_path.endswith('.json'):
                return pd.read_json(dataset_path)
            # Unknown extension: best-effort CSV parse before giving up.
            try:
                return pd.read_csv(dataset_path)
            except Exception:  # was a bare `except:` — never swallow SystemExit etc.
                raise ValueError(f"Unsupported file format for: {dataset_path}")
        except Exception as e:
            raise RuntimeError(f"Failed to read file {dataset_path}: {e}") from e

    def __len__(self):
        """Number of samples remaining after filtering."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return one sample as a dict of raw strings and scalar targets.

        Note: strings are returned untokenized; tokenize them in the
        DataLoader's collate_fn (e.g. DualTowerCollator).
        """
        return {
            'molecule_str': str(self.smiles[idx]),   # for the Molecule Encoder
            'protein_str': str(self.proteins[idx]),  # for the Protein Encoder
            'label': float(self.is_active[idx]),     # 0/1, for BCE loss / contrastive mask
            'score': float(self.potency[idx]),       # pXC50, for regression / margin ranking
        }
class DualTowerCollator:
    """Collate function for the dual-tower (protein / molecule) model.

    Turns a list of raw-string samples into padded tensor batches using two
    HuggingFace-style tokenizers, one per tower.
    """

    def __init__(self, protein_tokenizer, molecule_tokenizer, max_prot_len=1024, max_mol_len=512):
        self.protein_tokenizer = protein_tokenizer
        self.molecule_tokenizer = molecule_tokenizer
        self.max_prot_len = max_prot_len
        self.max_mol_len = max_mol_len

    def _encode(self, tokenizer, texts, max_length):
        # Shared tokenization path for both towers: pad to the longest
        # sequence in the batch, truncate at max_length, return tensors.
        return tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors='pt',
        )

    def __call__(self, batch: List[Dict]) -> Dict:
        """Collate a list of sample dicts into a single batch dict."""
        labels = torch.tensor([sample['label'] for sample in batch], dtype=torch.long)
        scores = torch.tensor([sample['score'] for sample in batch], dtype=torch.float)
        prot_inputs = self._encode(
            self.protein_tokenizer,
            [sample['protein_str'] for sample in batch],
            self.max_prot_len,
        )
        mol_inputs = self._encode(
            self.molecule_tokenizer,
            [sample['molecule_str'] for sample in batch],
            self.max_mol_len,
        )
        return {
            'protein_inputs': prot_inputs,
            'molecule_inputs': mol_inputs,
            'labels': labels,
            'scores': scores,
        }