Upload the lib
Browse files- PD_pLMProbXDiff/DataSetPack.py +0 -0
- PD_pLMProbXDiff/ModelPack.py +0 -0
- PD_pLMProbXDiff/PostMDPack.py +375 -0
- PD_pLMProbXDiff/TrainerPack.py +0 -0
- PD_pLMProbXDiff/UtilityPack.py +671 -0
PD_pLMProbXDiff/DataSetPack.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
PD_pLMProbXDiff/ModelPack.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
PD_pLMProbXDiff/PostMDPack.py
ADDED
|
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
import seaborn as sns
|
| 11 |
+
|
| 12 |
+
import linecache
|
| 13 |
+
import re
|
| 14 |
+
|
| 15 |
+
from Bio.PDB import PDBParser, PDBIO
|
| 16 |
+
import math
|
| 17 |
+
|
| 18 |
+
from Bio.PDB import PDBIO
|
| 19 |
+
from Bio.PDB import PDBParser
|
| 20 |
+
from Bio.PDB import Superimposer
|
| 21 |
+
from Bio.PDB.vectors import calc_angle, calc_dihedral
|
| 22 |
+
import Bio.PDB.vectors
|
| 23 |
+
#
|
| 24 |
+
from Bio.PDB.DSSP import DSSP # add try a self-made one
|
| 25 |
+
# from Bio.PDB.DSSP_SelfMade import DSSP_SelfMade # add try a self-made one
|
| 26 |
+
|
| 27 |
+
# Three-letter residue name -> one-letter amino-acid code.
# The names follow the force-field file, so the histidine variants
# HSD / HSE / HSP are accepted next to the standard HIS and all map to "H".
# (An earlier 20-entry table that was immediately overwritten by this one
# has been dropped; this is the only definition that ever took effect.)
resdict = {
    "ALA": "A",
    "ARG": "R",
    "ASN": "N",
    "ASP": "D",
    "CYS": "C",
    "GLN": "Q",
    "GLU": "E",
    "GLY": "G",
    "HIS": "H",
    "HSD": "H",
    "HSE": "H",
    "HSP": "H",
    "ILE": "I",
    "LYS": "K",
    "LEU": "L",
    "MET": "M",
    "PHE": "F",
    "PRO": "P",
    "SER": "S",
    "THR": "T",
    "TRP": "W",
    "TYR": "Y",
    "VAL": "V",
}
|
| 77 |
+
#
|
| 78 |
+
# SMD setup
|
| 79 |
+
SMD_Vel = 0.0001 # A/timestep
|
| 80 |
+
|
| 81 |
+
# step_data * SMD_Vel = pulling_dist
|
| 82 |
+
|
| 83 |
+
def collect_geo_of_backbone(chain):
    """Collect the sequence and backbone internal coordinates of one chain.

    Residues whose name is not a key of ``resdict`` are skipped entirely.
    The first accepted residue only contributes its one-letter code; every
    later accepted residue additionally contributes one entry to each
    geometry list, measured against the previously accepted residue.

    Args:
        chain: iterable of residue objects that provide ``get_resname()`` and
            item access to the "N", "CA" and "C" atoms (e.g. a Bio.PDB chain).

    Returns:
        dict with the keys
            "AA": one-letter codes, one per accepted residue;
            "Bond_CA_N", "Bond_CA_C", "Bond_N_C1": result of subtracting the
                two atom objects (Bio.PDB defines this as their distance);
            "Angl_CA1_C1_N", "Angl_C1_N_CA", "Angl_N_CA_C": bond angles,
                converted to degrees;
            "Dihe_PHI", "Dihe_PSI", "Dihe_OME": values exactly as returned by
                ``calc_dihedral``. NOTE: no degree conversion is applied to
                these (the earlier inline "# degree" remarks were wrong).
        Every geometry list is one element shorter than "AA".

    Missing backbone atoms are not checked for; the lookup error raised by
    ``res[...]`` propagates to the caller.
    """
    deg_per_rad = 180.0 / math.pi
    resu = {
        "AA": [],
        "Bond_CA_N": [], "Bond_CA_C": [], "Bond_N_C1": [],
        "Angl_CA1_C1_N": [], "Angl_C1_N_CA": [], "Angl_N_CA_C": [],
        "Dihe_PHI": [], "Dihe_PSI": [], "Dihe_OME": [],
    }

    # (N, CA, C) atoms of the previously accepted residue; None until the
    # first accepted residue has been seen.
    prev_atoms = None

    for res in chain:
        resname = res.get_resname()
        if resname not in resdict:
            continue
        resu["AA"].append(resdict[resname])

        N_curr, CA_curr, C_curr = res["N"], res["CA"], res["C"]

        if prev_atoms is not None:
            C_prev = prev_atoms[2]
            n1, ca1, c1 = (atom.get_vector() for atom in prev_atoms)
            n, ca, c = (atom.get_vector() for atom in (N_curr, CA_curr, C_curr))

            # layout along the backbone: n1-ca1-c1--n-ca-c
            resu["Bond_CA_N"].append(CA_curr - N_curr)
            resu["Bond_CA_C"].append(CA_curr - C_curr)
            resu["Bond_N_C1"].append(N_curr - C_prev)  # peptide bond

            resu["Angl_CA1_C1_N"].append(calc_angle(ca1, c1, n) * deg_per_rad)
            resu["Angl_C1_N_CA"].append(calc_angle(c1, n, ca) * deg_per_rad)
            resu["Angl_N_CA_C"].append(calc_angle(n, ca, c) * deg_per_rad)

            # phi belongs to the current residue, psi to the previous one,
            # omega to the peptide bond between them.
            resu["Dihe_PHI"].append(calc_dihedral(c1, n, ca, c))
            resu["Dihe_PSI"].append(calc_dihedral(n1, ca1, c1, n))
            resu["Dihe_OME"].append(calc_dihedral(ca1, c1, n, ca))

        prev_atoms = (N_curr, CA_curr, C_curr)

    return resu
|
| 157 |
+
#
|
| 158 |
+
def collect_multi_chain_AA_info(pdb_file):
    """Return the chain ids and per-chain one-letter sequences of a PDB file.

    Args:
        pdb_file: path handed to ``PDBParser.get_structure``.

    Returns:
        ``{"Chain": [chain ids in iteration order],
           "AA": {chain id: list of one-letter codes}}``.
        Only the "AA" part of ``collect_geo_of_backbone`` is kept; more of
        its output could be added here later.
    """
    structure = PDBParser().get_structure("sample", pdb_file)

    chain_ids = []
    aa_by_chain = {}
    for chain in structure.get_chains():
        cid = chain.get_id()
        chain_ids.append(cid)
        aa_by_chain[cid] = collect_geo_of_backbone(chain)["AA"]

    return {"Chain": chain_ids, "AA": aa_by_chain}
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
# read one record
|
| 176 |
+
|
| 177 |
+
# plot one record ONLY in the non-empty cases
|
| 178 |
+
#
|
| 179 |
+
def get_one_force_record(ii, resu_file_name_list):
    """Load the displacement/force/pulling curves of one SMD run.

    Args:
        ii: index of the record inside ``resu_file_name_list``.
        resu_file_name_list: table-like object whose 'PDB_ID' and 'Path'
            entries can be indexed with ``ii``.

    Returns:
        ``(disp_data, force_data, pdb_id, pull_data)`` where the first two
        are columns 1 and 7 of ``smd_resu.dat`` and ``pull_data`` is column 0
        (the step) multiplied by the velocity read from the third field of
        line 4 of ``box_dimension_after_eq.dat``.
    """
    pdb_id = resu_file_name_list['PDB_ID'][ii]
    run_dir = resu_file_name_list['Path'][ii]

    table = np.genfromtxt(run_dir + '/1_working_dir/collect_results/smd_resu.dat')
    disp_data = table[:, 1]
    force_data = table[:, 7]
    steps = table[:, 0]

    # the pulling velocity of this run is stored next to the box dimensions
    settings_file = run_dir + '/1_working_dir/box_dimension_after_eq.dat'
    velocity = float(linecache.getline(settings_file, 4).split()[2])

    return disp_data, force_data, pdb_id, velocity * steps
|
| 202 |
+
|
| 203 |
+
# collect AA from the record
|
| 204 |
+
def get_one_AA_record(ii, resu_file_name_list):
    """Return the one-letter sequence of the first chain of one record.

    The structure is read from
    ``<Path>/1_working_dir/TestProt_chain_0_after_psf.pdb``. Only the first
    chain is used: the tensile-test files are expected to hold a single chain.
    """
    pdb_file = resu_file_name_list['Path'][ii] + '/1_working_dir/TestProt_chain_0_after_psf.pdb'
    chain_info = collect_multi_chain_AA_info(pdb_file)
    first_chain = chain_info["Chain"][0]
    return ''.join(chain_info["AA"][first_chain])
|
| 215 |
+
|
| 216 |
+
# smooth functions
|
| 217 |
+
def conv_one_record(force_data, kernel_size):
    """Smooth ``force_data`` with a moving average of ``kernel_size`` points.

    The output has the same length as the input (``mode='same'``), so values
    near both ends are averaged against implicit zeros.
    """
    box_filter = np.ones(kernel_size) / kernel_size
    return np.convolve(force_data, box_filter, mode='same')
|
| 222 |
+
|
| 223 |
+
from math import factorial
|
| 224 |
+
|
| 225 |
+
from scipy.ndimage.filters import uniform_filter1d
|
| 226 |
+
#
|
| 227 |
+
# function to smooth the data
|
| 228 |
+
def savitzky_golay(y, window_size, order, deriv=0, rate=1):
    """Smooth (or differentiate) a 1-D signal with a Savitzky-Golay filter.

    Args:
        y: 1-D sequence of samples (converted with ``np.asarray``).
        window_size: length of the fitting window; must be a positive odd
            integer and at least ``order + 2``.
        order: degree of the fitted polynomial.
        deriv: order of the derivative to return (0 = smoothed signal).
        rate: scale factor applied as ``rate**deriv``.

    Returns:
        1-D array with the same length as ``y``.

    Raises:
        ValueError: if ``window_size`` or ``order`` cannot be converted to int.
        TypeError: if ``window_size`` is not a positive odd number or is too
            small for ``order`` (exception type kept for existing callers).
    """
    try:
        window_size = abs(int(window_size))
        order = abs(int(order))
    except ValueError as err:
        raise ValueError("window_size and order have to be of type int") from err

    if window_size % 2 != 1 or window_size < 1:
        raise TypeError("window_size size must be a positive odd number")
    if window_size < order + 2:
        raise TypeError("window_size is too small for the polynomials order")

    y = np.asarray(y)
    half_window = (window_size - 1) // 2

    # Design matrix of the local polynomial fit: row k holds k**0 .. k**order.
    # A plain ndarray is used because np.mat no longer exists in NumPy 2.
    b = np.array(
        [[k ** i for i in range(order + 1)]
         for k in range(-half_window, half_window + 1)]
    )
    # Row `deriv` of the pseudo-inverse holds the filter coefficients.
    m = np.linalg.pinv(b)[deriv] * rate ** deriv * factorial(deriv)

    # Pad both ends with values mirrored from the signal itself so that the
    # 'valid' convolution returns len(y) samples.
    firstvals = y[0] - np.abs(y[1:half_window + 1][::-1] - y[0])
    lastvals = y[-1] + np.abs(y[-half_window - 1:-1][::-1] - y[-1])
    padded = np.concatenate((firstvals, y, lastvals))

    return np.convolve(m[::-1], padded, mode='valid')
|
| 254 |
+
|
| 255 |
+
#
|
| 256 |
+
def read_gap_values_from_dat(file):
    """Read the initial and final gap from a whitespace-separated text file.

    The initial gap is the third field of line 2 and the final gap the third
    field of line 3 (line numbers are 1-based).

    Returns:
        ``(ini_gap, fin_gap)`` as floats.
    """
    def third_field(line_no):
        return float(linecache.getline(file, line_no).split()[2])

    ini_gap = third_field(2)
    fin_gap = third_field(3)
    return ini_gap, fin_gap
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def read_one_array_from_df(one_record):
    """Parse a string of single-space-separated numbers into a float array."""
    values = [float(token) for token in one_record.split(" ")]
    return np.array(values)
|
| 269 |
+
#
|
| 270 |
+
def read_string_find_max(reco):
    """Return the largest number in a space-separated numeric string.

    ``reco`` is parsed with ``read_one_array_from_df``. (This function used to
    be defined twice in a row with identical bodies; one definition is kept.)
    """
    return np.amax(read_one_array_from_df(reco))
|
| 277 |
+
#
|
| 278 |
+
def cal_seq_end_gap(x):
    """Return the end-to-end gap along a record, starting at ``x['ini_gap']``.

    Args:
        x: mapping-like record with 'posi_data' (array-like supporting
            element-wise subtraction) and 'ini_gap'
            (presumably a DataFrame row; confirm against the caller).

    Returns:
        ``ini_gap`` plus the change of 'posi_data' relative to its first entry.
    """
    positions = x['posi_data']
    offset_from_start = positions - positions[0]
    return x['ini_gap'] + offset_from_start
|
| 284 |
+
#
|
| 285 |
+
def cal_pull_end_gap(x):
    """Return the gap at the pulling point: ``x['ini_gap'] + x['pull_data']``.

    Unlike ``cal_seq_end_gap`` the first entry of 'pull_data' is NOT
    subtracted; the data are used as increments as they are.
    """
    return x['ini_gap'] + x['pull_data']
|
| 291 |
+
|
| 292 |
+
#
|
| 293 |
+
# pick the force at the unfolding of every residues
|
| 294 |
+
|
| 295 |
+
def simplify_NormPull_FORCEnF_rec(n_fold,this_seq_len,this_n_PullGap_arr,this_Force_arr):
    """Resample a force curve on a uniform grid of normalized pulling gaps.

    The grid has ``this_seq_len * n_fold + 1`` points from 0 to 1 inclusive.
    For every grid point the force of the nearest sample in
    ``this_n_PullGap_arr`` is taken; grid points that lie before the first
    recorded gap get a force of 0.

    Returns:
        dict with 'sample_NormPullGap' (the grid) and 'smaple_FORCE' (the
        picked forces); the spelling of the second key is part of the
        interface and is kept as is.
    """
    n_bins = this_seq_len * n_fold
    grid = [1. / n_bins * k for k in range(n_bins)]
    grid.append(1.)

    first_gap = this_n_PullGap_arr[0]
    picked = []
    for target in grid:
        if target < first_gap:
            # nothing recorded yet at this gap
            picked.append(0.)
            continue
        nearest = np.argmin(np.abs(this_n_PullGap_arr - target))
        picked.append(this_Force_arr[nearest])

    return {
        'sample_NormPullGap': np.array(grid),
        'smaple_FORCE': np.array(picked),
    }
|
| 324 |
+
|
| 325 |
+
# read input conditions
|
| 326 |
+
def read_input_model_A(file_path):
    """Load every ``[...]`` group of a text file as numbers.

    Each bracketed group (without nested brackets) is handed to
    ``np.loadtxt`` as one line, so several groups give a 2-D array and a
    single group a 1-D array.
    """
    with open(file_path, 'r') as handle:
        bracketed = re.findall(r'\[([^][]+)\]', handle.read())
    return np.loadtxt(bracketed)
|
| 335 |
+
|
| 336 |
+
def read_input_model_B(file_path):
    """Load the first ``[...]`` group of a text file as a 1-D float array.

    Line breaks inside the group are removed before parsing, so a group that
    was wrapped over several lines is read as one row. Later groups in the
    file are ignored.
    """
    with open(file_path, 'r') as handle:
        groups = re.findall(r'\[([^][]+)\]', handle.read())
    first_group = groups[0].replace('\n', '')
    return np.loadtxt([first_group])
|
| 346 |
+
|
| 347 |
+
def read_one_input_arr_from_txt(file_path):
|
| 348 |
+
with open(file_path, 'r') as f:
|
| 349 |
+
txt = f.read()
|
| 350 |
+
nums = re.findall(r'\[([^][]+)\]', txt)
|
| 351 |
+
# arr = np.loadtxt(nums)
|
| 352 |
+
arr = np.loadtxt( [nums[0].replace('\n','')] )
|
| 353 |
+
# print(arr)
|
| 354 |
+
# print(arr[0])
|
| 355 |
+
|
| 356 |
+
return arr
|
| 357 |
+
|
| 358 |
+
# this only for this version, in folder3 it is updated
|
| 359 |
+
# # for folder3
|
| 360 |
+
# def recover_input_for_model_B(file_path, seq_len):
|
| 361 |
+
# raw_arr = read_one_input_arr_from_txt(file_path)
|
| 362 |
+
# arr = raw_arr[1:1+seq_len+1]
|
| 363 |
+
# return arr
|
| 364 |
+
# for folder2
|
| 365 |
+
def recover_input_for_model_B_ver2(file_path, seq_len):
    """Return the first ``seq_len + 1`` values stored in ``file_path``.

    Variant used for the "folder2" layout, where the stored array is taken
    from its first entry on.
    """
    return read_one_input_arr_from_txt(file_path)[:seq_len + 1]
|
| 369 |
+
|
| 370 |
+
# for folder3
|
| 371 |
+
def recover_input_for_model_B_ver3(file_path, seq_len):
    """Return ``seq_len + 1`` values: a leading 0 followed by the stored data.

    Variant used for the "folder3" layout: the first ``seq_len`` stored
    values are shifted by one position and slot 0 is left at zero.
    """
    stored = read_one_input_arr_from_txt(file_path)
    padded = np.zeros(seq_len + 1)
    padded[1:seq_len + 1] = stored[:seq_len]
    return padded
|
PD_pLMProbXDiff/TrainerPack.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
PD_pLMProbXDiff/UtilityPack.py
ADDED
|
@@ -0,0 +1,671 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ==========================================================
|
| 2 |
+
# Utility functions
|
| 3 |
+
# ==========================================================
|
| 4 |
+
import os
|
| 5 |
+
from scipy.interpolate import CubicSpline, PchipInterpolator, Akima1DInterpolator
|
| 6 |
+
import numpy as np
|
| 7 |
+
import math
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
|
| 10 |
+
from Bio.PDB import PDBParser
|
| 11 |
+
from Bio.PDB.DSSP import DSSP
|
| 12 |
+
from Bio.PDB import PDBList
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
from einops import rearrange
|
| 16 |
+
import esm
|
| 17 |
+
# =========================================================
|
| 18 |
+
# create a folder path if not exist
|
| 19 |
+
def create_path(this_path):
    """Create the directory ``this_path`` if it does not exist yet.

    Missing parent directories are created as well.

    Returns:
        1 if the path was missing and has been created by this call,
        2 if the path already existed (nothing is changed in that case).
    """
    if os.path.exists(this_path):
        print('The given path already exists!')
        return 2

    print('Creating the given path...')
    # makedirs instead of mkdir so nested paths work; exist_ok guards against
    # another process creating the directory between the check and this call.
    os.makedirs(this_path, exist_ok=True)
    print('Done.')
    return 1
|
| 29 |
+
|
| 30 |
+
# ==========================================================
|
| 31 |
+
|
| 32 |
+
# measure the model size
|
| 33 |
+
def params(model):
    """Print the total and the trainable parameter count of ``model``.

    ``model`` must provide ``parameters()`` yielding objects with ``numel()``
    and ``requires_grad``. Nothing is returned.
    """
    total = 0
    trainable = 0
    for tensor in model.parameters():
        count = tensor.numel()
        total += count
        if tensor.requires_grad:
            trainable += count

    print ("Total parameters: ", total," trainable parameters: ", trainable)
|
| 38 |
+
|
| 39 |
+
# ==========================================================
|
| 40 |
+
# initialization function for dict for models
|
| 41 |
+
def prepare_UNet_keys(write_dict):
    """Build the full UNet keyword dictionary from a partial one.

    Every known key starts as ``None`` and is overwritten by the value in
    ``write_dict`` when present. Keys of ``write_dict`` that are not known
    are reported on stdout and left out of the result.
    """
    known_keys = ['dim', 'text_embed_dim', 'num_resnet_blocks', 'cond_dim', 'num_image_tokens', 'num_time_tokens', 'learned_sinu_pos_emb_dim', 'out_dim', 'dim_mults', 'cond_images_channels', 'channels', 'channels_out', 'attn_dim_head', 'attn_heads', 'ff_mult', 'lowres_cond', 'layer_attns', 'layer_attns_depth', 'layer_attns_add_text_cond', 'attend_at_middle', 'layer_cross_attns', 'use_linear_attn', 'use_linear_cross_attn', 'cond_on_text', 'max_text_len', 'init_dim', 'resnet_groups', 'init_conv_kernel_size', 'init_cross_embed', 'init_cross_embed_kernel_sizes', 'cross_embed_downsample', 'cross_embed_downsample_kernel_sizes', 'attn_pool_text', 'attn_pool_num_latents', 'dropout', 'memory_efficient', 'init_conv_to_final_conv_residual', 'use_global_context_attn', 'scale_skip_connection', 'final_resnet_block', 'final_conv_kernel_size', 'cosine_sim_attn', 'self_cond', 'combine_upsample_fmaps', 'pixel_shuffle_upsample', 'beginning_and_final_conv_present']

    # defaults: everything unset
    settings = dict.fromkeys(known_keys)

    for name, value in write_dict.items():
        if name not in settings:
            print("Wrong key found: ", name)
            continue
        settings[name] = value

    return settings
|
| 56 |
+
|
| 57 |
+
def prepare_ModelB_keys(write_dict):
    """Build the full model-B keyword dictionary from a partial one.

    Every known key starts as ``None`` and is overwritten by the value in
    ``write_dict`` when present. Keys of ``write_dict`` that are not known
    are reported on stdout and left out of the result.
    """
    known_keys = ['timesteps', 'dim', 'pred_dim', 'loss_type', 'elucidated', 'padding_idx', 'cond_dim', 'text_embed_dim', 'input_tokens', 'sequence_embed', 'embed_dim_position', 'max_text_len', 'cond_images_channels', 'max_length', 'device']

    # defaults: everything unset
    settings = dict.fromkeys(known_keys)

    for name, value in write_dict.items():
        if name not in settings:
            print("Wrong key found: ", name)
            continue
        settings[name] = value

    return settings
|
| 71 |
+
|
| 72 |
+
def modify_keys(old_dict,write_dict):
    """Return a copy of ``old_dict`` updated with the known keys of ``write_dict``.

    Keys of ``write_dict`` that do not occur in ``old_dict`` are reported on
    stdout and ignored. ``old_dict`` itself is left untouched (shallow copy).
    """
    updated = old_dict.copy()
    for name, value in write_dict.items():
        if name not in old_dict:
            print("Alien key found: ", name)
            continue
        updated[name] = value
    return updated
|
| 80 |
+
|
| 81 |
+
# ==========================================================
|
| 82 |
+
# mix two NForce record for a given AA length
|
| 83 |
+
# ==========================================================
|
| 84 |
+
def mixing_two_FORCE_for_AA_Len(NGap1,Force1,NGap2,Force2,LenAA,mix_fac):
    """Blend two force curves and sample the blend at ``LenAA + 1`` points.

    Both curves are interpolated with PCHIP, evaluated on a common grid over
    [0, 1] with twice as many points as the longer input, and mixed as
    ``mix_fac * curve1 + (1 - mix_fac) * curve2``. The mix is interpolated
    again with PCHIP.

    Returns:
        ``(fun_PI, x1, y1)``: the interpolator of the mixed curve, the
        uniform grid ``linspace(0, 1, LenAA + 1)`` and the mixed curve
        evaluated on that grid.
    """
    n_dense = math.ceil(max(len(NGap1), len(NGap2)) * 2)
    dense_grid = np.linspace(0, 1, n_dense)

    curve_1 = PchipInterpolator(NGap1, Force1)
    curve_2 = PchipInterpolator(NGap2, Force2)
    blended = curve_1(dense_grid) * mix_fac + curve_2(dense_grid) * (1 - mix_fac)
    fun_PI = PchipInterpolator(dense_grid, blended)

    # discrete result, one point per residue boundary
    x1 = np.linspace(0, 1, LenAA + 1)
    return fun_PI, x1, fun_PI(x1)
|
| 96 |
+
|
| 97 |
+
# =========================================================
|
| 98 |
+
#
|
| 99 |
+
# =========================================================
|
| 100 |
+
def _plot_ss_error_curve(error):
    """Show a line plot of the absolute error per secondary-structure type."""
    fig, ax = plt.subplots(1, 1, figsize=(6,3))
    plt.plot (error, 'o-', label='Error over SS type')
    plt.legend()
    plt.ylabel ('SS content')
    plt.show()


def _plot_ss_bars(sslabels, target, measured):
    """Show target (blue) and measured (red) content side by side per type."""
    x = np.linspace (0, len(sslabels) - 1, len(sslabels))
    fig, ax = plt.subplots(1, 1, figsize=(6,3))
    ax.bar(x-0.15, target, width=0.3, color='b', align='center')
    ax.bar(x+0.15, measured, width=0.3, color='r', align='center')
    ax.set_ylim([0, 1])
    plt.xticks(range(len(sslabels)), sslabels, size='medium')
    plt.legend (['GT','Prediction'])
    plt.ylabel ('SS content')
    plt.show()


def get_Model_A_error (fname, cond, plotit=True, ploterror=False):
    """Compare the DSSP secondary-structure content of a PDB file with a target.

    Args:
        fname: path of the PDB file, passed on to ``get_DSSP_result``.
        cond: eight target fractions ordered (H, E, T, ~, B, G, I, S).
        plotit: show the target-vs-measured bar charts (8-state and 3-state).
            This flag used to be ignored; the default keeps the old behavior.
        ploterror: additionally show line plots of the absolute errors.

    Returns:
        Array with the eight absolute errors of the 8-state comparison. The
        3-state errors are only printed and plotted.

    Note:
        A structure for which DSSP yields no residues leads to a division by
        zero, as before.
    """
    sec_structure, sec_structure_3state, sequence = get_DSSP_result (fname)
    length = len (sec_structure)

    # ---------------------------------------------------------- 8 types
    sslabels = ['H','E','T','~','B','G','I','S']
    sscount = np.asarray ([sec_structure.count(code)/length for code in sslabels])

    error = np.abs(sscount-cond)
    print ("Abs error per SS structure type (H, E, T, ~, B, G, I S): ", error)

    if ploterror:
        _plot_ss_error_curve(error)
    if plotit:
        _plot_ss_bars(sslabels, cond, sscount)

    # ---------------------------------------------------------- 3 types
    # 8-state targets are merged the same way get_DSSP_result reduces the
    # codes: helix = H+G+I, strand = E+B, other = T+~+S.
    sslabels = ['H','E','~']
    cond_p = [np.sum([cond[0],cond[5], cond[6]]), np.sum ([cond[1], cond[4]]), np.sum([cond[2],cond[3],cond[7]]) ]
    print ("cond 3type: ",cond_p)

    sscount = np.asarray ([sec_structure_3state.count(code)/length for code in sslabels])

    error3 = np.abs(sscount-cond_p)
    # Fixed: this line used to print the 8-state `error` under a 3-state
    # label whose order (C, H, E) did not match the values either.
    print ("Abs error per 3-type SS structure type (H, E, ~): ", error3)

    if ploterror:
        _plot_ss_error_curve(error3)
    if plotit:
        _plot_ss_bars(sslabels, cond_p, sscount)

    return error
|
| 186 |
+
|
| 187 |
+
def get_DSSP_result (fname):
    """Run DSSP on the first model of a PDB file.

    Args:
        fname: path of the PDB file (also used as the structure id).

    Returns:
        ``(sec_structure, sec_structure_3state, sequence)``, three strings
        with one character per DSSP record:
            sec_structure: 8-state codes with '-' rewritten as '~'
                H  Alpha helix (4-12)
                B  Isolated beta-bridge residue
                E  Strand
                G  3-10 helix
                I  Pi helix
                T  Turn
                S  Bend
                ~  None
            sec_structure_3state: the same string reduced to three states,
                G and I -> H, B -> E, T and S -> ~ (H, E and ~ unchanged;
                any other character is passed through untouched).
            sequence: element 1 of each DSSP record (the residue letter).
    """
    # parse structure and use only the first model
    structure = PDBParser().get_structure(fname, fname)
    model = structure[0]

    # calculate DSSP
    dssp = DSSP(model, fname, file_type='PDB' )

    # Walk the records once; the former index loop rebuilt the full key list
    # on every iteration.
    records = [dssp[a_key] for a_key in dssp.keys()]
    sequence = ''.join(rec[1] for rec in records)
    sec_structure = ''.join(rec[2] for rec in records).replace('-', '~')

    # 8-state -> 3-state; only the codes that actually change are listed
    to_3state = str.maketrans({'T': '~', 'B': 'E', 'G': 'H', 'I': 'H', 'S': '~'})
    sec_structure_3state = sec_structure.translate(to_3state)

    return sec_structure,sec_structure_3state, sequence
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def string_diff (seq1, seq2):
    """Count the positions at which two sequences differ.

    Positions are compared pairwise over the common length; every surplus
    element of the longer sequence counts as one more difference.
    """
    mismatches = 0
    for left, right in zip(seq1, seq2):
        if left != right:
            mismatches += 1
    return mismatches + abs(len(seq1) - len(seq2))
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
# ============================================================
|
| 246 |
+
# on esm, rebuild AA sequence from embedding
|
| 247 |
+
# ============================================================
|
| 248 |
+
import esm
|
| 249 |
+
|
| 250 |
+
def decode_one_ems_token_rec(this_token, esm_alphabet):
    """Decode one token vector into a string, leaving out the boundary markers.

    The decoded span starts right after the last occurrence of
    ``esm_alphabet.cls_idx`` (or after position 0 when it never occurs) and
    stops before the first occurrence of ``esm_alphabet.eos_idx`` (or at the
    end of the vector when it never occurs).
    """
    cls_hits = (this_token == esm_alphabet.cls_idx).nonzero(as_tuple=True)[0]
    eos_hits = (this_token == esm_alphabet.eos_idx).nonzero(as_tuple=True)[0]

    # no end marker: decode up to the end of the vector
    stop = eos_hits[0] if len(eos_hits) != 0 else len(this_token)
    # no begin marker: position 0 is still treated as the begin slot
    start = cls_hits[-1] if len(cls_hits) != 0 else 0

    return "".join(
        esm_alphabet.get_tok(this_token[pos])
        for pos in range(start + 1, stop)
    )
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def decode_many_ems_token_rec(batch_tokens, esm_alphabet):
    """Decode every token vector of a batch and return the strings as a list."""
    return [
        decode_one_ems_token_rec(batch_tokens[row], esm_alphabet)
        for row in range(len(batch_tokens))
    ]
|
| 294 |
+
|
| 295 |
+
# ++ for omegafold sequence: treat unknowns as X
|
| 296 |
+
# Token indices whose logits the *_for_folding decoders in this file set to
# -inf before taking the per-position maximum, so these tokens are never
# selected for a sequence that is handed to a folding tool.
# NOTE(review): presumably these are the ESM alphabet indices of the special
# and non-standard residue tokens - confirm against esm_alphabet.all_toks.
uncomm_idx_list = [0, 1, 2, 3, 24, 25, 26, 27, 28, 29, 30, 31, 32]
|
| 297 |
+
|
| 298 |
+
# this one decides the beginning and ending AUTOMATICALLY
|
| 299 |
+
def decode_one_ems_token_rec_for_folding(
    this_token,
    this_logits,
    esm_alphabet,
    esm_model):
    """Decode one sequence from logits, restricted to foldable residues.

    ``this_token`` is used only to locate the end of the sequence. The
    residues themselves are re-selected from ``this_logits`` after every
    index in the module-level ``uncomm_idx_list`` has been excluded.

    Args:
        this_token: 1-D token tensor of one sequence.
        this_logits: per-position logits of the same sequence
            (assumes shape (seq_len, n_tokens) - TODO confirm). It is NOT
            modified by this function.
        esm_alphabet: object providing ``eos_idx`` and ``get_tok``.
        esm_model: unused; kept so existing callers keep working.

    Returns:
        The decoded sequence as a string.
    """
    eos_hits = (this_token == esm_alphabet.eos_idx).nonzero(as_tuple=True)[0]

    # Principle:
    #   1. always begin at position 0
    #   2. end as soon as possible, relying on the first end marker being
    #      learned correctly
    id_b = 0
    id_e = eos_hits[0] if len(eos_hits) > 0 else len(this_token)
    # An end marker at position 0 or 1 would leave nothing to decode:
    # fall back to the second end marker, or to the full length.
    if id_e <= id_b + 1:
        id_e = eos_hits[1] if len(eos_hits) > 1 else len(this_token)

    # for debug:
    print("start at: ", id_b)
    print("end at: ", id_e)

    # Keep only positions [id_b+1:id_e], i.e. drop the begin slot and the end
    # marker. The slice is a view of the caller's tensor, so clone it before
    # masking in place; otherwise the caller's logits would be overwritten
    # with -inf.
    use_logits = this_logits[id_b+1:id_e].clone()
    use_logits[:, uncomm_idx_list] = -float('inf')
    use_token = use_logits.max(1).indices

    return "".join(esm_alphabet.get_tok(tok) for tok in use_token)
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
def decode_many_ems_token_rec_for_folding(
    batch_tokens,
    batch_logits,
    esm_alphabet,
    esm_model):
    """Run the single-sequence folding decoder over a batch.

    Entry ``k`` of the returned list is the string decoded from
    ``batch_tokens[k]`` and ``batch_logits[k]``.
    """
    return [
        decode_one_ems_token_rec_for_folding(
            batch_tokens[row],
            batch_logits[row],
            esm_alphabet,
            esm_model,
        )
        for row in range(len(batch_tokens))
    ]
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
def convert_into_logits(esm_model, result):
    """Swap the last two axes of ``result`` and apply the model's LM head.

    The head is evaluated under ``torch.no_grad()``.

    NOTE(review): the pattern labels the axes 'b l c -> b c l', while
    convert_into_tokens labels the same swap 'b c l -> b l c' - confirm which
    layout callers actually pass in.
    """
    swapped = rearrange(result, 'b l c -> b c l')
    with torch.no_grad():
        return esm_model.lm_head(swapped)
|
| 441 |
+
|
| 442 |
+
# this one returns the unmodified tokens and logits
|
| 443 |
+
def convert_into_tokens(model, result, pLM_Model_Name):
    """Turn a representation into logits and arg-max tokens via the LM head.

    Args:
        model: language model exposing ``lm_head``.
        result: representation tensor, rearranged with 'b c l -> b l c'
            before the head is applied.
        pLM_Model_Name: name of the language model; must be one of the
            supported ESM-2 variants listed below.

    Returns:
        ``(tokens, logits)`` where ``logits`` is the head output and
        ``tokens`` is its arg-max over the last axis.

    Raises:
        ValueError: if ``pLM_Model_Name`` is not supported. The original code
            only printed a message here and then failed with
            UnboundLocalError at the return statement.
    """
    supported_names = (
        'esm2_t33_650M_UR50D',
        'esm2_t36_3B_UR50D',
        'esm2_t30_150M_UR50D',
        'esm2_t12_35M_UR50D',
    )
    if pLM_Model_Name not in supported_names:
        raise ValueError(
            "pLM_Model is not defined...: {!r}".format(pLM_Model_Name)
        )

    repre = rearrange(result, 'b c l -> b l c')
    with torch.no_grad():
        logits = model.lm_head(repre)  # (b, l, token_dim)

    tokens = logits.max(2).indices  # (b, l)
    return tokens, logits
|
| 461 |
+
# ++
|
| 462 |
+
def convert_into_tokens_using_prob(prob_result, pLM_Model_Name):
    """Pick arg-max tokens directly from per-token scores (no LM head).

    Args:
        prob_result: score tensor, rearranged with 'b c l -> b l c' and then
            used as the logits as-is.
        pLM_Model_Name: name of the language model; must be one of the
            supported ESM-2 variants listed below.

    Returns:
        ``(tokens, logits)`` where ``logits`` is the rearranged input and
        ``tokens`` is its arg-max over the last axis.

    Raises:
        ValueError: if ``pLM_Model_Name`` is not supported. The original code
            only printed a message here and then failed with
            UnboundLocalError at the return statement.
    """
    supported_names = (
        'esm2_t33_650M_UR50D',
        'esm2_t36_3B_UR50D',
        'esm2_t30_150M_UR50D',
        'esm2_t12_35M_UR50D',
    )
    if pLM_Model_Name not in supported_names:
        raise ValueError(
            "pLM_Model is not defined...: {!r}".format(pLM_Model_Name)
        )

    logits = rearrange(prob_result, 'b c l -> b l c')
    tokens = logits.max(2).indices  # (b, l)
    return tokens, logits
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
#
|
| 484 |
+
def read_mask_from_input(
    # consider different type of inputs
    # raw data: x_data (sequences)
    # tokenized: x_data_tokenized
    tokenized_data=None,  # X_train_batch,
    mask_value=None,
    seq_data=None,
    max_seq_length=None,
    ):
    """Build a boolean mask that marks the content positions of each row.

    Exactly one kind of input is used; ``seq_data`` takes precedence.

    * ``seq_data`` (with ``max_seq_length``): row ``i`` is True on positions
      ``1 .. len(seq_data[i])``; position 0 is left False.
    * ``tokenized_data`` (with ``mask_value``): a position is True where the
      entry differs from ``mask_value``. In addition, every position from 1
      up to the first such entry is switched on, so that padding-valued
      entries at the beginning count as content ("0+content+00" rather than
      "00+content+00").

    Returns:
        A boolean torch tensor.

    Raises:
        ValueError: if neither ``seq_data`` nor ``tokenized_data`` is given
            (the original code failed with UnboundLocalError instead).
    """
    # `is not None` instead of `!= None`: the latter is an element-wise
    # comparison on array-like inputs.
    if seq_data is not None:
        # use the real sequence length to create the mask
        n_seq = len(seq_data)
        mask = torch.zeros(n_seq, max_seq_length, dtype=torch.bool)
        for ii in range(n_seq):
            mask[ii, 1:1 + len(seq_data[ii])] = True
        return mask

    if tokenized_data is not None:
        mask = tokenized_data != mask_value
        for ii in range(len(tokenized_data)):
            content_idx = mask[ii].nonzero(as_tuple=True)[0]
            if len(content_idx) == 0:
                # the row holds only mask_value: nothing to extend
                continue
            # correction for ForcPath:
            # pick up 0.0 for zero-force padding at the beginning
            mask[ii, 1:int(content_idx[0])] = True
        return mask

    raise ValueError("either seq_data or tokenized_data must be provided")
|
| 517 |
+
|
| 518 |
+
# ++ read one length
|
| 519 |
+
def read_one_len_from_padding_vec(
    in_np_array,
    padding_val=0.0,
    ):
    """Return the effective length of a padded vector.

    The length is the index of the last entry that differs from
    ``padding_val``, plus one; padding-valued entries before that position
    are counted as content.

    Returns 0 when every entry equals ``padding_val`` (the original code
    raised IndexError in that case).
    """
    content_idx = (in_np_array != padding_val).nonzero()[0]
    if len(content_idx) == 0:
        return 0
    return content_idx[-1] + 1
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
# this one decides the beginning and ending using a mask
|
| 531 |
+
def decode_one_ems_token_rec_for_folding_with_mask(
    this_token,
    this_logits,
    esm_alphabet,
    esm_model,
    this_mask,
    ):
    """Decode one sequence from logits, keeping only the unmasked positions.

    All positions are first translated into tokens, with every index in the
    module-level ``uncomm_idx_list`` excluded, and the result is then
    filtered by ``this_mask``.

    Args:
        this_token: unused; kept so existing callers keep working.
        this_logits: per-position logits of one sequence
            (assumes shape (seq_len, n_tokens) - TODO confirm). It is NOT
            modified by this function.
        esm_alphabet: object providing ``get_tok``.
        esm_model: unused; kept so existing callers keep working.
        this_mask: per-position mask; positions equal to True are kept.

    Returns:
        The decoded sequence as a string.
    """
    # Work on a copy: masking the argument in place would overwrite the
    # caller's logits with -inf.
    use_logits = this_logits.clone()
    use_logits[:, uncomm_idx_list] = -float('inf')
    use_token = use_logits.max(1).indices
    #
    print(use_token)
    # `== True` is kept on purpose so that non-boolean masks behave as before
    use_token = use_token[this_mask == True]

    return "".join(esm_alphabet.get_tok(tok) for tok in use_token)
|
| 564 |
+
|
| 565 |
+
def decode_many_ems_token_rec_for_folding_with_mask(
    batch_tokens,
    batch_logits,
    esm_alphabet,
    esm_model,
    mask):
    """Run the mask-based single-sequence folding decoder over a batch.

    Entry ``k`` of the returned list is the string decoded from
    ``batch_tokens[k]``, ``batch_logits[k]`` and ``mask[k]``.
    """
    return [
        decode_one_ems_token_rec_for_folding_with_mask(
            batch_tokens[row],
            batch_logits[row],
            esm_alphabet,
            esm_model,
            mask[row],
        )
        for row in range(len(batch_tokens))
    ]
|
| 584 |
+
|
| 585 |
+
# =====================================================
|
| 586 |
+
# create new input condition for ForcPath case
|
| 587 |
+
# =====================================================
|
| 588 |
+
from scipy import interpolate
|
| 589 |
+
|
| 590 |
+
def interpolate_and_resample_ForcPath(y0,seq_len1):
    """Resample a curve onto ``seq_len1 + 1`` evenly spaced points of [0, 1].

    ``y0`` is treated as samples on an even grid over [0, 1]; it is
    interpolated with ``scipy.interpolate.interp1d`` (default settings) and
    evaluated on the new grid. ``y0`` needs at least two samples for the
    interpolation to be defined.

    Returns:
        dict with keys ``'y1'`` (resampled values), ``'x1'`` (new grid) and
        ``'x0'`` (grid of the input samples).
    """
    seq_len0 = len(y0) - 1
    # linspace gives exactly n+1 nodes and ends exactly on 1.0. The previous
    # float-step np.arange could produce one node too many or overshoot 1.0
    # through rounding, which makes interp1d raise (length mismatch or
    # out-of-bounds evaluation).
    x0 = np.linspace(0., 1., seq_len0 + 1)
    f = interpolate.interp1d(x0, y0)
    #
    x1 = np.linspace(0., 1., seq_len1 + 1)
    y1 = f(x1)
    #
    return {'y1': y1, 'x1': x1, 'x0': x0}
|
| 603 |
+
#
|
| 604 |
+
def mix_two_ForcPath(y0,y1,seq_len2):
    """Add two curves after resampling both onto a common grid over [0, 1].

    ``y0`` and ``y1`` are each treated as samples on their own even grid over
    [0, 1], interpolated with ``scipy.interpolate.interp1d`` (default
    settings), evaluated on ``seq_len2 + 1`` evenly spaced points and summed.
    Each input needs at least two samples for the interpolation to be
    defined.

    Returns:
        dict with keys ``'y2'`` (summed values), ``'x2'`` (common grid),
        ``'x1'`` and ``'x0'`` (grids of the two inputs).
    """
    # linspace gives exactly n+1 nodes and ends exactly on 1.0. The previous
    # float-step np.arange could produce one node too many or overshoot 1.0
    # through rounding, which makes interp1d raise.
    seq_len0 = len(y0) - 1
    x0 = np.linspace(0., 1., seq_len0 + 1)
    seq_len1 = len(y1) - 1
    x1 = np.linspace(0., 1., seq_len1 + 1)
    f0 = interpolate.interp1d(x0, y0)
    f1 = interpolate.interp1d(x1, y1)
    #
    x2 = np.linspace(0., 1., seq_len2 + 1)
    # NOTE(review): the divisor is 1., so this is a plain sum rather than an
    # average - confirm that this is intended.
    y2 = (f0(x2) + f1(x2)) / 1.
    #
    return {'y2': y2, 'x2': x2, 'x1': x1, 'x0': x0}
|
| 621 |
+
#
|
| 622 |
+
# =====================================================
|
| 623 |
+
# load in function for language model
|
| 624 |
+
# =====================================================
|
| 625 |
+
import esm
|
| 626 |
+
|
| 627 |
+
def load_in_pLM(pLM_Model_Name,device):
    """Load a protein language model by name.

    Args:
        pLM_Model_Name: ``'trivial'`` (no language model) or the name of one
            of the ESM-2 loaders in ``esm.pretrained`` listed below.
        device: device the loaded model is moved to.

    Returns:
        ``(pLM_Model, esm_alphabet, esm_layer, len_toks)``. For ``'trivial'``
        this is ``(None, None, 0, 0)``. Otherwise the model is put into eval
        mode and moved to ``device``; ``esm_layer`` is the layer count taken
        from the table below and ``len_toks`` is
        ``len(esm_alphabet.all_toks)``.

    Raises:
        ValueError: if the name is unknown. The original code only printed a
            message here and then failed with UnboundLocalError at the
            return statement.
    """
    # ++ for pLM
    if pLM_Model_Name == 'trivial':
        return None, None, 0, 0

    # loader name in esm.pretrained -> number of layers
    esm_layer_by_name = {
        'esm2_t33_650M_UR50D': 33,  # dim: 1280
        'esm2_t36_3B_UR50D': 36,    # dim: 2560
        'esm2_t30_150M_UR50D': 30,  # dim: 640
        'esm2_t12_35M_UR50D': 12,   # dim: 480
    }
    if pLM_Model_Name not in esm_layer_by_name:
        raise ValueError(
            "pLM model is missing...: {!r}".format(pLM_Model_Name)
        )

    esm_layer = esm_layer_by_name[pLM_Model_Name]
    # the loader functions in esm.pretrained carry the model names
    pLM_Model, esm_alphabet = getattr(esm.pretrained, pLM_Model_Name)()
    len_toks = len(esm_alphabet.all_toks)
    pLM_Model.eval()
    pLM_Model.to(device)

    return pLM_Model, esm_alphabet, esm_layer, len_toks
|