{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from Bio import SeqIO\n", "from DeepMFPP.data_helper import Data2EqlTensor" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "('ARRRRCSDRFRNCPADEALCGRRRR', 25)" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "\"\"\"\n", ">000000000000000010000\n", ">010000000010000000000\n", ">010001000010000000000\n", ">011000000001000000000\n", ">100000000000000000000\n", "\"\"\"\n", "\n", "# parse the FASTA test file into (record id, sequence string) pairs\n", "file_path = './test_samples.fa'\n", "data = []\n", "for record in SeqIO.parse(file_path, 'fasta'):\n", " data.append((record.id, str(record.seq)))\n", "\n", "data[0][1],len(data[0][1])" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[('peptide_1', 'ARRRRCSDRFRNCPADEALCGRRRR'),\n", " ('peptide_2', 'FFHHIFRGIVHVGKTIHKLVTGT'),\n", " ('peptide_3', 'GLRKRLRKFRNKIKEKLKKIGQKIQGFVPKLAPRTDY'),\n", " ('peptide_4', 'FLGALWNVAKSVF'),\n", " ('peptide_5', 'KIKSCYYLPCFVTS')]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "length > 50:0\n" ] }, { "data": { "text/plain": [ "torch.Size([5, 50])" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# encode the peptides as equal-length integer tensors (padded to 50, per the shape below)\n", "seqs,ids = Data2EqlTensor(data,50)\n", "seqs.shape" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[ 1, 15, 15, 15, 15, 2, 16, 3, 15, 5, 15, 12, 2, 13, 1, 3, 4, 1,\n", " 10, 2, 6, 15, 15, 15, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [ 5, 5, 7, 7, 8, 5, 15, 6, 8, 18, 7, 18, 6, 9, 17, 8, 7, 9,\n", " 10, 18, 17, 6, 17, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [ 6, 10, 15, 9, 15, 10, 15, 9, 5, 15, 12, 9, 8, 9, 4, 9, 10, 9,\n", " 9, 8, 6, 14, 9, 8, 14, 6, 5, 18, 13, 9, 10, 1, 13, 15, 17, 3,\n", " 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [ 5, 10, 6, 1, 10, 19, 12, 18, 1, 9, 16, 18, 5, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", " [ 9, 8, 9, 16, 2, 20, 20, 10, 13, 2, 5, 18, 17, 16, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "seqs" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [
"from DeepMFPP.model import DeepMFPP\n",
"import torch\n",
"import torch.nn as nn\n",
"import numpy as np\n",
"from DeepMFPP.config import ArgsConfig\n",
"\n",
"args = ArgsConfig()\n",
"args.embedding_size = 480\n",
"args.aa_dict = 'esm'\n",
"args.loss_fn_name = 'MLFDL'\n",
"args.weight_decay = 0\n",
"args.batch_size = 192\n",
"args.dropout = 0.62\n",
"args.scale_factor = 100\n",
"args.fldl_pos_weight = 0.4\n",
"\n",
"sigmoid = nn.Sigmoid()\n",
"def predict(seqs:torch.Tensor,data:list,model_path:str, top_k:int=0,threshold:float=0.5, device=args.device):\n",
"    \"\"\"Predict multi-functional peptide labels for a batch of encoded sequences.\n",
"\n",
"    Args:\n",
"        seqs: integer-encoded peptide tensor produced by Data2EqlTensor.\n",
"        data: list of (id, sequence) pairs aligned row-for-row with `seqs`.\n",
"        model_path: path to the trained DeepMFPP weight file (.pth).\n",
"        top_k: if > 0, keep only the `top_k` most probable labels per peptide.\n",
"        threshold: minimum probability for a label to be reported.\n",
"        device: device used for both the model and the input tensor.\n",
"\n",
"    Returns:\n",
"        list of [id, sequence, 'LABEL: prob, ...'] records, one per input peptide.\n",
"    \"\"\"\n",
"    torch.manual_seed(args.random_seed)\n",
"    with torch.no_grad():\n",
"        model = DeepMFPP(vocab_size=21,embedding_size=args.embedding_size, encoder_layer_num=1, fan_layer_num=1, num_heads=8,output_size=args.num_classes,\n",
"                         esm_path=args.ems_path,layer_idx=args.esm_layer_idx,dropout=args.dropout,Contrastive_Learning=args.ctl).to(device)\n",
"        state_dict = torch.load(model_path, map_location=device)\n",
"        model.load_state_dict(state_dict,strict=False)\n",
"        model.eval()\n",
"        # FIX: Tensor.to() is NOT in-place -- the original `seqs.to(args.device)`\n",
"        # discarded its result, so the input never moved to the model's device.\n",
"        # Reassign, and use the same `device` for model and input throughout.\n",
"        seqs = seqs.to(device)\n",
"        _, logits = model(seqs)\n",
"        prob = sigmoid(logits).cpu().numpy()\n",
"        categories = ['AAP', 'ABP', 'ACP', 'ACVP','ADP', 'AEP', 'AFP', 'AHIVP', 'AHP', 'AIP', 'AMRSAP', \n",
"                      'APP', 'ATP', 'AVP', 'BBP', 'BIP', 'CPP', 'DPPIP', 'QSP', 'SBP', 'THP']\n",
"        final_out = []\n",
"        for (seq_id, seq), probs in zip(data, prob):\n",
"            # keep only the labels whose probability exceeds the threshold,\n",
"            # sorted by probability in descending order,\n",
"            # e.g. {'AVP': 0.567, 'ATP': 0.5123, ...}\n",
"            kept = {label: round(float(p), 4) for label, p in zip(categories, probs) if p > threshold}\n",
"            sorted_result = dict(sorted(kept.items(), key=lambda item: item[1], reverse=True))\n",
"            # optionally keep only the top_k most probable labels\n",
"            if top_k:\n",
"                sorted_result = dict(list(sorted_result.items())[:top_k])\n",
"            result_str = \", \".join(f\"{key}: {value}\" for key, value in sorted_result.items())\n",
"            final_out.append([seq_id, seq, result_str])\n",
"    return final_out"
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"\n",
"def MFPs_classifier(file:str,threshold:float=0.5,top_k=0):\n",
"    \"\"\"Read peptides from a FASTA file and return DeepMFPP label predictions.\"\"\"\n",
"    data = []\n",
"    for record in SeqIO.parse(file, 'fasta'):\n",
"        data.append((record.id, str(record.seq)))\n",
"    seqs,_ = Data2EqlTensor(data,51,AminoAcid_vocab=args.aa_dict)\n",
"    model_weight_path = './weight/DeepMFPP-Best.pth'\n",
"    MFPs_pred = predict(seqs=seqs, data=data, 
model_path=model_weight_path, threshold=threshold,top_k=top_k,device=device)\n", " \n", " return MFPs_pred" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "length > 51:0\n" ] }, { "data": { "text/plain": [ "[['peptide_1',\n", " 'ARRRRCSDRFRNCPADEALCGRRRR',\n", " 'ADP: 0.5739, BBP: 0.5358, AAP: 0.5204, AIP: 0.5153, DPPIP: 0.5056, ABP: 0.5'],\n", " ['peptide_2',\n", " 'FFHHIFRGIVHVGKTIHKLVTGT',\n", " 'ADP: 0.5633, AIP: 0.5371, BBP: 0.5369, AAP: 0.5235, DPPIP: 0.5084, SBP: 0.5065, AHP: 0.5059, APP: 0.5027, ACP: 0.502'],\n", " ['peptide_3',\n", " 'GLRKRLRKFRNKIKEKLKKIGQKIQGFVPKLAPRTDY',\n", " 'AAP: 0.5418, ADP: 0.5346, ABP: 0.5167, DPPIP: 0.5162, AIP: 0.5081, QSP: 0.5047, APP: 0.5034'],\n", " ['peptide_4',\n", " 'FLGALWNVAKSVF',\n", " 'ADP: 0.5684, BBP: 0.5619, AHP: 0.5381, AAP: 0.5319, AIP: 0.5189, ACP: 0.5124, QSP: 0.5104, SBP: 0.5059, DPPIP: 0.5012'],\n", " ['peptide_5',\n", " 'KIKSCYYLPCFVTS',\n", " 'ADP: 0.5862, BBP: 0.5636, ACP: 0.5271, AHP: 0.5266, AIP: 0.5244, AAP: 0.5149, DPPIP: 0.5111, QSP: 0.5093, APP: 0.5074']]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "out = MFPs_classifier(file_path,threshold=0.5,top_k=0)\n", "out" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(0.4508)" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = -0.1974\n", "sigmoid(torch.tensor(x))" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[('ADP', 0.5739), ('BBP', 0.5358), ('AAP', 0.5204), ('AIP', 0.5153), ('DPPIP', 0.5056), ('ABP', 0.5001)] 6\n", "{'ADP': 0.5739, 'BBP': 0.5358, 'AAP': 0.5204}\n" ] } ], "source": [ "original_dict = {'ADP': 0.5739, 'BBP': 0.5358, 'AAP': 0.5204, 'AIP': 0.5153, 'DPPIP': 0.5056, 'ABP': 0.5001}\n", "n = 3 # number of key-value pairs to keep\n", 
"\n", "sliced_items = list(original_dict.items())\n", "print(sliced_items,len(sliced_items))\n", "sliced_dict = dict(sliced_items[:n])\n", "print(sliced_dict)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "env3.8", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.0" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }