Update app.py
Browse files
app.py
CHANGED
|
@@ -13,25 +13,35 @@ from data.evaluate_data.utils import Ontology
|
|
| 13 |
import difflib
|
| 14 |
import re
|
| 15 |
from transformers import MistralForCausalLM
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
|
|
|
| 17 |
# Load the trained model
|
| 18 |
def get_model(type='Molecule Function'):
|
| 19 |
model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
|
| 20 |
if type == 'Molecule Function':
|
| 21 |
-
model.load_checkpoint("model/checkpoint_mf2.pth")
|
|
|
|
| 22 |
model.Qformer.bert = torch.load('model/mf2_bert.pth', map_location=torch.device('cpu'))
|
| 23 |
model.to('cuda')
|
| 24 |
elif type == 'Biological Process':
|
| 25 |
-
model.load_checkpoint("model/checkpoint_bp1.pth")
|
|
|
|
| 26 |
model.Qformer.bert = torch.load('model/bp1_bert.pth', map_location=torch.device('cpu'))
|
| 27 |
model.to('cuda')
|
| 28 |
elif type == 'Cellar Component':
|
| 29 |
-
model.load_checkpoint("model/checkpoint_cc2.pth")
|
|
|
|
| 30 |
model.Qformer.bert = torch.load('model/cc2_bert.pth', map_location=torch.device('cpu'))
|
| 31 |
model.to('cuda')
|
| 32 |
return model
|
| 33 |
|
| 34 |
-
|
| 35 |
models = {
|
| 36 |
'Molecule Function': get_model('Molecule Function'),
|
| 37 |
'Biological Process': get_model('Biological Process'),
|
|
|
|
| 13 |
import difflib
|
| 14 |
import re
|
| 15 |
from transformers import MistralForCausalLM
|
| 16 |
+
from huggingface_hub import hf_hub_download
|
| 17 |
+
bp_param = hf_hub_download(repo_id="wenkai/FAPM", filename="model/checkpoint_bp1.pth")
|
| 18 |
+
mf_param = hf_hub_download(repo_id="wenkai/FAPM", filename="model/checkpoint_mf2.pth")
|
| 19 |
+
cc_param = hf_hub_download(repo_id="wenkai/FAPM", filename="model/checkpoint_cc2.pth")
|
| 20 |
+
# hf_hub_download(repo_id="wenkai/FAPM", filename="model/mf2_bert.pth")
|
| 21 |
+
# hf_hub_download(repo_id="wenkai/FAPM", filename="model/bp1_bert.pth")
|
| 22 |
+
# hf_hub_download(repo_id="wenkai/FAPM", filename="model/cc2_bert.pth")
|
| 23 |
|
| 24 |
+
# bert_param = BertModel.from_pretrained("bert-base-uncased").state_dict()
|
| 25 |
# Load the trained model
|
| 26 |
def get_model(type='Molecule Function'):
|
| 27 |
model = Blip2ProteinMistral(config=FAPMConfig(), esm_size='3b')
|
| 28 |
if type == 'Molecule Function':
|
| 29 |
+
# model.load_checkpoint("model/checkpoint_mf2.pth")
|
| 30 |
+
model.load_checkpoint(mf_param)
|
| 31 |
model.Qformer.bert = torch.load('model/mf2_bert.pth', map_location=torch.device('cpu'))
|
| 32 |
model.to('cuda')
|
| 33 |
elif type == 'Biological Process':
|
| 34 |
+
# model.load_checkpoint("model/checkpoint_bp1.pth")
|
| 35 |
+
model.load_checkpoint(bp_param)
|
| 36 |
model.Qformer.bert = torch.load('model/bp1_bert.pth', map_location=torch.device('cpu'))
|
| 37 |
model.to('cuda')
|
| 38 |
elif type == 'Cellar Component':
|
| 39 |
+
# model.load_checkpoint("model/checkpoint_cc2.pth")
|
| 40 |
+
model.load_checkpoint(cc_param)
|
| 41 |
model.Qformer.bert = torch.load('model/cc2_bert.pth', map_location=torch.device('cpu'))
|
| 42 |
model.to('cuda')
|
| 43 |
return model
|
| 44 |
|
|
|
|
| 45 |
models = {
|
| 46 |
'Molecule Function': get_model('Molecule Function'),
|
| 47 |
'Biological Process': get_model('Biological Process'),
|