dianecy committed on
Commit ea1014e · verified · 1 Parent(s): 1120a2f

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +38 -0
  2. CGFormer/.gitignore +7 -0
  3. CGFormer/.ipynb_checkpoints/test-checkpoint.py +106 -0
  4. CGFormer/.ipynb_checkpoints/test_mosaic-checkpoint.py +106 -0
  5. CGFormer/LICENSE +21 -0
  6. CGFormer/README.md +56 -0
  7. CGFormer/bash_logs/ACE_filter050.log +480 -0
  8. CGFormer/bash_logs/ACE_filter050_rev.log +528 -0
  9. CGFormer/bash_logs/sanity_node03.log +0 -0
  10. CGFormer/bert/__pycache__/activations.cpython-38.pyc +0 -0
  11. CGFormer/bert/__pycache__/activations.cpython-39.pyc +0 -0
  12. CGFormer/bert/__pycache__/configuration_bert.cpython-38.pyc +0 -0
  13. CGFormer/bert/__pycache__/configuration_bert.cpython-39.pyc +0 -0
  14. CGFormer/bert/__pycache__/configuration_utils.cpython-38.pyc +0 -0
  15. CGFormer/bert/__pycache__/configuration_utils.cpython-39.pyc +0 -0
  16. CGFormer/bert/__pycache__/file_utils.cpython-38.pyc +0 -0
  17. CGFormer/bert/__pycache__/file_utils.cpython-39.pyc +0 -0
  18. CGFormer/bert/__pycache__/generation_utils.cpython-38.pyc +0 -0
  19. CGFormer/bert/__pycache__/generation_utils.cpython-39.pyc +0 -0
  20. CGFormer/bert/__pycache__/modeling_bert.cpython-38.pyc +0 -0
  21. CGFormer/bert/__pycache__/modeling_bert.cpython-39.pyc +0 -0
  22. CGFormer/bert/__pycache__/modeling_utils.cpython-38.pyc +0 -0
  23. CGFormer/bert/__pycache__/modeling_utils.cpython-39.pyc +0 -0
  24. CGFormer/bert/__pycache__/tokenization_bert.cpython-38.pyc +0 -0
  25. CGFormer/bert/__pycache__/tokenization_bert.cpython-39.pyc +0 -0
  26. CGFormer/bert/__pycache__/tokenization_utils.cpython-38.pyc +0 -0
  27. CGFormer/bert/__pycache__/tokenization_utils.cpython-39.pyc +0 -0
  28. CGFormer/bert/__pycache__/tokenization_utils_base.cpython-38.pyc +0 -0
  29. CGFormer/bert/__pycache__/tokenization_utils_base.cpython-39.pyc +0 -0
  30. CGFormer/bert/activations.py +56 -0
  31. CGFormer/bert/configuration_bert.py +143 -0
  32. CGFormer/bert/configuration_utils.py +408 -0
  33. CGFormer/bert/file_utils.py +808 -0
  34. CGFormer/bert/generation_utils.py +993 -0
  35. CGFormer/bert/modeling_bert.py +1569 -0
  36. CGFormer/bert/modeling_utils.py +1268 -0
  37. CGFormer/bert/tokenization_bert.py +545 -0
  38. CGFormer/bert/tokenization_utils.py +723 -0
  39. CGFormer/bert/tokenization_utils_base.py +0 -0
  40. CGFormer/ckpts/swin_base_patch4_window12_384_22k.pth +3 -0
  41. CGFormer/config/config_gref_ace.yaml +63 -0
  42. CGFormer/config/config_mosaic_refcocog_u.yaml +51 -0
  43. CGFormer/config/config_rcc_ace.yaml +63 -0
  44. CGFormer/config/config_rccp_ace.yaml +63 -0
  45. CGFormer/config/config_refzom_ace.yaml +64 -0
  46. CGFormer/config/config_refzom_repro.yaml +62 -0
  47. CGFormer/config/config_refzom_repro_eval.yaml +62 -0
  48. CGFormer/config/impl/config.yaml +53 -0
  49. CGFormer/config/open.yaml +55 -0
  50. CGFormer/config/refcoco_mosaic/config.yaml +59 -0
.gitattributes CHANGED
@@ -34,3 +34,41 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
  RIS-DMMI/refer/evaluation/tokenizer/stanford-corenlp-3.4.1.jar filter=lfs diff=lfs merge=lfs -text
+ CGFormer/external/mmsegmentation/demo/demo.png filter=lfs diff=lfs merge=lfs -text
+ CGFormer/external/mmsegmentation/docs/zh_cn/imgs/qq_group_qrcode.jpg filter=lfs diff=lfs merge=lfs -text
+ CGFormer/external/mmsegmentation/docs/zh_cn/imgs/zhihu_qrcode.jpg filter=lfs diff=lfs merge=lfs -text
+ CGFormer/external/mmsegmentation/projects/medical/2d_image/histopathology/conic2022_seg/conic2022_seg_dataset.png filter=lfs diff=lfs merge=lfs -text
+ CGFormer/external/mmsegmentation/resources/3dogs.jpg filter=lfs diff=lfs merge=lfs -text
+ CGFormer/external/mmsegmentation/resources/seg_demo.gif filter=lfs diff=lfs merge=lfs -text
+ CGFormer/external/mmsegmentation/tests/data/pseudo_loveda_dataset/img_dir/0.png filter=lfs diff=lfs merge=lfs -text
+ CGFormer/external/mmsegmentation/tests/data/pseudo_loveda_dataset/img_dir/1.png filter=lfs diff=lfs merge=lfs -text
+ CGFormer/external/mmsegmentation/tests/data/pseudo_loveda_dataset/img_dir/2.png filter=lfs diff=lfs merge=lfs -text
+ CGFormer/external/mmsegmentation/tests/data/pseudo_potsdam_dataset/img_dir/2_10_0_0_512_512.png filter=lfs diff=lfs merge=lfs -text
+ CGFormer/external/mmsegmentation/tests/data/pseudo_refuge_dataset/img_dir/pseudo_g0001.png filter=lfs diff=lfs merge=lfs -text
+ CGFormer/external/mmsegmentation/tests/data/pseudo_vaihingen_dataset/img_dir/area1_0_0_512_512.png filter=lfs diff=lfs merge=lfs -text
+ CGFormer/image/framework.jpg filter=lfs diff=lfs merge=lfs -text
+ CGFormer/wandb/offline-run-20250307_173512-9h2on932/run-9h2on932.wandb filter=lfs diff=lfs merge=lfs -text
+ CGFormer/wandb/offline-run-20250307_174303-li5zqatl/run-li5zqatl.wandb filter=lfs diff=lfs merge=lfs -text
+ CGFormer/wandb/offline-run-20250307_182402-j7d7o60n/run-j7d7o60n.wandb filter=lfs diff=lfs merge=lfs -text
+ CGFormer/wandb/offline-run-20250307_183427-lje8pep7/run-lje8pep7.wandb filter=lfs diff=lfs merge=lfs -text
+ CGFormer/wandb/offline-run-20250307_191605-qwg5jc6l/run-qwg5jc6l.wandb filter=lfs diff=lfs merge=lfs -text
+ CGFormer/wandb/offline-run-20250307_191652-pdgidm12/run-pdgidm12.wandb filter=lfs diff=lfs merge=lfs -text
+ CGFormer/wandb/offline-run-20250307_200613-ikc5v4qd/run-ikc5v4qd.wandb filter=lfs diff=lfs merge=lfs -text
+ CGFormer/wandb/offline-run-20250307_201001-i0m64au8/run-i0m64au8.wandb filter=lfs diff=lfs merge=lfs -text
+ CGFormer/wandb/offline-run-20250307_210707-ialdzorz/run-ialdzorz.wandb filter=lfs diff=lfs merge=lfs -text
+ CGFormer/wandb/offline-run-20250307_211011-2bbev839/run-2bbev839.wandb filter=lfs diff=lfs merge=lfs -text
+ CGFormer/wandb/offline-run-20250308_193217-dnb3uu3l/run-dnb3uu3l.wandb filter=lfs diff=lfs merge=lfs -text
+ CGFormer/wandb/offline-run-20250308_231240-qhgmf9fk/run-qhgmf9fk.wandb filter=lfs diff=lfs merge=lfs -text
+ CGFormer/wandb/offline-run-20250309_115725-5fgrfjdy/run-5fgrfjdy.wandb filter=lfs diff=lfs merge=lfs -text
+ CGFormer/wandb/offline-run-20250309_170924-0x06srss/run-0x06srss.wandb filter=lfs diff=lfs merge=lfs -text
+ CGFormer/wandb/offline-run-20250309_171616-684omhh0/run-684omhh0.wandb filter=lfs diff=lfs merge=lfs -text
+ CGFormer/wandb/offline-run-20250309_171623-3b8sr48c/run-3b8sr48c.wandb filter=lfs diff=lfs merge=lfs -text
+ CGFormer/wandb/offline-run-20250309_173234-04u0nc2s/run-04u0nc2s.wandb filter=lfs diff=lfs merge=lfs -text
+ CGFormer/wandb/offline-run-20250309_202147-gbujz424/run-gbujz424.wandb filter=lfs diff=lfs merge=lfs -text
+ CGFormer/wandb/offline-run-20250309_203311-xfi3d65b/run-xfi3d65b.wandb filter=lfs diff=lfs merge=lfs -text
+ CGFormer/wandb/offline-run-20250309_203334-e1fdhljy/run-e1fdhljy.wandb filter=lfs diff=lfs merge=lfs -text
+ CGFormer/wandb/offline-run-20250309_205719-wlfh3gyq/run-wlfh3gyq.wandb filter=lfs diff=lfs merge=lfs -text
+ CGFormer/wandb/offline-run-20250309_205856-a8l51dy6/run-a8l51dy6.wandb filter=lfs diff=lfs merge=lfs -text
+ CGFormer/wandb/offline-run-20250309_212058-k3fcizav/run-k3fcizav.wandb filter=lfs diff=lfs merge=lfs -text
+ CGFormer/wandb/offline-run-20250309_212213-a992tfly/run-a992tfly.wandb filter=lfs diff=lfs merge=lfs -text
+ CGFormer/wandb/offline-run-20250311_174825-ghnm4ky9/run-ghnm4ky9.wandb filter=lfs diff=lfs merge=lfs -text
CGFormer/.gitignore ADDED
@@ -0,0 +1,7 @@
+ /exp
+ /wandb/**
+ **/__pycache__
+ /train_open.py
+ /.vscode
+ config/config.yaml
+ config/open.yaml
CGFormer/.ipynb_checkpoints/test-checkpoint.py ADDED
@@ -0,0 +1,106 @@
+ import argparse
+ import os
+ import warnings
+
+ import cv2
+ import torch
+ import torch.nn.parallel
+ import torch.utils.data
+ from loguru import logger
+
+ import deepspeed
+ import utils.config as config
+ from engine.engine import inference
+ from model import build_segmenter
+ from utils.dataset import RefDataset
+ from utils.misc import setup_logger
+
+ from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
+
+
+ warnings.filterwarnings("ignore")
+ cv2.setNumThreads(0)
+
+
+ def get_parser():
+     parser = argparse.ArgumentParser(
+         description='Pytorch Referring Expression Segmentation')
+     parser.add_argument('--config',
+                         default='path to xxx.yaml',
+                         type=str,
+                         help='config file')
+     parser.add_argument('--opts',
+                         default=None,
+                         nargs=argparse.REMAINDER,
+                         help='override some settings in the config.')
+     args = parser.parse_args()
+     assert args.config is not None
+     cfg = config.load_cfg_from_cfg_file(args.config)
+     if args.opts is not None:
+         cfg = config.merge_cfg_from_list(cfg, args.opts)
+     return cfg
+
+
+ @logger.catch
+ def main():
+     args = get_parser()
+     args.output_dir = os.path.join(args.output_folder, args.exp_name)
+     if args.visualize:
+         args.vis_dir = os.path.join(args.output_dir, "vis")
+         os.makedirs(args.vis_dir, exist_ok=True)
+
+     # logger
+     setup_logger(args.output_dir,
+                  distributed_rank=0,
+                  filename="test.log",
+                  mode="a")
+     logger.info(args.test_split)
+
+     # build dataset & dataloader
+     test_data = RefDataset(lmdb_dir=args.test_lmdb,
+                            mask_dir=args.mask_root,
+                            dataset=args.dataset,
+                            split=args.test_split,
+                            mode='test',
+                            input_size=args.input_size,
+                            word_length=args.word_len,
+                            args=args)
+     test_loader = torch.utils.data.DataLoader(test_data,
+                                               batch_size=1,
+                                               shuffle=False,
+                                               num_workers=1,
+                                               pin_memory=True)
+
+     # build model
+     model = build_segmenter(args, DDP=False)
+
+     # deepspeed_config = {
+     #     "kernel_inject": False,
+     #     "dtype": "fp16",
+     #     "enable_cuda_graph": True,
+     #     "checkpoint": f'{args.output_dir}/best_model'
+     # }
+
+     # logger.info(model)
+
+     # args.model_dir = os.path.join(args.output_dir, "best_model.pth")
+     if os.path.isdir(args.output_dir):
+         logger.info(f"=> loading checkpoint '{args.output_dir}/best_model'")
+         # checkpoint = torch.load(args.model_dir)
+         # model.module.load_state_dict(checkpoint['model_state_dict'], strict=True)
+         model = load_state_dict_from_zero_checkpoint(model, args.output_dir, tag="best_model").cuda()
+         # model.load_checkpoint(args.output_dir, tag="best_model")
+
+         logger.info(f"=> loading checkpoint '{args.output_dir}/best_model'")
+     else:
+         raise ValueError(
+             "=> resume failed! no checkpoint found at '{}'. Please check args.resume again!"
+             .format(args.output_dir))
+
+     # inference
+     inference(test_loader, model, args)
+
+
+ if __name__ == '__main__':
+     main()
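
Note on the checkpoint loading above: `load_state_dict_from_zero_checkpoint` comes from DeepSpeed's `zero_to_fp32` utilities and reconstructs a full fp32 state dict from the sharded ZeRO checkpoint stored under `<output_dir>/<tag>`. A minimal, stand-alone sketch of consolidating such a checkpoint offline so it can later be loaded without DeepSpeed; the `exp/refcocog_u/...` path is illustrative (taken from this upload's logs), and the `best_model` tag matches the one used in `test.py`:

```
import torch
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

# Directory that contains the ZeRO checkpoint folder named after the tag
# ("best_model" here, as in test.py). Adjust the path to your own run.
ckpt_dir = "exp/refcocog_u/ACE_filter050"

# Gather the sharded parameter shards into a single fp32 state dict.
state_dict = get_fp32_state_dict_from_zero_checkpoint(ckpt_dir, tag="best_model")

# Save a plain PyTorch checkpoint that no longer requires DeepSpeed to load.
torch.save(state_dict, "best_model_fp32.pth")
```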
CGFormer/.ipynb_checkpoints/test_mosaic-checkpoint.py ADDED
@@ -0,0 +1,106 @@
+ import argparse
+ import os
+ import warnings
+
+ import cv2
+ import torch
+ import torch.nn.parallel
+ import torch.utils.data
+ from loguru import logger
+
+ import deepspeed
+ import utils.config as config
+ from engine.engine import inference
+ from model import build_segmenter
+ from utils.dataset_mosaic import RefDataset
+ from utils.misc import setup_logger
+
+ from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
+
+
+ warnings.filterwarnings("ignore")
+ cv2.setNumThreads(0)
+
+
+ def get_parser():
+     parser = argparse.ArgumentParser(
+         description='Pytorch Referring Expression Segmentation')
+     parser.add_argument('--config',
+                         default='path to xxx.yaml',
+                         type=str,
+                         help='config file')
+     parser.add_argument('--opts',
+                         default=None,
+                         nargs=argparse.REMAINDER,
+                         help='override some settings in the config.')
+     args = parser.parse_args()
+     assert args.config is not None
+     cfg = config.load_cfg_from_cfg_file(args.config)
+     if args.opts is not None:
+         cfg = config.merge_cfg_from_list(cfg, args.opts)
+     return cfg
+
+
+ @logger.catch
+ def main():
+     args = get_parser()
+     args.output_dir = os.path.join(args.output_folder, args.exp_name)
+     if args.visualize:
+         args.vis_dir = os.path.join(args.output_dir, "vis")
+         os.makedirs(args.vis_dir, exist_ok=True)
+
+     # logger
+     setup_logger(args.output_dir,
+                  distributed_rank=0,
+                  filename="test.log",
+                  mode="a")
+     logger.info(args.test_split)
+
+     # build dataset & dataloader
+     test_data = RefDataset(lmdb_dir=args.test_lmdb,
+                            mask_dir=args.mask_root,
+                            dataset=args.dataset,
+                            split=args.test_split,
+                            mode='test',
+                            input_size=args.input_size,
+                            word_length=args.word_len,
+                            args=args)
+     test_loader = torch.utils.data.DataLoader(test_data,
+                                               batch_size=1,
+                                               shuffle=False,
+                                               num_workers=1,
+                                               pin_memory=True)
+
+     # build model
+     model = build_segmenter(args, DDP=False)
+
+     # deepspeed_config = {
+     #     "kernel_inject": False,
+     #     "dtype": "fp16",
+     #     "enable_cuda_graph": True,
+     #     "checkpoint": f'{args.output_dir}/best_model'
+     # }
+
+     # logger.info(model)
+
+     # args.model_dir = os.path.join(args.output_dir, "best_model.pth")
+     if os.path.isdir(args.output_dir):
+         logger.info(f"=> loading checkpoint '{args.output_dir}/best_model'")
+         # checkpoint = torch.load(args.model_dir)
+         # model.module.load_state_dict(checkpoint['model_state_dict'], strict=True)
+         model = load_state_dict_from_zero_checkpoint(model, args.output_dir, tag="best_model").cuda()
+         # model.load_checkpoint(args.output_dir, tag="best_model")
+
+         logger.info(f"=> loading checkpoint '{args.output_dir}/best_model'")
+     else:
+         raise ValueError(
+             "=> resume failed! no checkpoint found at '{}'. Please check args.resume again!"
+             .format(args.output_dir))
+
+     # inference
+     inference(test_loader, model, args)
+
+
+ if __name__ == '__main__':
+     main()
CGFormer/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023 Toneyaya
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
CGFormer/README.md ADDED
@@ -0,0 +1,56 @@
+ # CGFormer
+ The official PyTorch implementation of the CVPR 2023 paper "Contrastive Grouping with Transformer for Referring Image Segmentation".
+
+ This paper first introduces learnable query tokens to represent objects and then alternately queries linguistic features and groups visual features into the query tokens for object-aware cross-modal reasoning. CGFormer achieves cross-level interaction by jointly updating the query tokens and decoding masks in every two consecutive layers. In addition, we introduce new dataset splits for evaluating the generalization of referring image segmentation models.
+
+ ## Framework
+ <p align="center">
+   <img src="image/framework.jpg" width="1000">
+ </p>
+
+ ## Preparation
+
+ 1. Environment
+    - [PyTorch](www.pytorch.org)
+    - Other dependencies in `requirements.txt`
+ 2. Datasets
+    - Detailed instructions are in [prepare_datasets](data/READEME.md)
+ 3. Pretrained weights
+    - [Swin-Base-window12](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth)
+
+ ## Train and Test (RIS)
+
+ This implementation only supports multi-GPU DistributedDataParallel training, which is faster and simpler; single-GPU and DataParallel training are not supported. Evaluation, however, only supports single-GPU mode.
+
+ To train CGFormer with 8 GPUs, run:
+
+ ```
+ python -u train.py --config config/config.yaml
+ ```
+
+ To evaluate CGFormer with 1 GPU, run:
+ ```
+ CUDA_VISIBLE_DEVICES=0 python -u test.py \
+       --config config/refcoco/config.yaml \
+       --opts TEST.test_split val \
+              TEST.test_lmdb path/val.lmdb
+ ```
+ ## License
+
+ This project is under the MIT license. See [LICENSE](LICENSE) for details.
+
+ ## Citation
+ If you find our work useful in your research, please consider citing:
+ ```
+ @InProceedings{Tang_2023_CVPR,
+     author    = {Tang, Jiajin and Zheng, Ge and Shi, Cheng and Yang, Sibei},
+     title     = {Contrastive Grouping With Transformer for Referring Image Segmentation},
+     booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
+     month     = {June},
+     year      = {2023},
+     pages     = {23570-23580}
+ }
+ ```
+
+ Some of the code is adapted from [CRIS](https://github.com/DerrickWang005/CRIS.pytorch/tree/master) and [LAVT](https://github.com/yz93/LAVT-RIS).
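
The `--opts` flag in the evaluation command above forwards trailing `KEY VALUE` pairs to the config loader: `get_parser()` in `test.py` collects them with `argparse.REMAINDER` and hands them to `utils.config.merge_cfg_from_list`. A minimal sketch of how such pairs can be folded into a flat config; `merge_opts` is a hypothetical stand-in for illustration only, not the repo's implementation:

```
import argparse

def merge_opts(cfg, opts):
    # Merge ['KEY', 'VALUE', ...] pairs (as collected by argparse.REMAINDER)
    # into a flat config dict. Hypothetical helper; utils.config.merge_cfg_from_list
    # in this repo may behave differently (e.g. type checking against the YAML).
    assert len(opts) % 2 == 0, "opts must come in KEY VALUE pairs"
    for key, value in zip(opts[0::2], opts[1::2]):
        cfg[key] = value
    return cfg

parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, required=True)
parser.add_argument('--opts', default=[], nargs=argparse.REMAINDER)
args = parser.parse_args(['--config', 'config/refcoco/config.yaml',
                          '--opts', 'TEST.test_split', 'val',
                          'TEST.test_lmdb', 'path/val.lmdb'])
cfg = merge_opts({}, args.opts)
print(cfg)  # {'TEST.test_split': 'val', 'TEST.test_lmdb': 'path/val.lmdb'}
```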
CGFormer/bash_logs/ACE_filter050.log ADDED
@@ -0,0 +1,480 @@
1
+ /home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/launch.py:181: FutureWarning: The module torch.distributed.launch is deprecated
2
+ and will be removed in future. Use torchrun.
3
+ Note that --use-env is set by default in torchrun.
4
+ If your script expects `--local-rank` argument to be set, please
5
+ change it to read from `os.environ['LOCAL_RANK']` instead. See
6
+ https://pytorch.org/docs/stable/distributed.html#launch-utility for
7
+ further instructions
8
+
9
+ warnings.warn(
10
+ [2025-03-03 00:25:13,383] torch.distributed.run: [WARNING]
11
+ [2025-03-03 00:25:13,383] torch.distributed.run: [WARNING] *****************************************
12
+ [2025-03-03 00:25:13,383] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
13
+ [2025-03-03 00:25:13,383] torch.distributed.run: [WARNING] *****************************************
14
+ /home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/albumentations/__init__.py:24: UserWarning: A new version of Albumentations is available: 2.0.5 (you have 1.4.24). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.
15
+ check_for_updates()
16
+ /home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/albumentations/__init__.py:24: UserWarning: A new version of Albumentations is available: 2.0.5 (you have 1.4.24). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.
17
+ check_for_updates()
18
+ /home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/albumentations/__init__.py:24: UserWarning: A new version of Albumentations is available: 2.0.5 (you have 1.4.24). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.
19
+ check_for_updates()
20
+ /home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/albumentations/__init__.py:24: UserWarning: A new version of Albumentations is available: 2.0.5 (you have 1.4.24). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.
21
+ check_for_updates()
22
+ /home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/timm/models/layers/__init__.py:48: FutureWarning: Importing from timm.models.layers is deprecated, please import via timm.layers
23
+ warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.layers", FutureWarning)
24
+ /home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/timm/models/layers/__init__.py:48: FutureWarning: Importing from timm.models.layers is deprecated, please import via timm.layers
25
+ warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.layers", FutureWarning)
26
+ /home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/timm/models/layers/__init__.py:48: FutureWarning: Importing from timm.models.layers is deprecated, please import via timm.layers
27
+ warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.layers", FutureWarning)
28
+ /home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/timm/models/layers/__init__.py:48: FutureWarning: Importing from timm.models.layers is deprecated, please import via timm.layers
29
+ warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.layers", FutureWarning)
30
+ 2025-03-03 00:25:21.347 | INFO | __main__:main:66 - LOCAL_RANK from env: 1
31
+ 2025-03-03 00:25:21.347 | INFO | __main__:main:66 - LOCAL_RANK from env: 0
32
+ 2025-03-03 00:25:21.359 | INFO | __main__:main:66 - LOCAL_RANK from env: 3
33
+ 2025-03-03 00:25:21.369 | INFO | __main__:main:66 - LOCAL_RANK from env: 2
34
+ 2025-03-03 00:25:21 | INFO | __main__:90 - Starting with GPU: 0, Rank: 0, World Size: 4
35
+ git root error: Cmd('git') failed due to: exit code(128)
36
+ cmdline: git rev-parse --show-toplevel
37
+ stderr: 'fatal: detected dubious ownership in repository at '/data2/projects/chaeyun/CGFormer'
38
+ To add an exception for this directory, call:
39
+
40
+ git config --global --add safe.directory /data2/projects/chaeyun/CGFormer'
41
+ wandb: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
42
+ wandb: Tracking run with wandb version 0.19.1
43
+ wandb: W&B syncing is set to `offline` in this directory.
44
+ wandb: Run `wandb online` or set WANDB_MODE=online to enable cloud syncing.
45
+ node03:2017036:2017036 [0] NCCL INFO Bootstrap : Using eth2:10.1.10.3<0>
46
+ node03:2017036:2017036 [0] NCCL INFO NET/Plugin : Plugin load (libnccl-net.so) returned 2 : libnccl-net.so: cannot open shared object file: No such file or directory
47
+ node03:2017036:2017036 [0] NCCL INFO NET/Plugin : No plugin found, using internal implementation
48
+ node03:2017036:2017036 [0] NCCL INFO cudaDriverVersion 12070
49
+ NCCL version 2.18.5+cuda11.8
50
+ node03:2017039:2017039 [3] NCCL INFO cudaDriverVersion 12070
51
+ node03:2017038:2017038 [2] NCCL INFO cudaDriverVersion 12070
52
+ node03:2017037:2017037 [1] NCCL INFO cudaDriverVersion 12070
53
+ node03:2017038:2017038 [2] NCCL INFO Bootstrap : Using eth2:10.1.10.3<0>
54
+ node03:2017039:2017039 [3] NCCL INFO Bootstrap : Using eth2:10.1.10.3<0>
55
+ node03:2017037:2017037 [1] NCCL INFO Bootstrap : Using eth2:10.1.10.3<0>
56
+ node03:2017038:2017038 [2] NCCL INFO NET/Plugin : Plugin load (libnccl-net.so) returned 2 : libnccl-net.so: cannot open shared object file: No such file or directory
57
+ node03:2017038:2017038 [2] NCCL INFO NET/Plugin : No plugin found, using internal implementation
58
+ node03:2017039:2017039 [3] NCCL INFO NET/Plugin : Plugin load (libnccl-net.so) returned 2 : libnccl-net.so: cannot open shared object file: No such file or directory
59
+ node03:2017039:2017039 [3] NCCL INFO NET/Plugin : No plugin found, using internal implementation
60
+ node03:2017037:2017037 [1] NCCL INFO NET/Plugin : Plugin load (libnccl-net.so) returned 2 : libnccl-net.so: cannot open shared object file: No such file or directory
61
+ node03:2017037:2017037 [1] NCCL INFO NET/Plugin : No plugin found, using internal implementation
62
+ node03:2017039:2017152 [3] NCCL INFO NET/IB : No device found.
63
+ node03:2017039:2017152 [3] NCCL INFO NET/Socket : Using [0]eth2:10.1.10.3<0>
64
+ node03:2017039:2017152 [3] NCCL INFO Using network Socket
65
+ node03:2017038:2017151 [2] NCCL INFO NET/IB : No device found.
66
+ node03:2017038:2017151 [2] NCCL INFO NET/Socket : Using [0]eth2:10.1.10.3<0>
67
+ node03:2017038:2017151 [2] NCCL INFO Using network Socket
68
+ node03:2017036:2017153 [0] NCCL INFO NET/IB : No device found.
69
+ node03:2017037:2017154 [1] NCCL INFO NET/IB : No device found.
70
+ node03:2017036:2017153 [0] NCCL INFO NET/Socket : Using [0]eth2:10.1.10.3<0>
71
+ node03:2017036:2017153 [0] NCCL INFO Using network Socket
72
+ node03:2017037:2017154 [1] NCCL INFO NET/Socket : Using [0]eth2:10.1.10.3<0>
73
+ node03:2017037:2017154 [1] NCCL INFO Using network Socket
74
+ node03:2017038:2017151 [2] NCCL INFO comm 0xaa0bbe0 rank 2 nranks 4 cudaDev 2 nvmlDev 2 busId 14000 commId 0x190df8e791fecf1b - Init START
75
+ node03:2017037:2017154 [1] NCCL INFO comm 0xadc9fc0 rank 1 nranks 4 cudaDev 1 nvmlDev 1 busId 13000 commId 0x190df8e791fecf1b - Init START
76
+ node03:2017039:2017152 [3] NCCL INFO comm 0xb009870 rank 3 nranks 4 cudaDev 3 nvmlDev 3 busId 48000 commId 0x190df8e791fecf1b - Init START
77
+ node03:2017036:2017153 [0] NCCL INFO comm 0xb555ea0 rank 0 nranks 4 cudaDev 0 nvmlDev 0 busId 12000 commId 0x190df8e791fecf1b - Init START
78
+ node03:2017037:2017154 [1] NCCL INFO Setting affinity for GPU 1 to 5500,00000055
79
+ node03:2017038:2017151 [2] NCCL INFO Setting affinity for GPU 2 to 5500,00000055
80
+ node03:2017039:2017152 [3] NCCL INFO Setting affinity for GPU 3 to 5500,00000055
81
+ node03:2017036:2017153 [0] NCCL INFO Setting affinity for GPU 0 to 5500,00000055
82
+ node03:2017039:2017152 [3] NCCL INFO Trees [0] -1/-1/-1->3->2 [1] -1/-1/-1->3->2
83
+ node03:2017036:2017153 [0] NCCL INFO Channel 00/02 : 0 1 2 3
84
+ node03:2017038:2017151 [2] NCCL INFO Trees [0] 3/-1/-1->2->1 [1] 3/-1/-1->2->1
85
+ node03:2017037:2017154 [1] NCCL INFO Trees [0] 2/-1/-1->1->0 [1] 2/-1/-1->1->0
86
+ node03:2017039:2017152 [3] NCCL INFO P2P Chunksize set to 131072
87
+ node03:2017036:2017153 [0] NCCL INFO Channel 01/02 : 0 1 2 3
88
+ node03:2017038:2017151 [2] NCCL INFO P2P Chunksize set to 131072
89
+ node03:2017037:2017154 [1] NCCL INFO P2P Chunksize set to 131072
90
+ node03:2017036:2017153 [0] NCCL INFO Trees [0] 1/-1/-1->0->-1 [1] 1/-1/-1->0->-1
91
+ node03:2017036:2017153 [0] NCCL INFO P2P Chunksize set to 131072
92
+ node03:2017037:2017154 [1] NCCL INFO Channel 00/0 : 1[1] -> 2[2] via P2P/IPC
93
+ node03:2017038:2017151 [2] NCCL INFO Channel 00 : 2[2] -> 3[3] via SHM/direct/direct
94
+ node03:2017037:2017154 [1] NCCL INFO Channel 01/0 : 1[1] -> 2[2] via P2P/IPC
95
+ node03:2017038:2017151 [2] NCCL INFO Channel 01 : 2[2] -> 3[3] via SHM/direct/direct
96
+ node03:2017036:2017153 [0] NCCL INFO Channel 00/0 : 0[0] -> 1[1] via P2P/IPC
97
+ node03:2017039:2017152 [3] NCCL INFO Channel 00 : 3[3] -> 0[0] via SHM/direct/direct
98
+ node03:2017039:2017152 [3] NCCL INFO Channel 01 : 3[3] -> 0[0] via SHM/direct/direct
99
+ node03:2017036:2017153 [0] NCCL INFO Channel 01/0 : 0[0] -> 1[1] via P2P/IPC
100
+ node03:2017037:2017154 [1] NCCL INFO Connected all rings
101
+ node03:2017036:2017153 [0] NCCL INFO Connected all rings
102
+ node03:2017038:2017151 [2] NCCL INFO Connected all rings
103
+ node03:2017039:2017152 [3] NCCL INFO Connected all rings
104
+ node03:2017037:2017154 [1] NCCL INFO Channel 00/0 : 1[1] -> 0[0] via P2P/IPC
105
+ node03:2017039:2017152 [3] NCCL INFO Channel 00 : 3[3] -> 2[2] via SHM/direct/direct
106
+ node03:2017037:2017154 [1] NCCL INFO Channel 01/0 : 1[1] -> 0[0] via P2P/IPC
107
+ node03:2017039:2017152 [3] NCCL INFO Channel 01 : 3[3] -> 2[2] via SHM/direct/direct
108
+ node03:2017036:2017153 [0] NCCL INFO Connected all trees
109
+ node03:2017036:2017153 [0] NCCL INFO threadThresholds 8/8/64 | 32/8/64 | 512 | 512
110
+ node03:2017036:2017153 [0] NCCL INFO 2 coll channels, 0 nvls channels, 2 p2p channels, 2 p2p channels per peer
111
+ node03:2017038:2017151 [2] NCCL INFO Channel 00/0 : 2[2] -> 1[1] via P2P/IPC
112
+ node03:2017038:2017151 [2] NCCL INFO Channel 01/0 : 2[2] -> 1[1] via P2P/IPC
113
+ node03:2017037:2017154 [1] NCCL INFO Connected all trees
114
+ node03:2017037:2017154 [1] NCCL INFO threadThresholds 8/8/64 | 32/8/64 | 512 | 512
115
+ node03:2017037:2017154 [1] NCCL INFO 2 coll channels, 0 nvls channels, 2 p2p channels, 2 p2p channels per peer
116
+ node03:2017038:2017151 [2] NCCL INFO Connected all trees
117
+ node03:2017038:2017151 [2] NCCL INFO threadThresholds 8/8/64 | 32/8/64 | 512 | 512
118
+ node03:2017038:2017151 [2] NCCL INFO 2 coll channels, 0 nvls channels, 2 p2p channels, 2 p2p channels per peer
119
+ node03:2017039:2017152 [3] NCCL INFO Connected all trees
120
+ node03:2017039:2017152 [3] NCCL INFO threadThresholds 8/8/64 | 32/8/64 | 512 | 512
121
+ node03:2017039:2017152 [3] NCCL INFO 2 coll channels, 0 nvls channels, 2 p2p channels, 2 p2p channels per peer
122
+ node03:2017039:2017152 [3] NCCL INFO comm 0xb009870 rank 3 nranks 4 cudaDev 3 nvmlDev 3 busId 48000 commId 0x190df8e791fecf1b - Init COMPLETE
123
+ node03:2017037:2017154 [1] NCCL INFO comm 0xadc9fc0 rank 1 nranks 4 cudaDev 1 nvmlDev 1 busId 13000 commId 0x190df8e791fecf1b - Init COMPLETE
124
+ node03:2017038:2017151 [2] NCCL INFO comm 0xaa0bbe0 rank 2 nranks 4 cudaDev 2 nvmlDev 2 busId 14000 commId 0x190df8e791fecf1b - Init COMPLETE
125
+ node03:2017036:2017153 [0] NCCL INFO comm 0xb555ea0 rank 0 nranks 4 cudaDev 0 nvmlDev 0 busId 12000 commId 0x190df8e791fecf1b - Init COMPLETE
126
+ 2025-03-03 00:25:23 | INFO | model:31 - Window size 12!
127
+ 2025-03-03 00:25:24 | INFO | model:51 - Initializing Multi-modal Swin Transformer weights from ckpts/swin_base_patch4_window12_384_22k.pth
128
+ 2025-03-03 00:25:25 | INFO | model.backbone:459 - loading swin success !!!
129
+ 2025-03-03 00:25:29 | INFO | __main__:144 - Model moved to GPU: 0
130
+ 2025-03-03 00:25:29 | INFO | __main__:145 - amsgrad: True
131
+ batch_size: 24
132
+ batch_size_val: 16
133
+ bert: bert-base-uncased
134
+ dataset: refcocog_u
135
+ dist_backend: nccl
136
+ dropout: 0.0
137
+ epochs: 50
138
+ evaluate: True
139
+ exclude_multiobj: True
140
+ exp_name: ACE_filter050
141
+ filter_threshold: 0.5
142
+ fusion_drop: 0.0
143
+ gpu: 0
144
+ hp_selection: strict
145
+ input_size: 480
146
+ local_rank: 0
147
+ loss_option: ACE_verbonly
148
+ lr: 0.0001
149
+ lr_backbone: 5e-05
150
+ lr_text_encoder: 5e-05
151
+ manual_seed: 2051388757
152
+ margin_value: 12
153
+ mask_root: data/masks/refcocog_u
154
+ metric_learning: True
155
+ metric_loss_weight: 0.1
156
+ metric_mode: hardpos_only_sbertsim_refined
157
+ mha: 8-8-8-8
158
+ mixup_lasttwo: True
159
+ num_token: 2
160
+ output_dir: exp/refcoco_u/ACE_filter050
161
+ output_folder: exp/refcoco_u
162
+ print_freq: 100
163
+ rank: 0
164
+ resume: None
165
+ save_freq: 1
166
+ start_epoch: 0
167
+ swin_pretrain: ckpts/swin_base_patch4_window12_384_22k.pth
168
+ swin_type: base
169
+ sync_bn: True
170
+ temperature: 0.07
171
+ test_lmdb: data/lmdb/refcocog_u/test.lmdb
172
+ test_split: test
173
+ token_dim: 512
174
+ train_lmdb: data/lmdb/refcocog_u/train.lmdb
175
+ train_split: train
176
+ val_lmdb: data/lmdb/refcocog_u/val.lmdb
177
+ val_split: val
178
+ vis_dim: 512
179
+ visualize: False
180
+ weight: None
181
+ weight_decay: 0.0001
182
+ window12: True
183
+ word_dim: 768
184
+ word_len: 20
185
+ workers: 32
186
+ workers_val: 8
187
+ world_size: 4
188
+ 2025-03-03 00:28:05 | INFO | utils.misc:108 - Training: Epoch=[1/50] [ 100/1759] Batch=1.34 (1.53) Data=0.00 (0.09) Lr=0.000100 Loss=1.1470 (1.1906) IoU=15.42 (19.69) Prec@50=0.00 (8.61)
189
+ 2025-03-03 00:30:31 | INFO | utils.misc:108 - Training: Epoch=[1/50] [ 200/1759] Batch=1.54 (1.50) Data=0.00 (0.06) Lr=0.000100 Loss=1.1851 (1.1191) IoU=24.54 (23.36) Prec@50=12.77 (12.84)
190
+ 2025-03-03 00:32:56 | INFO | utils.misc:108 - Training: Epoch=[1/50] [ 300/1759] Batch=1.23 (1.48) Data=0.00 (0.05) Lr=0.000100 Loss=0.9150 (1.0781) IoU=33.89 (25.72) Prec@50=29.46 (15.38)
191
+ 2025-03-03 00:35:23 | INFO | utils.misc:108 - Training: Epoch=[1/50] [ 400/1759] Batch=1.47 (1.48) Data=0.00 (0.04) Lr=0.000100 Loss=0.9083 (1.0518) IoU=31.66 (26.80) Prec@50=18.33 (16.54)
192
+ 2025-03-03 00:37:51 | INFO | utils.misc:108 - Training: Epoch=[1/50] [ 500/1759] Batch=1.51 (1.48) Data=0.00 (0.04) Lr=0.000100 Loss=0.9000 (1.0296) IoU=32.92 (27.76) Prec@50=22.46 (17.54)
193
+ 2025-03-03 00:40:19 | INFO | utils.misc:108 - Training: Epoch=[1/50] [ 600/1759] Batch=1.31 (1.48) Data=0.00 (0.04) Lr=0.000100 Loss=0.8722 (1.0109) IoU=40.11 (28.53) Prec@50=26.04 (18.47)
194
+ 2025-03-03 00:42:47 | INFO | utils.misc:108 - Training: Epoch=[1/50] [ 700/1759] Batch=1.56 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.8722 (0.9958) IoU=38.67 (29.04) Prec@50=33.12 (19.23)
195
+ 2025-03-03 00:45:15 | INFO | utils.misc:108 - Training: Epoch=[1/50] [ 800/1759] Batch=1.21 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.9943 (0.9830) IoU=31.87 (29.65) Prec@50=22.32 (20.11)
196
+ 2025-03-03 00:47:43 | INFO | utils.misc:108 - Training: Epoch=[1/50] [ 900/1759] Batch=1.35 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.8947 (0.9714) IoU=36.31 (30.10) Prec@50=25.10 (20.81)
197
+ 2025-03-03 00:50:09 | INFO | utils.misc:108 - Training: Epoch=[1/50] [1000/1759] Batch=1.77 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.8877 (0.9635) IoU=33.68 (30.41) Prec@50=25.49 (21.28)
198
+ 2025-03-03 00:52:36 | INFO | utils.misc:108 - Training: Epoch=[1/50] [1100/1759] Batch=1.59 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.7724 (0.9548) IoU=38.99 (30.85) Prec@50=36.88 (21.78)
199
+ 2025-03-03 00:55:03 | INFO | utils.misc:108 - Training: Epoch=[1/50] [1200/1759] Batch=1.45 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.7028 (0.9458) IoU=47.27 (31.19) Prec@50=49.79 (22.26)
200
+ 2025-03-03 00:57:31 | INFO | utils.misc:108 - Training: Epoch=[1/50] [1300/1759] Batch=1.42 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.7939 (0.9381) IoU=39.03 (31.59) Prec@50=25.15 (22.68)
201
+ 2025-03-03 00:59:59 | INFO | utils.misc:108 - Training: Epoch=[1/50] [1400/1759] Batch=1.27 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.7186 (0.9315) IoU=40.97 (31.96) Prec@50=35.71 (23.16)
202
+ 2025-03-03 01:02:23 | INFO | utils.misc:108 - Training: Epoch=[1/50] [1500/1759] Batch=1.80 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.8415 (0.9243) IoU=33.69 (32.31) Prec@50=23.39 (23.62)
203
+ 2025-03-03 01:04:50 | INFO | utils.misc:108 - Training: Epoch=[1/50] [1600/1759] Batch=2.03 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.8356 (0.9182) IoU=35.39 (32.66) Prec@50=21.51 (24.10)
204
+ 2025-03-03 01:07:15 | INFO | utils.misc:108 - Training: Epoch=[1/50] [1700/1759] Batch=1.40 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.7456 (0.9132) IoU=41.80 (33.01) Prec@50=30.85 (24.57)
205
+ 2025-03-03 01:09:16 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[1/50] mIoU=43.35 oIoU=42.54 Pr@50: 39.60 Pr@60: 28.96 Pr@70: 18.63 Pr@80: 10.71 Pr@90: 3.11
206
+ 2025-03-03 01:12:09 | INFO | utils.misc:108 - Training: Epoch=[2/50] [ 100/1759] Batch=1.30 (1.47) Data=0.00 (0.04) Lr=0.000100 Loss=0.7384 (0.7771) IoU=36.20 (40.67) Prec@50=25.00 (35.48)
207
+ 2025-03-03 01:14:38 | INFO | utils.misc:108 - Training: Epoch=[2/50] [ 200/1759] Batch=1.28 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.7371 (0.7694) IoU=41.03 (40.41) Prec@50=33.33 (35.67)
208
+ 2025-03-03 01:17:04 | INFO | utils.misc:108 - Training: Epoch=[2/50] [ 300/1759] Batch=1.42 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.7440 (0.7664) IoU=37.40 (40.49) Prec@50=29.86 (36.04)
209
+ 2025-03-03 01:19:35 | INFO | utils.misc:108 - Training: Epoch=[2/50] [ 400/1759] Batch=1.50 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.7437 (0.7650) IoU=35.39 (40.39) Prec@50=33.61 (36.12)
210
+ 2025-03-03 01:22:01 | INFO | utils.misc:108 - Training: Epoch=[2/50] [ 500/1759] Batch=1.61 (1.48) Data=0.01 (0.03) Lr=0.000100 Loss=0.8122 (0.7668) IoU=40.75 (40.34) Prec@50=39.79 (36.16)
211
+ 2025-03-03 01:24:25 | INFO | utils.misc:108 - Training: Epoch=[2/50] [ 600/1759] Batch=1.37 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.7956 (0.7655) IoU=35.37 (40.32) Prec@50=29.86 (36.20)
212
+ 2025-03-03 01:26:51 | INFO | utils.misc:108 - Training: Epoch=[2/50] [ 700/1759] Batch=1.30 (1.47) Data=0.01 (0.03) Lr=0.000100 Loss=0.7419 (0.7664) IoU=43.77 (40.30) Prec@50=37.50 (36.12)
213
+ 2025-03-03 01:29:16 | INFO | utils.misc:108 - Training: Epoch=[2/50] [ 800/1759] Batch=1.22 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.7913 (0.7653) IoU=42.33 (40.31) Prec@50=45.83 (36.31)
214
+ 2025-03-03 01:31:42 | INFO | utils.misc:108 - Training: Epoch=[2/50] [ 900/1759] Batch=1.67 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.8527 (0.7637) IoU=32.06 (40.55) Prec@50=30.27 (36.72)
215
+ 2025-03-03 01:34:07 | INFO | utils.misc:108 - Training: Epoch=[2/50] [1000/1759] Batch=1.41 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.8082 (0.7611) IoU=40.24 (40.63) Prec@50=29.76 (36.89)
216
+ 2025-03-03 01:36:33 | INFO | utils.misc:108 - Training: Epoch=[2/50] [1100/1759] Batch=1.27 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.8084 (0.7607) IoU=38.52 (40.66) Prec@50=31.25 (36.92)
217
+ 2025-03-03 01:38:57 | INFO | utils.misc:108 - Training: Epoch=[2/50] [1200/1759] Batch=1.47 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.6577 (0.7588) IoU=41.28 (40.73) Prec@50=39.49 (37.16)
218
+ 2025-03-03 01:41:22 | INFO | utils.misc:108 - Training: Epoch=[2/50] [1300/1759] Batch=1.32 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.7004 (0.7568) IoU=43.16 (40.84) Prec@50=37.05 (37.41)
219
+ 2025-03-03 01:43:49 | INFO | utils.misc:108 - Training: Epoch=[2/50] [1400/1759] Batch=1.23 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.7302 (0.7556) IoU=50.00 (40.88) Prec@50=54.76 (37.47)
220
+ 2025-03-03 01:46:16 | INFO | utils.misc:108 - Training: Epoch=[2/50] [1500/1759] Batch=1.36 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.8330 (0.7550) IoU=39.82 (40.90) Prec@50=27.08 (37.47)
221
+ 2025-03-03 01:48:44 | INFO | utils.misc:108 - Training: Epoch=[2/50] [1600/1759] Batch=1.76 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.7036 (0.7541) IoU=44.15 (40.92) Prec@50=44.38 (37.50)
222
+ 2025-03-03 01:51:09 | INFO | utils.misc:108 - Training: Epoch=[2/50] [1700/1759] Batch=1.50 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.6859 (0.7531) IoU=42.78 (41.05) Prec@50=39.23 (37.69)
223
+ 2025-03-03 01:53:12 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[2/50] mIoU=49.64 oIoU=46.67 Pr@50: 47.79 Pr@60: 37.85 Pr@70: 29.00 Pr@80: 18.63 Pr@90: 6.25
224
+ 2025-03-03 01:56:07 | INFO | utils.misc:108 - Training: Epoch=[3/50] [ 100/1759] Batch=1.46 (1.48) Data=0.00 (0.04) Lr=0.000100 Loss=0.6057 (0.6714) IoU=51.61 (44.71) Prec@50=54.41 (44.64)
225
+ 2025-03-03 01:58:35 | INFO | utils.misc:108 - Training: Epoch=[3/50] [ 200/1759] Batch=1.53 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.9918 (0.6685) IoU=29.66 (44.32) Prec@50=33.96 (43.73)
226
+ 2025-03-03 02:01:02 | INFO | utils.misc:108 - Training: Epoch=[3/50] [ 300/1759] Batch=1.31 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.6285 (0.6760) IoU=49.97 (44.40) Prec@50=53.97 (43.81)
227
+ 2025-03-03 02:03:29 | INFO | utils.misc:108 - Training: Epoch=[3/50] [ 400/1759] Batch=1.83 (1.48) Data=0.01 (0.03) Lr=0.000100 Loss=0.8127 (0.6792) IoU=33.49 (44.24) Prec@50=30.44 (43.48)
228
+ 2025-03-03 02:05:57 | INFO | utils.misc:108 - Training: Epoch=[3/50] [ 500/1759] Batch=1.61 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.6658 (0.6781) IoU=43.32 (44.19) Prec@50=49.29 (43.60)
229
+ 2025-03-03 02:08:23 | INFO | utils.misc:108 - Training: Epoch=[3/50] [ 600/1759] Batch=1.67 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.6875 (0.6804) IoU=45.37 (43.86) Prec@50=38.48 (43.05)
230
+ 2025-03-03 02:10:51 | INFO | utils.misc:108 - Training: Epoch=[3/50] [ 700/1759] Batch=1.31 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.6716 (0.6817) IoU=49.63 (43.89) Prec@50=45.83 (42.92)
231
+ 2025-03-03 02:13:18 | INFO | utils.misc:108 - Training: Epoch=[3/50] [ 800/1759] Batch=1.31 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.8302 (0.6830) IoU=40.04 (44.00) Prec@50=37.35 (43.08)
232
+ 2025-03-03 02:15:43 | INFO | utils.misc:108 - Training: Epoch=[3/50] [ 900/1759] Batch=1.39 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.8283 (0.6843) IoU=36.38 (44.04) Prec@50=28.97 (43.11)
233
+ 2025-03-03 02:18:09 | INFO | utils.misc:108 - Training: Epoch=[3/50] [1000/1759] Batch=1.41 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.6356 (0.6843) IoU=48.44 (43.85) Prec@50=50.15 (42.86)
234
+ 2025-03-03 02:20:34 | INFO | utils.misc:108 - Training: Epoch=[3/50] [1100/1759] Batch=1.86 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.5373 (0.6842) IoU=44.84 (43.78) Prec@50=50.40 (42.74)
235
+ 2025-03-03 02:23:04 | INFO | utils.misc:108 - Training: Epoch=[3/50] [1200/1759] Batch=1.41 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.6934 (0.6824) IoU=43.73 (43.80) Prec@50=42.86 (42.84)
236
+ 2025-03-03 02:25:31 | INFO | utils.misc:108 - Training: Epoch=[3/50] [1300/1759] Batch=1.38 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.8082 (0.6824) IoU=44.25 (43.82) Prec@50=43.25 (42.90)
237
+ 2025-03-03 02:27:58 | INFO | utils.misc:108 - Training: Epoch=[3/50] [1400/1759] Batch=1.77 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.7234 (0.6808) IoU=37.74 (43.89) Prec@50=31.61 (42.97)
238
+ 2025-03-03 02:30:25 | INFO | utils.misc:108 - Training: Epoch=[3/50] [1500/1759] Batch=1.36 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.6207 (0.6785) IoU=52.28 (44.04) Prec@50=49.60 (43.19)
239
+ 2025-03-03 02:32:52 | INFO | utils.misc:108 - Training: Epoch=[3/50] [1600/1759] Batch=1.38 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.5630 (0.6783) IoU=50.30 (44.12) Prec@50=51.19 (43.27)
240
+ 2025-03-03 02:35:20 | INFO | utils.misc:108 - Training: Epoch=[3/50] [1700/1759] Batch=1.35 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.6715 (0.6784) IoU=39.24 (44.26) Prec@50=39.93 (43.41)
241
+ 2025-03-03 02:37:23 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[3/50] mIoU=52.86 oIoU=49.37 Pr@50: 54.19 Pr@60: 45.54 Pr@70: 36.26 Pr@80: 25.82 Pr@90: 10.52
242
+ 2025-03-03 02:40:20 | INFO | utils.misc:108 - Training: Epoch=[4/50] [ 100/1759] Batch=1.27 (1.50) Data=0.00 (0.04) Lr=0.000100 Loss=0.6958 (0.6286) IoU=44.29 (46.25) Prec@50=38.39 (45.99)
243
+ 2025-03-03 02:42:47 | INFO | utils.misc:108 - Training: Epoch=[4/50] [ 200/1759] Batch=1.30 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.6073 (0.6325) IoU=46.17 (46.22) Prec@50=52.38 (45.96)
244
+ 2025-03-03 02:45:16 | INFO | utils.misc:108 - Training: Epoch=[4/50] [ 300/1759] Batch=1.32 (1.49) Data=0.00 (0.03) Lr=0.000100 Loss=0.5563 (0.6330) IoU=52.67 (46.34) Prec@50=59.87 (45.94)
245
+ 2025-03-03 02:47:42 | INFO | utils.misc:108 - Training: Epoch=[4/50] [ 400/1759] Batch=1.30 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.6527 (0.6364) IoU=52.90 (46.53) Prec@50=59.38 (46.17)
246
+ 2025-03-03 02:50:08 | INFO | utils.misc:108 - Training: Epoch=[4/50] [ 500/1759] Batch=1.28 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.7021 (0.6343) IoU=49.21 (46.89) Prec@50=42.71 (46.54)
247
+ 2025-03-03 02:52:34 | INFO | utils.misc:108 - Training: Epoch=[4/50] [ 600/1759] Batch=1.24 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.7423 (0.6356) IoU=39.62 (46.82) Prec@50=36.61 (46.63)
248
+ 2025-03-03 02:55:01 | INFO | utils.misc:108 - Training: Epoch=[4/50] [ 700/1759] Batch=1.34 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.7147 (0.6425) IoU=42.98 (46.63) Prec@50=37.15 (46.32)
249
+ 2025-03-03 02:57:30 | INFO | utils.misc:108 - Training: Epoch=[4/50] [ 800/1759] Batch=1.77 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.5617 (0.6499) IoU=42.31 (46.31) Prec@50=36.67 (45.86)
250
+ 2025-03-03 02:59:58 | INFO | utils.misc:108 - Training: Epoch=[4/50] [ 900/1759] Batch=1.70 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.6406 (0.6511) IoU=46.95 (46.46) Prec@50=45.00 (46.10)
251
+ 2025-03-03 03:02:23 | INFO | utils.misc:108 - Training: Epoch=[4/50] [1000/1759] Batch=1.28 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.5956 (0.6502) IoU=47.62 (46.59) Prec@50=43.75 (46.30)
252
+ 2025-03-03 03:04:52 | INFO | utils.misc:108 - Training: Epoch=[4/50] [1100/1759] Batch=1.31 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.6672 (0.6503) IoU=51.07 (46.69) Prec@50=51.29 (46.50)
253
+ 2025-03-03 03:07:16 | INFO | utils.misc:108 - Training: Epoch=[4/50] [1200/1759] Batch=1.22 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.6364 (0.6505) IoU=51.87 (46.69) Prec@50=51.93 (46.48)
254
+ 2025-03-03 03:09:43 | INFO | utils.misc:108 - Training: Epoch=[4/50] [1300/1759] Batch=1.37 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.4925 (0.6489) IoU=57.07 (46.74) Prec@50=55.01 (46.50)
255
+ 2025-03-03 03:12:10 | INFO | utils.misc:108 - Training: Epoch=[4/50] [1400/1759] Batch=1.79 (1.47) Data=0.01 (0.03) Lr=0.000100 Loss=0.6326 (0.6491) IoU=47.66 (46.81) Prec@50=46.04 (46.58)
256
+ 2025-03-03 03:14:35 | INFO | utils.misc:108 - Training: Epoch=[4/50] [1500/1759] Batch=1.36 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.6763 (0.6487) IoU=52.54 (46.97) Prec@50=58.58 (46.77)
257
+ 2025-03-03 03:17:03 | INFO | utils.misc:108 - Training: Epoch=[4/50] [1600/1759] Batch=1.41 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.6583 (0.6469) IoU=48.20 (47.13) Prec@50=50.83 (46.97)
258
+ 2025-03-03 03:19:31 | INFO | utils.misc:108 - Training: Epoch=[4/50] [1700/1759] Batch=1.27 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.6820 (0.6478) IoU=43.32 (47.12) Prec@50=48.66 (46.90)
259
+ 2025-03-03 03:21:35 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[4/50] mIoU=54.75 oIoU=50.99 Pr@50: 57.73 Pr@60: 48.80 Pr@70: 40.37 Pr@80: 29.04 Pr@90: 12.15
260
+ 2025-03-03 03:24:28 | INFO | utils.misc:108 - Training: Epoch=[5/50] [ 100/1759] Batch=1.28 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.5842 (0.5822) IoU=52.98 (50.21) Prec@50=50.89 (51.29)
261
+ 2025-03-03 03:26:54 | INFO | utils.misc:108 - Training: Epoch=[5/50] [ 200/1759] Batch=1.42 (1.46) Data=0.01 (0.03) Lr=0.000100 Loss=0.5671 (0.5931) IoU=52.04 (49.29) Prec@50=67.11 (50.16)
262
+ 2025-03-03 03:29:22 | INFO | utils.misc:108 - Training: Epoch=[5/50] [ 300/1759] Batch=1.35 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.6269 (0.5962) IoU=49.78 (48.97) Prec@50=53.17 (49.76)
263
+ 2025-03-03 03:31:49 | INFO | utils.misc:108 - Training: Epoch=[5/50] [ 400/1759] Batch=1.48 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.5380 (0.5930) IoU=53.36 (49.08) Prec@50=60.08 (50.03)
264
+ 2025-03-03 03:34:16 | INFO | utils.misc:108 - Training: Epoch=[5/50] [ 500/1759] Batch=1.79 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.6159 (0.5929) IoU=46.81 (49.18) Prec@50=45.49 (50.33)
265
+ 2025-03-03 03:36:45 | INFO | utils.misc:108 - Training: Epoch=[5/50] [ 600/1759] Batch=1.43 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.4669 (0.5917) IoU=56.32 (49.45) Prec@50=55.22 (50.71)
266
+ 2025-03-03 03:39:10 | INFO | utils.misc:108 - Training: Epoch=[5/50] [ 700/1759] Batch=1.57 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.6299 (0.5933) IoU=44.67 (49.38) Prec@50=45.40 (50.54)
267
+ 2025-03-03 03:41:37 | INFO | utils.misc:108 - Training: Epoch=[5/50] [ 800/1759] Batch=1.43 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.7958 (0.5949) IoU=42.46 (49.27) Prec@50=42.29 (50.46)
268
+ 2025-03-03 03:44:03 | INFO | utils.misc:108 - Training: Epoch=[5/50] [ 900/1759] Batch=1.39 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.7949 (0.5958) IoU=40.19 (49.17) Prec@50=37.70 (50.46)
269
+ 2025-03-03 03:46:30 | INFO | utils.misc:108 - Training: Epoch=[5/50] [1000/1759] Batch=1.75 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.4842 (0.5962) IoU=55.96 (49.27) Prec@50=60.33 (50.63)
270
+ 2025-03-03 03:48:56 | INFO | utils.misc:108 - Training: Epoch=[5/50] [1100/1759] Batch=1.34 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.5489 (0.5967) IoU=55.96 (49.18) Prec@50=60.76 (50.47)
271
+ 2025-03-03 03:51:24 | INFO | utils.misc:108 - Training: Epoch=[5/50] [1200/1759] Batch=1.77 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.5284 (0.5968) IoU=48.37 (49.16) Prec@50=56.47 (50.49)
272
+ 2025-03-03 03:53:48 | INFO | utils.misc:108 - Training: Epoch=[5/50] [1300/1759] Batch=1.36 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.5359 (0.5963) IoU=50.98 (49.14) Prec@50=49.26 (50.53)
273
+ 2025-03-03 03:56:18 | INFO | utils.misc:108 - Training: Epoch=[5/50] [1400/1759] Batch=1.38 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.8166 (0.5978) IoU=42.17 (48.99) Prec@50=36.31 (50.28)
274
+ 2025-03-03 03:58:44 | INFO | utils.misc:108 - Training: Epoch=[5/50] [1500/1759] Batch=1.36 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.6011 (0.5979) IoU=48.88 (49.04) Prec@50=53.27 (50.35)
275
+ 2025-03-03 04:01:11 | INFO | utils.misc:108 - Training: Epoch=[5/50] [1600/1759] Batch=1.35 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.7017 (0.5988) IoU=47.34 (48.99) Prec@50=45.04 (50.25)
276
+ 2025-03-03 04:03:39 | INFO | utils.misc:108 - Training: Epoch=[5/50] [1700/1759] Batch=1.66 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.5102 (0.5985) IoU=51.01 (49.05) Prec@50=53.77 (50.32)
277
+ 2025-03-03 04:05:41 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[5/50] mIoU=56.95 oIoU=53.30 Pr@50: 60.75 Pr@60: 53.26 Pr@70: 44.57 Pr@80: 32.76 Pr@90: 14.32
278
+ 2025-03-03 04:08:38 | INFO | utils.misc:108 - Training: Epoch=[6/50] [ 100/1759] Batch=1.27 (1.49) Data=0.00 (0.04) Lr=0.000100 Loss=0.6468 (0.5652) IoU=49.76 (51.26) Prec@50=42.71 (53.73)
279
+ 2025-03-03 04:11:06 | INFO | utils.misc:108 - Training: Epoch=[6/50] [ 200/1759] Batch=1.41 (1.49) Data=0.00 (0.03) Lr=0.000100 Loss=0.6561 (0.5522) IoU=44.92 (52.29) Prec@50=38.89 (54.84)
280
+ 2025-03-03 04:13:33 | INFO | utils.misc:108 - Training: Epoch=[6/50] [ 300/1759] Batch=1.39 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.4852 (0.5548) IoU=52.33 (51.63) Prec@50=54.51 (54.21)
281
+ 2025-03-03 04:15:59 | INFO | utils.misc:108 - Training: Epoch=[6/50] [ 400/1759] Batch=1.37 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.5306 (0.5601) IoU=55.68 (51.13) Prec@50=62.65 (53.69)
282
+ 2025-03-03 04:18:27 | INFO | utils.misc:108 - Training: Epoch=[6/50] [ 500/1759] Batch=1.39 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.4842 (0.5597) IoU=56.85 (51.28) Prec@50=55.85 (53.78)
283
+ 2025-03-03 04:20:53 | INFO | utils.misc:108 - Training: Epoch=[6/50] [ 600/1759] Batch=1.39 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.5329 (0.5586) IoU=47.32 (51.24) Prec@50=50.00 (53.65)
284
+ 2025-03-03 04:23:18 | INFO | utils.misc:108 - Training: Epoch=[6/50] [ 700/1759] Batch=1.14 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.5324 (0.5594) IoU=56.30 (51.25) Prec@50=59.52 (53.65)
285
+ 2025-03-03 04:25:47 | INFO | utils.misc:108 - Training: Epoch=[6/50] [ 800/1759] Batch=1.23 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.5867 (0.5610) IoU=50.78 (51.24) Prec@50=52.38 (53.66)
286
+ 2025-03-03 04:28:14 | INFO | utils.misc:108 - Training: Epoch=[6/50] [ 900/1759] Batch=1.49 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.5270 (0.5646) IoU=51.31 (51.14) Prec@50=48.21 (53.48)
287
+ 2025-03-03 04:30:41 | INFO | utils.misc:108 - Training: Epoch=[6/50] [1000/1759] Batch=1.35 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.5509 (0.5655) IoU=52.33 (51.12) Prec@50=51.79 (53.51)
288
+ 2025-03-03 04:33:09 | INFO | utils.misc:108 - Training: Epoch=[6/50] [1100/1759] Batch=1.25 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.5011 (0.5665) IoU=59.98 (51.03) Prec@50=73.66 (53.37)
289
+ 2025-03-03 04:35:34 | INFO | utils.misc:108 - Training: Epoch=[6/50] [1200/1759] Batch=1.13 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.7173 (0.5666) IoU=49.50 (51.04) Prec@50=52.98 (53.44)
290
+ 2025-03-03 04:38:02 | INFO | utils.misc:108 - Training: Epoch=[6/50] [1300/1759] Batch=1.78 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.6343 (0.5665) IoU=46.32 (50.93) Prec@50=45.29 (53.32)
291
+ 2025-03-03 04:40:29 | INFO | utils.misc:108 - Training: Epoch=[6/50] [1400/1759] Batch=1.27 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.5961 (0.5674) IoU=49.77 (50.91) Prec@50=50.00 (53.32)
292
+ 2025-03-03 04:42:57 | INFO | utils.misc:108 - Training: Epoch=[6/50] [1500/1759] Batch=1.43 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.4715 (0.5672) IoU=52.89 (50.96) Prec@50=54.68 (53.44)
293
+ 2025-03-03 04:45:21 | INFO | utils.misc:108 - Training: Epoch=[6/50] [1600/1759] Batch=1.78 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.5379 (0.5673) IoU=52.68 (50.94) Prec@50=58.28 (53.43)
294
+ 2025-03-03 04:47:46 | INFO | utils.misc:108 - Training: Epoch=[6/50] [1700/1759] Batch=1.70 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.6034 (0.5678) IoU=48.31 (50.90) Prec@50=47.55 (53.37)
295
+ 2025-03-03 04:49:50 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[6/50] mIoU=57.77 oIoU=53.88 Pr@50: 62.42 Pr@60: 54.46 Pr@70: 45.96 Pr@80: 34.74 Pr@90: 15.64
296
+ 2025-03-03 04:52:43 | INFO | utils.misc:108 - Training: Epoch=[7/50] [ 100/1759] Batch=1.26 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.4772 (0.5159) IoU=58.15 (52.95) Prec@50=68.60 (57.46)
297
+ 2025-03-03 04:55:10 | INFO | utils.misc:108 - Training: Epoch=[7/50] [ 200/1759] Batch=1.26 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.5554 (0.5202) IoU=53.89 (52.93) Prec@50=58.48 (57.39)
298
+ 2025-03-03 04:57:36 | INFO | utils.misc:108 - Training: Epoch=[7/50] [ 300/1759] Batch=1.34 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.4606 (0.5208) IoU=58.93 (53.09) Prec@50=64.43 (57.32)
299
+ 2025-03-03 05:00:03 | INFO | utils.misc:108 - Training: Epoch=[7/50] [ 400/1759] Batch=1.41 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.4532 (0.5196) IoU=59.22 (53.45) Prec@50=74.80 (57.58)
300
+ 2025-03-03 05:02:29 | INFO | utils.misc:108 - Training: Epoch=[7/50] [ 500/1759] Batch=1.38 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.5778 (0.5210) IoU=59.85 (53.71) Prec@50=80.95 (57.77)
301
+ 2025-03-03 05:04:55 | INFO | utils.misc:108 - Training: Epoch=[7/50] [ 600/1759] Batch=1.72 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.5753 (0.5226) IoU=56.04 (53.67) Prec@50=55.98 (57.49)
302
+ 2025-03-03 05:07:23 | INFO | utils.misc:108 - Training: Epoch=[7/50] [ 700/1759] Batch=1.36 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.6087 (0.5253) IoU=51.54 (53.41) Prec@50=50.00 (57.11)
303
+ 2025-03-03 05:09:49 | INFO | utils.misc:108 - Training: Epoch=[7/50] [ 800/1759] Batch=1.36 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.6311 (0.5259) IoU=55.29 (53.46) Prec@50=55.51 (57.13)
304
+ 2025-03-03 05:12:16 | INFO | utils.misc:108 - Training: Epoch=[7/50] [ 900/1759] Batch=1.26 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.7219 (0.5281) IoU=49.02 (53.55) Prec@50=53.87 (57.31)
305
+ 2025-03-03 05:14:42 | INFO | utils.misc:108 - Training: Epoch=[7/50] [1000/1759] Batch=1.43 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.4692 (0.5275) IoU=65.16 (53.58) Prec@50=78.17 (57.38)
306
+ 2025-03-03 05:17:07 | INFO | utils.misc:108 - Training: Epoch=[7/50] [1100/1759] Batch=1.26 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.5824 (0.5272) IoU=45.79 (53.59) Prec@50=48.66 (57.39)
307
+ 2025-03-03 05:19:35 | INFO | utils.misc:108 - Training: Epoch=[7/50] [1200/1759] Batch=1.81 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.4784 (0.5274) IoU=53.63 (53.55) Prec@50=54.79 (57.31)
308
+ 2025-03-03 05:22:04 | INFO | utils.misc:108 - Training: Epoch=[7/50] [1300/1759] Batch=1.72 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.5570 (0.5278) IoU=51.42 (53.54) Prec@50=50.05 (57.32)
309
+ 2025-03-03 05:24:31 | INFO | utils.misc:108 - Training: Epoch=[7/50] [1400/1759] Batch=1.28 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.4228 (0.5272) IoU=57.91 (53.52) Prec@50=65.62 (57.28)
310
+ 2025-03-03 05:27:00 | INFO | utils.misc:108 - Training: Epoch=[7/50] [1500/1759] Batch=1.83 (1.47) Data=0.01 (0.02) Lr=0.000100 Loss=0.4513 (0.5273) IoU=58.91 (53.56) Prec@50=49.84 (57.28)
311
+ 2025-03-03 05:29:27 | INFO | utils.misc:108 - Training: Epoch=[7/50] [1600/1759] Batch=1.39 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.4685 (0.5275) IoU=58.79 (53.56) Prec@50=64.34 (57.24)
312
+ 2025-03-03 05:31:54 | INFO | utils.misc:108 - Training: Epoch=[7/50] [1700/1759] Batch=1.62 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.4928 (0.5272) IoU=47.91 (53.57) Prec@50=45.88 (57.28)
313
+ 2025-03-03 05:33:56 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[7/50] mIoU=60.25 oIoU=55.60 Pr@50: 65.68 Pr@60: 59.39 Pr@70: 51.09 Pr@80: 38.20 Pr@90: 17.59
314
+ 2025-03-03 05:36:51 | INFO | utils.misc:108 - Training: Epoch=[8/50] [ 100/1759] Batch=1.44 (1.48) Data=0.00 (0.04) Lr=0.000100 Loss=0.4619 (0.4767) IoU=50.05 (57.19) Prec@50=56.96 (61.73)
315
+ 2025-03-03 05:39:19 | INFO | utils.misc:108 - Training: Epoch=[8/50] [ 200/1759] Batch=1.33 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.4224 (0.4827) IoU=60.05 (56.71) Prec@50=67.26 (61.55)
316
+ 2025-03-03 05:41:44 | INFO | utils.misc:108 - Training: Epoch=[8/50] [ 300/1759] Batch=1.34 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.4845 (0.4818) IoU=55.20 (56.42) Prec@50=65.23 (61.33)
317
+ 2025-03-03 05:44:13 | INFO | utils.misc:108 - Training: Epoch=[8/50] [ 400/1759] Batch=1.95 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.5197 (0.4854) IoU=51.30 (56.06) Prec@50=53.19 (60.90)
318
+ 2025-03-03 05:46:43 | INFO | utils.misc:108 - Training: Epoch=[8/50] [ 500/1759] Batch=1.39 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.5332 (0.4882) IoU=48.13 (55.89) Prec@50=51.34 (60.70)
319
+ 2025-03-03 05:49:08 | INFO | utils.misc:108 - Training: Epoch=[8/50] [ 600/1759] Batch=1.43 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.3839 (0.4883) IoU=52.71 (55.88) Prec@50=55.28 (60.68)
320
+ 2025-03-03 05:51:36 | INFO | utils.misc:108 - Training: Epoch=[8/50] [ 700/1759] Batch=1.25 (1.48) Data=0.00 (0.02) Lr=0.000100 Loss=0.4711 (0.4896) IoU=53.68 (55.75) Prec@50=63.10 (60.51)
321
+ 2025-03-03 05:54:02 | INFO | utils.misc:108 - Training: Epoch=[8/50] [ 800/1759] Batch=1.84 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.3726 (0.4901) IoU=57.30 (55.78) Prec@50=64.32 (60.39)
322
+ 2025-03-03 05:56:27 | INFO | utils.misc:108 - Training: Epoch=[8/50] [ 900/1759] Batch=1.25 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.4996 (0.4914) IoU=56.04 (55.96) Prec@50=54.91 (60.61)
323
+ 2025-03-03 05:58:54 | INFO | utils.misc:108 - Training: Epoch=[8/50] [1000/1759] Batch=1.41 (1.47) Data=0.01 (0.02) Lr=0.000100 Loss=0.5232 (0.4923) IoU=60.26 (55.92) Prec@50=69.84 (60.50)
324
+ 2025-03-03 06:01:18 | INFO | utils.misc:108 - Training: Epoch=[8/50] [1100/1759] Batch=1.80 (1.47) Data=0.01 (0.02) Lr=0.000100 Loss=0.5720 (0.4947) IoU=50.41 (55.79) Prec@50=54.66 (60.34)
325
+ 2025-03-03 06:03:42 | INFO | utils.misc:108 - Training: Epoch=[8/50] [1200/1759] Batch=1.44 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.4019 (0.4951) IoU=61.91 (55.80) Prec@50=69.10 (60.33)
326
+ 2025-03-03 06:06:09 | INFO | utils.misc:108 - Training: Epoch=[8/50] [1300/1759] Batch=1.35 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.5695 (0.4957) IoU=54.77 (55.78) Prec@50=55.75 (60.21)
327
+ 2025-03-03 06:08:38 | INFO | utils.misc:108 - Training: Epoch=[8/50] [1400/1759] Batch=1.84 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.6492 (0.4982) IoU=42.49 (55.71) Prec@50=42.14 (60.10)
328
+ 2025-03-03 06:11:03 | INFO | utils.misc:108 - Training: Epoch=[8/50] [1500/1759] Batch=1.36 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.5565 (0.4978) IoU=49.92 (55.73) Prec@50=46.83 (60.09)
329
+ 2025-03-03 06:13:31 | INFO | utils.misc:108 - Training: Epoch=[8/50] [1600/1759] Batch=1.30 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.4775 (0.4973) IoU=63.03 (55.75) Prec@50=68.60 (60.11)
330
+ 2025-03-03 06:16:00 | INFO | utils.misc:108 - Training: Epoch=[8/50] [1700/1759] Batch=1.34 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.4690 (0.4972) IoU=55.66 (55.79) Prec@50=46.83 (60.15)
331
+ 2025-03-03 06:18:02 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[8/50] mIoU=60.71 oIoU=56.70 Pr@50: 66.73 Pr@60: 59.36 Pr@70: 50.62 Pr@80: 38.63 Pr@90: 16.96
332
+ 2025-03-03 06:20:56 | INFO | utils.misc:108 - Training: Epoch=[9/50] [ 100/1759] Batch=1.45 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.3857 (0.4604) IoU=59.86 (58.29) Prec@50=64.48 (63.98)
333
+ 2025-03-03 06:23:21 | INFO | utils.misc:108 - Training: Epoch=[9/50] [ 200/1759] Batch=1.28 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.4570 (0.4735) IoU=63.38 (57.40) Prec@50=73.66 (62.69)
334
+ 2025-03-03 06:25:48 | INFO | utils.misc:108 - Training: Epoch=[9/50] [ 300/1759] Batch=1.43 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.5468 (0.4705) IoU=55.90 (57.64) Prec@50=59.08 (62.91)
335
+ 2025-03-03 06:28:13 | INFO | utils.misc:108 - Training: Epoch=[9/50] [ 400/1759] Batch=1.28 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.4366 (0.4683) IoU=63.43 (57.61) Prec@50=68.75 (62.74)
336
+ 2025-03-03 06:30:38 | INFO | utils.misc:108 - Training: Epoch=[9/50] [ 500/1759] Batch=1.31 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.6139 (0.4667) IoU=54.24 (57.58) Prec@50=55.75 (62.60)
337
+ 2025-03-03 06:33:03 | INFO | utils.misc:108 - Training: Epoch=[9/50] [ 600/1759] Batch=1.22 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.5159 (0.4657) IoU=58.25 (57.54) Prec@50=59.82 (62.54)
338
+ 2025-03-03 06:35:27 | INFO | utils.misc:108 - Training: Epoch=[9/50] [ 700/1759] Batch=1.20 (1.45) Data=0.00 (0.02) Lr=0.000100 Loss=0.4972 (0.4664) IoU=57.24 (57.36) Prec@50=67.56 (62.27)
339
+ 2025-03-03 06:37:55 | INFO | utils.misc:108 - Training: Epoch=[9/50] [ 800/1759] Batch=1.37 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.5249 (0.4666) IoU=58.34 (57.30) Prec@50=60.76 (62.13)
340
+ 2025-03-03 06:40:21 | INFO | utils.misc:108 - Training: Epoch=[9/50] [ 900/1759] Batch=1.27 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.5121 (0.4688) IoU=55.35 (57.25) Prec@50=52.23 (62.01)
341
+ 2025-03-03 06:42:47 | INFO | utils.misc:108 - Training: Epoch=[9/50] [1000/1759] Batch=1.45 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.5049 (0.4697) IoU=53.27 (57.15) Prec@50=58.41 (61.89)
342
+ 2025-03-03 06:45:12 | INFO | utils.misc:108 - Training: Epoch=[9/50] [1100/1759] Batch=1.25 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.5198 (0.4704) IoU=57.83 (57.14) Prec@50=58.48 (61.89)
343
+ 2025-03-03 06:47:41 | INFO | utils.misc:108 - Training: Epoch=[9/50] [1200/1759] Batch=1.39 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.4482 (0.4701) IoU=56.85 (57.09) Prec@50=62.10 (61.84)
344
+ 2025-03-03 06:50:09 | INFO | utils.misc:108 - Training: Epoch=[9/50] [1300/1759] Batch=1.39 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.6194 (0.4711) IoU=45.18 (57.02) Prec@50=49.31 (61.76)
345
+ 2025-03-03 06:52:37 | INFO | utils.misc:108 - Training: Epoch=[9/50] [1400/1759] Batch=1.26 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.5719 (0.4718) IoU=53.01 (56.94) Prec@50=56.99 (61.68)
346
+ 2025-03-03 06:55:05 | INFO | utils.misc:108 - Training: Epoch=[9/50] [1500/1759] Batch=1.31 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.3935 (0.4720) IoU=62.80 (56.86) Prec@50=68.85 (61.58)
347
+ 2025-03-03 06:57:30 | INFO | utils.misc:108 - Training: Epoch=[9/50] [1600/1759] Batch=1.32 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.5472 (0.4724) IoU=60.17 (56.79) Prec@50=65.97 (61.48)
348
+ 2025-03-03 06:59:58 | INFO | utils.misc:108 - Training: Epoch=[9/50] [1700/1759] Batch=1.38 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.3759 (0.4721) IoU=60.89 (56.73) Prec@50=63.24 (61.42)
349
+ 2025-03-03 07:01:58 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[9/50] mIoU=61.79 oIoU=57.95 Pr@50: 68.91 Pr@60: 61.84 Pr@70: 54.15 Pr@80: 40.99 Pr@90: 18.32
350
+ 2025-03-03 07:04:52 | INFO | utils.misc:108 - Training: Epoch=[10/50] [ 100/1759] Batch=1.42 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.4251 (0.4314) IoU=59.89 (58.14) Prec@50=68.54 (63.57)
351
+ 2025-03-03 07:07:22 | INFO | utils.misc:108 - Training: Epoch=[10/50] [ 200/1759] Batch=1.69 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.4248 (0.4374) IoU=62.13 (58.02) Prec@50=67.50 (63.52)
352
+ 2025-03-03 07:09:47 | INFO | utils.misc:108 - Training: Epoch=[10/50] [ 300/1759] Batch=1.39 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.3592 (0.4369) IoU=56.82 (58.28) Prec@50=68.06 (63.59)
353
+ 2025-03-03 07:12:15 | INFO | utils.misc:108 - Training: Epoch=[10/50] [ 400/1759] Batch=1.35 (1.47) Data=0.01 (0.03) Lr=0.000100 Loss=0.4061 (0.4374) IoU=64.88 (58.70) Prec@50=70.29 (63.98)
354
+ 2025-03-03 07:14:41 | INFO | utils.misc:108 - Training: Epoch=[10/50] [ 500/1759] Batch=1.37 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.4022 (0.4395) IoU=60.43 (58.72) Prec@50=59.42 (63.88)
355
+ 2025-03-03 07:17:09 | INFO | utils.misc:108 - Training: Epoch=[10/50] [ 600/1759] Batch=1.32 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.6635 (0.4405) IoU=42.66 (58.58) Prec@50=49.95 (63.78)
356
+ 2025-03-03 07:19:37 | INFO | utils.misc:108 - Training: Epoch=[10/50] [ 700/1759] Batch=1.27 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.4055 (0.4416) IoU=63.60 (58.54) Prec@50=67.86 (63.75)
357
+ 2025-03-03 07:22:05 | INFO | utils.misc:108 - Training: Epoch=[10/50] [ 800/1759] Batch=1.22 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.3670 (0.4434) IoU=68.35 (58.42) Prec@50=73.66 (63.65)
358
+ 2025-03-03 07:24:32 | INFO | utils.misc:108 - Training: Epoch=[10/50] [ 900/1759] Batch=1.27 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.4577 (0.4437) IoU=59.77 (58.38) Prec@50=72.02 (63.55)
359
+ 2025-03-03 07:26:55 | INFO | utils.misc:108 - Training: Epoch=[10/50] [1000/1759] Batch=1.38 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.3914 (0.4433) IoU=56.84 (58.38) Prec@50=71.92 (63.59)
360
+ 2025-03-03 07:29:24 | INFO | utils.misc:108 - Training: Epoch=[10/50] [1100/1759] Batch=1.40 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.4301 (0.4452) IoU=64.13 (58.34) Prec@50=64.24 (63.54)
361
+ 2025-03-03 07:31:49 | INFO | utils.misc:108 - Training: Epoch=[10/50] [1200/1759] Batch=1.78 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.5309 (0.4448) IoU=51.24 (58.37) Prec@50=55.62 (63.55)
362
+ 2025-03-03 07:34:15 | INFO | utils.misc:108 - Training: Epoch=[10/50] [1300/1759] Batch=1.31 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.5097 (0.4461) IoU=53.53 (58.35) Prec@50=59.08 (63.49)
363
+ 2025-03-03 07:36:42 | INFO | utils.misc:108 - Training: Epoch=[10/50] [1400/1759] Batch=1.40 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.4394 (0.4461) IoU=58.68 (58.39) Prec@50=59.52 (63.55)
364
+ 2025-03-03 07:39:05 | INFO | utils.misc:108 - Training: Epoch=[10/50] [1500/1759] Batch=1.39 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.4440 (0.4464) IoU=59.17 (58.38) Prec@50=67.06 (63.51)
365
+ 2025-03-03 07:41:31 | INFO | utils.misc:108 - Training: Epoch=[10/50] [1600/1759] Batch=1.34 (1.47) Data=0.00 (0.02) Lr=0.000100 Loss=0.4039 (0.4464) IoU=56.05 (58.33) Prec@50=60.76 (63.41)
366
+ 2025-03-03 07:43:56 | INFO | utils.misc:108 - Training: Epoch=[10/50] [1700/1759] Batch=1.61 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.4076 (0.4472) IoU=57.63 (58.25) Prec@50=65.15 (63.35)
367
+ 2025-03-03 07:46:00 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[10/50] mIoU=61.95 oIoU=58.15 Pr@50: 69.22 Pr@60: 62.58 Pr@70: 54.46 Pr@80: 42.31 Pr@90: 19.91
368
+ 2025-03-03 07:48:54 | INFO | utils.misc:108 - Training: Epoch=[11/50] [ 100/1759] Batch=1.31 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.4570 (0.3980) IoU=60.56 (59.63) Prec@50=61.71 (65.29)
369
+ 2025-03-03 07:51:22 | INFO | utils.misc:108 - Training: Epoch=[11/50] [ 200/1759] Batch=1.44 (1.48) Data=0.00 (0.03) Lr=0.000100 Loss=0.4088 (0.4088) IoU=55.49 (58.90) Prec@50=56.47 (64.34)
370
+ 2025-03-03 07:53:49 | INFO | utils.misc:108 - Training: Epoch=[11/50] [ 300/1759] Batch=1.34 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.5129 (0.4076) IoU=58.42 (59.49) Prec@50=58.73 (65.27)
371
+ 2025-03-03 07:56:14 | INFO | utils.misc:108 - Training: Epoch=[11/50] [ 400/1759] Batch=1.36 (1.47) Data=0.00 (0.03) Lr=0.000100 Loss=0.5632 (0.4112) IoU=48.45 (59.59) Prec@50=45.14 (65.45)
372
+ 2025-03-03 07:58:39 | INFO | utils.misc:108 - Training: Epoch=[11/50] [ 500/1759] Batch=1.45 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.3631 (0.4145) IoU=60.37 (59.62) Prec@50=61.03 (65.47)
373
+ 2025-03-03 08:01:04 | INFO | utils.misc:108 - Training: Epoch=[11/50] [ 600/1759] Batch=1.42 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.2917 (0.4155) IoU=71.40 (59.60) Prec@50=78.17 (65.43)
374
+ 2025-03-03 08:03:30 | INFO | utils.misc:108 - Training: Epoch=[11/50] [ 700/1759] Batch=1.44 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.4789 (0.4174) IoU=58.62 (59.62) Prec@50=61.34 (65.38)
375
+ 2025-03-03 08:05:57 | INFO | utils.misc:108 - Training: Epoch=[11/50] [ 800/1759] Batch=1.33 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.3516 (0.4210) IoU=60.97 (59.48) Prec@50=69.59 (65.11)
376
+ 2025-03-03 08:08:21 | INFO | utils.misc:108 - Training: Epoch=[11/50] [ 900/1759] Batch=1.95 (1.46) Data=0.00 (0.03) Lr=0.000100 Loss=0.3417 (0.4205) IoU=59.15 (59.43) Prec@50=67.80 (65.06)
377
+ 2025-03-03 08:10:47 | INFO | utils.misc:108 - Training: Epoch=[11/50] [1000/1759] Batch=1.21 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.3900 (0.4195) IoU=64.07 (59.51) Prec@50=68.75 (65.12)
378
+ 2025-03-03 08:13:12 | INFO | utils.misc:108 - Training: Epoch=[11/50] [1100/1759] Batch=1.51 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.3569 (0.4213) IoU=69.53 (59.50) Prec@50=74.52 (65.06)
379
+ 2025-03-03 08:15:38 | INFO | utils.misc:108 - Training: Epoch=[11/50] [1200/1759] Batch=1.22 (1.46) Data=0.01 (0.02) Lr=0.000100 Loss=0.5098 (0.4224) IoU=58.16 (59.38) Prec@50=62.05 (64.97)
380
+ 2025-03-03 08:18:03 | INFO | utils.misc:108 - Training: Epoch=[11/50] [1300/1759] Batch=1.35 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.6300 (0.4239) IoU=47.06 (59.24) Prec@50=51.79 (64.82)
381
+ 2025-03-03 08:20:26 | INFO | utils.misc:108 - Training: Epoch=[11/50] [1400/1759] Batch=1.32 (1.46) Data=0.01 (0.02) Lr=0.000100 Loss=0.3446 (0.4252) IoU=64.94 (59.19) Prec@50=75.60 (64.73)
382
+ 2025-03-03 08:22:51 | INFO | utils.misc:108 - Training: Epoch=[11/50] [1500/1759] Batch=1.28 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.4085 (0.4247) IoU=60.33 (59.26) Prec@50=66.96 (64.89)
383
+ 2025-03-03 08:25:17 | INFO | utils.misc:108 - Training: Epoch=[11/50] [1600/1759] Batch=1.81 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.2942 (0.4246) IoU=60.72 (59.32) Prec@50=59.59 (64.92)
384
+ 2025-03-03 08:27:44 | INFO | utils.misc:108 - Training: Epoch=[11/50] [1700/1759] Batch=1.20 (1.46) Data=0.00 (0.02) Lr=0.000100 Loss=0.3268 (0.4244) IoU=71.48 (59.39) Prec@50=83.04 (64.99)
385
+ 2025-03-03 08:29:47 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[11/50] mIoU=62.68 oIoU=58.89 Pr@50: 69.95 Pr@60: 62.97 Pr@70: 54.85 Pr@80: 43.25 Pr@90: 20.81
386
+ [2025-03-03 08:31:52,535] torch.distributed.elastic.agent.server.api: [WARNING] Received Signals.SIGINT death signal, shutting down workers
387
+ [2025-03-03 08:31:52,536] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2017036 closing signal SIGINT
388
+ [2025-03-03 08:31:52,536] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2017037 closing signal SIGINT
389
+ [2025-03-03 08:31:52,536] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2017038 closing signal SIGINT
390
+ [2025-03-03 08:31:52,536] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2017039 closing signal SIGINT
391
+ Exception in thread Thread-24:
392
+ Traceback (most recent call last):
393
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/threading.py", line 980, in _bootstrap_inner
394
+ self.run()
395
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/threading.py", line 917, in run
396
+ self._target(*self._args, **self._kwargs)
397
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/utils/data/_utils/pin_memory.py", line 54, in _pin_memory_loop
398
+ do_one_step()
399
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/utils/data/_utils/pin_memory.py", line 31, in do_one_step
400
+ r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
401
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/multiprocessing/queues.py", line 122, in get
402
+ return _ForkingPickler.loads(res)
403
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/multiprocessing/reductions.py", line 355, in rebuild_storage_fd
404
+ Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fa02edd68b0>
405
+ Traceback (most recent call last):
406
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
407
+ Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f6b5cf638b0>
408
+ Traceback (most recent call last):
409
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
410
+ Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f4c191bd8b0>
411
+ Traceback (most recent call last):
412
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
413
+ [2025-03-03 08:31:52,707] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2017036 closing signal SIGTERM
414
+ [2025-03-03 08:31:52,707] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2017037 closing signal SIGTERM
415
+ [2025-03-03 08:31:52,708] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2017038 closing signal SIGTERM
416
+ [2025-03-03 08:31:52,708] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2017039 closing signal SIGTERM
417
+ Traceback (most recent call last):
418
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/agent/server/api.py", line 736, in run
419
+ result = self._invoke_run(role)
420
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/agent/server/api.py", line 877, in _invoke_run
421
+ time.sleep(monitor_interval)
422
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 62, in _terminate_process_handler
423
+ raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
424
+ torch.distributed.elastic.multiprocessing.api.SignalException: Process 2017022 got signal: 2
425
+
426
+ During handling of the above exception, another exception occurred:
427
+
428
+ Traceback (most recent call last):
429
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/agent/server/api.py", line 743, in run
430
+ self._shutdown(e.sigval)
431
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/agent/server/local_elastic_agent.py", line 289, in _shutdown
432
+ self._pcontext.close(death_sig)
433
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 331, in close
434
+ self._close(death_sig=death_sig, timeout=timeout)
435
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 713, in _close
436
+ handler.proc.wait(time_to_wait)
437
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/subprocess.py", line 1189, in wait
438
+ return self._wait(timeout=timeout)
439
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/subprocess.py", line 1927, in _wait
440
+ time.sleep(delay)
441
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 62, in _terminate_process_handler
442
+ raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
443
+ torch.distributed.elastic.multiprocessing.api.SignalException: Process 2017022 got signal: 2
444
+
445
+ During handling of the above exception, another exception occurred:
446
+
447
+ Traceback (most recent call last):
448
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/runpy.py", line 197, in _run_module_as_main
449
+ return _run_code(code, main_globals, None,
450
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/runpy.py", line 87, in _run_code
451
+ exec(code, run_globals)
452
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/launch.py", line 196, in <module>
453
+ main()
454
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/launch.py", line 192, in main
455
+ launch(args)
456
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/launch.py", line 177, in launch
457
+ run(args)
458
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/run.py", line 797, in run
459
+ elastic_launch(
460
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
461
+ return launch_agent(self._config, self._entrypoint, list(args))
462
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 255, in launch_agent
463
+ result = agent.run()
464
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/metrics/api.py", line 124, in wrapper
465
+ result = f(*args, **kwargs)
466
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/agent/server/api.py", line 748, in run
467
+ self._shutdown()
468
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/agent/server/local_elastic_agent.py", line 289, in _shutdown
469
+ self._pcontext.close(death_sig)
470
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 331, in close
471
+ self._close(death_sig=death_sig, timeout=timeout)
472
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 713, in _close
473
+ handler.proc.wait(time_to_wait)
474
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/subprocess.py", line 1189, in wait
475
+ return self._wait(timeout=timeout)
476
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/subprocess.py", line 1927, in _wait
477
+ time.sleep(delay)
478
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 62, in _terminate_process_handler
479
+ raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
480
+ torch.distributed.elastic.multiprocessing.api.SignalException: Process 2017022 got signal: 2
CGFormer/bash_logs/ACE_filter050_rev.log ADDED
@@ -0,0 +1,528 @@
1
+ /home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/launch.py:181: FutureWarning: The module torch.distributed.launch is deprecated
2
+ and will be removed in future. Use torchrun.
3
+ Note that --use-env is set by default in torchrun.
4
+ If your script expects `--local-rank` argument to be set, please
5
+ change it to read from `os.environ['LOCAL_RANK']` instead. See
6
+ https://pytorch.org/docs/stable/distributed.html#launch-utility for
7
+ further instructions
8
+
9
+ warnings.warn(
10
+ [2025-03-03 16:23:35,171] torch.distributed.run: [WARNING]
11
+ [2025-03-03 16:23:35,171] torch.distributed.run: [WARNING] *****************************************
12
+ [2025-03-03 16:23:35,171] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
13
+ [2025-03-03 16:23:35,171] torch.distributed.run: [WARNING] *****************************************
14
+ /home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/albumentations/__init__.py:24: UserWarning: A new version of Albumentations is available: 2.0.5 (you have 1.4.24). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.
15
+ check_for_updates()
16
+ /home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/albumentations/__init__.py:24: UserWarning: A new version of Albumentations is available: 2.0.5 (you have 1.4.24). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.
17
+ check_for_updates()
18
+ /home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/albumentations/__init__.py:24: UserWarning: A new version of Albumentations is available: 2.0.5 (you have 1.4.24). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.
19
+ check_for_updates()
20
+ /home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/albumentations/__init__.py:24: UserWarning: A new version of Albumentations is available: 2.0.5 (you have 1.4.24). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.
21
+ check_for_updates()
22
+ /home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/albumentations/__init__.py:24: UserWarning: A new version of Albumentations is available: 2.0.5 (you have 1.4.24). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.
23
+ check_for_updates()
24
+ /home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/albumentations/__init__.py:24: UserWarning: A new version of Albumentations is available: 2.0.5 (you have 1.4.24). Upgrade using: pip install -U albumentations. To disable automatic update checks, set the environment variable NO_ALBUMENTATIONS_UPDATE to 1.
25
+ check_for_updates()
26
+ /home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/timm/models/layers/__init__.py:48: FutureWarning: Importing from timm.models.layers is deprecated, please import via timm.layers
27
+ warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.layers", FutureWarning)
28
+ /home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/timm/models/layers/__init__.py:48: FutureWarning: Importing from timm.models.layers is deprecated, please import via timm.layers
29
+ warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.layers", FutureWarning)
30
+ /home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/timm/models/layers/__init__.py:48: FutureWarning: Importing from timm.models.layers is deprecated, please import via timm.layers
31
+ warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.layers", FutureWarning)
32
+ /home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/timm/models/layers/__init__.py:48: FutureWarning: Importing from timm.models.layers is deprecated, please import via timm.layers
33
+ warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.layers", FutureWarning)
34
+ /home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/timm/models/layers/__init__.py:48: FutureWarning: Importing from timm.models.layers is deprecated, please import via timm.layers
35
+ warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.layers", FutureWarning)
36
+ /home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/timm/models/layers/__init__.py:48: FutureWarning: Importing from timm.models.layers is deprecated, please import via timm.layers
37
+ warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.layers", FutureWarning)
38
+ 2025-03-03 16:24:00.550 | INFO | __main__:main:66 - LOCAL_RANK from env: 2
39
+ 2025-03-03 16:24:00.550 | INFO | __main__:main:66 - LOCAL_RANK from env: 5
40
+ 2025-03-03 16:24:00.551 | INFO | __main__:main:66 - LOCAL_RANK from env: 4
41
+ 2025-03-03 16:24:00.551 | INFO | __main__:main:66 - LOCAL_RANK from env: 0
42
+ 2025-03-03 16:24:00.551 | INFO | __main__:main:66 - LOCAL_RANK from env: 1
43
+ 2025-03-03 16:24:00.551 | INFO | __main__:main:66 - LOCAL_RANK from env: 3
44
+ 2025-03-03 16:24:00 | INFO | __main__:90 - Starting with GPU: 0, Rank: 0, World Size: 6
45
+ git root error: Cmd('git') failed due to: exit code(128)
46
+ cmdline: git rev-parse --show-toplevel
47
+ stderr: 'fatal: detected dubious ownership in repository at '/data2/projects/chaeyun/CGFormer'
48
+ To add an exception for this directory, call:
49
+
50
+ git config --global --add safe.directory /data2/projects/chaeyun/CGFormer'
51
+ wandb: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
52
+ wandb: Tracking run with wandb version 0.19.1
53
+ wandb: W&B syncing is set to `offline` in this directory.
54
+ wandb: Run `wandb online` or set WANDB_MODE=online to enable cloud syncing.
55
+ node03:2316571:2316571 [0] NCCL INFO Bootstrap : Using eth2:10.1.10.3<0>
56
+ node03:2316571:2316571 [0] NCCL INFO NET/Plugin : Plugin load (libnccl-net.so) returned 2 : libnccl-net.so: cannot open shared object file: No such file or directory
57
+ node03:2316571:2316571 [0] NCCL INFO NET/Plugin : No plugin found, using internal implementation
58
+ node03:2316571:2316571 [0] NCCL INFO cudaDriverVersion 12070
59
+ NCCL version 2.18.5+cuda11.8
60
+ node03:2316574:2316574 [3] NCCL INFO cudaDriverVersion 12070
61
+ node03:2316575:2316575 [4] NCCL INFO cudaDriverVersion 12070
62
+ node03:2316572:2316572 [1] NCCL INFO cudaDriverVersion 12070
63
+ node03:2316576:2316576 [5] NCCL INFO cudaDriverVersion 12070
64
+ node03:2316573:2316573 [2] NCCL INFO cudaDriverVersion 12070
65
+ node03:2316575:2316575 [4] NCCL INFO Bootstrap : Using eth2:10.1.10.3<0>
66
+ node03:2316574:2316574 [3] NCCL INFO Bootstrap : Using eth2:10.1.10.3<0>
67
+ node03:2316576:2316576 [5] NCCL INFO Bootstrap : Using eth2:10.1.10.3<0>
68
+ node03:2316572:2316572 [1] NCCL INFO Bootstrap : Using eth2:10.1.10.3<0>
69
+ node03:2316573:2316573 [2] NCCL INFO Bootstrap : Using eth2:10.1.10.3<0>
70
+ node03:2316575:2316575 [4] NCCL INFO NET/Plugin : Plugin load (libnccl-net.so) returned 2 : libnccl-net.so: cannot open shared object file: No such file or directory
71
+ node03:2316575:2316575 [4] NCCL INFO NET/Plugin : No plugin found, using internal implementation
72
+ node03:2316574:2316574 [3] NCCL INFO NET/Plugin : Plugin load (libnccl-net.so) returned 2 : libnccl-net.so: cannot open shared object file: No such file or directory
73
+ node03:2316574:2316574 [3] NCCL INFO NET/Plugin : No plugin found, using internal implementation
74
+ node03:2316572:2316572 [1] NCCL INFO NET/Plugin : Plugin load (libnccl-net.so) returned 2 : libnccl-net.so: cannot open shared object file: No such file or directory
75
+ node03:2316572:2316572 [1] NCCL INFO NET/Plugin : No plugin found, using internal implementation
76
+ node03:2316576:2316576 [5] NCCL INFO NET/Plugin : Plugin load (libnccl-net.so) returned 2 : libnccl-net.so: cannot open shared object file: No such file or directory
77
+ node03:2316573:2316573 [2] NCCL INFO NET/Plugin : Plugin load (libnccl-net.so) returned 2 : libnccl-net.so: cannot open shared object file: No such file or directory
78
+ node03:2316576:2316576 [5] NCCL INFO NET/Plugin : No plugin found, using internal implementation
79
+ node03:2316573:2316573 [2] NCCL INFO NET/Plugin : No plugin found, using internal implementation
80
+ node03:2316571:2316708 [0] NCCL INFO NET/IB : No device found.
81
+ node03:2316573:2316712 [2] NCCL INFO NET/IB : No device found.
82
+ node03:2316571:2316708 [0] NCCL INFO NET/Socket : Using [0]eth2:10.1.10.3<0>
83
+ node03:2316571:2316708 [0] NCCL INFO Using network Socket
84
+ node03:2316573:2316712 [2] NCCL INFO NET/Socket : Using [0]eth2:10.1.10.3<0>
85
+ node03:2316573:2316712 [2] NCCL INFO Using network Socket
86
+ node03:2316574:2316710 [3] NCCL INFO NET/IB : No device found.
87
+ node03:2316574:2316710 [3] NCCL INFO NET/Socket : Using [0]eth2:10.1.10.3<0>
88
+ node03:2316574:2316710 [3] NCCL INFO Using network Socket
89
+ node03:2316572:2316711 [1] NCCL INFO NET/IB : No device found.
90
+ node03:2316572:2316711 [1] NCCL INFO NET/Socket : Using [0]eth2:10.1.10.3<0>
91
+ node03:2316572:2316711 [1] NCCL INFO Using network Socket
92
+ node03:2316576:2316713 [5] NCCL INFO NET/IB : No device found.
93
+ node03:2316576:2316713 [5] NCCL INFO NET/Socket : Using [0]eth2:10.1.10.3<0>
94
+ node03:2316576:2316713 [5] NCCL INFO Using network Socket
95
+ node03:2316575:2316709 [4] NCCL INFO NET/IB : No device found.
96
+ node03:2316575:2316709 [4] NCCL INFO NET/Socket : Using [0]eth2:10.1.10.3<0>
97
+ node03:2316575:2316709 [4] NCCL INFO Using network Socket
98
+ node03:2316571:2316708 [0] NCCL INFO comm 0xa205d10 rank 0 nranks 6 cudaDev 0 nvmlDev 0 busId 12000 commId 0x5c97a0f7f601b696 - Init START
99
+ node03:2316573:2316712 [2] NCCL INFO comm 0xac1c200 rank 2 nranks 6 cudaDev 2 nvmlDev 2 busId 14000 commId 0x5c97a0f7f601b696 - Init START
100
+ node03:2316576:2316713 [5] NCCL INFO comm 0x9b6ede0 rank 5 nranks 6 cudaDev 5 nvmlDev 5 busId c1000 commId 0x5c97a0f7f601b696 - Init START
101
+ node03:2316572:2316711 [1] NCCL INFO comm 0xa155230 rank 1 nranks 6 cudaDev 1 nvmlDev 1 busId 13000 commId 0x5c97a0f7f601b696 - Init START
102
+ node03:2316575:2316709 [4] NCCL INFO comm 0xa46f6f0 rank 4 nranks 6 cudaDev 4 nvmlDev 4 busId c0000 commId 0x5c97a0f7f601b696 - Init START
103
+ node03:2316574:2316710 [3] NCCL INFO comm 0xadbe390 rank 3 nranks 6 cudaDev 3 nvmlDev 3 busId 48000 commId 0x5c97a0f7f601b696 - Init START
104
+ node03:2316573:2316712 [2] NCCL INFO Setting affinity for GPU 2 to 14005500,00140055
105
+ node03:2316572:2316711 [1] NCCL INFO Setting affinity for GPU 1 to 14005500,00140055
106
+ node03:2316574:2316710 [3] NCCL INFO Setting affinity for GPU 3 to 14005500,00140055
107
+ node03:2316571:2316708 [0] NCCL INFO Setting affinity for GPU 0 to 14005500,00140055
108
+ node03:2316571:2316708 [0] NCCL INFO Channel 00/02 : 0 1 2 3 4 5
109
+ node03:2316575:2316709 [4] NCCL INFO Trees [0] 5/-1/-1->4->3 [1] 5/-1/-1->4->3
110
+ node03:2316572:2316711 [1] NCCL INFO Trees [0] 2/-1/-1->1->0 [1] 2/-1/-1->1->0
111
+ node03:2316574:2316710 [3] NCCL INFO Trees [0] 4/-1/-1->3->2 [1] 4/-1/-1->3->2
112
+ node03:2316573:2316712 [2] NCCL INFO Trees [0] 3/-1/-1->2->1 [1] 3/-1/-1->2->1
113
+ node03:2316571:2316708 [0] NCCL INFO Channel 01/02 : 0 1 2 3 4 5
114
+ node03:2316576:2316713 [5] NCCL INFO Trees [0] -1/-1/-1->5->4 [1] -1/-1/-1->5->4
115
+ node03:2316575:2316709 [4] NCCL INFO P2P Chunksize set to 131072
116
+ node03:2316572:2316711 [1] NCCL INFO P2P Chunksize set to 131072
117
+ node03:2316574:2316710 [3] NCCL INFO P2P Chunksize set to 131072
118
+ node03:2316573:2316712 [2] NCCL INFO P2P Chunksize set to 131072
119
+ node03:2316571:2316708 [0] NCCL INFO Trees [0] 1/-1/-1->0->-1 [1] 1/-1/-1->0->-1
120
+ node03:2316576:2316713 [5] NCCL INFO P2P Chunksize set to 131072
121
+ node03:2316571:2316708 [0] NCCL INFO P2P Chunksize set to 131072
122
+ node03:2316572:2316711 [1] NCCL INFO Channel 00/0 : 1[1] -> 2[2] via P2P/IPC
123
+ node03:2316572:2316711 [1] NCCL INFO Channel 01/0 : 1[1] -> 2[2] via P2P/IPC
124
+ node03:2316576:2316713 [5] NCCL INFO Channel 00 : 5[5] -> 0[0] via SHM/direct/direct
125
+ node03:2316573:2316712 [2] NCCL INFO Channel 00 : 2[2] -> 3[3] via SHM/direct/direct
126
+ node03:2316576:2316713 [5] NCCL INFO Channel 01 : 5[5] -> 0[0] via SHM/direct/direct
127
+ node03:2316573:2316712 [2] NCCL INFO Channel 01 : 2[2] -> 3[3] via SHM/direct/direct
128
+ node03:2316575:2316709 [4] NCCL INFO Channel 00/0 : 4[4] -> 5[5] via P2P/IPC
129
+ node03:2316575:2316709 [4] NCCL INFO Channel 01/0 : 4[4] -> 5[5] via P2P/IPC
130
+ node03:2316571:2316708 [0] NCCL INFO Channel 00/0 : 0[0] -> 1[1] via P2P/IPC
131
+ node03:2316574:2316710 [3] NCCL INFO Channel 00 : 3[3] -> 4[4] via SHM/direct/direct
132
+ node03:2316571:2316708 [0] NCCL INFO Channel 01/0 : 0[0] -> 1[1] via P2P/IPC
133
+ node03:2316574:2316710 [3] NCCL INFO Channel 01 : 3[3] -> 4[4] via SHM/direct/direct
134
+ node03:2316575:2316709 [4] NCCL INFO Connected all rings
135
+ node03:2316572:2316711 [1] NCCL INFO Connected all rings
136
+ node03:2316571:2316708 [0] NCCL INFO Connected all rings
137
+ node03:2316573:2316712 [2] NCCL INFO Connected all rings
138
+ node03:2316576:2316713 [5] NCCL INFO Connected all rings
139
+ node03:2316576:2316713 [5] NCCL INFO Channel 00/0 : 5[5] -> 4[4] via P2P/IPC
140
+ node03:2316572:2316711 [1] NCCL INFO Channel 00/0 : 1[1] -> 0[0] via P2P/IPC
141
+ node03:2316575:2316709 [4] NCCL INFO Channel 00 : 4[4] -> 3[3] via SHM/direct/direct
142
+ node03:2316574:2316710 [3] NCCL INFO Connected all rings
143
+ node03:2316572:2316711 [1] NCCL INFO Channel 01/0 : 1[1] -> 0[0] via P2P/IPC
144
+ node03:2316576:2316713 [5] NCCL INFO Channel 01/0 : 5[5] -> 4[4] via P2P/IPC
145
+ node03:2316575:2316709 [4] NCCL INFO Channel 01 : 4[4] -> 3[3] via SHM/direct/direct
146
+ node03:2316571:2316708 [0] NCCL INFO Connected all trees
147
+ node03:2316571:2316708 [0] NCCL INFO threadThresholds 8/8/64 | 48/8/64 | 512 | 512
148
+ node03:2316571:2316708 [0] NCCL INFO 2 coll channels, 0 nvls channels, 2 p2p channels, 2 p2p channels per peer
149
+ node03:2316576:2316713 [5] NCCL INFO Connected all trees
150
+ node03:2316576:2316713 [5] NCCL INFO threadThresholds 8/8/64 | 48/8/64 | 512 | 512
151
+ node03:2316576:2316713 [5] NCCL INFO 2 coll channels, 0 nvls channels, 2 p2p channels, 2 p2p channels per peer
152
+ node03:2316573:2316712 [2] NCCL INFO Channel 00/0 : 2[2] -> 1[1] via P2P/IPC
153
+ node03:2316573:2316712 [2] NCCL INFO Channel 01/0 : 2[2] -> 1[1] via P2P/IPC
154
+ node03:2316572:2316711 [1] NCCL INFO Connected all trees
155
+ node03:2316572:2316711 [1] NCCL INFO threadThresholds 8/8/64 | 48/8/64 | 512 | 512
156
+ node03:2316572:2316711 [1] NCCL INFO 2 coll channels, 0 nvls channels, 2 p2p channels, 2 p2p channels per peer
157
+ node03:2316574:2316710 [3] NCCL INFO Channel 00 : 3[3] -> 2[2] via SHM/direct/direct
158
+ node03:2316574:2316710 [3] NCCL INFO Channel 01 : 3[3] -> 2[2] via SHM/direct/direct
159
+ node03:2316573:2316712 [2] NCCL INFO Connected all trees
160
+ node03:2316573:2316712 [2] NCCL INFO threadThresholds 8/8/64 | 48/8/64 | 512 | 512
161
+ node03:2316573:2316712 [2] NCCL INFO 2 coll channels, 0 nvls channels, 2 p2p channels, 2 p2p channels per peer
162
+ node03:2316575:2316709 [4] NCCL INFO Connected all trees
163
+ node03:2316575:2316709 [4] NCCL INFO threadThresholds 8/8/64 | 48/8/64 | 512 | 512
164
+ node03:2316575:2316709 [4] NCCL INFO 2 coll channels, 0 nvls channels, 2 p2p channels, 2 p2p channels per peer
165
+ node03:2316574:2316710 [3] NCCL INFO Connected all trees
166
+ node03:2316574:2316710 [3] NCCL INFO threadThresholds 8/8/64 | 48/8/64 | 512 | 512
167
+ node03:2316574:2316710 [3] NCCL INFO 2 coll channels, 0 nvls channels, 2 p2p channels, 2 p2p channels per peer
168
+ node03:2316575:2316709 [4] NCCL INFO comm 0xa46f6f0 rank 4 nranks 6 cudaDev 4 nvmlDev 4 busId c0000 commId 0x5c97a0f7f601b696 - Init COMPLETE
169
+ node03:2316573:2316712 [2] NCCL INFO comm 0xac1c200 rank 2 nranks 6 cudaDev 2 nvmlDev 2 busId 14000 commId 0x5c97a0f7f601b696 - Init COMPLETE
170
+ node03:2316571:2316708 [0] NCCL INFO comm 0xa205d10 rank 0 nranks 6 cudaDev 0 nvmlDev 0 busId 12000 commId 0x5c97a0f7f601b696 - Init COMPLETE
171
+ node03:2316572:2316711 [1] NCCL INFO comm 0xa155230 rank 1 nranks 6 cudaDev 1 nvmlDev 1 busId 13000 commId 0x5c97a0f7f601b696 - Init COMPLETE
172
+ node03:2316574:2316710 [3] NCCL INFO comm 0xadbe390 rank 3 nranks 6 cudaDev 3 nvmlDev 3 busId 48000 commId 0x5c97a0f7f601b696 - Init COMPLETE
173
+ node03:2316576:2316713 [5] NCCL INFO comm 0x9b6ede0 rank 5 nranks 6 cudaDev 5 nvmlDev 5 busId c1000 commId 0x5c97a0f7f601b696 - Init COMPLETE
174
+ 2025-03-03 16:24:03 | INFO | model:31 - Window size 12!
175
+ 2025-03-03 16:24:03 | INFO | model:51 - Initializing Multi-modal Swin Transformer weights from ckpts/swin_base_patch4_window12_384_22k.pth
176
+ 2025-03-03 16:24:05 | INFO | model.backbone:459 - loading swin success !!!
177
+ 2025-03-03 16:24:08 | INFO | __main__:144 - Model moved to GPU: 0
178
+ 2025-03-03 16:24:08 | INFO | __main__:145 - amsgrad: True
179
+ batch_size: 30
180
+ batch_size_val: 16
181
+ bert: bert-base-uncased
182
+ dataset: refcocog_u
183
+ dist_backend: nccl
184
+ dropout: 0.0
185
+ epochs: 50
186
+ evaluate: True
187
+ exclude_multiobj: True
188
+ exp_name: ACE_filter050_rev
189
+ filter_threshold: 0.5
190
+ fusion_drop: 0.0
191
+ gpu: 0
192
+ hp_selection: strict
193
+ input_size: 480
194
+ local_rank: 0
195
+ loss_option: ACE_verbonly
196
+ lr: 0.0001
197
+ lr_backbone: 5e-05
198
+ lr_text_encoder: 5e-05
199
+ manual_seed: 1455390217
200
+ margin_value: 12
201
+ mask_root: data/masks/refcocog_u
202
+ metric_learning: True
203
+ metric_loss_weight: 0.1
204
+ metric_mode: hardpos_only_sbertsim_refined
205
+ mha: 8-8-8-8
206
+ mixup_lasttwo: False
207
+ num_token: 2
208
+ output_dir: exp/refcoco_u/ACE_filter050_rev
209
+ output_folder: exp/refcoco_u
210
+ print_freq: 100
211
+ rank: 0
212
+ resume: None
213
+ save_freq: 1
214
+ start_epoch: 0
215
+ swin_pretrain: ckpts/swin_base_patch4_window12_384_22k.pth
216
+ swin_type: base
217
+ sync_bn: True
218
+ temperature: 0.07
219
+ test_lmdb: data/lmdb/refcocog_u/test.lmdb
220
+ test_split: test
221
+ token_dim: 512
222
+ train_lmdb: data/lmdb/refcocog_u/train.lmdb
223
+ train_split: train
224
+ val_lmdb: data/lmdb/refcocog_u/val.lmdb
225
+ val_split: val
226
+ vis_dim: 512
227
+ visualize: False
228
+ weight: None
229
+ weight_decay: 0.0001
230
+ window12: True
231
+ word_dim: 768
232
+ word_len: 20
233
+ workers: 32
234
+ workers_val: 8
235
+ world_size: 6
236
+ 2025-03-03 16:26:29 | INFO | utils.misc:108 - Training: Epoch=[1/50] [ 100/1407] Batch=1.30 (1.38) Data=0.00 (0.07) Lr=0.000100 Loss=1.0907 (1.1761) IoU=26.22 (19.83) Prec@50=23.61 (8.67)
237
+ 2025-03-03 16:28:41 | INFO | utils.misc:108 - Training: Epoch=[1/50] [ 200/1407] Batch=1.29 (1.35) Data=0.00 (0.05) Lr=0.000100 Loss=0.9969 (1.0899) IoU=29.22 (24.57) Prec@50=25.56 (13.58)
238
+ 2025-03-03 16:30:54 | INFO | utils.misc:108 - Training: Epoch=[1/50] [ 300/1407] Batch=1.36 (1.34) Data=0.00 (0.04) Lr=0.000100 Loss=0.9088 (1.0496) IoU=35.17 (26.91) Prec@50=21.79 (16.03)
239
+ 2025-03-03 16:33:06 | INFO | utils.misc:108 - Training: Epoch=[1/50] [ 400/1407] Batch=1.22 (1.34) Data=0.00 (0.04) Lr=0.000100 Loss=0.8703 (1.0167) IoU=35.01 (28.40) Prec@50=21.29 (17.80)
240
+ 2025-03-03 16:35:19 | INFO | utils.misc:108 - Training: Epoch=[1/50] [ 500/1407] Batch=1.15 (1.34) Data=0.00 (0.04) Lr=0.000100 Loss=0.9754 (0.9922) IoU=31.02 (29.64) Prec@50=17.78 (19.37)
241
+ 2025-03-03 16:37:31 | INFO | utils.misc:108 - Training: Epoch=[1/50] [ 600/1407] Batch=1.16 (1.33) Data=0.00 (0.04) Lr=0.000100 Loss=0.8838 (0.9712) IoU=33.06 (30.63) Prec@50=26.75 (20.87)
242
+ 2025-03-03 16:39:43 | INFO | utils.misc:108 - Training: Epoch=[1/50] [ 700/1407] Batch=1.24 (1.33) Data=0.00 (0.03) Lr=0.000100 Loss=0.7971 (0.9545) IoU=34.82 (31.32) Prec@50=30.06 (21.88)
243
+ 2025-03-03 16:41:53 | INFO | utils.misc:108 - Training: Epoch=[1/50] [ 800/1407] Batch=1.48 (1.33) Data=0.00 (0.03) Lr=0.000100 Loss=0.9076 (0.9414) IoU=35.91 (31.87) Prec@50=23.91 (22.66)
244
+ 2025-03-03 16:44:04 | INFO | utils.misc:108 - Training: Epoch=[1/50] [ 900/1407] Batch=1.29 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.9933 (0.9301) IoU=30.51 (32.29) Prec@50=18.81 (23.24)
245
+ 2025-03-03 16:46:13 | INFO | utils.misc:108 - Training: Epoch=[1/50] [1000/1407] Batch=1.26 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.7584 (0.9199) IoU=40.61 (32.82) Prec@50=30.95 (23.96)
246
+ 2025-03-03 16:48:24 | INFO | utils.misc:108 - Training: Epoch=[1/50] [1100/1407] Batch=1.69 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.8691 (0.9110) IoU=35.25 (33.29) Prec@50=19.97 (24.69)
247
+ 2025-03-03 16:50:37 | INFO | utils.misc:108 - Training: Epoch=[1/50] [1200/1407] Batch=1.39 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.8363 (0.9033) IoU=34.11 (33.67) Prec@50=28.97 (25.27)
248
+ 2025-03-03 16:52:49 | INFO | utils.misc:108 - Training: Epoch=[1/50] [1300/1407] Batch=1.35 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.8169 (0.8959) IoU=34.20 (34.05) Prec@50=26.85 (25.87)
249
+ 2025-03-03 16:55:00 | INFO | utils.misc:108 - Training: Epoch=[1/50] [1400/1407] Batch=1.09 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.7557 (0.8885) IoU=44.41 (34.49) Prec@50=41.27 (26.50)
250
+ 2025-03-03 16:55:45 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[1/50] mIoU=44.49 oIoU=42.87 Pr@50: 40.95 Pr@60: 30.46 Pr@70: 20.32 Pr@80: 11.81 Pr@90: 2.80
251
+ 2025-03-03 16:58:21 | INFO | utils.misc:108 - Training: Epoch=[2/50] [ 100/1407] Batch=1.28 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.9049 (0.7654) IoU=35.26 (41.03) Prec@50=29.66 (36.59)
252
+ 2025-03-03 17:00:37 | INFO | utils.misc:108 - Training: Epoch=[2/50] [ 200/1407] Batch=1.68 (1.33) Data=0.00 (0.03) Lr=0.000100 Loss=0.7241 (0.7458) IoU=40.81 (41.92) Prec@50=42.92 (38.62)
253
+ 2025-03-03 17:02:47 | INFO | utils.misc:108 - Training: Epoch=[2/50] [ 300/1407] Batch=1.42 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.8373 (0.7440) IoU=40.21 (42.04) Prec@50=32.79 (38.73)
254
+ 2025-03-03 17:04:56 | INFO | utils.misc:108 - Training: Epoch=[2/50] [ 400/1407] Batch=1.33 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.8551 (0.7431) IoU=35.71 (41.98) Prec@50=34.74 (38.65)
255
+ 2025-03-03 17:07:05 | INFO | utils.misc:108 - Training: Epoch=[2/50] [ 500/1407] Batch=1.56 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.7449 (0.7434) IoU=44.00 (41.89) Prec@50=46.22 (38.68)
256
+ 2025-03-03 17:09:14 | INFO | utils.misc:108 - Training: Epoch=[2/50] [ 600/1407] Batch=1.31 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.7519 (0.7414) IoU=38.65 (41.92) Prec@50=30.89 (38.80)
257
+ 2025-03-03 17:11:26 | INFO | utils.misc:108 - Training: Epoch=[2/50] [ 700/1407] Batch=1.22 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.6646 (0.7398) IoU=43.75 (42.04) Prec@50=41.87 (39.10)
258
+ 2025-03-03 17:13:36 | INFO | utils.misc:108 - Training: Epoch=[2/50] [ 800/1407] Batch=1.24 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.7340 (0.7378) IoU=42.55 (42.15) Prec@50=39.05 (39.38)
259
+ 2025-03-03 17:15:44 | INFO | utils.misc:108 - Training: Epoch=[2/50] [ 900/1407] Batch=1.77 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.6731 (0.7369) IoU=43.04 (42.09) Prec@50=48.61 (39.34)
260
+ 2025-03-03 17:17:53 | INFO | utils.misc:108 - Training: Epoch=[2/50] [1000/1407] Batch=1.36 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.6417 (0.7343) IoU=45.02 (42.19) Prec@50=51.01 (39.55)
261
+ 2025-03-03 17:20:04 | INFO | utils.misc:108 - Training: Epoch=[2/50] [1100/1407] Batch=1.30 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.8503 (0.7329) IoU=38.33 (42.24) Prec@50=35.36 (39.68)
262
+ 2025-03-03 17:22:15 | INFO | utils.misc:108 - Training: Epoch=[2/50] [1200/1407] Batch=1.62 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.7639 (0.7322) IoU=45.07 (42.29) Prec@50=34.68 (39.71)
263
+ 2025-03-03 17:24:27 | INFO | utils.misc:108 - Training: Epoch=[2/50] [1300/1407] Batch=1.17 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.8485 (0.7314) IoU=36.88 (42.42) Prec@50=38.10 (39.84)
264
+ 2025-03-03 17:26:37 | INFO | utils.misc:108 - Training: Epoch=[2/50] [1400/1407] Batch=1.24 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.8243 (0.7302) IoU=44.98 (42.64) Prec@50=41.79 (40.17)
265
+ 2025-03-03 17:27:21 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[2/50] mIoU=50.35 oIoU=48.08 Pr@50: 51.28 Pr@60: 40.60 Pr@70: 30.92 Pr@80: 19.46 Pr@90: 6.02
266
+ 2025-03-03 17:29:59 | INFO | utils.misc:108 - Training: Epoch=[3/50] [ 100/1407] Batch=1.22 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.6662 (0.6521) IoU=43.34 (48.06) Prec@50=45.54 (48.24)
267
+ 2025-03-03 17:32:11 | INFO | utils.misc:108 - Training: Epoch=[3/50] [ 200/1407] Batch=1.32 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.6270 (0.6581) IoU=50.53 (47.15) Prec@50=58.53 (47.12)
268
+ 2025-03-03 17:34:21 | INFO | utils.misc:108 - Training: Epoch=[3/50] [ 300/1407] Batch=1.13 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.7448 (0.6597) IoU=42.87 (46.97) Prec@50=36.43 (46.95)
269
+ 2025-03-03 17:36:33 | INFO | utils.misc:108 - Training: Epoch=[3/50] [ 400/1407] Batch=1.39 (1.31) Data=0.01 (0.03) Lr=0.000100 Loss=0.6364 (0.6586) IoU=41.82 (46.78) Prec@50=42.46 (46.52)
270
+ 2025-03-03 17:38:45 | INFO | utils.misc:108 - Training: Epoch=[3/50] [ 500/1407] Batch=1.40 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.7957 (0.6597) IoU=43.02 (46.63) Prec@50=32.08 (46.34)
271
+ 2025-03-03 17:40:56 | INFO | utils.misc:108 - Training: Epoch=[3/50] [ 600/1407] Batch=1.14 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.7063 (0.6615) IoU=44.00 (46.61) Prec@50=39.84 (46.29)
272
+ 2025-03-03 17:43:06 | INFO | utils.misc:108 - Training: Epoch=[3/50] [ 700/1407] Batch=1.13 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.5882 (0.6644) IoU=54.05 (46.60) Prec@50=57.14 (46.15)
273
+ 2025-03-03 17:45:16 | INFO | utils.misc:108 - Training: Epoch=[3/50] [ 800/1407] Batch=1.26 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.6320 (0.6651) IoU=48.72 (46.60) Prec@50=47.24 (46.12)
274
+ 2025-03-03 17:47:27 | INFO | utils.misc:108 - Training: Epoch=[3/50] [ 900/1407] Batch=1.23 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.7309 (0.6651) IoU=44.48 (46.71) Prec@50=40.69 (46.31)
275
+ 2025-03-03 17:49:35 | INFO | utils.misc:108 - Training: Epoch=[3/50] [1000/1407] Batch=1.18 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.8106 (0.6643) IoU=33.29 (46.83) Prec@50=28.57 (46.37)
276
+ 2025-03-03 17:51:46 | INFO | utils.misc:108 - Training: Epoch=[3/50] [1100/1407] Batch=1.32 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.6855 (0.6630) IoU=43.40 (46.90) Prec@50=45.09 (46.52)
277
+ 2025-03-03 17:53:58 | INFO | utils.misc:108 - Training: Epoch=[3/50] [1200/1407] Batch=1.25 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.6270 (0.6619) IoU=51.54 (47.00) Prec@50=47.42 (46.63)
278
+ 2025-03-03 17:56:09 | INFO | utils.misc:108 - Training: Epoch=[3/50] [1300/1407] Batch=1.38 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.6188 (0.6614) IoU=48.14 (47.06) Prec@50=53.94 (46.75)
279
+ 2025-03-03 17:58:19 | INFO | utils.misc:108 - Training: Epoch=[3/50] [1400/1407] Batch=1.15 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.6312 (0.6612) IoU=53.97 (47.06) Prec@50=59.68 (46.75)
280
+ 2025-03-03 17:59:03 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[3/50] mIoU=54.39 oIoU=51.28 Pr@50: 57.42 Pr@60: 48.45 Pr@70: 39.20 Pr@80: 27.04 Pr@90: 10.96
281
+ 2025-03-03 18:01:40 | INFO | utils.misc:108 - Training: Epoch=[4/50] [ 100/1407] Batch=1.23 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.7058 (0.6106) IoU=38.22 (48.27) Prec@50=32.74 (49.35)
282
+ 2025-03-03 18:03:51 | INFO | utils.misc:108 - Training: Epoch=[4/50] [ 200/1407] Batch=1.35 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.5542 (0.6165) IoU=43.66 (47.87) Prec@50=45.73 (48.88)
283
+ 2025-03-03 18:06:01 | INFO | utils.misc:108 - Training: Epoch=[4/50] [ 300/1407] Batch=1.51 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.5978 (0.6161) IoU=47.81 (48.07) Prec@50=44.17 (49.12)
284
+ 2025-03-03 18:08:11 | INFO | utils.misc:108 - Training: Epoch=[4/50] [ 400/1407] Batch=1.11 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.7430 (0.6135) IoU=45.09 (48.32) Prec@50=44.84 (49.50)
285
+ 2025-03-03 18:10:22 | INFO | utils.misc:108 - Training: Epoch=[4/50] [ 500/1407] Batch=1.27 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.6319 (0.6166) IoU=49.79 (48.05) Prec@50=47.02 (49.14)
286
+ 2025-03-03 18:12:33 | INFO | utils.misc:108 - Training: Epoch=[4/50] [ 600/1407] Batch=1.59 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.6365 (0.6162) IoU=41.68 (48.00) Prec@50=35.05 (48.94)
287
+ 2025-03-03 18:14:46 | INFO | utils.misc:108 - Training: Epoch=[4/50] [ 700/1407] Batch=1.23 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.5867 (0.6171) IoU=53.39 (47.94) Prec@50=57.92 (48.86)
288
+ 2025-03-03 18:16:55 | INFO | utils.misc:108 - Training: Epoch=[4/50] [ 800/1407] Batch=1.62 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.6259 (0.6161) IoU=45.88 (48.17) Prec@50=48.51 (49.16)
289
+ 2025-03-03 18:19:05 | INFO | utils.misc:108 - Training: Epoch=[4/50] [ 900/1407] Batch=1.24 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.6503 (0.6170) IoU=50.25 (48.31) Prec@50=51.11 (49.28)
290
+ 2025-03-03 18:21:17 | INFO | utils.misc:108 - Training: Epoch=[4/50] [1000/1407] Batch=1.16 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.6409 (0.6158) IoU=48.29 (48.45) Prec@50=50.24 (49.47)
291
+ 2025-03-03 18:23:27 | INFO | utils.misc:108 - Training: Epoch=[4/50] [1100/1407] Batch=1.28 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.6323 (0.6145) IoU=50.15 (48.53) Prec@50=51.49 (49.55)
292
+ 2025-03-03 18:25:37 | INFO | utils.misc:108 - Training: Epoch=[4/50] [1200/1407] Batch=1.22 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.5679 (0.6129) IoU=55.70 (48.67) Prec@50=61.69 (49.72)
293
+ 2025-03-03 18:27:50 | INFO | utils.misc:108 - Training: Epoch=[4/50] [1300/1407] Batch=1.25 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.6499 (0.6116) IoU=50.86 (48.78) Prec@50=50.87 (49.90)
294
+ 2025-03-03 18:30:01 | INFO | utils.misc:108 - Training: Epoch=[4/50] [1400/1407] Batch=1.18 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.7374 (0.6114) IoU=39.56 (48.79) Prec@50=34.68 (49.92)
295
+ 2025-03-03 18:30:45 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[4/50] mIoU=56.35 oIoU=53.31 Pr@50: 60.68 Pr@60: 51.48 Pr@70: 42.50 Pr@80: 30.30 Pr@90: 11.50
296
+ 2025-03-03 18:33:23 | INFO | utils.misc:108 - Training: Epoch=[5/50] [ 100/1407] Batch=1.23 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.4113 (0.5488) IoU=61.30 (51.36) Prec@50=69.78 (53.66)
297
+ 2025-03-03 18:35:32 | INFO | utils.misc:108 - Training: Epoch=[5/50] [ 200/1407] Batch=1.35 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4562 (0.5582) IoU=59.50 (51.34) Prec@50=66.80 (53.66)
298
+ 2025-03-03 18:37:45 | INFO | utils.misc:108 - Training: Epoch=[5/50] [ 300/1407] Batch=1.43 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.5283 (0.5580) IoU=52.54 (51.87) Prec@50=57.52 (54.45)
299
+ 2025-03-03 18:39:56 | INFO | utils.misc:108 - Training: Epoch=[5/50] [ 400/1407] Batch=1.33 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.5587 (0.5559) IoU=49.20 (52.23) Prec@50=55.06 (54.90)
300
+ 2025-03-03 18:42:08 | INFO | utils.misc:108 - Training: Epoch=[5/50] [ 500/1407] Batch=1.13 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.6211 (0.5567) IoU=50.09 (52.48) Prec@50=54.76 (55.43)
301
+ 2025-03-03 18:44:19 | INFO | utils.misc:108 - Training: Epoch=[5/50] [ 600/1407] Batch=1.14 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.5935 (0.5572) IoU=51.16 (52.40) Prec@50=53.10 (55.33)
302
+ 2025-03-03 18:46:29 | INFO | utils.misc:108 - Training: Epoch=[5/50] [ 700/1407] Batch=1.20 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.6169 (0.5571) IoU=55.58 (52.48) Prec@50=61.11 (55.51)
303
+ 2025-03-03 18:48:40 | INFO | utils.misc:108 - Training: Epoch=[5/50] [ 800/1407] Batch=1.39 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.3962 (0.5566) IoU=63.84 (52.62) Prec@50=73.62 (55.67)
304
+ 2025-03-03 18:50:52 | INFO | utils.misc:108 - Training: Epoch=[5/50] [ 900/1407] Batch=1.23 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.5741 (0.5561) IoU=51.11 (52.62) Prec@50=49.78 (55.68)
305
+ 2025-03-03 18:53:04 | INFO | utils.misc:108 - Training: Epoch=[5/50] [1000/1407] Batch=1.30 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.5752 (0.5549) IoU=46.45 (52.75) Prec@50=45.34 (55.80)
306
+ 2025-03-03 18:55:17 | INFO | utils.misc:108 - Training: Epoch=[5/50] [1100/1407] Batch=1.49 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.5531 (0.5562) IoU=54.64 (52.65) Prec@50=56.15 (55.63)
307
+ 2025-03-03 18:57:30 | INFO | utils.misc:108 - Training: Epoch=[5/50] [1200/1407] Batch=1.23 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.6240 (0.5565) IoU=53.21 (52.64) Prec@50=60.42 (55.60)
308
+ 2025-03-03 18:59:42 | INFO | utils.misc:108 - Training: Epoch=[5/50] [1300/1407] Batch=1.35 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.5657 (0.5568) IoU=57.00 (52.65) Prec@50=65.97 (55.57)
309
+ 2025-03-03 19:01:53 | INFO | utils.misc:108 - Training: Epoch=[5/50] [1400/1407] Batch=1.10 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.4952 (0.5573) IoU=56.29 (52.65) Prec@50=57.14 (55.56)
310
+ 2025-03-03 19:02:37 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[5/50] mIoU=59.29 oIoU=56.00 Pr@50: 65.58 Pr@60: 57.50 Pr@70: 48.68 Pr@80: 36.48 Pr@90: 15.23
311
+ 2025-03-03 19:05:14 | INFO | utils.misc:108 - Training: Epoch=[6/50] [ 100/1407] Batch=1.34 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4567 (0.4978) IoU=56.26 (55.83) Prec@50=58.94 (60.25)
312
+ 2025-03-03 19:07:25 | INFO | utils.misc:108 - Training: Epoch=[6/50] [ 200/1407] Batch=1.21 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.5726 (0.4995) IoU=47.47 (55.42) Prec@50=48.71 (60.07)
313
+ 2025-03-03 19:09:34 | INFO | utils.misc:108 - Training: Epoch=[6/50] [ 300/1407] Batch=1.32 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.4761 (0.5059) IoU=58.15 (55.06) Prec@50=62.62 (59.25)
314
+ 2025-03-03 19:11:44 | INFO | utils.misc:108 - Training: Epoch=[6/50] [ 400/1407] Batch=1.23 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.5243 (0.5106) IoU=55.15 (54.66) Prec@50=58.63 (58.77)
315
+ 2025-03-03 19:13:54 | INFO | utils.misc:108 - Training: Epoch=[6/50] [ 500/1407] Batch=1.13 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.6256 (0.5109) IoU=46.39 (54.44) Prec@50=48.02 (58.53)
316
+ 2025-03-03 19:16:05 | INFO | utils.misc:108 - Training: Epoch=[6/50] [ 600/1407] Batch=1.45 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.5675 (0.5104) IoU=47.12 (54.51) Prec@50=49.41 (58.52)
317
+ 2025-03-03 19:18:15 | INFO | utils.misc:108 - Training: Epoch=[6/50] [ 700/1407] Batch=1.13 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.6829 (0.5110) IoU=48.30 (54.50) Prec@50=54.37 (58.45)
318
+ 2025-03-03 19:20:25 | INFO | utils.misc:108 - Training: Epoch=[6/50] [ 800/1407] Batch=1.24 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.5415 (0.5114) IoU=49.64 (54.42) Prec@50=50.30 (58.29)
319
+ 2025-03-03 19:22:34 | INFO | utils.misc:108 - Training: Epoch=[6/50] [ 900/1407] Batch=1.28 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.4806 (0.5120) IoU=58.17 (54.46) Prec@50=57.84 (58.30)
320
+ 2025-03-03 19:24:45 | INFO | utils.misc:108 - Training: Epoch=[6/50] [1000/1407] Batch=1.24 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.5476 (0.5130) IoU=46.38 (54.36) Prec@50=48.41 (58.26)
321
+ 2025-03-03 19:26:54 | INFO | utils.misc:108 - Training: Epoch=[6/50] [1100/1407] Batch=1.13 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.4327 (0.5143) IoU=61.53 (54.18) Prec@50=70.24 (58.03)
322
+ 2025-03-03 19:29:04 | INFO | utils.misc:108 - Training: Epoch=[6/50] [1200/1407] Batch=1.27 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.4774 (0.5143) IoU=55.10 (54.15) Prec@50=55.14 (58.05)
323
+ 2025-03-03 19:31:12 | INFO | utils.misc:108 - Training: Epoch=[6/50] [1300/1407] Batch=1.40 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.5230 (0.5145) IoU=51.25 (54.14) Prec@50=52.98 (58.04)
324
+ 2025-03-03 19:33:21 | INFO | utils.misc:108 - Training: Epoch=[6/50] [1400/1407] Batch=1.39 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.6039 (0.5156) IoU=48.91 (54.14) Prec@50=46.36 (57.95)
325
+ 2025-03-03 19:34:05 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[6/50] mIoU=60.83 oIoU=57.66 Pr@50: 67.29 Pr@60: 60.37 Pr@70: 51.44 Pr@80: 39.70 Pr@90: 18.07
326
+ 2025-03-03 19:36:44 | INFO | utils.misc:108 - Training: Epoch=[7/50] [ 100/1407] Batch=1.18 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.3694 (0.4535) IoU=64.46 (57.97) Prec@50=69.13 (63.46)
327
+ 2025-03-03 19:38:55 | INFO | utils.misc:108 - Training: Epoch=[7/50] [ 200/1407] Batch=1.22 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4216 (0.4611) IoU=56.11 (57.46) Prec@50=58.08 (62.46)
328
+ 2025-03-03 19:41:04 | INFO | utils.misc:108 - Training: Epoch=[7/50] [ 300/1407] Batch=1.18 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.3782 (0.4635) IoU=61.55 (57.46) Prec@50=69.13 (62.34)
329
+ 2025-03-03 19:43:16 | INFO | utils.misc:108 - Training: Epoch=[7/50] [ 400/1407] Batch=1.22 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4767 (0.4658) IoU=57.43 (57.24) Prec@50=67.86 (61.95)
330
+ 2025-03-03 19:45:27 | INFO | utils.misc:108 - Training: Epoch=[7/50] [ 500/1407] Batch=1.17 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4054 (0.4673) IoU=62.52 (57.11) Prec@50=69.68 (61.74)
331
+ 2025-03-03 19:47:38 | INFO | utils.misc:108 - Training: Epoch=[7/50] [ 600/1407] Batch=1.23 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.3967 (0.4709) IoU=63.16 (56.86) Prec@50=71.21 (61.39)
332
+ 2025-03-03 19:49:48 | INFO | utils.misc:108 - Training: Epoch=[7/50] [ 700/1407] Batch=1.22 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4413 (0.4737) IoU=62.03 (56.70) Prec@50=80.79 (61.12)
333
+ 2025-03-03 19:51:58 | INFO | utils.misc:108 - Training: Epoch=[7/50] [ 800/1407] Batch=1.22 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4080 (0.4749) IoU=62.66 (56.63) Prec@50=68.57 (61.05)
334
+ 2025-03-03 19:54:08 | INFO | utils.misc:108 - Training: Epoch=[7/50] [ 900/1407] Batch=1.15 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4112 (0.4776) IoU=60.53 (56.50) Prec@50=67.70 (60.89)
335
+ 2025-03-03 19:56:20 | INFO | utils.misc:108 - Training: Epoch=[7/50] [1000/1407] Batch=1.39 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4987 (0.4784) IoU=61.67 (56.37) Prec@50=65.00 (60.70)
336
+ 2025-03-03 19:58:31 | INFO | utils.misc:108 - Training: Epoch=[7/50] [1100/1407] Batch=1.27 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.3991 (0.4795) IoU=56.89 (56.34) Prec@50=59.13 (60.63)
337
+ 2025-03-03 20:00:42 | INFO | utils.misc:108 - Training: Epoch=[7/50] [1200/1407] Batch=1.59 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.3944 (0.4802) IoU=61.46 (56.35) Prec@50=64.31 (60.58)
338
+ 2025-03-03 20:02:55 | INFO | utils.misc:108 - Training: Epoch=[7/50] [1300/1407] Batch=1.24 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.6520 (0.4818) IoU=52.79 (56.28) Prec@50=57.54 (60.48)
339
+ 2025-03-03 20:05:06 | INFO | utils.misc:108 - Training: Epoch=[7/50] [1400/1407] Batch=1.35 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.5381 (0.4822) IoU=54.06 (56.23) Prec@50=57.18 (60.46)
340
+ 2025-03-03 20:05:51 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[7/50] mIoU=62.05 oIoU=58.56 Pr@50: 69.35 Pr@60: 61.85 Pr@70: 53.96 Pr@80: 40.87 Pr@90: 19.00
341
+ 2025-03-03 20:08:27 | INFO | utils.misc:108 - Training: Epoch=[8/50] [ 100/1407] Batch=1.26 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.3871 (0.4439) IoU=57.42 (57.00) Prec@50=60.81 (62.43)
342
+ 2025-03-03 20:10:38 | INFO | utils.misc:108 - Training: Epoch=[8/50] [ 200/1407] Batch=1.12 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4502 (0.4448) IoU=59.06 (57.36) Prec@50=60.71 (62.74)
343
+ 2025-03-03 20:12:49 | INFO | utils.misc:108 - Training: Epoch=[8/50] [ 300/1407] Batch=1.13 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4527 (0.4435) IoU=61.08 (57.48) Prec@50=67.06 (63.02)
344
+ 2025-03-03 20:15:02 | INFO | utils.misc:108 - Training: Epoch=[8/50] [ 400/1407] Batch=1.25 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4179 (0.4455) IoU=59.71 (57.75) Prec@50=65.75 (63.11)
345
+ 2025-03-03 20:17:13 | INFO | utils.misc:108 - Training: Epoch=[8/50] [ 500/1407] Batch=1.26 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4548 (0.4472) IoU=59.80 (57.79) Prec@50=66.67 (63.14)
346
+ 2025-03-03 20:19:26 | INFO | utils.misc:108 - Training: Epoch=[8/50] [ 600/1407] Batch=1.25 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.4598 (0.4486) IoU=49.73 (57.75) Prec@50=53.57 (62.92)
347
+ 2025-03-03 20:21:38 | INFO | utils.misc:108 - Training: Epoch=[8/50] [ 700/1407] Batch=1.24 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.4703 (0.4484) IoU=56.49 (57.69) Prec@50=60.24 (62.81)
348
+ 2025-03-03 20:23:48 | INFO | utils.misc:108 - Training: Epoch=[8/50] [ 800/1407] Batch=1.16 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.5171 (0.4514) IoU=57.90 (57.64) Prec@50=64.52 (62.68)
349
+ 2025-03-03 20:25:58 | INFO | utils.misc:108 - Training: Epoch=[8/50] [ 900/1407] Batch=1.22 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.5224 (0.4527) IoU=59.30 (57.70) Prec@50=61.31 (62.74)
350
+ 2025-03-03 20:28:09 | INFO | utils.misc:108 - Training: Epoch=[8/50] [1000/1407] Batch=1.22 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.5577 (0.4544) IoU=52.09 (57.60) Prec@50=56.67 (62.56)
351
+ 2025-03-03 20:30:22 | INFO | utils.misc:108 - Training: Epoch=[8/50] [1100/1407] Batch=1.23 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4545 (0.4550) IoU=56.95 (57.56) Prec@50=62.40 (62.53)
352
+ 2025-03-03 20:32:33 | INFO | utils.misc:108 - Training: Epoch=[8/50] [1200/1407] Batch=1.23 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.3771 (0.4548) IoU=61.85 (57.58) Prec@50=72.30 (62.52)
353
+ 2025-03-03 20:34:45 | INFO | utils.misc:108 - Training: Epoch=[8/50] [1300/1407] Batch=1.32 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4313 (0.4551) IoU=60.03 (57.54) Prec@50=65.79 (62.46)
354
+ 2025-03-03 20:36:57 | INFO | utils.misc:108 - Training: Epoch=[8/50] [1400/1407] Batch=1.21 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.5147 (0.4564) IoU=53.92 (57.49) Prec@50=52.18 (62.34)
355
+ 2025-03-03 20:37:41 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[8/50] mIoU=62.41 oIoU=58.68 Pr@50: 69.23 Pr@60: 61.46 Pr@70: 53.65 Pr@80: 41.69 Pr@90: 18.92
356
+ 2025-03-03 20:40:22 | INFO | utils.misc:108 - Training: Epoch=[9/50] [ 100/1407] Batch=1.15 (1.32) Data=0.01 (0.03) Lr=0.000100 Loss=0.3586 (0.4313) IoU=70.73 (59.54) Prec@50=83.97 (64.88)
357
+ 2025-03-03 20:42:34 | INFO | utils.misc:108 - Training: Epoch=[9/50] [ 200/1407] Batch=1.17 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.4021 (0.4262) IoU=58.45 (59.39) Prec@50=69.84 (64.84)
358
+ 2025-03-03 20:44:44 | INFO | utils.misc:108 - Training: Epoch=[9/50] [ 300/1407] Batch=1.21 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.5048 (0.4308) IoU=62.54 (59.46) Prec@50=73.25 (65.16)
359
+ 2025-03-03 20:46:54 | INFO | utils.misc:108 - Training: Epoch=[9/50] [ 400/1407] Batch=1.13 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.5139 (0.4306) IoU=57.54 (59.52) Prec@50=63.33 (65.13)
360
+ 2025-03-03 20:49:05 | INFO | utils.misc:108 - Training: Epoch=[9/50] [ 500/1407] Batch=1.14 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.3949 (0.4328) IoU=64.36 (59.33) Prec@50=70.08 (64.82)
361
+ 2025-03-03 20:51:16 | INFO | utils.misc:108 - Training: Epoch=[9/50] [ 600/1407] Batch=1.26 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4611 (0.4326) IoU=57.47 (59.35) Prec@50=60.89 (64.74)
362
+ 2025-03-03 20:53:26 | INFO | utils.misc:108 - Training: Epoch=[9/50] [ 700/1407] Batch=1.17 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4755 (0.4334) IoU=56.52 (59.25) Prec@50=56.35 (64.61)
363
+ 2025-03-03 20:55:37 | INFO | utils.misc:108 - Training: Epoch=[9/50] [ 800/1407] Batch=1.29 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4796 (0.4337) IoU=60.29 (59.28) Prec@50=68.59 (64.69)
364
+ 2025-03-03 20:57:47 | INFO | utils.misc:108 - Training: Epoch=[9/50] [ 900/1407] Batch=1.24 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4704 (0.4349) IoU=52.80 (59.28) Prec@50=63.89 (64.72)
365
+ 2025-03-03 20:59:59 | INFO | utils.misc:108 - Training: Epoch=[9/50] [1000/1407] Batch=1.22 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.3763 (0.4347) IoU=65.34 (59.13) Prec@50=72.64 (64.62)
366
+ 2025-03-03 21:02:11 | INFO | utils.misc:108 - Training: Epoch=[9/50] [1100/1407] Batch=1.42 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4699 (0.4350) IoU=59.93 (59.02) Prec@50=65.93 (64.56)
367
+ 2025-03-03 21:04:23 | INFO | utils.misc:108 - Training: Epoch=[9/50] [1200/1407] Batch=1.21 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.3397 (0.4357) IoU=65.16 (58.84) Prec@50=72.56 (64.37)
368
+ 2025-03-03 21:06:34 | INFO | utils.misc:108 - Training: Epoch=[9/50] [1300/1407] Batch=1.26 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.6262 (0.4363) IoU=48.86 (58.78) Prec@50=47.62 (64.23)
369
+ 2025-03-03 21:08:46 | INFO | utils.misc:108 - Training: Epoch=[9/50] [1400/1407] Batch=1.35 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4082 (0.4374) IoU=63.09 (58.70) Prec@50=74.50 (64.19)
370
+ 2025-03-03 21:09:29 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[9/50] mIoU=62.25 oIoU=58.06 Pr@50: 69.23 Pr@60: 63.05 Pr@70: 54.78 Pr@80: 42.58 Pr@90: 20.16
371
+ 2025-03-03 21:11:56 | INFO | utils.misc:108 - Training: Epoch=[10/50] [ 100/1407] Batch=1.28 (1.33) Data=0.00 (0.03) Lr=0.000100 Loss=0.3925 (0.4111) IoU=55.59 (59.60) Prec@50=61.05 (66.27)
372
+ 2025-03-03 21:14:06 | INFO | utils.misc:108 - Training: Epoch=[10/50] [ 200/1407] Batch=1.19 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.4349 (0.4082) IoU=58.24 (59.65) Prec@50=63.89 (66.06)
373
+ 2025-03-03 21:16:18 | INFO | utils.misc:108 - Training: Epoch=[10/50] [ 300/1407] Batch=1.72 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.4199 (0.4066) IoU=59.00 (59.75) Prec@50=65.48 (65.97)
374
+ 2025-03-03 21:18:31 | INFO | utils.misc:108 - Training: Epoch=[10/50] [ 400/1407] Batch=1.18 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.3139 (0.4110) IoU=64.88 (59.49) Prec@50=76.98 (65.56)
375
+ 2025-03-03 21:20:42 | INFO | utils.misc:108 - Training: Epoch=[10/50] [ 500/1407] Batch=1.24 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.4145 (0.4128) IoU=58.26 (59.39) Prec@50=63.10 (65.38)
376
+ 2025-03-03 21:22:55 | INFO | utils.misc:108 - Training: Epoch=[10/50] [ 600/1407] Batch=1.47 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.4332 (0.4123) IoU=55.08 (59.26) Prec@50=59.82 (65.29)
377
+ 2025-03-03 21:25:06 | INFO | utils.misc:108 - Training: Epoch=[10/50] [ 700/1407] Batch=1.37 (1.32) Data=0.01 (0.03) Lr=0.000100 Loss=0.5174 (0.4145) IoU=49.25 (59.16) Prec@50=51.03 (65.12)
378
+ 2025-03-03 21:27:17 | INFO | utils.misc:108 - Training: Epoch=[10/50] [ 800/1407] Batch=1.13 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.4140 (0.4154) IoU=55.88 (59.11) Prec@50=62.70 (65.09)
379
+ 2025-03-03 21:29:28 | INFO | utils.misc:108 - Training: Epoch=[10/50] [ 900/1407] Batch=1.36 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.4532 (0.4157) IoU=56.31 (59.08) Prec@50=59.76 (65.11)
380
+ 2025-03-03 21:31:37 | INFO | utils.misc:108 - Training: Epoch=[10/50] [1000/1407] Batch=1.22 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4510 (0.4167) IoU=54.24 (59.11) Prec@50=59.07 (65.15)
381
+ 2025-03-03 21:33:46 | INFO | utils.misc:108 - Training: Epoch=[10/50] [1100/1407] Batch=1.25 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4805 (0.4178) IoU=54.15 (59.11) Prec@50=53.57 (65.15)
382
+ 2025-03-03 21:35:55 | INFO | utils.misc:108 - Training: Epoch=[10/50] [1200/1407] Batch=1.17 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4031 (0.4190) IoU=61.14 (59.09) Prec@50=66.90 (65.10)
383
+ 2025-03-03 21:38:05 | INFO | utils.misc:108 - Training: Epoch=[10/50] [1300/1407] Batch=1.38 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4519 (0.4200) IoU=53.42 (59.04) Prec@50=58.10 (65.00)
384
+ 2025-03-03 21:40:16 | INFO | utils.misc:108 - Training: Epoch=[10/50] [1400/1407] Batch=1.15 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4071 (0.4211) IoU=64.93 (59.02) Prec@50=74.13 (64.97)
385
+ 2025-03-03 21:40:58 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[10/50] mIoU=62.92 oIoU=59.93 Pr@50: 70.59 Pr@60: 64.65 Pr@70: 55.67 Pr@80: 43.43 Pr@90: 20.32
386
+ 2025-03-03 21:43:38 | INFO | utils.misc:108 - Training: Epoch=[11/50] [ 100/1407] Batch=1.28 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.4913 (0.3952) IoU=56.25 (61.12) Prec@50=59.94 (67.24)
387
+ 2025-03-03 21:45:49 | INFO | utils.misc:108 - Training: Epoch=[11/50] [ 200/1407] Batch=1.39 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4227 (0.3953) IoU=61.48 (60.92) Prec@50=66.01 (66.88)
388
+ 2025-03-03 21:47:57 | INFO | utils.misc:108 - Training: Epoch=[11/50] [ 300/1407] Batch=1.24 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.2995 (0.3939) IoU=66.41 (61.32) Prec@50=76.88 (67.40)
389
+ 2025-03-03 21:50:08 | INFO | utils.misc:108 - Training: Epoch=[11/50] [ 400/1407] Batch=1.29 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.3530 (0.3954) IoU=63.13 (61.36) Prec@50=66.90 (67.54)
390
+ 2025-03-03 21:52:18 | INFO | utils.misc:108 - Training: Epoch=[11/50] [ 500/1407] Batch=1.28 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.3185 (0.3975) IoU=61.71 (61.22) Prec@50=66.07 (67.42)
391
+ 2025-03-03 21:54:28 | INFO | utils.misc:108 - Training: Epoch=[11/50] [ 600/1407] Batch=1.22 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.3977 (0.3992) IoU=57.77 (60.99) Prec@50=64.72 (67.13)
392
+ 2025-03-03 21:56:38 | INFO | utils.misc:108 - Training: Epoch=[11/50] [ 700/1407] Batch=1.23 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.4287 (0.3990) IoU=58.03 (60.86) Prec@50=61.11 (67.02)
393
+ 2025-03-03 21:58:50 | INFO | utils.misc:108 - Training: Epoch=[11/50] [ 800/1407] Batch=1.24 (1.30) Data=0.00 (0.03) Lr=0.000100 Loss=0.4175 (0.3999) IoU=62.48 (60.75) Prec@50=69.74 (66.85)
394
+ 2025-03-03 22:01:01 | INFO | utils.misc:108 - Training: Epoch=[11/50] [ 900/1407] Batch=1.27 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4177 (0.4006) IoU=56.65 (60.83) Prec@50=60.22 (66.99)
395
+ 2025-03-03 22:03:14 | INFO | utils.misc:108 - Training: Epoch=[11/50] [1000/1407] Batch=1.34 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.3733 (0.4012) IoU=64.66 (60.94) Prec@50=75.50 (67.11)
396
+ 2025-03-03 22:05:25 | INFO | utils.misc:108 - Training: Epoch=[11/50] [1100/1407] Batch=1.16 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.3426 (0.4011) IoU=72.35 (61.04) Prec@50=86.75 (67.21)
397
+ 2025-03-03 22:07:36 | INFO | utils.misc:108 - Training: Epoch=[11/50] [1200/1407] Batch=1.22 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.3453 (0.4008) IoU=65.79 (61.17) Prec@50=76.83 (67.42)
398
+ 2025-03-03 22:09:48 | INFO | utils.misc:108 - Training: Epoch=[11/50] [1300/1407] Batch=1.64 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4473 (0.4011) IoU=59.89 (61.20) Prec@50=73.39 (67.39)
399
+ 2025-03-03 22:12:00 | INFO | utils.misc:108 - Training: Epoch=[11/50] [1400/1407] Batch=1.23 (1.31) Data=0.00 (0.03) Lr=0.000100 Loss=0.4229 (0.4018) IoU=57.42 (61.18) Prec@50=66.96 (67.32)
400
+ 2025-03-03 22:12:46 | INFO | engine.engine_gref:166 - Evaluation: Epoch=[11/50] mIoU=63.56 oIoU=60.39 Pr@50: 71.37 Pr@60: 65.11 Pr@70: 57.58 Pr@80: 44.68 Pr@90: 20.90
401
+ 2025-03-03 22:15:27 | INFO | utils.misc:108 - Training: Epoch=[12/50] [ 100/1407] Batch=1.32 (1.35) Data=0.00 (0.04) Lr=0.000100 Loss=0.3085 (0.3645) IoU=65.95 (62.33) Prec@50=78.01 (69.10)
402
+ 2025-03-03 22:17:39 | INFO | utils.misc:108 - Training: Epoch=[12/50] [ 200/1407] Batch=1.66 (1.33) Data=0.00 (0.03) Lr=0.000100 Loss=0.2717 (0.3668) IoU=68.88 (62.85) Prec@50=78.70 (69.82)
403
+ 2025-03-03 22:19:51 | INFO | utils.misc:108 - Training: Epoch=[12/50] [ 300/1407] Batch=1.14 (1.33) Data=0.00 (0.03) Lr=0.000100 Loss=0.3265 (0.3708) IoU=62.95 (62.78) Prec@50=68.89 (69.53)
404
+ 2025-03-03 22:22:04 | INFO | utils.misc:108 - Training: Epoch=[12/50] [ 400/1407] Batch=1.13 (1.33) Data=0.00 (0.03) Lr=0.000100 Loss=0.3058 (0.3704) IoU=68.49 (62.68) Prec@50=73.65 (69.41)
405
+ 2025-03-03 22:24:17 | INFO | utils.misc:108 - Training: Epoch=[12/50] [ 500/1407] Batch=1.25 (1.33) Data=0.00 (0.03) Lr=0.000100 Loss=0.3924 (0.3704) IoU=69.36 (62.64) Prec@50=85.52 (69.33)
406
+ 2025-03-03 22:26:26 | INFO | utils.misc:108 - Training: Epoch=[12/50] [ 600/1407] Batch=1.16 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.2908 (0.3704) IoU=66.72 (62.64) Prec@50=72.06 (69.23)
407
+ 2025-03-03 22:28:38 | INFO | utils.misc:108 - Training: Epoch=[12/50] [ 700/1407] Batch=1.30 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.4189 (0.3716) IoU=54.04 (62.63) Prec@50=55.06 (69.14)
408
+ 2025-03-03 22:30:49 | INFO | utils.misc:108 - Training: Epoch=[12/50] [ 800/1407] Batch=1.26 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.2783 (0.3722) IoU=71.81 (62.64) Prec@50=80.36 (69.16)
409
+ 2025-03-03 22:32:58 | INFO | utils.misc:108 - Training: Epoch=[12/50] [ 900/1407] Batch=1.23 (1.32) Data=0.00 (0.03) Lr=0.000100 Loss=0.3591 (0.3739) IoU=69.03 (62.56) Prec@50=79.31 (69.07)
410
+ [2025-03-03 22:34:38,633] torch.distributed.elastic.agent.server.api: [WARNING] Received Signals.SIGINT death signal, shutting down workers
411
+ [2025-03-03 22:34:38,634] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2316571 closing signal SIGINT
412
+ [2025-03-03 22:34:38,636] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2316572 closing signal SIGINT
413
+ [2025-03-03 22:34:38,637] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2316573 closing signal SIGINT
414
+ [2025-03-03 22:34:38,637] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2316574 closing signal SIGINT
415
+ [2025-03-03 22:34:38,637] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2316575 closing signal SIGINT
416
+ [2025-03-03 22:34:38,637] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2316576 closing signal SIGINT
417
+ Exception ignored in: [2025-03-03 22:34:38,812] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2316571 closing signal SIGTERM
418
+ <function _MultiProcessingDataLoaderIter.__del__ at 0x7fcbc8afc8b0>
419
+ Traceback (most recent call last):
420
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
421
+ Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fe97c7bf8b0>
422
+ Traceback (most recent call last):
423
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
424
+ Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f7e575df8b0>
425
+ Traceback (most recent call last):
426
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
427
+ Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f0f473318b0>
428
+ Traceback (most recent call last):
429
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
430
+ Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f29013bb8b0>
431
+ Traceback (most recent call last):
432
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
433
+ [2025-03-03 22:34:38,864] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2316572 closing signal SIGTERM
434
+ [2025-03-03 22:34:38,864] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2316573 closing signal SIGTERM
435
+ [2025-03-03 22:34:38,864] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2316574 closing signal SIGTERM
436
+ [2025-03-03 22:34:38,864] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2316575 closing signal SIGTERM
437
+ [2025-03-03 22:34:38,864] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 2316576 closing signal SIGTERM
438
+ Traceback (most recent call last):
439
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/agent/server/api.py", line 736, in run
440
+ result = self._invoke_run(role)
441
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/agent/server/api.py", line 877, in _invoke_run
442
+ time.sleep(monitor_interval)
443
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 62, in _terminate_process_handler
444
+ raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
445
+ torch.distributed.elastic.multiprocessing.api.SignalException: Process 2316558 got signal: 2
446
+
447
+ During handling of the above exception, another exception occurred:
448
+
449
+ Traceback (most recent call last):
450
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/agent/server/api.py", line 743, in run
451
+ self._shutdown(e.sigval)
452
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/agent/server/local_elastic_agent.py", line 289, in _shutdown
453
+ self._pcontext.close(death_sig)
454
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 331, in close
455
+ self._close(death_sig=death_sig, timeout=timeout)
456
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 713, in _close
457
+ handler.proc.wait(time_to_wait)
458
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/subprocess.py", line 1189, in wait
459
+ return self._wait(timeout=timeout)
460
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/subprocess.py", line 1927, in _wait
461
+ time.sleep(delay)
462
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 62, in _terminate_process_handler
463
+ raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
464
+ torch.distributed.elastic.multiprocessing.api.SignalException: Process 2316558 got signal: 2
465
+
466
+ During handling of the above exception, another exception occurred:
467
+
468
+ Traceback (most recent call last):
469
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 255, in launch_agent
470
+ result = agent.run()
471
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/metrics/api.py", line 124, in wrapper
472
+ result = f(*args, **kwargs)
473
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/agent/server/api.py", line 748, in run
474
+ self._shutdown()
475
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/agent/server/local_elastic_agent.py", line 289, in _shutdown
476
+ self._pcontext.close(death_sig)
477
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 331, in close
478
+ self._close(death_sig=death_sig, timeout=timeout)
479
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 713, in _close
480
+ handler.proc.wait(time_to_wait)
481
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/subprocess.py", line 1189, in wait
482
+ return self._wait(timeout=timeout)
483
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/subprocess.py", line 1927, in _wait
484
+ time.sleep(delay)
485
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 62, in _terminate_process_handler
486
+ raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
487
+ torch.distributed.elastic.multiprocessing.api.SignalException: Process 2316558 got signal: 2
488
+
489
+ During handling of the above exception, another exception occurred:
490
+
491
+ Traceback (most recent call last):
492
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/runpy.py", line 197, in _run_module_as_main
493
+ return _run_code(code, main_globals, None,
494
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/runpy.py", line 87, in _run_code
495
+ exec(code, run_globals)
496
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/launch.py", line 196, in <module>
497
+ main()
498
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/launch.py", line 192, in main
499
+ launch(args)
500
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/launch.py", line 177, in launch
501
+ run(args)
502
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/run.py", line 797, in run
503
+ elastic_launch(
504
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
505
+ return launch_agent(self._config, self._entrypoint, list(args))
506
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 277, in launch_agent
507
+ events.record(agent.get_event_failed())
508
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/agent/server/api.py", line 756, in get_event_failed
509
+ raw_error=traceback.format_exc(),
510
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/traceback.py", line 167, in format_exc
511
+ return "".join(format_exception(*sys.exc_info(), limit=limit, chain=chain))
512
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/traceback.py", line 120, in format_exception
513
+ return list(TracebackException(
514
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/traceback.py", line 517, in __init__
515
+ self.stack = StackSummary.extract(
516
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/traceback.py", line 366, in extract
517
+ f.line
518
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/traceback.py", line 288, in line
519
+ self._line = linecache.getline(self.filename, self.lineno).strip()
520
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/linecache.py", line 30, in getline
521
+ lines = getlines(filename, module_globals)
522
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/linecache.py", line 46, in getlines
523
+ return updatecache(filename, module_globals)
524
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/linecache.py", line 93, in updatecache
525
+ stat = os.stat(fullname)
526
+ File "/home/seunghoon/.conda/envs/ris_all/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 62, in _terminate_process_handler
527
+ raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
528
+ torch.distributed.elastic.multiprocessing.api.SignalException: Process 2316558 got signal: 2
CGFormer/bash_logs/sanity_node03.log ADDED
The diff for this file is too large to render. See raw diff
 
CGFormer/bert/__pycache__/activations.cpython-38.pyc ADDED
Binary file (1.95 kB). View file
 
CGFormer/bert/__pycache__/activations.cpython-39.pyc ADDED
Binary file (1.94 kB). View file
 
CGFormer/bert/__pycache__/configuration_bert.cpython-38.pyc ADDED
Binary file (7.88 kB). View file
 
CGFormer/bert/__pycache__/configuration_bert.cpython-39.pyc ADDED
Binary file (7.88 kB). View file
 
CGFormer/bert/__pycache__/configuration_utils.cpython-38.pyc ADDED
Binary file (16.3 kB). View file
 
CGFormer/bert/__pycache__/configuration_utils.cpython-39.pyc ADDED
Binary file (16.3 kB). View file
 
CGFormer/bert/__pycache__/file_utils.cpython-38.pyc ADDED
Binary file (24.5 kB). View file
 
CGFormer/bert/__pycache__/file_utils.cpython-39.pyc ADDED
Binary file (24.7 kB). View file
 
CGFormer/bert/__pycache__/generation_utils.cpython-38.pyc ADDED
Binary file (28.2 kB). View file
 
CGFormer/bert/__pycache__/generation_utils.cpython-39.pyc ADDED
Binary file (28 kB). View file
 
CGFormer/bert/__pycache__/modeling_bert.cpython-38.pyc ADDED
Binary file (55.3 kB). View file
 
CGFormer/bert/__pycache__/modeling_bert.cpython-39.pyc ADDED
Binary file (55.2 kB). View file
 
CGFormer/bert/__pycache__/modeling_utils.cpython-38.pyc ADDED
Binary file (48 kB). View file
 
CGFormer/bert/__pycache__/modeling_utils.cpython-39.pyc ADDED
Binary file (48 kB). View file
 
CGFormer/bert/__pycache__/tokenization_bert.cpython-38.pyc ADDED
Binary file (19.3 kB). View file
 
CGFormer/bert/__pycache__/tokenization_bert.cpython-39.pyc ADDED
Binary file (19.3 kB). View file
 
CGFormer/bert/__pycache__/tokenization_utils.cpython-38.pyc ADDED
Binary file (24.9 kB). View file
 
CGFormer/bert/__pycache__/tokenization_utils.cpython-39.pyc ADDED
Binary file (24.9 kB). View file
 
CGFormer/bert/__pycache__/tokenization_utils_base.cpython-38.pyc ADDED
Binary file (82.4 kB). View file
 
CGFormer/bert/__pycache__/tokenization_utils_base.cpython-39.pyc ADDED
Binary file (82.4 kB). View file
 
CGFormer/bert/activations.py ADDED
@@ -0,0 +1,56 @@
1
+ import logging
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
+ def swish(x):
12
+ return x * torch.sigmoid(x)
13
+
14
+
15
+ def _gelu_python(x):
16
+ """ Original Implementation of the gelu activation function in Google Bert repo when initially created.
17
+ For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
18
+ 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
19
+ This is now written in C in torch.nn.functional
20
+ Also see https://arxiv.org/abs/1606.08415
21
+ """
22
+ return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
23
+
24
+
25
+ def gelu_new(x):
26
+ """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
27
+ Also see https://arxiv.org/abs/1606.08415
28
+ """
29
+ return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
30
+
31
+
32
+ if torch.__version__ < "1.4.0":
33
+ gelu = _gelu_python
34
+ else:
35
+ gelu = F.gelu
36
+
37
+
38
+ def gelu_fast(x):
39
+ return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
40
+
41
+
42
+ ACT2FN = {
43
+ "relu": F.relu,
44
+ "swish": swish,
45
+ "gelu": gelu,
46
+ "tanh": torch.tanh,
47
+ "gelu_new": gelu_new,
48
+ "gelu_fast": gelu_fast,
49
+ }
50
+
51
+
52
+ def get_activation(activation_string):
53
+ if activation_string in ACT2FN:
54
+ return ACT2FN[activation_string]
55
+ else:
56
+ raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, list(ACT2FN.keys())))
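A minimal usage sketch of the ACT2FN mapping and get_activation helper defined above; the import path is an assumption (in this repo the module lives at CGFormer/bert/activations.py), not something shown in the diff itself.

import torch
from bert.activations import get_activation  # assumed import path for the vendored module above

act_fn = get_activation("gelu")   # any key of ACT2FN ("relu", "swish", "gelu", "tanh", "gelu_new", "gelu_fast") works the same way
x = torch.randn(2, 4)
y = act_fn(x)                     # element-wise activation, same shape as the input
assert y.shape == x.shape
# Unknown names raise KeyError and list the supported keys, e.g. get_activation("prelu")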
CGFormer/bert/configuration_bert.py ADDED
@@ -0,0 +1,143 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ BERT model configuration """
17
+
18
+
19
+ import logging
20
+
21
+ from .configuration_utils import PretrainedConfig
22
+
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+ BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
27
+ "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
28
+ "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
29
+ "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
30
+ "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
31
+ "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
32
+ "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
33
+ "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
34
+ "bert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
35
+ "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
36
+ "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
37
+ "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json",
38
+ "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
39
+ "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
40
+ "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json",
41
+ "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json",
42
+ "cl-tohoku/bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese/config.json",
43
+ "cl-tohoku/bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking/config.json",
44
+ "cl-tohoku/bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char/config.json",
45
+ "cl-tohoku/bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking/config.json",
46
+ "TurkuNLP/bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/config.json",
47
+ "TurkuNLP/bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/config.json",
48
+ "wietsedv/bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/config.json",
49
+ # See all BERT models at https://huggingface.co/models?filter=bert
50
+ }
51
+
52
+
53
+ class BertConfig(PretrainedConfig):
54
+ r"""
55
+ This is the configuration class to store the configuration of a :class:`~transformers.BertModel`.
56
+ It is used to instantiate a BERT model according to the specified arguments, defining the model
57
+ architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
58
+ the BERT `bert-base-uncased <https://huggingface.co/bert-base-uncased>`__ architecture.
59
+
60
+ Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used
61
+ to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig`
62
+ for more information.
63
+
64
+
65
+ Args:
66
+ vocab_size (:obj:`int`, optional, defaults to 30522):
67
+ Vocabulary size of the BERT model. Defines the different tokens that
68
+ can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.BertModel`.
69
+ hidden_size (:obj:`int`, optional, defaults to 768):
70
+ Dimensionality of the encoder layers and the pooler layer.
71
+ num_hidden_layers (:obj:`int`, optional, defaults to 12):
72
+ Number of hidden layers in the Transformer encoder.
73
+ num_attention_heads (:obj:`int`, optional, defaults to 12):
74
+ Number of attention heads for each attention layer in the Transformer encoder.
75
+ intermediate_size (:obj:`int`, optional, defaults to 3072):
76
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
77
+ hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "gelu"):
78
+ The non-linear activation function (function or string) in the encoder and pooler.
79
+ If string, "gelu", "relu", "swish" and "gelu_new" are supported.
80
+ hidden_dropout_prob (:obj:`float`, optional, defaults to 0.1):
81
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
82
+ attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1):
83
+ The dropout ratio for the attention probabilities.
84
+ max_position_embeddings (:obj:`int`, optional, defaults to 512):
85
+ The maximum sequence length that this model might ever be used with.
86
+ Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
87
+ type_vocab_size (:obj:`int`, optional, defaults to 2):
88
+ The vocabulary size of the `token_type_ids` passed into :class:`~transformers.BertModel`.
89
+ initializer_range (:obj:`float`, optional, defaults to 0.02):
90
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
91
+ layer_norm_eps (:obj:`float`, optional, defaults to 1e-12):
92
+ The epsilon used by the layer normalization layers.
93
+ gradient_checkpointing (:obj:`bool`, optional, defaults to False):
94
+ If True, use gradient checkpointing to save memory at the expense of slower backward pass.
95
+
96
+ Example::
97
+
98
+ >>> from transformers import BertModel, BertConfig
99
+
100
+ >>> # Initializing a BERT bert-base-uncased style configuration
101
+ >>> configuration = BertConfig()
102
+
103
+ >>> # Initializing a model from the bert-base-uncased style configuration
104
+ >>> model = BertModel(configuration)
105
+
106
+ >>> # Accessing the model configuration
107
+ >>> configuration = model.config
108
+ """
109
+ model_type = "bert"
110
+
111
+ def __init__(
112
+ self,
113
+ vocab_size=30522,
114
+ hidden_size=768,
115
+ num_hidden_layers=12,
116
+ num_attention_heads=12,
117
+ intermediate_size=3072,
118
+ hidden_act="gelu",
119
+ hidden_dropout_prob=0.1,
120
+ attention_probs_dropout_prob=0.1,
121
+ max_position_embeddings=512,
122
+ type_vocab_size=2,
123
+ initializer_range=0.02,
124
+ layer_norm_eps=1e-12,
125
+ pad_token_id=0,
126
+ gradient_checkpointing=False,
127
+ **kwargs
128
+ ):
129
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
130
+
131
+ self.vocab_size = vocab_size
132
+ self.hidden_size = hidden_size
133
+ self.num_hidden_layers = num_hidden_layers
134
+ self.num_attention_heads = num_attention_heads
135
+ self.hidden_act = hidden_act
136
+ self.intermediate_size = intermediate_size
137
+ self.hidden_dropout_prob = hidden_dropout_prob
138
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
139
+ self.max_position_embeddings = max_position_embeddings
140
+ self.type_vocab_size = type_vocab_size
141
+ self.initializer_range = initializer_range
142
+ self.layer_norm_eps = layer_norm_eps
143
+ self.gradient_checkpointing = gradient_checkpointing
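A minimal sketch of constructing the vendored BertConfig above (import path assumed): with no arguments it reproduces the bert-base-uncased style defaults, and a saved config.json can be loaded through PretrainedConfig.from_json_file, defined in configuration_utils.py below.

from bert.configuration_bert import BertConfig  # assumed import path for the vendored module above

config = BertConfig()  # bert-base-uncased style defaults
assert (config.hidden_size, config.num_hidden_layers, config.num_attention_heads) == (768, 12, 12)

# Loading a previously saved configuration (the path is only illustrative):
# config = BertConfig.from_json_file("./my_model_directory/config.json")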
CGFormer/bert/configuration_utils.py ADDED
@@ -0,0 +1,408 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ Configuration base class and utilities."""
17
+
18
+
19
+ import copy
20
+ import json
21
+ import logging
22
+ import os
23
+ from typing import Dict, Tuple
24
+
25
+ from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url
26
+
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ class PretrainedConfig(object):
32
+ r""" Base class for all configuration classes.
33
+ Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations.
34
+
35
+ Note:
36
+ A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to initialize a model does **not** load the model weights.
37
+ It only affects the model's configuration.
38
+
39
+ Class attributes (overridden by derived classes):
40
+ - ``model_type``: a string that identifies the model type, that we serialize into the JSON file, and that we use to recreate the correct object in :class:`~transformers.AutoConfig`.
41
+
42
+ Args:
43
+ finetuning_task (:obj:`string` or :obj:`None`, `optional`, defaults to :obj:`None`):
44
+ Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow or PyTorch) checkpoint.
45
+ num_labels (:obj:`int`, `optional`, defaults to `2`):
46
+ Number of classes to use when the model is a classification model (sequences/tokens)
47
+ output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`False`):
48
+ Whether the model should return all hidden-states.
49
+ output_attentions (:obj:`bool`, `optional`, defaults to :obj:`False`):
50
+ Whether the model should return all attentions.
51
+ torchscript (:obj:`bool`, `optional`, defaults to :obj:`False`):
52
+ Is the model used with Torchscript (for PyTorch models).
53
+ """
54
+ model_type: str = ""
55
+
56
+ def __init__(self, **kwargs):
57
+ # Attributes with defaults
58
+ self.output_hidden_states = kwargs.pop("output_hidden_states", False)
59
+ self.output_attentions = kwargs.pop("output_attentions", False)
60
+ self.use_cache = kwargs.pop("use_cache", True) # Not used by all models
61
+ self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models
62
+ self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
63
+ self.pruned_heads = kwargs.pop("pruned_heads", {})
64
+
65
+ # `is_decoder` is used in encoder-decoder models to differentiate the encoder from the decoder
66
+ self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
67
+ self.is_decoder = kwargs.pop("is_decoder", False)
68
+
69
+ # Parameters for sequence generation
70
+ self.max_length = kwargs.pop("max_length", 20)
71
+ self.min_length = kwargs.pop("min_length", 0)
72
+ self.do_sample = kwargs.pop("do_sample", False)
73
+ self.early_stopping = kwargs.pop("early_stopping", False)
74
+ self.num_beams = kwargs.pop("num_beams", 1)
75
+ self.temperature = kwargs.pop("temperature", 1.0)
76
+ self.top_k = kwargs.pop("top_k", 50)
77
+ self.top_p = kwargs.pop("top_p", 1.0)
78
+ self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
79
+ self.length_penalty = kwargs.pop("length_penalty", 1.0)
80
+ self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
81
+ self.bad_words_ids = kwargs.pop("bad_words_ids", None)
82
+ self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
83
+
84
+ # Fine-tuning task arguments
85
+ self.architectures = kwargs.pop("architectures", None)
86
+ self.finetuning_task = kwargs.pop("finetuning_task", None)
87
+ self.id2label = kwargs.pop("id2label", None)
88
+ self.label2id = kwargs.pop("label2id", None)
89
+ if self.id2label is not None:
90
+ kwargs.pop("num_labels", None)
91
+ self.id2label = dict((int(key), value) for key, value in self.id2label.items())
92
+ # Keys are always strings in JSON so convert ids to int here.
93
+ else:
94
+ self.num_labels = kwargs.pop("num_labels", 2)
95
+
96
+ # Tokenizer arguments TODO: eventually tokenizer and models should share the same config
97
+ self.prefix = kwargs.pop("prefix", None)
98
+ self.bos_token_id = kwargs.pop("bos_token_id", None)
99
+ self.pad_token_id = kwargs.pop("pad_token_id", None)
100
+ self.eos_token_id = kwargs.pop("eos_token_id", None)
101
+ self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
102
+
103
+ # task specific arguments
104
+ self.task_specific_params = kwargs.pop("task_specific_params", None)
105
+
106
+ # TPU arguments
107
+ self.xla_device = kwargs.pop("xla_device", None)
108
+
109
+ # Additional attributes without default values
110
+ for key, value in kwargs.items():
111
+ try:
112
+ setattr(self, key, value)
113
+ except AttributeError as err:
114
+ logger.error("Can't set {} with value {} for {}".format(key, value, self))
115
+ raise err
116
+
117
+ @property
118
+ def num_labels(self):
119
+ return len(self.id2label)
120
+
121
+ @num_labels.setter
122
+ def num_labels(self, num_labels):
123
+ self.id2label = {i: "LABEL_{}".format(i) for i in range(num_labels)}
124
+ self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
125
+
126
+ def save_pretrained(self, save_directory):
127
+ """
128
+ Save a configuration object to the directory `save_directory`, so that it
129
+ can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method.
130
+
131
+ Args:
132
+ save_directory (:obj:`string`):
133
+ Directory where the configuration JSON file will be saved.
134
+ """
135
+ if os.path.isfile(save_directory):
136
+ raise AssertionError("Provided path ({}) should be a directory, not a file".format(save_directory))
137
+ os.makedirs(save_directory, exist_ok=True)
138
+ # If we save using the predefined names, we can load using `from_pretrained`
139
+ output_config_file = os.path.join(save_directory, CONFIG_NAME)
140
+
141
+ self.to_json_file(output_config_file, use_diff=True)
142
+ logger.info("Configuration saved in {}".format(output_config_file))
143
+
144
+ @classmethod
145
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs) -> "PretrainedConfig":
146
+ r"""
147
+
148
+ Instantiate a :class:`~transformers.PretrainedConfig` (or a derived class) from a pre-trained model configuration.
149
+
150
+ Args:
151
+ pretrained_model_name_or_path (:obj:`string`):
152
+ either:
153
+ - a string with the `shortcut name` of a pre-trained model configuration to load from cache or
154
+ download, e.g.: ``bert-base-uncased``.
155
+ - a string with the `identifier name` of a pre-trained model configuration that was user-uploaded to
156
+ our S3, e.g.: ``dbmdz/bert-base-german-cased``.
157
+ - a path to a `directory` containing a configuration file saved using the
158
+ :func:`~transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``.
159
+ - a path or url to a saved configuration JSON `file`, e.g.:
160
+ ``./my_model_directory/configuration.json``.
161
+ cache_dir (:obj:`string`, `optional`):
162
+ Path to a directory in which a downloaded pre-trained model
163
+ configuration should be cached if the standard cache should not be used.
164
+ kwargs (:obj:`Dict[str, any]`, `optional`):
165
+ The values in kwargs of any keys which are configuration attributes will be used to override the loaded
166
+ values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is
167
+ controlled by the `return_unused_kwargs` keyword parameter.
168
+ force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
169
+ Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
170
+ resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
171
+ Do not delete an incompletely received file. Attempt to resume the download if such a file exists.
172
+ proxies (:obj:`Dict`, `optional`):
173
+ A dictionary of proxy servers to use by protocol or endpoint, e.g.:
174
+ :obj:`{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.`
175
+ The proxies are used on each request.
176
+ return_unused_kwargs: (`optional`) bool:
177
+ If False, then this function returns just the final configuration object.
178
+ If True, then this functions returns a :obj:`Tuple(config, unused_kwargs)` where `unused_kwargs` is a
179
+ dictionary consisting of the key/value pairs whose keys are not configuration attributes: ie the part
180
+ of kwargs which has not been used to update `config` and is otherwise ignored.
181
+
182
+ Returns:
183
+ :class:`PretrainedConfig`: An instance of a configuration object
184
+
185
+ Examples::
186
+
187
+ # We can't instantiate directly the base class `PretrainedConfig` so let's show the examples on a
188
+ # derived class: BertConfig
189
+ config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache.
190
+ config = BertConfig.from_pretrained('./test/saved_model/') # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')`
191
+ config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json')
192
+ config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, foo=False)
193
+ assert config.output_attention == True
194
+ config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attention=True,
195
+ foo=False, return_unused_kwargs=True)
196
+ assert config.output_attention == True
197
+ assert unused_kwargs == {'foo': False}
198
+
199
+ """
200
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
201
+ return cls.from_dict(config_dict, **kwargs)
202
+
203
+ @classmethod
204
+ def get_config_dict(cls, pretrained_model_name_or_path: str, **kwargs) -> Tuple[Dict, Dict]:
205
+ """
206
+ From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used
207
+ for instantiating a Config using `from_dict`.
208
+
209
+ Parameters:
210
+ pretrained_model_name_or_path (:obj:`string`):
211
+ The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
212
+
213
+ Returns:
214
+ :obj:`Tuple[Dict, Dict]`: The dictionary that will be used to instantiate the configuration object.
215
+
216
+ """
217
+ cache_dir = kwargs.pop("cache_dir", None)
218
+ force_download = kwargs.pop("force_download", False)
219
+ resume_download = kwargs.pop("resume_download", False)
220
+ proxies = kwargs.pop("proxies", None)
221
+ local_files_only = kwargs.pop("local_files_only", False)
222
+
223
+ if os.path.isdir(pretrained_model_name_or_path):
224
+ config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
225
+ elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
226
+ config_file = pretrained_model_name_or_path
227
+ else:
228
+ config_file = hf_bucket_url(pretrained_model_name_or_path, filename=CONFIG_NAME, use_cdn=False)
229
+
230
+ try:
231
+ # Load from URL or cache if already cached
232
+ resolved_config_file = cached_path(
233
+ config_file,
234
+ cache_dir=cache_dir,
235
+ force_download=force_download,
236
+ proxies=proxies,
237
+ resume_download=resume_download,
238
+ local_files_only=local_files_only,
239
+ )
240
+ # Load config dict
241
+ if resolved_config_file is None:
242
+ raise EnvironmentError
243
+ config_dict = cls._dict_from_json_file(resolved_config_file)
244
+
245
+ except EnvironmentError:
246
+ msg = (
247
+ f"Can't load config for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
248
+ f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
249
+ f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a {CONFIG_NAME} file\n\n"
250
+ )
251
+ raise EnvironmentError(msg)
252
+
253
+ except json.JSONDecodeError:
254
+ msg = (
255
+ "Couldn't reach server at '{}' to download configuration file or "
256
+ "configuration file is not a valid JSON file. "
257
+ "Please check network or file content here: {}.".format(config_file, resolved_config_file)
258
+ )
259
+ raise EnvironmentError(msg)
260
+
261
+ if resolved_config_file == config_file:
262
+ logger.info("loading configuration file {}".format(config_file))
263
+ else:
264
+ logger.info("loading configuration file {} from cache at {}".format(config_file, resolved_config_file))
265
+
266
+ return config_dict, kwargs
267
+
268
+ @classmethod
269
+ def from_dict(cls, config_dict: Dict, **kwargs) -> "PretrainedConfig":
270
+ """
271
+ Constructs a `Config` from a Python dictionary of parameters.
272
+
273
+ Args:
274
+ config_dict (:obj:`Dict[str, any]`):
275
+ Dictionary that will be used to instantiate the configuration object. Such a dictionary can be retrieved
276
+ from a pre-trained checkpoint by leveraging the :func:`~transformers.PretrainedConfig.get_config_dict`
277
+ method.
278
+ kwargs (:obj:`Dict[str, any]`):
279
+ Additional parameters from which to initialize the configuration object.
280
+
281
+ Returns:
282
+ :class:`PretrainedConfig`: An instance of a configuration object
283
+ """
284
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
285
+
286
+ config = cls(**config_dict)
287
+
288
+ if hasattr(config, "pruned_heads"):
289
+ config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())
290
+
291
+ # Update config with kwargs if needed
292
+ to_remove = []
293
+ for key, value in kwargs.items():
294
+ if hasattr(config, key):
295
+ setattr(config, key, value)
296
+ to_remove.append(key)
297
+ for key in to_remove:
298
+ kwargs.pop(key, None)
299
+
300
+ logger.info("Model config %s", str(config))
301
+ if return_unused_kwargs:
302
+ return config, kwargs
303
+ else:
304
+ return config
305
+
306
+ @classmethod
307
+ def from_json_file(cls, json_file: str) -> "PretrainedConfig":
308
+ """
309
+ Constructs a `Config` from the path to a json file of parameters.
310
+
311
+ Args:
312
+ json_file (:obj:`string`):
313
+ Path to the JSON file containing the parameters.
314
+
315
+ Returns:
316
+ :class:`PretrainedConfig`: An instance of a configuration object
317
+
318
+ """
319
+ config_dict = cls._dict_from_json_file(json_file)
320
+ return cls(**config_dict)
321
+
322
+ @classmethod
323
+ def _dict_from_json_file(cls, json_file: str):
324
+ with open(json_file, "r", encoding="utf-8") as reader:
325
+ text = reader.read()
326
+ return json.loads(text)
327
+
328
+ def __eq__(self, other):
329
+ return self.__dict__ == other.__dict__
330
+
331
+ def __repr__(self):
332
+ return "{} {}".format(self.__class__.__name__, self.to_json_string())
333
+
334
+ def to_diff_dict(self):
335
+ """
336
+ Removes all attributes from config which correspond to the default
337
+ config attributes for better readability and serializes to a Python
338
+ dictionary.
339
+
340
+ Returns:
341
+ :obj:`Dict[str, any]`: Dictionary of the attributes that differ from the default configuration.
342
+ """
343
+ config_dict = self.to_dict()
344
+
345
+ # get the default config dict
346
+ default_config_dict = PretrainedConfig().to_dict()
347
+
348
+ serializable_config_dict = {}
349
+
350
+ # only serialize values that differ from the default config
351
+ for key, value in config_dict.items():
352
+ if key not in default_config_dict or value != default_config_dict[key]:
353
+ serializable_config_dict[key] = value
354
+
355
+ return serializable_config_dict
356
+
357
+ def to_dict(self):
358
+ """
359
+ Serializes this instance to a Python dictionary.
360
+
361
+ Returns:
362
+ :obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
363
+ """
364
+ output = copy.deepcopy(self.__dict__)
365
+ if hasattr(self.__class__, "model_type"):
366
+ output["model_type"] = self.__class__.model_type
367
+ return output
368
+
369
+ def to_json_string(self, use_diff=True):
370
+ """
371
+ Serializes this instance to a JSON string.
372
+
373
+ Args:
374
+ use_diff (:obj:`bool`):
375
+ If set to True, only the difference between the config instance and the default PretrainedConfig() is serialized to JSON string.
376
+
377
+ Returns:
378
+ :obj:`string`: String containing all the attributes that make up this configuration instance in JSON format.
379
+ """
380
+ if use_diff is True:
381
+ config_dict = self.to_diff_dict()
382
+ else:
383
+ config_dict = self.to_dict()
384
+ return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
385
+
386
+ def to_json_file(self, json_file_path, use_diff=True):
387
+ """
388
+ Save this instance to a json file.
389
+
390
+ Args:
391
+ json_file_path (:obj:`string`):
392
+ Path to the JSON file in which this configuration instance's parameters will be saved.
393
+ use_diff (:obj:`bool`):
394
+ If set to True, only the difference between the config instance and the default PretrainedConfig() is serialized to JSON file.
395
+ """
396
+ with open(json_file_path, "w", encoding="utf-8") as writer:
397
+ writer.write(self.to_json_string(use_diff=use_diff))
398
+
399
+ def update(self, config_dict: Dict):
400
+ """
401
+ Updates attributes of this class
402
+ with attributes from `config_dict`.
403
+
404
+ Args:
405
+ config_dict (:obj:`Dict[str, any]`): Dictionary of attributes that shall be used to update this class.
406
+ """
407
+ for key, value in config_dict.items():
408
+ setattr(self, key, value)
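A minimal usage sketch for the configuration class above, assuming the vendored `bert` package is importable from the repository root and using `BertConfig` (from the accompanying `configuration_bert.py`) as the concrete subclass; the local directory path is hypothetical and only for illustration:

    from bert.configuration_bert import BertConfig

    # Load a config.json from a local directory (hypothetical path) and override
    # one known attribute; unknown kwargs are returned separately.
    config, unused = BertConfig.from_pretrained(
        './ckpts/bert-base-uncased/',   # hypothetical directory containing config.json
        hidden_dropout_prob=0.2,        # existing attribute, applied to the config
        foo='bar',                      # not a config attribute, left in `unused`
        return_unused_kwargs=True,
    )
    assert config.hidden_dropout_prob == 0.2
    assert unused == {'foo': 'bar'}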
CGFormer/bert/file_utils.py ADDED
@@ -0,0 +1,808 @@
1
+ """
2
+ Utilities for working with the local dataset cache.
3
+ This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
4
+ Copyright by the AllenNLP authors.
5
+ """
6
+
7
+ import fnmatch
8
+ import json
9
+ import logging
10
+ import os
11
+ import shutil
12
+ import sys
13
+ import tarfile
14
+ import tempfile
15
+ from contextlib import contextmanager
16
+ from functools import partial, wraps
17
+ from hashlib import sha256
18
+ from pathlib import Path
19
+ from typing import Dict, Optional, Union
20
+ from urllib.parse import urlparse
21
+ from zipfile import ZipFile, is_zipfile
22
+
23
+ import requests
24
+ from filelock import FileLock
25
+ from tqdm.auto import tqdm
26
+
27
+ #from . import __version__
28
+ __version__ = "3.0.2"
29
+
30
+ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
31
+
32
+ try:
33
+ USE_TF = os.environ.get("USE_TF", "AUTO").upper()
34
+ USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
35
+ if USE_TORCH in ("1", "ON", "YES", "AUTO") and USE_TF not in ("1", "ON", "YES"):
36
+ import torch
37
+
38
+ _torch_available = True # pylint: disable=invalid-name
39
+ logger.info("PyTorch version {} available.".format(torch.__version__))
40
+ else:
41
+ logger.info("Disabling PyTorch because USE_TF is set")
42
+ _torch_available = False
43
+ except ImportError:
44
+ _torch_available = False # pylint: disable=invalid-name
45
+
46
+ try:
47
+ USE_TF = os.environ.get("USE_TF", "AUTO").upper()
48
+ USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
49
+
50
+ if USE_TF in ("1", "ON", "YES", "AUTO") and USE_TORCH not in ("1", "ON", "YES"):
51
+ import tensorflow as tf
52
+
53
+ assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
54
+ _tf_available = True # pylint: disable=invalid-name
55
+ logger.info("TensorFlow version {} available.".format(tf.__version__))
56
+ else:
57
+ logger.info("Disabling Tensorflow because USE_TORCH is set")
58
+ _tf_available = False
59
+ except (ImportError, AssertionError):
60
+ _tf_available = False # pylint: disable=invalid-name
61
+
62
+
63
+ try:
64
+ from torch.hub import _get_torch_home
65
+
66
+ torch_cache_home = _get_torch_home()
67
+ except ImportError:
68
+ torch_cache_home = os.path.expanduser(
69
+ os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
70
+ )
71
+
72
+
73
+ try:
74
+ import torch_xla.core.xla_model as xm # noqa: F401
75
+
76
+ if _torch_available:
77
+ _torch_tpu_available = True # pylint: disable=invalid-name
78
+ else:
79
+ _torch_tpu_available = False
80
+ except ImportError:
81
+ _torch_tpu_available = False
82
+
83
+
84
+ try:
85
+ import psutil # noqa: F401
86
+
87
+ _psutil_available = True
88
+
89
+ except ImportError:
90
+ _psutil_available = False
91
+
92
+
93
+ try:
94
+ import py3nvml # noqa: F401
95
+
96
+ _py3nvml_available = True
97
+
98
+ except ImportError:
99
+ _py3nvml_available = False
100
+
101
+
102
+ try:
103
+ from apex import amp # noqa: F401
104
+
105
+ _has_apex = True
106
+ except ImportError:
107
+ _has_apex = False
108
+
109
+ default_cache_path = os.path.join(torch_cache_home, "transformers")
110
+
111
+
112
+ PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
113
+ PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
114
+ TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
115
+
116
+ WEIGHTS_NAME = "pytorch_model.bin"
117
+ TF2_WEIGHTS_NAME = "tf_model.h5"
118
+ TF_WEIGHTS_NAME = "model.ckpt"
119
+ CONFIG_NAME = "config.json"
120
+ MODEL_CARD_NAME = "modelcard.json"
121
+
122
+
123
+ MULTIPLE_CHOICE_DUMMY_INPUTS = [[[0], [1]], [[0], [1]]]
124
+ DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
125
+ DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]
126
+
127
+ S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
128
+ CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co"
129
+
130
+
131
+ def is_torch_available():
132
+ return _torch_available
133
+
134
+
135
+ def is_tf_available():
136
+ return _tf_available
137
+
138
+
139
+ def is_torch_tpu_available():
140
+ return _torch_tpu_available
141
+
142
+
143
+ def is_psutil_available():
144
+ return _psutil_available
145
+
146
+
147
+ def is_py3nvml_available():
148
+ return _py3nvml_available
149
+
150
+
151
+ def is_apex_available():
152
+ return _has_apex
153
+
154
+
155
+ def add_start_docstrings(*docstr):
156
+ def docstring_decorator(fn):
157
+ fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
158
+ return fn
159
+
160
+ return docstring_decorator
161
+
162
+
163
+ def add_start_docstrings_to_callable(*docstr):
164
+ def docstring_decorator(fn):
165
+ class_name = ":class:`~transformers.{}`".format(fn.__qualname__.split(".")[0])
166
+ intro = " The {} forward method, overrides the :func:`__call__` special method.".format(class_name)
167
+ note = r"""
168
+
169
+ .. note::
170
+ Although the recipe for forward pass needs to be defined within
171
+ this function, one should call the :class:`Module` instance afterwards
172
+ instead of this since the former takes care of running the
173
+ pre and post processing steps while the latter silently ignores them.
174
+ """
175
+ fn.__doc__ = intro + note + "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
176
+ return fn
177
+
178
+ return docstring_decorator
179
+
180
+
181
+ def add_end_docstrings(*docstr):
182
+ def docstring_decorator(fn):
183
+ fn.__doc__ = fn.__doc__ + "".join(docstr)
184
+ return fn
185
+
186
+ return docstring_decorator
187
+
188
+
189
+ PT_TOKEN_CLASSIFICATION_SAMPLE = r"""
190
+ Example::
191
+
192
+ >>> from transformers import {tokenizer_class}, {model_class}
193
+ >>> import torch
194
+
195
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
196
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
197
+
198
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
199
+ >>> labels = torch.tensor([1] * inputs["input_ids"].size(1)).unsqueeze(0) # Batch size 1
200
+
201
+ >>> outputs = model(**inputs, labels=labels)
202
+ >>> loss, scores = outputs[:2]
203
+ """
204
+
205
+ PT_QUESTION_ANSWERING_SAMPLE = r"""
206
+ Example::
207
+
208
+ >>> from transformers import {tokenizer_class}, {model_class}
209
+ >>> import torch
210
+
211
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
212
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
213
+
214
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
215
+ >>> start_positions = torch.tensor([1])
216
+ >>> end_positions = torch.tensor([3])
217
+
218
+ >>> outputs = model(**inputs, start_positions=start_positions, end_positions=end_positions)
219
+ >>> loss, start_scores, end_scores = outputs[:3]
220
+ """
221
+
222
+ PT_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
223
+ Example::
224
+
225
+ >>> from transformers import {tokenizer_class}, {model_class}
226
+ >>> import torch
227
+
228
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
229
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
230
+
231
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
232
+ >>> labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
233
+ >>> outputs = model(**inputs, labels=labels)
234
+ >>> loss, logits = outputs[:2]
235
+ """
236
+
237
+ PT_MASKED_LM_SAMPLE = r"""
238
+ Example::
239
+
240
+ >>> from transformers import {tokenizer_class}, {model_class}
241
+ >>> import torch
242
+
243
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
244
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
245
+
246
+ >>> input_ids = tokenizer("Hello, my dog is cute", return_tensors="pt")["input_ids"]
247
+
248
+ >>> outputs = model(input_ids, labels=input_ids)
249
+ >>> loss, prediction_scores = outputs[:2]
250
+ """
251
+
252
+ PT_BASE_MODEL_SAMPLE = r"""
253
+ Example::
254
+
255
+ >>> from transformers import {tokenizer_class}, {model_class}
256
+ >>> import torch
257
+
258
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
259
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
260
+
261
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
262
+ >>> outputs = model(**inputs)
263
+
264
+ >>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
265
+ """
266
+
267
+ PT_MULTIPLE_CHOICE_SAMPLE = r"""
268
+ Example::
269
+
270
+ >>> from transformers import {tokenizer_class}, {model_class}
271
+ >>> import torch
272
+
273
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
274
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
275
+
276
+ >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
277
+ >>> choice0 = "It is eaten with a fork and a knife."
278
+ >>> choice1 = "It is eaten while held in the hand."
279
+ >>> labels = torch.tensor(0).unsqueeze(0) # choice0 is correct (according to Wikipedia ;)), batch size 1
280
+
281
+ >>> encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors='pt', padding=True)
282
+ >>> outputs = model(**{{k: v.unsqueeze(0) for k,v in encoding.items()}}, labels=labels) # batch size is 1
283
+
284
+ >>> # the linear classifier still needs to be trained
285
+ >>> loss, logits = outputs[:2]
286
+ """
287
+
288
+ PT_CAUSAL_LM_SAMPLE = r"""
289
+ Example::
290
+
291
+ >>> import torch
292
+ >>> from transformers import {tokenizer_class}, {model_class}
293
+
294
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
295
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
296
+
297
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
298
+ >>> outputs = model(**inputs, labels=inputs["input_ids"])
299
+ >>> loss, logits = outputs[:2]
300
+ """
301
+
302
+ TF_TOKEN_CLASSIFICATION_SAMPLE = r"""
303
+ Example::
304
+
305
+ >>> from transformers import {tokenizer_class}, {model_class}
306
+ >>> import tensorflow as tf
307
+
308
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
309
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
310
+
311
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
312
+ >>> input_ids = inputs["input_ids"]
313
+ >>> inputs["labels"] = tf.reshape(tf.constant([1] * tf.size(input_ids).numpy()), (-1, tf.size(input_ids))) # Batch size 1
314
+
315
+ >>> outputs = model(inputs)
316
+ >>> loss, scores = outputs[:2]
317
+ """
318
+
319
+ TF_QUESTION_ANSWERING_SAMPLE = r"""
320
+ Example::
321
+
322
+ >>> from transformers import {tokenizer_class}, {model_class}
323
+ >>> import tensorflow as tf
324
+
325
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
326
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
327
+
328
+ >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
329
+ >>> input_dict = tokenizer(question, text, return_tensors='tf')
330
+ >>> start_scores, end_scores = model(input_dict)
331
+
332
+ >>> all_tokens = tokenizer.convert_ids_to_tokens(input_dict["input_ids"].numpy()[0])
333
+ >>> answer = ' '.join(all_tokens[tf.math.argmax(start_scores, 1)[0] : tf.math.argmax(end_scores, 1)[0]+1])
334
+ """
335
+
336
+ TF_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
337
+ Example::
338
+
339
+ >>> from transformers import {tokenizer_class}, {model_class}
340
+ >>> import tensorflow as tf
341
+
342
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
343
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
344
+
345
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
346
+ >>> inputs["labels"] = tf.reshape(tf.constant(1), (-1, 1)) # Batch size 1
347
+
348
+ >>> outputs = model(inputs)
349
+ >>> loss, logits = outputs[:2]
350
+ """
351
+
352
+ TF_MASKED_LM_SAMPLE = r"""
353
+ Example::
354
+ >>> from transformers import {tokenizer_class}, {model_class}
355
+ >>> import tensorflow as tf
356
+
357
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
358
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
359
+
360
+ >>> input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] # Batch size 1
361
+
362
+ >>> outputs = model(input_ids)
363
+ >>> prediction_scores = outputs[0]
364
+ """
365
+
366
+ TF_BASE_MODEL_SAMPLE = r"""
367
+ Example::
368
+
369
+ >>> from transformers import {tokenizer_class}, {model_class}
370
+ >>> import tensorflow as tf
371
+
372
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
373
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
374
+
375
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
376
+ >>> outputs = model(inputs)
377
+
378
+ >>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
379
+ """
380
+
381
+ TF_MULTIPLE_CHOICE_SAMPLE = r"""
382
+ Example::
383
+
384
+ >>> from transformers import {tokenizer_class}, {model_class}
385
+ >>> import tensorflow as tf
386
+
387
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
388
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
389
+
390
+ >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
391
+ >>> choice0 = "It is eaten with a fork and a knife."
392
+ >>> choice1 = "It is eaten while held in the hand."
393
+
394
+ >>> encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors='tf', padding=True)
395
+ >>> inputs = {{k: tf.expand_dims(v, 0) for k, v in encoding.items()}}
396
+ >>> outputs = model(inputs) # batch size is 1
397
+
398
+ >>> # the linear classifier still needs to be trained
399
+ >>> logits = outputs[0]
400
+ """
401
+
402
+ TF_CAUSAL_LM_SAMPLE = r"""
403
+ Example::
404
+
405
+ >>> from transformers import {tokenizer_class}, {model_class}
406
+ >>> import tensorflow as tf
407
+
408
+ >>> tokenizer = {tokenizer_class}.from_pretrained('{checkpoint}')
409
+ >>> model = {model_class}.from_pretrained('{checkpoint}')
410
+
411
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
412
+ >>> outputs = model(inputs)
413
+ >>> logits = outputs[0]
414
+ """
415
+
416
+
417
+ def add_code_sample_docstrings(*docstr, tokenizer_class=None, checkpoint=None):
418
+ def docstring_decorator(fn):
419
+ model_class = fn.__qualname__.split(".")[0]
420
+ is_tf_class = model_class[:2] == "TF"
421
+
422
+ if "SequenceClassification" in model_class:
423
+ code_sample = TF_SEQUENCE_CLASSIFICATION_SAMPLE if is_tf_class else PT_SEQUENCE_CLASSIFICATION_SAMPLE
424
+ elif "QuestionAnswering" in model_class:
425
+ code_sample = TF_QUESTION_ANSWERING_SAMPLE if is_tf_class else PT_QUESTION_ANSWERING_SAMPLE
426
+ elif "TokenClassification" in model_class:
427
+ code_sample = TF_TOKEN_CLASSIFICATION_SAMPLE if is_tf_class else PT_TOKEN_CLASSIFICATION_SAMPLE
428
+ elif "MultipleChoice" in model_class:
429
+ code_sample = TF_MULTIPLE_CHOICE_SAMPLE if is_tf_class else PT_MULTIPLE_CHOICE_SAMPLE
430
+ elif "MaskedLM" in model_class:
431
+ code_sample = TF_MASKED_LM_SAMPLE if is_tf_class else PT_MASKED_LM_SAMPLE
432
+ elif "LMHead" in model_class:
433
+ code_sample = TF_CAUSAL_LM_SAMPLE if is_tf_class else PT_CAUSAL_LM_SAMPLE
434
+ elif "Model" in model_class:
435
+ code_sample = TF_BASE_MODEL_SAMPLE if is_tf_class else PT_BASE_MODEL_SAMPLE
436
+ else:
437
+ raise ValueError(f"Docstring can't be built for model {model_class}")
438
+
439
+ built_doc = code_sample.format(model_class=model_class, tokenizer_class=tokenizer_class, checkpoint=checkpoint)
440
+ fn.__doc__ = (fn.__doc__ or "") + "".join(docstr) + built_doc
441
+ return fn
442
+
443
+ return docstring_decorator
444
+
445
+
446
+ def is_remote_url(url_or_filename):
447
+ parsed = urlparse(url_or_filename)
448
+ return parsed.scheme in ("http", "https")
449
+
450
+
451
+ def hf_bucket_url(model_id: str, filename: str, use_cdn=True) -> str:
452
+ """
453
+ Resolve a model identifier, and a file name, to a HF-hosted url
454
+ on either S3 or Cloudfront (a Content Delivery Network, or CDN).
455
+
456
+ Cloudfront is replicated over the globe so downloads are way faster
457
+ for the end user (and it also lowers our bandwidth costs). However, it
458
+ is more aggressively cached by default, so may not always reflect the
459
+ latest changes to the underlying file (default TTL is 24 hours).
460
+
461
+ In terms of client-side caching from this library, even though
462
+ Cloudfront relays the ETags from S3, using one or the other
463
+ (or switching from one to the other) will affect caching: cached files
464
+ are not shared between the two because the cached file's name contains
465
+ a hash of the url.
466
+ """
467
+ endpoint = CLOUDFRONT_DISTRIB_PREFIX if use_cdn else S3_BUCKET_PREFIX
468
+ legacy_format = "/" not in model_id
469
+ if legacy_format:
470
+ return f"{endpoint}/{model_id}-{filename}"
471
+ else:
472
+ return f"{endpoint}/{model_id}/{filename}"
473
+
474
+
475
+ def url_to_filename(url, etag=None):
476
+ """
477
+ Convert `url` into a hashed filename in a repeatable way.
478
+ If `etag` is specified, append its hash to the url's, delimited
479
+ by a period.
480
+ If the url ends with .h5 (Keras HDF5 weights), '.h5' is appended to the name
481
+ so that TF 2.0 can identify it as a HDF5 file
482
+ (see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
483
+ """
484
+ url_bytes = url.encode("utf-8")
485
+ url_hash = sha256(url_bytes)
486
+ filename = url_hash.hexdigest()
487
+
488
+ if etag:
489
+ etag_bytes = etag.encode("utf-8")
490
+ etag_hash = sha256(etag_bytes)
491
+ filename += "." + etag_hash.hexdigest()
492
+
493
+ if url.endswith(".h5"):
494
+ filename += ".h5"
495
+
496
+ return filename
497
+
498
+
499
+ def filename_to_url(filename, cache_dir=None):
500
+ """
501
+ Return the url and etag (which may be ``None``) stored for `filename`.
502
+ Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
503
+ """
504
+ if cache_dir is None:
505
+ cache_dir = TRANSFORMERS_CACHE
506
+ if isinstance(cache_dir, Path):
507
+ cache_dir = str(cache_dir)
508
+
509
+ cache_path = os.path.join(cache_dir, filename)
510
+ if not os.path.exists(cache_path):
511
+ raise EnvironmentError("file {} not found".format(cache_path))
512
+
513
+ meta_path = cache_path + ".json"
514
+ if not os.path.exists(meta_path):
515
+ raise EnvironmentError("file {} not found".format(meta_path))
516
+
517
+ with open(meta_path, encoding="utf-8") as meta_file:
518
+ metadata = json.load(meta_file)
519
+ url = metadata["url"]
520
+ etag = metadata["etag"]
521
+
522
+ return url, etag
523
+
524
+
525
+ def cached_path(
526
+ url_or_filename,
527
+ cache_dir=None,
528
+ force_download=False,
529
+ proxies=None,
530
+ resume_download=False,
531
+ user_agent: Union[Dict, str, None] = None,
532
+ extract_compressed_file=False,
533
+ force_extract=False,
534
+ local_files_only=False,
535
+ ) -> Optional[str]:
536
+ """
537
+ Given something that might be a URL (or might be a local path),
538
+ determine which. If it's a URL, download the file and cache it, and
539
+ return the path to the cached file. If it's already a local path,
540
+ make sure the file exists and then return the path.
541
+ Args:
542
+ cache_dir: specify a cache directory to save the file to (overwrite the default cache dir).
543
+ force_download: if True, re-download the file even if it's already cached in the cache dir.
544
+ resume_download: if True, resume the download if an incompletely received file is found.
545
+ user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
546
+ extract_compressed_file: if True and the path points to a zip or tar file, extract the compressed
547
+ file in a folder along the archive.
548
+ force_extract: if True when extract_compressed_file is True and the archive was already extracted,
549
+ re-extract the archive and override the folder where it was extracted.
550
+
551
+ Return:
552
+ None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
553
+ Local path (string) otherwise
554
+ """
555
+ if cache_dir is None:
556
+ cache_dir = TRANSFORMERS_CACHE
557
+ if isinstance(url_or_filename, Path):
558
+ url_or_filename = str(url_or_filename)
559
+ if isinstance(cache_dir, Path):
560
+ cache_dir = str(cache_dir)
561
+
562
+ if is_remote_url(url_or_filename):
563
+ # URL, so get it from the cache (downloading if necessary)
564
+ output_path = get_from_cache(
565
+ url_or_filename,
566
+ cache_dir=cache_dir,
567
+ force_download=force_download,
568
+ proxies=proxies,
569
+ resume_download=resume_download,
570
+ user_agent=user_agent,
571
+ local_files_only=local_files_only,
572
+ )
573
+ elif os.path.exists(url_or_filename):
574
+ # File, and it exists.
575
+ output_path = url_or_filename
576
+ elif urlparse(url_or_filename).scheme == "":
577
+ # File, but it doesn't exist.
578
+ raise EnvironmentError("file {} not found".format(url_or_filename))
579
+ else:
580
+ # Something unknown
581
+ raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
582
+
583
+ if extract_compressed_file:
584
+ if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
585
+ return output_path
586
+
587
+ # Path where we extract compressed archives
588
+ # We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/"
589
+ output_dir, output_file = os.path.split(output_path)
590
+ output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
591
+ output_path_extracted = os.path.join(output_dir, output_extract_dir_name)
592
+
593
+ if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:
594
+ return output_path_extracted
595
+
596
+ # Prevent parallel extractions
597
+ lock_path = output_path + ".lock"
598
+ with FileLock(lock_path):
599
+ shutil.rmtree(output_path_extracted, ignore_errors=True)
600
+ os.makedirs(output_path_extracted)
601
+ if is_zipfile(output_path):
602
+ with ZipFile(output_path, "r") as zip_file:
603
+ zip_file.extractall(output_path_extracted)
604
+ zip_file.close()
605
+ elif tarfile.is_tarfile(output_path):
606
+ tar_file = tarfile.open(output_path)
607
+ tar_file.extractall(output_path_extracted)
608
+ tar_file.close()
609
+ else:
610
+ raise EnvironmentError("Archive format of {} could not be identified".format(output_path))
611
+
612
+ return output_path_extracted
613
+
614
+ return output_path
615
+
616
+
617
+ def http_get(url, temp_file, proxies=None, resume_size=0, user_agent: Union[Dict, str, None] = None):
618
+ ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
619
+ if is_torch_available():
620
+ ua += "; torch/{}".format(torch.__version__)
621
+ if is_tf_available():
622
+ ua += "; tensorflow/{}".format(tf.__version__)
623
+ if isinstance(user_agent, dict):
624
+ ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
625
+ elif isinstance(user_agent, str):
626
+ ua += "; " + user_agent
627
+ headers = {"user-agent": ua}
628
+ if resume_size > 0:
629
+ headers["Range"] = "bytes=%d-" % (resume_size,)
630
+ response = requests.get(url, stream=True, proxies=proxies, headers=headers)
631
+ if response.status_code == 416: # Range not satisfiable
632
+ return
633
+ content_length = response.headers.get("Content-Length")
634
+ total = resume_size + int(content_length) if content_length is not None else None
635
+ progress = tqdm(
636
+ unit="B",
637
+ unit_scale=True,
638
+ total=total,
639
+ initial=resume_size,
640
+ desc="Downloading",
641
+ disable=bool(logger.getEffectiveLevel() == logging.NOTSET),
642
+ )
643
+ for chunk in response.iter_content(chunk_size=1024):
644
+ if chunk: # filter out keep-alive new chunks
645
+ progress.update(len(chunk))
646
+ temp_file.write(chunk)
647
+ progress.close()
648
+
649
+
650
+ def get_from_cache(
651
+ url,
652
+ cache_dir=None,
653
+ force_download=False,
654
+ proxies=None,
655
+ etag_timeout=10,
656
+ resume_download=False,
657
+ user_agent: Union[Dict, str, None] = None,
658
+ local_files_only=False,
659
+ ) -> Optional[str]:
660
+ """
661
+ Given a URL, look for the corresponding file in the local cache.
662
+ If it's not there, download it. Then return the path to the cached file.
663
+
664
+ Return:
665
+ None in case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
666
+ Local path (string) otherwise
667
+ """
668
+ if cache_dir is None:
669
+ cache_dir = TRANSFORMERS_CACHE
670
+ if isinstance(cache_dir, Path):
671
+ cache_dir = str(cache_dir)
672
+
673
+ os.makedirs(cache_dir, exist_ok=True)
674
+
675
+ etag = None
676
+ if not local_files_only:
677
+ try:
678
+ response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
679
+ if response.status_code == 200:
680
+ etag = response.headers.get("ETag")
681
+ except (EnvironmentError, requests.exceptions.Timeout):
682
+ # etag is already None
683
+ pass
684
+
685
+ filename = url_to_filename(url, etag)
686
+
687
+ # get cache path to put the file
688
+ cache_path = os.path.join(cache_dir, filename)
689
+
690
+ # etag is None = we don't have a connection, or url doesn't exist, or is otherwise inaccessible.
691
+ # try to get the last downloaded one
692
+ if etag is None:
693
+ if os.path.exists(cache_path):
694
+ return cache_path
695
+ else:
696
+ matching_files = [
697
+ file
698
+ for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
699
+ if not file.endswith(".json") and not file.endswith(".lock")
700
+ ]
701
+ if len(matching_files) > 0:
702
+ return os.path.join(cache_dir, matching_files[-1])
703
+ else:
704
+ # If files cannot be found and local_files_only=True,
705
+ # the models might've been found if local_files_only=False
706
+ # Notify the user about that
707
+ if local_files_only:
708
+ raise ValueError(
709
+ "Cannot find the requested files in the cached path and outgoing traffic has been"
710
+ " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
711
+ " to False."
712
+ )
713
+ return None
714
+
715
+ # From now on, etag is not None.
716
+ if os.path.exists(cache_path) and not force_download:
717
+ return cache_path
718
+
719
+ # Prevent parallel downloads of the same file with a lock.
720
+ lock_path = cache_path + ".lock"
721
+ with FileLock(lock_path):
722
+
723
+ # If the download just completed while the lock was activated.
724
+ if os.path.exists(cache_path) and not force_download:
725
+ # Even if returning early like here, the lock will be released.
726
+ return cache_path
727
+
728
+ if resume_download:
729
+ incomplete_path = cache_path + ".incomplete"
730
+
731
+ @contextmanager
732
+ def _resumable_file_manager():
733
+ with open(incomplete_path, "a+b") as f:
734
+ yield f
735
+
736
+ temp_file_manager = _resumable_file_manager
737
+ if os.path.exists(incomplete_path):
738
+ resume_size = os.stat(incomplete_path).st_size
739
+ else:
740
+ resume_size = 0
741
+ else:
742
+ temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False)
743
+ resume_size = 0
744
+
745
+ # Download to temporary file, then copy to cache dir once finished.
746
+ # Otherwise you get corrupt cache entries if the download gets interrupted.
747
+ with temp_file_manager() as temp_file:
748
+ logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)
749
+
750
+ http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
751
+
752
+ logger.info("storing %s in cache at %s", url, cache_path)
753
+ os.replace(temp_file.name, cache_path)
754
+
755
+ logger.info("creating metadata file for %s", cache_path)
756
+ meta = {"url": url, "etag": etag}
757
+ meta_path = cache_path + ".json"
758
+ with open(meta_path, "w") as meta_file:
759
+ json.dump(meta, meta_file)
760
+
761
+ return cache_path
762
+
763
+
764
+ class cached_property(property):
765
+ """
766
+ Descriptor that mimics @property but caches output in member variable.
767
+
768
+ From tensorflow_datasets
769
+
770
+ Built-in in functools from Python 3.8.
771
+ """
772
+
773
+ def __get__(self, obj, objtype=None):
774
+ # See docs.python.org/3/howto/descriptor.html#properties
775
+ if obj is None:
776
+ return self
777
+ if self.fget is None:
778
+ raise AttributeError("unreadable attribute")
779
+ attr = "__cached_" + self.fget.__name__
780
+ cached = getattr(obj, attr, None)
781
+ if cached is None:
782
+ cached = self.fget(obj)
783
+ setattr(obj, attr, cached)
784
+ return cached
785
+
786
+
787
+ def torch_required(func):
788
+ # Chose a different decorator name than in tests so it's clear they are not the same.
789
+ @wraps(func)
790
+ def wrapper(*args, **kwargs):
791
+ if is_torch_available():
792
+ return func(*args, **kwargs)
793
+ else:
794
+ raise ImportError(f"Method `{func.__name__}` requires PyTorch.")
795
+
796
+ return wrapper
797
+
798
+
799
+ def tf_required(func):
800
+ # Chose a different decorator name than in tests so it's clear they are not the same.
801
+ @wraps(func)
802
+ def wrapper(*args, **kwargs):
803
+ if is_tf_available():
804
+ return func(*args, **kwargs)
805
+ else:
806
+ raise ImportError(f"Method `{func.__name__}` requires TF.")
807
+
808
+ return wrapper
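A short sketch exercising the caching helpers above, again assuming the vendored module is importable as `bert.file_utils` and that the working directory is the CGFormer folder so the repository's checkpoint file is reachable at the relative path shown (both are assumptions about the runtime environment):

    from bert.file_utils import CONFIG_NAME, cached_path, hf_bucket_url, url_to_filename

    # Build the legacy S3 URL for a model's config.json (a model id without '/'
    # selects the flat "{model_id}-{filename}" layout) and derive its cache key.
    url = hf_bucket_url('bert-base-uncased', filename=CONFIG_NAME, use_cdn=False)
    print(url)                   # https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json
    print(url_to_filename(url))  # sha256 of the URL; an ETag hash is appended when available

    # cached_path passes existing local files straight through.
    print(cached_path('ckpts/swin_base_patch4_window12_384_22k.pth'))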
CGFormer/bert/generation_utils.py ADDED
@@ -0,0 +1,993 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import logging
18
+ from typing import Iterable, Optional, Tuple
19
+
20
+ import torch
21
+ from torch import Tensor
22
+ from torch.nn import functional as F
23
+
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ class GenerationMixin:
29
+ """
30
+ A class containing all of the functions supporting generation, to be used as a mixin in PreTrainedModel.
31
+ """
32
+
33
+ def prepare_inputs_for_generation(self, input_ids, **kwargs):
34
+ return {"input_ids": input_ids}
35
+
36
+ def adjust_logits_during_generation(self, logits, **kwargs):
37
+ return logits
38
+
39
+ def _use_cache(self, outputs, use_cache):
40
+ """During generation, decide whether to pass the `past` variable to the next forward pass."""
41
+ if len(outputs) <= 1 or use_cache is False:
42
+ return False
43
+ if hasattr(self.config, "mem_len") and self.config.mem_len == 0:
44
+ return False
45
+ return True
46
+
47
+ def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
48
+ """repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858). """
49
+ for i in range(batch_size * num_beams):
50
+ for previous_token in set(prev_output_tokens[i].tolist()):
51
+ # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
52
+ if lprobs[i, previous_token] < 0:
53
+ lprobs[i, previous_token] *= repetition_penalty
54
+ else:
55
+ lprobs[i, previous_token] /= repetition_penalty
56
+
57
+ def postprocess_next_token_scores(
58
+ self,
59
+ scores,
60
+ input_ids,
61
+ no_repeat_ngram_size,
62
+ bad_words_ids,
63
+ cur_len,
64
+ min_length,
65
+ max_length,
66
+ eos_token_id,
67
+ repetition_penalty,
68
+ batch_size,
69
+ num_beams,
70
+ ):
71
+ # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
72
+ if repetition_penalty != 1.0:
73
+ self.enforce_repetition_penalty_(
74
+ scores, batch_size, num_beams, input_ids, repetition_penalty,
75
+ )
76
+
77
+ # set eos token prob to zero if min_length is not reached
78
+ if eos_token_id is not None and cur_len < min_length:
79
+ scores[:, eos_token_id] = -float("inf")
80
+
81
+ if no_repeat_ngram_size > 0:
82
+ # calculate a list of banned tokens to prevent repetitively generating the same ngrams
83
+ num_batch_hypotheses = batch_size * num_beams
84
+ # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
85
+ banned_batch_tokens = calc_banned_ngram_tokens(
86
+ input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len
87
+ )
88
+ for i, banned_tokens in enumerate(banned_batch_tokens):
89
+ scores[i, banned_tokens] = -float("inf")
90
+
91
+ if bad_words_ids is not None:
92
+ # calculate a list of banned tokens according to bad words
93
+ banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)
94
+
95
+ for i, banned_tokens in enumerate(banned_tokens):
96
+ scores[i, banned_tokens] = -float("inf")
97
+
98
+ return scores
99
+
100
+ @torch.no_grad()
101
+ def generate(
102
+ self,
103
+ input_ids: Optional[torch.LongTensor] = None,
104
+ max_length: Optional[int] = None,
105
+ min_length: Optional[int] = None,
106
+ do_sample: Optional[bool] = None,
107
+ early_stopping: Optional[bool] = None,
108
+ num_beams: Optional[int] = None,
109
+ temperature: Optional[float] = None,
110
+ top_k: Optional[int] = None,
111
+ top_p: Optional[float] = None,
112
+ repetition_penalty: Optional[float] = None,
113
+ bad_words_ids: Optional[Iterable[int]] = None,
114
+ bos_token_id: Optional[int] = None,
115
+ pad_token_id: Optional[int] = None,
116
+ eos_token_id: Optional[int] = None,
117
+ length_penalty: Optional[float] = None,
118
+ no_repeat_ngram_size: Optional[int] = None,
119
+ num_return_sequences: Optional[int] = None,
120
+ attention_mask: Optional[torch.LongTensor] = None,
121
+ decoder_start_token_id: Optional[int] = None,
122
+ use_cache: Optional[bool] = None,
123
+ **model_specific_kwargs
124
+ ) -> torch.LongTensor:
125
+ r""" Generates sequences for models with a LM head. The method currently supports greedy decoding, beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.
126
+
127
+ Adapted in part from `Facebook's XLM beam search code`_.
128
+
129
+ .. _`Facebook's XLM beam search code`:
130
+ https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529
131
+
132
+
133
+ Parameters:
134
+
135
+ input_ids: (`optional`) `torch.LongTensor` of shape `(batch_size, sequence_length)`
136
+ The sequence used as a prompt for the generation. If `None` the method initializes
137
+ it as an empty `torch.LongTensor` of shape `(1,)`.
138
+
139
+ max_length: (`optional`) int
140
+ The max length of the sequence to be generated. Between `min_length` and infinity. Default to 20.
141
+
142
+ min_length: (`optional`) int
143
+ The min length of the sequence to be generated. Between 0 and infinity. Default to 0.
144
+
145
+ do_sample: (`optional`) bool
146
+ If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`.
147
+
148
+ early_stopping: (`optional`) bool
149
+ if set to `True` beam search is stopped when at least `num_beams` sentences finished per batch. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`.
150
+
151
+ num_beams: (`optional`) int
152
+ Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Default to 1.
153
+
154
+ temperature: (`optional`) float
155
+ The value used to modulate the next token probabilities. Must be strictly positive. Default to 1.0.
156
+
157
+ top_k: (`optional`) int
158
+ The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
159
+
160
+ top_p: (`optional`) float
161
+ The cumulative probability threshold for nucleus sampling: only the smallest set of most probable vocabulary tokens whose cumulative probability reaches `top_p` is kept. Must be between 0 and 1. Default to 1.
162
+
163
+ repetition_penalty: (`optional`) float
164
+ The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.
165
+
166
+ pad_token_id: (`optional`) int
167
+ Padding token. Defaults to the model-specific pad_token_id, or None if it does not exist.
168
+
169
+ bos_token_id: (`optional`) int
170
+ BOS token. Defaults to `bos_token_id` as defined in the models config.
171
+
172
+ eos_token_id: (`optional`) int
173
+ EOS token. Defaults to `eos_token_id` as defined in the models config.
174
+
175
+ length_penalty: (`optional`) float
176
+ Exponential penalty to the length. Default to 1.
177
+
178
+ no_repeat_ngram_size: (`optional`) int
179
+ If set to int > 0, all ngrams of size `no_repeat_ngram_size` can only occur once.
180
+ bad_words_ids: (`optional`) list of lists of int
181
+ `bad_words_ids` contains tokens that are not allowed to be generated. In order to get the tokens of the words that should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`.
182
+
183
+ num_return_sequences: (`optional`) int
184
+ The number of independently computed returned sequences for each element in the batch. Default to 1.
185
+
186
+ attention_mask (`optional`) obj: `torch.LongTensor` of same shape as `input_ids`
187
+ Mask to avoid performing attention on padding token indices.
188
+ Mask values selected in ``[0, 1]``:
189
+ ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
190
+ Defaults to `None`.
191
+
192
+ `What are attention masks? <../glossary.html#attention-mask>`__
193
+
194
+ decoder_start_token_id=None: (`optional`) int
195
+ If an encoder-decoder model starts decoding with a different token than BOS.
196
+ Defaults to `None` and is changed to `BOS` later.
197
+
198
+ use_cache: (`optional`) bool
199
+ If `use_cache` is True, past key values are used to speed up decoding if applicable to model. Defaults to `True`.
200
+
201
+ model_specific_kwargs: (`optional`) dict
202
+ Additional model specific kwargs will be forwarded to the `forward` function of the model.
203
+
204
+ Return:
205
+
206
+ output: `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`
207
+ sequence_length is either equal to max_length or shorter if all batches finished early due to the `eos_token_id`
208
+
209
+ Examples::
210
+
211
+ tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer
212
+ model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
213
+ outputs = model.generate(max_length=40) # do greedy decoding
214
+ print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
215
+
216
+ tokenizer = AutoTokenizer.from_pretrained('openai-gpt') # Initialize tokenizer
217
+ model = AutoModelWithLMHead.from_pretrained('openai-gpt') # Download model and configuration from S3 and cache.
218
+ input_context = 'The dog'
219
+ input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
220
+ outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5) # generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog'
221
+ for i in range(3): # 3 output sequences were generated
222
+ print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
223
+
224
+ tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer
225
+ model = AutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
226
+ input_context = 'The dog'
227
+ input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
228
+ outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, do_sample=True, num_return_sequences=3) # generate 3 independent sequences by sampling
229
+ for i in range(3): # 3 output sequences were generated
230
+ print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
231
+
232
+ tokenizer = AutoTokenizer.from_pretrained('ctrl') # Initialize tokenizer
233
+ model = AutoModelWithLMHead.from_pretrained('ctrl') # Download model and configuration from S3 and cache.
234
+ input_context = 'Legal My neighbor is' # "Legal" is one of the control codes for ctrl
235
+ input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
236
+ outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2) # generate sequences
237
+ print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
238
+
239
+ tokenizer = AutoTokenizer.from_pretrained('gpt2') # Initialize tokenizer
240
+ model = AutoModelWithLMHead.from_pretrained('gpt2') # Download model and configuration from S3 and cache.
241
+ input_context = 'My cute dog'
242
+ bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']]
243
+ input_ids = tokenizer.encode(input_context, return_tensors='pt') # encode input context
244
+ outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids) # generate sequences without allowing bad_words to be generated
245
+ """
246
+
247
+ # We cannot generate if the model does not have a LM head
248
+ if self.get_output_embeddings() is None:
249
+ raise AttributeError(
250
+ "You tried to generate sequences with a model that does not have a LM Head."
251
+ "Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`, `XLMWithLMHeadModel`, `BartForConditionalGeneration` )"
252
+ )
253
+
254
+ max_length = max_length if max_length is not None else self.config.max_length
255
+ min_length = min_length if min_length is not None else self.config.min_length
256
+ do_sample = do_sample if do_sample is not None else self.config.do_sample
257
+ early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
258
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
259
+ num_beams = num_beams if num_beams is not None else self.config.num_beams
260
+ temperature = temperature if temperature is not None else self.config.temperature
261
+ top_k = top_k if top_k is not None else self.config.top_k
262
+ top_p = top_p if top_p is not None else self.config.top_p
263
+ repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
264
+ bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
265
+ pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
266
+ eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
267
+ length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
268
+ no_repeat_ngram_size = (
269
+ no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
270
+ )
271
+ bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
272
+ num_return_sequences = (
273
+ num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
274
+ )
275
+ decoder_start_token_id = (
276
+ decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
277
+ )
278
+
279
+ if input_ids is not None:
280
+ batch_size = input_ids.shape[0] # overridden by the input batch_size
281
+ else:
282
+ batch_size = 1
283
+
284
+ assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
285
+ assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
286
+ assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
287
+ assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
288
+ assert isinstance(use_cache, bool), "`use_cache` should be a boolean."
289
+ assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
290
+ assert temperature > 0, "`temperature` should be strictly positive."
291
+ assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
292
+ assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
293
+ assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
294
+ assert input_ids is not None or (
295
+ isinstance(bos_token_id, int) and bos_token_id >= 0
296
+ ), "If input_ids is not defined, `bos_token_id` should be a positive integer."
297
+ assert pad_token_id is None or (
298
+ isinstance(pad_token_id, int) and (pad_token_id >= 0)
299
+ ), "`pad_token_id` should be a positive integer."
300
+ assert (eos_token_id is None) or (
301
+ isinstance(eos_token_id, int) and (eos_token_id >= 0)
302
+ ), "`eos_token_id` should be a positive integer."
303
+ assert length_penalty > 0, "`length_penalty` should be strictly positive."
304
+ assert (
305
+ isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
306
+ ), "`no_repeat_ngram_size` should be a positive integer."
307
+ assert (
308
+ isinstance(num_return_sequences, int) and num_return_sequences > 0
309
+ ), "`num_return_sequences` should be a strictly positive integer."
310
+ assert (
311
+ bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
312
+ ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"
313
+
314
+ if input_ids is None:
315
+ assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
316
+ "you should either supply a context to complete as `input_ids` input "
317
+ "or a `bos_token_id` (integer >= 0) as a first token to start the generation."
318
+ )
319
+ input_ids = torch.full(
320
+ (batch_size, 1), bos_token_id, dtype=torch.long, device=next(self.parameters()).device,
321
+ )
322
+ else:
323
+ assert input_ids.dim() == 2, "Input prompt should be of shape (batch_size, sequence length)."
324
+
325
+ # do not allow duplicate outputs when doing greedy decoding
326
+ if do_sample is False:
327
+ if num_beams == 1:
328
+ # no_beam_search greedy generation conditions
329
+ assert (
330
+ num_return_sequences == 1
331
+ ), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1"
332
+
333
+ else:
334
+ # beam_search greedy generation conditions
335
+ assert (
336
+ num_beams >= num_return_sequences
337
+ ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"
338
+
339
+ # create attention mask if necessary
340
+ # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140
341
+ if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids):
342
+ attention_mask = input_ids.ne(pad_token_id).long()
343
+ elif attention_mask is None:
344
+ attention_mask = input_ids.new_ones(input_ids.shape)
345
+
346
+ # set pad_token_id to eos_token_id if not set. Important that this is done after
347
+ # attention_mask is created
348
+ if pad_token_id is None and eos_token_id is not None:
349
+ logger.warning(
350
+ "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
351
+ )
352
+ pad_token_id = eos_token_id
353
+
354
+ # current position and vocab size
355
+ if hasattr(self.config, "vocab_size"):
356
+ vocab_size = self.config.vocab_size
357
+ elif (
358
+ self.config.is_encoder_decoder
359
+ and hasattr(self.config, "decoder")
360
+ and hasattr(self.config.decoder, "vocab_size")
361
+ ):
362
+ vocab_size = self.config.decoder.vocab_size
363
+
364
+ # set effective batch size and effective batch multiplier according to do_sample
365
+ if do_sample:
366
+ effective_batch_size = batch_size * num_return_sequences
367
+ effective_batch_mult = num_return_sequences
368
+ else:
369
+ effective_batch_size = batch_size
370
+ effective_batch_mult = 1
371
+
372
+ if self.config.is_encoder_decoder:
373
+ if decoder_start_token_id is None:
374
+ decoder_start_token_id = bos_token_id
375
+
376
+ assert (
377
+ decoder_start_token_id is not None
378
+ ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
379
+ assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
380
+ assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
381
+
382
+ # get encoder and store encoder outputs
383
+ encoder = self.get_encoder()
384
+
385
+ encoder_outputs: tuple = encoder(input_ids, attention_mask=attention_mask)
386
+
387
+ # Expand input ids if num_beams > 1 or num_return_sequences > 1
388
+ if num_return_sequences > 1 or num_beams > 1:
389
+ input_ids_len = input_ids.shape[-1]
390
+ input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
391
+ attention_mask = attention_mask.unsqueeze(1).expand(
392
+ batch_size, effective_batch_mult * num_beams, input_ids_len
393
+ )
394
+
395
+ input_ids = input_ids.contiguous().view(
396
+ effective_batch_size * num_beams, input_ids_len
397
+ ) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
398
+ attention_mask = attention_mask.contiguous().view(
399
+ effective_batch_size * num_beams, input_ids_len
400
+ ) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
401
+
402
+ if self.config.is_encoder_decoder:
403
+ # create empty decoder_input_ids
404
+ input_ids = torch.full(
405
+ (effective_batch_size * num_beams, 1),
406
+ decoder_start_token_id,
407
+ dtype=torch.long,
408
+ device=next(self.parameters()).device,
409
+ )
410
+ cur_len = 1
411
+
412
+ assert (
413
+ batch_size == encoder_outputs[0].shape[0]
414
+ ), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[0]} "
415
+
416
+ # expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
417
+ expanded_batch_idxs = (
418
+ torch.arange(batch_size)
419
+ .view(-1, 1)
420
+ .repeat(1, num_beams * effective_batch_mult)
421
+ .view(-1)
422
+ .to(input_ids.device)
423
+ )
424
+ # expand encoder_outputs
425
+ encoder_outputs = (encoder_outputs[0].index_select(0, expanded_batch_idxs), *encoder_outputs[1:])
426
+
427
+ else:
428
+ encoder_outputs = None
429
+ cur_len = input_ids.shape[-1]
430
+
431
+ assert (
432
+ cur_len < max_length
433
+ ), f"The context has {cur_len} number of tokens, but `max_length` is only {max_length}. Please make sure that `max_length` is bigger than the number of tokens, by setting either `generate(max_length=...,...)` or `config.max_length = ...`"
434
+
435
+ if num_beams > 1:
436
+ output = self._generate_beam_search(
437
+ input_ids,
438
+ cur_len=cur_len,
439
+ max_length=max_length,
440
+ min_length=min_length,
441
+ do_sample=do_sample,
442
+ early_stopping=early_stopping,
443
+ temperature=temperature,
444
+ top_k=top_k,
445
+ top_p=top_p,
446
+ repetition_penalty=repetition_penalty,
447
+ no_repeat_ngram_size=no_repeat_ngram_size,
448
+ bad_words_ids=bad_words_ids,
449
+ pad_token_id=pad_token_id,
450
+ eos_token_id=eos_token_id,
451
+ batch_size=effective_batch_size,
452
+ num_return_sequences=num_return_sequences,
453
+ length_penalty=length_penalty,
454
+ num_beams=num_beams,
455
+ vocab_size=vocab_size,
456
+ encoder_outputs=encoder_outputs,
457
+ attention_mask=attention_mask,
458
+ use_cache=use_cache,
459
+ model_specific_kwargs=model_specific_kwargs,
460
+ )
461
+ else:
462
+ output = self._generate_no_beam_search(
463
+ input_ids,
464
+ cur_len=cur_len,
465
+ max_length=max_length,
466
+ min_length=min_length,
467
+ do_sample=do_sample,
468
+ temperature=temperature,
469
+ top_k=top_k,
470
+ top_p=top_p,
471
+ repetition_penalty=repetition_penalty,
472
+ no_repeat_ngram_size=no_repeat_ngram_size,
473
+ bad_words_ids=bad_words_ids,
474
+ pad_token_id=pad_token_id,
475
+ eos_token_id=eos_token_id,
476
+ batch_size=effective_batch_size,
477
+ encoder_outputs=encoder_outputs,
478
+ attention_mask=attention_mask,
479
+ use_cache=use_cache,
480
+ model_specific_kwargs=model_specific_kwargs,
481
+ )
482
+
483
+ return output
484
+
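A minimal usage sketch of the generate() mixin above, assuming it is mixed into a pretrained causal or encoder-decoder LM; `model` and `tokenizer` below stand for an illustrative matching pair, not objects defined in this repository:

import torch

input_ids = tokenizer("the person on the left", return_tensors="pt").input_ids
output_ids = model.generate(
    input_ids,
    max_length=32,           # hard upper bound checked by the asserts above
    num_beams=4,             # > 1 routes to _generate_beam_search below
    do_sample=False,         # greedy beam search
    no_repeat_ngram_size=2,  # handled by calc_banned_ngram_tokens further down
    early_stopping=True,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))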
485
+ def _generate_no_beam_search(
486
+ self,
487
+ input_ids,
488
+ cur_len,
489
+ max_length,
490
+ min_length,
491
+ do_sample,
492
+ temperature,
493
+ top_k,
494
+ top_p,
495
+ repetition_penalty,
496
+ no_repeat_ngram_size,
497
+ bad_words_ids,
498
+ pad_token_id,
499
+ eos_token_id,
500
+ batch_size,
501
+ encoder_outputs,
502
+ attention_mask,
503
+ use_cache,
504
+ model_specific_kwargs,
505
+ ):
506
+ """ Generate sequences for each example without beam search (num_beams == 1).
507
+ All returned sequences are generated independently.
508
+ """
509
+ # length of generated sentences / unfinished sentences
510
+ unfinished_sents = input_ids.new(batch_size).fill_(1)
511
+ sent_lengths = input_ids.new(batch_size).fill_(max_length)
512
+
513
+ past = (encoder_outputs, None) if encoder_outputs is not None else None
514
+
515
+ while cur_len < max_length:
516
+ model_inputs = self.prepare_inputs_for_generation(
517
+ input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
518
+ )
519
+
520
+ outputs = self(**model_inputs)
521
+ next_token_logits = outputs[0][:, -1, :]
522
+
523
+ scores = self.postprocess_next_token_scores(
524
+ scores=next_token_logits,
525
+ input_ids=input_ids,
526
+ no_repeat_ngram_size=no_repeat_ngram_size,
527
+ bad_words_ids=bad_words_ids,
528
+ cur_len=cur_len,
529
+ min_length=min_length,
530
+ max_length=max_length,
531
+ eos_token_id=eos_token_id,
532
+ repetition_penalty=repetition_penalty,
533
+ batch_size=batch_size,
534
+ num_beams=1,
535
+ )
536
+
537
+ # if model has past, then set the past variable to speed up decoding
538
+ if self._use_cache(outputs, use_cache):
539
+ past = outputs[1]
540
+
541
+ if do_sample:
542
+ # Temperature (higher temperature => more likely to sample low probability tokens)
543
+ if temperature != 1.0:
544
+ scores = scores / temperature
545
+ # Top-p/top-k filtering
546
+ next_token_logscores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p)
547
+ # Sample
548
+ probs = F.softmax(next_token_logscores, dim=-1)
549
+ next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
550
+ else:
551
+ # Greedy decoding
552
+ next_token = torch.argmax(next_token_logits, dim=-1)
553
+
554
+ # update generations and finished sentences
555
+ if eos_token_id is not None:
556
+ # pad finished sentences if eos_token_id exist
557
+ tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
558
+ else:
559
+ tokens_to_add = next_token
560
+
561
+ # add token and increase length by one
562
+ input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
563
+ cur_len = cur_len + 1
564
+
565
+ if eos_token_id is not None:
566
+ eos_in_sents = tokens_to_add == eos_token_id
567
+ # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
568
+ is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
569
+ sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len)
570
+ # unfinished_sents is set to zero if eos in sentence
571
+ unfinished_sents.mul_((~eos_in_sents).long())
572
+
573
+ # stop when there is a </s> in each sentence, or if we exceed the maximum length
574
+ if unfinished_sents.max() == 0:
575
+ break
576
+
577
+ # extend attention_mask for the newly generated token if the model is decoder-only
578
+ if self.config.is_encoder_decoder is False:
579
+ attention_mask = torch.cat(
580
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
581
+ )
582
+
583
+ return input_ids
584
+
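A toy illustration (values made up) of the padding bookkeeping used above: once a sentence has produced `eos_token_id`, its entry in `unfinished_sents` is zeroed and every later step appends `pad_token_id` instead of the decoded token:

import torch

unfinished_sents = torch.tensor([1, 0])   # the second sentence finished earlier
next_token = torch.tensor([42, 7])        # freshly decoded tokens for both sentences
pad_token_id = 0
tokens_to_add = next_token * unfinished_sents + pad_token_id * (1 - unfinished_sents)
print(tokens_to_add)                      # tensor([42, 0]) -> the finished sentence keeps getting padded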
585
+ def _generate_beam_search(
586
+ self,
587
+ input_ids,
588
+ cur_len,
589
+ max_length,
590
+ min_length,
591
+ do_sample,
592
+ early_stopping,
593
+ temperature,
594
+ top_k,
595
+ top_p,
596
+ repetition_penalty,
597
+ no_repeat_ngram_size,
598
+ bad_words_ids,
599
+ pad_token_id,
600
+ eos_token_id,
601
+ batch_size,
602
+ num_return_sequences,
603
+ length_penalty,
604
+ num_beams,
605
+ vocab_size,
606
+ encoder_outputs,
607
+ attention_mask,
608
+ use_cache,
609
+ model_specific_kwargs,
610
+ ):
611
+ """ Generate sequences for each example with beam search.
612
+ """
613
+
614
+ # generated hypotheses
615
+ generated_hyps = [
616
+ BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping)
617
+ for _ in range(batch_size)
618
+ ]
619
+
620
+ # scores for each sentence in the beam
621
+ beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
622
+
623
+ # for greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times
624
+ if do_sample is False:
625
+ beam_scores[:, 1:] = -1e9
626
+ beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
627
+
628
+ # cache compute states
629
+ past = (encoder_outputs, None) if encoder_outputs is not None else None
630
+
631
+ # done sentences
632
+ done = [False for _ in range(batch_size)]
633
+
634
+ while cur_len < max_length:
635
+ model_inputs = self.prepare_inputs_for_generation(
636
+ input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
637
+ )
638
+ outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
639
+ next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
640
+
641
+ # if model has past, then set the past variable to speed up decoding
642
+ if self._use_cache(outputs, use_cache):
643
+ past = outputs[1]
644
+ if self.config.is_encoder_decoder and do_sample is False:
645
+ # TODO (PVP) still a bit hacky here - there might be a better solution
646
+ next_token_logits = self.adjust_logits_during_generation(
647
+ next_token_logits, cur_len=cur_len, max_length=max_length
648
+ )
649
+
650
+ scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
651
+
652
+ scores = self.postprocess_next_token_scores(
653
+ scores=scores,
654
+ input_ids=input_ids,
655
+ no_repeat_ngram_size=no_repeat_ngram_size,
656
+ bad_words_ids=bad_words_ids,
657
+ cur_len=cur_len,
658
+ min_length=min_length,
659
+ max_length=max_length,
660
+ eos_token_id=eos_token_id,
661
+ repetition_penalty=repetition_penalty,
662
+ batch_size=batch_size,
663
+ num_beams=num_beams,
664
+ )
665
+
666
+ assert scores.shape == (batch_size * num_beams, vocab_size), "Shapes of scores: {} != {}".format(
667
+ scores.shape, (batch_size * num_beams, vocab_size)
668
+ )
669
+
670
+ if do_sample:
671
+ _scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
672
+ # Temperature
673
+ if temperature != 1.0:
674
+ _scores = _scores / temperature
675
+ # Top-p/top-k filtering
676
+ _scores = top_k_top_p_filtering(
677
+ _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
678
+ ) # (batch_size * num_beams, vocab_size)
679
+ # re-organize to group the beam together to sample from all beam_idxs
680
+ _scores = _scores.contiguous().view(
681
+ batch_size, num_beams * vocab_size
682
+ ) # (batch_size, num_beams * vocab_size)
683
+
684
+ # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
685
+ probs = F.softmax(_scores, dim=-1)
686
+ next_tokens = torch.multinomial(probs, num_samples=2 * num_beams) # (batch_size, num_beams * 2)
687
+ # Compute next scores
688
+ next_scores = torch.gather(_scores, -1, next_tokens) # (batch_size, num_beams * 2)
689
+ # sort the sampled vector to make sure that the first num_beams samples are the best
690
+ next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1)
691
+ next_tokens = torch.gather(next_tokens, -1, next_scores_indices) # (batch_size, num_beams * 2)
692
+
693
+ else:
694
+ next_scores = scores + beam_scores[:, None].expand_as(scores) # (batch_size * num_beams, vocab_size)
695
+
696
+ # re-organize to group the beams together (we are keeping the top hypotheses across beams)
697
+ next_scores = next_scores.view(
698
+ batch_size, num_beams * vocab_size
699
+ ) # (batch_size, num_beams * vocab_size)
700
+
701
+ next_scores, next_tokens = torch.topk(next_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
702
+
703
+ assert next_scores.size() == next_tokens.size() == (batch_size, 2 * num_beams)
704
+
705
+ # next batch beam content
706
+ next_batch_beam = []
707
+
708
+ # for each sentence
709
+ for batch_idx in range(batch_size):
710
+
711
+ # if we are done with this sentence, add a pad token
712
+ if done[batch_idx]:
713
+ assert (
714
+ len(generated_hyps[batch_idx]) >= num_beams
715
+ ), "Batch can only be done if at least {} beams have been generated".format(num_beams)
716
+ assert (
717
+ eos_token_id is not None and pad_token_id is not None
718
+ ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
719
+ next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
720
+ continue
721
+
722
+ # next sentence beam content, this will get added to next_batch_beam
723
+ next_sent_beam = []
724
+
725
+ # next tokens for this sentence
726
+ for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
727
+ zip(next_tokens[batch_idx], next_scores[batch_idx])
728
+ ):
729
+ # get beam and token IDs
730
+ beam_id = beam_token_id // vocab_size
731
+ token_id = beam_token_id % vocab_size
732
+
733
+ effective_beam_id = batch_idx * num_beams + beam_id
734
+ # add to generated hypotheses if end of sentence
735
+ if (eos_token_id is not None) and (token_id.item() == eos_token_id):
736
+ # if beam_token does not belong to top num_beams tokens, it should not be added
737
+ is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
738
+ if is_beam_token_worse_than_top_num_beams:
739
+ continue
740
+ generated_hyps[batch_idx].add(
741
+ input_ids[effective_beam_id].clone(), beam_token_score.item(),
742
+ )
743
+ else:
744
+ # add next predicted token since it is not eos_token
745
+ next_sent_beam.append((beam_token_score, token_id, effective_beam_id))
746
+
747
+ # once the beam for next step is full, don't add more tokens to it.
748
+ if len(next_sent_beam) == num_beams:
749
+ break
750
+
751
+ # Check if we are done so that we can save a pad step if all(done)
752
+ done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
753
+ next_scores[batch_idx].max().item(), cur_len
754
+ )
755
+
756
+ # update next beam content
757
+ assert len(next_sent_beam) == num_beams, "Beam should always be full"
758
+ next_batch_beam.extend(next_sent_beam)
759
+ assert len(next_batch_beam) == num_beams * (batch_idx + 1), "We should have added num_beams each step"
760
+
761
+ # stop when we are done with each sentence
762
+ if all(done):
763
+ break
764
+
765
+ # sanity check / prepare next batch
766
+ assert len(next_batch_beam) == batch_size * num_beams
767
+ beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
768
+ beam_tokens = input_ids.new([x[1] for x in next_batch_beam])
769
+ beam_idx = input_ids.new([x[2] for x in next_batch_beam])
770
+
771
+ # re-order batch and update current length
772
+ input_ids = input_ids[beam_idx, :]
773
+ input_ids = torch.cat([input_ids, beam_tokens.unsqueeze(1)], dim=-1)
774
+ cur_len = cur_len + 1
775
+
776
+ # re-order internal states
777
+ if past is not None:
778
+ past = self._reorder_cache(past, beam_idx)
779
+
780
+ # extend attention_mask for the newly generated token if the model is decoder-only
781
+ if self.config.is_encoder_decoder is False:
782
+ attention_mask = torch.cat(
783
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
784
+ )
785
+
786
+ # finalize all open beam hypotheses and add to generated hypotheses
787
+ for batch_idx in range(batch_size):
788
+ if done[batch_idx]:
789
+ continue
790
+
791
+ # test that beam scores match previously calculated scores if not eos and batch_idx not done
792
+ if eos_token_id is not None and all(
793
+ (token_id % vocab_size).item() != eos_token_id for token_id in next_tokens[batch_idx]
794
+ ):
795
+ assert torch.all(
796
+ next_scores[batch_idx, :num_beams] == beam_scores.view(batch_size, num_beams)[batch_idx]
797
+ ), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format(
798
+ next_scores[:, :num_beams][batch_idx], beam_scores.view(batch_size, num_beams)[batch_idx],
799
+ )
800
+
801
+ # need to add best num_beams hypotheses to generated hyps
802
+ for beam_id in range(num_beams):
803
+ effective_beam_id = batch_idx * num_beams + beam_id
804
+ final_score = beam_scores[effective_beam_id].item()
805
+ final_tokens = input_ids[effective_beam_id]
806
+ generated_hyps[batch_idx].add(final_tokens, final_score)
807
+
808
+ # depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
809
+ output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
810
+ output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences
811
+
812
+ # select the best hypotheses
813
+ sent_lengths = input_ids.new(output_batch_size)
814
+ best = []
815
+
816
+ # retrieve best hypotheses
817
+ for i, hypotheses in enumerate(generated_hyps):
818
+ sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
819
+ for j in range(output_num_return_sequences_per_batch):
820
+ effective_batch_idx = output_num_return_sequences_per_batch * i + j
821
+ best_hyp = sorted_hyps.pop()[1]
822
+ sent_lengths[effective_batch_idx] = len(best_hyp)
823
+ best.append(best_hyp)
824
+
825
+ # shorter batches are padded
826
+ if sent_lengths.min().item() != sent_lengths.max().item():
827
+ assert pad_token_id is not None, "`Pad_token_id` has to be defined"
828
+ sent_max_len = min(sent_lengths.max().item() + 1, max_length)
829
+ decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id)
830
+
831
+ # fill with hypothesis and eos_token_id if necessary
832
+ for i, hypo in enumerate(best):
833
+ decoded[i, : sent_lengths[i]] = hypo
834
+ if sent_lengths[i] < max_length:
835
+ decoded[i, sent_lengths[i]] = eos_token_id
836
+ else:
837
+ # none of the hypotheses have an eos_token
838
+ assert all(len(hypo) == max_length for hypo in best)
839
+ decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
840
+
841
+ return decoded
842
+
843
+ @staticmethod
844
+ def _reorder_cache(past: Tuple, beam_idx: Tensor) -> Tuple[Tensor]:
845
+ return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)
846
+
847
+
848
+ def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_ngram_size: int, cur_len: int) -> None:
849
+ """Copied from fairseq for no_repeat_ngram in beam_search"""
850
+ if cur_len + 1 < no_repeat_ngram_size:
851
+ # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
852
+ return [[] for _ in range(num_hypos)]
853
+ generated_ngrams = [{} for _ in range(num_hypos)]
854
+ for idx in range(num_hypos):
855
+ gen_tokens = prev_input_ids[idx].tolist()
856
+ generated_ngram = generated_ngrams[idx]
857
+ for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
858
+ prev_ngram_tuple = tuple(ngram[:-1])
859
+ generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
860
+
861
+ def _get_generated_ngrams(hypo_idx):
862
+ # Before decoding the next token, prevent decoding of ngrams that have already appeared
863
+ start_idx = cur_len + 1 - no_repeat_ngram_size
864
+ ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
865
+ return generated_ngrams[hypo_idx].get(ngram_idx, [])
866
+
867
+ banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
868
+ return banned_tokens
869
+
870
+
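A small sanity check of `calc_banned_ngram_tokens`, assuming one hypothesis whose generated tokens already contain the bigram (8, 5); the token ids are arbitrary:

import torch

prev_input_ids = torch.tensor([[5, 8, 5, 8]])  # one hypothesis, four generated tokens
banned = calc_banned_ngram_tokens(prev_input_ids, num_hypos=1, no_repeat_ngram_size=2, cur_len=4)
print(banned)  # [[5]] -> emitting 5 after the trailing 8 would repeat the bigram (8, 5)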
871
+ def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iterable[int]) -> Iterable[int]:
872
+ banned_tokens = []
873
+
874
+ def _tokens_match(prev_tokens, tokens):
875
+ if len(tokens) == 0:
876
+ # if the banned sequence is just one token, always ban it
877
+ return True
878
+ if len(tokens) > len(prev_input_ids):
879
+ # if bad word tokens are longer than prev input_ids they can't be equal
880
+ return False
881
+
882
+ if prev_tokens[-len(tokens) :] == tokens:
883
+ # if tokens match
884
+ return True
885
+ else:
886
+ return False
887
+
888
+ for prev_input_ids_slice in prev_input_ids:
889
+ banned_tokens_slice = []
890
+
891
+ for banned_token_seq in bad_words_ids:
892
+ assert len(banned_token_seq) > 0, "Banned words token sequences {} cannot have an empty list".format(
893
+ bad_words_ids
894
+ )
895
+
896
+ if _tokens_match(prev_input_ids_slice.tolist(), banned_token_seq[:-1]) is False:
897
+ # if tokens do not match continue
898
+ continue
899
+
900
+ banned_tokens_slice.append(banned_token_seq[-1])
901
+
902
+ banned_tokens.append(banned_tokens_slice)
903
+
904
+ return banned_tokens
905
+
906
+
907
+ def top_k_top_p_filtering(
908
+ logits: Tensor,
909
+ top_k: int = 0,
910
+ top_p: float = 1.0,
911
+ filter_value: float = -float("Inf"),
912
+ min_tokens_to_keep: int = 1,
913
+ ) -> Tensor:
914
+ """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
915
+ Args:
916
+ logits: logits distribution shape (batch size, vocabulary size)
917
+ if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
918
+ if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
919
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
920
+ Make sure we keep at least min_tokens_to_keep per batch example in the output
921
+ From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
922
+ """
923
+ if top_k > 0:
924
+ top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
925
+ # Remove all tokens with a probability less than the last token of the top-k
926
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
927
+ logits[indices_to_remove] = filter_value
928
+
929
+ if top_p < 1.0:
930
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
931
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
932
+
933
+ # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
934
+ sorted_indices_to_remove = cumulative_probs > top_p
935
+ if min_tokens_to_keep > 1:
936
+ # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
937
+ sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
938
+ # Shift the indices to the right to keep also the first token above the threshold
939
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
940
+ sorted_indices_to_remove[..., 0] = 0
941
+
942
+ # scatter sorted tensors to original indexing
943
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
944
+ logits[indices_to_remove] = filter_value
945
+ return logits
946
+
947
+
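A worked toy example of `top_k_top_p_filtering`, assuming a vocabulary of four tokens; with `top_k=2` the two weakest logits are pushed to -inf, so they receive zero probability after the softmax:

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])              # (batch=1, vocab=4)
filtered = top_k_top_p_filtering(logits.clone(), top_k=2)   # clone(): the filter writes in place
probs = F.softmax(filtered, dim=-1)
print(probs)  # ~[[0.73, 0.27, 0.00, 0.00]]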
948
+ class BeamHypotheses(object):
949
+ def __init__(self, num_beams, max_length, length_penalty, early_stopping):
950
+ """
951
+ Initialize n-best list of hypotheses.
952
+ """
953
+ self.max_length = max_length - 1 # ignoring bos_token
954
+ self.length_penalty = length_penalty
955
+ self.early_stopping = early_stopping
956
+ self.num_beams = num_beams
957
+ self.beams = []
958
+ self.worst_score = 1e9
959
+
960
+ def __len__(self):
961
+ """
962
+ Number of hypotheses in the list.
963
+ """
964
+ return len(self.beams)
965
+
966
+ def add(self, hyp, sum_logprobs):
967
+ """
968
+ Add a new hypothesis to the list.
969
+ """
970
+ score = sum_logprobs / len(hyp) ** self.length_penalty
971
+ if len(self) < self.num_beams or score > self.worst_score:
972
+ self.beams.append((score, hyp))
973
+ if len(self) > self.num_beams:
974
+ sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
975
+ del self.beams[sorted_scores[0][1]]
976
+ self.worst_score = sorted_scores[1][0]
977
+ else:
978
+ self.worst_score = min(score, self.worst_score)
979
+
980
+ def is_done(self, best_sum_logprobs, cur_len):
981
+ """
982
+ If there are enough hypotheses and none of the hypotheses being generated
983
+ can become better than the worst one in the heap, then we are done with this sentence.
984
+ """
985
+
986
+ if len(self) < self.num_beams:
987
+ return False
988
+ elif self.early_stopping:
989
+ return True
990
+ else:
991
+ cur_score = best_sum_logprobs / cur_len ** self.length_penalty
992
+ ret = self.worst_score >= cur_score
993
+ return ret
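A short sketch of how `BeamHypotheses` is used by the beam search above: finished beams are added with their summed log-probabilities, scores are length-normalised by `length_penalty`, and `is_done` reports whether any still-open beam could beat the worst kept hypothesis (numbers below are made up):

import torch

hyps = BeamHypotheses(num_beams=2, max_length=10, length_penalty=1.0, early_stopping=False)
hyps.add(torch.tensor([1, 5, 7]), sum_logprobs=-3.0)  # normalised score -1.0
hyps.add(torch.tensor([1, 9]), sum_logprobs=-1.0)     # normalised score -0.5
print(hyps.is_done(best_sum_logprobs=-2.0, cur_len=4))  # False: an open beam scoring -0.5 could still win
print(hyps.is_done(best_sum_logprobs=-8.0, cur_len=4))  # True: nothing open can beat the worst kept score (-1.0)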
CGFormer/bert/modeling_bert.py ADDED
@@ -0,0 +1,1569 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch BERT model. """
17
+
18
+
19
+ import logging
20
+ import math
21
+ import os
22
+ import warnings
23
+
24
+ import torch
25
+ import torch.utils.checkpoint
26
+ from torch import nn
27
+ from torch.nn import CrossEntropyLoss, MSELoss
28
+
29
+ from .activations import gelu, gelu_new, swish
30
+ from .configuration_bert import BertConfig
31
+ from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
32
+ from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
33
+
34
+
35
+ logger = logging.getLogger(__name__)
36
+
37
+ _TOKENIZER_FOR_DOC = "BertTokenizer"
38
+
39
+ BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
40
+ "bert-base-uncased",
41
+ "bert-large-uncased",
42
+ "bert-base-cased",
43
+ "bert-large-cased",
44
+ "bert-base-multilingual-uncased",
45
+ "bert-base-multilingual-cased",
46
+ "bert-base-chinese",
47
+ "bert-base-german-cased",
48
+ "bert-large-uncased-whole-word-masking",
49
+ "bert-large-cased-whole-word-masking",
50
+ "bert-large-uncased-whole-word-masking-finetuned-squad",
51
+ "bert-large-cased-whole-word-masking-finetuned-squad",
52
+ "bert-base-cased-finetuned-mrpc",
53
+ "bert-base-german-dbmdz-cased",
54
+ "bert-base-german-dbmdz-uncased",
55
+ "cl-tohoku/bert-base-japanese",
56
+ "cl-tohoku/bert-base-japanese-whole-word-masking",
57
+ "cl-tohoku/bert-base-japanese-char",
58
+ "cl-tohoku/bert-base-japanese-char-whole-word-masking",
59
+ "TurkuNLP/bert-base-finnish-cased-v1",
60
+ "TurkuNLP/bert-base-finnish-uncased-v1",
61
+ "wietsedv/bert-base-dutch-cased",
62
+ # See all BERT models at https://huggingface.co/models?filter=bert
63
+ ]
64
+
65
+
66
+ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
67
+ """ Load tf checkpoints in a pytorch model.
68
+ """
69
+ try:
70
+ import re
71
+ import numpy as np
72
+ import tensorflow as tf
73
+ except ImportError:
74
+ logger.error(
75
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
76
+ "https://www.tensorflow.org/install/ for installation instructions."
77
+ )
78
+ raise
79
+ tf_path = os.path.abspath(tf_checkpoint_path)
80
+ logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
81
+ # Load weights from TF model
82
+ init_vars = tf.train.list_variables(tf_path)
83
+ names = []
84
+ arrays = []
85
+ for name, shape in init_vars:
86
+ logger.info("Loading TF weight {} with shape {}".format(name, shape))
87
+ array = tf.train.load_variable(tf_path, name)
88
+ names.append(name)
89
+ arrays.append(array)
90
+
91
+ for name, array in zip(names, arrays):
92
+ name = name.split("/")
93
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
94
+ # which are not required for using pretrained model
95
+ if any(
96
+ n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
97
+ for n in name
98
+ ):
99
+ logger.info("Skipping {}".format("/".join(name)))
100
+ continue
101
+ pointer = model
102
+ for m_name in name:
103
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
104
+ scope_names = re.split(r"_(\d+)", m_name)
105
+ else:
106
+ scope_names = [m_name]
107
+ if scope_names[0] == "kernel" or scope_names[0] == "gamma":
108
+ pointer = getattr(pointer, "weight")
109
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
110
+ pointer = getattr(pointer, "bias")
111
+ elif scope_names[0] == "output_weights":
112
+ pointer = getattr(pointer, "weight")
113
+ elif scope_names[0] == "squad":
114
+ pointer = getattr(pointer, "classifier")
115
+ else:
116
+ try:
117
+ pointer = getattr(pointer, scope_names[0])
118
+ except AttributeError:
119
+ logger.info("Skipping {}".format("/".join(name)))
120
+ continue
121
+ if len(scope_names) >= 2:
122
+ num = int(scope_names[1])
123
+ pointer = pointer[num]
124
+ if m_name[-11:] == "_embeddings":
125
+ pointer = getattr(pointer, "weight")
126
+ elif m_name == "kernel":
127
+ array = np.transpose(array)
128
+ try:
129
+ assert pointer.shape == array.shape
130
+ except AssertionError as e:
131
+ e.args += (pointer.shape, array.shape)
132
+ raise
133
+ logger.info("Initialize PyTorch weight {}".format(name))
134
+ pointer.data = torch.from_numpy(array)
135
+ return model
136
+
137
+
138
+ def mish(x):
139
+ return x * torch.tanh(nn.functional.softplus(x))
140
+
141
+
142
+ ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new, "mish": mish}
143
+
144
+
145
+ BertLayerNorm = torch.nn.LayerNorm
146
+
147
+
148
+ class BertEmbeddings(nn.Module):
149
+ """Construct the embeddings from word, position and token_type embeddings.
150
+ """
151
+
152
+ def __init__(self, config):
153
+ super().__init__()
154
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
155
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
156
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
157
+
158
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
159
+ # any TensorFlow checkpoint file
160
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
161
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
162
+
163
+ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
164
+ if input_ids is not None:
165
+ input_shape = input_ids.size()
166
+ else:
167
+ input_shape = inputs_embeds.size()[:-1]
168
+
169
+ seq_length = input_shape[1]
170
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
171
+ if position_ids is None:
172
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
173
+ position_ids = position_ids.unsqueeze(0).expand(input_shape)
174
+ if token_type_ids is None:
175
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
176
+
177
+ if inputs_embeds is None:
178
+ inputs_embeds = self.word_embeddings(input_ids)
179
+ position_embeddings = self.position_embeddings(position_ids)
180
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
181
+
182
+ embeddings = inputs_embeds + position_embeddings + token_type_embeddings
183
+ embeddings = self.LayerNorm(embeddings)
184
+ embeddings = self.dropout(embeddings)
185
+ return embeddings
186
+
187
+
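A minimal shape walk-through of `BertEmbeddings`, assuming the default `BertConfig` shipped in this folder (vocab 30522, hidden size 768); position and token-type ids are filled in automatically when omitted:

import torch

config = BertConfig()
emb = BertEmbeddings(config)
input_ids = torch.randint(0, config.vocab_size, (2, 16))  # (batch, seq_len)
out = emb(input_ids=input_ids)                            # word + position + token-type, then LayerNorm/dropout
print(out.shape)                                          # torch.Size([2, 16, 768])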
188
+ class BertSelfAttention(nn.Module):
189
+ def __init__(self, config):
190
+ super().__init__()
191
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
192
+ raise ValueError(
193
+ "The hidden size (%d) is not a multiple of the number of attention "
194
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
195
+ )
196
+
197
+ self.num_attention_heads = config.num_attention_heads
198
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
199
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
200
+
201
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
202
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
203
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
204
+
205
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
206
+
207
+ def transpose_for_scores(self, x):
208
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
209
+ x = x.view(*new_x_shape)
210
+ return x.permute(0, 2, 1, 3)
211
+
212
+ def forward(
213
+ self,
214
+ hidden_states,
215
+ attention_mask=None,
216
+ head_mask=None,
217
+ encoder_hidden_states=None,
218
+ encoder_attention_mask=None,
219
+ output_attentions=False,
220
+ ):
221
+ mixed_query_layer = self.query(hidden_states)
222
+
223
+ # If this is instantiated as a cross-attention module, the keys
224
+ # and values come from an encoder; the attention mask needs to be
225
+ # such that the encoder's padding tokens are not attended to.
226
+ if encoder_hidden_states is not None:
227
+ mixed_key_layer = self.key(encoder_hidden_states)
228
+ mixed_value_layer = self.value(encoder_hidden_states)
229
+ attention_mask = encoder_attention_mask
230
+ else:
231
+ mixed_key_layer = self.key(hidden_states)
232
+ mixed_value_layer = self.value(hidden_states)
233
+
234
+ query_layer = self.transpose_for_scores(mixed_query_layer)
235
+ key_layer = self.transpose_for_scores(mixed_key_layer)
236
+ value_layer = self.transpose_for_scores(mixed_value_layer)
237
+
238
+ # Take the dot product between "query" and "key" to get the raw attention scores.
239
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
240
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
241
+ if attention_mask is not None:
242
+ # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
243
+ attention_scores = attention_scores + attention_mask
244
+
245
+ # Normalize the attention scores to probabilities.
246
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
247
+
248
+ # This is actually dropping out entire tokens to attend to, which might
249
+ # seem a bit unusual, but is taken from the original Transformer paper.
250
+ attention_probs = self.dropout(attention_probs)
251
+
252
+ # Mask heads if we want to
253
+ if head_mask is not None:
254
+ attention_probs = attention_probs * head_mask
255
+
256
+ context_layer = torch.matmul(attention_probs, value_layer)
257
+
258
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
259
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
260
+ context_layer = context_layer.view(*new_context_layer_shape)
261
+
262
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
263
+ return outputs
264
+
265
+
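A toy single-head version of the computation in `BertSelfAttention.forward` above (all sizes made up), showing the scaled dot-product, the additive mask convention, and the softmax-weighted sum over values:

import math
import torch
import torch.nn as nn

B, L, H = 2, 4, 8                                   # batch, sequence length, head size
q, k, v = torch.randn(B, L, H), torch.randn(B, L, H), torch.randn(B, L, H)
scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(H)   # (B, L, L)
mask = torch.zeros(B, 1, L)
mask[:, :, -1] = -10000.0                           # additive mask: large negative hides the last position
probs = nn.Softmax(dim=-1)(scores + mask)           # rows sum to 1, last column ~0
context = torch.matmul(probs, v)                    # (B, L, H)
print(context.shape)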
266
+ class BertSelfOutput(nn.Module):
267
+ def __init__(self, config):
268
+ super().__init__()
269
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
270
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
271
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
272
+
273
+ def forward(self, hidden_states, input_tensor):
274
+ hidden_states = self.dense(hidden_states)
275
+ hidden_states = self.dropout(hidden_states)
276
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
277
+ return hidden_states
278
+
279
+
280
+ class BertAttention(nn.Module):
281
+ def __init__(self, config):
282
+ super().__init__()
283
+ self.self = BertSelfAttention(config)
284
+ self.output = BertSelfOutput(config)
285
+ self.pruned_heads = set()
286
+
287
+ def prune_heads(self, heads):
288
+ if len(heads) == 0:
289
+ return
290
+ heads, index = find_pruneable_heads_and_indices(
291
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
292
+ )
293
+
294
+ # Prune linear layers
295
+ self.self.query = prune_linear_layer(self.self.query, index)
296
+ self.self.key = prune_linear_layer(self.self.key, index)
297
+ self.self.value = prune_linear_layer(self.self.value, index)
298
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
299
+
300
+ # Update hyper params and store pruned heads
301
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
302
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
303
+ self.pruned_heads = self.pruned_heads.union(heads)
304
+
305
+ def forward(
306
+ self,
307
+ hidden_states,
308
+ attention_mask=None,
309
+ head_mask=None,
310
+ encoder_hidden_states=None,
311
+ encoder_attention_mask=None,
312
+ output_attentions=False,
313
+ ):
314
+ self_outputs = self.self(
315
+ hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, output_attentions,
316
+ )
317
+ attention_output = self.output(self_outputs[0], hidden_states)
318
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
319
+ return outputs
320
+
321
+
322
+ class BertIntermediate(nn.Module):
323
+ def __init__(self, config):
324
+ super().__init__()
325
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
326
+ if isinstance(config.hidden_act, str):
327
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
328
+ else:
329
+ self.intermediate_act_fn = config.hidden_act
330
+
331
+ def forward(self, hidden_states):
332
+ hidden_states = self.dense(hidden_states)
333
+ hidden_states = self.intermediate_act_fn(hidden_states)
334
+ return hidden_states
335
+
336
+
337
+ class BertOutput(nn.Module):
338
+ def __init__(self, config):
339
+ super().__init__()
340
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
341
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
342
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
343
+
344
+ def forward(self, hidden_states, input_tensor):
345
+ hidden_states = self.dense(hidden_states)
346
+ hidden_states = self.dropout(hidden_states)
347
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
348
+ return hidden_states
349
+
350
+
351
+ class BertLayer(nn.Module):
352
+ def __init__(self, config):
353
+ super().__init__()
354
+ self.attention = BertAttention(config)
355
+ self.is_decoder = config.is_decoder
356
+ if self.is_decoder:
357
+ self.crossattention = BertAttention(config)
358
+ self.intermediate = BertIntermediate(config)
359
+ self.output = BertOutput(config)
360
+
361
+ def forward(
362
+ self,
363
+ hidden_states,
364
+ attention_mask=None,
365
+ head_mask=None,
366
+ encoder_hidden_states=None,
367
+ encoder_attention_mask=None,
368
+ output_attentions=False,
369
+ ):
370
+ self_attention_outputs = self.attention(
371
+ hidden_states, attention_mask, head_mask, output_attentions=output_attentions,
372
+ )
373
+ attention_output = self_attention_outputs[0]
374
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
375
+
376
+ if self.is_decoder and encoder_hidden_states is not None:
377
+ cross_attention_outputs = self.crossattention(
378
+ attention_output,
379
+ attention_mask,
380
+ head_mask,
381
+ encoder_hidden_states,
382
+ encoder_attention_mask,
383
+ output_attentions,
384
+ )
385
+ attention_output = cross_attention_outputs[0]
386
+ outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
387
+
388
+ intermediate_output = self.intermediate(attention_output)
389
+ layer_output = self.output(intermediate_output, attention_output)
390
+ outputs = (layer_output,) + outputs
391
+ return outputs
392
+
393
+
394
+ class BertEncoder(nn.Module):
395
+ def __init__(self, config):
396
+ super().__init__()
397
+ self.config = config
398
+ self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
399
+
400
+ def forward(
401
+ self,
402
+ hidden_states,
403
+ attention_mask=None,
404
+ head_mask=None,
405
+ encoder_hidden_states=None,
406
+ encoder_attention_mask=None,
407
+ output_attentions=False,
408
+ output_hidden_states=False,
409
+ ):
410
+ all_hidden_states = ()
411
+ all_attentions = ()
412
+ for i, layer_module in enumerate(self.layer):
413
+ if output_hidden_states:
414
+ all_hidden_states = all_hidden_states + (hidden_states,)
415
+
416
+ if getattr(self.config, "gradient_checkpointing", False):
417
+
418
+ def create_custom_forward(module):
419
+ def custom_forward(*inputs):
420
+ return module(*inputs, output_attentions)
421
+
422
+ return custom_forward
423
+
424
+ layer_outputs = torch.utils.checkpoint.checkpoint(
425
+ create_custom_forward(layer_module),
426
+ hidden_states,
427
+ attention_mask,
428
+ head_mask[i],
429
+ encoder_hidden_states,
430
+ encoder_attention_mask,
431
+ )
432
+ else:
433
+ layer_outputs = layer_module(
434
+ hidden_states,
435
+ attention_mask,
436
+ head_mask[i],
437
+ encoder_hidden_states,
438
+ encoder_attention_mask,
439
+ output_attentions,
440
+ )
441
+ hidden_states = layer_outputs[0]
442
+
443
+ if output_attentions:
444
+ all_attentions = all_attentions + (layer_outputs[1],)
445
+
446
+ # Add last layer
447
+ if output_hidden_states:
448
+ all_hidden_states = all_hidden_states + (hidden_states,)
449
+
450
+ outputs = (hidden_states,)
451
+ if output_hidden_states:
452
+ outputs = outputs + (all_hidden_states,)
453
+ if output_attentions:
454
+ outputs = outputs + (all_attentions,)
455
+ return outputs # last-layer hidden state, (all hidden states), (all attentions)
456
+
457
+
458
+ class BertPooler(nn.Module):
459
+ def __init__(self, config):
460
+ super().__init__()
461
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
462
+ self.activation = nn.Tanh()
463
+
464
+ def forward(self, hidden_states):
465
+ # We "pool" the model by simply taking the hidden state corresponding
466
+ # to the first token.
467
+ first_token_tensor = hidden_states[:, 0]
468
+ pooled_output = self.dense(first_token_tensor)
469
+ pooled_output = self.activation(pooled_output)
470
+ return pooled_output
471
+
472
+
473
+ class BertPredictionHeadTransform(nn.Module):
474
+ def __init__(self, config):
475
+ super().__init__()
476
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
477
+ if isinstance(config.hidden_act, str):
478
+ self.transform_act_fn = ACT2FN[config.hidden_act]
479
+ else:
480
+ self.transform_act_fn = config.hidden_act
481
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
482
+
483
+ def forward(self, hidden_states):
484
+ hidden_states = self.dense(hidden_states)
485
+ hidden_states = self.transform_act_fn(hidden_states)
486
+ hidden_states = self.LayerNorm(hidden_states)
487
+ return hidden_states
488
+
489
+
490
+ class BertLMPredictionHead(nn.Module):
491
+ def __init__(self, config):
492
+ super().__init__()
493
+ self.transform = BertPredictionHeadTransform(config)
494
+
495
+ # The output weights are the same as the input embeddings, but there is
496
+ # an output-only bias for each token.
497
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
498
+
499
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
500
+
501
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
502
+ self.decoder.bias = self.bias
503
+
504
+ def forward(self, hidden_states):
505
+ hidden_states = self.transform(hidden_states)
506
+ hidden_states = self.decoder(hidden_states)
507
+ return hidden_states
508
+
509
+
510
+ class BertOnlyMLMHead(nn.Module):
511
+ def __init__(self, config):
512
+ super().__init__()
513
+ self.predictions = BertLMPredictionHead(config)
514
+
515
+ def forward(self, sequence_output):
516
+ prediction_scores = self.predictions(sequence_output)
517
+ return prediction_scores
518
+
519
+
520
+ class BertOnlyNSPHead(nn.Module):
521
+ def __init__(self, config):
522
+ super().__init__()
523
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
524
+
525
+ def forward(self, pooled_output):
526
+ seq_relationship_score = self.seq_relationship(pooled_output)
527
+ return seq_relationship_score
528
+
529
+
530
+ class BertPreTrainingHeads(nn.Module):
531
+ def __init__(self, config):
532
+ super().__init__()
533
+ self.predictions = BertLMPredictionHead(config)
534
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
535
+
536
+ def forward(self, sequence_output, pooled_output):
537
+ prediction_scores = self.predictions(sequence_output)
538
+ seq_relationship_score = self.seq_relationship(pooled_output)
539
+ return prediction_scores, seq_relationship_score
540
+
541
+
542
+ class BertPreTrainedModel(PreTrainedModel):
543
+ """ An abstract class to handle weights initialization and
544
+ a simple interface for downloading and loading pretrained models.
545
+ """
546
+
547
+ config_class = BertConfig
548
+ load_tf_weights = load_tf_weights_in_bert
549
+ base_model_prefix = "bert"
550
+
551
+ def _init_weights(self, module):
552
+ """ Initialize the weights """
553
+ if isinstance(module, (nn.Linear, nn.Embedding)):
554
+ # Slightly different from the TF version which uses truncated_normal for initialization
555
+ # cf https://github.com/pytorch/pytorch/pull/5617
556
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
557
+ elif isinstance(module, BertLayerNorm):
558
+ module.bias.data.zero_()
559
+ module.weight.data.fill_(1.0)
560
+ if isinstance(module, nn.Linear) and module.bias is not None:
561
+ module.bias.data.zero_()
562
+
563
+
564
+ BERT_START_DOCSTRING = r"""
565
+ This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
566
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
567
+ usage and behavior.
568
+
569
+ Parameters:
570
+ config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
571
+ Initializing with a config file does not load the weights associated with the model, only the configuration.
572
+ Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
573
+ """
574
+
575
+ BERT_INPUTS_DOCSTRING = r"""
576
+ Args:
577
+ input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`):
578
+ Indices of input sequence tokens in the vocabulary.
579
+
580
+ Indices can be obtained using :class:`transformers.BertTokenizer`.
581
+ See :func:`transformers.PreTrainedTokenizer.encode` and
582
+ :func:`transformers.PreTrainedTokenizer.__call__` for details.
583
+
584
+ `What are input IDs? <../glossary.html#input-ids>`__
585
+ attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
586
+ Mask to avoid performing attention on padding token indices.
587
+ Mask values selected in ``[0, 1]``:
588
+ ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
589
+
590
+ `What are attention masks? <../glossary.html#attention-mask>`__
591
+ token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
592
+ Segment token indices to indicate first and second portions of the inputs.
593
+ Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
594
+ corresponds to a `sentence B` token
595
+
596
+ `What are token type IDs? <../glossary.html#token-type-ids>`_
597
+ position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
598
+ Indices of positions of each input sequence tokens in the position embeddings.
599
+ Selected in the range ``[0, config.max_position_embeddings - 1]``.
600
+
601
+ `What are position IDs? <../glossary.html#position-ids>`_
602
+ head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
603
+ Mask to nullify selected heads of the self-attention modules.
604
+ Mask values selected in ``[0, 1]``:
605
+ :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
606
+ inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
607
+ Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
608
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
609
+ than the model's internal embedding lookup matrix.
610
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
611
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
612
+ if the model is configured as a decoder.
613
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
614
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask
615
+ is used in the cross-attention if the model is configured as a decoder.
616
+ Mask values selected in ``[0, 1]``:
617
+ ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
618
+ output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
619
+ If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
620
+ """
621
+
622
+
623
+ @add_start_docstrings(
624
+ "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
625
+ BERT_START_DOCSTRING,
626
+ )
627
+ class BertModel(BertPreTrainedModel):
628
+ """
629
+
630
+ The model can behave as an encoder (with only self-attention) as well
631
+ as a decoder, in which case a layer of cross-attention is added between
632
+ the self-attention layers, following the architecture described in `Attention is all you need`_ by Ashish Vaswani,
633
+ Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
634
+
635
+ To behave as a decoder the model needs to be initialized with the
636
+ :obj:`is_decoder` argument of the configuration set to :obj:`True`; an
637
+ :obj:`encoder_hidden_states` is expected as an input to the forward pass.
638
+
639
+ .. _`Attention is all you need`:
640
+ https://arxiv.org/abs/1706.03762
641
+
642
+ """
643
+
644
+ def __init__(self, config):
645
+ super().__init__(config)
646
+ self.config = config
647
+
648
+ self.embeddings = BertEmbeddings(config)
649
+ self.encoder = BertEncoder(config)
650
+ self.pooler = BertPooler(config)
651
+
652
+ self.init_weights()
653
+
654
+ def get_input_embeddings(self):
655
+ return self.embeddings.word_embeddings
656
+
657
+ def set_input_embeddings(self, value):
658
+ self.embeddings.word_embeddings = value
659
+
660
+ def _prune_heads(self, heads_to_prune):
661
+ """ Prunes heads of the model.
662
+ heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
663
+ See base class PreTrainedModel
664
+ """
665
+ for layer, heads in heads_to_prune.items():
666
+ self.encoder.layer[layer].attention.prune_heads(heads)
667
+
668
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
669
+ @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
670
+ def forward(
671
+ self,
672
+ input_ids=None,
673
+ attention_mask=None,
674
+ token_type_ids=None,
675
+ position_ids=None,
676
+ head_mask=None,
677
+ inputs_embeds=None,
678
+ encoder_hidden_states=None,
679
+ encoder_attention_mask=None,
680
+ output_attentions=None,
681
+ output_hidden_states=None,
682
+ ):
683
+ r"""
684
+ Return:
685
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
686
+ last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
687
+ Sequence of hidden-states at the output of the last layer of the model.
688
+ pooler_output (:obj:`torch.FloatTensor`: of shape :obj:`(batch_size, hidden_size)`):
689
+ Last layer hidden-state of the first token of the sequence (classification token)
690
+ further processed by a Linear layer and a Tanh activation function. The Linear
691
+ layer weights are trained from the next sentence prediction (classification)
692
+ objective during pre-training.
693
+
694
+ This output is usually *not* a good summary
695
+ of the semantic content of the input; you're often better off averaging or pooling
696
+ the sequence of hidden-states for the whole input sequence.
697
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
698
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
699
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
700
+
701
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
702
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
703
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
704
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
705
+
706
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
707
+ heads.
708
+ """
709
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
710
+ output_hidden_states = (
711
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
712
+ )
713
+
714
+ if input_ids is not None and inputs_embeds is not None:
715
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
716
+ elif input_ids is not None:
717
+ input_shape = input_ids.size()
718
+ elif inputs_embeds is not None:
719
+ input_shape = inputs_embeds.size()[:-1]
720
+ else:
721
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
722
+
723
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
724
+
725
+ if attention_mask is None:
726
+ attention_mask = torch.ones(input_shape, device=device)
727
+ if token_type_ids is None:
728
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
729
+
730
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
731
+ # ourselves in which case we just need to make it broadcastable to all heads.
732
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
733
+
734
+ # If a 2D or 3D attention mask is provided for the cross-attention
735
+ # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
736
+ if self.config.is_decoder and encoder_hidden_states is not None:
737
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
738
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
739
+ if encoder_attention_mask is None:
740
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
741
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
742
+ else:
743
+ encoder_extended_attention_mask = None
744
+
745
+ # Prepare head mask if needed
746
+ # 1.0 in head_mask indicates we keep the head
747
+ # attention_probs has shape bsz x n_heads x N x N
748
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
749
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
750
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
751
+
752
+ embedding_output = self.embeddings(
753
+ input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
754
+ )
755
+ encoder_outputs = self.encoder(
756
+ embedding_output,
757
+ attention_mask=extended_attention_mask,
758
+ head_mask=head_mask,
759
+ encoder_hidden_states=encoder_hidden_states,
760
+ encoder_attention_mask=encoder_extended_attention_mask,
761
+ output_attentions=output_attentions,
762
+ output_hidden_states=output_hidden_states,
763
+ )
764
+ sequence_output = encoder_outputs[0]
765
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
766
+
767
+ outputs = (sequence_output, pooled_output,) + encoder_outputs[
768
+ 1:
769
+ ] # add hidden_states and attentions if they are here
770
+ return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
771
+
772
+
773
+ @add_start_docstrings(
774
+ """Bert Model with two heads on top as done during the pre-training: a `masked language modeling` head and
775
+ a `next sentence prediction (classification)` head. """,
776
+ BERT_START_DOCSTRING,
777
+ )
778
+ class BertForPreTraining(BertPreTrainedModel):
779
+ def __init__(self, config):
780
+ super().__init__(config)
781
+
782
+ self.bert = BertModel(config)
783
+ self.cls = BertPreTrainingHeads(config)
784
+
785
+ self.init_weights()
786
+
787
+ def get_output_embeddings(self):
788
+ return self.cls.predictions.decoder
789
+
790
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
791
+ def forward(
792
+ self,
793
+ input_ids=None,
794
+ attention_mask=None,
795
+ token_type_ids=None,
796
+ position_ids=None,
797
+ head_mask=None,
798
+ inputs_embeds=None,
799
+ labels=None,
800
+ next_sentence_label=None,
801
+ output_attentions=None,
802
+ output_hidden_states=None,
803
+ **kwargs
804
+ ):
805
+ r"""
806
+ labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`):
807
+ Labels for computing the masked language modeling loss.
808
+ Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
809
+ Tokens with indices set to ``-100`` are ignored (masked); the loss is only computed for the tokens with labels
810
+ in ``[0, ..., config.vocab_size]``
811
+ next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`, defaults to :obj:`None`):
812
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see :obj:`input_ids` docstring)
813
+ Indices should be in ``[0, 1]``.
814
+ ``0`` indicates sequence B is a continuation of sequence A,
815
+ ``1`` indicates sequence B is a random sequence.
816
+ kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
817
+ Used to hide legacy arguments that have been deprecated.
818
+
819
+ Returns:
820
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
821
+ loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
822
+ Total loss as the sum of the masked language modeling loss and the next sequence prediction (classification) loss.
823
+ prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
824
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
825
+ seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
826
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False
827
+ continuation before SoftMax).
828
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
829
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
830
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
831
+
832
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
833
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
834
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
835
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
836
+
837
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
838
+ heads.
839
+
840
+
841
+ Examples::
842
+
843
+ >>> from transformers import BertTokenizer, BertForPreTraining
844
+ >>> import torch
845
+
846
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
847
+ >>> model = BertForPreTraining.from_pretrained('bert-base-uncased')
848
+
849
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
850
+ >>> outputs = model(**inputs)
851
+
852
+ >>> prediction_scores, seq_relationship_scores = outputs[:2]
853
+
854
+ """
855
+ if "masked_lm_labels" in kwargs:
856
+ warnings.warn(
857
+ "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
858
+ DeprecationWarning,
859
+ )
860
+ labels = kwargs.pop("masked_lm_labels")
861
+ assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
862
+
863
+ outputs = self.bert(
864
+ input_ids,
865
+ attention_mask=attention_mask,
866
+ token_type_ids=token_type_ids,
867
+ position_ids=position_ids,
868
+ head_mask=head_mask,
869
+ inputs_embeds=inputs_embeds,
870
+ output_attentions=output_attentions,
871
+ output_hidden_states=output_hidden_states,
872
+ )
873
+
874
+ sequence_output, pooled_output = outputs[:2]
875
+ prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
876
+
877
+ outputs = (prediction_scores, seq_relationship_score,) + outputs[
878
+ 2:
879
+ ] # add hidden states and attention if they are here
880
+
881
+ if labels is not None and next_sentence_label is not None:
882
+ loss_fct = CrossEntropyLoss()
883
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
884
+ next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
885
+ total_loss = masked_lm_loss + next_sentence_loss
886
+ outputs = (total_loss,) + outputs
887
+
888
+ return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)
889
+
890
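A minimal sketch of the joint objective computed above, using random stand-in tensors in place of real model outputs: the total pre-training loss is simply the sum of the masked-LM and next-sentence cross-entropies.

import torch
from torch.nn import CrossEntropyLoss

vocab_size = 10
prediction_scores = torch.randn(2, 6, vocab_size)        # (batch, seq_len, vocab)
mlm_labels = torch.randint(0, vocab_size, (2, 6))
seq_relationship_score = torch.randn(2, 2)               # (batch, 2)
next_sentence_label = torch.tensor([0, 1])

loss_fct = CrossEntropyLoss()
masked_lm_loss = loss_fct(prediction_scores.view(-1, vocab_size), mlm_labels.view(-1))
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
total_loss = masked_lm_loss + next_sentence_loss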
+
891
+ @add_start_docstrings(
892
+ """Bert Model with a `language modeling` head on top for CLM fine-tuning. """, BERT_START_DOCSTRING
893
+ )
894
+ class BertLMHeadModel(BertPreTrainedModel):
895
+ def __init__(self, config):
896
+ super().__init__(config)
897
+ assert config.is_decoder, "If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True`."
898
+
899
+ self.bert = BertModel(config)
900
+ self.cls = BertOnlyMLMHead(config)
901
+
902
+ self.init_weights()
903
+
904
+ def get_output_embeddings(self):
905
+ return self.cls.predictions.decoder
906
+
907
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
908
+ def forward(
909
+ self,
910
+ input_ids=None,
911
+ attention_mask=None,
912
+ token_type_ids=None,
913
+ position_ids=None,
914
+ head_mask=None,
915
+ inputs_embeds=None,
916
+ labels=None,
917
+ encoder_hidden_states=None,
918
+ encoder_attention_mask=None,
919
+ output_attentions=None,
920
+ output_hidden_states=None,
921
+ **kwargs
922
+ ):
923
+ r"""
924
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
925
+ Labels for computing the left-to-right language modeling loss (next word prediction).
926
+ Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
927
+ Tokens with indices set to ``-100`` are ignored (masked); the loss is only computed for the tokens with labels
928
+ in ``[0, ..., config.vocab_size]``
929
+ kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
930
+ Used to hide legacy arguments that have been deprecated.
931
+
932
+ Returns:
933
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
934
+ ltr_lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
935
+ Next token prediction loss.
936
+ prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
937
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
938
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
939
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
940
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
941
+
942
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
943
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
944
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
945
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
946
+
947
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
948
+ heads.
949
+
950
+ Example::
951
+
952
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
953
+ >>> import torch
954
+
955
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
956
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
957
+ >>> config.is_decoder = True
958
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
959
+
960
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
961
+ >>> outputs = model(**inputs)
962
+
963
+ >>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
964
+ """
965
+
966
+ outputs = self.bert(
967
+ input_ids,
968
+ attention_mask=attention_mask,
969
+ token_type_ids=token_type_ids,
970
+ position_ids=position_ids,
971
+ head_mask=head_mask,
972
+ inputs_embeds=inputs_embeds,
973
+ encoder_hidden_states=encoder_hidden_states,
974
+ encoder_attention_mask=encoder_attention_mask,
975
+ output_attentions=output_attentions,
976
+ output_hidden_states=output_hidden_states,
977
+ )
978
+
979
+ sequence_output = outputs[0]
980
+ prediction_scores = self.cls(sequence_output)
981
+
982
+ outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
983
+
984
+ if labels is not None:
985
+ # we are doing next-token prediction; shift prediction scores and input ids by one
986
+ prediction_scores = prediction_scores[:, :-1, :].contiguous()
987
+ labels = labels[:, 1:].contiguous()
988
+ loss_fct = CrossEntropyLoss()
989
+ ltr_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
990
+ outputs = (ltr_lm_loss,) + outputs
991
+
992
+ return outputs # (ltr_lm_loss), prediction_scores, (hidden_states), (attentions)
993
+
994
+ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
995
+ input_shape = input_ids.shape
996
+
997
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
998
+ if attention_mask is None:
999
+ attention_mask = input_ids.new_ones(input_shape)
1000
+
1001
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
1002
+
1003
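A small sketch (with stand-in tensors) of the label shift performed in the forward above: the score at position i is matched against the token at position i + 1.

import torch
from torch.nn import CrossEntropyLoss

vocab_size = 10
prediction_scores = torch.randn(2, 5, vocab_size)        # (batch, seq_len, vocab)
labels = torch.randint(0, vocab_size, (2, 5))            # usually the input ids themselves

shifted_scores = prediction_scores[:, :-1, :].contiguous()
shifted_labels = labels[:, 1:].contiguous()
ltr_lm_loss = CrossEntropyLoss()(shifted_scores.view(-1, vocab_size), shifted_labels.view(-1))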
+
1004
+ @add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
1005
+ class BertForMaskedLM(BertPreTrainedModel):
1006
+ def __init__(self, config):
1007
+ super().__init__(config)
1008
+ assert (
1009
+ not config.is_decoder
1010
+ ), "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for bi-directional self-attention."
1011
+
1012
+ self.bert = BertModel(config)
1013
+ self.cls = BertOnlyMLMHead(config)
1014
+
1015
+ self.init_weights()
1016
+
1017
+ def get_output_embeddings(self):
1018
+ return self.cls.predictions.decoder
1019
+
1020
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
1021
+ @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
1022
+ def forward(
1023
+ self,
1024
+ input_ids=None,
1025
+ attention_mask=None,
1026
+ token_type_ids=None,
1027
+ position_ids=None,
1028
+ head_mask=None,
1029
+ inputs_embeds=None,
1030
+ labels=None,
1031
+ encoder_hidden_states=None,
1032
+ encoder_attention_mask=None,
1033
+ output_attentions=None,
1034
+ output_hidden_states=None,
1035
+ **kwargs
1036
+ ):
1037
+ r"""
1038
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
1039
+ Labels for computing the masked language modeling loss.
1040
+ Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
1041
+ Tokens with indices set to ``-100`` are ignored (masked); the loss is only computed for the tokens with labels
1042
+ in ``[0, ..., config.vocab_size]``
1043
+ kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
1044
+ Used to hide legacy arguments that have been deprecated.
1045
+
1046
+ Returns:
1047
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
1048
+ masked_lm_loss (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
1049
+ Masked language modeling loss.
1050
+ prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
1051
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
1052
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
1053
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
1054
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
1055
+
1056
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
1057
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
1058
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
1059
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
1060
+
1061
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
1062
+ heads.
1063
+ """
1064
+ if "masked_lm_labels" in kwargs:
1065
+ warnings.warn(
1066
+ "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
1067
+ DeprecationWarning,
1068
+ )
1069
+ labels = kwargs.pop("masked_lm_labels")
1070
+ assert "lm_labels" not in kwargs, "Use `BertWithLMHead` for autoregressive language modeling task."
1071
+ assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
1072
+
1073
+ outputs = self.bert(
1074
+ input_ids,
1075
+ attention_mask=attention_mask,
1076
+ token_type_ids=token_type_ids,
1077
+ position_ids=position_ids,
1078
+ head_mask=head_mask,
1079
+ inputs_embeds=inputs_embeds,
1080
+ encoder_hidden_states=encoder_hidden_states,
1081
+ encoder_attention_mask=encoder_attention_mask,
1082
+ output_attentions=output_attentions,
1083
+ output_hidden_states=output_hidden_states,
1084
+ )
1085
+
1086
+ sequence_output = outputs[0]
1087
+ prediction_scores = self.cls(sequence_output)
1088
+
1089
+ outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
1090
+
1091
+ if labels is not None:
1092
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1093
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1094
+ outputs = (masked_lm_loss,) + outputs
1095
+
1096
+ return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
1097
+
1098
+ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
1099
+ input_shape = input_ids.shape
1100
+ effective_batch_size = input_shape[0]
1101
+
1102
+ # add a dummy token
1103
+ assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
1104
+ attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
1105
+ dummy_token = torch.full(
1106
+ (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
1107
+ )
1108
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
1109
+
1110
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
1111
+
1112
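A sketch of how `labels` for the masked-LM loss above are typically built: every non-masked position is set to -100 so only the masked tokens are scored. The token ids below are assumed illustrative values, not taken from the file.

import torch

mask_token_id = 103                                       # assumed [MASK] id
input_ids = torch.tensor([[101, 7592, mask_token_id, 2003, 102]])
labels = torch.full_like(input_ids, -100)                 # ignore everything by default
labels[input_ids == mask_token_id] = 2166                 # assumed gold id for the masked slot
# loss, prediction_scores = model(input_ids, labels=labels)[:2]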
+
1113
+ @add_start_docstrings(
1114
+ """Bert Model with a `next sentence prediction (classification)` head on top. """, BERT_START_DOCSTRING,
1115
+ )
1116
+ class BertForNextSentencePrediction(BertPreTrainedModel):
1117
+ def __init__(self, config):
1118
+ super().__init__(config)
1119
+
1120
+ self.bert = BertModel(config)
1121
+ self.cls = BertOnlyNSPHead(config)
1122
+
1123
+ self.init_weights()
1124
+
1125
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
1126
+ def forward(
1127
+ self,
1128
+ input_ids=None,
1129
+ attention_mask=None,
1130
+ token_type_ids=None,
1131
+ position_ids=None,
1132
+ head_mask=None,
1133
+ inputs_embeds=None,
1134
+ next_sentence_label=None,
1135
+ output_attentions=None,
1136
+ output_hidden_states=None,
1137
+ ):
1138
+ r"""
1139
+ next_sentence_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
1140
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring)
1141
+ Indices should be in ``[0, 1]``.
1142
+ ``0`` indicates sequence B is a continuation of sequence A,
1143
+ ``1`` indicates sequence B is a random sequence.
1144
+
1145
+ Returns:
1146
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
1147
+ loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`next_sentence_label` is provided):
1148
+ Next sequence prediction (classification) loss.
1149
+ seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
1150
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
1151
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
1152
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
1153
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
1154
+
1155
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
1156
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
1157
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
1158
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
1159
+
1160
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
1161
+ heads.
1162
+
1163
+ Examples::
1164
+
1165
+ >>> from transformers import BertTokenizer, BertForNextSentencePrediction
1166
+ >>> import torch
1167
+
1168
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
1169
+ >>> model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
1170
+
1171
+ >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
1172
+ >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
1173
+ >>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt')
1174
+
1175
+ >>> loss, logits = model(**encoding, next_sentence_label=torch.LongTensor([1]))
1176
+ >>> assert logits[0, 0] < logits[0, 1] # next sentence was random
1177
+ """
1178
+
1179
+ outputs = self.bert(
1180
+ input_ids,
1181
+ attention_mask=attention_mask,
1182
+ token_type_ids=token_type_ids,
1183
+ position_ids=position_ids,
1184
+ head_mask=head_mask,
1185
+ inputs_embeds=inputs_embeds,
1186
+ output_attentions=output_attentions,
1187
+ output_hidden_states=output_hidden_states,
1188
+ )
1189
+
1190
+ pooled_output = outputs[1]
1191
+
1192
+ seq_relationship_score = self.cls(pooled_output)
1193
+
1194
+ outputs = (seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here
1195
+ if next_sentence_label is not None:
1196
+ loss_fct = CrossEntropyLoss()
1197
+ next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
1198
+ outputs = (next_sentence_loss,) + outputs
1199
+
1200
+ return outputs # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)
1201
+
1202
+
1203
+ @add_start_docstrings(
1204
+ """Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
1205
+ the pooled output) e.g. for GLUE tasks. """,
1206
+ BERT_START_DOCSTRING,
1207
+ )
1208
+ class BertForSequenceClassification(BertPreTrainedModel):
1209
+ def __init__(self, config):
1210
+ super().__init__(config)
1211
+ self.num_labels = config.num_labels
1212
+
1213
+ self.bert = BertModel(config)
1214
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1215
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1216
+
1217
+ self.init_weights()
1218
+
1219
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
1220
+ @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
1221
+ def forward(
1222
+ self,
1223
+ input_ids=None,
1224
+ attention_mask=None,
1225
+ token_type_ids=None,
1226
+ position_ids=None,
1227
+ head_mask=None,
1228
+ inputs_embeds=None,
1229
+ labels=None,
1230
+ output_attentions=None,
1231
+ output_hidden_states=None,
1232
+ ):
1233
+ r"""
1234
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
1235
+ Labels for computing the sequence classification/regression loss.
1236
+ Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
1237
+ If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
1238
+ If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1239
+
1240
+ Returns:
1241
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
1242
+ loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
1243
+ Classification (or regression if config.num_labels==1) loss.
1244
+ logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
1245
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
1246
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
1247
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
1248
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
1249
+
1250
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
1251
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
1252
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
1253
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
1254
+
1255
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
1256
+ heads.
1257
+ """
1258
+
1259
+ outputs = self.bert(
1260
+ input_ids,
1261
+ attention_mask=attention_mask,
1262
+ token_type_ids=token_type_ids,
1263
+ position_ids=position_ids,
1264
+ head_mask=head_mask,
1265
+ inputs_embeds=inputs_embeds,
1266
+ output_attentions=output_attentions,
1267
+ output_hidden_states=output_hidden_states,
1268
+ )
1269
+
1270
+ pooled_output = outputs[1]
1271
+
1272
+ pooled_output = self.dropout(pooled_output)
1273
+ logits = self.classifier(pooled_output)
1274
+
1275
+ outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
1276
+
1277
+ if labels is not None:
1278
+ if self.num_labels == 1:
1279
+ # We are doing regression
1280
+ loss_fct = MSELoss()
1281
+ loss = loss_fct(logits.view(-1), labels.view(-1))
1282
+ else:
1283
+ loss_fct = CrossEntropyLoss()
1284
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1285
+ outputs = (loss,) + outputs
1286
+
1287
+ return outputs # (loss), logits, (hidden_states), (attentions)
1288
+
1289
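A sketch of the loss branching above with stand-in logits: `num_labels == 1` is treated as regression (MSE), anything else as classification (cross-entropy).

import torch
from torch.nn import CrossEntropyLoss, MSELoss

logits = torch.randn(4, 3)                                # (batch, num_labels) with num_labels = 3
labels = torch.tensor([0, 2, 1, 1])
cls_loss = CrossEntropyLoss()(logits.view(-1, 3), labels.view(-1))

reg_logits = torch.randn(4, 1)                            # num_labels == 1 -> regression
reg_labels = torch.randn(4)
reg_loss = MSELoss()(reg_logits.view(-1), reg_labels.view(-1))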
+
1290
+ @add_start_docstrings(
1291
+ """Bert Model with a multiple choice classification head on top (a linear layer on top of
1292
+ the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
1293
+ BERT_START_DOCSTRING,
1294
+ )
1295
+ class BertForMultipleChoice(BertPreTrainedModel):
1296
+ def __init__(self, config):
1297
+ super().__init__(config)
1298
+
1299
+ self.bert = BertModel(config)
1300
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1301
+ self.classifier = nn.Linear(config.hidden_size, 1)
1302
+
1303
+ self.init_weights()
1304
+
1305
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
1306
+ @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
1307
+ def forward(
1308
+ self,
1309
+ input_ids=None,
1310
+ attention_mask=None,
1311
+ token_type_ids=None,
1312
+ position_ids=None,
1313
+ head_mask=None,
1314
+ inputs_embeds=None,
1315
+ labels=None,
1316
+ output_attentions=None,
1317
+ output_hidden_states=None,
1318
+ ):
1319
+ r"""
1320
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
1321
+ Labels for computing the multiple choice classification loss.
1322
+ Indices should be in ``[0, ..., num_choices-1]`` where `num_choices` is the size of the second dimension
1323
+ of the input tensors. (see `input_ids` above)
1324
+
1325
+ Returns:
1326
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
1327
+ loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided):
1328
+ Classification loss.
1329
+ classification_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
1330
+ `num_choices` is the second dimension of the input tensors. (see `input_ids` above).
1331
+
1332
+ Classification scores (before SoftMax).
1333
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
1334
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
1335
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
1336
+
1337
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
1338
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
1339
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
1340
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
1341
+
1342
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
1343
+ heads.
1344
+ """
1345
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
1346
+
1347
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
1348
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1349
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
1350
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
1351
+ inputs_embeds = (
1352
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
1353
+ if inputs_embeds is not None
1354
+ else None
1355
+ )
1356
+
1357
+ outputs = self.bert(
1358
+ input_ids,
1359
+ attention_mask=attention_mask,
1360
+ token_type_ids=token_type_ids,
1361
+ position_ids=position_ids,
1362
+ head_mask=head_mask,
1363
+ inputs_embeds=inputs_embeds,
1364
+ output_attentions=output_attentions,
1365
+ output_hidden_states=output_hidden_states,
1366
+ )
1367
+
1368
+ pooled_output = outputs[1]
1369
+
1370
+ pooled_output = self.dropout(pooled_output)
1371
+ logits = self.classifier(pooled_output)
1372
+ reshaped_logits = logits.view(-1, num_choices)
1373
+
1374
+ outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
1375
+
1376
+ if labels is not None:
1377
+ loss_fct = CrossEntropyLoss()
1378
+ loss = loss_fct(reshaped_logits, labels)
1379
+ outputs = (loss,) + outputs
1380
+
1381
+ return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
1382
+
1383
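A sketch of the reshaping above with stand-in tensors: the choice dimension is folded into the batch for a single BERT pass, and the per-choice scores are folded back afterwards.

import torch

batch, num_choices, seq_len, hidden = 2, 4, 16, 8
input_ids = torch.randint(0, 100, (batch, num_choices, seq_len))
flat_input_ids = input_ids.view(-1, input_ids.size(-1))   # (batch * num_choices, seq_len)

pooled_output = torch.randn(batch * num_choices, hidden)   # stand-in for BERT's pooled output
classifier = torch.nn.Linear(hidden, 1)
logits = classifier(pooled_output)                          # (batch * num_choices, 1)
reshaped_logits = logits.view(-1, num_choices)              # (batch, num_choices)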
+
1384
+ @add_start_docstrings(
1385
+ """Bert Model with a token classification head on top (a linear layer on top of
1386
+ the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
1387
+ BERT_START_DOCSTRING,
1388
+ )
1389
+ class BertForTokenClassification(BertPreTrainedModel):
1390
+ def __init__(self, config):
1391
+ super().__init__(config)
1392
+ self.num_labels = config.num_labels
1393
+
1394
+ self.bert = BertModel(config)
1395
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1396
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1397
+
1398
+ self.init_weights()
1399
+
1400
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
1401
+ @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
1402
+ def forward(
1403
+ self,
1404
+ input_ids=None,
1405
+ attention_mask=None,
1406
+ token_type_ids=None,
1407
+ position_ids=None,
1408
+ head_mask=None,
1409
+ inputs_embeds=None,
1410
+ labels=None,
1411
+ output_attentions=None,
1412
+ output_hidden_states=None,
1413
+ ):
1414
+ r"""
1415
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
1416
+ Labels for computing the token classification loss.
1417
+ Indices should be in ``[0, ..., config.num_labels - 1]``.
1418
+
1419
+ Returns:
1420
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
1421
+ loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) :
1422
+ Classification loss.
1423
+ scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`)
1424
+ Classification scores (before SoftMax).
1425
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
1426
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
1427
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
1428
+
1429
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
1430
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
1431
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
1432
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
1433
+
1434
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
1435
+ heads.
1436
+ """
1437
+
1438
+ outputs = self.bert(
1439
+ input_ids,
1440
+ attention_mask=attention_mask,
1441
+ token_type_ids=token_type_ids,
1442
+ position_ids=position_ids,
1443
+ head_mask=head_mask,
1444
+ inputs_embeds=inputs_embeds,
1445
+ output_attentions=output_attentions,
1446
+ output_hidden_states=output_hidden_states,
1447
+ )
1448
+
1449
+ sequence_output = outputs[0]
1450
+
1451
+ sequence_output = self.dropout(sequence_output)
1452
+ logits = self.classifier(sequence_output)
1453
+
1454
+ outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
1455
+ if labels is not None:
1456
+ loss_fct = CrossEntropyLoss()
1457
+ # Only keep active parts of the loss
1458
+ if attention_mask is not None:
1459
+ active_loss = attention_mask.view(-1) == 1
1460
+ active_logits = logits.view(-1, self.num_labels)
1461
+ active_labels = torch.where(
1462
+ active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
1463
+ )
1464
+ loss = loss_fct(active_logits, active_labels)
1465
+ else:
1466
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1467
+ outputs = (loss,) + outputs
1468
+
1469
+ return outputs # (loss), scores, (hidden_states), (attentions)
1470
+
1471
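A sketch of the "active loss" masking above with stand-in tensors: padded positions are mapped to the ignore index (-100) so they contribute nothing to the token-classification loss.

import torch
from torch.nn import CrossEntropyLoss

num_labels = 3
logits = torch.randn(2, 4, num_labels)
labels = torch.randint(0, num_labels, (2, 4))
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])

loss_fct = CrossEntropyLoss()
active = attention_mask.view(-1) == 1
active_labels = torch.where(active, labels.view(-1), torch.tensor(loss_fct.ignore_index))
loss = loss_fct(logits.view(-1, num_labels), active_labels)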
+
1472
+ @add_start_docstrings(
1473
+ """Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
1474
+ layers on top of the hidden-states output to compute `span start logits` and `span end logits`). """,
1475
+ BERT_START_DOCSTRING,
1476
+ )
1477
+ class BertForQuestionAnswering(BertPreTrainedModel):
1478
+ def __init__(self, config):
1479
+ super().__init__(config)
1480
+ self.num_labels = config.num_labels
1481
+
1482
+ self.bert = BertModel(config)
1483
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1484
+
1485
+ self.init_weights()
1486
+
1487
+ @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
1488
+ @add_code_sample_docstrings(tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased")
1489
+ def forward(
1490
+ self,
1491
+ input_ids=None,
1492
+ attention_mask=None,
1493
+ token_type_ids=None,
1494
+ position_ids=None,
1495
+ head_mask=None,
1496
+ inputs_embeds=None,
1497
+ start_positions=None,
1498
+ end_positions=None,
1499
+ output_attentions=None,
1500
+ output_hidden_states=None,
1501
+ ):
1502
+ r"""
1503
+ start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
1504
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1505
+ Positions are clamped to the length of the sequence (`sequence_length`).
1506
+ Positions outside of the sequence are not taken into account for computing the loss.
1507
+ end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
1508
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1509
+ Positions are clamped to the length of the sequence (`sequence_length`).
1510
+ Positions outside of the sequence are not taken into account for computing the loss.
1511
+
1512
+ Returns:
1513
+ :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
1514
+ loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
1515
+ Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
1516
+ start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
1517
+ Span-start scores (before SoftMax).
1518
+ end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
1519
+ Span-end scores (before SoftMax).
1520
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
1521
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
1522
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
1523
+
1524
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
1525
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
1526
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
1527
+ :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
1528
+
1529
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
1530
+ heads.
1531
+ """
1532
+
1533
+ outputs = self.bert(
1534
+ input_ids,
1535
+ attention_mask=attention_mask,
1536
+ token_type_ids=token_type_ids,
1537
+ position_ids=position_ids,
1538
+ head_mask=head_mask,
1539
+ inputs_embeds=inputs_embeds,
1540
+ output_attentions=output_attentions,
1541
+ output_hidden_states=output_hidden_states,
1542
+ )
1543
+
1544
+ sequence_output = outputs[0]
1545
+
1546
+ logits = self.qa_outputs(sequence_output)
1547
+ start_logits, end_logits = logits.split(1, dim=-1)
1548
+ start_logits = start_logits.squeeze(-1)
1549
+ end_logits = end_logits.squeeze(-1)
1550
+
1551
+ outputs = (start_logits, end_logits,) + outputs[2:]
1552
+ if start_positions is not None and end_positions is not None:
1553
+ # On multi-GPU, the start/end positions may carry an extra dimension; squeeze it away
1554
+ if len(start_positions.size()) > 1:
1555
+ start_positions = start_positions.squeeze(-1)
1556
+ if len(end_positions.size()) > 1:
1557
+ end_positions = end_positions.squeeze(-1)
1558
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1559
+ ignored_index = start_logits.size(1)
1560
+ start_positions.clamp_(0, ignored_index)
1561
+ end_positions.clamp_(0, ignored_index)
1562
+
1563
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1564
+ start_loss = loss_fct(start_logits, start_positions)
1565
+ end_loss = loss_fct(end_logits, end_positions)
1566
+ total_loss = (start_loss + end_loss) / 2
1567
+ outputs = (total_loss,) + outputs
1568
+
1569
+ return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
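A sketch of turning the start/end logits above into a predicted answer span (greedy argmax per side; real decoders usually also enforce start <= end and cap the span length).

import torch

start_logits = torch.randn(1, 20)                          # (batch, seq_len), stand-in values
end_logits = torch.randn(1, 20)
start_index = int(start_logits.argmax(dim=-1))
end_index = int(end_logits.argmax(dim=-1))
# answer_token_ids = input_ids[0, start_index : end_index + 1]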
CGFormer/bert/modeling_utils.py ADDED
@@ -0,0 +1,1268 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import inspect
18
+ import logging
19
+ import os
20
+ from typing import Callable, Dict, List, Optional, Tuple
21
+
22
+ import torch
23
+ from torch import Tensor, device, dtype, nn
24
+ from torch.nn import CrossEntropyLoss
25
+ from torch.nn import functional as F
26
+
27
+ from .activations import get_activation
28
+ from .configuration_utils import PretrainedConfig
29
+ from .file_utils import (
30
+ DUMMY_INPUTS,
31
+ TF2_WEIGHTS_NAME,
32
+ TF_WEIGHTS_NAME,
33
+ WEIGHTS_NAME,
34
+ cached_path,
35
+ hf_bucket_url,
36
+ is_remote_url,
37
+ )
38
+ from .generation_utils import GenerationMixin
39
+
40
+
41
+ logger = logging.getLogger(__name__)
42
+
43
+
44
+ try:
45
+ from torch.nn import Identity
46
+ except ImportError:
47
+ # Older PyTorch compatibility
48
+ class Identity(nn.Module):
49
+ r"""A placeholder identity operator that is argument-insensitive.
50
+ """
51
+
52
+ def __init__(self, *args, **kwargs):
53
+ super().__init__()
54
+
55
+ def forward(self, input):
56
+ return input
57
+
58
+
59
+ def find_pruneable_heads_and_indices(
60
+ heads: List, n_heads: int, head_size: int, already_pruned_heads: set
61
+ ) -> Tuple[set, "torch.LongTensor"]:
62
+ mask = torch.ones(n_heads, head_size)
63
+ heads = set(heads) - already_pruned_heads # Convert to set and remove already pruned heads
64
+ for head in heads:
65
+ # Compute how many pruned heads are before the head and move the index accordingly
66
+ head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
67
+ mask[head] = 0
68
+ mask = mask.view(-1).contiguous().eq(1)
69
+ index: torch.LongTensor = torch.arange(len(mask))[mask].long()
70
+ return heads, index
71
+
72
+
73
+ class ModuleUtilsMixin:
74
+ """
75
+ A few utilities for torch.nn.Modules, to be used as a mixin.
76
+ """
77
+
78
+ def num_parameters(self, only_trainable: bool = False) -> int:
79
+ """
80
+ Get number of (optionally, trainable) parameters in the module.
81
+ """
82
+ params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters()
83
+ return sum(p.numel() for p in params)
84
+
85
+ @staticmethod
86
+ def _hook_rss_memory_pre_forward(module, *args, **kwargs):
87
+ try:
88
+ import psutil
89
+ except ImportError:
90
+ raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")
91
+
92
+ process = psutil.Process(os.getpid())
93
+ mem = process.memory_info()
94
+ module.mem_rss_pre_forward = mem.rss
95
+ return None
96
+
97
+ @staticmethod
98
+ def _hook_rss_memory_post_forward(module, *args, **kwargs):
99
+ try:
100
+ import psutil
101
+ except ImportError:
102
+ raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")
103
+
104
+ process = psutil.Process(os.getpid())
105
+ mem = process.memory_info()
106
+ module.mem_rss_post_forward = mem.rss
107
+ mem_rss_diff = module.mem_rss_post_forward - module.mem_rss_pre_forward
108
+ module.mem_rss_diff = mem_rss_diff + (module.mem_rss_diff if hasattr(module, "mem_rss_diff") else 0)
109
+ return None
110
+
111
+ def add_memory_hooks(self):
112
+ """ Add a memory hook before and after each sub-module forward pass to record increase in memory consumption.
113
+ Increase in memory consumption is stored in a `mem_rss_diff` attribute for each module and can be reset to zero with `model.reset_memory_hooks_state()`
114
+ """
115
+ for module in self.modules():
116
+ module.register_forward_pre_hook(self._hook_rss_memory_pre_forward)
117
+ module.register_forward_hook(self._hook_rss_memory_post_forward)
118
+ self.reset_memory_hooks_state()
119
+
120
+ def reset_memory_hooks_state(self):
121
+ for module in self.modules():
122
+ module.mem_rss_diff = 0
123
+ module.mem_rss_post_forward = 0
124
+ module.mem_rss_pre_forward = 0
125
+
126
+ @property
127
+ def device(self) -> device:
128
+ """
129
+ Get torch.device from module, assuming that the whole module has one device.
130
+ """
131
+ try:
132
+ return next(self.parameters()).device
133
+ except StopIteration:
134
+ # For nn.DataParallel compatibility in PyTorch 1.5
135
+
136
+ def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
137
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
138
+ return tuples
139
+
140
+ gen = self._named_members(get_members_fn=find_tensor_attributes)
141
+ first_tuple = next(gen)
142
+ return first_tuple[1].device
143
+
144
+ @property
145
+ def dtype(self) -> dtype:
146
+ """
147
+ Get torch.dtype from module, assuming that the whole module has one dtype.
148
+ """
149
+ try:
150
+ return next(self.parameters()).dtype
151
+ except StopIteration:
152
+ # For nn.DataParallel compatibility in PyTorch 1.5
153
+
154
+ def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
155
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
156
+ return tuples
157
+
158
+ gen = self._named_members(get_members_fn=find_tensor_attributes)
159
+ first_tuple = next(gen)
160
+ return first_tuple[1].dtype
161
+
162
+ def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
163
+ """type: torch.Tensor -> torch.Tensor"""
164
+ if encoder_attention_mask.dim() == 3:
165
+ encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
166
+ if encoder_attention_mask.dim() == 2:
167
+ encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
168
+ # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
169
+ # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow
170
+ # /transformer/transformer_layers.py#L270
171
+ # encoder_extended_attention_mask = (encoder_extended_attention_mask ==
172
+ # encoder_extended_attention_mask.transpose(-1, -2))
173
+ encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
174
+
175
+ if self.dtype == torch.float16:
176
+ encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e4
177
+ elif self.dtype == torch.float32:
178
+ encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
179
+ else:
180
+ raise ValueError(
181
+ "{} not recognized. `dtype` should be set to either `torch.float32` or `torch.float16`".format(
182
+ self.dtype
183
+ )
184
+ )
185
+
186
+ return encoder_extended_attention_mask
187
+
188
+ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple, device: device) -> Tensor:
189
+ """Makes broadcastable attention mask and causal mask so that future and maked tokens are ignored.
190
+
191
+ Arguments:
192
+ attention_mask: torch.Tensor with 1 indicating tokens to ATTEND to
193
+ input_shape: tuple, shape of input_ids
194
+ device: torch.Device, usually self.device
195
+
196
+ Returns:
197
+ torch.Tensor with dtype of attention_mask.dtype
198
+ """
199
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
200
+ # ourselves in which case we just need to make it broadcastable to all heads.
201
+ if attention_mask.dim() == 3:
202
+ extended_attention_mask = attention_mask[:, None, :, :]
203
+ elif attention_mask.dim() == 2:
204
+ # Provided a padding mask of dimensions [batch_size, seq_length]
205
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
206
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
207
+ if self.config.is_decoder:
208
+ batch_size, seq_length = input_shape
209
+ seq_ids = torch.arange(seq_length, device=device)
210
+ causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
211
+ # causal and attention masks must have same type with pytorch version < 1.3
212
+ causal_mask = causal_mask.to(attention_mask.dtype)
213
+ extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
214
+ else:
215
+ extended_attention_mask = attention_mask[:, None, None, :]
216
+ else:
217
+ raise ValueError(
218
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
219
+ input_shape, attention_mask.shape
220
+ )
221
+ )
222
+
223
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
224
+ # masked positions, this operation will create a tensor which is 0.0 for
225
+ # positions we want to attend and -10000.0 for masked positions.
226
+ # Since we are adding it to the raw scores before the softmax, this is
227
+ # effectively the same as removing these entirely.
228
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
229
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
230
+ return extended_attention_mask
231
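A sketch of the additive mask this method produces: keep-positions become 0.0 and pad-positions become -10000.0, so the result can simply be added to the raw attention scores before the softmax.

import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0]])           # (batch, seq_len)
extended = attention_mask[:, None, None, :].to(torch.float32)   # (batch, 1, 1, seq_len)
extended = (1.0 - extended) * -10000.0
# broadcasts against attention scores of shape (batch, num_heads, seq_len, seq_len)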
+
232
+ def get_head_mask(self, head_mask: Tensor, num_hidden_layers: int, is_attention_chunked: bool = False) -> Tensor:
233
+ """
234
+ # Prepare head mask if needed
235
+ # 1.0 in head_mask indicates we keep the head
236
+ attention_probs has shape bsz x n_heads x N x N
237
+ Arguments:
238
+ head_mask: torch.Tensor or None: has shape [num_heads] or [num_hidden_layers x num_heads]
239
+ num_hidden_layers: int
240
+ Returns:
241
+ Tensor of shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
242
+ or list with [None] for each layer
243
+ """
244
+ if head_mask is not None:
245
+ head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
246
+ if is_attention_chunked is True:
247
+ head_mask = head_mask.unsqueeze(-1)
248
+ else:
249
+ head_mask = [None] * num_hidden_layers
250
+
251
+ return head_mask
252
+
253
+ def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
254
+ """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]"""
255
+ if head_mask.dim() == 1:
256
+ head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
257
+ head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
258
+ elif head_mask.dim() == 2:
259
+ head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
260
+ assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
261
+ head_mask = head_mask.to(dtype=self.dtype) # switch to float if needed + fp16 compatibility
262
+ return head_mask
263
+
264
+
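# Editor's note: an illustrative usage sketch, not part of the original file. `model` is assumed
# to be an instance of a concrete PreTrainedModel subclass with 12 layers and 12 attention heads.
#
#     head_mask = torch.ones(12)                       # shape [num_heads]: keep every head
#     mask_5d = model.get_head_mask(head_mask, num_hidden_layers=12)
#     # mask_5d.shape == (12, 1, 12, 1, 1): one slice per layer, broadcastable over the
#     # batch and both sequence dimensions of the attention probabilities.
#     model.get_head_mask(None, num_hidden_layers=12)  # -> [None] * 12 (no masking)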
265
+ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
266
+ r""" Base class for all models.
267
+
268
+ :class:`~transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
269
+ as well as a few methods common to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads.
270
+
271
+ Class attributes (overridden by derived classes):
272
+ - ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
273
+ - ``load_tf_weights``: a python ``method`` for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments:
274
+
275
+ - ``model``: an instance of the relevant subclass of :class:`~transformers.PreTrainedModel`,
276
+ - ``config``: an instance of the relevant subclass of :class:`~transformers.PretrainedConfig`,
277
+ - ``path``: a path (string) to the TensorFlow checkpoint.
278
+
279
+ - ``base_model_prefix``: a string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.
280
+ """
281
+ config_class = None
282
+ base_model_prefix = ""
283
+
284
+ @property
285
+ def dummy_inputs(self):
286
+ """ Dummy inputs to do a forward pass in the network.
287
+
288
+ Returns:
289
+ torch.Tensor with dummy inputs
290
+ """
291
+ return {"input_ids": torch.tensor(DUMMY_INPUTS)}
292
+
293
+ def __init__(self, config, *inputs, **kwargs):
294
+ super().__init__()
295
+ if not isinstance(config, PretrainedConfig):
296
+ raise ValueError(
297
+ "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
298
+ "To create a model from a pretrained model use "
299
+ "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
300
+ self.__class__.__name__, self.__class__.__name__
301
+ )
302
+ )
303
+ # Save config in model
304
+ self.config = config
305
+
306
+ @property
307
+ def base_model(self):
308
+ return getattr(self, self.base_model_prefix, self)
309
+
310
+ def get_input_embeddings(self):
311
+ """
312
+ Returns the model's input embeddings.
313
+
314
+ Returns:
315
+ :obj:`nn.Module`:
316
+ A torch module mapping vocabulary to hidden states.
317
+ """
318
+ base_model = getattr(self, self.base_model_prefix, self)
319
+ if base_model is not self:
320
+ return base_model.get_input_embeddings()
321
+ else:
322
+ raise NotImplementedError
323
+
324
+ def set_input_embeddings(self, value: nn.Module):
325
+ """
326
+ Set model's input embeddings
327
+
328
+ Args:
329
+ value (:obj:`nn.Module`):
330
+ A module mapping vocabulary to hidden states.
331
+ """
332
+ base_model = getattr(self, self.base_model_prefix, self)
333
+ if base_model is not self:
334
+ base_model.set_input_embeddings(value)
335
+ else:
336
+ raise NotImplementedError
337
+
338
+ def get_output_embeddings(self):
339
+ """
340
+ Returns the model's output embeddings.
341
+
342
+ Returns:
343
+ :obj:`nn.Module`:
344
+ A torch module mapping hidden states to vocabulary.
345
+ """
346
+ return None # Overwrite for models with output embeddings
347
+
348
+ def tie_weights(self):
349
+ """
350
+ Tie the weights between the input embeddings and the output embeddings.
351
+ If the `torchscript` flag is set in the configuration, TorchScript can't handle parameter sharing so we clone
352
+ the weights instead.
353
+ """
354
+ output_embeddings = self.get_output_embeddings()
355
+ if output_embeddings is not None:
356
+ self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
357
+
358
+ def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
359
+ """ Tie or clone module weights depending on whether we are using TorchScript or not
360
+ """
361
+ if self.config.torchscript:
362
+ output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone())
363
+ else:
364
+ output_embeddings.weight = input_embeddings.weight
365
+
366
+ if getattr(output_embeddings, "bias", None) is not None:
367
+ output_embeddings.bias.data = torch.nn.functional.pad(
368
+ output_embeddings.bias.data,
369
+ (0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0],),
370
+ "constant",
371
+ 0,
372
+ )
373
+ if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
374
+ output_embeddings.out_features = input_embeddings.num_embeddings
375
+
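# Editor's note: an illustrative sketch, not part of the original file. It assumes a subclass
# that overrides get_output_embeddings() to return its LM head (the base class returns None,
# in which case tie_weights() is a no-op) and that config.torchscript is False.
#
#     model.tie_weights()
#     assert model.get_output_embeddings().weight is model.get_input_embeddings().weight
#     # With config.torchscript == True the weight is cloned instead of shared.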
376
+ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None):
377
+ """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
378
+ Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
379
+
380
+ Arguments:
381
+
382
+ new_num_tokens: (`optional`) int:
383
+ New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
384
+ If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embeddings`` Module of the model.
385
+
386
+ Return: ``torch.nn.Embeddings``
387
+ Pointer to the input tokens Embeddings Module of the model
388
+ """
389
+ base_model = getattr(self, self.base_model_prefix, self) # get the base model if needed
390
+ model_embeds = base_model._resize_token_embeddings(new_num_tokens)
391
+ if new_num_tokens is None:
392
+ return model_embeds
393
+
394
+ # Update base model and current model config
395
+ self.config.vocab_size = new_num_tokens
396
+ base_model.vocab_size = new_num_tokens
397
+
398
+ # Tie weights again if needed
399
+ self.tie_weights()
400
+
401
+ return model_embeds
402
+
403
+ def _resize_token_embeddings(self, new_num_tokens):
404
+ old_embeddings = self.get_input_embeddings()
405
+ new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
406
+ self.set_input_embeddings(new_embeddings)
407
+ return self.get_input_embeddings()
408
+
409
+ def _get_resized_embeddings(
410
+ self, old_embeddings: torch.nn.Embedding, new_num_tokens: Optional[int] = None
411
+ ) -> torch.nn.Embedding:
412
+ """ Build a resized Embedding Module from a provided token Embedding Module.
413
+ Increasing the size will add newly initialized vectors at the end
414
+ Reducing the size will remove vectors from the end
415
+
416
+ Args:
417
+ old_embeddings: ``torch.nn.Embedding``
418
+ Old embeddings to be resized.
419
+ new_num_tokens: (`optional`) int
420
+ New number of tokens in the embedding matrix.
421
+ Increasing the size will add newly initialized vectors at the end
422
+ Reducing the size will remove vectors from the end
423
+ If not provided or None: return the provided token Embedding Module.
424
+ Return: ``torch.nn.Embedding``
425
+ Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
426
+ """
427
+ if new_num_tokens is None:
428
+ return old_embeddings
429
+
430
+ old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
431
+ if old_num_tokens == new_num_tokens:
432
+ return old_embeddings
433
+
434
+ # Build new embeddings
435
+ new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
436
+ new_embeddings.to(old_embeddings.weight.device)
437
+
438
+ # initialize all new embeddings (in particular added tokens)
439
+ self._init_weights(new_embeddings)
440
+
441
+ # Copy token embeddings from the previous weights
442
+ num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
443
+ new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
444
+
445
+ return new_embeddings
446
+
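# Editor's note: an illustrative sketch, not part of the original file, e.g. after adding new
# tokens to the tokenizer. `model` is assumed to be a concrete PreTrainedModel subclass instance.
#
#     old_size = model.config.vocab_size
#     new_embeddings = model.resize_token_embeddings(old_size + 2)
#     # new_embeddings.weight.shape == (old_size + 2, hidden_size); the first old_size rows are
#     # copied from the previous matrix, the two new rows are freshly initialized,
#     # config.vocab_size is updated and tie_weights() is called again.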
447
+ def init_weights(self):
448
+ """ Initialize weights, prune heads if needed and tie weights. """
449
+ # Initialize weights
450
+ self.apply(self._init_weights)
451
+
452
+ # Prune heads if needed
453
+ if self.config.pruned_heads:
454
+ self.prune_heads(self.config.pruned_heads)
455
+
456
+ # Tie weights if needed
457
+ self.tie_weights()
458
+
459
+ def prune_heads(self, heads_to_prune: Dict):
460
+ """ Prunes heads of the base model.
461
+
462
+ Arguments:
463
+
464
+ heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
465
+ E.g. {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
466
+ """
467
+ # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
468
+ for layer, heads in heads_to_prune.items():
469
+ union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads)
470
+ self.config.pruned_heads[layer] = list(union_heads) # Unfortunately we have to store it as list for JSON
471
+
472
+ self.base_model._prune_heads(heads_to_prune)
473
+
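# Editor's note: an illustrative sketch, not part of the original file.
#
#     model.prune_heads({1: [0, 2], 2: [2, 3]})  # drop heads 0 and 2 of layer 1, heads 2 and 3 of layer 2
#     # config.pruned_heads stores the union of everything pruned so far, so repeated calls
#     # accumulate rather than overwrite.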
474
+ def save_pretrained(self, save_directory):
475
+ """ Save a model and its configuration file to a directory, so that it
476
+ can be re-loaded using the :func:`~transformers.PreTrainedModel.from_pretrained` class method.
477
+
478
+ Arguments:
479
+ save_directory: directory to which to save.
480
+ """
481
+ if os.path.isfile(save_directory):
482
+ logger.error("Provided path ({}) should be a directory, not a file".format(save_directory))
483
+ return
484
+ os.makedirs(save_directory, exist_ok=True)
485
+
486
+ # Only save the model itself if we are using distributed training
487
+ model_to_save = self.module if hasattr(self, "module") else self
488
+
489
+ # Attach architecture to the config
490
+ model_to_save.config.architectures = [model_to_save.__class__.__name__]
491
+
492
+ # If we save using the predefined names, we can load using `from_pretrained`
493
+ output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
494
+
495
+ if getattr(self.config, "xla_device", False):
496
+ import torch_xla.core.xla_model as xm
497
+
498
+ if xm.is_master_ordinal():
499
+ # Save configuration file
500
+ model_to_save.config.save_pretrained(save_directory)
501
+ # xm.save takes care of saving only from master
502
+ xm.save(model_to_save.state_dict(), output_model_file)
503
+ else:
504
+ model_to_save.config.save_pretrained(save_directory)
505
+ torch.save(model_to_save.state_dict(), output_model_file)
506
+
507
+ logger.info("Model weights saved in {}".format(output_model_file))
508
+
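# Editor's note: an illustrative sketch, not part of the original file. `MyModel` stands for
# any concrete PreTrainedModel subclass defined elsewhere (e.g. the BertModel in modeling_bert.py).
#
#     model.save_pretrained("./my_model_directory")   # writes config.json and the weights file
#     reloaded = MyModel.from_pretrained("./my_model_directory")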
509
+ @classmethod
510
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
511
+ r"""Instantiate a pretrained pytorch model from a pre-trained model configuration.
512
+
513
+ The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated)
514
+ To train the model, you should first set it back in training mode with ``model.train()``
515
+
516
+ The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come pre-trained with the rest of the model.
517
+ It is up to you to train those weights with a downstream fine-tuning task.
518
+
519
+ The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those weights are discarded.
520
+
521
+ Parameters:
522
+ pretrained_model_name_or_path: either:
523
+ - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
524
+ - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
525
+ - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
526
+ - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
527
+ - None if you are both providing the configuration and state dictionary (resp. with keyword arguments ``config`` and ``state_dict``)
528
+
529
+ model_args: (`optional`) Sequence of positional arguments:
530
+ All remaining positional arguments will be passed to the underlying model's ``__init__`` method
531
+
532
+ config: (`optional`) one of:
533
+ - an instance of a class derived from :class:`~transformers.PretrainedConfig`, or
534
+ - a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained()`
535
+
536
+ Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
537
+ - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
538
+ - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
539
+ - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.
540
+
541
+ state_dict: (`optional`) dict:
542
+ an optional state dictionary for the model to use instead of a state dictionary loaded from the saved weights file.
543
+ This option can be used if you want to create a model from a pretrained configuration but load your own weights.
544
+ In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
545
+
546
+ cache_dir: (`optional`) string:
547
+ Path to a directory in which a downloaded pre-trained model
548
+ configuration should be cached if the standard cache should not be used.
549
+
550
+ force_download: (`optional`) boolean, default False:
551
+ Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
552
+
553
+ resume_download: (`optional`) boolean, default False:
554
+ Do not delete an incompletely received file. Attempt to resume the download if such a file exists.
555
+
556
+ proxies: (`optional`) dict, default None:
557
+ A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
558
+ The proxies are used on each request.
559
+
560
+ output_loading_info: (`optional`) boolean:
561
+ Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
562
+
563
+ kwargs: (`optional`) Remaining dictionary of keyword arguments:
564
+ Can be used to update the configuration object (after it has been loaded) and to initialize the model (e.g. ``output_attention=True``). Behaves differently depending on whether a ``config`` is provided or automatically loaded:
565
+
566
+ - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
567
+ - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.
568
+
569
+ Examples::
570
+
571
+ # For example purposes. Not runnable.
572
+ model = BertModel.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache.
573
+ model = BertModel.from_pretrained('./test/saved_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
574
+ model = BertModel.from_pretrained('bert-base-uncased', output_attention=True) # Update configuration during loading
575
+ assert model.config.output_attention == True
576
+ # Loading from a TF checkpoint file instead of a PyTorch model (slower)
577
+ config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
578
+ model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
579
+
580
+ """
581
+ config = kwargs.pop("config", None)
582
+ state_dict = kwargs.pop("state_dict", None)
583
+ cache_dir = kwargs.pop("cache_dir", None)
584
+ from_tf = kwargs.pop("from_tf", False)
585
+ force_download = kwargs.pop("force_download", False)
586
+ resume_download = kwargs.pop("resume_download", False)
587
+ proxies = kwargs.pop("proxies", None)
588
+ output_loading_info = kwargs.pop("output_loading_info", False)
589
+ local_files_only = kwargs.pop("local_files_only", False)
590
+ use_cdn = kwargs.pop("use_cdn", True)
591
+
592
+ # Load config if we don't provide a configuration
593
+ if not isinstance(config, PretrainedConfig):
594
+ config_path = config if config is not None else pretrained_model_name_or_path
595
+ config, model_kwargs = cls.config_class.from_pretrained(
596
+ config_path,
597
+ *model_args,
598
+ cache_dir=cache_dir,
599
+ return_unused_kwargs=True,
600
+ force_download=force_download,
601
+ resume_download=resume_download,
602
+ proxies=proxies,
603
+ local_files_only=local_files_only,
604
+ **kwargs,
605
+ )
606
+ else:
607
+ model_kwargs = kwargs
608
+
609
+ # Load model
610
+ if pretrained_model_name_or_path is not None:
611
+ if os.path.isdir(pretrained_model_name_or_path):
612
+ if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
613
+ # Load from a TF 1.0 checkpoint
614
+ archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
615
+ elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
616
+ # Load from a TF 2.0 checkpoint
617
+ archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
618
+ elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
619
+ # Load from a PyTorch checkpoint
620
+ archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
621
+ else:
622
+ raise EnvironmentError(
623
+ "Error no file named {} found in directory {} or `from_tf` set to False".format(
624
+ [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME + ".index"],
625
+ pretrained_model_name_or_path,
626
+ )
627
+ )
628
+ elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
629
+ archive_file = pretrained_model_name_or_path
630
+ elif os.path.isfile(pretrained_model_name_or_path + ".index"):
631
+ assert (
632
+ from_tf
633
+ ), "We found a TensorFlow checkpoint at {}, please set from_tf to True to load from this checkpoint".format(
634
+ pretrained_model_name_or_path + ".index"
635
+ )
636
+ archive_file = pretrained_model_name_or_path + ".index"
637
+ else:
638
+ archive_file = hf_bucket_url(
639
+ pretrained_model_name_or_path,
640
+ filename=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME),
641
+ use_cdn=use_cdn,
642
+ )
643
+
644
+ try:
645
+ # Load from URL or cache if already cached
646
+ resolved_archive_file = cached_path(
647
+ archive_file,
648
+ cache_dir=cache_dir,
649
+ force_download=force_download,
650
+ proxies=proxies,
651
+ resume_download=resume_download,
652
+ local_files_only=local_files_only,
653
+ )
654
+ if resolved_archive_file is None:
655
+ raise EnvironmentError
656
+ except EnvironmentError:
657
+ msg = (
658
+ f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
659
+ f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
660
+ f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME}.\n\n"
661
+ )
662
+ raise EnvironmentError(msg)
663
+
664
+ if resolved_archive_file == archive_file:
665
+ logger.info("loading weights file {}".format(archive_file))
666
+ else:
667
+ logger.info("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file))
668
+ else:
669
+ resolved_archive_file = None
670
+
671
+ # Instantiate model.
672
+ model = cls(config, *model_args, **model_kwargs)
673
+
674
+ if state_dict is None and not from_tf:
675
+ try:
676
+ state_dict = torch.load(resolved_archive_file, map_location="cpu")
677
+ except Exception:
678
+ raise OSError(
679
+ "Unable to load weights from pytorch checkpoint file. "
680
+ "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. "
681
+ )
682
+
683
+ missing_keys = []
684
+ unexpected_keys = []
685
+ error_msgs = []
686
+
687
+ if from_tf:
688
+ if resolved_archive_file.endswith(".index"):
689
+ # Load from a TensorFlow 1.X checkpoint - provided by original authors
690
+ model = cls.load_tf_weights(model, config, resolved_archive_file[:-6]) # Remove the '.index'
691
+ else:
692
+ # Load from our TensorFlow 2.0 checkpoints
693
+ try:
694
+ from transformers import load_tf2_checkpoint_in_pytorch_model
695
+
696
+ model = load_tf2_checkpoint_in_pytorch_model(model, resolved_archive_file, allow_missing_keys=True)
697
+ except ImportError:
698
+ logger.error(
699
+ "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
700
+ "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
701
+ )
702
+ raise
703
+ else:
704
+ # Convert old format to new format if needed from a PyTorch state_dict
705
+ old_keys = []
706
+ new_keys = []
707
+ for key in state_dict.keys():
708
+ new_key = None
709
+ if "gamma" in key:
710
+ new_key = key.replace("gamma", "weight")
711
+ if "beta" in key:
712
+ new_key = key.replace("beta", "bias")
713
+ if new_key:
714
+ old_keys.append(key)
715
+ new_keys.append(new_key)
716
+ for old_key, new_key in zip(old_keys, new_keys):
717
+ state_dict[new_key] = state_dict.pop(old_key)
718
+
719
+ # copy state_dict so _load_from_state_dict can modify it
720
+ metadata = getattr(state_dict, "_metadata", None)
721
+ state_dict = state_dict.copy()
722
+ if metadata is not None:
723
+ state_dict._metadata = metadata
724
+
725
+ ##############################################################################################
726
+ # Print out state_dict's contents: keys
727
+ '''
728
+ for key, _ in state_dict.items():
729
+ print(key)
730
+ '''
731
+ ##############################################################################################
732
+
733
+
734
+ # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
735
+ # so we need to apply the function recursively.
736
+ def load(module: nn.Module, prefix=""):
737
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
738
+ module._load_from_state_dict(
739
+ state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs,
740
+ )
741
+ for name, child in module._modules.items():
742
+ if child is not None:
743
+ load(child, prefix + name + ".")
744
+
745
+ # Make sure we are able to load base models as well as derived models (with heads)
746
+ start_prefix = ""
747
+ model_to_load = model
748
+ has_prefix_module = any(s.startswith(cls.base_model_prefix) for s in state_dict.keys())
749
+ if not hasattr(model, cls.base_model_prefix) and has_prefix_module:
750
+ start_prefix = cls.base_model_prefix + "."
751
+ if hasattr(model, cls.base_model_prefix) and not has_prefix_module:
752
+ model_to_load = getattr(model, cls.base_model_prefix)
753
+
754
+ load(model_to_load, prefix=start_prefix)
755
+
756
+ if model.__class__.__name__ != model_to_load.__class__.__name__:
757
+ base_model_state_dict = model_to_load.state_dict().keys()
758
+ head_model_state_dict_without_base_prefix = [
759
+ key.split(cls.base_model_prefix + ".")[-1] for key in model.state_dict().keys()
760
+ ]
761
+
762
+ missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)
763
+
764
+ if len(unexpected_keys) > 0:
765
+ logger.warning(
766
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
767
+ f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
768
+ f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
769
+ f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n"
770
+ f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
771
+ f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
772
+ )
773
+ else:
774
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
775
+ if len(missing_keys) > 0:
776
+ logger.warning(
777
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
778
+ f"and are newly initialized: {missing_keys}\n"
779
+ f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
780
+ )
781
+ else:
782
+ logger.info(
783
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
784
+ f"If your task is similar to the task the model of the checkpoint was trained on, "
785
+ f"you can already use {model.__class__.__name__} for predictions without further training."
786
+ )
787
+ if len(error_msgs) > 0:
788
+ raise RuntimeError(
789
+ "Error(s) in loading state_dict for {}:\n\t{}".format(
790
+ model.__class__.__name__, "\n\t".join(error_msgs)
791
+ )
792
+ )
793
+ model.tie_weights() # make sure token embedding weights are still tied if needed
794
+
795
+ # Set model in evaluation mode to deactivate DropOut modules by default
796
+ model.eval()
797
+
798
+ if output_loading_info:
799
+ loading_info = {
800
+ "missing_keys": missing_keys,
801
+ "unexpected_keys": unexpected_keys,
802
+ "error_msgs": error_msgs,
803
+ }
804
+ return model, loading_info
805
+
806
+ if hasattr(config, "xla_device") and config.xla_device:
807
+ import torch_xla.core.xla_model as xm
808
+
809
+ model = xm.send_cpu_data_to_device(model, xm.xla_device())
810
+ model.to(xm.xla_device())
811
+
812
+ return model
813
+
814
+
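# Editor's note: an illustrative sketch, not part of the original file. `MyModel` stands for
# any concrete PreTrainedModel subclass defined elsewhere.
#
#     model, info = MyModel.from_pretrained("./my_model_directory", output_loading_info=True)
#     # info["missing_keys"]    -> parameters that had to be newly initialized
#     # info["unexpected_keys"] -> checkpoint entries that were not used
#     # The returned model is already in eval() mode; call model.train() before fine-tuning.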
815
+ class Conv1D(nn.Module):
816
+ def __init__(self, nf, nx):
817
+ """ Conv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2)
818
+ Basically works like a Linear layer but the weights are transposed
819
+ """
820
+ super().__init__()
821
+ self.nf = nf
822
+ w = torch.empty(nx, nf)
823
+ nn.init.normal_(w, std=0.02)
824
+ self.weight = nn.Parameter(w)
825
+ self.bias = nn.Parameter(torch.zeros(nf))
826
+
827
+ def forward(self, x):
828
+ size_out = x.size()[:-1] + (self.nf,)
829
+ x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
830
+ x = x.view(*size_out)
831
+ return x
832
+
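# Editor's note: an illustrative sketch, not part of the original file.
#
#     conv = Conv1D(nf=3072, nx=768)     # behaves like nn.Linear(768, 3072) with a transposed weight
#     y = conv(torch.randn(2, 10, 768))  # y.shape == (2, 10, 3072)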
833
+
834
+ class PoolerStartLogits(nn.Module):
835
+ """ Compute SQuAD start_logits from sequence hidden states. """
836
+
837
+ def __init__(self, config):
838
+ super().__init__()
839
+ self.dense = nn.Linear(config.hidden_size, 1)
840
+
841
+ def forward(self, hidden_states, p_mask=None):
842
+ """ Args:
843
+ **p_mask**: (`optional`) ``torch.FloatTensor`` of shape `(batch_size, seq_len)`
844
+ invalid position mask such as query and special symbols (PAD, SEP, CLS)
845
+ 1.0 means token should be masked.
846
+ """
847
+ x = self.dense(hidden_states).squeeze(-1)
848
+
849
+ if p_mask is not None:
850
+ if next(self.parameters()).dtype == torch.float16:
851
+ x = x * (1 - p_mask) - 65500 * p_mask
852
+ else:
853
+ x = x * (1 - p_mask) - 1e30 * p_mask
854
+
855
+ return x
856
+
857
+
858
+ class PoolerEndLogits(nn.Module):
859
+ """ Compute SQuAD end_logits from sequence hidden states and start token hidden state.
860
+ """
861
+
862
+ def __init__(self, config):
863
+ super().__init__()
864
+ self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
865
+ self.activation = nn.Tanh()
866
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
867
+ self.dense_1 = nn.Linear(config.hidden_size, 1)
868
+
869
+ def forward(self, hidden_states, start_states=None, start_positions=None, p_mask=None):
870
+ """ Args:
871
+ One of ``start_states``, ``start_positions`` should be not None.
872
+ If both are set, ``start_positions`` overrides ``start_states``.
873
+
874
+ **start_states**: ``torch.LongTensor`` of shape identical to hidden_states
875
+ hidden states of the first tokens for the labeled span.
876
+ **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
877
+ position of the first token for the labeled span:
878
+ **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
879
+ Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
880
+ 1.0 means token should be masked.
881
+ """
882
+ assert (
883
+ start_states is not None or start_positions is not None
884
+ ), "One of start_states, start_positions should be not None"
885
+ if start_positions is not None:
886
+ slen, hsz = hidden_states.shape[-2:]
887
+ start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
888
+ start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz)
889
+ start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz)
890
+
891
+ x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1))
892
+ x = self.activation(x)
893
+ x = self.LayerNorm(x)
894
+ x = self.dense_1(x).squeeze(-1)
895
+
896
+ if p_mask is not None:
897
+ if next(self.parameters()).dtype == torch.float16:
898
+ x = x * (1 - p_mask) - 65500 * p_mask
899
+ else:
900
+ x = x * (1 - p_mask) - 1e30 * p_mask
901
+
902
+ return x
903
+
904
+
905
+ class PoolerAnswerClass(nn.Module):
906
+ """ Compute SQuAD 2.0 answer class from classification and start tokens hidden states. """
907
+
908
+ def __init__(self, config):
909
+ super().__init__()
910
+ self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
911
+ self.activation = nn.Tanh()
912
+ self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)
913
+
914
+ def forward(self, hidden_states, start_states=None, start_positions=None, cls_index=None):
915
+ """
916
+ Args:
917
+ One of ``start_states``, ``start_positions`` should be not None.
918
+ If both are set, ``start_positions`` overrides ``start_states``.
919
+
920
+ **start_states**: ``torch.LongTensor`` of shape identical to ``hidden_states``.
921
+ hidden states of the first tokens for the labeled span.
922
+ **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
923
+ position of the first token for the labeled span.
924
+ **cls_index**: torch.LongTensor of shape ``(batch_size,)``
925
+ position of the CLS token. If None, take the last token.
926
+
927
+ note(Original repo):
928
+ no dependency on end_feature so that we can obtain one single `cls_logits`
929
+ for each sample
930
+ """
931
+ hsz = hidden_states.shape[-1]
932
+ assert (
933
+ start_states is not None or start_positions is not None
934
+ ), "One of start_states, start_positions should be not None"
935
+ if start_positions is not None:
936
+ start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
937
+ start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz)
938
+
939
+ if cls_index is not None:
940
+ cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
941
+ cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz)
942
+ else:
943
+ cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz)
944
+
945
+ x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1))
946
+ x = self.activation(x)
947
+ x = self.dense_1(x).squeeze(-1)
948
+
949
+ return x
950
+
951
+
952
+ class SQuADHead(nn.Module):
953
+ r""" A SQuAD head inspired by XLNet.
954
+
955
+ Parameters:
956
+ config (:class:`~transformers.XLNetConfig`): Model configuration class with all the parameters of the model.
957
+
958
+ Inputs:
959
+ **hidden_states**: ``torch.FloatTensor`` of shape ``(batch_size, seq_len, hidden_size)``
960
+ hidden states of sequence tokens
961
+ **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
962
+ position of the first token for the labeled span.
963
+ **end_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
964
+ position of the last token for the labeled span.
965
+ **cls_index**: torch.LongTensor of shape ``(batch_size,)``
966
+ position of the CLS token. If None, take the last token.
967
+ **is_impossible**: ``torch.LongTensor`` of shape ``(batch_size,)``
968
+ Whether the question has a possible answer in the paragraph or not.
969
+ **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
970
+ Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
971
+ 1.0 means token should be masked.
972
+
973
+ Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
974
+ **loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
975
+ Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
976
+ **start_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
977
+ ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``
978
+ Log probabilities for the top config.start_n_top start token possibilities (beam-search).
979
+ **start_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
980
+ ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``
981
+ Indices for the top config.start_n_top start token possibilities (beam-search).
982
+ **end_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
983
+ ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
984
+ Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
985
+ **end_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
986
+ ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
987
+ Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
988
+ **cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
989
+ ``torch.FloatTensor`` of shape ``(batch_size,)``
990
+ Log probabilities for the ``is_impossible`` label of the answers.
991
+ """
992
+
993
+ def __init__(self, config):
994
+ super().__init__()
995
+ self.start_n_top = config.start_n_top
996
+ self.end_n_top = config.end_n_top
997
+
998
+ self.start_logits = PoolerStartLogits(config)
999
+ self.end_logits = PoolerEndLogits(config)
1000
+ self.answer_class = PoolerAnswerClass(config)
1001
+
1002
+ def forward(
1003
+ self, hidden_states, start_positions=None, end_positions=None, cls_index=None, is_impossible=None, p_mask=None,
1004
+ ):
1005
+ outputs = ()
1006
+
1007
+ start_logits = self.start_logits(hidden_states, p_mask=p_mask)
1008
+
1009
+ if start_positions is not None and end_positions is not None:
1010
+ # If we are on multi-GPU, let's remove the dimension added by batch splitting
1011
+ for x in (start_positions, end_positions, cls_index, is_impossible):
1012
+ if x is not None and x.dim() > 1:
1013
+ x.squeeze_(-1)
1014
+
1015
+ # during training, compute the end logits based on the ground truth of the start position
1016
+ end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)
1017
+
1018
+ loss_fct = CrossEntropyLoss()
1019
+ start_loss = loss_fct(start_logits, start_positions)
1020
+ end_loss = loss_fct(end_logits, end_positions)
1021
+ total_loss = (start_loss + end_loss) / 2
1022
+
1023
+ if cls_index is not None and is_impossible is not None:
1024
+ # Predict answerability from the representation of CLS and START
1025
+ cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
1026
+ loss_fct_cls = nn.BCEWithLogitsLoss()
1027
+ cls_loss = loss_fct_cls(cls_logits, is_impossible)
1028
+
1029
+ # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
1030
+ total_loss += cls_loss * 0.5
1031
+
1032
+ outputs = (total_loss,) + outputs
1033
+
1034
+ else:
1035
+ # during inference, compute the end logits based on beam search
1036
+ bsz, slen, hsz = hidden_states.size()
1037
+ start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen)
1038
+
1039
+ start_top_log_probs, start_top_index = torch.topk(
1040
+ start_log_probs, self.start_n_top, dim=-1
1041
+ ) # shape (bsz, start_n_top)
1042
+ start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
1043
+ start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz)
1044
+ start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz)
1045
+
1046
+ hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(
1047
+ start_states
1048
+ ) # shape (bsz, slen, start_n_top, hsz)
1049
+ p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
1050
+ end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
1051
+ end_log_probs = F.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top)
1052
+
1053
+ end_top_log_probs, end_top_index = torch.topk(
1054
+ end_log_probs, self.end_n_top, dim=1
1055
+ ) # shape (bsz, end_n_top, start_n_top)
1056
+ end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
1057
+ end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)
1058
+
1059
+ start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
1060
+ cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)
1061
+
1062
+ outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits,) + outputs
1063
+
1064
+ # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
1065
+ # or (if labels are provided) (total_loss,)
1066
+ return outputs
1067
+
1068
+
1069
+ class SequenceSummary(nn.Module):
1070
+ r""" Compute a single vector summary of a sequence hidden states according to various possibilities:
1071
+ Args of the config class:
1072
+ summary_type:
1073
+ - 'last' => [default] take the last token hidden state (like XLNet)
1074
+ - 'first' => take the first token hidden state (like Bert)
1075
+ - 'mean' => take the mean of all tokens hidden states
1076
+ - 'cls_index' => supply a Tensor of classification token position (GPT/GPT-2)
1077
+ - 'attn' => Not implemented now, use multi-head attention
1078
+ summary_use_proj: Add a projection after the vector extraction
1079
+ summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
1080
+ summary_activation: 'tanh' or another supported activation string => add that activation to the output; None => no activation (default)
1081
+ summary_first_dropout: Add a dropout before the projection and activation
1082
+ summary_last_dropout: Add a dropout after the projection and activation
1083
+ """
1084
+
1085
+ def __init__(self, config: PretrainedConfig):
1086
+ super().__init__()
1087
+
1088
+ self.summary_type = getattr(config, "summary_type", "last")
1089
+ if self.summary_type == "attn":
1090
+ # We should use a standard multi-head attention module with absolute positional embedding for that.
1091
+ # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
1092
+ # We can probably just use the multi-head attention module of PyTorch >=1.1.0
1093
+ raise NotImplementedError
1094
+
1095
+ self.summary = Identity()
1096
+ if hasattr(config, "summary_use_proj") and config.summary_use_proj:
1097
+ if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
1098
+ num_classes = config.num_labels
1099
+ else:
1100
+ num_classes = config.hidden_size
1101
+ self.summary = nn.Linear(config.hidden_size, num_classes)
1102
+
1103
+ activation_string = getattr(config, "summary_activation", None)
1104
+ self.activation: Callable = (get_activation(activation_string) if activation_string else Identity())
1105
+
1106
+ self.first_dropout = Identity()
1107
+ if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
1108
+ self.first_dropout = nn.Dropout(config.summary_first_dropout)
1109
+
1110
+ self.last_dropout = Identity()
1111
+ if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
1112
+ self.last_dropout = nn.Dropout(config.summary_last_dropout)
1113
+
1114
+ def forward(self, hidden_states, cls_index=None):
1115
+ """ hidden_states: float Tensor in shape [bsz, ..., seq_len, hidden_size], the hidden-states of the last layer.
1116
+ cls_index: [optional] position of the classification token if summary_type == 'cls_index',
1117
+ shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states.
1118
+ if summary_type == 'cls_index' and cls_index is None:
1119
+ we take the last token of the sequence as classification token
1120
+ """
1121
+ if self.summary_type == "last":
1122
+ output = hidden_states[:, -1]
1123
+ elif self.summary_type == "first":
1124
+ output = hidden_states[:, 0]
1125
+ elif self.summary_type == "mean":
1126
+ output = hidden_states.mean(dim=1)
1127
+ elif self.summary_type == "cls_index":
1128
+ if cls_index is None:
1129
+ cls_index = torch.full_like(hidden_states[..., :1, :], hidden_states.shape[-2] - 1, dtype=torch.long,)
1130
+ else:
1131
+ cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
1132
+ cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
1133
+ # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
1134
+ output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
1135
+ elif self.summary_type == "attn":
1136
+ raise NotImplementedError
1137
+
1138
+ output = self.first_dropout(output)
1139
+ output = self.summary(output)
1140
+ output = self.activation(output)
1141
+ output = self.last_dropout(output)
1142
+
1143
+ return output
1144
+
1145
+
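# Editor's note: an illustrative sketch, not part of the original file, assuming a config with
# summary_type == "first" and summary_use_proj == True.
#
#     summary = SequenceSummary(config)
#     pooled = summary(torch.randn(4, 16, config.hidden_size))
#     # pooled holds the projected hidden state of the first token of each sequence, of shape
#     # (4, hidden_size) or (4, num_labels) depending on summary_proj_to_labels.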
1146
+ def prune_linear_layer(layer, index, dim=0):
1147
+ """ Prune a linear layer (a model parameter) to keep only entries in index.
1148
+ Return the pruned layer as a new layer with requires_grad=True.
1149
+ Used to remove heads.
1150
+ """
1151
+ index = index.to(layer.weight.device)
1152
+ W = layer.weight.index_select(dim, index).clone().detach()
1153
+ if layer.bias is not None:
1154
+ if dim == 1:
1155
+ b = layer.bias.clone().detach()
1156
+ else:
1157
+ b = layer.bias[index].clone().detach()
1158
+ new_size = list(layer.weight.size())
1159
+ new_size[dim] = len(index)
1160
+ new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
1161
+ new_layer.weight.requires_grad = False
1162
+ new_layer.weight.copy_(W.contiguous())
1163
+ new_layer.weight.requires_grad = True
1164
+ if layer.bias is not None:
1165
+ new_layer.bias.requires_grad = False
1166
+ new_layer.bias.copy_(b.contiguous())
1167
+ new_layer.bias.requires_grad = True
1168
+ return new_layer
1169
+
1170
+
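# Editor's note: an illustrative sketch, not part of the original file.
#
#     layer = nn.Linear(768, 768)
#     keep = torch.arange(512)                         # indices of output units to keep
#     pruned = prune_linear_layer(layer, keep, dim=0)  # -> nn.Linear(768, 512)
#     # The pruned weights are detached copies with requires_grad=True, ready for further training.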
1171
+ def prune_conv1d_layer(layer, index, dim=1):
1172
+ """ Prune a Conv1D layer (a model parameter) to keep only entries in index.
1173
+ A Conv1D works like a Linear layer (as used in e.g. BERT) but the weights are transposed.
1174
+ Return the pruned layer as a new layer with requires_grad=True.
1175
+ Used to remove heads.
1176
+ """
1177
+ index = index.to(layer.weight.device)
1178
+ W = layer.weight.index_select(dim, index).clone().detach()
1179
+ if dim == 0:
1180
+ b = layer.bias.clone().detach()
1181
+ else:
1182
+ b = layer.bias[index].clone().detach()
1183
+ new_size = list(layer.weight.size())
1184
+ new_size[dim] = len(index)
1185
+ new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
1186
+ new_layer.weight.requires_grad = False
1187
+ new_layer.weight.copy_(W.contiguous())
1188
+ new_layer.weight.requires_grad = True
1189
+ new_layer.bias.requires_grad = False
1190
+ new_layer.bias.copy_(b.contiguous())
1191
+ new_layer.bias.requires_grad = True
1192
+ return new_layer
1193
+
1194
+
1195
+ def prune_layer(layer, index, dim=None):
1196
+ """ Prune a Conv1D or nn.Linear layer (a model parameter) to keep only entries in index.
1197
+ Return the pruned layer as a new layer with requires_grad=True.
1198
+ Used to remove heads.
1199
+ """
1200
+ if isinstance(layer, nn.Linear):
1201
+ return prune_linear_layer(layer, index, dim=0 if dim is None else dim)
1202
+ elif isinstance(layer, Conv1D):
1203
+ return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
1204
+ else:
1205
+ raise ValueError("Can't prune layer of class {}".format(layer.__class__))
1206
+
1207
+
1208
+ def apply_chunking_to_forward(
1209
+ chunk_size: int, chunk_dim: int, forward_fn: Callable[..., torch.Tensor], *input_tensors
1210
+ ) -> torch.Tensor:
1211
+ """
1212
+ This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension `chunk_dim`.
1213
+ It then applies a layer `forward_fn` to each chunk independently to save memory.
1214
+ If the `forward_fn` is independent across the `chunk_dim` this function will yield the
1215
+ same result as not applying it.
1216
+
1217
+ Args:
1218
+ chunk_size: int - the chunk size of a chunked tensor. `num_chunks` = `len(input_tensors[0]) / chunk_size`
1219
+ chunk_dim: int - the dimension over which the input_tensors should be chunked
1220
+ forward_fn: fn - the forward fn of the model
1221
+ input_tensors: tuple(torch.Tensor) - the input tensors of `forward_fn` which are chunked
1222
+ Returns:
1223
+ a Tensor with the same shape the forward_fn would have given if applied
1224
+
1225
+
1226
+ Examples::
1227
+
1228
+ # rename the usual forward() fn to forward_chunk()
1229
+ def forward_chunk(self, hidden_states):
1230
+ hidden_states = self.decoder(hidden_states)
1231
+ return hidden_states
1232
+
1233
+ # implement a chunked forward function
1234
+ def forward(self, hidden_states):
1235
+ return apply_chunking_to_forward(self.chunk_size_lm_head, self.seq_len_dim, self.forward_chunk, hidden_states)
1236
+ """
1237
+
1238
+ assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format(input_tensors)
1239
+ tensor_shape = input_tensors[0].shape
1240
+ assert all(
1241
+ input_tensor.shape == tensor_shape for input_tensor in input_tensors
1242
+ ), "All input tensors have to be of the same shape"
1243
+
1244
+ # inspect.signature has existed since Python 3.5 and is a pure-Python helper -> no problem with backward compatibility
1245
+ num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
1246
+ assert num_args_in_forward_chunk_fn == len(
1247
+ input_tensors
1248
+ ), "forward_chunk_fn expects {} arguments, but only {} input tensors are given".format(
1249
+ num_args_in_forward_chunk_fn, len(input_tensors)
1250
+ )
1251
+
1252
+ if chunk_size > 0:
1253
+ assert (
1254
+ input_tensors[0].shape[chunk_dim] % chunk_size == 0
1255
+ ), "The dimension to be chunked {} has to be a multiple of the chunk size {}".format(
1256
+ input_tensors[0].shape[chunk_dim], chunk_size
1257
+ )
1258
+
1259
+ num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
1260
+
1261
+ # chunk input tensor into tuples
1262
+ input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors)
1263
+ # apply forward fn to every tuple
1264
+ output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks))
1265
+ # concatenate output at same dimension
1266
+ return torch.cat(output_chunks, dim=chunk_dim)
1267
+
1268
+ return forward_fn(*input_tensors)
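# Editor's note: an illustrative sketch, not part of the original file. Any position-wise
# function gives the same result chunked or not; only peak memory changes.
#
#     ff = lambda t: t * 2.0                        # stand-in for a per-position feed-forward
#     x = torch.randn(2, 8, 16)
#     out = apply_chunking_to_forward(4, 1, ff, x)  # 8 / 4 = 2 chunks along dim 1
#     assert torch.allclose(out, ff(x))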
CGFormer/bert/tokenization_bert.py ADDED
@@ -0,0 +1,545 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes."""
16
+
17
+
18
+ import collections
19
+ import logging
20
+ import os
21
+ import unicodedata
22
+ from typing import List, Optional
23
+
24
+ from .tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
25
+
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
30
+
31
+ PRETRAINED_VOCAB_FILES_MAP = {
32
+ "vocab_file": {
33
+ "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
34
+ "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
35
+ "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
36
+ "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt",
37
+ "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
38
+ "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
39
+ "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
40
+ "bert-base-german-cased": "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt",
41
+ "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt",
42
+ "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt",
43
+ "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt",
44
+ "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt",
45
+ "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt",
46
+ "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-vocab.txt",
47
+ "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-vocab.txt",
48
+ "TurkuNLP/bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/vocab.txt",
49
+ "TurkuNLP/bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/vocab.txt",
50
+ "wietsedv/bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/vocab.txt",
51
+ }
52
+ }
53
+
54
+ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
55
+ "bert-base-uncased": 512,
56
+ "bert-large-uncased": 512,
57
+ "bert-base-cased": 512,
58
+ "bert-large-cased": 512,
59
+ "bert-base-multilingual-uncased": 512,
60
+ "bert-base-multilingual-cased": 512,
61
+ "bert-base-chinese": 512,
62
+ "bert-base-german-cased": 512,
63
+ "bert-large-uncased-whole-word-masking": 512,
64
+ "bert-large-cased-whole-word-masking": 512,
65
+ "bert-large-uncased-whole-word-masking-finetuned-squad": 512,
66
+ "bert-large-cased-whole-word-masking-finetuned-squad": 512,
67
+ "bert-base-cased-finetuned-mrpc": 512,
68
+ "bert-base-german-dbmdz-cased": 512,
69
+ "bert-base-german-dbmdz-uncased": 512,
70
+ "TurkuNLP/bert-base-finnish-cased-v1": 512,
71
+ "TurkuNLP/bert-base-finnish-uncased-v1": 512,
72
+ "wietsedv/bert-base-dutch-cased": 512,
73
+ }
74
+
75
+ PRETRAINED_INIT_CONFIGURATION = {
76
+ "bert-base-uncased": {"do_lower_case": True},
77
+ "bert-large-uncased": {"do_lower_case": True},
78
+ "bert-base-cased": {"do_lower_case": False},
79
+ "bert-large-cased": {"do_lower_case": False},
80
+ "bert-base-multilingual-uncased": {"do_lower_case": True},
81
+ "bert-base-multilingual-cased": {"do_lower_case": False},
82
+ "bert-base-chinese": {"do_lower_case": False},
83
+ "bert-base-german-cased": {"do_lower_case": False},
84
+ "bert-large-uncased-whole-word-masking": {"do_lower_case": True},
85
+ "bert-large-cased-whole-word-masking": {"do_lower_case": False},
86
+ "bert-large-uncased-whole-word-masking-finetuned-squad": {"do_lower_case": True},
87
+ "bert-large-cased-whole-word-masking-finetuned-squad": {"do_lower_case": False},
88
+ "bert-base-cased-finetuned-mrpc": {"do_lower_case": False},
89
+ "bert-base-german-dbmdz-cased": {"do_lower_case": False},
90
+ "bert-base-german-dbmdz-uncased": {"do_lower_case": True},
91
+ "TurkuNLP/bert-base-finnish-cased-v1": {"do_lower_case": False},
92
+ "TurkuNLP/bert-base-finnish-uncased-v1": {"do_lower_case": True},
93
+ "wietsedv/bert-base-dutch-cased": {"do_lower_case": False},
94
+ }
95
+
96
+
97
+ def load_vocab(vocab_file):
98
+ """Loads a vocabulary file into a dictionary."""
99
+ vocab = collections.OrderedDict()
100
+ with open(vocab_file, "r", encoding="utf-8") as reader:
101
+ tokens = reader.readlines()
102
+ for index, token in enumerate(tokens):
103
+ token = token.rstrip("\n")
104
+ vocab[token] = index
105
+ return vocab
106
+
107
+
108
+ def whitespace_tokenize(text):
109
+ """Runs basic whitespace cleaning and splitting on a piece of text."""
110
+ text = text.strip()
111
+ if not text:
112
+ return []
113
+ tokens = text.split()
114
+ return tokens
115
+
116
+
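# Editor's note: an illustrative sketch, not part of the original file. "vocab.txt" is a
# hypothetical file containing the lines "[PAD]", "[UNK]" and "hello".
#
#     vocab = load_vocab("vocab.txt")          # OrderedDict([('[PAD]', 0), ('[UNK]', 1), ('hello', 2)])
#     whitespace_tokenize("  hello  world ")   # -> ['hello', 'world']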
117
+ class BertTokenizer(PreTrainedTokenizer):
118
+ r"""
119
+ Constructs a BERT tokenizer. Based on WordPiece.
120
+
121
+ This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
122
+ should refer to the superclass for more information regarding methods.
123
+
124
+ Args:
125
+ vocab_file (:obj:`string`):
126
+ File containing the vocabulary.
127
+ do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
128
+ Whether to lowercase the input when tokenizing.
129
+ do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`):
130
+ Whether to do basic tokenization before WordPiece.
131
+ never_split (:obj:`Iterable`, `optional`, defaults to :obj:`None`):
132
+ Collection of tokens which will never be split during tokenization. Only has an effect when
133
+ :obj:`do_basic_tokenize=True`
134
+ unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
135
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
136
+ token instead.
137
+ sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
138
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
139
+ for sequence classification or for a text and a question for question answering.
140
+ It is also used as the last token of a sequence built with special tokens.
141
+ pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
142
+ The token used for padding, for example when batching sequences of different lengths.
143
+ cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
144
+ The classifier token which is used when doing sequence classification (classification of the whole
145
+ sequence instead of per-token classification). It is the first token of the sequence when built with
146
+ special tokens.
147
+ mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
148
+ The token used for masking values. This is the token used when training this model with masked language
149
+ modeling. This is the token which the model will try to predict.
150
+ tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
151
+ Whether to tokenize Chinese characters.
152
+ This should likely be deactivated for Japanese:
153
+ see: https://github.com/huggingface/transformers/issues/328
154
+ """
155
+
156
+ vocab_files_names = VOCAB_FILES_NAMES
157
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
158
+ pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
159
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
160
+
161
+ def __init__(
162
+ self,
163
+ vocab_file,
164
+ do_lower_case=True,
165
+ do_basic_tokenize=True,
166
+ never_split=None,
167
+ unk_token="[UNK]",
168
+ sep_token="[SEP]",
169
+ pad_token="[PAD]",
170
+ cls_token="[CLS]",
171
+ mask_token="[MASK]",
172
+ tokenize_chinese_chars=True,
173
+ **kwargs
174
+ ):
175
+ super().__init__(
176
+ unk_token=unk_token,
177
+ sep_token=sep_token,
178
+ pad_token=pad_token,
179
+ cls_token=cls_token,
180
+ mask_token=mask_token,
181
+ **kwargs,
182
+ )
183
+
184
+ if not os.path.isfile(vocab_file):
185
+ raise ValueError(
186
+ "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
187
+ "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)
188
+ )
189
+ self.vocab = load_vocab(vocab_file)
190
+ self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
191
+ self.do_basic_tokenize = do_basic_tokenize
192
+ if do_basic_tokenize:
193
+ self.basic_tokenizer = BasicTokenizer(
194
+ do_lower_case=do_lower_case, never_split=never_split, tokenize_chinese_chars=tokenize_chinese_chars
195
+ )
196
+ self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
197
+
198
+ @property
199
+ def vocab_size(self):
200
+ return len(self.vocab)
201
+
202
+ def get_vocab(self):
203
+ return dict(self.vocab, **self.added_tokens_encoder)
204
+
205
+ def _tokenize(self, text):
206
+ split_tokens = []
207
+ if self.do_basic_tokenize:
208
+ for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
209
+
210
+ # If the token is part of the never_split set
211
+ if token in self.basic_tokenizer.never_split:
212
+ split_tokens.append(token)
213
+ else:
214
+ split_tokens += self.wordpiece_tokenizer.tokenize(token)
215
+ else:
216
+ split_tokens = self.wordpiece_tokenizer.tokenize(text)
217
+ return split_tokens
218
+
219
+ def _convert_token_to_id(self, token):
220
+ """ Converts a token (str) to an id using the vocab. """
221
+ return self.vocab.get(token, self.vocab.get(self.unk_token))
222
+
223
+ def _convert_id_to_token(self, index):
224
+ """Converts an index (integer) to a token (str) using the vocab."""
225
+ return self.ids_to_tokens.get(index, self.unk_token)
226
+
227
+ def convert_tokens_to_string(self, tokens):
228
+ """ Converts a sequence of tokens (string) into a single string. """
229
+ out_string = " ".join(tokens).replace(" ##", "").strip()
230
+ return out_string
231
+
232
+ def build_inputs_with_special_tokens(
233
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
234
+ ) -> List[int]:
235
+ """
236
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks
237
+ by concatenating and adding special tokens.
238
+ A BERT sequence has the following format:
239
+
240
+ - single sequence: ``[CLS] X [SEP]``
241
+ - pair of sequences: ``[CLS] A [SEP] B [SEP]``
242
+
243
+ Args:
244
+ token_ids_0 (:obj:`List[int]`):
245
+ List of IDs to which the special tokens will be added
246
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
247
+ Optional second list of IDs for sequence pairs.
248
+
249
+ Returns:
250
+ :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
251
+ """
252
+ if token_ids_1 is None:
253
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
254
+ cls = [self.cls_token_id]
255
+ sep = [self.sep_token_id]
256
+ return cls + token_ids_0 + sep + token_ids_1 + sep
257
+
258
+ def get_special_tokens_mask(
259
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
260
+ ) -> List[int]:
261
+ """
262
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
263
+ special tokens using the tokenizer ``prepare_for_model`` method.
264
+
265
+ Args:
266
+ token_ids_0 (:obj:`List[int]`):
267
+ List of ids.
268
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
269
+ Optional second list of IDs for sequence pairs.
270
+ already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
271
+ Set to True if the token list is already formatted with special tokens for the model
272
+
273
+ Returns:
274
+ :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
275
+ """
276
+
277
+ if already_has_special_tokens:
278
+ if token_ids_1 is not None:
279
+ raise ValueError(
280
+ "You should not supply a second sequence if the provided sequence of "
281
+ "ids is already formatted with special tokens for the model."
282
+ )
283
+ return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
284
+
285
+ if token_ids_1 is not None:
286
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
287
+ return [1] + ([0] * len(token_ids_0)) + [1]
288
+
289
+ def create_token_type_ids_from_sequences(
290
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
291
+ ) -> List[int]:
292
+ """
293
+ Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
294
+ A BERT sequence pair mask has the following format:
295
+
296
+ ::
297
+
298
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
299
+ | first sequence | second sequence |
300
+
301
+ if token_ids_1 is None, only returns the first portion of the mask (0's).
302
+
303
+ Args:
304
+ token_ids_0 (:obj:`List[int]`):
305
+ List of ids.
306
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
307
+ Optional second list of IDs for sequence pairs.
308
+
309
+ Returns:
310
+ :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
311
+ sequence(s).
312
+ """
313
+ sep = [self.sep_token_id]
314
+ cls = [self.cls_token_id]
315
+ if token_ids_1 is None:
316
+ return len(cls + token_ids_0 + sep) * [0]
317
+ return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
318
+
319
+ def save_vocabulary(self, vocab_path):
320
+ """
321
+ Save the tokenizer vocabulary (the WordPiece vocab file) to a directory or file.
322
+
323
+ Args:
324
+ vocab_path (:obj:`str`):
325
+ The directory in which to save the vocabulary.
326
+
327
+ Returns:
328
+ :obj:`Tuple(str)`: Paths to the files saved.
329
+ """
330
+ index = 0
331
+ if os.path.isdir(vocab_path):
332
+ vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"])
333
+ else:
334
+ vocab_file = vocab_path
335
+ with open(vocab_file, "w", encoding="utf-8") as writer:
336
+ for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
337
+ if index != token_index:
338
+ logger.warning(
339
+ "Saving vocabulary to {}: vocabulary indices are not consecutive."
340
+ " Please check that the vocabulary is not corrupted!".format(vocab_file)
341
+ )
342
+ index = token_index
343
+ writer.write(token + "\n")
344
+ index += 1
345
+ return (vocab_file,)
346
+
347
+
348
+ class BasicTokenizer(object):
349
+ """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
350
+
351
+ def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True):
352
+ """ Constructs a BasicTokenizer.
353
+
354
+ Args:
355
+ **do_lower_case**: Whether to lower case the input.
356
+ **never_split**: (`optional`) list of str
357
+ Kept for backward compatibility purposes.
358
+ Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`)
359
+ List of tokens not to split.
360
+ **tokenize_chinese_chars**: (`optional`) boolean (default True)
361
+ Whether to tokenize Chinese characters.
362
+ This should likely be deactivated for Japanese:
363
+ see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328
364
+ """
365
+ if never_split is None:
366
+ never_split = []
367
+ self.do_lower_case = do_lower_case
368
+ self.never_split = set(never_split)
369
+ self.tokenize_chinese_chars = tokenize_chinese_chars
370
+
371
+ def tokenize(self, text, never_split=None):
372
+ """ Basic Tokenization of a piece of text.
373
+ Splits on "white spaces" only; for sub-word tokenization, see WordpieceTokenizer.
374
+
375
+ Args:
376
+ **never_split**: (`optional`) list of str
377
+ Kept for backward compatibility purposes.
378
+ Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`)
379
+ List of tokens not to split.
380
+ """
381
+ # union() returns a new set by concatenating the two sets.
382
+ never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
383
+
384
+ # This was added on November 1st, 2018 for the multilingual and Chinese
385
+ # models. This is also applied to the English models now, but it doesn't
386
+ # matter since the English models were not trained on any Chinese data
387
+ # and generally don't have any Chinese data in them (there are Chinese
388
+ # characters in the vocabulary because Wikipedia does have some Chinese
389
+ # words in the English Wikipedia.).
390
+ if self.tokenize_chinese_chars:
391
+ text = self._tokenize_chinese_chars(text)
392
+ orig_tokens = whitespace_tokenize(text)
393
+ split_tokens = []
394
+ for token in orig_tokens:
395
+ if self.do_lower_case and token not in never_split:
396
+ token = token.lower()
397
+ token = self._run_strip_accents(token)
398
+ split_tokens.extend(self._run_split_on_punc(token, never_split))
399
+
400
+ output_tokens = whitespace_tokenize(" ".join(split_tokens))
401
+ return output_tokens
402
+
403
+ def _run_strip_accents(self, text):
404
+ """Strips accents from a piece of text."""
405
+ text = unicodedata.normalize("NFD", text)
406
+ output = []
407
+ for char in text:
408
+ cat = unicodedata.category(char)
409
+ if cat == "Mn":
410
+ continue
411
+ output.append(char)
412
+ return "".join(output)
413
+
414
+ def _run_split_on_punc(self, text, never_split=None):
415
+ """Splits punctuation on a piece of text."""
416
+ if never_split is not None and text in never_split:
417
+ return [text]
418
+ chars = list(text)
419
+ i = 0
420
+ start_new_word = True
421
+ output = []
422
+ while i < len(chars):
423
+ char = chars[i]
424
+ if _is_punctuation(char):
425
+ output.append([char])
426
+ start_new_word = True
427
+ else:
428
+ if start_new_word:
429
+ output.append([])
430
+ start_new_word = False
431
+ output[-1].append(char)
432
+ i += 1
433
+
434
+ return ["".join(x) for x in output]
435
+
436
+ def _tokenize_chinese_chars(self, text):
437
+ """Adds whitespace around any CJK character."""
438
+ output = []
439
+ for char in text:
440
+ cp = ord(char)
441
+ if self._is_chinese_char(cp):
442
+ output.append(" ")
443
+ output.append(char)
444
+ output.append(" ")
445
+ else:
446
+ output.append(char)
447
+ return "".join(output)
448
+
449
+ def _is_chinese_char(self, cp):
450
+ """Checks whether CP is the codepoint of a CJK character."""
451
+ # This defines a "chinese character" as anything in the CJK Unicode block:
452
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
453
+ #
454
+ # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
455
+ # despite its name. The modern Korean Hangul alphabet is a different block,
456
+ # as is Japanese Hiragana and Katakana. Those alphabets are used to write
457
+ # space-separated words, so they are not treated specially and handled
458
+ # like all of the other languages.
459
+ if (
460
+ (cp >= 0x4E00 and cp <= 0x9FFF)
461
+ or (cp >= 0x3400 and cp <= 0x4DBF) #
462
+ or (cp >= 0x20000 and cp <= 0x2A6DF) #
463
+ or (cp >= 0x2A700 and cp <= 0x2B73F) #
464
+ or (cp >= 0x2B740 and cp <= 0x2B81F) #
465
+ or (cp >= 0x2B820 and cp <= 0x2CEAF) #
466
+ or (cp >= 0xF900 and cp <= 0xFAFF)
467
+ or (cp >= 0x2F800 and cp <= 0x2FA1F) #
468
+ ): #
469
+ return True
470
+
471
+ return False
472
+
473
+ def _clean_text(self, text):
474
+ """Performs invalid character removal and whitespace cleanup on text."""
475
+ output = []
476
+ for char in text:
477
+ cp = ord(char)
478
+ if cp == 0 or cp == 0xFFFD or _is_control(char):
479
+ continue
480
+ if _is_whitespace(char):
481
+ output.append(" ")
482
+ else:
483
+ output.append(char)
484
+ return "".join(output)
485
+
486
+
487
+ class WordpieceTokenizer(object):
488
+ """Runs WordPiece tokenization."""
489
+
490
+ def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
491
+ self.vocab = vocab
492
+ self.unk_token = unk_token
493
+ self.max_input_chars_per_word = max_input_chars_per_word
494
+
495
+ def tokenize(self, text):
496
+ """Tokenizes a piece of text into its word pieces.
497
+
498
+ This uses a greedy longest-match-first algorithm to perform tokenization
499
+ using the given vocabulary.
500
+
501
+ For example:
502
+ input = "unaffable"
503
+ output = ["un", "##aff", "##able"]
504
+
505
+ Args:
506
+ text: A single token or whitespace separated tokens. This should have
507
+ already been passed through `BasicTokenizer`.
508
+
509
+ Returns:
510
+ A list of wordpiece tokens.
511
+ """
512
+
513
+ output_tokens = []
514
+ for token in whitespace_tokenize(text):
515
+ chars = list(token)
516
+ if len(chars) > self.max_input_chars_per_word:
517
+ output_tokens.append(self.unk_token)
518
+ continue
519
+
520
+ is_bad = False
521
+ start = 0
522
+ sub_tokens = []
523
+ while start < len(chars):
524
+ end = len(chars)
525
+ cur_substr = None
526
+ while start < end:
527
+ substr = "".join(chars[start:end])
528
+ if start > 0:
529
+ substr = "##" + substr
530
+ if substr in self.vocab:
531
+ cur_substr = substr
532
+ break
533
+ end -= 1
534
+ if cur_substr is None:
535
+ is_bad = True
536
+ break
537
+ sub_tokens.append(cur_substr)
538
+ start = end
539
+
540
+ if is_bad:
541
+ output_tokens.append(self.unk_token)
542
+ else:
543
+ output_tokens.extend(sub_tokens)
544
+ return output_tokens
545
+
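For orientation, the vendored slow tokenizer above can be exercised on its own. A minimal sketch, assuming the package is importable as `bert` (as laid out under CGFormer/bert/) and that the `bert-base-uncased` vocabulary is reachable through `from_pretrained`:

# Minimal sketch: tokenize a referring expression with the vendored BertTokenizer.
from bert.tokenization_bert import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

expr = "the man in a red shirt holding a frisbee"
tokens = tokenizer.tokenize(expr)                      # WordPiece pieces, e.g. ['the', 'man', ...]
ids = tokenizer.convert_tokens_to_ids(tokens)          # vocabulary indices
ids = tokenizer.build_inputs_with_special_tokens(ids)  # prepends [CLS], appends [SEP]
print(tokens)
print(ids)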
CGFormer/bert/tokenization_utils.py ADDED
@@ -0,0 +1,723 @@
1
+ # coding=utf-8
2
+ # Copyright 2020 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Tokenization classes for python tokenizers.
16
+ For fast tokenizers (provided by HuggingFace's tokenizers library) see tokenization_utils_fast.py
17
+ """
18
+
19
+ import itertools
20
+ import logging
21
+ import re
22
+ import unicodedata
23
+ from typing import Dict, List, Optional, Tuple, Union
24
+
25
+ from .file_utils import add_end_docstrings
26
+ from .tokenization_utils_base import (
27
+ ENCODE_KWARGS_DOCSTRING,
28
+ ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING,
29
+ AddedToken,
30
+ BatchEncoding,
31
+ EncodedInput,
32
+ EncodedInputPair,
33
+ PaddingStrategy,
34
+ PreTokenizedInput,
35
+ PreTokenizedInputPair,
36
+ PreTrainedTokenizerBase,
37
+ TensorType,
38
+ TextInput,
39
+ TextInputPair,
40
+ TruncationStrategy,
41
+ )
42
+
43
+
44
+ logger = logging.getLogger(__name__)
45
+
46
+
47
+ def _is_whitespace(char):
48
+ """Checks whether `chars` is a whitespace character."""
49
+ # \t, \n, and \r are technically control characters but we treat them
50
+ # as whitespace since they are generally considered as such.
51
+ if char == " " or char == "\t" or char == "\n" or char == "\r":
52
+ return True
53
+ cat = unicodedata.category(char)
54
+ if cat == "Zs":
55
+ return True
56
+ return False
57
+
58
+
59
+ def _is_control(char):
60
+ """Checks whether `chars` is a control character."""
61
+ # These are technically control characters but we count them as whitespace
62
+ # characters.
63
+ if char == "\t" or char == "\n" or char == "\r":
64
+ return False
65
+ cat = unicodedata.category(char)
66
+ if cat.startswith("C"):
67
+ return True
68
+ return False
69
+
70
+
71
+ def _is_punctuation(char):
72
+ """Checks whether `chars` is a punctuation character."""
73
+ cp = ord(char)
74
+ # We treat all non-letter/number ASCII as punctuation.
75
+ # Characters such as "^", "$", and "`" are not in the Unicode
76
+ # Punctuation class but we treat them as punctuation anyways, for
77
+ # consistency.
78
+ if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
79
+ return True
80
+ cat = unicodedata.category(char)
81
+ if cat.startswith("P"):
82
+ return True
83
+ return False
84
+
85
+
86
+ def _is_end_of_word(text):
87
+ """Checks whether the last character in text is one of a punctuation, control or whitespace character."""
88
+ last_char = text[-1]
89
+ return bool(_is_control(last_char) | _is_punctuation(last_char) | _is_whitespace(last_char))
90
+
91
+
92
+ def _is_start_of_word(text):
93
+ """Checks whether the first character in text is one of a punctuation, control or whitespace character."""
94
+ first_char = text[0]
95
+ return bool(_is_control(first_char) | _is_punctuation(first_char) | _is_whitespace(first_char))
96
+
97
+
98
+ class PreTrainedTokenizer(PreTrainedTokenizerBase):
99
+ """ Base class for all slow tokenizers.
100
+
101
+ Handles all the shared methods for tokenization and special tokens, as well as methods for
102
+ downloading/caching/loading pretrained tokenizers and for adding tokens to the vocabulary.
103
+
104
+ This class also contains the added tokens in a unified way on top of all tokenizers so we don't
105
+ have to handle the specific vocabulary augmentation methods of the various underlying
106
+ dictionary structures (BPE, sentencepiece...).
107
+
108
+ Class attributes (overridden by derived classes):
109
+
110
+ - ``vocab_files_names``: a python ``dict`` with, as keys, the ``__init__`` keyword name of each vocabulary file
111
+ required by the model, and as associated values, the filename for saving the associated file (string).
112
+ - ``pretrained_vocab_files_map``: a python ``dict of dict`` the high-level keys
113
+ being the ``__init__`` keyword name of each vocabulary file required by the model, the low-level being the
114
+ `short-cut-names` (string) of the pretrained models with, as associated values, the `url` (string) to the
115
+ associated pretrained vocabulary file.
116
+ - ``max_model_input_sizes``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the pretrained
117
+ models, and as associated values, the maximum length of the sequence inputs of this model, or None if the
118
+ model has no maximum input size.
119
+ - ``pretrained_init_configuration``: a python ``dict`` with, as keys, the `short-cut-names` (string) of the
120
+ pretrained models, and as associated values, a dictionary of specific arguments to pass to the
121
+ ``__init__``method of the tokenizer class for this pretrained model when loading the tokenizer with the
122
+ ``from_pretrained()`` method.
123
+
124
+ Args:
125
+ - ``model_max_length``: (`Optional`) int: the maximum length in number of tokens for the inputs to the transformer model.
126
+ When the tokenizer is loaded with `from_pretrained`, this will be set to the value stored for the associated
127
+ model in ``max_model_input_sizes`` (see above). If no value is provided and
128
+ no associated max_length can be found in ``max_model_input_sizes``, it will default to VERY_LARGE_INTEGER (`int(1e30)`).
129
+ - ``padding_side``: (`Optional`) string: the side on which the model should have padding applied.
130
+ Should be selected between ['right', 'left']
131
+ - ``model_input_names``: (`Optional`) List[string]: the list of the forward pass inputs accepted by the
132
+ model ("token_type_ids", "attention_mask"...).
133
+ - ``bos_token``: (`Optional`) string: a beginning of sentence token.
134
+ Will be associated to ``self.bos_token`` and ``self.bos_token_id``
135
+ - ``eos_token``: (`Optional`) string: an end of sentence token.
136
+ Will be associated to ``self.eos_token`` and ``self.eos_token_id``
137
+ - ``unk_token``: (`Optional`) string: an unknown token.
138
+ Will be associated to ``self.unk_token`` and ``self.unk_token_id``
139
+ - ``sep_token``: (`Optional`) string: a separation token (e.g. to separate context and query in an input sequence).
140
+ Will be associated to ``self.sep_token`` and ``self.sep_token_id``
141
+ - ``pad_token``: (`Optional`) string: a padding token.
142
+ Will be associated to ``self.pad_token`` and ``self.pad_token_id``
143
+ - ``cls_token``: (`Optional`) string: a classification token (e.g. to extract a summary of an input sequence
144
+ leveraging self-attention along the full depth of the model).
145
+ Will be associated to ``self.cls_token`` and ``self.cls_token_id``
146
+ - ``mask_token``: (`Optional`) string: a masking token (e.g. when training a model with masked-language
147
+ modeling). Will be associated to ``self.mask_token`` and ``self.mask_token_id``
148
+ - ``additional_special_tokens``: (`Optional`) list: a list of additional special tokens.
149
+ Adding all special tokens here ensures they won't be split by the tokenization process.
150
+ Will be associated to ``self.additional_special_tokens`` and ``self.additional_special_tokens_ids``
151
+
152
+
153
+ .. automethod:: __call__
154
+ """
155
+
156
+ def __init__(self, **kwargs):
157
+ super().__init__(**kwargs)
158
+
159
+ # Added tokens - We store this for both slow and fast tokenizers
160
+ # until the serialization of Fast tokenizers is updated
161
+ self.added_tokens_encoder: Dict[str, int] = {}
162
+ self.added_tokens_decoder: Dict[int, str] = {}
163
+ self.unique_no_split_tokens: List[str] = []
164
+
165
+ @property
166
+ def is_fast(self) -> bool:
167
+ return False
168
+
169
+ @property
170
+ def vocab_size(self) -> int:
171
+ """ Size of the base vocabulary (without the added tokens) """
172
+ raise NotImplementedError
173
+
174
+ def get_vocab(self):
175
+ """ Returns the vocabulary as a dict of {token: index} pairs. `tokenizer.get_vocab()[token]` is equivalent to `tokenizer.convert_tokens_to_ids(token)` when `token` is in the vocab. """
176
+ raise NotImplementedError()
177
+
178
+ def get_added_vocab(self) -> Dict[str, int]:
179
+ return self.added_tokens_encoder
180
+
181
+ def __len__(self):
182
+ """ Size of the full vocabulary with the added tokens """
183
+ return self.vocab_size + len(self.added_tokens_encoder)
184
+
185
+ def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens=False) -> int:
186
+ """
187
+ Add a list of new tokens to the tokenizer class. If the new tokens are not in the
188
+ vocabulary, they are added to it with indices starting from length of the current vocabulary.
189
+
190
+ Args:
191
+ new_tokens: string or list of string. Each string is a token to add. Tokens are only added if they are not
192
+ already in the vocabulary (tested by checking whether the tokenizer assigns the index of the ``unk_token`` to them).
193
+
194
+ Returns:
195
+ Number of tokens added to the vocabulary.
196
+
197
+ Examples::
198
+
199
+ # Let's see how to increase the vocabulary of Bert model and tokenizer
200
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
201
+ model = BertModel.from_pretrained('bert-base-uncased')
202
+
203
+ num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'])
204
+ print('We have added', num_added_toks, 'tokens')
205
+ model.resize_token_embeddings(len(tokenizer)) # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
206
+ """
207
+ new_tokens = [str(tok) for tok in new_tokens]
208
+
209
+ tokens_to_add = []
210
+ for token in new_tokens:
211
+ assert isinstance(token, str)
212
+ if not special_tokens and self.init_kwargs.get("do_lower_case", False):
213
+ token = token.lower()
214
+ if (
215
+ token != self.unk_token
216
+ and self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token)
217
+ and token not in tokens_to_add
218
+ ):
219
+ tokens_to_add.append(token)
220
+ if self.verbose:
221
+ logger.info("Adding %s to the vocabulary", token)
222
+
223
+ added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(tokens_to_add))
224
+ added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
225
+ self.added_tokens_encoder.update(added_tok_encoder)
226
+ self.added_tokens_decoder.update(added_tok_decoder)
227
+
228
+ # Make sure we don't split on any special tokens (even if they were already in the vocab before, e.g. for Albert)
229
+ if special_tokens:
230
+ self.unique_no_split_tokens = list(set(self.unique_no_split_tokens).union(set(new_tokens)))
231
+ else:
232
+ # Or on the newly added tokens
233
+ self.unique_no_split_tokens = list(set(self.unique_no_split_tokens).union(set(tokens_to_add)))
234
+
235
+ return len(tokens_to_add)
236
+
237
+ def num_special_tokens_to_add(self, pair=False):
238
+ """
239
+ Returns the number of added tokens when encoding a sequence with special tokens.
240
+
241
+ Note:
242
+ This encodes inputs and checks the number of added tokens, and is therefore not efficient. Do not put this
243
+ inside your training loop.
244
+
245
+ Args:
246
+ pair: Returns the number of added tokens in the case of a sequence pair if set to True, returns the
247
+ number of added tokens in the case of a single sequence if set to False.
248
+
249
+ Returns:
250
+ Number of tokens added to sequences
251
+ """
252
+ token_ids_0 = []
253
+ token_ids_1 = []
254
+ return len(self.build_inputs_with_special_tokens(token_ids_0, token_ids_1 if pair else None))
255
+
256
+ def tokenize(self, text: TextInput, **kwargs):
257
+ """ Converts a string into a sequence of tokens (string), using the tokenizer.
258
+ Splits into words for word-based vocabularies or sub-words for sub-word-based
259
+ vocabularies (BPE/SentencePieces/WordPieces).
260
+
261
+ Take care of added tokens.
262
+
263
+ Args:
264
+ text (:obj:`string`): The sequence to be encoded.
265
+ **kwargs (:obj: `dict`): Arguments passed to the model-specific `prepare_for_tokenization` preprocessing method.
266
+ """
267
+ # Simple mapping string => AddedToken for special tokens with specific tokenization behaviors
268
+ all_special_tokens_extended = dict(
269
+ (str(t), t) for t in self.all_special_tokens_extended if isinstance(t, AddedToken)
270
+ )
271
+
272
+ text, kwargs = self.prepare_for_tokenization(text, **kwargs)
273
+
274
+ if kwargs:
275
+ logger.warning(f"Keyword arguments {kwargs} not recognized.")
276
+
277
+ # TODO: should this be in the base class?
278
+ if self.init_kwargs.get("do_lower_case", False):
279
+ # convert non-special tokens to lowercase
280
+ escaped_special_toks = [re.escape(s_tok) for s_tok in self.all_special_tokens]
281
+ pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
282
+ text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)
283
+
284
+ def split_on_token(tok, text):
285
+ result = []
286
+ tok_extended = all_special_tokens_extended.get(tok, None)
287
+ split_text = text.split(tok)
288
+ full_word = ""
289
+ for i, sub_text in enumerate(split_text):
290
+ # AddedToken can control whitespace stripping around them.
291
+ # We use them for GPT2 and Roberta to have different behavior depending on the special token
292
+ # Cf. https://github.com/huggingface/transformers/pull/2778
293
+ # and https://github.com/huggingface/transformers/issues/3788
294
+ if isinstance(tok_extended, AddedToken):
295
+ if tok_extended.single_word:
296
+ # Try to avoid splitting on token
297
+ if (
298
+ i < len(split_text) - 1
299
+ and not _is_end_of_word(sub_text)
300
+ and not _is_start_of_word(split_text[i + 1])
301
+ ):
302
+ # Don't extract the special token
303
+ full_word += sub_text + tok
304
+ elif full_word:
305
+ full_word += sub_text
306
+ result += [full_word]
307
+ full_word = ""
308
+ continue
309
+ # Strip white spaces on the right
310
+ if tok_extended.rstrip and i > 0:
311
+ # A bit counter-intuitive but we strip the left of the string
312
+ # since tok_extended.rstrip means the special token is eating all white spaces on its right
313
+ sub_text = sub_text.lstrip()
314
+ # Strip white spaces on the left
315
+ if tok_extended.lstrip and i < len(split_text) - 1:
316
+ sub_text = sub_text.rstrip() # Opposite here
317
+ else:
318
+ # We strip left and right by default
319
+ if i < len(split_text) - 1:
320
+ sub_text = sub_text.rstrip()
321
+ if i > 0:
322
+ sub_text = sub_text.lstrip()
323
+
324
+ if i == 0 and not sub_text:
325
+ result += [tok]
326
+ elif i == len(split_text) - 1:
327
+ if sub_text:
328
+ result += [sub_text]
329
+ else:
330
+ pass
331
+ else:
332
+ if sub_text:
333
+ result += [sub_text]
334
+ result += [tok]
335
+ return result
336
+
337
+ def split_on_tokens(tok_list, text):
338
+ if not text.strip():
339
+ return []
340
+ if not tok_list:
341
+ return self._tokenize(text)
342
+
343
+ tokenized_text = []
344
+ text_list = [text]
345
+ for tok in tok_list:
346
+ tokenized_text = []
347
+ for sub_text in text_list:
348
+ if sub_text not in self.unique_no_split_tokens:
349
+ tokenized_text += split_on_token(tok, sub_text)
350
+ else:
351
+ tokenized_text += [sub_text]
352
+ text_list = tokenized_text
353
+
354
+ return list(
355
+ itertools.chain.from_iterable(
356
+ (
357
+ self._tokenize(token) if token not in self.unique_no_split_tokens else [token]
358
+ for token in tokenized_text
359
+ )
360
+ )
361
+ )
362
+
363
+ no_split_token = self.unique_no_split_tokens
364
+ tokenized_text = split_on_tokens(no_split_token, text)
365
+ return tokenized_text
366
+
367
+ def _tokenize(self, text, **kwargs):
368
+ """ Converts a string into a sequence of tokens (string), using the tokenizer.
369
+ Splits into words for word-based vocabularies or sub-words for sub-word-based
370
+ vocabularies (BPE/SentencePieces/WordPieces).
371
+
372
+ Do NOT take care of added tokens.
373
+ """
374
+ raise NotImplementedError
375
+
376
+ def convert_tokens_to_ids(self, tokens):
377
+ """ Converts a token string (or a sequence of tokens) in a single integer id
378
+ (or a sequence of ids), using the vocabulary.
379
+ """
380
+ if tokens is None:
381
+ return None
382
+
383
+ if isinstance(tokens, str):
384
+ return self._convert_token_to_id_with_added_voc(tokens)
385
+
386
+ ids = []
387
+ for token in tokens:
388
+ ids.append(self._convert_token_to_id_with_added_voc(token))
389
+ return ids
390
+
391
+ def _convert_token_to_id_with_added_voc(self, token):
392
+ if token is None:
393
+ return None
394
+
395
+ if token in self.added_tokens_encoder:
396
+ return self.added_tokens_encoder[token]
397
+ return self._convert_token_to_id(token)
398
+
399
+ def _convert_token_to_id(self, token):
400
+ raise NotImplementedError
401
+
402
+ def _encode_plus(
403
+ self,
404
+ text: Union[TextInput, PreTokenizedInput, EncodedInput],
405
+ text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,
406
+ add_special_tokens: bool = True,
407
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
408
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
409
+ max_length: Optional[int] = None,
410
+ stride: int = 0,
411
+ is_pretokenized: bool = False,
412
+ pad_to_multiple_of: Optional[int] = None,
413
+ return_tensors: Optional[Union[str, TensorType]] = None,
414
+ return_token_type_ids: Optional[bool] = None,
415
+ return_attention_mask: Optional[bool] = None,
416
+ return_overflowing_tokens: bool = False,
417
+ return_special_tokens_mask: bool = False,
418
+ return_offsets_mapping: bool = False,
419
+ return_length: bool = False,
420
+ verbose: bool = True,
421
+ **kwargs
422
+ ) -> BatchEncoding:
423
+ def get_input_ids(text):
424
+ if isinstance(text, str):
425
+ tokens = self.tokenize(text, **kwargs)
426
+ return self.convert_tokens_to_ids(tokens)
427
+ elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
428
+ if is_pretokenized:
429
+ tokens = list(itertools.chain(*(self.tokenize(t, is_pretokenized=True, **kwargs) for t in text)))
430
+ return self.convert_tokens_to_ids(tokens)
431
+ else:
432
+ return self.convert_tokens_to_ids(text)
433
+ elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
434
+ return text
435
+ else:
436
+ if is_pretokenized:
437
+ raise ValueError(
438
+ f"Input {text} is not valid. Should be a string or a list/tuple of strings when `is_pretokenized=True`."
439
+ )
440
+ else:
441
+ raise ValueError(
442
+ f"Input {text} is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
443
+ )
444
+
445
+ if return_offsets_mapping:
446
+ raise NotImplementedError(
447
+ "return_offset_mapping is not available when using Python tokenizers."
448
+ "To use this feature, change your tokenizer to one deriving from "
449
+ "transformers.PreTrainedTokenizerFast."
450
+ "More information on available tokenizers at "
451
+ "https://github.com/huggingface/transformers/pull/2674"
452
+ )
453
+
454
+ first_ids = get_input_ids(text)
455
+ second_ids = get_input_ids(text_pair) if text_pair is not None else None
456
+
457
+ return self.prepare_for_model(
458
+ first_ids,
459
+ pair_ids=second_ids,
460
+ add_special_tokens=add_special_tokens,
461
+ padding=padding_strategy.value,
462
+ truncation=truncation_strategy.value,
463
+ max_length=max_length,
464
+ stride=stride,
465
+ pad_to_multiple_of=pad_to_multiple_of,
466
+ return_tensors=return_tensors,
467
+ prepend_batch_axis=True,
468
+ return_attention_mask=return_attention_mask,
469
+ return_token_type_ids=return_token_type_ids,
470
+ return_overflowing_tokens=return_overflowing_tokens,
471
+ return_special_tokens_mask=return_special_tokens_mask,
472
+ return_length=return_length,
473
+ verbose=verbose,
474
+ )
475
+
476
+ def _batch_encode_plus(
477
+ self,
478
+ batch_text_or_text_pairs: Union[
479
+ List[TextInput],
480
+ List[TextInputPair],
481
+ List[PreTokenizedInput],
482
+ List[PreTokenizedInputPair],
483
+ List[EncodedInput],
484
+ List[EncodedInputPair],
485
+ ],
486
+ add_special_tokens: bool = True,
487
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
488
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
489
+ max_length: Optional[int] = None,
490
+ stride: int = 0,
491
+ is_pretokenized: bool = False,
492
+ pad_to_multiple_of: Optional[int] = None,
493
+ return_tensors: Optional[Union[str, TensorType]] = None,
494
+ return_token_type_ids: Optional[bool] = None,
495
+ return_attention_mask: Optional[bool] = None,
496
+ return_overflowing_tokens: bool = False,
497
+ return_special_tokens_mask: bool = False,
498
+ return_offsets_mapping: bool = False,
499
+ return_length: bool = False,
500
+ verbose: bool = True,
501
+ **kwargs
502
+ ) -> BatchEncoding:
503
+ def get_input_ids(text):
504
+ if isinstance(text, str):
505
+ tokens = self.tokenize(text, **kwargs)
506
+ return self.convert_tokens_to_ids(tokens)
507
+ elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
508
+ if is_pretokenized:
509
+ tokens = list(itertools.chain(*(self.tokenize(t, is_pretokenized=True, **kwargs) for t in text)))
510
+ return self.convert_tokens_to_ids(tokens)
511
+ else:
512
+ return self.convert_tokens_to_ids(text)
513
+ elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
514
+ return text
515
+ else:
516
+ raise ValueError(
517
+ "Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
518
+ )
519
+
520
+ if return_offsets_mapping:
521
+ raise NotImplementedError(
522
+ "return_offset_mapping is not available when using Python tokenizers."
523
+ "To use this feature, change your tokenizer to one deriving from "
524
+ "transformers.PreTrainedTokenizerFast."
525
+ )
526
+
527
+ input_ids = []
528
+ for ids_or_pair_ids in batch_text_or_text_pairs:
529
+ if not isinstance(ids_or_pair_ids, (list, tuple)):
530
+ ids, pair_ids = ids_or_pair_ids, None
531
+ elif is_pretokenized and not isinstance(ids_or_pair_ids[0], (list, tuple)):
532
+ ids, pair_ids = ids_or_pair_ids, None
533
+ else:
534
+ ids, pair_ids = ids_or_pair_ids
535
+
536
+ first_ids = get_input_ids(ids)
537
+ second_ids = get_input_ids(pair_ids) if pair_ids is not None else None
538
+ input_ids.append((first_ids, second_ids))
539
+
540
+ batch_outputs = self._batch_prepare_for_model(
541
+ input_ids,
542
+ add_special_tokens=add_special_tokens,
543
+ padding_strategy=padding_strategy,
544
+ truncation_strategy=truncation_strategy,
545
+ max_length=max_length,
546
+ stride=stride,
547
+ pad_to_multiple_of=pad_to_multiple_of,
548
+ return_attention_mask=return_attention_mask,
549
+ return_token_type_ids=return_token_type_ids,
550
+ return_overflowing_tokens=return_overflowing_tokens,
551
+ return_special_tokens_mask=return_special_tokens_mask,
552
+ return_length=return_length,
553
+ return_tensors=return_tensors,
554
+ verbose=verbose,
555
+ )
556
+
557
+ return BatchEncoding(batch_outputs)
558
+
559
+ @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
560
+ def _batch_prepare_for_model(
561
+ self,
562
+ batch_ids_pairs: List[Union[PreTokenizedInputPair, Tuple[List[int], None]]],
563
+ add_special_tokens: bool = True,
564
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
565
+ truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
566
+ max_length: Optional[int] = None,
567
+ stride: int = 0,
568
+ pad_to_multiple_of: Optional[int] = None,
569
+ return_tensors: Optional[str] = None,
570
+ return_token_type_ids: Optional[bool] = None,
571
+ return_attention_mask: Optional[bool] = None,
572
+ return_overflowing_tokens: bool = False,
573
+ return_special_tokens_mask: bool = False,
574
+ return_length: bool = False,
575
+ verbose: bool = True,
576
+ ) -> BatchEncoding:
577
+ """ Prepares a sequence of input ids, or a pair of sequences of input ids, so that it can be used by the model.
578
+ It adds special tokens, truncates sequences if overflowing while taking into account the special tokens and
579
+ manages a moving window (with user defined stride) for overflowing tokens
580
+
581
+ Args:
582
+ batch_ids_pairs: list of tokenized input ids or input ids pairs
583
+ """
584
+
585
+ batch_outputs = {}
586
+ for first_ids, second_ids in batch_ids_pairs:
587
+ outputs = self.prepare_for_model(
588
+ first_ids,
589
+ second_ids,
590
+ add_special_tokens=add_special_tokens,
591
+ padding=PaddingStrategy.DO_NOT_PAD.value, # we pad in batch afterward
592
+ truncation=truncation_strategy.value,
593
+ max_length=max_length,
594
+ stride=stride,
595
+ pad_to_multiple_of=None, # we pad in batch afterward
596
+ return_attention_mask=False, # we pad in batch afterward
597
+ return_token_type_ids=return_token_type_ids,
598
+ return_overflowing_tokens=return_overflowing_tokens,
599
+ return_special_tokens_mask=return_special_tokens_mask,
600
+ return_length=return_length,
601
+ return_tensors=None, # We convert the whole batch to tensors at the end
602
+ prepend_batch_axis=False,
603
+ verbose=verbose,
604
+ )
605
+
606
+ for key, value in outputs.items():
607
+ if key not in batch_outputs:
608
+ batch_outputs[key] = []
609
+ batch_outputs[key].append(value)
610
+
611
+ batch_outputs = self.pad(
612
+ batch_outputs,
613
+ padding=padding_strategy.value,
614
+ max_length=max_length,
615
+ pad_to_multiple_of=pad_to_multiple_of,
616
+ return_attention_mask=return_attention_mask,
617
+ )
618
+
619
+ batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors)
620
+
621
+ return batch_outputs
622
+
623
+ def prepare_for_tokenization(self, text: str, is_pretokenized=False, **kwargs) -> (str, dict):
624
+ """ Performs any necessary transformations before tokenization.
625
+
626
+ This method should pop the arguments from kwargs and return kwargs as well.
627
+ We test kwargs at the end of the encoding process to be sure all the arguments have been used.
628
+ """
629
+ return (text, kwargs)
630
+
631
+ def get_special_tokens_mask(
632
+ self, token_ids_0: List, token_ids_1: Optional[List] = None, already_has_special_tokens: bool = False
633
+ ) -> List[int]:
634
+ """
635
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
636
+ special tokens using the tokenizer ``prepare_for_model`` method.
637
+
638
+ Args:
639
+ token_ids_0: list of ids (must not contain special tokens)
640
+ token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids
641
+ for sequence pairs
642
+ already_has_special_tokens: (default False) Set to True if the token list is already formatted with
643
+ special tokens for the model
644
+
645
+ Returns:
646
+ A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
647
+ """
648
+ return [0] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0))
649
+
650
+ def convert_ids_to_tokens(
651
+ self, ids: Union[int, List[int]], skip_special_tokens: bool = False
652
+ ) -> Union[str, List[str]]:
653
+ """ Converts a single index or a sequence of indices (integers) into a token
654
+ (resp. into a sequence of tokens (str)), using the vocabulary and added tokens.
655
+
656
+ Args:
657
+ skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False
658
+ """
659
+ if isinstance(ids, int):
660
+ if ids in self.added_tokens_decoder:
661
+ return self.added_tokens_decoder[ids]
662
+ else:
663
+ return self._convert_id_to_token(ids)
664
+ tokens = []
665
+ for index in ids:
666
+ index = int(index)
667
+ if skip_special_tokens and index in self.all_special_ids:
668
+ continue
669
+ if index in self.added_tokens_decoder:
670
+ tokens.append(self.added_tokens_decoder[index])
671
+ else:
672
+ tokens.append(self._convert_id_to_token(index))
673
+ return tokens
674
+
675
+ def _convert_id_to_token(self, index: int) -> str:
676
+ raise NotImplementedError
677
+
678
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
679
+ """ Converts a sequence of tokens (string) into a single string.
680
+ The simplest way to do it is ' '.join(self.convert_ids_to_tokens(token_ids)),
681
+ but we often want to remove sub-word tokenization artifacts at the same time.
682
+ """
683
+ return " ".join(self.convert_ids_to_tokens(tokens))
684
+
685
+ def decode(
686
+ self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True
687
+ ) -> str:
688
+ filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
689
+
690
+ # To avoid mixing byte-level and unicode for byte-level BPE,
691
+ # we need to build the string separately for added tokens and byte-level tokens
692
+ # cf. https://github.com/huggingface/transformers/issues/1133
693
+ sub_texts = []
694
+ current_sub_text = []
695
+ for token in filtered_tokens:
696
+ if skip_special_tokens and token in self.all_special_ids:
697
+ continue
698
+ if token in self.added_tokens_encoder:
699
+ if current_sub_text:
700
+ sub_texts.append(self.convert_tokens_to_string(current_sub_text))
701
+ current_sub_text = []
702
+ sub_texts.append(token)
703
+ else:
704
+ current_sub_text.append(token)
705
+ if current_sub_text:
706
+ sub_texts.append(self.convert_tokens_to_string(current_sub_text))
707
+ text = " ".join(sub_texts)
708
+
709
+ if clean_up_tokenization_spaces:
710
+ clean_text = self.clean_up_tokenization(text)
711
+ return clean_text
712
+ else:
713
+ return text
714
+
715
+ def save_vocabulary(self, save_directory) -> Tuple[str]:
716
+ """ Save the tokenizer vocabulary to a directory. This method does *NOT* save added tokens
717
+ and special token mappings.
718
+
719
+ Please use :func:`~transformers.PreTrainedTokenizer.save_pretrained` `()` to save the full
720
+ Tokenizer state if you want to reload it using the :func:`~transformers.PreTrainedTokenizer.from_pretrained`
721
+ class method.
722
+ """
723
+ raise NotImplementedError
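The added-token machinery defined above (``_add_tokens`` plus the ``split_on_tokens`` logic inside ``tokenize``) keeps newly registered tokens out of the sub-word splitting. A small sketch of that behavior, reusing the BertTokenizer from this commit (``add_tokens`` itself comes from the base class in tokenization_utils_base.py):

# Sketch: added tokens are never split by WordPiece once registered.
from bert.tokenization_bert import BertTokenizer

tok = BertTokenizer.from_pretrained("bert-base-uncased")

print(tok.tokenize("a frisbee-player"))   # split into sub-words/punctuation, exact pieces depend on the vocab
tok.add_tokens(["frisbee-player"])        # returns the number of tokens actually added
print(tok.tokenize("a frisbee-player"))   # ['a', 'frisbee-player'] -- the added token stays whole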
CGFormer/bert/tokenization_utils_base.py ADDED
The diff for this file is too large to render. See raw diff
 
CGFormer/ckpts/swin_base_patch4_window12_384_22k.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:70812ab6b0a7a38712409d13976df9431632466eaacf991d5e90d9a1e91f3ab1
3
+ size 450809979
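The checkpoint itself lives in Git LFS, so only the pointer above is tracked in the repository. A small sketch for checking a fetched copy against the recorded size and sha256 (assuming the file has already been pulled, e.g. with `git lfs pull`):

# Verify the Swin-B checkpoint against the LFS pointer recorded above.
import hashlib
import os

path = "ckpts/swin_base_patch4_window12_384_22k.pth"

sha = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        sha.update(chunk)

assert os.path.getsize(path) == 450809979
assert sha.hexdigest() == "70812ab6b0a7a38712409d13976df9431632466eaacf991d5e90d9a1e91f3ab1"
print("checkpoint matches the LFS pointer")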
CGFormer/config/config_gref_ace.yaml ADDED
@@ -0,0 +1,63 @@
1
+ DATA:
2
+ dataset: refcocog_u
3
+ train_split: train
4
+ train_lmdb: data/lmdb/refcocog_u/train.lmdb
5
+ val_split: val
6
+ val_lmdb: data/lmdb/refcocog_u/val.lmdb
7
+ mask_root: data/masks/refcocog_u
8
+ TRAIN:
9
+ swin_type: base
10
+ swin_pretrain: ckpts/swin_base_patch4_window12_384_22k.pth
11
+ bert: bert-base-uncased
12
+ mha: '8-8-8-8'
13
+ input_size: 480
14
+ word_len: 20
15
+ word_dim: 768
16
+ vis_dim: 512
17
+ num_token: 2
18
+ token_dim: 512
19
+ sync_bn: True
20
+ dropout: 0.
21
+ fusion_drop: 0.
22
+ workers: 32 # data loader workers
23
+ workers_val: 8
24
+ batch_size: 64 # batch size for training
25
+ batch_size_val: 16 # batch size for validation during training, memory and speed tradeoff
26
+ start_epoch: 0
27
+ epochs: 50
28
+ lr_backbone: 5.e-5
29
+ lr_text_encoder: 5.e-5
30
+ lr: 1.e-4
31
+ weight_decay: 1.e-4
32
+ amsgrad: True
33
+ manual_seed:
34
+ print_freq: 100
35
+ exp_name: cgformer_test
36
+ output_folder: /data/seunghoon/CGFormer/exp/seunghoon
37
+ save_freq: 1
38
+ weight:
39
+ resume:
40
+ evaluate: True
41
+ metric_learning: True
42
+ exclude_multiobj: True
43
+ metric_mode: hardpos_only_refined
44
+ metric_loss_weight: 0.1
45
+ loss_option: ACE_verbonly
46
+ margin_value: 12
47
+ temperature: 0.07
48
+ hp_selection: strict
49
+ filter_threshold: 0.5
50
+ mixup_lasttwo : False
51
+ use_projections : False
52
+
53
+ Distributed:
54
+ # dist_url: tcp://localhost:18123
55
+ dist_backend: 'nccl'
56
+ # multiprocessing_distributed: True
57
+ world_size: 1
58
+ # rank: 0
59
+ TEST:
60
+ window12: True # if window12-pretrained weights were used for training, set this to True for testing
61
+ test_split: test
62
+ test_lmdb: data/lmdb/refcocog_u/test.lmdb
63
+ visualize: False
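The config above (and the ones that follow) is plain nested YAML, so it can be inspected without going through the training entry point. A sketch using PyYAML; how the repo's own loader wraps this into a config object is an assumption and may differ:

# Quick look at a training config (run from the CGFormer/ directory).
import yaml

with open("config/config_gref_ace.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["DATA"]["dataset"])           # refcocog_u
print(cfg["TRAIN"]["batch_size"])       # 64
print(cfg["TRAIN"]["metric_learning"])  # True for this ACE run
print(cfg["TEST"]["test_split"])        # test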
CGFormer/config/config_mosaic_refcocog_u.yaml ADDED
@@ -0,0 +1,51 @@
1
+ DATA:
2
+ dataset: refcocog_u
3
+ train_split: train
4
+ train_lmdb: data/lmdb/refcocog_u/train.lmdb
5
+ val_split: val
6
+ val_lmdb: data/lmdb/refcocog_u/val.lmdb
7
+ mask_root: data/masks/refcocog_u
8
+ TRAIN:
9
+ swin_type: base
10
+ swin_pretrain: ckpts/swin_base_patch4_window12_384_22k.pth
11
+ bert: bert-base-uncased
12
+ mha: '8-8-8-8'
13
+ input_size: 480
14
+ word_len: 20
15
+ word_dim: 768
16
+ vis_dim: 512
17
+ num_token: 2
18
+ token_dim: 512
19
+ sync_bn: True
20
+ dropout: 0.
21
+ fusion_drop: 0.
22
+ workers: 32 # data loader workers
23
+ workers_val: 8
24
+ batch_size: 64 # batch size for training
25
+ batch_size_val: 16 # batch size for validation during training, memory and speed tradeoff
26
+ start_epoch: 0
27
+ epochs: 50
28
+ lr_backbone: 5.e-5
29
+ lr_text_encoder: 5.e-5
30
+ lr: 1.e-4
31
+ weight_decay: 1.e-4
32
+ amsgrad: True
33
+ manual_seed:
34
+ print_freq: 100
35
+ exp_name: cgformer
36
+ output_folder: exp/mosaic_refcocog_u/
37
+ save_freq: 1
38
+ weight:
39
+ resume:
40
+ evaluate: True
41
+ Distributed:
42
+ dist_url: tcp://localhost:12345
43
+ dist_backend: 'nccl'
44
+ multiprocessing_distributed: True
45
+ world_size: 1
46
+ rank: 0
47
+ TEST:
48
+ window12: True # if window12-pretrained weights were used for training, set this to True for testing
49
+ test_split: val
50
+ test_lmdb: data/lmdb/refcocog_u/val.lmdb
51
+ visualize: False
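Unlike the ACE configs, this file keeps the ``dist_url`` / ``multiprocessing_distributed`` fields uncommented in the Distributed block. A hedged sketch of how such fields typically map onto ``torch.distributed.init_process_group``; the exact wiring in the repo's entry point is an assumption:

# Hypothetical wiring of the Distributed block; the repo's launcher may differ.
import torch.distributed as dist

dist_cfg = {
    "dist_url": "tcp://localhost:12345",
    "dist_backend": "nccl",
    "world_size": 1,
    "rank": 0,
}

dist.init_process_group(
    backend=dist_cfg["dist_backend"],
    init_method=dist_cfg["dist_url"],
    world_size=dist_cfg["world_size"],
    rank=dist_cfg["rank"],
)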
CGFormer/config/config_rcc_ace.yaml ADDED
@@ -0,0 +1,63 @@
1
+ DATA:
2
+ dataset: refcoco
3
+ train_split: train
4
+ train_lmdb: data/lmdb/refcoco/train.lmdb
5
+ val_split: val
6
+ val_lmdb: data/lmdb/refcoco/val.lmdb
7
+ mask_root: data/masks/refcoco
8
+ TRAIN:
9
+ swin_type: base
10
+ swin_pretrain: ckpts/swin_base_patch4_window12_384_22k.pth
11
+ bert: bert-base-uncased
12
+ mha: '8-8-8-8'
13
+ input_size: 480
14
+ word_len: 20
15
+ word_dim: 768
16
+ vis_dim: 512
17
+ num_token: 2
18
+ token_dim: 512
19
+ sync_bn: True
20
+ dropout: 0.
21
+ fusion_drop: 0.
22
+ workers: 16 # data loader workers
23
+ workers_val: 8
24
+ batch_size: 64 #batch size for training
25
+ batch_size_val: 24 # batch size for validation during training, memory and speed tradeoff
26
+ start_epoch: 0
27
+ epochs: 50
28
+ lr_backbone: 5.e-5
29
+ lr_text_encoder: 5.e-5
30
+ lr: 1.e-4
31
+ weight_decay: 1.e-4
32
+ amsgrad: True
33
+ manual_seed:
34
+ print_freq: 100
35
+ exp_name: cgformer_test
36
+ output_folder: /data/seunghoon/CGFormer/exp/seunghoon
37
+ save_freq: 1
38
+ weight:
39
+ resume:
40
+ evaluate: True
41
+ metric_learning: False
42
+ exclude_multiobj: true
43
+ metric_mode: hardpos_only_refined
44
+ metric_loss_weight: 0.1
45
+ loss_option: ACE_verbonly
46
+ margin_value: 12
47
+ temperature: 0.07
48
+ hp_selection: strict
49
+ use_projections : True
50
+ mixup_lasttwo : False
51
+ filter_threshold: 0.68
52
+
53
+ Distributed:
54
+ # dist_url: tcp://localhost:18123
55
+ dist_backend: 'nccl'
56
+ # multiprocessing_distributed: True
57
+ world_size: 1
58
+ # rank: 0
59
+ TEST:
60
+ window12: True # if window12-pretrained weights were used for training, set this to True for testing
61
+ test_split: test
62
+ test_lmdb: data/lmdb/refcocog_u/test.lmdb
63
+ visualize: False
CGFormer/config/config_rccp_ace.yaml ADDED
@@ -0,0 +1,63 @@
1
+ DATA:
2
+ dataset: refcoco+
3
+ train_split: train
4
+ train_lmdb: data/lmdb/refcoco+/train.lmdb
5
+ val_split: val
6
+ val_lmdb: data/lmdb/refcoco+/val.lmdb
7
+ mask_root: data/masks/refcoco+
8
+ TRAIN:
9
+ swin_type: base
10
+ swin_pretrain: ckpts/swin_base_patch4_window12_384_22k.pth
11
+ bert: bert-base-uncased
12
+ mha: '8-8-8-8'
13
+ input_size: 480
14
+ word_len: 20
15
+ word_dim: 768
16
+ vis_dim: 512
17
+ num_token: 2
18
+ token_dim: 512
19
+ sync_bn: True
20
+ dropout: 0.
21
+ fusion_drop: 0.
22
+ workers: 16 # data loader workers
23
+ workers_val: 8
24
+ batch_size: 64 #batch size for training
25
+ batch_size_val: 24 # batch size for validation during training, memory and speed tradeoff
26
+ start_epoch: 0
27
+ epochs: 50
28
+ lr_backbone: 5.e-5
29
+ lr_text_encoder: 5.e-5
30
+ lr: 1.e-4
31
+ weight_decay: 1.e-4
32
+ amsgrad: True
33
+ manual_seed:
34
+ print_freq: 100
35
+ exp_name: cgformer_test
36
+ output_folder: /data/seunghoon/CGFormer/exp/seunghoon
37
+ save_freq: 1
38
+ weight:
39
+ resume:
40
+ evaluate: True
41
+ metric_learning: False
42
+ exclude_multiobj: true
43
+ metric_mode: hardpos_only_refined
44
+ metric_loss_weight: 0.1
45
+ loss_option: ACE_verbonly
46
+ margin_value: 12
47
+ temperature: 0.07
48
+ hp_selection: strict
49
+ use_projections : True
50
+ mixup_lasttwo : False
51
+ filter_threshold: 0.68
52
+
53
+ Distributed:
54
+ # dist_url: tcp://localhost:18123
55
+ dist_backend: 'nccl'
56
+ # multiprocessing_distributed: True
57
+ world_size: 1
58
+ # rank: 0
59
+ TEST:
60
+ window12: True # if window12-pretrained weights were used for training, set this to True for testing
61
+ test_split: test
62
+ test_lmdb: data/lmdb/refcocog_u/test.lmdb
63
+ visualize: False
CGFormer/config/config_refzom_ace.yaml ADDED
@@ -0,0 +1,64 @@
1
+ DATA:
2
+ dataset: ref-zom
3
+ train_split: train
4
+ train_lmdb: /data2/projects/chaeyun/VerbCentric_RIS/datasets/lmdb/ref-zom/train.lmdb
5
+ val_split: test
6
+ val_lmdb: /data2/projects/chaeyun/VerbCentric_RIS/datasets/lmdb/ref-zom/test.lmdb
7
+ mask_root: /data2/projects/chaeyun/VerbCentric_RIS/datasets/masks/ref-zom
8
+ TRAIN:
9
+ swin_type: base
10
+ swin_pretrain: ckpts/swin_base_patch4_window12_384_22k.pth
11
+ bert: bert-base-uncased
12
+ mha: '8-8-8-8'
13
+ input_size: 480
14
+ word_len: 20
15
+ word_dim: 768
16
+ vis_dim: 512
17
+ num_token: 2
18
+ token_dim: 512
19
+ sync_bn: True
20
+ dropout: 0.
21
+ fusion_drop: 0.
22
+ workers: 32 # data loader workers
23
+ workers_val: 8
24
+ batch_size: 64 # batch size for training
25
+ batch_size_val: 16 # batch size for validation during training, memory and speed tradeoff
26
+ start_epoch: 0
27
+ epochs: 50
28
+ lr_backbone: 5.e-5
29
+ lr_text_encoder: 5.e-5
30
+ lr: 1.e-4
31
+ weight_decay: 1.e-4
32
+ amsgrad: True
33
+ manual_seed:
34
+ print_freq: 100
35
+ exp_name: cgformer_test
36
+ output_folder: /data/seunghoon/CGFormer/exp/seunghoon
37
+ save_freq: 1
38
+ weight:
39
+ resume:
40
+ evaluate: True
41
+ metric_learning: True
42
+ exclude_multiobj: True
43
+ metric_mode: hardpos_only_refined
44
+ metric_loss_weight: 0.1
45
+ loss_option: ACE_verbonly
46
+ margin_value: 12
47
+ temperature: 0.07
48
+ hp_selection: strict
49
+ filter_threshold: 0.5
50
+ mixup_lasttwo : False
51
+ use_projections : False
52
+ fuse_mode : simple_attn
53
+
54
+ Distributed:
55
+ # dist_url: tcp://localhost:18123
56
+ dist_backend: 'nccl'
57
+ # multiprocessing_distributed: True
58
+ world_size: 1
59
+ # rank: 0
60
+ TEST:
61
+ window12: True # if use window12 pretrained for training, testing set true
62
+ test_split: test
63
+ test_lmdb: data/lmdb/refcocog_u/test.lmdb
64
+ visualize: False
CGFormer/config/config_refzom_repro.yaml ADDED
@@ -0,0 +1,62 @@
+ DATA:
+ dataset: ref-zom
+ train_split: train
+ train_lmdb: /data2/projects/chaeyun/VerbCentric_RIS/datasets/lmdb/ref-zom/train.lmdb
+ val_split: test
+ val_lmdb: /data2/projects/chaeyun/VerbCentric_RIS/datasets/lmdb/ref-zom/test.lmdb
+ mask_root: /data2/projects/chaeyun/VerbCentric_RIS/datasets/masks/ref-zom
+ TRAIN:
+ swin_type: base
+ swin_pretrain: ckpts/swin_base_patch4_window12_384_22k.pth
+ bert: bert-base-uncased
+ mha: '8-8-8-8'
+ input_size: 480
+ word_len: 20
+ word_dim: 768
+ vis_dim: 512
+ num_token: 2
+ token_dim: 512
+ sync_bn: True
+ dropout: 0.
+ fusion_drop: 0.
+ workers: 32 # data loader workers
+ workers_val: 8
+ batch_size: 64 # batch size for training
+ batch_size_val: 16 # batch size for validation during training, memory and speed tradeoff
+ start_epoch: 0
+ epochs: 50
+ lr_backbone: 5.e-5
+ lr_text_encoder: 5.e-5
+ lr: 1.e-4
+ weight_decay: 1.e-4
+ amsgrad: True
+ manual_seed:
+ print_freq: 100
+ exp_name: cgformer_test
+ output_folder: /data/seunghoon/CGFormer/exp/seunghoon
+ save_freq: 1
+ weight:
+ resume:
+ evaluate: True
+ metric_learning: False
+ exclude_multiobj: True
+ metric_mode: hardpos_only_refined
+ metric_loss_weight: 0.1
+ loss_option: ACE_verbonly
+ margin_value: 12
+ temperature: 0.07
+ hp_selection: strict
+ filter_threshold: 0.5
+ mixup_lasttwo : False
+
+ Distributed:
+ # dist_url: tcp://localhost:18123
+ dist_backend: 'nccl'
+ # multiprocessing_distributed: True
+ world_size: 1
+ # rank: 0
+ TEST:
+ window12: True # if use window12 pretrained for training, testing set true
+ test_split: test
+ test_lmdb: /data2/projects/chaeyun/VerbCentric_RIS/datasets/lmdb/ref-zom/test.lmdb
+ visualize: False
CGFormer/config/config_refzom_repro_eval.yaml ADDED
@@ -0,0 +1,62 @@
+ DATA:
+ dataset: ref-zom
+ train_split: train
+ train_lmdb: /data2/projects/chaeyun/VerbCentric_RIS/datasets/lmdb/ref-zom/train.lmdb
+ val_split: test
+ val_lmdb: /data2/projects/chaeyun/VerbCentric_RIS/datasets/lmdb/ref-zom/test.lmdb
+ mask_root: /data2/projects/chaeyun/VerbCentric_RIS/datasets/masks/ref-zom
+ TRAIN:
+ swin_type: base
+ swin_pretrain: ckpts/swin_base_patch4_window12_384_22k.pth
+ bert: bert-base-uncased
+ mha: '8-8-8-8'
+ input_size: 480
+ word_len: 20
+ word_dim: 768
+ vis_dim: 512
+ num_token: 2
+ token_dim: 512
+ sync_bn: True
+ dropout: 0.
+ fusion_drop: 0.
+ workers: 32 # data loader workers
+ workers_val: 8
+ batch_size: 64 # batch size for training
+ batch_size_val: 16 # batch size for validation during training, memory and speed tradeoff
+ start_epoch: 0
+ epochs: 50
+ lr_backbone: 5.e-5
+ lr_text_encoder: 5.e-5
+ lr: 1.e-4
+ weight_decay: 1.e-4
+ amsgrad: True
+ manual_seed:
+ print_freq: 100
+ exp_name: cgformer_test
+ output_folder: /data/seunghoon/CGFormer/exp/seunghoon
+ save_freq: 1
+ weight:
+ resume:
+ evaluate: True
+ metric_learning: False
+ exclude_multiobj: True
+ metric_mode: hardpos_only_refined
+ metric_loss_weight: 0.1
+ loss_option: ACE_verbonly
+ margin_value: 12
+ temperature: 0.07
+ hp_selection: strict
+ filter_threshold: 0.5
+ mixup_lasttwo : False
+
+ Distributed:
+ # dist_url: tcp://localhost:18123
+ dist_backend: 'nccl'
+ # multiprocessing_distributed: True
+ world_size: 1
+ # rank: 0
+ TEST:
+ window12: True # if use window12 pretrained for training, testing set true
+ test_split: test
+ test_lmdb: /data2/projects/chaeyun/VerbCentric_RIS/datasets/lmdb/ref-zom/test.lmdb
+ visualize: False
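
The TRAIN block in these configs carries three separate learning rates (lr_backbone, lr_text_encoder, lr) together with weight_decay and amsgrad. A common way such values are consumed is to split the model's parameters into per-module optimizer groups; the sketch below is only an illustration that assumes parameter names containing "backbone" and "text_encoder", and is not necessarily the repository's exact training code.

# Illustrative only: assumed parameter-name prefixes, not CGFormer's actual code.
import torch

def build_optimizer(model, cfg_train):
    # split parameters by (assumed) module name prefixes
    backbone, text, rest = [], [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if "backbone" in name:          # Swin visual backbone (assumed naming)
            backbone.append(param)
        elif "text_encoder" in name:    # BERT text encoder (assumed naming)
            text.append(param)
        else:
            rest.append(param)
    param_groups = [
        {"params": backbone, "lr": cfg_train["lr_backbone"]},
        {"params": text, "lr": cfg_train["lr_text_encoder"]},
        {"params": rest, "lr": cfg_train["lr"]},
    ]
    return torch.optim.AdamW(param_groups,
                             weight_decay=cfg_train["weight_decay"],
                             amsgrad=cfg_train["amsgrad"])
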
CGFormer/config/impl/config.yaml ADDED
@@ -0,0 +1,53 @@
+ DATA:
+ dataset: refcocog_u
+ train_split: train
+ train_lmdb: data/lmdb/refcocog_u/train.lmdb
+ val_split: val
+ val_lmdb: data/lmdb/refcocog_u/val.lmdb
+ mask_root: data/masks/refcocog_u
+ AUG:
+ check: null
+ TRAIN:
+ swin_type: base
+ swin_pretrain: ckpts/swin_base_patch4_window12_384_22k.pth
+ bert: bert-base-uncased
+ mha: '8-8-8-8'
+ input_size: 480
+ word_len: 20
+ word_dim: 768
+ vis_dim: 512
+ num_token: 2
+ token_dim: 512
+ sync_bn: True
+ dropout: 0.
+ fusion_drop: 0.
+ workers: 32 # data loader workers
+ workers_val: 8
+ batch_size: 32 #64 # batch size for training
+ batch_size_val: 16 # batch size for validation during training, memory and speed tradeoff
+ start_epoch: 0
+ epochs: 3
+ lr_backbone: 5.e-5
+ lr_text_encoder: 5.e-5
+ lr: 1.e-4
+ weight_decay: 1.e-4
+ amsgrad: True
+ manual_seed:
+ print_freq: 100
+ exp_name: cgformer
+ output_folder: exp/impl/
+ save_freq: 1
+ weight:
+ resume:
+ evaluate: True
+ Distributed:
+ # dist_url: tcp://localhost:18123
+ dist_backend: 'nccl'
+ # multiprocessing_distributed: True
+ world_size: 1
+ # rank: 0
+ TEST:
+ window12: True # if use window12 pretrained for training, testing set true
+ test_split: test
+ test_lmdb: data/lmdb/refcocog_u/test.lmdb
+ visualize: False
CGFormer/config/open.yaml ADDED
@@ -0,0 +1,55 @@
+ DATA:
+ dataset: refcoco
+ train_split: train_seen
+ train_lmdb: path/open_lmdb/refcoco/train_seen.lmdb
+ val_seen_split: val_seen
+ val_seen_lmdb: path/open_lmdb/refcoco/val_seen.lmdb
+ val_unseen_split: val_unseen
+ val_unseen_lmdb: path/open_lmdb/refcoco/val_unseen.lmdb
+ mask_root: path/masks/refcoco
+ TRAIN:
+ swin_type: base
+ swin_pretrain: path/swin_base_patch4_window12_384_22k.pth
+ bert: bert-base-uncased
+ clip_pretrain: path/pretrain/ViT-L-14-336px.pt
+ mha: '8-8-8-8'
+ input_size: 480
+ clip_dim: 768
+ word_len: 20
+ num_token: 2
+ word_dim: 768
+ vis_dim: 512
+ token_dim: 512
+ sync_bn: True
+ dropout: 0.
+ fusion_drop: 0.
+ workers: 32 # data loader workers
+ workers_val: 8
+ batch_size: 64 # batch size for training
+ batch_size_val: 16 # batch size for validation during training, memory and speed tradeoff
+ start_epoch: 0
+ epochs: 1000
+ lr_backbone: 5.e-5
+ lr_text_encoder: 5.e-5
+ lr: 1.e-4
+ weight_decay: 1.e-4
+ amsgrad: True
+ manual_seed: 0
+ print_freq: 100
+ exp_name: open
+ output_folder: exp/refcoco
+ save_freq: 1
+ weight:
+ resume:
+ evaluate: True # evaluate on validation set, extra gpu memory needed and small batch_size_val is recommend
+ Distributed:
+ dist_url: tcp://localhost:12345
+ dist_backend: 'nccl'
+ multiprocessing_distributed: True
+ world_size: 1
+ rank: 0
+ TEST:
+ window12: True # if use window12 pretrained for training, testing set true
+ test_split: test_unseen
+ test_lmdb: path/refcoco/test_unseen.lmdb
+ visualize: False
CGFormer/config/refcoco_mosaic/config.yaml ADDED
@@ -0,0 +1,59 @@
+ DATA:
+ dataset: refcoco
+ train_split: train
+ train_lmdb: data/lmdb/refcoco/train.lmdb
+ val_split: val
+ val_lmdb: data/lmdb/refcoco/val.lmdb
+ mask_root: data/masks/refcoco
+
+ TRAIN:
+ swin_type: base
+ swin_pretrain: ckpts/swin_base_patch4_window12_384_22k.pth
+ bert: bert-base-uncased
+ mha: '8-8-8-8'
+ input_size: 480
+ word_len: 20
+ word_dim: 768
+ vis_dim: 512
+ num_token: 2
+ token_dim: 512
+ sync_bn: True
+ dropout: 0.
+ fusion_drop: 0.
+ workers: 16 # data loader workers
+ workers_val: 8
+ batch_size: 64 #batch size for training
+ batch_size_val: 16 # 16 batch size for validation during training, memory and speed tradeoff
+ start_epoch: 0
+ epochs: 50
+ lr_backbone: 5.e-5
+ lr_text_encoder: 5.e-5
+ lr: 1.e-4
+ weight_decay: 1.e-4
+ amsgrad: True
+ manual_seed:
+ print_freq: 100
+ exp_name: cgformer
+ output_folder: exp/refcoco_mosaic/
+ save_freq: 1
+ weight:
+ resume:
+ evaluate: True
+ aug:
+ num_bgs: 4
+ aug_prob: 0.6
+ tgt_selection: fixed
+ move_crs_pnt: False
+ blur: False
+
+ Distributed:
+ # dist_url: tcp://localhost:18123
+ dist_backend: 'nccl'
+ # multiprocessing_distributed: True
+ world_size: 1
+ # rank: 0
+ TEST:
+ window12: True # if use window12 pretrained for training, testing set true
+ test_split: val
+ test_lmdb: data/lmdb/refcoco/test.lmdb
+ visualize: False
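
The aug block in this mosaic config (num_bgs: 4, aug_prob: 0.6) points to a mosaic-style augmentation applied with a fixed probability. The sketch below is a purely hypothetical reading of those two keys, a 2x2 tiling of the target image with three background images, and it does not reproduce the repository's tgt_selection, move_crs_pnt, or blur handling.

# Illustrative only: a hypothetical mosaic helper, not the repository's augmentation code.
import random
import numpy as np

def maybe_mosaic(target, backgrounds, aug_prob=0.6, num_bgs=4):
    """target and each background: HxWx3 uint8 arrays of identical size."""
    if random.random() > aug_prob or num_bgs != 4 or len(backgrounds) < num_bgs - 1:
        return target                                 # keep the original image
    tiles = [target] + list(backgrounds[:num_bgs - 1])
    top = np.concatenate(tiles[:2], axis=1)           # top-left, top-right
    bottom = np.concatenate(tiles[2:], axis=1)        # bottom-left, bottom-right
    return np.concatenate([top, bottom], axis=0)      # 2x2 grid, 2H x 2W output
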