import pandas as pd
import numpy as np
import itertools
import torch
from torch.utils.data import Dataset
import re
import json
from typing import Literal
import os
# import io
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem.Draw import rdMolDraw2D
# from PIL import Image
import torchvision.io as tvio
# import torchvision.transforms as tvt
import torchvision.transforms.v2.functional as tvtF

# --- Helper functions ---

# The 20 standard amino-acid one-letter codes, in alphabetical order.
AMINO_ACIDS = ['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y']
AA_to_index = {aa: i for i, aa in enumerate(AMINO_ACIDS)}
valid_aa = set(AMINO_ACIDS)


def is_valid_sequence(seq):
    """Return True if *seq* contains only standard amino-acid letters.

    Uppercase letters denote L-amino acids; lowercase letters denote
    D-amino acids and are also considered valid.
    """
    for ch in seq:
        if not ch.isalpha():
            return False
        if ch.upper() not in valid_aa:
            return False
    return True


def parse_mic(mic_str):
    """Parse a MIC value from the spreadsheet into a float.

    Supported formats:
      1. plain number, e.g. "5"            -> 5.0
      2. ">{num}" or "≥{num}"              -> num * 1.5
      3. "<{num}" or "≤{num}"              -> num * 0.75
      4. "mean ± std", e.g. "3.2 ± 0.4"    -> mean (3.2)
      5. range "a - b", e.g. "2.0 - 4.0"   -> (a + b) / 2

    Symbols and digits may be separated by whitespace; the
    greater-or-equal sign is "≥" rather than ">=".  Unparseable strings
    yield np.nan (with a warning printed).
    """
    if not isinstance(mic_str, str):
        return float(mic_str)
    mic_str = mic_str.strip()
    # Collapse all internal whitespace so "≥ 4" and "2.0 - 4.0" match.
    mic_str = re.sub(r'\s+', '', mic_str)
    # Plain number.
    if re.fullmatch(r'\d+(\.\d+)?', mic_str):
        return float(mic_str)
    # >{num} or ≥{num}
    m = re.fullmatch(r'[>≥](\d+(\.\d+)?)', mic_str)
    if m:
        num = float(m.group(1))
        return num * 1.5
    # <{num} or ≤{num}
    m = re.fullmatch(r'[<≤](\d+(\.\d+)?)', mic_str)
    if m:
        num = float(m.group(1))
        return num * 0.75
    # {num}±{num} -> take the mean.
    m = re.fullmatch(r'(\d+(\.\d+)?)[±](\d+(\.\d+)?)', mic_str)
    if m:
        return float(m.group(1))
    # {num}-{num} -> midpoint of the range.
    m = re.fullmatch(r'(\d+(\.\d+)?)-(\d+(\.\d+)?)', mic_str)
    if m:
        num1 = float(m.group(1))
        num2 = float(m.group(3))
        return (num1 + num2) / 2.0
    print(f"Warning: 无法解析 MIC 值 {mic_str}")
    return np.nan


def encode_sequence(seq, pad_length):
    """Encode a peptide sequence as a fixed-size (pad_length, 21) tensor.

    Per residue row layout:
      - column 0: D-amino-acid indicator (1.0 if the letter is lowercase);
      - columns 1-20: one-hot encoding of the 20 standard amino acids
        (matched after uppercasing).

    Sequences shorter than pad_length are zero-padded at the end; residues
    beyond pad_length are ignored (the dataset builders filter long
    sequences beforehand).
    """
    n = len(seq)
    arr = np.zeros((pad_length, 21), dtype=np.float32)
    for i, char in enumerate(seq):
        if i >= pad_length:
            break  # overflow is not encoded (long sequences filtered upstream)
        if char.islower():
            d_indicator = 1.0
            aa = char.upper()
        else:
            d_indicator = 0.0
            aa = char
        arr[i, 0] = d_indicator
        if aa in AA_to_index:
            idx = AA_to_index[aa]
            arr[i, idx + 1] = 1.0
        else:
            print(f"Warning: 氨基酸 {aa} 不在标准列表中")
    return torch.tensor(arr)


def geometric_mean(values):
    """Return the geometric mean of a sequence of (positive) numbers.

    NOTE(review): values must be > 0 — a zero produces log(0) = -inf.
    Upstream parse_mic output is expected to be positive; confirm if that
    invariant ever changes.
    """
    log_vals = np.log(np.array(values))
    return float(np.exp(log_vals.mean()))


def process_label(ratio, task):
    """Apply a log2 transform to *ratio* and produce the task label.

    - task="reg": return log2(ratio) as np.float32;
    - task="cls": binary label from log2(ratio):
        log2(ratio) < 0  -> 1  (variant 2 is more active)
        otherwise        -> 0
    Non-positive ratios yield np.nan.

    Raises:
        ValueError: if task is neither "reg" nor "cls".
    """
    if ratio <= 0:
        return np.nan
    ratio_log = np.log2(ratio)
    if task == "reg":
        return np.float32(ratio_log)
    elif task == "cls":
        if ratio_log < 0.:
            return 1
        else:
            return 0
    else:
        raise ValueError("未知的 task 类型,请使用 'reg' 或 'cls'")


# --- Preprocessing and dataset construction ---

def load_data(xlsx_file, condition=None):
    """Read the xlsx file and group geometric-mean MIC values by prototype.

    For every concrete variant (same prototype/variant pair) the MIC
    values are geometric-mean averaged.  Rows whose prototype or variant
    sequence contains non-standard amino acids or non-letter characters
    are dropped.

    Args:
        xlsx_file: path to the spreadsheet.
        condition: unused; kept for signature parity with
            load_data_stability so both can be called interchangeably.

    Returns:
        dict mapping prototype sequence -> {variant sequence
        ("SEQUENCE - D-type amino acid substitution") -> geometric mean
        of all its MIC values}.
    """
    df = pd.read_excel(xlsx_file)
    # df = df[df['TARGET ACTIVITY - ACTIVITY MEASURE VALUE'] != 'MBC']
    groups = {}
    for _, row in df.iterrows():
        orig = row["SEQUENCE - Original"]
        variant = row["SEQUENCE - D-type amino acid substitution"]
        mic_raw = row["TARGET ACTIVITY - CONCENTRATION"]
        # Drop rows where either sequence has non-standard residues or
        # non-letter characters.
        if not (isinstance(orig, str) and is_valid_sequence(orig)):
            continue
        if not (isinstance(variant, str) and is_valid_sequence(variant)):
            continue
        mic_val = parse_mic(mic_raw)
        if orig not in groups:
            groups[orig] = {}
        if variant not in groups[orig]:
            groups[orig][variant] = []
        groups[orig][variant].append(mic_val)
    # Geometric mean per variant (NaN values filtered out).
    groups_avg = {}
    for orig, var_dict in groups.items():
        groups_avg[orig] = {}
        for variant, mic_list in var_dict.items():
            mic_list = [x for x in mic_list if not np.isnan(x)]
            if len(mic_list) == 0:
                continue
            groups_avg[orig][variant] = geometric_mean(mic_list)
    return groups_avg


def load_data_stability(xlsx_file, condition):
    """Read the stability xlsx file and group MIC values by prototype.

    Same grouping/averaging behaviour as load_data, but the rows are
    first filtered by the experimental *condition* and the prototype is
    derived by uppercasing the variant ("SEQUENCE" column).

    Args:
        xlsx_file: path to the spreadsheet.
        condition: one of '125fbs', '25fbs', 'mhb', 'nacl' (mapped to the
            spreadsheet's "Condition" values).

    Returns:
        dict mapping prototype sequence -> {variant sequence -> geometric
        mean of its MIC values}.
    """
    map_dict = {
        '125fbs': '12.5% FBS',
        '25fbs': '25% FBS',
        'mhb': 'MHB',
        'nacl': '150mM NaCl'
    }
    df = pd.read_excel(xlsx_file)
    df = df[df['Condition'] == map_dict[condition]]
    groups = {}
    for _, row in df.iterrows():
        variant = row["SEQUENCE"]
        # BUGFIX: validate the variant is a string BEFORE calling .upper();
        # previously a non-string cell crashed here instead of being filtered.
        if not (isinstance(variant, str) and is_valid_sequence(variant)):
            continue
        orig = variant.upper()
        mic_raw = row["Activity"]
        if not is_valid_sequence(orig):
            continue
        mic_val = parse_mic(mic_raw)
        if orig not in groups:
            groups[orig] = {}
        if variant not in groups[orig]:
            groups[orig][variant] = []
        groups[orig][variant].append(mic_val)
    # Geometric mean per variant (NaN values filtered out).
    groups_avg = {}
    for orig, var_dict in groups.items():
        groups_avg[orig] = {}
        for variant, mic_list in var_dict.items():
            mic_list = [x for x in mic_list if not np.isnan(x)]
            if len(mic_list) == 0:
                continue
            groups_avg[orig][variant] = geometric_mean(mic_list)
    return groups_avg


class PeptidePairDataset(Dataset):
    def __init__(self, mode: Literal['train', 'test', '125fbs', 'nacl', '25fbs', 'mhb'] = 'train',
                 pad_length=30, task="cls", include_reverse=False, include_self=False,
                 one_way=False, gf=False):
        """Pairwise peptide dataset of one-hot encoded sequences.

        Reads the xlsx file, groups variants by prototype, filters rows
        with non-standard residues and variants longer than pad_length,
        and forms pairs of variants sharing the same prototype.

        Args:
            mode: data source; 'train'/'test'/'r2_case'/'r2_case_' use the
                MIC spreadsheets, '125fbs'/'nacl'/'25fbs'/'mhb' use the
                stability spreadsheet (these force one_way=True).
                BUGFIX: previously the default was the typing.Literal
                object itself, which always raised ValueError when omitted.
            pad_length: maximum sequence length / padding target.
            task: "reg" for regression (float32 log2-ratio label) or
                "cls" for binary classification (log2-ratio < 0 -> 1,
                otherwise 0 — see process_label).
            include_reverse: also add the mirrored pair (B, A).
            include_self: also add (A, A) pairs with label
                process_label(1) (log2(1) = 0).
            one_way: pair only the first variant against the rest.
            gf: attach a precomputed ProtBERT global feature per prototype.

        Each item is ((enc1, enc2[, glob_feat]), label) where enc* are
        (pad_length, 21) tensors from encode_sequence.
        """
        if mode == "train":
            loader = load_data
            xlsx_file = os.path.join(os.path.dirname(__file__), 'dataset', 'train.xlsx')
        elif mode in ["test", "r2_case", 'r2_case_', "125fbs", "nacl", "25fbs", "mhb"]:
            one_way = True
            if mode in ["test", "r2_case", 'r2_case_']:
                loader = load_data
                xlsx_file = os.path.join(os.path.dirname(__file__), 'dataset', f'{mode}.xlsx')
            else:
                loader = load_data_stability
                xlsx_file = os.path.join(os.path.dirname(__file__), 'dataset', 'stability.xlsx')
        else:
            raise ValueError("未知的 mode,请使用 'train' 或 'test'")
        self.data = []
        self.pad_length = pad_length
        self.task = task
        groups_avg = loader(xlsx_file, mode)
        if gf:
            gf_dict = torch.load(os.path.join(os.path.dirname(__file__), 'dataset', 'protbert.pth'))
        # For every prototype, drop variants longer than pad_length.
        for orig, variant_dict in groups_avg.items():
            filtered_variants = {variant: mic for variant, mic in variant_dict.items()
                                 if len(variant) <= pad_length}
            variants = list(filtered_variants.keys())
            n_variants = len(variants)
            if n_variants == 0:
                continue
            if gf:
                glob_feat = gf_dict[orig.upper()]
            # Self pairs (A, A): label = process_label(1) -> log2(1) = 0.
            if include_self and (not one_way):
                for variant in variants:
                    encoded_seq = encode_sequence(variant, pad_length)
                    label = process_label(1.0, task)  # log2(1) = 0
                    if gf:
                        self.data.append(((encoded_seq, encoded_seq, glob_feat), label))
                    else:
                        self.data.append(((encoded_seq, encoded_seq), label))
            # Pairs of distinct variants of the same prototype.
            for i in [0] if one_way else range(n_variants):
                for j in range(i + 1, n_variants):
                    var1 = variants[i]
                    var2 = variants[j]
                    mic1 = filtered_variants[var1]
                    mic2 = filtered_variants[var2]
                    # Forward pair (var1, var2): label = log2(mic2 / mic1).
                    ratio = mic2 / mic1 if mic1 != 0 else np.nan
                    label = process_label(ratio, task)
                    if np.isnan(label):
                        continue
                    encoded_var1 = encode_sequence(var1, pad_length)
                    encoded_var2 = encode_sequence(var2, pad_length)
                    if gf:
                        self.data.append(((encoded_var1, encoded_var2, glob_feat), label))
                    else:
                        self.data.append(((encoded_var1, encoded_var2), label))
                    # Optional mirrored pair (var2, var1).
                    if include_reverse and (not one_way):
                        rev_ratio = mic1 / mic2 if mic2 != 0 else np.nan
                        rev_label = process_label(rev_ratio, task)
                        # BUGFIX: skip NaN reverse labels; previously they
                        # were appended unchecked, unlike the forward pair.
                        if np.isnan(rev_label):
                            continue
                        if gf:
                            self.data.append(((encoded_var2, encoded_var1, glob_feat), rev_label))
                        else:
                            self.data.append(((encoded_var2, encoded_var1), rev_label))

    def reg_sample_weight(self):
        """Return per-sample weights inversely related to label density.

        Fits a Gaussian to the label distribution and weights samples by
        log(C / p) (C = median density), so rare labels are up-weighted;
        weights are mean-normalized then exponentiated to stay positive.
        """
        y = np.array([label for _, label in self.data])
        mu = np.mean(y)
        sigma = np.std(y)
        p = 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-((y - mu) ** 2) / (2 * sigma ** 2))
        # Use the median density as the baseline constant C.
        C = np.median(p)
        epsilon = 1e-6
        # Log transform: lower density -> higher weight.
        weights = np.log(C / (p + epsilon))
        # Normalize so the mean weight is 1, then exponentiate to keep > 0.
        weights_normalized = weights / np.mean(weights)
        positive_weights = np.exp(weights_normalized)
        return torch.tensor(positive_weights, dtype=torch.float32)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class PeptidePairPicDataset(Dataset):
    def __init__(self, mode: Literal['train', 'test', '125fbs', 'nacl', '25fbs', 'mhb'] = 'train',
                 pad_length=30, task="reg", include_reverse=False, include_self=False,
                 one_way=False, gf=False, side_enc=None, pcs=False, resize=None):
        """Pairwise peptide dataset returning rendered structure images.

        Same pairing/filtering logic as PeptidePairDataset, but sequences
        are kept as strings and rendered to image tensors via draw_peptide
        (cached in self.pics); __getitem__ returns images rather than
        one-hot encodings.

        Args:
            mode: see PeptidePairDataset.  BUGFIX: previously the default
                was the typing.Literal object itself, which always raised
                ValueError when omitted.
            pad_length, task, include_reverse, include_self, one_way, gf:
                see PeptidePairDataset.
            side_enc: if truthy, each image is paired with the sequence's
                one-hot encoding.
            pcs: highlight protease cleavage sites; 'mix' disables the
                highlighting only for prototype (all-uppercase) sequences.
            resize: image size passed to draw_peptide (None -> 768x768).
        """
        if mode == "train":
            loader = load_data
            xlsx_file = os.path.join(os.path.dirname(__file__), 'dataset', 'train.xlsx')
        elif mode in ["test", "r2_case", 'r2_case_', "125fbs", "nacl", "25fbs", "mhb"]:
            one_way = True
            if mode in ["test", "r2_case", 'r2_case_']:
                loader = load_data
                xlsx_file = os.path.join(os.path.dirname(__file__), 'dataset', f'{mode}.xlsx')
            else:
                loader = load_data_stability
                xlsx_file = os.path.join(os.path.dirname(__file__), 'dataset', 'stability.xlsx')
        else:
            raise ValueError("未知的 mode,请使用 'train' 或 'test'")
        self.data = []
        self.pics = {}
        self.pad_length = pad_length
        self.task = task
        self.gf = gf
        self.side_enc = True if side_enc else False
        self.pcs = pcs
        self.resize = resize
        groups_avg = loader(xlsx_file, mode)
        if gf:
            gf_dict = torch.load(os.path.join(os.path.dirname(__file__), 'dataset', 'protbert.pth'))
        # For every prototype, drop variants longer than pad_length.
        for orig, variant_dict in groups_avg.items():
            filtered_variants = {variant: mic for variant, mic in variant_dict.items()
                                 if len(variant) <= pad_length}
            variants = list(filtered_variants.keys())
            # Pre-render and cache one image per variant.
            for variant in variants:
                if self.pcs == 'mix' and variant == orig:
                    self.pics[variant] = self.read_img(variant, False)
                else:
                    self.pics[variant] = self.read_img(variant, self.pcs)
            n_variants = len(variants)
            if n_variants == 0:
                continue
            if gf:
                glob_feat = gf_dict[orig.upper()]
            # Self pairs (A, A): label = process_label(1) -> log2(1) = 0.
            if include_self and (not one_way):
                for variant in variants:
                    label = process_label(1.0, task)  # log2(1) = 0
                    if gf:
                        self.data.append((variant, variant, glob_feat, label))
                    else:
                        self.data.append((variant, variant, label))
            # Pairs of distinct variants of the same prototype.
            for i in [0] if one_way else range(n_variants):
                for j in range(i + 1, n_variants):
                    var1 = variants[i]
                    var2 = variants[j]
                    mic1 = filtered_variants[var1]
                    mic2 = filtered_variants[var2]
                    # Forward pair (var1, var2): label = log2(mic2 / mic1).
                    ratio = mic2 / mic1 if mic1 != 0 else np.nan
                    label = process_label(ratio, task)
                    if np.isnan(label):
                        continue
                    if gf:
                        self.data.append((var1, var2, glob_feat, label))
                    else:
                        self.data.append((var1, var2, label))
                    # Optional mirrored pair (var2, var1).
                    if include_reverse and (not one_way):
                        rev_ratio = mic1 / mic2 if mic2 != 0 else np.nan
                        rev_label = process_label(rev_ratio, task)
                        # BUGFIX: skip NaN reverse labels; previously they
                        # were appended unchecked, unlike the forward pair.
                        if np.isnan(rev_label):
                            continue
                        if gf:
                            self.data.append((var2, var1, glob_feat, rev_label))
                        else:
                            self.data.append((var2, var1, rev_label))

    def reg_sample_weight(self):
        """Return per-sample weights inversely related to label density.

        Identical weighting scheme to PeptidePairDataset.reg_sample_weight;
        the label is the last element of each data tuple here.
        """
        y = np.array([d[-1] for d in self.data])
        mu = np.mean(y)
        sigma = np.std(y)
        p = 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-((y - mu) ** 2) / (2 * sigma ** 2))
        # Use the median density as the baseline constant C.
        C = np.median(p)
        epsilon = 1e-6
        # Log transform: lower density -> higher weight.
        weights = np.log(C / (p + epsilon))
        # Normalize so the mean weight is 1, then exponentiate to keep > 0.
        weights_normalized = weights / np.mean(weights)
        positive_weights = np.exp(weights_normalized)
        return torch.tensor(positive_weights, dtype=torch.float32)

    def read_img(self, peptide, pcs):
        """Render *peptide* to an image tensor via draw_peptide."""
        image = draw_peptide(peptide, self.resize, pcs)
        return image

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if self.gf:
            seq1, seq2, glob_feat, label = self.data[idx]
        else:
            seq1, seq2, label = self.data[idx]
        img1 = self.pics[seq1]
        img2 = self.pics[seq2]
        if self.side_enc:
            img1 = (img1, encode_sequence(seq1, self.pad_length))
            img2 = (img2, encode_sequence(seq2, self.pad_length))
        if self.gf:
            return (img1, img2, glob_feat), label
        else:
            return (img1, img2), label


class SimplePairClsDataset(Dataset):
    def __init__(self, pad_length=30, llm=False, ftr2=False, gf=False, q_encoder=None,
                 side_enc=None, pcs=False, resize=None):
        """Binary-classification pair dataset loaded from a JSON file.

        The JSON maps prototype sequence -> {"1": [...], "0": [...]}
        variant lists; each item is a (prototype, variant) pair with an
        integer label.  Prototypes longer than pad_length are skipped
        (NOTE(review): variant lengths are not checked here — confirm the
        JSON is pre-filtered).

        Args:
            pad_length: maximum sequence length / padding target.
            llm: load the LLM-augmented training set.
            ftr2: load the r2 fine-tuning set (ignored when llm is True).
            gf: attach ProtBERT global features keyed by prototype.
            q_encoder: 'cnn'/'rn18' switch the items from one-hot
                encodings to pre-rendered images.
            side_enc: if truthy, pair each image with its one-hot encoding.
            pcs: highlight protease cleavage sites; 'mix' disables
                highlighting for all-uppercase (prototype) sequences.
            resize: image size passed to draw_peptide (None -> 768x768).
        """
        if llm:
            file_path = os.path.join(os.path.dirname(__file__), 'dataset', 'train_set_llm_aug.json')
        elif ftr2:
            file_path = os.path.join(os.path.dirname(__file__), 'dataset', 'finetune_for_r2_llm.json')
        else:
            file_path = os.path.join(os.path.dirname(__file__), 'dataset', 'train_set.json')
        with open(file_path, 'r', encoding='utf-8') as f:
            dataset = json.load(f)
        self.data = []
        self.pics = {}
        self.pad_length = pad_length
        self.gf = gf
        self.q_encoder = q_encoder
        self.side_enc = True if side_enc else False
        self.pcs = pcs
        self.resize = resize
        if gf:
            # NOTE(review): keys of protbert.pth appear to be uppercase
            # prototype sequences (see other classes using .upper()) — the
            # seq1 used for lookup in __getitem__ is always a prototype here.
            self.gf_dict = torch.load(os.path.join(os.path.dirname(__file__), 'dataset', 'protbert.pth'))
        all_seqs = []
        for orig, variants in dataset.items():
            if len(orig) > pad_length:
                continue
            all_seqs.append(orig)
            for label in ["1", "0"]:
                for variant in variants[label]:
                    self.data.append((orig, variant, int(label)))
                    all_seqs.append(variant)
        # Pre-render images only when an image encoder is requested.
        if q_encoder in ['cnn', 'rn18']:
            for i in all_seqs:
                if self.pcs == 'mix' and i.isupper():
                    self.pics[i] = self.read_img(i, False)
                else:
                    self.pics[i] = self.read_img(i, self.pcs)

    def read_img(self, peptide, pcs):
        """Render *peptide* to an image tensor via draw_peptide."""
        image = draw_peptide(peptide, self.resize, pcs)
        return image

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        seq1, seq2, label = self.data[idx]
        if self.q_encoder in ['cnn', 'rn18']:
            img1 = self.pics[seq1]
            img2 = self.pics[seq2]
            if self.side_enc:
                img1 = (img1, encode_sequence(seq1, self.pad_length))
                img2 = (img2, encode_sequence(seq2, self.pad_length))
        else:
            img1 = encode_sequence(seq1, self.pad_length)
            img2 = encode_sequence(seq2, self.pad_length)
        if self.gf:
            return (img1, img2, self.gf_dict[seq1]), label
        else:
            return (img1, img2), label


class PeptidePairCaseDataset(Dataset):
    def __init__(self, case: str = 'r2', pad_length=30, gf=False):
        """Enumerate all D-substitution variants of a template peptide.

        For a template sequence, every residue except glycine may be
        either L (uppercase) or D (lowercase); the dataset is the
        Cartesian product of those choices, excluding the all-L template
        itself.  Items pair the encoded template with each encoded
        variant; the "label" is the variant string (for bookkeeping).

        Args:
            case: a named template ('r2', 'Indolicidin', 'Temporin-A',
                'Melittin', 'Anoplin') or an arbitrary sequence (taken
                uppercase).
            pad_length: padding target for encode_sequence.
            gf: attach the template's ProtBERT global feature.
        """
        if case == 'r2':
            self.template = 'KWKIKWPVKWFKML'
        elif case == 'Indolicidin':
            self.template = 'ILPWKWPWWPWRR'
        elif case == 'Temporin-A':
            self.template = 'FLPLIGRVLSGIL'
        elif case == 'Melittin':
            self.template = 'GIGAVLKVLTTGLPALISWIKRKRQQ'
        elif case == 'Anoplin':
            self.template = 'GLLKRIKTLL'
        else:
            self.template = case.upper().strip()
        self.data = []
        self.pad_length = pad_length
        self.gf = gf
        if gf:
            self.glob_feat = torch.load(
                os.path.join(os.path.dirname(__file__), 'dataset', 'protbert.pth'))[self.template]
        # Glycine has no stereocenter, so it only appears uppercase.
        pools = [(ch.upper(), ch.lower()) if ch != 'G' else (ch.upper(),) for ch in self.template]
        # Cartesian product = every L/D combination; [1:] drops the all-L template.
        self.variants = [''.join(chars) for chars in itertools.product(*pools)][1:]
        self.template_seq = encode_sequence(self.template, self.pad_length)

    def __len__(self):
        return len(self.variants)

    def __getitem__(self, idx):
        variant = self.variants[idx]
        seq2, label = variant, variant
        enc_seq1 = self.template_seq
        enc_seq2 = encode_sequence(seq2, self.pad_length)
        if self.gf:
            return (enc_seq1, enc_seq2, self.glob_feat), label
        else:
            return (enc_seq1, enc_seq2), label


class PeptidePairPicCaseDataset(Dataset):
    def __init__(self, case: str = 'r2', pad_length=30, side_enc=None, pcs=False,
                 resize=None, gf=False):
        """Image-based counterpart of PeptidePairCaseDataset.

        Enumerates all D-substitution variants of a template and returns
        rendered structure images; variant images are drawn lazily in
        __getitem__ (only the template image is cached).

        Args:
            case: named template or arbitrary sequence (see
                PeptidePairCaseDataset).
            pad_length: padding target for the optional side encodings.
            side_enc: if truthy, pair each image with its one-hot encoding.
            pcs: highlight protease cleavage sites in the drawings.
            resize: image size passed to draw_peptide (None -> 768x768).
            gf: attach the template's ProtBERT global feature.
        """
        if case == 'r2':
            self.template = 'KWKIKWPVKWFKML'
        elif case == 'Indolicidin':
            self.template = 'ILPWKWPWWPWRR'
        elif case == 'Temporin-A':
            self.template = 'FLPLIGRVLSGIL'
        elif case == 'Melittin':
            self.template = 'GIGAVLKVLTTGLPALISWIKRKRQQ'
        elif case == 'Anoplin':
            self.template = 'GLLKRIKTLL'
        else:
            self.template = case.upper().strip()
        self.data = []
        self.pad_length = pad_length
        self.side_enc = True if side_enc else False
        self.pcs = pcs
        self.resize = resize
        self.gf = gf
        if gf:
            self.glob_feat = torch.load(
                os.path.join(os.path.dirname(__file__), 'dataset', 'protbert.pth'))[self.template]
        # Glycine has no stereocenter, so it only appears uppercase.
        pools = [(ch.upper(), ch.lower()) if ch != 'G' else (ch.upper(),) for ch in self.template]
        # Cartesian product = every L/D combination; [1:] drops the all-L template.
        self.variants = [''.join(chars) for chars in itertools.product(*pools)][1:]
        self.template_pic = self.read_img(self.template)
        if self.side_enc:
            self.template_seq = encode_sequence(self.template, self.pad_length)

    def read_img(self, peptide):
        """Render *peptide* to an image tensor via draw_peptide."""
        image = draw_peptide(peptide, self.resize, self.pcs)
        return image

    def __len__(self):
        return len(self.variants)

    def __getitem__(self, idx):
        variant = self.variants[idx]
        seq2, label = variant, variant
        img1 = self.template_pic
        img2 = self.read_img(variant)
        if self.side_enc:
            img1 = (img1, self.template_seq)
            img2 = (img2, encode_sequence(seq2, self.pad_length))
        if self.gf:
            return (img1, img2, self.glob_feat), label
        else:
            return (img1, img2), label


# Side-chain SMILES fragments for the 20 standard amino acids.
aa_side = {
    "A": "C",
    "R": "CCCNC(N)=N",
    "N": "CC(=O)N",
    "D": "CC(=O)O",
    "C": "CS",
    "E": "CCC(=O)O",
    "Q": "CCC(=O)N",
    "G": "",
    "H": "Cc1cnc[nH]1",
    "I": "C(C)CC",
    "L": "CC(C)C",
    "K": "CCCCN",
    "M": "CCSC",
    "F": "Cc1ccccc1",
    "P": "C1CCN1",
    "S": "CO",
    "T": "C(C)O",
    "W": "Cc1c[nH]c2ccccc12",
    "Y": "Cc1ccc(O)cc1",
    "V": "C(C)C"
}

# Residue templates: "{AA}_{L|D}" for internal residues and
# "{AA}_{L|D}_term" for the C-terminal residue.  The '{idx}' placeholder
# is filled with the residue number (atom map on the alpha-carbon).
aa_tpl = {}
for aa, R in aa_side.items():
    for stereo, chir in (("L", "@"), ("D", "@@")):
        if aa == "G":
            # Glycine has no stereocenter: N-CA(mapped)-C.
            backbone = "N[C:{idx}]C"
        else:
            backbone = f"N[C{chir}H:{'{idx}'}]({R})C"  # N-[C@H:idx](R)-C
        aa_tpl[f"{aa}_{stereo}"] = backbone + "(=O)"        # internal residue
        aa_tpl[f"{aa}_{stereo}_term"] = backbone + "(=O)O"  # C-terminal residue


def build_peptide_smiles(seq: str) -> str:
    """Return a SMILES string for *seq* with atom-mapped backbone.

    Uppercase = L-form, lowercase = D-form.  The atom map number on each
    alpha-carbon is the 1-based residue index.  Empty input yields "".
    """
    if not seq:
        return ""
    out = []
    n = len(seq)
    for i, aa in enumerate(seq, start=1):
        key = f"{aa.upper()}_{'L' if aa.isupper() else 'D'}"
        if i == n:
            key += "_term"
        out.append(aa_tpl[key].format(idx=i))
    return "".join(out)


# Protease recognition patterns (P1--P1' position regexes on the sequence).
protease_patterns = {
    'trypsin': re.compile(r'(?<=[KR])(?!P)'),
    'chymotrypsin': re.compile(r'(?<=[FYWL])(?!P)'),
    'elastase': re.compile(r'(?<=[AVSGT])(?!P)'),
    'enterokinase': re.compile(r'D{4}K(?=[^P])'),
    'caspase': re.compile(r'(?<=D)(?=[GSA])'),
}


def draw_peptide(sequence, size=None, pcs=False):
    """Render the peptide's 2D structure as a normalized image tensor.

    Bonds around D-residue alpha-carbons are highlighted blue; when
    *pcs* is truthy, peptide bonds at protease cleavage sites are
    highlighted red (overriding blue).  Supported proteases and their
    regex patterns (P1--P1'):
      - trypsin:       (?<=[KR])(?!P)
      - chymotrypsin:  (?<=[FYWL])(?!P)
      - elastase:      (?<=[AVSGT])(?!P)
      - enterokinase:  D{4}K(?=[^P])
      - caspase:       (?<=D)(?=[GSA])

    Args:
        sequence: peptide sequence (uppercase = L, lowercase = D).
        size: one- or two-element sequence giving the canvas size in
            pixels; None means 768x768.  BUGFIX: previously the default
            was the mutable list [768], and passing None (the datasets'
            default resize) crashed on len(None).
        pcs: enable cleavage-site highlighting.

    Returns:
        float32 tensor of shape [3, H, W], ImageNet-normalized.
    """
    if size is None:
        size = (768,)
    # 1. Build an atom-mapped SMILES (residue index on the alpha-carbon).
    smiles = build_peptide_smiles(sequence)
    mol = Chem.MolFromSmiles(smiles)
    # if mol is None:
    #     raise ValueError("SMILES 解析失败,请检查输入序列和侧链字典。")
    AllChem.Compute2DCoords(mol)
    highlight_bonds = []
    bond_colors = {}
    # ----------------------------------------------------
    # 2. Mark D residues first: highlight bonds at the alpha-carbon in blue.
    d_positions = {i for i, aa in enumerate(sequence, start=1) if aa.islower()}
    for atom in mol.GetAtoms():
        if atom.GetAtomMapNum() in d_positions:
            # This atom is the alpha-carbon; highlight all its bonds.
            for b in atom.GetBonds():
                idx = b.GetIdx()
                if idx not in highlight_bonds:
                    highlight_bonds.append(idx)
                bond_colors[idx] = (0.0, 0.0, 1.0)
    # ----------------------------------------------------
    # 3. Then mark cleavage bonds in red (overriding earlier blue).
    if pcs:
        cleavage_sites = set()
        for pat in protease_patterns.values():
            for m in pat.finditer(sequence):
                cut = m.end()  # cleavage occurs after position `cut`
                if 1 <= cut < len(sequence):
                    cleavage_sites.add(cut)
        for pos in cleavage_sites:
            # Find the P1 residue's alpha-carbon.
            ca = next((a for a in mol.GetAtoms() if a.GetAtomMapNum() == pos), None)
            if ca is None:
                continue
            # Find the same residue's carbonyl carbon (sp2, double-bonded O).
            carbonyl_c = None
            for nb in ca.GetNeighbors():
                if nb.GetSymbol() != "C":
                    continue
                # Check for a "=O" bond.
                if any(bond.GetBondType() == Chem.BondType.DOUBLE and o.GetSymbol() == "O"
                       for bond in nb.GetBonds()
                       for o in (bond.GetBeginAtom(), bond.GetEndAtom())):
                    carbonyl_c = nb
                    break
            if carbonyl_c is None:
                continue
            # The N bonded to the carbonyl carbon is the next residue's nitrogen.
            peptide_bond = None
            for b in carbonyl_c.GetBonds():
                o_atom = b.GetOtherAtom(carbonyl_c)
                if o_atom.GetSymbol() == "N":
                    peptide_bond = b
                    break
            if peptide_bond is None:
                continue
            bidx = peptide_bond.GetIdx()
            if bidx not in highlight_bonds:
                highlight_bonds.append(bidx)
            bond_colors[bidx] = (1.0, 0.0, 0.0)  # red
    # 4. Canvas size.
    if len(size) == 1:
        w = h = size[0]
    else:
        w, h = size
    # 5. MolDraw2DCairo accepts highlightBondColors; style (bond line
    # width, atom font, ...) could be tuned via drawer.drawOptions().
    drawer = rdMolDraw2D.MolDraw2DCairo(w, h)
    drawer.DrawMolecule(
        mol,
        highlightAtoms=[],
        highlightBonds=highlight_bonds,
        highlightAtomColors={},
        highlightBondColors=bond_colors
    )
    drawer.FinishDrawing()
    # 6. Convert the PNG bytes to a tensor.
    png_bytes = bytearray(drawer.GetDrawingText())
    byte_tensor = torch.frombuffer(png_bytes, dtype=torch.uint8)
    img = tvio.decode_png(byte_tensor, mode=tvio.ImageReadMode.RGB)  # [3, H, W], uint8
    # NOTE(review): to_dtype without scale=True keeps values in 0..255
    # before ImageNet normalization — confirm this is intentional.
    img = tvtF.to_dtype(img, torch.float32)
    img = tvtF.normalize(img, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    return img


if __name__ == '__main__':
    # Smoke test: build the r2 case dataset (one-way pairs) and inspect it.
    pad_length = 30
    dataset = PeptidePairDataset('r2_case', pad_length, "cls",
                                 include_reverse=False, include_self=False, one_way=True)
    if len(dataset) > 0:
        (encoded_seq1, encoded_seq2), ratio = dataset[0]
        print("第一个样本:")
        print("变种1的编码张量形状:", encoded_seq1.shape)
        print("变种2的编码张量形状:", encoded_seq2.shape)
        print("标签比值(变种2/变种1):", ratio)
        print(f"数据集大小:{len(dataset)}")
        label_pos = 0
        for (_, _), i in dataset:
            label_pos += i
        print(label_pos)
    else:
        print("未读入组合数据!")