# ProtT3_model/model/blip2_stage3.py
# Uploaded by yuccaaa ("Add files using upload-large-folder tool"), commit 4d12519 (verified).
import os
import torch
from model.blip2_opt import Blip2OPT
import pytorch_lightning as pl
from torch import optim
from lavis.common.optims import LinearWarmupCosineLRScheduler, LinearWarmupStepLRScheduler
import json
import torch.distributed as dist
# from peft import LoraConfig, TaskType
from typing import Any, Dict
from model.help_funcs import caption_evaluate, AttrDict
from datetime import datetime
try:
    from model.opt_flash_attention import replace_opt_attn_with_flash_attn, replace_opt_attn_with_original_attn
except ModuleNotFoundError:
    # Flash attention is optional. Instead of silently leaving the hook names
    # undefined (which surfaces later as a NameError inside a Lightning hook),
    # install stand-ins that fail with an explicit message if a run with
    # enable_flash=True actually tries to swap the attention implementation.
    def _flash_attn_unavailable(*args, **kwargs):
        """Raise because the optional flash-attention module could not be imported."""
        raise ModuleNotFoundError(
            "enable_flash is set but model.opt_flash_attention could not be imported"
        )

    replace_opt_attn_with_flash_attn = _flash_attn_unavailable
    replace_opt_attn_with_original_attn = _flash_attn_unavailable
def get_module_state_dict(state_dict, module_name):
    """Extract the entries of ``state_dict`` that belong to ``module_name``.

    Keys of the form ``f"{module_name}.<rest>"`` are returned re-keyed as
    ``<rest>``. If ``state_dict`` contains the exact key ``module_name`` (the
    "module" is a single tensor/parameter), that value is returned directly
    instead of a dict. Keys of other modules that merely share the textual
    prefix (e.g. ``"blip2qformer.x"`` when asking for ``"blip2"``) are ignored.

    Args:
        state_dict: mapping of dotted parameter names to values.
        module_name: dotted name of the wanted sub-module.

    Returns:
        A new dict of stripped keys to values, or the bare value on an exact
        key match.
    """
    # Require the dot separator so sibling modules sharing a prefix don't match.
    prefix = module_name + '.'
    module_state_dict = {}
    for key, value in state_dict.items():
        if key == module_name:
            return value
        if key.startswith(prefix):
            module_state_dict[key[len(prefix):]] = value
    return module_state_dict
class Blip2Stage3(pl.LightningModule):
    """LightningModule wrapping ``Blip2OPT`` for stage-3 training and evaluation.

    Training batches and even-indexed validation dataloaders are scored with the
    loss returned by ``self.blip2(batch)``. Odd-indexed validation dataloaders and
    the test dataloader yield ``(prot_batch, prompt_batch, target_dict)`` triples
    whose text is produced by ``self.blip2.generate``, gathered across ranks,
    written to disk and scored with captioning metrics on rank 0.
    """

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Remove non-trainable entries from ``checkpoint['state_dict']`` in place.

        Only parameters with ``requires_grad=True`` are kept. Frozen parameters
        are dropped, and so is every entry that ``get_parameter`` does not resolve
        to a parameter (it raises ``AttributeError``, e.g. for buffers).
        """
        state_dict = checkpoint['state_dict']
        to_be_removed = []
        for key in state_dict:
            try:
                if not self.get_parameter(key).requires_grad:
                    to_be_removed.append(key)
            except AttributeError:
                to_be_removed.append(key)
        # Pop after the scan: a dict must not change size while being iterated.
        for key in to_be_removed:
            state_dict.pop(key)

    def __init__(self, args):
        """Build the wrapped ``Blip2OPT`` model from a config object.

        Args:
            args: an argparse-style namespace or a plain dict (converted to
                ``AttrDict``). It carries the options registered in
                ``add_model_specific_args`` plus others read by this class
                (``enable_flash``, ``max_epochs``).
        """
        super().__init__()
        if isinstance(args, dict):
            args = AttrDict(**args)

        self.args = args
        self.caption_eval_epoch = args.caption_eval_epoch
        self.do_sample = args.do_sample
        self.num_beams = args.num_beams
        self.max_inference_len = args.max_inference_len
        self.min_inference_len = args.min_inference_len
        self.llm_tune = args.llm_tune
        # NOTE: the flash-attention hooks come from an optional import at the top
        # of this module; enable_flash=True requires that import to have succeeded.
        self.enable_flash = args.enable_flash
        self.blip2 = Blip2OPT(
            args.bert_name,
            args.num_query_token,
            args.cross_attention_freq,
            args.plm_model,
            args.plm_tune,
            args.llm_name,
            args.llm_tune,
            args.peft_dir,
            args,
        )
        self.save_hyperparameters(args)

    def load_from_stage1_checkpoint(self, path):
        """Load stage-1 Q-Former weights from a Lightning checkpoint into ``self.blip2``.

        Each key is re-keyed to the part following ``'blip2qformer.'``. Keys that
        do not contain that marker are skipped instead of raising ``IndexError``.
        Loading is non-strict, so missing/unexpected keys are tolerated.

        Args:
            path: filesystem path of the stage-1 checkpoint.

        Returns:
            ``self``, so the call can be chained.
        """
        ckpt = torch.load(path, map_location='cpu')
        marker = 'blip2qformer.'
        state_dict = {
            k.split(marker)[1]: v
            for k, v in ckpt['state_dict'].items()
            if marker in k
        }
        self.blip2.load_state_dict(state_dict, strict=False)
        return self

    def configure_optimizers(self):
        """Create an AdamW optimizer over all parameters plus a manual LR scheduler.

        The scheduler is not handed to Lightning. It is stored on
        ``self.scheduler`` and stepped by hand in ``training_step``. Warmup is
        capped at the number of batches in one training epoch.

        Raises:
            NotImplementedError: if ``args.scheduler`` is not a supported name.
        """
        # Materialize the train dataloader now so its length is known for warmup.
        self.trainer.fit_loop.setup_data()
        warmup_steps = min(len(self.trainer.train_dataloader), self.args.warmup_steps)
        optimizer = optim.AdamW(self.parameters(), lr=self.args.init_lr, weight_decay=self.args.weight_decay)
        if self.args.scheduler == 'linear_warmup_cosine_lr':
            self.scheduler = LinearWarmupCosineLRScheduler(optimizer, self.args.max_epochs, self.args.min_lr, self.args.init_lr, warmup_steps, self.args.warmup_lr)
        elif self.args.scheduler == 'linear_warmup_step_lr':
            self.scheduler = LinearWarmupStepLRScheduler(optimizer, self.args.max_epochs, self.args.min_lr, self.args.init_lr, self.args.lr_decay_rate, self.args.warmup_lr, warmup_steps)
        elif self.args.scheduler == 'None':
            self.scheduler = None
        else:
            raise NotImplementedError()
        return optimizer

    def on_validation_epoch_start(self) -> None:
        """Use the original OPT attention during validation and reset result buffers."""
        if self.enable_flash:
            replace_opt_attn_with_original_attn()
        self.saved_dict_list = []
        self.prediction_list0 = []
        self.target_list0 = []
        self.prediction_list1 = []
        self.target_list1 = []

    @torch.no_grad()
    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Log the loss (even dataloaders) or collect generations (odd dataloaders).

        Generation only runs on epochs where ``(current_epoch + 1)`` is a multiple
        of ``caption_eval_epoch``. The generated text is stored under
        ``target_dict['predictions']`` and buffered for ``on_validation_epoch_end``.
        """
        if dataloader_idx % 2 == 0:
            # Assumes the last batch element is a tokenizer output with
            # ``input_ids`` (batch, seq) — used only for the logging batch size.
            text_batch = batch[-1]
            batch_size = text_batch.input_ids.shape[0]
            loss = self.blip2(batch)
            self.log(f"dataloader{dataloader_idx}/val loss", float(loss), batch_size=batch_size, sync_dist=True)
            return
        if (self.current_epoch + 1) % self.caption_eval_epoch != 0:
            return
        prot_batch, prompt_batch, target_dict = batch
        samples = {'prot_batch': prot_batch, 'prompt_batch': prompt_batch}
        predictions = self.blip2.generate(
            samples,
            do_sample=self.do_sample,
            num_beams=self.num_beams,
            max_length=self.max_inference_len,
            min_length=self.min_inference_len
        )
        target_dict['predictions'] = predictions
        self.saved_dict_list.append(target_dict)

    def gather_dict_results(self, dict_list):
        """Collect per-batch result dicts from every rank and split them per sample.

        Args:
            dict_list: this rank's list of batch dicts. All dicts are expected to
                share the same keys (including ``'predictions'``), each value being
                a per-sample list of equal length.

        Returns:
            One dict per sample, ordered rank by rank. Empty when no rank
            produced any results.
        """
        if dist.is_available() and dist.is_initialized():
            # Collective call: every rank must reach this point.
            list_of_dict_list = [None for _ in range(dist.get_world_size())]
            dist.all_gather_object(list_of_dict_list, dict_list)
        else:
            # Single process without a process group: nothing to gather.
            list_of_dict_list = [dict_list]
        batch_dicts = [d for rank_list in list_of_dict_list for d in rank_list]
        if not batch_dicts:
            return []
        keys = batch_dicts[0].keys()
        # Concatenate each key's per-sample lists across all batches.
        gathered_dict = {key: [item for d in batch_dicts for item in d[key]] for key in keys}
        return [
            {k: gathered_dict[k][i] for k in keys}
            for i in range(len(gathered_dict['predictions']))
        ]

    def save_results(self, dict_list, log_prefix=""):
        """Write one JSON object per line to ``[<log_prefix>_]predictions.txt`` in the CWD."""
        if log_prefix:
            name = f'{log_prefix}_predictions.txt'
        else:
            name = 'predictions.txt'
        with open(name, 'w', encoding='utf8') as f:
            for d in dict_list:
                f.write(json.dumps(d, ensure_ascii=True) + '\n')

    def _caption_metrics(self, all_predictions, all_targets):
        """Return exact-match accuracy and captioning scores as an ordered name -> value dict."""
        bleu2, bleu4, rouge_1, rouge_2, rouge_l, meteor_score = caption_evaluate(
            all_predictions, all_targets, self.blip2.llm_tokenizer, self.max_inference_len)
        return {
            "acc": evaluate_exact_match(all_predictions, all_targets),
            "bleu2": bleu2,
            "bleu4": bleu4,
            "rouge_1": rouge_1,
            "rouge_2": rouge_2,
            "rouge_l": rouge_l,
            "meteor_score": meteor_score,
        }

    def on_validation_epoch_end(self):
        """Restore flash attention; on caption-eval epochs gather, save and score generations."""
        if self.enable_flash:
            replace_opt_attn_with_flash_attn()
        if (self.current_epoch + 1) % self.caption_eval_epoch != 0:
            return
        result_list = self.gather_dict_results(self.saved_dict_list)
        # Release the per-epoch buffer.
        self.saved_dict_list = []
        if self.global_rank != 0:
            return
        # NOTE(review): hard-coded run tag used as the output file prefix.
        self.save_results(result_list, 'deeplocmulti_07141239')
        if not result_list or 'q_types' in result_list[0]:
            # Nothing was generated, or this is protein QA (not evaluated here).
            return
        all_predictions = [i['predictions'] for i in result_list]
        all_targets = [i['targets'] for i in result_list]
        log_prefix = 'dataset0'  # FIXME: placeholder dataset name
        for name, value in self._caption_metrics(all_predictions, all_targets).items():
            self.log(f"{log_prefix}/{name}", value, sync_dist=False)

    def on_test_epoch_start(self) -> None:
        """Use the original OPT attention during testing and reset the result buffer."""
        if self.enable_flash:
            replace_opt_attn_with_original_attn()
        self.saved_dict_list = []

    @torch.no_grad()
    def test_step(self, batch, batch_idx):
        """Generate text for one ``(prot_batch, prompt_batch, target_dict)`` batch and buffer it."""
        prot_batch, prompt_batch, target_dict = batch
        samples = {'prot_batch': prot_batch, 'prompt_batch': prompt_batch}
        predictions = self.blip2.generate(
            samples,
            do_sample=self.do_sample,
            num_beams=self.num_beams,
            max_length=self.max_inference_len,
            min_length=self.min_inference_len
        )
        target_dict['predictions'] = predictions
        self.saved_dict_list.append(target_dict)

    def on_test_epoch_end(self):
        """Gather test generations; on rank 0 write them and their metrics under ``results/<MMDDHHMM>/``."""
        if self.enable_flash:
            replace_opt_attn_with_flash_attn()
        result_list = self.gather_dict_results(self.saved_dict_list)
        self.saved_dict_list = []
        if self.global_rank != 0:
            return
        timestamp = datetime.now().strftime("%m%d%H%M")
        prediction_file = f"results/{timestamp}/predictions_test.jsonl"
        metrics_file = f"results/{timestamp}/metrics_test.json"
        # Both files live in the same directory.
        os.makedirs(os.path.dirname(prediction_file), exist_ok=True)
        all_predictions = []
        all_targets = []
        # Save prediction/target pairs as JSON lines.
        with open(prediction_file, 'w', encoding='utf-8') as f:
            for d in result_list:
                pred = d['predictions']
                target = d['targets']
                all_predictions.append(pred)
                all_targets.append(target)
                f.write(json.dumps({'prediction': pred, 'target': target}, ensure_ascii=False) + '\n')
        if not result_list or 'q_types' in result_list[0]:
            # Nothing to score, or protein QA (its evaluation is skipped).
            return
        metrics = self._caption_metrics(all_predictions, all_targets)
        for k, v in metrics.items():
            self.log(f"test/{k}", v, sync_dist=False)
        # Save metrics to a JSON file.
        with open(metrics_file, 'w', encoding='utf-8') as f:
            json.dump(metrics, f, indent=2, ensure_ascii=False)

    def training_step(self, batch, batch_idx):
        """Step the manual LR scheduler, then compute and log the training loss and LR."""
        if self.scheduler:
            self.scheduler.step(self.trainer.current_epoch, self.trainer.global_step)
        # Assumes the last batch element is a tokenizer output with ``input_ids``.
        batch_size = batch[-1].input_ids.size(0)
        loss = self.blip2(batch)
        self.log("loss", float(loss), batch_size=batch_size, sync_dist=True)
        self.log("lr", self.trainer.optimizers[0].param_groups[0]['lr'], batch_size=batch_size, sync_dist=True)
        return loss

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Register this model's command-line options on ``parent_parser`` and return it."""
        parser = parent_parser.add_argument_group("ProtBlip2")
        # Bert
        parser.add_argument('--bert_name', type=str, default='microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract')
        parser.add_argument('--cross_attention_freq', type=int, default=2)
        parser.add_argument('--num_query_token', type=int, default=8)
        # OPT
        parser.add_argument('--llm_name', type=str, default="facebook/galactica-1.3b")
        parser.add_argument('--num_beams', type=int, default=5)
        parser.add_argument('--do_sample', action='store_true', default=False)
        parser.add_argument('--max_inference_len', type=int, default=128)
        parser.add_argument('--min_inference_len', type=int, default=1)
        parser.add_argument('--llm_tune', type=str, default='freeze')
        parser.add_argument('--peft_config', type=str, default='')
        parser.add_argument('--peft_dir', type=str, default='')
        # protein language model
        parser.add_argument('--plm_model', type=str, default='facebook/esm2_t30_150M_UR50D')
        parser.add_argument('--plm_tune', type=str, default='freeze')
        # LoRA config
        parser.add_argument('--lora_r', type=int, default=8)
        parser.add_argument('--lora_alpha', type=int, default=16)
        # Dropout is a probability; type=int rejected values such as 0.05.
        parser.add_argument('--lora_dropout', type=float, default=0.1)
        # NOTE: the flag name is misspelled ("enbale"); kept because callers read
        # args.enbale_gradient_checkpointing.
        parser.add_argument('--enbale_gradient_checkpointing', action='store_true', default=False)
        # optimization
        parser.add_argument('--weight_decay', type=float, default=0.05, help='optimizer weight decay')
        parser.add_argument('--init_lr', type=float, default=1e-4, help='optimizer init learning rate')
        parser.add_argument('--min_lr', type=float, default=1e-5, help='optimizer min learning rate')
        parser.add_argument('--warmup_lr', type=float, default=1e-6, help='optimizer warmup learning rate')
        parser.add_argument('--warmup_steps', type=int, default=1000, help='optimizer warmup steps')
        parser.add_argument('--lr_decay_rate', type=float, default=0.9, help='optimizer lr decay rate')
        parser.add_argument('--scheduler', type=str, default='linear_warmup_cosine_lr', help='type of scheduler')  # or linear_warmup_step_lr
        parser.add_argument('--checkpoint_name', type=str, default='')
        parser.add_argument('--caption_eval_epoch', type=int, default=10)
        return parent_parser
def evaluate_exact_match(predictions, targets):
    """Return the exact-match accuracy of ``predictions`` against ``targets`` in percent.

    Items are compared as strings after stripping surrounding whitespace, and
    the result is rounded to two decimals. Pairs are formed with ``zip``, so
    surplus items in the longer sequence are not compared. The denominator is
    always ``len(predictions)``.

    Returns:
        The accuracy in ``[0, 100]``. An empty ``predictions`` gives ``0.0``
        instead of raising ``ZeroDivisionError``.
    """
    if len(predictions) == 0:
        return 0.0
    hits = sum(
        str(prediction).strip() == str(target).strip()
        for prediction, target in zip(predictions, targets)
    )
    return round(hits / len(predictions) * 100, 2)