# DAminoMuta / dataset.py
# (web-page residue) auralray's picture
# (web-page residue) Upload folder using huggingface_hub
# (web-page residue) acbef3a verified
import pandas as pd
import numpy as np
import itertools
import torch
from torch.utils.data import Dataset
import re
import json
from typing import Literal
import os
# import io
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem.Draw import rdMolDraw2D
# from PIL import Image
import torchvision.io as tvio
# import torchvision.transforms as tvt
import torchvision.transforms.v2.functional as tvtF
# --- Helper functions ---
# One-letter codes of the 20 standard amino acids, in alphabetical order.
AMINO_ACIDS = list("ACDEFGHIKLMNPQRSTVWY")
# Position of every one-letter code inside AMINO_ACIDS.
AA_to_index = dict(zip(AMINO_ACIDS, range(len(AMINO_ACIDS))))
# Set view of AMINO_ACIDS for membership tests.
valid_aa = set(AMINO_ACIDS)
def is_valid_sequence(seq):
    """
    Return True when every character of ``seq`` is a letter whose upper-case
    form is one of the standard amino acids in ``valid_aa``.

    Lower-case letters (used in this file to mark D-type residues) are
    therefore accepted. An empty sequence is considered valid.
    """
    return all(ch.isalpha() and ch.upper() in valid_aa for ch in seq)
def parse_mic(mic_str):
    """
    Parse a raw MIC cell into a float.

    Non-string values are converted with ``float()``; values that cannot be
    converted (e.g. ``None`` or a date object) produce a warning and NaN
    instead of raising.

    For strings, all whitespace is removed first and these forms are accepted:
      1. plain number, e.g. "5"                     -> 5.0
      2. ">{n}" or "≥{n}", e.g. "> 4"               -> n * 1.5
      3. "<{n}" or "≤{n}", e.g. "≤ 4"               -> n * 0.75
      4. mean ± std, e.g. "3.2 ± 0.4"               -> the mean (3.2)
      5. range, e.g. "2.0 - 4.0"                    -> midpoint (3.0)

    Anything else prints a warning and returns ``np.nan``.
    Note that only the single characters "≥" / "≤" are recognised, not the
    two-character ">=" / "<=".
    """
    if not isinstance(mic_str, str):
        try:
            return float(mic_str)
        except (TypeError, ValueError):
            # e.g. None or a datetime produced by spreadsheet auto-formatting
            print(f"Warning: 无法解析 MIC 值 {mic_str}")
            return np.nan

    number = r'(\d+(?:\.\d+)?)'
    compact = re.sub(r'\s+', '', mic_str)

    # Plain number
    if re.fullmatch(number, compact):
        return float(compact)

    # Lower bound: ">n" / "≥n"
    m = re.fullmatch(r'[>≥]' + number, compact)
    if m:
        return float(m.group(1)) * 1.5

    # Upper bound: "<n" / "≤n"
    m = re.fullmatch(r'[<≤]' + number, compact)
    if m:
        return float(m.group(1)) * 0.75

    # mean±std: keep the mean only
    m = re.fullmatch(number + r'[±]' + number, compact)
    if m:
        return float(m.group(1))

    # Range "a-b": midpoint
    m = re.fullmatch(number + r'-' + number, compact)
    if m:
        return (float(m.group(1)) + float(m.group(2))) / 2.0

    print(f"Warning: 无法解析 MIC 值 {compact}")
    return np.nan
def encode_sequence(seq, pad_length):
    """
    Encode a peptide sequence as a float32 tensor of shape (pad_length, 21).

    - One row per residue.
    - Column 0 is the D-type flag: 1.0 when the character is lower-case,
      0.0 otherwise.
    - Columns 1..20 are the one-hot encoding of the upper-cased residue
      according to ``AA_to_index``.

    Rows beyond the end of the sequence stay all-zero (padding). Residues
    beyond ``pad_length`` are ignored. A residue that is not in
    ``AA_to_index`` prints a warning and keeps an all-zero one-hot part.
    """
    arr = np.zeros((pad_length, 21), dtype=np.float32)
    # zip with range() stops after pad_length residues, so over-long
    # sequences are silently truncated.
    for row, char in zip(range(pad_length), seq):
        is_d_type = char.islower()
        aa = char.upper() if is_d_type else char
        arr[row, 0] = 1.0 if is_d_type else 0.0
        if aa in AA_to_index:
            arr[row, AA_to_index[aa] + 1] = 1.0
        else:
            print(f"Warning: 氨基酸 {aa} 不在标准列表中")
    return torch.tensor(arr)
def geometric_mean(values):
    """
    Return the geometric mean of ``values`` as a Python float, computed as
    exp(mean(log(values))).
    """
    logs = np.log(np.asarray(values))
    return float(np.exp(np.mean(logs)))
def process_label(ratio, task):
    """
    Turn a MIC ratio into a training label via a log2 transform.

    - task="reg": returns log2(ratio) as ``np.float32``.
    - task="cls": returns 1 when log2(ratio) < 0 (i.e. ratio < 1),
      otherwise 0.

    Returns ``np.nan`` when ``ratio`` is NaN or not strictly positive, so
    callers can filter such samples with ``np.isnan``. (Without the NaN
    check a NaN ratio would compare False against 0 and come out as
    class 0 in "cls" mode.)

    Raises:
        ValueError: if ``task`` is neither "reg" nor "cls" (only reached for
            a valid positive ratio).
    """
    if np.isnan(ratio) or ratio <= 0:
        return np.nan

    ratio_log = np.log2(ratio)
    if task == "reg":
        return np.float32(ratio_log)
    if task == "cls":
        return 1 if ratio_log < 0. else 0
    raise ValueError("未知的 task 类型,请使用 'reg' 或 'cls'")
# --- 数据预处理与构建数据集 ---
def load_data(xlsx_file, condition=None):
    """
    Read an xlsx file and return geometric-mean MIC values grouped by
    prototype sequence.

    Rows are read from the columns "SEQUENCE - Original",
    "SEQUENCE - D-type amino acid substitution" and
    "TARGET ACTIVITY - CONCENTRATION". A row is dropped when either sequence
    is not a string or fails ``is_valid_sequence``. MIC cells are parsed with
    ``parse_mic``; NaN results are discarded before averaging.

    ``condition`` is accepted but not used here.

    Returns:
        dict mapping prototype sequence -> dict mapping variant sequence ->
        geometric mean of that variant's MIC values. Variants with no usable
        MIC value are omitted (their prototype key may map to an empty dict).
    """
    frame = pd.read_excel(xlsx_file)

    collected = {}
    for _, record in frame.iterrows():
        prototype = record["SEQUENCE - Original"]
        mutant = record["SEQUENCE - D-type amino acid substitution"]
        concentration = record["TARGET ACTIVITY - CONCENTRATION"]

        # Both sequences must be strings made of standard residues only.
        if not isinstance(prototype, str) or not is_valid_sequence(prototype):
            continue
        if not isinstance(mutant, str) or not is_valid_sequence(mutant):
            continue

        value = parse_mic(concentration)
        collected.setdefault(prototype, {}).setdefault(mutant, []).append(value)

    # Geometric mean per variant, ignoring unparsable (NaN) readings.
    averaged = {}
    for prototype, per_variant in collected.items():
        means = {}
        for mutant, readings in per_variant.items():
            usable = [v for v in readings if not np.isnan(v)]
            if usable:
                means[mutant] = geometric_mean(usable)
        averaged[prototype] = means
    return averaged
def load_data_stability(xlsx_file, condition):
    """
    Read the stability xlsx file for one experimental condition and return
    geometric-mean activity values grouped by prototype sequence.

    ``condition`` is a short key ('125fbs', '25fbs', 'mhb', 'nacl') that is
    mapped to the text stored in the "Condition" column; only matching rows
    are kept. An unknown key raises ``KeyError``.

    Each row provides a variant in "SEQUENCE" and a value in "Activity".
    The prototype is taken to be the upper-cased variant. Rows whose
    sequence is not a string (e.g. an empty cell) or fails
    ``is_valid_sequence`` are skipped. Values are parsed with ``parse_mic``
    and NaN results are discarded before averaging.

    Returns:
        dict mapping prototype sequence -> dict mapping variant sequence ->
        geometric mean of that variant's parsed values.
    """
    map_dict = {
        '125fbs': '12.5% FBS',
        '25fbs': '25% FBS',
        'mhb': 'MHB',
        'nacl': '150mM NaCl'
    }
    df = pd.read_excel(xlsx_file)
    df = df[df['Condition'] == map_dict[condition]]

    groups = {}
    for _, row in df.iterrows():
        variant = row["SEQUENCE"]
        mic_raw = row["Activity"]
        # Type check must come before .upper(): blank cells are not strings.
        if not isinstance(variant, str):
            continue
        orig = variant.upper()
        if not (is_valid_sequence(orig) and is_valid_sequence(variant)):
            continue
        mic_val = parse_mic(mic_raw)
        groups.setdefault(orig, {}).setdefault(variant, []).append(mic_val)

    # Geometric mean per variant, ignoring unparsable (NaN) readings.
    groups_avg = {}
    for orig, var_dict in groups.items():
        groups_avg[orig] = {}
        for variant, mic_list in var_dict.items():
            usable = [x for x in mic_list if not np.isnan(x)]
            if not usable:
                continue
            groups_avg[orig][variant] = geometric_mean(usable)
    return groups_avg
class PeptidePairDataset(Dataset):
    """
    Dataset of peptide-variant pairs that share a prototype sequence.

    Every item is ``((enc1, enc2), label)`` or, when ``gf`` is enabled,
    ``((enc1, enc2, glob_feat), label)`` where ``enc1`` / ``enc2`` are the
    ``encode_sequence`` tensors of the two variants and ``label`` is
    ``process_label(mic2 / mic1, task)``.
    """

    def __init__(self,
                 mode: Literal['train', 'test', 'r2_case', 'r2_case_',
                               '125fbs', 'nacl', '25fbs', 'mhb'] = 'train',
                 pad_length=30, task="cls",
                 include_reverse=False, include_self=False, one_way=False, gf=False):
        """
        Build the pair list.

        Args:
            mode: which spreadsheet to read. 'train' reads dataset/train.xlsx;
                'test', 'r2_case', 'r2_case_' read dataset/<mode>.xlsx; the
                condition keys '125fbs', 'nacl', '25fbs', 'mhb' read
                dataset/stability.xlsx via ``load_data_stability``. Every mode
                except 'train' forces ``one_way=True``. Defaults to 'train'
                (the ``Literal`` was previously used as the default value
                itself, which always ended in the ValueError below).
            pad_length: variants longer than this are dropped; also the
                number of rows of each encoded tensor.
            task: passed to ``process_label`` ("reg" or "cls").
            include_reverse: also add (B, A) for every (A, B); ignored when
                one_way is active.
            include_self: also add (A, A) with the label of ratio 1; ignored
                when one_way is active.
            one_way: only pair the first variant of each group with the
                others instead of building all pairs.
            gf: append a per-prototype global feature looked up in
                dataset/protbert.pth under the upper-cased prototype.

        Raises:
            ValueError: for an unknown ``mode``.
        """
        dataset_dir = os.path.join(os.path.dirname(__file__), 'dataset')
        if mode == "train":
            loader = load_data
            xlsx_file = os.path.join(dataset_dir, 'train.xlsx')
        elif mode in ("test", "r2_case", "r2_case_"):
            one_way = True
            loader = load_data
            xlsx_file = os.path.join(dataset_dir, f'{mode}.xlsx')
        elif mode in ("125fbs", "nacl", "25fbs", "mhb"):
            one_way = True
            loader = load_data_stability
            xlsx_file = os.path.join(dataset_dir, 'stability.xlsx')
        else:
            raise ValueError("未知的 mode,请使用 'train' 或 'test'")

        self.data = []
        self.pad_length = pad_length
        self.task = task
        groups_avg = loader(xlsx_file, mode)

        gf_dict = None
        if gf:
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            gf_dict = torch.load(os.path.join(dataset_dir, 'protbert.pth'))

        for orig, variant_dict in groups_avg.items():
            # Drop variants longer than pad_length.
            filtered_variants = {variant: mic for variant, mic in variant_dict.items()
                                 if len(variant) <= pad_length}
            variants = list(filtered_variants)
            n_variants = len(variants)
            if n_variants == 0:
                continue

            # Trailing element appended to every input tuple when gf is on.
            extra = (gf_dict[orig.upper()],) if gf else ()
            # Encode each variant once instead of once per pair; samples that
            # mention the same variant share the same tensor object.
            encoded = {variant: encode_sequence(variant, pad_length) for variant in variants}

            # Self pairs (A, A): ratio 1 -> log2 = 0.
            if include_self and (not one_way):
                for variant in variants:
                    label = process_label(1.0, task)
                    self.data.append(((encoded[variant], encoded[variant]) + extra, label))

            # Pairs of distinct variants.
            for i in ([0] if one_way else range(n_variants)):
                for j in range(i + 1, n_variants):
                    var1, var2 = variants[i], variants[j]
                    mic1, mic2 = filtered_variants[var1], filtered_variants[var2]
                    if mic1 == 0:
                        # Ratio undefined; skip explicitly so that "cls" mode
                        # cannot turn a NaN ratio into a class-0 sample.
                        continue
                    # Forward pair (var1, var2): label from mic2 / mic1.
                    label = process_label(mic2 / mic1, task)
                    if np.isnan(label):
                        continue
                    self.data.append(((encoded[var1], encoded[var2]) + extra, label))

                    # Reverse pair (var2, var1): label from mic1 / mic2.
                    if include_reverse and (not one_way):
                        rev_ratio = mic1 / mic2 if mic2 != 0 else np.nan
                        rev_label = process_label(rev_ratio, task)
                        if np.isnan(rev_label):
                            continue
                        self.data.append(((encoded[var2], encoded[var1]) + extra, rev_label))

    def reg_sample_weight(self):
        """
        Return one positive float32 sampling weight per sample, derived from
        the Gaussian density of each label under the labels' mean and std.

        When there are no samples or all labels are identical (std == 0) the
        density is undefined, so uniform weights of 1 are returned instead of
        NaN.
        """
        y = np.array([label for _, label in self.data])
        if y.size == 0:
            return torch.ones(0, dtype=torch.float32)
        mu = np.mean(y)
        sigma = np.std(y)
        if sigma == 0:
            return torch.ones(len(y), dtype=torch.float32)
        p = 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-((y - mu) ** 2) / (2 * sigma ** 2))
        # The median density serves as the reference constant C.
        C = np.median(p)
        epsilon = 1e-6
        weights = np.log(C / (p + epsilon))
        # NOTE(review): mean(weights) can be negative or close to zero, in
        # which case this division flips or inflates the ordering of the
        # weights — confirm this is the intended normalisation.
        weights_normalized = weights / np.mean(weights)
        positive_weights = np.exp(weights_normalized)
        return torch.tensor(positive_weights, dtype=torch.float32)

    def __len__(self):
        """Number of stored pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the stored ``(inputs, label)`` tuple at ``idx``."""
        return self.data[idx]
class PeptidePairPicDataset(Dataset):
    """
    Image-based counterpart of the pair dataset: every variant is rendered
    once with ``draw_peptide`` and items are assembled lazily from the cache.

    Items are ``((img1, img2), label)`` or ``((img1, img2, glob_feat), label)``
    when ``gf`` is enabled. With ``side_enc`` each ``img`` is replaced by the
    tuple ``(img, encode_sequence(seq, pad_length))``.
    """

    def __init__(self,
                 mode: Literal['train', 'test', 'r2_case', 'r2_case_',
                               '125fbs', 'nacl', '25fbs', 'mhb'] = 'train',
                 pad_length=30, task="reg",
                 include_reverse=False, include_self=False, one_way=False, gf=False,
                 side_enc=None, pcs=False, resize=None):
        """
        Build the pair list and render all variant images.

        Args:
            mode: which spreadsheet to read. 'train' reads dataset/train.xlsx;
                'test', 'r2_case', 'r2_case_' read dataset/<mode>.xlsx; the
                condition keys '125fbs', 'nacl', '25fbs', 'mhb' read
                dataset/stability.xlsx via ``load_data_stability``. Every mode
                except 'train' forces ``one_way=True``. Defaults to 'train'
                (the ``Literal`` was previously used as the default value
                itself, which always ended in the ValueError below).
            pad_length: variants longer than this are dropped; also used for
                the optional side encoding.
            task: passed to ``process_label`` ("reg" or "cls").
            include_reverse: also add (B, A) for every (A, B); ignored when
                one_way is active.
            include_self: also add (A, A) with the label of ratio 1; ignored
                when one_way is active.
            one_way: only pair the first variant of each group with the others.
            gf: append a per-prototype global feature looked up in
                dataset/protbert.pth under the upper-cased prototype.
            side_enc: truthy to return the sequence encoding next to each image.
            pcs: forwarded to ``draw_peptide``; the special value 'mix'
                disables it for a variant equal to its prototype.
            resize: canvas size forwarded to ``draw_peptide``; ``None`` uses
                that function's default size.

        Raises:
            ValueError: for an unknown ``mode``.
        """
        dataset_dir = os.path.join(os.path.dirname(__file__), 'dataset')
        if mode == "train":
            loader = load_data
            xlsx_file = os.path.join(dataset_dir, 'train.xlsx')
        elif mode in ("test", "r2_case", "r2_case_"):
            one_way = True
            loader = load_data
            xlsx_file = os.path.join(dataset_dir, f'{mode}.xlsx')
        elif mode in ("125fbs", "nacl", "25fbs", "mhb"):
            one_way = True
            loader = load_data_stability
            xlsx_file = os.path.join(dataset_dir, 'stability.xlsx')
        else:
            raise ValueError("未知的 mode,请使用 'train' 或 'test'")

        self.data = []
        self.pics = {}
        self.pad_length = pad_length
        self.task = task
        self.gf = gf
        self.side_enc = True if side_enc else False
        self.pcs = pcs
        self.resize = resize
        groups_avg = loader(xlsx_file, mode)

        gf_dict = None
        if gf:
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            gf_dict = torch.load(os.path.join(dataset_dir, 'protbert.pth'))

        for orig, variant_dict in groups_avg.items():
            # Drop variants longer than pad_length.
            filtered_variants = {variant: mic for variant, mic in variant_dict.items()
                                 if len(variant) <= pad_length}
            variants = list(filtered_variants)

            # Render every variant once; in 'mix' mode the prototype itself
            # is drawn without cleavage-site highlighting.
            for variant in variants:
                use_pcs = False if (self.pcs == 'mix' and variant == orig) else self.pcs
                self.pics[variant] = self.read_img(variant, use_pcs)

            n_variants = len(variants)
            if n_variants == 0:
                continue

            # Trailing element inserted before the label when gf is on.
            extra = (gf_dict[orig.upper()],) if gf else ()

            # Self pairs (A, A): ratio 1 -> log2 = 0.
            if include_self and (not one_way):
                for variant in variants:
                    label = process_label(1.0, task)
                    self.data.append((variant, variant) + extra + (label,))

            # Pairs of distinct variants.
            for i in ([0] if one_way else range(n_variants)):
                for j in range(i + 1, n_variants):
                    var1, var2 = variants[i], variants[j]
                    mic1, mic2 = filtered_variants[var1], filtered_variants[var2]
                    if mic1 == 0:
                        # Ratio undefined; skip explicitly so that "cls" mode
                        # cannot turn a NaN ratio into a class-0 sample.
                        continue
                    # Forward pair (var1, var2): label from mic2 / mic1.
                    label = process_label(mic2 / mic1, task)
                    if np.isnan(label):
                        continue
                    self.data.append((var1, var2) + extra + (label,))

                    # Reverse pair (var2, var1): label from mic1 / mic2.
                    if include_reverse and (not one_way):
                        rev_ratio = mic1 / mic2 if mic2 != 0 else np.nan
                        rev_label = process_label(rev_ratio, task)
                        if np.isnan(rev_label):
                            continue
                        self.data.append((var2, var1) + extra + (rev_label,))

    def reg_sample_weight(self):
        """
        Return one positive float32 sampling weight per sample, derived from
        the Gaussian density of each label under the labels' mean and std.

        When there are no samples or all labels are identical (std == 0) the
        density is undefined, so uniform weights of 1 are returned instead of
        NaN.
        """
        y = np.array([entry[-1] for entry in self.data])
        if y.size == 0:
            return torch.ones(0, dtype=torch.float32)
        mu = np.mean(y)
        sigma = np.std(y)
        if sigma == 0:
            return torch.ones(len(y), dtype=torch.float32)
        p = 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-((y - mu) ** 2) / (2 * sigma ** 2))
        # The median density serves as the reference constant C.
        C = np.median(p)
        epsilon = 1e-6
        weights = np.log(C / (p + epsilon))
        # NOTE(review): mean(weights) can be negative or close to zero, in
        # which case this division flips or inflates the ordering of the
        # weights — confirm this is the intended normalisation.
        weights_normalized = weights / np.mean(weights)
        positive_weights = np.exp(weights_normalized)
        return torch.tensor(positive_weights, dtype=torch.float32)

    def read_img(self, peptide, pcs):
        """Render ``peptide`` with ``draw_peptide`` at the configured size."""
        if self.resize is None:
            # Let draw_peptide use its own default instead of passing None,
            # which it cannot take the length of.
            return draw_peptide(peptide, pcs=pcs)
        return draw_peptide(peptide, self.resize, pcs)

    def __len__(self):
        """Number of stored pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Assemble the item at ``idx`` from the cached images."""
        if self.gf:
            seq1, seq2, glob_feat, label = self.data[idx]
        else:
            seq1, seq2, label = self.data[idx]
        img1 = self.pics[seq1]
        img2 = self.pics[seq2]
        if self.side_enc:
            img1 = (img1, encode_sequence(seq1, self.pad_length))
            img2 = (img2, encode_sequence(seq2, self.pad_length))
        if self.gf:
            return (img1, img2, glob_feat), label
        return (img1, img2), label
class SimplePairClsDataset(Dataset):
    """
    Binary-labelled (prototype, variant) pairs loaded from a JSON file.

    The JSON maps each prototype sequence to a dict with the keys "1" and
    "0", each holding a list of variant sequences; the key is used as the
    integer label. Items are ``((x1, x2), label)`` or
    ``((x1, x2, glob_feat), label)`` when ``gf`` is enabled.
    """

    def __init__(self, pad_length=30, llm=False, ftr2=False, gf=False,
                 q_encoder=None, side_enc=None, pcs=False, resize=None):
        """
        Args:
            pad_length: prototypes longer than this are skipped with all
                their variants; also used for sequence encoding.
            llm: read dataset/train_set_llm_aug.json.
            ftr2: read dataset/finetune_for_r2_llm.json (only when ``llm`` is
                falsy). Otherwise dataset/train_set.json is read.
            gf: load dataset/protbert.pth and append the entry stored under
                the prototype sequence to every item.
            q_encoder: 'cnn' or 'rn18' selects image inputs rendered with
                ``draw_peptide``; anything else selects ``encode_sequence``.
            side_enc: truthy to return the sequence encoding next to each
                image (image mode only).
            pcs: forwarded to ``draw_peptide``; the special value 'mix'
                disables it for sequences that are entirely upper-case.
            resize: canvas size forwarded to ``draw_peptide``; ``None`` uses
                that function's default size.
        """
        dataset_dir = os.path.join(os.path.dirname(__file__), 'dataset')
        if llm:
            file_name = 'train_set_llm_aug.json'
        elif ftr2:
            file_name = 'finetune_for_r2_llm.json'
        else:
            file_name = 'train_set.json'
        with open(os.path.join(dataset_dir, file_name), 'r', encoding='utf-8') as f:
            dataset = json.load(f)

        self.data = []
        self.pics = {}
        self.pad_length = pad_length
        self.gf = gf
        self.q_encoder = q_encoder
        self.side_enc = True if side_enc else False
        self.pcs = pcs
        self.resize = resize
        if gf:
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            self.gf_dict = torch.load(os.path.join(dataset_dir, 'protbert.pth'))

        all_seqs = []
        for orig, variants in dataset.items():
            if len(orig) > pad_length:
                continue
            all_seqs.append(orig)
            for label in ["1", "0"]:
                for variant in variants[label]:
                    self.data.append((orig, variant, int(label)))
                    all_seqs.append(variant)

        if q_encoder in ['cnn', 'rn18']:
            # The picture depends only on the sequence text, so each distinct
            # sequence is rendered once even if it occurs in several pairs.
            for seq in dict.fromkeys(all_seqs):
                use_pcs = False if (self.pcs == 'mix' and seq.isupper()) else self.pcs
                self.pics[seq] = self.read_img(seq, use_pcs)

    def read_img(self, peptide, pcs):
        """Render ``peptide`` with ``draw_peptide`` at the configured size."""
        if self.resize is None:
            # Let draw_peptide use its own default instead of passing None,
            # which it cannot take the length of.
            return draw_peptide(peptide, pcs=pcs)
        return draw_peptide(peptide, self.resize, pcs)

    def __len__(self):
        """Number of stored pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Assemble the item at ``idx`` as images or sequence encodings."""
        seq1, seq2, label = self.data[idx]
        if self.q_encoder in ['cnn', 'rn18']:
            img1 = self.pics[seq1]
            img2 = self.pics[seq2]
            if self.side_enc:
                img1 = (img1, encode_sequence(seq1, self.pad_length))
                img2 = (img2, encode_sequence(seq2, self.pad_length))
        else:
            img1 = encode_sequence(seq1, self.pad_length)
            img2 = encode_sequence(seq2, self.pad_length)
        if self.gf:
            # NOTE(review): the lookup uses the prototype as written in the
            # JSON, without upper-casing — confirm the keys of protbert.pth.
            return (img1, img2, self.gf_dict[seq1]), label
        return (img1, img2), label
class PeptidePairCaseDataset(Dataset):
    """
    Enumerates every L/D variant of a template peptide and pairs each one
    with the encoded template.

    Each non-glycine residue may be upper- or lower-case; 'G' stays
    upper-case. The all-upper-case combination (the first product element)
    is excluded. Items are ``((template_enc, variant_enc), variant)`` or
    ``((template_enc, variant_enc, glob_feat), variant)`` with ``gf`` — the
    "label" slot carries the variant string itself.
    """

    def __init__(self, case:str ='r2', pad_length=30, gf=False):
        """
        Args:
            case: a known case name ('r2', 'Indolicidin', 'Temporin-A',
                'Melittin', 'Anoplin') or a raw sequence, which is
                upper-cased and stripped.
            pad_length: number of rows of each encoded tensor.
            gf: load dataset/protbert.pth and keep the entry for the template.
        """
        known_templates = {
            'r2': 'KWKIKWPVKWFKML',
            'Indolicidin': 'ILPWKWPWWPWRR',
            'Temporin-A': 'FLPLIGRVLSGIL',
            'Melittin': 'GIGAVLKVLTTGLPALISWIKRKRQQ',
            'Anoplin': 'GLLKRIKTLL',
        }
        if case in known_templates:
            self.template = known_templates[case]
        else:
            self.template = case.upper().strip()

        self.data = []
        self.pad_length = pad_length
        self.gf = gf
        if gf:
            feature_path = os.path.join(os.path.dirname(__file__), 'dataset', 'protbert.pth')
            self.glob_feat = torch.load(feature_path)[self.template]

        # One choice tuple per residue: glycine has a single form.
        pools = [(ch.upper(),) if ch == 'G' else (ch.upper(), ch.lower())
                 for ch in self.template]
        # Cartesian product of all choices, skipping the first element.
        combos = itertools.islice(itertools.product(*pools), 1, None)
        self.variants = [''.join(chars) for chars in combos]
        self.template_seq = encode_sequence(self.template, self.pad_length)

    def __len__(self):
        """Number of enumerated variants."""
        return len(self.variants)

    def __getitem__(self, idx):
        """Return the template/variant encodings with the variant string as label."""
        variant = self.variants[idx]
        inputs = (self.template_seq, encode_sequence(variant, self.pad_length))
        if self.gf:
            inputs = inputs + (self.glob_feat,)
        return inputs, variant
class PeptidePairPicCaseDataset(Dataset):
    """
    Image-based enumeration of every L/D variant of a template peptide.

    Each non-glycine residue may be upper- or lower-case; 'G' stays
    upper-case. The all-upper-case combination (the first product element)
    is excluded. The template image is rendered once; variant images are
    rendered on access. Items are ``((img1, img2), variant)`` or
    ``((img1, img2, glob_feat), variant)`` with ``gf`` — the "label" slot
    carries the variant string itself. With ``side_enc`` each image is
    replaced by ``(img, encode_sequence(seq, pad_length))``.
    """

    def __init__(self, case:str ='r2', pad_length=30, side_enc=None, pcs=False, resize=None, gf=False):
        """
        Args:
            case: a known case name ('r2', 'Indolicidin', 'Temporin-A',
                'Melittin', 'Anoplin') or a raw sequence, which is
                upper-cased and stripped.
            pad_length: number of rows of the optional side encoding.
            side_enc: truthy to return the sequence encoding next to each image.
            pcs: forwarded to ``draw_peptide``.
            resize: canvas size forwarded to ``draw_peptide``; ``None`` uses
                that function's default size.
            gf: load dataset/protbert.pth and keep the entry for the template.
        """
        known_templates = {
            'r2': 'KWKIKWPVKWFKML',
            'Indolicidin': 'ILPWKWPWWPWRR',
            'Temporin-A': 'FLPLIGRVLSGIL',
            'Melittin': 'GIGAVLKVLTTGLPALISWIKRKRQQ',
            'Anoplin': 'GLLKRIKTLL',
        }
        if case in known_templates:
            self.template = known_templates[case]
        else:
            self.template = case.upper().strip()

        self.data = []
        self.pad_length = pad_length
        self.side_enc = True if side_enc else False
        self.pcs = pcs
        self.resize = resize
        self.gf = gf
        if gf:
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            self.glob_feat = torch.load(
                os.path.join(os.path.dirname(__file__), 'dataset', 'protbert.pth'))[self.template]

        # One choice tuple per residue: glycine has a single form.
        pools = [(ch.upper(), ch.lower()) if ch != 'G' else (ch.upper(),) for ch in self.template]
        # Cartesian product of all choices, skipping the first element.
        self.variants = [''.join(chars)
                         for chars in itertools.islice(itertools.product(*pools), 1, None)]
        self.template_pic = self.read_img(self.template)
        if self.side_enc:
            self.template_seq = encode_sequence(self.template, self.pad_length)

    def read_img(self, peptide):
        """Render ``peptide`` with ``draw_peptide`` at the configured size."""
        if self.resize is None:
            # Let draw_peptide use its own default instead of passing None,
            # which it cannot take the length of.
            return draw_peptide(peptide, pcs=self.pcs)
        return draw_peptide(peptide, self.resize, self.pcs)

    def __len__(self):
        """Number of enumerated variants."""
        return len(self.variants)

    def __getitem__(self, idx):
        """Render the variant at ``idx`` and pair it with the template image."""
        variant = self.variants[idx]
        seq2, label = variant, variant
        img1 = self.template_pic
        img2 = self.read_img(variant)
        if self.side_enc:
            img1 = (img1, self.template_seq)
            img2 = (img2, encode_sequence(seq2, self.pad_length))
        if self.gf:
            return (img1, img2, self.glob_feat), label
        return (img1, img2), label
# Side-chain SMILES fragment for each one-letter amino-acid code
# (glycine has no side chain).
aa_side = {
    "A": "C",
    "R": "CCCNC(N)=N",
    "N": "CC(=O)N",
    "D": "CC(=O)O",
    "C": "CS",
    "E": "CCC(=O)O",
    "Q": "CCC(=O)N",
    "G": "",
    "H": "Cc1cnc[nH]1",
    "I": "C(C)CC",
    "L": "CC(C)C",
    "K": "CCCCN",
    "M": "CCSC",
    "F": "Cc1ccccc1",
    "P": "C1CCN1",
    "S": "CO",
    "T": "C(C)O",
    "W": "Cc1c[nH]c2ccccc12",
    "Y": "Cc1ccc(O)cc1",
    "V": "C(C)C",
}

# Residue SMILES templates keyed "<AA>_<L|D>" (inner residue) and
# "<AA>_<L|D>_term" (C-terminal residue). Each template keeps a literal
# "{idx}" placeholder for the atom-map number of the alpha carbon.
aa_tpl = {}
for aa, R in aa_side.items():
    for stereo, chir in (("L", "@"), ("D", "@@")):
        # Glycine is achiral, so its template carries no chirality tag.
        # NOTE(review): "[C:{idx}]" is a bracket atom without an H count —
        # presumably "[CH2:{idx}]" was meant; confirm before changing, since
        # it would alter the rendered images.
        backbone = "N[C:{idx}]C" if aa == "G" else f"N[C{chir}H:{{idx}}]({R})C"
        aa_tpl.update({
            f"{aa}_{stereo}": backbone + "(=O)",         # inner residue
            f"{aa}_{stereo}_term": backbone + "(=O)O",   # C-terminus
        })
def build_peptide_smiles(seq: str) -> str:
    """
    Build the SMILES string of a peptide from its one-letter sequence.

    Upper-case letters select the L template, anything else the D template.
    Each residue's alpha carbon receives its 1-based position as atom-map
    number, and the last residue uses the C-terminal ("_term") template.
    An empty sequence yields an empty string.
    """
    if not seq:
        return ""

    last = len(seq)
    pieces = []
    for position, residue in enumerate(seq, start=1):
        chirality = 'L' if residue.isupper() else 'D'
        suffix = "_term" if position == last else ""
        template = aa_tpl[f"{residue.upper()}_{chirality}{suffix}"]
        pieces.append(template.format(idx=position))
    return "".join(pieces)
# Compiled recognition patterns per protease. The position where a match
# ends is used as the cleavage position by draw_peptide.
protease_patterns = {
    name: re.compile(pattern)
    for name, pattern in (
        ('trypsin', r'(?<=[KR])(?!P)'),
        ('chymotrypsin', r'(?<=[FYWL])(?!P)'),
        ('elastase', r'(?<=[AVSGT])(?!P)'),
        ('enterokinase', r'D{4}K(?=[^P])'),
        ('caspase', r'(?<=D)(?=[GSA])'),
    )
}
def _is_carbonyl_carbon(atom):
    """Return True when ``atom`` has a double bond to an oxygen atom."""
    return any(
        bond.GetBondType() == Chem.BondType.DOUBLE
        and bond.GetOtherAtom(atom).GetSymbol() == "O"
        for bond in atom.GetBonds()
    )


def draw_peptide(sequence, size=(768,), pcs=False):
    """
    Render a peptide as a normalised float32 RGB image tensor of shape
    [3, H, W].

    Bonds attached to the alpha carbon of every lower-case (D-type) residue
    are highlighted in blue. When ``pcs`` is truthy, the peptide bond after
    every position matched by ``protease_patterns`` is highlighted in red
    (overriding blue):
      - trypsin:       (?<=[KR])(?!P)
      - chymotrypsin:  (?<=[FYWL])(?!P)
      - elastase:      (?<=[AVSGT])(?!P)
      - enterokinase:  D{4}K(?=[^P])
      - caspase:       (?<=D)(?=[GSA])
    The patterns are matched against ``sequence`` as given, so they only
    match upper-case letters.

    Args:
        sequence: one-letter peptide sequence (lower-case = D-type residue).
        size: ``(side,)`` for a square canvas or ``(width, height)``; a bare
            int or ``None`` is also accepted (``None`` means 768 x 768).
        pcs: enable cleavage-site highlighting.

    Raises:
        ValueError: if the generated SMILES cannot be parsed.
    """
    # 1. SMILES with atom-map numbers on the alpha carbons.
    smiles = build_peptide_smiles(sequence)
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"could not parse SMILES for peptide {sequence!r}: {smiles!r}")
    AllChem.Compute2DCoords(mol)

    # Bond index -> RGB colour; insertion order doubles as highlight order.
    bond_colors = {}
    # Atom-map number (residue position) -> alpha-carbon atom.
    alpha_carbons = {}
    for atom in mol.GetAtoms():
        map_num = atom.GetAtomMapNum()
        if map_num:
            alpha_carbons.setdefault(map_num, atom)

    # 2. D-type residues: colour every bond of their alpha carbon blue.
    d_positions = {i for i, aa in enumerate(sequence, start=1) if aa.islower()}
    for position, ca in alpha_carbons.items():
        if position in d_positions:
            for bond in ca.GetBonds():
                bond_colors[bond.GetIdx()] = (0.0, 0.0, 1.0)

    # 3. Cleavage sites: colour the peptide bond red (overrides blue).
    if pcs:
        cleavage_sites = set()
        for pat in protease_patterns.values():
            for m in pat.finditer(sequence):
                cut = m.end()  # cleavage happens after residue number `cut`
                if 1 <= cut < len(sequence):
                    cleavage_sites.add(cut)

        for pos in sorted(cleavage_sites):
            ca = alpha_carbons.get(pos)
            if ca is None:
                continue
            # Carbonyl carbon of the same residue: a carbon neighbour of the
            # alpha carbon that carries a C=O double bond.
            carbonyl_c = next(
                (nb for nb in ca.GetNeighbors()
                 if nb.GetSymbol() == "C" and _is_carbonyl_carbon(nb)),
                None,
            )
            if carbonyl_c is None:
                continue
            # Its bond to a nitrogen is the peptide bond to the next residue.
            peptide_bond = next(
                (b for b in carbonyl_c.GetBonds()
                 if b.GetOtherAtom(carbonyl_c).GetSymbol() == "N"),
                None,
            )
            if peptide_bond is None:
                continue
            bond_colors[peptide_bond.GetIdx()] = (1.0, 0.0, 0.0)

    highlight_bonds = list(bond_colors)

    # 4. Canvas size.
    if size is None:
        size = (768,)
    elif isinstance(size, int):
        size = (size,)
    if len(size) == 1:
        w = h = size[0]
    else:
        w, h = size

    # 5. Draw with per-bond highlight colours.
    drawer = rdMolDraw2D.MolDraw2DCairo(w, h)
    drawer.DrawMolecule(
        mol,
        highlightAtoms=[],
        highlightBonds=highlight_bonds,
        highlightAtomColors={},
        highlightBondColors=bond_colors
    )
    drawer.FinishDrawing()

    # 6. PNG bytes -> tensor.
    png_bytes = bytearray(drawer.GetDrawingText())
    byte_tensor = torch.frombuffer(png_bytes, dtype=torch.uint8)
    img = tvio.decode_png(byte_tensor, mode=tvio.ImageReadMode.RGB)  # [3, H, W], uint8
    # NOTE(review): to_dtype is called without scale=True, so the values
    # presumably stay in 0..255 while the normalisation constants below are
    # the usual ImageNet statistics for 0..1 inputs — confirm against the
    # trained models before changing.
    img = tvtF.to_dtype(img, torch.float32)
    img = tvtF.normalize(img, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    return img
if __name__ == '__main__':
    # Smoke test: build the 'r2_case' classification dataset with
    # pad_length 30 (one-way pairs, no reverse or self pairs) and print a
    # short summary.
    pad_length = 30
    dataset = PeptidePairDataset('r2_case', pad_length, "cls", include_reverse=False, include_self=False, one_way=True)

    if len(dataset) == 0:
        print("未读入组合数据!")
    else:
        # Show the first sample.
        (encoded_seq1, encoded_seq2), ratio = dataset[0]
        print("第一个样本:")
        print("变种1的编码张量形状:", encoded_seq1.shape)
        print("变种2的编码张量形状:", encoded_seq2.shape)
        print("标签比值(变种2/变种1):", ratio)
        print(f"数据集大小:{len(dataset)}")
        # Sum of the labels, i.e. the number of positive samples in cls mode.
        label_pos = sum(i for (_, _), i in dataset)
        print(label_pos)
# # 测试 PeptidesDataset
# pad_length = 30
# dataset = PeptidesDataset(xlsx_file="./dataset/train.xlsx", pad_length=pad_length)
# print(f"PeptidesDataset 样本总数: {len(dataset)}")
# if len(dataset) > 0:
# encoded_seq, label = dataset[0]
# print("第一个样本:")
# print("多肽编码张量形状:", encoded_seq.shape)
# print("标签浓度值(几何平均后):", label)
# else:
# print("未读取到有效数据!")