import os
try:
    from model.opt_flash_attention import replace_opt_attn_with_flash_attn
except ModuleNotFoundError:
    pass
import torch
import argparse
import warnings
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger, WandbLogger
from data_provider.stage3_dm import Stage3DM
from data_provider.prot_qa_dm import ProtQADM
from model.blip2_stage3 import Blip2Stage3
from model.dist_funs import MyDeepSpeedStrategy
from pathlib import Path
import pytorch_lightning.callbacks as plc
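
# The flash-attention import above is optional: if model.opt_flash_attention is
# unavailable, the script still runs and --enable_flash degrades to a warning.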

os.environ['OPENBLAS_NUM_THREADS'] = '1'
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
torch.set_float32_matmul_precision('medium')


def main(args):
    pl.seed_everything(args.seed)

    # strict=False tolerates key mismatches between the checkpoint and the
    # current model definition; map_location='cpu' keeps the initial load off the GPU.
    model = Blip2Stage3.load_from_checkpoint(args.checkpoint_name, strict=False, args=args, map_location='cpu')
    print(f"loaded init checkpoint from {args.checkpoint_name}")
    print('total params:', sum(p.numel() for p in model.parameters()))
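
    # Build the test datamodule and hand it the checkpoint's tokenizers, so that
    # batch collation matches the vocabularies the model was trained with.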
    dm = Stage3DM(args.dataset, args)
    dm.init_tokenizer(model.blip2.llm_tokenizer, model.blip2.plm_tokenizer)
    test_loader = dm.test_dataloader()
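
    # Single-batch smoke test: run one test_step by hand before launching the
    # full Trainer-driven test pass below.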
    batch = next(iter(test_loader))
    model.eval()
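
    # Move the model and every tensor in the (dict-style) batch to the GPU when
    # one is available; non-tensor entries pass through unchanged.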
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    batch = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}
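
    # No gradients are needed at test time; 0 is the batch_idx argument that
    # LightningModule.test_step expects.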
    with torch.no_grad():
        output = model.test_step(batch, 0)
        print(output)

    # The trainer construction was missing here; the block below is a minimal
    # reconstruction from the available args and imports. MyDeepSpeedStrategy's
    # constructor arguments and the logger settings are assumptions.
    if args.strategy == 'deepspeed':
        strategy = MyDeepSpeedStrategy(stage=1)  # assumed ZeRO stage
    else:
        strategy = 'auto'
    if args.use_wandb_logger:
        logger = WandbLogger(project=args.filename, save_dir=f'./all_checkpoints/{args.filename}/')
    else:
        logger = CSVLogger(save_dir=f'./all_checkpoints/{args.filename}/')
    trainer = Trainer(
        accelerator=args.accelerator,
        devices=args.devices,
        precision=args.precision,
        max_epochs=args.max_epochs,
        check_val_every_n_epoch=args.check_val_every_n_epoch,
        strategy=strategy,
        logger=logger,
    )
    trainer.test(model, datamodule=dm)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--filename', type=str, default="stage2_test")
    parser.add_argument('--seed', type=int, default=42, help='random seed')

    parser.add_argument('--strategy', type=str, default='deepspeed')
    parser.add_argument('--accelerator', type=str, default='gpu')
    parser.add_argument('--devices', type=str, default='0,1,2,3')
    parser.add_argument('--precision', type=str, default='bf16')
    parser.add_argument('--max_epochs', type=int, default=10)
    parser.add_argument('--check_val_every_n_epoch', type=int, default=1)
    parser.add_argument('--enable_flash', action='store_true', default=False)
    parser.add_argument('--use_wandb_logger', action='store_true', default=False)
    parser = Blip2Stage3.add_model_specific_args(parser)
    parser = Stage3DM.add_model_specific_args(parser)
    args = parser.parse_args()

    if args.enable_flash:
        # Guard against the optional flash-attention import having failed above,
        # which would otherwise raise a NameError here.
        if 'replace_opt_attn_with_flash_attn' in globals():
            replace_opt_attn_with_flash_attn()
        else:
            warnings.warn('--enable_flash was set but model.opt_flash_attention is unavailable')

    print("=========================================")
    for k, v in sorted(vars(args).items()):
        print(k, '=', v)
    print("=========================================")
    return args


if __name__ == '__main__':
    main(get_args())
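
# Example invocation (illustrative only; --checkpoint_name and --dataset are
# registered by the model/datamodule-specific args, and the script name here
# is a placeholder):
#   python stage3_test.py --devices 0,1 --precision bf16 \
#       --checkpoint_name <path/to/checkpoint.ckpt> --dataset <dataset_name>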