|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import itertools |
|
|
import torch |
|
|
from torch.utils.data import Dataset |
|
|
import re |
|
|
import json |
|
|
from typing import Literal |
|
|
import os |
|
|
|
|
|
from rdkit import Chem |
|
|
from rdkit.Chem import AllChem |
|
|
from rdkit.Chem.Draw import rdMolDraw2D |
|
|
|
|
|
import torchvision.io as tvio |
|
|
|
|
|
import torchvision.transforms.v2.functional as tvtF |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# One-letter codes of the 20 standard amino acids, in alphabetical order.
AMINO_ACIDS = list('ACDEFGHIKLMNPQRSTVWY')

# Residue letter -> column offset used by the one-hot part of encode_sequence().
AA_to_index = dict(zip(AMINO_ACIDS, range(len(AMINO_ACIDS))))

# Set form for O(1) membership checks during sequence validation.
valid_aa = set(AMINO_ACIDS)


def is_valid_sequence(seq):
    """Return True iff *seq* contains only standard amino-acid letters.

    Lowercase letters (denoting D-form residues) are accepted as valid;
    any digit, punctuation, or non-standard residue letter is rejected.
    An empty sequence is considered valid.
    """
    return all(ch.isalpha() and ch.upper() in valid_aa for ch in seq)
|
|
|
|
|
def parse_mic(mic_str):
    """
    Parse a MIC value. Supported string formats (whitespace is ignored):

    1. Plain number, e.g. "5"                -> 5.0
    2. ">{n}" / "≥{n}" (also ">={n}")        -> n * 1.5
    3. "<{n}" / "≤{n}" (also "<={n}")        -> n * 0.75
    4. Mean ± std, e.g. "3.2 ± 0.4"          -> the mean, 3.2
    5. Range, e.g. "2.0 - 4.0"               -> midpoint, (2.0 + 4.0) / 2

    Non-string input is simply converted with float(). Unparseable strings
    yield np.nan after printing a warning.
    """
    if not isinstance(mic_str, str):
        return float(mic_str)

    mic_str = re.sub(r'\s+', '', mic_str.strip())
    # Generalization: normalize ASCII two-character comparators to the unicode
    # forms the character classes below expect (backward compatible — the
    # original only accepted "≥"/"≤").
    mic_str = mic_str.replace('>=', '≥').replace('<=', '≤')

    # Plain number.
    if re.fullmatch(r'\d+(\.\d+)?', mic_str):
        return float(mic_str)

    # Greater-than / greater-or-equal: scale up by 1.5.
    m = re.fullmatch(r'[>≥](\d+(\.\d+)?)', mic_str)
    if m:
        return float(m.group(1)) * 1.5

    # Less-than / less-or-equal: scale down by 0.75.
    m = re.fullmatch(r'[<≤](\d+(\.\d+)?)', mic_str)
    if m:
        return float(m.group(1)) * 0.75

    # Mean ± std: keep the mean only.
    m = re.fullmatch(r'(\d+(\.\d+)?)[±](\d+(\.\d+)?)', mic_str)
    if m:
        return float(m.group(1))

    # Range: return the midpoint (group(3) is the second number; group(2)
    # is the first number's optional decimal part).
    m = re.fullmatch(r'(\d+(\.\d+)?)-(\d+(\.\d+)?)', mic_str)
    if m:
        return (float(m.group(1)) + float(m.group(3))) / 2.0

    print(f"Warning: 无法解析 MIC 值 {mic_str}")
    return np.nan
|
|
|
|
|
def encode_sequence(seq, pad_length):
    """Encode a peptide as a fixed-size (pad_length, 21) float32 tensor.

    Per residue row:
    - column 0 flags a D-form residue (1.0 when the letter is lowercase);
    - columns 1-20 one-hot encode the residue identity (case-insensitive).

    Rows beyond len(seq) remain all-zero padding; residues past pad_length
    are dropped. Unknown residues get only the D-flag column and a warning.
    """
    encoding = np.zeros((pad_length, 21), dtype=np.float32)

    for pos, residue in enumerate(seq[:pad_length]):
        is_d_form = residue.islower()
        encoding[pos, 0] = 1.0 if is_d_form else 0.0

        letter = residue.upper() if is_d_form else residue
        if letter in AA_to_index:
            encoding[pos, AA_to_index[letter] + 1] = 1.0
        else:
            print(f"Warning: 氨基酸 {letter} 不在标准列表中")

    return torch.tensor(encoding)
|
|
|
|
|
def geometric_mean(values):
    """Return the geometric mean of *values* (assumed positive) as a float."""
    arr = np.asarray(values, dtype=float)
    return float(np.exp(np.log(arr).mean()))
|
|
|
|
|
def process_label(ratio, task):
    """
    Map a MIC ratio to a training label via a log2 transform.

    - task="reg": return log2(ratio) as np.float32.
    - task="cls": return 1 if log2(ratio) < 0 (i.e. ratio < 1), else 0.
      NOTE(review): the original docstring described a "<= -0.5" threshold,
      but the code compares against 0 — confirm which is intended.
    - Non-positive ratios return np.nan so callers can skip the sample.

    Raises ValueError for any other task string.
    """
    if ratio <= 0:
        return np.nan
    ratio_log = np.log2(ratio)
    if task == "reg":
        return np.float32(ratio_log)
    elif task == "cls":
        # Label 1 marks ratios strictly below 1 (negative log2 ratio).
        if ratio_log < 0.:
            return 1
        else:
            return 0
    else:
        raise ValueError("未知的 task 类型,请使用 'reg' 或 'cls'")
|
|
|
|
|
|
|
|
|
|
|
def load_data(xlsx_file, condition=None):
    """Read MIC measurements from *xlsx_file* and group them by prototype.

    Rows whose prototype or variant sequence contains non-alphabetic or
    non-standard amino-acid characters are skipped. Repeated measurements of
    the same (prototype, variant) are reduced to their geometric mean,
    ignoring values that failed to parse (NaN).

    Returns:
        dict mapping prototype sequence -> {variant sequence -> geometric-mean MIC}.
        A prototype whose variants all failed to parse maps to an empty dict.

    *condition* is unused here; it exists so this loader shares a call
    signature with load_data_stability (both are invoked as loader(path, mode)).
    """
    df = pd.read_excel(xlsx_file)

    # Collect every raw MIC reading per (prototype, variant) pair.
    groups = {}
    for _, row in df.iterrows():
        orig = row["SEQUENCE - Original"]
        variant = row["SEQUENCE - D-type amino acid substitution"]
        mic_raw = row["TARGET ACTIVITY - CONCENTRATION"]

        if not (isinstance(orig, str) and is_valid_sequence(orig)):
            continue
        if not (isinstance(variant, str) and is_valid_sequence(variant)):
            continue

        groups.setdefault(orig, {}).setdefault(variant, []).append(parse_mic(mic_raw))

    # Reduce each variant's readings to a geometric mean, dropping NaNs.
    groups_avg = {}
    for orig, per_variant in groups.items():
        averaged = {}
        for variant, mics in per_variant.items():
            finite = [m for m in mics if not np.isnan(m)]
            if finite:
                averaged[variant] = geometric_mean(finite)
        groups_avg[orig] = averaged
    return groups_avg
|
|
|
|
|
|
|
|
def load_data_stability(xlsx_file, condition):
    """Read stability-assay MIC measurements for one condition from *xlsx_file*.

    *condition* is a short key ('125fbs', '25fbs', 'mhb', 'nacl') mapped to
    the human-readable value stored in the sheet's 'Condition' column; only
    matching rows are kept. The prototype is the variant sequence uppercased
    (the all-L form). Invalid sequences are skipped, and repeated MIC readings
    per (prototype, variant) are reduced to their geometric mean (NaNs dropped).

    Returns:
        dict mapping prototype sequence -> {variant sequence -> geometric-mean MIC}.
    """
    condition_names = {
        '125fbs': '12.5% FBS',
        '25fbs': '25% FBS',
        'mhb': 'MHB',
        'nacl': '150mM NaCl'
    }
    df = pd.read_excel(xlsx_file)
    df = df[df['Condition'] == condition_names[condition]]

    # Collect every raw MIC reading per (prototype, variant) pair.
    groups = {}
    for _, row in df.iterrows():
        variant = row["SEQUENCE"]
        orig = variant.upper()
        mic_raw = row["Activity"]

        if not (isinstance(orig, str) and is_valid_sequence(orig)):
            continue
        if not (isinstance(variant, str) and is_valid_sequence(variant)):
            continue

        groups.setdefault(orig, {}).setdefault(variant, []).append(parse_mic(mic_raw))

    # Reduce each variant's readings to a geometric mean, dropping NaNs.
    groups_avg = {}
    for orig, per_variant in groups.items():
        averaged = {}
        for variant, mics in per_variant.items():
            finite = [m for m in mics if not np.isnan(m)]
            if finite:
                averaged[variant] = geometric_mean(finite)
        groups_avg[orig] = averaged
    return groups_avg
|
|
|
|
|
class PeptidePairDataset(Dataset):
    """Pairs of encoded peptide variants sharing a prototype sequence,
    labelled with the (log2-transformed) ratio of their MIC values."""

    def __init__(self, mode: Literal['train', 'test', '125fbs', 'nacl', '25fbs', 'mhb'] = 'train',
                 pad_length=30, task="cls",
                 include_reverse=False, include_self=False, one_way=False, gf=False):
        """
        Build the dataset.

        - Reads rows from an xlsx file and groups variants by prototype;
          rows with non-standard amino acids or non-alphabetic characters are
          filtered out, as are variants longer than pad_length.
        - Distinct variants of the same prototype are paired.
        - include_reverse: also add (B, A) for every pair (A, B) (ignored when
          one_way is in effect).
        - include_self: also add (A, A) with label log2(1) = 0 (ignored when
          one_way is in effect).
        - task: "reg" yields float32 log2 ratios; "cls" yields integer labels,
          1 if the log2 ratio is negative, else 0 (see process_label).
        - one_way: only pair each prototype's first variant with the others;
          forced True for all evaluation modes.
        - gf: attach a per-prototype global feature loaded from protbert.pth.

        Each item is ((enc_var1, enc_var2[, glob_feat]), label), where each
        encoding has shape (pad_length, 21).

        Bug fix: the original signature used `mode=Literal[...]`, which made
        the typing object itself the *default value* (and crashed in the mode
        dispatch below); the intended annotation-with-default form is used
        instead, so omitting mode now selects 'train'.
        """
        if mode == "train":
            loader = load_data
            xlsx_file = os.path.join(os.path.dirname(__file__), 'dataset', 'train.xlsx')
        elif mode in ["test", "r2_case", 'r2_case_', "125fbs", "nacl", "25fbs", "mhb"]:
            # Evaluation modes always pair the first variant against the rest.
            one_way = True
            if mode in ["test", "r2_case", 'r2_case_']:
                loader = load_data
                xlsx_file = os.path.join(os.path.dirname(__file__), 'dataset', f'{mode}.xlsx')
            else:
                loader = load_data_stability
                xlsx_file = os.path.join(os.path.dirname(__file__), 'dataset', 'stability.xlsx')
        else:
            raise ValueError("未知的 mode,请使用 'train' 或 'test'")

        self.data = []
        self.pad_length = pad_length
        self.task = task
        groups_avg = loader(xlsx_file, mode)
        if gf:
            gf_dict = torch.load(os.path.join(os.path.dirname(__file__), 'dataset', 'protbert.pth'))

        for orig, variant_dict in groups_avg.items():
            # Drop variants that would be truncated by the fixed-size encoding.
            filtered_variants = {variant: mic for variant, mic in variant_dict.items()
                                 if len(variant) <= pad_length}
            variants = list(filtered_variants.keys())
            n_variants = len(variants)
            if n_variants == 0:
                continue

            if gf:
                glob_feat = gf_dict[orig.upper()]

            # Self pairs: (A, A) with ratio 1 (label log2(1) = 0).
            if include_self and (not one_way):
                for variant in variants:
                    encoded_seq = encode_sequence(variant, pad_length)
                    label = process_label(1.0, task)
                    if gf:
                        self.data.append(((encoded_seq, encoded_seq, glob_feat), label))
                    else:
                        self.data.append(((encoded_seq, encoded_seq), label))

            for i in [0] if one_way else range(n_variants):
                for j in range(i + 1, n_variants):
                    var1 = variants[i]
                    var2 = variants[j]
                    mic1 = filtered_variants[var1]
                    mic2 = filtered_variants[var2]

                    # Label is derived from MIC(var2) / MIC(var1); NaN labels
                    # (unparseable or non-positive ratios) are skipped.
                    ratio = mic2 / mic1 if mic1 != 0 else np.nan
                    label = process_label(ratio, task)
                    if np.isnan(label):
                        continue
                    encoded_var1 = encode_sequence(var1, pad_length)
                    encoded_var2 = encode_sequence(var2, pad_length)
                    if gf:
                        self.data.append(((encoded_var1, encoded_var2, glob_feat), label))
                    else:
                        self.data.append(((encoded_var1, encoded_var2), label))

                    # Reverse pair uses the inverse ratio.
                    if include_reverse and (not one_way):
                        rev_ratio = mic1 / mic2 if mic2 != 0 else np.nan
                        rev_label = process_label(rev_ratio, task)
                        if gf:
                            self.data.append(((encoded_var2, encoded_var1, glob_feat), rev_label))
                        else:
                            self.data.append(((encoded_var2, encoded_var1), rev_label))

    def reg_sample_weight(self):
        """Return a strictly positive weight per sample that up-weights rare
        regression labels.

        A Gaussian density is fitted to the labels; weights are
        log(C / (p + eps)) with C the median density, mean-normalized and
        exponentiated so every weight is positive.
        """
        y = np.array([label for _, label in self.data])
        mu = np.mean(y)
        sigma = np.std(y)
        # Density of each label under the fitted N(mu, sigma^2).
        p = 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-((y - mu) ** 2) / (2 * sigma ** 2))

        C = np.median(p)
        epsilon = 1e-6  # guards against division by zero for extreme labels

        weights = np.log(C / (p + epsilon))
        weights_normalized = weights / np.mean(weights)
        positive_weights = np.exp(weights_normalized)

        return torch.tensor(positive_weights, dtype=torch.float32)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
|
|
|
|
|
|
|
|
class PeptidePairPicDataset(Dataset):
    """Variant pairs labelled by MIC log2 ratios, returned as rendered
    structure images (optionally alongside their sequence encodings)."""

    def __init__(self, mode: Literal['train', 'test', '125fbs', 'nacl', '25fbs', 'mhb'] = 'train',
                 pad_length=30, task="reg",
                 include_reverse=False, include_self=False, one_way=False, gf=False,
                 side_enc=None, pcs=False, resize=None):
        """
        Build the dataset. Grouping, pairing and labelling follow the same
        rules as the sequence-tensor dataset in this module:

        - rows with invalid sequences or variants longer than pad_length are
          dropped; distinct variants of the same prototype are paired;
        - include_reverse / include_self add (B, A) / (A, A) pairs (both are
          ignored when one_way is in effect);
        - task "reg" yields float32 log2 ratios, "cls" yields integer labels
          (1 if the log2 ratio is negative, else 0 — see process_label);
        - gf attaches a per-prototype global feature from protbert.pth.

        Image-specific parameters:
        - side_enc: if truthy, each image is returned together with the
          (pad_length, 21) sequence encoding.
        - pcs: highlight protease cleavage sites in rendered images; the
          special value 'mix' disables highlighting for the prototype itself.
        - resize: image size forwarded to draw_peptide.

        Items are stored as sequence strings; images are pre-rendered once per
        variant into self.pics and looked up in __getitem__.

        Bug fix: the original signature used `mode=Literal[...]`, which made
        the typing object itself the *default value*; mode now defaults to
        'train' with a proper annotation.
        """
        if mode == "train":
            loader = load_data
            xlsx_file = os.path.join(os.path.dirname(__file__), 'dataset', 'train.xlsx')
        elif mode in ["test", "r2_case", 'r2_case_', "125fbs", "nacl", "25fbs", "mhb"]:
            # Evaluation modes always pair the first variant against the rest.
            one_way = True
            if mode in ["test", "r2_case", 'r2_case_']:
                loader = load_data
                xlsx_file = os.path.join(os.path.dirname(__file__), 'dataset', f'{mode}.xlsx')
            else:
                loader = load_data_stability
                xlsx_file = os.path.join(os.path.dirname(__file__), 'dataset', 'stability.xlsx')
        else:
            raise ValueError("未知的 mode,请使用 'train' 或 'test'")

        self.data = []
        self.pics = {}
        self.pad_length = pad_length
        self.task = task
        self.gf = gf
        self.side_enc = True if side_enc else False
        self.pcs = pcs
        self.resize = resize
        groups_avg = loader(xlsx_file, mode)
        if gf:
            gf_dict = torch.load(os.path.join(os.path.dirname(__file__), 'dataset', 'protbert.pth'))

        for orig, variant_dict in groups_avg.items():
            filtered_variants = {variant: mic for variant, mic in variant_dict.items()
                                 if len(variant) <= pad_length}
            variants = list(filtered_variants.keys())
            # Pre-render every variant's structure image once; with pcs='mix'
            # the prototype itself is drawn without cleavage-site highlights.
            for variant in variants:
                if self.pcs == 'mix' and variant == orig:
                    self.pics[variant] = self.read_img(variant, False)
                else:
                    self.pics[variant] = self.read_img(variant, self.pcs)
            n_variants = len(variants)
            if n_variants == 0:
                continue

            if gf:
                glob_feat = gf_dict[orig.upper()]

            # Self pairs: (A, A) with ratio 1 (label log2(1) = 0).
            if include_self and (not one_way):
                for variant in variants:
                    label = process_label(1.0, task)
                    if gf:
                        self.data.append((variant, variant, glob_feat, label))
                    else:
                        self.data.append((variant, variant, label))

            for i in [0] if one_way else range(n_variants):
                for j in range(i + 1, n_variants):
                    var1 = variants[i]
                    var2 = variants[j]
                    mic1 = filtered_variants[var1]
                    mic2 = filtered_variants[var2]

                    # Label is derived from MIC(var2) / MIC(var1); NaN labels are skipped.
                    ratio = mic2 / mic1 if mic1 != 0 else np.nan
                    label = process_label(ratio, task)
                    if np.isnan(label):
                        continue
                    if gf:
                        self.data.append((var1, var2, glob_feat, label))
                    else:
                        self.data.append((var1, var2, label))

                    # Reverse pair uses the inverse ratio.
                    if include_reverse and (not one_way):
                        rev_ratio = mic1 / mic2 if mic2 != 0 else np.nan
                        rev_label = process_label(rev_ratio, task)
                        if gf:
                            self.data.append((var2, var1, glob_feat, rev_label))
                        else:
                            self.data.append((var2, var1, rev_label))

    def reg_sample_weight(self):
        """Return strictly positive density-based sample weights for the
        regression labels (Gaussian fit; log(C/p) mean-normalized and
        exponentiated)."""
        y = np.array([d[-1] for d in self.data])
        mu = np.mean(y)
        sigma = np.std(y)
        # Density of each label under the fitted N(mu, sigma^2).
        p = 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-((y - mu) ** 2) / (2 * sigma ** 2))

        C = np.median(p)
        epsilon = 1e-6  # guards against division by zero for extreme labels

        weights = np.log(C / (p + epsilon))
        weights_normalized = weights / np.mean(weights)
        positive_weights = np.exp(weights_normalized)

        return torch.tensor(positive_weights, dtype=torch.float32)

    def read_img(self, peptide, pcs):
        """Render *peptide* to a normalized image tensor via draw_peptide."""
        image = draw_peptide(peptide, self.resize, pcs)
        return image

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if self.gf:
            seq1, seq2, glob_feat, label = self.data[idx]
        else:
            seq1, seq2, label = self.data[idx]
        img1 = self.pics[seq1]
        img2 = self.pics[seq2]

        if self.side_enc:
            img1 = (img1, encode_sequence(seq1, self.pad_length))
            img2 = (img2, encode_sequence(seq2, self.pad_length))

        if self.gf:
            return (img1, img2, glob_feat), label
        else:
            return (img1, img2), label
|
|
|
|
|
|
|
|
class SimplePairClsDataset(Dataset):
    """Binary-classification pairs (prototype, variant, label) read from a
    JSON training set; items are returned as images or sequence encodings
    depending on q_encoder."""

    def __init__(self, pad_length=30, llm=False, ftr2=False, gf=False,
                 q_encoder=None, side_enc=None, pcs=False, resize=None):
        base_dir = os.path.join(os.path.dirname(__file__), 'dataset')

        # Pick the JSON source: LLM-augmented set, R2 finetune set, or default.
        if llm:
            json_name = 'train_set_llm_aug.json'
        elif ftr2:
            json_name = 'finetune_for_r2_llm.json'
        else:
            json_name = 'train_set.json'

        with open(os.path.join(base_dir, json_name), 'r', encoding='utf-8') as fh:
            dataset = json.load(fh)

        self.data = []
        self.pics = {}
        self.pad_length = pad_length
        self.gf = gf
        self.q_encoder = q_encoder
        self.side_enc = bool(side_enc)
        self.pcs = pcs
        self.resize = resize
        if gf:
            self.gf_dict = torch.load(os.path.join(base_dir, 'protbert.pth'))

        # Flatten the JSON structure {prototype: {"1": [...], "0": [...]}}
        # into (prototype, variant, label) triples; prototypes longer than
        # pad_length are dropped along with all their variants.
        all_seqs = []
        for orig, variants in dataset.items():
            if len(orig) > pad_length:
                continue
            all_seqs.append(orig)
            for label in ("1", "0"):
                for variant in variants[label]:
                    self.data.append((orig, variant, int(label)))
                    all_seqs.append(variant)

        # Image-based encoders need the structures pre-rendered; with
        # pcs='mix', all-uppercase (all-L) sequences skip the highlighting.
        if q_encoder in ('cnn', 'rn18'):
            for seq in all_seqs:
                use_pcs = False if (self.pcs == 'mix' and seq.isupper()) else self.pcs
                self.pics[seq] = self.read_img(seq, use_pcs)

    def read_img(self, peptide, pcs):
        """Render *peptide* to a normalized image tensor via draw_peptide."""
        return draw_peptide(peptide, self.resize, pcs)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        seq1, seq2, label = self.data[idx]
        if self.q_encoder in ('cnn', 'rn18'):
            item1, item2 = self.pics[seq1], self.pics[seq2]
            if self.side_enc:
                item1 = (item1, encode_sequence(seq1, self.pad_length))
                item2 = (item2, encode_sequence(seq2, self.pad_length))
        else:
            item1 = encode_sequence(seq1, self.pad_length)
            item2 = encode_sequence(seq2, self.pad_length)

        if self.gf:
            return (item1, item2, self.gf_dict[seq1]), label
        return (item1, item2), label
|
|
|
|
|
|
|
|
class PeptidePairCaseDataset(Dataset):
    """Pairs a fixed template peptide with every one of its single-to-full
    D-substitution variants; the 'label' returned is the variant string."""

    def __init__(self, case: str = 'r2', pad_length=30, gf=False):
        # Known template peptides by name; anything else is treated as a raw
        # sequence (uppercased/stripped).
        named_templates = {
            'r2': 'KWKIKWPVKWFKML',
            'Indolicidin': 'ILPWKWPWWPWRR',
            'Temporin-A': 'FLPLIGRVLSGIL',
            'Melittin': 'GIGAVLKVLTTGLPALISWIKRKRQQ',
            'Anoplin': 'GLLKRIKTLL',
        }
        self.template = named_templates.get(case) or case.upper().strip()
        self.data = []
        self.pad_length = pad_length
        self.gf = gf

        if gf:
            self.glob_feat = torch.load(
                os.path.join(os.path.dirname(__file__), 'dataset', 'protbert.pth')
            )[self.template]

        # Each non-glycine residue may stay L (uppercase) or flip to D
        # (lowercase); glycine is achiral, so it has a single option.
        per_residue = [(c.upper(),) if c == 'G' else (c.upper(), c.lower())
                       for c in self.template]

        # itertools.product yields the all-L template first; drop it so only
        # genuine variants remain.
        self.variants = [''.join(combo) for combo in itertools.product(*per_residue)][1:]

        self.template_seq = encode_sequence(self.template, self.pad_length)

    def __len__(self):
        return len(self.variants)

    def __getitem__(self, idx):
        variant = self.variants[idx]
        enc_template = self.template_seq
        enc_variant = encode_sequence(variant, self.pad_length)

        if self.gf:
            return (enc_template, enc_variant, self.glob_feat), variant
        return (enc_template, enc_variant), variant
|
|
|
|
|
|
|
|
|
|
|
class PeptidePairPicCaseDataset(Dataset):
    """Pairs a fixed template peptide's rendered structure image with the
    image of each D-substitution variant; the 'label' is the variant string."""

    def __init__(self, case: str = 'r2', pad_length=30, side_enc=None, pcs=False, resize=None, gf=False):
        # Known template peptides by name; anything else is treated as a raw
        # sequence (uppercased/stripped).
        named_templates = {
            'r2': 'KWKIKWPVKWFKML',
            'Indolicidin': 'ILPWKWPWWPWRR',
            'Temporin-A': 'FLPLIGRVLSGIL',
            'Melittin': 'GIGAVLKVLTTGLPALISWIKRKRQQ',
            'Anoplin': 'GLLKRIKTLL',
        }
        self.template = named_templates.get(case) or case.upper().strip()
        self.data = []
        self.pad_length = pad_length
        self.side_enc = bool(side_enc)
        self.pcs = pcs
        self.resize = resize
        self.gf = gf

        if gf:
            self.glob_feat = torch.load(
                os.path.join(os.path.dirname(__file__), 'dataset', 'protbert.pth')
            )[self.template]

        # Each non-glycine residue may stay L (uppercase) or flip to D
        # (lowercase); glycine is achiral, so it has a single option.
        per_residue = [(c.upper(),) if c == 'G' else (c.upper(), c.lower())
                       for c in self.template]

        # itertools.product yields the all-L template first; drop it.
        self.variants = [''.join(combo) for combo in itertools.product(*per_residue)][1:]

        self.template_pic = self.read_img(self.template)
        if self.side_enc:
            self.template_seq = encode_sequence(self.template, self.pad_length)

    def read_img(self, peptide):
        """Render *peptide* to a normalized image tensor via draw_peptide."""
        return draw_peptide(peptide, self.resize, self.pcs)

    def __len__(self):
        return len(self.variants)

    def __getitem__(self, idx):
        variant = self.variants[idx]
        item1 = self.template_pic
        item2 = self.read_img(variant)

        if self.side_enc:
            item1 = (item1, self.template_seq)
            item2 = (item2, encode_sequence(variant, self.pad_length))

        if self.gf:
            return (item1, item2, self.glob_feat), variant
        return (item1, item2), variant
|
|
|
|
|
|
|
|
# SMILES fragment of each amino acid's side chain as attached to the
# alpha-carbon (glycine has no side chain, hence the empty string).
aa_side = {
    "A": "C", "R": "CCCNC(N)=N", "N": "CC(=O)N", "D": "CC(=O)O", "C": "CS",
    "E": "CCC(=O)O", "Q": "CCC(=O)N", "G": "", "H": "Cc1cnc[nH]1", "I": "C(C)CC",
    "L": "CC(C)C", "K": "CCCCN", "M": "CCSC", "F": "Cc1ccccc1", "P": "C1CCN1",
    "S": "CO", "T": "C(C)O", "W": "Cc1c[nH]c2ccccc12", "Y": "Cc1ccc(O)cc1", "V": "C(C)C"
}


# Residue SMILES templates keyed by "<AA>_<L|D>" (internal residue) and
# "<AA>_<L|D>_term" (C-terminal residue, free acid). The literal "{idx}"
# placeholder is later filled with the residue position, putting that atom
# map number on the alpha-carbon.
aa_tpl = {}
for aa, R in aa_side.items():
    for stereo, chir in (("L", "@"), ("D", "@@")):
        if aa == "G":
            core = "N[C:{idx}]C"  # achiral: no stereo tag, no side chain
        else:
            core = "N[C" + chir + "H:{idx}](" + R + ")C"
        aa_tpl[f"{aa}_{stereo}"] = core + "(=O)"
        aa_tpl[f"{aa}_{stereo}_term"] = core + "(=O)O"


def build_peptide_smiles(seq: str) -> str:
    """Return an atom-mapped SMILES string for the peptide *seq*.

    Uppercase letters are L-residues, lowercase are D-residues. Residue
    number i (1-based) becomes the atom map number of that residue's
    alpha-carbon. An empty sequence yields an empty string.
    """
    if not seq:
        return ""

    pieces = []
    last = len(seq)
    for pos, letter in enumerate(seq, start=1):
        key = f"{letter.upper()}_{'L' if letter.isupper() else 'D'}"
        if pos == last:
            key += "_term"  # last residue carries the free carboxylic acid
        pieces.append(aa_tpl[key].format(idx=pos))
    return "".join(pieces)
|
|
|
|
|
# Regex patterns locating protease cleavage points (the P1--P1' boundary).
# All but 'enterokinase' are zero-width (lookbehind/lookahead) and match at
# the cut position itself; 'enterokinase' matches the DDDDK recognition motif.
# Patterns use uppercase letters only, so lowercase (D-form) residues never
# match — presumably intentional since D-amino acids resist proteolysis
# (NOTE(review): confirm).
protease_patterns = {
    'trypsin': re.compile(r'(?<=[KR])(?!P)'),
    'chymotrypsin': re.compile(r'(?<=[FYWL])(?!P)'),
    'elastase': re.compile(r'(?<=[AVSGT])(?!P)'),
    'enterokinase': re.compile(r'D{4}K(?=[^P])'),
    'caspase': re.compile(r'(?<=D)(?=[GSA])'),
}
|
|
|
|
|
def draw_peptide(sequence, size=(768,), pcs=False):
    """
    Render a 2D structure image of *sequence* as a normalized float tensor.

    Bonds of each D-form residue's alpha-carbon (lowercase letters) are
    highlighted in blue. When *pcs* is truthy, peptide bonds at predicted
    protease cleavage sites are highlighted in red, using these P1--P1'
    patterns from protease_patterns:
      - trypsin:       (?<=[KR])(?!P)
      - chymotrypsin:  (?<=[FYWL])(?!P)
      - elastase:      (?<=[AVSGT])(?!P)
      - enterokinase:  D{4}K(?=[^P])
      - caspase:       (?<=D)(?=[GSA])

    size: a one-element sequence renders a square image; two elements are
    (width, height). None falls back to the 768x768 default.

    Bug fixes: the default was a mutable list (size=[768]); and callers that
    forward resize=None (the read_img helpers) crashed on len(None) — None is
    now treated as "use the default size".
    """
    if size is None:
        size = (768,)

    smiles = build_peptide_smiles(sequence)
    mol = Chem.MolFromSmiles(smiles)

    AllChem.Compute2DCoords(mol)

    highlight_bonds = []
    bond_colors = {}

    # Atom map numbers were assigned per residue in build_peptide_smiles, so a
    # 1-based residue index directly identifies its alpha-carbon atom.
    d_positions = {i for i, aa in enumerate(sequence, start=1) if aa.islower()}

    # Blue: every bond of a D-residue's alpha-carbon.
    for atom in mol.GetAtoms():
        if atom.GetAtomMapNum() in d_positions:
            for b in atom.GetBonds():
                idx = b.GetIdx()
                if idx not in highlight_bonds:
                    highlight_bonds.append(idx)
                    bond_colors[idx] = (0.0, 0.0, 1.0)

    # Red: peptide bonds at predicted cleavage sites (overrides blue where
    # both apply, since bond_colors is keyed per bond).
    if pcs:
        cleavage_sites = set()
        for pat in protease_patterns.values():
            for m in pat.finditer(sequence):
                cut = m.end()
                # Only internal cut points correspond to an actual peptide bond.
                if 1 <= cut < len(sequence):
                    cleavage_sites.add(cut)

        for pos in cleavage_sites:
            # Alpha-carbon of the P1 residue.
            ca = next((a for a in mol.GetAtoms()
                       if a.GetAtomMapNum() == pos), None)
            if ca is None:
                continue

            # Find the carbonyl carbon bonded to that alpha-carbon (a C with
            # a double bond to O).
            carbonyl_c = None
            for nb in ca.GetNeighbors():
                if nb.GetSymbol() != "C":
                    continue

                if any(bond.GetBondType() == Chem.BondType.DOUBLE and
                       o.GetSymbol() == "O"
                       for bond in nb.GetBonds()
                       for o in (bond.GetBeginAtom(), bond.GetEndAtom())):
                    carbonyl_c = nb
                    break
            if carbonyl_c is None:
                continue

            # The peptide bond is the carbonyl carbon's bond to the next
            # residue's backbone nitrogen.
            peptide_bond = None
            for b in carbonyl_c.GetBonds():
                o_atom = b.GetOtherAtom(carbonyl_c)
                if o_atom.GetSymbol() == "N":
                    peptide_bond = b
                    break
            if peptide_bond is None:
                continue

            bidx = peptide_bond.GetIdx()
            if bidx not in highlight_bonds:
                highlight_bonds.append(bidx)
            bond_colors[bidx] = (1.0, 0.0, 0.0)

    if len(size) == 1:
        w = h = size[0]
    else:
        w, h = size

    drawer = rdMolDraw2D.MolDraw2DCairo(w, h)

    drawer.DrawMolecule(
        mol,
        highlightAtoms=[],
        highlightBonds=highlight_bonds,
        highlightAtomColors={},
        highlightBondColors=bond_colors
    )
    drawer.FinishDrawing()

    # Decode the rendered PNG straight into a torch tensor and normalize with
    # ImageNet statistics.
    # NOTE(review): to_dtype is called without scale=True, so pixel values stay
    # in 0..255 before normalization — confirm this matches how the downstream
    # models were trained.
    png_bytes = bytearray(drawer.GetDrawingText())
    byte_tensor = torch.frombuffer(png_bytes, dtype=torch.uint8)
    img = tvio.decode_png(byte_tensor, mode=tvio.ImageReadMode.RGB)
    img = tvtF.to_dtype(img, torch.float32)
    img = tvtF.normalize(img, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    return img
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: build the r2 evaluation pairing and inspect the first item.
    pad_length = 30
    dataset = PeptidePairDataset('r2_case', pad_length, "cls", include_reverse=False, include_self=False, one_way=True)

    if len(dataset) > 0:
        (encoded_seq1, encoded_seq2), ratio = dataset[0]
        print("第一个样本:")
        print("变种1的编码张量形状:", encoded_seq1.shape)
        print("变种2的编码张量形状:", encoded_seq2.shape)
        print("标签比值(变种2/变种1):", ratio)
        print(f"数据集大小:{len(dataset)}")
        # Count how many pairs carry the positive classification label.
        label_pos = sum(label for _, label in dataset)
        print(label_pos)
    else:
        print("未读入组合数据!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|