# Source: nas/ProtT3/stage3.py
# Commit 48cce71 (verified) by yuccaaa: "Add files using upload-large-folder tool"
import os
# OpenBLAS reads this variable once, when the shared library is loaded, so it
# must be set before torch / numpy / project modules are imported below;
# setting it afterwards has no effect on the already-initialised library.
os.environ['OPENBLAS_NUM_THREADS'] = '1'

# Optional flash-attention patch. If the module is missing, the name
# `replace_opt_attn_with_flash_attn` is deliberately left undefined.
try:
    from model.opt_flash_attention import replace_opt_attn_with_flash_attn
except ModuleNotFoundError:
    pass
import torch
import argparse
import warnings
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger, WandbLogger
from data_provider.stage3_dm import Stage3DM
from data_provider.prot_qa_dm import ProtQADM
from model.blip2_stage3 import Blip2Stage3
from model.dist_funs import MyDeepSpeedStrategy
from pathlib import Path
import pytorch_lightning.callbacks as plc

## for pyg bug: silence the "TypedStorage is deprecated" UserWarning
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
## for A5000 gpus: 'medium' allows bfloat16 internally, 'high' allows TensorFloat32,
## 'highest' keeps full float32 for float32 matmuls
torch.set_float32_matmul_precision('medium')
def main(args):
    """Run the Stage-3 model on the first test batch and print the result.

    Loads a ``Blip2Stage3`` checkpoint from ``args.checkpoint_name``, builds the
    ``Stage3DM`` data module for ``args.dataset``, takes the first batch of its
    test dataloader, and calls ``model.test_step`` on it under ``torch.no_grad``.
    This is a single-batch smoke test; no ``Trainer`` is constructed, so no
    full-dataset evaluation is run.

    Args:
        args: parsed command-line namespace (see ``get_args``).

    Raises:
        RuntimeError: if the test dataloader yields no batches.
    """
    pl.seed_everything(args.seed)

    # Model. strict=False lets checkpoint keys differ from the module's
    # parameters; load on CPU first, then move to the target device below.
    model = Blip2Stage3.load_from_checkpoint(
        args.checkpoint_name, strict=False, args=args, map_location='cpu'
    )
    print(f"loaded init checkpoint from {args.checkpoint_name}")
    print('total params:', sum(p.numel() for p in model.parameters()))

    # Data: the data module is given the model's LLM and PLM tokenizers.
    dm = Stage3DM(args.dataset, args)
    dm.init_tokenizer(model.blip2.llm_tokenizer, model.blip2.plm_tokenizer)
    test_loader = dm.test_dataloader()

    # Fetch only the first test batch.
    try:
        batch = next(iter(test_loader))
    except StopIteration:
        raise RuntimeError(
            f"test dataloader for dataset {args.dataset!r} yielded no batches"
        ) from None

    model.eval()
    # Use the GPU if one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    # Lightning's own batch-transfer hook: moves tensors nested in dicts, lists
    # and tuples, and objects exposing .to() (e.g. tokenizer outputs). For a
    # flat dict of tensors this matches a per-key `.to(device)`.
    batch = model.transfer_batch_to_device(batch, device, 0)

    # Inference.
    # NOTE(review): test_step may rely on an attached Trainer (e.g. self.log);
    # confirm it behaves when called directly like this.
    with torch.no_grad():
        output = model.test_step(batch, 0)  # 0 is batch_idx
    print(output)
def get_args():
    """Parse, echo, and return the command-line arguments for this script.

    Combines the script-level options below with the options contributed by
    ``Blip2Stage3.add_model_specific_args`` and
    ``Stage3DM.add_model_specific_args``. When ``--enable_flash`` is given,
    ``replace_opt_attn_with_flash_attn()`` from ``model.opt_flash_attention``
    is called before the arguments are returned.

    Returns:
        argparse.Namespace: the parsed arguments.

    Raises:
        ModuleNotFoundError: if ``--enable_flash`` is set but
            ``model.opt_flash_attention`` cannot be imported.
    """
    parser = argparse.ArgumentParser()
    # NOTE(review): default says "stage2" in the stage-3 script; kept as-is.
    parser.add_argument('--filename', type=str, default="stage2_test")
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    parser.add_argument('--strategy', type=str, default='deepspeed')
    # trainer arguments (not read directly by main(); the full namespace is
    # passed to Blip2Stage3 / Stage3DM, which may use them)
    parser.add_argument('--accelerator', type=str, default='gpu')
    parser.add_argument('--devices', type=str, default='0,1,2,3')
    parser.add_argument('--precision', type=str, default='bf16')
    parser.add_argument('--max_epochs', type=int, default=10)
    parser.add_argument('--check_val_every_n_epoch', type=int, default=1)
    parser.add_argument('--enable_flash', action='store_true', default=False)
    parser.add_argument('--use_wandb_logger', action='store_true', default=False)
    parser = Blip2Stage3.add_model_specific_args(parser)  # add model args
    parser = Stage3DM.add_model_specific_args(parser)  # add data-module args
    args = parser.parse_args()

    if args.enable_flash:
        # Import here: the module-level import of this name is wrapped in
        # try/except and may have been skipped, which would otherwise surface
        # as a confusing NameError instead of a ModuleNotFoundError.
        from model.opt_flash_attention import replace_opt_attn_with_flash_attn
        replace_opt_attn_with_flash_attn()

    print("=========================================")
    for k, v in sorted(vars(args).items()):
        print(k, '=', v)
    print("=========================================")
    return args
if __name__ == '__main__':
    # Script entry point: parse CLI options, then run the single-batch inference.
    cli_args = get_args()
    main(cli_args)