# Snapshot header from the source viewer: commit acbef3a, 33,942 bytes.
import pandas as pd
import numpy as np
import itertools
import torch
from torch.utils.data import Dataset
import re
import json
from typing import Literal
import os
# import io
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem.Draw import rdMolDraw2D
# from PIL import Image
import torchvision.io as tvio
# import torchvision.transforms as tvt
import torchvision.transforms.v2.functional as tvtF
# --- Helper functions ---
# The 20 standard amino acids as one-letter codes, in alphabetical order.
AMINO_ACIDS = list("ACDEFGHIKLMNPQRSTVWY")
# One-letter code -> position in AMINO_ACIDS (used as the one-hot column offset).
AA_to_index = {aa: pos for pos, aa in enumerate(AMINO_ACIDS)}
# Set view of AMINO_ACIDS for fast membership tests.
valid_aa = set(AMINO_ACIDS)


def is_valid_sequence(seq):
    """Return True if every character of ``seq`` is a standard amino-acid letter.

    Uppercase letters (L-residues) and lowercase letters (D-residues) are both
    accepted; membership is checked on the uppercased character. Any
    non-letter, or a letter outside the 20 standard codes, makes the sequence
    invalid. An empty sequence is considered valid.
    """
    return all(ch.isalpha() and ch.upper() in valid_aa for ch in seq)
def parse_mic(mic_str):
    """Parse a raw MIC cell into a single float.

    Supported string formats (all whitespace is ignored, so "≥ 4" == "≥4"):

    1. Plain number, e.g. ``"5"`` -> 5.0.
    2. Lower bound ``">{num}"`` or ``"≥{num}"``, e.g. ``">4"`` -> num * 1.5.
    3. Upper bound ``"<{num}"`` or ``"≤{num}"``, e.g. ``"<4"`` -> num * 0.75.
    4. Mean ± standard deviation, e.g. ``"3.2 ± 0.4"`` -> the mean, 3.2.
    5. Range ``"{a}-{b}"``, e.g. ``"2.0 - 4.0"`` -> (a + b) / 2. An en dash
       (``–``), em dash (``—``) or tilde (``~``) is also accepted as the range
       separator.

    Note: the data uses the single characters "≥"/"≤", not ">="/"<=".

    Non-string input (numbers, or NaN for empty cells) is passed through
    ``float``. An unparseable string prints a warning and yields ``np.nan``.
    """
    if not isinstance(mic_str, str):
        return float(mic_str)
    # Removing every whitespace run also trims leading/trailing whitespace.
    mic_str = re.sub(r'\s+', '', mic_str)
    number = r'(\d+(?:\.\d+)?)'
    # Plain number.
    if re.fullmatch(number, mic_str):
        return float(mic_str)
    # Lower bound: heuristically inflate the reported bound by 1.5x.
    m = re.fullmatch(r'[>≥]' + number, mic_str)
    if m:
        return float(m.group(1)) * 1.5
    # Upper bound: heuristically deflate the reported bound to 0.75x.
    m = re.fullmatch(r'[<≤]' + number, mic_str)
    if m:
        return float(m.group(1)) * 0.75
    # Mean ± SD: keep only the mean.
    m = re.fullmatch(number + r'±' + number, mic_str)
    if m:
        return float(m.group(1))
    # Range: take the midpoint. '-' is first in the class, so it is literal.
    m = re.fullmatch(number + r'[-–—~]' + number, mic_str)
    if m:
        return (float(m.group(1)) + float(m.group(2))) / 2.0
    print(f"Warning: 无法解析 MIC 值 {mic_str}")
    return np.nan
def encode_sequence(seq, pad_length):
    """Encode a peptide sequence as a ``(pad_length, 21)`` float32 tensor.

    Row ``i`` describes residue ``i``:
      * column 0 is 1.0 if the residue is a D-amino acid (lowercase letter),
        otherwise 0.0;
      * columns 1-20 one-hot encode the residue's uppercase letter in
        ``AMINO_ACIDS`` order.

    Shorter sequences are zero-padded at the end. Residues past
    ``pad_length`` are silently ignored. A letter outside the standard
    alphabet gets no one-hot bit (only its D flag, if lowercase) and a
    warning is printed.
    """
    arr = np.zeros((pad_length, 21), dtype=np.float32)
    for i, char in enumerate(seq):
        if i >= pad_length:
            break  # over-long residues are dropped
        is_d = char.islower()
        aa = char.upper() if is_d else char
        arr[i, 0] = 1.0 if is_d else 0.0
        idx = AA_to_index.get(aa)
        if idx is None:
            print(f"Warning: 氨基酸 {aa} 不在标准列表中")
        else:
            arr[i, idx + 1] = 1.0
    # `arr` is local, so sharing its memory avoids a needless copy.
    return torch.from_numpy(arr)
def geometric_mean(values):
    """Return the geometric mean of ``values`` as a Python float.

    Computed in log space as exp(mean(log(x))). A zero in ``values`` gives
    0.0 and a negative value gives NaN; numpy emits a RuntimeWarning in both
    cases.
    """
    logs = np.log(np.array(values))
    return float(np.exp(np.mean(logs)))
def process_label(ratio, task):
    """Turn an MIC ratio into a training label.

    The ratio is log2-transformed, then:
      * ``task="reg"``: return log2(ratio) as ``np.float32``;
      * ``task="cls"``: return 1 if log2(ratio) < 0 (i.e. ratio < 1), else 0.

    Returns ``np.nan`` when ``ratio`` is not a positive number. This covers
    zero, negative values and NaN. The dataset builders pass NaN for an
    undefined ratio and drop samples whose label is NaN.

    Raises:
        ValueError: if ``task`` is neither "reg" nor "cls" (and ``ratio`` is
            positive).
    """
    # `not ratio > 0` instead of `ratio <= 0` so NaN is rejected too; otherwise
    # NaN fell through and was classified as 0 in the "cls" task.
    if not ratio > 0:
        return np.nan
    ratio_log = np.log2(ratio)
    if task == "reg":
        return np.float32(ratio_log)
    if task == "cls":
        return 1 if ratio_log < 0. else 0
    raise ValueError("未知的 task 类型,请使用 'reg' 或 'cls'")
# --- 数据预处理与构建数据集 ---
def load_data(xlsx_file, condition=None):
    """Read prototype/variant MIC rows from an Excel sheet and average replicates.

    Each row supplies a prototype sequence ("SEQUENCE - Original"), a variant
    ("SEQUENCE - D-type amino acid substitution") and a raw MIC value
    ("TARGET ACTIVITY - CONCENTRATION", parsed with ``parse_mic``). Rows whose
    prototype or variant is not a string of standard amino-acid letters are
    skipped.

    Args:
        xlsx_file: path of the Excel file.
        condition: unused; accepted so the signature matches
            ``load_data_stability`` (callers pass it positionally).

    Returns:
        dict: prototype -> {variant: geometric mean of that variant's non-NaN
        MIC values}. Variants with no usable value are omitted, so a prototype
        may map to an empty dict.
    """
    sheet = pd.read_excel(xlsx_file)
    # A commented-out MBC filter was left here:
    #   df = df[df['TARGET ACTIVITY - ACTIVITY MEASURE VALUE'] != 'MBC']
    collected = {}
    for _, record in sheet.iterrows():
        proto = record["SEQUENCE - Original"]
        variant = record["SEQUENCE - D-type amino acid substitution"]
        raw_mic = record["TARGET ACTIVITY - CONCENTRATION"]
        # Both sequences must be real strings (not NaN) of valid letters.
        proto_ok = isinstance(proto, str) and is_valid_sequence(proto)
        if not proto_ok:
            continue
        variant_ok = isinstance(variant, str) and is_valid_sequence(variant)
        if not variant_ok:
            continue
        collected.setdefault(proto, {}).setdefault(variant, []).append(parse_mic(raw_mic))

    result = {}
    for proto, per_variant in collected.items():
        averaged = {}
        for variant, mics in per_variant.items():
            usable = [m for m in mics if not np.isnan(m)]
            if usable:
                averaged[variant] = geometric_mean(usable)
        result[proto] = averaged
    return result
def load_data_stability(xlsx_file, condition):
    """Read MIC rows measured under one assay condition and average replicates.

    The sheet is filtered to rows whose "Condition" column equals the label
    for ``condition``:

        '125fbs' -> '12.5% FBS', '25fbs' -> '25% FBS',
        'mhb'    -> 'MHB',       'nacl'  -> '150mM NaCl'

    Each remaining row's "SEQUENCE" is a variant (lowercase letters mark
    D-residues). Its prototype is the same sequence upper-cased. "Activity"
    holds the raw MIC, parsed with ``parse_mic``. Rows whose sequence is
    missing or contains non-standard characters are skipped.

    Returns:
        dict: prototype -> {variant: geometric mean of that variant's non-NaN
        MIC values}. Variants with no usable value are omitted.

    Raises:
        KeyError: if ``condition`` is not one of the keys above.
    """
    map_dict = {
        '125fbs': '12.5% FBS',
        '25fbs': '25% FBS',
        'mhb': 'MHB',
        'nacl': '150mM NaCl'
    }
    df = pd.read_excel(xlsx_file)
    df = df[df['Condition'] == map_dict[condition]]
    groups = {}
    for _, row in df.iterrows():
        variant = row["SEQUENCE"]
        mic_raw = row["Activity"]
        # Validate before calling .upper(). A non-string cell (e.g. NaN from an
        # empty cell) previously crashed here with AttributeError. If the
        # variant is valid, its upper-cased prototype is valid too.
        if not (isinstance(variant, str) and is_valid_sequence(variant)):
            continue
        orig = variant.upper()
        mic_val = parse_mic(mic_raw)
        groups.setdefault(orig, {}).setdefault(variant, []).append(mic_val)

    # Geometric mean per variant, ignoring NaN (unparseable) values.
    groups_avg = {}
    for orig, var_dict in groups.items():
        groups_avg[orig] = {}
        for variant, mic_list in var_dict.items():
            mic_list = [x for x in mic_list if not np.isnan(x)]
            if not mic_list:
                continue
            groups_avg[orig][variant] = geometric_mean(mic_list)
    return groups_avg
class PeptidePairDataset(Dataset):
    """Pairs of variants of the same prototype peptide, labelled by MIC ratio.

    Items are ``((enc_a, enc_b), label)``, or ``((enc_a, enc_b, glob_feat),
    label)`` when ``gf`` is set. ``enc_*`` are ``encode_sequence`` tensors of
    shape ``(pad_length, 21)`` and ``label = process_label(MIC_b / MIC_a, task)``.
    """

    def __init__(self, mode: Literal['train', 'test', 'r2_case', 'r2_case_',
                                     '125fbs', 'nacl', '25fbs', 'mhb'] = 'train',
                 pad_length=30, task="cls",
                 include_reverse=False, include_self=False, one_way=False, gf=False):
        """Build the pair dataset.

        Data source by ``mode``:
          * "train": ``dataset/train.xlsx`` via ``load_data``;
          * "test", "r2_case", "r2_case_": ``dataset/<mode>.xlsx`` via ``load_data``;
          * "125fbs", "nacl", "25fbs", "mhb": ``dataset/stability.xlsx`` via
            ``load_data_stability``, filtered to that condition.
        Every mode except "train" forces ``one_way=True``.

        Within each prototype, variants longer than ``pad_length`` are
        dropped and the rest are paired:
          * ``one_way``: only (first variant, later variant) pairs;
            otherwise every (i, j) pair with i < j;
          * ``include_reverse``: also add (B, A) for every (A, B)
            (ignored when ``one_way``);
          * ``include_self``: also add (A, A) with ratio 1
            (ignored when ``one_way``).

        Labels come from ``process_label(MIC_B / MIC_A, task)``. For "reg"
        the label is the float32 log2 ratio. For "cls" it is the binary class
        defined there (currently 1 when log2 ratio < 0, i.e. B has the lower
        MIC, else 0). Pairs whose label is NaN are skipped.

        With ``gf=True`` each sample also carries the prototype's global
        feature from ``dataset/protbert.pth``, keyed by the upper-cased
        prototype.

        Raises:
            ValueError: for an unknown ``mode``.
        """
        dataset_dir = os.path.join(os.path.dirname(__file__), 'dataset')
        if mode == "train":
            loader = load_data
            xlsx_file = os.path.join(dataset_dir, 'train.xlsx')
        elif mode in ["test", "r2_case", 'r2_case_', "125fbs", "nacl", "25fbs", "mhb"]:
            one_way = True
            if mode in ["test", "r2_case", 'r2_case_']:
                loader = load_data
                xlsx_file = os.path.join(dataset_dir, f'{mode}.xlsx')
            else:
                loader = load_data_stability
                xlsx_file = os.path.join(dataset_dir, 'stability.xlsx')
        else:
            raise ValueError("未知的 mode,请使用 'train' 或 'test'")
        self.data = []
        self.pad_length = pad_length
        self.task = task
        groups_avg = loader(xlsx_file, mode)
        if gf:
            gf_dict = torch.load(os.path.join(dataset_dir, 'protbert.pth'))
        for orig, variant_dict in groups_avg.items():
            # Drop variants that do not fit the fixed-length encoding.
            filtered_variants = {variant: mic for variant, mic in variant_dict.items()
                                 if len(variant) <= pad_length}
            variants = list(filtered_variants)
            n_variants = len(variants)
            if n_variants == 0:
                continue
            if gf:
                glob_feat = gf_dict[orig.upper()]
            # Self pairs (A, A): ratio 1 -> log2 label 0 (class 0).
            if include_self and not one_way:
                for variant in variants:
                    encoded_seq = encode_sequence(variant, pad_length)
                    label = process_label(1.0, task)
                    if gf:
                        self.data.append(((encoded_seq, encoded_seq, glob_feat), label))
                    else:
                        self.data.append(((encoded_seq, encoded_seq), label))
            # one_way pairs only the first variant (in loader order) with the rest.
            first_indices = [0] if one_way else range(n_variants)
            for i in first_indices:
                for j in range(i + 1, n_variants):
                    var1 = variants[i]
                    var2 = variants[j]
                    mic1 = filtered_variants[var1]
                    mic2 = filtered_variants[var2]
                    # Forward pair (var1, var2) is labelled from log2(mic2 / mic1).
                    ratio = mic2 / mic1 if mic1 != 0 else np.nan
                    label = process_label(ratio, task)
                    if np.isnan(label):
                        continue
                    encoded_var1 = encode_sequence(var1, pad_length)
                    encoded_var2 = encode_sequence(var2, pad_length)
                    if gf:
                        self.data.append(((encoded_var1, encoded_var2, glob_feat), label))
                    else:
                        self.data.append(((encoded_var1, encoded_var2), label))
                    # Reverse pair (var2, var1) labelled from log2(mic1 / mic2).
                    if include_reverse and not one_way:
                        rev_ratio = mic1 / mic2 if mic2 != 0 else np.nan
                        rev_label = process_label(rev_ratio, task)
                        if gf:
                            self.data.append(((encoded_var2, encoded_var1, glob_feat), rev_label))
                        else:
                            self.data.append(((encoded_var2, encoded_var1), rev_label))

    def reg_sample_weight(self):
        """Return one positive float32 weight per sample, larger for rarer labels.

        A Gaussian N(mu, sigma) is fitted to the labels. Each label's density
        ``p`` is turned into ``log(C / (p + 1e-6))``, with ``C`` the median
        density, so labels in low-density regions get larger values. These are
        divided by their mean and exponentiated to make them positive.

        An empty dataset yields an empty tensor. If every label is identical
        (sigma == 0, where the density is undefined and previously produced
        NaN weights), all weights are 1.
        """
        y = np.array([label for _, label in self.data])
        if y.size == 0:
            return torch.zeros(0, dtype=torch.float32)
        mu = np.mean(y)
        sigma = np.std(y)
        if sigma == 0:
            return torch.ones(y.size, dtype=torch.float32)
        p = 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-((y - mu) ** 2) / (2 * sigma ** 2))
        # The median density is the reference constant C.
        C = np.median(p)
        epsilon = 1e-6
        # Lower density -> larger log ratio -> larger weight.
        weights = np.log(C / (p + epsilon))
        # NOTE(review): dividing by mean(weights) assumes that mean is positive;
        # if it is <= 0 the weight ordering flips or blows up. Confirm intended.
        weights_normalized = weights / np.mean(weights)
        positive_weights = np.exp(weights_normalized)
        return torch.tensor(positive_weights, dtype=torch.float32)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
class PeptidePairPicDataset(Dataset):
    """Image version of ``PeptidePairDataset``: variant pairs drawn as structures.

    Items are ``((img_a, img_b), label)``, or ``((img_a, img_b, glob_feat),
    label)`` when ``gf`` is set. ``img_*`` are the ``draw_peptide`` images
    rendered once per variant at construction. With ``side_enc`` each image
    becomes ``(image, encode_sequence(seq, pad_length))``.
    """

    def __init__(self, mode: Literal['train', 'test', 'r2_case', 'r2_case_',
                                     '125fbs', 'nacl', '25fbs', 'mhb'] = 'train',
                 pad_length=30, task="reg",
                 include_reverse=False, include_self=False, one_way=False, gf=False,
                 side_enc=None, pcs=False, resize=None):
        """Build the image pair dataset.

        Data source by ``mode``:
          * "train": ``dataset/train.xlsx`` via ``load_data``;
          * "test", "r2_case", "r2_case_": ``dataset/<mode>.xlsx`` via ``load_data``;
          * "125fbs", "nacl", "25fbs", "mhb": ``dataset/stability.xlsx`` via
            ``load_data_stability``, filtered to that condition.
        Every mode except "train" forces ``one_way=True``.

        Pairing rules (``one_way``, ``include_reverse``, ``include_self``) and
        labels (``process_label(MIC_B / MIC_A, task)``, NaN labels skipped)
        are the same as in ``PeptidePairDataset``. Samples store sequences;
        images are looked up in ``self.pics`` in ``__getitem__``.

        Args:
            side_enc: if truthy, pair each image with its one-hot encoding.
            pcs: forwarded to ``draw_peptide``. With "mix", a variant identical
                to its prototype is drawn with ``pcs=False``.
            resize: forwarded to ``draw_peptide`` as the canvas size.

        Raises:
            ValueError: for an unknown ``mode``.
        """
        dataset_dir = os.path.join(os.path.dirname(__file__), 'dataset')
        if mode == "train":
            loader = load_data
            xlsx_file = os.path.join(dataset_dir, 'train.xlsx')
        elif mode in ["test", "r2_case", 'r2_case_', "125fbs", "nacl", "25fbs", "mhb"]:
            one_way = True
            if mode in ["test", "r2_case", 'r2_case_']:
                loader = load_data
                xlsx_file = os.path.join(dataset_dir, f'{mode}.xlsx')
            else:
                loader = load_data_stability
                xlsx_file = os.path.join(dataset_dir, 'stability.xlsx')
        else:
            raise ValueError("未知的 mode,请使用 'train' 或 'test'")
        self.data = []
        self.pics = {}
        self.pad_length = pad_length
        self.task = task
        self.gf = gf
        self.side_enc = True if side_enc else False
        self.pcs = pcs
        self.resize = resize
        groups_avg = loader(xlsx_file, mode)
        if gf:
            gf_dict = torch.load(os.path.join(dataset_dir, 'protbert.pth'))
        for orig, variant_dict in groups_avg.items():
            # Drop variants that do not fit the fixed-length encoding.
            filtered_variants = {variant: mic for variant, mic in variant_dict.items()
                                 if len(variant) <= pad_length}
            variants = list(filtered_variants)
            # Render each variant once; samples refer to images by sequence.
            for variant in variants:
                if self.pcs == 'mix' and variant == orig:
                    self.pics[variant] = self.read_img(variant, False)
                else:
                    self.pics[variant] = self.read_img(variant, self.pcs)
            n_variants = len(variants)
            if n_variants == 0:
                continue
            if gf:
                glob_feat = gf_dict[orig.upper()]
            # Self pairs (A, A): ratio 1 -> log2 label 0 (class 0).
            if include_self and not one_way:
                for variant in variants:
                    label = process_label(1.0, task)
                    if gf:
                        self.data.append((variant, variant, glob_feat, label))
                    else:
                        self.data.append((variant, variant, label))
            # one_way pairs only the first variant (in loader order) with the rest.
            first_indices = [0] if one_way else range(n_variants)
            for i in first_indices:
                for j in range(i + 1, n_variants):
                    var1 = variants[i]
                    var2 = variants[j]
                    mic1 = filtered_variants[var1]
                    mic2 = filtered_variants[var2]
                    # Forward pair (var1, var2) is labelled from log2(mic2 / mic1).
                    ratio = mic2 / mic1 if mic1 != 0 else np.nan
                    label = process_label(ratio, task)
                    if np.isnan(label):
                        continue
                    if gf:
                        self.data.append((var1, var2, glob_feat, label))
                    else:
                        self.data.append((var1, var2, label))
                    # Reverse pair (var2, var1) labelled from log2(mic1 / mic2).
                    if include_reverse and not one_way:
                        rev_ratio = mic1 / mic2 if mic2 != 0 else np.nan
                        rev_label = process_label(rev_ratio, task)
                        if gf:
                            self.data.append((var2, var1, glob_feat, rev_label))
                        else:
                            self.data.append((var2, var1, rev_label))

    def reg_sample_weight(self):
        """Return one positive float32 weight per sample, larger for rarer labels.

        A Gaussian N(mu, sigma) is fitted to the labels (the last element of
        each stored sample). Each label's density ``p`` is turned into
        ``log(C / (p + 1e-6))``, with ``C`` the median density. These values
        are divided by their mean and exponentiated to make them positive.

        An empty dataset yields an empty tensor. If every label is identical
        (sigma == 0, where the density is undefined and previously produced
        NaN weights), all weights are 1.
        """
        y = np.array([d[-1] for d in self.data])
        if y.size == 0:
            return torch.zeros(0, dtype=torch.float32)
        mu = np.mean(y)
        sigma = np.std(y)
        if sigma == 0:
            return torch.ones(y.size, dtype=torch.float32)
        p = 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-((y - mu) ** 2) / (2 * sigma ** 2))
        # The median density is the reference constant C.
        C = np.median(p)
        epsilon = 1e-6
        # Lower density -> larger log ratio -> larger weight.
        weights = np.log(C / (p + epsilon))
        # NOTE(review): dividing by mean(weights) assumes that mean is positive;
        # if it is <= 0 the weight ordering flips or blows up. Confirm intended.
        weights_normalized = weights / np.mean(weights)
        positive_weights = np.exp(weights_normalized)
        return torch.tensor(positive_weights, dtype=torch.float32)

    def read_img(self, peptide, pcs):
        """Render ``peptide`` with ``draw_peptide`` at this dataset's ``resize``."""
        image = draw_peptide(peptide, self.resize, pcs)
        return image

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if self.gf:
            seq1, seq2, glob_feat, label = self.data[idx]
        else:
            seq1, seq2, label = self.data[idx]
        img1 = self.pics[seq1]
        img2 = self.pics[seq2]
        if self.side_enc:
            img1 = (img1, encode_sequence(seq1, self.pad_length))
            img2 = (img2, encode_sequence(seq2, self.pad_length))
        if self.gf:
            return (img1, img2, glob_feat), label
        return (img1, img2), label
class SimplePairClsDataset(Dataset):
    """Binary (prototype, variant) classification pairs loaded from a JSON file.

    The JSON maps each prototype sequence to an object with list-valued keys
    "1" and "0". Every listed variant becomes a sample ``(prototype, variant,
    int(key))``. Items are ``((x1, x2), label)``, plus the prototype's global
    feature when ``gf`` is set. ``x*`` are pre-rendered images when
    ``q_encoder`` is "cnn" or "rn18" (optionally paired with one-hot encodings
    via ``side_enc``); otherwise they are ``encode_sequence`` tensors.
    """

    def __init__(self, pad_length=30, llm=False, ftr2=False, gf=False,
                 q_encoder=None, side_enc=None, pcs=False, resize=None):
        """Load the pairs and, in image mode, render the sequences up front.

        Args:
            pad_length: prototypes longer than this are skipped; also the
                one-hot padding length.
            llm: read ``dataset/train_set_llm_aug.json``.
            ftr2: read ``dataset/finetune_for_r2_llm.json`` (only if ``llm``
                is false). Otherwise ``dataset/train_set.json`` is read.
            gf: attach global features from ``dataset/protbert.pth``, looked
                up by prototype sequence.
            q_encoder: "cnn" / "rn18" selects image inputs; anything else
                selects one-hot inputs.
            side_enc: if truthy, image inputs become ``(image, one_hot)``.
            pcs: forwarded to ``draw_peptide``. With "mix", all-uppercase
                sequences are drawn with ``pcs=False``.
            resize: forwarded to ``draw_peptide`` as the canvas size.
        """
        dataset_dir = os.path.join(os.path.dirname(__file__), 'dataset')
        if llm:
            json_name = 'train_set_llm_aug.json'
        elif ftr2:
            json_name = 'finetune_for_r2_llm.json'
        else:
            json_name = 'train_set.json'
        with open(os.path.join(dataset_dir, json_name), 'r', encoding='utf-8') as fh:
            pairs_by_orig = json.load(fh)
        self.data = []
        self.pics = {}
        self.pad_length = pad_length
        self.gf = gf
        self.q_encoder = q_encoder
        self.side_enc = bool(side_enc)
        self.pcs = pcs
        self.resize = resize
        if gf:
            self.gf_dict = torch.load(os.path.join(dataset_dir, 'protbert.pth'))

        seqs_to_draw = []
        for orig, variants in pairs_by_orig.items():
            if len(orig) > pad_length:
                continue
            seqs_to_draw.append(orig)
            for key in ("1", "0"):
                for variant in variants[key]:
                    self.data.append((orig, variant, int(key)))
                    seqs_to_draw.append(variant)

        if q_encoder in ('cnn', 'rn18'):
            for seq in seqs_to_draw:
                # "mix": no cleavage highlighting for all-L (uppercase) sequences.
                draw_pcs = False if (self.pcs == 'mix' and seq.isupper()) else self.pcs
                self.pics[seq] = self.read_img(seq, draw_pcs)

    def read_img(self, peptide, pcs):
        """Render ``peptide`` with ``draw_peptide`` at this dataset's ``resize``."""
        return draw_peptide(peptide, self.resize, pcs)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        seq1, seq2, label = self.data[idx]
        if self.q_encoder in ('cnn', 'rn18'):
            first, second = self.pics[seq1], self.pics[seq2]
            if self.side_enc:
                first = (first, encode_sequence(seq1, self.pad_length))
                second = (second, encode_sequence(seq2, self.pad_length))
        else:
            first = encode_sequence(seq1, self.pad_length)
            second = encode_sequence(seq2, self.pad_length)
        inputs = (first, second, self.gf_dict[seq1]) if self.gf else (first, second)
        return inputs, label
class PeptidePairCaseDataset(Dataset):
    """Case study: a template peptide paired with each of its D-substituted variants.

    Every non-glycine residue of the template may be L (uppercase) or D
    (lowercase). All combinations except the all-L template itself are
    enumerated. Items are ``((template_enc, variant_enc), variant)``, with
    the template's global feature appended when ``gf`` is set. The variant
    string is returned in the label slot (presumably so predictions can be
    mapped back to sequences).
    """

    def __init__(self, case: str = 'r2', pad_length=30, gf=False):
        # Named templates; any other ``case`` is treated as a raw sequence.
        named_templates = {
            'r2': 'KWKIKWPVKWFKML',
            'Indolicidin': 'ILPWKWPWWPWRR',
            'Temporin-A': 'FLPLIGRVLSGIL',
            'Melittin': 'GIGAVLKVLTTGLPALISWIKRKRQQ',
            'Anoplin': 'GLLKRIKTLL',
        }
        self.template = named_templates[case] if case in named_templates else case.upper().strip()
        self.data = []
        self.pad_length = pad_length
        self.gf = gf
        if gf:
            feats = torch.load(os.path.join(os.path.dirname(__file__), 'dataset', 'protbert.pth'))
            self.glob_feat = feats[self.template]
        # Glycine is achiral, so it has only one form.
        choices = [(ch.upper(),) if ch == 'G' else (ch.upper(), ch.lower())
                   for ch in self.template]
        combos = itertools.product(*choices)
        # The first combination is the unmodified all-uppercase template; skip it.
        next(combos, None)
        self.variants = [''.join(combo) for combo in combos]
        self.template_seq = encode_sequence(self.template, self.pad_length)

    def __len__(self):
        return len(self.variants)

    def __getitem__(self, idx):
        variant = self.variants[idx]
        enc_variant = encode_sequence(variant, self.pad_length)
        if self.gf:
            inputs = (self.template_seq, enc_variant, self.glob_feat)
        else:
            inputs = (self.template_seq, enc_variant)
        return inputs, variant
class PeptidePairPicCaseDataset(Dataset):
    """Image case study: the template's structure image vs. each D-variant's image.

    Variants are every L/D combination of the template's non-glycine residues
    except the all-L template itself. The template image is rendered once.
    Variant images are rendered on demand in ``__getitem__``. Items are
    ``((template_img, variant_img), variant)``, with the template's global
    feature appended when ``gf`` is set. With ``side_enc`` each image becomes
    ``(image, one_hot_encoding)``.
    """

    def __init__(self, case: str = 'r2', pad_length=30, side_enc=None, pcs=False, resize=None, gf=False):
        # Named templates; any other ``case`` is treated as a raw sequence.
        named_templates = {
            'r2': 'KWKIKWPVKWFKML',
            'Indolicidin': 'ILPWKWPWWPWRR',
            'Temporin-A': 'FLPLIGRVLSGIL',
            'Melittin': 'GIGAVLKVLTTGLPALISWIKRKRQQ',
            'Anoplin': 'GLLKRIKTLL',
        }
        self.template = named_templates[case] if case in named_templates else case.upper().strip()
        self.data = []
        self.pad_length = pad_length
        self.side_enc = bool(side_enc)
        self.pcs = pcs
        self.resize = resize
        self.gf = gf
        if gf:
            feats = torch.load(os.path.join(os.path.dirname(__file__), 'dataset', 'protbert.pth'))
            self.glob_feat = feats[self.template]
        # Glycine is achiral, so it has only one form.
        choices = [(ch.upper(),) if ch == 'G' else (ch.upper(), ch.lower())
                   for ch in self.template]
        combos = itertools.product(*choices)
        # The first combination is the unmodified all-uppercase template; skip it.
        next(combos, None)
        self.variants = [''.join(combo) for combo in combos]
        self.template_pic = self.read_img(self.template)
        if self.side_enc:
            self.template_seq = encode_sequence(self.template, self.pad_length)

    def read_img(self, peptide):
        """Render ``peptide`` with this dataset's ``resize`` and ``pcs`` settings."""
        return draw_peptide(peptide, self.resize, self.pcs)

    def __len__(self):
        return len(self.variants)

    def __getitem__(self, idx):
        variant = self.variants[idx]
        first = self.template_pic
        second = self.read_img(variant)
        if self.side_enc:
            first = (first, self.template_seq)
            second = (second, encode_sequence(variant, self.pad_length))
        inputs = (first, second, self.glob_feat) if self.gf else (first, second)
        return inputs, variant
# Side-chain SMILES fragments, written as attached to the α-carbon.
# Proline's fragment closes ring bond 1, which its backbone template opens on
# the amide nitrogen. This gives the pyrrolidine ring N-CA-CB-CG-CD.
aa_side = {
    "A": "C", "R": "CCCNC(N)=N", "N": "CC(=O)N", "D": "CC(=O)O", "C": "CS",
    "E": "CCC(=O)O", "Q": "CCC(=O)N", "G": "", "H": "Cc1cnc[nH]1", "I": "C(C)CC",
    "L": "CC(C)C", "K": "CCCCN", "M": "CCSC", "F": "Cc1ccccc1", "P": "CCC1",
    "S": "CO", "T": "C(C)O", "W": "Cc1c[nH]c2ccccc12", "Y": "Cc1ccc(O)cc1", "V": "C(C)C"
}

# Residue SMILES templates. Keys are "<AA>_<L|D>" for an internal residue
# (ending in the carbonyl) and "<AA>_<L|D>_term" for the C-terminal residue
# (ending in the free acid). "{idx}" is filled with the residue number by
# build_peptide_smiles and becomes the α-carbon's atom-map number.
#
# Written N->C as N[C@@H](R)C(=O)..., the "@@" tag is the L-configuration
# (N[C@@H](C)C(=O)O is L-alanine) and "@" is D.
aa_tpl = {}
for aa, R in aa_side.items():
    for stereo, chir in (("L", "@@"), ("D", "@")):
        if aa == "G":
            # Gly is achiral. The explicit H2 is required: a bracket atom with
            # no H count has zero hydrogens, so "[C:n]" would be a carbene.
            backbone = "N[CH2:{idx}]C"
        elif aa == "P":
            # Proline: the amide N opens ring bond 1 that the side chain closes.
            backbone = f"N1[C{chir}H:{{idx}}]({R})C"
        else:
            backbone = f"N[C{chir}H:{{idx}}]({R})C"
        aa_tpl[f"{aa}_{stereo}"] = backbone + "(=O)"          # internal residue
        aa_tpl[f"{aa}_{stereo}_term"] = backbone + "(=O)O"    # C-terminal residue
def build_peptide_smiles(seq: str) -> str:
    """Return the SMILES of a linear peptide whose α-carbons carry atom maps.

    Uppercase letters are L-residues and lowercase letters D-residues. Residue
    ``k`` (1-based) gets atom-map number ``k`` on its α-carbon. The last
    residue uses the C-terminal ("_term") template. An empty sequence gives "".
    Raises KeyError for a letter without a template in ``aa_tpl``.
    """
    if not seq:
        return ""
    last = len(seq)
    pieces = []
    for pos, residue in enumerate(seq, start=1):
        stereo = "L" if residue.isupper() else "D"
        suffix = "_term" if pos == last else ""
        template = aa_tpl[f"{residue.upper()}_{stereo}{suffix}"]
        pieces.append(template.format(idx=pos))
    return "".join(pieces)
# Protease cleavage rules as regexes over the one-letter sequence. A match's
# end offset ``m.end()`` is the 1-based number of the P1 residue, i.e. the
# cut lies after that residue. Patterns use uppercase letters only.
protease_patterns = {
    name: re.compile(pattern)
    for name, pattern in (
        ('trypsin', r'(?<=[KR])(?!P)'),
        ('chymotrypsin', r'(?<=[FYWL])(?!P)'),
        ('elastase', r'(?<=[AVSGT])(?!P)'),
        ('enterokinase', r'D{4}K(?=[^P])'),
        ('caspase', r'(?<=D)(?=[GSA])'),
    )
}
def draw_peptide(sequence, size=None, pcs=False):
    """Render a peptide's 2D structure as a normalized RGB image tensor.

    The molecule comes from ``build_peptide_smiles(sequence)`` (uppercase =
    L-residue, lowercase = D-residue). Bonds are highlighted as follows:

    * blue: every bond of the α-carbon of each D-residue;
    * red (only if ``pcs`` is truthy; overrides blue): the peptide bond after
      each P1 residue found by ``protease_patterns``:
        - trypsin:      (?<=[KR])(?!P)
        - chymotrypsin: (?<=[FYWL])(?!P)
        - elastase:     (?<=[AVSGT])(?!P)
        - enterokinase: D{4}K(?=[^P])
        - caspase:      (?<=D)(?=[GSA])
      The patterns use uppercase letters only, so a lowercase (D) residue
      never matches a letter named in a pattern.

    Args:
        sequence: one-letter peptide sequence.
        size: canvas size in pixels. ``None`` (default) gives 768x768. An int
            or a one-element sequence ``[s]`` gives s x s. A two-element
            sequence gives ``(width, height)``.
        pcs: whether to highlight protease cleavage sites.

    Returns:
        float32 tensor of shape ``[3, H, W]``, scaled to [0, 1] and then
        normalized with the ImageNet channel mean/std.

    Raises:
        ValueError: if RDKit cannot parse the generated SMILES.
    """
    # 1. Build the molecule; α-carbons carry atom-map numbers = residue numbers.
    smiles = build_peptide_smiles(sequence)
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Fail here with context instead of a less helpful error from Compute2DCoords(None).
        raise ValueError(f"RDKit could not parse the SMILES built for {sequence!r}: {smiles}")
    AllChem.Compute2DCoords(mol)

    highlight_bonds = []  # bond indices to highlight, in first-seen order
    bond_colors = {}      # bond index -> RGB; later assignments win

    # 2. D-residues: highlight every bond of the α-carbon in blue.
    d_positions = {i for i, aa in enumerate(sequence, start=1) if aa.islower()}
    for atom in mol.GetAtoms():
        if atom.GetAtomMapNum() in d_positions:
            for bond in atom.GetBonds():
                idx = bond.GetIdx()
                if idx not in bond_colors:
                    highlight_bonds.append(idx)
                bond_colors[idx] = (0.0, 0.0, 1.0)

    # 3. Protease cleavage sites: colour the P1-P1' peptide bond red.
    if pcs:
        cleavage_sites = set()
        for pat in protease_patterns.values():
            for m in pat.finditer(sequence):
                cut = m.end()  # the bond after residue number `cut` is cleaved
                # Ignore "cuts" before the first or after the last residue.
                if 1 <= cut < len(sequence):
                    cleavage_sites.add(cut)
        for pos in cleavage_sites:
            # α-carbon of the P1 residue.
            ca = next((a for a in mol.GetAtoms() if a.GetAtomMapNum() == pos), None)
            if ca is None:
                continue
            # Backbone carbonyl carbon: a carbon neighbour of the α-carbon bearing C=O.
            carbonyl_c = None
            for nb in ca.GetNeighbors():
                if nb.GetSymbol() != "C":
                    continue
                if any(bond.GetBondType() == Chem.BondType.DOUBLE and
                       bond.GetOtherAtom(nb).GetSymbol() == "O"
                       for bond in nb.GetBonds()):
                    carbonyl_c = nb
                    break
            if carbonyl_c is None:
                continue
            # The nitrogen bonded to that carbonyl carbon belongs to the next residue.
            peptide_bond = next((b for b in carbonyl_c.GetBonds()
                                 if b.GetOtherAtom(carbonyl_c).GetSymbol() == "N"), None)
            if peptide_bond is None:
                continue
            bidx = peptide_bond.GetIdx()
            if bidx not in bond_colors:
                highlight_bonds.append(bidx)
            bond_colors[bidx] = (1.0, 0.0, 0.0)

    # 4. Canvas size. Callers pass their `resize` attribute here, which defaults
    #    to None, so None must mean "use the default size".
    if size is None:
        w = h = 768
    elif isinstance(size, (int, np.integer)):
        w = h = int(size)
    elif len(size) == 1:
        w = h = size[0]
    else:
        w, h = size

    # 5. Draw with per-bond highlight colours.
    drawer = rdMolDraw2D.MolDraw2DCairo(w, h)
    drawer.DrawMolecule(
        mol,
        highlightAtoms=[],
        highlightBonds=highlight_bonds,
        highlightAtomColors={},
        highlightBondColors=bond_colors
    )
    drawer.FinishDrawing()

    # 6. Decode the PNG bytes into a tensor.
    png_bytes = bytearray(drawer.GetDrawingText())
    byte_tensor = torch.frombuffer(png_bytes, dtype=torch.uint8)
    img = tvio.decode_png(byte_tensor, mode=tvio.ImageReadMode.RGB)  # [3, H, W], uint8
    # scale=True maps uint8 [0, 255] to float [0, 1], the range the ImageNet
    # mean/std below assume. The default scale=False would keep 0-255 values.
    img = tvtF.to_dtype(img, torch.float32, scale=True)
    img = tvtF.normalize(img, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    return img
if __name__ == '__main__':
    # Smoke test: build the one-way 'r2_case' pair dataset with binary ("cls")
    # labels, without reverse or self pairs, and print a short summary.
    pad_length = 30
    dataset = PeptidePairDataset('r2_case', pad_length, "cls", include_reverse=False, include_self=False, one_way=True)
    if len(dataset) > 0:
        # First sample: two one-hot encodings and the pair's class label.
        (encoded_seq1, encoded_seq2), label = dataset[0]
        print("第一个样本:")
        print("变种1的编码张量形状:", encoded_seq1.shape)
        print("变种2的编码张量形状:", encoded_seq2.shape)
        print("标签比值(变种2/变种1):", label)
        print(f"数据集大小:{len(dataset)}")
        # Number of positive (label == 1) pairs.
        label_pos = sum(lbl for _, lbl in dataset)
        print(label_pos)
    else:
        print("未读入组合数据!")