Update app.py
Browse files
app.py
CHANGED
|
@@ -7,23 +7,135 @@ from lavis.models.protein_models.protein_function_opt import Blip2ProteinMistral
|
|
| 7 |
from lavis.models.base_model import FAPMConfig
|
| 8 |
import spaces
|
| 9 |
import gradio as gr
|
| 10 |
-
from esm_scripts.extract import run_demo
|
| 11 |
from esm import pretrained, FastaBatchedDataset
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
model
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
@spaces.GPU
|
| 23 |
def generate_caption(protein, prompt):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
-
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
'''
|
| 28 |
inputs = tokenizer([protein], return_tensors="pt", padding=True, truncation=True).to('cuda')
|
| 29 |
with torch.no_grad():
|
|
@@ -32,17 +144,50 @@ def generate_caption(protein, prompt):
|
|
| 32 |
'''
|
| 33 |
print("esm embedding generated")
|
| 34 |
esm_emb = F.pad(esm_emb.t(), (0, 1024 - len(esm_emb))).t().to('cuda')
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
| 36 |
samples = {'name': ['protein_name'],
|
| 37 |
'image': torch.unsqueeze(esm_emb, dim=0),
|
| 38 |
'text_input': ['none'],
|
| 39 |
'prompt': [prompt]}
|
| 40 |
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
# return "test"
|
| 47 |
|
| 48 |
|
|
@@ -51,16 +196,50 @@ description = """Quick demonstration of the FAPM model for protein function pred
|
|
| 51 |
|
| 52 |
The model used in this app is available at [Hugging Face Model Hub](https://huggingface.co/wenkai/FAPM) and the source code can be found on [GitHub](https://github.com/xiangwenkai/FAPM/tree/main)."""
|
| 53 |
|
| 54 |
-
iface = gr.Interface(
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
)
|
| 60 |
-
|
| 61 |
-
#
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
|
|
|
| 7 |
from lavis.models.base_model import FAPMConfig
|
| 8 |
import spaces
|
| 9 |
import gradio as gr
|
| 10 |
+
# from esm_scripts.extract import run_demo
|
| 11 |
from esm import pretrained, FastaBatchedDataset
|
| 12 |
+
from data.evaluate_data.utils import Ontology
|
| 13 |
+
import difflib
|
| 14 |
+
import re
|
| 15 |
+
from transformers import MistralForCausalLM
|
| 16 |
+
|
| 17 |
+
# Load the trained model
|
| 18 |
+
def get_model(type='Molecule Function'):
|
| 19 |
+
model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
|
| 20 |
+
if type == 'Molecule Function':
|
| 21 |
+
model.load_checkpoint("model/checkpoint_mf2.pth")
|
| 22 |
+
model.to('cuda')
|
| 23 |
+
elif type == 'Biological Process':
|
| 24 |
+
model.load_checkpoint("model/checkpoint_bp1.pth")
|
| 25 |
+
model.to('cuda')
|
| 26 |
+
elif type == 'Cellar Component':
|
| 27 |
+
model.load_checkpoint("model/checkpoint_cc2.pth")
|
| 28 |
+
model.to('cuda')
|
| 29 |
+
return model
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
models = {
|
| 33 |
+
'Molecule Function': get_model('Molecule Function'),
|
| 34 |
+
'Biological Process': get_model('Biological Process'),
|
| 35 |
+
'Cellular Component': get_model('Cellar Component'),
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# Load the mistral model
|
| 40 |
+
mistral_model = MistralForCausalLM.from_pretrained("teknium/OpenHermes-2.5-Mistral-7B", torch_dtype=torch.float16).to('cuda')
|
| 41 |
+
|
| 42 |
+
# Load ESM2 model
|
| 43 |
+
model_esm, alphabet = pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
|
| 44 |
+
model_esm.to('cuda')
|
| 45 |
+
model_esm.eval()
|
| 46 |
+
|
| 47 |
+
godb = Ontology(f'data/go1.4-basic.obo', with_rels=True)
|
| 48 |
+
go_des = pd.read_csv('data/go_descriptions1.4.txt', sep='|', header=None)
|
| 49 |
+
go_des.columns = ['id', 'text']
|
| 50 |
+
go_des = go_des.dropna()
|
| 51 |
+
go_des['id'] = go_des['id'].apply(lambda x: re.sub('_', ':', x))
|
| 52 |
+
go_obo_set = set(go_des['id'].tolist())
|
| 53 |
+
go_des['text'] = go_des['text'].apply(lambda x: x.lower())
|
| 54 |
+
GO_dict = dict(zip(go_des['text'], go_des['id']))
|
| 55 |
+
Func_dict = dict(zip(go_des['id'], go_des['text']))
|
| 56 |
+
|
| 57 |
+
terms_mf = pd.read_pickle('data/terms/mf_terms.pkl')
|
| 58 |
+
choices_mf = [Func_dict[i] for i in list(set(terms_mf['gos']))]
|
| 59 |
+
choices_mf = {x.lower(): x for x in choices_mf}
|
| 60 |
+
terms_bp = pd.read_pickle('data/terms/bp_terms.pkl')
|
| 61 |
+
choices_bp = [Func_dict[i] for i in list(set(terms_bp['gos']))]
|
| 62 |
+
choices_bp = {x.lower(): x for x in choices_bp}
|
| 63 |
+
terms_cc = pd.read_pickle('data/terms/cc_terms.pkl')
|
| 64 |
+
choices_cc = [Func_dict[i] for i in list(set(terms_cc['gos']))]
|
| 65 |
+
choices_cc = {x.lower(): x for x in choices_cc}
|
| 66 |
+
choices = {
|
| 67 |
+
'Molecule Function': choices_mf,
|
| 68 |
+
'Biological Process': choices_bp,
|
| 69 |
+
'Cellular Component': choices_cc,
|
| 70 |
+
}
|
| 71 |
|
| 72 |
@spaces.GPU
|
| 73 |
def generate_caption(protein, prompt):
|
| 74 |
+
# Process the image and the prompt
|
| 75 |
+
# with open('/home/user/app/example.fasta', 'w') as f:
|
| 76 |
+
# f.write('>{}\n'.format("protein_name"))
|
| 77 |
+
# f.write('{}\n'.format(protein.strip()))
|
| 78 |
+
# os.system("python esm_scripts/extract.py esm2_t36_3B_UR50D /home/user/app/example.fasta /home/user/app --repr_layers 36 --truncation_seq_length 1024 --include per_tok")
|
| 79 |
+
# esm_emb = run_demo(protein_name='protein_name', protein_seq=protein,
|
| 80 |
+
# model=model_esm, alphabet=alphabet,
|
| 81 |
+
# include='per_tok', repr_layers=[36], truncation_seq_length=1024)
|
| 82 |
+
|
| 83 |
+
protein_name = 'protein_name'
|
| 84 |
+
protein_seq = protein
|
| 85 |
+
include = 'per_tok'
|
| 86 |
+
repr_layers = [36]
|
| 87 |
+
truncation_seq_length = 1024
|
| 88 |
+
toks_per_batch = 4096
|
| 89 |
+
print("start")
|
| 90 |
+
dataset = FastaBatchedDataset([protein_name], [protein_seq])
|
| 91 |
+
print("dataset prepared")
|
| 92 |
+
batches = dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
|
| 93 |
+
print("batches prepared")
|
| 94 |
+
|
| 95 |
+
data_loader = torch.utils.data.DataLoader(
|
| 96 |
+
dataset, collate_fn=alphabet.get_batch_converter(truncation_seq_length), batch_sampler=batches
|
| 97 |
+
)
|
| 98 |
+
print(f"Read sequences")
|
| 99 |
+
return_contacts = "contacts" in include
|
| 100 |
+
|
| 101 |
+
assert all(-(model_esm.num_layers + 1) <= i <= model_esm.num_layers for i in repr_layers)
|
| 102 |
+
repr_layers = [(i + model_esm.num_layers + 1) % (model_esm.num_layers + 1) for i in repr_layers]
|
| 103 |
|
| 104 |
+
with torch.no_grad():
|
| 105 |
+
for batch_idx, (labels, strs, toks) in enumerate(data_loader):
|
| 106 |
+
print(
|
| 107 |
+
f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)"
|
| 108 |
+
)
|
| 109 |
+
if torch.cuda.is_available():
|
| 110 |
+
toks = toks.to(device="cuda", non_blocking=True)
|
| 111 |
+
out = model_esm(toks, repr_layers=repr_layers, return_contacts=return_contacts)
|
| 112 |
+
representations = {
|
| 113 |
+
layer: t.to(device="cpu") for layer, t in out["representations"].items()
|
| 114 |
+
}
|
| 115 |
+
if return_contacts:
|
| 116 |
+
contacts = out["contacts"].to(device="cpu")
|
| 117 |
+
for i, label in enumerate(labels):
|
| 118 |
+
result = {"label": label}
|
| 119 |
+
truncate_len = min(truncation_seq_length, len(strs[i]))
|
| 120 |
+
# Call clone on tensors to ensure tensors are not views into a larger representation
|
| 121 |
+
# See https://github.com/pytorch/pytorch/issues/1995
|
| 122 |
+
if "per_tok" in include:
|
| 123 |
+
result["representations"] = {
|
| 124 |
+
layer: t[i, 1: truncate_len + 1].clone()
|
| 125 |
+
for layer, t in representations.items()
|
| 126 |
+
}
|
| 127 |
+
if "mean" in include:
|
| 128 |
+
result["mean_representations"] = {
|
| 129 |
+
layer: t[i, 1: truncate_len + 1].mean(0).clone()
|
| 130 |
+
for layer, t in representations.items()
|
| 131 |
+
}
|
| 132 |
+
if "bos" in include:
|
| 133 |
+
result["bos_representations"] = {
|
| 134 |
+
layer: t[i, 0].clone() for layer, t in representations.items()
|
| 135 |
+
}
|
| 136 |
+
if return_contacts:
|
| 137 |
+
result["contacts"] = contacts[i, : truncate_len, : truncate_len].clone()
|
| 138 |
+
esm_emb = result['representations'][36]
|
| 139 |
'''
|
| 140 |
inputs = tokenizer([protein], return_tensors="pt", padding=True, truncation=True).to('cuda')
|
| 141 |
with torch.no_grad():
|
|
|
|
| 144 |
'''
|
| 145 |
print("esm embedding generated")
|
| 146 |
esm_emb = F.pad(esm_emb.t(), (0, 1024 - len(esm_emb))).t().to('cuda')
|
| 147 |
+
if prompt is None:
|
| 148 |
+
prompt = 'none'
|
| 149 |
+
else:
|
| 150 |
+
prompt = prompt.lower()
|
| 151 |
samples = {'name': ['protein_name'],
|
| 152 |
'image': torch.unsqueeze(esm_emb, dim=0),
|
| 153 |
'text_input': ['none'],
|
| 154 |
'prompt': [prompt]}
|
| 155 |
|
| 156 |
+
union_pred_terms = []
|
| 157 |
+
for model_id in models.keys():
|
| 158 |
+
model = models[model_id]
|
| 159 |
+
# Generate the output
|
| 160 |
+
prediction = model.generate(mistral_model, samples, length_penalty=0., num_beams=15, num_captions=10, temperature=1.,
|
| 161 |
+
repetition_penalty=1.0)
|
| 162 |
+
x = prediction[0]
|
| 163 |
+
x = [eval(i) for i in x.split('; ')]
|
| 164 |
+
pred_terms = []
|
| 165 |
+
temp = []
|
| 166 |
+
for i in x:
|
| 167 |
+
txt = i[0]
|
| 168 |
+
prob = i[1]
|
| 169 |
+
sim_list = difflib.get_close_matches(txt.lower(), choices[model_id], n=1, cutoff=0.9)
|
| 170 |
+
if len(sim_list) > 0:
|
| 171 |
+
t_standard = sim_list[0]
|
| 172 |
+
if t_standard not in temp:
|
| 173 |
+
pred_terms.append(t_standard+f'({prob})')
|
| 174 |
+
temp.append(t_standard)
|
| 175 |
+
union_pred_terms.append(pred_terms)
|
| 176 |
+
|
| 177 |
+
if prompt == 'none':
|
| 178 |
+
res_str = "No available predictions for this protein, you can use other two types of model, remove prompt or try another sequence!"
|
| 179 |
+
else:
|
| 180 |
+
res_str = "No available predictions for this protein, you can use other two types of model or try another sequence!"
|
| 181 |
+
if len(union_pred_terms[0]) == 0 and len(union_pred_terms[1]) == 0 and len(union_pred_terms[2]) == 0:
|
| 182 |
+
return res_str
|
| 183 |
+
res_str = ''
|
| 184 |
+
if len(union_pred_terms[0]) != 0:
|
| 185 |
+
res_str += f"Based on the given amino acid sequence, the protein appears to have a primary function of {', '.join(pred_terms)}. "
|
| 186 |
+
if len(union_pred_terms[1]) != 0:
|
| 187 |
+
res_str += f"It is likely involved in the {', '.join(pred_terms)}. "
|
| 188 |
+
if len(union_pred_terms[2]) != 0:
|
| 189 |
+
res_str += f"It's subcellular localization is within the {', '.join(pred_terms)}."
|
| 190 |
+
return res_str
|
| 191 |
# return "test"
|
| 192 |
|
| 193 |
|
|
|
|
| 196 |
|
| 197 |
The model used in this app is available at [Hugging Face Model Hub](https://huggingface.co/wenkai/FAPM) and the source code can be found on [GitHub](https://github.com/xiangwenkai/FAPM/tree/main)."""
|
| 198 |
|
| 199 |
+
# iface = gr.Interface(
|
| 200 |
+
# fn=generate_caption,
|
| 201 |
+
# inputs=[gr.Textbox(type="text", label="Upload sequence"), gr.Textbox(type="text", label="Prompt")],
|
| 202 |
+
# outputs=gr.Textbox(label="Generated description"),
|
| 203 |
+
# description=description
|
| 204 |
+
# )
|
| 205 |
+
# # Launch the interface
|
| 206 |
+
# iface.launch()
|
| 207 |
+
|
| 208 |
+
css = """
|
| 209 |
+
#output {
|
| 210 |
+
height: 500px;
|
| 211 |
+
overflow: auto;
|
| 212 |
+
border: 1px solid #ccc;
|
| 213 |
+
}
|
| 214 |
+
"""
|
| 215 |
+
|
| 216 |
+
with gr.Blocks(css=css) as demo:
|
| 217 |
+
gr.Markdown(description)
|
| 218 |
+
with gr.Tab(label="Protein caption"):
|
| 219 |
+
with gr.Row():
|
| 220 |
+
with gr.Column():
|
| 221 |
+
input_protein = gr.Textbox(type="text", label="Upload sequence")
|
| 222 |
+
prompt = gr.Textbox(type="text", label="Taxonomy Prompt (Optional)")
|
| 223 |
+
submit_btn = gr.Button(value="Submit")
|
| 224 |
+
with gr.Column():
|
| 225 |
+
output_text = gr.Textbox(label="Output Text")
|
| 226 |
+
# O14813 train index 127, 266, 738, 1060 test index 4
|
| 227 |
+
gr.Examples(
|
| 228 |
+
examples=[
|
| 229 |
+
["MDYSYLNSYDSCVAAMEASAYGDFGACSQPGGFQYSPLRPAFPAAGPPCPALGSSNCALGALRDHQPAPYSAVPYKFFPEPSGLHEKRKQRRIRTTFTSAQLKELERVFAETHYPDIYTREELALKIDLTEARVQVWFQNRRAKFRKQERAASAKGAAGAAGAKKGEARCSSEDDDSKESTCSPTPDSTASLPPPPAPGLASPRLSPSPLPVALGSGPGPGPGPQPLKGALWAGVAGGGGGGPGAGAAELLKAWQPAESGPGPFSGVLSSFHRKPGPALKTNLF", ''],
|
| 230 |
+
["MKTLALFLVLVCVLGLVQSWEWPWNRKPTKFPIPSPNPRDKWCRLNLGPAWGGRC", ''],
|
| 231 |
+
["MAAAGGARLLRAASAVLGGPAGRWLHHAGSRAGSSGLLRNRGPGGSAEASRSLSVSARARSSSEDKITVHFINRDGETLTTKGKVGDSLLDVVVENNLDIDGFGACEGTLACSTCHLIFEDHIYEKLDAITDEENDMLDLAYGLTDRSRLGCQICLTKSMDNMTVRVPETVADARQSIDVGKTS", 'Homo'],
|
| 232 |
+
['MASAELSREENVYMAKLAEQAERYEEMVEFMEKVAKTVDSEELTVEERNLLSVAYKNVIGARRASWRIISSIEQKEEGRGNEDRVTLIKDYRGKIETELTKICDGILKLLETHLVPSSTAPESKVFYLKMKGDYYRYLAEFKTGAERKDAAENTMVAYKAAQDIALAELAPTHPIRLGLALNFSVFYYEILNSPDRACSLAKQAFDEAISELDTLSEESYKDSTLIMQLLRDNLTLWTSDISEDPAEEIREAPKRDSSEGQ', 'Zea'],
|
| 233 |
+
['MIKAAVTKESLYRMNTLMEAFQGFLGLDLGEFTFKVKPGVFLLTDVKSYLIGDKYDDAFNALIDFVLRNDRDAVEGTETDVSIRLGLSPSDMVVKRQDKTFTFTHGDLEFEVHWINL', 'Bacteriophage'],
|
| 234 |
+
['MNDLMIQLLDQFEMGLRERAIKVMATINDEKHRFPMELNKKQCSLMLLGTTDTTTFDMRFNSKKDFPRIKGAREKYPRDAVIEWYHQNWMRTEVKQ', 'Bacteriophage'],
|
| 235 |
+
],
|
| 236 |
+
inputs=[input_protein, prompt],
|
| 237 |
+
outputs=[output_text],
|
| 238 |
+
fn=generate_caption,
|
| 239 |
+
cache_examples=True,
|
| 240 |
+
label='Try examples'
|
| 241 |
+
)
|
| 242 |
+
submit_btn.click(generate_caption, [input_protein, prompt], [output_text])
|
| 243 |
+
|
| 244 |
+
demo.launch(debug=True)
|
| 245 |
|