Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +38 -0
- CGFormer/.gitignore +7 -0
- CGFormer/.ipynb_checkpoints/test-checkpoint.py +106 -0
- CGFormer/.ipynb_checkpoints/test_mosaic-checkpoint.py +106 -0
- CGFormer/LICENSE +21 -0
- CGFormer/README.md +56 -0
- CGFormer/bash_logs/ACE_filter050.log +480 -0
- CGFormer/bash_logs/ACE_filter050_rev.log +528 -0
- CGFormer/bash_logs/sanity_node03.log +0 -0
- CGFormer/bert/__pycache__/activations.cpython-38.pyc +0 -0
- CGFormer/bert/__pycache__/activations.cpython-39.pyc +0 -0
- CGFormer/bert/__pycache__/configuration_bert.cpython-38.pyc +0 -0
- CGFormer/bert/__pycache__/configuration_bert.cpython-39.pyc +0 -0
- CGFormer/bert/__pycache__/configuration_utils.cpython-38.pyc +0 -0
- CGFormer/bert/__pycache__/configuration_utils.cpython-39.pyc +0 -0
- CGFormer/bert/__pycache__/file_utils.cpython-38.pyc +0 -0
- CGFormer/bert/__pycache__/file_utils.cpython-39.pyc +0 -0
- CGFormer/bert/__pycache__/generation_utils.cpython-38.pyc +0 -0
- CGFormer/bert/__pycache__/generation_utils.cpython-39.pyc +0 -0
- CGFormer/bert/__pycache__/modeling_bert.cpython-38.pyc +0 -0
- CGFormer/bert/__pycache__/modeling_bert.cpython-39.pyc +0 -0
- CGFormer/bert/__pycache__/modeling_utils.cpython-38.pyc +0 -0
- CGFormer/bert/__pycache__/modeling_utils.cpython-39.pyc +0 -0
- CGFormer/bert/__pycache__/tokenization_bert.cpython-38.pyc +0 -0
- CGFormer/bert/__pycache__/tokenization_bert.cpython-39.pyc +0 -0
- CGFormer/bert/__pycache__/tokenization_utils.cpython-38.pyc +0 -0
- CGFormer/bert/__pycache__/tokenization_utils.cpython-39.pyc +0 -0
- CGFormer/bert/__pycache__/tokenization_utils_base.cpython-38.pyc +0 -0
- CGFormer/bert/__pycache__/tokenization_utils_base.cpython-39.pyc +0 -0
- CGFormer/bert/activations.py +56 -0
- CGFormer/bert/configuration_bert.py +143 -0
- CGFormer/bert/configuration_utils.py +408 -0
- CGFormer/bert/file_utils.py +808 -0
- CGFormer/bert/generation_utils.py +993 -0
- CGFormer/bert/modeling_bert.py +1569 -0
- CGFormer/bert/modeling_utils.py +1268 -0
- CGFormer/bert/tokenization_bert.py +545 -0
- CGFormer/bert/tokenization_utils.py +723 -0
- CGFormer/bert/tokenization_utils_base.py +0 -0
- CGFormer/ckpts/swin_base_patch4_window12_384_22k.pth +3 -0
- CGFormer/config/config_gref_ace.yaml +63 -0
- CGFormer/config/config_mosaic_refcocog_u.yaml +51 -0
- CGFormer/config/config_rcc_ace.yaml +63 -0
- CGFormer/config/config_rccp_ace.yaml +63 -0
- CGFormer/config/config_refzom_ace.yaml +64 -0
- CGFormer/config/config_refzom_repro.yaml +62 -0
- CGFormer/config/config_refzom_repro_eval.yaml +62 -0
- CGFormer/config/impl/config.yaml +53 -0
- CGFormer/config/open.yaml +55 -0
- CGFormer/config/refcoco_mosaic/config.yaml +59 -0
.gitattributes
CHANGED
|
@@ -34,3 +34,41 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
RIS-DMMI/refer/evaluation/tokenizer/stanford-corenlp-3.4.1.jar filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
RIS-DMMI/refer/evaluation/tokenizer/stanford-corenlp-3.4.1.jar filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
CGFormer/external/mmsegmentation/demo/demo.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
CGFormer/external/mmsegmentation/docs/zh_cn/imgs/qq_group_qrcode.jpg filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
CGFormer/external/mmsegmentation/docs/zh_cn/imgs/zhihu_qrcode.jpg filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
CGFormer/external/mmsegmentation/projects/medical/2d_image/histopathology/conic2022_seg/conic2022_seg_dataset.png filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
CGFormer/external/mmsegmentation/resources/3dogs.jpg filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
CGFormer/external/mmsegmentation/resources/seg_demo.gif filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
CGFormer/external/mmsegmentation/tests/data/pseudo_loveda_dataset/img_dir/0.png filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
CGFormer/external/mmsegmentation/tests/data/pseudo_loveda_dataset/img_dir/1.png filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
CGFormer/external/mmsegmentation/tests/data/pseudo_loveda_dataset/img_dir/2.png filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
CGFormer/external/mmsegmentation/tests/data/pseudo_potsdam_dataset/img_dir/2_10_0_0_512_512.png filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
CGFormer/external/mmsegmentation/tests/data/pseudo_refuge_dataset/img_dir/pseudo_g0001.png filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
CGFormer/external/mmsegmentation/tests/data/pseudo_vaihingen_dataset/img_dir/area1_0_0_512_512.png filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
CGFormer/image/framework.jpg filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
CGFormer/wandb/offline-run-20250307_173512-9h2on932/run-9h2on932.wandb filter=lfs diff=lfs merge=lfs -text
|
| 51 |
+
CGFormer/wandb/offline-run-20250307_174303-li5zqatl/run-li5zqatl.wandb filter=lfs diff=lfs merge=lfs -text
|
| 52 |
+
CGFormer/wandb/offline-run-20250307_182402-j7d7o60n/run-j7d7o60n.wandb filter=lfs diff=lfs merge=lfs -text
|
| 53 |
+
CGFormer/wandb/offline-run-20250307_183427-lje8pep7/run-lje8pep7.wandb filter=lfs diff=lfs merge=lfs -text
|
| 54 |
+
CGFormer/wandb/offline-run-20250307_191605-qwg5jc6l/run-qwg5jc6l.wandb filter=lfs diff=lfs merge=lfs -text
|
| 55 |
+
CGFormer/wandb/offline-run-20250307_191652-pdgidm12/run-pdgidm12.wandb filter=lfs diff=lfs merge=lfs -text
|
| 56 |
+
CGFormer/wandb/offline-run-20250307_200613-ikc5v4qd/run-ikc5v4qd.wandb filter=lfs diff=lfs merge=lfs -text
|
| 57 |
+
CGFormer/wandb/offline-run-20250307_201001-i0m64au8/run-i0m64au8.wandb filter=lfs diff=lfs merge=lfs -text
|
| 58 |
+
CGFormer/wandb/offline-run-20250307_210707-ialdzorz/run-ialdzorz.wandb filter=lfs diff=lfs merge=lfs -text
|
| 59 |
+
CGFormer/wandb/offline-run-20250307_211011-2bbev839/run-2bbev839.wandb filter=lfs diff=lfs merge=lfs -text
|
| 60 |
+
CGFormer/wandb/offline-run-20250308_193217-dnb3uu3l/run-dnb3uu3l.wandb filter=lfs diff=lfs merge=lfs -text
|
| 61 |
+
CGFormer/wandb/offline-run-20250308_231240-qhgmf9fk/run-qhgmf9fk.wandb filter=lfs diff=lfs merge=lfs -text
|
| 62 |
+
CGFormer/wandb/offline-run-20250309_115725-5fgrfjdy/run-5fgrfjdy.wandb filter=lfs diff=lfs merge=lfs -text
|
| 63 |
+
CGFormer/wandb/offline-run-20250309_170924-0x06srss/run-0x06srss.wandb filter=lfs diff=lfs merge=lfs -text
|
| 64 |
+
CGFormer/wandb/offline-run-20250309_171616-684omhh0/run-684omhh0.wandb filter=lfs diff=lfs merge=lfs -text
|
| 65 |
+
CGFormer/wandb/offline-run-20250309_171623-3b8sr48c/run-3b8sr48c.wandb filter=lfs diff=lfs merge=lfs -text
|
| 66 |
+
CGFormer/wandb/offline-run-20250309_173234-04u0nc2s/run-04u0nc2s.wandb filter=lfs diff=lfs merge=lfs -text
|
| 67 |
+
CGFormer/wandb/offline-run-20250309_202147-gbujz424/run-gbujz424.wandb filter=lfs diff=lfs merge=lfs -text
|
| 68 |
+
CGFormer/wandb/offline-run-20250309_203311-xfi3d65b/run-xfi3d65b.wandb filter=lfs diff=lfs merge=lfs -text
|
| 69 |
+
CGFormer/wandb/offline-run-20250309_203334-e1fdhljy/run-e1fdhljy.wandb filter=lfs diff=lfs merge=lfs -text
|
| 70 |
+
CGFormer/wandb/offline-run-20250309_205719-wlfh3gyq/run-wlfh3gyq.wandb filter=lfs diff=lfs merge=lfs -text
|
| 71 |
+
CGFormer/wandb/offline-run-20250309_205856-a8l51dy6/run-a8l51dy6.wandb filter=lfs diff=lfs merge=lfs -text
|
| 72 |
+
CGFormer/wandb/offline-run-20250309_212058-k3fcizav/run-k3fcizav.wandb filter=lfs diff=lfs merge=lfs -text
|
| 73 |
+
CGFormer/wandb/offline-run-20250309_212213-a992tfly/run-a992tfly.wandb filter=lfs diff=lfs merge=lfs -text
|
| 74 |
+
CGFormer/wandb/offline-run-20250311_174825-ghnm4ky9/run-ghnm4ky9.wandb filter=lfs diff=lfs merge=lfs -text
|
CGFormer/.gitignore
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/exp
|
| 2 |
+
/wandb/**
|
| 3 |
+
**/__pycache__
|
| 4 |
+
/train_open.py
|
| 5 |
+
/.vscode
|
| 6 |
+
config/config.yaml
|
| 7 |
+
config/open.yaml
|
CGFormer/.ipynb_checkpoints/test-checkpoint.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import warnings
|
| 4 |
+
|
| 5 |
+
import cv2
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.parallel
|
| 8 |
+
import torch.utils.data
|
| 9 |
+
from loguru import logger
|
| 10 |
+
|
| 11 |
+
import deepspeed
|
| 12 |
+
import utils.config as config
|
| 13 |
+
from engine.engine import inference
|
| 14 |
+
from model import build_segmenter
|
| 15 |
+
from utils.dataset import RefDataset
|
| 16 |
+
from utils.misc import setup_logger
|
| 17 |
+
|
| 18 |
+
from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
warnings.filterwarnings("ignore")
|
| 22 |
+
cv2.setNumThreads(0)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def get_parser():
|
| 26 |
+
parser = argparse.ArgumentParser(
|
| 27 |
+
description='Pytorch Referring Expression Segmentation')
|
| 28 |
+
parser.add_argument('--config',
|
| 29 |
+
default='path to xxx.yaml',
|
| 30 |
+
type=str,
|
| 31 |
+
help='config file')
|
| 32 |
+
parser.add_argument('--opts',
|
| 33 |
+
default=None,
|
| 34 |
+
nargs=argparse.REMAINDER,
|
| 35 |
+
help='override some settings in the config.')
|
| 36 |
+
args = parser.parse_args()
|
| 37 |
+
assert args.config is not None
|
| 38 |
+
cfg = config.load_cfg_from_cfg_file(args.config)
|
| 39 |
+
if args.opts is not None:
|
| 40 |
+
cfg = config.merge_cfg_from_list(cfg, args.opts)
|
| 41 |
+
return cfg
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@logger.catch
|
| 45 |
+
def main():
|
| 46 |
+
args = get_parser()
|
| 47 |
+
args.output_dir = os.path.join(args.output_folder, args.exp_name)
|
| 48 |
+
if args.visualize:
|
| 49 |
+
args.vis_dir = os.path.join(args.output_dir, "vis")
|
| 50 |
+
os.makedirs(args.vis_dir, exist_ok=True)
|
| 51 |
+
|
| 52 |
+
# logger
|
| 53 |
+
setup_logger(args.output_dir,
|
| 54 |
+
distributed_rank=0,
|
| 55 |
+
filename="test.log",
|
| 56 |
+
mode="a")
|
| 57 |
+
logger.info(args.test_split)
|
| 58 |
+
|
| 59 |
+
# build dataset & dataloader
|
| 60 |
+
test_data = RefDataset(lmdb_dir=args.test_lmdb,
|
| 61 |
+
mask_dir=args.mask_root,
|
| 62 |
+
dataset=args.dataset,
|
| 63 |
+
split=args.test_split,
|
| 64 |
+
mode='test',
|
| 65 |
+
input_size=args.input_size,
|
| 66 |
+
word_length=args.word_len,
|
| 67 |
+
args=args)
|
| 68 |
+
test_loader = torch.utils.data.DataLoader(test_data,
|
| 69 |
+
batch_size=1,
|
| 70 |
+
shuffle=False,
|
| 71 |
+
num_workers=1,
|
| 72 |
+
pin_memory=True)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
# build model
|
| 76 |
+
model = build_segmenter(args, DDP=False)
|
| 77 |
+
|
| 78 |
+
# deepspeed_config = {
|
| 79 |
+
# "kernel_inject": False,
|
| 80 |
+
# "dtype": "fp16",
|
| 81 |
+
# "enable_cuda_graph": True,
|
| 82 |
+
# "checkpoint": f'{args.output_dir}/best_model'
|
| 83 |
+
# }
|
| 84 |
+
|
| 85 |
+
# logger.info(model)
|
| 86 |
+
|
| 87 |
+
#args.model_dir = os.path.join(args.output_dir, "best_model.pth")
|
| 88 |
+
if os.path.isdir(args.output_dir):
|
| 89 |
+
logger.info(f"=> loading checkpoint '{args.output_dir}/best_model'")
|
| 90 |
+
#checkpoint = torch.load(args.model_dir)
|
| 91 |
+
#model.module.load_state_dict(checkpoint['model_state_dict'], strict=True)
|
| 92 |
+
model = load_state_dict_from_zero_checkpoint(model, args.output_dir, tag="best_model").cuda()
|
| 93 |
+
#model.load_checkpoint(args.output_dir, tag="best_model")
|
| 94 |
+
|
| 95 |
+
logger.info(f"=> loading checkpoint '{args.output_dir}/best_model'")
|
| 96 |
+
else:
|
| 97 |
+
raise ValueError(
|
| 98 |
+
"=> resume failed! no checkpoint found at '{}'. Please check args.resume again!"
|
| 99 |
+
.format(args.output_dir))
|
| 100 |
+
|
| 101 |
+
# inference
|
| 102 |
+
inference(test_loader, model, args)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
if __name__ == '__main__':
|
| 106 |
+
main()
|
CGFormer/.ipynb_checkpoints/test_mosaic-checkpoint.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import warnings
|
| 4 |
+
|
| 5 |
+
import cv2
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.parallel
|
| 8 |
+
import torch.utils.data
|
| 9 |
+
from loguru import logger
|
| 10 |
+
|
| 11 |
+
import deepspeed
|
| 12 |
+
import utils.config as config
|
| 13 |
+
from engine.engine import inference
|
| 14 |
+
from model import build_segmenter
|
| 15 |
+
from utils.dataset_mosaic import RefDataset
|
| 16 |
+
from utils.misc import setup_logger
|
| 17 |
+
|
| 18 |
+
from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
warnings.filterwarnings("ignore")
|
| 22 |
+
cv2.setNumThreads(0)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def get_parser():
|
| 26 |
+
parser = argparse.ArgumentParser(
|
| 27 |
+
description='Pytorch Referring Expression Segmentation')
|
| 28 |
+
parser.add_argument('--config',
|
| 29 |
+
default='path to xxx.yaml',
|
| 30 |
+
type=str,
|
| 31 |
+
help='config file')
|
| 32 |
+
parser.add_argument('--opts',
|
| 33 |
+
default=None,
|
| 34 |
+
nargs=argparse.REMAINDER,
|
| 35 |
+
help='override some settings in the config.')
|
| 36 |
+
args = parser.parse_args()
|
| 37 |
+
assert args.config is not None
|
| 38 |
+
cfg = config.load_cfg_from_cfg_file(args.config)
|
| 39 |
+
if args.opts is not None:
|
| 40 |
+
cfg = config.merge_cfg_from_list(cfg, args.opts)
|
| 41 |
+
return cfg
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@logger.catch
|
| 45 |
+
def main():
|
| 46 |
+
args = get_parser()
|
| 47 |
+
args.output_dir = os.path.join(args.output_folder, args.exp_name)
|
| 48 |
+
if args.visualize:
|
| 49 |
+
args.vis_dir = os.path.join(args.output_dir, "vis")
|
| 50 |
+
os.makedirs(args.vis_dir, exist_ok=True)
|
| 51 |
+
|
| 52 |
+
# logger
|
| 53 |
+
setup_logger(args.output_dir,
|
| 54 |
+
distributed_rank=0,
|
| 55 |
+
filename="test.log",
|
| 56 |
+
mode="a")
|
| 57 |
+
logger.info(args.test_split)
|
| 58 |
+
|
| 59 |
+
# build dataset & dataloader
|
| 60 |
+
test_data = RefDataset(lmdb_dir=args.test_lmdb,
|
| 61 |
+
mask_dir=args.mask_root,
|
| 62 |
+
dataset=args.dataset,
|
| 63 |
+
split=args.test_split,
|
| 64 |
+
mode='test',
|
| 65 |
+
input_size=args.input_size,
|
| 66 |
+
word_length=args.word_len,
|
| 67 |
+
args=args)
|
| 68 |
+
test_loader = torch.utils.data.DataLoader(test_data,
|
| 69 |
+
batch_size=1,
|
| 70 |
+
shuffle=False,
|
| 71 |
+
num_workers=1,
|
| 72 |
+
pin_memory=True)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
# build model
|
| 76 |
+
model = build_segmenter(args, DDP=False)
|
| 77 |
+
|
| 78 |
+
# deepspeed_config = {
|
| 79 |
+
# "kernel_inject": False,
|
| 80 |
+
# "dtype": "fp16",
|
| 81 |
+
# "enable_cuda_graph": True,
|
| 82 |
+
# "checkpoint": f'{args.output_dir}/best_model'
|
| 83 |
+
# }
|
| 84 |
+
|
| 85 |
+
# logger.info(model)
|
| 86 |
+
|
| 87 |
+
#args.model_dir = os.path.join(args.output_dir, "best_model.pth")
|
| 88 |
+
if os.path.isdir(args.output_dir):
|
| 89 |
+
logger.info(f"=> loading checkpoint '{args.output_dir}/best_model'")
|
| 90 |
+
#checkpoint = torch.load(args.model_dir)
|
| 91 |
+
#model.module.load_state_dict(checkpoint['model_state_dict'], strict=True)
|
| 92 |
+
model = load_state_dict_from_zero_checkpoint(model, args.output_dir, tag="best_model").cuda()
|
| 93 |
+
#model.load_checkpoint(args.output_dir, tag="best_model")
|
| 94 |
+
|
| 95 |
+
logger.info(f"=> loading checkpoint '{args.output_dir}/best_model'")
|
| 96 |
+
else:
|
| 97 |
+
raise ValueError(
|
| 98 |
+
"=> resume failed! no checkpoint found at '{}'. Please check args.resume again!"
|
| 99 |
+
.format(args.output_dir))
|
| 100 |
+
|
| 101 |
+
# inference
|
| 102 |
+
inference(test_loader, model, args)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
if __name__ == '__main__':
|
| 106 |
+
main()
|
CGFormer/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2023 Toneyaya
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
CGFormer/README.md
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CGFormer
|
| 2 |
+
The official PyTorch implementation of the CVPR 2023 paper "Contrastive Grouping with Transformer for Referring Image Segmentation".
|
| 3 |
+
|
| 4 |
+
This paper first introduces learnable query tokens to represent objects and then alternately queries linguistic features and groups visual features into the query tokens for object-aware cross-modal reasoning. CGFormer achieves cross-level interaction by jointly updating the query tokens and decoding masks in every two consecutive layers. In addition, we introduce new splits on datasets for evaluating generalization for referring image segmentation models.
|
| 5 |
+
|
| 6 |
+
## Framework
|
| 7 |
+
<p align="center">
|
| 8 |
+
<img src="image/framework.jpg" width="1000">
|
| 9 |
+
</p>
|
| 10 |
+
|
| 11 |
+
## Preparation
|
| 12 |
+
|
| 13 |
+
1. Environment
|
| 14 |
+
- [PyTorch](www.pytorch.org)
|
| 15 |
+
- Other dependencies in `requirements.txt`
|
| 16 |
+
2. Datasets
|
| 17 |
+
- The detailed instruction is in [prepare_datasets](data/READEME.md)
|
| 18 |
+
3. Pretrained weights
|
| 19 |
+
- [Swin-Base-window12](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth)
|
| 20 |
+
|
| 21 |
+
## Train and Test (RIS)
|
| 22 |
+
|
| 23 |
+
This implementation only supports **multi-gpu**, **DistributedDataParallel** training, which is faster and simpler; single-gpu or DataParallel training is not supported. Besides, the evaluation only supports single-gpu mode.
|
| 24 |
+
|
| 25 |
+
To do training of CGFormer with 8 GPUs, run:
|
| 26 |
+
|
| 27 |
+
```
|
| 28 |
+
python -u train.py --config config/config.yaml
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
To do evaluation of CGFormer with 1 GPU, run:
|
| 32 |
+
```
|
| 33 |
+
CUDA_VISIBLE_DEVICES=0 python -u test.py \
|
| 34 |
+
--config config/refcoco/config.yaml \
|
| 35 |
+
--opts TEST.test_split val \
|
| 36 |
+
TEST.test_lmdb path/val.lmdb
|
| 37 |
+
```
|
| 38 |
+
## License
|
| 39 |
+
|
| 40 |
+
This project is under the MIT license. See [LICENSE](LICENSE) for details.
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
## Citation
|
| 44 |
+
If you find our work useful in your research, please consider citing:
|
| 45 |
+
```
|
| 46 |
+
@InProceedings{Tang_2023_CVPR,
|
| 47 |
+
author = {Tang, Jiajin and Zheng, Ge and Shi, Cheng and Yang, Sibei},
|
| 48 |
+
title = {Contrastive Grouping With Transformer for Referring Image Segmentation},
|
| 49 |
+
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
| 50 |
+
month = {June},
|
| 51 |
+
year = {2023},
|
| 52 |
+
pages = {23570-23580}
|
| 53 |
+
}
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
Some code changes come from [CRIS](https://github.com/DerrickWang005/CRIS.pytorch/tree/master) and [LAVT](https://github.com/yz93/LAVT-RIS).
|
CGFormer/bash_logs/ACE_filter050.log
ADDED
|
@@ -0,0 +1,480 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/launch.py:181: FutureWarning: The module torch.distributed.launch is deprecated
|
| 2 |
+
and will be removed in future. Use torchrun.
|
| 3 |
+
Note that --use-env is set by default in torchrun.
|
| 4 |
+
If your script expects `--local-rank` argument to be set, please
|
| 5 |
+
change it to read from `os.environ['LOCAL_RANK']` instead. See
|
| 6 |
+
https://pytorch.org/docs/stable/distributed.html#launch-utility for
|
| 7 |
+
further instructions
|
| 8 |
+
|
| 9 |
+
warnings.warn(
|
| 10 |
+
[2025-03-03 00:25:13,383] torch.distributed.run: [WARNING]
|
| 11 |
+
[2025-03-03 00:25:13,383] torch.distributed.run: [WARNING] *****************************************
|
| 12 |
+
[2025-03-03 00:25:13,383] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
|
| 13 |
+
[2025-03-03 00:25:13,383] torch.distributed.run: [WARNING] *****************************************
|
| 14 |
+
/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/albumentations/__init__.py:24: UserWarning: A new version of Albumentations is available: 2.0.5 (you have 1.4.24). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.
|
| 15 |
+
check_for_updates()
|
| 16 |
+
/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/albumentations/__init__.py:24: UserWarning: A new version of Albumentations is available: 2.0.5 (you have 1.4.24). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.
|
| 17 |
+
check_for_updates()
|
| 18 |
+
/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/albumentations/__init__.py:24: UserWarning: A new version of Albumentations is available: 2.0.5 (you have 1.4.24). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.
|
| 19 |
+
check_for_updates()
|
| 20 |
+
/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/albumentations/__init__.py:24: UserWarning: A new version of Albumentations is available: 2.0.5 (you have 1.4.24). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.
|
| 21 |
+
check_for_updates()
|
| 22 |
+
/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/timm/models/layers/__init__.py:48: FutureWarning: Importing from timm.models.layers is deprecated, please import via timm.layers
|
| 23 |
+
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.layers", FutureWarning)
|
| 24 |
+
/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/timm/models/layers/__init__.py:48: FutureWarning: Importing from timm.models.layers is deprecated, please import via timm.layers
|
| 25 |
+
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.layers", FutureWarning)
|
| 26 |
+
/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/timm/models/layers/__init__.py:48: FutureWarning: Importing from timm.models.layers is deprecated, please import via timm.layers
|
| 27 |
+
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.layers", FutureWarning)
|
| 28 |
+
/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/timm/models/layers/__init__.py:48: FutureWarning: Importing from timm.models.layers is deprecated, please import via timm.layers
|
| 29 |
+
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.layers", FutureWarning)
|
| 30 |
+
2025-03-03 00:25:21.347 | INFO | __main__:main:66 - LOCAL_RANK from env: 1
|
| 31 |
+
2025-03-03 00:25:21.347 | INFO | __main__:main:66 - LOCAL_RANK from env: 0
|
| 32 |
+
2025-03-03 00:25:21.359 | INFO | __main__:main:66 - LOCAL_RANK from env: 3
|
| 33 |
+
2025-03-03 00:25:21.369 | INFO | __main__:main:66 - LOCAL_RANK from env: 2
|
| 34 |
+
2025-03-03 00:25:21 | INFO | __main__:90 - Starting with GPU: 0, Rank: 0, World Size: 4
|
| 35 |
+
git root error: Cmd('git') failed due to: exit code(128)
|
| 36 |
+
cmdline: git rev-parse --show-toplevel
|
| 37 |
+
stderr: 'fatal: detected dubious ownership in repository at '/data2/projects/chaeyun/CGFormer'
|
| 38 |
+
To add an exception for this directory, call:
|
| 39 |
+
|
| 40 |
+
git config --global --add safe.directory /data2/projects/chaeyun/CGFormer'
|
| 41 |
+
wandb: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
|
| 42 |
+
wandb: Tracking run with wandb version 0.19.1
|
| 43 |
+
wandb: W&B syncing is set to `offline` in this directory.
|
| 44 |
+
wandb: Run `wandb online` or set WANDB_MODE=online to enable cloud syncing.
|
| 45 |
+
node03:2017036:2017036 [0] NCCL INFO Bootstrap : Using eth2:10.1.10.3<0>
|
| 46 |
+
node03:2017036:2017036 [0] NCCL INFO NET/Plugin : Plugin load (libnccl-net.so) returned 2 : libnccl-net.so: cannot open shared object file: No such file or directory
|
| 47 |
+
node03:2017036:2017036 [0] NCCL INFO NET/Plugin : No plugin found, using internal implementation
|
| 48 |
+
node03:2017036:2017036 [0] NCCL INFO cudaDriverVersion 12070
|
| 49 |
+
NCCL version 2.18.5+cuda11.8
|
| 50 |
+
node03:2017039:2017039 [3] NCCL INFO cudaDriverVersion 12070
|
| 51 |
+
node03:2017038:2017038 [2] NCCL INFO cudaDriverVersion 12070
|
| 52 |
+
node03:2017037:2017037 [1] NCCL INFO cudaDriverVersion 12070
|
| 53 |
+
node03:2017038:2017038 [2] NCCL INFO Bootstrap : Using eth2:10.1.10.3<0>
|
| 54 |
+
node03:2017039:2017039 [3] NCCL INFO Bootstrap : Using eth2:10.1.10.3<0>
|
| 55 |
+
node03:2017037:2017037 [1] NCCL INFO Bootstrap : Using eth2:10.1.10.3<0>
|
| 56 |
+
node03:2017038:2017038 [2] NCCL INFO NET/Plugin : Plugin load (libnccl-net.so) returned 2 : libnccl-net.so: cannot open shared object file: No such file or directory
|
| 57 |
+
node03:2017038:2017038 [2] NCCL INFO NET/Plugin : No plugin found, using internal implementation
|
| 58 |
+
node03:2017039:2017039 [3] NCCL INFO NET/Plugin : Plugin load (libnccl-net.so) returned 2 : libnccl-net.so: cannot open shared object file: No such file or directory
|
| 59 |
+
node03:2017039:2017039 [3] NCCL INFO NET/Plugin : No plugin found, using internal implementation
|
| 60 |
+
node03:2017037:2017037 [1] NCCL INFO NET/Plugin : Plugin load (libnccl-net.so) returned 2 : libnccl-net.so: cannot open shared object file: No such file or directory
|
| 61 |
+
node03:2017037:2017037 [1] NCCL INFO NET/Plugin : No plugin found, using internal implementation
|
| 62 |
+
node03:2017039:2017152 [3] NCCL INFO NET/IB : No device found.
|
| 63 |
+
node03:2017039:2017152 [3] NCCL INFO NET/Socket : Using [0]eth2:10.1.10.3<0>
|
| 64 |
+
node03:2017039:2017152 [3] NCCL INFO Using network Socket
|
| 65 |
+
node03:2017038:2017151 [2] NCCL INFO NET/IB : No device found.
|
| 66 |
+
node03:2017038:2017151 [2] NCCL INFO NET/Socket : Using [0]eth2:10.1.10.3<0>
|
| 67 |
+
node03:2017038:2017151 [2] NCCL INFO Using network Socket
|
| 68 |
+
node03:2017036:2017153 [0] NCCL INFO NET/IB : No device found.
|
| 69 |
+
node03:2017037:2017154 [1] NCCL INFO NET/IB : No device found.
|
| 70 |
+
node03:2017036:2017153 [0] NCCL INFO NET/Socket : Using [0]eth2:10.1.10.3<0>
|
| 71 |
+
node03:2017036:2017153 [0] NCCL INFO Using network Socket
|
| 72 |
+
node03:2017037:2017154 [1] NCCL INFO NET/Socket : Using [0]eth2:10.1.10.3<0>
|
| 73 |
+
node03:2017037:2017154 [1] NCCL INFO Using network Socket
|
| 74 |
+
node03:2017038:2017151 [2] NCCL INFO comm 0xaa0bbe0 rank 2 nranks 4 cudaDev 2 nvmlDev 2 busId 14000 commId 0x190df8e791fecf1b - Init START
|
| 75 |
+
node03:2017037:2017154 [1] NCCL INFO comm 0xadc9fc0 rank 1 nranks 4 cudaDev 1 nvmlDev 1 busId 13000 commId 0x190df8e791fecf1b - Init START
|
| 76 |
+
node03:2017039:2017152 [3] NCCL INFO comm 0xb009870 rank 3 nranks 4 cudaDev 3 nvmlDev 3 busId 48000 commId 0x190df8e791fecf1b - Init START
|
| 77 |
+
node03:2017036:2017153 [0] NCCL INFO comm 0xb555ea0 rank 0 nranks 4 cudaDev 0 nvmlDev 0 busId 12000 commId 0x190df8e791fecf1b - Init START
|
| 78 |
+
node03:2017037:2017154 [1] NCCL INFO Setting affinity for GPU 1 to 5500,00000055
|
| 79 |
+
node03:2017038:2017151 [2] NCCL INFO Setting affinity for GPU 2 to 5500,00000055
|
| 80 |
+
node03:2017039:2017152 [3] NCCL INFO Setting affinity for GPU 3 to 5500,00000055
|
| 81 |
+
node03:2017036:2017153 [0] NCCL INFO Setting affinity for GPU 0 to 5500,00000055
|
| 82 |
+
node03:2017039:2017152 [3] NCCL INFO Trees [0] -1/-1/-1->3->2 [1] -1/-1/-1->3->2
|
| 83 |
+
node03:2017036:2017153 [0] NCCL INFO Channel 00/02 : 0 1 2 3
|
| 84 |
+
node03:2017038:2017151 [2] NCCL INFO Trees [0] 3/-1/-1->2->1 [1] 3/-1/-1->2->1
|
| 85 |
+
node03:2017037:2017154 [1] NCCL INFO Trees [0] 2/-1/-1->1->0 [1] 2/-1/-1->1->0
|
| 86 |
+
node03:2017039:2017152 [3] NCCL INFO P2P Chunksize set to 131072
|
| 87 |
+
node03:2017036:2017153 [0] NCCL INFO Channel 01/02 : 0 1 2 3
|
| 88 |
+
node03:2017038:2017151 [2] NCCL INFO P2P Chunksize set to 131072
|
| 89 |
+
node03:2017037:2017154 [1] NCCL INFO P2P Chunksize set to 131072
|
| 90 |
+
node03:2017036:2017153 [0] NCCL INFO Trees [0] 1/-1/-1->0->-1 [1] 1/-1/-1->0->-1
|
| 91 |
+
node03:2017036:2017153 [0] NCCL INFO P2P Chunksize set to 131072
|
| 92 |
+
node03:2017037:2017154 [1] NCCL INFO Channel 00/0 : 1[1] -> 2[2] via P2P/IPC
|
| 93 |
+
node03:2017038:2017151 [2] NCCL INFO Channel 00 : 2[2] -> 3[3] via SHM/direct/direct
|
| 94 |
+
node03:2017037:2017154 [1] NCCL INFO Channel 01/0 : 1[1] -> 2[2] via P2P/IPC
|
| 95 |
+
node03:2017038:2017151 [2] NCCL INFO Channel 01 : 2[2] -> 3[3] via SHM/direct/direct
|
| 96 |
+
node03:2017036:2017153 [0] NCCL INFO Channel 00/0 : 0[0] -> 1[1] via P2P/IPC
|
| 97 |
+
node03:2017039:2017152 [3] NCCL INFO Channel 00 : 3[3] -> 0[0] via SHM/direct/direct
|
| 98 |
+
node03:2017039:2017152 [3] NCCL INFO Channel 01 : 3[3] -> 0[0] via SHM/direct/direct
|
| 99 |
+
node03:2017036:2017153 [0] NCCL INFO Channel 01/0 : 0[0] -> 1[1] via P2P/IPC
|
| 100 |
+
node03:2017037:2017154 [1] NCCL INFO Connected all rings
|
| 101 |
+
node03:2017036:2017153 [0] NCCL INFO Connected all rings
|
| 102 |
+
node03:2017038:2017151 [2] NCCL INFO Connected all rings
|
| 103 |
+
node03:2017039:2017152 [3] NCCL INFO Connected all rings
|
| 104 |
+
node03:2017037:2017154 [1] NCCL INFO Channel 00/0 : 1[1] -> 0[0] via P2P/IPC
|
| 105 |
+
node03:2017039:2017152 [3] NCCL INFO Channel 00 : 3[3] -> 2[2] via SHM/direct/direct
|
| 106 |
+
node03:2017037:2017154 [1] NCCL INFO Channel 01/0 : 1[1] -> 0[0] via P2P/IPC
|
| 107 |
+
node03:2017039:2017152 [3] NCCL INFO Channel 01 : 3[3] -> 2[2] via SHM/direct/direct
|
| 108 |
+
node03:2017036:2017153 [0] NCCL INFO Connected all trees
|
| 109 |
+
node03:2017036:2017153 [0] NCCL INFO threadThresholds 8/8/64 | 32/8/64 | 512 | 512
|
| 110 |
+
node03:2017036:2017153 [0] NCCL INFO 2 coll channels, 0 nvls channels, 2 p2p channels, 2 p2p channels per peer
|
| 111 |
+
node03:2017038:2017151 [2] NCCL INFO Channel 00/0 : 2[2] -> 1[1] via P2P/IPC
|
| 112 |
+
node03:2017038:2017151 [2] NCCL INFO Channel 01/0 : 2[2] -> 1[1] via P2P/IPC
|
| 113 |
+
node03:2017037:2017154 [1] NCCL INFO Connected all trees
|
| 114 |
+
node03:2017037:2017154 [1] NCCL INFO threadThresholds 8/8/64 | 32/8/64 | 512 | 512
|
| 115 |
+
node03:2017037:2017154 [1] NCCL INFO 2 coll channels, 0 nvls channels, 2 p2p channels, 2 p2p channels per peer
|
| 116 |
+
node03:2017038:2017151 [2] NCCL INFO Connected all trees
|
| 117 |
+
node03:2017038:2017151 [2] NCCL INFO threadThresholds 8/8/64 | 32/8/64 | 512 | 512
|
| 118 |
+
node03:2017038:2017151 [2] NCCL INFO 2 coll channels, 0 nvls channels, 2 p2p channels, 2 p2p channels per peer
|
| 119 |
+
node03:2017039:2017152 [3] NCCL INFO Connected all trees
|
| 120 |
+
node03:2017039:2017152 [3] NCCL INFO threadThresholds 8/8/64 | 32/8/64 | 512 | 512
|
| 121 |
+
node03:2017039:2017152 [3] NCCL INFO 2 coll channels, 0 nvls channels, 2 p2p channels, 2 p2p channels per peer
|
| 122 |
+
node03:2017039:2017152 [3] NCCL INFO comm 0xb009870 rank 3 nranks 4 cudaDev 3 nvmlDev 3 busId 48000 commId 0x190df8e791fecf1b - Init COMPLETE
|
| 123 |
+
node03:2017037:2017154 [1] NCCL INFO comm 0xadc9fc0 rank 1 nranks 4 cudaDev 1 nvmlDev 1 busId 13000 commId 0x190df8e791fecf1b - Init COMPLETE
|
| 124 |
+
node03:2017038:2017151 [2] NCCL INFO comm 0xaa0bbe0 rank 2 nranks 4 cudaDev 2 nvmlDev 2 busId 14000 commId 0x190df8e791fecf1b - Init COMPLETE
|
| 125 |
+
node03:2017036:2017153 [0] NCCL INFO comm 0xb555ea0 rank 0 nranks 4 cudaDev 0 nvmlDev 0 busId 12000 commId 0x190df8e791fecf1b - Init COMPLETE
|
| 126 |
+
2025-03-03 00:25:23 | INFO | model:31 - Window size 12!
|
| 127 |
+
2025-03-03 00:25:24 | INFO | model:51 - Initializing Multi-modal Swin Transformer weights from ckpts/swin_base_patch4_window12_384_22k.pth
|
| 128 |
+
2025-03-03 00:25:25 | INFO | model.backbone:459 - loading swin success !!!
|
| 129 |
+
2025-03-03 00:25:29 | INFO | __main__:144 - Model moved to GPU: 0
|
| 130 |
+
2025-03-03 00:25:29 | INFO | __main__:145 - amsgrad: True
|
| 131 |
+
batch_size: 24
|
| 132 |
+
batch_size_val: 16
|
| 133 |
+
bert: bert-base-uncased
|
| 134 |
+
dataset: refcocog_u
|
| 135 |
+
dist_backend: nccl
|
| 136 |
+
dropout: 0.0
|
| 137 |
+
epochs: 50
|
| 138 |
+
evaluate: True
|
| 139 |
+
exclude_multiobj: True
|
| 140 |
+
exp_name: ACE_filter050
|
| 141 |
+
filter_threshold: 0.5
|
| 142 |
+
fusion_drop: 0.0
|
| 143 |
+
gpu: 0
|
| 144 |
+
hp_selection: strict
|
| 145 |
+
input_size: 480
|
| 146 |
+
local_rank: 0
|
| 147 |
+
loss_option: ACE_verbonly
|
| 148 |
+
lr: 0.0001
|
| 149 |
+
lr_backbone: 5e-05
|
| 150 |
+
lr_text_encoder: 5e-05
|
| 151 |
+
manual_seed: 2051388757
|
| 152 |
+
margin_value: 12
|
| 153 |
+
mask_root: data/masks/refcocog_u
|
| 154 |
+
metric_learning: True
|
| 155 |
+
metric_loss_weight: 0.1
|
| 156 |
+
metric_mode: hardpos_only_sbertsim_refined
|
| 157 |
+
mha: 8-8-8-8
|
| 158 |
+
mixup_lasttwo: True
|
| 159 |
+
num_token: 2
|
| 160 |
+
output_dir: exp/refcoco_u/ACE_filter050
|
| 161 |
+
output_folder: exp/refcoco_u
|
| 162 |
+
print_freq: 100
|
| 163 |
+
rank: 0
|
| 164 |
+
resume: None
|
| 165 |
+
save_freq: 1
|
| 166 |
+
start_epoch: 0
|
| 167 |
+
swin_pretrain: ckpts/swin_base_patch4_window12_384_22k.pth
|
| 168 |
+
swin_type: base
|
| 169 |
+
sync_bn: True
|
| 170 |
+
temperature: 0.07
|
| 171 |
+
test_lmdb: data/lmdb/refcocog_u/test.lmdb
|
| 172 |
+
test_split: test
|
| 173 |
+
token_dim: 512
|
| 174 |
+
train_lmdb: data/lmdb/refcocog_u/train.lmdb
|
| 175 |
+
train_split: train
|
| 176 |
+
val_lmdb: data/lmdb/refcocog_u/val.lmdb
|
| 177 |
+
val_split: val
|
| 178 |
+
vis_dim: 512
|
| 179 |
+
visualize: False
|
| 180 |
+
weight: None
|
| 181 |
+
weight_decay: 0.0001
|
| 182 |
+
window12: True
|
| 183 |
+
word_dim: 768
|
| 184 |
+
word_len: 20
|
| 185 |
+
workers: 32
|
| 186 |
+
workers_val: 8
|
| 187 |
+
world_size: 4
|
| 188 |
+
2025-03-03 00:28:05 | INFO | utils.misc:108 - Training: Epoch=[1/50] [ 100/1759] Batch=1.34 (1.53) Data=0.00 (0.09) Lr=0.000100 Loss=1.1470 (1.1906) IoU=15.42 (19.69) Prec@50=0.00 (8.61)
|
| 189 |
+
2025-03-03 00:30:31 | INFO | utils.misc:108 - Training: Epoch=[1/50] [ 200/1759] Batch=1.54 (1.50) Data=0.00 (0.06) Lr=0.000100 Loss=1.1851 (1.1191) IoU=24.54 (23.36) Prec@50=12.77 (12.84)
|
| 190 |
+
2025-03-03 00:32:56 | INFO | utils.misc:108 - Training: Epoch=[1/50] [ 300/1759] Batch=1.23 (1.48) Data=0.00 (0.05) Lr=0.000100 Loss=0.9150 (1.0781) IoU=33.89 (25.72) Prec@50=29.46 (15.38)
|
| 191 |
+
2025-03-03 00:35:23 | INFO | utils.misc:108 - Training: Epoch=[1/50] [ 400/1759] Batch=1.47 (1.48) Data=0.00 (0.04) Lr=0.000100 Loss=0.9083 (1.0518) IoU=31.66 (26.80) Prec@50=18.33 (16.54)
|
| 192 |
+
2025-03-03 00:37:51 | INFO | utils.misc:108 - Training: Epoch=[1/50] [ 500/1759] Batch=1.51 (1.48) Data=0.00 (0.04) Lr=0.000100 Loss=0.9000 (1.0296) IoU=32.92 (27.76) Prec@50=22.46 (17.54)
|
| 193 |
+
2025-03-03 00:40:19 | INFO | utils.misc:108 - Training: Epoch=[1/50] [ 600/1759] Batch=1.31 (1.48) Data=0.00 (0.04) Lr=0.000100 Loss=0.8722 (1.0109) IoU=40.11 (28.53) Prec@50=26.04 (18.47)
|
| 194 |
+
2025-03-03 00:42:47 | INFO | utils.misc:108 - Training: Epoch=[1/50] [ 700/1759] Batch=1.56 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.8722 (0.9958) IoU=38.67 (29.04) Prec@50=33.12 (19.23)
|
| 195 |
+
2025-03-03 00:45:15 | INFO | utils.misc:108 - Training: Epoch=[1/50] [ 800/1759] Batch=1.21 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.9943 (0.9830) IoU=31.87 (29.65) Prec@50=22.32 (20.11)
|
| 196 |
+
2025-03-03 00:47:43 | INFO | utils.misc:108 - Training: Epoch=[1/50] [ 900/1759] Batch=1.35 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.8947 (0.9714) IoU=36.31 (30.10) Prec@50=25.10 (20.81)
|
| 197 |
+
2025-03-03 00:50:09 | INFO | utils.misc:108 - Training: Epoch=[1/50] [1000/1759] Batch=1.77 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.8877 (0.9635) IoU=33.68 (30.41) Prec@50=25.49 (21.28)
|
| 198 |
+
2025-03-03 00:52:36 | INFO | utils.misc:108 - Training: Epoch=[1/50] [1100/1759] Batch=1.59 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.7724 (0.9548) IoU=38.99 (30.85) Prec@50=36.88 (21.78)
|
| 199 |
+
2025-03-03 00:55:03 | INFO | utils.misc:108 - Training: Epoch=[1/50] [1200/1759] Batch=1.45 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.7028 (0.9458) IoU=47.27 (31.19) Prec@50=49.79 (22.26)
|
| 200 |
+
2025-03-03 00:57:31 | INFO | utils.misc:108 - Training: Epoch=[1/50] [1300/1759] Batch=1.42 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.7939 (0.9381) IoU=39.03 (31.59) Prec@50=25.15 (22.68)
|
| 201 |
+
2025-03-03 00:59:59 | INFO | utils.misc:108 - Training: Epoch=[1/50] [1400/1759] Batch=1.27 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.7186 (0.9315) IoU=40.97 (31.96) Prec@50=35.71 (23.16)
|
| 202 |
+
2025-03-03 01:02:23 | INFO | utils.misc:108 - Training: Epoch=[1/50] [1500/1759] Batch=1.80 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.8415 (0.9243) IoU=33.69 (32.31) Prec@50=23.39 (23.62)
|
| 203 |
+
2025-03-03 01:04:50 | INFO | utils.misc:108 - Training: Epoch=[1/50] [1600/1759] Batch=2.03 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.8356 (0.9182) IoU=35.39 (32.66) Prec@50=21.51 (24.10)
|
| 204 |
+
2025-03-03 01:07:15 | INFO | utils.misc:108 - Training: Epoch=[1/50] [1700/1759] Batch=1.40 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.7456 (0.9132) IoU=41.80 (33.01) Prec@50=30.85 (24.57)
|
| 205 |
+
2025-03-03 01:09:16 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[1/50] mIoU=43.35 oIoU=42.54 Pr@50: 39.60 Pr@60: 28.96 Pr@70: 18.63 Pr@80: 10.71 Pr@90: 3.11
|
| 206 |
+
2025-03-03 01:12:09 | INFO | utils.misc:108 - Training: Epoch=[2/50] [ 100/1759] Batch=1.30 (1.47) Data=0.00 (0.04) Lr=0.000100 Loss=0.7384 (0.7771) IoU=36.20 (40.67) Prec@50=25.00 (35.48)
|
| 207 |
+
2025-03-03 01:14:38 | INFO | utils.misc:108 - Training: Epoch=[2/50] [ 200/1759] Batch=1.28 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.7371 (0.7694) IoU=41.03 (40.41) Prec@50=33.33 (35.67)
|
| 208 |
+
2025-03-03 01:17:04 | INFO | utils.misc:108 - Training: Epoch=[2/50] [ 300/1759] Batch=1.42 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.7440 (0.7664) IoU=37.40 (40.49) Prec@50=29.86 (36.04)
|
| 209 |
+
2025-03-03 01:19:35 | INFO | utils.misc:108 - Training: Epoch=[2/50] [ 400/1759] Batch=1.50 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.7437 (0.7650) IoU=35.39 (40.39) Prec@50=33.61 (36.12)
|
| 210 |
+
2025-03-03 01:22:01 | INFO | utils.misc:108 - Training: Epoch=[2/50] [ 500/1759] Batch=1.61 (1.48) Data=0.01 (0.03) Lr=0.000100 Loss=0.8122 (0.7668) IoU=40.75 (40.34) Prec@50=39.79 (36.16)
|
| 211 |
+
2025-03-03 01:24:25 | INFO | utils.misc:108 - Training: Epoch=[2/50] [ 600/1759] Batch=1.37 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.7956 (0.7655) IoU=35.37 (40.32) Prec@50=29.86 (36.20)
|
| 212 |
+
2025-03-03 01:26:51 | INFO | utils.misc:108 - Training: Epoch=[2/50] [ 700/1759] Batch=1.30 (1.47) Data=0.01 (0.03) Lr=0.000100 Loss=0.7419 (0.7664) IoU=43.77 (40.30) Prec@50=37.50 (36.12)
|
| 213 |
+
2025-03-03 01:29:16 | INFO | utils.misc:108 - Training: Epoch=[2/50] [ 800/1759] Batch=1.22 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.7913 (0.7653) IoU=42.33 (40.31) Prec@50=45.83 (36.31)
|
| 214 |
+
2025-03-03 01:31:42 | INFO | utils.misc:108 - Training: Epoch=[2/50] [ 900/1759] Batch=1.67 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.8527 (0.7637) IoU=32.06 (40.55) Prec@50=30.27 (36.72)
|
| 215 |
+
2025-03-03 01:34:07 | INFO | utils.misc:108 - Training: Epoch=[2/50] [1000/1759] Batch=1.41 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.8082 (0.7611) IoU=40.24 (40.63) Prec@50=29.76 (36.89)
|
| 216 |
+
2025-03-03 01:36:33 | INFO | utils.misc:108 - Training: Epoch=[2/50] [1100/1759] Batch=1.27 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.8084 (0.7607) IoU=38.52 (40.66) Prec@50=31.25 (36.92)
|
| 217 |
+
2025-03-03 01:38:57 | INFO | utils.misc:108 - Training: Epoch=[2/50] [1200/1759] Batch=1.47 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.6577 (0.7588) IoU=41.28 (40.73) Prec@50=39.49 (37.16)
|
| 218 |
+
2025-03-03 01:41:22 | INFO | utils.misc:108 - Training: Epoch=[2/50] [1300/1759] Batch=1.32 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.7004 (0.7568) IoU=43.16 (40.84) Prec@50=37.05 (37.41)
|
| 219 |
+
2025-03-03 01:43:49 | INFO | utils.misc:108 - Training: Epoch=[2/50] [1400/1759] Batch=1.23 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.7302 (0.7556) IoU=50.00 (40.88) Prec@50=54.76 (37.47)
|
| 220 |
+
2025-03-03 01:46:16 | INFO | utils.misc:108 - Training: Epoch=[2/50] [1500/1759] Batch=1.36 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.8330 (0.7550) IoU=39.82 (40.90) Prec@50=27.08 (37.47)
|
| 221 |
+
2025-03-03 01:48:44 | INFO | utils.misc:108 - Training: Epoch=[2/50] [1600/1759] Batch=1.76 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.7036 (0.7541) IoU=44.15 (40.92) Prec@50=44.38 (37.50)
|
| 222 |
+
2025-03-03 01:51:09 | INFO | utils.misc:108 - Training: Epoch=[2/50] [1700/1759] Batch=1.50 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.6859 (0.7531) IoU=42.78 (41.05) Prec@50=39.23 (37.69)
|
| 223 |
+
2025-03-03 01:53:12 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[2/50] mIoU=49.64 oIoU=46.67 Pr@50: 47.79 Pr@60: 37.85 Pr@70: 29.00 Pr@80: 18.63 Pr@90: 6.25
|
| 224 |
+
2025-03-03 01:56:07 | INFO | utils.misc:108 - Training: Epoch=[3/50] [ 100/1759] Batch=1.46 (1.48) Data=0.00 (0.04) Lr=0.000100 Loss=0.6057 (0.6714) IoU=51.61 (44.71) Prec@50=54.41 (44.64)
|
| 225 |
+
2025-03-03 01:58:35 | INFO | utils.misc:108 - Training: Epoch=[3/50] [ 200/1759] Batch=1.53 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.9918 (0.6685) IoU=29.66 (44.32) Prec@50=33.96 (43.73)
|
| 226 |
+
2025-03-03 02:01:02 | INFO | utils.misc:108 - Training: Epoch=[3/50] [ 300/1759] Batch=1.31 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.6285 (0.6760) IoU=49.97 (44.40) Prec@50=53.97 (43.81)
|
| 227 |
+
2025-03-03 02:03:29 | INFO | utils.misc:108 - Training: Epoch=[3/50] [ 400/1759] Batch=1.83 (1.48) Data=0.01 (0.03) Lr=0.000100 Loss=0.8127 (0.6792) IoU=33.49 (44.24) Prec@50=30.44 (43.48)
|
| 228 |
+
2025-03-03 02:05:57 | INFO | utils.misc:108 - Training: Epoch=[3/50] [ 500/1759] Batch=1.61 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.6658 (0.6781) IoU=43.32 (44.19) Prec@50=49.29 (43.60)
|
| 229 |
+
2025-03-03 02:08:23 | INFO | utils.misc:108 - Training: Epoch=[3/50] [ 600/1759] Batch=1.67 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.6875 (0.6804) IoU=45.37 (43.86) Prec@50=38.48 (43.05)
|
| 230 |
+
2025-03-03 02:10:51 | INFO | utils.misc:108 - Training: Epoch=[3/50] [ 700/1759] Batch=1.31 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.6716 (0.6817) IoU=49.63 (43.89) Prec@50=45.83 (42.92)
|
| 231 |
+
2025-03-03 02:13:18 | INFO | utils.misc:108 - Training: Epoch=[3/50] [ 800/1759] Batch=1.31 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.8302 (0.6830) IoU=40.04 (44.00) Prec@50=37.35 (43.08)
|
| 232 |
+
2025-03-03 02:15:43 | INFO | utils.misc:108 - Training: Epoch=[3/50] [ 900/1759] Batch=1.39 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.8283 (0.6843) IoU=36.38 (44.04) Prec@50=28.97 (43.11)
|
| 233 |
+
2025-03-03 02:18:09 | INFO | utils.misc:108 - Training: Epoch=[3/50] [1000/1759] Batch=1.41 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.6356 (0.6843) IoU=48.44 (43.85) Prec@50=50.15 (42.86)
|
| 234 |
+
2025-03-03 02:20:34 | INFO | utils.misc:108 - Training: Epoch=[3/50] [1100/1759] Batch=1.86 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.5373 (0.6842) IoU=44.84 (43.78) Prec@50=50.40 (42.74)
|
| 235 |
+
2025-03-03 02:23:04 | INFO | utils.misc:108 - Training: Epoch=[3/50] [1200/1759] Batch=1.41 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.6934 (0.6824) IoU=43.73 (43.80) Prec@50=42.86 (42.84)
|
| 236 |
+
2025-03-03 02:25:31 | INFO | utils.misc:108 - Training: Epoch=[3/50] [1300/1759] Batch=1.38 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.8082 (0.6824) IoU=44.25 (43.82) Prec@50=43.25 (42.90)
|
| 237 |
+
2025-03-03 02:27:58 | INFO | utils.misc:108 - Training: Epoch=[3/50] [1400/1759] Batch=1.77 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.7234 (0.6808) IoU=37.74 (43.89) Prec@50=31.61 (42.97)
|
| 238 |
+
2025-03-03 02:30:25 | INFO | utils.misc:108 - Training: Epoch=[3/50] [1500/1759] Batch=1.36 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.6207 (0.6785) IoU=52.28 (44.04) Prec@50=49.60 (43.19)
|
| 239 |
+
2025-03-03 02:32:52 | INFO | utils.misc:108 - Training: Epoch=[3/50] [1600/1759] Batch=1.38 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.5630 (0.6783) IoU=50.30 (44.12) Prec@50=51.19 (43.27)
|
| 240 |
+
2025-03-03 02:35:20 | INFO | utils.misc:108 - Training: Epoch=[3/50] [1700/1759] Batch=1.35 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.6715 (0.6784) IoU=39.24 (44.26) Prec@50=39.93 (43.41)
|
| 241 |
+
2025-03-03 02:37:23 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[3/50] mIoU=52.86 oIoU=49.37 Pr@50: 54.19 Pr@60: 45.54 Pr@70: 36.26 Pr@80: 25.82 Pr@90: 10.52
|
| 242 |
+
2025-03-03 02:40:20 | INFO | utils.misc:108 - Training: Epoch=[4/50] [ 100/1759] Batch=1.27 (1.50) Data=0.00 (0.04) Lr=0.000100 Loss=0.6958 (0.6286) IoU=44.29 (46.25) Prec@50=38.39 (45.99)
|
| 243 |
+
2025-03-03 02:42:47 | INFO | utils.misc:108 - Training: Epoch=[4/50] [ 200/1759] Batch=1.30 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.6073 (0.6325) IoU=46.17 (46.22) Prec@50=52.38 (45.96)
|
| 244 |
+
2025-03-03 02:45:16 | INFO | utils.misc:108 - Training: Epoch=[4/50] [ 300/1759] Batch=1.32 (1.49) Data=0.00 (0.03) Lr=0.000100 Loss=0.5563 (0.6330) IoU=52.67 (46.34) Prec@50=59.87 (45.94)
|
| 245 |
+
2025-03-03 02:47:42 | INFO | utils.misc:108 - Training: Epoch=[4/50] [ 400/1759] Batch=1.30 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.6527 (0.6364) IoU=52.90 (46.53) Prec@50=59.38 (46.17)
|
| 246 |
+
2025-03-03 02:50:08 | INFO | utils.misc:108 - Training: Epoch=[4/50] [ 500/1759] Batch=1.28 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.7021 (0.6343) IoU=49.21 (46.89) Prec@50=42.71 (46.54)
|
| 247 |
+
2025-03-03 02:52:34 | INFO | utils.misc:108 - Training: Epoch=[4/50] [ 600/1759] Batch=1.24 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.7423 (0.6356) IoU=39.62 (46.82) Prec@50=36.61 (46.63)
|
| 248 |
+
2025-03-03 02:55:01 | INFO | utils.misc:108 - Training: Epoch=[4/50] [ 700/1759] Batch=1.34 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.7147 (0.6425) IoU=42.98 (46.63) Prec@50=37.15 (46.32)
|
| 249 |
+
2025-03-03 02:57:30 | INFO | utils.misc:108 - Training: Epoch=[4/50] [ 800/1759] Batch=1.77 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.5617 (0.6499) IoU=42.31 (46.31) Prec@50=36.67 (45.86)
|
| 250 |
+
2025-03-03 02:59:58 | INFO | utils.misc:108 - Training: Epoch=[4/50] [ 900/1759] Batch=1.70 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.6406 (0.6511) IoU=46.95 (46.46) Prec@50=45.00 (46.10)
|
| 251 |
+
2025-03-03 03:02:23 | INFO | utils.misc:108 - Training: Epoch=[4/50] [1000/1759] Batch=1.28 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.5956 (0.6502) IoU=47.62 (46.59) Prec@50=43.75 (46.30)
|
| 252 |
+
2025-03-03 03:04:52 | INFO | utils.misc:108 - Training: Epoch=[4/50] [1100/1759] Batch=1.31 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.6672 (0.6503) IoU=51.07 (46.69) Prec@50=51.29 (46.50)
|
| 253 |
+
2025-03-03 03:07:16 | INFO | utils.misc:108 - Training: Epoch=[4/50] [1200/1759] Batch=1.22 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.6364 (0.6505) IoU=51.87 (46.69) Prec@50=51.93 (46.48)
|
| 254 |
+
2025-03-03 03:09:43 | INFO | utils.misc:108 - Training: Epoch=[4/50] [1300/1759] Batch=1.37 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.4925 (0.6489) IoU=57.07 (46.74) Prec@50=55.01 (46.50)
|
| 255 |
+
2025-03-03 03:12:10 | INFO | utils.misc:108 - Training: Epoch=[4/50] [1400/1759] Batch=1.79 (1.47) Data=0.01 (0.03) Lr=0.000100 Loss=0.6326 (0.6491) IoU=47.66 (46.81) Prec@50=46.04 (46.58)
|
| 256 |
+
2025-03-03 03:14:35 | INFO | utils.misc:108 - Training: Epoch=[4/50] [1500/1759] Batch=1.36 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.6763 (0.6487) IoU=52.54 (46.97) Prec@50=58.58 (46.77)
|
| 257 |
+
2025-03-03 03:17:03 | INFO | utils.misc:108 - Training: Epoch=[4/50] [1600/1759] Batch=1.41 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.6583 (0.6469) IoU=48.20 (47.13) Prec@50=50.83 (46.97)
|
| 258 |
+
2025-03-03 03:19:31 | INFO | utils.misc:108 - Training: Epoch=[4/50] [1700/1759] Batch=1.27 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.6820 (0.6478) IoU=43.32 (47.12) Prec@50=48.66 (46.90)
|
| 259 |
+
2025-03-03 03:21:35 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[4/50] mIoU=54.75 oIoU=50.99 Pr@50: 57.73 Pr@60: 48.80 Pr@70: 40.37 Pr@80: 29.04 Pr@90: 12.15
|
| 260 |
+
2025-03-03 03:24:28 | INFO | utils.misc:108 - Training: Epoch=[5/50] [ 100/1759] Batch=1.28 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.5842 (0.5822) IoU=52.98 (50.21) Prec@50=50.89 (51.29)
|
| 261 |
+
2025-03-03 03:26:54 | INFO | utils.misc:108 - Training: Epoch=[5/50] [ 200/1759] Batch=1.42 (1.46) Data=0.01 (0.03) Lr=0.000100 Loss=0.5671 (0.5931) IoU=52.04 (49.29) Prec@50=67.11 (50.16)
|
| 262 |
+
2025-03-03 03:29:22 | INFO | utils.misc:108 - Training: Epoch=[5/50] [ 300/1759] Batch=1.35 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.6269 (0.5962) IoU=49.78 (48.97) Prec@50=53.17 (49.76)
|
| 263 |
+
2025-03-03 03:31:49 | INFO | utils.misc:108 - Training: Epoch=[5/50] [ 400/1759] Batch=1.48 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.5380 (0.5930) IoU=53.36 (49.08) Prec@50=60.08 (50.03)
|
| 264 |
+
2025-03-03 03:34:16 | INFO | utils.misc:108 - Training: Epoch=[5/50] [ 500/1759] Batch=1.79 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.6159 (0.5929) IoU=46.81 (49.18) Prec@50=45.49 (50.33)
|
| 265 |
+
2025-03-03 03:36:45 | INFO | utils.misc:108 - Training: Epoch=[5/50] [ 600/1759] Batch=1.43 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.4669 (0.5917) IoU=56.32 (49.45) Prec@50=55.22 (50.71)
|
| 266 |
+
2025-03-03 03:39:10 | INFO | utils.misc:108 - Training: Epoch=[5/50] [ 700/1759] Batch=1.57 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.6299 (0.5933) IoU=44.67 (49.38) Prec@50=45.40 (50.54)
|
| 267 |
+
2025-03-03 03:41:37 | INFO | utils.misc:108 - Training: Epoch=[5/50] [ 800/1759] Batch=1.43 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.7958 (0.5949) IoU=42.46 (49.27) Prec@50=42.29 (50.46)
|
| 268 |
+
2025-03-03 03:44:03 | INFO | utils.misc:108 - Training: Epoch=[5/50] [ 900/1759] Batch=1.39 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.7949 (0.5958) IoU=40.19 (49.17) Prec@50=37.70 (50.46)
|
| 269 |
+
2025-03-03 03:46:30 | INFO | utils.misc:108 - Training: Epoch=[5/50] [1000/1759] Batch=1.75 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.4842 (0.5962) IoU=55.96 (49.27) Prec@50=60.33 (50.63)
|
| 270 |
+
2025-03-03 03:48:56 | INFO | utils.misc:108 - Training: Epoch=[5/50] [1100/1759] Batch=1.34 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.5489 (0.5967) IoU=55.96 (49.18) Prec@50=60.76 (50.47)
|
| 271 |
+
2025-03-03 03:51:24 | INFO | utils.misc:108 - Training: Epoch=[5/50] [1200/1759] Batch=1.77 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.5284 (0.5968) IoU=48.37 (49.16) Prec@50=56.47 (50.49)
|
| 272 |
+
2025-03-03 03:53:48 | INFO | utils.misc:108 - Training: Epoch=[5/50] [1300/1759] Batch=1.36 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.5359 (0.5963) IoU=50.98 (49.14) Prec@50=49.26 (50.53)
|
| 273 |
+
2025-03-03 03:56:18 | INFO | utils.misc:108 - Training: Epoch=[5/50] [1400/1759] Batch=1.38 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.8166 (0.5978) IoU=42.17 (48.99) Prec@50=36.31 (50.28)
|
| 274 |
+
2025-03-03 03:58:44 | INFO | utils.misc:108 - Training: Epoch=[5/50] [1500/1759] Batch=1.36 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.6011 (0.5979) IoU=48.88 (49.04) Prec@50=53.27 (50.35)
|
| 275 |
+
2025-03-03 04:01:11 | INFO | utils.misc:108 - Training: Epoch=[5/50] [1600/1759] Batch=1.35 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.7017 (0.5988) IoU=47.34 (48.99) Prec@50=45.04 (50.25)
|
| 276 |
+
2025-03-03 04:03:39 | INFO | utils.misc:108 - Training: Epoch=[5/50] [1700/1759] Batch=1.66 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.5102 (0.5985) IoU=51.01 (49.05) Prec@50=53.77 (50.32)
|
| 277 |
+
2025-03-03 04:05:41 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[5/50] mIoU=56.95 oIoU=53.30 Pr@50: 60.75 Pr@60: 53.26 Pr@70: 44.57 Pr@80: 32.76 Pr@90: 14.32
|
| 278 |
+
2025-03-03 04:08:38 | INFO | utils.misc:108 - Training: Epoch=[6/50] [ 100/1759] Batch=1.27 (1.49) Data=0.00 (0.04) Lr=0.000100 Loss=0.6468 (0.5652) IoU=49.76 (51.26) Prec@50=42.71 (53.73)
|
| 279 |
+
2025-03-03 04:11:06 | INFO | utils.misc:108 - Training: Epoch=[6/50] [ 200/1759] Batch=1.41 (1.49) Data=0.00 (0.03) Lr=0.000100 Loss=0.6561 (0.5522) IoU=44.92 (52.29) Prec@50=38.89 (54.84)
|
| 280 |
+
2025-03-03 04:13:33 | INFO | utils.misc:108 - Training: Epoch=[6/50] [ 300/1759] Batch=1.39 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.4852 (0.5548) IoU=52.33 (51.63) Prec@50=54.51 (54.21)
|
| 281 |
+
2025-03-03 04:15:59 | INFO | utils.misc:108 - Training: Epoch=[6/50] [ 400/1759] Batch=1.37 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.5306 (0.5601) IoU=55.68 (51.13) Prec@50=62.65 (53.69)
|
| 282 |
+
2025-03-03 04:18:27 | INFO | utils.misc:108 - Training: Epoch=[6/50] [ 500/1759] Batch=1.39 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.4842 (0.5597) IoU=56.85 (51.28) Prec@50=55.85 (53.78)
|
| 283 |
+
2025-03-03 04:20:53 | INFO | utils.misc:108 - Training: Epoch=[6/50] [ 600/1759] Batch=1.39 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.5329 (0.5586) IoU=47.32 (51.24) Prec@50=50.00 (53.65)
|
| 284 |
+
2025-03-03 04:23:18 | INFO | utils.misc:108 - Training: Epoch=[6/50] [ 700/1759] Batch=1.14 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.5324 (0.5594) IoU=56.30 (51.25) Prec@50=59.52 (53.65)
|
| 285 |
+
2025-03-03 04:25:47 | INFO | utils.misc:108 - Training: Epoch=[6/50] [ 800/1759] Batch=1.23 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.5867 (0.5610) IoU=50.78 (51.24) Prec@50=52.38 (53.66)
|
| 286 |
+
2025-03-03 04:28:14 | INFO | utils.misc:108 - Training: Epoch=[6/50] [ 900/1759] Batch=1.49 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.5270 (0.5646) IoU=51.31 (51.14) Prec@50=48.21 (53.48)
|
| 287 |
+
2025-03-03 04:30:41 | INFO | utils.misc:108 - Training: Epoch=[6/50] [1000/1759] Batch=1.35 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.5509 (0.5655) IoU=52.33 (51.12) Prec@50=51.79 (53.51)
|
| 288 |
+
2025-03-03 04:33:09 | INFO | utils.misc:108 - Training: Epoch=[6/50] [1100/1759] Batch=1.25 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.5011 (0.5665) IoU=59.98 (51.03) Prec@50=73.66 (53.37)
|
| 289 |
+
2025-03-03 04:35:34 | INFO | utils.misc:108 - Training: Epoch=[6/50] [1200/1759] Batch=1.13 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.7173 (0.5666) IoU=49.50 (51.04) Prec@50=52.98 (53.44)
|
| 290 |
+
2025-03-03 04:38:02 | INFO | utils.misc:108 - Training: Epoch=[6/50] [1300/1759] Batch=1.78 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.6343 (0.5665) IoU=46.32 (50.93) Prec@50=45.29 (53.32)
|
| 291 |
+
2025-03-03 04:40:29 | INFO | utils.misc:108 - Training: Epoch=[6/50] [1400/1759] Batch=1.27 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.5961 (0.5674) IoU=49.77 (50.91) Prec@50=50.00 (53.32)
|
| 292 |
+
2025-03-03 04:42:57 | INFO | utils.misc:108 - Training: Epoch=[6/50] [1500/1759] Batch=1.43 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.4715 (0.5672) IoU=52.89 (50.96) Prec@50=54.68 (53.44)
|
| 293 |
+
2025-03-03 04:45:21 | INFO | utils.misc:108 - Training: Epoch=[6/50] [1600/1759] Batch=1.78 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.5379 (0.5673) IoU=52.68 (50.94) Prec@50=58.28 (53.43)
|
| 294 |
+
2025-03-03 04:47:46 | INFO | utils.misc:108 - Training: Epoch=[6/50] [1700/1759] Batch=1.70 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.6034 (0.5678) IoU=48.31 (50.90) Prec@50=47.55 (53.37)
|
| 295 |
+
2025-03-03 04:49:50 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[6/50] mIoU=57.77 oIoU=53.88 Pr@50: 62.42 Pr@60: 54.46 Pr@70: 45.96 Pr@80: 34.74 Pr@90: 15.64
|
| 296 |
+
2025-03-03 04:52:43 | INFO | utils.misc:108 - Training: Epoch=[7/50] [ 100/1759] Batch=1.26 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.4772 (0.5159) IoU=58.15 (52.95) Prec@50=68.60 (57.46)
|
| 297 |
+
2025-03-03 04:55:10 | INFO | utils.misc:108 - Training: Epoch=[7/50] [ 200/1759] Batch=1.26 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.5554 (0.5202) IoU=53.89 (52.93) Prec@50=58.48 (57.39)
|
| 298 |
+
2025-03-03 04:57:36 | INFO | utils.misc:108 - Training: Epoch=[7/50] [ 300/1759] Batch=1.34 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.4606 (0.5208) IoU=58.93 (53.09) Prec@50=64.43 (57.32)
|
| 299 |
+
2025-03-03 05:00:03 | INFO | utils.misc:108 - Training: Epoch=[7/50] [ 400/1759] Batch=1.41 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.4532 (0.5196) IoU=59.22 (53.45) Prec@50=74.80 (57.58)
|
| 300 |
+
2025-03-03 05:02:29 | INFO | utils.misc:108 - Training: Epoch=[7/50] [ 500/1759] Batch=1.38 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.5778 (0.5210) IoU=59.85 (53.71) Prec@50=80.95 (57.77)
|
| 301 |
+
2025-03-03 05:04:55 | INFO | utils.misc:108 - Training: Epoch=[7/50] [ 600/1759] Batch=1.72 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.5753 (0.5226) IoU=56.04 (53.67) Prec@50=55.98 (57.49)
|
| 302 |
+
2025-03-03 05:07:23 | INFO | utils.misc:108 - Training: Epoch=[7/50] [ 700/1759] Batch=1.36 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.6087 (0.5253) IoU=51.54 (53.41) Prec@50=50.00 (57.11)
|
| 303 |
+
2025-03-03 05:09:49 | INFO | utils.misc:108 - Training: Epoch=[7/50] [ 800/1759] Batch=1.36 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.6311 (0.5259) IoU=55.29 (53.46) Prec@50=55.51 (57.13)
|
| 304 |
+
2025-03-03 05:12:16 | INFO | utils.misc:108 - Training: Epoch=[7/50] [ 900/1759] Batch=1.26 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.7219 (0.5281) IoU=49.02 (53.55) Prec@50=53.87 (57.31)
|
| 305 |
+
2025-03-03 05:14:42 | INFO | utils.misc:108 - Training: Epoch=[7/50] [1000/1759] Batch=1.43 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.4692 (0.5275) IoU=65.16 (53.58) Prec@50=78.17 (57.38)
|
| 306 |
+
2025-03-03 05:17:07 | INFO | utils.misc:108 - Training: Epoch=[7/50] [1100/1759] Batch=1.26 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.5824 (0.5272) IoU=45.79 (53.59) Prec@50=48.66 (57.39)
|
| 307 |
+
2025-03-03 05:19:35 | INFO | utils.misc:108 - Training: Epoch=[7/50] [1200/1759] Batch=1.81 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.4784 (0.5274) IoU=53.63 (53.55) Prec@50=54.79 (57.31)
|
| 308 |
+
2025-03-03 05:22:04 | INFO | utils.misc:108 - Training: Epoch=[7/50] [1300/1759] Batch=1.72 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.5570 (0.5278) IoU=51.42 (53.54) Prec@50=50.05 (57.32)
|
| 309 |
+
2025-03-03 05:24:31 | INFO | utils.misc:108 - Training: Epoch=[7/50] [1400/1759] Batch=1.28 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.4228 (0.5272) IoU=57.91 (53.52) Prec@50=65.62 (57.28)
|
| 310 |
+
2025-03-03 05:27:00 | INFO | utils.misc:108 - Training: Epoch=[7/50] [1500/1759] Batch=1.83 (1.47) Data=0.01 (0.02) Lr=0.000100 Loss=0.4513 (0.5273) IoU=58.91 (53.56) Prec@50=49.84 (57.28)
|
| 311 |
+
2025-03-03 05:29:27 | INFO | utils.misc:108 - Training: Epoch=[7/50] [1600/1759] Batch=1.39 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.4685 (0.5275) IoU=58.79 (53.56) Prec@50=64.34 (57.24)
|
| 312 |
+
2025-03-03 05:31:54 | INFO | utils.misc:108 - Training: Epoch=[7/50] [1700/1759] Batch=1.62 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.4928 (0.5272) IoU=47.91 (53.57) Prec@50=45.88 (57.28)
|
| 313 |
+
2025-03-03 05:33:56 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[7/50] mIoU=60.25 oIoU=55.60 Pr@50: 65.68 Pr@60: 59.39 Pr@70: 51.09 Pr@80: 38.20 Pr@90: 17.59
|
| 314 |
+
2025-03-03 05:36:51 | INFO | utils.misc:108 - Training: Epoch=[8/50] [ 100/1759] Batch=1.44 (1.48) Data=0.00 (0.04) Lr=0.000100 Loss=0.4619 (0.4767) IoU=50.05 (57.19) Prec@50=56.96 (61.73)
|
| 315 |
+
2025-03-03 05:39:19 | INFO | utils.misc:108 - Training: Epoch=[8/50] [ 200/1759] Batch=1.33 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.4224 (0.4827) IoU=60.05 (56.71) Prec@50=67.26 (61.55)
|
| 316 |
+
2025-03-03 05:41:44 | INFO | utils.misc:108 - Training: Epoch=[8/50] [ 300/1759] Batch=1.34 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.4845 (0.4818) IoU=55.20 (56.42) Prec@50=65.23 (61.33)
|
| 317 |
+
2025-03-03 05:44:13 | INFO | utils.misc:108 - Training: Epoch=[8/50] [ 400/1759] Batch=1.95 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.5197 (0.4854) IoU=51.30 (56.06) Prec@50=53.19 (60.90)
|
| 318 |
+
2025-03-03 05:46:43 | INFO | utils.misc:108 - Training: Epoch=[8/50] [ 500/1759] Batch=1.39 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.5332 (0.4882) IoU=48.13 (55.89) Prec@50=51.34 (60.70)
|
| 319 |
+
2025-03-03 05:49:08 | INFO | utils.misc:108 - Training: Epoch=[8/50] [ 600/1759] Batch=1.43 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.3839 (0.4883) IoU=52.71 (55.88) Prec@50=55.28 (60.68)
|
| 320 |
+
2025-03-03 05:51:36 | INFO | utils.misc:108 - Training: Epoch=[8/50] [ 700/1759] Batch=1.25 (1.48) Data=0.00 (0.02) Lr=0.000100 Loss=0.4711 (0.4896) IoU=53.68 (55.75) Prec@50=63.10 (60.51)
|
| 321 |
+
2025-03-03 05:54:02 | INFO | utils.misc:108 - Training: Epoch=[8/50] [ 800/1759] Batch=1.84 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.3726 (0.4901) IoU=57.30 (55.78) Prec@50=64.32 (60.39)
|
| 322 |
+
2025-03-03 05:56:27 | INFO | utils.misc:108 - Training: Epoch=[8/50] [ 900/1759] Batch=1.25 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.4996 (0.4914) IoU=56.04 (55.96) Prec@50=54.91 (60.61)
|
| 323 |
+
2025-03-03 05:58:54 | INFO | utils.misc:108 - Training: Epoch=[8/50] [1000/1759] Batch=1.41 (1.47) Data=0.01 (0.02) Lr=0.000100 Loss=0.5232 (0.4923) IoU=60.26 (55.92) Prec@50=69.84 (60.50)
|
| 324 |
+
2025-03-03 06:01:18 | INFO | utils.misc:108 - Training: Epoch=[8/50] [1100/1759] Batch=1.80 (1.47) Data=0.01 (0.02) Lr=0.000100 Loss=0.5720 (0.4947) IoU=50.41 (55.79) Prec@50=54.66 (60.34)
|
| 325 |
+
2025-03-03 06:03:42 | INFO | utils.misc:108 - Training: Epoch=[8/50] [1200/1759] Batch=1.44 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.4019 (0.4951) IoU=61.91 (55.80) Prec@50=69.10 (60.33)
|
| 326 |
+
2025-03-03 06:06:09 | INFO | utils.misc:108 - Training: Epoch=[8/50] [1300/1759] Batch=1.35 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.5695 (0.4957) IoU=54.77 (55.78) Prec@50=55.75 (60.21)
|
| 327 |
+
2025-03-03 06:08:38 | INFO | utils.misc:108 - Training: Epoch=[8/50] [1400/1759] Batch=1.84 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.6492 (0.4982) IoU=42.49 (55.71) Prec@50=42.14 (60.10)
|
| 328 |
+
2025-03-03 06:11:03 | INFO | utils.misc:108 - Training: Epoch=[8/50] [1500/1759] Batch=1.36 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.5565 (0.4978) IoU=49.92 (55.73) Prec@50=46.83 (60.09)
|
| 329 |
+
2025-03-03 06:13:31 | INFO | utils.misc:108 - Training: Epoch=[8/50] [1600/1759] Batch=1.30 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.4775 (0.4973) IoU=63.03 (55.75) Prec@50=68.60 (60.11)
|
| 330 |
+
2025-03-03 06:16:00 | INFO | utils.misc:108 - Training: Epoch=[8/50] [1700/1759] Batch=1.34 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.4690 (0.4972) IoU=55.66 (55.79) Prec@50=46.83 (60.15)
|
| 331 |
+
2025-03-03 06:18:02 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[8/50] mIoU=60.71 oIoU=56.70 Pr@50: 66.73 Pr@60: 59.36 Pr@70: 50.62 Pr@80: 38.63 Pr@90: 16.96
|
| 332 |
+
2025-03-03 06:20:56 | INFO | utils.misc:108 - Training: Epoch=[9/50] [ 100/1759] Batch=1.45 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.3857 (0.4604) IoU=59.86 (58.29) Prec@50=64.48 (63.98)
|
| 333 |
+
2025-03-03 06:23:21 | INFO | utils.misc:108 - Training: Epoch=[9/50] [ 200/1759] Batch=1.28 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.4570 (0.4735) IoU=63.38 (57.40) Prec@50=73.66 (62.69)
|
| 334 |
+
2025-03-03 06:25:48 | INFO | utils.misc:108 - Training: Epoch=[9/50] [ 300/1759] Batch=1.43 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.5468 (0.4705) IoU=55.90 (57.64) Prec@50=59.08 (62.91)
|
| 335 |
+
2025-03-03 06:28:13 | INFO | utils.misc:108 - Training: Epoch=[9/50] [ 400/1759] Batch=1.28 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.4366 (0.4683) IoU=63.43 (57.61) Prec@50=68.75 (62.74)
|
| 336 |
+
2025-03-03 06:30:38 | INFO | utils.misc:108 - Training: Epoch=[9/50] [ 500/1759] Batch=1.31 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.6139 (0.4667) IoU=54.24 (57.58) Prec@50=55.75 (62.60)
|
| 337 |
+
2025-03-03 06:33:03 | INFO | utils.misc:108 - Training: Epoch=[9/50] [ 600/1759] Batch=1.22 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.5159 (0.4657) IoU=58.25 (57.54) Prec@50=59.82 (62.54)
|
| 338 |
+
2025-03-03 06:35:27 | INFO | utils.misc:108 - Training: Epoch=[9/50] [ 700/1759] Batch=1.20 (1.45) Data=0.00 (0.02) Lr=0.000100 Loss=0.4972 (0.4664) IoU=57.24 (57.36) Prec@50=67.56 (62.27)
|
| 339 |
+
2025-03-03 06:37:55 | INFO | utils.misc:108 - Training: Epoch=[9/50] [ 800/1759] Batch=1.37 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.5249 (0.4666) IoU=58.34 (57.30) Prec@50=60.76 (62.13)
|
| 340 |
+
2025-03-03 06:40:21 | INFO | utils.misc:108 - Training: Epoch=[9/50] [ 900/1759] Batch=1.27 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.5121 (0.4688) IoU=55.35 (57.25) Prec@50=52.23 (62.01)
|
| 341 |
+
2025-03-03 06:42:47 | INFO | utils.misc:108 - Training: Epoch=[9/50] [1000/1759] Batch=1.45 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.5049 (0.4697) IoU=53.27 (57.15) Prec@50=58.41 (61.89)
|
| 342 |
+
2025-03-03 06:45:12 | INFO | utils.misc:108 - Training: Epoch=[9/50] [1100/1759] Batch=1.25 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.5198 (0.4704) IoU=57.83 (57.14) Prec@50=58.48 (61.89)
|
| 343 |
+
2025-03-03 06:47:41 | INFO | utils.misc:108 - Training: Epoch=[9/50] [1200/1759] Batch=1.39 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.4482 (0.4701) IoU=56.85 (57.09) Prec@50=62.10 (61.84)
|
| 344 |
+
2025-03-03 06:50:09 | INFO | utils.misc:108 - Training: Epoch=[9/50] [1300/1759] Batch=1.39 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.6194 (0.4711) IoU=45.18 (57.02) Prec@50=49.31 (61.76)
|
| 345 |
+
2025-03-03 06:52:37 | INFO | utils.misc:108 - Training: Epoch=[9/50] [1400/1759] Batch=1.26 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.5719 (0.4718) IoU=53.01 (56.94) Prec@50=56.99 (61.68)
|
| 346 |
+
2025-03-03 06:55:05 | INFO | utils.misc:108 - Training: Epoch=[9/50] [1500/1759] Batch=1.31 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.3935 (0.4720) IoU=62.80 (56.86) Prec@50=68.85 (61.58)
|
| 347 |
+
2025-03-03 06:57:30 | INFO | utils.misc:108 - Training: Epoch=[9/50] [1600/1759] Batch=1.32 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.5472 (0.4724) IoU=60.17 (56.79) Prec@50=65.97 (61.48)
|
| 348 |
+
2025-03-03 06:59:58 | INFO | utils.misc:108 - Training: Epoch=[9/50] [1700/1759] Batch=1.38 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.3759 (0.4721) IoU=60.89 (56.73) Prec@50=63.24 (61.42)
|
| 349 |
+
2025-03-03 07:01:58 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[9/50] mIoU=61.79 oIoU=57.95 Pr@50: 68.91 Pr@60: 61.84 Pr@70: 54.15 Pr@80: 40.99 Pr@90: 18.32
|
| 350 |
+
2025-03-03 07:04:52 | INFO | utils.misc:108 - Training: Epoch=[10/50] [ 100/1759] Batch=1.42 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.4251 (0.4314) IoU=59.89 (58.14) Prec@50=68.54 (63.57)
|
| 351 |
+
2025-03-03 07:07:22 | INFO | utils.misc:108 - Training: Epoch=[10/50] [ 200/1759] Batch=1.69 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.4248 (0.4374) IoU=62.13 (58.02) Prec@50=67.50 (63.52)
|
| 352 |
+
2025-03-03 07:09:47 | INFO | utils.misc:108 - Training: Epoch=[10/50] [ 300/1759] Batch=1.39 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.3592 (0.4369) IoU=56.82 (58.28) Prec@50=68.06 (63.59)
|
| 353 |
+
2025-03-03 07:12:15 | INFO | utils.misc:108 - Training: Epoch=[10/50] [ 400/1759] Batch=1.35 (1.47) Data=0.01 (0.03) Lr=0.000100 Loss=0.4061 (0.4374) IoU=64.88 (58.70) Prec@50=70.29 (63.98)
|
| 354 |
+
2025-03-03 07:14:41 | INFO | utils.misc:108 - Training: Epoch=[10/50] [ 500/1759] Batch=1.37 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.4022 (0.4395) IoU=60.43 (58.72) Prec@50=59.42 (63.88)
|
| 355 |
+
2025-03-03 07:17:09 | INFO | utils.misc:108 - Training: Epoch=[10/50] [ 600/1759] Batch=1.32 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.6635 (0.4405) IoU=42.66 (58.58) Prec@50=49.95 (63.78)
|
| 356 |
+
2025-03-03 07:19:37 | INFO | utils.misc:108 - Training: Epoch=[10/50] [ 700/1759] Batch=1.27 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.4055 (0.4416) IoU=63.60 (58.54) Prec@50=67.86 (63.75)
|
| 357 |
+
2025-03-03 07:22:05 | INFO | utils.misc:108 - Training: Epoch=[10/50] [ 800/1759] Batch=1.22 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.3670 (0.4434) IoU=68.35 (58.42) Prec@50=73.66 (63.65)
|
| 358 |
+
2025-03-03 07:24:32 | INFO | utils.misc:108 - Training: Epoch=[10/50] [ 900/1759] Batch=1.27 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.4577 (0.4437) IoU=59.77 (58.38) Prec@50=72.02 (63.55)
|
| 359 |
+
2025-03-03 07:26:55 | INFO | utils.misc:108 - Training: Epoch=[10/50] [1000/1759] Batch=1.38 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.3914 (0.4433) IoU=56.84 (58.38) Prec@50=71.92 (63.59)
|
| 360 |
+
2025-03-03 07:29:24 | INFO | utils.misc:108 - Training: Epoch=[10/50] [1100/1759] Batch=1.40 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.4301 (0.4452) IoU=64.13 (58.34) Prec@50=64.24 (63.54)
|
| 361 |
+
2025-03-03 07:31:49 | INFO | utils.misc:108 - Training: Epoch=[10/50] [1200/1759] Batch=1.78 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.5309 (0.4448) IoU=51.24 (58.37) Prec@50=55.62 (63.55)
|
| 362 |
+
2025-03-03 07:34:15 | INFO | utils.misc:108 - Training: Epoch=[10/50] [1300/1759] Batch=1.31 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.5097 (0.4461) IoU=53.53 (58.35) Prec@50=59.08 (63.49)
|
| 363 |
+
2025-03-03 07:36:42 | INFO | utils.misc:108 - Training: Epoch=[10/50] [1400/1759] Batch=1.40 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.4394 (0.4461) IoU=58.68 (58.39) Prec@50=59.52 (63.55)
|
| 364 |
+
2025-03-03 07:39:05 | INFO | utils.misc:108 - Training: Epoch=[10/50] [1500/1759] Batch=1.39 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.4440 (0.4464) IoU=59.17 (58.38) Prec@50=67.06 (63.51)
|
| 365 |
+
2025-03-03 07:41:31 | INFO | utils.misc:108 - Training: Epoch=[10/50] [1600/1759] Batch=1.34 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.4039 (0.4464) IoU=56.05 (58.33) Prec@50=60.76 (63.41)
|
| 366 |
+
2025-03-03 07:43:56 | INFO | utils.misc:108 - Training: Epoch=[10/50] [1700/1759] Batch=1.61 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.4076 (0.4472) IoU=57.63 (58.25) Prec@50=65.15 (63.35)
|
| 367 |
+
2025-03-03 07:46:00 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[10/50] mIoU=61.95 oIoU=58.15 Pr@50: 69.22 Pr@60: 62.58 Pr@70: 54.46 Pr@80: 42.31 Pr@90: 19.91
|
| 368 |
+
2025-03-03 07:48:54 | INFO | utils.misc:108 - Training: Epoch=[11/50] [ 100/1759] Batch=1.31 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.4570 (0.3980) IoU=60.56 (59.63) Prec@50=61.71 (65.29)
|
| 369 |
+
2025-03-03 07:51:22 | INFO | utils.misc:108 - Training: Epoch=[11/50] [ 200/1759] Batch=1.44 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.4088 (0.4088) IoU=55.49 (58.90) Prec@50=56.47 (64.34)
|
| 370 |
+
2025-03-03 07:53:49 | INFO | utils.misc:108 - Training: Epoch=[11/50] [ 300/1759] Batch=1.34 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.5129 (0.4076) IoU=58.42 (59.49) Prec@50=58.73 (65.27)
|
| 371 |
+
2025-03-03 07:56:14 | INFO | utils.misc:108 - Training: Epoch=[11/50] [ 400/1759] Batch=1.36 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.5632 (0.4112) IoU=48.45 (59.59) Prec@50=45.14 (65.45)
|
| 372 |
+
2025-03-03 07:58:39 | INFO | utils.misc:108 - Training: Epoch=[11/50] [ 500/1759] Batch=1.45 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.3631 (0.4145) IoU=60.37 (59.62) Prec@50=61.03 (65.47)
|
| 373 |
+
2025-03-03 08:01:04 | INFO | utils.misc:108 - Training: Epoch=[11/50] [ 600/1759] Batch=1.42 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.2917 (0.4155) IoU=71.40 (59.60) Prec@50=78.17 (65.43)
|
| 374 |
+
2025-03-03 08:03:30 | INFO | utils.misc:108 - Training: Epoch=[11/50] [ 700/1759] Batch=1.44 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.4789 (0.4174) IoU=58.62 (59.62) Prec@50=61.34 (65.38)
|
| 375 |
+
2025-03-03 08:05:57 | INFO | utils.misc:108 - Training: Epoch=[11/50] [ 800/1759] Batch=1.33 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.3516 (0.4210) IoU=60.97 (59.48) Prec@50=69.59 (65.11)
|
| 376 |
+
2025-03-03 08:08:21 | INFO | utils.misc:108 - Training: Epoch=[11/50] [ 900/1759] Batch=1.95 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.3417 (0.4205) IoU=59.15 (59.43) Prec@50=67.80 (65.06)
|
| 377 |
+
2025-03-03 08:10:47 | INFO | utils.misc:108 - Training: Epoch=[11/50] [1000/1759] Batch=1.21 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.3900 (0.4195) IoU=64.07 (59.51) Prec@50=68.75 (65.12)
|
| 378 |
+
2025-03-03 08:13:12 | INFO | utils.misc:108 - Training: Epoch=[11/50] [1100/1759] Batch=1.51 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.3569 (0.4213) IoU=69.53 (59.50) Prec@50=74.52 (65.06)
|
| 379 |
+
2025-03-03 08:15:38 | INFO | utils.misc:108 - Training: Epoch=[11/50] [1200/1759] Batch=1.22 (1.46) Data=0.01 (0.02) Lr=0.000100 Loss=0.5098 (0.4224) IoU=58.16 (59.38) Prec@50=62.05 (64.97)
|
| 380 |
+
2025-03-03 08:18:03 | INFO | utils.misc:108 - Training: Epoch=[11/50] [1300/1759] Batch=1.35 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.6300 (0.4239) IoU=47.06 (59.24) Prec@50=51.79 (64.82)
|
| 381 |
+
2025-03-03 08:20:26 | INFO | utils.misc:108 - Training: Epoch=[11/50] [1400/1759] Batch=1.32 (1.46) Data=0.01 (0.02) Lr=0.000100 Loss=0.3446 (0.4252) IoU=64.94 (59.19) Prec@50=75.60 (64.73)
|
| 382 |
+
2025-03-03 08:22:51 | INFO | utils.misc:108 - Training: Epoch=[11/50] [1500/1759] Batch=1.28 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.4085 (0.4247) IoU=60.33 (59.26) Prec@50=66.96 (64.89)
|
| 383 |
+
2025-03-03 08:25:17 | INFO | utils.misc:108 - Training: Epoch=[11/50] [1600/1759] Batch=1.81 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.2942 (0.4246) IoU=60.72 (59.32) Prec@50=59.59 (64.92)
|
| 384 |
+
2025-03-03 08:27:44 | INFO | utils.misc:108 - Training: Epoch=[11/50] [1700/1759] Batch=1.20 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.3268 (0.4244) IoU=71.48 (59.39) Prec@50=83.04 (64.99)
|
| 385 |
+
2025-03-03 08:29:47 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[11/50] mIoU=62.68 oIoU=58.89 Pr@50: 69.95 Pr@60: 62.97 Pr@70: 54.85 Pr@80: 43.25 Pr@90: 20.81
|
| 386 |
+
[2025-03-03 08:31:52,535] torch.distributed.elastic.agent.server.api: [WARNING] Received Signals.SIGINT death signal, shutting down workers
|
| 387 |
+
[2025-03-03 08:31:52,536] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2017036 closing signal SIGINT
|
| 388 |
+
[2025-03-03 08:31:52,536] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2017037 closing signal SIGINT
|
| 389 |
+
[2025-03-03 08:31:52,536] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2017038 closing signal SIGINT
|
| 390 |
+
[2025-03-03 08:31:52,536] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2017039 closing signal SIGINT
|
| 391 |
+
Exception in thread Thread-24:
|
| 392 |
+
Traceback (most recent call last):
|
| 393 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/threading.py", line 980, in _bootstrap_inner
|
| 394 |
+
self.run()
|
| 395 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/threading.py", line 917, in run
|
| 396 |
+
self._target(*self._args, **self._kwargs)
|
| 397 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/utils/data/_utils/pin_memory.py", line 54, in _pin_memory_loop
|
| 398 |
+
do_one_step()
|
| 399 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/utils/data/_utils/pin_memory.py", line 31, in do_one_step
|
| 400 |
+
r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
|
| 401 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/multiprocessing/queues.py", line 122, in get
|
| 402 |
+
return _ForkingPickler.loads(res)
|
| 403 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/multiprocessing/reductions.py", line 355, in rebuild_storage_fd
|
| 404 |
+
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fa02edd68b0>
|
| 405 |
+
Traceback (most recent call last):
|
| 406 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
|
| 407 |
+
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f6b5cf638b0>
|
| 408 |
+
Traceback (most recent call last):
|
| 409 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
|
| 410 |
+
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4c191bd8b0>
|
| 411 |
+
Traceback (most recent call last):
|
| 412 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
|
| 413 |
+
[2025-03-03 08:31:52,707] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2017036 closing signal SIGTERM
|
| 414 |
+
[2025-03-03 08:31:52,707] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2017037 closing signal SIGTERM
|
| 415 |
+
[2025-03-03 08:31:52,708] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2017038 closing signal SIGTERM
|
| 416 |
+
[2025-03-03 08:31:52,708] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2017039 closing signal SIGTERM
|
| 417 |
+
Traceback (most recent call last):
|
| 418 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/agent/server/api.py", line 736, in run
|
| 419 |
+
result = self._invoke_run(role)
|
| 420 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/agent/server/api.py", line 877, in _invoke_run
|
| 421 |
+
time.sleep(monitor_interval)
|
| 422 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 62, in _terminate_process_handler
|
| 423 |
+
raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
|
| 424 |
+
torch.distributed.elastic.multiprocessing.api.SignalException: Process 2017022 got signal: 2
|
| 425 |
+
|
| 426 |
+
During handling of the above exception, another exception occurred:
|
| 427 |
+
|
| 428 |
+
Traceback (most recent call last):
|
| 429 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/agent/server/api.py", line 743, in run
|
| 430 |
+
self._shutdown(e.sigval)
|
| 431 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/agent/server/local_elastic_agent.py", line 289, in _shutdown
|
| 432 |
+
self._pcontext.close(death_sig)
|
| 433 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 331, in close
|
| 434 |
+
self._close(death_sig=death_sig, timeout=timeout)
|
| 435 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 713, in _close
|
| 436 |
+
handler.proc.wait(time_to_wait)
|
| 437 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/subprocess.py", line 1189, in wait
|
| 438 |
+
return self._wait(timeout=timeout)
|
| 439 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/subprocess.py", line 1927, in _wait
|
| 440 |
+
time.sleep(delay)
|
| 441 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 62, in _terminate_process_handler
|
| 442 |
+
raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
|
| 443 |
+
torch.distributed.elastic.multiprocessing.api.SignalException: Process 2017022 got signal: 2
|
| 444 |
+
|
| 445 |
+
During handling of the above exception, another exception occurred:
|
| 446 |
+
|
| 447 |
+
Traceback (most recent call last):
|
| 448 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/runpy.py", line 197, in _run_module_as_main
|
| 449 |
+
return _run_code(code, main_globals, None,
|
| 450 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/runpy.py", line 87, in _run_code
|
| 451 |
+
exec(code, run_globals)
|
| 452 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/launch.py", line 196, in <module>
|
| 453 |
+
main()
|
| 454 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/launch.py", line 192, in main
|
| 455 |
+
launch(args)
|
| 456 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/launch.py", line 177, in launch
|
| 457 |
+
run(args)
|
| 458 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/run.py", line 797, in run
|
| 459 |
+
elastic_launch(
|
| 460 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
|
| 461 |
+
return launch_agent(self._config, self._entrypoint, list(args))
|
| 462 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 255, in launch_agent
|
| 463 |
+
result = agent.run()
|
| 464 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/metrics/api.py", line 124, in wrapper
|
| 465 |
+
result = f(*args, **kwargs)
|
| 466 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/agent/server/api.py", line 748, in run
|
| 467 |
+
self._shutdown()
|
| 468 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/agent/server/local_elastic_agent.py", line 289, in _shutdown
|
| 469 |
+
self._pcontext.close(death_sig)
|
| 470 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 331, in close
|
| 471 |
+
self._close(death_sig=death_sig, timeout=timeout)
|
| 472 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 713, in _close
|
| 473 |
+
handler.proc.wait(time_to_wait)
|
| 474 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/subprocess.py", line 1189, in wait
|
| 475 |
+
return self._wait(timeout=timeout)
|
| 476 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/subprocess.py", line 1927, in _wait
|
| 477 |
+
time.sleep(delay)
|
| 478 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 62, in _terminate_process_handler
|
| 479 |
+
raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
|
| 480 |
+
torch.distributed.elastic.multiprocessing.api.SignalException: Process 2017022 got signal: 2
|
CGFormer/bash_logs/ACE_filter050_rev.log
ADDED
|
@@ -0,0 +1,528 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/launch.py:181: FutureWarning: The module torch.distributed.launch is deprecated
|
| 2 |
+
and will be removed in future. Use torchrun.
|
| 3 |
+
Note that --use-env is set by default in torchrun.
|
| 4 |
+
If your script expects `--local-rank` argument to be set, please
|
| 5 |
+
change it to read from `os.environ['LOCAL_RANK']` instead. See
|
| 6 |
+
https://pytorch.org/docs/stable/distributed.html#launch-utility for
|
| 7 |
+
further instructions
|
| 8 |
+
|
| 9 |
+
warnings.warn(
|
| 10 |
+
[2025-03-03 16:23:35,171] torch.distributed.run: [WARNING]
|
| 11 |
+
[2025-03-03 16:23:35,171] torch.distributed.run: [WARNING] *****************************************
|
| 12 |
+
[2025-03-03 16:23:35,171] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
|
| 13 |
+
[2025-03-03 16:23:35,171] torch.distributed.run: [WARNING] *****************************************
|
| 14 |
+
/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/albumentations/__init__.py:24: UserWarning: A new version of Albumentations is available: 2.0.5 (you have 1.4.24). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.
|
| 15 |
+
check_for_updates()
|
| 16 |
+
/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/albumentations/__init__.py:24: UserWarning: A new version of Albumentations is available: 2.0.5 (you have 1.4.24). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.
|
| 17 |
+
check_for_updates()
|
| 18 |
+
/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/albumentations/__init__.py:24: UserWarning: A new version of Albumentations is available: 2.0.5 (you have 1.4.24). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.
|
| 19 |
+
check_for_updates()
|
| 20 |
+
/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/albumentations/__init__.py:24: UserWarning: A new version of Albumentations is available: 2.0.5 (you have 1.4.24). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.
|
| 21 |
+
check_for_updates()
|
| 22 |
+
/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/albumentations/__init__.py:24: UserWarning: A new version of Albumentations is available: 2.0.5 (you have 1.4.24). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.
|
| 23 |
+
check_for_updates()
|
| 24 |
+
/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/albumentations/__init__.py:24: UserWarning: A new version of Albumentations is available: 2.0.5 (you have 1.4.24). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.
|
| 25 |
+
check_for_updates()
|
| 26 |
+
/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/timm/models/layers/__init__.py:48: FutureWarning: Importing from timm.models.layers is deprecated, please import via timm.layers
|
| 27 |
+
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.layers", FutureWarning)
|
| 28 |
+
/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/timm/models/layers/__init__.py:48: FutureWarning: Importing from timm.models.layers is deprecated, please import via timm.layers
|
| 29 |
+
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.layers", FutureWarning)
|
| 30 |
+
/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/timm/models/layers/__init__.py:48: FutureWarning: Importing from timm.models.layers is deprecated, please import via timm.layers
|
| 31 |
+
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.layers", FutureWarning)
|
| 32 |
+
/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/timm/models/layers/__init__.py:48: FutureWarning: Importing from timm.models.layers is deprecated, please import via timm.layers
|
| 33 |
+
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.layers", FutureWarning)
|
| 34 |
+
/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/timm/models/layers/__init__.py:48: FutureWarning: Importing from timm.models.layers is deprecated, please import via timm.layers
|
| 35 |
+
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.layers", FutureWarning)
|
| 36 |
+
/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/timm/models/layers/__init__.py:48: FutureWarning: Importing from timm.models.layers is deprecated, please import via timm.layers
|
| 37 |
+
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.layers", FutureWarning)
|
| 38 |
+
2025-03-03 16:24:00.550 | INFO | __main__:main:66 - LOCAL_RANK from env: 2
|
| 39 |
+
2025-03-03 16:24:00.550 | INFO | __main__:main:66 - LOCAL_RANK from env: 5
|
| 40 |
+
2025-03-03 16:24:00.551 | INFO | __main__:main:66 - LOCAL_RANK from env: 4
|
| 41 |
+
2025-03-03 16:24:00.551 | INFO | __main__:main:66 - LOCAL_RANK from env: 0
|
| 42 |
+
2025-03-03 16:24:00.551 | INFO | __main__:main:66 - LOCAL_RANK from env: 1
|
| 43 |
+
2025-03-03 16:24:00.551 | INFO | __main__:main:66 - LOCAL_RANK from env: 3
|
| 44 |
+
2025-03-03 16:24:00 | INFO | __main__:90 - Starting with GPU: 0, Rank: 0, World Size: 6
|
| 45 |
+
git root error: Cmd('git') failed due to: exit code(128)
|
| 46 |
+
cmdline: git rev-parse --show-toplevel
|
| 47 |
+
stderr: 'fatal: detected dubious ownership in repository at '/data2/projects/chaeyun/CGFormer'
|
| 48 |
+
To add an exception for this directory, call:
|
| 49 |
+
|
| 50 |
+
git config --global --add safe.directory /data2/projects/chaeyun/CGFormer'
|
| 51 |
+
wandb: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
|
| 52 |
+
wandb: Tracking run with wandb version 0.19.1
|
| 53 |
+
wandb: W&B syncing is set to `offline` in this directory.
|
| 54 |
+
wandb: Run `wandb online` or set WANDB_MODE=online to enable cloud syncing.
|
| 55 |
+
node03:2316571:2316571 [0] NCCL INFO Bootstrap : Using eth2:10.1.10.3<0>
|
| 56 |
+
node03:2316571:2316571 [0] NCCL INFO NET/Plugin : Plugin load (libnccl-net.so) returned 2 : libnccl-net.so: cannot open shared object file: No such file or directory
|
| 57 |
+
node03:2316571:2316571 [0] NCCL INFO NET/Plugin : No plugin found, using internal implementation
|
| 58 |
+
node03:2316571:2316571 [0] NCCL INFO cudaDriverVersion 12070
|
| 59 |
+
NCCL version 2.18.5+cuda11.8
|
| 60 |
+
node03:2316574:2316574 [3] NCCL INFO cudaDriverVersion 12070
|
| 61 |
+
node03:2316575:2316575 [4] NCCL INFO cudaDriverVersion 12070
|
| 62 |
+
node03:2316572:2316572 [1] NCCL INFO cudaDriverVersion 12070
|
| 63 |
+
node03:2316576:2316576 [5] NCCL INFO cudaDriverVersion 12070
|
| 64 |
+
node03:2316573:2316573 [2] NCCL INFO cudaDriverVersion 12070
|
| 65 |
+
node03:2316575:2316575 [4] NCCL INFO Bootstrap : Using eth2:10.1.10.3<0>
|
| 66 |
+
node03:2316574:2316574 [3] NCCL INFO Bootstrap : Using eth2:10.1.10.3<0>
|
| 67 |
+
node03:2316576:2316576 [5] NCCL INFO Bootstrap : Using eth2:10.1.10.3<0>
|
| 68 |
+
node03:2316572:2316572 [1] NCCL INFO Bootstrap : Using eth2:10.1.10.3<0>
|
| 69 |
+
node03:2316573:2316573 [2] NCCL INFO Bootstrap : Using eth2:10.1.10.3<0>
|
| 70 |
+
node03:2316575:2316575 [4] NCCL INFO NET/Plugin : Plugin load (libnccl-net.so) returned 2 : libnccl-net.so: cannot open shared object file: No such file or directory
|
| 71 |
+
node03:2316575:2316575 [4] NCCL INFO NET/Plugin : No plugin found, using internal implementation
|
| 72 |
+
node03:2316574:2316574 [3] NCCL INFO NET/Plugin : Plugin load (libnccl-net.so) returned 2 : libnccl-net.so: cannot open shared object file: No such file or directory
|
| 73 |
+
node03:2316574:2316574 [3] NCCL INFO NET/Plugin : No plugin found, using internal implementation
|
| 74 |
+
node03:2316572:2316572 [1] NCCL INFO NET/Plugin : Plugin load (libnccl-net.so) returned 2 : libnccl-net.so: cannot open shared object file: No such file or directory
|
| 75 |
+
node03:2316572:2316572 [1] NCCL INFO NET/Plugin : No plugin found, using internal implementation
|
| 76 |
+
node03:2316576:2316576 [5] NCCL INFO NET/Plugin : Plugin load (libnccl-net.so) returned 2 : libnccl-net.so: cannot open shared object file: No such file or directory
|
| 77 |
+
node03:2316573:2316573 [2] NCCL INFO NET/Plugin : Plugin load (libnccl-net.so) returned 2 : libnccl-net.so: cannot open shared object file: No such file or directory
|
| 78 |
+
node03:2316576:2316576 [5] NCCL INFO NET/Plugin : No plugin found, using internal implementation
|
| 79 |
+
node03:2316573:2316573 [2] NCCL INFO NET/Plugin : No plugin found, using internal implementation
|
| 80 |
+
node03:2316571:2316708 [0] NCCL INFO NET/IB : No device found.
|
| 81 |
+
node03:2316573:2316712 [2] NCCL INFO NET/IB : No device found.
|
| 82 |
+
node03:2316571:2316708 [0] NCCL INFO NET/Socket : Using [0]eth2:10.1.10.3<0>
|
| 83 |
+
node03:2316571:2316708 [0] NCCL INFO Using network Socket
|
| 84 |
+
node03:2316573:2316712 [2] NCCL INFO NET/Socket : Using [0]eth2:10.1.10.3<0>
|
| 85 |
+
node03:2316573:2316712 [2] NCCL INFO Using network Socket
|
| 86 |
+
node03:2316574:2316710 [3] NCCL INFO NET/IB : No device found.
|
| 87 |
+
node03:2316574:2316710 [3] NCCL INFO NET/Socket : Using [0]eth2:10.1.10.3<0>
|
| 88 |
+
node03:2316574:2316710 [3] NCCL INFO Using network Socket
|
| 89 |
+
node03:2316572:2316711 [1] NCCL INFO NET/IB : No device found.
|
| 90 |
+
node03:2316572:2316711 [1] NCCL INFO NET/Socket : Using [0]eth2:10.1.10.3<0>
|
| 91 |
+
node03:2316572:2316711 [1] NCCL INFO Using network Socket
|
| 92 |
+
node03:2316576:2316713 [5] NCCL INFO NET/IB : No device found.
|
| 93 |
+
node03:2316576:2316713 [5] NCCL INFO NET/Socket : Using [0]eth2:10.1.10.3<0>
|
| 94 |
+
node03:2316576:2316713 [5] NCCL INFO Using network Socket
|
| 95 |
+
node03:2316575:2316709 [4] NCCL INFO NET/IB : No device found.
|
| 96 |
+
node03:2316575:2316709 [4] NCCL INFO NET/Socket : Using [0]eth2:10.1.10.3<0>
|
| 97 |
+
node03:2316575:2316709 [4] NCCL INFO Using network Socket
|
| 98 |
+
node03:2316571:2316708 [0] NCCL INFO comm 0xa205d10 rank 0 nranks 6 cudaDev 0 nvmlDev 0 busId 12000 commId 0x5c97a0f7f601b696 - Init START
|
| 99 |
+
node03:2316573:2316712 [2] NCCL INFO comm 0xac1c200 rank 2 nranks 6 cudaDev 2 nvmlDev 2 busId 14000 commId 0x5c97a0f7f601b696 - Init START
|
| 100 |
+
node03:2316576:2316713 [5] NCCL INFO comm 0x9b6ede0 rank 5 nranks 6 cudaDev 5 nvmlDev 5 busId c1000 commId 0x5c97a0f7f601b696 - Init START
|
| 101 |
+
node03:2316572:2316711 [1] NCCL INFO comm 0xa155230 rank 1 nranks 6 cudaDev 1 nvmlDev 1 busId 13000 commId 0x5c97a0f7f601b696 - Init START
|
| 102 |
+
node03:2316575:2316709 [4] NCCL INFO comm 0xa46f6f0 rank 4 nranks 6 cudaDev 4 nvmlDev 4 busId c0000 commId 0x5c97a0f7f601b696 - Init START
|
| 103 |
+
node03:2316574:2316710 [3] NCCL INFO comm 0xadbe390 rank 3 nranks 6 cudaDev 3 nvmlDev 3 busId 48000 commId 0x5c97a0f7f601b696 - Init START
|
| 104 |
+
node03:2316573:2316712 [2] NCCL INFO Setting affinity for GPU 2 to 14005500,00140055
|
| 105 |
+
node03:2316572:2316711 [1] NCCL INFO Setting affinity for GPU 1 to 14005500,00140055
|
| 106 |
+
node03:2316574:2316710 [3] NCCL INFO Setting affinity for GPU 3 to 14005500,00140055
|
| 107 |
+
node03:2316571:2316708 [0] NCCL INFO Setting affinity for GPU 0 to 14005500,00140055
|
| 108 |
+
node03:2316571:2316708 [0] NCCL INFO Channel 00/02 : 0 1 2 3 4 5
|
| 109 |
+
node03:2316575:2316709 [4] NCCL INFO Trees [0] 5/-1/-1->4->3 [1] 5/-1/-1->4->3
|
| 110 |
+
node03:2316572:2316711 [1] NCCL INFO Trees [0] 2/-1/-1->1->0 [1] 2/-1/-1->1->0
|
| 111 |
+
node03:2316574:2316710 [3] NCCL INFO Trees [0] 4/-1/-1->3->2 [1] 4/-1/-1->3->2
|
| 112 |
+
node03:2316573:2316712 [2] NCCL INFO Trees [0] 3/-1/-1->2->1 [1] 3/-1/-1->2->1
|
| 113 |
+
node03:2316571:2316708 [0] NCCL INFO Channel 01/02 : 0 1 2 3 4 5
|
| 114 |
+
node03:2316576:2316713 [5] NCCL INFO Trees [0] -1/-1/-1->5->4 [1] -1/-1/-1->5->4
|
| 115 |
+
node03:2316575:2316709 [4] NCCL INFO P2P Chunksize set to 131072
|
| 116 |
+
node03:2316572:2316711 [1] NCCL INFO P2P Chunksize set to 131072
|
| 117 |
+
node03:2316574:2316710 [3] NCCL INFO P2P Chunksize set to 131072
|
| 118 |
+
node03:2316573:2316712 [2] NCCL INFO P2P Chunksize set to 131072
|
| 119 |
+
node03:2316571:2316708 [0] NCCL INFO Trees [0] 1/-1/-1->0->-1 [1] 1/-1/-1->0->-1
|
| 120 |
+
node03:2316576:2316713 [5] NCCL INFO P2P Chunksize set to 131072
|
| 121 |
+
node03:2316571:2316708 [0] NCCL INFO P2P Chunksize set to 131072
|
| 122 |
+
node03:2316572:2316711 [1] NCCL INFO Channel 00/0 : 1[1] -> 2[2] via P2P/IPC
|
| 123 |
+
node03:2316572:2316711 [1] NCCL INFO Channel 01/0 : 1[1] -> 2[2] via P2P/IPC
|
| 124 |
+
node03:2316576:2316713 [5] NCCL INFO Channel 00 : 5[5] -> 0[0] via SHM/direct/direct
|
| 125 |
+
node03:2316573:2316712 [2] NCCL INFO Channel 00 : 2[2] -> 3[3] via SHM/direct/direct
|
| 126 |
+
node03:2316576:2316713 [5] NCCL INFO Channel 01 : 5[5] -> 0[0] via SHM/direct/direct
|
| 127 |
+
node03:2316573:2316712 [2] NCCL INFO Channel 01 : 2[2] -> 3[3] via SHM/direct/direct
|
| 128 |
+
node03:2316575:2316709 [4] NCCL INFO Channel 00/0 : 4[4] -> 5[5] via P2P/IPC
|
| 129 |
+
node03:2316575:2316709 [4] NCCL INFO Channel 01/0 : 4[4] -> 5[5] via P2P/IPC
|
| 130 |
+
node03:2316571:2316708 [0] NCCL INFO Channel 00/0 : 0[0] -> 1[1] via P2P/IPC
|
| 131 |
+
node03:2316574:2316710 [3] NCCL INFO Channel 00 : 3[3] -> 4[4] via SHM/direct/direct
|
| 132 |
+
node03:2316571:2316708 [0] NCCL INFO Channel 01/0 : 0[0] -> 1[1] via P2P/IPC
|
| 133 |
+
node03:2316574:2316710 [3] NCCL INFO Channel 01 : 3[3] -> 4[4] via SHM/direct/direct
|
| 134 |
+
node03:2316575:2316709 [4] NCCL INFO Connected all rings
|
| 135 |
+
node03:2316572:2316711 [1] NCCL INFO Connected all rings
|
| 136 |
+
node03:2316571:2316708 [0] NCCL INFO Connected all rings
|
| 137 |
+
node03:2316573:2316712 [2] NCCL INFO Connected all rings
|
| 138 |
+
node03:2316576:2316713 [5] NCCL INFO Connected all rings
|
| 139 |
+
node03:2316576:2316713 [5] NCCL INFO Channel 00/0 : 5[5] -> 4[4] via P2P/IPC
|
| 140 |
+
node03:2316572:2316711 [1] NCCL INFO Channel 00/0 : 1[1] -> 0[0] via P2P/IPC
|
| 141 |
+
node03:2316575:2316709 [4] NCCL INFO Channel 00 : 4[4] -> 3[3] via SHM/direct/direct
|
| 142 |
+
node03:2316574:2316710 [3] NCCL INFO Connected all rings
|
| 143 |
+
node03:2316572:2316711 [1] NCCL INFO Channel 01/0 : 1[1] -> 0[0] via P2P/IPC
|
| 144 |
+
node03:2316576:2316713 [5] NCCL INFO Channel 01/0 : 5[5] -> 4[4] via P2P/IPC
|
| 145 |
+
node03:2316575:2316709 [4] NCCL INFO Channel 01 : 4[4] -> 3[3] via SHM/direct/direct
|
| 146 |
+
node03:2316571:2316708 [0] NCCL INFO Connected all trees
|
| 147 |
+
node03:2316571:2316708 [0] NCCL INFO threadThresholds 8/8/64 | 48/8/64 | 512 | 512
|
| 148 |
+
node03:2316571:2316708 [0] NCCL INFO 2 coll channels, 0 nvls channels, 2 p2p channels, 2 p2p channels per peer
|
| 149 |
+
node03:2316576:2316713 [5] NCCL INFO Connected all trees
|
| 150 |
+
node03:2316576:2316713 [5] NCCL INFO threadThresholds 8/8/64 | 48/8/64 | 512 | 512
|
| 151 |
+
node03:2316576:2316713 [5] NCCL INFO 2 coll channels, 0 nvls channels, 2 p2p channels, 2 p2p channels per peer
|
| 152 |
+
node03:2316573:2316712 [2] NCCL INFO Channel 00/0 : 2[2] -> 1[1] via P2P/IPC
|
| 153 |
+
node03:2316573:2316712 [2] NCCL INFO Channel 01/0 : 2[2] -> 1[1] via P2P/IPC
|
| 154 |
+
node03:2316572:2316711 [1] NCCL INFO Connected all trees
|
| 155 |
+
node03:2316572:2316711 [1] NCCL INFO threadThresholds 8/8/64 | 48/8/64 | 512 | 512
|
| 156 |
+
node03:2316572:2316711 [1] NCCL INFO 2 coll channels, 0 nvls channels, 2 p2p channels, 2 p2p channels per peer
|
| 157 |
+
node03:2316574:2316710 [3] NCCL INFO Channel 00 : 3[3] -> 2[2] via SHM/direct/direct
|
| 158 |
+
node03:2316574:2316710 [3] NCCL INFO Channel 01 : 3[3] -> 2[2] via SHM/direct/direct
|
| 159 |
+
node03:2316573:2316712 [2] NCCL INFO Connected all trees
|
| 160 |
+
node03:2316573:2316712 [2] NCCL INFO threadThresholds 8/8/64 | 48/8/64 | 512 | 512
|
| 161 |
+
node03:2316573:2316712 [2] NCCL INFO 2 coll channels, 0 nvls channels, 2 p2p channels, 2 p2p channels per peer
|
| 162 |
+
node03:2316575:2316709 [4] NCCL INFO Connected all trees
|
| 163 |
+
node03:2316575:2316709 [4] NCCL INFO threadThresholds 8/8/64 | 48/8/64 | 512 | 512
|
| 164 |
+
node03:2316575:2316709 [4] NCCL INFO 2 coll channels, 0 nvls channels, 2 p2p channels, 2 p2p channels per peer
|
| 165 |
+
node03:2316574:2316710 [3] NCCL INFO Connected all trees
|
| 166 |
+
node03:2316574:2316710 [3] NCCL INFO threadThresholds 8/8/64 | 48/8/64 | 512 | 512
|
| 167 |
+
node03:2316574:2316710 [3] NCCL INFO 2 coll channels, 0 nvls channels, 2 p2p channels, 2 p2p channels per peer
|
| 168 |
+
node03:2316575:2316709 [4] NCCL INFO comm 0xa46f6f0 rank 4 nranks 6 cudaDev 4 nvmlDev 4 busId c0000 commId 0x5c97a0f7f601b696 - Init COMPLETE
|
| 169 |
+
node03:2316573:2316712 [2] NCCL INFO comm 0xac1c200 rank 2 nranks 6 cudaDev 2 nvmlDev 2 busId 14000 commId 0x5c97a0f7f601b696 - Init COMPLETE
|
| 170 |
+
node03:2316571:2316708 [0] NCCL INFO comm 0xa205d10 rank 0 nranks 6 cudaDev 0 nvmlDev 0 busId 12000 commId 0x5c97a0f7f601b696 - Init COMPLETE
|
| 171 |
+
node03:2316572:2316711 [1] NCCL INFO comm 0xa155230 rank 1 nranks 6 cudaDev 1 nvmlDev 1 busId 13000 commId 0x5c97a0f7f601b696 - Init COMPLETE
|
| 172 |
+
node03:2316574:2316710 [3] NCCL INFO comm 0xadbe390 rank 3 nranks 6 cudaDev 3 nvmlDev 3 busId 48000 commId 0x5c97a0f7f601b696 - Init COMPLETE
|
| 173 |
+
node03:2316576:2316713 [5] NCCL INFO comm 0x9b6ede0 rank 5 nranks 6 cudaDev 5 nvmlDev 5 busId c1000 commId 0x5c97a0f7f601b696 - Init COMPLETE
|
| 174 |
+
2025-03-03 16:24:03 | INFO | model:31 - Window size 12!
|
| 175 |
+
2025-03-03 16:24:03 | INFO | model:51 - Initializing Multi-modal Swin Transformer weights from ckpts/swin_base_patch4_window12_384_22k.pth
|
| 176 |
+
2025-03-03 16:24:05 | INFO | model.backbone:459 - loading swin success !!!
|
| 177 |
+
2025-03-03 16:24:08 | INFO | __main__:144 - Model moved to GPU: 0
|
| 178 |
+
2025-03-03 16:24:08 | INFO | __main__:145 - amsgrad: True
|
| 179 |
+
batch_size: 30
|
| 180 |
+
batch_size_val: 16
|
| 181 |
+
bert: bert-base-uncased
|
| 182 |
+
dataset: refcocog_u
|
| 183 |
+
dist_backend: nccl
|
| 184 |
+
dropout: 0.0
|
| 185 |
+
epochs: 50
|
| 186 |
+
evaluate: True
|
| 187 |
+
exclude_multiobj: True
|
| 188 |
+
exp_name: ACE_filter050_rev
|
| 189 |
+
filter_threshold: 0.5
|
| 190 |
+
fusion_drop: 0.0
|
| 191 |
+
gpu: 0
|
| 192 |
+
hp_selection: strict
|
| 193 |
+
input_size: 480
|
| 194 |
+
local_rank: 0
|
| 195 |
+
loss_option: ACE_verbonly
|
| 196 |
+
lr: 0.0001
|
| 197 |
+
lr_backbone: 5e-05
|
| 198 |
+
lr_text_encoder: 5e-05
|
| 199 |
+
manual_seed: 1455390217
|
| 200 |
+
margin_value: 12
|
| 201 |
+
mask_root: data/masks/refcocog_u
|
| 202 |
+
metric_learning: True
|
| 203 |
+
metric_loss_weight: 0.1
|
| 204 |
+
metric_mode: hardpos_only_sbertsim_refined
|
| 205 |
+
mha: 8-8-8-8
|
| 206 |
+
mixup_lasttwo: False
|
| 207 |
+
num_token: 2
|
| 208 |
+
output_dir: exp/refcoco_u/ACE_filter050_rev
|
| 209 |
+
output_folder: exp/refcoco_u
|
| 210 |
+
print_freq: 100
|
| 211 |
+
rank: 0
|
| 212 |
+
resume: None
|
| 213 |
+
save_freq: 1
|
| 214 |
+
start_epoch: 0
|
| 215 |
+
swin_pretrain: ckpts/swin_base_patch4_window12_384_22k.pth
|
| 216 |
+
swin_type: base
|
| 217 |
+
sync_bn: True
|
| 218 |
+
temperature: 0.07
|
| 219 |
+
test_lmdb: data/lmdb/refcocog_u/test.lmdb
|
| 220 |
+
test_split: test
|
| 221 |
+
token_dim: 512
|
| 222 |
+
train_lmdb: data/lmdb/refcocog_u/train.lmdb
|
| 223 |
+
train_split: train
|
| 224 |
+
val_lmdb: data/lmdb/refcocog_u/val.lmdb
|
| 225 |
+
val_split: val
|
| 226 |
+
vis_dim: 512
|
| 227 |
+
visualize: False
|
| 228 |
+
weight: None
|
| 229 |
+
weight_decay: 0.0001
|
| 230 |
+
window12: True
|
| 231 |
+
word_dim: 768
|
| 232 |
+
word_len: 20
|
| 233 |
+
workers: 32
|
| 234 |
+
workers_val: 8
|
| 235 |
+
world_size: 6
|
| 236 |
+
2025-03-03 16:26:29 | INFO | utils.misc:108 - Training: Epoch=[1/50] [ 100/1407] Batch=1.30 (1.38) Data=0.00 (0.07) Lr=0.000100 Loss=1.0907 (1.1761) IoU=26.22 (19.83) Prec@50=23.61 (8.67)
|
| 237 |
+
2025-03-03 16:28:41 | INFO | utils.misc:108 - Training: Epoch=[1/50] [ 200/1407] Batch=1.29 (1.35) Data=0.00 (0.05) Lr=0.000100 Loss=0.9969 (1.0899) IoU=29.22 (24.57) Prec@50=25.56 (13.58)
|
| 238 |
+
2025-03-03 16:30:54 | INFO | utils.misc:108 - Training: Epoch=[1/50] [ 300/1407] Batch=1.36 (1.34) Data=0.00 (0.04) Lr=0.000100 Loss=0.9088 (1.0496) IoU=35.17 (26.91) Prec@50=21.79 (16.03)
|
| 239 |
+
2025-03-03 16:33:06 | INFO | utils.misc:108 - Training: Epoch=[1/50] [ 400/1407] Batch=1.22 (1.34) Data=0.00 (0.04) Lr=0.000100 Loss=0.8703 (1.0167) IoU=35.01 (28.40) Prec@50=21.29 (17.80)
|
| 240 |
+
2025-03-03 16:35:19 | INFO | utils.misc:108 - Training: Epoch=[1/50] [ 500/1407] Batch=1.15 (1.34) Data=0.00 (0.04) Lr=0.000100 Loss=0.9754 (0.9922) IoU=31.02 (29.64) Prec@50=17.78 (19.37)
|
| 241 |
+
2025-03-03 16:37:31 | INFO | utils.misc:108 - Training: Epoch=[1/50] [ 600/1407] Batch=1.16 (1.33) Data=0.00 (0.04) Lr=0.000100 Loss=0.8838 (0.9712) IoU=33.06 (30.63) Prec@50=26.75 (20.87)
|
| 242 |
+
2025-03-03 16:39:43 | INFO | utils.misc:108 - Training: Epoch=[1/50] [ 700/1407] Batch=1.24 (1.33) Data=0.00 (0.03) Lr=0.000100 Loss=0.7971 (0.9545) IoU=34.82 (31.32) Prec@50=30.06 (21.88)
|
| 243 |
+
2025-03-03 16:41:53 | INFO | utils.misc:108 - Training: Epoch=[1/50] [ 800/1407] Batch=1.48 (1.33) Data=0.00 (0.03) Lr=0.000100 Loss=0.9076 (0.9414) IoU=35.91 (31.87) Prec@50=23.91 (22.66)
|
| 244 |
+
2025-03-03 16:44:04 | INFO | utils.misc:108 - Training: Epoch=[1/50] [ 900/1407] Batch=1.29 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.9933 (0.9301) IoU=30.51 (32.29) Prec@50=18.81 (23.24)
|
| 245 |
+
2025-03-03 16:46:13 | INFO | utils.misc:108 - Training: Epoch=[1/50] [1000/1407] Batch=1.26 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.7584 (0.9199) IoU=40.61 (32.82) Prec@50=30.95 (23.96)
|
| 246 |
+
2025-03-03 16:48:24 | INFO | utils.misc:108 - Training: Epoch=[1/50] [1100/1407] Batch=1.69 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.8691 (0.9110) IoU=35.25 (33.29) Prec@50=19.97 (24.69)
|
| 247 |
+
2025-03-03 16:50:37 | INFO | utils.misc:108 - Training: Epoch=[1/50] [1200/1407] Batch=1.39 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.8363 (0.9033) IoU=34.11 (33.67) Prec@50=28.97 (25.27)
|
| 248 |
+
2025-03-03 16:52:49 | INFO | utils.misc:108 - Training: Epoch=[1/50] [1300/1407] Batch=1.35 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.8169 (0.8959) IoU=34.20 (34.05) Prec@50=26.85 (25.87)
|
| 249 |
+
2025-03-03 16:55:00 | INFO | utils.misc:108 - Training: Epoch=[1/50] [1400/1407] Batch=1.09 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.7557 (0.8885) IoU=44.41 (34.49) Prec@50=41.27 (26.50)
|
| 250 |
+
2025-03-03 16:55:45 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[1/50] mIoU=44.49 oIoU=42.87 Pr@50: 40.95 Pr@60: 30.46 Pr@70: 20.32 Pr@80: 11.81 Pr@90: 2.80
|
| 251 |
+
2025-03-03 16:58:21 | INFO | utils.misc:108 - Training: Epoch=[2/50] [ 100/1407] Batch=1.28 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.9049 (0.7654) IoU=35.26 (41.03) Prec@50=29.66 (36.59)
|
| 252 |
+
2025-03-03 17:00:37 | INFO | utils.misc:108 - Training: Epoch=[2/50] [ 200/1407] Batch=1.68 (1.33) Data=0.00 (0.03) Lr=0.000100 Loss=0.7241 (0.7458) IoU=40.81 (41.92) Prec@50=42.92 (38.62)
|
| 253 |
+
2025-03-03 17:02:47 | INFO | utils.misc:108 - Training: Epoch=[2/50] [ 300/1407] Batch=1.42 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.8373 (0.7440) IoU=40.21 (42.04) Prec@50=32.79 (38.73)
|
| 254 |
+
2025-03-03 17:04:56 | INFO | utils.misc:108 - Training: Epoch=[2/50] [ 400/1407] Batch=1.33 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.8551 (0.7431) IoU=35.71 (41.98) Prec@50=34.74 (38.65)
|
| 255 |
+
2025-03-03 17:07:05 | INFO | utils.misc:108 - Training: Epoch=[2/50] [ 500/1407] Batch=1.56 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.7449 (0.7434) IoU=44.00 (41.89) Prec@50=46.22 (38.68)
|
| 256 |
+
2025-03-03 17:09:14 | INFO | utils.misc:108 - Training: Epoch=[2/50] [ 600/1407] Batch=1.31 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.7519 (0.7414) IoU=38.65 (41.92) Prec@50=30.89 (38.80)
|
| 257 |
+
2025-03-03 17:11:26 | INFO | utils.misc:108 - Training: Epoch=[2/50] [ 700/1407] Batch=1.22 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.6646 (0.7398) IoU=43.75 (42.04) Prec@50=41.87 (39.10)
|
| 258 |
+
2025-03-03 17:13:36 | INFO | utils.misc:108 - Training: Epoch=[2/50] [ 800/1407] Batch=1.24 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.7340 (0.7378) IoU=42.55 (42.15) Prec@50=39.05 (39.38)
|
| 259 |
+
2025-03-03 17:15:44 | INFO | utils.misc:108 - Training: Epoch=[2/50] [ 900/1407] Batch=1.77 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.6731 (0.7369) IoU=43.04 (42.09) Prec@50=48.61 (39.34)
|
| 260 |
+
2025-03-03 17:17:53 | INFO | utils.misc:108 - Training: Epoch=[2/50] [1000/1407] Batch=1.36 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.6417 (0.7343) IoU=45.02 (42.19) Prec@50=51.01 (39.55)
|
| 261 |
+
2025-03-03 17:20:04 | INFO | utils.misc:108 - Training: Epoch=[2/50] [1100/1407] Batch=1.30 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.8503 (0.7329) IoU=38.33 (42.24) Prec@50=35.36 (39.68)
|
| 262 |
+
2025-03-03 17:22:15 | INFO | utils.misc:108 - Training: Epoch=[2/50] [1200/1407] Batch=1.62 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.7639 (0.7322) IoU=45.07 (42.29) Prec@50=34.68 (39.71)
|
| 263 |
+
2025-03-03 17:24:27 | INFO | utils.misc:108 - Training: Epoch=[2/50] [1300/1407] Batch=1.17 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.8485 (0.7314) IoU=36.88 (42.42) Prec@50=38.10 (39.84)
|
| 264 |
+
2025-03-03 17:26:37 | INFO | utils.misc:108 - Training: Epoch=[2/50] [1400/1407] Batch=1.24 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.8243 (0.7302) IoU=44.98 (42.64) Prec@50=41.79 (40.17)
|
| 265 |
+
2025-03-03 17:27:21 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[2/50] mIoU=50.35 oIoU=48.08 Pr@50: 51.28 Pr@60: 40.60 Pr@70: 30.92 Pr@80: 19.46 Pr@90: 6.02
|
| 266 |
+
2025-03-03 17:29:59 | INFO | utils.misc:108 - Training: Epoch=[3/50] [ 100/1407] Batch=1.22 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.6662 (0.6521) IoU=43.34 (48.06) Prec@50=45.54 (48.24)
|
| 267 |
+
2025-03-03 17:32:11 | INFO | utils.misc:108 - Training: Epoch=[3/50] [ 200/1407] Batch=1.32 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.6270 (0.6581) IoU=50.53 (47.15) Prec@50=58.53 (47.12)
|
| 268 |
+
2025-03-03 17:34:21 | INFO | utils.misc:108 - Training: Epoch=[3/50] [ 300/1407] Batch=1.13 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.7448 (0.6597) IoU=42.87 (46.97) Prec@50=36.43 (46.95)
|
| 269 |
+
2025-03-03 17:36:33 | INFO | utils.misc:108 - Training: Epoch=[3/50] [ 400/1407] Batch=1.39 (1.31) Data=0.01 (0.03) Lr=0.000100 Loss=0.6364 (0.6586) IoU=41.82 (46.78) Prec@50=42.46 (46.52)
|
| 270 |
+
2025-03-03 17:38:45 | INFO | utils.misc:108 - Training: Epoch=[3/50] [ 500/1407] Batch=1.40 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.7957 (0.6597) IoU=43.02 (46.63) Prec@50=32.08 (46.34)
|
| 271 |
+
2025-03-03 17:40:56 | INFO | utils.misc:108 - Training: Epoch=[3/50] [ 600/1407] Batch=1.14 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.7063 (0.6615) IoU=44.00 (46.61) Prec@50=39.84 (46.29)
|
| 272 |
+
2025-03-03 17:43:06 | INFO | utils.misc:108 - Training: Epoch=[3/50] [ 700/1407] Batch=1.13 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.5882 (0.6644) IoU=54.05 (46.60) Prec@50=57.14 (46.15)
|
| 273 |
+
2025-03-03 17:45:16 | INFO | utils.misc:108 - Training: Epoch=[3/50] [ 800/1407] Batch=1.26 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.6320 (0.6651) IoU=48.72 (46.60) Prec@50=47.24 (46.12)
|
| 274 |
+
2025-03-03 17:47:27 | INFO | utils.misc:108 - Training: Epoch=[3/50] [ 900/1407] Batch=1.23 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.7309 (0.6651) IoU=44.48 (46.71) Prec@50=40.69 (46.31)
|
| 275 |
+
2025-03-03 17:49:35 | INFO | utils.misc:108 - Training: Epoch=[3/50] [1000/1407] Batch=1.18 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.8106 (0.6643) IoU=33.29 (46.83) Prec@50=28.57 (46.37)
|
| 276 |
+
2025-03-03 17:51:46 | INFO | utils.misc:108 - Training: Epoch=[3/50] [1100/1407] Batch=1.32 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.6855 (0.6630) IoU=43.40 (46.90) Prec@50=45.09 (46.52)
|
| 277 |
+
2025-03-03 17:53:58 | INFO | utils.misc:108 - Training: Epoch=[3/50] [1200/1407] Batch=1.25 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.6270 (0.6619) IoU=51.54 (47.00) Prec@50=47.42 (46.63)
|
| 278 |
+
2025-03-03 17:56:09 | INFO | utils.misc:108 - Training: Epoch=[3/50] [1300/1407] Batch=1.38 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.6188 (0.6614) IoU=48.14 (47.06) Prec@50=53.94 (46.75)
|
| 279 |
+
2025-03-03 17:58:19 | INFO | utils.misc:108 - Training: Epoch=[3/50] [1400/1407] Batch=1.15 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.6312 (0.6612) IoU=53.97 (47.06) Prec@50=59.68 (46.75)
|
| 280 |
+
2025-03-03 17:59:03 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[3/50] mIoU=54.39 oIoU=51.28 Pr@50: 57.42 Pr@60: 48.45 Pr@70: 39.20 Pr@80: 27.04 Pr@90: 10.96
|
| 281 |
+
2025-03-03 18:01:40 | INFO | utils.misc:108 - Training: Epoch=[4/50] [ 100/1407] Batch=1.23 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.7058 (0.6106) IoU=38.22 (48.27) Prec@50=32.74 (49.35)
|
| 282 |
+
2025-03-03 18:03:51 | INFO | utils.misc:108 - Training: Epoch=[4/50] [ 200/1407] Batch=1.35 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.5542 (0.6165) IoU=43.66 (47.87) Prec@50=45.73 (48.88)
|
| 283 |
+
2025-03-03 18:06:01 | INFO | utils.misc:108 - Training: Epoch=[4/50] [ 300/1407] Batch=1.51 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.5978 (0.6161) IoU=47.81 (48.07) Prec@50=44.17 (49.12)
|
| 284 |
+
2025-03-03 18:08:11 | INFO | utils.misc:108 - Training: Epoch=[4/50] [ 400/1407] Batch=1.11 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.7430 (0.6135) IoU=45.09 (48.32) Prec@50=44.84 (49.50)
|
| 285 |
+
2025-03-03 18:10:22 | INFO | utils.misc:108 - Training: Epoch=[4/50] [ 500/1407] Batch=1.27 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.6319 (0.6166) IoU=49.79 (48.05) Prec@50=47.02 (49.14)
|
| 286 |
+
2025-03-03 18:12:33 | INFO | utils.misc:108 - Training: Epoch=[4/50] [ 600/1407] Batch=1.59 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.6365 (0.6162) IoU=41.68 (48.00) Prec@50=35.05 (48.94)
|
| 287 |
+
2025-03-03 18:14:46 | INFO | utils.misc:108 - Training: Epoch=[4/50] [ 700/1407] Batch=1.23 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.5867 (0.6171) IoU=53.39 (47.94) Prec@50=57.92 (48.86)
|
| 288 |
+
2025-03-03 18:16:55 | INFO | utils.misc:108 - Training: Epoch=[4/50] [ 800/1407] Batch=1.62 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.6259 (0.6161) IoU=45.88 (48.17) Prec@50=48.51 (49.16)
|
| 289 |
+
2025-03-03 18:19:05 | INFO | utils.misc:108 - Training: Epoch=[4/50] [ 900/1407] Batch=1.24 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.6503 (0.6170) IoU=50.25 (48.31) Prec@50=51.11 (49.28)
|
| 290 |
+
2025-03-03 18:21:17 | INFO | utils.misc:108 - Training: Epoch=[4/50] [1000/1407] Batch=1.16 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.6409 (0.6158) IoU=48.29 (48.45) Prec@50=50.24 (49.47)
|
| 291 |
+
2025-03-03 18:23:27 | INFO | utils.misc:108 - Training: Epoch=[4/50] [1100/1407] Batch=1.28 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.6323 (0.6145) IoU=50.15 (48.53) Prec@50=51.49 (49.55)
|
| 292 |
+
2025-03-03 18:25:37 | INFO | utils.misc:108 - Training: Epoch=[4/50] [1200/1407] Batch=1.22 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.5679 (0.6129) IoU=55.70 (48.67) Prec@50=61.69 (49.72)
|
| 293 |
+
2025-03-03 18:27:50 | INFO | utils.misc:108 - Training: Epoch=[4/50] [1300/1407] Batch=1.25 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.6499 (0.6116) IoU=50.86 (48.78) Prec@50=50.87 (49.90)
|
| 294 |
+
2025-03-03 18:30:01 | INFO | utils.misc:108 - Training: Epoch=[4/50] [1400/1407] Batch=1.18 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.7374 (0.6114) IoU=39.56 (48.79) Prec@50=34.68 (49.92)
|
| 295 |
+
2025-03-03 18:30:45 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[4/50] mIoU=56.35 oIoU=53.31 Pr@50: 60.68 Pr@60: 51.48 Pr@70: 42.50 Pr@80: 30.30 Pr@90: 11.50
|
| 296 |
+
2025-03-03 18:33:23 | INFO | utils.misc:108 - Training: Epoch=[5/50] [ 100/1407] Batch=1.23 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.4113 (0.5488) IoU=61.30 (51.36) Prec@50=69.78 (53.66)
|
| 297 |
+
2025-03-03 18:35:32 | INFO | utils.misc:108 - Training: Epoch=[5/50] [ 200/1407] Batch=1.35 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4562 (0.5582) IoU=59.50 (51.34) Prec@50=66.80 (53.66)
|
| 298 |
+
2025-03-03 18:37:45 | INFO | utils.misc:108 - Training: Epoch=[5/50] [ 300/1407] Batch=1.43 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.5283 (0.5580) IoU=52.54 (51.87) Prec@50=57.52 (54.45)
|
| 299 |
+
2025-03-03 18:39:56 | INFO | utils.misc:108 - Training: Epoch=[5/50] [ 400/1407] Batch=1.33 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.5587 (0.5559) IoU=49.20 (52.23) Prec@50=55.06 (54.90)
|
| 300 |
+
2025-03-03 18:42:08 | INFO | utils.misc:108 - Training: Epoch=[5/50] [ 500/1407] Batch=1.13 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.6211 (0.5567) IoU=50.09 (52.48) Prec@50=54.76 (55.43)
|
| 301 |
+
2025-03-03 18:44:19 | INFO | utils.misc:108 - Training: Epoch=[5/50] [ 600/1407] Batch=1.14 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.5935 (0.5572) IoU=51.16 (52.40) Prec@50=53.10 (55.33)
|
| 302 |
+
2025-03-03 18:46:29 | INFO | utils.misc:108 - Training: Epoch=[5/50] [ 700/1407] Batch=1.20 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.6169 (0.5571) IoU=55.58 (52.48) Prec@50=61.11 (55.51)
|
| 303 |
+
2025-03-03 18:48:40 | INFO | utils.misc:108 - Training: Epoch=[5/50] [ 800/1407] Batch=1.39 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.3962 (0.5566) IoU=63.84 (52.62) Prec@50=73.62 (55.67)
|
| 304 |
+
2025-03-03 18:50:52 | INFO | utils.misc:108 - Training: Epoch=[5/50] [ 900/1407] Batch=1.23 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.5741 (0.5561) IoU=51.11 (52.62) Prec@50=49.78 (55.68)
|
| 305 |
+
2025-03-03 18:53:04 | INFO | utils.misc:108 - Training: Epoch=[5/50] [1000/1407] Batch=1.30 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.5752 (0.5549) IoU=46.45 (52.75) Prec@50=45.34 (55.80)
|
| 306 |
+
2025-03-03 18:55:17 | INFO | utils.misc:108 - Training: Epoch=[5/50] [1100/1407] Batch=1.49 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.5531 (0.5562) IoU=54.64 (52.65) Prec@50=56.15 (55.63)
|
| 307 |
+
2025-03-03 18:57:30 | INFO | utils.misc:108 - Training: Epoch=[5/50] [1200/1407] Batch=1.23 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.6240 (0.5565) IoU=53.21 (52.64) Prec@50=60.42 (55.60)
|
| 308 |
+
2025-03-03 18:59:42 | INFO | utils.misc:108 - Training: Epoch=[5/50] [1300/1407] Batch=1.35 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.5657 (0.5568) IoU=57.00 (52.65) Prec@50=65.97 (55.57)
|
| 309 |
+
2025-03-03 19:01:53 | INFO | utils.misc:108 - Training: Epoch=[5/50] [1400/1407] Batch=1.10 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.4952 (0.5573) IoU=56.29 (52.65) Prec@50=57.14 (55.56)
|
| 310 |
+
2025-03-03 19:02:37 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[5/50] mIoU=59.29 oIoU=56.00 Pr@50: 65.58 Pr@60: 57.50 Pr@70: 48.68 Pr@80: 36.48 Pr@90: 15.23
|
| 311 |
+
2025-03-03 19:05:14 | INFO | utils.misc:108 - Training: Epoch=[6/50] [ 100/1407] Batch=1.34 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4567 (0.4978) IoU=56.26 (55.83) Prec@50=58.94 (60.25)
|
| 312 |
+
2025-03-03 19:07:25 | INFO | utils.misc:108 - Training: Epoch=[6/50] [ 200/1407] Batch=1.21 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.5726 (0.4995) IoU=47.47 (55.42) Prec@50=48.71 (60.07)
|
| 313 |
+
2025-03-03 19:09:34 | INFO | utils.misc:108 - Training: Epoch=[6/50] [ 300/1407] Batch=1.32 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.4761 (0.5059) IoU=58.15 (55.06) Prec@50=62.62 (59.25)
|
| 314 |
+
2025-03-03 19:11:44 | INFO | utils.misc:108 - Training: Epoch=[6/50] [ 400/1407] Batch=1.23 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.5243 (0.5106) IoU=55.15 (54.66) Prec@50=58.63 (58.77)
|
| 315 |
+
2025-03-03 19:13:54 | INFO | utils.misc:108 - Training: Epoch=[6/50] [ 500/1407] Batch=1.13 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.6256 (0.5109) IoU=46.39 (54.44) Prec@50=48.02 (58.53)
|
| 316 |
+
2025-03-03 19:16:05 | INFO | utils.misc:108 - Training: Epoch=[6/50] [ 600/1407] Batch=1.45 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.5675 (0.5104) IoU=47.12 (54.51) Prec@50=49.41 (58.52)
|
| 317 |
+
2025-03-03 19:18:15 | INFO | utils.misc:108 - Training: Epoch=[6/50] [ 700/1407] Batch=1.13 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.6829 (0.5110) IoU=48.30 (54.50) Prec@50=54.37 (58.45)
|
| 318 |
+
2025-03-03 19:20:25 | INFO | utils.misc:108 - Training: Epoch=[6/50] [ 800/1407] Batch=1.24 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.5415 (0.5114) IoU=49.64 (54.42) Prec@50=50.30 (58.29)
|
| 319 |
+
2025-03-03 19:22:34 | INFO | utils.misc:108 - Training: Epoch=[6/50] [ 900/1407] Batch=1.28 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.4806 (0.5120) IoU=58.17 (54.46) Prec@50=57.84 (58.30)
|
| 320 |
+
2025-03-03 19:24:45 | INFO | utils.misc:108 - Training: Epoch=[6/50] [1000/1407] Batch=1.24 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.5476 (0.5130) IoU=46.38 (54.36) Prec@50=48.41 (58.26)
|
| 321 |
+
2025-03-03 19:26:54 | INFO | utils.misc:108 - Training: Epoch=[6/50] [1100/1407] Batch=1.13 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.4327 (0.5143) IoU=61.53 (54.18) Prec@50=70.24 (58.03)
|
| 322 |
+
2025-03-03 19:29:04 | INFO | utils.misc:108 - Training: Epoch=[6/50] [1200/1407] Batch=1.27 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.4774 (0.5143) IoU=55.10 (54.15) Prec@50=55.14 (58.05)
|
| 323 |
+
2025-03-03 19:31:12 | INFO | utils.misc:108 - Training: Epoch=[6/50] [1300/1407] Batch=1.40 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.5230 (0.5145) IoU=51.25 (54.14) Prec@50=52.98 (58.04)
|
| 324 |
+
2025-03-03 19:33:21 | INFO | utils.misc:108 - Training: Epoch=[6/50] [1400/1407] Batch=1.39 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.6039 (0.5156) IoU=48.91 (54.14) Prec@50=46.36 (57.95)
|
| 325 |
+
2025-03-03 19:34:05 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[6/50] mIoU=60.83 oIoU=57.66 Pr@50: 67.29 Pr@60: 60.37 Pr@70: 51.44 Pr@80: 39.70 Pr@90: 18.07
|
| 326 |
+
2025-03-03 19:36:44 | INFO | utils.misc:108 - Training: Epoch=[7/50] [ 100/1407] Batch=1.18 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.3694 (0.4535) IoU=64.46 (57.97) Prec@50=69.13 (63.46)
|
| 327 |
+
2025-03-03 19:38:55 | INFO | utils.misc:108 - Training: Epoch=[7/50] [ 200/1407] Batch=1.22 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4216 (0.4611) IoU=56.11 (57.46) Prec@50=58.08 (62.46)
|
| 328 |
+
2025-03-03 19:41:04 | INFO | utils.misc:108 - Training: Epoch=[7/50] [ 300/1407] Batch=1.18 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.3782 (0.4635) IoU=61.55 (57.46) Prec@50=69.13 (62.34)
|
| 329 |
+
2025-03-03 19:43:16 | INFO | utils.misc:108 - Training: Epoch=[7/50] [ 400/1407] Batch=1.22 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4767 (0.4658) IoU=57.43 (57.24) Prec@50=67.86 (61.95)
|
| 330 |
+
2025-03-03 19:45:27 | INFO | utils.misc:108 - Training: Epoch=[7/50] [ 500/1407] Batch=1.17 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4054 (0.4673) IoU=62.52 (57.11) Prec@50=69.68 (61.74)
|
| 331 |
+
2025-03-03 19:47:38 | INFO | utils.misc:108 - Training: Epoch=[7/50] [ 600/1407] Batch=1.23 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.3967 (0.4709) IoU=63.16 (56.86) Prec@50=71.21 (61.39)
|
| 332 |
+
2025-03-03 19:49:48 | INFO | utils.misc:108 - Training: Epoch=[7/50] [ 700/1407] Batch=1.22 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4413 (0.4737) IoU=62.03 (56.70) Prec@50=80.79 (61.12)
|
| 333 |
+
2025-03-03 19:51:58 | INFO | utils.misc:108 - Training: Epoch=[7/50] [ 800/1407] Batch=1.22 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4080 (0.4749) IoU=62.66 (56.63) Prec@50=68.57 (61.05)
|
| 334 |
+
2025-03-03 19:54:08 | INFO | utils.misc:108 - Training: Epoch=[7/50] [ 900/1407] Batch=1.15 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4112 (0.4776) IoU=60.53 (56.50) Prec@50=67.70 (60.89)
|
| 335 |
+
2025-03-03 19:56:20 | INFO | utils.misc:108 - Training: Epoch=[7/50] [1000/1407] Batch=1.39 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4987 (0.4784) IoU=61.67 (56.37) Prec@50=65.00 (60.70)
|
| 336 |
+
2025-03-03 19:58:31 | INFO | utils.misc:108 - Training: Epoch=[7/50] [1100/1407] Batch=1.27 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.3991 (0.4795) IoU=56.89 (56.34) Prec@50=59.13 (60.63)
|
| 337 |
+
2025-03-03 20:00:42 | INFO | utils.misc:108 - Training: Epoch=[7/50] [1200/1407] Batch=1.59 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.3944 (0.4802) IoU=61.46 (56.35) Prec@50=64.31 (60.58)
|
| 338 |
+
2025-03-03 20:02:55 | INFO | utils.misc:108 - Training: Epoch=[7/50] [1300/1407] Batch=1.24 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.6520 (0.4818) IoU=52.79 (56.28) Prec@50=57.54 (60.48)
|
| 339 |
+
2025-03-03 20:05:06 | INFO | utils.misc:108 - Training: Epoch=[7/50] [1400/1407] Batch=1.35 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.5381 (0.4822) IoU=54.06 (56.23) Prec@50=57.18 (60.46)
|
| 340 |
+
2025-03-03 20:05:51 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[7/50] mIoU=62.05 oIoU=58.56 Pr@50: 69.35 Pr@60: 61.85 Pr@70: 53.96 Pr@80: 40.87 Pr@90: 19.00
|
| 341 |
+
2025-03-03 20:08:27 | INFO | utils.misc:108 - Training: Epoch=[8/50] [ 100/1407] Batch=1.26 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.3871 (0.4439) IoU=57.42 (57.00) Prec@50=60.81 (62.43)
|
| 342 |
+
2025-03-03 20:10:38 | INFO | utils.misc:108 - Training: Epoch=[8/50] [ 200/1407] Batch=1.12 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4502 (0.4448) IoU=59.06 (57.36) Prec@50=60.71 (62.74)
|
| 343 |
+
2025-03-03 20:12:49 | INFO | utils.misc:108 - Training: Epoch=[8/50] [ 300/1407] Batch=1.13 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4527 (0.4435) IoU=61.08 (57.48) Prec@50=67.06 (63.02)
|
| 344 |
+
2025-03-03 20:15:02 | INFO | utils.misc:108 - Training: Epoch=[8/50] [ 400/1407] Batch=1.25 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4179 (0.4455) IoU=59.71 (57.75) Prec@50=65.75 (63.11)
|
| 345 |
+
2025-03-03 20:17:13 | INFO | utils.misc:108 - Training: Epoch=[8/50] [ 500/1407] Batch=1.26 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4548 (0.4472) IoU=59.80 (57.79) Prec@50=66.67 (63.14)
|
| 346 |
+
2025-03-03 20:19:26 | INFO | utils.misc:108 - Training: Epoch=[8/50] [ 600/1407] Batch=1.25 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.4598 (0.4486) IoU=49.73 (57.75) Prec@50=53.57 (62.92)
|
| 347 |
+
2025-03-03 20:21:38 | INFO | utils.misc:108 - Training: Epoch=[8/50] [ 700/1407] Batch=1.24 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.4703 (0.4484) IoU=56.49 (57.69) Prec@50=60.24 (62.81)
|
| 348 |
+
2025-03-03 20:23:48 | INFO | utils.misc:108 - Training: Epoch=[8/50] [ 800/1407] Batch=1.16 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.5171 (0.4514) IoU=57.90 (57.64) Prec@50=64.52 (62.68)
|
| 349 |
+
2025-03-03 20:25:58 | INFO | utils.misc:108 - Training: Epoch=[8/50] [ 900/1407] Batch=1.22 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.5224 (0.4527) IoU=59.30 (57.70) Prec@50=61.31 (62.74)
|
| 350 |
+
2025-03-03 20:28:09 | INFO | utils.misc:108 - Training: Epoch=[8/50] [1000/1407] Batch=1.22 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.5577 (0.4544) IoU=52.09 (57.60) Prec@50=56.67 (62.56)
|
| 351 |
+
2025-03-03 20:30:22 | INFO | utils.misc:108 - Training: Epoch=[8/50] [1100/1407] Batch=1.23 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4545 (0.4550) IoU=56.95 (57.56) Prec@50=62.40 (62.53)
|
| 352 |
+
2025-03-03 20:32:33 | INFO | utils.misc:108 - Training: Epoch=[8/50] [1200/1407] Batch=1.23 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.3771 (0.4548) IoU=61.85 (57.58) Prec@50=72.30 (62.52)
|
| 353 |
+
2025-03-03 20:34:45 | INFO | utils.misc:108 - Training: Epoch=[8/50] [1300/1407] Batch=1.32 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4313 (0.4551) IoU=60.03 (57.54) Prec@50=65.79 (62.46)
|
| 354 |
+
2025-03-03 20:36:57 | INFO | utils.misc:108 - Training: Epoch=[8/50] [1400/1407] Batch=1.21 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.5147 (0.4564) IoU=53.92 (57.49) Prec@50=52.18 (62.34)
|
| 355 |
+
2025-03-03 20:37:41 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[8/50] mIoU=62.41 oIoU=58.68 Pr@50: 69.23 Pr@60: 61.46 Pr@70: 53.65 Pr@80: 41.69 Pr@90: 18.92
|
| 356 |
+
2025-03-03 20:40:22 | INFO | utils.misc:108 - Training: Epoch=[9/50] [ 100/1407] Batch=1.15 (1.32) Data=0.01 (0.03) Lr=0.000100 Loss=0.3586 (0.4313) IoU=70.73 (59.54) Prec@50=83.97 (64.88)
|
| 357 |
+
2025-03-03 20:42:34 | INFO | utils.misc:108 - Training: Epoch=[9/50] [ 200/1407] Batch=1.17 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.4021 (0.4262) IoU=58.45 (59.39) Prec@50=69.84 (64.84)
|
| 358 |
+
2025-03-03 20:44:44 | INFO | utils.misc:108 - Training: Epoch=[9/50] [ 300/1407] Batch=1.21 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.5048 (0.4308) IoU=62.54 (59.46) Prec@50=73.25 (65.16)
|
| 359 |
+
2025-03-03 20:46:54 | INFO | utils.misc:108 - Training: Epoch=[9/50] [ 400/1407] Batch=1.13 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.5139 (0.4306) IoU=57.54 (59.52) Prec@50=63.33 (65.13)
|
| 360 |
+
2025-03-03 20:49:05 | INFO | utils.misc:108 - Training: Epoch=[9/50] [ 500/1407] Batch=1.14 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.3949 (0.4328) IoU=64.36 (59.33) Prec@50=70.08 (64.82)
|
| 361 |
+
2025-03-03 20:51:16 | INFO | utils.misc:108 - Training: Epoch=[9/50] [ 600/1407] Batch=1.26 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4611 (0.4326) IoU=57.47 (59.35) Prec@50=60.89 (64.74)
|
| 362 |
+
2025-03-03 20:53:26 | INFO | utils.misc:108 - Training: Epoch=[9/50] [ 700/1407] Batch=1.17 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4755 (0.4334) IoU=56.52 (59.25) Prec@50=56.35 (64.61)
|
| 363 |
+
2025-03-03 20:55:37 | INFO | utils.misc:108 - Training: Epoch=[9/50] [ 800/1407] Batch=1.29 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4796 (0.4337) IoU=60.29 (59.28) Prec@50=68.59 (64.69)
|
| 364 |
+
2025-03-03 20:57:47 | INFO | utils.misc:108 - Training: Epoch=[9/50] [ 900/1407] Batch=1.24 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4704 (0.4349) IoU=52.80 (59.28) Prec@50=63.89 (64.72)
|
| 365 |
+
2025-03-03 20:59:59 | INFO | utils.misc:108 - Training: Epoch=[9/50] [1000/1407] Batch=1.22 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.3763 (0.4347) IoU=65.34 (59.13) Prec@50=72.64 (64.62)
|
| 366 |
+
2025-03-03 21:02:11 | INFO | utils.misc:108 - Training: Epoch=[9/50] [1100/1407] Batch=1.42 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4699 (0.4350) IoU=59.93 (59.02) Prec@50=65.93 (64.56)
|
| 367 |
+
2025-03-03 21:04:23 | INFO | utils.misc:108 - Training: Epoch=[9/50] [1200/1407] Batch=1.21 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.3397 (0.4357) IoU=65.16 (58.84) Prec@50=72.56 (64.37)
|
| 368 |
+
2025-03-03 21:06:34 | INFO | utils.misc:108 - Training: Epoch=[9/50] [1300/1407] Batch=1.26 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.6262 (0.4363) IoU=48.86 (58.78) Prec@50=47.62 (64.23)
|
| 369 |
+
2025-03-03 21:08:46 | INFO | utils.misc:108 - Training: Epoch=[9/50] [1400/1407] Batch=1.35 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4082 (0.4374) IoU=63.09 (58.70) Prec@50=74.50 (64.19)
|
| 370 |
+
2025-03-03 21:09:29 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[9/50] mIoU=62.25 oIoU=58.06 Pr@50: 69.23 Pr@60: 63.05 Pr@70: 54.78 Pr@80: 42.58 Pr@90: 20.16
|
| 371 |
+
2025-03-03 21:11:56 | INFO | utils.misc:108 - Training: Epoch=[10/50] [ 100/1407] Batch=1.28 (1.33) Data=0.00 (0.03) Lr=0.000100 Loss=0.3925 (0.4111) IoU=55.59 (59.60) Prec@50=61.05 (66.27)
|
| 372 |
+
2025-03-03 21:14:06 | INFO | utils.misc:108 - Training: Epoch=[10/50] [ 200/1407] Batch=1.19 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.4349 (0.4082) IoU=58.24 (59.65) Prec@50=63.89 (66.06)
|
| 373 |
+
2025-03-03 21:16:18 | INFO | utils.misc:108 - Training: Epoch=[10/50] [ 300/1407] Batch=1.72 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.4199 (0.4066) IoU=59.00 (59.75) Prec@50=65.48 (65.97)
|
| 374 |
+
2025-03-03 21:18:31 | INFO | utils.misc:108 - Training: Epoch=[10/50] [ 400/1407] Batch=1.18 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.3139 (0.4110) IoU=64.88 (59.49) Prec@50=76.98 (65.56)
|
| 375 |
+
2025-03-03 21:20:42 | INFO | utils.misc:108 - Training: Epoch=[10/50] [ 500/1407] Batch=1.24 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.4145 (0.4128) IoU=58.26 (59.39) Prec@50=63.10 (65.38)
|
| 376 |
+
2025-03-03 21:22:55 | INFO | utils.misc:108 - Training: Epoch=[10/50] [ 600/1407] Batch=1.47 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.4332 (0.4123) IoU=55.08 (59.26) Prec@50=59.82 (65.29)
|
| 377 |
+
2025-03-03 21:25:06 | INFO | utils.misc:108 - Training: Epoch=[10/50] [ 700/1407] Batch=1.37 (1.32) Data=0.01 (0.03) Lr=0.000100 Loss=0.5174 (0.4145) IoU=49.25 (59.16) Prec@50=51.03 (65.12)
|
| 378 |
+
2025-03-03 21:27:17 | INFO | utils.misc:108 - Training: Epoch=[10/50] [ 800/1407] Batch=1.13 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.4140 (0.4154) IoU=55.88 (59.11) Prec@50=62.70 (65.09)
|
| 379 |
+
2025-03-03 21:29:28 | INFO | utils.misc:108 - Training: Epoch=[10/50] [ 900/1407] Batch=1.36 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.4532 (0.4157) IoU=56.31 (59.08) Prec@50=59.76 (65.11)
|
| 380 |
+
2025-03-03 21:31:37 | INFO | utils.misc:108 - Training: Epoch=[10/50] [1000/1407] Batch=1.22 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4510 (0.4167) IoU=54.24 (59.11) Prec@50=59.07 (65.15)
|
| 381 |
+
2025-03-03 21:33:46 | INFO | utils.misc:108 - Training: Epoch=[10/50] [1100/1407] Batch=1.25 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4805 (0.4178) IoU=54.15 (59.11) Prec@50=53.57 (65.15)
|
| 382 |
+
2025-03-03 21:35:55 | INFO | utils.misc:108 - Training: Epoch=[10/50] [1200/1407] Batch=1.17 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4031 (0.4190) IoU=61.14 (59.09) Prec@50=66.90 (65.10)
|
| 383 |
+
2025-03-03 21:38:05 | INFO | utils.misc:108 - Training: Epoch=[10/50] [1300/1407] Batch=1.38 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4519 (0.4200) IoU=53.42 (59.04) Prec@50=58.10 (65.00)
|
| 384 |
+
2025-03-03 21:40:16 | INFO | utils.misc:108 - Training: Epoch=[10/50] [1400/1407] Batch=1.15 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4071 (0.4211) IoU=64.93 (59.02) Prec@50=74.13 (64.97)
|
| 385 |
+
2025-03-03 21:40:58 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[10/50] mIoU=62.92 oIoU=59.93 Pr@50: 70.59 Pr@60: 64.65 Pr@70: 55.67 Pr@80: 43.43 Pr@90: 20.32
|
| 386 |
+
2025-03-03 21:43:38 | INFO | utils.misc:108 - Training: Epoch=[11/50] [ 100/1407] Batch=1.28 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.4913 (0.3952) IoU=56.25 (61.12) Prec@50=59.94 (67.24)
|
| 387 |
+
2025-03-03 21:45:49 | INFO | utils.misc:108 - Training: Epoch=[11/50] [ 200/1407] Batch=1.39 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4227 (0.3953) IoU=61.48 (60.92) Prec@50=66.01 (66.88)
|
| 388 |
+
2025-03-03 21:47:57 | INFO | utils.misc:108 - Training: Epoch=[11/50] [ 300/1407] Batch=1.24 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.2995 (0.3939) IoU=66.41 (61.32) Prec@50=76.88 (67.40)
|
| 389 |
+
2025-03-03 21:50:08 | INFO | utils.misc:108 - Training: Epoch=[11/50] [ 400/1407] Batch=1.29 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.3530 (0.3954) IoU=63.13 (61.36) Prec@50=66.90 (67.54)
|
| 390 |
+
2025-03-03 21:52:18 | INFO | utils.misc:108 - Training: Epoch=[11/50] [ 500/1407] Batch=1.28 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.3185 (0.3975) IoU=61.71 (61.22) Prec@50=66.07 (67.42)
|
| 391 |
+
2025-03-03 21:54:28 | INFO | utils.misc:108 - Training: Epoch=[11/50] [ 600/1407] Batch=1.22 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.3977 (0.3992) IoU=57.77 (60.99) Prec@50=64.72 (67.13)
|
| 392 |
+
2025-03-03 21:56:38 | INFO | utils.misc:108 - Training: Epoch=[11/50] [ 700/1407] Batch=1.23 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.4287 (0.3990) IoU=58.03 (60.86) Prec@50=61.11 (67.02)
|
| 393 |
+
2025-03-03 21:58:50 | INFO | utils.misc:108 - Training: Epoch=[11/50] [ 800/1407] Batch=1.24 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.4175 (0.3999) IoU=62.48 (60.75) Prec@50=69.74 (66.85)
|
| 394 |
+
2025-03-03 22:01:01 | INFO | utils.misc:108 - Training: Epoch=[11/50] [ 900/1407] Batch=1.27 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4177 (0.4006) IoU=56.65 (60.83) Prec@50=60.22 (66.99)
|
| 395 |
+
2025-03-03 22:03:14 | INFO | utils.misc:108 - Training: Epoch=[11/50] [1000/1407] Batch=1.34 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.3733 (0.4012) IoU=64.66 (60.94) Prec@50=75.50 (67.11)
|
| 396 |
+
2025-03-03 22:05:25 | INFO | utils.misc:108 - Training: Epoch=[11/50] [1100/1407] Batch=1.16 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.3426 (0.4011) IoU=72.35 (61.04) Prec@50=86.75 (67.21)
|
| 397 |
+
2025-03-03 22:07:36 | INFO | utils.misc:108 - Training: Epoch=[11/50] [1200/1407] Batch=1.22 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.3453 (0.4008) IoU=65.79 (61.17) Prec@50=76.83 (67.42)
|
| 398 |
+
2025-03-03 22:09:48 | INFO | utils.misc:108 - Training: Epoch=[11/50] [1300/1407] Batch=1.64 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4473 (0.4011) IoU=59.89 (61.20) Prec@50=73.39 (67.39)
|
| 399 |
+
2025-03-03 22:12:00 | INFO | utils.misc:108 - Training: Epoch=[11/50] [1400/1407] Batch=1.23 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4229 (0.4018) IoU=57.42 (61.18) Prec@50=66.96 (67.32)
|
| 400 |
+
2025-03-03 22:12:46 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[11/50] mIoU=63.56 oIoU=60.39 Pr@50: 71.37 Pr@60: 65.11 Pr@70: 57.58 Pr@80: 44.68 Pr@90: 20.90
|
| 401 |
+
2025-03-03 22:15:27 | INFO | utils.misc:108 - Training: Epoch=[12/50] [ 100/1407] Batch=1.32 (1.35) Data=0.00 (0.04) Lr=0.000100 Loss=0.3085 (0.3645) IoU=65.95 (62.33) Prec@50=78.01 (69.10)
|
| 402 |
+
2025-03-03 22:17:39 | INFO | utils.misc:108 - Training: Epoch=[12/50] [ 200/1407] Batch=1.66 (1.33) Data=0.00 (0.03) Lr=0.000100 Loss=0.2717 (0.3668) IoU=68.88 (62.85) Prec@50=78.70 (69.82)
|
| 403 |
+
2025-03-03 22:19:51 | INFO | utils.misc:108 - Training: Epoch=[12/50] [ 300/1407] Batch=1.14 (1.33) Data=0.00 (0.03) Lr=0.000100 Loss=0.3265 (0.3708) IoU=62.95 (62.78) Prec@50=68.89 (69.53)
|
| 404 |
+
2025-03-03 22:22:04 | INFO | utils.misc:108 - Training: Epoch=[12/50] [ 400/1407] Batch=1.13 (1.33) Data=0.00 (0.03) Lr=0.000100 Loss=0.3058 (0.3704) IoU=68.49 (62.68) Prec@50=73.65 (69.41)
|
| 405 |
+
2025-03-03 22:24:17 | INFO | utils.misc:108 - Training: Epoch=[12/50] [ 500/1407] Batch=1.25 (1.33) Data=0.00 (0.03) Lr=0.000100 Loss=0.3924 (0.3704) IoU=69.36 (62.64) Prec@50=85.52 (69.33)
|
| 406 |
+
2025-03-03 22:26:26 | INFO | utils.misc:108 - Training: Epoch=[12/50] [ 600/1407] Batch=1.16 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.2908 (0.3704) IoU=66.72 (62.64) Prec@50=72.06 (69.23)
|
| 407 |
+
2025-03-03 22:28:38 | INFO | utils.misc:108 - Training: Epoch=[12/50] [ 700/1407] Batch=1.30 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.4189 (0.3716) IoU=54.04 (62.63) Prec@50=55.06 (69.14)
|
| 408 |
+
2025-03-03 22:30:49 | INFO | utils.misc:108 - Training: Epoch=[12/50] [ 800/1407] Batch=1.26 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.2783 (0.3722) IoU=71.81 (62.64) Prec@50=80.36 (69.16)
|
| 409 |
+
2025-03-03 22:32:58 | INFO | utils.misc:108 - Training: Epoch=[12/50] [ 900/1407] Batch=1.23 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.3591 (0.3739) IoU=69.03 (62.56) Prec@50=79.31 (69.07)
|
| 410 |
+
[2025-03-03 22:34:38,633] torch.distributed.elastic.agent.server.api: [WARNING] Received Signals.SIGINT death signal, shutting down workers
|
| 411 |
+
[2025-03-03 22:34:38,634] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2316571 closing signal SIGINT
|
| 412 |
+
[2025-03-03 22:34:38,636] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2316572 closing signal SIGINT
|
| 413 |
+
[2025-03-03 22:34:38,637] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2316573 closing signal SIGINT
|
| 414 |
+
[2025-03-03 22:34:38,637] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2316574 closing signal SIGINT
|
| 415 |
+
[2025-03-03 22:34:38,637] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2316575 closing signal SIGINT
|
| 416 |
+
[2025-03-03 22:34:38,637] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2316576 closing signal SIGINT
|
| 417 |
+
Exception ignored in: [2025-03-03 22:34:38,812] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2316571 closing signal SIGTERM
|
| 418 |
+
<function _MultiProcessingDataLoaderIter.__del__ at 0x7fcbc8afc8b0>
|
| 419 |
+
Traceback (most recent call last):
|
| 420 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
|
| 421 |
+
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fe97c7bf8b0>
|
| 422 |
+
Traceback (most recent call last):
|
| 423 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
|
| 424 |
+
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7e575df8b0>
|
| 425 |
+
Traceback (most recent call last):
|
| 426 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
|
| 427 |
+
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f0f473318b0>
|
| 428 |
+
Traceback (most recent call last):
|
| 429 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
|
| 430 |
+
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f29013bb8b0>
|
| 431 |
+
Traceback (most recent call last):
|
| 432 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
|
| 433 |
+
[2025-03-03 22:34:38,864] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2316572 closing signal SIGTERM
|
| 434 |
+
[2025-03-03 22:34:38,864] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2316573 closing signal SIGTERM
|
| 435 |
+
[2025-03-03 22:34:38,864] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2316574 closing signal SIGTERM
|
| 436 |
+
[2025-03-03 22:34:38,864] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2316575 closing signal SIGTERM
|
| 437 |
+
[2025-03-03 22:34:38,864] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2316576 closing signal SIGTERM
|
| 438 |
+
Traceback (most recent call last):
|
| 439 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/agent/server/api.py", line 736, in run
|
| 440 |
+
result = self._invoke_run(role)
|
| 441 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/agent/server/api.py", line 877, in _invoke_run
|
| 442 |
+
time.sleep(monitor_interval)
|
| 443 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 62, in _terminate_process_handler
|
| 444 |
+
raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
|
| 445 |
+
torch.distributed.elastic.multiprocessing.api.SignalException: Process 2316558 got signal: 2
|
| 446 |
+
|
| 447 |
+
During handling of the above exception, another exception occurred:
|
| 448 |
+
|
| 449 |
+
Traceback (most recent call last):
|
| 450 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/agent/server/api.py", line 743, in run
|
| 451 |
+
self._shutdown(e.sigval)
|
| 452 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/agent/server/local_elastic_agent.py", line 289, in _shutdown
|
| 453 |
+
self._pcontext.close(death_sig)
|
| 454 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 331, in close
|
| 455 |
+
self._close(death_sig=death_sig, timeout=timeout)
|
| 456 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 713, in _close
|
| 457 |
+
handler.proc.wait(time_to_wait)
|
| 458 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/subprocess.py", line 1189, in wait
|
| 459 |
+
return self._wait(timeout=timeout)
|
| 460 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/subprocess.py", line 1927, in _wait
|
| 461 |
+
time.sleep(delay)
|
| 462 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 62, in _terminate_process_handler
|
| 463 |
+
raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
|
| 464 |
+
torch.distributed.elastic.multiprocessing.api.SignalException: Process 2316558 got signal: 2
|
| 465 |
+
|
| 466 |
+
During handling of the above exception, another exception occurred:
|
| 467 |
+
|
| 468 |
+
Traceback (most recent call last):
|
| 469 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 255, in launch_agent
|
| 470 |
+
result = agent.run()
|
| 471 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/metrics/api.py", line 124, in wrapper
|
| 472 |
+
result = f(*args, **kwargs)
|
| 473 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/agent/server/api.py", line 748, in run
|
| 474 |
+
self._shutdown()
|
| 475 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/agent/server/local_elastic_agent.py", line 289, in _shutdown
|
| 476 |
+
self._pcontext.close(death_sig)
|
| 477 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 331, in close
|
| 478 |
+
self._close(death_sig=death_sig, timeout=timeout)
|
| 479 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 713, in _close
|
| 480 |
+
handler.proc.wait(time_to_wait)
|
| 481 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/subprocess.py", line 1189, in wait
|
| 482 |
+
return self._wait(timeout=timeout)
|
| 483 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/subprocess.py", line 1927, in _wait
|
| 484 |
+
time.sleep(delay)
|
| 485 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 62, in _terminate_process_handler
|
| 486 |
+
raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
|
| 487 |
+
torch.distributed.elastic.multiprocessing.api.SignalException: Process 2316558 got signal: 2
|
| 488 |
+
|
| 489 |
+
During handling of the above exception, another exception occurred:
|
| 490 |
+
|
| 491 |
+
Traceback (most recent call last):
|
| 492 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/runpy.py", line 197, in _run_module_as_main
|
| 493 |
+
return _run_code(code, main_globals, None,
|
| 494 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/runpy.py", line 87, in _run_code
|
| 495 |
+
exec(code, run_globals)
|
| 496 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/launch.py", line 196, in <module>
|
| 497 |
+
main()
|
| 498 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/launch.py", line 192, in main
|
| 499 |
+
launch(args)
|
| 500 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/launch.py", line 177, in launch
|
| 501 |
+
run(args)
|
| 502 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/run.py", line 797, in run
|
| 503 |
+
elastic_launch(
|
| 504 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
|
| 505 |
+
return launch_agent(self._config, self._entrypoint, list(args))
|
| 506 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 277, in launch_agent
|
| 507 |
+
events.record(agent.get_event_failed())
|
| 508 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/agent/server/api.py", line 756, in get_event_failed
|
| 509 |
+
raw_error=traceback.format_exc(),
|
| 510 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/traceback.py", line 167, in format_exc
|
| 511 |
+
return "".join(format_exception(*sys.exc_info(), limit=limit, chain=chain))
|
| 512 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/traceback.py", line 120, in format_exception
|
| 513 |
+
return list(TracebackException(
|
| 514 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/traceback.py", line 517, in __init__
|
| 515 |
+
self.stack = StackSummary.extract(
|
| 516 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/traceback.py", line 366, in extract
|
| 517 |
+
f.line
|
| 518 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/traceback.py", line 288, in line
|
| 519 |
+
self._line = linecache.getline(self.filename, self.lineno).strip()
|
| 520 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/linecache.py", line 30, in getline
|
| 521 |
+
lines = getlines(filename, module_globals)
|
| 522 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/linecache.py", line 46, in getlines
|
| 523 |
+
return updatecache(filename, module_globals)
|
| 524 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/linecache.py", line 93, in updatecache
|
| 525 |
+
stat = os.stat(fullname)
|
| 526 |
+
File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 62, in _terminate_process_handler
|
| 527 |
+
raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
|
| 528 |
+
torch.distributed.elastic.multiprocessing.api.SignalException: Process 2316558 got signal: 2
|
CGFormer/bash_logs/sanity_node03.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
CGFormer/bert/__pycache__/activations.cpython-38.pyc
ADDED
|
Binary file (1.95 kB). View file
|
|
|
CGFormer/bert/__pycache__/activations.cpython-39.pyc
ADDED
|
Binary file (1.94 kB). View file
|
|
|
CGFormer/bert/__pycache__/configuration_bert.cpython-38.pyc
ADDED
|
Binary file (7.88 kB). View file
|
|
|
CGFormer/bert/__pycache__/configuration_bert.cpython-39.pyc
ADDED
|
Binary file (7.88 kB). View file
|
|
|
CGFormer/bert/__pycache__/configuration_utils.cpython-38.pyc
ADDED
|
Binary file (16.3 kB). View file
|
|
|
CGFormer/bert/__pycache__/configuration_utils.cpython-39.pyc
ADDED
|
Binary file (16.3 kB). View file
|
|
|
CGFormer/bert/__pycache__/file_utils.cpython-38.pyc
ADDED
|
Binary file (24.5 kB). View file
|
|
|
CGFormer/bert/__pycache__/file_utils.cpython-39.pyc
ADDED
|
Binary file (24.7 kB). View file
|
|
|
CGFormer/bert/__pycache__/generation_utils.cpython-38.pyc
ADDED
|
Binary file (28.2 kB). View file
|
|
|
CGFormer/bert/__pycache__/generation_utils.cpython-39.pyc
ADDED
|
Binary file (28 kB). View file
|
|
|
CGFormer/bert/__pycache__/modeling_bert.cpython-38.pyc
ADDED
|
Binary file (55.3 kB). View file
|
|
|
CGFormer/bert/__pycache__/modeling_bert.cpython-39.pyc
ADDED
|
Binary file (55.2 kB). View file
|
|
|
CGFormer/bert/__pycache__/modeling_utils.cpython-38.pyc
ADDED
|
Binary file (48 kB). View file
|
|
|
CGFormer/bert/__pycache__/modeling_utils.cpython-39.pyc
ADDED
|
Binary file (48 kB). View file
|
|
|
CGFormer/bert/__pycache__/tokenization_bert.cpython-38.pyc
ADDED
|
Binary file (19.3 kB). View file
|
|
|
CGFormer/bert/__pycache__/tokenization_bert.cpython-39.pyc
ADDED
|
Binary file (19.3 kB). View file
|
|
|
CGFormer/bert/__pycache__/tokenization_utils.cpython-38.pyc
ADDED
|
Binary file (24.9 kB). View file
|
|
|
CGFormer/bert/__pycache__/tokenization_utils.cpython-39.pyc
ADDED
|
Binary file (24.9 kB). View file
|
|
|
CGFormer/bert/__pycache__/tokenization_utils_base.cpython-38.pyc
ADDED
|
Binary file (82.4 kB). View file
|
|
|
CGFormer/bert/__pycache__/tokenization_utils_base.cpython-39.pyc
ADDED
|
Binary file (82.4 kB). View file
|
|
|
CGFormer/bert/activations.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import math
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def swish(x):
|
| 12 |
+
return x * torch.sigmoid(x)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def _gelu_python(x):
|
| 16 |
+
""" Original Implementation of the gelu activation function in Google Bert repo when initially created.
|
| 17 |
+
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
|
| 18 |
+
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
| 19 |
+
This is now written in C in torch.nn.functional
|
| 20 |
+
Also see https://arxiv.org/abs/1606.08415
|
| 21 |
+
"""
|
| 22 |
+
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def gelu_new(x):
|
| 26 |
+
""" Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
|
| 27 |
+
Also see https://arxiv.org/abs/1606.08415
|
| 28 |
+
"""
|
| 29 |
+
return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
if torch.__version__ < "1.4.0":
|
| 33 |
+
gelu = _gelu_python
|
| 34 |
+
else:
|
| 35 |
+
gelu = F.gelu
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def gelu_fast(x):
|
| 39 |
+
return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
ACT2FN = {
|
| 43 |
+
"relu": F.relu,
|
| 44 |
+
"swish": swish,
|
| 45 |
+
"gelu": gelu,
|
| 46 |
+
"tanh": torch.tanh,
|
| 47 |
+
"gelu_new": gelu_new,
|
| 48 |
+
"gelu_fast": gelu_fast,
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def get_activation(activation_string):
|
| 53 |
+
if activation_string in ACT2FN:
|
| 54 |
+
return ACT2FN[activation_string]
|
| 55 |
+
else:
|
| 56 |
+
raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, list(ACT2FN.keys())))
|
CGFormer/bert/configuration_bert.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
| 3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
""" BERT model configuration """
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
import logging
|
| 20 |
+
|
| 21 |
+
from .configuration_utils import PretrainedConfig
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger(__name__)
|
| 25 |
+
|
| 26 |
+
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
| 27 |
+
"bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
|
| 28 |
+
"bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
|
| 29 |
+
"bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
|
| 30 |
+
"bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
|
| 31 |
+
"bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
|
| 32 |
+
"bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
|
| 33 |
+
"bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
|
| 34 |
+
"bert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
|
| 35 |
+
"bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
|
| 36 |
+
"bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
|
| 37 |
+
"bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json",
|
| 38 |
+
"bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
|
| 39 |
+
"bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
|
| 40 |
+
"bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json",
|
| 41 |
+
"bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json",
|
| 42 |
+
"cl-tohoku/bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese/config.json",
|
| 43 |
+
"cl-tohoku/bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking/config.json",
|
| 44 |
+
"cl-tohoku/bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char/config.json",
|
| 45 |
+
"cl-tohoku/bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking/config.json",
|
| 46 |
+
"TurkuNLP/bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/config.json",
|
| 47 |
+
"TurkuNLP/bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/config.json",
|
| 48 |
+
"wietsedv/bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/config.json",
|
| 49 |
+
# See all BERT models at https://huggingface.co/models?filter=bert
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class BertConfig(PretrainedConfig):
|
| 54 |
+
r"""
|
| 55 |
+
This is the configuration class to store the configuration of a :class:`~transformers.BertModel`.
|
| 56 |
+
It is used to instantiate an BERT model according to the specified arguments, defining the model
|
| 57 |
+
architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
|
| 58 |
+
the BERT `bert-base-uncased <https://huggingface.co/bert-base-uncased>`__ architecture.
|
| 59 |
+
|
| 60 |
+
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used
|
| 61 |
+
to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig`
|
| 62 |
+
for more information.
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
vocab_size (:obj:`int`, optional, defaults to 30522):
|
| 67 |
+
Vocabulary size of the BERT model. Defines the different tokens that
|
| 68 |
+
can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.BertModel`.
|
| 69 |
+
hidden_size (:obj:`int`, optional, defaults to 768):
|
| 70 |
+
Dimensionality of the encoder layers and the pooler layer.
|
| 71 |
+
num_hidden_layers (:obj:`int`, optional, defaults to 12):
|
| 72 |
+
Number of hidden layers in the Transformer encoder.
|
| 73 |
+
num_attention_heads (:obj:`int`, optional, defaults to 12):
|
| 74 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 75 |
+
intermediate_size (:obj:`int`, optional, defaults to 3072):
|
| 76 |
+
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
| 77 |
+
hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "gelu"):
|
| 78 |
+
The non-linear activation function (function or string) in the encoder and pooler.
|
| 79 |
+
If string, "gelu", "relu", "swish" and "gelu_new" are supported.
|
| 80 |
+
hidden_dropout_prob (:obj:`float`, optional, defaults to 0.1):
|
| 81 |
+
The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
|
| 82 |
+
attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1):
|
| 83 |
+
The dropout ratio for the attention probabilities.
|
| 84 |
+
max_position_embeddings (:obj:`int`, optional, defaults to 512):
|
| 85 |
+
The maximum sequence length that this model might ever be used with.
|
| 86 |
+
Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
|
| 87 |
+
type_vocab_size (:obj:`int`, optional, defaults to 2):
|
| 88 |
+
The vocabulary size of the `token_type_ids` passed into :class:`~transformers.BertModel`.
|
| 89 |
+
initializer_range (:obj:`float`, optional, defaults to 0.02):
|
| 90 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 91 |
+
layer_norm_eps (:obj:`float`, optional, defaults to 1e-12):
|
| 92 |
+
The epsilon used by the layer normalization layers.
|
| 93 |
+
gradient_checkpointing (:obj:`bool`, optional, defaults to False):
|
| 94 |
+
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
|
| 95 |
+
|
| 96 |
+
Example::
|
| 97 |
+
|
| 98 |
+
>>> from transformers import BertModel, BertConfig
|
| 99 |
+
|
| 100 |
+
>>> # Initializing a BERT bert-base-uncased style configuration
|
| 101 |
+
>>> configuration = BertConfig()
|
| 102 |
+
|
| 103 |
+
>>> # Initializing a model from the bert-base-uncased style configuration
|
| 104 |
+
>>> model = BertModel(configuration)
|
| 105 |
+
|
| 106 |
+
>>> # Accessing the model configuration
|
| 107 |
+
>>> configuration = model.config
|
| 108 |
+
"""
|
| 109 |
+
model_type = "bert"
|
| 110 |
+
|
| 111 |
+
def __init__(
|
| 112 |
+
self,
|
| 113 |
+
vocab_size=30522,
|
| 114 |
+
hidden_size=768,
|
| 115 |
+
num_hidden_layers=12,
|
| 116 |
+
num_attention_heads=12,
|
| 117 |
+
intermediate_size=3072,
|
| 118 |
+
hidden_act="gelu",
|
| 119 |
+
hidden_dropout_prob=0.1,
|
| 120 |
+
attention_probs_dropout_prob=0.1,
|
| 121 |
+
max_position_embeddings=512,
|
| 122 |
+
type_vocab_size=2,
|
| 123 |
+
initializer_range=0.02,
|
| 124 |
+
layer_norm_eps=1e-12,
|
| 125 |
+
pad_token_id=0,
|
| 126 |
+
gradient_checkpointing=False,
|
| 127 |
+
**kwargs
|
| 128 |
+
):
|
| 129 |
+
super().__init__(pad_token_id=pad_token_id, **kwargs)
|
| 130 |
+
|
| 131 |
+
self.vocab_size = vocab_size
|
| 132 |
+
self.hidden_size = hidden_size
|
| 133 |
+
self.num_hidden_layers = num_hidden_layers
|
| 134 |
+
self.num_attention_heads = num_attention_heads
|
| 135 |
+
self.hidden_act = hidden_act
|
| 136 |
+
self.intermediate_size = intermediate_size
|
| 137 |
+
self.hidden_dropout_prob = hidden_dropout_prob
|
| 138 |
+
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
| 139 |
+
self.max_position_embeddings = max_position_embeddings
|
| 140 |
+
self.type_vocab_size = type_vocab_size
|
| 141 |
+
self.initializer_range = initializer_range
|
| 142 |
+
self.layer_norm_eps = layer_norm_eps
|
| 143 |
+
self.gradient_checkpointing = gradient_checkpointing
|
CGFormer/bert/configuration_utils.py
ADDED
|
@@ -0,0 +1,408 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
| 3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
""" Configuration base class and utilities."""
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
import copy
|
| 20 |
+
import json
|
| 21 |
+
import logging
|
| 22 |
+
import os
|
| 23 |
+
from typing import Dict, Tuple
|
| 24 |
+
|
| 25 |
+
from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
logger = logging.getLogger(__name__)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class PretrainedConfig(object):
|
| 32 |
+
r""" Base class for all configuration classes.
|
| 33 |
+
Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations.
|
| 34 |
+
|
| 35 |
+
Note:
|
| 36 |
+
A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to initialize a model does **not** load the model weights.
|
| 37 |
+
It only affects the model's configuration.
|
| 38 |
+
|
| 39 |
+
Class attributes (overridden by derived classes):
|
| 40 |
+
- ``model_type``: a string that identifies the model type, that we serialize into the JSON file, and that we use to recreate the correct object in :class:`~transformers.AutoConfig`.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
finetuning_task (:obj:`string` or :obj:`None`, `optional`, defaults to :obj:`None`):
|
| 44 |
+
Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow or PyTorch) checkpoint.
|
| 45 |
+
num_labels (:obj:`int`, `optional`, defaults to `2`):
|
| 46 |
+
Number of classes to use when the model is a classification model (sequences/tokens)
|
| 47 |
+
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
| 48 |
+
Should the model returns all hidden-states.
|
| 49 |
+
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
| 50 |
+
Should the model returns all attentions.
|
| 51 |
+
torchscript (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
| 52 |
+
Is the model used with Torchscript (for PyTorch models).
|
| 53 |
+
"""
|
| 54 |
+
model_type: str = ""
|
| 55 |
+
|
| 56 |
+
def __init__(self, **kwargs):
|
| 57 |
+
# Attributes with defaults
|
| 58 |
+
self.output_hidden_states = kwargs.pop("output_hidden_states", False)
|
| 59 |
+
self.output_attentions = kwargs.pop("output_attentions", False)
|
| 60 |
+
self.use_cache = kwargs.pop("use_cache", True) # Not used by all models
|
| 61 |
+
self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models
|
| 62 |
+
self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
|
| 63 |
+
self.pruned_heads = kwargs.pop("pruned_heads", {})
|
| 64 |
+
|
| 65 |
+
# Is decoder is used in encoder-decoder models to differentiate encoder from decoder
|
| 66 |
+
self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
|
| 67 |
+
self.is_decoder = kwargs.pop("is_decoder", False)
|
| 68 |
+
|
| 69 |
+
# Parameters for sequence generation
|
| 70 |
+
self.max_length = kwargs.pop("max_length", 20)
|
| 71 |
+
self.min_length = kwargs.pop("min_length", 0)
|
| 72 |
+
self.do_sample = kwargs.pop("do_sample", False)
|
| 73 |
+
self.early_stopping = kwargs.pop("early_stopping", False)
|
| 74 |
+
self.num_beams = kwargs.pop("num_beams", 1)
|
| 75 |
+
self.temperature = kwargs.pop("temperature", 1.0)
|
| 76 |
+
self.top_k = kwargs.pop("top_k", 50)
|
| 77 |
+
self.top_p = kwargs.pop("top_p", 1.0)
|
| 78 |
+
self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
|
| 79 |
+
self.length_penalty = kwargs.pop("length_penalty", 1.0)
|
| 80 |
+
self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
|
| 81 |
+
self.bad_words_ids = kwargs.pop("bad_words_ids", None)
|
| 82 |
+
self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
|
| 83 |
+
|
| 84 |
+
# Fine-tuning task arguments
|
| 85 |
+
self.architectures = kwargs.pop("architectures", None)
|
| 86 |
+
self.finetuning_task = kwargs.pop("finetuning_task", None)
|
| 87 |
+
self.id2label = kwargs.pop("id2label", None)
|
| 88 |
+
self.label2id = kwargs.pop("label2id", None)
|
| 89 |
+
if self.id2label is not None:
|
| 90 |
+
kwargs.pop("num_labels", None)
|
| 91 |
+
self.id2label = dict((int(key), value) for key, value in self.id2label.items())
|
| 92 |
+
# Keys are always strings in JSON so convert ids to int here.
|
| 93 |
+
else:
|
| 94 |
+
self.num_labels = kwargs.pop("num_labels", 2)
|
| 95 |
+
|
| 96 |
+
# Tokenizer arguments TODO: eventually tokenizer and models should share the same config
|
| 97 |
+
self.prefix = kwargs.pop("prefix", None)
|
| 98 |
+
self.bos_token_id = kwargs.pop("bos_token_id", None)
|
| 99 |
+
self.pad_token_id = kwargs.pop("pad_token_id", None)
|
| 100 |
+
self.eos_token_id = kwargs.pop("eos_token_id", None)
|
| 101 |
+
self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
|
| 102 |
+
|
| 103 |
+
# task specific arguments
|
| 104 |
+
self.task_specific_params = kwargs.pop("task_specific_params", None)
|
| 105 |
+
|
| 106 |
+
# TPU arguments
|
| 107 |
+
self.xla_device = kwargs.pop("xla_device", None)
|
| 108 |
+
|
| 109 |
+
# Additional attributes without default values
|
| 110 |
+
for key, value in kwargs.items():
|
| 111 |
+
try:
|
| 112 |
+
setattr(self, key, value)
|
| 113 |
+
except AttributeError as err:
|
| 114 |
+
logger.error("Can't set {} with value {} for {}".format(key, value, self))
|
| 115 |
+
raise err
|
| 116 |
+
|
| 117 |
+
@property
|
| 118 |
+
def num_labels(self):
|
| 119 |
+
return len(self.id2label)
|
| 120 |
+
|
| 121 |
+
@num_labels.setter
|
| 122 |
+
def num_labels(self, num_labels):
|
| 123 |
+
self.id2label = {i: "LABEL_{}".format(i) for i in range(num_labels)}
|
| 124 |
+
self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
|
| 125 |
+
|
| 126 |
+
def save_pretrained(self, save_directory):
|
| 127 |
+
"""
|
| 128 |
+
Save a configuration object to the directory `save_directory`, so that it
|
| 129 |
+
can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method.
|
| 130 |
+
|
| 131 |
+
Args:
|
| 132 |
+
save_directory (:obj:`string`):
|
| 133 |
+
Directory where the configuration JSON file will be saved.
|
| 134 |
+
"""
|
| 135 |
+
if os.path.isfile(save_directory):
|
| 136 |
+
raise AssertionError("Provided path ({}) should be a directory, not a file".format(save_directory))
|
| 137 |
+
os.makedirs(save_directory, exist_ok=True)
|
| 138 |
+
# If we save using the predefined names, we can load using `from_pretrained`
|
| 139 |
+
output_config_file = os.path.join(save_directory, CONFIG_NAME)
|
| 140 |
+
|
| 141 |
+
self.to_json_file(output_config_file, use_diff=True)
|
| 142 |
+
logger.info("Configuration saved in {}".format(output_config_file))
|
| 143 |
+
|
| 144 |
+
@classmethod
|
| 145 |
+
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs) -> "PretrainedConfig":
|
| 146 |
+
r"""
|
| 147 |
+
|
| 148 |
+
Instantiate a :class:`~transformers.PretrainedConfig` (or a derived class) from a pre-trained model configuration.
|
| 149 |
+
|
| 150 |
+
Args:
|
| 151 |
+
pretrained_model_name_or_path (:obj:`string`):
|
| 152 |
+
either:
|
| 153 |
+
- a string with the `shortcut name` of a pre-trained model configuration to load from cache or
|
| 154 |
+
download, e.g.: ``bert-base-uncased``.
|
| 155 |
+
- a string with the `identifier name` of a pre-trained model configuration that was user-uploaded to
|
| 156 |
+
our S3, e.g.: ``dbmdz/bert-base-german-cased``.
|
| 157 |
+
- a path to a `directory` containing a configuration file saved using the
|
| 158 |
+
:func:`~transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``.
|
| 159 |
+
- a path or url to a saved configuration JSON `file`, e.g.:
|
| 160 |
+
``./my_model_directory/configuration.json``.
|
| 161 |
+
cache_dir (:obj:`string`, `optional`):
|
| 162 |
+
Path to a directory in which a downloaded pre-trained model
|
| 163 |
+
configuration should be cached if the standard cache should not be used.
|
| 164 |
+
kwargs (:obj:`Dict[str, any]`, `optional`):
|
| 165 |
+
The values in kwargs of any keys which are configuration attributes will be used to override the loaded
|
| 166 |
+
values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is
|
| 167 |
+
controlled by the `return_unused_kwargs` keyword parameter.
|
| 168 |
+
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
| 169 |
+
Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
|
| 170 |
+
resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
| 171 |
+
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
|
| 172 |
+
proxies (:obj:`Dict`, `optional`):
|
| 173 |
+
A dictionary of proxy servers to use by protocol or endpoint, e.g.:
|
| 174 |
+
:obj:`{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.`
|
| 175 |
+
The proxies are used on each request.
|
| 176 |
+
return_unused_kwargs: (`optional`) bool:
|
| 177 |
+
If False, then this function returns just the final configuration object.
|
| 178 |
+
If True, then this functions returns a :obj:`Tuple(config, unused_kwargs)` where `unused_kwargs` is a
|
| 179 |
+
dictionary consisting of the key/value pairs whose keys are not configuration attributes: ie the part
|
| 180 |
+
of kwargs which has not been used to update `config` and is otherwise ignored.
|
| 181 |
+
|
| 182 |
+
Returns:
|
| 183 |
+
:class:`PretrainedConfig`: An instance of a configuration object
|
| 184 |
+
|
| 185 |
+
Examples::
|
| 186 |
+
|
| 187 |
+
# We can't instantiate directly the base class `PretrainedConfig` so let's show the examples on a
|
| 188 |
+
# derived class: BertConfig
|
| 189 |
+
config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
|
| 190 |
+
config = BertConfig.from_pretrained('./test/saved_model/') # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')`
|
| 191 |
+
config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json')
|
| 192 |
+
config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, foo=False)
|
| 193 |
+
assert config.output_attention == True
|
| 194 |
+
config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attention=True,
|
| 195 |
+
foo=False, return_unused_kwargs=True)
|
| 196 |
+
assert config.output_attention == True
|
| 197 |
+
assert unused_kwargs == {'foo': False}
|
| 198 |
+
|
| 199 |
+
"""
|
| 200 |
+
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
| 201 |
+
return cls.from_dict(config_dict, **kwargs)
|
| 202 |
+
|
| 203 |
+
@classmethod
|
| 204 |
+
def get_config_dict(cls, pretrained_model_name_or_path: str, **kwargs) -> Tuple[Dict, Dict]:
|
| 205 |
+
"""
|
| 206 |
+
From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used
|
| 207 |
+
for instantiating a Config using `from_dict`.
|
| 208 |
+
|
| 209 |
+
Parameters:
|
| 210 |
+
pretrained_model_name_or_path (:obj:`string`):
|
| 211 |
+
The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
|
| 212 |
+
|
| 213 |
+
Returns:
|
| 214 |
+
:obj:`Tuple[Dict, Dict]`: The dictionary that will be used to instantiate the configuration object.
|
| 215 |
+
|
| 216 |
+
"""
|
| 217 |
+
cache_dir = kwargs.pop("cache_dir", None)
|
| 218 |
+
force_download = kwargs.pop("force_download", False)
|
| 219 |
+
resume_download = kwargs.pop("resume_download", False)
|
| 220 |
+
proxies = kwargs.pop("proxies", None)
|
| 221 |
+
local_files_only = kwargs.pop("local_files_only", False)
|
| 222 |
+
|
| 223 |
+
if os.path.isdir(pretrained_model_name_or_path):
|
| 224 |
+
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
|
| 225 |
+
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
| 226 |
+
config_file = pretrained_model_name_or_path
|
| 227 |
+
else:
|
| 228 |
+
config_file = hf_bucket_url(pretrained_model_name_or_path, filename=CONFIG_NAME, use_cdn=False)
|
| 229 |
+
|
| 230 |
+
try:
|
| 231 |
+
# Load from URL or cache if already cached
|
| 232 |
+
resolved_config_file = cached_path(
|
| 233 |
+
config_file,
|
| 234 |
+
cache_dir=cache_dir,
|
| 235 |
+
force_download=force_download,
|
| 236 |
+
proxies=proxies,
|
| 237 |
+
resume_download=resume_download,
|
| 238 |
+
local_files_only=local_files_only,
|
| 239 |
+
)
|
| 240 |
+
# Load config dict
|
| 241 |
+
if resolved_config_file is None:
|
| 242 |
+
raise EnvironmentError
|
| 243 |
+
config_dict = cls._dict_from_json_file(resolved_config_file)
|
| 244 |
+
|
| 245 |
+
except EnvironmentError:
|
| 246 |
+
msg = (
|
| 247 |
+
f"Can't load config for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
|
| 248 |
+
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
|
| 249 |
+
f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a {CONFIG_NAME} file\n\n"
|
| 250 |
+
)
|
| 251 |
+
raise EnvironmentError(msg)
|
| 252 |
+
|
| 253 |
+
except json.JSONDecodeError:
|
| 254 |
+
msg = (
|
| 255 |
+
"Couldn't reach server at '{}' to download configuration file or "
|
| 256 |
+
"configuration file is not a valid JSON file. "
|
| 257 |
+
"Please check network or file content here: {}.".format(config_file, resolved_config_file)
|
| 258 |
+
)
|
| 259 |
+
raise EnvironmentError(msg)
|
| 260 |
+
|
| 261 |
+
if resolved_config_file == config_file:
|
| 262 |
+
logger.info("loading configuration file {}".format(config_file))
|
| 263 |
+
else:
|
| 264 |
+
logger.info("loading configuration file {} from cache at {}".format(config_file, resolved_config_file))
|
| 265 |
+
|
| 266 |
+
return config_dict, kwargs
|
| 267 |
+
|
| 268 |
+
@classmethod
|
| 269 |
+
def from_dict(cls, config_dict: Dict, **kwargs) -> "PretrainedConfig":
|
| 270 |
+
"""
|
| 271 |
+
Constructs a `Config` from a Python dictionary of parameters.
|
| 272 |
+
|
| 273 |
+
Args:
|
| 274 |
+
config_dict (:obj:`Dict[str, any]`):
|
| 275 |
+
Dictionary that will be used to instantiate the configuration object. Such a dictionary can be retrieved
|
| 276 |
+
from a pre-trained checkpoint by leveraging the :func:`~transformers.PretrainedConfig.get_config_dict`
|
| 277 |
+
method.
|
| 278 |
+
kwargs (:obj:`Dict[str, any]`):
|
| 279 |
+
Additional parameters from which to initialize the configuration object.
|
| 280 |
+
|
| 281 |
+
Returns:
|
| 282 |
+
:class:`PretrainedConfig`: An instance of a configuration object
|
| 283 |
+
"""
|
| 284 |
+
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
|
| 285 |
+
|
| 286 |
+
config = cls(**config_dict)
|
| 287 |
+
|
| 288 |
+
if hasattr(config, "pruned_heads"):
|
| 289 |
+
config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())
|
| 290 |
+
|
| 291 |
+
# Update config with kwargs if needed
|
| 292 |
+
to_remove = []
|
| 293 |
+
for key, value in kwargs.items():
|
| 294 |
+
if hasattr(config, key):
|
| 295 |
+
setattr(config, key, value)
|
| 296 |
+
to_remove.append(key)
|
| 297 |
+
for key in to_remove:
|
| 298 |
+
kwargs.pop(key, None)
|
| 299 |
+
|
| 300 |
+
logger.info("Model config %s", str(config))
|
| 301 |
+
if return_unused_kwargs:
|
| 302 |
+
return config, kwargs
|
| 303 |
+
else:
|
| 304 |
+
return config
|
| 305 |
+
|
| 306 |
+
@classmethod
|
| 307 |
+
def from_json_file(cls, json_file: str) -> "PretrainedConfig":
|
| 308 |
+
"""
|
| 309 |
+
Constructs a `Config` from the path to a json file of parameters.
|
| 310 |
+
|
| 311 |
+
Args:
|
| 312 |
+
json_file (:obj:`string`):
|
| 313 |
+
Path to the JSON file containing the parameters.
|
| 314 |
+
|
| 315 |
+
Returns:
|
| 316 |
+
:class:`PretrainedConfig`: An instance of a configuration object
|
| 317 |
+
|
| 318 |
+
"""
|
| 319 |
+
config_dict = cls._dict_from_json_file(json_file)
|
| 320 |
+
return cls(**config_dict)
|
| 321 |
+
|
| 322 |
+
@classmethod
|
| 323 |
+
def _dict_from_json_file(cls, json_file: str):
|
| 324 |
+
with open(json_file, "r", encoding="utf-8") as reader:
|
| 325 |
+
text = reader.read()
|
| 326 |
+
return json.loads(text)
|
| 327 |
+
|
| 328 |
+
def __eq__(self, other):
|
| 329 |
+
return self.__dict__ == other.__dict__
|
| 330 |
+
|
| 331 |
+
def __repr__(self):
|
| 332 |
+
return "{} {}".format(self.__class__.__name__, self.to_json_string())
|
| 333 |
+
|
| 334 |
+
def to_diff_dict(self):
|
| 335 |
+
"""
|
| 336 |
+
Removes all attributes from config which correspond to the default
|
| 337 |
+
config attributes for better readability and serializes to a Python
|
| 338 |
+
dictionary.
|
| 339 |
+
|
| 340 |
+
Returns:
|
| 341 |
+
:obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
|
| 342 |
+
"""
|
| 343 |
+
config_dict = self.to_dict()
|
| 344 |
+
|
| 345 |
+
# get the default config dict
|
| 346 |
+
default_config_dict = PretrainedConfig().to_dict()
|
| 347 |
+
|
| 348 |
+
serializable_config_dict = {}
|
| 349 |
+
|
| 350 |
+
# only serialize values that differ from the default config
|
| 351 |
+
for key, value in config_dict.items():
|
| 352 |
+
if key not in default_config_dict or value != default_config_dict[key]:
|
| 353 |
+
serializable_config_dict[key] = value
|
| 354 |
+
|
| 355 |
+
return serializable_config_dict
|
| 356 |
+
|
| 357 |
+
def to_dict(self):
|
| 358 |
+
"""
|
| 359 |
+
Serializes this instance to a Python dictionary.
|
| 360 |
+
|
| 361 |
+
Returns:
|
| 362 |
+
:obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
|
| 363 |
+
"""
|
| 364 |
+
output = copy.deepcopy(self.__dict__)
|
| 365 |
+
if hasattr(self.__class__, "model_type"):
|
| 366 |
+
output["model_type"] = self.__class__.model_type
|
| 367 |
+
return output
|
| 368 |
+
|
| 369 |
+
def to_json_string(self, use_diff=True):
|
| 370 |
+
"""
|
| 371 |
+
Serializes this instance to a JSON string.
|
| 372 |
+
|
| 373 |
+
Args:
|
| 374 |
+
use_diff (:obj:`bool`):
|
| 375 |
+
If set to True, only the difference between the config instance and the default PretrainedConfig() is serialized to JSON string.
|
| 376 |
+
|
| 377 |
+
Returns:
|
| 378 |
+
:obj:`string`: String containing all the attributes that make up this configuration instance in JSON format.
|
| 379 |
+
"""
|
| 380 |
+
if use_diff is True:
|
| 381 |
+
config_dict = self.to_diff_dict()
|
| 382 |
+
else:
|
| 383 |
+
config_dict = self.to_dict()
|
| 384 |
+
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
| 385 |
+
|
| 386 |
+
def to_json_file(self, json_file_path, use_diff=True):
|
| 387 |
+
"""
|
| 388 |
+
Save this instance to a json file.
|
| 389 |
+
|
| 390 |
+
Args:
|
| 391 |
+
json_file_path (:obj:`string`):
|
| 392 |
+
Path to the JSON file in which this configuration instance's parameters will be saved.
|
| 393 |
+
use_diff (:obj:`bool`):
|
| 394 |
+
If set to True, only the difference between the config instance and the default PretrainedConfig() is serialized to JSON file.
|
| 395 |
+
"""
|
| 396 |
+
with open(json_file_path, "w", encoding="utf-8") as writer:
|
| 397 |
+
writer.write(self.to_json_string(use_diff=use_diff))
|
| 398 |
+
|
| 399 |
+
def update(self, config_dict: Dict):
|
| 400 |
+
"""
|
| 401 |
+
Updates attributes of this class
|
| 402 |
+
with attributes from `config_dict`.
|
| 403 |
+
|
| 404 |
+
Args:
|
| 405 |
+
:obj:`Dict[str, any]`: Dictionary of attributes that shall be updated for this class.
|
| 406 |
+
"""
|
| 407 |
+
for key, value in config_dict.items():
|
| 408 |
+
setattr(self, key, value)
|
CGFormer/bert/file_utils.py
ADDED
|
@@ -0,0 +1,808 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utilities for working with the local dataset cache.
|
| 3 |
+
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
|
| 4 |
+
Copyright by the AllenNLP authors.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import fnmatch
|
| 8 |
+
import json
|
| 9 |
+
import logging
|
| 10 |
+
import os
|
| 11 |
+
import shutil
|
| 12 |
+
import sys
|
| 13 |
+
import tarfile
|
| 14 |
+
import tempfile
|
| 15 |
+
from contextlib import contextmanager
|
| 16 |
+
from functools import partial, wraps
|
| 17 |
+
from hashlib import sha256
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from typing import Dict, Optional, Union
|
| 20 |
+
from urllib.parse import urlparse
|
| 21 |
+
from zipfile import ZipFile, is_zipfile
|
| 22 |
+
|
| 23 |
+
import requests
|
| 24 |
+
from filelock import FileLock
|
| 25 |
+
from tqdm.auto import tqdm
|
| 26 |
+
|
| 27 |
+
#from . import __version__
|
| 28 |
+
__version__ = "3.0.2"
|
| 29 |
+
|
| 30 |
+
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
| 31 |
+
|
| 32 |
+
try:
|
| 33 |
+
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
|
| 34 |
+
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
|
| 35 |
+
if USE_TORCH in ("1", "ON", "YES", "AUTO") and USE_TF not in ("1", "ON", "YES"):
|
| 36 |
+
import torch
|
| 37 |
+
|
| 38 |
+
_torch_available = True # pylint: disable=invalid-name
|
| 39 |
+
logger.info("PyTorch version {} available.".format(torch.__version__))
|
| 40 |
+
else:
|
| 41 |
+
logger.info("Disabling PyTorch because USE_TF is set")
|
| 42 |
+
_torch_available = False
|
| 43 |
+
except ImportError:
|
| 44 |
+
_torch_available = False # pylint: disable=invalid-name
|
| 45 |
+
|
| 46 |
+
try:
|
| 47 |
+
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
|
| 48 |
+
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
|
| 49 |
+
|
| 50 |
+
if USE_TF in ("1", "ON", "YES", "AUTO") and USE_TORCH not in ("1", "ON", "YES"):
|
| 51 |
+
import tensorflow as tf
|
| 52 |
+
|
| 53 |
+
assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
|
| 54 |
+
_tf_available = True # pylint: disable=invalid-name
|
| 55 |
+
logger.info("TensorFlow version {} available.".format(tf.__version__))
|
| 56 |
+
else:
|
| 57 |
+
logger.info("Disabling Tensorflow because USE_TORCH is set")
|
| 58 |
+
_tf_available = False
|
| 59 |
+
except (ImportError, AssertionError):
|
| 60 |
+
_tf_available = False # pylint: disable=invalid-name
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
try:
|
| 64 |
+
from torch.hub import _get_torch_home
|
| 65 |
+
|
| 66 |
+
torch_cache_home = _get_torch_home()
|
| 67 |
+
except ImportError:
|
| 68 |
+
torch_cache_home = os.path.expanduser(
|
| 69 |
+
os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
try:
|
| 74 |
+
import torch_xla.core.xla_model as xm # noqa: F401
|
| 75 |
+
|
| 76 |
+
if _torch_available:
|
| 77 |
+
_torch_tpu_available = True # pylint: disable=
|
| 78 |
+
else:
|
| 79 |
+
_torch_tpu_available = False
|
| 80 |
+
except ImportError:
|
| 81 |
+
_torch_tpu_available = False
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
try:
|
| 85 |
+
import psutil # noqa: F401
|
| 86 |
+
|
| 87 |
+
_psutil_available = True
|
| 88 |
+
|
| 89 |
+
except ImportError:
|
| 90 |
+
_psutil_available = False
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
try:
|
| 94 |
+
import py3nvml # noqa: F401
|
| 95 |
+
|
| 96 |
+
_py3nvml_available = True
|
| 97 |
+
|
| 98 |
+
except ImportError:
|
| 99 |
+
_py3nvml_available = False
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
try:
|
| 103 |
+
from apex import amp # noqa: F401
|
| 104 |
+
|
| 105 |
+
_has_apex = True
|
| 106 |
+
except ImportError:
|
| 107 |
+
_has_apex = False
|
| 108 |
+
|
| 109 |
+
default_cache_path = os.path.join(torch_cache_home, "transformers")
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
|
| 113 |
+
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
|
| 114 |
+
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
|
| 115 |
+
|
| 116 |
+
WEIGHTS_NAME = "pytorch_model.bin"
|
| 117 |
+
TF2_WEIGHTS_NAME = "tf_model.h5"
|
| 118 |
+
TF_WEIGHTS_NAME = "model.ckpt"
|
| 119 |
+
CONFIG_NAME = "config.json"
|
| 120 |
+
MODEL_CARD_NAME = "modelcard.json"
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
MULTIPLE_CHOICE_DUMMY_INPUTS = [[[0], [1]], [[0], [1]]]
|
| 124 |
+
DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
|
| 125 |
+
DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]
|
| 126 |
+
|
| 127 |
+
S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
|
| 128 |
+
CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co"
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def is_torch_available():
|
| 132 |
+
return _torch_available
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def is_tf_available():
|
| 136 |
+
return _tf_available
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def is_torch_tpu_available():
|
| 140 |
+
return _torch_tpu_available
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def is_psutil_available():
|
| 144 |
+
return _psutil_available
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def is_py3nvml_available():
|
| 148 |
+
return _py3nvml_available
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def is_apex_available():
|
| 152 |
+
return _has_apex
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def add_start_docstrings(*docstr):
|
| 156 |
+
def docstring_decorator(fn):
|
| 157 |
+
fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
|
| 158 |
+
return fn
|
| 159 |
+
|
| 160 |
+
return docstring_decorator
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def add_start_docstrings_to_callable(*docstr):
|
| 164 |
+
def docstring_decorator(fn):
|
| 165 |
+
class_name = ":class:`~transformers.{}`".format(fn.__qualname__.split(".")[0])
|
| 166 |
+
intro = " The {} forward method, overrides the :func:`__call__` special method.".format(class_name)
|
| 167 |
+
note = r"""
|
| 168 |
+
|
| 169 |
+
.. note::
|
| 170 |
+
Although the recipe for forward pass needs to be defined within
|
| 171 |
+
this function, one should call the :class:`Module` instance afterwards
|
| 172 |
+
instead of this since the former takes care of running the
|
| 173 |
+
pre and post processing steps while the latter silently ignores them.
|
| 174 |
+
"""
|
| 175 |
+
fn.__doc__ = intro + note + "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
|
| 176 |
+
return fn
|
| 177 |
+
|
| 178 |
+
return docstring_decorator
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def add_end_docstrings(*docstr):
|
| 182 |
+
def docstring_decorator(fn):
|
| 183 |
+
fn.__doc__ = fn.__doc__ + "".join(docstr)
|
| 184 |
+
return fn
|
| 185 |
+
|
| 186 |
+
return docstring_decorator
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
PT_TOKEN_CLASSIFICATION_SAMPLE = r"""
|
| 190 |
+
Example::
|
| 191 |
+
|
| 192 |
+
>>> from transformers import {tokenizer_class}, {model_class}
|
| 193 |
+
>>> import torch
|
| 194 |
+
|
| 195 |
+
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
|
| 196 |
+
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
| 197 |
+
|
| 198 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
| 199 |
+
>>> labels = torch.tensor([1] * inputs["input_ids"].size(1)).unsqueeze(0) # Batch size 1
|
| 200 |
+
|
| 201 |
+
>>> outputs = model(**inputs, labels=labels)
|
| 202 |
+
>>> loss, scores = outputs[:2]
|
| 203 |
+
"""
|
| 204 |
+
|
| 205 |
+
PT_QUESTION_ANSWERING_SAMPLE = r"""
|
| 206 |
+
Example::
|
| 207 |
+
|
| 208 |
+
>>> from transformers import {tokenizer_class}, {model_class}
|
| 209 |
+
>>> import torch
|
| 210 |
+
|
| 211 |
+
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
|
| 212 |
+
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
| 213 |
+
|
| 214 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
| 215 |
+
>>> start_positions = torch.tensor([1])
|
| 216 |
+
>>> end_positions = torch.tensor([3])
|
| 217 |
+
|
| 218 |
+
>>> outputs = model(**inputs, start_positions=start_positions, end_positions=end_positions)
|
| 219 |
+
>>> loss, start_scores, end_scores = outputs[:3]
|
| 220 |
+
"""
|
| 221 |
+
|
| 222 |
+
PT_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
|
| 223 |
+
Example::
|
| 224 |
+
|
| 225 |
+
>>> from transformers import {tokenizer_class}, {model_class}
|
| 226 |
+
>>> import torch
|
| 227 |
+
|
| 228 |
+
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
|
| 229 |
+
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
| 230 |
+
|
| 231 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
| 232 |
+
>>> labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
|
| 233 |
+
>>> outputs = model(**inputs, labels=labels)
|
| 234 |
+
>>> loss, logits = outputs[:2]
|
| 235 |
+
"""
|
| 236 |
+
|
| 237 |
+
PT_MASKED_LM_SAMPLE = r"""
|
| 238 |
+
Example::
|
| 239 |
+
|
| 240 |
+
>>> from transformers import {tokenizer_class}, {model_class}
|
| 241 |
+
>>> import torch
|
| 242 |
+
|
| 243 |
+
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
|
| 244 |
+
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
| 245 |
+
|
| 246 |
+
>>> input_ids = tokenizer("Hello, my dog is cute", return_tensors="pt")["input_ids"]
|
| 247 |
+
|
| 248 |
+
>>> outputs = model(input_ids, labels=input_ids)
|
| 249 |
+
>>> loss, prediction_scores = outputs[:2]
|
| 250 |
+
"""
|
| 251 |
+
|
| 252 |
+
PT_BASE_MODEL_SAMPLE = r"""
|
| 253 |
+
Example::
|
| 254 |
+
|
| 255 |
+
>>> from transformers import {tokenizer_class}, {model_class}
|
| 256 |
+
>>> import torch
|
| 257 |
+
|
| 258 |
+
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
|
| 259 |
+
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
| 260 |
+
|
| 261 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
| 262 |
+
>>> outputs = model(**inputs)
|
| 263 |
+
|
| 264 |
+
>>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
| 265 |
+
"""
|
| 266 |
+
|
| 267 |
+
PT_MULTIPLE_CHOICE_SAMPLE = r"""
|
| 268 |
+
Example::
|
| 269 |
+
|
| 270 |
+
>>> from transformers import {tokenizer_class}, {model_class}
|
| 271 |
+
>>> import torch
|
| 272 |
+
|
| 273 |
+
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
|
| 274 |
+
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
| 275 |
+
|
| 276 |
+
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
|
| 277 |
+
>>> choice0 = "It is eaten with a fork and a knife."
|
| 278 |
+
>>> choice1 = "It is eaten while held in the hand."
|
| 279 |
+
>>> labels = torch.tensor(0).unsqueeze(0) # choice0 is correct (according to Wikipedia ;)), batch size 1
|
| 280 |
+
|
| 281 |
+
>>> encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors='pt', padding=True)
|
| 282 |
+
>>> outputs = model(**{{k: v.unsqueeze(0) for k,v in encoding.items()}}, labels=labels) # batch size is 1
|
| 283 |
+
|
| 284 |
+
>>> # the linear classifier still needs to be trained
|
| 285 |
+
>>> loss, logits = outputs[:2]
|
| 286 |
+
"""
|
| 287 |
+
|
| 288 |
+
PT_CAUSAL_LM_SAMPLE = r"""
|
| 289 |
+
Example::
|
| 290 |
+
|
| 291 |
+
>>> import torch
|
| 292 |
+
>>> from transformers import {tokenizer_class}, {model_class}
|
| 293 |
+
|
| 294 |
+
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
|
| 295 |
+
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
| 296 |
+
|
| 297 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
| 298 |
+
>>> outputs = model(**inputs, labels=inputs["input_ids"])
|
| 299 |
+
>>> loss, logits = outputs[:2]
|
| 300 |
+
"""
|
| 301 |
+
|
| 302 |
+
TF_TOKEN_CLASSIFICATION_SAMPLE = r"""
|
| 303 |
+
Example::
|
| 304 |
+
|
| 305 |
+
>>> from transformers import {tokenizer_class}, {model_class}
|
| 306 |
+
>>> import tensorflow as tf
|
| 307 |
+
|
| 308 |
+
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
|
| 309 |
+
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
| 310 |
+
|
| 311 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
|
| 312 |
+
>>> input_ids = inputs["input_ids"]
|
| 313 |
+
>>> inputs["labels"] = tf.reshape(tf.constant([1] * tf.size(input_ids).numpy()), (-1, tf.size(input_ids))) # Batch size 1
|
| 314 |
+
|
| 315 |
+
>>> outputs = model(inputs)
|
| 316 |
+
>>> loss, scores = outputs[:2]
|
| 317 |
+
"""
|
| 318 |
+
|
| 319 |
+
TF_QUESTION_ANSWERING_SAMPLE = r"""
|
| 320 |
+
Example::
|
| 321 |
+
|
| 322 |
+
>>> from transformers import {tokenizer_class}, {model_class}
|
| 323 |
+
>>> import tensorflow as tf
|
| 324 |
+
|
| 325 |
+
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
|
| 326 |
+
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
| 327 |
+
|
| 328 |
+
>>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
|
| 329 |
+
>>> input_dict = tokenizer(question, text, return_tensors='tf')
|
| 330 |
+
>>> start_scores, end_scores = model(input_dict)
|
| 331 |
+
|
| 332 |
+
>>> all_tokens = tokenizer.convert_ids_to_tokens(input_dict["input_ids"].numpy()[0])
|
| 333 |
+
>>> answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
|
| 334 |
+
"""
|
| 335 |
+
|
| 336 |
+
TF_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
|
| 337 |
+
Example::
|
| 338 |
+
|
| 339 |
+
>>> from transformers import {tokenizer_class}, {model_class}
|
| 340 |
+
>>> import tensorflow as tf
|
| 341 |
+
|
| 342 |
+
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
|
| 343 |
+
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
| 344 |
+
|
| 345 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
|
| 346 |
+
>>> inputs["labels"] = tf.reshape(tf.constant(1), (-1, 1)) # Batch size 1
|
| 347 |
+
|
| 348 |
+
>>> outputs = model(inputs)
|
| 349 |
+
>>> loss, logits = outputs[:2]
|
| 350 |
+
"""
|
| 351 |
+
|
| 352 |
+
TF_MASKED_LM_SAMPLE = r"""
|
| 353 |
+
Example::
|
| 354 |
+
>>> from transformers import {tokenizer_class}, {model_class}
|
| 355 |
+
>>> import tensorflow as tf
|
| 356 |
+
|
| 357 |
+
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
|
| 358 |
+
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
| 359 |
+
|
| 360 |
+
>>> input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] # Batch size 1
|
| 361 |
+
|
| 362 |
+
>>> outputs = model(input_ids)
|
| 363 |
+
>>> prediction_scores = outputs[0]
|
| 364 |
+
"""
|
| 365 |
+
|
| 366 |
+
TF_BASE_MODEL_SAMPLE = r"""
|
| 367 |
+
Example::
|
| 368 |
+
|
| 369 |
+
>>> from transformers import {tokenizer_class}, {model_class}
|
| 370 |
+
>>> import tensorflow as tf
|
| 371 |
+
|
| 372 |
+
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
|
| 373 |
+
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
| 374 |
+
|
| 375 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
|
| 376 |
+
>>> outputs = model(inputs)
|
| 377 |
+
|
| 378 |
+
>>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
| 379 |
+
"""
|
| 380 |
+
|
| 381 |
+
TF_MULTIPLE_CHOICE_SAMPLE = r"""
|
| 382 |
+
Example::
|
| 383 |
+
|
| 384 |
+
>>> from transformers import {tokenizer_class}, {model_class}
|
| 385 |
+
>>> import tensorflow as tf
|
| 386 |
+
|
| 387 |
+
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
|
| 388 |
+
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
| 389 |
+
|
| 390 |
+
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
|
| 391 |
+
>>> choice0 = "It is eaten with a fork and a knife."
|
| 392 |
+
>>> choice1 = "It is eaten while held in the hand."
|
| 393 |
+
|
| 394 |
+
>>> encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors='tf', padding=True)
|
| 395 |
+
>>> inputs = {{k: tf.expand_dims(v, 0) for k, v in encoding.items()}}
|
| 396 |
+
>>> outputs = model(inputs) # batch size is 1
|
| 397 |
+
|
| 398 |
+
>>> # the linear classifier still needs to be trained
|
| 399 |
+
>>> logits = outputs[0]
|
| 400 |
+
"""
|
| 401 |
+
|
| 402 |
+
TF_CAUSAL_LM_SAMPLE = r"""
|
| 403 |
+
Example::
|
| 404 |
+
|
| 405 |
+
>>> from transformers import {tokenizer_class}, {model_class}
|
| 406 |
+
>>> import tensorflow as tf
|
| 407 |
+
|
| 408 |
+
>>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
|
| 409 |
+
>>> model = {model_class}.from_pretrained('{checkpoint}')
|
| 410 |
+
|
| 411 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
|
| 412 |
+
>>> outputs = model(inputs)
|
| 413 |
+
>>> logits = outputs[0]
|
| 414 |
+
"""
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
def add_code_sample_docstrings(*docstr, tokenizer_class=None, checkpoint=None):
|
| 418 |
+
def docstring_decorator(fn):
|
| 419 |
+
model_class = fn.__qualname__.split(".")[0]
|
| 420 |
+
is_tf_class = model_class[:2] == "TF"
|
| 421 |
+
|
| 422 |
+
if "SequenceClassification" in model_class:
|
| 423 |
+
code_sample = TF_SEQUENCE_CLASSIFICATION_SAMPLE if is_tf_class else PT_SEQUENCE_CLASSIFICATION_SAMPLE
|
| 424 |
+
elif "QuestionAnswering" in model_class:
|
| 425 |
+
code_sample = TF_QUESTION_ANSWERING_SAMPLE if is_tf_class else PT_QUESTION_ANSWERING_SAMPLE
|
| 426 |
+
elif "TokenClassification" in model_class:
|
| 427 |
+
code_sample = TF_TOKEN_CLASSIFICATION_SAMPLE if is_tf_class else PT_TOKEN_CLASSIFICATION_SAMPLE
|
| 428 |
+
elif "MultipleChoice" in model_class:
|
| 429 |
+
code_sample = TF_MULTIPLE_CHOICE_SAMPLE if is_tf_class else PT_MULTIPLE_CHOICE_SAMPLE
|
| 430 |
+
elif "MaskedLM" in model_class:
|
| 431 |
+
code_sample = TF_MASKED_LM_SAMPLE if is_tf_class else PT_MASKED_LM_SAMPLE
|
| 432 |
+
elif "LMHead" in model_class:
|
| 433 |
+
code_sample = TF_CAUSAL_LM_SAMPLE if is_tf_class else PT_CAUSAL_LM_SAMPLE
|
| 434 |
+
elif "Model" in model_class:
|
| 435 |
+
code_sample = TF_BASE_MODEL_SAMPLE if is_tf_class else PT_BASE_MODEL_SAMPLE
|
| 436 |
+
else:
|
| 437 |
+
raise ValueError(f"Docstring can't be built for model {model_class}")
|
| 438 |
+
|
| 439 |
+
built_doc = code_sample.format(model_class=model_class, tokenizer_class=tokenizer_class, checkpoint=checkpoint)
|
| 440 |
+
fn.__doc__ = (fn.__doc__ or "") + "".join(docstr) + built_doc
|
| 441 |
+
return fn
|
| 442 |
+
|
| 443 |
+
return docstring_decorator
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
def is_remote_url(url_or_filename):
|
| 447 |
+
parsed = urlparse(url_or_filename)
|
| 448 |
+
return parsed.scheme in ("http", "https")
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
def hf_bucket_url(model_id: str, filename: str, use_cdn=True) -> str:
|
| 452 |
+
"""
|
| 453 |
+
Resolve a model identifier, and a file name, to a HF-hosted url
|
| 454 |
+
on either S3 or Cloudfront (a Content Delivery Network, or CDN).
|
| 455 |
+
|
| 456 |
+
Cloudfront is replicated over the globe so downloads are way faster
|
| 457 |
+
for the end user (and it also lowers our bandwidth costs). However, it
|
| 458 |
+
is more aggressively cached by default, so may not always reflect the
|
| 459 |
+
latest changes to the underlying file (default TTL is 24 hours).
|
| 460 |
+
|
| 461 |
+
In terms of client-side caching from this library, even though
|
| 462 |
+
Cloudfront relays the ETags from S3, using one or the other
|
| 463 |
+
(or switching from one to the other) will affect caching: cached files
|
| 464 |
+
are not shared between the two because the cached file's name contains
|
| 465 |
+
a hash of the url.
|
| 466 |
+
"""
|
| 467 |
+
endpoint = CLOUDFRONT_DISTRIB_PREFIX if use_cdn else S3_BUCKET_PREFIX
|
| 468 |
+
legacy_format = "/" not in model_id
|
| 469 |
+
if legacy_format:
|
| 470 |
+
return f"{endpoint}/{model_id}-{filename}"
|
| 471 |
+
else:
|
| 472 |
+
return f"{endpoint}/{model_id}/{filename}"
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
def url_to_filename(url, etag=None):
|
| 476 |
+
"""
|
| 477 |
+
Convert `url` into a hashed filename in a repeatable way.
|
| 478 |
+
If `etag` is specified, append its hash to the url's, delimited
|
| 479 |
+
by a period.
|
| 480 |
+
If the url ends with .h5 (Keras HDF5 weights) adds '.h5' to the name
|
| 481 |
+
so that TF 2.0 can identify it as a HDF5 file
|
| 482 |
+
(see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
|
| 483 |
+
"""
|
| 484 |
+
url_bytes = url.encode("utf-8")
|
| 485 |
+
url_hash = sha256(url_bytes)
|
| 486 |
+
filename = url_hash.hexdigest()
|
| 487 |
+
|
| 488 |
+
if etag:
|
| 489 |
+
etag_bytes = etag.encode("utf-8")
|
| 490 |
+
etag_hash = sha256(etag_bytes)
|
| 491 |
+
filename += "." + etag_hash.hexdigest()
|
| 492 |
+
|
| 493 |
+
if url.endswith(".h5"):
|
| 494 |
+
filename += ".h5"
|
| 495 |
+
|
| 496 |
+
return filename
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
def filename_to_url(filename, cache_dir=None):
|
| 500 |
+
"""
|
| 501 |
+
Return the url and etag (which may be ``None``) stored for `filename`.
|
| 502 |
+
Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
|
| 503 |
+
"""
|
| 504 |
+
if cache_dir is None:
|
| 505 |
+
cache_dir = TRANSFORMERS_CACHE
|
| 506 |
+
if isinstance(cache_dir, Path):
|
| 507 |
+
cache_dir = str(cache_dir)
|
| 508 |
+
|
| 509 |
+
cache_path = os.path.join(cache_dir, filename)
|
| 510 |
+
if not os.path.exists(cache_path):
|
| 511 |
+
raise EnvironmentError("file {} not found".format(cache_path))
|
| 512 |
+
|
| 513 |
+
meta_path = cache_path + ".json"
|
| 514 |
+
if not os.path.exists(meta_path):
|
| 515 |
+
raise EnvironmentError("file {} not found".format(meta_path))
|
| 516 |
+
|
| 517 |
+
with open(meta_path, encoding="utf-8") as meta_file:
|
| 518 |
+
metadata = json.load(meta_file)
|
| 519 |
+
url = metadata["url"]
|
| 520 |
+
etag = metadata["etag"]
|
| 521 |
+
|
| 522 |
+
return url, etag
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
def cached_path(
|
| 526 |
+
url_or_filename,
|
| 527 |
+
cache_dir=None,
|
| 528 |
+
force_download=False,
|
| 529 |
+
proxies=None,
|
| 530 |
+
resume_download=False,
|
| 531 |
+
user_agent: Union[Dict, str, None] = None,
|
| 532 |
+
extract_compressed_file=False,
|
| 533 |
+
force_extract=False,
|
| 534 |
+
local_files_only=False,
|
| 535 |
+
) -> Optional[str]:
|
| 536 |
+
"""
|
| 537 |
+
Given something that might be a URL (or might be a local path),
|
| 538 |
+
determine which. If it's a URL, download the file and cache it, and
|
| 539 |
+
return the path to the cached file. If it's already a local path,
|
| 540 |
+
make sure the file exists and then return the path.
|
| 541 |
+
Args:
|
| 542 |
+
cache_dir: specify a cache directory to save the file to (overwrite the default cache dir).
|
| 543 |
+
force_download: if True, re-dowload the file even if it's already cached in the cache dir.
|
| 544 |
+
resume_download: if True, resume the download if incompletly recieved file is found.
|
| 545 |
+
user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
|
| 546 |
+
extract_compressed_file: if True and the path point to a zip or tar file, extract the compressed
|
| 547 |
+
file in a folder along the archive.
|
| 548 |
+
force_extract: if True when extract_compressed_file is True and the archive was already extracted,
|
| 549 |
+
re-extract the archive and overide the folder where it was extracted.
|
| 550 |
+
|
| 551 |
+
Return:
|
| 552 |
+
None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
|
| 553 |
+
Local path (string) otherwise
|
| 554 |
+
"""
|
| 555 |
+
if cache_dir is None:
|
| 556 |
+
cache_dir = TRANSFORMERS_CACHE
|
| 557 |
+
if isinstance(url_or_filename, Path):
|
| 558 |
+
url_or_filename = str(url_or_filename)
|
| 559 |
+
if isinstance(cache_dir, Path):
|
| 560 |
+
cache_dir = str(cache_dir)
|
| 561 |
+
|
| 562 |
+
if is_remote_url(url_or_filename):
|
| 563 |
+
# URL, so get it from the cache (downloading if necessary)
|
| 564 |
+
output_path = get_from_cache(
|
| 565 |
+
url_or_filename,
|
| 566 |
+
cache_dir=cache_dir,
|
| 567 |
+
force_download=force_download,
|
| 568 |
+
proxies=proxies,
|
| 569 |
+
resume_download=resume_download,
|
| 570 |
+
user_agent=user_agent,
|
| 571 |
+
local_files_only=local_files_only,
|
| 572 |
+
)
|
| 573 |
+
elif os.path.exists(url_or_filename):
|
| 574 |
+
# File, and it exists.
|
| 575 |
+
output_path = url_or_filename
|
| 576 |
+
elif urlparse(url_or_filename).scheme == "":
|
| 577 |
+
# File, but it doesn't exist.
|
| 578 |
+
raise EnvironmentError("file {} not found".format(url_or_filename))
|
| 579 |
+
else:
|
| 580 |
+
# Something unknown
|
| 581 |
+
raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
|
| 582 |
+
|
| 583 |
+
if extract_compressed_file:
|
| 584 |
+
if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
|
| 585 |
+
return output_path
|
| 586 |
+
|
| 587 |
+
# Path where we extract compressed archives
|
| 588 |
+
# We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/"
|
| 589 |
+
output_dir, output_file = os.path.split(output_path)
|
| 590 |
+
output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
|
| 591 |
+
output_path_extracted = os.path.join(output_dir, output_extract_dir_name)
|
| 592 |
+
|
| 593 |
+
if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:
|
| 594 |
+
return output_path_extracted
|
| 595 |
+
|
| 596 |
+
# Prevent parallel extractions
|
| 597 |
+
lock_path = output_path + ".lock"
|
| 598 |
+
with FileLock(lock_path):
|
| 599 |
+
shutil.rmtree(output_path_extracted, ignore_errors=True)
|
| 600 |
+
os.makedirs(output_path_extracted)
|
| 601 |
+
if is_zipfile(output_path):
|
| 602 |
+
with ZipFile(output_path, "r") as zip_file:
|
| 603 |
+
zip_file.extractall(output_path_extracted)
|
| 604 |
+
zip_file.close()
|
| 605 |
+
elif tarfile.is_tarfile(output_path):
|
| 606 |
+
tar_file = tarfile.open(output_path)
|
| 607 |
+
tar_file.extractall(output_path_extracted)
|
| 608 |
+
tar_file.close()
|
| 609 |
+
else:
|
| 610 |
+
raise EnvironmentError("Archive format of {} could not be identified".format(output_path))
|
| 611 |
+
|
| 612 |
+
return output_path_extracted
|
| 613 |
+
|
| 614 |
+
return output_path
|
| 615 |
+
|
| 616 |
+
|
| 617 |
+
def http_get(url, temp_file, proxies=None, resume_size=0, user_agent: Union[Dict, str, None] = None):
|
| 618 |
+
ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
|
| 619 |
+
if is_torch_available():
|
| 620 |
+
ua += "; torch/{}".format(torch.__version__)
|
| 621 |
+
if is_tf_available():
|
| 622 |
+
ua += "; tensorflow/{}".format(tf.__version__)
|
| 623 |
+
if isinstance(user_agent, dict):
|
| 624 |
+
ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
|
| 625 |
+
elif isinstance(user_agent, str):
|
| 626 |
+
ua += "; " + user_agent
|
| 627 |
+
headers = {"user-agent": ua}
|
| 628 |
+
if resume_size > 0:
|
| 629 |
+
headers["Range"] = "bytes=%d-" % (resume_size,)
|
| 630 |
+
response = requests.get(url, stream=True, proxies=proxies, headers=headers)
|
| 631 |
+
if response.status_code == 416: # Range not satisfiable
|
| 632 |
+
return
|
| 633 |
+
content_length = response.headers.get("Content-Length")
|
| 634 |
+
total = resume_size + int(content_length) if content_length is not None else None
|
| 635 |
+
progress = tqdm(
|
| 636 |
+
unit="B",
|
| 637 |
+
unit_scale=True,
|
| 638 |
+
total=total,
|
| 639 |
+
initial=resume_size,
|
| 640 |
+
desc="Downloading",
|
| 641 |
+
disable=bool(logger.getEffectiveLevel() == logging.NOTSET),
|
| 642 |
+
)
|
| 643 |
+
for chunk in response.iter_content(chunk_size=1024):
|
| 644 |
+
if chunk: # filter out keep-alive new chunks
|
| 645 |
+
progress.update(len(chunk))
|
| 646 |
+
temp_file.write(chunk)
|
| 647 |
+
progress.close()
|
| 648 |
+
|
| 649 |
+
|
| 650 |
+
def get_from_cache(
|
| 651 |
+
url,
|
| 652 |
+
cache_dir=None,
|
| 653 |
+
force_download=False,
|
| 654 |
+
proxies=None,
|
| 655 |
+
etag_timeout=10,
|
| 656 |
+
resume_download=False,
|
| 657 |
+
user_agent: Union[Dict, str, None] = None,
|
| 658 |
+
local_files_only=False,
|
| 659 |
+
) -> Optional[str]:
|
| 660 |
+
"""
|
| 661 |
+
Given a URL, look for the corresponding file in the local cache.
|
| 662 |
+
If it's not there, download it. Then return the path to the cached file.
|
| 663 |
+
|
| 664 |
+
Return:
|
| 665 |
+
None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
|
| 666 |
+
Local path (string) otherwise
|
| 667 |
+
"""
|
| 668 |
+
if cache_dir is None:
|
| 669 |
+
cache_dir = TRANSFORMERS_CACHE
|
| 670 |
+
if isinstance(cache_dir, Path):
|
| 671 |
+
cache_dir = str(cache_dir)
|
| 672 |
+
|
| 673 |
+
os.makedirs(cache_dir, exist_ok=True)
|
| 674 |
+
|
| 675 |
+
etag = None
|
| 676 |
+
if not local_files_only:
|
| 677 |
+
try:
|
| 678 |
+
response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
|
| 679 |
+
if response.status_code == 200:
|
| 680 |
+
etag = response.headers.get("ETag")
|
| 681 |
+
except (EnvironmentError, requests.exceptions.Timeout):
|
| 682 |
+
# etag is already None
|
| 683 |
+
pass
|
| 684 |
+
|
| 685 |
+
filename = url_to_filename(url, etag)
|
| 686 |
+
|
| 687 |
+
# get cache path to put the file
|
| 688 |
+
cache_path = os.path.join(cache_dir, filename)
|
| 689 |
+
|
| 690 |
+
# etag is None = we don't have a connection, or url doesn't exist, or is otherwise inaccessible.
|
| 691 |
+
# try to get the last downloaded one
|
| 692 |
+
if etag is None:
|
| 693 |
+
if os.path.exists(cache_path):
|
| 694 |
+
return cache_path
|
| 695 |
+
else:
|
| 696 |
+
matching_files = [
|
| 697 |
+
file
|
| 698 |
+
for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
|
| 699 |
+
if not file.endswith(".json") and not file.endswith(".lock")
|
| 700 |
+
]
|
| 701 |
+
if len(matching_files) > 0:
|
| 702 |
+
return os.path.join(cache_dir, matching_files[-1])
|
| 703 |
+
else:
|
| 704 |
+
# If files cannot be found and local_files_only=True,
|
| 705 |
+
# the models might've been found if local_files_only=False
|
| 706 |
+
# Notify the user about that
|
| 707 |
+
if local_files_only:
|
| 708 |
+
raise ValueError(
|
| 709 |
+
"Cannot find the requested files in the cached path and outgoing traffic has been"
|
| 710 |
+
" disabled. To enable model look-ups and downloads online, set 'local_files_only'"
|
| 711 |
+
" to False."
|
| 712 |
+
)
|
| 713 |
+
return None
|
| 714 |
+
|
| 715 |
+
# From now on, etag is not None.
|
| 716 |
+
if os.path.exists(cache_path) and not force_download:
|
| 717 |
+
return cache_path
|
| 718 |
+
|
| 719 |
+
# Prevent parallel downloads of the same file with a lock.
|
| 720 |
+
lock_path = cache_path + ".lock"
|
| 721 |
+
with FileLock(lock_path):
|
| 722 |
+
|
| 723 |
+
# If the download just completed while the lock was activated.
|
| 724 |
+
if os.path.exists(cache_path) and not force_download:
|
| 725 |
+
# Even if returning early like here, the lock will be released.
|
| 726 |
+
return cache_path
|
| 727 |
+
|
| 728 |
+
if resume_download:
|
| 729 |
+
incomplete_path = cache_path + ".incomplete"
|
| 730 |
+
|
| 731 |
+
@contextmanager
|
| 732 |
+
def _resumable_file_manager():
|
| 733 |
+
with open(incomplete_path, "a+b") as f:
|
| 734 |
+
yield f
|
| 735 |
+
|
| 736 |
+
temp_file_manager = _resumable_file_manager
|
| 737 |
+
if os.path.exists(incomplete_path):
|
| 738 |
+
resume_size = os.stat(incomplete_path).st_size
|
| 739 |
+
else:
|
| 740 |
+
resume_size = 0
|
| 741 |
+
else:
|
| 742 |
+
temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False)
|
| 743 |
+
resume_size = 0
|
| 744 |
+
|
| 745 |
+
# Download to temporary file, then copy to cache dir once finished.
|
| 746 |
+
# Otherwise you get corrupt cache entries if the download gets interrupted.
|
| 747 |
+
with temp_file_manager() as temp_file:
|
| 748 |
+
logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
|
| 749 |
+
|
| 750 |
+
http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
|
| 751 |
+
|
| 752 |
+
logger.info("storing %s in cache at %s", url, cache_path)
|
| 753 |
+
os.replace(temp_file.name, cache_path)
|
| 754 |
+
|
| 755 |
+
logger.info("creating metadata file for %s", cache_path)
|
| 756 |
+
meta = {"url": url, "etag": etag}
|
| 757 |
+
meta_path = cache_path + ".json"
|
| 758 |
+
with open(meta_path, "w") as meta_file:
|
| 759 |
+
json.dump(meta, meta_file)
|
| 760 |
+
|
| 761 |
+
return cache_path
|
| 762 |
+
|
| 763 |
+
|
| 764 |
+
class cached_property(property):
|
| 765 |
+
"""
|
| 766 |
+
Descriptor that mimics @property but caches output in member variable.
|
| 767 |
+
|
| 768 |
+
From tensorflow_datasets
|
| 769 |
+
|
| 770 |
+
Built-in in functools from Python 3.8.
|
| 771 |
+
"""
|
| 772 |
+
|
| 773 |
+
def __get__(self, obj, objtype=None):
|
| 774 |
+
# See docs.python.org/3/howto/descriptor.html#properties
|
| 775 |
+
if obj is None:
|
| 776 |
+
return self
|
| 777 |
+
if self.fget is None:
|
| 778 |
+
raise AttributeError("unreadable attribute")
|
| 779 |
+
attr = "__cached_" + self.fget.__name__
|
| 780 |
+
cached = getattr(obj, attr, None)
|
| 781 |
+
if cached is None:
|
| 782 |
+
cached = self.fget(obj)
|
| 783 |
+
setattr(obj, attr, cached)
|
| 784 |
+
return cached
|
| 785 |
+
|
| 786 |
+
|
| 787 |
+
def torch_required(func):
|
| 788 |
+
# Chose a different decorator name than in tests so it's clear they are not the same.
|
| 789 |
+
@wraps(func)
|
| 790 |
+
def wrapper(*args, **kwargs):
|
| 791 |
+
if is_torch_available():
|
| 792 |
+
return func(*args, **kwargs)
|
| 793 |
+
else:
|
| 794 |
+
raise ImportError(f"Method `{func.__name__}` requires PyTorch.")
|
| 795 |
+
|
| 796 |
+
return wrapper
|
| 797 |
+
|
| 798 |
+
|
| 799 |
+
def tf_required(func):
|
| 800 |
+
# Chose a different decorator name than in tests so it's clear they are not the same.
|
| 801 |
+
@wraps(func)
|
| 802 |
+
def wrapper(*args, **kwargs):
|
| 803 |
+
if is_tf_available():
|
| 804 |
+
return func(*args, **kwargs)
|
| 805 |
+
else:
|
| 806 |
+
raise ImportError(f"Method `{func.__name__}` requires TF.")
|
| 807 |
+
|
| 808 |
+
return wrapper
|
CGFormer/bert/generation_utils.py
ADDED
|
@@ -0,0 +1,993 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
|
| 3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import logging
|
| 18 |
+
from typing import Iterable, Optional, Tuple
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
from torch import Tensor
|
| 22 |
+
from torch.nn import functional as F
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class GenerationMixin:
|
| 29 |
+
"""
|
| 30 |
+
A class contraining all of the functions supporting generation, to be used as a mixin in PreTrainedModel.
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
def prepare_inputs_for_generation(self, input_ids, **kwargs):
|
| 34 |
+
return {"input_ids": input_ids}
|
| 35 |
+
|
| 36 |
+
def adjust_logits_during_generation(self, logits, **kwargs):
|
| 37 |
+
return logits
|
| 38 |
+
|
| 39 |
+
def _use_cache(self, outputs, use_cache):
|
| 40 |
+
"""During generation, decide whether to pass the `past` variable to the next forward pass."""
|
| 41 |
+
if len(outputs) <= 1 or use_cache is False:
|
| 42 |
+
return False
|
| 43 |
+
if hasattr(self.config, "mem_len") and self.config.mem_len == 0:
|
| 44 |
+
return False
|
| 45 |
+
return True
|
| 46 |
+
|
| 47 |
+
def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
|
| 48 |
+
"""repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858). """
|
| 49 |
+
for i in range(batch_size * num_beams):
|
| 50 |
+
for previous_token in set(prev_output_tokens[i].tolist()):
|
| 51 |
+
# if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
|
| 52 |
+
if lprobs[i, previous_token] < 0:
|
| 53 |
+
lprobs[i, previous_token] *= repetition_penalty
|
| 54 |
+
else:
|
| 55 |
+
lprobs[i, previous_token] /= repetition_penalty
|
| 56 |
+
|
| 57 |
+
def postprocess_next_token_scores(
|
| 58 |
+
self,
|
| 59 |
+
scores,
|
| 60 |
+
input_ids,
|
| 61 |
+
no_repeat_ngram_size,
|
| 62 |
+
bad_words_ids,
|
| 63 |
+
cur_len,
|
| 64 |
+
min_length,
|
| 65 |
+
max_length,
|
| 66 |
+
eos_token_id,
|
| 67 |
+
repetition_penalty,
|
| 68 |
+
batch_size,
|
| 69 |
+
num_beams,
|
| 70 |
+
):
|
| 71 |
+
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
|
| 72 |
+
if repetition_penalty != 1.0:
|
| 73 |
+
self.enforce_repetition_penalty_(
|
| 74 |
+
scores, batch_size, num_beams, input_ids, repetition_penalty,
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
# set eos token prob to zero if min_length is not reached
|
| 78 |
+
if eos_token_id is not None and cur_len < min_length:
|
| 79 |
+
scores[:, eos_token_id] = -float("inf")
|
| 80 |
+
|
| 81 |
+
if no_repeat_ngram_size > 0:
|
| 82 |
+
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
|
| 83 |
+
num_batch_hypotheses = batch_size * num_beams
|
| 84 |
+
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
|
| 85 |
+
banned_batch_tokens = calc_banned_ngram_tokens(
|
| 86 |
+
input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len
|
| 87 |
+
)
|
| 88 |
+
for i, banned_tokens in enumerate(banned_batch_tokens):
|
| 89 |
+
scores[i, banned_tokens] = -float("inf")
|
| 90 |
+
|
| 91 |
+
if bad_words_ids is not None:
|
| 92 |
+
# calculate a list of banned tokens according to bad words
|
| 93 |
+
banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)
|
| 94 |
+
|
| 95 |
+
for i, banned_tokens in enumerate(banned_tokens):
|
| 96 |
+
scores[i, banned_tokens] = -float("inf")
|
| 97 |
+
|
| 98 |
+
return scores
|
| 99 |
+
|
| 100 |
+
@torch.no_grad()
|
| 101 |
+
def generate(
|
| 102 |
+
self,
|
| 103 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 104 |
+
max_length: Optional[int] = None,
|
| 105 |
+
min_length: Optional[int] = None,
|
| 106 |
+
do_sample: Optional[bool] = None,
|
| 107 |
+
early_stopping: Optional[bool] = None,
|
| 108 |
+
num_beams: Optional[int] = None,
|
| 109 |
+
temperature: Optional[float] = None,
|
| 110 |
+
top_k: Optional[int] = None,
|
| 111 |
+
top_p: Optional[float] = None,
|
| 112 |
+
repetition_penalty: Optional[float] = None,
|
| 113 |
+
bad_words_ids: Optional[Iterable[int]] = None,
|
| 114 |
+
bos_token_id: Optional[int] = None,
|
| 115 |
+
pad_token_id: Optional[int] = None,
|
| 116 |
+
eos_token_id: Optional[int] = None,
|
| 117 |
+
length_penalty: Optional[float] = None,
|
| 118 |
+
no_repeat_ngram_size: Optional[int] = None,
|
| 119 |
+
num_return_sequences: Optional[int] = None,
|
| 120 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 121 |
+
decoder_start_token_id: Optional[int] = None,
|
| 122 |
+
use_cache: Optional[bool] = None,
|
| 123 |
+
**model_specific_kwargs
|
| 124 |
+
) -> torch.LongTensor:
|
| 125 |
+
r""" Generates sequences for models with a LM head. The method currently supports greedy decoding, beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.
|
| 126 |
+
|
| 127 |
+
Adapted in part from `Facebook's XLM beam search code`_.
|
| 128 |
+
|
| 129 |
+
.. _`Facebook's XLM beam search code`:
|
| 130 |
+
https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
Parameters:
|
| 134 |
+
|
| 135 |
+
input_ids: (`optional`) `torch.LongTensor` of shape `(batch_size, sequence_length)`
|
| 136 |
+
The sequence used as a prompt for the generation. If `None` the method initializes
|
| 137 |
+
it as an empty `torch.LongTensor` of shape `(1,)`.
|
| 138 |
+
|
| 139 |
+
max_length: (`optional`) int
|
| 140 |
+
The max length of the sequence to be generated. Between `min_length` and infinity. Default to 20.
|
| 141 |
+
|
| 142 |
+
min_length: (`optional`) int
|
| 143 |
+
The min length of the sequence to be generated. Between 0 and infinity. Default to 0.
|
| 144 |
+
|
| 145 |
+
do_sample: (`optional`) bool
|
| 146 |
+
If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`.
|
| 147 |
+
|
| 148 |
+
early_stopping: (`optional`) bool
|
| 149 |
+
if set to `True` beam search is stopped when at least `num_beams` sentences finished per batch. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`.
|
| 150 |
+
|
| 151 |
+
num_beams: (`optional`) int
|
| 152 |
+
Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Default to 1.
|
| 153 |
+
|
| 154 |
+
temperature: (`optional`) float
|
| 155 |
+
The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
|
| 156 |
+
|
| 157 |
+
top_k: (`optional`) int
|
| 158 |
+
The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
|
| 159 |
+
|
| 160 |
+
top_p: (`optional`) float
|
| 161 |
+
The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.
|
| 162 |
+
|
| 163 |
+
repetition_penalty: (`optional`) float
|
| 164 |
+
The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.
|
| 165 |
+
|
| 166 |
+
pad_token_id: (`optional`) int
|
| 167 |
+
Padding token. Default to specicic model pad_token_id or None if it does not exist.
|
| 168 |
+
|
| 169 |
+
bos_token_id: (`optional`) int
|
| 170 |
+
BOS token. Defaults to `bos_token_id` as defined in the models config.
|
| 171 |
+
|
| 172 |
+
eos_token_id: (`optional`) int
|
| 173 |
+
EOS token. Defaults to `eos_token_id` as defined in the models config.
|
| 174 |
+
|
| 175 |
+
length_penalty: (`optional`) float
|
| 176 |
+
Exponential penalty to the length. Default to 1.
|
| 177 |
+
|
| 178 |
+
no_repeat_ngram_size: (`optional`) int
|
| 179 |
+
If set to int > 0, all ngrams of size `no_repeat_ngram_size` can only occur once.
|
| 180 |
+
bad_words_ids: (`optional`) list of lists of int
|
| 181 |
+
`bad_words_ids` contains tokens that are not allowed to be generated. In order to get the tokens of the words that should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`.
|
| 182 |
+
|
| 183 |
+
num_return_sequences: (`optional`) int
|
| 184 |
+
The number of independently computed returned sequences for each element in the batch. Default to 1.
|
| 185 |
+
|
| 186 |
+
attention_mask (`optional`) obj: `torch.LongTensor` of same shape as `input_ids`
|
| 187 |
+
Mask to avoid performing attention on padding token indices.
|
| 188 |
+
Mask values selected in ``[0, 1]``:
|
| 189 |
+
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
| 190 |
+
Defaults to `None`.
|
| 191 |
+
|
| 192 |
+
`What are attention masks? <../glossary.html#attention-mask>`__
|
| 193 |
+
|
| 194 |
+
decoder_start_token_id=None: (`optional`) int
|
| 195 |
+
If an encoder-decoder model starts decoding with a different token than BOS.
|
| 196 |
+
Defaults to `None` and is changed to `BOS` later.
|
| 197 |
+
|
| 198 |
+
use_cache: (`optional`) bool
|
| 199 |
+
If `use_cache` is True, past key values are used to speed up decoding if applicable to model. Defaults to `True`.
|
| 200 |
+
|
| 201 |
+
model_specific_kwargs: (`optional`) dict
|
| 202 |
+
Additional model specific kwargs will be forwarded to the `forward` function of the model.
|
| 203 |
+
|
| 204 |
+
Return:
|
| 205 |
+
|
| 206 |
+
output: `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`
|
| 207 |
+
sequence_length is either equal to max_length or shorter if all batches finished early due to the `eos_token_id`
|
| 208 |
+
|
| 209 |
+
Examples::
|
| 210 |
+
|
| 211 |
+
tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer
|
| 212 |
+
model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
|
| 213 |
+
outputs = model.generate(max_length=40) # do greedy decoding
|
| 214 |
+
print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
|
| 215 |
+
|
| 216 |
+
tokenizer = AutoTokenizer.from_pretrained('openai-gpt') # Initialize tokenizer
|
| 217 |
+
model = AutoModelWithLMHead.from_pretrained('openai-gpt') # Download model and configuration from S3 and cache.
|
| 218 |
+
input_context = 'The dog'
|
| 219 |
+
input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
|
| 220 |
+
outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5) # generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog'
|
| 221 |
+
for i in range(3): # 3 output sequences were generated
|
| 222 |
+
print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
|
| 223 |
+
|
| 224 |
+
tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer
|
| 225 |
+
model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
|
| 226 |
+
input_context = 'The dog'
|
| 227 |
+
input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
|
| 228 |
+
outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3) # 3 generate sequences using by sampling
|
| 229 |
+
for i in range(3): # 3 output sequences were generated
|
| 230 |
+
print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
|
| 231 |
+
|
| 232 |
+
tokenizer = AutoTokenizer.from_pretrained('ctrl') # Initialize tokenizer
|
| 233 |
+
model = AutoModelWithLMHead.from_pretrained('ctrl') # Download model and configuration from S3 and cache.
|
| 234 |
+
input_context = 'Legal My neighbor is' # "Legal" is one of the control codes for ctrl
|
| 235 |
+
input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
|
| 236 |
+
outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2) # generate sequences
|
| 237 |
+
print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
|
| 238 |
+
|
| 239 |
+
tokenizer = AutoTokenizer.from_pretrained('gpt2') # Initialize tokenizer
|
| 240 |
+
model = AutoModelWithLMHead.from_pretrained('gpt2') # Download model and configuration from S3 and cache.
|
| 241 |
+
input_context = 'My cute dog' # "Legal" is one of the control codes for ctrl
|
| 242 |
+
bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']]
|
| 243 |
+
input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
|
| 244 |
+
outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids) # generate sequences without allowing bad_words to be generated
|
| 245 |
+
"""
|
| 246 |
+
|
| 247 |
+
# We cannot generate if the model does not have a LM head
|
| 248 |
+
if self.get_output_embeddings() is None:
|
| 249 |
+
raise AttributeError(
|
| 250 |
+
"You tried to generate sequences with a model that does not have a LM Head."
|
| 251 |
+
"Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`, `XLMWithLMHeadModel`, `BartForConditionalGeneration` )"
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
max_length = max_length if max_length is not None else self.config.max_length
|
| 255 |
+
min_length = min_length if min_length is not None else self.config.min_length
|
| 256 |
+
do_sample = do_sample if do_sample is not None else self.config.do_sample
|
| 257 |
+
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
|
| 258 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 259 |
+
num_beams = num_beams if num_beams is not None else self.config.num_beams
|
| 260 |
+
temperature = temperature if temperature is not None else self.config.temperature
|
| 261 |
+
top_k = top_k if top_k is not None else self.config.top_k
|
| 262 |
+
top_p = top_p if top_p is not None else self.config.top_p
|
| 263 |
+
repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
|
| 264 |
+
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
|
| 265 |
+
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
| 266 |
+
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
| 267 |
+
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
|
| 268 |
+
no_repeat_ngram_size = (
|
| 269 |
+
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
|
| 270 |
+
)
|
| 271 |
+
bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
|
| 272 |
+
num_return_sequences = (
|
| 273 |
+
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
|
| 274 |
+
)
|
| 275 |
+
decoder_start_token_id = (
|
| 276 |
+
decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
if input_ids is not None:
|
| 280 |
+
batch_size = input_ids.shape[0] # overriden by the input batch_size
|
| 281 |
+
else:
|
| 282 |
+
batch_size = 1
|
| 283 |
+
|
| 284 |
+
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
|
| 285 |
+
assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
|
| 286 |
+
assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
|
| 287 |
+
assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
|
| 288 |
+
assert isinstance(use_cache, bool), "`use_cache` should be a boolean."
|
| 289 |
+
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
|
| 290 |
+
assert temperature > 0, "`temperature` should be strictly positive."
|
| 291 |
+
assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
|
| 292 |
+
assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
|
| 293 |
+
assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
|
| 294 |
+
assert input_ids is not None or (
|
| 295 |
+
isinstance(bos_token_id, int) and bos_token_id >= 0
|
| 296 |
+
), "If input_ids is not defined, `bos_token_id` should be a positive integer."
|
| 297 |
+
assert pad_token_id is None or (
|
| 298 |
+
isinstance(pad_token_id, int) and (pad_token_id >= 0)
|
| 299 |
+
), "`pad_token_id` should be a positive integer."
|
| 300 |
+
assert (eos_token_id is None) or (
|
| 301 |
+
isinstance(eos_token_id, int) and (eos_token_id >= 0)
|
| 302 |
+
), "`eos_token_id` should be a positive integer."
|
| 303 |
+
assert length_penalty > 0, "`length_penalty` should be strictly positive."
|
| 304 |
+
assert (
|
| 305 |
+
isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
|
| 306 |
+
), "`no_repeat_ngram_size` should be a positive integer."
|
| 307 |
+
assert (
|
| 308 |
+
isinstance(num_return_sequences, int) and num_return_sequences > 0
|
| 309 |
+
), "`num_return_sequences` should be a strictly positive integer."
|
| 310 |
+
assert (
|
| 311 |
+
bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
|
| 312 |
+
), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"
|
| 313 |
+
|
| 314 |
+
if input_ids is None:
|
| 315 |
+
assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
|
| 316 |
+
"you should either supply a context to complete as `input_ids` input "
|
| 317 |
+
"or a `bos_token_id` (integer >= 0) as a first token to start the generation."
|
| 318 |
+
)
|
| 319 |
+
input_ids = torch.full(
|
| 320 |
+
(batch_size, 1), bos_token_id, dtype=torch.long, device=next(self.parameters()).device,
|
| 321 |
+
)
|
| 322 |
+
else:
|
| 323 |
+
assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."
|
| 324 |
+
|
| 325 |
+
# not allow to duplicate outputs when greedy decoding
|
| 326 |
+
if do_sample is False:
|
| 327 |
+
if num_beams == 1:
|
| 328 |
+
# no_beam_search greedy generation conditions
|
| 329 |
+
assert (
|
| 330 |
+
num_return_sequences == 1
|
| 331 |
+
), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1"
|
| 332 |
+
|
| 333 |
+
else:
|
| 334 |
+
# beam_search greedy generation conditions
|
| 335 |
+
assert (
|
| 336 |
+
num_beams >= num_return_sequences
|
| 337 |
+
), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"
|
| 338 |
+
|
| 339 |
+
# create attention mask if necessary
|
| 340 |
+
# TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140
|
| 341 |
+
if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids):
|
| 342 |
+
attention_mask = input_ids.ne(pad_token_id).long()
|
| 343 |
+
elif attention_mask is None:
|
| 344 |
+
attention_mask = input_ids.new_ones(input_ids.shape)
|
| 345 |
+
|
| 346 |
+
# set pad_token_id to eos_token_id if not set. Important that this is done after
|
| 347 |
+
# attention_mask is created
|
| 348 |
+
if pad_token_id is None and eos_token_id is not None:
|
| 349 |
+
logger.warning(
|
| 350 |
+
"Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
|
| 351 |
+
)
|
| 352 |
+
pad_token_id = eos_token_id
|
| 353 |
+
|
| 354 |
+
# current position and vocab size
|
| 355 |
+
if hasattr(self.config, "vocab_size"):
|
| 356 |
+
vocab_size = self.config.vocab_size
|
| 357 |
+
elif (
|
| 358 |
+
self.config.is_encoder_decoder
|
| 359 |
+
and hasattr(self.config, "decoder")
|
| 360 |
+
and hasattr(self.config.decoder, "vocab_size")
|
| 361 |
+
):
|
| 362 |
+
vocab_size = self.config.decoder.vocab_size
|
| 363 |
+
|
| 364 |
+
# set effective batch size and effective batch multiplier according to do_sample
|
| 365 |
+
if do_sample:
|
| 366 |
+
effective_batch_size = batch_size * num_return_sequences
|
| 367 |
+
effective_batch_mult = num_return_sequences
|
| 368 |
+
else:
|
| 369 |
+
effective_batch_size = batch_size
|
| 370 |
+
effective_batch_mult = 1
|
| 371 |
+
|
| 372 |
+
if self.config.is_encoder_decoder:
|
| 373 |
+
if decoder_start_token_id is None:
|
| 374 |
+
decoder_start_token_id = bos_token_id
|
| 375 |
+
|
| 376 |
+
assert (
|
| 377 |
+
decoder_start_token_id is not None
|
| 378 |
+
), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
|
| 379 |
+
assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
|
| 380 |
+
assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
|
| 381 |
+
|
| 382 |
+
# get encoder and store encoder outputs
|
| 383 |
+
encoder = self.get_encoder()
|
| 384 |
+
|
| 385 |
+
encoder_outputs: tuple = encoder(input_ids, attention_mask=attention_mask)
|
| 386 |
+
|
| 387 |
+
# Expand input ids if num_beams > 1 or num_return_sequences > 1
|
| 388 |
+
if num_return_sequences > 1 or num_beams > 1:
|
| 389 |
+
input_ids_len = input_ids.shape[-1]
|
| 390 |
+
input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
|
| 391 |
+
attention_mask = attention_mask.unsqueeze(1).expand(
|
| 392 |
+
batch_size, effective_batch_mult * num_beams, input_ids_len
|
| 393 |
+
)
|
| 394 |
+
|
| 395 |
+
input_ids = input_ids.contiguous().view(
|
| 396 |
+
effective_batch_size * num_beams, input_ids_len
|
| 397 |
+
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
|
| 398 |
+
attention_mask = attention_mask.contiguous().view(
|
| 399 |
+
effective_batch_size * num_beams, input_ids_len
|
| 400 |
+
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
|
| 401 |
+
|
| 402 |
+
if self.config.is_encoder_decoder:
|
| 403 |
+
# create empty decoder_input_ids
|
| 404 |
+
input_ids = torch.full(
|
| 405 |
+
(effective_batch_size * num_beams, 1),
|
| 406 |
+
decoder_start_token_id,
|
| 407 |
+
dtype=torch.long,
|
| 408 |
+
device=next(self.parameters()).device,
|
| 409 |
+
)
|
| 410 |
+
cur_len = 1
|
| 411 |
+
|
| 412 |
+
assert (
|
| 413 |
+
batch_size == encoder_outputs[0].shape[0]
|
| 414 |
+
), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[0]} "
|
| 415 |
+
|
| 416 |
+
# expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
|
| 417 |
+
expanded_batch_idxs = (
|
| 418 |
+
torch.arange(batch_size)
|
| 419 |
+
.view(-1, 1)
|
| 420 |
+
.repeat(1, num_beams * effective_batch_mult)
|
| 421 |
+
.view(-1)
|
| 422 |
+
.to(input_ids.device)
|
| 423 |
+
)
|
| 424 |
+
# expand encoder_outputs
|
| 425 |
+
encoder_outputs = (encoder_outputs[0].index_select(0, expanded_batch_idxs), *encoder_outputs[1:])
|
| 426 |
+
|
| 427 |
+
else:
|
| 428 |
+
encoder_outputs = None
|
| 429 |
+
cur_len = input_ids.shape[-1]
|
| 430 |
+
|
| 431 |
+
assert (
|
| 432 |
+
cur_len < max_length
|
| 433 |
+
), f"The context has {cur_len} number of tokens, but `max_length` is only {max_length}. Please make sure that `max_length` is bigger than the number of tokens, by setting either `generate(max_length=...,...)` or `config.max_length = ...`"
|
| 434 |
+
|
| 435 |
+
if num_beams > 1:
|
| 436 |
+
output = self._generate_beam_search(
|
| 437 |
+
input_ids,
|
| 438 |
+
cur_len=cur_len,
|
| 439 |
+
max_length=max_length,
|
| 440 |
+
min_length=min_length,
|
| 441 |
+
do_sample=do_sample,
|
| 442 |
+
early_stopping=early_stopping,
|
| 443 |
+
temperature=temperature,
|
| 444 |
+
top_k=top_k,
|
| 445 |
+
top_p=top_p,
|
| 446 |
+
repetition_penalty=repetition_penalty,
|
| 447 |
+
no_repeat_ngram_size=no_repeat_ngram_size,
|
| 448 |
+
bad_words_ids=bad_words_ids,
|
| 449 |
+
pad_token_id=pad_token_id,
|
| 450 |
+
eos_token_id=eos_token_id,
|
| 451 |
+
batch_size=effective_batch_size,
|
| 452 |
+
num_return_sequences=num_return_sequences,
|
| 453 |
+
length_penalty=length_penalty,
|
| 454 |
+
num_beams=num_beams,
|
| 455 |
+
vocab_size=vocab_size,
|
| 456 |
+
encoder_outputs=encoder_outputs,
|
| 457 |
+
attention_mask=attention_mask,
|
| 458 |
+
use_cache=use_cache,
|
| 459 |
+
model_specific_kwargs=model_specific_kwargs,
|
| 460 |
+
)
|
| 461 |
+
else:
|
| 462 |
+
output = self._generate_no_beam_search(
|
| 463 |
+
input_ids,
|
| 464 |
+
cur_len=cur_len,
|
| 465 |
+
max_length=max_length,
|
| 466 |
+
min_length=min_length,
|
| 467 |
+
do_sample=do_sample,
|
| 468 |
+
temperature=temperature,
|
| 469 |
+
top_k=top_k,
|
| 470 |
+
top_p=top_p,
|
| 471 |
+
repetition_penalty=repetition_penalty,
|
| 472 |
+
no_repeat_ngram_size=no_repeat_ngram_size,
|
| 473 |
+
bad_words_ids=bad_words_ids,
|
| 474 |
+
pad_token_id=pad_token_id,
|
| 475 |
+
eos_token_id=eos_token_id,
|
| 476 |
+
batch_size=effective_batch_size,
|
| 477 |
+
encoder_outputs=encoder_outputs,
|
| 478 |
+
attention_mask=attention_mask,
|
| 479 |
+
use_cache=use_cache,
|
| 480 |
+
model_specific_kwargs=model_specific_kwargs,
|
| 481 |
+
)
|
| 482 |
+
|
| 483 |
+
return output
|
| 484 |
+
|
| 485 |
+
def _generate_no_beam_search(
|
| 486 |
+
self,
|
| 487 |
+
input_ids,
|
| 488 |
+
cur_len,
|
| 489 |
+
max_length,
|
| 490 |
+
min_length,
|
| 491 |
+
do_sample,
|
| 492 |
+
temperature,
|
| 493 |
+
top_k,
|
| 494 |
+
top_p,
|
| 495 |
+
repetition_penalty,
|
| 496 |
+
no_repeat_ngram_size,
|
| 497 |
+
bad_words_ids,
|
| 498 |
+
pad_token_id,
|
| 499 |
+
eos_token_id,
|
| 500 |
+
batch_size,
|
| 501 |
+
encoder_outputs,
|
| 502 |
+
attention_mask,
|
| 503 |
+
use_cache,
|
| 504 |
+
model_specific_kwargs,
|
| 505 |
+
):
|
| 506 |
+
""" Generate sequences for each example without beam search (num_beams == 1).
|
| 507 |
+
All returned sequence are generated independantly.
|
| 508 |
+
"""
|
| 509 |
+
# length of generated sentences / unfinished sentences
|
| 510 |
+
unfinished_sents = input_ids.new(batch_size).fill_(1)
|
| 511 |
+
sent_lengths = input_ids.new(batch_size).fill_(max_length)
|
| 512 |
+
|
| 513 |
+
past = (encoder_outputs, None) if encoder_outputs is not None else None
|
| 514 |
+
|
| 515 |
+
while cur_len < max_length:
|
| 516 |
+
model_inputs = self.prepare_inputs_for_generation(
|
| 517 |
+
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
|
| 518 |
+
)
|
| 519 |
+
|
| 520 |
+
outputs = self(**model_inputs)
|
| 521 |
+
next_token_logits = outputs[0][:, -1, :]
|
| 522 |
+
|
| 523 |
+
scores = self.postprocess_next_token_scores(
|
| 524 |
+
scores=next_token_logits,
|
| 525 |
+
input_ids=input_ids,
|
| 526 |
+
no_repeat_ngram_size=no_repeat_ngram_size,
|
| 527 |
+
bad_words_ids=bad_words_ids,
|
| 528 |
+
cur_len=cur_len,
|
| 529 |
+
min_length=min_length,
|
| 530 |
+
max_length=max_length,
|
| 531 |
+
eos_token_id=eos_token_id,
|
| 532 |
+
repetition_penalty=repetition_penalty,
|
| 533 |
+
batch_size=batch_size,
|
| 534 |
+
num_beams=1,
|
| 535 |
+
)
|
| 536 |
+
|
| 537 |
+
# if model has past, then set the past variable to speed up decoding
|
| 538 |
+
if self._use_cache(outputs, use_cache):
|
| 539 |
+
past = outputs[1]
|
| 540 |
+
|
| 541 |
+
if do_sample:
|
| 542 |
+
# Temperature (higher temperature => more likely to sample low probability tokens)
|
| 543 |
+
if temperature != 1.0:
|
| 544 |
+
scores = scores / temperature
|
| 545 |
+
# Top-p/top-k filtering
|
| 546 |
+
next_token_logscores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p)
|
| 547 |
+
# Sample
|
| 548 |
+
probs = F.softmax(next_token_logscores, dim=-1)
|
| 549 |
+
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
|
| 550 |
+
else:
|
| 551 |
+
# Greedy decoding
|
| 552 |
+
next_token = torch.argmax(next_token_logits, dim=-1)
|
| 553 |
+
|
| 554 |
+
# update generations and finished sentences
|
| 555 |
+
if eos_token_id is not None:
|
| 556 |
+
# pad finished sentences if eos_token_id exist
|
| 557 |
+
tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
|
| 558 |
+
else:
|
| 559 |
+
tokens_to_add = next_token
|
| 560 |
+
|
| 561 |
+
# add token and increase length by one
|
| 562 |
+
input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
|
| 563 |
+
cur_len = cur_len + 1
|
| 564 |
+
|
| 565 |
+
if eos_token_id is not None:
|
| 566 |
+
eos_in_sents = tokens_to_add == eos_token_id
|
| 567 |
+
# if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
|
| 568 |
+
is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
|
| 569 |
+
sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len)
|
| 570 |
+
# unfinished_sents is set to zero if eos in sentence
|
| 571 |
+
unfinished_sents.mul_((~eos_in_sents).long())
|
| 572 |
+
|
| 573 |
+
# stop when there is a </s> in each sentence, or if we exceed the maximul length
|
| 574 |
+
if unfinished_sents.max() == 0:
|
| 575 |
+
break
|
| 576 |
+
|
| 577 |
+
# extend attention_mask for new generated input if only decoder
|
| 578 |
+
if self.config.is_encoder_decoder is False:
|
| 579 |
+
attention_mask = torch.cat(
|
| 580 |
+
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
|
| 581 |
+
)
|
| 582 |
+
|
| 583 |
+
return input_ids
|
| 584 |
+
|
| 585 |
+
def _generate_beam_search(
|
| 586 |
+
self,
|
| 587 |
+
input_ids,
|
| 588 |
+
cur_len,
|
| 589 |
+
max_length,
|
| 590 |
+
min_length,
|
| 591 |
+
do_sample,
|
| 592 |
+
early_stopping,
|
| 593 |
+
temperature,
|
| 594 |
+
top_k,
|
| 595 |
+
top_p,
|
| 596 |
+
repetition_penalty,
|
| 597 |
+
no_repeat_ngram_size,
|
| 598 |
+
bad_words_ids,
|
| 599 |
+
pad_token_id,
|
| 600 |
+
eos_token_id,
|
| 601 |
+
batch_size,
|
| 602 |
+
num_return_sequences,
|
| 603 |
+
length_penalty,
|
| 604 |
+
num_beams,
|
| 605 |
+
vocab_size,
|
| 606 |
+
encoder_outputs,
|
| 607 |
+
attention_mask,
|
| 608 |
+
use_cache,
|
| 609 |
+
model_specific_kwargs,
|
| 610 |
+
):
|
| 611 |
+
""" Generate sequences for each example with beam search.
|
| 612 |
+
"""
|
| 613 |
+
|
| 614 |
+
# generated hypotheses
|
| 615 |
+
generated_hyps = [
|
| 616 |
+
BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping)
|
| 617 |
+
for _ in range(batch_size)
|
| 618 |
+
]
|
| 619 |
+
|
| 620 |
+
# scores for each sentence in the beam
|
| 621 |
+
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
|
| 622 |
+
|
| 623 |
+
# for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times
|
| 624 |
+
if do_sample is False:
|
| 625 |
+
beam_scores[:, 1:] = -1e9
|
| 626 |
+
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
|
| 627 |
+
|
| 628 |
+
# cache compute states
|
| 629 |
+
past = (encoder_outputs, None) if encoder_outputs is not None else None
|
| 630 |
+
|
| 631 |
+
# done sentences
|
| 632 |
+
done = [False for _ in range(batch_size)]
|
| 633 |
+
|
| 634 |
+
while cur_len < max_length:
|
| 635 |
+
model_inputs = self.prepare_inputs_for_generation(
|
| 636 |
+
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
|
| 637 |
+
)
|
| 638 |
+
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
|
| 639 |
+
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
|
| 640 |
+
|
| 641 |
+
# if model has past, then set the past variable to speed up decoding
|
| 642 |
+
if self._use_cache(outputs, use_cache):
|
| 643 |
+
past = outputs[1]
|
| 644 |
+
if self.config.is_encoder_decoder and do_sample is False:
|
| 645 |
+
# TODO (PVP) still a bit hacky here - there might be a better solution
|
| 646 |
+
next_token_logits = self.adjust_logits_during_generation(
|
| 647 |
+
next_token_logits, cur_len=cur_len, max_length=max_length
|
| 648 |
+
)
|
| 649 |
+
|
| 650 |
+
scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
|
| 651 |
+
|
| 652 |
+
scores = self.postprocess_next_token_scores(
|
| 653 |
+
scores=scores,
|
| 654 |
+
input_ids=input_ids,
|
| 655 |
+
no_repeat_ngram_size=no_repeat_ngram_size,
|
| 656 |
+
bad_words_ids=bad_words_ids,
|
| 657 |
+
cur_len=cur_len,
|
| 658 |
+
min_length=min_length,
|
| 659 |
+
max_length=max_length,
|
| 660 |
+
eos_token_id=eos_token_id,
|
| 661 |
+
repetition_penalty=repetition_penalty,
|
| 662 |
+
batch_size=batch_size,
|
| 663 |
+
num_beams=num_beams,
|
| 664 |
+
)
|
| 665 |
+
|
| 666 |
+
assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(
|
| 667 |
+
scores.shape, (batch_size * num_beams, vocab_size)
|
| 668 |
+
)
|
| 669 |
+
|
| 670 |
+
if do_sample:
|
| 671 |
+
_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
|
| 672 |
+
# Temperature
|
| 673 |
+
if temperature != 1.0:
|
| 674 |
+
_scores = _scores / temperature
|
| 675 |
+
# Top-p/top-k filtering
|
| 676 |
+
_scores = top_k_top_p_filtering(
|
| 677 |
+
_scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
|
| 678 |
+
) # (batch_size * num_beams, vocab_size)
|
| 679 |
+
# re-organize to group the beam together to sample from all beam_idxs
|
| 680 |
+
_scores = _scores.contiguous().view(
|
| 681 |
+
batch_size, num_beams * vocab_size
|
| 682 |
+
) # (batch_size, num_beams * vocab_size)
|
| 683 |
+
|
| 684 |
+
# Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
|
| 685 |
+
probs = F.softmax(_scores, dim=-1)
|
| 686 |
+
next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) # (batch_size, num_beams * 2)
|
| 687 |
+
# Compute next scores
|
| 688 |
+
next_scores = torch.gather(_scores, -1, next_tokens) # (batch_size, num_beams * 2)
|
| 689 |
+
# sort the sampled vector to make sure that the first num_beams samples are the best
|
| 690 |
+
next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1)
|
| 691 |
+
next_tokens = torch.gather(next_tokens, -1, next_scores_indices) # (batch_size, num_beams * 2)
|
| 692 |
+
|
| 693 |
+
else:
|
| 694 |
+
next_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
|
| 695 |
+
|
| 696 |
+
# re-organize to group the beam together (we are keeping top hypothesis accross beams)
|
| 697 |
+
next_scores = next_scores.view(
|
| 698 |
+
batch_size, num_beams * vocab_size
|
| 699 |
+
) # (batch_size, num_beams * vocab_size)
|
| 700 |
+
|
| 701 |
+
next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
|
| 702 |
+
|
| 703 |
+
assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams)
|
| 704 |
+
|
| 705 |
+
# next batch beam content
|
| 706 |
+
next_batch_beam = []
|
| 707 |
+
|
| 708 |
+
# for each sentence
|
| 709 |
+
for batch_idx in range(batch_size):
|
| 710 |
+
|
| 711 |
+
# if we are done with this sentence, add a pad token
|
| 712 |
+
if done[batch_idx]:
|
| 713 |
+
assert (
|
| 714 |
+
len(generated_hyps[batch_idx]) >= num_beams
|
| 715 |
+
), "Batch can only be done if at least {} beams have been generated".format(num_beams)
|
| 716 |
+
assert (
|
| 717 |
+
eos_token_id is not None and pad_token_id is not None
|
| 718 |
+
), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
|
| 719 |
+
next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
|
| 720 |
+
continue
|
| 721 |
+
|
| 722 |
+
# next sentence beam content, this will get added to next_batch_beam
|
| 723 |
+
next_sent_beam = []
|
| 724 |
+
|
| 725 |
+
# next tokens for this sentence
|
| 726 |
+
for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
|
| 727 |
+
zip(next_tokens[batch_idx], next_scores[batch_idx])
|
| 728 |
+
):
|
| 729 |
+
# get beam and token IDs
|
| 730 |
+
beam_id = beam_token_id // vocab_size
|
| 731 |
+
token_id = beam_token_id % vocab_size
|
| 732 |
+
|
| 733 |
+
effective_beam_id = batch_idx * num_beams + beam_id
|
| 734 |
+
# add to generated hypotheses if end of sentence
|
| 735 |
+
if (eos_token_id is not None) and (token_id.item() == eos_token_id):
|
| 736 |
+
# if beam_token does not belong to top num_beams tokens, it should not be added
|
| 737 |
+
is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
|
| 738 |
+
if is_beam_token_worse_than_top_num_beams:
|
| 739 |
+
continue
|
| 740 |
+
generated_hyps[batch_idx].add(
|
| 741 |
+
input_ids[effective_beam_id].clone(), beam_token_score.item(),
|
| 742 |
+
)
|
| 743 |
+
else:
|
| 744 |
+
# add next predicted token since it is not eos_token
|
| 745 |
+
next_sent_beam.append((beam_token_score, token_id, effective_beam_id))
|
| 746 |
+
|
| 747 |
+
# once the beam for next step is full, don't add more tokens to it.
|
| 748 |
+
if len(next_sent_beam) == num_beams:
|
| 749 |
+
break
|
| 750 |
+
|
| 751 |
+
# Check if we are done so that we can save a pad step if all(done)
|
| 752 |
+
done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
|
| 753 |
+
next_scores[batch_idx].max().item(), cur_len
|
| 754 |
+
)
|
| 755 |
+
|
| 756 |
+
# update next beam content
|
| 757 |
+
assert len(next_sent_beam) == num_beams, "Beam should always be full"
|
| 758 |
+
next_batch_beam.extend(next_sent_beam)
|
| 759 |
+
assert len(next_batch_beam) == num_beams * (batch_idx + 1), "We should have added num_beams each step"
|
| 760 |
+
|
| 761 |
+
# stop when we are done with each sentence
|
| 762 |
+
if all(done):
|
| 763 |
+
break
|
| 764 |
+
|
| 765 |
+
# sanity check / prepare next batch
|
| 766 |
+
assert len(next_batch_beam) == batch_size * num_beams
|
| 767 |
+
beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
|
| 768 |
+
beam_tokens = input_ids.new([x[1] for x in next_batch_beam])
|
| 769 |
+
beam_idx = input_ids.new([x[2] for x in next_batch_beam])
|
| 770 |
+
|
| 771 |
+
# re-order batch and update current length
|
| 772 |
+
input_ids = input_ids[beam_idx, :]
|
| 773 |
+
input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
|
| 774 |
+
cur_len = cur_len + 1
|
| 775 |
+
|
| 776 |
+
# re-order internal states
|
| 777 |
+
if past is not None:
|
| 778 |
+
past = self._reorder_cache(past, beam_idx)
|
| 779 |
+
|
| 780 |
+
# extend attention_mask for new generated input if only decoder
|
| 781 |
+
if self.config.is_encoder_decoder is False:
|
| 782 |
+
attention_mask = torch.cat(
|
| 783 |
+
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
|
| 784 |
+
)
|
| 785 |
+
|
| 786 |
+
# finalize all open beam hypotheses and add to generated hypotheses
|
| 787 |
+
for batch_idx in range(batch_size):
|
| 788 |
+
if done[batch_idx]:
|
| 789 |
+
continue
|
| 790 |
+
|
| 791 |
+
# test that beam scores match previously calculated scores if not eos and batch_idx not done
|
| 792 |
+
if eos_token_id is not None and all(
|
| 793 |
+
(token_id % vocab_size).item() != eos_token_id for token_id in next_tokens[batch_idx]
|
| 794 |
+
):
|
| 795 |
+
assert torch.all(
|
| 796 |
+
next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
|
| 797 |
+
), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format(
|
| 798 |
+
next_scores[:, :num_beams][batch_idx], beam_scores.view(batch_size, num_beams)[batch_idx],
|
| 799 |
+
)
|
| 800 |
+
|
| 801 |
+
# need to add best num_beams hypotheses to generated hyps
|
| 802 |
+
for beam_id in range(num_beams):
|
| 803 |
+
effective_beam_id = batch_idx * num_beams + beam_id
|
| 804 |
+
final_score = beam_scores[effective_beam_id].item()
|
| 805 |
+
final_tokens = input_ids[effective_beam_id]
|
| 806 |
+
generated_hyps[batch_idx].add(final_tokens, final_score)
|
| 807 |
+
|
| 808 |
+
# depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
|
| 809 |
+
output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
|
| 810 |
+
output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences
|
| 811 |
+
|
| 812 |
+
# select the best hypotheses
|
| 813 |
+
sent_lengths = input_ids.new(output_batch_size)
|
| 814 |
+
best = []
|
| 815 |
+
|
| 816 |
+
# retrieve best hypotheses
|
| 817 |
+
for i, hypotheses in enumerate(generated_hyps):
|
| 818 |
+
sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
|
| 819 |
+
for j in range(output_num_return_sequences_per_batch):
|
| 820 |
+
effective_batch_idx = output_num_return_sequences_per_batch * i + j
|
| 821 |
+
best_hyp = sorted_hyps.pop()[1]
|
| 822 |
+
sent_lengths[effective_batch_idx] = len(best_hyp)
|
| 823 |
+
best.append(best_hyp)
|
| 824 |
+
|
| 825 |
+
# shorter batches are padded
|
| 826 |
+
if sent_lengths.min().item() != sent_lengths.max().item():
|
| 827 |
+
assert pad_token_id is not None, "`Pad_token_id` has to be defined"
|
| 828 |
+
sent_max_len = min(sent_lengths.max().item() + 1, max_length)
|
| 829 |
+
decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id)
|
| 830 |
+
|
| 831 |
+
# fill with hypothesis and eos_token_id if necessary
|
| 832 |
+
for i, hypo in enumerate(best):
|
| 833 |
+
decoded[i, : sent_lengths[i]] = hypo
|
| 834 |
+
if sent_lengths[i] < max_length:
|
| 835 |
+
decoded[i, sent_lengths[i]] = eos_token_id
|
| 836 |
+
else:
|
| 837 |
+
# none of the hypotheses have an eos_token
|
| 838 |
+
assert (len(hypo) == max_length for hypo in best)
|
| 839 |
+
decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
|
| 840 |
+
|
| 841 |
+
return decoded
|
| 842 |
+
|
| 843 |
+
@staticmethod
|
| 844 |
+
def _reorder_cache(past: Tuple, beam_idx: Tensor) -> Tuple[Tensor]:
|
| 845 |
+
return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)
|
| 846 |
+
|
| 847 |
+
|
| 848 |
+
def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_ngram_size: int, cur_len: int) -> None:
|
| 849 |
+
"""Copied from fairseq for no_repeat_ngram in beam_search"""
|
| 850 |
+
if cur_len + 1 < no_repeat_ngram_size:
|
| 851 |
+
# return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
|
| 852 |
+
return [[] for _ in range(num_hypos)]
|
| 853 |
+
generated_ngrams = [{} for _ in range(num_hypos)]
|
| 854 |
+
for idx in range(num_hypos):
|
| 855 |
+
gen_tokens = prev_input_ids[idx].tolist()
|
| 856 |
+
generated_ngram = generated_ngrams[idx]
|
| 857 |
+
for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
|
| 858 |
+
prev_ngram_tuple = tuple(ngram[:-1])
|
| 859 |
+
generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
|
| 860 |
+
|
| 861 |
+
def _get_generated_ngrams(hypo_idx):
|
| 862 |
+
# Before decoding the next token, prevent decoding of ngrams that have already appeared
|
| 863 |
+
start_idx = cur_len + 1 - no_repeat_ngram_size
|
| 864 |
+
ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
|
| 865 |
+
return generated_ngrams[hypo_idx].get(ngram_idx, [])
|
| 866 |
+
|
| 867 |
+
banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
|
| 868 |
+
return banned_tokens
|
| 869 |
+
|
| 870 |
+
|
| 871 |
+
def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iterable[int]) -> Iterable[int]:
|
| 872 |
+
banned_tokens = []
|
| 873 |
+
|
| 874 |
+
def _tokens_match(prev_tokens, tokens):
|
| 875 |
+
if len(tokens) == 0:
|
| 876 |
+
# if bad word tokens is just one token always ban it
|
| 877 |
+
return True
|
| 878 |
+
if len(tokens) > len(prev_input_ids):
|
| 879 |
+
# if bad word tokens are longer then prev input_ids they can't be equal
|
| 880 |
+
return False
|
| 881 |
+
|
| 882 |
+
if prev_tokens[-len(tokens) :] == tokens:
|
| 883 |
+
# if tokens match
|
| 884 |
+
return True
|
| 885 |
+
else:
|
| 886 |
+
return False
|
| 887 |
+
|
| 888 |
+
for prev_input_ids_slice in prev_input_ids:
|
| 889 |
+
banned_tokens_slice = []
|
| 890 |
+
|
| 891 |
+
for banned_token_seq in bad_words_ids:
|
| 892 |
+
assert len(banned_token_seq) > 0, "Banned words token sequences {} cannot have an empty list".format(
|
| 893 |
+
bad_words_ids
|
| 894 |
+
)
|
| 895 |
+
|
| 896 |
+
if _tokens_match(prev_input_ids_slice.tolist(), banned_token_seq[:-1]) is False:
|
| 897 |
+
# if tokens do not match continue
|
| 898 |
+
continue
|
| 899 |
+
|
| 900 |
+
banned_tokens_slice.append(banned_token_seq[-1])
|
| 901 |
+
|
| 902 |
+
banned_tokens.append(banned_tokens_slice)
|
| 903 |
+
|
| 904 |
+
return banned_tokens
|
| 905 |
+
|
| 906 |
+
|
| 907 |
+
def top_k_top_p_filtering(
|
| 908 |
+
logits: Tensor,
|
| 909 |
+
top_k: int = 0,
|
| 910 |
+
top_p: float = 1.0,
|
| 911 |
+
filter_value: float = -float("Inf"),
|
| 912 |
+
min_tokens_to_keep: int = 1,
|
| 913 |
+
) -> Tensor:
|
| 914 |
+
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
| 915 |
+
Args:
|
| 916 |
+
logits: logits distribution shape (batch size, vocabulary size)
|
| 917 |
+
if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
|
| 918 |
+
if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
|
| 919 |
+
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
|
| 920 |
+
Make sure we keep at least min_tokens_to_keep per batch example in the output
|
| 921 |
+
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
|
| 922 |
+
"""
|
| 923 |
+
if top_k > 0:
|
| 924 |
+
top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
|
| 925 |
+
# Remove all tokens with a probability less than the last token of the top-k
|
| 926 |
+
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
| 927 |
+
logits[indices_to_remove] = filter_value
|
| 928 |
+
|
| 929 |
+
if top_p < 1.0:
|
| 930 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
| 931 |
+
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
| 932 |
+
|
| 933 |
+
# Remove tokens with cumulative probability above the threshold (token with 0 are kept)
|
| 934 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
| 935 |
+
if min_tokens_to_keep > 1:
|
| 936 |
+
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
|
| 937 |
+
sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
|
| 938 |
+
# Shift the indices to the right to keep also the first token above the threshold
|
| 939 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
| 940 |
+
sorted_indices_to_remove[..., 0] = 0
|
| 941 |
+
|
| 942 |
+
# scatter sorted tensors to original indexing
|
| 943 |
+
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
| 944 |
+
logits[indices_to_remove] = filter_value
|
| 945 |
+
return logits
|
| 946 |
+
|
| 947 |
+
|
| 948 |
+
class BeamHypotheses(object):
|
| 949 |
+
def __init__(self, num_beams, max_length, length_penalty, early_stopping):
|
| 950 |
+
"""
|
| 951 |
+
Initialize n-best list of hypotheses.
|
| 952 |
+
"""
|
| 953 |
+
self.max_length = max_length - 1 # ignoring bos_token
|
| 954 |
+
self.length_penalty = length_penalty
|
| 955 |
+
self.early_stopping = early_stopping
|
| 956 |
+
self.num_beams = num_beams
|
| 957 |
+
self.beams = []
|
| 958 |
+
self.worst_score = 1e9
|
| 959 |
+
|
| 960 |
+
def __len__(self):
|
| 961 |
+
"""
|
| 962 |
+
Number of hypotheses in the list.
|
| 963 |
+
"""
|
| 964 |
+
return len(self.beams)
|
| 965 |
+
|
| 966 |
+
def add(self, hyp, sum_logprobs):
|
| 967 |
+
"""
|
| 968 |
+
Add a new hypothesis to the list.
|
| 969 |
+
"""
|
| 970 |
+
score = sum_logprobs / len(hyp) ** self.length_penalty
|
| 971 |
+
if len(self) < self.num_beams or score > self.worst_score:
|
| 972 |
+
self.beams.append((score, hyp))
|
| 973 |
+
if len(self) > self.num_beams:
|
| 974 |
+
sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
|
| 975 |
+
del self.beams[sorted_scores[0][1]]
|
| 976 |
+
self.worst_score = sorted_scores[1][0]
|
| 977 |
+
else:
|
| 978 |
+
self.worst_score = min(score, self.worst_score)
|
| 979 |
+
|
| 980 |
+
def is_done(self, best_sum_logprobs, cur_len):
|
| 981 |
+
"""
|
| 982 |
+
If there are enough hypotheses and that none of the hypotheses being generated
|
| 983 |
+
can become better than the worst one in the heap, then we are done with this sentence.
|
| 984 |
+
"""
|
| 985 |
+
|
| 986 |
+
if len(self) < self.num_beams:
|
| 987 |
+
return False
|
| 988 |
+
elif self.early_stopping:
|
| 989 |
+
return True
|
| 990 |
+
else:
|
| 991 |
+
cur_score = best_sum_logprobs / cur_len ** self.length_penalty
|
| 992 |
+
ret = self.worst_score >= cur_score
|
| 993 |
+
return ret
|
CGFormer/bert/modeling_bert.py
ADDED
|
@@ -0,0 +1,1569 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
| 3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
"""PyTorch BERT model. """
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
import logging
|
| 20 |
+
import math
|
| 21 |
+
import os
|
| 22 |
+
import warnings
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
import torch.utils.checkpoint
|
| 26 |
+
from torch import nn
|
| 27 |
+
from torch.nn import CrossEntropyLoss, MSELoss
|
| 28 |
+
|
| 29 |
+
from .activations import gelu, gelu_new, swish
|
| 30 |
+
from .configuration_bert import BertConfig
|
| 31 |
+
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
|
| 32 |
+
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
logger = logging.getLogger(__name__)
|
| 36 |
+
|
| 37 |
+
_TOKENIZER_FOR_DOC = "BertTokenizer"
|
| 38 |
+
|
| 39 |
+
BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
| 40 |
+
"bert-base-uncased",
|
| 41 |
+
"bert-large-uncased",
|
| 42 |
+
"bert-base-cased",
|
| 43 |
+
"bert-large-cased",
|
| 44 |
+
"bert-base-multilingual-uncased",
|
| 45 |
+
"bert-base-multilingual-cased",
|
| 46 |
+
"bert-base-chinese",
|
| 47 |
+
"bert-base-german-cased",
|
| 48 |
+
"bert-large-uncased-whole-word-masking",
|
| 49 |
+
"bert-large-cased-whole-word-masking",
|
| 50 |
+
"bert-large-uncased-whole-word-masking-finetuned-squad",
|
| 51 |
+
"bert-large-cased-whole-word-masking-finetuned-squad",
|
| 52 |
+
"bert-base-cased-finetuned-mrpc",
|
| 53 |
+
"bert-base-german-dbmdz-cased",
|
| 54 |
+
"bert-base-german-dbmdz-uncased",
|
| 55 |
+
"cl-tohoku/bert-base-japanese",
|
| 56 |
+
"cl-tohoku/bert-base-japanese-whole-word-masking",
|
| 57 |
+
"cl-tohoku/bert-base-japanese-char",
|
| 58 |
+
"cl-tohoku/bert-base-japanese-char-whole-word-masking",
|
| 59 |
+
"TurkuNLP/bert-base-finnish-cased-v1",
|
| 60 |
+
"TurkuNLP/bert-base-finnish-uncased-v1",
|
| 61 |
+
"wietsedv/bert-base-dutch-cased",
|
| 62 |
+
# See all BERT models at https://huggingface.co/models?filter=bert
|
| 63 |
+
]
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
|
| 67 |
+
""" Load tf checkpoints in a pytorch model.
|
| 68 |
+
"""
|
| 69 |
+
try:
|
| 70 |
+
import re
|
| 71 |
+
import numpy as np
|
| 72 |
+
import tensorflow as tf
|
| 73 |
+
except ImportError:
|
| 74 |
+
logger.error(
|
| 75 |
+
"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
|
| 76 |
+
"https://www.tensorflow.org/install/ for installation instructions."
|
| 77 |
+
)
|
| 78 |
+
raise
|
| 79 |
+
tf_path = os.path.abspath(tf_checkpoint_path)
|
| 80 |
+
logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
|
| 81 |
+
# Load weights from TF model
|
| 82 |
+
init_vars = tf.train.list_variables(tf_path)
|
| 83 |
+
names = []
|
| 84 |
+
arrays = []
|
| 85 |
+
for name, shape in init_vars:
|
| 86 |
+
logger.info("Loading TF weight {} with shape {}".format(name, shape))
|
| 87 |
+
array = tf.train.load_variable(tf_path, name)
|
| 88 |
+
names.append(name)
|
| 89 |
+
arrays.append(array)
|
| 90 |
+
|
| 91 |
+
for name, array in zip(names, arrays):
|
| 92 |
+
name = name.split("/")
|
| 93 |
+
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
| 94 |
+
# which are not required for using pretrained model
|
| 95 |
+
if any(
|
| 96 |
+
n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
|
| 97 |
+
for n in name
|
| 98 |
+
):
|
| 99 |
+
logger.info("Skipping {}".format("/".join(name)))
|
| 100 |
+
continue
|
| 101 |
+
pointer = model
|
| 102 |
+
for m_name in name:
|
| 103 |
+
if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
|
| 104 |
+
scope_names = re.split(r"_(\d+)", m_name)
|
| 105 |
+
else:
|
| 106 |
+
scope_names = [m_name]
|
| 107 |
+
if scope_names[0] == "kernel" or scope_names[0] == "gamma":
|
| 108 |
+
pointer = getattr(pointer, "weight")
|
| 109 |
+
elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
|
| 110 |
+
pointer = getattr(pointer, "bias")
|
| 111 |
+
elif scope_names[0] == "output_weights":
|
| 112 |
+
pointer = getattr(pointer, "weight")
|
| 113 |
+
elif scope_names[0] == "squad":
|
| 114 |
+
pointer = getattr(pointer, "classifier")
|
| 115 |
+
else:
|
| 116 |
+
try:
|
| 117 |
+
pointer = getattr(pointer, scope_names[0])
|
| 118 |
+
except AttributeError:
|
| 119 |
+
logger.info("Skipping {}".format("/".join(name)))
|
| 120 |
+
continue
|
| 121 |
+
if len(scope_names) >= 2:
|
| 122 |
+
num = int(scope_names[1])
|
| 123 |
+
pointer = pointer[num]
|
| 124 |
+
if m_name[-11:] == "_embeddings":
|
| 125 |
+
pointer = getattr(pointer, "weight")
|
| 126 |
+
elif m_name == "kernel":
|
| 127 |
+
array = np.transpose(array)
|
| 128 |
+
try:
|
| 129 |
+
assert pointer.shape == array.shape
|
| 130 |
+
except AssertionError as e:
|
| 131 |
+
e.args += (pointer.shape, array.shape)
|
| 132 |
+
raise
|
| 133 |
+
logger.info("Initialize PyTorch weight {}".format(name))
|
| 134 |
+
pointer.data = torch.from_numpy(array)
|
| 135 |
+
return model
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def mish(x):
|
| 139 |
+
return x * torch.tanh(nn.functional.softplus(x))
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new, "mish": mish}
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
BertLayerNorm = torch.nn.LayerNorm
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class BertEmbeddings(nn.Module):
|
| 149 |
+
"""Construct the embeddings from word, position and token_type embeddings.
|
| 150 |
+
"""
|
| 151 |
+
|
| 152 |
+
def __init__(self, config):
|
| 153 |
+
super().__init__()
|
| 154 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
| 155 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
| 156 |
+
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
|
| 157 |
+
|
| 158 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
| 159 |
+
# any TensorFlow checkpoint file
|
| 160 |
+
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 161 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 162 |
+
|
| 163 |
+
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
|
| 164 |
+
if input_ids is not None:
|
| 165 |
+
input_shape = input_ids.size()
|
| 166 |
+
else:
|
| 167 |
+
input_shape = inputs_embeds.size()[:-1]
|
| 168 |
+
|
| 169 |
+
seq_length = input_shape[1]
|
| 170 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
| 171 |
+
if position_ids is None:
|
| 172 |
+
position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
|
| 173 |
+
position_ids = position_ids.unsqueeze(0).expand(input_shape)
|
| 174 |
+
if token_type_ids is None:
|
| 175 |
+
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
| 176 |
+
|
| 177 |
+
if inputs_embeds is None:
|
| 178 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
| 179 |
+
position_embeddings = self.position_embeddings(position_ids)
|
| 180 |
+
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
| 181 |
+
|
| 182 |
+
embeddings = inputs_embeds + position_embeddings + token_type_embeddings
|
| 183 |
+
embeddings = self.LayerNorm(embeddings)
|
| 184 |
+
embeddings = self.dropout(embeddings)
|
| 185 |
+
return embeddings
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class BertSelfAttention(nn.Module):
|
| 189 |
+
def __init__(self, config):
|
| 190 |
+
super().__init__()
|
| 191 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
| 192 |
+
raise ValueError(
|
| 193 |
+
"The hidden size (%d) is not a multiple of the number of attention "
|
| 194 |
+
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
self.num_attention_heads = config.num_attention_heads
|
| 198 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
| 199 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
| 200 |
+
|
| 201 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
| 202 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
| 203 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
| 204 |
+
|
| 205 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
| 206 |
+
|
| 207 |
+
def transpose_for_scores(self, x):
|
| 208 |
+
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
| 209 |
+
x = x.view(*new_x_shape)
|
| 210 |
+
return x.permute(0, 2, 1, 3)
|
| 211 |
+
|
| 212 |
+
def forward(
|
| 213 |
+
self,
|
| 214 |
+
hidden_states,
|
| 215 |
+
attention_mask=None,
|
| 216 |
+
head_mask=None,
|
| 217 |
+
encoder_hidden_states=None,
|
| 218 |
+
encoder_attention_mask=None,
|
| 219 |
+
output_attentions=False,
|
| 220 |
+
):
|
| 221 |
+
mixed_query_layer = self.query(hidden_states)
|
| 222 |
+
|
| 223 |
+
# If this is instantiated as a cross-attention module, the keys
|
| 224 |
+
# and values come from an encoder; the attention mask needs to be
|
| 225 |
+
# such that the encoder's padding tokens are not attended to.
|
| 226 |
+
if encoder_hidden_states is not None:
|
| 227 |
+
mixed_key_layer = self.key(encoder_hidden_states)
|
| 228 |
+
mixed_value_layer = self.value(encoder_hidden_states)
|
| 229 |
+
attention_mask = encoder_attention_mask
|
| 230 |
+
else:
|
| 231 |
+
mixed_key_layer = self.key(hidden_states)
|
| 232 |
+
mixed_value_layer = self.value(hidden_states)
|
| 233 |
+
|
| 234 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
| 235 |
+
key_layer = self.transpose_for_scores(mixed_key_layer)
|
| 236 |
+
value_layer = self.transpose_for_scores(mixed_value_layer)
|
| 237 |
+
|
| 238 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
| 239 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
| 240 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
| 241 |
+
if attention_mask is not None:
|
| 242 |
+
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
|
| 243 |
+
attention_scores = attention_scores + attention_mask
|
| 244 |
+
|
| 245 |
+
# Normalize the attention scores to probabilities.
|
| 246 |
+
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
| 247 |
+
|
| 248 |
+
# This is actually dropping out entire tokens to attend to, which might
|
| 249 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
| 250 |
+
attention_probs = self.dropout(attention_probs)
|
| 251 |
+
|
| 252 |
+
# Mask heads if we want to
|
| 253 |
+
if head_mask is not None:
|
| 254 |
+
attention_probs = attention_probs * head_mask
|
| 255 |
+
|
| 256 |
+
context_layer = torch.matmul(attention_probs, value_layer)
|
| 257 |
+
|
| 258 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
| 259 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
| 260 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
| 261 |
+
|
| 262 |
+
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
| 263 |
+
return outputs
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
class BertSelfOutput(nn.Module):
|
| 267 |
+
def __init__(self, config):
|
| 268 |
+
super().__init__()
|
| 269 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 270 |
+
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 271 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 272 |
+
|
| 273 |
+
def forward(self, hidden_states, input_tensor):
|
| 274 |
+
hidden_states = self.dense(hidden_states)
|
| 275 |
+
hidden_states = self.dropout(hidden_states)
|
| 276 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
| 277 |
+
return hidden_states
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
class BertAttention(nn.Module):
|
| 281 |
+
def __init__(self, config):
|
| 282 |
+
super().__init__()
|
| 283 |
+
self.self = BertSelfAttention(config)
|
| 284 |
+
self.output = BertSelfOutput(config)
|
| 285 |
+
self.pruned_heads = set()
|
| 286 |
+
|
| 287 |
+
def prune_heads(self, heads):
|
| 288 |
+
if len(heads) == 0:
|
| 289 |
+
return
|
| 290 |
+
heads, index = find_pruneable_heads_and_indices(
|
| 291 |
+
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
# Prune linear layers
|
| 295 |
+
self.self.query = prune_linear_layer(self.self.query, index)
|
| 296 |
+
self.self.key = prune_linear_layer(self.self.key, index)
|
| 297 |
+
self.self.value = prune_linear_layer(self.self.value, index)
|
| 298 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
| 299 |
+
|
| 300 |
+
# Update hyper params and store pruned heads
|
| 301 |
+
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
| 302 |
+
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
| 303 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
| 304 |
+
|
| 305 |
+
def forward(
|
| 306 |
+
self,
|
| 307 |
+
hidden_states,
|
| 308 |
+
attention_mask=None,
|
| 309 |
+
head_mask=None,
|
| 310 |
+
encoder_hidden_states=None,
|
| 311 |
+
encoder_attention_mask=None,
|
| 312 |
+
output_attentions=False,
|
| 313 |
+
):
|
| 314 |
+
self_outputs = self.self(
|
| 315 |
+
hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, output_attentions,
|
| 316 |
+
)
|
| 317 |
+
attention_output = self.output(self_outputs[0], hidden_states)
|
| 318 |
+
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
| 319 |
+
return outputs
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
class BertIntermediate(nn.Module):
|
| 323 |
+
def __init__(self, config):
|
| 324 |
+
super().__init__()
|
| 325 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
| 326 |
+
if isinstance(config.hidden_act, str):
|
| 327 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
| 328 |
+
else:
|
| 329 |
+
self.intermediate_act_fn = config.hidden_act
|
| 330 |
+
|
| 331 |
+
def forward(self, hidden_states):
|
| 332 |
+
hidden_states = self.dense(hidden_states)
|
| 333 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
| 334 |
+
return hidden_states
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
class BertOutput(nn.Module):
|
| 338 |
+
def __init__(self, config):
|
| 339 |
+
super().__init__()
|
| 340 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
| 341 |
+
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 342 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 343 |
+
|
| 344 |
+
def forward(self, hidden_states, input_tensor):
|
| 345 |
+
hidden_states = self.dense(hidden_states)
|
| 346 |
+
hidden_states = self.dropout(hidden_states)
|
| 347 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
| 348 |
+
return hidden_states
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
class BertLayer(nn.Module):
|
| 352 |
+
def __init__(self, config):
|
| 353 |
+
super().__init__()
|
| 354 |
+
self.attention = BertAttention(config)
|
| 355 |
+
self.is_decoder = config.is_decoder
|
| 356 |
+
if self.is_decoder:
|
| 357 |
+
self.crossattention = BertAttention(config)
|
| 358 |
+
self.intermediate = BertIntermediate(config)
|
| 359 |
+
self.output = BertOutput(config)
|
| 360 |
+
|
| 361 |
+
def forward(
|
| 362 |
+
self,
|
| 363 |
+
hidden_states,
|
| 364 |
+
attention_mask=None,
|
| 365 |
+
head_mask=None,
|
| 366 |
+
encoder_hidden_states=None,
|
| 367 |
+
encoder_attention_mask=None,
|
| 368 |
+
output_attentions=False,
|
| 369 |
+
):
|
| 370 |
+
self_attention_outputs = self.attention(
|
| 371 |
+
hidden_states, attention_mask, head_mask, output_attentions=output_attentions,
|
| 372 |
+
)
|
| 373 |
+
attention_output = self_attention_outputs[0]
|
| 374 |
+
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
| 375 |
+
|
| 376 |
+
if self.is_decoder and encoder_hidden_states is not None:
|
| 377 |
+
cross_attention_outputs = self.crossattention(
|
| 378 |
+
attention_output,
|
| 379 |
+
attention_mask,
|
| 380 |
+
head_mask,
|
| 381 |
+
encoder_hidden_states,
|
| 382 |
+
encoder_attention_mask,
|
| 383 |
+
output_attentions,
|
| 384 |
+
)
|
| 385 |
+
attention_output = cross_attention_outputs[0]
|
| 386 |
+
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
|
| 387 |
+
|
| 388 |
+
intermediate_output = self.intermediate(attention_output)
|
| 389 |
+
layer_output = self.output(intermediate_output, attention_output)
|
| 390 |
+
outputs = (layer_output,) + outputs
|
| 391 |
+
return outputs
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
class BertEncoder(nn.Module):
|
| 395 |
+
def __init__(self, config):
|
| 396 |
+
super().__init__()
|
| 397 |
+
self.config = config
|
| 398 |
+
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
|
| 399 |
+
|
| 400 |
+
def forward(
|
| 401 |
+
self,
|
| 402 |
+
hidden_states,
|
| 403 |
+
attention_mask=None,
|
| 404 |
+
head_mask=None,
|
| 405 |
+
encoder_hidden_states=None,
|
| 406 |
+
encoder_attention_mask=None,
|
| 407 |
+
output_attentions=False,
|
| 408 |
+
output_hidden_states=False,
|
| 409 |
+
):
|
| 410 |
+
all_hidden_states = ()
|
| 411 |
+
all_attentions = ()
|
| 412 |
+
for i, layer_module in enumerate(self.layer):
|
| 413 |
+
if output_hidden_states:
|
| 414 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 415 |
+
|
| 416 |
+
if getattr(self.config, "gradient_checkpointing", False):
|
| 417 |
+
|
| 418 |
+
def create_custom_forward(module):
|
| 419 |
+
def custom_forward(*inputs):
|
| 420 |
+
return module(*inputs, output_attentions)
|
| 421 |
+
|
| 422 |
+
return custom_forward
|
| 423 |
+
|
| 424 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
| 425 |
+
create_custom_forward(layer_module),
|
| 426 |
+
hidden_states,
|
| 427 |
+
attention_mask,
|
| 428 |
+
head_mask[i],
|
| 429 |
+
encoder_hidden_states,
|
| 430 |
+
encoder_attention_mask,
|
| 431 |
+
)
|
| 432 |
+
else:
|
| 433 |
+
layer_outputs = layer_module(
|
| 434 |
+
hidden_states,
|
| 435 |
+
attention_mask,
|
| 436 |
+
head_mask[i],
|
| 437 |
+
encoder_hidden_states,
|
| 438 |
+
encoder_attention_mask,
|
| 439 |
+
output_attentions,
|
| 440 |
+
)
|
| 441 |
+
hidden_states = layer_outputs[0]
|
| 442 |
+
|
| 443 |
+
if output_attentions:
|
| 444 |
+
all_attentions = all_attentions + (layer_outputs[1],)
|
| 445 |
+
|
| 446 |
+
# Add last layer
|
| 447 |
+
if output_hidden_states:
|
| 448 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 449 |
+
|
| 450 |
+
outputs = (hidden_states,)
|
| 451 |
+
if output_hidden_states:
|
| 452 |
+
outputs = outputs + (all_hidden_states,)
|
| 453 |
+
if output_attentions:
|
| 454 |
+
outputs = outputs + (all_attentions,)
|
| 455 |
+
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
class BertPooler(nn.Module):
|
| 459 |
+
def __init__(self, config):
|
| 460 |
+
super().__init__()
|
| 461 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 462 |
+
self.activation = nn.Tanh()
|
| 463 |
+
|
| 464 |
+
def forward(self, hidden_states):
|
| 465 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
| 466 |
+
# to the first token.
|
| 467 |
+
first_token_tensor = hidden_states[:, 0]
|
| 468 |
+
pooled_output = self.dense(first_token_tensor)
|
| 469 |
+
pooled_output = self.activation(pooled_output)
|
| 470 |
+
return pooled_output
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
class BertPredictionHeadTransform(nn.Module):
|
| 474 |
+
def __init__(self, config):
|
| 475 |
+
super().__init__()
|
| 476 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 477 |
+
if isinstance(config.hidden_act, str):
|
| 478 |
+
self.transform_act_fn = ACT2FN[config.hidden_act]
|
| 479 |
+
else:
|
| 480 |
+
self.transform_act_fn = config.hidden_act
|
| 481 |
+
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 482 |
+
|
| 483 |
+
def forward(self, hidden_states):
|
| 484 |
+
hidden_states = self.dense(hidden_states)
|
| 485 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
| 486 |
+
hidden_states = self.LayerNorm(hidden_states)
|
| 487 |
+
return hidden_states
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
class BertLMPredictionHead(nn.Module):
|
| 491 |
+
def __init__(self, config):
|
| 492 |
+
super().__init__()
|
| 493 |
+
self.transform = BertPredictionHeadTransform(config)
|
| 494 |
+
|
| 495 |
+
# The output weights are the same as the input embeddings, but there is
|
| 496 |
+
# an output-only bias for each token.
|
| 497 |
+
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 498 |
+
|
| 499 |
+
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
| 500 |
+
|
| 501 |
+
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
| 502 |
+
self.decoder.bias = self.bias
|
| 503 |
+
|
| 504 |
+
def forward(self, hidden_states):
|
| 505 |
+
hidden_states = self.transform(hidden_states)
|
| 506 |
+
hidden_states = self.decoder(hidden_states)
|
| 507 |
+
return hidden_states
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
class BertOnlyMLMHead(nn.Module):
|
| 511 |
+
def __init__(self, config):
|
| 512 |
+
super().__init__()
|
| 513 |
+
self.predictions = BertLMPredictionHead(config)
|
| 514 |
+
|
| 515 |
+
def forward(self, sequence_output):
|
| 516 |
+
prediction_scores = self.predictions(sequence_output)
|
| 517 |
+
return prediction_scores
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
class BertOnlyNSPHead(nn.Module):
|
| 521 |
+
def __init__(self, config):
|
| 522 |
+
super().__init__()
|
| 523 |
+
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
| 524 |
+
|
| 525 |
+
def forward(self, pooled_output):
|
| 526 |
+
seq_relationship_score = self.seq_relationship(pooled_output)
|
| 527 |
+
return seq_relationship_score
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
class BertPreTrainingHeads(nn.Module):
|
| 531 |
+
def __init__(self, config):
|
| 532 |
+
super().__init__()
|
| 533 |
+
self.predictions = BertLMPredictionHead(config)
|
| 534 |
+
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
| 535 |
+
|
| 536 |
+
def forward(self, sequence_output, pooled_output):
|
| 537 |
+
prediction_scores = self.predictions(sequence_output)
|
| 538 |
+
seq_relationship_score = self.seq_relationship(pooled_output)
|
| 539 |
+
return prediction_scores, seq_relationship_score
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
class BertPreTrainedModel(PreTrainedModel):
|
| 543 |
+
""" An abstract class to handle weights initialization and
|
| 544 |
+
a simple interface for downloading and loading pretrained models.
|
| 545 |
+
"""
|
| 546 |
+
|
| 547 |
+
config_class = BertConfig
|
| 548 |
+
load_tf_weights = load_tf_weights_in_bert
|
| 549 |
+
base_model_prefix = "bert"
|
| 550 |
+
|
| 551 |
+
def _init_weights(self, module):
|
| 552 |
+
""" Initialize the weights """
|
| 553 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
| 554 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
| 555 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
| 556 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 557 |
+
elif isinstance(module, BertLayerNorm):
|
| 558 |
+
module.bias.data.zero_()
|
| 559 |
+
module.weight.data.fill_(1.0)
|
| 560 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
| 561 |
+
module.bias.data.zero_()
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
BERT_START_DOCSTRING = r"""
|
| 565 |
+
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
|
| 566 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
|
| 567 |
+
usage and behavior.
|
| 568 |
+
|
| 569 |
+
Parameters:
|
| 570 |
+
config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
|
| 571 |
+
Initializing with a config file does not load the weights associated with the model, only the configuration.
|
| 572 |
+
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
|
| 573 |
+
"""
|
| 574 |
+
|
| 575 |
+
BERT_INPUTS_DOCSTRING = r"""
|
| 576 |
+
Args:
|
| 577 |
+
input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`):
|
| 578 |
+
Indices of input sequence tokens in the vocabulary.
|
| 579 |
+
|
| 580 |
+
Indices can be obtained using :class:`transformers.BertTokenizer`.
|
| 581 |
+
See :func:`transformers.PreTrainedTokenizer.encode` and
|
| 582 |
+
:func:`transformers.PreTrainedTokenizer.__call__` for details.
|
| 583 |
+
|
| 584 |
+
`What are input IDs? <../glossary.html#input-ids>`__
|
| 585 |
+
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
|
| 586 |
+
Mask to avoid performing attention on padding token indices.
|
| 587 |
+
Mask values selected in ``[0, 1]``:
|
| 588 |
+
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
| 589 |
+
|
| 590 |
+
`What are attention masks? <../glossary.html#attention-mask>`__
|
| 591 |
+
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
|
| 592 |
+
Segment token indices to indicate first and second portions of the inputs.
|
| 593 |
+
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
|
| 594 |
+
corresponds to a `sentence B` token
|
| 595 |
+
|
| 596 |
+
`What are token type IDs? <../glossary.html#token-type-ids>`_
|
| 597 |
+
position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
|
| 598 |
+
Indices of positions of each input sequence tokens in the position embeddings.
|
| 599 |
+
Selected in the range ``[0, config.max_position_embeddings - 1]``.
|
| 600 |
+
|
| 601 |
+
`What are position IDs? <../glossary.html#position-ids>`_
|
| 602 |
+
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
|
| 603 |
+
Mask to nullify selected heads of the self-attention modules.
|
| 604 |
+
Mask values selected in ``[0, 1]``:
|
| 605 |
+
:obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
|
| 606 |
+
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
| 607 |
+
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
| 608 |
+
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
| 609 |
+
than the model's internal embedding lookup matrix.
|
| 610 |
+
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
| 611 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
|
| 612 |
+
if the model is configured as a decoder.
|
| 613 |
+
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
| 614 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask
|
| 615 |
+
is used in the cross-attention if the model is configured as a decoder.
|
| 616 |
+
Mask values selected in ``[0, 1]``:
|
| 617 |
+
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
| 618 |
+
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
| 619 |
+
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
|
| 620 |
+
"""
|
| 621 |
+
|
| 622 |
+
|
| 623 |
+
@add_start_docstrings(
|
| 624 |
+
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
|
| 625 |
+
BERT_START_DOCSTRING,
|
| 626 |
+
)
|
| 627 |
+
class BertModel(BertPreTrainedModel):
|
| 628 |
+
"""
|
| 629 |
+
|
| 630 |
+
The model can behave as an encoder (with only self-attention) as well
|
| 631 |
+
as a decoder, in which case a layer of cross-attention is added between
|
| 632 |
+
the self-attention layers, following the architecture described in `Attention is all you need`_ by Ashish Vaswani,
|
| 633 |
+
Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
| 634 |
+
|
| 635 |
+
To behave as an decoder the model needs to be initialized with the
|
| 636 |
+
:obj:`is_decoder` argument of the configuration set to :obj:`True`; an
|
| 637 |
+
:obj:`encoder_hidden_states` is expected as an input to the forward pass.
|
| 638 |
+
|
| 639 |
+
.. _`Attention is all you need`:
|
| 640 |
+
https://arxiv.org/abs/1706.03762
|
| 641 |
+
|
| 642 |
+
"""
|
| 643 |
+
|
| 644 |
+
def __init__(self, config):
|
| 645 |
+
super().__init__(config)
|
| 646 |
+
self.config = config
|
| 647 |
+
|
| 648 |
+
self.embeddings = BertEmbeddings(config)
|
| 649 |
+
self.encoder = BertEncoder(config)
|
| 650 |
+
self.pooler = BertPooler(config)
|
| 651 |
+
|
| 652 |
+
self.init_weights()
|
| 653 |
+
|
| 654 |
+
def get_input_embeddings(self):
|
| 655 |
+
return self.embeddings.word_embeddings
|
| 656 |
+
|
| 657 |
+
def set_input_embeddings(self, value):
|
| 658 |
+
self.embeddings.word_embeddings = value
|
| 659 |
+
|
| 660 |
+
def _prune_heads(self, heads_to_prune):
|
| 661 |
+
""" Prunes heads of the model.
|
| 662 |
+
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
|
| 663 |
+
See base class PreTrainedModel
|
| 664 |
+
"""
|
| 665 |
+
for layer, heads in heads_to_prune.items():
|
| 666 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
| 667 |
+
|
| 668 |
+
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
| 669 |
+
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
|
| 670 |
+
def forward(
|
| 671 |
+
self,
|
| 672 |
+
input_ids=None,
|
| 673 |
+
attention_mask=None,
|
| 674 |
+
token_type_ids=None,
|
| 675 |
+
position_ids=None,
|
| 676 |
+
head_mask=None,
|
| 677 |
+
inputs_embeds=None,
|
| 678 |
+
encoder_hidden_states=None,
|
| 679 |
+
encoder_attention_mask=None,
|
| 680 |
+
output_attentions=None,
|
| 681 |
+
output_hidden_states=None,
|
| 682 |
+
):
|
| 683 |
+
r"""
|
| 684 |
+
Return:
|
| 685 |
+
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
| 686 |
+
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
| 687 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
| 688 |
+
pooler_output (:obj:`torch.FloatTensor`: of shape :obj:`(batch_size, hidden_size)`):
|
| 689 |
+
Last layer hidden-state of the first token of the sequence (classification token)
|
| 690 |
+
further processed by a Linear layer and a Tanh activation function. The Linear
|
| 691 |
+
layer weights are trained from the next sentence prediction (classification)
|
| 692 |
+
objective during pre-training.
|
| 693 |
+
|
| 694 |
+
This output is usually *not* a good summary
|
| 695 |
+
of the semantic content of the input, you're often better with averaging or pooling
|
| 696 |
+
the sequence of hidden-states for the whole input sequence.
|
| 697 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
| 698 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
| 699 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
| 700 |
+
|
| 701 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
| 702 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
| 703 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
| 704 |
+
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
| 705 |
+
|
| 706 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 707 |
+
heads.
|
| 708 |
+
"""
|
| 709 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 710 |
+
output_hidden_states = (
|
| 711 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 712 |
+
)
|
| 713 |
+
|
| 714 |
+
if input_ids is not None and inputs_embeds is not None:
|
| 715 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
| 716 |
+
elif input_ids is not None:
|
| 717 |
+
input_shape = input_ids.size()
|
| 718 |
+
elif inputs_embeds is not None:
|
| 719 |
+
input_shape = inputs_embeds.size()[:-1]
|
| 720 |
+
else:
|
| 721 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
| 722 |
+
|
| 723 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
| 724 |
+
|
| 725 |
+
if attention_mask is None:
|
| 726 |
+
attention_mask = torch.ones(input_shape, device=device)
|
| 727 |
+
if token_type_ids is None:
|
| 728 |
+
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
| 729 |
+
|
| 730 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
| 731 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
| 732 |
+
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
| 733 |
+
|
| 734 |
+
# If a 2D ou 3D attention mask is provided for the cross-attention
|
| 735 |
+
# we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
|
| 736 |
+
if self.config.is_decoder and encoder_hidden_states is not None:
|
| 737 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
| 738 |
+
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
| 739 |
+
if encoder_attention_mask is None:
|
| 740 |
+
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
| 741 |
+
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
| 742 |
+
else:
|
| 743 |
+
encoder_extended_attention_mask = None
|
| 744 |
+
|
| 745 |
+
# Prepare head mask if needed
|
| 746 |
+
# 1.0 in head_mask indicate we keep the head
|
| 747 |
+
# attention_probs has shape bsz x n_heads x N x N
|
| 748 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
| 749 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
| 750 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
| 751 |
+
|
| 752 |
+
embedding_output = self.embeddings(
|
| 753 |
+
input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
|
| 754 |
+
)
|
| 755 |
+
encoder_outputs = self.encoder(
|
| 756 |
+
embedding_output,
|
| 757 |
+
attention_mask=extended_attention_mask,
|
| 758 |
+
head_mask=head_mask,
|
| 759 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 760 |
+
encoder_attention_mask=encoder_extended_attention_mask,
|
| 761 |
+
output_attentions=output_attentions,
|
| 762 |
+
output_hidden_states=output_hidden_states,
|
| 763 |
+
)
|
| 764 |
+
sequence_output = encoder_outputs[0]
|
| 765 |
+
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
| 766 |
+
|
| 767 |
+
outputs = (sequence_output, pooled_output,) + encoder_outputs[
|
| 768 |
+
1:
|
| 769 |
+
] # add hidden_states and attentions if they are here
|
| 770 |
+
return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
|
| 771 |
+
|
| 772 |
+
|
| 773 |
+
@add_start_docstrings(
|
| 774 |
+
"""Bert Model with two heads on top as done during the pre-training: a `masked language modeling` head and
|
| 775 |
+
a `next sentence prediction (classification)` head. """,
|
| 776 |
+
BERT_START_DOCSTRING,
|
| 777 |
+
)
|
| 778 |
+
class BertForPreTraining(BertPreTrainedModel):
|
| 779 |
+
def __init__(self, config):
|
| 780 |
+
super().__init__(config)
|
| 781 |
+
|
| 782 |
+
self.bert = BertModel(config)
|
| 783 |
+
self.cls = BertPreTrainingHeads(config)
|
| 784 |
+
|
| 785 |
+
self.init_weights()
|
| 786 |
+
|
| 787 |
+
def get_output_embeddings(self):
|
| 788 |
+
return self.cls.predictions.decoder
|
| 789 |
+
|
| 790 |
+
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
| 791 |
+
def forward(
|
| 792 |
+
self,
|
| 793 |
+
input_ids=None,
|
| 794 |
+
attention_mask=None,
|
| 795 |
+
token_type_ids=None,
|
| 796 |
+
position_ids=None,
|
| 797 |
+
head_mask=None,
|
| 798 |
+
inputs_embeds=None,
|
| 799 |
+
labels=None,
|
| 800 |
+
next_sentence_label=None,
|
| 801 |
+
output_attentions=None,
|
| 802 |
+
output_hidden_states=None,
|
| 803 |
+
**kwargs
|
| 804 |
+
):
|
| 805 |
+
r"""
|
| 806 |
+
labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`):
|
| 807 |
+
Labels for computing the masked language modeling loss.
|
| 808 |
+
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
|
| 809 |
+
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
|
| 810 |
+
in ``[0, ..., config.vocab_size]``
|
| 811 |
+
next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`, defaults to :obj:`None`):
|
| 812 |
+
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see :obj:`input_ids` docstring)
|
| 813 |
+
Indices should be in ``[0, 1]``.
|
| 814 |
+
``0`` indicates sequence B is a continuation of sequence A,
|
| 815 |
+
``1`` indicates sequence B is a random sequence.
|
| 816 |
+
kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
|
| 817 |
+
Used to hide legacy arguments that have been deprecated.
|
| 818 |
+
|
| 819 |
+
Returns:
|
| 820 |
+
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
| 821 |
+
loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
| 822 |
+
Total loss as the sum of the masked language modeling loss and the next sequence prediction (classification) loss.
|
| 823 |
+
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
|
| 824 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
| 825 |
+
seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
|
| 826 |
+
Prediction scores of the next sequence prediction (classification) head (scores of True/False
|
| 827 |
+
continuation before SoftMax).
|
| 828 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
| 829 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
| 830 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
| 831 |
+
|
| 832 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
| 833 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
| 834 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
| 835 |
+
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
| 836 |
+
|
| 837 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 838 |
+
heads.
|
| 839 |
+
|
| 840 |
+
|
| 841 |
+
Examples::
|
| 842 |
+
|
| 843 |
+
>>> from transformers import BertTokenizer, BertForPreTraining
|
| 844 |
+
>>> import torch
|
| 845 |
+
|
| 846 |
+
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
| 847 |
+
>>> model = BertForPreTraining.from_pretrained('bert-base-uncased')
|
| 848 |
+
|
| 849 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
| 850 |
+
>>> outputs = model(**inputs)
|
| 851 |
+
|
| 852 |
+
>>> prediction_scores, seq_relationship_scores = outputs[:2]
|
| 853 |
+
|
| 854 |
+
"""
|
| 855 |
+
if "masked_lm_labels" in kwargs:
|
| 856 |
+
warnings.warn(
|
| 857 |
+
"The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
| 858 |
+
DeprecationWarning,
|
| 859 |
+
)
|
| 860 |
+
labels = kwargs.pop("masked_lm_labels")
|
| 861 |
+
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
| 862 |
+
|
| 863 |
+
outputs = self.bert(
|
| 864 |
+
input_ids,
|
| 865 |
+
attention_mask=attention_mask,
|
| 866 |
+
token_type_ids=token_type_ids,
|
| 867 |
+
position_ids=position_ids,
|
| 868 |
+
head_mask=head_mask,
|
| 869 |
+
inputs_embeds=inputs_embeds,
|
| 870 |
+
output_attentions=output_attentions,
|
| 871 |
+
output_hidden_states=output_hidden_states,
|
| 872 |
+
)
|
| 873 |
+
|
| 874 |
+
sequence_output, pooled_output = outputs[:2]
|
| 875 |
+
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
|
| 876 |
+
|
| 877 |
+
outputs = (prediction_scores, seq_relationship_score,) + outputs[
|
| 878 |
+
2:
|
| 879 |
+
] # add hidden states and attention if they are here
|
| 880 |
+
|
| 881 |
+
if labels is not None and next_sentence_label is not None:
|
| 882 |
+
loss_fct = CrossEntropyLoss()
|
| 883 |
+
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
| 884 |
+
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
|
| 885 |
+
total_loss = masked_lm_loss + next_sentence_loss
|
| 886 |
+
outputs = (total_loss,) + outputs
|
| 887 |
+
|
| 888 |
+
return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)
|
| 889 |
+
|
| 890 |
+
|
| 891 |
+
@add_start_docstrings(
|
| 892 |
+
"""Bert Model with a `language modeling` head on top for CLM fine-tuning. """, BERT_START_DOCSTRING
|
| 893 |
+
)
|
| 894 |
+
class BertLMHeadModel(BertPreTrainedModel):
|
| 895 |
+
def __init__(self, config):
|
| 896 |
+
super().__init__(config)
|
| 897 |
+
assert config.is_decoder, "If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True`."
|
| 898 |
+
|
| 899 |
+
self.bert = BertModel(config)
|
| 900 |
+
self.cls = BertOnlyMLMHead(config)
|
| 901 |
+
|
| 902 |
+
self.init_weights()
|
| 903 |
+
|
| 904 |
+
def get_output_embeddings(self):
|
| 905 |
+
return self.cls.predictions.decoder
|
| 906 |
+
|
| 907 |
+
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
| 908 |
+
def forward(
|
| 909 |
+
self,
|
| 910 |
+
input_ids=None,
|
| 911 |
+
attention_mask=None,
|
| 912 |
+
token_type_ids=None,
|
| 913 |
+
position_ids=None,
|
| 914 |
+
head_mask=None,
|
| 915 |
+
inputs_embeds=None,
|
| 916 |
+
labels=None,
|
| 917 |
+
encoder_hidden_states=None,
|
| 918 |
+
encoder_attention_mask=None,
|
| 919 |
+
output_attentions=None,
|
| 920 |
+
output_hidden_states=None,
|
| 921 |
+
**kwargs
|
| 922 |
+
):
|
| 923 |
+
r"""
|
| 924 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
| 925 |
+
Labels for computing the left-to-right language modeling loss (next word prediction).
|
| 926 |
+
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
|
| 927 |
+
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
|
| 928 |
+
in ``[0, ..., config.vocab_size]``
|
| 929 |
+
kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
|
| 930 |
+
Used to hide legacy arguments that have been deprecated.
|
| 931 |
+
|
| 932 |
+
Returns:
|
| 933 |
+
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
| 934 |
+
ltr_lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
|
| 935 |
+
Next token prediction loss.
|
| 936 |
+
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
|
| 937 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
| 938 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
| 939 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
| 940 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
| 941 |
+
|
| 942 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
| 943 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
| 944 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
| 945 |
+
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
| 946 |
+
|
| 947 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 948 |
+
heads.
|
| 949 |
+
|
| 950 |
+
Example::
|
| 951 |
+
|
| 952 |
+
>>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
|
| 953 |
+
>>> import torch
|
| 954 |
+
|
| 955 |
+
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
|
| 956 |
+
>>> config = BertConfig.from_pretrained("bert-base-cased")
|
| 957 |
+
>>> config.is_decoder = True
|
| 958 |
+
>>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
|
| 959 |
+
|
| 960 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
| 961 |
+
>>> outputs = model(**inputs)
|
| 962 |
+
|
| 963 |
+
>>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
| 964 |
+
"""
|
| 965 |
+
|
| 966 |
+
outputs = self.bert(
|
| 967 |
+
input_ids,
|
| 968 |
+
attention_mask=attention_mask,
|
| 969 |
+
token_type_ids=token_type_ids,
|
| 970 |
+
position_ids=position_ids,
|
| 971 |
+
head_mask=head_mask,
|
| 972 |
+
inputs_embeds=inputs_embeds,
|
| 973 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 974 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 975 |
+
output_attentions=output_attentions,
|
| 976 |
+
output_hidden_states=output_hidden_states,
|
| 977 |
+
)
|
| 978 |
+
|
| 979 |
+
sequence_output = outputs[0]
|
| 980 |
+
prediction_scores = self.cls(sequence_output)
|
| 981 |
+
|
| 982 |
+
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
|
| 983 |
+
|
| 984 |
+
if labels is not None:
|
| 985 |
+
# we are doing next-token prediction; shift prediction scores and input ids by one
|
| 986 |
+
prediction_scores = prediction_scores[:, :-1, :].contiguous()
|
| 987 |
+
labels = labels[:, 1:].contiguous()
|
| 988 |
+
loss_fct = CrossEntropyLoss()
|
| 989 |
+
ltr_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
| 990 |
+
outputs = (ltr_lm_loss,) + outputs
|
| 991 |
+
|
| 992 |
+
return outputs # (ltr_lm_loss), prediction_scores, (hidden_states), (attentions)
|
| 993 |
+
|
| 994 |
+
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
|
| 995 |
+
input_shape = input_ids.shape
|
| 996 |
+
|
| 997 |
+
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
| 998 |
+
if attention_mask is None:
|
| 999 |
+
attention_mask = input_ids.new_ones(input_shape)
|
| 1000 |
+
|
| 1001 |
+
return {"input_ids": input_ids, "attention_mask": attention_mask}
|
| 1002 |
+
|
| 1003 |
+
|
| 1004 |
+
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
|
| 1005 |
+
class BertForMaskedLM(BertPreTrainedModel):
|
| 1006 |
+
def __init__(self, config):
|
| 1007 |
+
super().__init__(config)
|
| 1008 |
+
assert (
|
| 1009 |
+
not config.is_decoder
|
| 1010 |
+
), "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for bi-directional self-attention."
|
| 1011 |
+
|
| 1012 |
+
self.bert = BertModel(config)
|
| 1013 |
+
self.cls = BertOnlyMLMHead(config)
|
| 1014 |
+
|
| 1015 |
+
self.init_weights()
|
| 1016 |
+
|
| 1017 |
+
def get_output_embeddings(self):
|
| 1018 |
+
return self.cls.predictions.decoder
|
| 1019 |
+
|
| 1020 |
+
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
| 1021 |
+
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
|
| 1022 |
+
def forward(
|
| 1023 |
+
self,
|
| 1024 |
+
input_ids=None,
|
| 1025 |
+
attention_mask=None,
|
| 1026 |
+
token_type_ids=None,
|
| 1027 |
+
position_ids=None,
|
| 1028 |
+
head_mask=None,
|
| 1029 |
+
inputs_embeds=None,
|
| 1030 |
+
labels=None,
|
| 1031 |
+
encoder_hidden_states=None,
|
| 1032 |
+
encoder_attention_mask=None,
|
| 1033 |
+
output_attentions=None,
|
| 1034 |
+
output_hidden_states=None,
|
| 1035 |
+
**kwargs
|
| 1036 |
+
):
|
| 1037 |
+
r"""
|
| 1038 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
| 1039 |
+
Labels for computing the masked language modeling loss.
|
| 1040 |
+
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
|
| 1041 |
+
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
|
| 1042 |
+
in ``[0, ..., config.vocab_size]``
|
| 1043 |
+
kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
|
| 1044 |
+
Used to hide legacy arguments that have been deprecated.
|
| 1045 |
+
|
| 1046 |
+
Returns:
|
| 1047 |
+
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
| 1048 |
+
masked_lm_loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
| 1049 |
+
Masked language modeling loss.
|
| 1050 |
+
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
|
| 1051 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
| 1052 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
| 1053 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
| 1054 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
| 1055 |
+
|
| 1056 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
| 1057 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
| 1058 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
| 1059 |
+
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
| 1060 |
+
|
| 1061 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 1062 |
+
heads.
|
| 1063 |
+
"""
|
| 1064 |
+
if "masked_lm_labels" in kwargs:
|
| 1065 |
+
warnings.warn(
|
| 1066 |
+
"The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
| 1067 |
+
DeprecationWarning,
|
| 1068 |
+
)
|
| 1069 |
+
labels = kwargs.pop("masked_lm_labels")
|
| 1070 |
+
assert "lm_labels" not in kwargs, "Use `BertWithLMHead` for autoregressive language modeling task."
|
| 1071 |
+
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
| 1072 |
+
|
| 1073 |
+
outputs = self.bert(
|
| 1074 |
+
input_ids,
|
| 1075 |
+
attention_mask=attention_mask,
|
| 1076 |
+
token_type_ids=token_type_ids,
|
| 1077 |
+
position_ids=position_ids,
|
| 1078 |
+
head_mask=head_mask,
|
| 1079 |
+
inputs_embeds=inputs_embeds,
|
| 1080 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 1081 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 1082 |
+
output_attentions=output_attentions,
|
| 1083 |
+
output_hidden_states=output_hidden_states,
|
| 1084 |
+
)
|
| 1085 |
+
|
| 1086 |
+
sequence_output = outputs[0]
|
| 1087 |
+
prediction_scores = self.cls(sequence_output)
|
| 1088 |
+
|
| 1089 |
+
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
|
| 1090 |
+
|
| 1091 |
+
if labels is not None:
|
| 1092 |
+
loss_fct = CrossEntropyLoss() # -100 index = padding token
|
| 1093 |
+
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
| 1094 |
+
outputs = (masked_lm_loss,) + outputs
|
| 1095 |
+
|
| 1096 |
+
return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
|
| 1097 |
+
|
| 1098 |
+
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
|
| 1099 |
+
input_shape = input_ids.shape
|
| 1100 |
+
effective_batch_size = input_shape[0]
|
| 1101 |
+
|
| 1102 |
+
# add a dummy token
|
| 1103 |
+
assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
|
| 1104 |
+
attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
|
| 1105 |
+
dummy_token = torch.full(
|
| 1106 |
+
(effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
|
| 1107 |
+
)
|
| 1108 |
+
input_ids = torch.cat([input_ids, dummy_token], dim=1)
|
| 1109 |
+
|
| 1110 |
+
return {"input_ids": input_ids, "attention_mask": attention_mask}
|
| 1111 |
+
|
| 1112 |
+
|
| 1113 |
+
@add_start_docstrings(
|
| 1114 |
+
"""Bert Model with a `next sentence prediction (classification)` head on top. """, BERT_START_DOCSTRING,
|
| 1115 |
+
)
|
| 1116 |
+
class BertForNextSentencePrediction(BertPreTrainedModel):
|
| 1117 |
+
def __init__(self, config):
|
| 1118 |
+
super().__init__(config)
|
| 1119 |
+
|
| 1120 |
+
self.bert = BertModel(config)
|
| 1121 |
+
self.cls = BertOnlyNSPHead(config)
|
| 1122 |
+
|
| 1123 |
+
self.init_weights()
|
| 1124 |
+
|
| 1125 |
+
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
| 1126 |
+
def forward(
|
| 1127 |
+
self,
|
| 1128 |
+
input_ids=None,
|
| 1129 |
+
attention_mask=None,
|
| 1130 |
+
token_type_ids=None,
|
| 1131 |
+
position_ids=None,
|
| 1132 |
+
head_mask=None,
|
| 1133 |
+
inputs_embeds=None,
|
| 1134 |
+
next_sentence_label=None,
|
| 1135 |
+
output_attentions=None,
|
| 1136 |
+
output_hidden_states=None,
|
| 1137 |
+
):
|
| 1138 |
+
r"""
|
| 1139 |
+
next_sentence_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
| 1140 |
+
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring)
|
| 1141 |
+
Indices should be in ``[0, 1]``.
|
| 1142 |
+
``0`` indicates sequence B is a continuation of sequence A,
|
| 1143 |
+
``1`` indicates sequence B is a random sequence.
|
| 1144 |
+
|
| 1145 |
+
Returns:
|
| 1146 |
+
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
| 1147 |
+
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`next_sentence_label` is provided):
|
| 1148 |
+
Next sequence prediction (classification) loss.
|
| 1149 |
+
seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
|
| 1150 |
+
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
|
| 1151 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
| 1152 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
| 1153 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
| 1154 |
+
|
| 1155 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
| 1156 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
| 1157 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
| 1158 |
+
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
| 1159 |
+
|
| 1160 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 1161 |
+
heads.
|
| 1162 |
+
|
| 1163 |
+
Examples::
|
| 1164 |
+
|
| 1165 |
+
>>> from transformers import BertTokenizer, BertForNextSentencePrediction
|
| 1166 |
+
>>> import torch
|
| 1167 |
+
|
| 1168 |
+
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
| 1169 |
+
>>> model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
|
| 1170 |
+
|
| 1171 |
+
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
|
| 1172 |
+
>>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
|
| 1173 |
+
>>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt')
|
| 1174 |
+
|
| 1175 |
+
>>> loss, logits = model(**encoding, next_sentence_label=torch.LongTensor([1]))
|
| 1176 |
+
>>> assert logits[0, 0] < logits[0, 1] # next sentence was random
|
| 1177 |
+
"""
|
| 1178 |
+
|
| 1179 |
+
outputs = self.bert(
|
| 1180 |
+
input_ids,
|
| 1181 |
+
attention_mask=attention_mask,
|
| 1182 |
+
token_type_ids=token_type_ids,
|
| 1183 |
+
position_ids=position_ids,
|
| 1184 |
+
head_mask=head_mask,
|
| 1185 |
+
inputs_embeds=inputs_embeds,
|
| 1186 |
+
output_attentions=output_attentions,
|
| 1187 |
+
output_hidden_states=output_hidden_states,
|
| 1188 |
+
)
|
| 1189 |
+
|
| 1190 |
+
pooled_output = outputs[1]
|
| 1191 |
+
|
| 1192 |
+
seq_relationship_score = self.cls(pooled_output)
|
| 1193 |
+
|
| 1194 |
+
outputs = (seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here
|
| 1195 |
+
if next_sentence_label is not None:
|
| 1196 |
+
loss_fct = CrossEntropyLoss()
|
| 1197 |
+
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
|
| 1198 |
+
outputs = (next_sentence_loss,) + outputs
|
| 1199 |
+
|
| 1200 |
+
return outputs # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)
|
| 1201 |
+
|
| 1202 |
+
|
| 1203 |
+
@add_start_docstrings(
|
| 1204 |
+
"""Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
|
| 1205 |
+
the pooled output) e.g. for GLUE tasks. """,
|
| 1206 |
+
BERT_START_DOCSTRING,
|
| 1207 |
+
)
|
| 1208 |
+
class BertForSequenceClassification(BertPreTrainedModel):
|
| 1209 |
+
def __init__(self, config):
|
| 1210 |
+
super().__init__(config)
|
| 1211 |
+
self.num_labels = config.num_labels
|
| 1212 |
+
|
| 1213 |
+
self.bert = BertModel(config)
|
| 1214 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 1215 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
| 1216 |
+
|
| 1217 |
+
self.init_weights()
|
| 1218 |
+
|
| 1219 |
+
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
| 1220 |
+
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
|
| 1221 |
+
def forward(
|
| 1222 |
+
self,
|
| 1223 |
+
input_ids=None,
|
| 1224 |
+
attention_mask=None,
|
| 1225 |
+
token_type_ids=None,
|
| 1226 |
+
position_ids=None,
|
| 1227 |
+
head_mask=None,
|
| 1228 |
+
inputs_embeds=None,
|
| 1229 |
+
labels=None,
|
| 1230 |
+
output_attentions=None,
|
| 1231 |
+
output_hidden_states=None,
|
| 1232 |
+
):
|
| 1233 |
+
r"""
|
| 1234 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
| 1235 |
+
Labels for computing the sequence classification/regression loss.
|
| 1236 |
+
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
|
| 1237 |
+
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
|
| 1238 |
+
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
| 1239 |
+
|
| 1240 |
+
Returns:
|
| 1241 |
+
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
| 1242 |
+
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
|
| 1243 |
+
Classification (or regression if config.num_labels==1) loss.
|
| 1244 |
+
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
|
| 1245 |
+
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
| 1246 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
| 1247 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
| 1248 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
| 1249 |
+
|
| 1250 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
| 1251 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
| 1252 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
| 1253 |
+
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
| 1254 |
+
|
| 1255 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 1256 |
+
heads.
|
| 1257 |
+
"""
|
| 1258 |
+
|
| 1259 |
+
outputs = self.bert(
|
| 1260 |
+
input_ids,
|
| 1261 |
+
attention_mask=attention_mask,
|
| 1262 |
+
token_type_ids=token_type_ids,
|
| 1263 |
+
position_ids=position_ids,
|
| 1264 |
+
head_mask=head_mask,
|
| 1265 |
+
inputs_embeds=inputs_embeds,
|
| 1266 |
+
output_attentions=output_attentions,
|
| 1267 |
+
output_hidden_states=output_hidden_states,
|
| 1268 |
+
)
|
| 1269 |
+
|
| 1270 |
+
pooled_output = outputs[1]
|
| 1271 |
+
|
| 1272 |
+
pooled_output = self.dropout(pooled_output)
|
| 1273 |
+
logits = self.classifier(pooled_output)
|
| 1274 |
+
|
| 1275 |
+
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
|
| 1276 |
+
|
| 1277 |
+
if labels is not None:
|
| 1278 |
+
if self.num_labels == 1:
|
| 1279 |
+
# We are doing regression
|
| 1280 |
+
loss_fct = MSELoss()
|
| 1281 |
+
loss = loss_fct(logits.view(-1), labels.view(-1))
|
| 1282 |
+
else:
|
| 1283 |
+
loss_fct = CrossEntropyLoss()
|
| 1284 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
| 1285 |
+
outputs = (loss,) + outputs
|
| 1286 |
+
|
| 1287 |
+
return outputs # (loss), logits, (hidden_states), (attentions)
|
| 1288 |
+
|
| 1289 |
+
|
| 1290 |
+
@add_start_docstrings(
|
| 1291 |
+
"""Bert Model with a multiple choice classification head on top (a linear layer on top of
|
| 1292 |
+
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
|
| 1293 |
+
BERT_START_DOCSTRING,
|
| 1294 |
+
)
|
| 1295 |
+
class BertForMultipleChoice(BertPreTrainedModel):
|
| 1296 |
+
def __init__(self, config):
|
| 1297 |
+
super().__init__(config)
|
| 1298 |
+
|
| 1299 |
+
self.bert = BertModel(config)
|
| 1300 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 1301 |
+
self.classifier = nn.Linear(config.hidden_size, 1)
|
| 1302 |
+
|
| 1303 |
+
self.init_weights()
|
| 1304 |
+
|
| 1305 |
+
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
|
| 1306 |
+
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
|
| 1307 |
+
def forward(
|
| 1308 |
+
self,
|
| 1309 |
+
input_ids=None,
|
| 1310 |
+
attention_mask=None,
|
| 1311 |
+
token_type_ids=None,
|
| 1312 |
+
position_ids=None,
|
| 1313 |
+
head_mask=None,
|
| 1314 |
+
inputs_embeds=None,
|
| 1315 |
+
labels=None,
|
| 1316 |
+
output_attentions=None,
|
| 1317 |
+
output_hidden_states=None,
|
| 1318 |
+
):
|
| 1319 |
+
r"""
|
| 1320 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
| 1321 |
+
Labels for computing the multiple choice classification loss.
|
| 1322 |
+
Indices should be in ``[0, ..., num_choices-1]`` where `num_choices` is the size of the second dimension
|
| 1323 |
+
of the input tensors. (see `input_ids` above)
|
| 1324 |
+
|
| 1325 |
+
Returns:
|
| 1326 |
+
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
| 1327 |
+
loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided):
|
| 1328 |
+
Classification loss.
|
| 1329 |
+
classification_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
|
| 1330 |
+
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
|
| 1331 |
+
|
| 1332 |
+
Classification scores (before SoftMax).
|
| 1333 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
| 1334 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
| 1335 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
| 1336 |
+
|
| 1337 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
| 1338 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
| 1339 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
| 1340 |
+
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
| 1341 |
+
|
| 1342 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 1343 |
+
heads.
|
| 1344 |
+
"""
|
| 1345 |
+
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
|
| 1346 |
+
|
| 1347 |
+
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
|
| 1348 |
+
attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
|
| 1349 |
+
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
|
| 1350 |
+
position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
|
| 1351 |
+
inputs_embeds = (
|
| 1352 |
+
inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
|
| 1353 |
+
if inputs_embeds is not None
|
| 1354 |
+
else None
|
| 1355 |
+
)
|
| 1356 |
+
|
| 1357 |
+
outputs = self.bert(
|
| 1358 |
+
input_ids,
|
| 1359 |
+
attention_mask=attention_mask,
|
| 1360 |
+
token_type_ids=token_type_ids,
|
| 1361 |
+
position_ids=position_ids,
|
| 1362 |
+
head_mask=head_mask,
|
| 1363 |
+
inputs_embeds=inputs_embeds,
|
| 1364 |
+
output_attentions=output_attentions,
|
| 1365 |
+
output_hidden_states=output_hidden_states,
|
| 1366 |
+
)
|
| 1367 |
+
|
| 1368 |
+
pooled_output = outputs[1]
|
| 1369 |
+
|
| 1370 |
+
pooled_output = self.dropout(pooled_output)
|
| 1371 |
+
logits = self.classifier(pooled_output)
|
| 1372 |
+
reshaped_logits = logits.view(-1, num_choices)
|
| 1373 |
+
|
| 1374 |
+
outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
|
| 1375 |
+
|
| 1376 |
+
if labels is not None:
|
| 1377 |
+
loss_fct = CrossEntropyLoss()
|
| 1378 |
+
loss = loss_fct(reshaped_logits, labels)
|
| 1379 |
+
outputs = (loss,) + outputs
|
| 1380 |
+
|
| 1381 |
+
return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
|
| 1382 |
+
|
| 1383 |
+
|
| 1384 |
+
@add_start_docstrings(
|
| 1385 |
+
"""Bert Model with a token classification head on top (a linear layer on top of
|
| 1386 |
+
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
|
| 1387 |
+
BERT_START_DOCSTRING,
|
| 1388 |
+
)
|
| 1389 |
+
class BertForTokenClassification(BertPreTrainedModel):
|
| 1390 |
+
def __init__(self, config):
|
| 1391 |
+
super().__init__(config)
|
| 1392 |
+
self.num_labels = config.num_labels
|
| 1393 |
+
|
| 1394 |
+
self.bert = BertModel(config)
|
| 1395 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 1396 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
| 1397 |
+
|
| 1398 |
+
self.init_weights()
|
| 1399 |
+
|
| 1400 |
+
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
| 1401 |
+
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
|
| 1402 |
+
def forward(
|
| 1403 |
+
self,
|
| 1404 |
+
input_ids=None,
|
| 1405 |
+
attention_mask=None,
|
| 1406 |
+
token_type_ids=None,
|
| 1407 |
+
position_ids=None,
|
| 1408 |
+
head_mask=None,
|
| 1409 |
+
inputs_embeds=None,
|
| 1410 |
+
labels=None,
|
| 1411 |
+
output_attentions=None,
|
| 1412 |
+
output_hidden_states=None,
|
| 1413 |
+
):
|
| 1414 |
+
r"""
|
| 1415 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
| 1416 |
+
Labels for computing the token classification loss.
|
| 1417 |
+
Indices should be in ``[0, ..., config.num_labels - 1]``.
|
| 1418 |
+
|
| 1419 |
+
Returns:
|
| 1420 |
+
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
| 1421 |
+
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) :
|
| 1422 |
+
Classification loss.
|
| 1423 |
+
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`)
|
| 1424 |
+
Classification scores (before SoftMax).
|
| 1425 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
| 1426 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
| 1427 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
| 1428 |
+
|
| 1429 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
| 1430 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
| 1431 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
| 1432 |
+
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
| 1433 |
+
|
| 1434 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 1435 |
+
heads.
|
| 1436 |
+
"""
|
| 1437 |
+
|
| 1438 |
+
outputs = self.bert(
|
| 1439 |
+
input_ids,
|
| 1440 |
+
attention_mask=attention_mask,
|
| 1441 |
+
token_type_ids=token_type_ids,
|
| 1442 |
+
position_ids=position_ids,
|
| 1443 |
+
head_mask=head_mask,
|
| 1444 |
+
inputs_embeds=inputs_embeds,
|
| 1445 |
+
output_attentions=output_attentions,
|
| 1446 |
+
output_hidden_states=output_hidden_states,
|
| 1447 |
+
)
|
| 1448 |
+
|
| 1449 |
+
sequence_output = outputs[0]
|
| 1450 |
+
|
| 1451 |
+
sequence_output = self.dropout(sequence_output)
|
| 1452 |
+
logits = self.classifier(sequence_output)
|
| 1453 |
+
|
| 1454 |
+
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
|
| 1455 |
+
if labels is not None:
|
| 1456 |
+
loss_fct = CrossEntropyLoss()
|
| 1457 |
+
# Only keep active parts of the loss
|
| 1458 |
+
if attention_mask is not None:
|
| 1459 |
+
active_loss = attention_mask.view(-1) == 1
|
| 1460 |
+
active_logits = logits.view(-1, self.num_labels)
|
| 1461 |
+
active_labels = torch.where(
|
| 1462 |
+
active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
|
| 1463 |
+
)
|
| 1464 |
+
loss = loss_fct(active_logits, active_labels)
|
| 1465 |
+
else:
|
| 1466 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
| 1467 |
+
outputs = (loss,) + outputs
|
| 1468 |
+
|
| 1469 |
+
return outputs # (loss), scores, (hidden_states), (attentions)
|
| 1470 |
+
|
| 1471 |
+
|
| 1472 |
+
@add_start_docstrings(
|
| 1473 |
+
"""Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
|
| 1474 |
+
layers on top of the hidden-states output to compute `span start logits` and `span end logits`). """,
|
| 1475 |
+
BERT_START_DOCSTRING,
|
| 1476 |
+
)
|
| 1477 |
+
class BertForQuestionAnswering(BertPreTrainedModel):
|
| 1478 |
+
def __init__(self, config):
|
| 1479 |
+
super().__init__(config)
|
| 1480 |
+
self.num_labels = config.num_labels
|
| 1481 |
+
|
| 1482 |
+
self.bert = BertModel(config)
|
| 1483 |
+
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
| 1484 |
+
|
| 1485 |
+
self.init_weights()
|
| 1486 |
+
|
| 1487 |
+
@add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
| 1488 |
+
@add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
|
| 1489 |
+
def forward(
|
| 1490 |
+
self,
|
| 1491 |
+
input_ids=None,
|
| 1492 |
+
attention_mask=None,
|
| 1493 |
+
token_type_ids=None,
|
| 1494 |
+
position_ids=None,
|
| 1495 |
+
head_mask=None,
|
| 1496 |
+
inputs_embeds=None,
|
| 1497 |
+
start_positions=None,
|
| 1498 |
+
end_positions=None,
|
| 1499 |
+
output_attentions=None,
|
| 1500 |
+
output_hidden_states=None,
|
| 1501 |
+
):
|
| 1502 |
+
r"""
|
| 1503 |
+
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
| 1504 |
+
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
| 1505 |
+
Positions are clamped to the length of the sequence (`sequence_length`).
|
| 1506 |
+
Position outside of the sequence are not taken into account for computing the loss.
|
| 1507 |
+
end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
| 1508 |
+
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
| 1509 |
+
Positions are clamped to the length of the sequence (`sequence_length`).
|
| 1510 |
+
Position outside of the sequence are not taken into account for computing the loss.
|
| 1511 |
+
|
| 1512 |
+
Returns:
|
| 1513 |
+
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
| 1514 |
+
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
|
| 1515 |
+
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
|
| 1516 |
+
start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
|
| 1517 |
+
Span-start scores (before SoftMax).
|
| 1518 |
+
end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
|
| 1519 |
+
Span-end scores (before SoftMax).
|
| 1520 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
| 1521 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
| 1522 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
| 1523 |
+
|
| 1524 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
| 1525 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
| 1526 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
| 1527 |
+
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
| 1528 |
+
|
| 1529 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 1530 |
+
heads.
|
| 1531 |
+
"""
|
| 1532 |
+
|
| 1533 |
+
outputs = self.bert(
|
| 1534 |
+
input_ids,
|
| 1535 |
+
attention_mask=attention_mask,
|
| 1536 |
+
token_type_ids=token_type_ids,
|
| 1537 |
+
position_ids=position_ids,
|
| 1538 |
+
head_mask=head_mask,
|
| 1539 |
+
inputs_embeds=inputs_embeds,
|
| 1540 |
+
output_attentions=output_attentions,
|
| 1541 |
+
output_hidden_states=output_hidden_states,
|
| 1542 |
+
)
|
| 1543 |
+
|
| 1544 |
+
sequence_output = outputs[0]
|
| 1545 |
+
|
| 1546 |
+
logits = self.qa_outputs(sequence_output)
|
| 1547 |
+
start_logits, end_logits = logits.split(1, dim=-1)
|
| 1548 |
+
start_logits = start_logits.squeeze(-1)
|
| 1549 |
+
end_logits = end_logits.squeeze(-1)
|
| 1550 |
+
|
| 1551 |
+
outputs = (start_logits, end_logits,) + outputs[2:]
|
| 1552 |
+
if start_positions is not None and end_positions is not None:
|
| 1553 |
+
# If we are on multi-GPU, split add a dimension
|
| 1554 |
+
if len(start_positions.size()) > 1:
|
| 1555 |
+
start_positions = start_positions.squeeze(-1)
|
| 1556 |
+
if len(end_positions.size()) > 1:
|
| 1557 |
+
end_positions = end_positions.squeeze(-1)
|
| 1558 |
+
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
| 1559 |
+
ignored_index = start_logits.size(1)
|
| 1560 |
+
start_positions.clamp_(0, ignored_index)
|
| 1561 |
+
end_positions.clamp_(0, ignored_index)
|
| 1562 |
+
|
| 1563 |
+
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
| 1564 |
+
start_loss = loss_fct(start_logits, start_positions)
|
| 1565 |
+
end_loss = loss_fct(end_logits, end_positions)
|
| 1566 |
+
total_loss = (start_loss + end_loss) / 2
|
| 1567 |
+
outputs = (total_loss,) + outputs
|
| 1568 |
+
|
| 1569 |
+
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
|
CGFormer/bert/modeling_utils.py
ADDED
|
@@ -0,0 +1,1268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
|
| 3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import inspect
|
| 18 |
+
import logging
|
| 19 |
+
import os
|
| 20 |
+
from typing import Callable, Dict, List, Optional, Tuple
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
from torch import Tensor, device, dtype, nn
|
| 24 |
+
from torch.nn import CrossEntropyLoss
|
| 25 |
+
from torch.nn import functional as F
|
| 26 |
+
|
| 27 |
+
from .activations import get_activation
|
| 28 |
+
from .configuration_utils import PretrainedConfig
|
| 29 |
+
from .file_utils import (
|
| 30 |
+
DUMMY_INPUTS,
|
| 31 |
+
TF2_WEIGHTS_NAME,
|
| 32 |
+
TF_WEIGHTS_NAME,
|
| 33 |
+
WEIGHTS_NAME,
|
| 34 |
+
cached_path,
|
| 35 |
+
hf_bucket_url,
|
| 36 |
+
is_remote_url,
|
| 37 |
+
)
|
| 38 |
+
from .generation_utils import GenerationMixin
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
logger = logging.getLogger(__name__)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
try:
|
| 45 |
+
from torch.nn import Identity
|
| 46 |
+
except ImportError:
|
| 47 |
+
# Older PyTorch compatibility
|
| 48 |
+
class Identity(nn.Module):
|
| 49 |
+
r"""A placeholder identity operator that is argument-insensitive.
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
def __init__(self, *args, **kwargs):
|
| 53 |
+
super().__init__()
|
| 54 |
+
|
| 55 |
+
def forward(self, input):
|
| 56 |
+
return input
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def find_pruneable_heads_and_indices(
|
| 60 |
+
heads: List, n_heads: int, head_size: int, already_pruned_heads: set
|
| 61 |
+
) -> Tuple[set, "torch.LongTensor"]:
|
| 62 |
+
mask = torch.ones(n_heads, head_size)
|
| 63 |
+
heads = set(heads) - already_pruned_heads # Convert to set and remove already pruned heads
|
| 64 |
+
for head in heads:
|
| 65 |
+
# Compute how many pruned heads are before the head and move the index accordingly
|
| 66 |
+
head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
|
| 67 |
+
mask[head] = 0
|
| 68 |
+
mask = mask.view(-1).contiguous().eq(1)
|
| 69 |
+
index: torch.LongTensor = torch.arange(len(mask))[mask].long()
|
| 70 |
+
return heads, index
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class ModuleUtilsMixin:
|
| 74 |
+
"""
|
| 75 |
+
A few utilities for torch.nn.Modules, to be used as a mixin.
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
def num_parameters(self, only_trainable: bool = False) -> int:
|
| 79 |
+
"""
|
| 80 |
+
Get number of (optionally, trainable) parameters in the module.
|
| 81 |
+
"""
|
| 82 |
+
params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters()
|
| 83 |
+
return sum(p.numel() for p in params)
|
| 84 |
+
|
| 85 |
+
@staticmethod
|
| 86 |
+
def _hook_rss_memory_pre_forward(module, *args, **kwargs):
|
| 87 |
+
try:
|
| 88 |
+
import psutil
|
| 89 |
+
except (ImportError):
|
| 90 |
+
raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")
|
| 91 |
+
|
| 92 |
+
process = psutil.Process(os.getpid())
|
| 93 |
+
mem = process.memory_info()
|
| 94 |
+
module.mem_rss_pre_forward = mem.rss
|
| 95 |
+
return None
|
| 96 |
+
|
| 97 |
+
@staticmethod
|
| 98 |
+
def _hook_rss_memory_post_forward(module, *args, **kwargs):
|
| 99 |
+
try:
|
| 100 |
+
import psutil
|
| 101 |
+
except (ImportError):
|
| 102 |
+
raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")
|
| 103 |
+
|
| 104 |
+
process = psutil.Process(os.getpid())
|
| 105 |
+
mem = process.memory_info()
|
| 106 |
+
module.mem_rss_post_forward = mem.rss
|
| 107 |
+
mem_rss_diff = module.mem_rss_post_forward - module.mem_rss_pre_forward
|
| 108 |
+
module.mem_rss_diff = mem_rss_diff + (module.mem_rss_diff if hasattr(module, "mem_rss_diff") else 0)
|
| 109 |
+
return None
|
| 110 |
+
|
| 111 |
+
def add_memory_hooks(self):
|
| 112 |
+
""" Add a memory hook before and after each sub-module forward pass to record increase in memory consumption.
|
| 113 |
+
Increase in memory consumption is stored in a `mem_rss_diff` attribute for each module and can be reset to zero with `model.reset_memory_hooks_state()`
|
| 114 |
+
"""
|
| 115 |
+
for module in self.modules():
|
| 116 |
+
module.register_forward_pre_hook(self._hook_rss_memory_pre_forward)
|
| 117 |
+
module.register_forward_hook(self._hook_rss_memory_post_forward)
|
| 118 |
+
self.reset_memory_hooks_state()
|
| 119 |
+
|
| 120 |
+
def reset_memory_hooks_state(self):
|
| 121 |
+
for module in self.modules():
|
| 122 |
+
module.mem_rss_diff = 0
|
| 123 |
+
module.mem_rss_post_forward = 0
|
| 124 |
+
module.mem_rss_pre_forward = 0
|
| 125 |
+
|
| 126 |
+
@property
|
| 127 |
+
def device(self) -> device:
|
| 128 |
+
"""
|
| 129 |
+
Get torch.device from module, assuming that the whole module has one device.
|
| 130 |
+
"""
|
| 131 |
+
try:
|
| 132 |
+
return next(self.parameters()).device
|
| 133 |
+
except StopIteration:
|
| 134 |
+
# For nn.DataParallel compatibility in PyTorch 1.5
|
| 135 |
+
|
| 136 |
+
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
|
| 137 |
+
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
| 138 |
+
return tuples
|
| 139 |
+
|
| 140 |
+
gen = self._named_members(get_members_fn=find_tensor_attributes)
|
| 141 |
+
first_tuple = next(gen)
|
| 142 |
+
return first_tuple[1].device
|
| 143 |
+
|
| 144 |
+
@property
|
| 145 |
+
def dtype(self) -> dtype:
|
| 146 |
+
"""
|
| 147 |
+
Get torch.dtype from module, assuming that the whole module has one dtype.
|
| 148 |
+
"""
|
| 149 |
+
try:
|
| 150 |
+
return next(self.parameters()).dtype
|
| 151 |
+
except StopIteration:
|
| 152 |
+
# For nn.DataParallel compatibility in PyTorch 1.5
|
| 153 |
+
|
| 154 |
+
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
|
| 155 |
+
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
| 156 |
+
return tuples
|
| 157 |
+
|
| 158 |
+
gen = self._named_members(get_members_fn=find_tensor_attributes)
|
| 159 |
+
first_tuple = next(gen)
|
| 160 |
+
return first_tuple[1].dtype
|
| 161 |
+
|
| 162 |
+
def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
|
| 163 |
+
"""type: torch.Tensor -> torch.Tensor"""
|
| 164 |
+
if encoder_attention_mask.dim() == 3:
|
| 165 |
+
encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
|
| 166 |
+
if encoder_attention_mask.dim() == 2:
|
| 167 |
+
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
|
| 168 |
+
# T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
|
| 169 |
+
# Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow
|
| 170 |
+
# /transformer/transformer_layers.py#L270
|
| 171 |
+
# encoder_extended_attention_mask = (encoder_extended_attention_mask ==
|
| 172 |
+
# encoder_extended_attention_mask.transpose(-1, -2))
|
| 173 |
+
encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
| 174 |
+
|
| 175 |
+
if self.dtype == torch.float16:
|
| 176 |
+
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e4
|
| 177 |
+
elif self.dtype == torch.float32:
|
| 178 |
+
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
|
| 179 |
+
else:
|
| 180 |
+
raise ValueError(
|
| 181 |
+
"{} not recognized. `dtype` should be set to either `torch.float32` or `torch.float16`".format(
|
| 182 |
+
self.dtype
|
| 183 |
+
)
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
return encoder_extended_attention_mask
|
| 187 |
+
|
| 188 |
+
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple, device: device) -> Tensor:
|
| 189 |
+
"""Makes broadcastable attention mask and causal mask so that future and maked tokens are ignored.
|
| 190 |
+
|
| 191 |
+
Arguments:
|
| 192 |
+
attention_mask: torch.Tensor with 1 indicating tokens to ATTEND to
|
| 193 |
+
input_shape: tuple, shape of input_ids
|
| 194 |
+
device: torch.Device, usually self.device
|
| 195 |
+
|
| 196 |
+
Returns:
|
| 197 |
+
torch.Tensor with dtype of attention_mask.dtype
|
| 198 |
+
"""
|
| 199 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
| 200 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
| 201 |
+
if attention_mask.dim() == 3:
|
| 202 |
+
extended_attention_mask = attention_mask[:, None, :, :]
|
| 203 |
+
elif attention_mask.dim() == 2:
|
| 204 |
+
# Provided a padding mask of dimensions [batch_size, seq_length]
|
| 205 |
+
# - if the model is a decoder, apply a causal mask in addition to the padding mask
|
| 206 |
+
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
| 207 |
+
if self.config.is_decoder:
|
| 208 |
+
batch_size, seq_length = input_shape
|
| 209 |
+
seq_ids = torch.arange(seq_length, device=device)
|
| 210 |
+
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
|
| 211 |
+
# causal and attention masks must have same type with pytorch version < 1.3
|
| 212 |
+
causal_mask = causal_mask.to(attention_mask.dtype)
|
| 213 |
+
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
| 214 |
+
else:
|
| 215 |
+
extended_attention_mask = attention_mask[:, None, None, :]
|
| 216 |
+
else:
|
| 217 |
+
raise ValueError(
|
| 218 |
+
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
|
| 219 |
+
input_shape, attention_mask.shape
|
| 220 |
+
)
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
| 224 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
| 225 |
+
# positions we want to attend and -10000.0 for masked positions.
|
| 226 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
| 227 |
+
# effectively the same as removing these entirely.
|
| 228 |
+
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
| 229 |
+
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
| 230 |
+
return extended_attention_mask
|
| 231 |
+
|
| 232 |
+
def get_head_mask(self, head_mask: Tensor, num_hidden_layers: int, is_attention_chunked: bool = False) -> Tensor:
|
| 233 |
+
"""
|
| 234 |
+
# Prepare head mask if needed
|
| 235 |
+
# 1.0 in head_mask indicate we keep the head
|
| 236 |
+
attention_probs has shape bsz x n_heads x N x N
|
| 237 |
+
Arguments:
|
| 238 |
+
head_mask: torch.Tensor or None: has shape [num_heads] or [num_hidden_layers x num_heads]
|
| 239 |
+
num_hidden_layers: int
|
| 240 |
+
Returns:
|
| 241 |
+
Tensor of shape shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
| 242 |
+
or list with [None] for each layer
|
| 243 |
+
"""
|
| 244 |
+
if head_mask is not None:
|
| 245 |
+
head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
|
| 246 |
+
if is_attention_chunked is True:
|
| 247 |
+
head_mask = head_mask.unsqueeze(-1)
|
| 248 |
+
else:
|
| 249 |
+
head_mask = [None] * num_hidden_layers
|
| 250 |
+
|
| 251 |
+
return head_mask
|
| 252 |
+
|
| 253 |
+
def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
|
| 254 |
+
"""-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]"""
|
| 255 |
+
if head_mask.dim() == 1:
|
| 256 |
+
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
| 257 |
+
head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
|
| 258 |
+
elif head_mask.dim() == 2:
|
| 259 |
+
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
|
| 260 |
+
assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
|
| 261 |
+
head_mask = head_mask.to(dtype=self.dtype) # switch to fload if need + fp16 compatibility
|
| 262 |
+
return head_mask
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
|
| 266 |
+
r""" Base class for all models.
|
| 267 |
+
|
| 268 |
+
:class:`~transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
|
| 269 |
+
as well as a few methods common to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads.
|
| 270 |
+
|
| 271 |
+
Class attributes (overridden by derived classes):
|
| 272 |
+
- ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
|
| 273 |
+
- ``load_tf_weights``: a python ``method`` for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments:
|
| 274 |
+
|
| 275 |
+
- ``model``: an instance of the relevant subclass of :class:`~transformers.PreTrainedModel`,
|
| 276 |
+
- ``config``: an instance of the relevant subclass of :class:`~transformers.PretrainedConfig`,
|
| 277 |
+
- ``path``: a path (string) to the TensorFlow checkpoint.
|
| 278 |
+
|
| 279 |
+
- ``base_model_prefix``: a string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.
|
| 280 |
+
"""
|
| 281 |
+
config_class = None
|
| 282 |
+
base_model_prefix = ""
|
| 283 |
+
|
| 284 |
+
@property
|
| 285 |
+
def dummy_inputs(self):
|
| 286 |
+
""" Dummy inputs to do a forward pass in the network.
|
| 287 |
+
|
| 288 |
+
Returns:
|
| 289 |
+
torch.Tensor with dummy inputs
|
| 290 |
+
"""
|
| 291 |
+
return {"input_ids": torch.tensor(DUMMY_INPUTS)}
|
| 292 |
+
|
| 293 |
+
def __init__(self, config, *inputs, **kwargs):
|
| 294 |
+
super().__init__()
|
| 295 |
+
if not isinstance(config, PretrainedConfig):
|
| 296 |
+
raise ValueError(
|
| 297 |
+
"Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
|
| 298 |
+
"To create a model from a pretrained model use "
|
| 299 |
+
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
|
| 300 |
+
self.__class__.__name__, self.__class__.__name__
|
| 301 |
+
)
|
| 302 |
+
)
|
| 303 |
+
# Save config in model
|
| 304 |
+
self.config = config
|
| 305 |
+
|
| 306 |
+
@property
|
| 307 |
+
def base_model(self):
|
| 308 |
+
return getattr(self, self.base_model_prefix, self)
|
| 309 |
+
|
| 310 |
+
def get_input_embeddings(self):
|
| 311 |
+
"""
|
| 312 |
+
Returns the model's input embeddings.
|
| 313 |
+
|
| 314 |
+
Returns:
|
| 315 |
+
:obj:`nn.Module`:
|
| 316 |
+
A torch module mapping vocabulary to hidden states.
|
| 317 |
+
"""
|
| 318 |
+
base_model = getattr(self, self.base_model_prefix, self)
|
| 319 |
+
if base_model is not self:
|
| 320 |
+
return base_model.get_input_embeddings()
|
| 321 |
+
else:
|
| 322 |
+
raise NotImplementedError
|
| 323 |
+
|
| 324 |
+
def set_input_embeddings(self, value: nn.Module):
|
| 325 |
+
"""
|
| 326 |
+
Set model's input embeddings
|
| 327 |
+
|
| 328 |
+
Args:
|
| 329 |
+
value (:obj:`nn.Module`):
|
| 330 |
+
A module mapping vocabulary to hidden states.
|
| 331 |
+
"""
|
| 332 |
+
base_model = getattr(self, self.base_model_prefix, self)
|
| 333 |
+
if base_model is not self:
|
| 334 |
+
base_model.set_input_embeddings(value)
|
| 335 |
+
else:
|
| 336 |
+
raise NotImplementedError
|
| 337 |
+
|
| 338 |
+
def get_output_embeddings(self):
|
| 339 |
+
"""
|
| 340 |
+
Returns the model's output embeddings.
|
| 341 |
+
|
| 342 |
+
Returns:
|
| 343 |
+
:obj:`nn.Module`:
|
| 344 |
+
A torch module mapping hidden states to vocabulary.
|
| 345 |
+
"""
|
| 346 |
+
return None # Overwrite for models with output embeddings
|
| 347 |
+
|
| 348 |
+
def tie_weights(self):
|
| 349 |
+
"""
|
| 350 |
+
Tie the weights between the input embeddings and the output embeddings.
|
| 351 |
+
If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning
|
| 352 |
+
the weights instead.
|
| 353 |
+
"""
|
| 354 |
+
output_embeddings = self.get_output_embeddings()
|
| 355 |
+
if output_embeddings is not None:
|
| 356 |
+
self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
|
| 357 |
+
|
| 358 |
+
def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
|
| 359 |
+
""" Tie or clone module weights depending of whether we are using TorchScript or not
|
| 360 |
+
"""
|
| 361 |
+
if self.config.torchscript:
|
| 362 |
+
output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone())
|
| 363 |
+
else:
|
| 364 |
+
output_embeddings.weight = input_embeddings.weight
|
| 365 |
+
|
| 366 |
+
if getattr(output_embeddings, "bias", None) is not None:
|
| 367 |
+
output_embeddings.bias.data = torch.nn.functional.pad(
|
| 368 |
+
output_embeddings.bias.data,
|
| 369 |
+
(0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0],),
|
| 370 |
+
"constant",
|
| 371 |
+
0,
|
| 372 |
+
)
|
| 373 |
+
if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
|
| 374 |
+
output_embeddings.out_features = input_embeddings.num_embeddings
|
| 375 |
+
|
| 376 |
+
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None):
|
| 377 |
+
""" Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
|
| 378 |
+
Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
|
| 379 |
+
|
| 380 |
+
Arguments:
|
| 381 |
+
|
| 382 |
+
new_num_tokens: (`optional`) int:
|
| 383 |
+
New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
|
| 384 |
+
If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embeddings`` Module of the model.
|
| 385 |
+
|
| 386 |
+
Return: ``torch.nn.Embeddings``
|
| 387 |
+
Pointer to the input tokens Embeddings Module of the model
|
| 388 |
+
"""
|
| 389 |
+
base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
|
| 390 |
+
model_embeds = base_model._resize_token_embeddings(new_num_tokens)
|
| 391 |
+
if new_num_tokens is None:
|
| 392 |
+
return model_embeds
|
| 393 |
+
|
| 394 |
+
# Update base model and current model config
|
| 395 |
+
self.config.vocab_size = new_num_tokens
|
| 396 |
+
base_model.vocab_size = new_num_tokens
|
| 397 |
+
|
| 398 |
+
# Tie weights again if needed
|
| 399 |
+
self.tie_weights()
|
| 400 |
+
|
| 401 |
+
return model_embeds
|
| 402 |
+
|
| 403 |
+
def _resize_token_embeddings(self, new_num_tokens):
|
| 404 |
+
old_embeddings = self.get_input_embeddings()
|
| 405 |
+
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
|
| 406 |
+
self.set_input_embeddings(new_embeddings)
|
| 407 |
+
return self.get_input_embeddings()
|
| 408 |
+
|
| 409 |
+
def _get_resized_embeddings(
|
| 410 |
+
self, old_embeddings: torch.nn.Embedding, new_num_tokens: Optional[int] = None
|
| 411 |
+
) -> torch.nn.Embedding:
|
| 412 |
+
""" Build a resized Embedding Module from a provided token Embedding Module.
|
| 413 |
+
Increasing the size will add newly initialized vectors at the end
|
| 414 |
+
Reducing the size will remove vectors from the end
|
| 415 |
+
|
| 416 |
+
Args:
|
| 417 |
+
old_embeddings: ``torch.nn.Embedding``
|
| 418 |
+
Old embeddings to be resized.
|
| 419 |
+
new_num_tokens: (`optional`) int
|
| 420 |
+
New number of tokens in the embedding matrix.
|
| 421 |
+
Increasing the size will add newly initialized vectors at the end
|
| 422 |
+
Reducing the size will remove vectors from the end
|
| 423 |
+
If not provided or None: return the provided token Embedding Module.
|
| 424 |
+
Return: ``torch.nn.Embedding``
|
| 425 |
+
Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
|
| 426 |
+
"""
|
| 427 |
+
if new_num_tokens is None:
|
| 428 |
+
return old_embeddings
|
| 429 |
+
|
| 430 |
+
old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
|
| 431 |
+
if old_num_tokens == new_num_tokens:
|
| 432 |
+
return old_embeddings
|
| 433 |
+
|
| 434 |
+
# Build new embeddings
|
| 435 |
+
new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
|
| 436 |
+
new_embeddings.to(old_embeddings.weight.device)
|
| 437 |
+
|
| 438 |
+
# initialize all new embeddings (in particular added tokens)
|
| 439 |
+
self._init_weights(new_embeddings)
|
| 440 |
+
|
| 441 |
+
# Copy token embeddings from the previous weights
|
| 442 |
+
num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
|
| 443 |
+
new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
|
| 444 |
+
|
| 445 |
+
return new_embeddings
|
| 446 |
+
|
| 447 |
+
def init_weights(self):
|
| 448 |
+
""" Initialize and prunes weights if needed. """
|
| 449 |
+
# Initialize weights
|
| 450 |
+
self.apply(self._init_weights)
|
| 451 |
+
|
| 452 |
+
# Prune heads if needed
|
| 453 |
+
if self.config.pruned_heads:
|
| 454 |
+
self.prune_heads(self.config.pruned_heads)
|
| 455 |
+
|
| 456 |
+
# Tie weights if needed
|
| 457 |
+
self.tie_weights()
|
| 458 |
+
|
| 459 |
+
def prune_heads(self, heads_to_prune: Dict):
|
| 460 |
+
""" Prunes heads of the base model.
|
| 461 |
+
|
| 462 |
+
Arguments:
|
| 463 |
+
|
| 464 |
+
heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
|
| 465 |
+
E.g. {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
|
| 466 |
+
"""
|
| 467 |
+
# save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
|
| 468 |
+
for layer, heads in heads_to_prune.items():
|
| 469 |
+
union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads)
|
| 470 |
+
self.config.pruned_heads[layer] = list(union_heads) # Unfortunately we have to store it as list for JSON
|
| 471 |
+
|
| 472 |
+
self.base_model._prune_heads(heads_to_prune)
|
| 473 |
+
|
| 474 |
+
def save_pretrained(self, save_directory):
|
| 475 |
+
""" Save a model and its configuration file to a directory, so that it
|
| 476 |
+
can be re-loaded using the `:func:`~transformers.PreTrainedModel.from_pretrained`` class method.
|
| 477 |
+
|
| 478 |
+
Arguments:
|
| 479 |
+
save_directory: directory to which to save.
|
| 480 |
+
"""
|
| 481 |
+
if os.path.isfile(save_directory):
|
| 482 |
+
logger.error("Provided path ({}) should be a directory, not a file".format(save_directory))
|
| 483 |
+
return
|
| 484 |
+
os.makedirs(save_directory, exist_ok=True)
|
| 485 |
+
|
| 486 |
+
# Only save the model itself if we are using distributed training
|
| 487 |
+
model_to_save = self.module if hasattr(self, "module") else self
|
| 488 |
+
|
| 489 |
+
# Attach architecture to the config
|
| 490 |
+
model_to_save.config.architectures = [model_to_save.__class__.__name__]
|
| 491 |
+
|
| 492 |
+
# If we save using the predefined names, we can load using `from_pretrained`
|
| 493 |
+
output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
|
| 494 |
+
|
| 495 |
+
if getattr(self.config, "xla_device", False):
|
| 496 |
+
import torch_xla.core.xla_model as xm
|
| 497 |
+
|
| 498 |
+
if xm.is_master_ordinal():
|
| 499 |
+
# Save configuration file
|
| 500 |
+
model_to_save.config.save_pretrained(save_directory)
|
| 501 |
+
# xm.save takes care of saving only from master
|
| 502 |
+
xm.save(model_to_save.state_dict(), output_model_file)
|
| 503 |
+
else:
|
| 504 |
+
model_to_save.config.save_pretrained(save_directory)
|
| 505 |
+
torch.save(model_to_save.state_dict(), output_model_file)
|
| 506 |
+
|
| 507 |
+
logger.info("Model weights saved in {}".format(output_model_file))
|
| 508 |
+
|
| 509 |
+
@classmethod
|
| 510 |
+
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
| 511 |
+
r"""Instantiate a pretrained pytorch model from a pre-trained model configuration.
|
| 512 |
+
|
| 513 |
+
The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated)
|
| 514 |
+
To train the model, you should first set it back in training mode with ``model.train()``
|
| 515 |
+
|
| 516 |
+
The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come pre-trained with the rest of the model.
|
| 517 |
+
It is up to you to train those weights with a downstream fine-tuning task.
|
| 518 |
+
|
| 519 |
+
The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those weights are discarded.
|
| 520 |
+
|
| 521 |
+
Parameters:
|
| 522 |
+
pretrained_model_name_or_path: either:
|
| 523 |
+
- a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
|
| 524 |
+
- a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
|
| 525 |
+
- a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
|
| 526 |
+
- a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
|
| 527 |
+
- None if you are both providing the configuration and state dictionary (resp. with keyword arguments ``config`` and ``state_dict``)
|
| 528 |
+
|
| 529 |
+
model_args: (`optional`) Sequence of positional arguments:
|
| 530 |
+
All remaning positional arguments will be passed to the underlying model's ``__init__`` method
|
| 531 |
+
|
| 532 |
+
config: (`optional`) one of:
|
| 533 |
+
- an instance of a class derived from :class:`~transformers.PretrainedConfig`, or
|
| 534 |
+
- a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained()`
|
| 535 |
+
|
| 536 |
+
Configuration for the model to use instead of an automatically loaded configuation. Configuration can be automatically loaded when:
|
| 537 |
+
- the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
|
| 538 |
+
- the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by suppling the save directory.
|
| 539 |
+
- the model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.
|
| 540 |
+
|
| 541 |
+
state_dict: (`optional`) dict:
|
| 542 |
+
an optional state dictionnary for the model to use instead of a state dictionary loaded from saved weights file.
|
| 543 |
+
This option can be used if you want to create a model from a pretrained configuration but load your own weights.
|
| 544 |
+
In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
|
| 545 |
+
|
| 546 |
+
cache_dir: (`optional`) string:
|
| 547 |
+
Path to a directory in which a downloaded pre-trained model
|
| 548 |
+
configuration should be cached if the standard cache should not be used.
|
| 549 |
+
|
| 550 |
+
force_download: (`optional`) boolean, default False:
|
| 551 |
+
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
|
| 552 |
+
|
| 553 |
+
resume_download: (`optional`) boolean, default False:
|
| 554 |
+
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
|
| 555 |
+
|
| 556 |
+
proxies: (`optional`) dict, default None:
|
| 557 |
+
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
|
| 558 |
+
The proxies are used on each request.
|
| 559 |
+
|
| 560 |
+
output_loading_info: (`optional`) boolean:
|
| 561 |
+
Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages.
|
| 562 |
+
|
| 563 |
+
kwargs: (`optional`) Remaining dictionary of keyword arguments:
|
| 564 |
+
Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded:
|
| 565 |
+
|
| 566 |
+
- If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
|
| 567 |
+
- If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.
|
| 568 |
+
|
| 569 |
+
Examples::
|
| 570 |
+
|
| 571 |
+
# For example purposes. Not runnable.
|
| 572 |
+
model = BertModel.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache.
|
| 573 |
+
model = BertModel.from_pretrained('./test/saved_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
|
| 574 |
+
model = BertModel.from_pretrained('bert-base-uncased', output_attention=True) # Update configuration during loading
|
| 575 |
+
assert model.config.output_attention == True
|
| 576 |
+
# Loading from a TF checkpoint file instead of a PyTorch model (slower)
|
| 577 |
+
config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
|
| 578 |
+
model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
|
| 579 |
+
|
| 580 |
+
"""
|
| 581 |
+
config = kwargs.pop("config", None)
|
| 582 |
+
state_dict = kwargs.pop("state_dict", None)
|
| 583 |
+
cache_dir = kwargs.pop("cache_dir", None)
|
| 584 |
+
from_tf = kwargs.pop("from_tf", False)
|
| 585 |
+
force_download = kwargs.pop("force_download", False)
|
| 586 |
+
resume_download = kwargs.pop("resume_download", False)
|
| 587 |
+
proxies = kwargs.pop("proxies", None)
|
| 588 |
+
output_loading_info = kwargs.pop("output_loading_info", False)
|
| 589 |
+
local_files_only = kwargs.pop("local_files_only", False)
|
| 590 |
+
use_cdn = kwargs.pop("use_cdn", True)
|
| 591 |
+
|
| 592 |
+
# Load config if we don't provide a configuration
|
| 593 |
+
if not isinstance(config, PretrainedConfig):
|
| 594 |
+
config_path = config if config is not None else pretrained_model_name_or_path
|
| 595 |
+
config, model_kwargs = cls.config_class.from_pretrained(
|
| 596 |
+
config_path,
|
| 597 |
+
*model_args,
|
| 598 |
+
cache_dir=cache_dir,
|
| 599 |
+
return_unused_kwargs=True,
|
| 600 |
+
force_download=force_download,
|
| 601 |
+
resume_download=resume_download,
|
| 602 |
+
proxies=proxies,
|
| 603 |
+
local_files_only=local_files_only,
|
| 604 |
+
**kwargs,
|
| 605 |
+
)
|
| 606 |
+
else:
|
| 607 |
+
model_kwargs = kwargs
|
| 608 |
+
|
| 609 |
+
# Load model
|
| 610 |
+
if pretrained_model_name_or_path is not None:
|
| 611 |
+
if os.path.isdir(pretrained_model_name_or_path):
|
| 612 |
+
if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
|
| 613 |
+
# Load from a TF 1.0 checkpoint
|
| 614 |
+
archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
|
| 615 |
+
elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
|
| 616 |
+
# Load from a TF 2.0 checkpoint
|
| 617 |
+
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
|
| 618 |
+
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
|
| 619 |
+
# Load from a PyTorch checkpoint
|
| 620 |
+
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
| 621 |
+
else:
|
| 622 |
+
raise EnvironmentError(
|
| 623 |
+
"Error no file named {} found in directory {} or `from_tf` set to False".format(
|
| 624 |
+
[WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"],
|
| 625 |
+
pretrained_model_name_or_path,
|
| 626 |
+
)
|
| 627 |
+
)
|
| 628 |
+
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
| 629 |
+
archive_file = pretrained_model_name_or_path
|
| 630 |
+
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
|
| 631 |
+
assert (
|
| 632 |
+
from_tf
|
| 633 |
+
), "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format(
|
| 634 |
+
pretrained_model_name_or_path + ".index"
|
| 635 |
+
)
|
| 636 |
+
archive_file = pretrained_model_name_or_path + ".index"
|
| 637 |
+
else:
|
| 638 |
+
archive_file = hf_bucket_url(
|
| 639 |
+
pretrained_model_name_or_path,
|
| 640 |
+
filename=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME),
|
| 641 |
+
use_cdn=use_cdn,
|
| 642 |
+
)
|
| 643 |
+
|
| 644 |
+
try:
|
| 645 |
+
# Load from URL or cache if already cached
|
| 646 |
+
resolved_archive_file = cached_path(
|
| 647 |
+
archive_file,
|
| 648 |
+
cache_dir=cache_dir,
|
| 649 |
+
force_download=force_download,
|
| 650 |
+
proxies=proxies,
|
| 651 |
+
resume_download=resume_download,
|
| 652 |
+
local_files_only=local_files_only,
|
| 653 |
+
)
|
| 654 |
+
if resolved_archive_file is None:
|
| 655 |
+
raise EnvironmentError
|
| 656 |
+
except EnvironmentError:
|
| 657 |
+
msg = (
|
| 658 |
+
f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
|
| 659 |
+
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
|
| 660 |
+
f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME}.\n\n"
|
| 661 |
+
)
|
| 662 |
+
raise EnvironmentError(msg)
|
| 663 |
+
|
| 664 |
+
if resolved_archive_file == archive_file:
|
| 665 |
+
logger.info("loading weights file {}".format(archive_file))
|
| 666 |
+
else:
|
| 667 |
+
logger.info("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file))
|
| 668 |
+
else:
|
| 669 |
+
resolved_archive_file = None
|
| 670 |
+
|
| 671 |
+
# Instantiate model.
|
| 672 |
+
model = cls(config, *model_args, **model_kwargs)
|
| 673 |
+
|
| 674 |
+
if state_dict is None and not from_tf:
|
| 675 |
+
try:
|
| 676 |
+
state_dict = torch.load(resolved_archive_file, map_location="cpu")
|
| 677 |
+
except Exception:
|
| 678 |
+
raise OSError(
|
| 679 |
+
"Unable to load weights from pytorch checkpoint file. "
|
| 680 |
+
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
|
| 681 |
+
)
|
| 682 |
+
|
| 683 |
+
missing_keys = []
|
| 684 |
+
unexpected_keys = []
|
| 685 |
+
error_msgs = []
|
| 686 |
+
|
| 687 |
+
if from_tf:
|
| 688 |
+
if resolved_archive_file.endswith(".index"):
|
| 689 |
+
# Load from a TensorFlow 1.X checkpoint - provided by original authors
|
| 690 |
+
model = cls.load_tf_weights(model, config, resolved_archive_file[:-6]) # Remove the '.index'
|
| 691 |
+
else:
|
| 692 |
+
# Load from our TensorFlow 2.0 checkpoints
|
| 693 |
+
try:
|
| 694 |
+
from transformers import load_tf2_checkpoint_in_pytorch_model
|
| 695 |
+
|
| 696 |
+
model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True)
|
| 697 |
+
except ImportError:
|
| 698 |
+
logger.error(
|
| 699 |
+
"Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
|
| 700 |
+
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
|
| 701 |
+
)
|
| 702 |
+
raise
|
| 703 |
+
else:
|
| 704 |
+
# Convert old format to new format if needed from a PyTorch state_dict
|
| 705 |
+
old_keys = []
|
| 706 |
+
new_keys = []
|
| 707 |
+
for key in state_dict.keys():
|
| 708 |
+
new_key = None
|
| 709 |
+
if "gamma" in key:
|
| 710 |
+
new_key = key.replace("gamma", "weight")
|
| 711 |
+
if "beta" in key:
|
| 712 |
+
new_key = key.replace("beta", "bias")
|
| 713 |
+
if new_key:
|
| 714 |
+
old_keys.append(key)
|
| 715 |
+
new_keys.append(new_key)
|
| 716 |
+
for old_key, new_key in zip(old_keys, new_keys):
|
| 717 |
+
state_dict[new_key] = state_dict.pop(old_key)
|
| 718 |
+
|
| 719 |
+
# copy state_dict so _load_from_state_dict can modify it
|
| 720 |
+
metadata = getattr(state_dict, "_metadata", None)
|
| 721 |
+
state_dict = state_dict.copy()
|
| 722 |
+
if metadata is not None:
|
| 723 |
+
state_dict._metadata = metadata
|
| 724 |
+
|
| 725 |
+
##############################################################################################
|
| 726 |
+
# Print out state_dict's contents: keys
|
| 727 |
+
'''
|
| 728 |
+
for key, _ in state_dict.items():
|
| 729 |
+
print(key)
|
| 730 |
+
'''
|
| 731 |
+
##############################################################################################
|
| 732 |
+
|
| 733 |
+
|
| 734 |
+
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
|
| 735 |
+
# so we need to apply the function recursively.
|
| 736 |
+
def load(module: nn.Module, prefix=""):
|
| 737 |
+
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
| 738 |
+
module._load_from_state_dict(
|
| 739 |
+
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs,
|
| 740 |
+
)
|
| 741 |
+
for name, child in module._modules.items():
|
| 742 |
+
if child is not None:
|
| 743 |
+
load(child, prefix + name + ".")
|
| 744 |
+
|
| 745 |
+
# Make sure we are able to load base models as well as derived models (with heads)
|
| 746 |
+
start_prefix = ""
|
| 747 |
+
model_to_load = model
|
| 748 |
+
has_prefix_module = any(s.startswith(cls.base_model_prefix) for s in state_dict.keys())
|
| 749 |
+
if not hasattr(model, cls.base_model_prefix) and has_prefix_module:
|
| 750 |
+
start_prefix = cls.base_model_prefix + "."
|
| 751 |
+
if hasattr(model, cls.base_model_prefix) and not has_prefix_module:
|
| 752 |
+
model_to_load = getattr(model, cls.base_model_prefix)
|
| 753 |
+
|
| 754 |
+
load(model_to_load, prefix=start_prefix)
|
| 755 |
+
|
| 756 |
+
if model.__class__.__name__ != model_to_load.__class__.__name__:
|
| 757 |
+
base_model_state_dict = model_to_load.state_dict().keys()
|
| 758 |
+
head_model_state_dict_without_base_prefix = [
|
| 759 |
+
key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys()
|
| 760 |
+
]
|
| 761 |
+
|
| 762 |
+
missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)
|
| 763 |
+
|
| 764 |
+
if len(unexpected_keys) > 0:
|
| 765 |
+
logger.warning(
|
| 766 |
+
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
|
| 767 |
+
f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
|
| 768 |
+
f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
|
| 769 |
+
f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n"
|
| 770 |
+
f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
|
| 771 |
+
f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
|
| 772 |
+
)
|
| 773 |
+
else:
|
| 774 |
+
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
|
| 775 |
+
if len(missing_keys) > 0:
|
| 776 |
+
logger.warning(
|
| 777 |
+
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
|
| 778 |
+
f"and are newly initialized: {missing_keys}\n"
|
| 779 |
+
f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
| 780 |
+
)
|
| 781 |
+
else:
|
| 782 |
+
logger.info(
|
| 783 |
+
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
|
| 784 |
+
f"If your task is similar to the task the model of the ckeckpoint was trained on, "
|
| 785 |
+
f"you can already use {model.__class__.__name__} for predictions without further training."
|
| 786 |
+
)
|
| 787 |
+
if len(error_msgs) > 0:
|
| 788 |
+
raise RuntimeError(
|
| 789 |
+
"Error(s) in loading state_dict for {}:\n\t{}".format(
|
| 790 |
+
model.__class__.__name__, "\n\t".join(error_msgs)
|
| 791 |
+
)
|
| 792 |
+
)
|
| 793 |
+
model.tie_weights() # make sure token embedding weights are still tied if needed
|
| 794 |
+
|
| 795 |
+
# Set model in evaluation mode to deactivate DropOut modules by default
|
| 796 |
+
model.eval()
|
| 797 |
+
|
| 798 |
+
if output_loading_info:
|
| 799 |
+
loading_info = {
|
| 800 |
+
"missing_keys": missing_keys,
|
| 801 |
+
"unexpected_keys": unexpected_keys,
|
| 802 |
+
"error_msgs": error_msgs,
|
| 803 |
+
}
|
| 804 |
+
return model, loading_info
|
| 805 |
+
|
| 806 |
+
if hasattr(config, "xla_device") and config.xla_device:
|
| 807 |
+
import torch_xla.core.xla_model as xm
|
| 808 |
+
|
| 809 |
+
model = xm.send_cpu_data_to_device(model, xm.xla_device())
|
| 810 |
+
model.to(xm.xla_device())
|
| 811 |
+
|
| 812 |
+
return model
|
| 813 |
+
|
| 814 |
+
|
| 815 |
+
class Conv1D(nn.Module):
|
| 816 |
+
def __init__(self, nf, nx):
|
| 817 |
+
""" Conv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2)
|
| 818 |
+
Basically works like a Linear layer but the weights are transposed
|
| 819 |
+
"""
|
| 820 |
+
super().__init__()
|
| 821 |
+
self.nf = nf
|
| 822 |
+
w = torch.empty(nx, nf)
|
| 823 |
+
nn.init.normal_(w, std=0.02)
|
| 824 |
+
self.weight = nn.Parameter(w)
|
| 825 |
+
self.bias = nn.Parameter(torch.zeros(nf))
|
| 826 |
+
|
| 827 |
+
def forward(self, x):
|
| 828 |
+
size_out = x.size()[:-1] + (self.nf,)
|
| 829 |
+
x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
|
| 830 |
+
x = x.view(*size_out)
|
| 831 |
+
return x
|
| 832 |
+
|
| 833 |
+
|
| 834 |
+
class PoolerStartLogits(nn.Module):
|
| 835 |
+
""" Compute SQuAD start_logits from sequence hidden states. """
|
| 836 |
+
|
| 837 |
+
def __init__(self, config):
|
| 838 |
+
super().__init__()
|
| 839 |
+
self.dense = nn.Linear(config.hidden_size, 1)
|
| 840 |
+
|
| 841 |
+
def forward(self, hidden_states, p_mask=None):
|
| 842 |
+
""" Args:
|
| 843 |
+
**p_mask**: (`optional`) ``torch.FloatTensor`` of shape `(batch_size, seq_len)`
|
| 844 |
+
invalid position mask such as query and special symbols (PAD, SEP, CLS)
|
| 845 |
+
1.0 means token should be masked.
|
| 846 |
+
"""
|
| 847 |
+
x = self.dense(hidden_states).squeeze(-1)
|
| 848 |
+
|
| 849 |
+
if p_mask is not None:
|
| 850 |
+
if next(self.parameters()).dtype == torch.float16:
|
| 851 |
+
x = x * (1 - p_mask) - 65500 * p_mask
|
| 852 |
+
else:
|
| 853 |
+
x = x * (1 - p_mask) - 1e30 * p_mask
|
| 854 |
+
|
| 855 |
+
return x
|
| 856 |
+
|
| 857 |
+
|
| 858 |
+
class PoolerEndLogits(nn.Module):
|
| 859 |
+
""" Compute SQuAD end_logits from sequence hidden states and start token hidden state.
|
| 860 |
+
"""
|
| 861 |
+
|
| 862 |
+
def __init__(self, config):
|
| 863 |
+
super().__init__()
|
| 864 |
+
self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
|
| 865 |
+
self.activation = nn.Tanh()
|
| 866 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 867 |
+
self.dense_1 = nn.Linear(config.hidden_size, 1)
|
| 868 |
+
|
| 869 |
+
def forward(self, hidden_states, start_states=None, start_positions=None, p_mask=None):
|
| 870 |
+
""" Args:
|
| 871 |
+
One of ``start_states``, ``start_positions`` should be not None.
|
| 872 |
+
If both are set, ``start_positions`` overrides ``start_states``.
|
| 873 |
+
|
| 874 |
+
**start_states**: ``torch.LongTensor`` of shape identical to hidden_states
|
| 875 |
+
hidden states of the first tokens for the labeled span.
|
| 876 |
+
**start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
|
| 877 |
+
position of the first token for the labeled span:
|
| 878 |
+
**p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
|
| 879 |
+
Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
|
| 880 |
+
1.0 means token should be masked.
|
| 881 |
+
"""
|
| 882 |
+
assert (
|
| 883 |
+
start_states is not None or start_positions is not None
|
| 884 |
+
), "One of start_states, start_positions should be not None"
|
| 885 |
+
if start_positions is not None:
|
| 886 |
+
slen, hsz = hidden_states.shape[-2:]
|
| 887 |
+
start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
|
| 888 |
+
start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz)
|
| 889 |
+
start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz)
|
| 890 |
+
|
| 891 |
+
x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
|
| 892 |
+
x = self.activation(x)
|
| 893 |
+
x = self.LayerNorm(x)
|
| 894 |
+
x = self.dense_1(x).squeeze(-1)
|
| 895 |
+
|
| 896 |
+
if p_mask is not None:
|
| 897 |
+
if next(self.parameters()).dtype == torch.float16:
|
| 898 |
+
x = x * (1 - p_mask) - 65500 * p_mask
|
| 899 |
+
else:
|
| 900 |
+
x = x * (1 - p_mask) - 1e30 * p_mask
|
| 901 |
+
|
| 902 |
+
return x
|
| 903 |
+
|
| 904 |
+
|
| 905 |
+
class PoolerAnswerClass(nn.Module):
|
| 906 |
+
""" Compute SQuAD 2.0 answer class from classification and start tokens hidden states. """
|
| 907 |
+
|
| 908 |
+
def __init__(self, config):
|
| 909 |
+
super().__init__()
|
| 910 |
+
self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
|
| 911 |
+
self.activation = nn.Tanh()
|
| 912 |
+
self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)
|
| 913 |
+
|
| 914 |
+
def forward(self, hidden_states, start_states=None, start_positions=None, cls_index=None):
|
| 915 |
+
"""
|
| 916 |
+
Args:
|
| 917 |
+
One of ``start_states``, ``start_positions`` should be not None.
|
| 918 |
+
If both are set, ``start_positions`` overrides ``start_states``.
|
| 919 |
+
|
| 920 |
+
**start_states**: ``torch.LongTensor`` of shape identical to ``hidden_states``.
|
| 921 |
+
hidden states of the first tokens for the labeled span.
|
| 922 |
+
**start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
|
| 923 |
+
position of the first token for the labeled span.
|
| 924 |
+
**cls_index**: torch.LongTensor of shape ``(batch_size,)``
|
| 925 |
+
position of the CLS token. If None, take the last token.
|
| 926 |
+
|
| 927 |
+
note(Original repo):
|
| 928 |
+
no dependency on end_feature so that we can obtain one single `cls_logits`
|
| 929 |
+
for each sample
|
| 930 |
+
"""
|
| 931 |
+
hsz = hidden_states.shape[-1]
|
| 932 |
+
assert (
|
| 933 |
+
start_states is not None or start_positions is not None
|
| 934 |
+
), "One of start_states, start_positions should be not None"
|
| 935 |
+
if start_positions is not None:
|
| 936 |
+
start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
|
| 937 |
+
start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz)
|
| 938 |
+
|
| 939 |
+
if cls_index is not None:
|
| 940 |
+
cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
|
| 941 |
+
cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz)
|
| 942 |
+
else:
|
| 943 |
+
cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz)
|
| 944 |
+
|
| 945 |
+
x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1))
|
| 946 |
+
x = self.activation(x)
|
| 947 |
+
x = self.dense_1(x).squeeze(-1)
|
| 948 |
+
|
| 949 |
+
return x
|
| 950 |
+
|
| 951 |
+
|
| 952 |
+
class SQuADHead(nn.Module):
|
| 953 |
+
r""" A SQuAD head inspired by XLNet.
|
| 954 |
+
|
| 955 |
+
Parameters:
|
| 956 |
+
config (:class:`~transformers.XLNetConfig`): Model configuration class with all the parameters of the model.
|
| 957 |
+
|
| 958 |
+
Inputs:
|
| 959 |
+
**hidden_states**: ``torch.FloatTensor`` of shape ``(batch_size, seq_len, hidden_size)``
|
| 960 |
+
hidden states of sequence tokens
|
| 961 |
+
**start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
|
| 962 |
+
position of the first token for the labeled span.
|
| 963 |
+
**end_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
|
| 964 |
+
position of the last token for the labeled span.
|
| 965 |
+
**cls_index**: torch.LongTensor of shape ``(batch_size,)``
|
| 966 |
+
position of the CLS token. If None, take the last token.
|
| 967 |
+
**is_impossible**: ``torch.LongTensor`` of shape ``(batch_size,)``
|
| 968 |
+
Whether the question has a possible answer in the paragraph or not.
|
| 969 |
+
**p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
|
| 970 |
+
Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
|
| 971 |
+
1.0 means token should be masked.
|
| 972 |
+
|
| 973 |
+
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
| 974 |
+
**loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
| 975 |
+
Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
|
| 976 |
+
**start_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
| 977 |
+
``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``
|
| 978 |
+
Log probabilities for the top config.start_n_top start token possibilities (beam-search).
|
| 979 |
+
**start_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
| 980 |
+
``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``
|
| 981 |
+
Indices for the top config.start_n_top start token possibilities (beam-search).
|
| 982 |
+
**end_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
| 983 |
+
``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
|
| 984 |
+
Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
|
| 985 |
+
**end_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
| 986 |
+
``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
|
| 987 |
+
Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
|
| 988 |
+
**cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
| 989 |
+
``torch.FloatTensor`` of shape ``(batch_size,)``
|
| 990 |
+
Log probabilities for the ``is_impossible`` label of the answers.
|
| 991 |
+
"""
|
| 992 |
+
|
| 993 |
+
def __init__(self, config):
|
| 994 |
+
super().__init__()
|
| 995 |
+
self.start_n_top = config.start_n_top
|
| 996 |
+
self.end_n_top = config.end_n_top
|
| 997 |
+
|
| 998 |
+
self.start_logits = PoolerStartLogits(config)
|
| 999 |
+
self.end_logits = PoolerEndLogits(config)
|
| 1000 |
+
self.answer_class = PoolerAnswerClass(config)
|
| 1001 |
+
|
| 1002 |
+
def forward(
|
| 1003 |
+
self, hidden_states, start_positions=None, end_positions=None, cls_index=None, is_impossible=None, p_mask=None,
|
| 1004 |
+
):
|
| 1005 |
+
outputs = ()
|
| 1006 |
+
|
| 1007 |
+
start_logits = self.start_logits(hidden_states, p_mask=p_mask)
|
| 1008 |
+
|
| 1009 |
+
if start_positions is not None and end_positions is not None:
|
| 1010 |
+
# If we are on multi-GPU, let's remove the dimension added by batch splitting
|
| 1011 |
+
for x in (start_positions, end_positions, cls_index, is_impossible):
|
| 1012 |
+
if x is not None and x.dim() > 1:
|
| 1013 |
+
x.squeeze_(-1)
|
| 1014 |
+
|
| 1015 |
+
# during training, compute the end logits based on the ground truth of the start position
|
| 1016 |
+
end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)
|
| 1017 |
+
|
| 1018 |
+
loss_fct = CrossEntropyLoss()
|
| 1019 |
+
start_loss = loss_fct(start_logits, start_positions)
|
| 1020 |
+
end_loss = loss_fct(end_logits, end_positions)
|
| 1021 |
+
total_loss = (start_loss + end_loss) / 2
|
| 1022 |
+
|
| 1023 |
+
if cls_index is not None and is_impossible is not None:
|
| 1024 |
+
# Predict answerability from the representation of CLS and START
|
| 1025 |
+
cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
|
| 1026 |
+
loss_fct_cls = nn.BCEWithLogitsLoss()
|
| 1027 |
+
cls_loss = loss_fct_cls(cls_logits, is_impossible)
|
| 1028 |
+
|
| 1029 |
+
# note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
|
| 1030 |
+
total_loss += cls_loss * 0.5
|
| 1031 |
+
|
| 1032 |
+
outputs = (total_loss,) + outputs
|
| 1033 |
+
|
| 1034 |
+
else:
|
| 1035 |
+
# during inference, compute the end logits based on beam search
|
| 1036 |
+
bsz, slen, hsz = hidden_states.size()
|
| 1037 |
+
start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen)
|
| 1038 |
+
|
| 1039 |
+
start_top_log_probs, start_top_index = torch.topk(
|
| 1040 |
+
start_log_probs, self.start_n_top, dim=-1
|
| 1041 |
+
) # shape (bsz, start_n_top)
|
| 1042 |
+
start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
|
| 1043 |
+
start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz)
|
| 1044 |
+
start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz)
|
| 1045 |
+
|
| 1046 |
+
hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(
|
| 1047 |
+
start_states
|
| 1048 |
+
) # shape (bsz, slen, start_n_top, hsz)
|
| 1049 |
+
p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
|
| 1050 |
+
end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
|
| 1051 |
+
end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top)
|
| 1052 |
+
|
| 1053 |
+
end_top_log_probs, end_top_index = torch.topk(
|
| 1054 |
+
end_log_probs, self.end_n_top, dim=1
|
| 1055 |
+
) # shape (bsz, end_n_top, start_n_top)
|
| 1056 |
+
end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
|
| 1057 |
+
end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)
|
| 1058 |
+
|
| 1059 |
+
start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
|
| 1060 |
+
cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)
|
| 1061 |
+
|
| 1062 |
+
outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits,) + outputs
|
| 1063 |
+
|
| 1064 |
+
# return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
|
| 1065 |
+
# or (if labels are provided) (total_loss,)
|
| 1066 |
+
return outputs
|
| 1067 |
+
|
| 1068 |
+
|
| 1069 |
+
class SequenceSummary(nn.Module):
|
| 1070 |
+
r""" Compute a single vector summary of a sequence hidden states according to various possibilities:
|
| 1071 |
+
Args of the config class:
|
| 1072 |
+
summary_type:
|
| 1073 |
+
- 'last' => [default] take the last token hidden state (like XLNet)
|
| 1074 |
+
- 'first' => take the first token hidden state (like Bert)
|
| 1075 |
+
- 'mean' => take the mean of all tokens hidden states
|
| 1076 |
+
- 'cls_index' => supply a Tensor of classification token position (GPT/GPT-2)
|
| 1077 |
+
- 'attn' => Not implemented now, use multi-head attention
|
| 1078 |
+
summary_use_proj: Add a projection after the vector extraction
|
| 1079 |
+
summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
|
| 1080 |
+
summary_activation: 'tanh' or another string => add an activation to the output, Other => no activation. Default
|
| 1081 |
+
summary_first_dropout: Add a dropout before the projection and activation
|
| 1082 |
+
summary_last_dropout: Add a dropout after the projection and activation
|
| 1083 |
+
"""
|
| 1084 |
+
|
| 1085 |
+
def __init__(self, config: PretrainedConfig):
|
| 1086 |
+
super().__init__()
|
| 1087 |
+
|
| 1088 |
+
self.summary_type = getattr(config, "summary_type", "last")
|
| 1089 |
+
if self.summary_type == "attn":
|
| 1090 |
+
# We should use a standard multi-head attention module with absolute positional embedding for that.
|
| 1091 |
+
# Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
|
| 1092 |
+
# We can probably just use the multi-head attention module of PyTorch >=1.1.0
|
| 1093 |
+
raise NotImplementedError
|
| 1094 |
+
|
| 1095 |
+
self.summary = Identity()
|
| 1096 |
+
if hasattr(config, "summary_use_proj") and config.summary_use_proj:
|
| 1097 |
+
if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
|
| 1098 |
+
num_classes = config.num_labels
|
| 1099 |
+
else:
|
| 1100 |
+
num_classes = config.hidden_size
|
| 1101 |
+
self.summary = nn.Linear(config.hidden_size, num_classes)
|
| 1102 |
+
|
| 1103 |
+
activation_string = getattr(config, "summary_activation", None)
|
| 1104 |
+
self.activation: Callable = (get_activation(activation_string) if activation_string else Identity())
|
| 1105 |
+
|
| 1106 |
+
self.first_dropout = Identity()
|
| 1107 |
+
if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
|
| 1108 |
+
self.first_dropout = nn.Dropout(config.summary_first_dropout)
|
| 1109 |
+
|
| 1110 |
+
self.last_dropout = Identity()
|
| 1111 |
+
if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
|
| 1112 |
+
self.last_dropout = nn.Dropout(config.summary_last_dropout)
|
| 1113 |
+
|
| 1114 |
+
def forward(self, hidden_states, cls_index=None):
|
| 1115 |
+
""" hidden_states: float Tensor in shape [bsz, ..., seq_len, hidden_size], the hidden-states of the last layer.
|
| 1116 |
+
cls_index: [optional] position of the classification token if summary_type == 'cls_index',
|
| 1117 |
+
shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states.
|
| 1118 |
+
if summary_type == 'cls_index' and cls_index is None:
|
| 1119 |
+
we take the last token of the sequence as classification token
|
| 1120 |
+
"""
|
| 1121 |
+
if self.summary_type == "last":
|
| 1122 |
+
output = hidden_states[:, -1]
|
| 1123 |
+
elif self.summary_type == "first":
|
| 1124 |
+
output = hidden_states[:, 0]
|
| 1125 |
+
elif self.summary_type == "mean":
|
| 1126 |
+
output = hidden_states.mean(dim=1)
|
| 1127 |
+
elif self.summary_type == "cls_index":
|
| 1128 |
+
if cls_index is None:
|
| 1129 |
+
cls_index = torch.full_like(hidden_states[..., :1, :], hidden_states.shape[-2] - 1, dtype=torch.long,)
|
| 1130 |
+
else:
|
| 1131 |
+
cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
|
| 1132 |
+
cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
|
| 1133 |
+
# shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
|
| 1134 |
+
output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
|
| 1135 |
+
elif self.summary_type == "attn":
|
| 1136 |
+
raise NotImplementedError
|
| 1137 |
+
|
| 1138 |
+
output = self.first_dropout(output)
|
| 1139 |
+
output = self.summary(output)
|
| 1140 |
+
output = self.activation(output)
|
| 1141 |
+
output = self.last_dropout(output)
|
| 1142 |
+
|
| 1143 |
+
return output
|
| 1144 |
+
|
| 1145 |
+
|
| 1146 |
+
def prune_linear_layer(layer, index, dim=0):
|
| 1147 |
+
""" Prune a linear layer (a model parameters) to keep only entries in index.
|
| 1148 |
+
Return the pruned layer as a new layer with requires_grad=True.
|
| 1149 |
+
Used to remove heads.
|
| 1150 |
+
"""
|
| 1151 |
+
index = index.to(layer.weight.device)
|
| 1152 |
+
W = layer.weight.index_select(dim, index).clone().detach()
|
| 1153 |
+
if layer.bias is not None:
|
| 1154 |
+
if dim == 1:
|
| 1155 |
+
b = layer.bias.clone().detach()
|
| 1156 |
+
else:
|
| 1157 |
+
b = layer.bias[index].clone().detach()
|
| 1158 |
+
new_size = list(layer.weight.size())
|
| 1159 |
+
new_size[dim] = len(index)
|
| 1160 |
+
new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
|
| 1161 |
+
new_layer.weight.requires_grad = False
|
| 1162 |
+
new_layer.weight.copy_(W.contiguous())
|
| 1163 |
+
new_layer.weight.requires_grad = True
|
| 1164 |
+
if layer.bias is not None:
|
| 1165 |
+
new_layer.bias.requires_grad = False
|
| 1166 |
+
new_layer.bias.copy_(b.contiguous())
|
| 1167 |
+
new_layer.bias.requires_grad = True
|
| 1168 |
+
return new_layer
|
| 1169 |
+
|
| 1170 |
+
|
| 1171 |
+
def prune_conv1d_layer(layer, index, dim=1):
|
| 1172 |
+
""" Prune a Conv1D layer (a model parameters) to keep only entries in index.
|
| 1173 |
+
A Conv1D work as a Linear layer (see e.g. BERT) but the weights are transposed.
|
| 1174 |
+
Return the pruned layer as a new layer with requires_grad=True.
|
| 1175 |
+
Used to remove heads.
|
| 1176 |
+
"""
|
| 1177 |
+
index = index.to(layer.weight.device)
|
| 1178 |
+
W = layer.weight.index_select(dim, index).clone().detach()
|
| 1179 |
+
if dim == 0:
|
| 1180 |
+
b = layer.bias.clone().detach()
|
| 1181 |
+
else:
|
| 1182 |
+
b = layer.bias[index].clone().detach()
|
| 1183 |
+
new_size = list(layer.weight.size())
|
| 1184 |
+
new_size[dim] = len(index)
|
| 1185 |
+
new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
|
| 1186 |
+
new_layer.weight.requires_grad = False
|
| 1187 |
+
new_layer.weight.copy_(W.contiguous())
|
| 1188 |
+
new_layer.weight.requires_grad = True
|
| 1189 |
+
new_layer.bias.requires_grad = False
|
| 1190 |
+
new_layer.bias.copy_(b.contiguous())
|
| 1191 |
+
new_layer.bias.requires_grad = True
|
| 1192 |
+
return new_layer
|
| 1193 |
+
|
| 1194 |
+
|
| 1195 |
+
def prune_layer(layer, index, dim=None):
|
| 1196 |
+
""" Prune a Conv1D or nn.Linear layer (a model parameters) to keep only entries in index.
|
| 1197 |
+
Return the pruned layer as a new layer with requires_grad=True.
|
| 1198 |
+
Used to remove heads.
|
| 1199 |
+
"""
|
| 1200 |
+
if isinstance(layer, nn.Linear):
|
| 1201 |
+
return prune_linear_layer(layer, index, dim=0 if dim is None else dim)
|
| 1202 |
+
elif isinstance(layer, Conv1D):
|
| 1203 |
+
return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
|
| 1204 |
+
else:
|
| 1205 |
+
raise ValueError("Can't prune layer of class {}".format(layer.__class__))
|
| 1206 |
+
|
| 1207 |
+
|
| 1208 |
+
def apply_chunking_to_forward(
|
| 1209 |
+
chunk_size: int, chunk_dim: int, forward_fn: Callable[..., torch.Tensor], *input_tensors
|
| 1210 |
+
) -> torch.Tensor:
|
| 1211 |
+
"""
|
| 1212 |
+
This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension `chunk_dim`.
|
| 1213 |
+
It then applies a layer `forward_fn` to each chunk independently to save memory.
|
| 1214 |
+
If the `forward_fn` is independent across the `chunk_dim` this function will yield the
|
| 1215 |
+
same result as not applying it.
|
| 1216 |
+
|
| 1217 |
+
Args:
|
| 1218 |
+
chunk_size: int - the chunk size of a chunked tensor. `num_chunks` = `len(input_tensors[0]) / chunk_size`
|
| 1219 |
+
chunk_dim: int - the dimension over which the input_tensors should be chunked
|
| 1220 |
+
forward_fn: fn - the forward fn of the model
|
| 1221 |
+
input_tensors: tuple(torch.Tensor) - the input tensors of `forward_fn` which are chunked
|
| 1222 |
+
Returns:
|
| 1223 |
+
a Tensor with the same shape the foward_fn would have given if applied
|
| 1224 |
+
|
| 1225 |
+
|
| 1226 |
+
Examples::
|
| 1227 |
+
|
| 1228 |
+
# rename the usual forward() fn to forward_chunk()
|
| 1229 |
+
def forward_chunk(self, hidden_states):
|
| 1230 |
+
hidden_states = self.decoder(hidden_states)
|
| 1231 |
+
return hidden_states
|
| 1232 |
+
|
| 1233 |
+
# implement a chunked forward function
|
| 1234 |
+
def forward(self, hidden_states):
|
| 1235 |
+
return apply_chunking_to_forward(self.chunk_size_lm_head, self.seq_len_dim, self.forward_chunk, hidden_states)
|
| 1236 |
+
"""
|
| 1237 |
+
|
| 1238 |
+
assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format(input_tensors)
|
| 1239 |
+
tensor_shape = input_tensors[0].shape
|
| 1240 |
+
assert all(
|
| 1241 |
+
input_tensor.shape == tensor_shape for input_tensor in input_tensors
|
| 1242 |
+
), "All input tenors have to be of the same shape"
|
| 1243 |
+
|
| 1244 |
+
# inspect.signature exist since python 3.5 and is a python method -> no problem with backward compability
|
| 1245 |
+
num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
|
| 1246 |
+
assert num_args_in_forward_chunk_fn == len(
|
| 1247 |
+
input_tensors
|
| 1248 |
+
), "forward_chunk_fn expects {} arguments, but only {} input tensors are given".format(
|
| 1249 |
+
num_args_in_forward_chunk_fn, len(input_tensors)
|
| 1250 |
+
)
|
| 1251 |
+
|
| 1252 |
+
if chunk_size > 0:
|
| 1253 |
+
assert (
|
| 1254 |
+
input_tensors[0].shape[chunk_dim] % chunk_size == 0
|
| 1255 |
+
), "The dimension to be chunked {} has to be a multiple of the chunk size {}".format(
|
| 1256 |
+
input_tensors[0].shape[chunk_dim], chunk_size
|
| 1257 |
+
)
|
| 1258 |
+
|
| 1259 |
+
num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
|
| 1260 |
+
|
| 1261 |
+
# chunk input tensor into tuples
|
| 1262 |
+
input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors)
|
| 1263 |
+
# apply forward fn to every tuple
|
| 1264 |
+
output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks))
|
| 1265 |
+
# concatenate output at same dimension
|
| 1266 |
+
return torch.cat(output_chunks, dim=chunk_dim)
|
| 1267 |
+
|
| 1268 |
+
return forward_fn(*input_tensors)
|
CGFormer/bert/tokenization_bert.py
ADDED
|
@@ -0,0 +1,545 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Tokenization classes."""
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
import collections
|
| 19 |
+
import logging
|
| 20 |
+
import os
|
| 21 |
+
import unicodedata
|
| 22 |
+
from typing import List, Optional
|
| 23 |
+
|
| 24 |
+
from .tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
+
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
|
| 30 |
+
|
| 31 |
+
PRETRAINED_VOCAB_FILES_MAP = {
|
| 32 |
+
"vocab_file": {
|
| 33 |
+
"bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
|
| 34 |
+
"bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
|
| 35 |
+
"bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
|
| 36 |
+
"bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
|
| 37 |
+
"bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
|
| 38 |
+
"bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
|
| 39 |
+
"bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
|
| 40 |
+
"bert-base-german-cased": "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt",
|
| 41 |
+
"bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt",
|
| 42 |
+
"bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt",
|
| 43 |
+
"bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt",
|
| 44 |
+
"bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt",
|
| 45 |
+
"bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt",
|
| 46 |
+
"bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-vocab.txt",
|
| 47 |
+
"bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-vocab.txt",
|
| 48 |
+
"TurkuNLP/bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/vocab.txt",
|
| 49 |
+
"TurkuNLP/bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/vocab.txt",
|
| 50 |
+
"wietsedv/bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/vocab.txt",
|
| 51 |
+
}
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
| 55 |
+
"bert-base-uncased": 512,
|
| 56 |
+
"bert-large-uncased": 512,
|
| 57 |
+
"bert-base-cased": 512,
|
| 58 |
+
"bert-large-cased": 512,
|
| 59 |
+
"bert-base-multilingual-uncased": 512,
|
| 60 |
+
"bert-base-multilingual-cased": 512,
|
| 61 |
+
"bert-base-chinese": 512,
|
| 62 |
+
"bert-base-german-cased": 512,
|
| 63 |
+
"bert-large-uncased-whole-word-masking": 512,
|
| 64 |
+
"bert-large-cased-whole-word-masking": 512,
|
| 65 |
+
"bert-large-uncased-whole-word-masking-finetuned-squad": 512,
|
| 66 |
+
"bert-large-cased-whole-word-masking-finetuned-squad": 512,
|
| 67 |
+
"bert-base-cased-finetuned-mrpc": 512,
|
| 68 |
+
"bert-base-german-dbmdz-cased": 512,
|
| 69 |
+
"bert-base-german-dbmdz-uncased": 512,
|
| 70 |
+
"TurkuNLP/bert-base-finnish-cased-v1": 512,
|
| 71 |
+
"TurkuNLP/bert-base-finnish-uncased-v1": 512,
|
| 72 |
+
"wietsedv/bert-base-dutch-cased": 512,
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
PRETRAINED_INIT_CONFIGURATION = {
|
| 76 |
+
"bert-base-uncased": {"do_lower_case": True},
|
| 77 |
+
"bert-large-uncased": {"do_lower_case": True},
|
| 78 |
+
"bert-base-cased": {"do_lower_case": False},
|
| 79 |
+
"bert-large-cased": {"do_lower_case": False},
|
| 80 |
+
"bert-base-multilingual-uncased": {"do_lower_case": True},
|
| 81 |
+
"bert-base-multilingual-cased": {"do_lower_case": False},
|
| 82 |
+
"bert-base-chinese": {"do_lower_case": False},
|
| 83 |
+
"bert-base-german-cased": {"do_lower_case": False},
|
| 84 |
+
"bert-large-uncased-whole-word-masking": {"do_lower_case": True},
|
| 85 |
+
"bert-large-cased-whole-word-masking": {"do_lower_case": False},
|
| 86 |
+
"bert-large-uncased-whole-word-masking-finetuned-squad": {"do_lower_case": True},
|
| 87 |
+
"bert-large-cased-whole-word-masking-finetuned-squad": {"do_lower_case": False},
|
| 88 |
+
"bert-base-cased-finetuned-mrpc": {"do_lower_case": False},
|
| 89 |
+
"bert-base-german-dbmdz-cased": {"do_lower_case": False},
|
| 90 |
+
"bert-base-german-dbmdz-uncased": {"do_lower_case": True},
|
| 91 |
+
"TurkuNLP/bert-base-finnish-cased-v1": {"do_lower_case": False},
|
| 92 |
+
"TurkuNLP/bert-base-finnish-uncased-v1": {"do_lower_case": True},
|
| 93 |
+
"wietsedv/bert-base-dutch-cased": {"do_lower_case": False},
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def load_vocab(vocab_file):
|
| 98 |
+
"""Loads a vocabulary file into a dictionary."""
|
| 99 |
+
vocab = collections.OrderedDict()
|
| 100 |
+
with open(vocab_file, "r", encoding="utf-8") as reader:
|
| 101 |
+
tokens = reader.readlines()
|
| 102 |
+
for index, token in enumerate(tokens):
|
| 103 |
+
token = token.rstrip("\n")
|
| 104 |
+
vocab[token] = index
|
| 105 |
+
return vocab
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def whitespace_tokenize(text):
|
| 109 |
+
"""Runs basic whitespace cleaning and splitting on a piece of text."""
|
| 110 |
+
text = text.strip()
|
| 111 |
+
if not text:
|
| 112 |
+
return []
|
| 113 |
+
tokens = text.split()
|
| 114 |
+
return tokens
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class BertTokenizer(PreTrainedTokenizer):
|
| 118 |
+
r"""
|
| 119 |
+
Constructs a BERT tokenizer. Based on WordPiece.
|
| 120 |
+
|
| 121 |
+
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
|
| 122 |
+
should refer to the superclass for more information regarding methods.
|
| 123 |
+
|
| 124 |
+
Args:
|
| 125 |
+
vocab_file (:obj:`string`):
|
| 126 |
+
File containing the vocabulary.
|
| 127 |
+
do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
| 128 |
+
Whether to lowercase the input when tokenizing.
|
| 129 |
+
do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
| 130 |
+
Whether to do basic tokenization before WordPiece.
|
| 131 |
+
never_split (:obj:`Iterable`, `optional`, defaults to :obj:`None`):
|
| 132 |
+
Collection of tokens which will never be split during tokenization. Only has an effect when
|
| 133 |
+
:obj:`do_basic_tokenize=True`
|
| 134 |
+
unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
|
| 135 |
+
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
|
| 136 |
+
token instead.
|
| 137 |
+
sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
|
| 138 |
+
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
|
| 139 |
+
for sequence classification or for a text and a question for question answering.
|
| 140 |
+
It is also used as the last token of a sequence built with special tokens.
|
| 141 |
+
pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
|
| 142 |
+
The token used for padding, for example when batching sequences of different lengths.
|
| 143 |
+
cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
|
| 144 |
+
The classifier token which is used when doing sequence classification (classification of the whole
|
| 145 |
+
sequence instead of per-token classification). It is the first token of the sequence when built with
|
| 146 |
+
special tokens.
|
| 147 |
+
mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
|
| 148 |
+
The token used for masking values. This is the token used when training this model with masked language
|
| 149 |
+
modeling. This is the token which the model will try to predict.
|
| 150 |
+
tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
| 151 |
+
Whether to tokenize Chinese characters.
|
| 152 |
+
This should likely be deactivated for Japanese:
|
| 153 |
+
see: https://github.com/huggingface/transformers/issues/328
|
| 154 |
+
"""
|
| 155 |
+
|
| 156 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
| 157 |
+
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
| 158 |
+
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
|
| 159 |
+
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
| 160 |
+
|
| 161 |
+
def __init__(
|
| 162 |
+
self,
|
| 163 |
+
vocab_file,
|
| 164 |
+
do_lower_case=True,
|
| 165 |
+
do_basic_tokenize=True,
|
| 166 |
+
never_split=None,
|
| 167 |
+
unk_token="[UNK]",
|
| 168 |
+
sep_token="[SEP]",
|
| 169 |
+
pad_token="[PAD]",
|
| 170 |
+
cls_token="[CLS]",
|
| 171 |
+
mask_token="[MASK]",
|
| 172 |
+
tokenize_chinese_chars=True,
|
| 173 |
+
**kwargs
|
| 174 |
+
):
|
| 175 |
+
super().__init__(
|
| 176 |
+
unk_token=unk_token,
|
| 177 |
+
sep_token=sep_token,
|
| 178 |
+
pad_token=pad_token,
|
| 179 |
+
cls_token=cls_token,
|
| 180 |
+
mask_token=mask_token,
|
| 181 |
+
**kwargs,
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
if not os.path.isfile(vocab_file):
|
| 185 |
+
raise ValueError(
|
| 186 |
+
"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
|
| 187 |
+
"model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)
|
| 188 |
+
)
|
| 189 |
+
self.vocab = load_vocab(vocab_file)
|
| 190 |
+
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
|
| 191 |
+
self.do_basic_tokenize = do_basic_tokenize
|
| 192 |
+
if do_basic_tokenize:
|
| 193 |
+
self.basic_tokenizer = BasicTokenizer(
|
| 194 |
+
do_lower_case=do_lower_case, never_split=never_split, tokenize_chinese_chars=tokenize_chinese_chars
|
| 195 |
+
)
|
| 196 |
+
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
|
| 197 |
+
|
| 198 |
+
@property
|
| 199 |
+
def vocab_size(self):
|
| 200 |
+
return len(self.vocab)
|
| 201 |
+
|
| 202 |
+
def get_vocab(self):
|
| 203 |
+
return dict(self.vocab, **self.added_tokens_encoder)
|
| 204 |
+
|
| 205 |
+
def _tokenize(self, text):
|
| 206 |
+
split_tokens = []
|
| 207 |
+
if self.do_basic_tokenize:
|
| 208 |
+
for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
|
| 209 |
+
|
| 210 |
+
# If the token is part of the never_split set
|
| 211 |
+
if token in self.basic_tokenizer.never_split:
|
| 212 |
+
split_tokens.append(token)
|
| 213 |
+
else:
|
| 214 |
+
split_tokens += self.wordpiece_tokenizer.tokenize(token)
|
| 215 |
+
else:
|
| 216 |
+
split_tokens = self.wordpiece_tokenizer.tokenize(text)
|
| 217 |
+
return split_tokens
|
| 218 |
+
|
| 219 |
+
def _convert_token_to_id(self, token):
|
| 220 |
+
""" Converts a token (str) in an id using the vocab. """
|
| 221 |
+
return self.vocab.get(token, self.vocab.get(self.unk_token))
|
| 222 |
+
|
| 223 |
+
def _convert_id_to_token(self, index):
|
| 224 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
| 225 |
+
return self.ids_to_tokens.get(index, self.unk_token)
|
| 226 |
+
|
| 227 |
+
def convert_tokens_to_string(self, tokens):
|
| 228 |
+
""" Converts a sequence of tokens (string) in a single string. """
|
| 229 |
+
out_string = " ".join(tokens).replace(" ##", "").strip()
|
| 230 |
+
return out_string
|
| 231 |
+
|
| 232 |
+
def build_inputs_with_special_tokens(
|
| 233 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 234 |
+
) -> List[int]:
|
| 235 |
+
"""
|
| 236 |
+
Build model inputs from a sequence or a pair of sequence for sequence classification tasks
|
| 237 |
+
by concatenating and adding special tokens.
|
| 238 |
+
A BERT sequence has the following format:
|
| 239 |
+
|
| 240 |
+
- single sequence: ``[CLS] X [SEP]``
|
| 241 |
+
- pair of sequences: ``[CLS] A [SEP] B [SEP]``
|
| 242 |
+
|
| 243 |
+
Args:
|
| 244 |
+
token_ids_0 (:obj:`List[int]`):
|
| 245 |
+
List of IDs to which the special tokens will be added
|
| 246 |
+
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
|
| 247 |
+
Optional second list of IDs for sequence pairs.
|
| 248 |
+
|
| 249 |
+
Returns:
|
| 250 |
+
:obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
|
| 251 |
+
"""
|
| 252 |
+
if token_ids_1 is None:
|
| 253 |
+
return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
|
| 254 |
+
cls = [self.cls_token_id]
|
| 255 |
+
sep = [self.sep_token_id]
|
| 256 |
+
return cls + token_ids_0 + sep + token_ids_1 + sep
|
| 257 |
+
|
| 258 |
+
def get_special_tokens_mask(
|
| 259 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
|
| 260 |
+
) -> List[int]:
|
| 261 |
+
"""
|
| 262 |
+
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
| 263 |
+
special tokens using the tokenizer ``prepare_for_model`` method.
|
| 264 |
+
|
| 265 |
+
Args:
|
| 266 |
+
token_ids_0 (:obj:`List[int]`):
|
| 267 |
+
List of ids.
|
| 268 |
+
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
|
| 269 |
+
Optional second list of IDs for sequence pairs.
|
| 270 |
+
already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
| 271 |
+
Set to True if the token list is already formatted with special tokens for the model
|
| 272 |
+
|
| 273 |
+
Returns:
|
| 274 |
+
:obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
| 275 |
+
"""
|
| 276 |
+
|
| 277 |
+
if already_has_special_tokens:
|
| 278 |
+
if token_ids_1 is not None:
|
| 279 |
+
raise ValueError(
|
| 280 |
+
"You should not supply a second sequence if the provided sequence of "
|
| 281 |
+
"ids is already formated with special tokens for the model."
|
| 282 |
+
)
|
| 283 |
+
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
|
| 284 |
+
|
| 285 |
+
if token_ids_1 is not None:
|
| 286 |
+
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
|
| 287 |
+
return [1] + ([0] * len(token_ids_0)) + [1]
|
| 288 |
+
|
| 289 |
+
def create_token_type_ids_from_sequences(
|
| 290 |
+
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
| 291 |
+
) -> List[int]:
|
| 292 |
+
"""
|
| 293 |
+
Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
|
| 294 |
+
A BERT sequence pair mask has the following format:
|
| 295 |
+
|
| 296 |
+
::
|
| 297 |
+
|
| 298 |
+
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
|
| 299 |
+
| first sequence | second sequence |
|
| 300 |
+
|
| 301 |
+
if token_ids_1 is None, only returns the first portion of the mask (0's).
|
| 302 |
+
|
| 303 |
+
Args:
|
| 304 |
+
token_ids_0 (:obj:`List[int]`):
|
| 305 |
+
List of ids.
|
| 306 |
+
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
|
| 307 |
+
Optional second list of IDs for sequence pairs.
|
| 308 |
+
|
| 309 |
+
Returns:
|
| 310 |
+
:obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
|
| 311 |
+
sequence(s).
|
| 312 |
+
"""
|
| 313 |
+
sep = [self.sep_token_id]
|
| 314 |
+
cls = [self.cls_token_id]
|
| 315 |
+
if token_ids_1 is None:
|
| 316 |
+
return len(cls + token_ids_0 + sep) * [0]
|
| 317 |
+
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
|
| 318 |
+
|
| 319 |
+
def save_vocabulary(self, vocab_path):
|
| 320 |
+
"""
|
| 321 |
+
Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory.
|
| 322 |
+
|
| 323 |
+
Args:
|
| 324 |
+
vocab_path (:obj:`str`):
|
| 325 |
+
The directory in which to save the vocabulary.
|
| 326 |
+
|
| 327 |
+
Returns:
|
| 328 |
+
:obj:`Tuple(str)`: Paths to the files saved.
|
| 329 |
+
"""
|
| 330 |
+
index = 0
|
| 331 |
+
if os.path.isdir(vocab_path):
|
| 332 |
+
vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"])
|
| 333 |
+
else:
|
| 334 |
+
vocab_file = vocab_path
|
| 335 |
+
with open(vocab_file, "w", encoding="utf-8") as writer:
|
| 336 |
+
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
|
| 337 |
+
if index != token_index:
|
| 338 |
+
logger.warning(
|
| 339 |
+
"Saving vocabulary to {}: vocabulary indices are not consecutive."
|
| 340 |
+
" Please check that the vocabulary is not corrupted!".format(vocab_file)
|
| 341 |
+
)
|
| 342 |
+
index = token_index
|
| 343 |
+
writer.write(token + "\n")
|
| 344 |
+
index += 1
|
| 345 |
+
return (vocab_file,)
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
class BasicTokenizer(object):
|
| 349 |
+
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
|
| 350 |
+
|
| 351 |
+
def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True):
|
| 352 |
+
""" Constructs a BasicTokenizer.
|
| 353 |
+
|
| 354 |
+
Args:
|
| 355 |
+
**do_lower_case**: Whether to lower case the input.
|
| 356 |
+
**never_split**: (`optional`) list of str
|
| 357 |
+
Kept for backward compatibility purposes.
|
| 358 |
+
Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`)
|
| 359 |
+
List of token not to split.
|
| 360 |
+
**tokenize_chinese_chars**: (`optional`) boolean (default True)
|
| 361 |
+
Whether to tokenize Chinese characters.
|
| 362 |
+
This should likely be deactivated for Japanese:
|
| 363 |
+
see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328
|
| 364 |
+
"""
|
| 365 |
+
if never_split is None:
|
| 366 |
+
never_split = []
|
| 367 |
+
self.do_lower_case = do_lower_case
|
| 368 |
+
self.never_split = set(never_split)
|
| 369 |
+
self.tokenize_chinese_chars = tokenize_chinese_chars
|
| 370 |
+
|
| 371 |
+
def tokenize(self, text, never_split=None):
|
| 372 |
+
""" Basic Tokenization of a piece of text.
|
| 373 |
+
Split on "white spaces" only, for sub-word tokenization, see WordPieceTokenizer.
|
| 374 |
+
|
| 375 |
+
Args:
|
| 376 |
+
**never_split**: (`optional`) list of str
|
| 377 |
+
Kept for backward compatibility purposes.
|
| 378 |
+
Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`)
|
| 379 |
+
List of token not to split.
|
| 380 |
+
"""
|
| 381 |
+
# union() returns a new set by concatenating the two sets.
|
| 382 |
+
never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
|
| 383 |
+
|
| 384 |
+
# This was added on November 1st, 2018 for the multilingual and Chinese
|
| 385 |
+
# models. This is also applied to the English models now, but it doesn't
|
| 386 |
+
# matter since the English models were not trained on any Chinese data
|
| 387 |
+
# and generally don't have any Chinese data in them (there are Chinese
|
| 388 |
+
# characters in the vocabulary because Wikipedia does have some Chinese
|
| 389 |
+
# words in the English Wikipedia.).
|
| 390 |
+
if self.tokenize_chinese_chars:
|
| 391 |
+
text = self._tokenize_chinese_chars(text)
|
| 392 |
+
orig_tokens = whitespace_tokenize(text)
|
| 393 |
+
split_tokens = []
|
| 394 |
+
for token in orig_tokens:
|
| 395 |
+
if self.do_lower_case and token not in never_split:
|
| 396 |
+
token = token.lower()
|
| 397 |
+
token = self._run_strip_accents(token)
|
| 398 |
+
split_tokens.extend(self._run_split_on_punc(token, never_split))
|
| 399 |
+
|
| 400 |
+
output_tokens = whitespace_tokenize(" ".join(split_tokens))
|
| 401 |
+
return output_tokens
|
| 402 |
+
|
| 403 |
+
def _run_strip_accents(self, text):
|
| 404 |
+
"""Strips accents from a piece of text."""
|
| 405 |
+
text = unicodedata.normalize("NFD", text)
|
| 406 |
+
output = []
|
| 407 |
+
for char in text:
|
| 408 |
+
cat = unicodedata.category(char)
|
| 409 |
+
if cat == "Mn":
|
| 410 |
+
continue
|
| 411 |
+
output.append(char)
|
| 412 |
+
return "".join(output)
|
| 413 |
+
|
| 414 |
+
def _run_split_on_punc(self, text, never_split=None):
|
| 415 |
+
"""Splits punctuation on a piece of text."""
|
| 416 |
+
if never_split is not None and text in never_split:
|
| 417 |
+
return [text]
|
| 418 |
+
chars = list(text)
|
| 419 |
+
i = 0
|
| 420 |
+
start_new_word = True
|
| 421 |
+
output = []
|
| 422 |
+
while i < len(chars):
|
| 423 |
+
char = chars[i]
|
| 424 |
+
if _is_punctuation(char):
|
| 425 |
+
output.append([char])
|
| 426 |
+
start_new_word = True
|
| 427 |
+
else:
|
| 428 |
+
if start_new_word:
|
| 429 |
+
output.append([])
|
| 430 |
+
start_new_word = False
|
| 431 |
+
output[-1].append(char)
|
| 432 |
+
i += 1
|
| 433 |
+
|
| 434 |
+
return ["".join(x) for x in output]
|
| 435 |
+
|
| 436 |
+
def _tokenize_chinese_chars(self, text):
|
| 437 |
+
"""Adds whitespace around any CJK character."""
|
| 438 |
+
output = []
|
| 439 |
+
for char in text:
|
| 440 |
+
cp = ord(char)
|
| 441 |
+
if self._is_chinese_char(cp):
|
| 442 |
+
output.append(" ")
|
| 443 |
+
output.append(char)
|
| 444 |
+
output.append(" ")
|
| 445 |
+
else:
|
| 446 |
+
output.append(char)
|
| 447 |
+
return "".join(output)
|
| 448 |
+
|
| 449 |
+
def _is_chinese_char(self, cp):
|
| 450 |
+
"""Checks whether CP is the codepoint of a CJK character."""
|
| 451 |
+
# This defines a "chinese character" as anything in the CJK Unicode block:
|
| 452 |
+
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
|
| 453 |
+
#
|
| 454 |
+
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
|
| 455 |
+
# despite its name. The modern Korean Hangul alphabet is a different block,
|
| 456 |
+
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
|
| 457 |
+
# space-separated words, so they are not treated specially and handled
|
| 458 |
+
# like the all of the other languages.
|
| 459 |
+
if (
|
| 460 |
+
(cp >= 0x4E00 and cp <= 0x9FFF)
|
| 461 |
+
or (cp >= 0x3400 and cp <= 0x4DBF) #
|
| 462 |
+
or (cp >= 0x20000 and cp <= 0x2A6DF) #
|
| 463 |
+
or (cp >= 0x2A700 and cp <= 0x2B73F) #
|
| 464 |
+
or (cp >= 0x2B740 and cp <= 0x2B81F) #
|
| 465 |
+
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
|
| 466 |
+
or (cp >= 0xF900 and cp <= 0xFAFF)
|
| 467 |
+
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
|
| 468 |
+
): #
|
| 469 |
+
return True
|
| 470 |
+
|
| 471 |
+
return False
|
| 472 |
+
|
| 473 |
+
def _clean_text(self, text):
|
| 474 |
+
"""Performs invalid character removal and whitespace cleanup on text."""
|
| 475 |
+
output = []
|
| 476 |
+
for char in text:
|
| 477 |
+
cp = ord(char)
|
| 478 |
+
if cp == 0 or cp == 0xFFFD or _is_control(char):
|
| 479 |
+
continue
|
| 480 |
+
if _is_whitespace(char):
|
| 481 |
+
output.append(" ")
|
| 482 |
+
else:
|
| 483 |
+
output.append(char)
|
| 484 |
+
return "".join(output)
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
class WordpieceTokenizer(object):
|
| 488 |
+
"""Runs WordPiece tokenization."""
|
| 489 |
+
|
| 490 |
+
def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
|
| 491 |
+
self.vocab = vocab
|
| 492 |
+
self.unk_token = unk_token
|
| 493 |
+
self.max_input_chars_per_word = max_input_chars_per_word
|
| 494 |
+
|
| 495 |
+
def tokenize(self, text):
|
| 496 |
+
"""Tokenizes a piece of text into its word pieces.
|
| 497 |
+
|
| 498 |
+
This uses a greedy longest-match-first algorithm to perform tokenization
|
| 499 |
+
using the given vocabulary.
|
| 500 |
+
|
| 501 |
+
For example:
|
| 502 |
+
input = "unaffable"
|
| 503 |
+
output = ["un", "##aff", "##able"]
|
| 504 |
+
|
| 505 |
+
Args:
|
| 506 |
+
text: A single token or whitespace separated tokens. This should have
|
| 507 |
+
already been passed through `BasicTokenizer`.
|
| 508 |
+
|
| 509 |
+
Returns:
|
| 510 |
+
A list of wordpiece tokens.
|
| 511 |
+
"""
|
| 512 |
+
|
| 513 |
+
output_tokens = []
|
| 514 |
+
for token in whitespace_tokenize(text):
|
| 515 |
+
chars = list(token)
|
| 516 |
+
if len(chars) > self.max_input_chars_per_word:
|
| 517 |
+
output_tokens.append(self.unk_token)
|
| 518 |
+
continue
|
| 519 |
+
|
| 520 |
+
is_bad = False
|
| 521 |
+
start = 0
|
| 522 |
+
sub_tokens = []
|
| 523 |
+
while start < len(chars):
|
| 524 |
+
end = len(chars)
|
| 525 |
+
cur_substr = None
|
| 526 |
+
while start < end:
|
| 527 |
+
substr = "".join(chars[start:end])
|
| 528 |
+
if start > 0:
|
| 529 |
+
substr = "##" + substr
|
| 530 |
+
if substr in self.vocab:
|
| 531 |
+
cur_substr = substr
|
| 532 |
+
break
|
| 533 |
+
end -= 1
|
| 534 |
+
if cur_substr is None:
|
| 535 |
+
is_bad = True
|
| 536 |
+
break
|
| 537 |
+
sub_tokens.append(cur_substr)
|
| 538 |
+
start = end
|
| 539 |
+
|
| 540 |
+
if is_bad:
|
| 541 |
+
output_tokens.append(self.unk_token)
|
| 542 |
+
else:
|
| 543 |
+
output_tokens.extend(sub_tokens)
|
| 544 |
+
return output_tokens
|
| 545 |
+
|
CGFormer/bert/tokenization_utils.py
ADDED
|
@@ -0,0 +1,723 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2020 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
""" Tokenization classes for python tokenizers.
|
| 16 |
+
For fast tokenizers (provided by HuggingFace's tokenizers library) see tokenization_utils_fast.py
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import itertools
|
| 20 |
+
import logging
|
| 21 |
+
import re
|
| 22 |
+
import unicodedata
|
| 23 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 24 |
+
|
| 25 |
+
from .file_utils import add_end_docstrings
|
| 26 |
+
from .tokenization_utils_base import (
|
| 27 |
+
ENCODE_KWARGS_DOCSTRING,
|
| 28 |
+
ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING,
|
| 29 |
+
AddedToken,
|
| 30 |
+
BatchEncoding,
|
| 31 |
+
EncodedInput,
|
| 32 |
+
EncodedInputPair,
|
| 33 |
+
PaddingStrategy,
|
| 34 |
+
PreTokenizedInput,
|
| 35 |
+
PreTokenizedInputPair,
|
| 36 |
+
PreTrainedTokenizerBase,
|
| 37 |
+
TensorType,
|
| 38 |
+
TextInput,
|
| 39 |
+
TextInputPair,
|
| 40 |
+
TruncationStrategy,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
logger = logging.getLogger(__name__)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _is_whitespace(char):
|
| 48 |
+
"""Checks whether `chars` is a whitespace character."""
|
| 49 |
+
# \t, \n, and \r are technically contorl characters but we treat them
|
| 50 |
+
# as whitespace since they are generally considered as such.
|
| 51 |
+
if char == " " or char == "\t" or char == "\n" or char == "\r":
|
| 52 |
+
return True
|
| 53 |
+
cat = unicodedata.category(char)
|
| 54 |
+
if cat == "Zs":
|
| 55 |
+
return True
|
| 56 |
+
return False
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _is_control(char):
|
| 60 |
+
"""Checks whether `chars` is a control character."""
|
| 61 |
+
# These are technically control characters but we count them as whitespace
|
| 62 |
+
# characters.
|
| 63 |
+
if char == "\t" or char == "\n" or char == "\r":
|
| 64 |
+
return False
|
| 65 |
+
cat = unicodedata.category(char)
|
| 66 |
+
if cat.startswith("C"):
|
| 67 |
+
return True
|
| 68 |
+
return False
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def _is_punctuation(char):
|
| 72 |
+
"""Checks whether `chars` is a punctuation character."""
|
| 73 |
+
cp = ord(char)
|
| 74 |
+
# We treat all non-letter/number ASCII as punctuation.
|
| 75 |
+
# Characters such as "^", "$", and "`" are not in the Unicode
|
| 76 |
+
# Punctuation class but we treat them as punctuation anyways, for
|
| 77 |
+
# consistency.
|
| 78 |
+
if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
|
| 79 |
+
return True
|
| 80 |
+
cat = unicodedata.category(char)
|
| 81 |
+
if cat.startswith("P"):
|
| 82 |
+
return True
|
| 83 |
+
return False
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def _is_end_of_word(text):
|
| 87 |
+
"""Checks whether the last character in text is one of a punctuation, control or whitespace character."""
|
| 88 |
+
last_char = text[-1]
|
| 89 |
+
return bool(_is_control(last_char) | _is_punctuation(last_char) | _is_whitespace(last_char))
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def _is_start_of_word(text):
|
| 93 |
+
"""Checks whether the first character in text is one of a punctuation, control or whitespace character."""
|
| 94 |
+
first_char = text[0]
|
| 95 |
+
return bool(_is_control(first_char) | _is_punctuation(first_char) | _is_whitespace(first_char))
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class PreTrainedTokenizer(PreTrainedTokenizerBase):
|
| 99 |
+
""" Base class for all slow tokenizers.
|
| 100 |
+
|
| 101 |
+
Handle all the shared methods for tokenization and special tokens as well as methods
|
| 102 |
+
downloading/caching/loading pretrained tokenizers as well as adding tokens to the vocabulary.
|
| 103 |
+
|
| 104 |
+
This class also contain the added tokens in a unified way on top of all tokenizers so we don't
|
| 105 |
+
have to handle the specific vocabulary augmentation methods of the various underlying
|
| 106 |
+
dictionary structures (BPE, sentencepiece...).
|
| 107 |
+
|
| 108 |
+
Class attributes (overridden by derived classes):
|
| 109 |
+
|
| 110 |
+
- ``vocab_files_names``: a python ``dict`` with, as keys, the ``__init__`` keyword name of each vocabulary file
|
| 111 |
+
required by the model, and as associated values, the filename for saving the associated file (string).
|
| 112 |
+
- ``pretrained_vocab_files_map``: a python ``dict of dict`` the high-level keys
|
| 113 |
+
being the ``__init__`` keyword name of each vocabulary file required by the model, the low-level being the
|
| 114 |
+
`short-cut-names` (string) of the pretrained models with, as associated values, the `url` (string) to the
|
| 115 |
+
associated pretrained vocabulary file.
|
| 116 |
+
- ``max_model_input_sizes``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained
|
| 117 |
+
models, and as associated values, the maximum length of the sequence inputs of this model, or None if the
|
| 118 |
+
model has no maximum input size.
|
| 119 |
+
- ``pretrained_init_configuration``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the
|
| 120 |
+
pretrained models, and as associated values, a dictionnary of specific arguments to pass to the
|
| 121 |
+
``__init__``method of the tokenizer class for this pretrained model when loading the tokenizer with the
|
| 122 |
+
``from_pretrained()`` method.
|
| 123 |
+
|
| 124 |
+
Args:
|
| 125 |
+
- ``model_max_length``: (`Optional`) int: the maximum length in number of tokens for the inputs to the transformer model.
|
| 126 |
+
When the tokenizer is loaded with `from_pretrained`, this will be set to the value stored for the associated
|
| 127 |
+
model in ``max_model_input_sizes`` (see above). If no value is provided, will default to VERY_LARGE_INTEGER (`int(1e30)`).
|
| 128 |
+
no associated max_length can be found in ``max_model_input_sizes``.
|
| 129 |
+
- ``padding_side``: (`Optional`) string: the side on which the model should have padding applied.
|
| 130 |
+
Should be selected between ['right', 'left']
|
| 131 |
+
- ``model_input_names``: (`Optional`) List[string]: the list of the forward pass inputs accepted by the
|
| 132 |
+
model ("token_type_ids", "attention_mask"...).
|
| 133 |
+
- ``bos_token``: (`Optional`) string: a beginning of sentence token.
|
| 134 |
+
Will be associated to ``self.bos_token`` and ``self.bos_token_id``
|
| 135 |
+
- ``eos_token``: (`Optional`) string: an end of sentence token.
|
| 136 |
+
Will be associated to ``self.eos_token`` and ``self.eos_token_id``
|
| 137 |
+
- ``unk_token``: (`Optional`) string: an unknown token.
|
| 138 |
+
Will be associated to ``self.unk_token`` and ``self.unk_token_id``
|
| 139 |
+
- ``sep_token``: (`Optional`) string: a separation token (e.g. to separate context and query in an input sequence).
|
| 140 |
+
Will be associated to ``self.sep_token`` and ``self.sep_token_id``
|
| 141 |
+
- ``pad_token``: (`Optional`) string: a padding token.
|
| 142 |
+
Will be associated to ``self.pad_token`` and ``self.pad_token_id``
|
| 143 |
+
- ``cls_token``: (`Optional`) string: a classification token (e.g. to extract a summary of an input sequence
|
| 144 |
+
leveraging self-attention along the full depth of the model).
|
| 145 |
+
Will be associated to ``self.cls_token`` and ``self.cls_token_id``
|
| 146 |
+
- ``mask_token``: (`Optional`) string: a masking token (e.g. when training a model with masked-language
|
| 147 |
+
modeling). Will be associated to ``self.mask_token`` and ``self.mask_token_id``
|
| 148 |
+
- ``additional_special_tokens``: (`Optional`) list: a list of additional special tokens.
|
| 149 |
+
Adding all special tokens here ensure they won't be split by the tokenization process.
|
| 150 |
+
Will be associated to ``self.additional_special_tokens`` and ``self.additional_special_tokens_ids``
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
.. automethod:: __call__
|
| 154 |
+
"""
|
| 155 |
+
|
| 156 |
+
def __init__(self, **kwargs):
|
| 157 |
+
super().__init__(**kwargs)
|
| 158 |
+
|
| 159 |
+
# Added tokens - We store this for both slow and fast tokenizers
|
| 160 |
+
# until the serialization of Fast tokenizers is updated
|
| 161 |
+
self.added_tokens_encoder: Dict[str, int] = {}
|
| 162 |
+
self.added_tokens_decoder: Dict[int, str] = {}
|
| 163 |
+
self.unique_no_split_tokens: List[str] = []
|
| 164 |
+
|
| 165 |
+
@property
|
| 166 |
+
def is_fast(self) -> bool:
|
| 167 |
+
return False
|
| 168 |
+
|
| 169 |
+
@property
|
| 170 |
+
def vocab_size(self) -> int:
|
| 171 |
+
""" Size of the base vocabulary (without the added tokens) """
|
| 172 |
+
raise NotImplementedError
|
| 173 |
+
|
| 174 |
+
def get_vocab(self):
|
| 175 |
+
""" Returns the vocabulary as a dict of {token: index} pairs. `tokenizer.get_vocab()[token]` is equivalent to `tokenizer.convert_tokens_to_ids(token)` when `token` is in the vocab. """
|
| 176 |
+
raise NotImplementedError()
|
| 177 |
+
|
| 178 |
+
def get_added_vocab(self) -> Dict[str, int]:
|
| 179 |
+
return self.added_tokens_encoder
|
| 180 |
+
|
| 181 |
+
def __len__(self):
|
| 182 |
+
""" Size of the full vocabulary with the added tokens """
|
| 183 |
+
return self.vocab_size + len(self.added_tokens_encoder)
|
| 184 |
+
|
| 185 |
+
def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens=False) -> int:
|
| 186 |
+
"""
|
| 187 |
+
Add a list of new tokens to the tokenizer class. If the new tokens are not in the
|
| 188 |
+
vocabulary, they are added to it with indices starting from length of the current vocabulary.
|
| 189 |
+
|
| 190 |
+
Args:
|
| 191 |
+
new_tokens: string or list of string. Each string is a token to add. Tokens are only added if they are not
|
| 192 |
+
already in the vocabulary (tested by checking if the tokenizer assign the index of the ``unk_token`` to them).
|
| 193 |
+
|
| 194 |
+
Returns:
|
| 195 |
+
Number of tokens added to the vocabulary.
|
| 196 |
+
|
| 197 |
+
Examples::
|
| 198 |
+
|
| 199 |
+
# Let's see how to increase the vocabulary of Bert model and tokenizer
|
| 200 |
+
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
| 201 |
+
model = BertModel.from_pretrained('bert-base-uncased')
|
| 202 |
+
|
| 203 |
+
num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'])
|
| 204 |
+
print('We have added', num_added_toks, 'tokens')
|
| 205 |
+
model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
|
| 206 |
+
"""
|
| 207 |
+
new_tokens = [str(tok) for tok in new_tokens]
|
| 208 |
+
|
| 209 |
+
tokens_to_add = []
|
| 210 |
+
for token in new_tokens:
|
| 211 |
+
assert isinstance(token, str)
|
| 212 |
+
if not special_tokens and self.init_kwargs.get("do_lower_case", False):
|
| 213 |
+
token = token.lower()
|
| 214 |
+
if (
|
| 215 |
+
token != self.unk_token
|
| 216 |
+
and self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token)
|
| 217 |
+
and token not in tokens_to_add
|
| 218 |
+
):
|
| 219 |
+
tokens_to_add.append(token)
|
| 220 |
+
if self.verbose:
|
| 221 |
+
logger.info("Adding %s to the vocabulary", token)
|
| 222 |
+
|
| 223 |
+
added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(tokens_to_add))
|
| 224 |
+
added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
|
| 225 |
+
self.added_tokens_encoder.update(added_tok_encoder)
|
| 226 |
+
self.added_tokens_decoder.update(added_tok_decoder)
|
| 227 |
+
|
| 228 |
+
# Make sure we don't split on any special tokens (even they were already in the vocab before e.g. for Albert)
|
| 229 |
+
if special_tokens:
|
| 230 |
+
self.unique_no_split_tokens = list(set(self.unique_no_split_tokens).union(set(new_tokens)))
|
| 231 |
+
else:
|
| 232 |
+
# Or on the newly added tokens
|
| 233 |
+
self.unique_no_split_tokens = list(set(self.unique_no_split_tokens).union(set(tokens_to_add)))
|
| 234 |
+
|
| 235 |
+
return len(tokens_to_add)
|
| 236 |
+
|
| 237 |
+
def num_special_tokens_to_add(self, pair=False):
|
| 238 |
+
"""
|
| 239 |
+
Returns the number of added tokens when encoding a sequence with special tokens.
|
| 240 |
+
|
| 241 |
+
Note:
|
| 242 |
+
This encodes inputs and checks the number of added tokens, and is therefore not efficient. Do not put this
|
| 243 |
+
inside your training loop.
|
| 244 |
+
|
| 245 |
+
Args:
|
| 246 |
+
pair: Returns the number of added tokens in the case of a sequence pair if set to True, returns the
|
| 247 |
+
number of added tokens in the case of a single sequence if set to False.
|
| 248 |
+
|
| 249 |
+
Returns:
|
| 250 |
+
Number of tokens added to sequences
|
| 251 |
+
"""
|
| 252 |
+
token_ids_0 = []
|
| 253 |
+
token_ids_1 = []
|
| 254 |
+
return len(self.build_inputs_with_special_tokens(token_ids_0, token_ids_1 if pair else None))
|
| 255 |
+
|
| 256 |
+
def tokenize(self, text: TextInput, **kwargs):
|
| 257 |
+
""" Converts a string in a sequence of tokens (string), using the tokenizer.
|
| 258 |
+
Split in words for word-based vocabulary or sub-words for sub-word-based
|
| 259 |
+
vocabularies (BPE/SentencePieces/WordPieces).
|
| 260 |
+
|
| 261 |
+
Take care of added tokens.
|
| 262 |
+
|
| 263 |
+
Args:
|
| 264 |
+
text (:obj:`string`): The sequence to be encoded.
|
| 265 |
+
**kwargs (:obj: `dict`): Arguments passed to the model-specific `prepare_for_tokenization` preprocessing method.
|
| 266 |
+
"""
|
| 267 |
+
# Simple mapping string => AddedToken for special tokens with specific tokenization behaviors
|
| 268 |
+
all_special_tokens_extended = dict(
|
| 269 |
+
(str(t), t) for t in self.all_special_tokens_extended if isinstance(t, AddedToken)
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
text, kwargs = self.prepare_for_tokenization(text, **kwargs)
|
| 273 |
+
|
| 274 |
+
if kwargs:
|
| 275 |
+
logger.warning(f"Keyword arguments {kwargs} not recognized.")
|
| 276 |
+
|
| 277 |
+
# TODO: should this be in the base class?
|
| 278 |
+
if self.init_kwargs.get("do_lower_case", False):
|
| 279 |
+
# convert non-special tokens to lowercase
|
| 280 |
+
escaped_special_toks = [re.escape(s_tok) for s_tok in self.all_special_tokens]
|
| 281 |
+
pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
|
| 282 |
+
text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)
|
| 283 |
+
|
| 284 |
+
def split_on_token(tok, text):
|
| 285 |
+
result = []
|
| 286 |
+
tok_extended = all_special_tokens_extended.get(tok, None)
|
| 287 |
+
split_text = text.split(tok)
|
| 288 |
+
full_word = ""
|
| 289 |
+
for i, sub_text in enumerate(split_text):
|
| 290 |
+
# AddedToken can control whitespace stripping around them.
|
| 291 |
+
# We use them for GPT2 and Roberta to have different behavior depending on the special token
|
| 292 |
+
# Cf. https://github.com/huggingface/transformers/pull/2778
|
| 293 |
+
# and https://github.com/huggingface/transformers/issues/3788
|
| 294 |
+
if isinstance(tok_extended, AddedToken):
|
| 295 |
+
if tok_extended.single_word:
|
| 296 |
+
# Try to avoid splitting on token
|
| 297 |
+
if (
|
| 298 |
+
i < len(split_text) - 1
|
| 299 |
+
and not _is_end_of_word(sub_text)
|
| 300 |
+
and not _is_start_of_word(split_text[i + 1])
|
| 301 |
+
):
|
| 302 |
+
# Don't extract the special token
|
| 303 |
+
full_word += sub_text + tok
|
| 304 |
+
elif full_word:
|
| 305 |
+
full_word += sub_text
|
| 306 |
+
result += [full_word]
|
| 307 |
+
full_word = ""
|
| 308 |
+
continue
|
| 309 |
+
# Strip white spaces on the right
|
| 310 |
+
if tok_extended.rstrip and i > 0:
|
| 311 |
+
# A bit counter-intuitive but we strip the left of the string
|
| 312 |
+
# since tok_extended.rstrip means the special token is eating all white spaces on its right
|
| 313 |
+
sub_text = sub_text.lstrip()
|
| 314 |
+
# Strip white spaces on the left
|
| 315 |
+
if tok_extended.lstrip and i < len(split_text) - 1:
|
| 316 |
+
sub_text = sub_text.rstrip() # Opposite here
|
| 317 |
+
else:
|
| 318 |
+
# We strip left and right by default
|
| 319 |
+
if i < len(split_text) - 1:
|
| 320 |
+
sub_text = sub_text.rstrip()
|
| 321 |
+
if i > 0:
|
| 322 |
+
sub_text = sub_text.lstrip()
|
| 323 |
+
|
| 324 |
+
if i == 0 and not sub_text:
|
| 325 |
+
result += [tok]
|
| 326 |
+
elif i == len(split_text) - 1:
|
| 327 |
+
if sub_text:
|
| 328 |
+
result += [sub_text]
|
| 329 |
+
else:
|
| 330 |
+
pass
|
| 331 |
+
else:
|
| 332 |
+
if sub_text:
|
| 333 |
+
result += [sub_text]
|
| 334 |
+
result += [tok]
|
| 335 |
+
return result
|
| 336 |
+
|
| 337 |
+
def split_on_tokens(tok_list, text):
|
| 338 |
+
if not text.strip():
|
| 339 |
+
return []
|
| 340 |
+
if not tok_list:
|
| 341 |
+
return self._tokenize(text)
|
| 342 |
+
|
| 343 |
+
tokenized_text = []
|
| 344 |
+
text_list = [text]
|
| 345 |
+
for tok in tok_list:
|
| 346 |
+
tokenized_text = []
|
| 347 |
+
for sub_text in text_list:
|
| 348 |
+
if sub_text not in self.unique_no_split_tokens:
|
| 349 |
+
tokenized_text += split_on_token(tok, sub_text)
|
| 350 |
+
else:
|
| 351 |
+
tokenized_text += [sub_text]
|
| 352 |
+
text_list = tokenized_text
|
| 353 |
+
|
| 354 |
+
return list(
|
| 355 |
+
itertools.chain.from_iterable(
|
| 356 |
+
(
|
| 357 |
+
self._tokenize(token) if token not in self.unique_no_split_tokens else [token]
|
| 358 |
+
for token in tokenized_text
|
| 359 |
+
)
|
| 360 |
+
)
|
| 361 |
+
)
|
| 362 |
+
|
| 363 |
+
no_split_token = self.unique_no_split_tokens
|
| 364 |
+
tokenized_text = split_on_tokens(no_split_token, text)
|
| 365 |
+
return tokenized_text
|
| 366 |
+
|
| 367 |
+
def _tokenize(self, text, **kwargs):
|
| 368 |
+
""" Converts a string in a sequence of tokens (string), using the tokenizer.
|
| 369 |
+
Split in words for word-based vocabulary or sub-words for sub-word-based
|
| 370 |
+
vocabularies (BPE/SentencePieces/WordPieces).
|
| 371 |
+
|
| 372 |
+
Do NOT take care of added tokens.
|
| 373 |
+
"""
|
| 374 |
+
raise NotImplementedError
|
| 375 |
+
|
| 376 |
+
def convert_tokens_to_ids(self, tokens):
|
| 377 |
+
""" Converts a token string (or a sequence of tokens) in a single integer id
|
| 378 |
+
(or a sequence of ids), using the vocabulary.
|
| 379 |
+
"""
|
| 380 |
+
if tokens is None:
|
| 381 |
+
return None
|
| 382 |
+
|
| 383 |
+
if isinstance(tokens, str):
|
| 384 |
+
return self._convert_token_to_id_with_added_voc(tokens)
|
| 385 |
+
|
| 386 |
+
ids = []
|
| 387 |
+
for token in tokens:
|
| 388 |
+
ids.append(self._convert_token_to_id_with_added_voc(token))
|
| 389 |
+
return ids
|
| 390 |
+
|
| 391 |
+
def _convert_token_to_id_with_added_voc(self, token):
|
| 392 |
+
if token is None:
|
| 393 |
+
return None
|
| 394 |
+
|
| 395 |
+
if token in self.added_tokens_encoder:
|
| 396 |
+
return self.added_tokens_encoder[token]
|
| 397 |
+
return self._convert_token_to_id(token)
|
| 398 |
+
|
| 399 |
+
def _convert_token_to_id(self, token):
|
| 400 |
+
raise NotImplementedError
|
| 401 |
+
|
| 402 |
+
def _encode_plus(
|
| 403 |
+
self,
|
| 404 |
+
text: Union[TextInput, PreTokenizedInput, EncodedInput],
|
| 405 |
+
text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,
|
| 406 |
+
add_special_tokens: bool = True,
|
| 407 |
+
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
|
| 408 |
+
truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
|
| 409 |
+
max_length: Optional[int] = None,
|
| 410 |
+
stride: int = 0,
|
| 411 |
+
is_pretokenized: bool = False,
|
| 412 |
+
pad_to_multiple_of: Optional[int] = None,
|
| 413 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 414 |
+
return_token_type_ids: Optional[bool] = None,
|
| 415 |
+
return_attention_mask: Optional[bool] = None,
|
| 416 |
+
return_overflowing_tokens: bool = False,
|
| 417 |
+
return_special_tokens_mask: bool = False,
|
| 418 |
+
return_offsets_mapping: bool = False,
|
| 419 |
+
return_length: bool = False,
|
| 420 |
+
verbose: bool = True,
|
| 421 |
+
**kwargs
|
| 422 |
+
) -> BatchEncoding:
|
| 423 |
+
def get_input_ids(text):
|
| 424 |
+
if isinstance(text, str):
|
| 425 |
+
tokens = self.tokenize(text, **kwargs)
|
| 426 |
+
return self.convert_tokens_to_ids(tokens)
|
| 427 |
+
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
|
| 428 |
+
if is_pretokenized:
|
| 429 |
+
tokens = list(itertools.chain(*(self.tokenize(t, is_pretokenized=True, **kwargs) for t in text)))
|
| 430 |
+
return self.convert_tokens_to_ids(tokens)
|
| 431 |
+
else:
|
| 432 |
+
return self.convert_tokens_to_ids(text)
|
| 433 |
+
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
|
| 434 |
+
return text
|
| 435 |
+
else:
|
| 436 |
+
if is_pretokenized:
|
| 437 |
+
raise ValueError(
|
| 438 |
+
f"Input {text} is not valid. Should be a string or a list/tuple of strings when `is_pretokenized=True`."
|
| 439 |
+
)
|
| 440 |
+
else:
|
| 441 |
+
raise ValueError(
|
| 442 |
+
f"Input {text} is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
|
| 443 |
+
)
|
| 444 |
+
|
| 445 |
+
if return_offsets_mapping:
|
| 446 |
+
raise NotImplementedError(
|
| 447 |
+
"return_offset_mapping is not available when using Python tokenizers."
|
| 448 |
+
"To use this feature, change your tokenizer to one deriving from "
|
| 449 |
+
"transformers.PreTrainedTokenizerFast."
|
| 450 |
+
"More information on available tokenizers at "
|
| 451 |
+
"https://github.com/huggingface/transformers/pull/2674"
|
| 452 |
+
)
|
| 453 |
+
|
| 454 |
+
first_ids = get_input_ids(text)
|
| 455 |
+
second_ids = get_input_ids(text_pair) if text_pair is not None else None
|
| 456 |
+
|
| 457 |
+
return self.prepare_for_model(
|
| 458 |
+
first_ids,
|
| 459 |
+
pair_ids=second_ids,
|
| 460 |
+
add_special_tokens=add_special_tokens,
|
| 461 |
+
padding=padding_strategy.value,
|
| 462 |
+
truncation=truncation_strategy.value,
|
| 463 |
+
max_length=max_length,
|
| 464 |
+
stride=stride,
|
| 465 |
+
pad_to_multiple_of=pad_to_multiple_of,
|
| 466 |
+
return_tensors=return_tensors,
|
| 467 |
+
prepend_batch_axis=True,
|
| 468 |
+
return_attention_mask=return_attention_mask,
|
| 469 |
+
return_token_type_ids=return_token_type_ids,
|
| 470 |
+
return_overflowing_tokens=return_overflowing_tokens,
|
| 471 |
+
return_special_tokens_mask=return_special_tokens_mask,
|
| 472 |
+
return_length=return_length,
|
| 473 |
+
verbose=verbose,
|
| 474 |
+
)
|
| 475 |
+
|
| 476 |
+
def _batch_encode_plus(
|
| 477 |
+
self,
|
| 478 |
+
batch_text_or_text_pairs: Union[
|
| 479 |
+
List[TextInput],
|
| 480 |
+
List[TextInputPair],
|
| 481 |
+
List[PreTokenizedInput],
|
| 482 |
+
List[PreTokenizedInputPair],
|
| 483 |
+
List[EncodedInput],
|
| 484 |
+
List[EncodedInputPair],
|
| 485 |
+
],
|
| 486 |
+
add_special_tokens: bool = True,
|
| 487 |
+
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
|
| 488 |
+
truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
|
| 489 |
+
max_length: Optional[int] = None,
|
| 490 |
+
stride: int = 0,
|
| 491 |
+
is_pretokenized: bool = False,
|
| 492 |
+
pad_to_multiple_of: Optional[int] = None,
|
| 493 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
| 494 |
+
return_token_type_ids: Optional[bool] = None,
|
| 495 |
+
return_attention_mask: Optional[bool] = None,
|
| 496 |
+
return_overflowing_tokens: bool = False,
|
| 497 |
+
return_special_tokens_mask: bool = False,
|
| 498 |
+
return_offsets_mapping: bool = False,
|
| 499 |
+
return_length: bool = False,
|
| 500 |
+
verbose: bool = True,
|
| 501 |
+
**kwargs
|
| 502 |
+
) -> BatchEncoding:
|
| 503 |
+
def get_input_ids(text):
|
| 504 |
+
if isinstance(text, str):
|
| 505 |
+
tokens = self.tokenize(text, **kwargs)
|
| 506 |
+
return self.convert_tokens_to_ids(tokens)
|
| 507 |
+
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
|
| 508 |
+
if is_pretokenized:
|
| 509 |
+
tokens = list(itertools.chain(*(self.tokenize(t, is_pretokenized=True, **kwargs) for t in text)))
|
| 510 |
+
return self.convert_tokens_to_ids(tokens)
|
| 511 |
+
else:
|
| 512 |
+
return self.convert_tokens_to_ids(text)
|
| 513 |
+
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
|
| 514 |
+
return text
|
| 515 |
+
else:
|
| 516 |
+
raise ValueError(
|
| 517 |
+
"Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
|
| 518 |
+
)
|
| 519 |
+
|
| 520 |
+
if return_offsets_mapping:
|
| 521 |
+
raise NotImplementedError(
|
| 522 |
+
"return_offset_mapping is not available when using Python tokenizers."
|
| 523 |
+
"To use this feature, change your tokenizer to one deriving from "
|
| 524 |
+
"transformers.PreTrainedTokenizerFast."
|
| 525 |
+
)
|
| 526 |
+
|
| 527 |
+
input_ids = []
|
| 528 |
+
for ids_or_pair_ids in batch_text_or_text_pairs:
|
| 529 |
+
if not isinstance(ids_or_pair_ids, (list, tuple)):
|
| 530 |
+
ids, pair_ids = ids_or_pair_ids, None
|
| 531 |
+
elif is_pretokenized and not isinstance(ids_or_pair_ids[0], (list, tuple)):
|
| 532 |
+
ids, pair_ids = ids_or_pair_ids, None
|
| 533 |
+
else:
|
| 534 |
+
ids, pair_ids = ids_or_pair_ids
|
| 535 |
+
|
| 536 |
+
first_ids = get_input_ids(ids)
|
| 537 |
+
second_ids = get_input_ids(pair_ids) if pair_ids is not None else None
|
| 538 |
+
input_ids.append((first_ids, second_ids))
|
| 539 |
+
|
| 540 |
+
batch_outputs = self._batch_prepare_for_model(
|
| 541 |
+
input_ids,
|
| 542 |
+
add_special_tokens=add_special_tokens,
|
| 543 |
+
padding_strategy=padding_strategy,
|
| 544 |
+
truncation_strategy=truncation_strategy,
|
| 545 |
+
max_length=max_length,
|
| 546 |
+
stride=stride,
|
| 547 |
+
pad_to_multiple_of=pad_to_multiple_of,
|
| 548 |
+
return_attention_mask=return_attention_mask,
|
| 549 |
+
return_token_type_ids=return_token_type_ids,
|
| 550 |
+
return_overflowing_tokens=return_overflowing_tokens,
|
| 551 |
+
return_special_tokens_mask=return_special_tokens_mask,
|
| 552 |
+
return_length=return_length,
|
| 553 |
+
return_tensors=return_tensors,
|
| 554 |
+
verbose=verbose,
|
| 555 |
+
)
|
| 556 |
+
|
| 557 |
+
return BatchEncoding(batch_outputs)
|
| 558 |
+
|
| 559 |
+
@add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
|
| 560 |
+
def _batch_prepare_for_model(
|
| 561 |
+
self,
|
| 562 |
+
batch_ids_pairs: List[Union[PreTokenizedInputPair, Tuple[List[int], None]]],
|
| 563 |
+
add_special_tokens: bool = True,
|
| 564 |
+
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
|
| 565 |
+
truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
|
| 566 |
+
max_length: Optional[int] = None,
|
| 567 |
+
stride: int = 0,
|
| 568 |
+
pad_to_multiple_of: Optional[int] = None,
|
| 569 |
+
return_tensors: Optional[str] = None,
|
| 570 |
+
return_token_type_ids: Optional[bool] = None,
|
| 571 |
+
return_attention_mask: Optional[bool] = None,
|
| 572 |
+
return_overflowing_tokens: bool = False,
|
| 573 |
+
return_special_tokens_mask: bool = False,
|
| 574 |
+
return_length: bool = False,
|
| 575 |
+
verbose: bool = True,
|
| 576 |
+
) -> BatchEncoding:
|
| 577 |
+
""" Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model.
|
| 578 |
+
It adds special tokens, truncates sequences if overflowing while taking into account the special tokens and
|
| 579 |
+
manages a moving window (with user defined stride) for overflowing tokens
|
| 580 |
+
|
| 581 |
+
Args:
|
| 582 |
+
batch_ids_pairs: list of tokenized input ids or input ids pairs
|
| 583 |
+
"""
|
| 584 |
+
|
| 585 |
+
batch_outputs = {}
|
| 586 |
+
for first_ids, second_ids in batch_ids_pairs:
|
| 587 |
+
outputs = self.prepare_for_model(
|
| 588 |
+
first_ids,
|
| 589 |
+
second_ids,
|
| 590 |
+
add_special_tokens=add_special_tokens,
|
| 591 |
+
padding=PaddingStrategy.DO_NOT_PAD.value, # we pad in batch afterward
|
| 592 |
+
truncation=truncation_strategy.value,
|
| 593 |
+
max_length=max_length,
|
| 594 |
+
stride=stride,
|
| 595 |
+
pad_to_multiple_of=None, # we pad in batch afterward
|
| 596 |
+
return_attention_mask=False, # we pad in batch afterward
|
| 597 |
+
return_token_type_ids=return_token_type_ids,
|
| 598 |
+
return_overflowing_tokens=return_overflowing_tokens,
|
| 599 |
+
return_special_tokens_mask=return_special_tokens_mask,
|
| 600 |
+
return_length=return_length,
|
| 601 |
+
return_tensors=None, # We convert the whole batch to tensors at the end
|
| 602 |
+
prepend_batch_axis=False,
|
| 603 |
+
verbose=verbose,
|
| 604 |
+
)
|
| 605 |
+
|
| 606 |
+
for key, value in outputs.items():
|
| 607 |
+
if key not in batch_outputs:
|
| 608 |
+
batch_outputs[key] = []
|
| 609 |
+
batch_outputs[key].append(value)
|
| 610 |
+
|
| 611 |
+
batch_outputs = self.pad(
|
| 612 |
+
batch_outputs,
|
| 613 |
+
padding=padding_strategy.value,
|
| 614 |
+
max_length=max_length,
|
| 615 |
+
pad_to_multiple_of=pad_to_multiple_of,
|
| 616 |
+
return_attention_mask=return_attention_mask,
|
| 617 |
+
)
|
| 618 |
+
|
| 619 |
+
batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors)
|
| 620 |
+
|
| 621 |
+
return batch_outputs
|
| 622 |
+
|
| 623 |
+
def prepare_for_tokenization(self, text: str, is_pretokenized=False, **kwargs) -> (str, dict):
|
| 624 |
+
""" Performs any necessary transformations before tokenization.
|
| 625 |
+
|
| 626 |
+
This method should pop the arguments from kwargs and return kwargs as well.
|
| 627 |
+
We test kwargs at the end of the encoding process to be sure all the arguments have been used.
|
| 628 |
+
"""
|
| 629 |
+
return (text, kwargs)
|
| 630 |
+
|
| 631 |
+
def get_special_tokens_mask(
|
| 632 |
+
self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
|
| 633 |
+
) -> List[int]:
|
| 634 |
+
"""
|
| 635 |
+
Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
|
| 636 |
+
special tokens using the tokenizer ``prepare_for_model`` method.
|
| 637 |
+
|
| 638 |
+
Args:
|
| 639 |
+
token_ids_0: list of ids (must not contain special tokens)
|
| 640 |
+
token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids
|
| 641 |
+
for sequence pairs
|
| 642 |
+
already_has_special_tokens: (default False) Set to True if the token list is already formated with
|
| 643 |
+
special tokens for the model
|
| 644 |
+
|
| 645 |
+
Returns:
|
| 646 |
+
A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
| 647 |
+
"""
|
| 648 |
+
return [0] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0))
|
| 649 |
+
|
| 650 |
+
def convert_ids_to_tokens(
|
| 651 |
+
self, ids: Union[int, List[int]], skip_special_tokens: bool = False
|
| 652 |
+
) -> Union[str, List[str]]:
|
| 653 |
+
""" Converts a single index or a sequence of indices (integers) in a token "
|
| 654 |
+
(resp.) a sequence of tokens (str), using the vocabulary and added tokens.
|
| 655 |
+
|
| 656 |
+
Args:
|
| 657 |
+
skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False
|
| 658 |
+
"""
|
| 659 |
+
if isinstance(ids, int):
|
| 660 |
+
if ids in self.added_tokens_decoder:
|
| 661 |
+
return self.added_tokens_decoder[ids]
|
| 662 |
+
else:
|
| 663 |
+
return self._convert_id_to_token(ids)
|
| 664 |
+
tokens = []
|
| 665 |
+
for index in ids:
|
| 666 |
+
index = int(index)
|
| 667 |
+
if skip_special_tokens and index in self.all_special_ids:
|
| 668 |
+
continue
|
| 669 |
+
if index in self.added_tokens_decoder:
|
| 670 |
+
tokens.append(self.added_tokens_decoder[index])
|
| 671 |
+
else:
|
| 672 |
+
tokens.append(self._convert_id_to_token(index))
|
| 673 |
+
return tokens
|
| 674 |
+
|
| 675 |
+
def _convert_id_to_token(self, index: int) -> str:
|
| 676 |
+
raise NotImplementedError
|
| 677 |
+
|
| 678 |
+
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
| 679 |
+
""" Converts a sequence of tokens (string) in a single string.
|
| 680 |
+
The most simple way to do it is ' '.join(self.convert_ids_to_tokens(token_ids))
|
| 681 |
+
but we often want to remove sub-word tokenization artifacts at the same time.
|
| 682 |
+
"""
|
| 683 |
+
return " ".join(self.convert_ids_to_tokens(tokens))
|
| 684 |
+
|
| 685 |
+
def decode(
|
| 686 |
+
self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True
|
| 687 |
+
) -> str:
|
| 688 |
+
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
|
| 689 |
+
|
| 690 |
+
# To avoid mixing byte-level and unicode for byte-level BPT
|
| 691 |
+
# we need to build string separatly for added tokens and byte-level tokens
|
| 692 |
+
# cf. https://github.com/huggingface/transformers/issues/1133
|
| 693 |
+
sub_texts = []
|
| 694 |
+
current_sub_text = []
|
| 695 |
+
for token in filtered_tokens:
|
| 696 |
+
if skip_special_tokens and token in self.all_special_ids:
|
| 697 |
+
continue
|
| 698 |
+
if token in self.added_tokens_encoder:
|
| 699 |
+
if current_sub_text:
|
| 700 |
+
sub_texts.append(self.convert_tokens_to_string(current_sub_text))
|
| 701 |
+
current_sub_text = []
|
| 702 |
+
sub_texts.append(token)
|
| 703 |
+
else:
|
| 704 |
+
current_sub_text.append(token)
|
| 705 |
+
if current_sub_text:
|
| 706 |
+
sub_texts.append(self.convert_tokens_to_string(current_sub_text))
|
| 707 |
+
text = " ".join(sub_texts)
|
| 708 |
+
|
| 709 |
+
if clean_up_tokenization_spaces:
|
| 710 |
+
clean_text = self.clean_up_tokenization(text)
|
| 711 |
+
return clean_text
|
| 712 |
+
else:
|
| 713 |
+
return text
|
| 714 |
+
|
| 715 |
+
def save_vocabulary(self, save_directory) -> Tuple[str]:
|
| 716 |
+
""" Save the tokenizer vocabulary to a directory. This method does *NOT* save added tokens
|
| 717 |
+
and special token mappings.
|
| 718 |
+
|
| 719 |
+
Please use :func:`~transformers.PreTrainedTokenizer.save_pretrained` `()` to save the full
|
| 720 |
+
Tokenizer state if you want to reload it using the :func:`~transformers.PreTrainedTokenizer.from_pretrained`
|
| 721 |
+
class method.
|
| 722 |
+
"""
|
| 723 |
+
raise NotImplementedError
|
CGFormer/bert/tokenization_utils_base.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
CGFormer/ckpts/swin_base_patch4_window12_384_22k.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:70812ab6b0a7a38712409d13976df9431632466eaacf991d5e90d9a1e91f3ab1
|
| 3 |
+
size 450809979
|
CGFormer/config/config_gref_ace.yaml
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
DATA:
|
| 2 |
+
dataset: refcocog_u
|
| 3 |
+
train_split: train
|
| 4 |
+
train_lmdb: data/lmdb/refcocog_u/train.lmdb
|
| 5 |
+
val_split: val
|
| 6 |
+
val_lmdb: data/lmdb/refcocog_u/val.lmdb
|
| 7 |
+
mask_root: data/masks/refcocog_u
|
| 8 |
+
TRAIN:
|
| 9 |
+
swin_type: base
|
| 10 |
+
swin_pretrain: ckpts/swin_base_patch4_window12_384_22k.pth
|
| 11 |
+
bert: bert-base-uncased
|
| 12 |
+
mha: '8-8-8-8'
|
| 13 |
+
input_size: 480
|
| 14 |
+
word_len: 20
|
| 15 |
+
word_dim: 768
|
| 16 |
+
vis_dim: 512
|
| 17 |
+
num_token: 2
|
| 18 |
+
token_dim: 512
|
| 19 |
+
sync_bn: True
|
| 20 |
+
dropout: 0.
|
| 21 |
+
fusion_drop: 0.
|
| 22 |
+
workers: 32 # data loader workers
|
| 23 |
+
workers_val: 8
|
| 24 |
+
batch_size: 64 # batch size for training
|
| 25 |
+
batch_size_val: 16 # batch size for validation during training, memory and speed tradeoff
|
| 26 |
+
start_epoch: 0
|
| 27 |
+
epochs: 50
|
| 28 |
+
lr_backbone: 5.e-5
|
| 29 |
+
lr_text_encoder: 5.e-5
|
| 30 |
+
lr: 1.e-4
|
| 31 |
+
weight_decay: 1.e-4
|
| 32 |
+
amsgrad: True
|
| 33 |
+
manual_seed:
|
| 34 |
+
print_freq: 100
|
| 35 |
+
exp_name: cgformer_test
|
| 36 |
+
output_folder: /data/seunghoon/CGFormer/exp/seunghoon
|
| 37 |
+
save_freq: 1
|
| 38 |
+
weight:
|
| 39 |
+
resume:
|
| 40 |
+
evaluate: True
|
| 41 |
+
metric_learning: True
|
| 42 |
+
exclude_multiobj: True
|
| 43 |
+
metric_mode: hardpos_only_refined
|
| 44 |
+
metric_loss_weight: 0.1
|
| 45 |
+
loss_option: ACE_verbonly
|
| 46 |
+
margin_value: 12
|
| 47 |
+
temperature: 0.07
|
| 48 |
+
hp_selection: strict
|
| 49 |
+
filter_threshold: 0.5
|
| 50 |
+
mixup_lasttwo : False
|
| 51 |
+
use_projections : False
|
| 52 |
+
|
| 53 |
+
Distributed:
|
| 54 |
+
# dist_url: tcp://localhost:18123
|
| 55 |
+
dist_backend: 'nccl'
|
| 56 |
+
# multiprocessing_distributed: True
|
| 57 |
+
world_size: 1
|
| 58 |
+
# rank: 0
|
| 59 |
+
TEST:
|
| 60 |
+
window12: True # if use window12 pretrained for training, testing set true
|
| 61 |
+
test_split: test
|
| 62 |
+
test_lmdb: data/lmdb/refcocog_u/test.lmdb
|
| 63 |
+
visualize: False
|
CGFormer/config/config_mosaic_refcocog_u.yaml
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
DATA:
|
| 2 |
+
dataset: refcocog_u
|
| 3 |
+
train_split: train
|
| 4 |
+
train_lmdb: data/lmdb/refcocog_u/train.lmdb
|
| 5 |
+
val_split: val
|
| 6 |
+
val_lmdb: data/lmdb/refcocog_u/val.lmdb
|
| 7 |
+
mask_root: data/masks/refcocog_u
|
| 8 |
+
TRAIN:
|
| 9 |
+
swin_type: base
|
| 10 |
+
swin_pretrain: ckpts/swin_base_patch4_window12_384_22k.pth
|
| 11 |
+
bert: bert-base-uncased
|
| 12 |
+
mha: '8-8-8-8'
|
| 13 |
+
input_size: 480
|
| 14 |
+
word_len: 20
|
| 15 |
+
word_dim: 768
|
| 16 |
+
vis_dim: 512
|
| 17 |
+
num_token: 2
|
| 18 |
+
token_dim: 512
|
| 19 |
+
sync_bn: True
|
| 20 |
+
dropout: 0.
|
| 21 |
+
fusion_drop: 0.
|
| 22 |
+
workers: 32 # data loader workers
|
| 23 |
+
workers_val: 8
|
| 24 |
+
batch_size: 64 # batch size for training
|
| 25 |
+
batch_size_val: 16 # batch size for validation during training, memory and speed tradeoff
|
| 26 |
+
start_epoch: 0
|
| 27 |
+
epochs: 50
|
| 28 |
+
lr_backbone: 5.e-5
|
| 29 |
+
lr_text_encoder: 5.e-5
|
| 30 |
+
lr: 1.e-4
|
| 31 |
+
weight_decay: 1.e-4
|
| 32 |
+
amsgrad: True
|
| 33 |
+
manual_seed:
|
| 34 |
+
print_freq: 100
|
| 35 |
+
exp_name: cgformer
|
| 36 |
+
output_folder: exp/mosaic_refcocog_u/
|
| 37 |
+
save_freq: 1
|
| 38 |
+
weight:
|
| 39 |
+
resume:
|
| 40 |
+
evaluate: True
|
| 41 |
+
Distributed:
|
| 42 |
+
dist_url: tcp://localhost:12345
|
| 43 |
+
dist_backend: 'nccl'
|
| 44 |
+
multiprocessing_distributed: True
|
| 45 |
+
world_size: 1
|
| 46 |
+
rank: 0
|
| 47 |
+
TEST:
|
| 48 |
+
window12: True # if use window12 pretrained for training, testing set true
|
| 49 |
+
test_split: val
|
| 50 |
+
test_lmdb: data/lmdb/refcocog_u/val.lmdb
|
| 51 |
+
visualize: False
|
CGFormer/config/config_rcc_ace.yaml
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
DATA:
|
| 2 |
+
dataset: refcoco
|
| 3 |
+
train_split: train
|
| 4 |
+
train_lmdb: data/lmdb/refcoco/train.lmdb
|
| 5 |
+
val_split: val
|
| 6 |
+
val_lmdb: data/lmdb/refcoco/val.lmdb
|
| 7 |
+
mask_root: data/masks/refcoco
|
| 8 |
+
TRAIN:
|
| 9 |
+
swin_type: base
|
| 10 |
+
swin_pretrain: ckpts/swin_base_patch4_window12_384_22k.pth
|
| 11 |
+
bert: bert-base-uncased
|
| 12 |
+
mha: '8-8-8-8'
|
| 13 |
+
input_size: 480
|
| 14 |
+
word_len: 20
|
| 15 |
+
word_dim: 768
|
| 16 |
+
vis_dim: 512
|
| 17 |
+
num_token: 2
|
| 18 |
+
token_dim: 512
|
| 19 |
+
sync_bn: True
|
| 20 |
+
dropout: 0.
|
| 21 |
+
fusion_drop: 0.
|
| 22 |
+
workers: 16 # data loader workers
|
| 23 |
+
workers_val: 8
|
| 24 |
+
batch_size: 64 #batch size for training
|
| 25 |
+
batch_size_val: 24 # batch size for validation during training, memory and speed tradeoff
|
| 26 |
+
start_epoch: 0
|
| 27 |
+
epochs: 50
|
| 28 |
+
lr_backbone: 5.e-5
|
| 29 |
+
lr_text_encoder: 5.e-5
|
| 30 |
+
lr: 1.e-4
|
| 31 |
+
weight_decay: 1.e-4
|
| 32 |
+
amsgrad: True
|
| 33 |
+
manual_seed:
|
| 34 |
+
print_freq: 100
|
| 35 |
+
exp_name: cgformer_test
|
| 36 |
+
output_folder: /data/seunghoon/CGFormer/exp/seunghoon
|
| 37 |
+
save_freq: 1
|
| 38 |
+
weight:
|
| 39 |
+
resume:
|
| 40 |
+
evaluate: True
|
| 41 |
+
metric_learning: False
|
| 42 |
+
exclude_multiobj: true
|
| 43 |
+
metric_mode: hardpos_only_refined
|
| 44 |
+
metric_loss_weight: 0.1
|
| 45 |
+
loss_option: ACE_verbonly
|
| 46 |
+
margin_value: 12
|
| 47 |
+
temperature: 0.07
|
| 48 |
+
hp_selection: strict
|
| 49 |
+
use_projections : True
|
| 50 |
+
mixup_lasttwo : False
|
| 51 |
+
filter_threshold: 0.68
|
| 52 |
+
|
| 53 |
+
Distributed:
|
| 54 |
+
# dist_url: tcp://localhost:18123
|
| 55 |
+
dist_backend: 'nccl'
|
| 56 |
+
# multiprocessing_distributed: True
|
| 57 |
+
world_size: 1
|
| 58 |
+
# rank: 0
|
| 59 |
+
TEST:
|
| 60 |
+
window12: True # if use window12 pretrained for training, testing set true
|
| 61 |
+
test_split: test
|
| 62 |
+
test_lmdb: data/lmdb/refcocog_u/test.lmdb
|
| 63 |
+
visualize: False
|
CGFormer/config/config_rccp_ace.yaml
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
DATA:
|
| 2 |
+
dataset: refcoco+
|
| 3 |
+
train_split: train
|
| 4 |
+
train_lmdb: data/lmdb/refcoco+/train.lmdb
|
| 5 |
+
val_split: val
|
| 6 |
+
val_lmdb: data/lmdb/refcoco+/val.lmdb
|
| 7 |
+
mask_root: data/masks/refcoco+
|
| 8 |
+
TRAIN:
|
| 9 |
+
swin_type: base
|
| 10 |
+
swin_pretrain: ckpts/swin_base_patch4_window12_384_22k.pth
|
| 11 |
+
bert: bert-base-uncased
|
| 12 |
+
mha: '8-8-8-8'
|
| 13 |
+
input_size: 480
|
| 14 |
+
word_len: 20
|
| 15 |
+
word_dim: 768
|
| 16 |
+
vis_dim: 512
|
| 17 |
+
num_token: 2
|
| 18 |
+
token_dim: 512
|
| 19 |
+
sync_bn: True
|
| 20 |
+
dropout: 0.
|
| 21 |
+
fusion_drop: 0.
|
| 22 |
+
workers: 16 # data loader workers
|
| 23 |
+
workers_val: 8
|
| 24 |
+
batch_size: 64 #batch size for training
|
| 25 |
+
batch_size_val: 24 # batch size for validation during training, memory and speed tradeoff
|
| 26 |
+
start_epoch: 0
|
| 27 |
+
epochs: 50
|
| 28 |
+
lr_backbone: 5.e-5
|
| 29 |
+
lr_text_encoder: 5.e-5
|
| 30 |
+
lr: 1.e-4
|
| 31 |
+
weight_decay: 1.e-4
|
| 32 |
+
amsgrad: True
|
| 33 |
+
manual_seed:
|
| 34 |
+
print_freq: 100
|
| 35 |
+
exp_name: cgformer_test
|
| 36 |
+
output_folder: /data/seunghoon/CGFormer/exp/seunghoon
|
| 37 |
+
save_freq: 1
|
| 38 |
+
weight:
|
| 39 |
+
resume:
|
| 40 |
+
evaluate: True
|
| 41 |
+
metric_learning: False
|
| 42 |
+
exclude_multiobj: true
|
| 43 |
+
metric_mode: hardpos_only_refined
|
| 44 |
+
metric_loss_weight: 0.1
|
| 45 |
+
loss_option: ACE_verbonly
|
| 46 |
+
margin_value: 12
|
| 47 |
+
temperature: 0.07
|
| 48 |
+
hp_selection: strict
|
| 49 |
+
use_projections : True
|
| 50 |
+
mixup_lasttwo : False
|
| 51 |
+
filter_threshold: 0.68
|
| 52 |
+
|
| 53 |
+
Distributed:
|
| 54 |
+
# dist_url: tcp://localhost:18123
|
| 55 |
+
dist_backend: 'nccl'
|
| 56 |
+
# multiprocessing_distributed: True
|
| 57 |
+
world_size: 1
|
| 58 |
+
# rank: 0
|
| 59 |
+
TEST:
|
| 60 |
+
window12: True # if use window12 pretrained for training, testing set true
|
| 61 |
+
test_split: test
|
| 62 |
+
test_lmdb: data/lmdb/refcocog_u/test.lmdb
|
| 63 |
+
visualize: False
|
CGFormer/config/config_refzom_ace.yaml
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
DATA:
|
| 2 |
+
dataset: ref-zom
|
| 3 |
+
train_split: train
|
| 4 |
+
train_lmdb: /data2/projects/chaeyun/VerbCentric_RIS/datasets/lmdb/ref-zom/train.lmdb
|
| 5 |
+
val_split: test
|
| 6 |
+
val_lmdb: /data2/projects/chaeyun/VerbCentric_RIS/datasets/lmdb/ref-zom/test.lmdb
|
| 7 |
+
mask_root: /data2/projects/chaeyun/VerbCentric_RIS/datasets/masks/ref-zom
|
| 8 |
+
TRAIN:
|
| 9 |
+
swin_type: base
|
| 10 |
+
swin_pretrain: ckpts/swin_base_patch4_window12_384_22k.pth
|
| 11 |
+
bert: bert-base-uncased
|
| 12 |
+
mha: '8-8-8-8'
|
| 13 |
+
input_size: 480
|
| 14 |
+
word_len: 20
|
| 15 |
+
word_dim: 768
|
| 16 |
+
vis_dim: 512
|
| 17 |
+
num_token: 2
|
| 18 |
+
token_dim: 512
|
| 19 |
+
sync_bn: True
|
| 20 |
+
dropout: 0.
|
| 21 |
+
fusion_drop: 0.
|
| 22 |
+
workers: 32 # data loader workers
|
| 23 |
+
workers_val: 8
|
| 24 |
+
batch_size: 64 # batch size for training
|
| 25 |
+
batch_size_val: 16 # batch size for validation during training, memory and speed tradeoff
|
| 26 |
+
start_epoch: 0
|
| 27 |
+
epochs: 50
|
| 28 |
+
lr_backbone: 5.e-5
|
| 29 |
+
lr_text_encoder: 5.e-5
|
| 30 |
+
lr: 1.e-4
|
| 31 |
+
weight_decay: 1.e-4
|
| 32 |
+
amsgrad: True
|
| 33 |
+
manual_seed:
|
| 34 |
+
print_freq: 100
|
| 35 |
+
exp_name: cgformer_test
|
| 36 |
+
output_folder: /data/seunghoon/CGFormer/exp/seunghoon
|
| 37 |
+
save_freq: 1
|
| 38 |
+
weight:
|
| 39 |
+
resume:
|
| 40 |
+
evaluate: True
|
| 41 |
+
metric_learning: True
|
| 42 |
+
exclude_multiobj: True
|
| 43 |
+
metric_mode: hardpos_only_refined
|
| 44 |
+
metric_loss_weight: 0.1
|
| 45 |
+
loss_option: ACE_verbonly
|
| 46 |
+
margin_value: 12
|
| 47 |
+
temperature: 0.07
|
| 48 |
+
hp_selection: strict
|
| 49 |
+
filter_threshold: 0.5
|
| 50 |
+
mixup_lasttwo : False
|
| 51 |
+
use_projections : False
|
| 52 |
+
fuse_mode : simple_attn
|
| 53 |
+
|
| 54 |
+
Distributed:
|
| 55 |
+
# dist_url: tcp://localhost:18123
|
| 56 |
+
dist_backend: 'nccl'
|
| 57 |
+
# multiprocessing_distributed: True
|
| 58 |
+
world_size: 1
|
| 59 |
+
# rank: 0
|
| 60 |
+
TEST:
|
| 61 |
+
window12: True # if use window12 pretrained for training, testing set true
|
| 62 |
+
test_split: test
|
| 63 |
+
test_lmdb: data/lmdb/refcocog_u/test.lmdb
|
| 64 |
+
visualize: False
|
CGFormer/config/config_refzom_repro.yaml
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
DATA:
|
| 2 |
+
dataset: ref-zom
|
| 3 |
+
train_split: train
|
| 4 |
+
train_lmdb: /data2/projects/chaeyun/VerbCentric_RIS/datasets/lmdb/ref-zom/train.lmdb
|
| 5 |
+
val_split: test
|
| 6 |
+
val_lmdb: /data2/projects/chaeyun/VerbCentric_RIS/datasets/lmdb/ref-zom/test.lmdb
|
| 7 |
+
mask_root: /data2/projects/chaeyun/VerbCentric_RIS/datasets/masks/ref-zom
|
| 8 |
+
TRAIN:
|
| 9 |
+
swin_type: base
|
| 10 |
+
swin_pretrain: ckpts/swin_base_patch4_window12_384_22k.pth
|
| 11 |
+
bert: bert-base-uncased
|
| 12 |
+
mha: '8-8-8-8'
|
| 13 |
+
input_size: 480
|
| 14 |
+
word_len: 20
|
| 15 |
+
word_dim: 768
|
| 16 |
+
vis_dim: 512
|
| 17 |
+
num_token: 2
|
| 18 |
+
token_dim: 512
|
| 19 |
+
sync_bn: True
|
| 20 |
+
dropout: 0.
|
| 21 |
+
fusion_drop: 0.
|
| 22 |
+
workers: 32 # data loader workers
|
| 23 |
+
workers_val: 8
|
| 24 |
+
batch_size: 64 # batch size for training
|
| 25 |
+
batch_size_val: 16 # batch size for validation during training, memory and speed tradeoff
|
| 26 |
+
start_epoch: 0
|
| 27 |
+
epochs: 50
|
| 28 |
+
lr_backbone: 5.e-5
|
| 29 |
+
lr_text_encoder: 5.e-5
|
| 30 |
+
lr: 1.e-4
|
| 31 |
+
weight_decay: 1.e-4
|
| 32 |
+
amsgrad: True
|
| 33 |
+
manual_seed:
|
| 34 |
+
print_freq: 100
|
| 35 |
+
exp_name: cgformer_test
|
| 36 |
+
output_folder: /data/seunghoon/CGFormer/exp/seunghoon
|
| 37 |
+
save_freq: 1
|
| 38 |
+
weight:
|
| 39 |
+
resume:
|
| 40 |
+
evaluate: True
|
| 41 |
+
metric_learning: False
|
| 42 |
+
exclude_multiobj: True
|
| 43 |
+
metric_mode: hardpos_only_refined
|
| 44 |
+
metric_loss_weight: 0.1
|
| 45 |
+
loss_option: ACE_verbonly
|
| 46 |
+
margin_value: 12
|
| 47 |
+
temperature: 0.07
|
| 48 |
+
hp_selection: strict
|
| 49 |
+
filter_threshold: 0.5
|
| 50 |
+
mixup_lasttwo : False
|
| 51 |
+
|
| 52 |
+
Distributed:
|
| 53 |
+
# dist_url: tcp://localhost:18123
|
| 54 |
+
dist_backend: 'nccl'
|
| 55 |
+
# multiprocessing_distributed: True
|
| 56 |
+
world_size: 1
|
| 57 |
+
# rank: 0
|
| 58 |
+
TEST:
|
| 59 |
+
window12: True # if use window12 pretrained for training, testing set true
|
| 60 |
+
test_split: test
|
| 61 |
+
test_lmdb: /data2/projects/chaeyun/VerbCentric_RIS/datasets/lmdb/ref-zom/test.lmdb
|
| 62 |
+
visualize: False
|
CGFormer/config/config_refzom_repro_eval.yaml
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
DATA:
|
| 2 |
+
dataset: ref-zom
|
| 3 |
+
train_split: train
|
| 4 |
+
train_lmdb: /data2/projects/chaeyun/VerbCentric_RIS/datasets/lmdb/ref-zom/train.lmdb
|
| 5 |
+
val_split: test
|
| 6 |
+
val_lmdb: /data2/projects/chaeyun/VerbCentric_RIS/datasets/lmdb/ref-zom/test.lmdb
|
| 7 |
+
mask_root: /data2/projects/chaeyun/VerbCentric_RIS/datasets/masks/ref-zom
|
| 8 |
+
TRAIN:
|
| 9 |
+
swin_type: base
|
| 10 |
+
swin_pretrain: ckpts/swin_base_patch4_window12_384_22k.pth
|
| 11 |
+
bert: bert-base-uncased
|
| 12 |
+
mha: '8-8-8-8'
|
| 13 |
+
input_size: 480
|
| 14 |
+
word_len: 20
|
| 15 |
+
word_dim: 768
|
| 16 |
+
vis_dim: 512
|
| 17 |
+
num_token: 2
|
| 18 |
+
token_dim: 512
|
| 19 |
+
sync_bn: True
|
| 20 |
+
dropout: 0.
|
| 21 |
+
fusion_drop: 0.
|
| 22 |
+
workers: 32 # data loader workers
|
| 23 |
+
workers_val: 8
|
| 24 |
+
batch_size: 64 # batch size for training
|
| 25 |
+
batch_size_val: 16 # batch size for validation during training, memory and speed tradeoff
|
| 26 |
+
start_epoch: 0
|
| 27 |
+
epochs: 50
|
| 28 |
+
lr_backbone: 5.e-5
|
| 29 |
+
lr_text_encoder: 5.e-5
|
| 30 |
+
lr: 1.e-4
|
| 31 |
+
weight_decay: 1.e-4
|
| 32 |
+
amsgrad: True
|
| 33 |
+
manual_seed:
|
| 34 |
+
print_freq: 100
|
| 35 |
+
exp_name: cgformer_test
|
| 36 |
+
output_folder: /data/seunghoon/CGFormer/exp/seunghoon
|
| 37 |
+
save_freq: 1
|
| 38 |
+
weight:
|
| 39 |
+
resume:
|
| 40 |
+
evaluate: True
|
| 41 |
+
metric_learning: False
|
| 42 |
+
exclude_multiobj: True
|
| 43 |
+
metric_mode: hardpos_only_refined
|
| 44 |
+
metric_loss_weight: 0.1
|
| 45 |
+
loss_option: ACE_verbonly
|
| 46 |
+
margin_value: 12
|
| 47 |
+
temperature: 0.07
|
| 48 |
+
hp_selection: strict
|
| 49 |
+
filter_threshold: 0.5
|
| 50 |
+
mixup_lasttwo : False
|
| 51 |
+
|
| 52 |
+
Distributed:
|
| 53 |
+
# dist_url: tcp://localhost:18123
|
| 54 |
+
dist_backend: 'nccl'
|
| 55 |
+
# multiprocessing_distributed: True
|
| 56 |
+
world_size: 1
|
| 57 |
+
# rank: 0
|
| 58 |
+
TEST:
|
| 59 |
+
window12: True # if use window12 pretrained for training, testing set true
|
| 60 |
+
test_split: test
|
| 61 |
+
test_lmdb: /data2/projects/chaeyun/VerbCentric_RIS/datasets/lmdb/ref-zom/test.lmdb
|
| 62 |
+
visualize: False
|
CGFormer/config/impl/config.yaml
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
DATA:
|
| 2 |
+
dataset: refcocog_u
|
| 3 |
+
train_split: train
|
| 4 |
+
train_lmdb: data/lmdb/refcocog_u/train.lmdb
|
| 5 |
+
val_split: val
|
| 6 |
+
val_lmdb: data/lmdb/refcocog_u/val.lmdb
|
| 7 |
+
mask_root: data/masks/refcocog_u
|
| 8 |
+
AUG:
|
| 9 |
+
check: null
|
| 10 |
+
TRAIN:
|
| 11 |
+
swin_type: base
|
| 12 |
+
swin_pretrain: ckpts/swin_base_patch4_window12_384_22k.pth
|
| 13 |
+
bert: bert-base-uncased
|
| 14 |
+
mha: '8-8-8-8'
|
| 15 |
+
input_size: 480
|
| 16 |
+
word_len: 20
|
| 17 |
+
word_dim: 768
|
| 18 |
+
vis_dim: 512
|
| 19 |
+
num_token: 2
|
| 20 |
+
token_dim: 512
|
| 21 |
+
sync_bn: True
|
| 22 |
+
dropout: 0.
|
| 23 |
+
fusion_drop: 0.
|
| 24 |
+
workers: 32 # data loader workers
|
| 25 |
+
workers_val: 8
|
| 26 |
+
batch_size: 32 #64 # batch size for training
|
| 27 |
+
batch_size_val: 16 # batch size for validation during training, memory and speed tradeoff
|
| 28 |
+
start_epoch: 0
|
| 29 |
+
epochs: 3
|
| 30 |
+
lr_backbone: 5.e-5
|
| 31 |
+
lr_text_encoder: 5.e-5
|
| 32 |
+
lr: 1.e-4
|
| 33 |
+
weight_decay: 1.e-4
|
| 34 |
+
amsgrad: True
|
| 35 |
+
manual_seed:
|
| 36 |
+
print_freq: 100
|
| 37 |
+
exp_name: cgformer
|
| 38 |
+
output_folder: exp/impl/
|
| 39 |
+
save_freq: 1
|
| 40 |
+
weight:
|
| 41 |
+
resume:
|
| 42 |
+
evaluate: True
|
| 43 |
+
Distributed:
|
| 44 |
+
# dist_url: tcp://localhost:18123
|
| 45 |
+
dist_backend: 'nccl'
|
| 46 |
+
# multiprocessing_distributed: True
|
| 47 |
+
world_size: 1
|
| 48 |
+
# rank: 0
|
| 49 |
+
TEST:
|
| 50 |
+
window12: True # if use window12 pretrained for training, testing set true
|
| 51 |
+
test_split: test
|
| 52 |
+
test_lmdb: data/lmdb/refcocog_u/test.lmdb
|
| 53 |
+
visualize: False
|
CGFormer/config/open.yaml
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
DATA:
|
| 2 |
+
dataset: refcoco
|
| 3 |
+
train_split: train_seen
|
| 4 |
+
train_lmdb: path/open_lmdb/refcoco/train_seen.lmdb
|
| 5 |
+
val_seen_split: val_seen
|
| 6 |
+
val_seen_lmdb: path/open_lmdb/refcoco/val_seen.lmdb
|
| 7 |
+
val_unseen_split: val_unseen
|
| 8 |
+
val_unseen_lmdb: path/open_lmdb/refcoco/val_unseen.lmdb
|
| 9 |
+
mask_root: path/masks/refcoco
|
| 10 |
+
TRAIN:
|
| 11 |
+
swin_type: base
|
| 12 |
+
swin_pretrain: path/swin_base_patch4_window12_384_22k.pth
|
| 13 |
+
bert: bert-base-uncased
|
| 14 |
+
clip_pretrain: path/pretrain/ViT-L-14-336px.pt
|
| 15 |
+
mha: '8-8-8-8'
|
| 16 |
+
input_size: 480
|
| 17 |
+
clip_dim: 768
|
| 18 |
+
word_len: 20
|
| 19 |
+
num_token: 2
|
| 20 |
+
word_dim: 768
|
| 21 |
+
vis_dim: 512
|
| 22 |
+
token_dim: 512
|
| 23 |
+
sync_bn: True
|
| 24 |
+
dropout: 0.
|
| 25 |
+
fusion_drop: 0.
|
| 26 |
+
workers: 32 # data loader workers
|
| 27 |
+
workers_val: 8
|
| 28 |
+
batch_size: 64 # batch size for training
|
| 29 |
+
batch_size_val: 16 # batch size for validation during training, memory and speed tradeoff
|
| 30 |
+
start_epoch: 0
|
| 31 |
+
epochs: 1000
|
| 32 |
+
lr_backbone: 5.e-5
|
| 33 |
+
lr_text_encoder: 5.e-5
|
| 34 |
+
lr: 1.e-4
|
| 35 |
+
weight_decay: 1.e-4
|
| 36 |
+
amsgrad: True
|
| 37 |
+
manual_seed: 0
|
| 38 |
+
print_freq: 100
|
| 39 |
+
exp_name: open
|
| 40 |
+
output_folder: exp/refcoco
|
| 41 |
+
save_freq: 1
|
| 42 |
+
weight:
|
| 43 |
+
resume:
|
| 44 |
+
evaluate: True # evaluate on validation set, extra gpu memory needed and small batch_size_val is recommend
|
| 45 |
+
Distributed:
|
| 46 |
+
dist_url: tcp://localhost:12345
|
| 47 |
+
dist_backend: 'nccl'
|
| 48 |
+
multiprocessing_distributed: True
|
| 49 |
+
world_size: 1
|
| 50 |
+
rank: 0
|
| 51 |
+
TEST:
|
| 52 |
+
window12: True # if use window12 pretrained for training, testing set true
|
| 53 |
+
test_split: test_unseen
|
| 54 |
+
test_lmdb: path/refcoco/test_unseen.lmdb
|
| 55 |
+
visualize: False
|
CGFormer/config/refcoco_mosaic/config.yaml
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
DATA:
|
| 2 |
+
dataset: refcoco
|
| 3 |
+
train_split: train
|
| 4 |
+
train_lmdb: data/lmdb/refcoco/train.lmdb
|
| 5 |
+
val_split: val
|
| 6 |
+
val_lmdb: data/lmdb/refcoco/val.lmdb
|
| 7 |
+
mask_root: data/masks/refcoco
|
| 8 |
+
|
| 9 |
+
TRAIN:
|
| 10 |
+
swin_type: base
|
| 11 |
+
swin_pretrain: ckpts/swin_base_patch4_window12_384_22k.pth
|
| 12 |
+
bert: bert-base-uncased
|
| 13 |
+
mha: '8-8-8-8'
|
| 14 |
+
input_size: 480
|
| 15 |
+
word_len: 20
|
| 16 |
+
word_dim: 768
|
| 17 |
+
vis_dim: 512
|
| 18 |
+
num_token: 2
|
| 19 |
+
token_dim: 512
|
| 20 |
+
sync_bn: True
|
| 21 |
+
dropout: 0.
|
| 22 |
+
fusion_drop: 0.
|
| 23 |
+
workers: 16 # data loader workers
|
| 24 |
+
workers_val: 8
|
| 25 |
+
batch_size: 64 #batch size for training
|
| 26 |
+
batch_size_val: 16 # 16 batch size for validation during training, memory and speed tradeoff
|
| 27 |
+
start_epoch: 0
|
| 28 |
+
epochs: 50
|
| 29 |
+
lr_backbone: 5.e-5
|
| 30 |
+
lr_text_encoder: 5.e-5
|
| 31 |
+
lr: 1.e-4
|
| 32 |
+
weight_decay: 1.e-4
|
| 33 |
+
amsgrad: True
|
| 34 |
+
manual_seed:
|
| 35 |
+
print_freq: 100
|
| 36 |
+
exp_name: cgformer
|
| 37 |
+
output_folder: exp/refcoco_mosaic/
|
| 38 |
+
save_freq: 1
|
| 39 |
+
weight:
|
| 40 |
+
resume:
|
| 41 |
+
evaluate: True
|
| 42 |
+
aug:
|
| 43 |
+
num_bgs: 4
|
| 44 |
+
aug_prob: 0.6
|
| 45 |
+
tgt_selection: fixed
|
| 46 |
+
move_crs_pnt: False
|
| 47 |
+
blur: False
|
| 48 |
+
|
| 49 |
+
Distributed:
|
| 50 |
+
# dist_url: tcp://localhost:18123
|
| 51 |
+
dist_backend: 'nccl'
|
| 52 |
+
# multiprocessing_distributed: True
|
| 53 |
+
world_size: 1
|
| 54 |
+
# rank: 0
|
| 55 |
+
TEST:
|
| 56 |
+
window12: True # if use window12 pretrained for training, testing set true
|
| 57 |
+
test_split: val
|
| 58 |
+
test_lmdb: data/lmdb/refcoco/test.lmdb
|
| 59 |
+
visualize: False
|