# ProtT3_model / model / blip2_stage1.py
# Uploaded by: yuccaaa
# Commit 4d12519 (verified): "Add files using upload-large-folder tool"
import contextlib
import torch
from model.blip2qformer import Blip2Qformer
import pytorch_lightning as pl
from torch import optim
from lavis.common.optims import LinearWarmupCosineLRScheduler, LinearWarmupStepLRScheduler
from tqdm import tqdm
from model.help_funcs import AttrDict, pad_and_concat
from typing import Any, Dict
class Blip2Stage1(pl.LightningModule):
    """Stage-1 LightningModule that trains a ``Blip2Qformer`` on protein/text pairs.

    The wrapped model returns ``loss_itc``, ``loss_itm``, ``loss_lm`` and their
    total ``loss``; they are logged here as ptc / ptm / lm / total. Retrieval
    evaluation helpers (``eval_retrieval_*``) can be run through
    ``retrieval_evaluation_and_log``.
    """

    def __init__(self, args):
        super().__init__()
        # Accept a plain dict as well as an attribute-style args object.
        if isinstance(args, dict):
            args = AttrDict(**args)
        self.args = args
        self.rerank_cand_num = args.rerank_cand_num
        self.blip2qformer = Blip2Qformer(
            args.ptm, args.lm, args.bert_name, args.plm_name, args.temperature,
            args.plm_tune, args.num_query_token, args.cross_attention_freq,
            args.projection_dim, args.pool_size, args.load_4bit,
        )
        self.save_hyperparameters(args)

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Strip every non-trainable entry from ``checkpoint['state_dict']``.

        An entry is removed when it is a parameter with ``requires_grad=False``,
        or when ``get_parameter`` cannot resolve it to a parameter (it raises
        AttributeError, e.g. for buffers). NOTE(review): loading such a
        checkpoint presumably relies on the frozen weights being re-created
        from their pretrained sources.
        """
        # checkpoint.pop('optimizer_states')
        state_dict = checkpoint['state_dict']
        to_be_removed = []
        for key in state_dict:
            try:
                if not self.get_parameter(key).requires_grad:
                    to_be_removed.append(key)
            except AttributeError:
                to_be_removed.append(key)
        for key in to_be_removed:
            state_dict.pop(key)

    def maybe_autocast(self, dtype=torch.bfloat16):
        """Return a CUDA autocast context using ``dtype`` (default bfloat16).

        On CPU a no-op context is returned instead, so autocast is disabled.
        """
        if self.device != torch.device("cpu"):
            # Same as the deprecated torch.cuda.amp.autocast(dtype=dtype).
            return torch.autocast(device_type="cuda", dtype=dtype)
        return contextlib.nullcontext()

    def configure_optimizers(self):
        """Create AdamW over all parameters plus the LAVIS LR scheduler.

        The scheduler is stored in ``self.scheduler`` (``None`` when
        ``args.scheduler == 'None'``) and stepped manually in ``training_step``.
        Only the optimizer is returned to Lightning.
        """
        # Set up the train dataloader now so its length is known below.
        self.trainer.fit_loop.setup_data()
        # Warmup never lasts longer than one epoch's worth of batches.
        warmup_steps = min(len(self.trainer.train_dataloader), self.args.warmup_steps)
        optimizer = optim.AdamW(self.parameters(), lr=self.args.init_lr, weight_decay=self.args.weight_decay)
        if self.args.scheduler == 'linear_warmup_cosine_lr':
            self.scheduler = LinearWarmupCosineLRScheduler(optimizer, self.args.max_epochs, self.args.min_lr, self.args.init_lr, warmup_steps, self.args.warmup_lr)
        elif self.args.scheduler == 'linear_warmup_step_lr':
            self.scheduler = LinearWarmupStepLRScheduler(optimizer, self.args.max_epochs, self.args.min_lr, self.args.init_lr, self.args.lr_decay_rate, self.args.warmup_lr, warmup_steps)
        elif self.args.scheduler == 'None':
            self.scheduler = None
        else:
            raise NotImplementedError()
        return optimizer

    @torch.no_grad()
    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Compute and log the per-dataloader validation losses; return the total loss."""
        prot_batch, text_batch = batch
        batch_size = prot_batch.input_ids.shape[0]
        blip2_loss = self.blip2qformer(batch)
        ###============== Overall Loss ===================###
        self.log(f"loader{dataloader_idx}/val_loss_ptc", float(blip2_loss.loss_itc), batch_size=batch_size, sync_dist=True)
        self.log(f"loader{dataloader_idx}/val_loss_ptm", float(blip2_loss.loss_itm), batch_size=batch_size, sync_dist=True)
        self.log(f"loader{dataloader_idx}/val_loss_lm", float(blip2_loss.loss_lm), batch_size=batch_size, sync_dist=True)
        self.log(f"loader{dataloader_idx}/val_loss", float(blip2_loss.loss), batch_size=batch_size, sync_dist=True)
        return blip2_loss.loss

    def get_precision(self, precision):
        """Map a Lightning precision setting to the torch dtype used for autocast.

        Accepts the legacy forms ('16', 'bf16', '32', or integers 16/32) as
        well as Lightning 2.x strings such as '16-mixed', 'bf16-mixed' and
        '32-true'.

        Raises:
            NotImplementedError: for any other precision value.
        """
        key = str(precision)
        if key in {'16', '16-mixed', '16-true'}:
            return torch.float16
        elif key in {'bf16', 'bf16-mixed', 'bf16-true'}:
            return torch.bfloat16
        elif key in {'32', '32-true'}:
            return torch.float32
        else:
            raise NotImplementedError(f"unsupported precision: {precision!r}")

    def retrieval_evaluation_and_log(self, match_dataloader, log_prefix="") -> None:
        """Run in-batch and full-set retrieval (with and without re-ranking) and log the metrics.

        Every metric name is prefixed with ``log_prefix``; metrics are logged
        without cross-process syncing.
        """
        with self.maybe_autocast(self.get_precision(self.trainer.precision)):
            ## for onto test set
            p2t_acc, t2p_acc, p2t_rec20, t2p_rec20, \
            p2t_rerank_acc, t2p_rerank_acc, p2t_rerank_rec20, t2p_rerank_rec20, \
            prot_feat_total, text_feat_total, prot_embed_total, prot_mask_total, text_total, text_mask_total = \
                eval_retrieval_inbatch_with_rerank(self.blip2qformer, match_dataloader, self.device)
            self.log(f"{log_prefix}inbatch_p2t_acc", p2t_acc, sync_dist=False)
            self.log(f"{log_prefix}inbatch_t2p_acc", t2p_acc, sync_dist=False)
            self.log(f"{log_prefix}inbatch_p2t_rec20", p2t_rec20, sync_dist=False)
            self.log(f"{log_prefix}inbatch_t2p_rec20", t2p_rec20, sync_dist=False)
            self.log(f"{log_prefix}rerank_inbatch_p2t_acc", p2t_rerank_acc, sync_dist=False)
            self.log(f"{log_prefix}rerank_inbatch_t2p_acc", t2p_rerank_acc, sync_dist=False)
            self.log(f"{log_prefix}rerank_inbatch_p2t_rec20", p2t_rerank_rec20, sync_dist=False)
            self.log(f"{log_prefix}rerank_inbatch_t2p_rec20", t2p_rerank_rec20, sync_dist=False)
            p2t_acc, p2t_rec20, t2p_acc, t2p_rec20, sim_p2t = \
                eval_retrieval_fullset(prot_feat_total, text_feat_total, self.device)
            self.log(f"{log_prefix}fullset_p2t_acc", p2t_acc, sync_dist=False)
            self.log(f"{log_prefix}fullset_t2p_acc", t2p_acc, sync_dist=False)
            self.log(f"{log_prefix}fullset_p2t_rec20", p2t_rec20, sync_dist=False)
            self.log(f"{log_prefix}fullset_t2p_rec20", t2p_rec20, sync_dist=False)
            p2t_acc, p2t_rec20, t2p_acc, t2p_rec20 = \
                eval_retrieval_fullset_for_rerank(self.blip2qformer, sim_p2t, prot_embed_total, prot_mask_total, text_total, text_mask_total, self.rerank_cand_num, self.device)
            self.log(f"{log_prefix}rerank_fullset_p2t_acc", p2t_acc, sync_dist=False)
            self.log(f"{log_prefix}rerank_fullset_t2p_acc", t2p_acc, sync_dist=False)
            self.log(f"{log_prefix}rerank_fullset_p2t_rec20", p2t_rec20, sync_dist=False)
            self.log(f"{log_prefix}rerank_fullset_t2p_rec20", t2p_rec20, sync_dist=False)

    def on_validation_epoch_end(self) -> None:
        """End-of-validation hook; currently performs no evaluation.

        The retrieval evaluation that used to run here is kept commented out
        below for reference.
        """
        if self.current_epoch == 0 or (self.current_epoch + 1) % self.args.retrieval_eval_epoch != 0:
            return
        if self.trainer.global_rank == 0:
            # Other non-evaluation logic (e.g. custom logging) can be added here.
            pass
        # if self.current_epoch == 0 or (self.current_epoch + 1) % self.args.retrieval_eval_epoch != 0:
        #     return
        # if self.trainer.global_rank == 0:
        #     ## evaluation for mix dataloaders
        #     if hasattr(self, 'swiss_test_match_loader') and hasattr(self, 'onto_test_match_loader'):
        #         self.retrieval_evaluation_and_log(self.swiss_test_match_loader, log_prefix="swiss_test_")
        #         self.retrieval_evaluation_and_log(self.onto_test_match_loader, log_prefix="onto_test_")
        #         return
        #     with self.maybe_autocast(self.get_precision(self.trainer.precision)):
        #         ## for validation set
        #         p2t_acc, t2p_acc, p2t_rec20, t2p_rec20, \
        #         p2t_rerank_acc, t2p_rerank_acc, p2t_rerank_rec20, t2p_rerank_rec20,\
        #         prot_feat_total, text_feat_total, _, _, _, _ = \
        #             eval_retrieval_inbatch_with_rerank(self.blip2qformer, self.val_match_loader, self.device)
        #         self.log("val_inbatch_p2t_acc", p2t_acc, sync_dist=False)
        #         self.log("val_inbatch_t2p_acc", t2p_acc, sync_dist=False)
        #         self.log("val_inbatch_p2t_rec20", p2t_rec20, sync_dist=False)
        #         self.log("val_inbatch_t2p_rec20", t2p_rec20, sync_dist=False)
        #         self.log("rerank_val_inbatch_p2t_acc", p2t_rerank_acc, sync_dist=False)
        #         self.log("rerank_val_inbatch_t2p_acc", t2p_rerank_acc, sync_dist=False)
        #         self.log("rerank_val_inbatch_p2t_rec20", p2t_rerank_rec20, sync_dist=False)
        #         self.log("rerank_val_inbatch_t2p_rec20", t2p_rerank_rec20, sync_dist=False)
        #         p2t_acc, p2t_rec20, t2p_acc, t2p_rec20, _ = \
        #             eval_retrieval_fullset(prot_feat_total, text_feat_total, self.device)
        #         self.log("val_fullset_p2t_acc", p2t_acc, sync_dist=False)
        #         self.log("val_fullset_t2p_acc", t2p_acc, sync_dist=False)
        #         self.log("val_fullset_p2t_rec20", p2t_rec20, sync_dist=False)
        #         self.log("val_fullset_t2p_rec20", t2p_rec20, sync_dist=False)
        #         ## for test set
        #         p2t_acc, t2p_acc, p2t_rec20, t2p_rec20, \
        #         p2t_rerank_acc, t2p_rerank_acc, p2t_rerank_rec20, t2p_rerank_rec20, \
        #         prot_feat_total, text_feat_total, prot_embed_total, prot_mask_total, text_total, text_mask_total = \
        #             eval_retrieval_inbatch_with_rerank(self.blip2qformer, self.test_match_loader, self.device)
        #         self.log("rerank_test_inbatch_p2t_acc", p2t_rerank_acc, sync_dist=False)
        #         self.log("rerank_test_inbatch_t2p_acc", t2p_rerank_acc, sync_dist=False)
        #         self.log("rerank_test_inbatch_p2t_rec20", p2t_rerank_rec20, sync_dist=False)
        #         self.log("rerank_test_inbatch_t2p_rec20", t2p_rerank_rec20, sync_dist=False)
        #         self.log("test_inbatch_p2t_acc", p2t_acc, sync_dist=False)
        #         self.log("test_inbatch_t2p_acc", t2p_acc, sync_dist=False)
        #         self.log("test_inbatch_p2t_rec20", p2t_rec20, sync_dist=False)
        #         self.log("test_inbatch_t2p_rec20", t2p_rec20, sync_dist=False)
        #         p2t_acc, p2t_rec20, t2p_acc, t2p_rec20, sim_p2t = \
        #             eval_retrieval_fullset(prot_feat_total, text_feat_total, self.device)
        #         self.log("test_fullset_p2t_acc", p2t_acc, sync_dist=False)
        #         self.log("test_fullset_t2p_acc", t2p_acc, sync_dist=False)
        #         self.log("test_fullset_p2t_rec20", p2t_rec20, sync_dist=False)
        #         self.log("test_fullset_t2p_rec20", t2p_rec20, sync_dist=False)
        #         p2t_acc, p2t_rec20, t2p_acc, t2p_rec20 = \
        #             eval_retrieval_fullset_for_rerank(self.blip2qformer, sim_p2t, prot_embed_total, prot_mask_total, text_total, text_mask_total, self.rerank_cand_num, self.device)
        #         self.log("rerank_test_fullset_p2t_acc", p2t_acc, sync_dist=False)
        #         self.log("rerank_test_fullset_t2p_acc", t2p_acc, sync_dist=False)
        #         self.log("rerank_test_fullset_p2t_rec20", p2t_rec20, sync_dist=False)
        #         self.log("rerank_test_fullset_t2p_rec20", t2p_rec20, sync_dist=False)
        #         del prot_feat_total, text_feat_total

    def training_step(self, batch, batch_idx):
        """Step the LR scheduler, compute the training losses, log them and return the total loss."""
        # LAVIS schedulers are stepped manually every iteration. With
        # --scheduler None there is no scheduler, so the LR stays at init_lr.
        if self.scheduler is not None:
            self.scheduler.step(self.trainer.current_epoch, self.trainer.global_step)
        prot_batch, text_batch = batch
        batch_size = prot_batch.input_ids.shape[0]
        blip2_loss = self.blip2qformer(batch)
        ###============== Overall Loss ===================###
        self.log("train_loss_ptc", float(blip2_loss.loss_itc), batch_size=batch_size, sync_dist=True)
        self.log("train_loss_ptm", float(blip2_loss.loss_itm), batch_size=batch_size, sync_dist=True)
        self.log("train_loss_lm", float(blip2_loss.loss_lm), batch_size=batch_size, sync_dist=True)
        self.log("train_loss", float(blip2_loss.loss), batch_size=batch_size, sync_dist=True)
        self.log("lr", self.trainer.optimizers[0].param_groups[0]['lr'], batch_size=batch_size, sync_dist=True)
        return blip2_loss.loss

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Register the Stage-1 command-line options on ``parent_parser`` and return it."""
        parser = parent_parser.add_argument_group("Stage1")
        # train mode
        parser.add_argument('--temperature', type=float, default=0.1, help='the temperature of NT_XentLoss')
        parser.add_argument('--save_every_n_epochs', type=int, default=0)
        # --ptm / --lm are store_true with default=True and therefore always on;
        # --no_ptm / --no_lm (same dest) are the switches that turn them off.
        parser.add_argument('--ptm', action='store_true', help='use graph-text matching or not', default=True)
        parser.add_argument('--no_ptm', dest='ptm', action='store_false', help='disable protein-text matching')
        parser.add_argument('--lm', action='store_true', help='use language modeling or not', default=True)
        parser.add_argument('--no_lm', dest='lm', action='store_false', help='disable language modeling')
        # evaluation
        parser.add_argument('--rerank_cand_num', type=int, default=128)
        # plm
        # The default used to carry a trailing note (",can be a model path") inside
        # the string itself, which made it an invalid model id.
        parser.add_argument('--plm_name', type=str, default='facebook/esm2_t30_150M_UR50D', help='protein language model name or local model path')
        parser.add_argument('--plm_tune', type=str, default='freeze')
        parser.add_argument('--load_4bit', action='store_true', default=False)
        parser.add_argument('--pool_size', type=int, default=0)
        # Bert
        parser.add_argument('--bert_hidden_dim', type=int, default=768, help='')
        parser.add_argument('--bert_name', type=str, default='microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract')
        parser.add_argument('--projection_dim', type=int, default=256)
        parser.add_argument('--cross_attention_freq', type=int, default=2)
        parser.add_argument('--num_query_token', type=int, default=8)
        # optimization
        parser.add_argument('--weight_decay', type=float, default=0.05, help='optimizer weight decay')
        parser.add_argument('--init_lr', type=float, default=1e-4, help='optimizer init learning rate')
        parser.add_argument('--min_lr', type=float, default=1e-5, help='optimizer min learning rate')
        parser.add_argument('--warmup_lr', type=float, default=1e-6, help='optimizer warmup learning rate')
        parser.add_argument('--warmup_steps', type=int, default=1000, help='optimizer warmup steps')
        parser.add_argument('--lr_decay_rate', type=float, default=0.9, help='optimizer lr decay rate')
        parser.add_argument('--scheduler', type=str, default='linear_warmup_cosine_lr', help='type of scheduler')  # or linear_warmup_step_lr / None
        parser.add_argument('--init_checkpoint', type=str, default='')
        parser.add_argument('--retrieval_eval_epoch', type=int, default=10)
        return parent_parser
@torch.no_grad()
def eval_retrieval_fullset(prot_feat, text_feat, device, batch_size=8):
    """Full-set protein<->text retrieval accuracy and recall@20.

    The similarity between protein i and text j is the maximum, over the
    protein's query vectors, of the dot product with text j's feature. Row i
    of ``prot_feat`` and row i of ``text_feat`` are treated as the matching
    pair, so the similarity matrix is assumed to be square.

    Args:
        prot_feat: protein query features, indexed as [N, num_qs, D].
        text_feat: text features, [N, D].
        device: device used for the similarity and ranking computation.
        batch_size: number of rows processed per chunk (bounds peak memory).

    Returns:
        (p2t_acc, p2t_rec20, t2p_acc, t2p_rec20, sim_p2t). The metrics are
        percentages rounded to 2 decimals; sim_p2t is the [N, N] CPU similarity
        matrix (rows are proteins, columns are texts).
    """
    N = prot_feat.shape[0]
    B = batch_size
    text_feat = text_feat.to(device)

    sim_p2t = []
    for i in tqdm(range(0, N, B)):
        l_prot_feat = prot_feat[i:i + B].to(device)
        # [b, 1, num_qs, D] @ [N, D, 1] -> [b, N, num_qs, 1]. Squeeze only the
        # trailing dim so a chunk of size 1 (or num_qs == 1) keeps its rank.
        l_sim_q2t = (l_prot_feat.unsqueeze(1) @ text_feat.unsqueeze(-1)).squeeze(-1)
        l_sim_p2t, _ = l_sim_q2t.max(-1)  # [b, N]
        sim_p2t.append(l_sim_p2t)
    sim_p2t = torch.cat(sim_p2t, dim=0).cpu()  # [N, N]

    def _diagonal_ranks(sim):
        # Rank (0 = best) of the diagonal entry within each row of `sim`.
        ranks = []
        for start in range(0, N, B):
            sorted_ids = torch.argsort(sim[start:start + B].to(device), descending=True)
            targets = torch.arange(start, start + sorted_ids.shape[0], device=device).reshape(-1, 1)
            ranks.append((sorted_ids == targets).int().argmax(dim=-1))
        return torch.cat(ranks, dim=0)

    rank_p2t = _diagonal_ranks(sim_p2t)
    rank_t2p = _diagonal_ranks(sim_p2t.T)

    p2t_acc = round(float((rank_p2t == 0).float().mean()) * 100, 2)
    p2t_rec20 = round(float((rank_p2t < 20).float().mean()) * 100, 2)
    t2p_acc = round(float((rank_t2p == 0).float().mean()) * 100, 2)
    t2p_rec20 = round(float((rank_t2p < 20).float().mean()) * 100, 2)
    return p2t_acc, p2t_rec20, t2p_acc, t2p_rec20, sim_p2t
@torch.no_grad()
def eval_retrieval_fullset_for_rerank(model, sim_p2t_total, prot_embed_total, prot_mask_total, text_total, text_mask_total, rerank_cand_num, device, batch_size=4):
    """Full-set retrieval after re-ranking the top candidates with ``model.compute_ptm``.

    For each query (a protein for p2t, a text for t2p), the top ``rerank_cand_num``
    candidates by ``sim_p2t_total`` are re-sorted by ``sim + compute_ptm(...)``.
    Item i of each side is treated as the matching pair. A query whose true
    match is not among its candidates counts as a miss.

    Args:
        model: object providing ``compute_ptm(prot_embed, prot_mask, text, text_mask)``.
        sim_p2t_total: [N, N] protein-to-text similarity matrix (CPU).
        prot_embed_total, prot_mask_total: per-protein embeddings and attention masks.
        text_total, text_mask_total: per-text token ids and attention masks.
        rerank_cand_num: number of candidates re-ranked per query (capped at the
            number of columns of ``sim_p2t_total``).
        device: device on which candidates are scored.
        batch_size: number of queries scored per chunk.

    Returns:
        (p2t_acc, p2t_rec20, t2p_acc, t2p_rec20) as percentages rounded to 2 decimals.
    """
    N = sim_p2t_total.shape[0]
    B = batch_size
    rcn = min(rerank_cand_num, sim_p2t_total.shape[1])  ## re-rank candidate number; topk cannot exceed the column count

    hit_p2t = []
    for i in tqdm(range(0, N, B), desc='re-ranking p2t'):
        sim = sim_p2t_total[i:i + B].to(device)
        rB = sim.shape[0]  # real chunk size (the last chunk may be smaller)
        topk_sim, topk_idx = sim.topk(k=rcn, dim=1)  # [rB, rcn]
        topk_idx = topk_idx.cpu()
        # Pair every query protein with each of its rcn candidate texts.
        prot_embed = prot_embed_total[i:i + B].to(device).repeat_interleave(rcn, 0)  # [rB * rcn, ...]
        prot_mask = prot_mask_total[i:i + B].to(device).repeat_interleave(rcn, 0)  # [rB * rcn, prot_len]
        text = text_total[topk_idx].flatten(0, 1).to(device)  # [rB * rcn, text_len]
        text_mask = text_mask_total[topk_idx].flatten(0, 1).to(device)  # [rB * rcn, text_len]
        ptm_sim = model.compute_ptm(prot_embed, prot_mask, text, text_mask).reshape(rB, rcn)  ## fixme, using the linear clf's logits directly, without softmax
        order = torch.argsort(topk_sim + ptm_sim, descending=True).cpu()  # [rB, rcn]
        sorted_ids = torch.gather(topk_idx, 1, order)  # map back to global text ids
        hit_p2t.append((sorted_ids == torch.arange(i, i + rB).reshape(-1, 1)).int())
    hit_p2t = torch.cat(hit_p2t, dim=0)  # [N, rcn]; at most one hit per row

    hit_t2p = []
    sim_t2p_total = sim_p2t_total.T
    for i in tqdm(range(0, N, B), desc='re-ranking t2p'):
        sim = sim_t2p_total[i:i + B].to(device)
        rB = sim.shape[0]
        topk_sim, topk_idx = sim.topk(k=rcn, dim=1)
        topk_idx = topk_idx.cpu()
        # Pair every query text with each of its rcn candidate proteins.
        text = text_total[i:i + B].to(device).repeat_interleave(rcn, 0)
        text_mask = text_mask_total[i:i + B].to(device).repeat_interleave(rcn, 0)
        prot_embed = prot_embed_total[topk_idx].to(device).flatten(0, 1)
        prot_mask = prot_mask_total[topk_idx].to(device).flatten(0, 1)
        ptm_sim = model.compute_ptm(prot_embed, prot_mask, text, text_mask).reshape(rB, rcn)
        order = torch.argsort(topk_sim + ptm_sim, descending=True).cpu()  # [rB, rcn]
        sorted_ids = torch.gather(topk_idx, 1, order)  # map back to global protein ids
        hit_t2p.append((sorted_ids == torch.arange(i, i + rB).reshape(-1, 1)).int())
    hit_t2p = torch.cat(hit_t2p, dim=0)

    # Each row contains at most one hit, so summing the first 20 columns counts
    # the queries whose match lands in the top 20.
    p2t_acc = round(float(hit_p2t[:, 0].float().mean()) * 100, 2)
    p2t_rec20 = round(float(hit_p2t[:, :20].float().sum() / N) * 100, 2)
    t2p_acc = round(float(hit_t2p[:, 0].float().mean()) * 100, 2)
    t2p_rec20 = round(float(hit_t2p[:, :20].float().sum() / N) * 100, 2)
    return p2t_acc, p2t_rec20, t2p_acc, t2p_rec20
@torch.no_grad()
def eval_retrieval_inbatch_with_rerank(model, dataloader, device=None, rerank_batch_size=64):
    '''
    In-batch protein<->text retrieval evaluation, with and without PTM re-ranking.

    Within every batch yielded by ``dataloader`` as ``(prot_batch, text_batch)``,
    protein i and text i form the matching pair and the other items of the same
    batch are the candidates. Features and inputs are also collected so the
    caller can run full-set evaluation afterwards.

    Args:
        model: the Blip2Qformer to evaluate. NOTE: it is switched to eval mode
            and left in eval mode on return.
        dataloader: iterable of ``(prot_batch, text_batch)``.
        device: device the batches are moved to.
        rerank_batch_size: number of protein-text pairs scored per
            ``model.compute_ptm`` call during re-ranking.

    Returns:
        A 14-tuple:
        - p2t_acc, t2p_acc, p2t_rec20, t2p_rec20: plain ranking metrics.
        - p2t_rerank_acc, t2p_rerank_acc, p2t_rerank_rec20, t2p_rerank_rec20:
          re-ranked metrics.
        - prot_feat_total, text_feat_total, prot_embed_total, prot_mask_total,
          text_total, text_mask_total: the collected tensors.

        All metrics are percentages rounded to 2 decimals. The last four tensors
        are padded across batches with ``pad_and_concat``; ``text_total`` is
        padded with the tokenizer's pad id.
    '''
    assert isinstance(model, Blip2Qformer)
    pad_token_id = model.tokenizer.pad_token_id
    model.eval()

    p2t_acc = t2p_acc = p2t_rec20 = t2p_rec20 = 0
    p2t_rerank_acc = t2p_rerank_acc = p2t_rerank_rec20 = t2p_rerank_rec20 = 0
    allcnt = 0

    prot_feat_total = []
    text_feat_total = []
    prot_embed_total = []
    prot_mask_total = []
    text_total = []
    text_mask_total = []
    for batch in tqdm(dataloader):
        prot_batch, text_batch = batch
        prot_batch, text_batch = prot_batch.to(device), text_batch.to(device)
        text_total.append(text_batch.input_ids)
        text_mask_total.append(text_batch.attention_mask)

        prot_feats, prot_embeds = model.prot_forward(prot_batch)  # shape = [B, num_qs, D]
        text_feats = model.text_forward(text_batch)  # shape = [B, D]
        # [B, 1, num_qs, D] @ [B, D, 1] -> [B, B, num_qs, 1]. Squeeze only the
        # trailing dim so a batch of size 1 (or num_qs == 1) keeps its rank.
        sim_q2t = (prot_feats.unsqueeze(1) @ text_feats.unsqueeze(-1)).squeeze(-1)
        sim_p2t, _ = sim_q2t.max(-1)  # shape = [B, B]
        B = sim_p2t.shape[0]
        targets = torch.arange(B).reshape(-1, 1)

        # Plain in-batch ranking: position of the diagonal entry in each sorted row.
        sorted_ids = sim_p2t.argsort(descending=True).cpu()
        p2t_rank = (sorted_ids == targets).int().argmax(dim=-1)
        sorted_ids = sim_p2t.T.argsort(descending=True).cpu()
        t2p_rank = (sorted_ids == targets).int().argmax(dim=-1)

        p2t_acc += float((p2t_rank == 0).sum())
        t2p_acc += float((t2p_rank == 0).sum())
        p2t_rec20 += float((p2t_rank < 20).sum())
        t2p_rec20 += float((t2p_rank < 20).sum())
        allcnt += B

        prot_feat_total.append(prot_feats.cpu())
        text_feat_total.append(text_feats.cpu())
        prot_embed_total.append(prot_embeds.cpu())
        prot_mask_total.append(prot_batch.attention_mask.cpu())

        ## reranking: score all B x B pairs (row-major: protein-major, text-minor)
        prot_embeds = prot_embeds.repeat_interleave(B, 0)  # shape = [B * B, prot_len, D]
        prot_mask = prot_batch.attention_mask.repeat_interleave(B, 0)  # shape = [B * B, prot_len]
        text_ids = text_batch.input_ids.repeat(B, 1)  # shape = [B * B, text_len]
        text_mask = text_batch.attention_mask.repeat(B, 1)  # shape = [B * B, text_len]
        ## batched reranking
        ptm_sim = []
        for start in range(0, prot_embeds.shape[0], rerank_batch_size):
            stop = start + rerank_batch_size
            ptm_sim.append(model.compute_ptm(prot_embeds[start:stop], prot_mask[start:stop], text_ids[start:stop], text_mask[start:stop]))
        ptm_sim = torch.cat(ptm_sim, dim=0).reshape(B, B)
        rerank_sim = sim_p2t + ptm_sim

        ## p2t rerank
        sorted_ids = torch.argsort(rerank_sim, descending=True).cpu()  # shape = [B, B]
        hit_p2t = (sorted_ids == targets).float()
        p2t_rerank_acc += float(hit_p2t[:, 0].sum())
        p2t_rerank_rec20 += float(hit_p2t[:, :20].sum())
        ## t2p rerank
        sorted_ids = torch.argsort(rerank_sim.T, descending=True).cpu()  # shape = [B, B]
        hit_t2p = (sorted_ids == targets).float()
        t2p_rerank_acc += float(hit_t2p[:, 0].sum())
        t2p_rerank_rec20 += float(hit_t2p[:, :20].sum())

    prot_feat_total = torch.cat(prot_feat_total, dim=0)
    text_feat_total = torch.cat(text_feat_total, dim=0)
    prot_embed_total = pad_and_concat(prot_embed_total)
    prot_mask_total = pad_and_concat(prot_mask_total)
    text_total = pad_and_concat(text_total, fill_value=pad_token_id)
    text_mask_total = pad_and_concat(text_mask_total)

    p2t_acc = round(p2t_acc / allcnt * 100, 2)
    t2p_acc = round(t2p_acc / allcnt * 100, 2)
    p2t_rec20 = round(p2t_rec20 / allcnt * 100, 2)
    t2p_rec20 = round(t2p_rec20 / allcnt * 100, 2)
    p2t_rerank_acc = round(p2t_rerank_acc / allcnt * 100, 2)
    t2p_rerank_acc = round(t2p_rerank_acc / allcnt * 100, 2)
    p2t_rerank_rec20 = round(p2t_rerank_rec20 / allcnt * 100, 2)
    t2p_rerank_rec20 = round(t2p_rerank_rec20 / allcnt * 100, 2)
    return p2t_acc, t2p_acc, p2t_rec20, t2p_rec20, \
        p2t_rerank_acc, t2p_rerank_acc, p2t_rerank_rec20, t2p_rerank_rec20, \
        prot_feat_total, text_feat_total, prot_embed_total, prot_mask_total, text_total, text_mask_total