Upload 22 files
- .gitattributes +4 -0
- data/emb_esm2_3b/P18281.pt +3 -0
- data/evaluate_data/evaluate_cases.py +213 -0
- data/evaluate_data/evaluate_pretrain.py +282 -0
- data/evaluate_data/evaluate_with_ancestors.py +339 -0
- data/evaluate_data/evaluate_with_ancestors_exp.py +339 -0
- data/evaluate_data/pretrain_output_to_deepgozero.py +477 -0
- data/evaluate_data/process_case.py +50 -0
- data/evaluate_data/utils.py +280 -0
- data/go1.4-basic.obo +3 -0
- data/go_descriptions1.4.txt +0 -0
- data/swissprot_exp/test_exp_prompt_bp_new.csv +0 -0
- data/swissprot_exp/test_exp_prompt_cc_new.csv +0 -0
- data/swissprot_exp/test_exp_prompt_mf_new.csv +0 -0
- data/swissprot_exp/train_exp_prompt_bp_new.csv +3 -0
- data/swissprot_exp/train_exp_prompt_cc_new.csv +3 -0
- data/swissprot_exp/train_exp_prompt_mf_new.csv +3 -0
- data/swissprot_exp/val_exp_prompt_bp_new.csv +0 -0
- data/swissprot_exp/val_exp_prompt_cc_new.csv +0 -0
- data/swissprot_exp/val_exp_prompt_mf_new.csv +0 -0
- data/terms/bp_terms.pkl +3 -0
- data/terms/cc_terms.pkl +3 -0
- data/terms/mf_terms.pkl +3 -0
.gitattributes
CHANGED
@@ -35,3 +35,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 assets/FAPM.png filter=lfs diff=lfs merge=lfs -text
 assets/LAVIS_technical_report.pdf filter=lfs diff=lfs merge=lfs -text
+data/go1.4-basic.obo filter=lfs diff=lfs merge=lfs -text
+data/swissprot_exp/train_exp_prompt_bp_new.csv filter=lfs diff=lfs merge=lfs -text
+data/swissprot_exp/train_exp_prompt_cc_new.csv filter=lfs diff=lfs merge=lfs -text
+data/swissprot_exp/train_exp_prompt_mf_new.csv filter=lfs diff=lfs merge=lfs -text
data/emb_esm2_3b/P18281.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:91714943ae1d08f860e86cfcd098f3973dc14ca63d88556223223fc9687ac7ec
size 901864
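Once pulled through Git LFS, this pointer resolves to the saved ESM-2 3B embedding for protein P18281. A minimal loading sketch; the internal layout of the saved object is not shown in this commit, so the snippet only loads and inspects it:

import torch

# Hypothetical sketch: inspect the ESM-2 3B embedding stored for P18281.
# Whether the file holds a bare tensor or a dict of representations is an
# assumption left open here, so we just look at what torch.load returns.
emb = torch.load("data/emb_esm2_3b/P18281.pt", map_location="cpu")
print(type(emb))
if isinstance(emb, dict):
    print(list(emb.keys()))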
data/evaluate_data/evaluate_cases.py
ADDED
@@ -0,0 +1,213 @@
import pandas as pd
import re
import random
import Levenshtein
import numpy as np
import difflib
# from torchmetrics.text import BLEUScore
import time
from multiprocessing import Pool, Queue, Process
import matplotlib.pyplot as plt
from data.evaluate_data.utils import Ontology
# bleu = BLEUScore(n_gram=1)

def fuzzy_match(texts):
    text_dict = {}
    for context in texts:
        if context not in choices:
            # txt_dict[txt] = process.extractOne(txt, choices)[0]
            text_dict[context] = difflib.get_close_matches(context, choices, n=1, cutoff=0.)[0]
    return text_dict


def get_sim(text, label):
    all_s = []
    for x in label:
        s = 0
        for y in text:
            temp = Levenshtein.ratio(x, y)
            if temp > s:
                s = temp
        all_s.append(s)
    all_s = [round(i, 3) for i in all_s]

    # bs = [bleu(x, [label]) for x in text]
    return all_s


def txt_map(x, txt_dict):
    if type(x) == str:
        x = eval(x)
    x_ = []
    for i in x:
        if i == '':
            continue
        if i in txt_dict:
            x_.append(txt_dict[i])
        else:
            x_.append(i)
    return x_


def go_map(t):
    if t in GO_dict:
        return GO_dict[t]
    else:
        print(t)


def get_term(df):
    from collections import Counter
    cnt = Counter()
    for i, row in enumerate(df.itertuples()):
        for term in row.prop_annotations:
            cnt[term] += 1
    terms = list(cnt.keys())
    # remove top
    for top_term in ['GO:0005575', 'GO:0003674', 'GO:0008150']:
        if top_term in terms:
            terms.remove(top_term)
    terms_df = pd.DataFrame({'gos': terms})
    terms_df.to_pickle(f'/cluster/home/wenkai/deepgozero/data/blip2/terms.pkl')


if __name__ == "__main__":
    go = Ontology(f'/cluster/home/wenkai/deepgozero/data/data/go.obo', with_rels=True)
    go_des = pd.read_csv('/cluster/home/wenkai/LAVIS/data/go_descriptions_new.txt', sep='|', header=None)
    go_des.columns = ['GO', 'function']
    go_des = go_des[go_des['function'].notnull()]
    go_des['function'] = go_des['function'].apply(lambda x: x.lower().strip())
    go_des['GO'] = go_des['GO'].apply(lambda x: re.sub('_', ':', x))
    GO_dict = dict(zip(go_des['function'], go_des['GO']))

    data = pd.read_csv('/cluster/home/wenkai/LAVIS/output/output_case.txt', sep='|', header=None)
    data.columns = ['protein', 'pred', 'label']
    data['label'] = data['label'].apply(lambda x: x.lower())
    data['pred'] = data['pred'].apply(lambda x: re.sub('</s>', '', x))

    data['label_list'] = data['label'].apply(lambda x: [i.strip() for i in x.split(';')])
    data['pred_list'] = data['pred'].apply(lambda x: [i.strip() for i in x.split(';')])

    test = pd.read_csv('/cluster/home/wenkai/LAVIS/data/pretrain/test.csv', sep='|')
    test = test.drop_duplicates()
    test['function'] = test['function'].apply(lambda x: x.lower().strip())
    test['function'] = test['function'].apply(lambda x: [i.strip() for i in x.split(';')])
    test['GO_label'] = test['GO_label'].apply(lambda x: [i.strip() for i in x.split(';')])

    test_dict = dict()
    for x, y in zip(test['function'], test['GO_label']):
        temp = dict(zip(x, y))
        test_dict.update(temp)
    GO_dict.update(test_dict)

    choices = list(test_dict.keys())

    ### If a predicted text is not an exact GO label phrase, map it to the most similar GO label
    '''
    print("Finding the GO label most similar to each predicted text ......")
    t0 = time.time()
    txt_dict = {}

    all_txt = []
    for txt in data['pred_list']:
        if type(txt) == str:
            all_txt.extend(eval(txt))
        else:
            all_txt.extend(txt)
    all_txt = list(set(all_txt))

    n = len(all_txt)
    thread = 10
    size = int(n/thread)
    inds = list(range(0, n, size))
    inds.append(n)
    all_txt_sep = [all_txt[i: min(i+size, n)] for i in inds[:-1]]

    with Pool(processes=thread) as pool:
        result = pool.map(fuzzy_match, all_txt_sep)
        pool.close()
        pool.join()
    for d in result:
        txt_dict.update(d)

    # for txt in all_txt[:10]:
    #     fuzzy_match(txt)

    data['pred_list'] = data['pred_list'].apply(lambda x: txt_map(x, txt_dict))
    data['pred_list'] = data['pred_list'].apply(lambda x: list(set(x)))
    print("fuzzy matching time: {}".format(time.time() - t0))

    print("calculating f1 score ......")
    data['label_list_go'] = data['label_list'].apply(lambda x: [go_map(i) for i in x])
    data['pred_list_go'] = data['pred_list'].apply(lambda x: [go_map(i) for i in x])
    '''

    # Prepare case-study data: GO labels predicted by blip2 are the features; labels with ancestors added are the prediction target Y
    prepare_ancestors = True
    if prepare_ancestors:
        print("Preparing data with ancestor terms added ......")
        def prop(df):
            prop_annotations = []
            for i, row in df.iterrows():
                # Propagate annotations
                annot_set = set()
                annots = row['GO_label']
                for go_id in annots:
                    annot_set |= go.get_anchestors(go_id)
                annots = list(annot_set)
                prop_annotations.append(annots)
            df['prop_annotations'] = prop_annotations
            return df

        def pred_text_to_go(df):
            df['pred'] = df['pred'].apply(lambda x: re.sub('</s>', '', x))

            df['pred_list'] = df['pred'].apply(lambda x: [i.strip() for i in x.split(';')])
            ### If a predicted text is not an exact GO label phrase, map it to the most similar GO label
            t0 = time.time()
            txt_dict = {}

            all_txt = []
            for txt in df['pred_list']:
                if type(txt) == str:
                    all_txt.extend(eval(txt))
                else:
                    all_txt.extend(txt)

            all_txt = list(set(all_txt))
            if '' in all_txt:
                all_txt.remove('')

            n = len(all_txt)
            thread = 10
            size = int(n / thread)
            inds = list(range(0, n, size))
            inds.append(n)
            all_txt_sep = [all_txt[i: min(i + size, n)] for i in inds[:-1]]

            with Pool(processes=thread) as pool:
                result = pool.map(fuzzy_match, all_txt_sep)
                pool.close()
                pool.join()
            for d in result:
                txt_dict.update(d)

            # for txt in all_txt[:10]:
            #     fuzzy_match(txt)

            df['pred_list'] = df['pred_list'].apply(lambda x: txt_map(x, txt_dict))
            df['pred_list'] = df['pred_list'].apply(lambda x: list(set(x)))
            print("fuzzy matching time: {}".format(time.time() - t0))

            df['pred_list_go'] = df['pred_list'].apply(lambda x: [go_map(i) for i in x])
            return df


        test_pred = pd.read_csv('/cluster/home/wenkai/LAVIS/output/output_case.txt', sep='|', header=None)
        test_pred.columns = ['protein', 'pred', 'GO_label']
        test_pred['GO_label'] = test_pred['GO_label'].apply(lambda x: [i.strip() for i in x.split(';')])
        test_pred = prop(test_pred)
        test_pred = pred_text_to_go(test_pred)

        for cat in ['mf', 'bp', 'cc']:
            test_pred.to_pickle('/cluster/home/wenkai/deepgozero/data/blip2/{}/test_case.pkl'.format(cat))
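A note on the matching step above: predictions and labels are ';'-separated phrases, and any predicted phrase that is not an exact GO description gets snapped to the closest one with difflib. A toy sketch of that snapping, with an invented vocabulary standing in for the real GO descriptions:

import difflib

# Toy GO-description vocabulary (invented for illustration only).
choices = ["atp binding", "dna binding", "protein kinase activity"]

pred_texts = ["atp bindng", "protein kinase activty"]  # model outputs with typos
mapped = {t: difflib.get_close_matches(t, choices, n=1, cutoff=0.0)[0]
          for t in pred_texts if t not in choices}
print(mapped)  # {'atp bindng': 'atp binding', 'protein kinase activty': 'protein kinase activity'}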
data/evaluate_data/evaluate_pretrain.py
ADDED
@@ -0,0 +1,282 @@
import pandas as pd
import re
import random
import Levenshtein
import numpy as np
import difflib
# from torchmetrics.text import BLEUScore
import time
from multiprocessing import Pool, Queue, Process
import matplotlib.pyplot as plt
from data.evaluate_data.utils import Ontology
# bleu = BLEUScore(n_gram=1)

def fuzzy_match(texts):
    text_dict = {}
    for context in texts:
        if context not in choices:
            # txt_dict[txt] = process.extractOne(txt, choices)[0]
            sim_list = difflib.get_close_matches(context, choices, n=1, cutoff=0.93)
            if len(sim_list) > 0:
                text_dict[context] = sim_list[0]
            else:
                text_dict[context] = ''
    return text_dict


def get_sim(text, label):
    all_s = []
    for x in label:
        s = 0
        for y in text:
            temp = Levenshtein.ratio(x, y)
            if temp > s:
                s = temp
        all_s.append(s)
    all_s = [round(i, 3) for i in all_s]

    # bs = [bleu(x, [label]) for x in text]
    return all_s


def txt_map(x, txt_dict):
    if type(x) == str:
        x = eval(x)
    x_ = []
    for i in x:
        if i == '':
            continue
        if i in txt_dict:
            x_.append(txt_dict[i])
        else:
            x_.append(i)
    return x_


def go_map(t):
    if t in GO_dict:
        return GO_dict[t]
    else:
        pass
        #print(t)


def get_term(df):
    from collections import Counter
    cnt = Counter()
    for i, row in enumerate(df.itertuples()):
        for term in row.prop_annotations:
            cnt[term] += 1
    terms = list(cnt.keys())
    # remove top
    for top_term in ['GO:0005575', 'GO:0003674', 'GO:0008150']:
        if top_term in terms:
            terms.remove(top_term)
    terms_df = pd.DataFrame({'gos': terms})
    terms_df.to_pickle(f'/cluster/home/wenkai/deepgozero/data/blip2/terms.pkl')


if __name__ == "__main__":
    go = Ontology(f'/cluster/home/wenkai/deepgozero/data/data/go.obo', with_rels=True)
    go_des = pd.read_csv('/cluster/home/wenkai/LAVIS/data/go_descriptions_new.txt', sep='|', header=None)
    go_des.columns = ['GO', 'function']
    go_des = go_des[go_des['function'].notnull()]
    go_des['function'] = go_des['function'].apply(lambda x: x.lower().strip())
    go_des['GO'] = go_des['GO'].apply(lambda x: re.sub('_', ':', x))
    GO_dict = dict(zip(go_des['function'], go_des['GO']))

    data = pd.read_csv('/cluster/home/wenkai/LAVIS/output/output_go_train.txt', sep='|', header=None, on_bad_lines='skip')
    data.columns = ['name', 'pred', 'label']
    #data['label'] = data['label'].apply(lambda x: x.lower())
    data['pred'] = data['pred'].apply(lambda x: re.sub('</s>', '', x))

    #data['label_list'] = data['label'].apply(lambda x: [i.strip() for i in x.split(';')])
    data['pred_list'] = data['pred'].apply(lambda x: list(set([i.strip() for i in x.split(';')])))

    #train = pd.read_csv('/cluster/home/wenkai/LAVIS/data/pretrain/train_exp.csv', sep='|')
    test = pd.read_csv('/cluster/home/wenkai/LAVIS/data/pretrain/train_exp.csv', sep='|')
    test = test.drop_duplicates()
    test['function'] = test['function'].apply(lambda x: x.lower().strip())
    test['function'] = test['function'].apply(lambda x: [i.strip() for i in x.split(';')])
    test['GO_label'] = test['GO_label'].apply(lambda x: [i.strip() for i in x.split(';')])

    data = pd.merge(data, test[['name', 'function']], on='name', how='left')
    data['label_list'] = data['function']

    test_dict = dict()
    for x, y in zip(test['function'], test['GO_label']):
        temp = dict(zip(x, y))
        test_dict.update(temp)
    GO_dict.update(test_dict)

    choices = list(test_dict.keys())

    ### If a predicted text is not an exact GO label phrase, map it to the most similar GO label
    print("Finding the GO label most similar to each predicted text ......")
    t0 = time.time()
    txt_dict = {}

    all_txt = []
    for txt in data['pred_list']:
        if type(txt) == str:
            all_txt.extend(eval(txt))
        else:
            all_txt.extend(txt)
    all_txt = list(set(all_txt))

    n = len(all_txt)
    thread = 40
    size = int(n/thread)
    inds = list(range(0, n, size))
    inds.append(n)
    all_txt_sep = [all_txt[i: min(i+size, n)] for i in inds[:-1]]

    with Pool(processes=thread) as pool:
        result = pool.map(fuzzy_match, all_txt_sep)
        pool.close()
        pool.join()
    for d in result:
        txt_dict.update(d)

    # for txt in all_txt[:10]:
    #     fuzzy_match(txt)

    data['pred_list'] = data['pred_list'].apply(lambda x: txt_map(x, txt_dict))
    data['pred_list'] = data['pred_list'].apply(lambda x: list(set(x)))
    print("fuzzy matching time: {}".format(time.time() - t0))

    print("calculating f1 score ......")
    data['label_list_go'] = data['label_list'].apply(lambda x: [go_map(i) for i in x])
    data['pred_list_go'] = data['pred_list'].apply(lambda x: [go_map(i) for i in x])


    labels = []
    pred_labels = []
    for l in data['label_list_go']:
        if type(l) == str:
            l = eval(l)
        labels.extend(l)

    label_count = {}
    for x in labels:
        if x not in label_count:
            label_count[x] = 1
        else:
            label_count[x] += 1

    labels = list(set(labels))
    total = len(labels)
    recalls = []
    precisions = []
    # Per-term TP/FN/FP counts, using exact matches between predicted and true GO terms
    tp_dict, fp_dict, fn_dict = dict(zip(labels, [0]*len(labels))), dict(zip(labels, [0]*len(labels))), dict(zip(labels, [0]*len(labels)))
    for preds, label in zip(data['pred_list_go'], data['label_list_go']):
        if type(label) == str:
            label = eval(label)
        if type(preds) == str:
            txts = eval(preds)
        ll = len(label)
        for t in label:
            # supgo = go.get_anchestors(t)
            # if supgo.intersection(set(preds)):
            if t in preds:
                tp_dict[t] += 1
            else:
                fn_dict[t] += 1
        for p in preds:
            # supgo = go.get_anchestors(p)
            # if not supgo.intersection(set(label)):
            if p not in label:
                if p in fp_dict:
                    fp_dict[p] += 1
                else:
                    fp_dict[p] = 1
        pred_labels.extend(preds)
    p_total = len(set(pred_labels))
    recall, pr = 0., 0.
    for x in labels:
        recall += tp_dict[x] / (1.0 * (tp_dict[x] + fn_dict[x] + 1e-8))
        pr += tp_dict[x] / (1.0 * (tp_dict[x] + fp_dict[x] + 1e-8))
    r = recall / total
    p = pr / p_total
    f1 = 2 * p * r / (p + r)

    print("preds not in labels: {}".format(len(list(fp_dict.keys())) - total))
    print("recall:{}; precision:{}; f1 score: {}".format(r, p, f1))


    # Prepare data: GO labels predicted by blip2 are the features; labels with ancestors added are the prediction target Y
    prepare_ancestors = False
    if prepare_ancestors:
        print("Preparing data with ancestor terms added ......")
        def prop(df):
            prop_annotations = []
            for i, row in df.iterrows():
                # Propagate annotations
                annot_set = set()
                annots = row['GO_label']
                for go_id in annots:
                    annot_set |= go.get_anchestors(go_id)
                annots = list(annot_set)
                prop_annotations.append(annots)
            df['prop_annotations'] = prop_annotations
            return df

        def remove_nan(x):
            if '' in x:
                x.remove('')
            return x

        def pred_text_to_go(df):
            df['pred'] = df['pred'].apply(lambda x: re.sub('</s>', '', x))

            df['pred_list'] = df['pred'].apply(lambda x: list(set([i.strip() for i in x.split(';')])))
            ### If a predicted text is not an exact GO label phrase, map it to the most similar GO label
            t0 = time.time()
            txt_dict = {}

            all_txt = []
            for txt in df['pred_list']:
                if type(txt) == str:
                    all_txt.extend(eval(txt))
                else:
                    all_txt.extend(txt)

            all_txt = list(set(all_txt))
            if '' in all_txt:
                all_txt.remove('')

            n = len(all_txt)
            thread = 40
            size = int(n / thread)
            inds = list(range(0, n, size))
            inds.append(n)
            all_txt_sep = [all_txt[i: min(i + size, n)] for i in inds[:-1]]

            with Pool(processes=thread) as pool:
                result = pool.map(fuzzy_match, all_txt_sep)
                pool.close()
                pool.join()
            for d in result:
                txt_dict.update(d)

            # for txt in all_txt[:10]:
            #     fuzzy_match(txt)

            df['pred_list'] = df['pred_list'].apply(lambda x: txt_map(x, txt_dict))
            df['pred_list'] = df['pred_list'].apply(lambda x: list(set(x)))
            df['pred_list'] = df['pred_list'].apply(lambda x: remove_nan(x))
            print("fuzzy matching time: {}".format(time.time() - t0))

            df['pred_list_go'] = df['pred_list'].apply(lambda x: [go_map(i) for i in x])
            return df


        test_pred = pd.read_csv('/cluster/home/wenkai/LAVIS/pretrain/output_pretrain.txt', sep='|', header=None)
        test_pred.columns = ['protein', 'pred', 'GO_label']
        test_pred['GO_label'] = test_pred['GO_label'].apply(lambda x: [i.strip() for i in x.split(';')])
        test_pred = prop(test_pred)
        get_term(test_pred)
        test_pred = pred_text_to_go(test_pred)

        test_pred.to_pickle('/cluster/home/wenkai/deepgozero/data/blip2/{}/test_pretrain.pkl')
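For orientation, the F1 block above averages recall over the set of ground-truth GO terms and precision over the set of predicted terms, then combines the two means. A small self-contained rehearsal of that computation on toy terms (not real GO IDs):

# Toy rehearsal of the per-term recall / precision / F1 used above.
preds_per_protein  = [["GO:A", "GO:B"], ["GO:A"]]
labels_per_protein = [["GO:A"], ["GO:A", "GO:C"]]

labels = sorted({t for l in labels_per_protein for t in l})   # ['GO:A', 'GO:C']
tp = {t: 0 for t in labels}
fn = {t: 0 for t in labels}
fp = {}
pred_terms = set()
for preds, label in zip(preds_per_protein, labels_per_protein):
    for t in label:
        if t in preds:
            tp[t] += 1
        else:
            fn[t] += 1
    for p in preds:
        if p not in label:
            fp[p] = fp.get(p, 0) + 1
    pred_terms.update(preds)

recall = sum(tp[t] / (tp[t] + fn[t] + 1e-8) for t in labels) / len(labels)
precision = sum(tp[t] / (tp[t] + fp.get(t, 0) + 1e-8) for t in labels) / len(pred_terms)
f1 = 2 * precision * recall / (precision + recall)
print(round(recall, 3), round(precision, 3), round(f1, 3))   # 0.5 0.5 0.5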
data/evaluate_data/evaluate_with_ancestors.py
ADDED
@@ -0,0 +1,339 @@
import pandas as pd
import re
import random
import Levenshtein
import numpy as np
import difflib
# from torchmetrics.text import BLEUScore
import time
from multiprocessing import Pool, Queue, Process
import matplotlib.pyplot as plt
from data.evaluate_data.utils import Ontology
# bleu = BLEUScore(n_gram=1)

def fuzzy_match(texts):
    text_dict = {}
    for context in texts:
        if context not in choices:
            # txt_dict[txt] = process.extractOne(txt, choices)[0]
            text_dict[context] = difflib.get_close_matches(context, choices, n=1, cutoff=0.)[0]
    return text_dict


def get_sim(text, label):
    all_s = []
    for x in label:
        s = 0
        for y in text:
            temp = Levenshtein.ratio(x, y)
            if temp > s:
                s = temp
        all_s.append(s)
    all_s = [round(i, 3) for i in all_s]

    # bs = [bleu(x, [label]) for x in text]
    return all_s


def txt_map(x, txt_dict):
    if type(x) == str:
        x = eval(x)
    x_ = []
    for i in x:
        if i == '':
            continue
        if i in txt_dict:
            x_.append(txt_dict[i])
        else:
            x_.append(i)
    return x_


def go_map(t):
    if t in GO_dict:
        return GO_dict[t]
    else:
        print(t)


def get_term(df):
    from collections import Counter
    cnt = Counter()
    for i, row in enumerate(df.itertuples()):
        for term in row.prop_annotations:
            cnt[term] += 1
    terms = list(cnt.keys())
    # remove top
    for top_term in ['GO:0005575', 'GO:0003674', 'GO:0008150']:
        if top_term in terms:
            terms.remove(top_term)
    terms_df = pd.DataFrame({'gos': terms})
    terms_df.to_pickle(f'/cluster/home/wenkai/deepgozero/data/blip2/{cat}/terms.pkl')


if __name__ == "__main__":
    cat = 'mf'

    go = Ontology(f'/cluster/home/wenkai/deepgozero/data/data/go.obo', with_rels=True)
    go_des = pd.read_csv('/cluster/home/wenkai/LAVIS/data/go_descriptions_new.txt', sep='|', header=None)
    go_des.columns = ['GO', 'function']
    go_des = go_des[go_des['function'].notnull()]
    go_des['function'] = go_des['function'].apply(lambda x: x.lower().strip())
    go_des['GO'] = go_des['GO'].apply(lambda x: re.sub('_', ':', x))
    GO_dict = dict(zip(go_des['function'], go_des['GO']))


    data = pd.read_csv('/cluster/home/wenkai/LAVIS/output/predict_concat_test{}.csv'.format(cat), sep='|')

    data['label'] = data['label'].apply(lambda x: x.lower())
    data['pred'] = data['pred'].apply(lambda x: re.sub('</s>', '', x))

    data['label_list'] = data['label'].apply(lambda x: [i.strip() for i in x.split(';')])
    data['pred_list'] = data['pred'].apply(lambda x: [i.strip() for i in x.split(';')])

    train = pd.read_csv('/cluster/home/wenkai/LAVIS/data/sim_split/train_{}.csv'.format(cat), sep='|')
    train = train.drop_duplicates()
    train['function'] = train['function'].apply(lambda x: x.lower().strip())
    train_dict = dict(zip(train['function'], train['GO_label']))
    test = pd.read_csv('/cluster/home/wenkai/LAVIS/data/sim_split/test_{}.csv'.format(cat), sep='|')
    test = test.drop_duplicates()
    test['function'] = test['function'].apply(lambda x: x.lower().strip())
    test_dict = dict(zip(test['function'], test['GO_label']))
    GO_dict.update(train_dict)
    GO_dict.update(test_dict)

    choices = []
    for x in data['label_list'].tolist() + train['function'].tolist():
        choices.extend(x)
    choices = list(set(choices))


    ### If a predicted text is not an exact GO label phrase, map it to the most similar GO label
    print("Finding the GO label most similar to each predicted text ......")
    t0 = time.time()
    txt_dict = {}

    all_txt = []
    for txt in data['pred_list']:
        if type(txt) == str:
            all_txt.extend(eval(txt))
        else:
            all_txt.extend(txt)
    all_txt = list(set(all_txt))

    n = len(all_txt)
    thread = 40
    size = int(n/thread)
    inds = list(range(0, n, size))
    inds.append(n)
    all_txt_sep = [all_txt[i: min(i+size, n)] for i in inds[:-1]]

    with Pool(processes=thread) as pool:
        result = pool.map(fuzzy_match, all_txt_sep)
        pool.close()
        pool.join()
    for d in result:
        txt_dict.update(d)

    # for txt in all_txt[:10]:
    #     fuzzy_match(txt)

    data['pred_list'] = data['pred_list'].apply(lambda x: txt_map(x, txt_dict))
    data['pred_list'] = data['pred_list'].apply(lambda x: list(set(x)))
    print("fuzzy matching time: {}".format(time.time() - t0))


    # sims = []
    # for text, label in zip(data['pred_list'].tolist(), data['label_list'].tolist()):
    #     a = get_sim(text, label)
    #     sims.append(a)
    #
    # data['sim'] = sims
    # data['avg_sim'] = data['sim'].apply(lambda x: round(np.mean(x), 3))
    # print("similarity: {}".format(data['avg_sim'].mean()))


    print("calculating f1 score ......")
    data['label_list_go'] = data['label_list'].apply(lambda x: [go_map(i) for i in x])
    data['pred_list_go'] = data['pred_list'].apply(lambda x: [go_map(i) for i in x])


    labels = []
    pred_labels = []
    for l in data['label_list_go']:
        if type(l) == str:
            l = eval(l)
        labels.extend(l)

    label_count = {}
    for x in labels:
        if x not in label_count:
            label_count[x] = 1
        else:
            label_count[x] += 1

    labels = list(set(labels))
    total = len(labels)
    recalls = []
    precisions = []
    # Per-term TP/FN/FP counts; a true term counts as recovered if any of its ancestors was predicted
    tp_dict, fp_dict, fn_dict = dict(zip(labels, [0]*len(labels))), dict(zip(labels, [0]*len(labels))), dict(zip(labels, [0]*len(labels)))
    for preds, label in zip(data['pred_list_go'], data['label_list_go']):
        if type(label) == str:
            label = eval(label)
        if type(preds) == str:
            txts = eval(preds)
        ll = len(label)
        for t in label:
            supgo = go.get_anchestors(t)
            if supgo.intersection(set(preds)):
                tp_dict[t] += 1
            else:
                fn_dict[t] += 1
        for p in preds:
            supgo = go.get_anchestors(p)
            if not supgo.intersection(set(label)):
                if p in fp_dict:
                    fp_dict[p] += 1
                else:
                    fp_dict[p] = 1
        pred_labels.extend(preds)
    p_total = len(set(pred_labels))
    recall, pr = 0., 0.
    for x in labels:
        recall += tp_dict[x] / (1.0 * (tp_dict[x] + fn_dict[x] + 1e-8))
        pr += tp_dict[x] / (1.0 * (tp_dict[x] + fp_dict[x] + 1e-8))
    r = recall / total
    p = pr / p_total
    f1 = 2 * p * r / (p + r)

    print("preds not in labels: {}".format(len(list(fp_dict.keys())) - total))
    print("f1 score: {}".format(f1))

    '''
    cat_f1 = {}
    for x in labels:
        if tp_dict[x] + fn_dict[x] > 0:
            re = tp_dict[x] / (1.0 * (tp_dict[x] + fn_dict[x] + 1e-8))
            pr = tp_dict[x] / (1.0 * (tp_dict[x] + fp_dict[x] + 1e-8))
            cat_f1[x] = 2 * pr * re / (pr + re + 1e-10)

    plt.xlabel('f score')
    plt.ylabel('count')
    print(np.mean(list(cat_f1.values())))
    plt.hist(list(cat_f1.values()), color='red', bins=30)
    plt.show()

    xs, ys = [], []
    for x in labels:
        xs.append(label_count[x])
        ys.append(cat_f1[x])
    df_count = pd.DataFrame({'xs': xs, 'ys': ys})
    df_count['xs'].loc[df_count['xs'] > 10] = 11
    df_count['xs'] = df_count['xs'].astype(str)
    df_count1 = df_count.groupby('xs').mean().reset_index()
    df_count2 = df_count.groupby('xs').count().reset_index()

    plt.xlabel('label count')
    plt.ylabel('f score mean')
    df_count1['xs'] = df_count1['xs'].astype(int)
    plt.scatter(df_count1['xs'], df_count1['ys'], color='red')
    plt.show()

    plt.xlabel('label count')
    plt.ylabel('protein num')
    df_count2['xs'] = df_count2['xs'].astype(int)
    plt.bar(df_count2['xs'], df_count2['ys'], color='red')
    plt.show()
    '''


    # Prepare data: GO labels predicted by blip2 are the features; labels with ancestors added are the prediction target Y
    print("Preparing data with ancestor terms added ......")
    train = pd.read_csv('/cluster/home/wenkai/LAVIS/data/sim_split/train_{}.csv'.format(cat), sep='|')
    test = pd.read_csv('/cluster/home/wenkai/LAVIS/data/sim_split/test_{}.csv'.format(cat), sep='|')
    train = train.groupby('name').agg({'GO_label': list}).reset_index()
    test = test.groupby('name').agg({'GO_label': list}).reset_index()

    def prop(df):
        prop_annotations = []
        for i, row in df.iterrows():
            # Propagate annotations
            annot_set = set()
            annots = row['GO_label']
            for go_id in annots:
                annot_set |= go.get_anchestors(go_id)
            annots = list(annot_set)
            prop_annotations.append(annots)
        df['prop_annotations'] = prop_annotations
        return df

    train = prop(train)
    test = prop(test)

    train_test = pd.concat([train, test])
    get_term(train_test)
    del train_test

    def pred_text_to_go(df):
        df['pred'] = df['pred'].apply(lambda x: re.sub('</s>', '', x))

        df['pred_list'] = df['pred'].apply(lambda x: [i.strip() for i in x.split(';')])
        ### If a predicted text is not an exact GO label phrase, map it to the most similar GO label
        t0 = time.time()
        txt_dict = {}

        all_txt = []
        for txt in df['pred_list']:
            if type(txt) == str:
                all_txt.extend(eval(txt))
            else:
                all_txt.extend(txt)

        all_txt = list(set(all_txt))
        if '' in all_txt:
            all_txt.remove('')

        n = len(all_txt)
        thread = 40
        size = int(n / thread)
        inds = list(range(0, n, size))
        inds.append(n)
        all_txt_sep = [all_txt[i: min(i + size, n)] for i in inds[:-1]]

        with Pool(processes=thread) as pool:
            result = pool.map(fuzzy_match, all_txt_sep)
            pool.close()
            pool.join()
        for d in result:
            txt_dict.update(d)

        # for txt in all_txt[:10]:
        #     fuzzy_match(txt)

        df['pred_list'] = df['pred_list'].apply(lambda x: txt_map(x, txt_dict))
        df['pred_list'] = df['pred_list'].apply(lambda x: list(set(x)))
        print("fuzzy matching time: {}".format(time.time() - t0))

        df['pred_list_go'] = df['pred_list'].apply(lambda x: [go_map(i) for i in x])
        return df


    train_pred = pd.read_csv('/cluster/home/wenkai/LAVIS/output/predict_concat_train{}.csv'.format(cat), sep='|')
    test_pred = pd.read_csv('/cluster/home/wenkai/LAVIS/output/predict_concat_test{}.csv'.format(cat), sep='|')

    train_pred = pred_text_to_go(train_pred)
    test_pred = pred_text_to_go(test_pred)

    train_data = pd.merge(train[['name', 'prop_annotations']],
                          train_pred[['name', 'pred_list_go']],
                          on='name', how='inner')
    train_data = train_data.drop_duplicates('name')
    train_data.to_pickle('/cluster/home/wenkai/deepgozero/data/blip2/{}/train_data.pkl'.format(cat))

    test_data = pd.merge(test[['name', 'prop_annotations']],
                         test_pred[['name', 'pred_list_go']],
                         on='name', how='inner')
    test_data = test_data.drop_duplicates('name')
    test_data.to_pickle('/cluster/home/wenkai/deepgozero/data/blip2/{}/test_data.pkl'.format(cat))
    test_data.to_pickle('/cluster/home/wenkai/deepgozero/data/blip2/{}/valid_data.pkl'.format(cat))
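Unlike the exact-match scoring in evaluate_pretrain.py, this evaluator counts a ground-truth term as recovered when any of its ancestors appears among the predictions. A toy sketch of that hit test, with a hand-written ancestor map standing in for Ontology.get_anchestors (assumed here to return a term together with its ancestors):

# Toy illustration of the ancestor-aware hit test used above; the ancestor map
# is invented and only stands in for go.get_anchestors.
ancestors = {
    "GO:child":  {"GO:child", "GO:parent", "GO:root"},
    "GO:parent": {"GO:parent", "GO:root"},
}

label = ["GO:child"]
preds = ["GO:parent"]          # the model predicted a more general term

for t in label:
    supgo = ancestors[t]
    hit = bool(supgo.intersection(set(preds)))
    print(t, "counted as true positive:", hit)   # True: an ancestor was predicted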
data/evaluate_data/evaluate_with_ancestors_exp.py
ADDED
@@ -0,0 +1,339 @@
import pandas as pd
import re
import random
import Levenshtein
import numpy as np
import difflib
# from torchmetrics.text import BLEUScore
import time
from multiprocessing import Pool, Queue, Process
import matplotlib.pyplot as plt
from data.evaluate_data.utils import Ontology
# bleu = BLEUScore(n_gram=1)

def fuzzy_match(texts):
    text_dict = {}
    for context in texts:
        if context not in choices:
            # txt_dict[txt] = process.extractOne(txt, choices)[0]
            text_dict[context] = difflib.get_close_matches(context, choices, n=1, cutoff=0.)[0]
    return text_dict


def get_sim(text, label):
    all_s = []
    for x in label:
        s = 0
        for y in text:
            temp = Levenshtein.ratio(x, y)
            if temp > s:
                s = temp
        all_s.append(s)
    all_s = [round(i, 3) for i in all_s]

    # bs = [bleu(x, [label]) for x in text]
    return all_s


def txt_map(x, txt_dict):
    if type(x) == str:
        x = eval(x)
    x_ = []
    for i in x:
        if i == '':
            continue
        if i in txt_dict:
            x_.append(txt_dict[i])
        else:
            x_.append(i)
    return x_


def go_map(t):
    if t in GO_dict:
        return GO_dict[t]
    else:
        print(t)


def get_term(df):
    from collections import Counter
    cnt = Counter()
    for i, row in enumerate(df.itertuples()):
        for term in row.prop_annotations:
            cnt[term] += 1
    terms = list(cnt.keys())
    # remove top
    for top_term in ['GO:0005575', 'GO:0003674', 'GO:0008150']:
        if top_term in terms:
            terms.remove(top_term)
    terms_df = pd.DataFrame({'gos': terms})
    terms_df.to_pickle(f'/cluster/home/wenkai/deepgozero/data/blip2/{cat}/terms.pkl')


if __name__ == "__main__":
    cat = 'mf'

    go = Ontology(f'/cluster/home/wenkai/deepgozero/data/data/go.obo', with_rels=True)
    go_des = pd.read_csv('/cluster/home/wenkai/LAVIS/data/go_descriptions_new.txt', sep='|', header=None)
    go_des.columns = ['GO', 'function']
    go_des = go_des[go_des['function'].notnull()]
    go_des['function'] = go_des['function'].apply(lambda x: x.lower().strip())
    go_des['GO'] = go_des['GO'].apply(lambda x: re.sub('_', ':', x))
    GO_dict = dict(zip(go_des['function'], go_des['GO']))


    data = pd.read_csv('/cluster/home/wenkai/LAVIS/output_exp/predict_concat_test{}.csv'.format(cat), sep='|')

    data['label'] = data['label'].apply(lambda x: x.lower())
    data['pred'] = data['pred'].apply(lambda x: re.sub('</s>', '', x))

    data['label_list'] = data['label'].apply(lambda x: [i.strip() for i in x.split(';')])
    data['pred_list'] = data['pred'].apply(lambda x: [i.strip() for i in x.split(';')])

    train = pd.read_csv('/cluster/home/wenkai/LAVIS/data/sim_exp/train_{}.csv'.format(cat), sep='|')
    train = train.drop_duplicates()
    train['function'] = train['function'].apply(lambda x: x.lower().strip())
    train_dict = dict(zip(train['function'], train['GO_label']))
    test = pd.read_csv('/cluster/home/wenkai/LAVIS/data/sim_exp/test_{}.csv'.format(cat), sep='|')
    test = test.drop_duplicates()
    test['function'] = test['function'].apply(lambda x: x.lower().strip())
    test_dict = dict(zip(test['function'], test['GO_label']))
    GO_dict.update(train_dict)
    GO_dict.update(test_dict)

    choices = []
    for x in data['label_list'].tolist() + train['function'].tolist():
        choices.extend(x)
    choices = list(set(choices))


    ### If a predicted text is not an exact GO label phrase, map it to the most similar GO label
    print("Finding the GO label most similar to each predicted text ......")
    t0 = time.time()
    txt_dict = {}

    all_txt = []
    for txt in data['pred_list']:
        if type(txt) == str:
            all_txt.extend(eval(txt))
        else:
            all_txt.extend(txt)
    all_txt = list(set(all_txt))

    n = len(all_txt)
    thread = 40
    size = int(n/thread)
    inds = list(range(0, n, size))
    inds.append(n)
    all_txt_sep = [all_txt[i: min(i+size, n)] for i in inds[:-1]]

    with Pool(processes=thread) as pool:
        result = pool.map(fuzzy_match, all_txt_sep)
        pool.close()
        pool.join()
    for d in result:
        txt_dict.update(d)

    # for txt in all_txt[:10]:
    #     fuzzy_match(txt)

    data['pred_list'] = data['pred_list'].apply(lambda x: txt_map(x, txt_dict))
    data['pred_list'] = data['pred_list'].apply(lambda x: list(set(x)))
    print("fuzzy matching time: {}".format(time.time() - t0))


    # sims = []
    # for text, label in zip(data['pred_list'].tolist(), data['label_list'].tolist()):
    #     a = get_sim(text, label)
    #     sims.append(a)
    #
    # data['sim'] = sims
    # data['avg_sim'] = data['sim'].apply(lambda x: round(np.mean(x), 3))
    # print("similarity: {}".format(data['avg_sim'].mean()))


    print("calculating f1 score ......")
    data['label_list_go'] = data['label_list'].apply(lambda x: [go_map(i) for i in x])
    data['pred_list_go'] = data['pred_list'].apply(lambda x: [go_map(i) for i in x])


    labels = []
    pred_labels = []
    for l in data['label_list_go']:
        if type(l) == str:
            l = eval(l)
        labels.extend(l)

    label_count = {}
    for x in labels:
        if x not in label_count:
            label_count[x] = 1
        else:
            label_count[x] += 1

    labels = list(set(labels))
    total = len(labels)
    recalls = []
    precisions = []
    # Per-term TP/FN/FP counts; a true term counts as recovered if any of its ancestors was predicted
    tp_dict, fp_dict, fn_dict = dict(zip(labels, [0]*len(labels))), dict(zip(labels, [0]*len(labels))), dict(zip(labels, [0]*len(labels)))
    for preds, label in zip(data['pred_list_go'], data['label_list_go']):
        if type(label) == str:
            label = eval(label)
        if type(preds) == str:
            txts = eval(preds)
        ll = len(label)
        for t in label:
            supgo = go.get_anchestors(t)
            if supgo.intersection(set(preds)):
                tp_dict[t] += 1
            else:
                fn_dict[t] += 1
        for p in preds:
            supgo = go.get_anchestors(p)
            if not supgo.intersection(set(label)):
                if p in fp_dict:
                    fp_dict[p] += 1
                else:
                    fp_dict[p] = 1
        pred_labels.extend(preds)
    p_total = len(set(pred_labels))
    recall, pr = 0., 0.
    for x in labels:
        recall += tp_dict[x] / (1.0 * (tp_dict[x] + fn_dict[x] + 1e-8))
        pr += tp_dict[x] / (1.0 * (tp_dict[x] + fp_dict[x] + 1e-8))
    r = recall / total
    p = pr / p_total
    f1 = 2 * p * r / (p + r)

    print("preds not in labels: {}".format(len(list(fp_dict.keys())) - total))
    print("f1 score: {}".format(f1))

    '''
    cat_f1 = {}
    for x in labels:
        if tp_dict[x] + fn_dict[x] > 0:
            re = tp_dict[x] / (1.0 * (tp_dict[x] + fn_dict[x] + 1e-8))
            pr = tp_dict[x] / (1.0 * (tp_dict[x] + fp_dict[x] + 1e-8))
            cat_f1[x] = 2 * pr * re / (pr + re + 1e-10)

    plt.xlabel('f score')
    plt.ylabel('count')
    print(np.mean(list(cat_f1.values())))
    plt.hist(list(cat_f1.values()), color='red', bins=30)
    plt.show()

    xs, ys = [], []
    for x in labels:
        xs.append(label_count[x])
        ys.append(cat_f1[x])
    df_count = pd.DataFrame({'xs': xs, 'ys': ys})
    df_count['xs'].loc[df_count['xs'] > 10] = 11
    df_count['xs'] = df_count['xs'].astype(str)
    df_count1 = df_count.groupby('xs').mean().reset_index()
    df_count2 = df_count.groupby('xs').count().reset_index()

    plt.xlabel('label count')
    plt.ylabel('f score mean')
    df_count1['xs'] = df_count1['xs'].astype(int)
    plt.scatter(df_count1['xs'], df_count1['ys'], color='red')
    plt.show()

    plt.xlabel('label count')
    plt.ylabel('protein num')
    df_count2['xs'] = df_count2['xs'].astype(int)
    plt.bar(df_count2['xs'], df_count2['ys'], color='red')
    plt.show()
    '''


    # Prepare data: GO labels predicted by blip2 are the features; labels with ancestors added are the prediction target Y
    print("Preparing data with ancestor terms added ......")
    train = pd.read_csv('/cluster/home/wenkai/LAVIS/data/sim_exp/train_{}.csv'.format(cat), sep='|')
    test = pd.read_csv('/cluster/home/wenkai/LAVIS/data/sim_exp/test_{}.csv'.format(cat), sep='|')
    train = train.groupby('name').agg({'GO_label': list}).reset_index()
    test = test.groupby('name').agg({'GO_label': list}).reset_index()

    def prop(df):
        prop_annotations = []
        for i, row in df.iterrows():
            # Propagate annotations
            annot_set = set()
            annots = row['GO_label']
            for go_id in annots:
                annot_set |= go.get_anchestors(go_id)
            annots = list(annot_set)
            prop_annotations.append(annots)
        df['prop_annotations'] = prop_annotations
        return df

    train = prop(train)
    test = prop(test)

    train_test = pd.concat([train, test])
    get_term(train_test)
    del train_test

    def pred_text_to_go(df):
        df['pred'] = df['pred'].apply(lambda x: re.sub('</s>', '', x))

        df['pred_list'] = df['pred'].apply(lambda x: [i.strip() for i in x.split(';')])
        ### If a predicted text is not an exact GO label phrase, map it to the most similar GO label
        t0 = time.time()
        txt_dict = {}

        all_txt = []
        for txt in df['pred_list']:
            if type(txt) == str:
                all_txt.extend(eval(txt))
            else:
                all_txt.extend(txt)

        all_txt = list(set(all_txt))
        if '' in all_txt:
            all_txt.remove('')

        n = len(all_txt)
        thread = 40
        size = int(n / thread)
        inds = list(range(0, n, size))
        inds.append(n)
        all_txt_sep = [all_txt[i: min(i + size, n)] for i in inds[:-1]]

        with Pool(processes=thread) as pool:
            result = pool.map(fuzzy_match, all_txt_sep)
            pool.close()
            pool.join()
        for d in result:
            txt_dict.update(d)

        # for txt in all_txt[:10]:
        #     fuzzy_match(txt)

        df['pred_list'] = df['pred_list'].apply(lambda x: txt_map(x, txt_dict))
        df['pred_list'] = df['pred_list'].apply(lambda x: list(set(x)))
        print("fuzzy matching time: {}".format(time.time() - t0))

        df['pred_list_go'] = df['pred_list'].apply(lambda x: [go_map(i) for i in x])
        return df


    train_pred = pd.read_csv('/cluster/home/wenkai/LAVIS/output_exp/predict_concat_train{}.csv'.format(cat), sep='|')
    test_pred = pd.read_csv('/cluster/home/wenkai/LAVIS/output_exp/predict_concat_test{}.csv'.format(cat), sep='|')

    train_pred = pred_text_to_go(train_pred)
    test_pred = pred_text_to_go(test_pred)

    train_data = pd.merge(train[['name', 'prop_annotations']],
                          train_pred[['name', 'pred_list_go']],
                          on='name', how='inner')
    train_data = train_data.drop_duplicates('name')
    train_data.to_pickle('/cluster/home/wenkai/deepgozero/data/blip2/{}/train_data.pkl'.format(cat))

    test_data = pd.merge(test[['name', 'prop_annotations']],
                         test_pred[['name', 'pred_list_go']],
                         on='name', how='inner')
    test_data = test_data.drop_duplicates('name')
    test_data.to_pickle('/cluster/home/wenkai/deepgozero/data/blip2/{}/test_data.pkl'.format(cat))
    test_data.to_pickle('/cluster/home/wenkai/deepgozero/data/blip2/{}/valid_data.pkl'.format(cat))
data/evaluate_data/pretrain_output_to_deepgozero.py
ADDED
|
@@ -0,0 +1,477 @@
import re
import pandas as pd
import time
from multiprocessing import Pool
import difflib
from utils import Ontology
import os


# Keep only the GO ids that appear in the filtered vocabulary (filter_go).
def filter(x_list):
    new_go = []
    # x_list = [i.strip() for i in x.split(';')]
    for i in x_list:
        if i in filter_go:
            new_go.append(i)
    return '; '.join(new_go)


# Map each predicted text to its closest GO description in `choices` (difflib, cutoff 0.9).
def fuzzy_match(texts):
    text_dict = {}
    for context in texts:
        if context in choices:
            text_dict[context] = context
        elif context not in choices:
            # txt_dict[txt] = process.extractOne(txt, choices)[0]
            sim_list = difflib.get_close_matches(context.lower(), choices, n=1, cutoff=0.9)
            if len(sim_list) > 0:
                text_dict[context] = sim_list[0]
            else:
                # text_dict[context] = ''
                pass
    return text_dict


def txt_map(x, txt_dict):
    if type(x) == str:
        x = eval(x)
    x_ = []
    for i in x:
        if i == '':
            continue
        if i in txt_dict:
            x_.append(txt_dict[i])
        else:
            # x_.append(i)
            pass
    return x_


def go_map_prob(x, GO_dict):
    res = []
    for t in x:
        if t[0] in GO_dict:
            res.append((GO_dict[t[0]], t[1]))
        else:
            pass
            # print("{} not in GO_dict".format(t[0]))
    return res


def txt_map_prob(x, txt_dict):
    if type(x) == str:
        x = eval(x)
    x_ = []
    temp = set()
    for i in x:
        if i[0] == '':
            continue
        elif i[0] in txt_dict and txt_dict[i[0]] not in temp:
            x_.append((txt_dict[i[0]].lower(), i[1]))
            temp.add(txt_dict[i[0]])
        # elif i[0] not in txt_dict:
        #     x_.append((i[0].lower(), i[1]))
        #     temp.add(i[0])
        else:
            continue
    return x_


def go_map(x, GO_dict):
    res = []
    for t in x:
        if t in GO_dict:
            res.append(GO_dict[t])
        else:
            # pass
            print("{} not in GO_dict".format(t))
    return res


def prop(df):
    prop_annotations = []
    for i, row in df.iterrows():
        # Propagate annotations
        annot_set = set()
        annots = row['GO_label']
        for go_id in annots:
            annot_set |= godb.get_anchestors(go_id)
        annots = list(annot_set)
        prop_annotations.append(annots)
    df['prop_annotations'] = prop_annotations
    return df


# Convert raw model output text into GO ids, optionally keeping per-term probabilities.
def pred_text_to_go(df, with_prob=False):
    # df['pred'] = df['pred'].apply(lambda x: re.sub('</s>', '', x))
    if with_prob:
        df['pred_list_prob'] = df['pred'].apply(lambda x: [eval(i.strip()) for i in x.split(';')])
        df['pred_list'] = df['pred_list_prob'].apply(lambda x: [i[0] for i in x])
    else:
        df['pred_list'] = df['pred'].apply(lambda x: list(set([i.strip() for i in x.split(';')])))
    ### If a predicted text is not one of the GO label phrases, map it to the most similar GO label
    t0 = time.time()
    txt_dict = {}
    all_txt = []
    for txt in df['pred_list']:
        if type(txt) == str:
            all_txt.extend(eval(txt))
        else:
            all_txt.extend(txt)
    all_txt = list(set(all_txt))
    if '' in all_txt:
        all_txt.remove('')
    n = len(all_txt)
    thread = 10
    size = int(n / thread)
    inds = list(range(0, n, size))
    inds.append(n)
    all_txt_sep = [all_txt[i: min(i + size, n)] for i in inds[:-1]]
    with Pool(processes=thread) as pool:
        result = pool.map(fuzzy_match, all_txt_sep)
        pool.close()
        pool.join()
    for d in result:
        txt_dict.update(d)
    # print(txt_dict)
    # for txt in all_txt[:10]:
    #     fuzzy_match(txt)
    if with_prob:
        df['pred_list_prob'] = df['pred_list_prob'].apply(lambda x: txt_map_prob(x, txt_dict))
        print("fuzzy matching time: {}".format(time.time() - t0))
        df['pred_list_go_prob'] = df['pred_list_prob'].apply(lambda x: go_map_prob(x, GO_dict))
        n0 = df.shape[0]
        df['len'] = df['pred_list_go_prob'].apply(lambda x: len(x))
        df = df[df['len'] > 0]
        df = df.drop('len', axis=1)
        df = df.dropna()
        print('{} rows in total, {} with non-empty predictions'.format(n0, df.shape[0]))
    else:
        df['pred_list'] = df['pred_list'].apply(lambda x: txt_map(x, txt_dict))
        df['pred_list'] = df['pred_list'].apply(lambda x: [i.lower() for i in list(set(x))])
        print("fuzzy matching time: {}".format(time.time() - t0))
        df['pred_list_go'] = df['pred_list'].apply(lambda x: go_map(x, GO_dict))

        n0 = df.shape[0]
        df['len'] = df['pred_list_go'].apply(lambda x: len(x))
        df = df[df['len'] > 0]
        df = df.drop('len', axis=1)
        df = df.dropna()
        print('{} rows in total, {} with non-empty predictions'.format(n0, df.shape[0]))
    return df


def cal_f1(df):
    df['label_list_go'] = df['label'].apply(lambda x: [i.strip() for i in x.split(';')])
    df['pred_list_go'] = df['pred_list'].apply(lambda x: [i.strip() for i in x.split(';')])

    labels = []
    pred_labels = []
    for l in df['label_list_go']:
        labels.extend(l)

    label_count = {}
    for x in labels:
        if x not in label_count:
            label_count[x] = 1
        else:
            label_count[x] += 1

    labels = list(set(labels))
    total = len(labels)
    tp_dict, fp_dict, fn_dict = dict(zip(labels, [0] * len(labels))), dict(zip(labels, [0] * len(labels))), dict(
        zip(labels, [0] * len(labels)))
    for preds, label in zip(df['pred_list_go'], df['label_list_go']):
        for t in label:
            # supgo = godb.get_anchestors(t)
            # if supgo.intersection(set(preds)):
            if t in preds:
                tp_dict[t] += 1
            else:
                fn_dict[t] += 1
        for p in preds:
            # supgo = godb.get_anchestors(p)
            # if not supgo.intersection(set(label)):
            if p not in label:
                if p in fp_dict:
                    fp_dict[p] += 1
                else:
                    fp_dict[p] = 1
        pred_labels.extend(preds)
    p_total = len(set(pred_labels))
    recall, pr = 0., 0.
    for x in labels:
        recall += tp_dict[x] / (1.0 * (tp_dict[x] + fn_dict[x] + 1e-8))
        pr += tp_dict[x] / (1.0 * (tp_dict[x] + fp_dict[x] + 1e-8))
    r = recall / total
    p = pr / p_total
    f1 = 2 * p * r / (p + r)

    print("preds not in labels: {}".format(len(list(fp_dict.keys())) - total))
    print("recall:{}; precision:{}; f1 score: {}".format(r, p, f1))


def cat_go(x):
    try:
        cat = godb.get_namespace(x)
    except:
        print("{} not found".format(x))
        return
    if cat == NAMESPACES['mf']:
        return 'mf'
    elif cat == NAMESPACES['bp']:
        return 'bp'
    elif cat == NAMESPACES['cc']:
        return 'cc'
    return


def remove_root(x):
    if 'molecular_function' in x:
        x.remove('molecular_function')
    if 'biological_process' in x:
        x.remove('biological_process')
    if 'cellular_component' in x:
        x.remove('cellular_component')
    return x

if __name__ == "__main__":
    NAMESPACES = {
        'cc': 'cellular_component',
        'mf': 'molecular_function',
        'bp': 'biological_process'
    }
    #if not os.path.exists('/cluster/home/wenkai/LAVIS/data/pretrain/mf_bp_cc/terms.pkl'):
    if 1==1:
        data = pd.read_csv('/cluster/home/wenkai/LAVIS/data/pretrain/swissprot_domain_and_train_exp_prompt_new.csv', sep='|')
        print('Dataset size: {}'.format(data.shape[0]))
        # data['function'] = data['function'].apply(lambda x: re.sub('[FPC]:', '', x))
        # data.to_csv('swissprot_domain_and_train_exp.csv', sep='|', index=False)

        godb = Ontology(f'/cluster/home/wenkai/LAVIS/data/go1.4-basic.obo', with_rels=True)
        go_des = pd.read_csv('/cluster/home/wenkai/LAVIS/data/go_descriptions1.4.txt', sep='|', header=None)
        go_des.columns = ['id', 'text']
        go_des = go_des.dropna()
        go_des['id'] = go_des['id'].apply(lambda x: re.sub('_', ':', x))
        go_des['ont'] = go_des['id'].apply(lambda x: cat_go(x))
        go_des = go_des.dropna()
        go_obo_set = set(go_des['id'].tolist())
        go_des['text'] = go_des['text'].apply(lambda x: x.lower())

        data['GO_label'] = data['GO_label'].apply(lambda x: [i.strip() for i in x.split(';')])
        data = prop(data)

        # Add parent terms to build the full term set, the mapping tables, etc.
        go_dict = {}
        for x_list in data['prop_annotations']:
            for goid in x_list:
                if goid in go_dict:
                    go_dict[goid] += 1
                else:
                    go_dict[goid] = 1
        df_stat = pd.DataFrame({'id': list(go_dict.keys()), 'count': list(go_dict.values())})
        data_gos = set(df_stat['id'].tolist())
        go_des = go_des[go_des['id'].isin(data_gos)]
        filter_go = data_gos.intersection(go_obo_set)
        print(f"{len(data_gos)} GO terms including parents, of which {len(filter_go)} appear in go1.4.obo")

        go_des.to_pickle('/cluster/home/wenkai/LAVIS/data/pretrain/mf_bp_cc/go_des.pkl')
        id2text_dict = dict(zip(go_des['id'], go_des['text']))
        GO_dict = dict(zip(go_des['text'], go_des['id']))

        choices_mf = list(set(go_des[go_des['ont'] == 'mf']['text']))
        choices_bp = list(set(go_des[go_des['ont'] == 'bp']['text']))
        choices_cc = list(set(go_des[go_des['ont'] == 'cc']['text']))

        choices_mf = {x.lower(): x for x in choices_mf}
        choices_bp = {x.lower(): x for x in choices_bp}
        choices_cc = {x.lower(): x for x in choices_cc}

        data['GO_label'] = data['GO_label'].apply(lambda x: filter(x))
        data = data[data['GO_label'] != '']
        data['function'] = data['GO_label'].apply(lambda x: [id2text_dict[i.strip()] for i in x.split(';')])
        data['function'] = data['function'].apply(lambda x: '; '.join(x))

        terms = pd.DataFrame({'gos': list(filter_go)})
        terms.to_pickle('/cluster/home/wenkai/LAVIS/data/pretrain/mf_bp_cc/terms.pkl')
        terms.to_pickle('/cluster/home/wenkai/deepgozero/data/blip2/pretrain/terms.pkl')

        terms_mf = pd.DataFrame({'gos': list(set(go_des[go_des['ont'] == 'mf']['id']))})
        terms_mf.to_pickle('/cluster/home/wenkai/deepgozero/data/blip2/pretrain/mf/terms.pkl')
        terms_mf.to_pickle('/cluster/home/wenkai/deepgo2/data/mf/terms.pkl')
        terms_bp = pd.DataFrame({'gos': list(set(go_des[go_des['ont'] == 'bp']['id']))})
        terms_bp.to_pickle('/cluster/home/wenkai/deepgozero/data/blip2/pretrain/bp/terms.pkl')
        terms_bp.to_pickle('/cluster/home/wenkai/deepgo2/data/bp/terms.pkl')
        terms_cc = pd.DataFrame({'gos': list(set(go_des[go_des['ont'] == 'cc']['id']))})
        terms_cc.to_pickle('/cluster/home/wenkai/deepgozero/data/blip2/pretrain/cc/terms.pkl')
        terms_cc.to_pickle('/cluster/home/wenkai/deepgo2/data/cc/terms.pkl')
    else:
        godb = Ontology(f'/cluster/home/wenkai/LAVIS/data/go1.4-basic.obo', with_rels=True)
        terms = pd.read_pickle('/cluster/home/wenkai/LAVIS/data/pretrain/mf_bp_cc/terms.pkl')
        filter_go = set(terms['gos'].tolist())

        terms_mf = pd.read_pickle('/cluster/home/wenkai/deepgo2/data/mf/terms.pkl')
        terms_bp = pd.read_pickle('/cluster/home/wenkai/deepgo2/data/bp/terms.pkl')
        terms_cc = pd.read_pickle('/cluster/home/wenkai/deepgo2/data/cc/terms.pkl')

        choices_mf = {x.lower(): x for x in terms_mf['gos'].tolist()}
        choices_bp = {x.lower(): x for x in terms_bp['gos'].tolist()}
        choices_cc = {x.lower(): x for x in terms_cc['gos'].tolist()}

        go_des = pd.read_pickle('/cluster/home/wenkai/LAVIS/data/pretrain/mf_bp_cc/go_des.pkl')
        id2text_dict = dict(zip(go_des['id'], go_des['text']))
        GO_dict = dict(zip(go_des['text'], go_des['id']))

    # For prediction files: filter GO terms and match them to filter_go with the similarity matcher;
    # for the train/test/val files: filter GO terms, add ancestors, and add InterPro features.
    # Add InterPro features
    df_interpro = pd.read_csv('/cluster/home/wenkai/LAVIS/data/uniprot_sprot_blip2_func_data.txt', sep='|',
                              nrows=546389,
                              header=None)
    df_interpro.columns = ['name', 'seq', 'go', 'text', 'evi', 'ipr']
    df_interpro = df_interpro[df_interpro['ipr'].notnull()]
    df_interpro['ipr'] = df_interpro['ipr'].apply(lambda x: [i.strip() for i in x.split(';')])

    iprs = []
    for x in df_interpro['ipr'].tolist():
        if len(x) > 0:
            iprs.extend(x)
    iprs = list(set(iprs))
    print("Number of InterPro entries: {}".format(len(iprs)))
    df_ipr = pd.DataFrame({'interpros': iprs})
    df_ipr.to_pickle('/cluster/home/wenkai/LAVIS/data/interpros.pkl')
    df_ipr.to_pickle('/cluster/home/wenkai/deepgozero/data/blip2/pretrain/interpros.pkl')


    '''
    # test cases
    df_real = pd.read_csv('/cluster/home/wenkai/LAVIS/data/pretrain/test_2000.csv', sep='|')
    df_real[col] = df_real[col].apply(lambda x: [i.strip() for i in x.split(';')])
    #df_real[col] = df_real[col].apply(lambda x: filter(x))
    df_real = df_real[df_real[col] != '']
    print(df_real.shape)
    #df_real['GO_label'] = df_real['GO_label'].apply(lambda x: [id2text_dict[i] for i in x])
    #df_real['GO_label'] = df_real['GO_label'].apply(lambda x: [GO_dict[i] for i in x])
    df_real = prop(df_real)
    #df_real['prop_annotations'] = df_real['prop_annotations'].apply(lambda x: [id2text_dict[i] for i in x])
    #df_real['prop_annotations'] = df_real['prop_annotations'].apply(lambda x: remove_root(x))
    #df_real['prop_annotations'] = df_real['prop_annotations'].apply(lambda x: list(set([GO_dict[i] for i in x])))
    for ont in ['mf', 'bp', 'cc']:
        file_name = 'output_{}_test_2000'.format(ont)
        if ont == 'mf':
            choices = choices_mf
        elif ont == 'bp':
            choices = choices_bp
        elif ont == 'cc':
            choices = choices_cc
        print("Standardizing predictions for {}...".format(file_name))
        df_pred = pd.read_csv('/cluster/home/wenkai/LAVIS/output/{}.txt'.format(file_name), sep='|', header=None, on_bad_lines='skip')
        df_pred.columns = ['name', 'pred', 'label']
        n0 = df_pred.shape[0]
        df_pred = pred_text_to_go(df_pred, with_prob=True)
        print("{}: {} rows had no sufficiently similar GO description".format(file_name, n0-df_pred.shape[0]))
        #df_pred['pred_list'] = df_pred['pred_list'].apply(lambda x: '; '.join(x))
        #cal_f1(df_pred)
        df_pred[['name', 'pred_list_prob', 'label']].to_csv('/cluster/home/wenkai/LAVIS/output/{}_standard.csv'.format(file_name), sep='|', index=False)

        df_pred = pd.merge(df_pred[['name', 'pred_list_go_prob']], df_interpro[['name', 'ipr']], on='name', how='left')
        df_pred['ipr'] = df_pred['ipr'].fillna("").apply(list)
        ipr_and_pred = []
        for x, y in zip(df_pred['ipr'], df_pred['pred_list_go_prob']):
            try:
                ipr_and_pred.append(x + y)
            except:
                ipr_and_pred.append(y)
        df_pred['ipr_and_pred'] = ipr_and_pred
        print(df_real.isnull().sum())
        df_pred = pd.merge(df_pred, df_real[['name', 'protein', 'prop_annotations']], on='name', how='left')
        #df_pred = df_pred.dropna()
        print(df_pred.shape)
        df_pred[['name', 'protein', 'ipr', 'pred_list_go_prob', 'ipr_and_pred', 'prop_annotations']].to_pickle(
            '/cluster/home/wenkai/deepgozero/data/blip2/pretrain/{}/test_2000_data.pkl'.format(ont))
    '''

    '''
    df_real = pd.read_csv('/cluster/home/wenkai/LAVIS/data/pretrain/nextprot_mf.csv', sep='|')
    df_real['GO_label'] = df_real['GO_label'].apply(lambda x: [i.strip() for i in x.split(';')])
    df_real['GO_label'] = df_real['GO_label'].apply(lambda x: [id2text_dict[i] for i in x])
    df_real['GO_label'] = df_real['GO_label'].apply(lambda x: [GO_dict[i] for i in x])
    df_real = prop(df_real)
    df_real['prop_annotations'] = df_real['prop_annotations'].apply(lambda x: [id2text_dict[i] for i in x])
    df_real['prop_annotations'] = df_real['prop_annotations'].apply(lambda x: remove_root(x))
    df_real['prop_annotations'] = df_real['prop_annotations'].apply(lambda x: list(set([GO_dict[i] for i in x])))

    file = 'output_nextprot'
    choices = choices_mf
    df_pred = pd.read_csv('/cluster/home/wenkai/LAVIS/output/{}.txt'.format(file), sep='|', header=None, on_bad_lines='skip')
    df_pred.columns = ['name', 'pred', 'label']
    df_pred = pred_text_to_go(df_pred, with_prob=True)
    df_pred[['name', 'pred_list_prob', 'label']].to_csv('/cluster/home/wenkai/LAVIS/output/{}_standard.csv'.format(file), sep='|', index=False)

    df_pred = pd.merge(df_pred, df_real[['name', 'protein', 'prop_annotations']], on='name', how='left')
    df_pred['ipr'] = [[] for _ in range(df_pred.shape[0])]
    df_pred['ipr_and_pred'] = df_pred['pred_list_go_prob']
    df_pred[['name', 'protein', 'ipr', 'pred_list_go_prob', 'ipr_and_pred', 'prop_annotations']].to_pickle(
        '/cluster/home/wenkai/deepgozero/data/blip2/pretrain/mf/nextprot_data.pkl')
    '''
    # '''
    cat_id = {'mf': '445772', 'bp': '496359', 'cc': '505955'}
    col = 'GO_label'
    for ont in ['mf', 'bp', 'cc']:
    #for ont in ['mf']:
        if ont == 'mf':
            choices = choices_mf
        elif ont == 'bp':
            choices = choices_bp
        elif ont == 'cc':
            choices = choices_cc
        for split in ['train', 'val', 'test']:
        #for split in ['test']:
            df_real = pd.read_csv(f'/cluster/home/wenkai/LAVIS/data/pretrain/mf_bp_cc/{split}_exp_{ont}_new.csv',
                                  sep='|')
            df_real[col] = df_real[col].apply(lambda x: [i.strip() for i in x.split(';')])
            df_real[col] = df_real[col].apply(lambda x: filter(x))
            df_real = df_real[df_real[col] != '']
            print(df_real.shape)
            df_real['GO_label'] = df_real['GO_label'].apply(lambda x: [i.strip() for i in x.split(';')])
            df_real['GO_label'] = df_real['GO_label'].apply(lambda x: [id2text_dict[i] for i in x])
            df_real['GO_label'] = df_real['GO_label'].apply(lambda x: [GO_dict[i] for i in x])
            df_real = prop(df_real)
            df_real['prop_annotations'] = df_real['prop_annotations'].apply(lambda x: [id2text_dict[i] for i in x])
            df_real['prop_annotations'] = df_real['prop_annotations'].apply(lambda x: remove_root(x))
            df_real['prop_annotations'] = df_real['prop_annotations'].apply(lambda x: list(set([GO_dict[i] for i in x])))

            # Convert predicted text to GO ids
            df_pred = pd.read_csv(
                f'/cluster/home/wenkai/LAVIS/output/mf_bp_cc/output_{split}_{ont}_exp_{cat_id[ont]}.txt', sep='|',
                header=None, on_bad_lines='skip')
            df_pred.columns = ['name', 'pred', 'label']
            n0 = df_pred.shape[0]
            df_pred = pred_text_to_go(df_pred, with_prob=True)
            print("{}: {} rows had no sufficiently similar GO description".format(ont, n0 - df_pred.shape[0]))
            df_pred[['name', 'pred_list_prob', 'label']].to_csv(
                f'/cluster/home/wenkai/LAVIS/output/mf_bp_cc/output_{split}_{ont}_{cat_id[ont]}_standard.csv', sep='|',
                index=False)

            df_pred = pd.merge(df_pred[['name', 'pred_list_go_prob']], df_interpro[['name', 'ipr']], on='name', how='left')
            df_pred['ipr'] = df_pred['ipr'].fillna("").apply(list)
            ipr_and_pred = []
            for x, y in zip(df_pred['ipr'], df_pred['pred_list_go_prob']):
                try:
                    ipr_and_pred.append(x + y)
                except:
                    ipr_and_pred.append(y)
            df_pred['ipr_and_pred'] = ipr_and_pred

            df_pred = pd.merge(df_pred, df_real[['name', 'protein', 'prop_annotations']], on='name', how='left')
            df_pred = df_pred.dropna()
            df_pred[['name', 'protein', 'ipr', 'pred_list_go_prob', 'ipr_and_pred', 'prop_annotations']].to_pickle(
                f'/cluster/home/wenkai/deepgozero/data/blip2/pretrain/{ont}/{split}_data_{cat_id[ont]}.pkl')
            df_pred[['name', 'protein', 'ipr', 'pred_list_go_prob', 'ipr_and_pred', 'prop_annotations']].to_pickle(
                f'/cluster/home/wenkai/deepgo2/data/{ont}/{split}_data_{cat_id[ont]}.pkl')
            if split == 'val':
                df_pred[['name', 'protein', 'ipr', 'pred_list_go_prob', 'ipr_and_pred', 'prop_annotations']].to_pickle(
                    f'/cluster/home/wenkai/deepgozero/data/blip2/pretrain/{ont}/valid_data_{cat_id[ont]}.pkl')
                df_pred[['name', 'protein', 'ipr', 'pred_list_go_prob', 'ipr_and_pred', 'prop_annotations']].to_pickle(
                    f'/cluster/home/wenkai/deepgo2/data/{ont}/valid_data_{cat_id[ont]}.pkl')
            print(f"{ont} {split} deepgozero propagation data completed")
    # '''
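The `prop` helper in the script above applies the true-path rule: every annotated GO term also implies all of its ancestors, which is what `godb.get_anchestors` returns. A toy sketch of that step with a hand-made parent map (the ids and parent links below are invented for illustration, not taken from the real ontology):

    # Invented parent links, for illustration only.
    parents = {
        "GO:A": ["GO:B"],
        "GO:B": ["GO:C"],
        "GO:C": [],
    }

    def ancestors(term):
        # walk up the is_a links, returning the term itself plus all of its ancestors
        seen, stack = set(), [term]
        while stack:
            t = stack.pop()
            if t not in seen:
                seen.add(t)
                stack.extend(parents.get(t, []))
        return seen

    annot_set = set()
    for go_id in ["GO:A"]:
        annot_set |= ancestors(go_id)
    print(sorted(annot_set))  # ['GO:A', 'GO:B', 'GO:C']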
data/evaluate_data/process_case.py
ADDED
|
@@ -0,0 +1,50 @@
import pandas as pd
from utils import Ontology


def prop(df):
    prop_annotations = []
    for i, row in df.iterrows():
        # Propagate annotations
        annot_set = set()
        annots = row['GO_label']
        for go_id in annots:
            annot_set |= godb.get_anchestors(go_id)
        annots = list(annot_set)
        prop_annotations.append(annots)
    df['prop_annotations'] = prop_annotations
    return df

godb = Ontology(f'/cluster/home/wenkai/LAVIS/data/go1.4-basic.obo', with_rels=True)

case_mf = pd.read_csv('/cluster/home/wenkai/LAVIS/data/pretrain/cases_mf.csv', sep='|')

# BP cases, including the capsaicin receptor
case_bp = pd.read_csv('/cluster/home/wenkai/LAVIS/data/pretrain/cases_bp.csv', sep='|')
case_bp['GO_label'] = case_bp['GO_label'].apply(lambda x: [i.strip() for i in x.split(';')])
case_bp = prop(case_bp)
case_bp['GO_label'] = case_bp['GO_label'].apply(lambda x: '; '.join(x))
case_bp['prop_annotations'] = case_bp['prop_annotations'].apply(lambda x: '; '.join(x))
case_bp[['name', 'protein', 'function', 'GO_label', 'id', 'prompt', 'prop_annotations']].to_pickle('/cluster/home/wenkai/deepgo2/data/bp/cases_data.pkl')

case_mf['GO_label'] = case_mf['GO_label'].apply(lambda x: [i.strip() for i in x.split(';')])
case_mf = prop(case_mf)
case_mf['GO_label'] = case_mf['GO_label'].apply(lambda x: '; '.join(x))
case_mf['prop_annotations'] = case_mf['prop_annotations'].apply(lambda x: '; '.join(x))

case_bp['GO_label'] = case_bp['GO_label'].apply(lambda x: [i.strip() for i in x.split(';')])
case_bp = prop(case_bp)
case_mf[['name', 'protein', 'function', 'GO_label', 'id', 'prompt', 'prop_annotations']].to_pickle('/cluster/home/wenkai/deepgo2/data/mf/cases_data_445772.pkl')
data/evaluate_data/utils.py
ADDED
|
@@ -0,0 +1,280 @@
from collections import deque, Counter
import warnings
import pandas as pd
import numpy as np
from xml.etree import ElementTree as ET
import math

BIOLOGICAL_PROCESS = 'GO:0008150'
MOLECULAR_FUNCTION = 'GO:0003674'
CELLULAR_COMPONENT = 'GO:0005575'
FUNC_DICT = {
    'cc': CELLULAR_COMPONENT,
    'mf': MOLECULAR_FUNCTION,
    'bp': BIOLOGICAL_PROCESS}

NAMESPACES = {
    'cc': 'cellular_component',
    'mf': 'molecular_function',
    'bp': 'biological_process'
}

EXP_CODES = set([
    'EXP', 'IDA', 'IPI', 'IMP', 'IGI', 'IEP', 'TAS', 'IC',
    'HTP', 'HDA', 'HMP', 'HGI', 'HEP'])

# CAFA4 Targets
CAFA_TARGETS = set([
    '287', '3702', '4577', '6239', '7227', '7955', '9606', '9823', '10090',
    '10116', '44689', '83333', '99287', '226900', '243273', '284812', '559292'])


def is_cafa_target(org):
    return org in CAFA_TARGETS


def is_exp_code(code):
    return code in EXP_CODES


def get_goplus_defs(filename='data/definitions.txt'):
    plus_defs = {}
    with open(filename) as f:
        for line in f:
            line = line.strip()
            go_id, definition = line.split(': ')
            go_id = go_id.replace('_', ':')
            definition = definition.replace('_', ':')
            plus_defs[go_id] = set(definition.split(' and '))
    return plus_defs


class Ontology(object):

    def __init__(self, filename='data/go.obo', with_rels=False):
        self.ont = self.load(filename, with_rels)
        self.ic = None
        self.ic_norm = 0.0

    def has_term(self, term_id):
        return term_id in self.ont

    def get_term(self, term_id):
        if self.has_term(term_id):
            return self.ont[term_id]
        return None

    def calculate_ic(self, annots):
        cnt = Counter()
        for x in annots:
            cnt.update(x)
        self.ic = {}
        for go_id, n in cnt.items():
            parents = self.get_parents(go_id)
            if len(parents) == 0:
                min_n = n
            else:
                min_n = min([cnt[x] for x in parents])

            self.ic[go_id] = math.log(min_n / n, 2)
            self.ic_norm = max(self.ic_norm, self.ic[go_id])

    def get_ic(self, go_id):
        if self.ic is None:
            raise Exception('Not yet calculated')
        if go_id not in self.ic:
            return 0.0
        return self.ic[go_id]

    def get_norm_ic(self, go_id):
        return self.get_ic(go_id) / self.ic_norm

    def load(self, filename, with_rels):
        ont = dict()
        obj = None
        with open(filename, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                if line == '[Term]':
                    if obj is not None:
                        ont[obj['id']] = obj
                    obj = dict()
                    obj['is_a'] = list()
                    obj['part_of'] = list()
                    obj['regulates'] = list()
                    obj['alt_ids'] = list()
                    obj['is_obsolete'] = False
                    continue
                elif line == '[Typedef]':
                    if obj is not None:
                        ont[obj['id']] = obj
                    obj = None
                else:
                    if obj is None:
                        continue
                    l = line.split(": ")
                    if l[0] == 'id':
                        obj['id'] = l[1]
                    elif l[0] == 'alt_id':
                        obj['alt_ids'].append(l[1])
                    elif l[0] == 'namespace':
                        obj['namespace'] = l[1]
                    elif l[0] == 'is_a':
                        obj['is_a'].append(l[1].split(' ! ')[0])
                    elif with_rels and l[0] == 'relationship':
                        it = l[1].split()
                        # add all types of relationships
                        obj['is_a'].append(it[1])
                    elif l[0] == 'name':
                        obj['name'] = l[1]
                    elif l[0] == 'is_obsolete' and l[1] == 'true':
                        obj['is_obsolete'] = True
        if obj is not None:
            ont[obj['id']] = obj
        for term_id in list(ont.keys()):
            for t_id in ont[term_id]['alt_ids']:
                ont[t_id] = ont[term_id]
            if ont[term_id]['is_obsolete']:
                del ont[term_id]
        for term_id, val in ont.items():
            if 'children' not in val:
                val['children'] = set()
            for p_id in val['is_a']:
                if p_id in ont:
                    if 'children' not in ont[p_id]:
                        ont[p_id]['children'] = set()
                    ont[p_id]['children'].add(term_id)

        return ont

    def get_anchestors(self, term_id):
        if term_id not in self.ont:
            return set()
        term_set = set()
        q = deque()
        q.append(term_id)
        while (len(q) > 0):
            t_id = q.popleft()
            if t_id not in term_set:
                term_set.add(t_id)
                for parent_id in self.ont[t_id]['is_a']:
                    if parent_id in self.ont:
                        q.append(parent_id)
        return term_set

    def get_prop_terms(self, terms):
        prop_terms = set()

        for term_id in terms:
            prop_terms |= self.get_anchestors(term_id)
        return prop_terms

    def get_parents(self, term_id):
        if term_id not in self.ont:
            return set()
        term_set = set()
        for parent_id in self.ont[term_id]['is_a']:
            if parent_id in self.ont:
                term_set.add(parent_id)
        return term_set

    def get_namespace_terms(self, namespace):
        terms = set()
        for go_id, obj in self.ont.items():
            if obj['namespace'] == namespace:
                terms.add(go_id)
        return terms

    def get_namespace(self, term_id):
        return self.ont[term_id]['namespace']

    def get_term_set(self, term_id):
        if term_id not in self.ont:
            return set()
        term_set = set()
        q = deque()
        q.append(term_id)
        while len(q) > 0:
            t_id = q.popleft()
            if t_id not in term_set:
                term_set.add(t_id)
                for ch_id in self.ont[t_id]['children']:
                    q.append(ch_id)
        return term_set


def read_fasta(filename):
    seqs = list()
    info = list()
    seq = ''
    inf = ''
    with open(filename, 'r') as f:
        for line in f:
            line = line.strip()
            if line.startswith('>'):
                if seq != '':
                    seqs.append(seq)
                    info.append(inf)
                    seq = ''
                inf = line[1:].split()[0]
            else:
                seq += line
        seqs.append(seq)
        info.append(inf)
    return info, seqs


class DataGenerator(object):

    def __init__(self, batch_size, is_sparse=False):
        self.batch_size = batch_size
        self.is_sparse = is_sparse

    def fit(self, inputs, targets=None):
        self.start = 0
        self.inputs = inputs
        self.targets = targets
        if isinstance(self.inputs, tuple) or isinstance(self.inputs, list):
            self.size = self.inputs[0].shape[0]
        else:
            self.size = self.inputs.shape[0]
        self.has_targets = targets is not None

    def __next__(self):
        return self.next()

    def reset(self):
        self.start = 0

    def next(self):
        if self.start < self.size:
            batch_index = np.arange(
                self.start, min(self.size, self.start + self.batch_size))
            if isinstance(self.inputs, tuple) or isinstance(self.inputs, list):
                res_inputs = []
                for inp in self.inputs:
                    if self.is_sparse:
                        res_inputs.append(
                            inp[batch_index, :].toarray())
                    else:
                        res_inputs.append(inp[batch_index, :])
            else:
                if self.is_sparse:
                    res_inputs = self.inputs[batch_index, :].toarray()
                else:
                    res_inputs = self.inputs[batch_index, :]
            self.start += self.batch_size
            if self.has_targets:
                if self.is_sparse:
                    labels = self.targets[batch_index, :].toarray()
                else:
                    labels = self.targets[batch_index, :]
                return (res_inputs, labels)
            return res_inputs
        else:
            self.reset()
            return self.next()
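A short usage sketch of the Ontology class above, mirroring how the other scripts in data/evaluate_data call it (the .obo path is whatever local copy of data/go1.4-basic.obo is available; GO:0003677 is used only as an example id):

    from utils import Ontology

    godb = Ontology('data/go1.4-basic.obo', with_rels=True)
    term = 'GO:0003677'  # example GO id (DNA binding)
    if godb.has_term(term):
        print(godb.get_namespace(term))        # e.g. molecular_function
        print(len(godb.get_anchestors(term)))  # ancestor count, the term itself included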
data/go1.4-basic.obo
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3da20cc774d666b4338446bc81341eaf536885dc10ccb667480a79f6b964aa3c
size 31134256
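The 3-line diff above, like the similar .csv and .pkl entries below, is a Git LFS pointer file: the repository stores only the LFS spec version, the sha256 of the content, and its size, while the actual payload lives in LFS storage and is fetched locally with `git lfs pull` after cloning.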
data/go_descriptions1.4.txt
ADDED
The diff for this file is too large to render. See raw diff
data/swissprot_exp/test_exp_prompt_bp_new.csv
ADDED
The diff for this file is too large to render. See raw diff
data/swissprot_exp/test_exp_prompt_cc_new.csv
ADDED
The diff for this file is too large to render. See raw diff
data/swissprot_exp/test_exp_prompt_mf_new.csv
ADDED
The diff for this file is too large to render. See raw diff
data/swissprot_exp/train_exp_prompt_bp_new.csv
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:12359211ab95f1ce1962b69f033b55e9f502a7527f49414792d1c117ec50b0be
size 28503657
data/swissprot_exp/train_exp_prompt_cc_new.csv
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:01c6144b0e338d3ce8ce98adfd4f9d09f56dc58cd347f4fbaafb6782d694ffd1
size 23292609
data/swissprot_exp/train_exp_prompt_mf_new.csv
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ca3eee941dfc0ee37f59adec6abf8a7276441f04c484a9275274c7003ef4145e
size 18791760
data/swissprot_exp/val_exp_prompt_bp_new.csv
ADDED
The diff for this file is too large to render. See raw diff
data/swissprot_exp/val_exp_prompt_cc_new.csv
ADDED
The diff for this file is too large to render. See raw diff
data/swissprot_exp/val_exp_prompt_mf_new.csv
ADDED
The diff for this file is too large to render. See raw diff
data/terms/bp_terms.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4952f3551e4fe205640b81f9a1816c15c14cc889bbe55f57d378fb3c6d57f2f7
size 274892
data/terms/cc_terms.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:20992c211336c4f876c920c2995ae85c1422e8742b7094c997aa70ddec7fc8fd
size 39440
data/terms/mf_terms.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:192861bad821ef3523ab2dcdd1db5eac093364e9b9b4869f75587d656864d29b
size 107802
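The three pickles under data/terms/ are binary, so nothing readable appears in the diff. Assuming they follow the same layout as the terms.pkl files written by pretrain_output_to_deepgozero.py above (a one-column DataFrame of GO ids), they can be inspected with a couple of lines once the LFS content has been pulled:

    import pandas as pd

    terms = pd.read_pickle('data/terms/mf_terms.pkl')  # loads whatever object was pickled
    print(type(terms))
    # If it is the one-column DataFrame produced above, this shows the first GO ids:
    # print(terms['gos'].head())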