Commit
·
9cf79cf
1
Parent(s):
361e832
fine tune decoder with mask
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes copy +35 -0
- .vscode/launch.json +25 -0
- Fine_Tune_VQ_w_P/20240604_030821.log +174 -0
- Fine_Tune_VQ_w_P/checkpoint-20.pth +3 -0
- Fine_Tune_VQ_w_P/checkpoint-last.pth +3 -0
- Fine_Tune_VQ_w_P/events.out.tfevents.1717470501.t2vg-a100-G4-10.976065.0 +3 -0
- Fine_Tune_VQ_w_P/log.txt +24 -0
- MIM_vbr/20240516_010714.log +53 -0
- MIM_vbr/checkpoint-399.pth +3 -0
- MIM_vbr/continue_tune_rope.log +514 -0
- MIM_vbr/events.out.tfevents.1714742678.t2vg-a100-G4-10.4118013.0 +3 -0
- MIM_vbr/log.txt +0 -0
- MIM_vbr_eval_lossless/20240518_044257.log +53 -0
- MIM_vbr_eval_lossless/20240521_120210.log +52 -0
- MIM_vbr_eval_lossless/20240521_121146.log +63 -0
- VQ_w_Prompt/20240602_123240.log +196 -0
- VQ_w_Prompt/checkpoint_69.pth.tar +3 -0
- __pycache__/model_vq.cpython-38.pyc +0 -0
- __pycache__/models_mage_codec.cpython-38.pyc +0 -0
- __pycache__/models_mage_codec_cls.cpython-38.pyc +0 -0
- __pycache__/models_mage_codec_full.cpython-38.pyc +0 -0
- __pycache__/models_mage_codec_high_resolu.cpython-38.pyc +0 -0
- __pycache__/models_mage_codec_hr_new.cpython-38.pyc +0 -0
- __pycache__/models_mage_codec_hs_w_p.cpython-38.pyc +0 -0
- __pycache__/models_mage_codec_rope.cpython-38.pyc +0 -0
- cal_upper_bound.py +287 -0
- ckpt_pretrained/vgg.pth +3 -0
- coco_eval.py +0 -0
- config/cal_upper_bound.yaml +13 -0
- config/eval copy.yaml +41 -0
- config/eval.yaml +41 -0
- config/eval_config.yaml +34 -0
- config/eval_high_res.yaml +41 -0
- config/eval_vq.yaml +38 -0
- config/fine_tune_vq_w_mask.yaml +39 -0
- config/train_det.yaml +38 -0
- config/train_im_cls.yaml +32 -0
- config/train_mim_config.yaml +33 -0
- config/vqgan.yaml +16 -0
- deploited/models_mage_codec_pos_linear.py +767 -0
- deploited/models_mage_codec_rm_pos.py +759 -0
- deploited/models_mage_codec_rope.py +767 -0
- deploited/random_with_rope.log +118 -0
- deploited/random_with_rope.pth +3 -0
- deploited/test_crop.py +68 -0
- deploited/train_vbr_codec_rope.py +673 -0
- detection.py +635 -0
- engine_finetune.py +128 -0
- engine_pretrain.py +72 -0
- environment.yaml +26 -0
.gitattributes copy
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.vscode/launch.json
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"version": "0.2.0",
|
| 3 |
+
"configurations": [
|
| 4 |
+
{
|
| 5 |
+
"name": "Train VBR Codec",
|
| 6 |
+
"type": "debugpy",
|
| 7 |
+
"request": "launch",
|
| 8 |
+
"program": "${workspaceFolder}/fine_tune_vq_w_mask.py",
|
| 9 |
+
"console": "integratedTerminal",
|
| 10 |
+
"args": [
|
| 11 |
+
"-c", "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/config/fine_tune_vq_w_mask.yaml",
|
| 12 |
+
"--world_size", "2"
|
| 13 |
+
],
|
| 14 |
+
"env": {
|
| 15 |
+
"CUDA_VISIBLE_DEVICES": "2,3",
|
| 16 |
+
"MASTER_ADDR": "localhost",
|
| 17 |
+
"MASTER_PORT": "12345",
|
| 18 |
+
"NODE_RANK": "0",
|
| 19 |
+
"NPROC_PER_NODE": "2",
|
| 20 |
+
"NNODES": "1"
|
| 21 |
+
},
|
| 22 |
+
"preLaunchTask": ""
|
| 23 |
+
}
|
| 24 |
+
]
|
| 25 |
+
}
|
Fine_Tune_VQ_w_P/20240604_030821.log
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2024-06-04 03:08:21,626 [INFO ] Logging file is /home/t2vg-a100-G4-10/project/qyp/mimc_rope/Fine_Tune_VQ_w_P//20240604_030821.log
|
| 2 |
+
2024-06-04 03:08:21,626 [INFO ] ======================= exp =======================
|
| 3 |
+
2024-06-04 03:08:21,626 [INFO ] config:config/vpt_default.yaml
|
| 4 |
+
2024-06-04 03:08:21,626 [INFO ] name:exp
|
| 5 |
+
2024-06-04 03:08:21,626 [INFO ] lr:None
|
| 6 |
+
2024-06-04 03:08:21,626 [INFO ] world_size:4
|
| 7 |
+
2024-06-04 03:08:21,626 [INFO ] local_rank:-1
|
| 8 |
+
2024-06-04 03:08:21,627 [INFO ] dist_on_itp:False
|
| 9 |
+
2024-06-04 03:08:21,627 [INFO ] dist_url:env://
|
| 10 |
+
2024-06-04 03:08:21,627 [INFO ] TEST:False
|
| 11 |
+
2024-06-04 03:08:21,627 [INFO ] model:mage_vit_base_patch16
|
| 12 |
+
2024-06-04 03:08:21,627 [INFO ] dataset:coco
|
| 13 |
+
2024-06-04 03:08:21,627 [INFO ] dataset_path:/home/t2vg-a100-G4-10/project/qyp/datasets/COCO
|
| 14 |
+
2024-06-04 03:08:21,627 [INFO ] kodak_path:/home/t2vg-a100-G4-10/project/qyp/datasets/test/
|
| 15 |
+
2024-06-04 03:08:21,627 [INFO ] eval_path:
|
| 16 |
+
2024-06-04 03:08:21,627 [INFO ] batch_size:64
|
| 17 |
+
2024-06-04 03:08:21,627 [INFO ] learning_rate:0.0001
|
| 18 |
+
2024-06-04 03:08:21,627 [INFO ] test_batch_size:1
|
| 19 |
+
2024-06-04 03:08:21,627 [INFO ] num_workers:4
|
| 20 |
+
2024-06-04 03:08:21,627 [INFO ] epochs:100
|
| 21 |
+
2024-06-04 03:08:21,627 [INFO ] blr:0.0001
|
| 22 |
+
2024-06-04 03:08:21,627 [INFO ] min_lr:5e-06
|
| 23 |
+
2024-06-04 03:08:21,627 [INFO ] accum_iter:1
|
| 24 |
+
2024-06-04 03:08:21,627 [INFO ] input_size:256
|
| 25 |
+
2024-06-04 03:08:21,627 [INFO ] weight_decay:0.05
|
| 26 |
+
2024-06-04 03:08:21,627 [INFO ] warmup_epochs:10
|
| 27 |
+
2024-06-04 03:08:21,627 [INFO ] mask_ratio_min:0.25
|
| 28 |
+
2024-06-04 03:08:21,627 [INFO ] mask_ratio_max:0.75
|
| 29 |
+
2024-06-04 03:08:21,627 [INFO ] grad_clip:3.0
|
| 30 |
+
2024-06-04 03:08:21,628 [INFO ] device:cuda
|
| 31 |
+
2024-06-04 03:08:21,628 [INFO ] distributed:True
|
| 32 |
+
2024-06-04 03:08:21,628 [INFO ] seed:0
|
| 33 |
+
2024-06-04 03:08:21,628 [INFO ] gpu_id:2
|
| 34 |
+
2024-06-04 03:08:21,628 [INFO ] resume:/home/t2vg-a100-G4-10/project/qyp/lossless/checkpoint-330.pth
|
| 35 |
+
2024-06-04 03:08:21,628 [INFO ] maskrcnn_path:/home/t2vg-a100-G4-10/project/qyp/ckpt/detectron2/mask_rcnn_r50_fpn_3x.pkl
|
| 36 |
+
2024-06-04 03:08:21,628 [INFO ] start_epoch:0
|
| 37 |
+
2024-06-04 03:08:21,628 [INFO ] pin_mem:True
|
| 38 |
+
2024-06-04 03:08:21,628 [INFO ] save:True
|
| 39 |
+
2024-06-04 03:08:21,628 [INFO ] exp_name:Fine_Tune_VQ_w_P
|
| 40 |
+
2024-06-04 03:08:21,628 [INFO ] root:/home/t2vg-a100-G4-10/project/qyp/mimc_rope
|
| 41 |
+
2024-06-04 03:08:21,628 [INFO ] output_dir:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/Fine_Tune_VQ_w_P
|
| 42 |
+
2024-06-04 03:08:21,628 [INFO ] log_dir:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/Fine_Tune_VQ_w_P
|
| 43 |
+
2024-06-04 03:08:21,628 [INFO ] strart_epoch:0
|
| 44 |
+
2024-06-04 03:08:21,628 [INFO ] cuda:True
|
| 45 |
+
2024-06-04 03:08:21,628 [INFO ] rank:0
|
| 46 |
+
2024-06-04 03:08:21,628 [INFO ] gpu:0
|
| 47 |
+
2024-06-04 03:08:21,628 [INFO ] dist_backend:nccl
|
| 48 |
+
2024-06-04 03:08:21,628 [INFO ] ===================================================
|
| 49 |
+
2024-06-04 03:10:20,905 [INFO ] -1|taskloss:2.49146|mask ratio:0.75||bpp:0.01396|mask:0.00003|psnr:14.06076|lpips:0.46229|dists:0.19295
|
| 50 |
+
2024-06-04 03:10:48,101 [INFO ] -1|taskloss:1.68947|mask ratio:0.65||bpp:0.01952|mask:0.00003|psnr:15.15691|lpips:0.41068|dists:0.16344
|
| 51 |
+
2024-06-04 03:11:14,835 [INFO ] -1|taskloss:1.33339|mask ratio:0.55||bpp:0.02512|mask:0.00003|psnr:15.94234|lpips:0.37990|dists:0.14601
|
| 52 |
+
2024-06-04 03:11:41,295 [INFO ] -1|taskloss:1.03034|mask ratio:0.45||bpp:0.03068|mask:0.00003|psnr:16.56941|lpips:0.35252|dists:0.13884
|
| 53 |
+
2024-06-04 03:11:41,296 [INFO ] Test loss=1.636163592338562
|
| 54 |
+
2024-06-04 03:28:26,377 [INFO ] 0|taskloss:2.33836|mask ratio:0.75||bpp:0.01397|mask:0.00003|psnr:14.29709|lpips:0.46480|dists:0.21399
|
| 55 |
+
2024-06-04 03:28:53,053 [INFO ] 0|taskloss:1.62612|mask ratio:0.65||bpp:0.01953|mask:0.00003|psnr:15.42644|lpips:0.42104|dists:0.19210
|
| 56 |
+
2024-06-04 03:29:20,005 [INFO ] 0|taskloss:1.29490|mask ratio:0.55||bpp:0.02512|mask:0.00003|psnr:15.93270|lpips:0.38919|dists:0.17758
|
| 57 |
+
2024-06-04 03:29:47,006 [INFO ] 0|taskloss:0.99774|mask ratio:0.45||bpp:0.03067|mask:0.00003|psnr:16.65782|lpips:0.36090|dists:0.16477
|
| 58 |
+
2024-06-04 03:29:47,006 [INFO ] Test loss=1.973321333527565
|
| 59 |
+
2024-06-04 03:29:47,960 [INFO ] Saving BEST checkpoint!
|
| 60 |
+
2024-06-04 03:46:15,723 [INFO ] 1|taskloss:2.27418|mask ratio:0.75||bpp:0.01398|mask:0.00003|psnr:14.28722|lpips:0.47362|dists:0.22182
|
| 61 |
+
2024-06-04 03:46:42,375 [INFO ] 1|taskloss:1.65933|mask ratio:0.65||bpp:0.01952|mask:0.00003|psnr:15.20894|lpips:0.41849|dists:0.19428
|
| 62 |
+
2024-06-04 03:47:09,078 [INFO ] 1|taskloss:1.28544|mask ratio:0.55||bpp:0.02511|mask:0.00003|psnr:15.95363|lpips:0.39304|dists:0.18215
|
| 63 |
+
2024-06-04 03:47:35,953 [INFO ] 1|taskloss:1.02331|mask ratio:0.45||bpp:0.03067|mask:0.00003|psnr:16.67013|lpips:0.35942|dists:0.17008
|
| 64 |
+
2024-06-04 03:47:35,954 [INFO ] Test loss=2.0538930036127567
|
| 65 |
+
2024-06-04 04:04:02,717 [INFO ] 2|taskloss:2.25254|mask ratio:0.75||bpp:0.01397|mask:0.00003|psnr:13.97699|lpips:0.47056|dists:0.22627
|
| 66 |
+
2024-06-04 04:04:29,767 [INFO ] 2|taskloss:1.66452|mask ratio:0.65||bpp:0.01954|mask:0.00003|psnr:15.12406|lpips:0.42705|dists:0.20062
|
| 67 |
+
2024-06-04 04:04:56,853 [INFO ] 2|taskloss:1.26301|mask ratio:0.55||bpp:0.02512|mask:0.00003|psnr:15.65730|lpips:0.39007|dists:0.18716
|
| 68 |
+
2024-06-04 04:05:23,967 [INFO ] 2|taskloss:1.03645|mask ratio:0.45||bpp:0.03068|mask:0.00003|psnr:16.02926|lpips:0.36933|dists:0.18181
|
| 69 |
+
2024-06-04 04:05:23,968 [INFO ] Test loss=2.067601718939841
|
| 70 |
+
2024-06-04 04:21:49,690 [INFO ] 3|taskloss:2.28804|mask ratio:0.75||bpp:0.01397|mask:0.00003|psnr:14.03534|lpips:0.48001|dists:0.22667
|
| 71 |
+
2024-06-04 04:22:16,508 [INFO ] 3|taskloss:1.66096|mask ratio:0.65||bpp:0.01952|mask:0.00003|psnr:15.18743|lpips:0.42347|dists:0.20229
|
| 72 |
+
2024-06-04 04:22:43,283 [INFO ] 3|taskloss:1.28439|mask ratio:0.55||bpp:0.02512|mask:0.00003|psnr:15.80286|lpips:0.39074|dists:0.18586
|
| 73 |
+
2024-06-04 04:23:09,940 [INFO ] 3|taskloss:1.01509|mask ratio:0.45||bpp:0.03067|mask:0.00003|psnr:16.39696|lpips:0.37266|dists:0.18077
|
| 74 |
+
2024-06-04 04:23:09,941 [INFO ] Test loss=2.0790173208806664
|
| 75 |
+
2024-06-04 04:39:36,248 [INFO ] 4|taskloss:2.17187|mask ratio:0.75||bpp:0.01397|mask:0.00003|psnr:14.04568|lpips:0.46708|dists:0.21667
|
| 76 |
+
2024-06-04 04:40:03,220 [INFO ] 4|taskloss:1.68821|mask ratio:0.65||bpp:0.01952|mask:0.00003|psnr:15.43250|lpips:0.42161|dists:0.19500
|
| 77 |
+
2024-06-04 04:40:29,994 [INFO ] 4|taskloss:1.23894|mask ratio:0.55||bpp:0.02511|mask:0.00003|psnr:15.90158|lpips:0.38726|dists:0.18233
|
| 78 |
+
2024-06-04 04:40:56,799 [INFO ] 4|taskloss:1.01278|mask ratio:0.45||bpp:0.03067|mask:0.00003|psnr:16.31897|lpips:0.36575|dists:0.17080
|
| 79 |
+
2024-06-04 04:40:56,799 [INFO ] Test loss=2.047704319062177
|
| 80 |
+
2024-06-04 04:57:23,679 [INFO ] 5|taskloss:2.24287|mask ratio:0.75||bpp:0.01395|mask:0.00003|psnr:14.29049|lpips:0.46545|dists:0.21852
|
| 81 |
+
2024-06-04 04:57:50,801 [INFO ] 5|taskloss:1.62839|mask ratio:0.65||bpp:0.01952|mask:0.00003|psnr:15.02618|lpips:0.41894|dists:0.19633
|
| 82 |
+
2024-06-04 04:58:17,569 [INFO ] 5|taskloss:1.29102|mask ratio:0.55||bpp:0.02511|mask:0.00003|psnr:15.74514|lpips:0.39395|dists:0.18185
|
| 83 |
+
2024-06-04 04:58:44,510 [INFO ] 5|taskloss:1.01440|mask ratio:0.45||bpp:0.03066|mask:0.00003|psnr:16.38981|lpips:0.36329|dists:0.17676
|
| 84 |
+
2024-06-04 04:58:44,511 [INFO ] Test loss=2.0560930917417863
|
| 85 |
+
2024-06-04 05:15:10,098 [INFO ] 6|taskloss:2.32205|mask ratio:0.75||bpp:0.01398|mask:0.00003|psnr:14.04229|lpips:0.46922|dists:0.21845
|
| 86 |
+
2024-06-04 05:15:36,966 [INFO ] 6|taskloss:1.60729|mask ratio:0.65||bpp:0.01952|mask:0.00003|psnr:14.98244|lpips:0.42790|dists:0.20636
|
| 87 |
+
2024-06-04 05:16:03,975 [INFO ] 6|taskloss:1.26320|mask ratio:0.55||bpp:0.02509|mask:0.00003|psnr:15.94796|lpips:0.38764|dists:0.18432
|
| 88 |
+
2024-06-04 05:16:30,771 [INFO ] 6|taskloss:1.03816|mask ratio:0.45||bpp:0.03067|mask:0.00003|psnr:16.29310|lpips:0.37118|dists:0.17678
|
| 89 |
+
2024-06-04 05:16:30,771 [INFO ] Test loss=2.0717015551963414
|
| 90 |
+
2024-06-04 05:32:57,502 [INFO ] 7|taskloss:2.36936|mask ratio:0.75||bpp:0.01397|mask:0.00003|psnr:14.30189|lpips:0.46484|dists:0.21939
|
| 91 |
+
2024-06-04 05:33:24,434 [INFO ] 7|taskloss:1.63491|mask ratio:0.65||bpp:0.01952|mask:0.00003|psnr:15.08945|lpips:0.42373|dists:0.19809
|
| 92 |
+
2024-06-04 05:33:51,433 [INFO ] 7|taskloss:1.23911|mask ratio:0.55||bpp:0.02511|mask:0.00003|psnr:15.95378|lpips:0.38986|dists:0.18503
|
| 93 |
+
2024-06-04 05:34:18,373 [INFO ] 7|taskloss:0.99703|mask ratio:0.45||bpp:0.03068|mask:0.00003|psnr:16.33615|lpips:0.36766|dists:0.17350
|
| 94 |
+
2024-06-04 05:34:18,373 [INFO ] Test loss=2.0780272555211923
|
| 95 |
+
2024-06-04 05:50:45,587 [INFO ] 8|taskloss:2.23844|mask ratio:0.75||bpp:0.01397|mask:0.00003|psnr:14.26759|lpips:0.46414|dists:0.22619
|
| 96 |
+
2024-06-04 05:51:12,540 [INFO ] 8|taskloss:1.66498|mask ratio:0.65||bpp:0.01953|mask:0.00003|psnr:14.91135|lpips:0.43537|dists:0.21694
|
| 97 |
+
2024-06-04 05:51:39,386 [INFO ] 8|taskloss:1.27593|mask ratio:0.55||bpp:0.02510|mask:0.00003|psnr:15.52610|lpips:0.40321|dists:0.20067
|
| 98 |
+
2024-06-04 05:52:05,972 [INFO ] 8|taskloss:1.03960|mask ratio:0.45||bpp:0.03066|mask:0.00003|psnr:15.80372|lpips:0.38153|dists:0.19366
|
| 99 |
+
2024-06-04 05:52:05,973 [INFO ] Test loss=2.0742442923749422
|
| 100 |
+
2024-06-04 06:08:33,064 [INFO ] 9|taskloss:2.32985|mask ratio:0.75||bpp:0.01397|mask:0.00003|psnr:14.33968|lpips:0.47202|dists:0.22093
|
| 101 |
+
2024-06-04 06:08:59,961 [INFO ] 9|taskloss:1.62963|mask ratio:0.65||bpp:0.01952|mask:0.00003|psnr:15.15994|lpips:0.42402|dists:0.19959
|
| 102 |
+
2024-06-04 06:09:26,718 [INFO ] 9|taskloss:1.25288|mask ratio:0.55||bpp:0.02510|mask:0.00003|psnr:15.62853|lpips:0.38951|dists:0.17977
|
| 103 |
+
2024-06-04 06:09:53,638 [INFO ] 9|taskloss:1.00600|mask ratio:0.45||bpp:0.03066|mask:0.00003|psnr:16.26209|lpips:0.36578|dists:0.17169
|
| 104 |
+
2024-06-04 06:09:53,639 [INFO ] Test loss=2.0731508512786263
|
| 105 |
+
2024-06-04 06:26:19,204 [INFO ] 10|taskloss:2.34149|mask ratio:0.75||bpp:0.01398|mask:0.00003|psnr:14.26575|lpips:0.47004|dists:0.22524
|
| 106 |
+
2024-06-04 06:26:45,967 [INFO ] 10|taskloss:1.62147|mask ratio:0.65||bpp:0.01953|mask:0.00003|psnr:15.09978|lpips:0.42139|dists:0.19693
|
| 107 |
+
2024-06-04 06:27:12,796 [INFO ] 10|taskloss:1.27020|mask ratio:0.55||bpp:0.02510|mask:0.00003|psnr:15.77150|lpips:0.38540|dists:0.17910
|
| 108 |
+
2024-06-04 06:27:39,548 [INFO ] 10|taskloss:1.02932|mask ratio:0.45||bpp:0.03067|mask:0.00003|psnr:16.26428|lpips:0.36994|dists:0.17073
|
| 109 |
+
2024-06-04 06:27:39,549 [INFO ] Test loss=2.0839089696479647
|
| 110 |
+
2024-06-04 06:44:06,036 [INFO ] 11|taskloss:2.24896|mask ratio:0.75||bpp:0.01397|mask:0.00003|psnr:14.07473|lpips:0.47022|dists:0.21116
|
| 111 |
+
2024-06-04 06:44:32,940 [INFO ] 11|taskloss:1.69199|mask ratio:0.65||bpp:0.01952|mask:0.00003|psnr:15.18235|lpips:0.42282|dists:0.18916
|
| 112 |
+
2024-06-04 06:44:59,784 [INFO ] 11|taskloss:1.26253|mask ratio:0.55||bpp:0.02512|mask:0.00003|psnr:16.03878|lpips:0.38759|dists:0.17844
|
| 113 |
+
2024-06-04 06:45:26,464 [INFO ] 11|taskloss:1.01063|mask ratio:0.45||bpp:0.03067|mask:0.00003|psnr:16.42851|lpips:0.35960|dists:0.16572
|
| 114 |
+
2024-06-04 06:45:26,465 [INFO ] Test loss=2.0745057896989785
|
| 115 |
+
2024-06-04 07:01:51,689 [INFO ] 12|taskloss:2.26208|mask ratio:0.75||bpp:0.01398|mask:0.00003|psnr:14.23028|lpips:0.46908|dists:0.22384
|
| 116 |
+
2024-06-04 07:02:18,690 [INFO ] 12|taskloss:1.60826|mask ratio:0.65||bpp:0.01953|mask:0.00003|psnr:15.05205|lpips:0.42262|dists:0.19667
|
| 117 |
+
2024-06-04 07:02:45,653 [INFO ] 12|taskloss:1.24774|mask ratio:0.55||bpp:0.02510|mask:0.00003|psnr:15.88546|lpips:0.38749|dists:0.18423
|
| 118 |
+
2024-06-04 07:03:12,706 [INFO ] 12|taskloss:1.00849|mask ratio:0.45||bpp:0.03068|mask:0.00003|psnr:16.33344|lpips:0.36366|dists:0.17469
|
| 119 |
+
2024-06-04 07:03:12,707 [INFO ] Test loss=2.0502657253597727
|
| 120 |
+
2024-06-04 07:19:36,537 [INFO ] 13|taskloss:2.27210|mask ratio:0.75||bpp:0.01396|mask:0.00003|psnr:14.07016|lpips:0.47514|dists:0.21749
|
| 121 |
+
2024-06-04 07:20:03,327 [INFO ] 13|taskloss:1.62562|mask ratio:0.65||bpp:0.01954|mask:0.00003|psnr:15.32130|lpips:0.41837|dists:0.19381
|
| 122 |
+
2024-06-04 07:20:30,312 [INFO ] 13|taskloss:1.20922|mask ratio:0.55||bpp:0.02510|mask:0.00003|psnr:16.01362|lpips:0.38459|dists:0.17793
|
| 123 |
+
2024-06-04 07:20:57,097 [INFO ] 13|taskloss:1.02460|mask ratio:0.45||bpp:0.03067|mask:0.00003|psnr:16.43850|lpips:0.36615|dists:0.17285
|
| 124 |
+
2024-06-04 07:20:57,097 [INFO ] Test loss=2.045453323897087
|
| 125 |
+
2024-06-04 07:37:23,230 [INFO ] 14|taskloss:2.26558|mask ratio:0.75||bpp:0.01397|mask:0.00003|psnr:14.31213|lpips:0.47121|dists:0.22154
|
| 126 |
+
2024-06-04 07:37:50,002 [INFO ] 14|taskloss:1.65192|mask ratio:0.65||bpp:0.01953|mask:0.00003|psnr:15.25740|lpips:0.43203|dists:0.20555
|
| 127 |
+
2024-06-04 07:38:16,697 [INFO ] 14|taskloss:1.28469|mask ratio:0.55||bpp:0.02511|mask:0.00003|psnr:15.82785|lpips:0.39241|dists:0.18424
|
| 128 |
+
2024-06-04 07:38:43,483 [INFO ] 14|taskloss:0.98623|mask ratio:0.45||bpp:0.03068|mask:0.00003|psnr:16.52078|lpips:0.36753|dists:0.17625
|
| 129 |
+
2024-06-04 07:38:43,484 [INFO ] Test loss=2.058469656600645
|
| 130 |
+
2024-06-04 07:55:08,283 [INFO ] 15|taskloss:2.34680|mask ratio:0.75||bpp:0.01397|mask:0.00003|psnr:14.18895|lpips:0.46331|dists:0.22039
|
| 131 |
+
2024-06-04 07:55:34,925 [INFO ] 15|taskloss:1.61486|mask ratio:0.65||bpp:0.01953|mask:0.00003|psnr:15.15459|lpips:0.42977|dists:0.20644
|
| 132 |
+
2024-06-04 07:56:01,738 [INFO ] 15|taskloss:1.21967|mask ratio:0.55||bpp:0.02512|mask:0.00003|psnr:15.76448|lpips:0.38905|dists:0.18471
|
| 133 |
+
2024-06-04 07:56:28,420 [INFO ] 15|taskloss:1.01635|mask ratio:0.45||bpp:0.03067|mask:0.00003|psnr:16.40507|lpips:0.36936|dists:0.17892
|
| 134 |
+
2024-06-04 07:56:28,421 [INFO ] Test loss=2.064034671459121
|
| 135 |
+
2024-06-04 08:12:55,104 [INFO ] 16|taskloss:2.24783|mask ratio:0.75||bpp:0.01398|mask:0.00003|psnr:14.38436|lpips:0.46735|dists:0.21890
|
| 136 |
+
2024-06-04 08:13:22,005 [INFO ] 16|taskloss:1.68004|mask ratio:0.65||bpp:0.01954|mask:0.00003|psnr:15.20924|lpips:0.42151|dists:0.19677
|
| 137 |
+
2024-06-04 08:13:48,870 [INFO ] 16|taskloss:1.22699|mask ratio:0.55||bpp:0.02511|mask:0.00003|psnr:16.00963|lpips:0.38648|dists:0.18326
|
| 138 |
+
2024-06-04 08:14:15,705 [INFO ] 16|taskloss:1.00658|mask ratio:0.45||bpp:0.03067|mask:0.00003|psnr:16.69925|lpips:0.36080|dists:0.17115
|
| 139 |
+
2024-06-04 08:14:15,706 [INFO ] Test loss=2.0563709961557195
|
| 140 |
+
2024-06-04 08:30:41,819 [INFO ] 17|taskloss:2.26464|mask ratio:0.75||bpp:0.01397|mask:0.00003|psnr:14.51690|lpips:0.46617|dists:0.21955
|
| 141 |
+
2024-06-04 08:31:08,480 [INFO ] 17|taskloss:1.69014|mask ratio:0.65||bpp:0.01954|mask:0.00003|psnr:15.10275|lpips:0.42308|dists:0.19810
|
| 142 |
+
2024-06-04 08:31:35,927 [INFO ] 17|taskloss:1.24362|mask ratio:0.55||bpp:0.02510|mask:0.00003|psnr:15.88523|lpips:0.39605|dists:0.19024
|
| 143 |
+
2024-06-04 08:32:04,307 [INFO ] 17|taskloss:1.00460|mask ratio:0.45||bpp:0.03068|mask:0.00003|psnr:16.62445|lpips:0.36797|dists:0.17678
|
| 144 |
+
2024-06-04 08:32:04,308 [INFO ] Test loss=2.064842736402745
|
| 145 |
+
2024-06-04 08:48:38,838 [INFO ] 18|taskloss:2.21708|mask ratio:0.75||bpp:0.01399|mask:0.00003|psnr:14.15950|lpips:0.47041|dists:0.22174
|
| 146 |
+
2024-06-04 08:49:06,574 [INFO ] 18|taskloss:1.61804|mask ratio:0.65||bpp:0.01951|mask:0.00003|psnr:15.24594|lpips:0.42104|dists:0.19944
|
| 147 |
+
2024-06-04 08:49:33,943 [INFO ] 18|taskloss:1.25278|mask ratio:0.55||bpp:0.02511|mask:0.00003|psnr:15.76423|lpips:0.39182|dists:0.18584
|
| 148 |
+
2024-06-04 08:50:01,283 [INFO ] 18|taskloss:1.00574|mask ratio:0.45||bpp:0.03067|mask:0.00003|psnr:16.33763|lpips:0.36889|dists:0.17756
|
| 149 |
+
2024-06-04 08:50:01,284 [INFO ] Test loss=2.039618709603845
|
| 150 |
+
2024-06-04 09:06:34,348 [INFO ] 19|taskloss:2.28699|mask ratio:0.75||bpp:0.01397|mask:0.00003|psnr:14.06225|lpips:0.47773|dists:0.22275
|
| 151 |
+
2024-06-04 09:07:01,971 [INFO ] 19|taskloss:1.62656|mask ratio:0.65||bpp:0.01953|mask:0.00003|psnr:15.01703|lpips:0.42262|dists:0.19507
|
| 152 |
+
2024-06-04 09:07:29,380 [INFO ] 19|taskloss:1.26701|mask ratio:0.55||bpp:0.02511|mask:0.00003|psnr:16.04263|lpips:0.38800|dists:0.18259
|
| 153 |
+
2024-06-04 09:07:56,604 [INFO ] 19|taskloss:1.02889|mask ratio:0.45||bpp:0.03066|mask:0.00003|psnr:16.22125|lpips:0.36636|dists:0.17257
|
| 154 |
+
2024-06-04 09:07:56,605 [INFO ] Test loss=2.0622663147548583
|
| 155 |
+
2024-06-04 09:24:25,377 [INFO ] 20|taskloss:2.24449|mask ratio:0.75||bpp:0.01396|mask:0.00003|psnr:14.10004|lpips:0.46875|dists:0.22301
|
| 156 |
+
2024-06-04 09:24:52,350 [INFO ] 20|taskloss:1.67153|mask ratio:0.65||bpp:0.01953|mask:0.00003|psnr:15.36281|lpips:0.42634|dists:0.20213
|
| 157 |
+
2024-06-04 09:25:19,089 [INFO ] 20|taskloss:1.27836|mask ratio:0.55||bpp:0.02511|mask:0.00003|psnr:15.84954|lpips:0.39339|dists:0.18800
|
| 158 |
+
2024-06-04 09:25:45,995 [INFO ] 20|taskloss:1.01734|mask ratio:0.45||bpp:0.03067|mask:0.00003|psnr:16.35971|lpips:0.36891|dists:0.17725
|
| 159 |
+
2024-06-04 09:25:45,996 [INFO ] Test loss=2.0684968145514464
|
| 160 |
+
2024-06-04 09:42:14,168 [INFO ] 21|taskloss:2.27269|mask ratio:0.75||bpp:0.01398|mask:0.00003|psnr:14.36867|lpips:0.46679|dists:0.21023
|
| 161 |
+
2024-06-04 09:42:41,882 [INFO ] 21|taskloss:1.59626|mask ratio:0.65||bpp:0.01952|mask:0.00003|psnr:15.07057|lpips:0.42891|dists:0.19607
|
| 162 |
+
2024-06-04 09:43:09,575 [INFO ] 21|taskloss:1.23144|mask ratio:0.55||bpp:0.02511|mask:0.00003|psnr:16.13393|lpips:0.38757|dists:0.18141
|
| 163 |
+
2024-06-04 09:43:37,232 [INFO ] 21|taskloss:1.02276|mask ratio:0.45||bpp:0.03068|mask:0.00003|psnr:16.52278|lpips:0.36648|dists:0.17367
|
| 164 |
+
2024-06-04 09:43:37,233 [INFO ] Test loss=2.0479125060594314
|
| 165 |
+
2024-06-04 10:00:03,966 [INFO ] 22|taskloss:2.29455|mask ratio:0.75||bpp:0.01397|mask:0.00003|psnr:14.29998|lpips:0.46570|dists:0.21669
|
| 166 |
+
2024-06-04 10:00:32,019 [INFO ] 22|taskloss:1.68339|mask ratio:0.65||bpp:0.01951|mask:0.00003|psnr:15.14034|lpips:0.42580|dists:0.20284
|
| 167 |
+
2024-06-04 10:00:59,949 [INFO ] 22|taskloss:1.26892|mask ratio:0.55||bpp:0.02510|mask:0.00003|psnr:16.01256|lpips:0.39080|dists:0.18718
|
| 168 |
+
2024-06-04 10:01:27,782 [INFO ] 22|taskloss:1.02441|mask ratio:0.45||bpp:0.03067|mask:0.00003|psnr:16.42309|lpips:0.36846|dists:0.17773
|
| 169 |
+
2024-06-04 10:01:27,783 [INFO ] Test loss=2.0797956952938494
|
| 170 |
+
2024-06-04 10:17:48,799 [INFO ] 23|taskloss:2.28370|mask ratio:0.75||bpp:0.01397|mask:0.00003|psnr:14.41319|lpips:0.47015|dists:0.21995
|
| 171 |
+
2024-06-04 10:18:16,236 [INFO ] 23|taskloss:1.61646|mask ratio:0.65||bpp:0.01952|mask:0.00003|psnr:15.18517|lpips:0.42225|dists:0.20285
|
| 172 |
+
2024-06-04 10:18:43,717 [INFO ] 23|taskloss:1.26966|mask ratio:0.55||bpp:0.02511|mask:0.00003|psnr:16.16500|lpips:0.39132|dists:0.18260
|
| 173 |
+
2024-06-04 10:19:11,303 [INFO ] 23|taskloss:1.00680|mask ratio:0.45||bpp:0.03068|mask:0.00003|psnr:16.58024|lpips:0.36382|dists:0.17243
|
| 174 |
+
2024-06-04 10:19:11,304 [INFO ] Test loss=2.0641073229285345
|
Fine_Tune_VQ_w_P/checkpoint-20.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:54885516ebb1f758a23f3f74c232ecf58376c561d0ed59fe51607b2103fdbfc8
|
| 3 |
+
size 867021559
|
Fine_Tune_VQ_w_P/checkpoint-last.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1af8fab142de32e2b7a9f341dc0eef48b6f5f2c91cfd42840697c817240cf78a
|
| 3 |
+
size 867021559
|
Fine_Tune_VQ_w_P/events.out.tfevents.1717470501.t2vg-a100-G4-10.976065.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:88722834393587807a9428e2fb5b03c41b660c9fed2948159db4e2b773bb74a1
|
| 3 |
+
size 999224
|
Fine_Tune_VQ_w_P/log.txt
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"train_lr": 4.9890829694323135e-06, "train_loss": 1.243749070896332, "train_mask_ratio": 0.5038230929312227, "epoch": 0}
|
| 2 |
+
{"train_lr": 1.4989082969432316e-05, "train_loss": 1.215295555061909, "train_mask_ratio": 0.502963803220524, "epoch": 1}
|
| 3 |
+
{"train_lr": 2.498908296943232e-05, "train_loss": 1.2075574521349508, "train_mask_ratio": 0.5024819186681223, "epoch": 2}
|
| 4 |
+
{"train_lr": 3.498908296943235e-05, "train_loss": 1.2206204718702744, "train_mask_ratio": 0.50749479735262, "epoch": 3}
|
| 5 |
+
{"train_lr": 4.498908296943233e-05, "train_loss": 1.2098940466754302, "train_mask_ratio": 0.5005287936681223, "epoch": 4}
|
| 6 |
+
{"train_lr": 5.498908296943233e-05, "train_loss": 1.190811416142372, "train_mask_ratio": 0.4951235841975983, "epoch": 5}
|
| 7 |
+
{"train_lr": 6.49890829694323e-05, "train_loss": 1.21159366228695, "train_mask_ratio": 0.5036205308406113, "epoch": 6}
|
| 8 |
+
{"train_lr": 7.498908296943243e-05, "train_loss": 1.2252963672038248, "train_mask_ratio": 0.508678186408297, "epoch": 7}
|
| 9 |
+
{"train_lr": 8.498908296943224e-05, "train_loss": 1.1960869143233028, "train_mask_ratio": 0.49665879162117904, "epoch": 8}
|
| 10 |
+
{"train_lr": 9.49890829694323e-05, "train_loss": 1.1917035718932423, "train_mask_ratio": 0.49613426241812225, "epoch": 9}
|
| 11 |
+
{"train_lr": 9.999038593559283e-05, "train_loss": 1.2097623654718483, "train_mask_ratio": 0.5052154407751092, "epoch": 10}
|
| 12 |
+
{"train_lr": 9.993258939408399e-05, "train_loss": 1.1992728840966411, "train_mask_ratio": 0.4989019002456332, "epoch": 11}
|
| 13 |
+
{"train_lr": 9.981700354882874e-05, "train_loss": 1.2060506974050051, "train_mask_ratio": 0.50310026610262, "epoch": 12}
|
| 14 |
+
{"train_lr": 9.96437692233749e-05, "train_loss": 1.2169636118685314, "train_mask_ratio": 0.5097805506277293, "epoch": 13}
|
| 15 |
+
{"train_lr": 9.94130974770633e-05, "train_loss": 1.1891578605081317, "train_mask_ratio": 0.49562039437772926, "epoch": 14}
|
| 16 |
+
{"train_lr": 9.912526934788473e-05, "train_loss": 1.2026429740817786, "train_mask_ratio": 0.5030107123362445, "epoch": 15}
|
| 17 |
+
{"train_lr": 9.878063551007767e-05, "train_loss": 1.1993349365336927, "train_mask_ratio": 0.5014925627729258, "epoch": 16}
|
| 18 |
+
{"train_lr": 9.837961584688702e-05, "train_loss": 1.1961212191063764, "train_mask_ratio": 0.5002942480895196, "epoch": 17}
|
| 19 |
+
{"train_lr": 9.792269893900009e-05, "train_loss": 1.1912841638753509, "train_mask_ratio": 0.49881874317685587, "epoch": 18}
|
| 20 |
+
{"train_lr": 9.741044146928648e-05, "train_loss": 1.1925643646691044, "train_mask_ratio": 0.4981215031386463, "epoch": 19}
|
| 21 |
+
{"train_lr": 9.68434675445656e-05, "train_loss": 1.1990113721013589, "train_mask_ratio": 0.5008848765010917, "epoch": 20}
|
| 22 |
+
{"train_lr": 9.622246793522915e-05, "train_loss": 1.2057850010790678, "train_mask_ratio": 0.5041706468340611, "epoch": 21}
|
| 23 |
+
{"train_lr": 9.554819923364366e-05, "train_loss": 1.2057653185972481, "train_mask_ratio": 0.5051941184497817, "epoch": 22}
|
| 24 |
+
{"train_lr": 9.482148293235867e-05, "train_loss": 1.1847582247645054, "train_mask_ratio": 0.49711722161572053, "epoch": 23}
|
MIM_vbr/20240516_010714.log
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2024-05-16 01:07:14,223 [INFO ] Logging file is /home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr//20240516_010714.log
|
| 2 |
+
2024-05-16 01:07:14,223 [INFO ] ======================= exp =======================
|
| 3 |
+
2024-05-16 01:07:14,223 [INFO ] config:config/vpt_default.yaml
|
| 4 |
+
2024-05-16 01:07:14,223 [INFO ] name:exp
|
| 5 |
+
2024-05-16 01:07:14,223 [INFO ] lr:None
|
| 6 |
+
2024-05-16 01:07:14,223 [INFO ] world_size:1
|
| 7 |
+
2024-05-16 01:07:14,223 [INFO ] local_rank:-1
|
| 8 |
+
2024-05-16 01:07:14,223 [INFO ] dist_on_itp:False
|
| 9 |
+
2024-05-16 01:07:14,223 [INFO ] dist_url:env://
|
| 10 |
+
2024-05-16 01:07:14,223 [INFO ] TEST:False
|
| 11 |
+
2024-05-16 01:07:14,224 [INFO ] model:mage_vit_base_patch16
|
| 12 |
+
2024-05-16 01:07:14,224 [INFO ] dataset:coco
|
| 13 |
+
2024-05-16 01:07:14,224 [INFO ] dataset_path:/home/t2vg-a100-G4-10/project/qyp/datasets/COCO
|
| 14 |
+
2024-05-16 01:07:14,224 [INFO ] kodak_path:/home/t2vg-a100-G4-10/project/qyp/datasets/COCO/val2017/
|
| 15 |
+
2024-05-16 01:07:14,224 [INFO ] eval_path:
|
| 16 |
+
2024-05-16 01:07:14,224 [INFO ] batch_size:32
|
| 17 |
+
2024-05-16 01:07:14,224 [INFO ] test_batch_size:1
|
| 18 |
+
2024-05-16 01:07:14,224 [INFO ] num_workers:4
|
| 19 |
+
2024-05-16 01:07:14,224 [INFO ] epochs:400
|
| 20 |
+
2024-05-16 01:07:14,224 [INFO ] blr:0.00015
|
| 21 |
+
2024-05-16 01:07:14,224 [INFO ] min_lr:5e-06
|
| 22 |
+
2024-05-16 01:07:14,224 [INFO ] accum_iter:1
|
| 23 |
+
2024-05-16 01:07:14,224 [INFO ] input_size:256
|
| 24 |
+
2024-05-16 01:07:14,224 [INFO ] weight_decay:0.05
|
| 25 |
+
2024-05-16 01:07:14,224 [INFO ] warmup_epochs:20
|
| 26 |
+
2024-05-16 01:07:14,224 [INFO ] mask_ratio_min:0.5
|
| 27 |
+
2024-05-16 01:07:14,224 [INFO ] mask_ratio_max:0.9
|
| 28 |
+
2024-05-16 01:07:14,224 [INFO ] mask_ratio_mu:0.55
|
| 29 |
+
2024-05-16 01:07:14,224 [INFO ] mask_ratio_std:0.25
|
| 30 |
+
2024-05-16 01:07:14,225 [INFO ] grad_clip:3.0
|
| 31 |
+
2024-05-16 01:07:14,225 [INFO ] device:cuda
|
| 32 |
+
2024-05-16 01:07:14,225 [INFO ] distributed:True
|
| 33 |
+
2024-05-16 01:07:14,225 [INFO ] seed:0
|
| 34 |
+
2024-05-16 01:07:14,225 [INFO ] resume:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/random_cls_ori_pos_embed.pth
|
| 35 |
+
2024-05-16 01:07:14,225 [INFO ] start_epoch:0
|
| 36 |
+
2024-05-16 01:07:14,225 [INFO ] pin_mem:True
|
| 37 |
+
2024-05-16 01:07:14,225 [INFO ] save:True
|
| 38 |
+
2024-05-16 01:07:14,225 [INFO ] exp_name:MIM_vbr
|
| 39 |
+
2024-05-16 01:07:14,225 [INFO ] root:/home/t2vg-a100-G4-10/project/qyp/mimc_rope
|
| 40 |
+
2024-05-16 01:07:14,225 [INFO ] output_dir:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_high_resolu_eval
|
| 41 |
+
2024-05-16 01:07:14,225 [INFO ] log_dir:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_high_resolu_eval
|
| 42 |
+
2024-05-16 01:07:14,225 [INFO ] strart_epoch:0
|
| 43 |
+
2024-05-16 01:07:14,225 [INFO ] cuda:True
|
| 44 |
+
2024-05-16 01:07:14,225 [INFO ] rank:0
|
| 45 |
+
2024-05-16 01:07:14,225 [INFO ] gpu:0
|
| 46 |
+
2024-05-16 01:07:14,225 [INFO ] dist_backend:nccl
|
| 47 |
+
2024-05-16 01:07:14,225 [INFO ] ===================================================
|
| 48 |
+
2024-05-16 04:50:34,512 [INFO ] -1|bpp:0.06320|mask:0.01002|mask_ratio:0.05|psnr:19.27399|lpips:0.29888|dists:0.10206|Test loss:3932.46558
|
| 49 |
+
2024-05-16 08:29:19,640 [INFO ] -1|bpp:0.05797|mask:0.01040|mask_ratio:0.15|psnr:18.67500|lpips:0.31491|dists:0.10963|Test loss:1802.33093
|
| 50 |
+
2024-05-16 12:06:32,090 [INFO ] -1|bpp:0.05017|mask:0.01099|mask_ratio:0.3|psnr:17.72943|lpips:0.34331|dists:0.12283|Test loss:563.63690
|
| 51 |
+
2024-05-16 15:45:18,037 [INFO ] -1|bpp:0.04235|mask:0.01158|mask_ratio:0.45|psnr:16.67805|lpips:0.37921|dists:0.13982|Test loss:177.07100
|
| 52 |
+
2024-05-16 19:29:09,057 [INFO ] -1|bpp:0.03453|mask:0.01215|mask_ratio:0.6|psnr:15.40944|lpips:0.42819|dists:0.16429|Test loss:56.36422
|
| 53 |
+
2024-05-16 22:59:34,921 [INFO ] -1|bpp:0.02675|mask:0.01275|mask_ratio:0.75|psnr:13.72992|lpips:0.49992|dists:0.20360|Test loss:18.22452
|
MIM_vbr/checkpoint-399.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4435ef340ca3965a9c47c956fea2bb48e8613162260e443c4a37d428e6890569
|
| 3 |
+
size 1960373936
|
MIM_vbr/continue_tune_rope.log
ADDED
|
@@ -0,0 +1,514 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2024-05-03 13:24:36,083 [INFO ] Logging file is /home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr//20240503_132436.log
|
| 2 |
+
2024-05-03 13:24:36,083 [INFO ] Logging file is /home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr//20240503_132436.log
|
| 3 |
+
2024-05-03 13:24:36,083 [INFO ] ======================= exp =======================
|
| 4 |
+
2024-05-03 13:24:36,083 [INFO ] ======================= exp =======================
|
| 5 |
+
2024-05-03 13:24:36,083 [INFO ] config:config/vpt_default.yaml
|
| 6 |
+
2024-05-03 13:24:36,083 [INFO ] config:config/vpt_default.yaml
|
| 7 |
+
2024-05-03 13:24:36,083 [INFO ] name:exp
|
| 8 |
+
2024-05-03 13:24:36,083 [INFO ] name:exp
|
| 9 |
+
2024-05-03 13:24:36,083 [INFO ] lr:None
|
| 10 |
+
2024-05-03 13:24:36,083 [INFO ] lr:None
|
| 11 |
+
2024-05-03 13:24:36,083 [INFO ] world_size:4
|
| 12 |
+
2024-05-03 13:24:36,083 [INFO ] world_size:4
|
| 13 |
+
2024-05-03 13:24:36,084 [INFO ] local_rank:-1
|
| 14 |
+
2024-05-03 13:24:36,084 [INFO ] local_rank:-1
|
| 15 |
+
2024-05-03 13:24:36,084 [INFO ] dist_on_itp:False
|
| 16 |
+
2024-05-03 13:24:36,084 [INFO ] dist_on_itp:False
|
| 17 |
+
2024-05-03 13:24:36,084 [INFO ] dist_url:env://
|
| 18 |
+
2024-05-03 13:24:36,084 [INFO ] dist_url:env://
|
| 19 |
+
2024-05-03 13:24:36,084 [INFO ] TEST:False
|
| 20 |
+
2024-05-03 13:24:36,084 [INFO ] TEST:False
|
| 21 |
+
2024-05-03 13:24:36,084 [INFO ] model:mage_vit_base_patch16
|
| 22 |
+
2024-05-03 13:24:36,084 [INFO ] model:mage_vit_base_patch16
|
| 23 |
+
2024-05-03 13:24:36,084 [INFO ] dataset:imagenet
|
| 24 |
+
2024-05-03 13:24:36,084 [INFO ] dataset:imagenet
|
| 25 |
+
2024-05-03 13:24:36,084 [INFO ] Logging file is /home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr//20240503_132436.log
|
| 26 |
+
2024-05-03 13:24:36,084 [INFO ] dataset_path:/home/t2vg-a100-G4-10/project/qyp/datasets/imagenet
|
| 27 |
+
2024-05-03 13:24:36,084 [INFO ] dataset_path:/home/t2vg-a100-G4-10/project/qyp/datasets/imagenet
|
| 28 |
+
2024-05-03 13:24:36,084 [INFO ] eval_path:
|
| 29 |
+
2024-05-03 13:24:36,084 [INFO ] eval_path:
|
| 30 |
+
2024-05-03 13:24:36,084 [INFO ] batch_size:64
|
| 31 |
+
2024-05-03 13:24:36,084 [INFO ] batch_size:64
|
| 32 |
+
2024-05-03 13:24:36,084 [INFO ] Logging file is /home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr//20240503_132436.log
|
| 33 |
+
2024-05-03 13:24:36,084 [INFO ] ======================= exp =======================
|
| 34 |
+
2024-05-03 13:24:36,084 [INFO ] test_batch_size:64
|
| 35 |
+
2024-05-03 13:24:36,084 [INFO ] test_batch_size:64
|
| 36 |
+
2024-05-03 13:24:36,084 [INFO ] config:config/vpt_default.yaml
|
| 37 |
+
2024-05-03 13:24:36,084 [INFO ] num_workers:4
|
| 38 |
+
2024-05-03 13:24:36,084 [INFO ] num_workers:4
|
| 39 |
+
2024-05-03 13:24:36,084 [INFO ] ======================= exp =======================
|
| 40 |
+
2024-05-03 13:24:36,084 [INFO ] name:exp
|
| 41 |
+
2024-05-03 13:24:36,084 [INFO ] epochs:400
|
| 42 |
+
2024-05-03 13:24:36,084 [INFO ] epochs:400
|
| 43 |
+
2024-05-03 13:24:36,084 [INFO ] lr:None
|
| 44 |
+
2024-05-03 13:24:36,084 [INFO ] blr:0.00015
|
| 45 |
+
2024-05-03 13:24:36,084 [INFO ] config:config/vpt_default.yaml
|
| 46 |
+
2024-05-03 13:24:36,084 [INFO ] blr:0.00015
|
| 47 |
+
2024-05-03 13:24:36,084 [INFO ] world_size:4
|
| 48 |
+
2024-05-03 13:24:36,084 [INFO ] min_lr:5e-06
|
| 49 |
+
2024-05-03 13:24:36,084 [INFO ] name:exp
|
| 50 |
+
2024-05-03 13:24:36,084 [INFO ] min_lr:5e-06
|
| 51 |
+
2024-05-03 13:24:36,084 [INFO ] accum_iter:1
|
| 52 |
+
2024-05-03 13:24:36,084 [INFO ] local_rank:-1
|
| 53 |
+
2024-05-03 13:24:36,084 [INFO ] lr:None
|
| 54 |
+
2024-05-03 13:24:36,084 [INFO ] accum_iter:1
|
| 55 |
+
2024-05-03 13:24:36,084 [INFO ] world_size:4
|
| 56 |
+
2024-05-03 13:24:36,084 [INFO ] dist_on_itp:False
|
| 57 |
+
2024-05-03 13:24:36,084 [INFO ] input_size:256
|
| 58 |
+
2024-05-03 13:24:36,084 [INFO ] input_size:256
|
| 59 |
+
2024-05-03 13:24:36,084 [INFO ] weight_decay:0.05
|
| 60 |
+
2024-05-03 13:24:36,084 [INFO ] dist_url:env://
|
| 61 |
+
2024-05-03 13:24:36,084 [INFO ] local_rank:-1
|
| 62 |
+
2024-05-03 13:24:36,084 [INFO ] weight_decay:0.05
|
| 63 |
+
2024-05-03 13:24:36,084 [INFO ] warmup_epochs:20
|
| 64 |
+
2024-05-03 13:24:36,084 [INFO ] warmup_epochs:20
|
| 65 |
+
2024-05-03 13:24:36,084 [INFO ] TEST:False
|
| 66 |
+
2024-05-03 13:24:36,084 [INFO ] dist_on_itp:False
|
| 67 |
+
2024-05-03 13:24:36,084 [INFO ] mask_ratio_min:0.5
|
| 68 |
+
2024-05-03 13:24:36,084 [INFO ] mask_ratio_min:0.5
|
| 69 |
+
2024-05-03 13:24:36,084 [INFO ] model:mage_vit_base_patch16
|
| 70 |
+
2024-05-03 13:24:36,084 [INFO ] dist_url:env://
|
| 71 |
+
2024-05-03 13:24:36,084 [INFO ] mask_ratio_max:0.9
|
| 72 |
+
2024-05-03 13:24:36,084 [INFO ] mask_ratio_max:0.9
|
| 73 |
+
2024-05-03 13:24:36,084 [INFO ] TEST:False
|
| 74 |
+
2024-05-03 13:24:36,084 [INFO ] dataset:imagenet
|
| 75 |
+
2024-05-03 13:24:36,084 [INFO ] mask_ratio_mu:0.55
|
| 76 |
+
2024-05-03 13:24:36,084 [INFO ] mask_ratio_mu:0.55
|
| 77 |
+
2024-05-03 13:24:36,084 [INFO ] model:mage_vit_base_patch16
|
| 78 |
+
2024-05-03 13:24:36,084 [INFO ] dataset_path:/home/t2vg-a100-G4-10/project/qyp/datasets/imagenet
|
| 79 |
+
2024-05-03 13:24:36,084 [INFO ] mask_ratio_std:0.25
|
| 80 |
+
2024-05-03 13:24:36,084 [INFO ] mask_ratio_std:0.25
|
| 81 |
+
2024-05-03 13:24:36,084 [INFO ] dataset:imagenet
|
| 82 |
+
2024-05-03 13:24:36,084 [INFO ] eval_path:
|
| 83 |
+
2024-05-03 13:24:36,084 [INFO ] grad_clip:3.0
|
| 84 |
+
2024-05-03 13:24:36,084 [INFO ] grad_clip:3.0
|
| 85 |
+
2024-05-03 13:24:36,085 [INFO ] dataset_path:/home/t2vg-a100-G4-10/project/qyp/datasets/imagenet
|
| 86 |
+
2024-05-03 13:24:36,085 [INFO ] batch_size:64
|
| 87 |
+
2024-05-03 13:24:36,085 [INFO ] device:cuda
|
| 88 |
+
2024-05-03 13:24:36,085 [INFO ] device:cuda
|
| 89 |
+
2024-05-03 13:24:36,085 [INFO ] test_batch_size:64
|
| 90 |
+
2024-05-03 13:24:36,085 [INFO ] eval_path:
|
| 91 |
+
2024-05-03 13:24:36,085 [INFO ] distributed:True
|
| 92 |
+
2024-05-03 13:24:36,085 [INFO ] distributed:True
|
| 93 |
+
2024-05-03 13:24:36,085 [INFO ] num_workers:4
|
| 94 |
+
2024-05-03 13:24:36,085 [INFO ] batch_size:64
|
| 95 |
+
2024-05-03 13:24:36,085 [INFO ] seed:0
|
| 96 |
+
2024-05-03 13:24:36,085 [INFO ] seed:0
|
| 97 |
+
2024-05-03 13:24:36,085 [INFO ] test_batch_size:64
|
| 98 |
+
2024-05-03 13:24:36,085 [INFO ] epochs:400
|
| 99 |
+
2024-05-03 13:24:36,085 [INFO ] resume:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/random_with_rope.pth
|
| 100 |
+
2024-05-03 13:24:36,085 [INFO ] resume:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/random_with_rope.pth
|
| 101 |
+
2024-05-03 13:24:36,085 [INFO ] blr:0.00015
|
| 102 |
+
2024-05-03 13:24:36,085 [INFO ] num_workers:4
|
| 103 |
+
2024-05-03 13:24:36,085 [INFO ] start_epoch:0
|
| 104 |
+
2024-05-03 13:24:36,085 [INFO ] start_epoch:0
|
| 105 |
+
2024-05-03 13:24:36,085 [INFO ] epochs:400
|
| 106 |
+
2024-05-03 13:24:36,085 [INFO ] min_lr:5e-06
|
| 107 |
+
2024-05-03 13:24:36,085 [INFO ] pin_mem:True
|
| 108 |
+
2024-05-03 13:24:36,085 [INFO ] pin_mem:True
|
| 109 |
+
2024-05-03 13:24:36,085 [INFO ] blr:0.00015
|
| 110 |
+
2024-05-03 13:24:36,085 [INFO ] accum_iter:1
|
| 111 |
+
2024-05-03 13:24:36,085 [INFO ] save:True
|
| 112 |
+
2024-05-03 13:24:36,085 [INFO ] save:True
|
| 113 |
+
2024-05-03 13:24:36,085 [INFO ] input_size:256
|
| 114 |
+
2024-05-03 13:24:36,085 [INFO ] min_lr:5e-06
|
| 115 |
+
2024-05-03 13:24:36,085 [INFO ] exp_name:MIM_vbr
|
| 116 |
+
2024-05-03 13:24:36,085 [INFO ] exp_name:MIM_vbr
|
| 117 |
+
2024-05-03 13:24:36,085 [INFO ] weight_decay:0.05
|
| 118 |
+
2024-05-03 13:24:36,085 [INFO ] accum_iter:1
|
| 119 |
+
2024-05-03 13:24:36,085 [INFO ] warmup_epochs:20
|
| 120 |
+
2024-05-03 13:24:36,085 [INFO ] root:/home/t2vg-a100-G4-10/project/qyp/mimc_rope
|
| 121 |
+
2024-05-03 13:24:36,085 [INFO ] root:/home/t2vg-a100-G4-10/project/qyp/mimc_rope
|
| 122 |
+
2024-05-03 13:24:36,085 [INFO ] input_size:256
|
| 123 |
+
2024-05-03 13:24:36,085 [INFO ] output_dir:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr
|
| 124 |
+
2024-05-03 13:24:36,085 [INFO ] mask_ratio_min:0.5
|
| 125 |
+
2024-05-03 13:24:36,085 [INFO ] output_dir:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr
|
| 126 |
+
2024-05-03 13:24:36,085 [INFO ] weight_decay:0.05
|
| 127 |
+
2024-05-03 13:24:36,085 [INFO ] log_dir:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr
|
| 128 |
+
2024-05-03 13:24:36,085 [INFO ] mask_ratio_max:0.9
|
| 129 |
+
2024-05-03 13:24:36,085 [INFO ] log_dir:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr
|
| 130 |
+
2024-05-03 13:24:36,085 [INFO ] warmup_epochs:20
|
| 131 |
+
2024-05-03 13:24:36,085 [INFO ] strart_epoch:0
|
| 132 |
+
2024-05-03 13:24:36,085 [INFO ] mask_ratio_mu:0.55
|
| 133 |
+
2024-05-03 13:24:36,085 [INFO ] strart_epoch:0
|
| 134 |
+
2024-05-03 13:24:36,085 [INFO ] mask_ratio_min:0.5
|
| 135 |
+
2024-05-03 13:24:36,085 [INFO ] cuda:True
|
| 136 |
+
2024-05-03 13:24:36,085 [INFO ] mask_ratio_std:0.25
|
| 137 |
+
2024-05-03 13:24:36,085 [INFO ] cuda:True
|
| 138 |
+
2024-05-03 13:24:36,085 [INFO ] mask_ratio_max:0.9
|
| 139 |
+
2024-05-03 13:24:36,085 [INFO ] rank:0
|
| 140 |
+
2024-05-03 13:24:36,085 [INFO ] grad_clip:3.0
|
| 141 |
+
2024-05-03 13:24:36,085 [INFO ] rank:1
|
| 142 |
+
2024-05-03 13:24:36,085 [INFO ] gpu:0
|
| 143 |
+
2024-05-03 13:24:36,085 [INFO ] mask_ratio_mu:0.55
|
| 144 |
+
2024-05-03 13:24:36,085 [INFO ] device:cuda
|
| 145 |
+
2024-05-03 13:24:36,085 [INFO ] gpu:1
|
| 146 |
+
2024-05-03 13:24:36,085 [INFO ] dist_backend:nccl
|
| 147 |
+
2024-05-03 13:24:36,085 [INFO ] distributed:True
|
| 148 |
+
2024-05-03 13:24:36,085 [INFO ] mask_ratio_std:0.25
|
| 149 |
+
2024-05-03 13:24:36,085 [INFO ] dist_backend:nccl
|
| 150 |
+
2024-05-03 13:24:36,085 [INFO ] ===================================================
|
| 151 |
+
2024-05-03 13:24:36,085 [INFO ] grad_clip:3.0
|
| 152 |
+
2024-05-03 13:24:36,085 [INFO ] seed:0
|
| 153 |
+
2024-05-03 13:24:36,085 [INFO ] ===================================================
|
| 154 |
+
2024-05-03 13:24:36,085 [INFO ] device:cuda
|
| 155 |
+
2024-05-03 13:24:36,085 [INFO ] resume:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/random_with_rope.pth
|
| 156 |
+
2024-05-03 13:24:36,085 [INFO ] start_epoch:0
|
| 157 |
+
2024-05-03 13:24:36,085 [INFO ] distributed:True
|
| 158 |
+
2024-05-03 13:24:36,085 [INFO ] seed:0
|
| 159 |
+
2024-05-03 13:24:36,085 [INFO ] pin_mem:True
|
| 160 |
+
2024-05-03 13:24:36,085 [INFO ] save:True
|
| 161 |
+
2024-05-03 13:24:36,085 [INFO ] resume:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/random_with_rope.pth
|
| 162 |
+
2024-05-03 13:24:36,085 [INFO ] exp_name:MIM_vbr
|
| 163 |
+
2024-05-03 13:24:36,085 [INFO ] start_epoch:0
|
| 164 |
+
2024-05-03 13:24:36,086 [INFO ] root:/home/t2vg-a100-G4-10/project/qyp/mimc_rope
|
| 165 |
+
2024-05-03 13:24:36,086 [INFO ] pin_mem:True
|
| 166 |
+
2024-05-03 13:24:36,086 [INFO ] save:True
|
| 167 |
+
2024-05-03 13:24:36,086 [INFO ] output_dir:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr
|
| 168 |
+
2024-05-03 13:24:36,086 [INFO ] log_dir:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr
|
| 169 |
+
2024-05-03 13:24:36,086 [INFO ] exp_name:MIM_vbr
|
| 170 |
+
2024-05-03 13:24:36,086 [INFO ] strart_epoch:0
|
| 171 |
+
2024-05-03 13:24:36,086 [INFO ] root:/home/t2vg-a100-G4-10/project/qyp/mimc_rope
|
| 172 |
+
2024-05-03 13:24:36,086 [INFO ] cuda:True
|
| 173 |
+
2024-05-03 13:24:36,086 [INFO ] output_dir:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr
|
| 174 |
+
2024-05-03 13:24:36,086 [INFO ] log_dir:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr
|
| 175 |
+
2024-05-03 13:24:36,086 [INFO ] rank:3
|
| 176 |
+
2024-05-03 13:24:36,086 [INFO ] gpu:3
|
| 177 |
+
2024-05-03 13:24:36,086 [INFO ] strart_epoch:0
|
| 178 |
+
2024-05-03 13:24:36,086 [INFO ] cuda:True
|
| 179 |
+
2024-05-03 13:24:36,086 [INFO ] dist_backend:nccl
|
| 180 |
+
2024-05-03 13:24:36,086 [INFO ] ===================================================
|
| 181 |
+
2024-05-03 13:24:36,086 [INFO ] rank:2
|
| 182 |
+
2024-05-03 13:24:36,086 [INFO ] gpu:2
|
| 183 |
+
2024-05-03 13:24:36,086 [INFO ] dist_backend:nccl
|
| 184 |
+
2024-05-03 13:24:36,086 [INFO ] ===================================================
|
| 185 |
+
2024-05-03 13:25:40,076 [INFO ] -1|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.28939|msssim:0.33530|lpips:0.50428|dists:0.27922|Test loss:22.31097
|
| 186 |
+
2024-05-03 13:45:27,931 [INFO ] 71|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.27695|msssim:0.33996|lpips:0.50599|dists:0.28065|Test loss:22.37941
|
| 187 |
+
2024-05-03 14:05:20,381 [INFO ] 72|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.18272|msssim:0.33581|lpips:0.50830|dists:0.28347|Test loss:22.39342
|
| 188 |
+
2024-05-03 14:25:08,798 [INFO ] 73|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.07158|msssim:0.33378|lpips:0.50512|dists:0.28206|Test loss:22.34897
|
| 189 |
+
2024-05-03 14:45:13,282 [INFO ] 74|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.25096|msssim:0.33580|lpips:0.50984|dists:0.28295|Test loss:22.41223
|
| 190 |
+
2024-05-03 15:05:00,304 [INFO ] 75|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.34433|msssim:0.34183|lpips:0.50830|dists:0.28667|Test loss:22.41598
|
| 191 |
+
2024-05-03 15:24:49,830 [INFO ] 76|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.18746|msssim:0.33532|lpips:0.50817|dists:0.28051|Test loss:22.38807
|
| 192 |
+
2024-05-03 15:44:36,569 [INFO ] 77|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.24094|msssim:0.33337|lpips:0.50695|dists:0.28093|Test loss:22.35640
|
| 193 |
+
2024-05-03 16:04:23,158 [INFO ] 78|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.25324|msssim:0.33446|lpips:0.51007|dists:0.28490|Test loss:22.40084
|
| 194 |
+
2024-05-03 16:24:10,702 [INFO ] 79|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.22266|msssim:0.33541|lpips:0.50979|dists:0.28361|Test loss:22.35502
|
| 195 |
+
2024-05-03 16:43:58,827 [INFO ] 80|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.10645|msssim:0.32806|lpips:0.51042|dists:0.28348|Test loss:22.37403
|
| 196 |
+
2024-05-03 17:03:49,009 [INFO ] 81|bpp:0.01873|mask:0.00888|mask_ratio:0.75|psnr:12.16949|msssim:0.33372|lpips:0.50971|dists:0.28431|Test loss:22.38677
|
| 197 |
+
2024-05-03 17:23:36,872 [INFO ] 82|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.19599|msssim:0.33479|lpips:0.50738|dists:0.27982|Test loss:22.39563
|
| 198 |
+
2024-05-03 17:43:24,129 [INFO ] 83|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.13946|msssim:0.33657|lpips:0.50608|dists:0.28026|Test loss:22.37709
|
| 199 |
+
2024-05-03 18:03:09,632 [INFO ] 84|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.19059|msssim:0.32992|lpips:0.51449|dists:0.28200|Test loss:22.39230
|
| 200 |
+
2024-05-03 18:22:55,813 [INFO ] 85|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.18655|msssim:0.33290|lpips:0.50831|dists:0.28350|Test loss:22.38299
|
| 201 |
+
2024-05-03 18:42:41,651 [INFO ] 86|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.21415|msssim:0.33196|lpips:0.50763|dists:0.28348|Test loss:22.34028
|
| 202 |
+
2024-05-03 19:02:43,711 [INFO ] 87|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.12306|msssim:0.33658|lpips:0.50986|dists:0.28280|Test loss:22.35450
|
| 203 |
+
2024-05-03 19:22:30,211 [INFO ] 88|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.27434|msssim:0.33857|lpips:0.50953|dists:0.28144|Test loss:22.32993
|
| 204 |
+
2024-05-03 19:42:31,870 [INFO ] 89|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.29707|msssim:0.33239|lpips:0.51159|dists:0.28729|Test loss:22.39535
|
| 205 |
+
2024-05-03 20:02:18,523 [INFO ] 90|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:11.93105|msssim:0.32874|lpips:0.51374|dists:0.28500|Test loss:22.37451
|
| 206 |
+
2024-05-03 20:22:40,704 [INFO ] 91|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.15259|msssim:0.33332|lpips:0.51196|dists:0.28565|Test loss:22.39588
|
| 207 |
+
2024-05-03 20:42:50,402 [INFO ] 92|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.22207|msssim:0.33683|lpips:0.50868|dists:0.28256|Test loss:22.34765
|
| 208 |
+
2024-05-03 21:03:00,830 [INFO ] 93|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.11886|msssim:0.33659|lpips:0.50814|dists:0.28405|Test loss:22.34739
|
| 209 |
+
2024-05-03 21:23:11,053 [INFO ] 94|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.13203|msssim:0.33091|lpips:0.50892|dists:0.28377|Test loss:22.36003
|
| 210 |
+
2024-05-03 21:43:19,380 [INFO ] 95|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.13799|msssim:0.33715|lpips:0.51032|dists:0.28500|Test loss:22.35577
|
| 211 |
+
2024-05-03 22:03:29,276 [INFO ] 96|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.29663|msssim:0.33942|lpips:0.50849|dists:0.28102|Test loss:22.39275
|
| 212 |
+
2024-05-03 22:23:40,310 [INFO ] 97|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.31019|msssim:0.34155|lpips:0.50632|dists:0.28538|Test loss:22.33029
|
| 213 |
+
2024-05-03 22:25:02,985 [INFO ] 98|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.28031|msssim:0.33755|lpips:0.50549|dists:0.28245|Test loss:22.33734
|
| 214 |
+
2024-05-03 22:45:13,341 [INFO ] 99|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.17199|msssim:0.33258|lpips:0.50680|dists:0.28107|Test loss:22.37611
|
| 215 |
+
2024-05-03 23:05:22,081 [INFO ] 100|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.12885|msssim:0.33378|lpips:0.50697|dists:0.27960|Test loss:22.35676
|
| 216 |
+
2024-05-03 23:25:59,251 [INFO ] 101|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.23955|msssim:0.33759|lpips:0.50643|dists:0.28164|Test loss:22.33918
|
| 217 |
+
2024-05-03 23:46:28,120 [INFO ] 102|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.28548|msssim:0.33363|lpips:0.50599|dists:0.28383|Test loss:22.36715
|
| 218 |
+
2024-05-04 00:07:06,785 [INFO ] 103|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.38201|msssim:0.33652|lpips:0.50561|dists:0.28420|Test loss:22.34940
|
| 219 |
+
2024-05-04 00:27:18,585 [INFO ] 104|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.23971|msssim:0.33526|lpips:0.50535|dists:0.28196|Test loss:22.32486
|
| 220 |
+
2024-05-04 00:47:46,770 [INFO ] 105|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.16863|msssim:0.33425|lpips:0.50665|dists:0.28272|Test loss:22.35667
|
| 221 |
+
2024-05-04 01:07:48,995 [INFO ] 106|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.33955|msssim:0.34124|lpips:0.50687|dists:0.28063|Test loss:22.33323
|
| 222 |
+
2024-05-04 01:27:57,424 [INFO ] 107|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.32910|msssim:0.33785|lpips:0.50615|dists:0.28166|Test loss:22.33379
|
| 223 |
+
2024-05-04 01:48:04,206 [INFO ] 108|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.17313|msssim:0.33479|lpips:0.50718|dists:0.28483|Test loss:22.34892
|
| 224 |
+
2024-05-04 02:08:13,034 [INFO ] 109|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.25400|msssim:0.33878|lpips:0.50396|dists:0.27973|Test loss:22.27535
|
| 225 |
+
2024-05-04 02:28:39,568 [INFO ] 110|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.18919|msssim:0.33519|lpips:0.50766|dists:0.28370|Test loss:22.34272
|
| 226 |
+
2024-05-04 02:48:31,198 [INFO ] 111|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.25287|msssim:0.34151|lpips:0.50405|dists:0.28154|Test loss:22.29224
|
| 227 |
+
2024-05-04 03:08:18,193 [INFO ] 112|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.25592|msssim:0.33658|lpips:0.50613|dists:0.27908|Test loss:22.31837
|
| 228 |
+
2024-05-04 03:28:06,312 [INFO ] 113|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.19013|msssim:0.33855|lpips:0.50963|dists:0.28313|Test loss:22.35398
|
| 229 |
+
2024-05-04 03:47:53,623 [INFO ] 114|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.22372|msssim:0.33530|lpips:0.50483|dists:0.27558|Test loss:22.27828
|
| 230 |
+
2024-05-04 04:07:39,992 [INFO ] 115|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.35171|msssim:0.34350|lpips:0.50526|dists:0.28044|Test loss:22.30720
|
| 231 |
+
2024-05-04 04:27:27,792 [INFO ] 116|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.14400|msssim:0.33273|lpips:0.51011|dists:0.28111|Test loss:22.33026
|
| 232 |
+
2024-05-04 04:47:15,548 [INFO ] 117|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.22451|msssim:0.33556|lpips:0.50722|dists:0.28211|Test loss:22.30748
|
| 233 |
+
2024-05-04 05:07:04,750 [INFO ] 118|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.39794|msssim:0.33993|lpips:0.50137|dists:0.28058|Test loss:22.32248
|
| 234 |
+
2024-05-04 05:26:52,003 [INFO ] 119|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.35083|msssim:0.34228|lpips:0.50405|dists:0.28113|Test loss:22.29005
|
| 235 |
+
2024-05-04 05:46:54,175 [INFO ] 120|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.19306|msssim:0.33767|lpips:0.50508|dists:0.28069|Test loss:22.30016
|
| 236 |
+
2024-05-04 06:07:56,985 [INFO ] 121|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.30696|msssim:0.34229|lpips:0.50263|dists:0.27623|Test loss:22.32207
|
| 237 |
+
2024-05-04 06:28:36,777 [INFO ] 122|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.32276|msssim:0.34709|lpips:0.50338|dists:0.27890|Test loss:22.27271
|
| 238 |
+
2024-05-04 06:49:34,456 [INFO ] 123|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.15545|msssim:0.33990|lpips:0.50738|dists:0.28200|Test loss:22.31753
|
| 239 |
+
2024-05-04 07:10:10,082 [INFO ] 124|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.26117|msssim:0.33900|lpips:0.50466|dists:0.28039|Test loss:22.31486
|
| 240 |
+
2024-05-04 07:30:52,259 [INFO ] 125|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.22142|msssim:0.33723|lpips:0.50712|dists:0.28167|Test loss:22.32770
|
| 241 |
+
2024-05-04 07:51:29,029 [INFO ] 126|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.29230|msssim:0.33539|lpips:0.50329|dists:0.27957|Test loss:22.30125
|
| 242 |
+
2024-05-04 08:11:57,295 [INFO ] 127|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.10057|msssim:0.33604|lpips:0.50390|dists:0.27900|Test loss:22.32180
|
| 243 |
+
2024-05-04 08:32:29,651 [INFO ] 128|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.35982|msssim:0.34019|lpips:0.50363|dists:0.27863|Test loss:22.30414
|
| 244 |
+
2024-05-04 08:53:13,202 [INFO ] 129|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.15058|msssim:0.33351|lpips:0.50549|dists:0.27943|Test loss:22.26283
|
| 245 |
+
2024-05-04 09:13:57,329 [INFO ] 130|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.27512|msssim:0.33742|lpips:0.50353|dists:0.27905|Test loss:22.30725
|
| 246 |
+
2024-05-04 09:15:24,858 [INFO ] 131|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.20872|msssim:0.34373|lpips:0.50210|dists:0.28035|Test loss:22.27275
|
| 247 |
+
2024-05-04 09:36:10,653 [INFO ] 132|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.20383|msssim:0.33974|lpips:0.50339|dists:0.27824|Test loss:22.29464
|
| 248 |
+
2024-05-04 09:56:51,249 [INFO ] 133|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.12968|msssim:0.33882|lpips:0.50371|dists:0.27564|Test loss:22.30516
|
| 249 |
+
2024-05-04 10:17:09,926 [INFO ] 134|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.15733|msssim:0.33909|lpips:0.50420|dists:0.27941|Test loss:22.28371
|
| 250 |
+
2024-05-04 10:37:22,751 [INFO ] 135|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.14190|msssim:0.33220|lpips:0.50398|dists:0.27657|Test loss:22.25691
|
| 251 |
+
2024-05-04 10:58:32,045 [INFO ] 136|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.21952|msssim:0.34333|lpips:0.50654|dists:0.28236|Test loss:22.29695
|
| 252 |
+
2024-05-04 11:19:04,383 [INFO ] 137|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.12107|msssim:0.33130|lpips:0.50644|dists:0.27993|Test loss:22.30494
|
| 253 |
+
2024-05-04 11:39:23,509 [INFO ] 138|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.28914|msssim:0.34730|lpips:0.50144|dists:0.27736|Test loss:22.26708
|
| 254 |
+
2024-05-04 11:59:44,396 [INFO ] 139|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.12318|msssim:0.33862|lpips:0.50449|dists:0.27968|Test loss:22.28541
|
| 255 |
+
2024-05-04 12:20:03,670 [INFO ] 140|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.36024|msssim:0.33926|lpips:0.50315|dists:0.27705|Test loss:22.24402
|
| 256 |
+
2024-05-04 12:40:47,447 [INFO ] 141|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.37666|msssim:0.33869|lpips:0.50093|dists:0.27720|Test loss:22.27834
|
| 257 |
+
2024-05-04 13:01:04,137 [INFO ] 142|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.21903|msssim:0.34257|lpips:0.50445|dists:0.27656|Test loss:22.27222
|
| 258 |
+
2024-05-04 13:21:33,832 [INFO ] 143|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.23086|msssim:0.33986|lpips:0.50459|dists:0.28040|Test loss:22.30935
|
| 259 |
+
2024-05-04 13:42:00,911 [INFO ] 144|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.18145|msssim:0.34159|lpips:0.50526|dists:0.27986|Test loss:22.29733
|
| 260 |
+
2024-05-04 14:02:21,466 [INFO ] 145|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.41649|msssim:0.34486|lpips:0.50061|dists:0.27830|Test loss:22.27924
|
| 261 |
+
2024-05-04 14:23:01,203 [INFO ] 146|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.40444|msssim:0.33984|lpips:0.50394|dists:0.27939|Test loss:22.27717
|
| 262 |
+
2024-05-04 14:44:11,569 [INFO ] 147|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.10670|msssim:0.33935|lpips:0.50644|dists:0.28129|Test loss:22.28859
|
| 263 |
+
2024-05-04 15:05:12,252 [INFO ] 148|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.31263|msssim:0.34371|lpips:0.50163|dists:0.27719|Test loss:22.28988
|
| 264 |
+
2024-05-04 15:26:22,270 [INFO ] 149|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.11357|msssim:0.34112|lpips:0.50444|dists:0.28143|Test loss:22.27471
|
| 265 |
+
2024-05-04 15:47:18,685 [INFO ] 150|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.37583|msssim:0.34108|lpips:0.50440|dists:0.27920|Test loss:22.28879
|
| 266 |
+
2024-05-04 16:08:26,498 [INFO ] 151|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.23233|msssim:0.33481|lpips:0.50793|dists:0.28015|Test loss:22.30692
|
| 267 |
+
2024-05-04 16:29:29,985 [INFO ] 152|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.38831|msssim:0.34042|lpips:0.50221|dists:0.27659|Test loss:22.24423
|
| 268 |
+
2024-05-04 16:50:31,249 [INFO ] 153|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.19962|msssim:0.34101|lpips:0.50407|dists:0.28065|Test loss:22.29826
|
| 269 |
+
2024-05-04 17:11:31,226 [INFO ] 154|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.11960|msssim:0.34173|lpips:0.50215|dists:0.27693|Test loss:22.27951
|
| 270 |
+
2024-05-04 17:32:28,025 [INFO ] 155|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.17072|msssim:0.33614|lpips:0.50420|dists:0.27642|Test loss:22.25476
|
| 271 |
+
2024-05-04 17:53:03,788 [INFO ] 156|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.18252|msssim:0.33534|lpips:0.50560|dists:0.27726|Test loss:22.23966
|
| 272 |
+
2024-05-04 18:13:53,631 [INFO ] 157|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.30066|msssim:0.34024|lpips:0.50406|dists:0.27919|Test loss:22.27285
|
| 273 |
+
2024-05-04 18:34:02,548 [INFO ] 158|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.16105|msssim:0.33724|lpips:0.50539|dists:0.27802|Test loss:22.28563
|
| 274 |
+
2024-05-04 18:54:17,391 [INFO ] 159|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.26286|msssim:0.33958|lpips:0.50246|dists:0.27786|Test loss:22.25693
|
| 275 |
+
2024-05-04 19:14:45,980 [INFO ] 160|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.10013|msssim:0.33552|lpips:0.50203|dists:0.27745|Test loss:22.27180
|
| 276 |
+
2024-05-04 19:35:55,093 [INFO ] 161|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.19999|msssim:0.34075|lpips:0.50399|dists:0.28136|Test loss:22.28090
|
| 277 |
+
2024-05-04 19:57:04,549 [INFO ] 162|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.35397|msssim:0.34255|lpips:0.50396|dists:0.28067|Test loss:22.30754
|
| 278 |
+
2024-05-04 20:18:09,435 [INFO ] 163|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.07471|msssim:0.33934|lpips:0.50488|dists:0.27944|Test loss:22.29782
|
| 279 |
+
2024-05-04 20:19:36,675 [INFO ] 164|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.15191|msssim:0.33803|lpips:0.50363|dists:0.28080|Test loss:22.30785
|
| 280 |
+
2024-05-04 20:40:45,907 [INFO ] 165|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.35093|msssim:0.34222|lpips:0.50133|dists:0.27822|Test loss:22.23812
|
| 281 |
+
2024-05-04 21:02:06,997 [INFO ] 166|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.14147|msssim:0.34612|lpips:0.50051|dists:0.27732|Test loss:22.26769
|
| 282 |
+
2024-05-04 21:23:13,250 [INFO ] 167|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.28212|msssim:0.34412|lpips:0.50155|dists:0.27944|Test loss:22.26505
|
| 283 |
+
2024-05-04 21:44:09,446 [INFO ] 168|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.23860|msssim:0.34012|lpips:0.50238|dists:0.27605|Test loss:22.24193
|
| 284 |
+
2024-05-04 22:04:20,528 [INFO ] 169|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.18714|msssim:0.34326|lpips:0.50174|dists:0.27552|Test loss:22.19141
|
| 285 |
+
2024-05-04 22:24:52,856 [INFO ] 170|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.23269|msssim:0.34707|lpips:0.50277|dists:0.28087|Test loss:22.25782
|
| 286 |
+
2024-05-04 22:46:02,786 [INFO ] 171|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.34741|msssim:0.34084|lpips:0.50177|dists:0.27559|Test loss:22.23672
|
| 287 |
+
2024-05-04 23:06:52,141 [INFO ] 172|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.28709|msssim:0.34182|lpips:0.50183|dists:0.27629|Test loss:22.25575
|
| 288 |
+
2024-05-04 23:27:45,873 [INFO ] 173|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.32183|msssim:0.34279|lpips:0.50424|dists:0.27647|Test loss:22.26211
|
| 289 |
+
2024-05-04 23:49:09,385 [INFO ] 174|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.27799|msssim:0.34067|lpips:0.50149|dists:0.27606|Test loss:22.25085
|
| 290 |
+
2024-05-05 00:09:54,865 [INFO ] 175|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.23614|msssim:0.34026|lpips:0.50595|dists:0.27867|Test loss:22.26061
|
| 291 |
+
2024-05-05 00:30:37,643 [INFO ] 176|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.26992|msssim:0.34440|lpips:0.50466|dists:0.27647|Test loss:22.26275
|
| 292 |
+
2024-05-05 00:51:58,336 [INFO ] 177|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.27768|msssim:0.34192|lpips:0.50217|dists:0.27718|Test loss:22.24697
|
| 293 |
+
2024-05-05 01:12:41,638 [INFO ] 178|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.22162|msssim:0.34181|lpips:0.50407|dists:0.28091|Test loss:22.28326
|
| 294 |
+
2024-05-05 01:33:33,913 [INFO ] 179|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.33287|msssim:0.34137|lpips:0.50010|dists:0.27889|Test loss:22.25626
|
| 295 |
+
2024-05-05 01:54:24,919 [INFO ] 180|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.26020|msssim:0.33738|lpips:0.50198|dists:0.27898|Test loss:22.27012
|
| 296 |
+
2024-05-05 02:15:11,781 [INFO ] 181|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.24128|msssim:0.34296|lpips:0.50206|dists:0.27911|Test loss:22.27168
|
| 297 |
+
2024-05-05 02:36:00,258 [INFO ] 182|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.29608|msssim:0.34143|lpips:0.50258|dists:0.27963|Test loss:22.23007
|
| 298 |
+
2024-05-05 02:56:51,997 [INFO ] 183|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.14427|msssim:0.33923|lpips:0.50353|dists:0.27847|Test loss:22.26153
|
| 299 |
+
2024-05-05 03:17:42,729 [INFO ] 184|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.24628|msssim:0.33477|lpips:0.50120|dists:0.27460|Test loss:22.22222
|
| 300 |
+
2024-05-05 03:38:55,556 [INFO ] 185|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.22856|msssim:0.33995|lpips:0.49928|dists:0.27506|Test loss:22.21741
|
| 301 |
+
2024-05-05 03:59:54,038 [INFO ] 186|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.25829|msssim:0.34396|lpips:0.50041|dists:0.27953|Test loss:22.21000
|
| 302 |
+
2024-05-05 04:20:57,361 [INFO ] 187|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.33900|msssim:0.34563|lpips:0.50119|dists:0.27801|Test loss:22.24254
|
| 303 |
+
2024-05-05 04:41:56,894 [INFO ] 188|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.25541|msssim:0.34364|lpips:0.50117|dists:0.27644|Test loss:22.22793
|
| 304 |
+
2024-05-05 05:02:57,947 [INFO ] 189|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.29414|msssim:0.34397|lpips:0.50046|dists:0.27675|Test loss:22.22504
|
| 305 |
+
2024-05-05 05:24:00,200 [INFO ] 190|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.42643|msssim:0.34251|lpips:0.50034|dists:0.27756|Test loss:22.20704
|
| 306 |
+
2024-05-05 05:45:06,616 [INFO ] 191|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.24620|msssim:0.33765|lpips:0.50179|dists:0.27763|Test loss:22.24989
|
| 307 |
+
2024-05-05 06:05:48,232 [INFO ] 192|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.20766|msssim:0.34222|lpips:0.50102|dists:0.27429|Test loss:22.20998
|
| 308 |
+
2024-05-05 06:26:41,032 [INFO ] 193|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.28995|msssim:0.34232|lpips:0.50011|dists:0.27672|Test loss:22.19870
|
| 309 |
+
2024-05-05 06:47:28,088 [INFO ] 194|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.31111|msssim:0.34537|lpips:0.50108|dists:0.27792|Test loss:22.22558
|
| 310 |
+
2024-05-05 07:08:15,894 [INFO ] 195|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.19613|msssim:0.34356|lpips:0.50329|dists:0.27837|Test loss:22.27292
|
| 311 |
+
2024-05-05 07:29:07,588 [INFO ] 196|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.34946|msssim:0.34073|lpips:0.50324|dists:0.27687|Test loss:22.25968
|
| 312 |
+
2024-05-05 07:30:34,517 [INFO ] 197|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.28973|msssim:0.34913|lpips:0.49958|dists:0.27724|Test loss:22.21089
|
| 313 |
+
2024-05-05 07:51:24,205 [INFO ] 198|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.22106|msssim:0.33796|lpips:0.50167|dists:0.27715|Test loss:22.27205
|
| 314 |
+
2024-05-05 08:12:37,156 [INFO ] 199|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.26782|msssim:0.34472|lpips:0.50100|dists:0.27782|Test loss:22.21741
|
| 315 |
+
2024-05-05 08:33:49,412 [INFO ] 200|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.28700|msssim:0.34200|lpips:0.50071|dists:0.27705|Test loss:22.24541
|
| 316 |
+
2024-05-05 08:55:00,482 [INFO ] 201|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.26943|msssim:0.34567|lpips:0.50009|dists:0.27623|Test loss:22.23442
|
| 317 |
+
2024-05-05 09:15:55,407 [INFO ] 202|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.28424|msssim:0.33845|lpips:0.50023|dists:0.27742|Test loss:22.21661
|
| 318 |
+
2024-05-05 09:37:07,037 [INFO ] 203|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.36869|msssim:0.34921|lpips:0.49663|dists:0.27317|Test loss:22.18738
|
| 319 |
+
2024-05-05 09:58:22,992 [INFO ] 204|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.29465|msssim:0.34732|lpips:0.49851|dists:0.27560|Test loss:22.21542
|
| 320 |
+
2024-05-05 10:19:30,345 [INFO ] 205|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.34941|msssim:0.34522|lpips:0.49991|dists:0.27673|Test loss:22.23310
|
| 321 |
+
2024-05-05 10:40:34,778 [INFO ] 206|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.29729|msssim:0.34074|lpips:0.49963|dists:0.27439|Test loss:22.18471
|
| 322 |
+
2024-05-05 11:01:52,830 [INFO ] 207|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.20647|msssim:0.34440|lpips:0.50193|dists:0.27544|Test loss:22.16906
|
| 323 |
+
2024-05-05 11:23:10,869 [INFO ] 208|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.33049|msssim:0.34167|lpips:0.49881|dists:0.27573|Test loss:22.22115
|
| 324 |
+
2024-05-05 11:44:29,730 [INFO ] 209|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.36687|msssim:0.34201|lpips:0.50120|dists:0.27688|Test loss:22.21192
|
| 325 |
+
2024-05-05 12:05:44,940 [INFO ] 210|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.19353|msssim:0.33601|lpips:0.50266|dists:0.27510|Test loss:22.24716
|
| 326 |
+
2024-05-05 12:27:06,882 [INFO ] 211|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.31839|msssim:0.34303|lpips:0.49930|dists:0.27422|Test loss:22.19657
|
| 327 |
+
2024-05-05 12:48:21,046 [INFO ] 212|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.42858|msssim:0.34387|lpips:0.49946|dists:0.27734|Test loss:22.23478
|
| 328 |
+
2024-05-05 13:09:31,493 [INFO ] 213|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.31374|msssim:0.34684|lpips:0.49643|dists:0.27400|Test loss:22.22582
|
| 329 |
+
2024-05-05 13:30:34,522 [INFO ] 214|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.21293|msssim:0.34112|lpips:0.49883|dists:0.27535|Test loss:22.21174
|
| 330 |
+
2024-05-05 13:51:47,865 [INFO ] 215|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.41841|msssim:0.34328|lpips:0.49828|dists:0.27544|Test loss:22.16456
|
| 331 |
+
2024-05-05 14:13:13,905 [INFO ] 216|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.25790|msssim:0.34093|lpips:0.49955|dists:0.27657|Test loss:22.21031
|
| 332 |
+
2024-05-05 14:34:27,772 [INFO ] 217|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.25331|msssim:0.34462|lpips:0.50060|dists:0.27798|Test loss:22.19541
|
| 333 |
+
2024-05-05 14:55:44,043 [INFO ] 218|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.34120|msssim:0.34909|lpips:0.49698|dists:0.27423|Test loss:22.18064
|
| 334 |
+
2024-05-05 15:17:02,181 [INFO ] 219|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.37122|msssim:0.34335|lpips:0.49592|dists:0.27407|Test loss:22.21357
|
| 335 |
+
2024-05-05 15:38:13,952 [INFO ] 220|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.33670|msssim:0.34260|lpips:0.50183|dists:0.27681|Test loss:22.21649
|
| 336 |
+
2024-05-05 15:59:11,744 [INFO ] 221|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.28591|msssim:0.34584|lpips:0.50050|dists:0.27571|Test loss:22.23926
|
| 337 |
+
2024-05-05 16:20:07,029 [INFO ] 222|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.26039|msssim:0.34677|lpips:0.49812|dists:0.27303|Test loss:22.18227
|
| 338 |
+
2024-05-05 16:41:01,438 [INFO ] 223|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.35407|msssim:0.34511|lpips:0.49698|dists:0.27452|Test loss:22.20431
|
| 339 |
+
2024-05-05 17:01:56,573 [INFO ] 224|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.36067|msssim:0.34552|lpips:0.49578|dists:0.27214|Test loss:22.16769
|
| 340 |
+
2024-05-05 17:23:09,208 [INFO ] 225|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.34319|msssim:0.34962|lpips:0.49547|dists:0.27391|Test loss:22.15918
|
| 341 |
+
2024-05-05 17:44:29,823 [INFO ] 226|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.21200|msssim:0.34131|lpips:0.49971|dists:0.27562|Test loss:22.19288
|
| 342 |
+
2024-05-05 18:05:38,805 [INFO ] 227|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.34673|msssim:0.34305|lpips:0.49692|dists:0.27394|Test loss:22.18381
|
| 343 |
+
2024-05-05 18:26:44,515 [INFO ] 228|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.26413|msssim:0.33844|lpips:0.50093|dists:0.27558|Test loss:22.21518
|
| 344 |
+
2024-05-05 18:47:55,570 [INFO ] 229|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.20001|msssim:0.34085|lpips:0.50149|dists:0.27539|Test loss:22.16549
|
| 345 |
+
2024-05-05 18:49:21,297 [INFO ] 230|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.24847|msssim:0.34171|lpips:0.50002|dists:0.27554|Test loss:22.18368
|
| 346 |
+
2024-05-05 19:10:31,373 [INFO ] 231|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.26719|msssim:0.34548|lpips:0.50115|dists:0.27678|Test loss:22.23358
|
| 347 |
+
2024-05-05 19:31:27,088 [INFO ] 232|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.29815|msssim:0.34711|lpips:0.49873|dists:0.27557|Test loss:22.20006
|
| 348 |
+
2024-05-05 19:52:35,830 [INFO ] 233|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.45504|msssim:0.35044|lpips:0.50058|dists:0.27448|Test loss:22.19251
|
| 349 |
+
2024-05-05 20:13:42,913 [INFO ] 234|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.36103|msssim:0.34949|lpips:0.49969|dists:0.27417|Test loss:22.19380
|
| 350 |
+
2024-05-05 20:34:48,991 [INFO ] 235|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.43152|msssim:0.34284|lpips:0.49837|dists:0.27674|Test loss:22.17983
|
| 351 |
+
2024-05-05 20:55:57,801 [INFO ] 236|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.31431|msssim:0.34494|lpips:0.49805|dists:0.27595|Test loss:22.18135
|
| 352 |
+
2024-05-05 21:17:01,921 [INFO ] 237|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.39373|msssim:0.34481|lpips:0.49903|dists:0.27500|Test loss:22.19592
|
| 353 |
+
2024-05-05 21:38:15,187 [INFO ] 238|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.32557|msssim:0.34869|lpips:0.50099|dists:0.27830|Test loss:22.18304
|
| 354 |
+
2024-05-05 21:59:20,693 [INFO ] 239|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.35618|msssim:0.34089|lpips:0.49914|dists:0.27589|Test loss:22.22733
|
| 355 |
+
2024-05-05 22:20:28,843 [INFO ] 240|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.38320|msssim:0.34359|lpips:0.49845|dists:0.27558|Test loss:22.20330
|
| 356 |
+
2024-05-05 22:41:33,103 [INFO ] 241|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.41832|msssim:0.34829|lpips:0.49624|dists:0.27300|Test loss:22.20170
|
| 357 |
+
2024-05-05 23:02:36,039 [INFO ] 242|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.26759|msssim:0.34710|lpips:0.49868|dists:0.27406|Test loss:22.18713
|
| 358 |
+
2024-05-05 23:23:41,292 [INFO ] 243|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.25480|msssim:0.33822|lpips:0.49764|dists:0.27286|Test loss:22.20413
|
| 359 |
+
2024-05-05 23:44:51,601 [INFO ] 244|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.24137|msssim:0.34107|lpips:0.49667|dists:0.27336|Test loss:22.19491
|
| 360 |
+
2024-05-06 00:05:59,368 [INFO ] 245|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.44126|msssim:0.34903|lpips:0.49630|dists:0.27284|Test loss:22.21318
|
| 361 |
+
2024-05-06 00:26:57,417 [INFO ] 246|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.35589|msssim:0.34458|lpips:0.49846|dists:0.27602|Test loss:22.18793
|
| 362 |
+
2024-05-06 00:48:01,606 [INFO ] 247|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.29865|msssim:0.34271|lpips:0.49826|dists:0.27014|Test loss:22.17141
|
| 363 |
+
2024-05-06 01:09:07,370 [INFO ] 248|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.31535|msssim:0.34943|lpips:0.49795|dists:0.27377|Test loss:22.17951
|
| 364 |
+
2024-05-06 01:30:20,686 [INFO ] 249|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.39809|msssim:0.34563|lpips:0.50040|dists:0.27597|Test loss:22.16323
|
| 365 |
+
2024-05-06 01:51:22,543 [INFO ] 250|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.28734|msssim:0.35247|lpips:0.49776|dists:0.27438|Test loss:22.14602
|
| 366 |
+
2024-05-06 02:12:51,537 [INFO ] 251|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.28380|msssim:0.34202|lpips:0.49683|dists:0.27599|Test loss:22.20221
|
| 367 |
+
2024-05-06 02:33:50,217 [INFO ] 252|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.33769|msssim:0.35137|lpips:0.49936|dists:0.27664|Test loss:22.18598
|
| 368 |
+
2024-05-06 02:55:05,177 [INFO ] 253|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.30681|msssim:0.34182|lpips:0.50133|dists:0.27604|Test loss:22.19516
|
| 369 |
+
2024-05-06 03:16:12,944 [INFO ] 254|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.39329|msssim:0.34256|lpips:0.49706|dists:0.27318|Test loss:22.16968
|
| 370 |
+
2024-05-06 03:37:19,768 [INFO ] 255|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.44825|msssim:0.34998|lpips:0.49653|dists:0.27555|Test loss:22.16102
|
| 371 |
+
2024-05-06 03:58:30,008 [INFO ] 256|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.30121|msssim:0.34926|lpips:0.49515|dists:0.27333|Test loss:22.14309
|
| 372 |
+
2024-05-06 04:19:52,471 [INFO ] 257|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.35113|msssim:0.34655|lpips:0.49584|dists:0.27406|Test loss:22.19664
|
| 373 |
+
2024-05-06 04:40:58,643 [INFO ] 258|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.34788|msssim:0.34555|lpips:0.49748|dists:0.27429|Test loss:22.15390
|
| 374 |
+
2024-05-06 05:02:09,625 [INFO ] 259|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.37536|msssim:0.34993|lpips:0.49926|dists:0.27615|Test loss:22.16938
|
| 375 |
+
2024-05-06 05:23:15,111 [INFO ] 260|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.30780|msssim:0.34366|lpips:0.49741|dists:0.27334|Test loss:22.14079
|
| 376 |
+
2024-05-06 05:44:58,198 [INFO ] 261|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.32658|msssim:0.34584|lpips:0.49735|dists:0.27514|Test loss:22.15270
|
| 377 |
+
2024-05-06 06:06:38,107 [INFO ] 262|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.34172|msssim:0.34937|lpips:0.49673|dists:0.27184|Test loss:22.16732
|
| 378 |
+
2024-05-06 06:08:06,494 [INFO ] 263|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.34207|msssim:0.34853|lpips:0.49675|dists:0.27406|Test loss:22.15906
|
| 379 |
+
2024-05-06 06:29:36,637 [INFO ] 264|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.46720|msssim:0.34606|lpips:0.49734|dists:0.27330|Test loss:22.16713
|
| 380 |
+
2024-05-06 06:50:49,014 [INFO ] 265|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.39163|msssim:0.34577|lpips:0.49680|dists:0.27300|Test loss:22.15004
|
| 381 |
+
2024-05-06 07:11:43,798 [INFO ] 266|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.34610|msssim:0.34627|lpips:0.49790|dists:0.27368|Test loss:22.18821
|
| 382 |
+
2024-05-06 07:32:32,314 [INFO ] 267|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.35793|msssim:0.34650|lpips:0.49353|dists:0.27379|Test loss:22.15183
|
| 383 |
+
2024-05-06 07:53:23,679 [INFO ] 268|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.35679|msssim:0.34936|lpips:0.49728|dists:0.27346|Test loss:22.14839
|
| 384 |
+
2024-05-06 08:14:16,352 [INFO ] 269|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.33630|msssim:0.34671|lpips:0.49758|dists:0.27488|Test loss:22.11851
|
| 385 |
+
2024-05-06 08:35:41,030 [INFO ] 270|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.50002|msssim:0.35202|lpips:0.49321|dists:0.27390|Test loss:22.13897
|
| 386 |
+
2024-05-06 08:55:44,311 [INFO ] 271|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.26539|msssim:0.33970|lpips:0.49583|dists:0.27329|Test loss:22.18124
|
| 387 |
+
2024-05-06 09:15:39,615 [INFO ] 272|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.35656|msssim:0.34859|lpips:0.49631|dists:0.27369|Test loss:22.16513
|
| 388 |
+
2024-05-06 09:35:34,569 [INFO ] 273|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.37810|msssim:0.34932|lpips:0.49532|dists:0.27358|Test loss:22.14227
|
| 389 |
+
2024-05-06 09:55:25,915 [INFO ] 274|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.33323|msssim:0.34434|lpips:0.49772|dists:0.27299|Test loss:22.13977
|
| 390 |
+
2024-05-06 10:15:09,984 [INFO ] 275|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.29836|msssim:0.34814|lpips:0.49551|dists:0.27238|Test loss:22.17381
|
| 391 |
+
2024-05-06 10:34:55,762 [INFO ] 276|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.41804|msssim:0.34762|lpips:0.49558|dists:0.27152|Test loss:22.13700
|
| 392 |
+
2024-05-06 10:54:41,271 [INFO ] 277|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.34791|msssim:0.34759|lpips:0.49562|dists:0.27405|Test loss:22.12424
|
| 393 |
+
2024-05-06 11:14:24,991 [INFO ] 278|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.43791|msssim:0.34807|lpips:0.49478|dists:0.27218|Test loss:22.15800
|
| 394 |
+
2024-05-06 11:34:06,688 [INFO ] 279|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.37573|msssim:0.34698|lpips:0.49517|dists:0.27491|Test loss:22.17768
|
| 395 |
+
2024-05-06 11:53:52,546 [INFO ] 280|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.32452|msssim:0.34915|lpips:0.49737|dists:0.27415|Test loss:22.14754
|
| 396 |
+
2024-05-06 12:13:42,463 [INFO ] 281|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.23457|msssim:0.34536|lpips:0.49823|dists:0.27429|Test loss:22.18300
|
| 397 |
+
2024-05-06 12:33:26,800 [INFO ] 282|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.32926|msssim:0.34179|lpips:0.49565|dists:0.27257|Test loss:22.12766
|
| 398 |
+
2024-05-06 12:53:12,784 [INFO ] 283|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.37897|msssim:0.34765|lpips:0.49597|dists:0.27514|Test loss:22.13906
|
| 399 |
+
2024-05-06 13:13:01,155 [INFO ] 284|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.35453|msssim:0.34555|lpips:0.49322|dists:0.27190|Test loss:22.15379
|
| 400 |
+
2024-05-06 13:32:55,503 [INFO ] 285|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.25997|msssim:0.34366|lpips:0.49811|dists:0.27320|Test loss:22.14294
|
| 401 |
+
2024-05-06 13:53:05,337 [INFO ] 286|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.33252|msssim:0.34884|lpips:0.49568|dists:0.27505|Test loss:22.14334
|
| 402 |
+
2024-05-06 14:12:57,225 [INFO ] 287|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.37169|msssim:0.34353|lpips:0.49783|dists:0.27417|Test loss:22.14881
|
| 403 |
+
2024-05-06 14:32:49,867 [INFO ] 288|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.28825|msssim:0.34587|lpips:0.49870|dists:0.27497|Test loss:22.16187
|
| 404 |
+
2024-05-06 14:52:38,558 [INFO ] 289|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.37071|msssim:0.34675|lpips:0.49692|dists:0.27600|Test loss:22.16026
|
| 405 |
+
2024-05-06 15:12:23,303 [INFO ] 290|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.37343|msssim:0.35109|lpips:0.49530|dists:0.27273|Test loss:22.14378
|
| 406 |
+
2024-05-06 15:32:11,252 [INFO ] 291|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.37243|msssim:0.34936|lpips:0.49572|dists:0.27396|Test loss:22.17331
|
| 407 |
+
2024-05-06 15:51:58,068 [INFO ] 292|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.55883|msssim:0.35747|lpips:0.49342|dists:0.27048|Test loss:22.13242
|
| 408 |
+
2024-05-06 16:11:44,645 [INFO ] 293|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.33236|msssim:0.34500|lpips:0.49713|dists:0.27188|Test loss:22.13761
|
| 409 |
+
2024-05-06 16:31:32,164 [INFO ] 294|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.37576|msssim:0.35218|lpips:0.49406|dists:0.27047|Test loss:22.14124
|
| 410 |
+
2024-05-06 16:51:17,715 [INFO ] 295|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.38965|msssim:0.34790|lpips:0.49487|dists:0.27294|Test loss:22.12060
|
| 411 |
+
2024-05-06 16:52:39,431 [INFO ] 296|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.32517|msssim:0.35096|lpips:0.49530|dists:0.27325|Test loss:22.13534
|
| 412 |
+
2024-05-06 17:12:41,901 [INFO ] 297|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.38625|msssim:0.35519|lpips:0.49261|dists:0.27088|Test loss:22.13709
|
| 413 |
+
2024-05-06 17:32:30,223 [INFO ] 298|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.45197|msssim:0.35502|lpips:0.49372|dists:0.27574|Test loss:22.10319
|
| 414 |
+
2024-05-06 17:53:02,560 [INFO ] 299|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.27800|msssim:0.34538|lpips:0.49640|dists:0.27273|Test loss:22.16596
|
| 415 |
+
2024-05-06 18:13:17,675 [INFO ] 300|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.43854|msssim:0.35064|lpips:0.49498|dists:0.27317|Test loss:22.13541
|
| 416 |
+
2024-05-06 18:33:40,420 [INFO ] 301|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.42806|msssim:0.35004|lpips:0.49591|dists:0.27262|Test loss:22.13021
|
| 417 |
+
2024-05-06 18:54:35,721 [INFO ] 302|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.40629|msssim:0.35309|lpips:0.49433|dists:0.27251|Test loss:22.10585
|
| 418 |
+
2024-05-06 19:15:23,438 [INFO ] 303|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.43635|msssim:0.35212|lpips:0.49461|dists:0.27331|Test loss:22.09783
|
| 419 |
+
2024-05-06 19:36:11,827 [INFO ] 304|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.41165|msssim:0.34843|lpips:0.49495|dists:0.27378|Test loss:22.13802
|
| 420 |
+
2024-05-06 19:56:38,743 [INFO ] 305|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.36958|msssim:0.35154|lpips:0.49564|dists:0.27358|Test loss:22.18889
|
| 421 |
+
2024-05-06 20:17:02,992 [INFO ] 306|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.36006|msssim:0.34681|lpips:0.49732|dists:0.27467|Test loss:22.16472
|
| 422 |
+
2024-05-06 20:37:27,146 [INFO ] 307|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.33619|msssim:0.34884|lpips:0.49251|dists:0.27202|Test loss:22.10112
|
| 423 |
+
2024-05-06 20:57:51,479 [INFO ] 308|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.38734|msssim:0.34704|lpips:0.49590|dists:0.27333|Test loss:22.16504
|
| 424 |
+
2024-05-06 21:17:40,325 [INFO ] 309|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.47638|msssim:0.34979|lpips:0.49419|dists:0.27257|Test loss:22.09890
|
| 425 |
+
2024-05-06 21:38:20,823 [INFO ] 310|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.45121|msssim:0.35461|lpips:0.49393|dists:0.27205|Test loss:22.10250
|
| 426 |
+
2024-05-06 21:58:54,290 [INFO ] 311|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.42236|msssim:0.34856|lpips:0.49429|dists:0.27074|Test loss:22.12690
|
| 427 |
+
2024-05-06 22:19:28,569 [INFO ] 312|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.42022|msssim:0.34901|lpips:0.49516|dists:0.27319|Test loss:22.14768
|
| 428 |
+
2024-05-06 22:40:03,891 [INFO ] 313|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.35907|msssim:0.34569|lpips:0.49461|dists:0.27076|Test loss:22.15050
|
| 429 |
+
2024-05-06 23:00:39,129 [INFO ] 314|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.28936|msssim:0.34797|lpips:0.49777|dists:0.27248|Test loss:22.13753
|
| 430 |
+
2024-05-06 23:21:22,649 [INFO ] 315|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.29058|msssim:0.35014|lpips:0.49607|dists:0.27623|Test loss:22.14135
|
| 431 |
+
2024-05-06 23:42:05,520 [INFO ] 316|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.37230|msssim:0.34671|lpips:0.49491|dists:0.27353|Test loss:22.08937
|
| 432 |
+
2024-05-07 00:02:59,358 [INFO ] 317|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.33447|msssim:0.35346|lpips:0.49475|dists:0.27089|Test loss:22.10210
|
| 433 |
+
2024-05-07 00:23:07,579 [INFO ] 318|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.49132|msssim:0.35204|lpips:0.49550|dists:0.27231|Test loss:22.11998
|
| 434 |
+
2024-05-07 00:42:57,724 [INFO ] 319|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.37582|msssim:0.35174|lpips:0.49482|dists:0.27242|Test loss:22.16880
|
| 435 |
+
2024-05-07 01:02:46,713 [INFO ] 320|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.35898|msssim:0.35062|lpips:0.49379|dists:0.27271|Test loss:22.11210
|
| 436 |
+
2024-05-07 01:22:57,462 [INFO ] 321|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.32817|msssim:0.35568|lpips:0.49295|dists:0.26980|Test loss:22.14307
|
| 437 |
+
2024-05-07 01:42:55,818 [INFO ] 322|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.40482|msssim:0.34941|lpips:0.49731|dists:0.27440|Test loss:22.12967
|
| 438 |
+
2024-05-07 02:03:35,260 [INFO ] 323|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.48860|msssim:0.35103|lpips:0.49498|dists:0.27326|Test loss:22.12212
|
| 439 |
+
2024-05-07 02:23:46,368 [INFO ] 324|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.41354|msssim:0.35028|lpips:0.49496|dists:0.27339|Test loss:22.11742
|
| 440 |
+
2024-05-07 02:43:38,446 [INFO ] 325|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.37060|msssim:0.34866|lpips:0.49335|dists:0.27179|Test loss:22.13050
|
| 441 |
+
2024-05-07 03:03:35,008 [INFO ] 326|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.52411|msssim:0.35067|lpips:0.49280|dists:0.27236|Test loss:22.11754
|
| 442 |
+
2024-05-07 03:23:27,535 [INFO ] 327|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.39696|msssim:0.35174|lpips:0.49652|dists:0.27232|Test loss:22.12615
|
| 443 |
+
2024-05-07 03:43:44,363 [INFO ] 328|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.39724|msssim:0.34908|lpips:0.49480|dists:0.27135|Test loss:22.14223
|
| 444 |
+
2024-05-07 03:45:05,759 [INFO ] 329|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.40106|msssim:0.35324|lpips:0.49468|dists:0.27436|Test loss:22.10712
|
| 445 |
+
2024-05-07 04:04:47,920 [INFO ] 330|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.32596|msssim:0.34426|lpips:0.49445|dists:0.26987|Test loss:22.13806
|
| 446 |
+
2024-05-07 04:24:31,360 [INFO ] 331|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.26787|msssim:0.34595|lpips:0.49652|dists:0.27052|Test loss:22.12575
|
| 447 |
+
2024-05-07 04:44:11,953 [INFO ] 332|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.43838|msssim:0.34961|lpips:0.49307|dists:0.27184|Test loss:22.11797
|
| 448 |
+
2024-05-07 05:03:51,738 [INFO ] 333|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.42003|msssim:0.35362|lpips:0.49367|dists:0.27208|Test loss:22.11651
|
| 449 |
+
2024-05-07 05:23:34,532 [INFO ] 334|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.37144|msssim:0.35248|lpips:0.49673|dists:0.27496|Test loss:22.14811
|
| 450 |
+
2024-05-07 05:43:17,906 [INFO ] 335|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.41272|msssim:0.35291|lpips:0.49492|dists:0.27193|Test loss:22.12853
|
| 451 |
+
2024-05-07 06:03:26,128 [INFO ] 336|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.31984|msssim:0.34795|lpips:0.49517|dists:0.27108|Test loss:22.11144
|
| 452 |
+
2024-05-07 06:23:14,028 [INFO ] 337|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.35224|msssim:0.35102|lpips:0.49385|dists:0.27123|Test loss:22.12615
|
| 453 |
+
2024-05-07 06:43:29,005 [INFO ] 338|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.34989|msssim:0.34449|lpips:0.49436|dists:0.27227|Test loss:22.14807
|
| 454 |
+
2024-05-07 07:03:22,585 [INFO ] 339|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.35900|msssim:0.35273|lpips:0.49441|dists:0.27247|Test loss:22.11002
|
| 455 |
+
2024-05-07 07:23:32,865 [INFO ] 340|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.43373|msssim:0.35449|lpips:0.49218|dists:0.27258|Test loss:22.09663
|
| 456 |
+
2024-05-07 07:44:08,078 [INFO ] 341|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.49582|msssim:0.35435|lpips:0.49275|dists:0.27220|Test loss:22.09673
|
| 457 |
+
2024-05-07 08:05:12,924 [INFO ] 342|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.27418|msssim:0.35315|lpips:0.49452|dists:0.27191|Test loss:22.11007
|
| 458 |
+
2024-05-07 08:26:23,259 [INFO ] 343|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.35699|msssim:0.34731|lpips:0.49411|dists:0.27074|Test loss:22.13691
|
| 459 |
+
2024-05-07 08:47:40,259 [INFO ] 344|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.43843|msssim:0.35445|lpips:0.49354|dists:0.27134|Test loss:22.09072
|
| 460 |
+
2024-05-07 09:08:57,546 [INFO ] 345|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.40076|msssim:0.35221|lpips:0.49355|dists:0.27157|Test loss:22.12100
|
| 461 |
+
2024-05-07 09:30:12,427 [INFO ] 346|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.36718|msssim:0.35203|lpips:0.49355|dists:0.27277|Test loss:22.08085
|
| 462 |
+
2024-05-07 09:50:36,931 [INFO ] 347|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.34232|msssim:0.34843|lpips:0.49354|dists:0.27177|Test loss:22.12479
|
| 463 |
+
2024-05-07 10:10:45,355 [INFO ] 348|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.42333|msssim:0.34844|lpips:0.49300|dists:0.27182|Test loss:22.11209
|
| 464 |
+
2024-05-07 10:30:50,802 [INFO ] 349|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.42717|msssim:0.35086|lpips:0.49458|dists:0.27253|Test loss:22.11386
|
| 465 |
+
2024-05-07 10:50:55,389 [INFO ] 350|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.35927|msssim:0.34962|lpips:0.49381|dists:0.26980|Test loss:22.10434
|
| 466 |
+
2024-05-07 11:11:33,800 [INFO ] 351|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.29633|msssim:0.34823|lpips:0.49724|dists:0.27674|Test loss:22.14577
|
| 467 |
+
2024-05-07 11:32:25,539 [INFO ] 352|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.36162|msssim:0.35235|lpips:0.49426|dists:0.26995|Test loss:22.10663
|
| 468 |
+
2024-05-07 11:53:24,658 [INFO ] 353|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.37880|msssim:0.35225|lpips:0.49382|dists:0.26963|Test loss:22.11384
|
| 469 |
+
2024-05-07 12:14:36,359 [INFO ] 354|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.40548|msssim:0.34668|lpips:0.49391|dists:0.27245|Test loss:22.14980
|
| 470 |
+
2024-05-07 12:35:38,612 [INFO ] 355|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.42286|msssim:0.35354|lpips:0.49382|dists:0.27242|Test loss:22.12479
|
| 471 |
+
2024-05-07 12:56:38,507 [INFO ] 356|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.33143|msssim:0.35170|lpips:0.49645|dists:0.27390|Test loss:22.13347
|
| 472 |
+
2024-05-07 13:17:39,596 [INFO ] 357|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.36672|msssim:0.34573|lpips:0.49592|dists:0.27143|Test loss:22.13349
|
| 473 |
+
2024-05-07 13:38:44,893 [INFO ] 358|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.42944|msssim:0.35298|lpips:0.49284|dists:0.27126|Test loss:22.12769
|
| 474 |
+
2024-05-07 13:59:43,389 [INFO ] 359|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.38038|msssim:0.35037|lpips:0.49457|dists:0.27234|Test loss:22.09782
|
| 475 |
+
2024-05-07 14:20:17,253 [INFO ] 360|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.38950|msssim:0.35237|lpips:0.49389|dists:0.27129|Test loss:22.09906
|
| 476 |
+
2024-05-07 14:40:43,289 [INFO ] 361|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.43936|msssim:0.34640|lpips:0.49196|dists:0.27161|Test loss:22.11428
|
| 477 |
+
2024-05-07 14:42:07,243 [INFO ] 362|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.40003|msssim:0.35458|lpips:0.49361|dists:0.27126|Test loss:22.08922
|
| 478 |
+
2024-05-07 15:02:21,429 [INFO ] 363|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.48445|msssim:0.35553|lpips:0.49189|dists:0.27160|Test loss:22.08326
|
| 479 |
+
2024-05-07 15:22:34,155 [INFO ] 364|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.35317|msssim:0.35344|lpips:0.49614|dists:0.27294|Test loss:22.15190
|
| 480 |
+
2024-05-07 15:42:52,603 [INFO ] 365|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.42529|msssim:0.35113|lpips:0.49577|dists:0.27466|Test loss:22.11909
|
| 481 |
+
2024-05-07 16:03:38,749 [INFO ] 366|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.33904|msssim:0.35202|lpips:0.49448|dists:0.27241|Test loss:22.12104
|
| 482 |
+
2024-05-07 16:24:24,975 [INFO ] 367|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.42592|msssim:0.35292|lpips:0.49367|dists:0.27379|Test loss:22.11291
|
| 483 |
+
2024-05-07 16:44:56,650 [INFO ] 368|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.39894|msssim:0.35064|lpips:0.49435|dists:0.27276|Test loss:22.12427
|
| 484 |
+
2024-05-07 17:05:25,010 [INFO ] 369|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.48808|msssim:0.35531|lpips:0.49202|dists:0.27039|Test loss:22.11523
|
| 485 |
+
2024-05-07 17:25:58,021 [INFO ] 370|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.39756|msssim:0.35016|lpips:0.49330|dists:0.27168|Test loss:22.10594
|
| 486 |
+
2024-05-07 17:46:18,607 [INFO ] 371|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.49306|msssim:0.35298|lpips:0.49468|dists:0.27251|Test loss:22.11654
|
| 487 |
+
2024-05-07 18:06:49,127 [INFO ] 372|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.34533|msssim:0.35142|lpips:0.49453|dists:0.27251|Test loss:22.12963
|
| 488 |
+
2024-05-07 18:27:08,641 [INFO ] 373|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.34550|msssim:0.35208|lpips:0.49310|dists:0.27301|Test loss:22.11551
|
| 489 |
+
2024-05-07 18:47:50,057 [INFO ] 374|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.37816|msssim:0.35055|lpips:0.49417|dists:0.27248|Test loss:22.10346
|
| 490 |
+
2024-05-07 19:08:22,282 [INFO ] 375|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.36707|msssim:0.34627|lpips:0.49279|dists:0.27090|Test loss:22.07935
|
| 491 |
+
2024-05-07 19:29:07,038 [INFO ] 376|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.42175|msssim:0.35269|lpips:0.49278|dists:0.27011|Test loss:22.07009
|
| 492 |
+
2024-05-07 19:49:50,578 [INFO ] 377|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.40540|msssim:0.35338|lpips:0.49272|dists:0.27135|Test loss:22.12318
|
| 493 |
+
2024-05-07 20:10:17,385 [INFO ] 378|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.39138|msssim:0.34611|lpips:0.49605|dists:0.27153|Test loss:22.13597
|
| 494 |
+
2024-05-07 20:30:46,458 [INFO ] 379|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.39260|msssim:0.34934|lpips:0.49563|dists:0.27480|Test loss:22.14463
|
| 495 |
+
2024-05-07 20:51:11,028 [INFO ] 380|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.42045|msssim:0.34795|lpips:0.49372|dists:0.26907|Test loss:22.10082
|
| 496 |
+
2024-05-07 21:11:40,325 [INFO ] 381|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.27697|msssim:0.34880|lpips:0.49537|dists:0.27470|Test loss:22.09123
|
| 497 |
+
2024-05-07 21:32:25,704 [INFO ] 382|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.35365|msssim:0.34741|lpips:0.49410|dists:0.27204|Test loss:22.12331
|
| 498 |
+
2024-05-07 21:53:04,306 [INFO ] 383|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.28105|msssim:0.35066|lpips:0.49763|dists:0.27276|Test loss:22.13340
|
| 499 |
+
2024-05-07 22:13:43,291 [INFO ] 384|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.31601|msssim:0.34885|lpips:0.49536|dists:0.27251|Test loss:22.12415
|
| 500 |
+
2024-05-07 22:34:19,100 [INFO ] 385|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.34422|msssim:0.35122|lpips:0.49521|dists:0.27288|Test loss:22.14277
|
| 501 |
+
2024-05-07 22:54:56,647 [INFO ] 386|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.41440|msssim:0.34890|lpips:0.49383|dists:0.27097|Test loss:22.10782
|
| 502 |
+
2024-05-07 23:15:36,868 [INFO ] 387|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.45614|msssim:0.35336|lpips:0.49236|dists:0.27192|Test loss:22.09378
|
| 503 |
+
2024-05-07 23:36:07,945 [INFO ] 388|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.38647|msssim:0.34605|lpips:0.49631|dists:0.27383|Test loss:22.12287
|
| 504 |
+
2024-05-07 23:56:49,417 [INFO ] 389|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.38224|msssim:0.35091|lpips:0.49457|dists:0.27112|Test loss:22.07871
|
| 505 |
+
2024-05-08 00:17:12,737 [INFO ] 390|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.43890|msssim:0.35296|lpips:0.49275|dists:0.27223|Test loss:22.13008
|
| 506 |
+
2024-05-08 00:37:32,168 [INFO ] 391|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.35077|msssim:0.34966|lpips:0.49267|dists:0.27072|Test loss:22.13726
|
| 507 |
+
2024-05-08 00:58:00,221 [INFO ] 392|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.43491|msssim:0.34659|lpips:0.49620|dists:0.27359|Test loss:22.16497
|
| 508 |
+
2024-05-08 01:18:28,620 [INFO ] 393|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.46330|msssim:0.35150|lpips:0.49221|dists:0.27154|Test loss:22.10838
|
| 509 |
+
2024-05-08 01:39:01,279 [INFO ] 394|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.48316|msssim:0.35406|lpips:0.49092|dists:0.26870|Test loss:22.06746
|
| 510 |
+
2024-05-08 01:40:42,930 [INFO ] 395|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.39606|msssim:0.35099|lpips:0.49426|dists:0.27250|Test loss:22.11529
|
| 511 |
+
2024-05-08 02:00:57,924 [INFO ] 396|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.37678|msssim:0.35195|lpips:0.49500|dists:0.27399|Test loss:22.10022
|
| 512 |
+
2024-05-08 02:21:26,695 [INFO ] 397|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.36464|msssim:0.35303|lpips:0.49387|dists:0.27059|Test loss:22.10485
|
| 513 |
+
2024-05-08 02:42:04,184 [INFO ] 398|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.42602|msssim:0.34833|lpips:0.49350|dists:0.27095|Test loss:22.11267
|
| 514 |
+
2024-05-08 03:02:23,282 [INFO ] 399|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.46092|msssim:0.34956|lpips:0.49528|dists:0.27270|Test loss:22.10216
|
MIM_vbr/events.out.tfevents.1714742678.t2vg-a100-G4-10.4118013.0
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:03724613c625203caa59f67da24b894285dbf610923f2262214ffab63cca8c59
|
| 3 |
+
size 4582056
|
MIM_vbr/log.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
MIM_vbr_eval_lossless/20240518_044257.log
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2024-05-18 04:42:57,952 [INFO ] Logging file is /home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr_eval_lossless//20240518_044257.log
|
| 2 |
+
2024-05-18 04:42:57,952 [INFO ] ======================= exp =======================
|
| 3 |
+
2024-05-18 04:42:57,952 [INFO ] config:config/vpt_default.yaml
|
| 4 |
+
2024-05-18 04:42:57,952 [INFO ] name:exp
|
| 5 |
+
2024-05-18 04:42:57,952 [INFO ] lr:None
|
| 6 |
+
2024-05-18 04:42:57,952 [INFO ] world_size:1
|
| 7 |
+
2024-05-18 04:42:57,952 [INFO ] local_rank:-1
|
| 8 |
+
2024-05-18 04:42:57,952 [INFO ] dist_on_itp:False
|
| 9 |
+
2024-05-18 04:42:57,952 [INFO ] dist_url:env://
|
| 10 |
+
2024-05-18 04:42:57,952 [INFO ] TEST:False
|
| 11 |
+
2024-05-18 04:42:57,952 [INFO ] model:mage_vit_base_patch16
|
| 12 |
+
2024-05-18 04:42:57,953 [INFO ] dataset:imagenet
|
| 13 |
+
2024-05-18 04:42:57,953 [INFO ] dataset_path:/home/t2vg-a100-G4-10/project/qyp/datasets/imagenet
|
| 14 |
+
2024-05-18 04:42:57,953 [INFO ] eval_path:
|
| 15 |
+
2024-05-18 04:42:57,953 [INFO ] batch_size:64
|
| 16 |
+
2024-05-18 04:42:57,953 [INFO ] test_batch_size:200
|
| 17 |
+
2024-05-18 04:42:57,953 [INFO ] num_workers:4
|
| 18 |
+
2024-05-18 04:42:57,953 [INFO ] epochs:400
|
| 19 |
+
2024-05-18 04:42:57,953 [INFO ] blr:0.0002
|
| 20 |
+
2024-05-18 04:42:57,953 [INFO ] min_lr:5e-06
|
| 21 |
+
2024-05-18 04:42:57,953 [INFO ] accum_iter:1
|
| 22 |
+
2024-05-18 04:42:57,953 [INFO ] input_size:256
|
| 23 |
+
2024-05-18 04:42:57,953 [INFO ] weight_decay:0.05
|
| 24 |
+
2024-05-18 04:42:57,953 [INFO ] warmup_epochs:20
|
| 25 |
+
2024-05-18 04:42:57,953 [INFO ] mask_ratio_min:0.5
|
| 26 |
+
2024-05-18 04:42:57,953 [INFO ] mask_ratio_max:0.9
|
| 27 |
+
2024-05-18 04:42:57,953 [INFO ] mask_ratio_mu:0.55
|
| 28 |
+
2024-05-18 04:42:57,953 [INFO ] mask_ratio_std:0.25
|
| 29 |
+
2024-05-18 04:42:57,953 [INFO ] grad_clip:3.0
|
| 30 |
+
2024-05-18 04:42:57,953 [INFO ] device:cuda
|
| 31 |
+
2024-05-18 04:42:57,953 [INFO ] distributed:True
|
| 32 |
+
2024-05-18 04:42:57,953 [INFO ] seed:0
|
| 33 |
+
2024-05-18 04:42:57,953 [INFO ] resume:/home/t2vg-a100-G4-10/project/qyp/lossless/checkpoint-330.pth
|
| 34 |
+
2024-05-18 04:42:57,953 [INFO ] start_epoch:0
|
| 35 |
+
2024-05-18 04:42:57,953 [INFO ] pin_mem:True
|
| 36 |
+
2024-05-18 04:42:57,953 [INFO ] save:True
|
| 37 |
+
2024-05-18 04:42:57,953 [INFO ] exp_name:MIM_vbr_eval_lossless
|
| 38 |
+
2024-05-18 04:42:57,954 [INFO ] root:/home/t2vg-a100-G4-10/project/qyp/mimc_rope
|
| 39 |
+
2024-05-18 04:42:57,954 [INFO ] output_dir:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr_eval_lossless
|
| 40 |
+
2024-05-18 04:42:57,954 [INFO ] log_dir:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr_eval_lossless
|
| 41 |
+
2024-05-18 04:42:57,954 [INFO ] strart_epoch:0
|
| 42 |
+
2024-05-18 04:42:57,954 [INFO ] cuda:True
|
| 43 |
+
2024-05-18 04:42:57,954 [INFO ] rank:0
|
| 44 |
+
2024-05-18 04:42:57,954 [INFO ] gpu:0
|
| 45 |
+
2024-05-18 04:42:57,954 [INFO ] dist_backend:nccl
|
| 46 |
+
2024-05-18 04:42:57,954 [INFO ] ===================================================
|
| 47 |
+
2024-05-18 06:20:16,116 [INFO ] -1|bpp:0.02418|mask:0.00847|mask_ratio:0.6|psnr:14.80285|msssim:0.51750|lpips:0.27254|dists:0.19907|accu:0.63626|Test loss:64.62691
|
| 48 |
+
2024-05-18 07:56:33,739 [INFO ] -1|bpp:0.02791|mask:0.00819|mask_ratio:0.5|psnr:15.67397|msssim:0.56751|lpips:0.23728|dists:0.18073|accu:0.67258|Test loss:139.71758
|
| 49 |
+
2024-05-18 09:30:08,858 [INFO ] -1|bpp:0.02046|mask:0.00875|mask_ratio:0.7|psnr:13.74718|msssim:0.45169|lpips:0.32138|dists:0.22512|accu:0.56262|Test loss:30.11541
|
| 50 |
+
2024-05-18 11:05:40,441 [INFO ] -1|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:13.17649|msssim:0.41227|lpips:0.35134|dists:0.24154|accu:0.50616|Test loss:21.23853
|
| 51 |
+
2024-05-18 12:36:08,830 [INFO ] -1|bpp:0.03521|mask:0.00764|mask_ratio:0.3|psnr:17.07054|msssim:0.63998|lpips:0.18754|dists:0.15498|accu:0.70446|Test loss:643.00879
|
| 52 |
+
2024-05-18 14:00:50,675 [INFO ] -1|bpp:0.04252|mask:0.00709|mask_ratio:0.1|psnr:18.25655|msssim:0.69397|lpips:0.15150|dists:0.13635|accu:0.72014|Test loss:2993.12524
|
| 53 |
+
2024-05-18 15:25:33,426 [INFO ] -1|bpp:0.04581|mask:0.00684|mask_ratio:0.01|psnr:18.75086|msssim:0.71406|lpips:0.13808|dists:0.12934|accu:0.72174|Test loss:6007.86475
|
MIM_vbr_eval_lossless/20240521_120210.log
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2024-05-21 12:02:10,417 [INFO ] Logging file is /home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr_eval_lossless//20240521_120210.log
|
| 2 |
+
2024-05-21 12:02:10,418 [INFO ] ======================= exp =======================
|
| 3 |
+
2024-05-21 12:02:10,418 [INFO ] config:config/vpt_default.yaml
|
| 4 |
+
2024-05-21 12:02:10,418 [INFO ] name:exp
|
| 5 |
+
2024-05-21 12:02:10,418 [INFO ] lr:None
|
| 6 |
+
2024-05-21 12:02:10,418 [INFO ] world_size:1
|
| 7 |
+
2024-05-21 12:02:10,418 [INFO ] local_rank:-1
|
| 8 |
+
2024-05-21 12:02:10,418 [INFO ] dist_on_itp:False
|
| 9 |
+
2024-05-21 12:02:10,418 [INFO ] dist_url:env://
|
| 10 |
+
2024-05-21 12:02:10,418 [INFO ] TEST:False
|
| 11 |
+
2024-05-21 12:02:10,418 [INFO ] model:mage_vit_base_patch16
|
| 12 |
+
2024-05-21 12:02:10,418 [INFO ] dataset:imagenet
|
| 13 |
+
2024-05-21 12:02:10,418 [INFO ] dataset_path:/home/t2vg-a100-G4-10/project/qyp/datasets/imagenet
|
| 14 |
+
2024-05-21 12:02:10,418 [INFO ] kodak_path:/home/t2vg-a100-G4-10/project/qyp/datasets/kodak/
|
| 15 |
+
2024-05-21 12:02:10,418 [INFO ] eval_path:
|
| 16 |
+
2024-05-21 12:02:10,418 [INFO ] batch_size:64
|
| 17 |
+
2024-05-21 12:02:10,418 [INFO ] test_batch_size:200
|
| 18 |
+
2024-05-21 12:02:10,418 [INFO ] num_workers:4
|
| 19 |
+
2024-05-21 12:02:10,418 [INFO ] epochs:400
|
| 20 |
+
2024-05-21 12:02:10,418 [INFO ] blr:0.0002
|
| 21 |
+
2024-05-21 12:02:10,419 [INFO ] min_lr:5e-06
|
| 22 |
+
2024-05-21 12:02:10,419 [INFO ] accum_iter:1
|
| 23 |
+
2024-05-21 12:02:10,419 [INFO ] input_size:256
|
| 24 |
+
2024-05-21 12:02:10,419 [INFO ] weight_decay:0.05
|
| 25 |
+
2024-05-21 12:02:10,419 [INFO ] warmup_epochs:20
|
| 26 |
+
2024-05-21 12:02:10,419 [INFO ] mask_ratio_min:0.5
|
| 27 |
+
2024-05-21 12:02:10,419 [INFO ] mask_ratio_max:0.9
|
| 28 |
+
2024-05-21 12:02:10,419 [INFO ] mask_ratio_mu:0.55
|
| 29 |
+
2024-05-21 12:02:10,419 [INFO ] mask_ratio_std:0.25
|
| 30 |
+
2024-05-21 12:02:10,419 [INFO ] grad_clip:3.0
|
| 31 |
+
2024-05-21 12:02:10,419 [INFO ] device:cuda
|
| 32 |
+
2024-05-21 12:02:10,419 [INFO ] distributed:True
|
| 33 |
+
2024-05-21 12:02:10,419 [INFO ] seed:0
|
| 34 |
+
2024-05-21 12:02:10,419 [INFO ] resume:/home/t2vg-a100-G4-10/project/qyp/lossless/checkpoint-330.pth
|
| 35 |
+
2024-05-21 12:02:10,419 [INFO ] start_epoch:0
|
| 36 |
+
2024-05-21 12:02:10,419 [INFO ] pin_mem:True
|
| 37 |
+
2024-05-21 12:02:10,419 [INFO ] save:True
|
| 38 |
+
2024-05-21 12:02:10,419 [INFO ] exp_name:MIM_vbr_eval_lossless
|
| 39 |
+
2024-05-21 12:02:10,419 [INFO ] root:/home/t2vg-a100-G4-10/project/qyp/mimc_rope
|
| 40 |
+
2024-05-21 12:02:10,419 [INFO ] output_dir:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr_eval_lossless
|
| 41 |
+
2024-05-21 12:02:10,419 [INFO ] log_dir:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr_eval_lossless
|
| 42 |
+
2024-05-21 12:02:10,419 [INFO ] strart_epoch:0
|
| 43 |
+
2024-05-21 12:02:10,420 [INFO ] cuda:True
|
| 44 |
+
2024-05-21 12:02:10,420 [INFO ] rank:0
|
| 45 |
+
2024-05-21 12:02:10,420 [INFO ] gpu:0
|
| 46 |
+
2024-05-21 12:02:10,420 [INFO ] dist_backend:nccl
|
| 47 |
+
2024-05-21 12:02:10,420 [INFO ] ===================================================
|
| 48 |
+
2024-05-21 12:03:00,851 [INFO ] -1|bpp:0.00784|mask:0.00908|mask_ratio:0.8|psnr:13.82288|msssim:0.34806|lpips:0.42462|dists:0.27493|Test loss:15.10523
|
| 49 |
+
2024-05-21 12:03:41,020 [INFO ] -1|bpp:0.00987|mask:0.00891|mask_ratio:0.75|psnr:14.37770|msssim:0.39104|lpips:0.38284|dists:0.26121|Test loss:22.23518
|
| 50 |
+
2024-05-21 12:04:19,453 [INFO ] -1|bpp:0.01170|mask:0.00879|mask_ratio:0.7|psnr:15.33662|msssim:0.45445|lpips:0.34781|dists:0.23473|Test loss:31.04415
|
| 51 |
+
2024-05-21 12:05:01,263 [INFO ] -1|bpp:0.01369|mask:0.00867|mask_ratio:0.65|psnr:15.06704|msssim:0.46401|lpips:0.32283|dists:0.22039|Test loss:45.53056
|
| 52 |
+
2024-05-21 12:05:46,943 [INFO ] -1|bpp:0.01568|mask:0.00854|mask_ratio:0.6|psnr:16.10357|msssim:0.51571|lpips:0.28952|dists:0.20649|Test loss:66.32143
|
MIM_vbr_eval_lossless/20240521_121146.log
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2024-05-21 12:11:46,727 [INFO ] Logging file is /home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr_eval_lossless//20240521_121146.log
|
| 2 |
+
2024-05-21 12:11:46,727 [INFO ] ======================= exp =======================
|
| 3 |
+
2024-05-21 12:11:46,727 [INFO ] config:config/vpt_default.yaml
|
| 4 |
+
2024-05-21 12:11:46,727 [INFO ] name:exp
|
| 5 |
+
2024-05-21 12:11:46,727 [INFO ] lr:None
|
| 6 |
+
2024-05-21 12:11:46,727 [INFO ] world_size:1
|
| 7 |
+
2024-05-21 12:11:46,727 [INFO ] local_rank:-1
|
| 8 |
+
2024-05-21 12:11:46,727 [INFO ] dist_on_itp:False
|
| 9 |
+
2024-05-21 12:11:46,727 [INFO ] dist_url:env://
|
| 10 |
+
2024-05-21 12:11:46,727 [INFO ] TEST:False
|
| 11 |
+
2024-05-21 12:11:46,727 [INFO ] model:mage_vit_base_patch16
|
| 12 |
+
2024-05-21 12:11:46,728 [INFO ] dataset:imagenet
|
| 13 |
+
2024-05-21 12:11:46,728 [INFO ] dataset_path:/home/t2vg-a100-G4-10/project/qyp/datasets/imagenet
|
| 14 |
+
2024-05-21 12:11:46,728 [INFO ] kodak_path:/home/t2vg-a100-G4-10/project/qyp/datasets/kodak/
|
| 15 |
+
2024-05-21 12:11:46,728 [INFO ] eval_path:
|
| 16 |
+
2024-05-21 12:11:46,728 [INFO ] batch_size:64
|
| 17 |
+
2024-05-21 12:11:46,728 [INFO ] test_batch_size:200
|
| 18 |
+
2024-05-21 12:11:46,728 [INFO ] num_workers:4
|
| 19 |
+
2024-05-21 12:11:46,728 [INFO ] epochs:400
|
| 20 |
+
2024-05-21 12:11:46,728 [INFO ] blr:0.0002
|
| 21 |
+
2024-05-21 12:11:46,728 [INFO ] min_lr:5e-06
|
| 22 |
+
2024-05-21 12:11:46,728 [INFO ] accum_iter:1
|
| 23 |
+
2024-05-21 12:11:46,728 [INFO ] input_size:256
|
| 24 |
+
2024-05-21 12:11:46,728 [INFO ] weight_decay:0.05
|
| 25 |
+
2024-05-21 12:11:46,728 [INFO ] warmup_epochs:20
|
| 26 |
+
2024-05-21 12:11:46,728 [INFO ] mask_ratio_min:0.5
|
| 27 |
+
2024-05-21 12:11:46,728 [INFO ] mask_ratio_max:0.9
|
| 28 |
+
2024-05-21 12:11:46,728 [INFO ] mask_ratio_mu:0.55
|
| 29 |
+
2024-05-21 12:11:46,728 [INFO ] mask_ratio_std:0.25
|
| 30 |
+
2024-05-21 12:11:46,728 [INFO ] grad_clip:3.0
|
| 31 |
+
2024-05-21 12:11:46,728 [INFO ] device:cuda
|
| 32 |
+
2024-05-21 12:11:46,729 [INFO ] distributed:True
|
| 33 |
+
2024-05-21 12:11:46,729 [INFO ] seed:0
|
| 34 |
+
2024-05-21 12:11:46,729 [INFO ] resume:/home/t2vg-a100-G4-10/project/qyp/lossless/checkpoint-330.pth
|
| 35 |
+
2024-05-21 12:11:46,729 [INFO ] start_epoch:0
|
| 36 |
+
2024-05-21 12:11:46,729 [INFO ] pin_mem:True
|
| 37 |
+
2024-05-21 12:11:46,729 [INFO ] save:True
|
| 38 |
+
2024-05-21 12:11:46,729 [INFO ] exp_name:MIM_vbr_eval_lossless
|
| 39 |
+
2024-05-21 12:11:46,729 [INFO ] root:/home/t2vg-a100-G4-10/project/qyp/mimc_rope
|
| 40 |
+
2024-05-21 12:11:46,729 [INFO ] output_dir:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr_eval_lossless
|
| 41 |
+
2024-05-21 12:11:46,729 [INFO ] log_dir:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr_eval_lossless
|
| 42 |
+
2024-05-21 12:11:46,729 [INFO ] strart_epoch:0
|
| 43 |
+
2024-05-21 12:11:46,729 [INFO ] cuda:True
|
| 44 |
+
2024-05-21 12:11:46,729 [INFO ] rank:0
|
| 45 |
+
2024-05-21 12:11:46,729 [INFO ] gpu:0
|
| 46 |
+
2024-05-21 12:11:46,729 [INFO ] dist_backend:nccl
|
| 47 |
+
2024-05-21 12:11:46,729 [INFO ] ===================================================
|
| 48 |
+
2024-05-21 12:12:38,993 [INFO ] -1|bpp:0.00784|mask:0.00908|mask_ratio:0.8|psnr:13.82276|msssim:0.34805|lpips:0.42462|dists:0.27493|Test loss:15.10523
|
| 49 |
+
2024-05-21 12:13:18,373 [INFO ] -1|bpp:0.00987|mask:0.00891|mask_ratio:0.75|psnr:14.37755|msssim:0.39104|lpips:0.38284|dists:0.26121|Test loss:22.23518
|
| 50 |
+
2024-05-21 12:13:58,355 [INFO ] -1|bpp:0.01170|mask:0.00879|mask_ratio:0.7|psnr:15.33651|msssim:0.45445|lpips:0.34781|dists:0.23473|Test loss:31.04415
|
| 51 |
+
2024-05-21 12:14:38,346 [INFO ] -1|bpp:0.01369|mask:0.00867|mask_ratio:0.65|psnr:15.06691|msssim:0.46401|lpips:0.32283|dists:0.22039|Test loss:45.53056
|
| 52 |
+
2024-05-21 12:15:17,341 [INFO ] -1|bpp:0.01568|mask:0.00854|mask_ratio:0.6|psnr:16.10348|msssim:0.51571|lpips:0.28952|dists:0.20649|Test loss:66.32143
|
| 53 |
+
2024-05-21 12:15:59,518 [INFO ] -1|bpp:0.01769|mask:0.00842|mask_ratio:0.55|psnr:16.66886|msssim:0.54068|lpips:0.27387|dists:0.19902|Test loss:97.59130
|
| 54 |
+
2024-05-21 12:16:43,259 [INFO ] -1|bpp:0.01969|mask:0.00830|mask_ratio:0.5|psnr:17.05398|msssim:0.55989|lpips:0.25349|dists:0.18627|Test loss:143.87851
|
| 55 |
+
2024-05-21 12:17:21,314 [INFO ] -1|bpp:0.02154|mask:0.00818|mask_ratio:0.45|psnr:17.42430|msssim:0.58354|lpips:0.24210|dists:0.17977|Test loss:203.66936
|
| 56 |
+
2024-05-21 12:18:04,417 [INFO ] -1|bpp:0.02357|mask:0.00794|mask_ratio:0.4|psnr:17.89758|msssim:0.60007|lpips:0.22977|dists:0.17287|Test loss:304.47241
|
| 57 |
+
2024-05-21 12:18:46,874 [INFO ] -1|bpp:0.02555|mask:0.00781|mask_ratio:0.35|psnr:18.47392|msssim:0.62136|lpips:0.21420|dists:0.16763|Test loss:447.87213
|
| 58 |
+
2024-05-21 12:19:30,648 [INFO ] -1|bpp:0.02756|mask:0.00769|mask_ratio:0.3|psnr:18.70894|msssim:0.64106|lpips:0.19699|dists:0.16134|Test loss:662.32788
|
| 59 |
+
2024-05-21 12:20:10,918 [INFO ] -1|bpp:0.02956|mask:0.00757|mask_ratio:0.25|psnr:18.98025|msssim:0.65567|lpips:0.18680|dists:0.15801|Test loss:983.28394
|
| 60 |
+
2024-05-21 12:20:50,854 [INFO ] -1|bpp:0.03138|mask:0.00745|mask_ratio:0.2|psnr:19.34285|msssim:0.66431|lpips:0.18044|dists:0.15229|Test loss:1385.45557
|
| 61 |
+
2024-05-21 12:21:32,470 [INFO ] -1|bpp:0.03339|mask:0.00732|mask_ratio:0.15|psnr:19.71813|msssim:0.68537|lpips:0.16575|dists:0.14260|Test loss:2053.51758
|
| 62 |
+
2024-05-21 12:22:12,619 [INFO ] -1|bpp:0.03539|mask:0.00720|mask_ratio:0.1|psnr:20.06234|msssim:0.69720|lpips:0.15856|dists:0.14339|Test loss:3017.49292
|
| 63 |
+
2024-05-21 12:22:51,448 [INFO ] -1|bpp:0.03739|mask:0.00703|mask_ratio:0.05|psnr:20.41125|msssim:0.71023|lpips:0.15008|dists:0.14040|Test loss:4599.72754
|
VQ_w_Prompt/20240602_123240.log
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2024-06-02 12:32:40,143 [INFO ] Logging file is /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt//20240602_123240.log
|
| 2 |
+
2024-06-02 12:32:40,143 [INFO ] ======================= exp =======================
|
| 3 |
+
2024-06-02 12:32:40,143 [INFO ] config:config/vpt_default.yaml
|
| 4 |
+
2024-06-02 12:32:40,143 [INFO ] name:exp
|
| 5 |
+
2024-06-02 12:32:40,144 [INFO ] lr:None
|
| 6 |
+
2024-06-02 12:32:40,144 [INFO ] world_size:1
|
| 7 |
+
2024-06-02 12:32:40,144 [INFO ] local_rank:-1
|
| 8 |
+
2024-06-02 12:32:40,144 [INFO ] dist_on_itp:False
|
| 9 |
+
2024-06-02 12:32:40,144 [INFO ] dist_url:env://
|
| 10 |
+
2024-06-02 12:32:40,144 [INFO ] TEST:False
|
| 11 |
+
2024-06-02 12:32:40,144 [INFO ] model:mage_vit_base_patch16
|
| 12 |
+
2024-06-02 12:32:40,144 [INFO ] dataset:coco
|
| 13 |
+
2024-06-02 12:32:40,144 [INFO ] dataset_path:/home/t2vg-a100-G4-10/project/qyp/datasets/COCO
|
| 14 |
+
2024-06-02 12:32:40,144 [INFO ] kodak_path:/home/t2vg-a100-G4-10/project/qyp/datasets/kodak/
|
| 15 |
+
2024-06-02 12:32:40,144 [INFO ] eval_path:
|
| 16 |
+
2024-06-02 12:32:40,144 [INFO ] batch_size:128
|
| 17 |
+
2024-06-02 12:32:40,144 [INFO ] learning_rate:0.0001
|
| 18 |
+
2024-06-02 12:32:40,144 [INFO ] test_batch_size:1
|
| 19 |
+
2024-06-02 12:32:40,144 [INFO ] num_workers:4
|
| 20 |
+
2024-06-02 12:32:40,144 [INFO ] epochs:100
|
| 21 |
+
2024-06-02 12:32:40,144 [INFO ] blr:0.00015
|
| 22 |
+
2024-06-02 12:32:40,144 [INFO ] min_lr:5e-06
|
| 23 |
+
2024-06-02 12:32:40,144 [INFO ] accum_iter:1
|
| 24 |
+
2024-06-02 12:32:40,144 [INFO ] input_size:256
|
| 25 |
+
2024-06-02 12:32:40,144 [INFO ] weight_decay:0.05
|
| 26 |
+
2024-06-02 12:32:40,144 [INFO ] warmup_epochs:10
|
| 27 |
+
2024-06-02 12:32:40,144 [INFO ] mask_ratio_min:0.5
|
| 28 |
+
2024-06-02 12:32:40,144 [INFO ] mask_ratio_max:0.9
|
| 29 |
+
2024-06-02 12:32:40,144 [INFO ] mask_ratio_mu:0.55
|
| 30 |
+
2024-06-02 12:32:40,145 [INFO ] mask_ratio_std:0.25
|
| 31 |
+
2024-06-02 12:32:40,145 [INFO ] grad_clip:3.0
|
| 32 |
+
2024-06-02 12:32:40,145 [INFO ] device:cuda
|
| 33 |
+
2024-06-02 12:32:40,145 [INFO ] distributed:False
|
| 34 |
+
2024-06-02 12:32:40,145 [INFO ] seed:0
|
| 35 |
+
2024-06-02 12:32:40,145 [INFO ] gpu_id:0,1
|
| 36 |
+
2024-06-02 12:32:40,145 [INFO ] resume:/home/t2vg-a100-G4-10/project/qyp/lossless/checkpoint-330.pth
|
| 37 |
+
2024-06-02 12:32:40,145 [INFO ] maskrcnn_path:/home/t2vg-a100-G4-10/project/qyp/ckpt/detectron2/mask_rcnn_r50_fpn_3x.pkl
|
| 38 |
+
2024-06-02 12:32:40,145 [INFO ] start_epoch:0
|
| 39 |
+
2024-06-02 12:32:40,145 [INFO ] pin_mem:True
|
| 40 |
+
2024-06-02 12:32:40,145 [INFO ] save:True
|
| 41 |
+
2024-06-02 12:32:40,145 [INFO ] exp_name:VQ_w_Prompt
|
| 42 |
+
2024-06-02 12:32:40,145 [INFO ] root:/home/t2vg-a100-G4-10/project/qyp/mimc_rope
|
| 43 |
+
2024-06-02 12:32:40,145 [INFO ] output_dir:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt
|
| 44 |
+
2024-06-02 12:32:40,145 [INFO ] log_dir:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt
|
| 45 |
+
2024-06-02 12:32:40,145 [INFO ] strart_epoch:0
|
| 46 |
+
2024-06-02 12:32:40,145 [INFO ] cuda:True
|
| 47 |
+
2024-06-02 12:32:40,145 [INFO ] ===================================================
|
| 48 |
+
2024-06-02 12:36:03,409 [INFO ] Epoch: -1 | perc loss: 0.33139 | psnr: 21.90954 | lpips: 0.22209 | dists: 0.09723
|
| 49 |
+
2024-06-02 12:55:03,529 [INFO ] Epoch: 0 | perc loss: 0.33168 | psnr: 21.83200 | lpips: 0.22776 | dists: 0.11052
|
| 50 |
+
2024-06-02 12:55:03,533 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 51 |
+
2024-06-02 12:55:03,812 [INFO ] Saving BEST checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 52 |
+
2024-06-02 13:13:51,301 [INFO ] Epoch: 1 | perc loss: 0.33369 | psnr: 21.56722 | lpips: 0.22748 | dists: 0.10901
|
| 53 |
+
2024-06-02 13:13:51,305 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 54 |
+
2024-06-02 13:32:37,056 [INFO ] Epoch: 2 | perc loss: 0.32811 | psnr: 21.75561 | lpips: 0.22586 | dists: 0.10543
|
| 55 |
+
2024-06-02 13:32:37,060 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 56 |
+
2024-06-02 13:32:37,942 [INFO ] Saving BEST checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 57 |
+
2024-06-02 13:51:19,223 [INFO ] Epoch: 3 | perc loss: 0.33057 | psnr: 21.71532 | lpips: 0.22620 | dists: 0.10827
|
| 58 |
+
2024-06-02 13:51:19,227 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 59 |
+
2024-06-02 14:10:08,103 [INFO ] Epoch: 4 | perc loss: 0.33012 | psnr: 21.70657 | lpips: 0.22693 | dists: 0.10772
|
| 60 |
+
2024-06-02 14:10:08,108 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 61 |
+
2024-06-02 14:28:55,372 [INFO ] Epoch: 5 | perc loss: 0.33078 | psnr: 21.71807 | lpips: 0.22735 | dists: 0.10827
|
| 62 |
+
2024-06-02 14:28:55,376 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 63 |
+
2024-06-02 14:47:49,287 [INFO ] Epoch: 6 | perc loss: 0.33090 | psnr: 21.79235 | lpips: 0.22534 | dists: 0.10498
|
| 64 |
+
2024-06-02 14:47:49,291 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 65 |
+
2024-06-02 15:06:38,406 [INFO ] Epoch: 7 | perc loss: 0.33033 | psnr: 21.74950 | lpips: 0.22632 | dists: 0.10654
|
| 66 |
+
2024-06-02 15:06:38,410 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 67 |
+
2024-06-02 15:25:29,010 [INFO ] Epoch: 8 | perc loss: 0.32939 | psnr: 21.71603 | lpips: 0.22598 | dists: 0.10593
|
| 68 |
+
2024-06-02 15:25:29,014 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 69 |
+
2024-06-02 15:44:21,498 [INFO ] Epoch: 9 | perc loss: 0.33421 | psnr: 21.67705 | lpips: 0.22861 | dists: 0.10923
|
| 70 |
+
2024-06-02 15:44:21,502 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 71 |
+
2024-06-02 16:03:15,138 [INFO ] Epoch: 10 | perc loss: 0.33003 | psnr: 21.75746 | lpips: 0.22706 | dists: 0.10813
|
| 72 |
+
2024-06-02 16:03:15,141 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 73 |
+
2024-06-02 16:21:55,573 [INFO ] Epoch: 11 | perc loss: 0.33178 | psnr: 21.70939 | lpips: 0.22774 | dists: 0.10836
|
| 74 |
+
2024-06-02 16:21:55,577 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 75 |
+
2024-06-02 16:40:34,144 [INFO ] Epoch: 12 | perc loss: 0.33171 | psnr: 21.89004 | lpips: 0.22690 | dists: 0.10719
|
| 76 |
+
2024-06-02 16:40:34,147 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 77 |
+
2024-06-02 16:59:27,863 [INFO ] Epoch: 13 | perc loss: 0.32893 | psnr: 21.72392 | lpips: 0.22659 | dists: 0.10690
|
| 78 |
+
2024-06-02 16:59:27,866 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 79 |
+
2024-06-02 17:18:09,724 [INFO ] Epoch: 14 | perc loss: 0.33351 | psnr: 21.66716 | lpips: 0.22815 | dists: 0.10899
|
| 80 |
+
2024-06-02 17:18:09,727 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 81 |
+
2024-06-02 17:36:50,797 [INFO ] Epoch: 15 | perc loss: 0.33180 | psnr: 21.96293 | lpips: 0.22651 | dists: 0.10622
|
| 82 |
+
2024-06-02 17:36:50,801 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 83 |
+
2024-06-02 17:55:29,103 [INFO ] Epoch: 16 | perc loss: 0.33189 | psnr: 21.86891 | lpips: 0.22776 | dists: 0.10837
|
| 84 |
+
2024-06-02 17:55:29,107 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 85 |
+
2024-06-02 18:14:11,530 [INFO ] Epoch: 17 | perc loss: 0.33210 | psnr: 21.85856 | lpips: 0.22759 | dists: 0.10800
|
| 86 |
+
2024-06-02 18:14:11,534 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 87 |
+
2024-06-02 18:32:51,595 [INFO ] Epoch: 18 | perc loss: 0.33330 | psnr: 21.89861 | lpips: 0.22754 | dists: 0.10816
|
| 88 |
+
2024-06-02 18:32:51,598 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 89 |
+
2024-06-02 18:51:33,222 [INFO ] Epoch: 19 | perc loss: 0.33425 | psnr: 21.91737 | lpips: 0.22715 | dists: 0.10672
|
| 90 |
+
2024-06-02 18:51:33,226 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 91 |
+
2024-06-02 19:10:20,889 [INFO ] Epoch: 20 | perc loss: 0.33207 | psnr: 21.79711 | lpips: 0.22724 | dists: 0.10738
|
| 92 |
+
2024-06-02 19:10:20,892 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 93 |
+
2024-06-02 19:29:09,240 [INFO ] Epoch: 21 | perc loss: 0.33146 | psnr: 21.83736 | lpips: 0.22665 | dists: 0.10620
|
| 94 |
+
2024-06-02 19:29:09,244 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 95 |
+
2024-06-02 19:47:50,856 [INFO ] Epoch: 22 | perc loss: 0.33375 | psnr: 21.78199 | lpips: 0.22757 | dists: 0.10758
|
| 96 |
+
2024-06-02 19:47:50,860 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 97 |
+
2024-06-02 20:06:27,901 [INFO ] Epoch: 23 | perc loss: 0.33233 | psnr: 21.79922 | lpips: 0.22696 | dists: 0.10653
|
| 98 |
+
2024-06-02 20:06:27,905 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 99 |
+
2024-06-02 20:25:12,269 [INFO ] Epoch: 24 | perc loss: 0.33376 | psnr: 21.88994 | lpips: 0.22836 | dists: 0.10865
|
| 100 |
+
2024-06-02 20:25:12,272 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 101 |
+
2024-06-02 20:43:57,739 [INFO ] Epoch: 25 | perc loss: 0.33358 | psnr: 21.91691 | lpips: 0.22740 | dists: 0.10682
|
| 102 |
+
2024-06-02 20:43:57,743 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 103 |
+
2024-06-02 21:02:38,787 [INFO ] Epoch: 26 | perc loss: 0.33293 | psnr: 21.82560 | lpips: 0.22706 | dists: 0.10724
|
| 104 |
+
2024-06-02 21:02:38,790 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 105 |
+
2024-06-02 21:21:23,150 [INFO ] Epoch: 27 | perc loss: 0.33276 | psnr: 21.93110 | lpips: 0.22654 | dists: 0.10670
|
| 106 |
+
2024-06-02 21:21:23,154 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 107 |
+
2024-06-02 21:40:05,814 [INFO ] Epoch: 28 | perc loss: 0.33448 | psnr: 21.88258 | lpips: 0.22824 | dists: 0.10863
|
| 108 |
+
2024-06-02 21:40:05,818 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 109 |
+
2024-06-02 21:58:50,101 [INFO ] Epoch: 29 | perc loss: 0.33435 | psnr: 21.86732 | lpips: 0.22764 | dists: 0.10790
|
| 110 |
+
2024-06-02 21:58:50,105 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 111 |
+
2024-06-02 22:17:35,636 [INFO ] Epoch: 30 | perc loss: 0.33510 | psnr: 21.83126 | lpips: 0.22859 | dists: 0.10916
|
| 112 |
+
2024-06-02 22:17:35,640 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 113 |
+
2024-06-02 22:36:25,572 [INFO ] Epoch: 31 | perc loss: 0.33503 | psnr: 21.87967 | lpips: 0.22826 | dists: 0.10809
|
| 114 |
+
2024-06-02 22:36:25,576 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 115 |
+
2024-06-02 22:55:03,799 [INFO ] Epoch: 32 | perc loss: 0.33180 | psnr: 21.86306 | lpips: 0.22669 | dists: 0.10689
|
| 116 |
+
2024-06-02 22:55:03,803 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 117 |
+
2024-06-02 23:13:45,201 [INFO ] Epoch: 33 | perc loss: 0.33209 | psnr: 21.90117 | lpips: 0.22692 | dists: 0.10679
|
| 118 |
+
2024-06-02 23:13:45,206 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 119 |
+
2024-06-02 23:32:29,814 [INFO ] Epoch: 34 | perc loss: 0.33378 | psnr: 21.78412 | lpips: 0.22783 | dists: 0.10720
|
| 120 |
+
2024-06-02 23:32:29,818 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 121 |
+
2024-06-02 23:51:14,307 [INFO ] Epoch: 35 | perc loss: 0.33245 | psnr: 21.80858 | lpips: 0.22757 | dists: 0.10710
|
| 122 |
+
2024-06-02 23:51:14,311 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 123 |
+
2024-06-03 00:09:57,617 [INFO ] Epoch: 36 | perc loss: 0.33471 | psnr: 21.78929 | lpips: 0.22853 | dists: 0.10903
|
| 124 |
+
2024-06-03 00:09:57,620 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 125 |
+
2024-06-03 00:28:36,627 [INFO ] Epoch: 37 | perc loss: 0.33349 | psnr: 21.98594 | lpips: 0.22731 | dists: 0.10736
|
| 126 |
+
2024-06-03 00:28:36,630 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 127 |
+
2024-06-03 00:47:20,972 [INFO ] Epoch: 38 | perc loss: 0.33425 | psnr: 21.97987 | lpips: 0.22731 | dists: 0.10701
|
| 128 |
+
2024-06-03 00:47:20,976 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 129 |
+
2024-06-03 01:05:59,529 [INFO ] Epoch: 39 | perc loss: 0.33325 | psnr: 21.69044 | lpips: 0.22760 | dists: 0.10743
|
| 130 |
+
2024-06-03 01:05:59,534 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 131 |
+
2024-06-03 01:24:45,951 [INFO ] Epoch: 40 | perc loss: 0.33306 | psnr: 21.84931 | lpips: 0.22768 | dists: 0.10715
|
| 132 |
+
2024-06-03 01:24:45,955 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 133 |
+
2024-06-03 01:43:29,237 [INFO ] Epoch: 41 | perc loss: 0.33394 | psnr: 21.76097 | lpips: 0.22829 | dists: 0.10813
|
| 134 |
+
2024-06-03 01:43:29,240 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 135 |
+
2024-06-03 02:02:16,998 [INFO ] Epoch: 42 | perc loss: 0.33422 | psnr: 22.03626 | lpips: 0.22767 | dists: 0.10776
|
| 136 |
+
2024-06-03 02:02:17,001 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 137 |
+
2024-06-03 02:20:56,025 [INFO ] Epoch: 43 | perc loss: 0.33415 | psnr: 21.82381 | lpips: 0.22773 | dists: 0.10845
|
| 138 |
+
2024-06-03 02:20:56,028 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 139 |
+
2024-06-03 02:39:40,184 [INFO ] Epoch: 44 | perc loss: 0.33257 | psnr: 21.83755 | lpips: 0.22701 | dists: 0.10675
|
| 140 |
+
2024-06-03 02:39:40,189 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 141 |
+
2024-06-03 02:58:26,226 [INFO ] Epoch: 45 | perc loss: 0.33221 | psnr: 21.75377 | lpips: 0.22706 | dists: 0.10657
|
| 142 |
+
2024-06-03 02:58:26,230 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 143 |
+
2024-06-03 03:17:13,245 [INFO ] Epoch: 46 | perc loss: 0.33469 | psnr: 21.82597 | lpips: 0.22778 | dists: 0.10752
|
| 144 |
+
2024-06-03 03:17:13,248 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 145 |
+
2024-06-03 03:35:52,662 [INFO ] Epoch: 47 | perc loss: 0.33427 | psnr: 21.91075 | lpips: 0.22704 | dists: 0.10703
|
| 146 |
+
2024-06-03 03:35:52,666 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 147 |
+
2024-06-03 03:54:34,629 [INFO ] Epoch: 48 | perc loss: 0.33344 | psnr: 21.73634 | lpips: 0.22771 | dists: 0.10786
|
| 148 |
+
2024-06-03 03:54:34,634 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 149 |
+
2024-06-03 04:13:18,077 [INFO ] Epoch: 49 | perc loss: 0.33254 | psnr: 21.84265 | lpips: 0.22720 | dists: 0.10674
|
| 150 |
+
2024-06-03 04:13:18,081 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 151 |
+
2024-06-03 04:31:57,392 [INFO ] Epoch: 50 | perc loss: 0.33315 | psnr: 21.84459 | lpips: 0.22754 | dists: 0.10744
|
| 152 |
+
2024-06-03 04:31:57,396 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 153 |
+
2024-06-03 04:50:35,589 [INFO ] Epoch: 51 | perc loss: 0.33395 | psnr: 21.88569 | lpips: 0.22835 | dists: 0.10805
|
| 154 |
+
2024-06-03 04:50:35,593 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 155 |
+
2024-06-03 05:09:26,310 [INFO ] Epoch: 52 | perc loss: 0.33427 | psnr: 21.80637 | lpips: 0.22858 | dists: 0.10838
|
| 156 |
+
2024-06-03 05:09:26,314 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 157 |
+
2024-06-03 05:28:16,003 [INFO ] Epoch: 53 | perc loss: 0.33332 | psnr: 21.78990 | lpips: 0.22774 | dists: 0.10752
|
| 158 |
+
2024-06-03 05:28:16,008 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 159 |
+
2024-06-03 05:46:56,485 [INFO ] Epoch: 54 | perc loss: 0.33440 | psnr: 21.92330 | lpips: 0.22701 | dists: 0.10652
|
| 160 |
+
2024-06-03 05:46:56,489 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 161 |
+
2024-06-03 06:05:39,283 [INFO ] Epoch: 55 | perc loss: 0.33324 | psnr: 21.91578 | lpips: 0.22739 | dists: 0.10760
|
| 162 |
+
2024-06-03 06:05:39,287 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 163 |
+
2024-06-03 06:24:16,107 [INFO ] Epoch: 56 | perc loss: 0.33304 | psnr: 21.88298 | lpips: 0.22694 | dists: 0.10703
|
| 164 |
+
2024-06-03 06:24:16,111 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 165 |
+
2024-06-03 06:43:07,089 [INFO ] Epoch: 57 | perc loss: 0.33419 | psnr: 21.84341 | lpips: 0.22722 | dists: 0.10693
|
| 166 |
+
2024-06-03 06:43:07,093 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 167 |
+
2024-06-03 07:01:46,596 [INFO ] Epoch: 58 | perc loss: 0.33342 | psnr: 21.82555 | lpips: 0.22763 | dists: 0.10721
|
| 168 |
+
2024-06-03 07:01:46,600 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 169 |
+
2024-06-03 07:20:46,044 [INFO ] Epoch: 59 | perc loss: 0.33426 | psnr: 21.89526 | lpips: 0.22781 | dists: 0.10761
|
| 170 |
+
2024-06-03 07:20:46,047 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 171 |
+
2024-06-03 07:39:26,255 [INFO ] Epoch: 60 | perc loss: 0.33444 | psnr: 21.91903 | lpips: 0.22711 | dists: 0.10757
|
| 172 |
+
2024-06-03 07:39:26,259 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 173 |
+
2024-06-03 07:58:06,507 [INFO ] Epoch: 61 | perc loss: 0.33431 | psnr: 21.97837 | lpips: 0.22842 | dists: 0.10904
|
| 174 |
+
2024-06-03 07:58:06,512 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 175 |
+
2024-06-03 08:16:49,835 [INFO ] Epoch: 62 | perc loss: 0.33514 | psnr: 21.99198 | lpips: 0.22782 | dists: 0.10722
|
| 176 |
+
2024-06-03 08:16:49,839 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 177 |
+
2024-06-03 08:35:33,361 [INFO ] Epoch: 63 | perc loss: 0.33423 | psnr: 21.85692 | lpips: 0.22768 | dists: 0.10727
|
| 178 |
+
2024-06-03 08:35:33,366 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 179 |
+
2024-06-03 08:54:14,939 [INFO ] Epoch: 64 | perc loss: 0.33443 | psnr: 21.82958 | lpips: 0.22788 | dists: 0.10829
|
| 180 |
+
2024-06-03 08:54:14,942 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 181 |
+
2024-06-03 09:12:53,969 [INFO ] Epoch: 65 | perc loss: 0.33494 | psnr: 21.82883 | lpips: 0.22806 | dists: 0.10793
|
| 182 |
+
2024-06-03 09:12:53,973 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 183 |
+
2024-06-03 09:31:37,870 [INFO ] Epoch: 66 | perc loss: 0.33452 | psnr: 21.88860 | lpips: 0.22774 | dists: 0.10777
|
| 184 |
+
2024-06-03 09:31:37,874 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 185 |
+
2024-06-03 09:50:15,999 [INFO ] Epoch: 67 | perc loss: 0.33397 | psnr: 21.87563 | lpips: 0.22739 | dists: 0.10670
|
| 186 |
+
2024-06-03 09:50:16,002 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 187 |
+
2024-06-03 10:09:03,337 [INFO ] Epoch: 68 | perc loss: 0.33403 | psnr: 21.84018 | lpips: 0.22757 | dists: 0.10754
|
| 188 |
+
2024-06-03 10:09:03,340 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 189 |
+
2024-06-03 10:27:57,411 [INFO ] Epoch: 69 | perc loss: 0.33503 | psnr: 21.81171 | lpips: 0.22727 | dists: 0.10667
|
| 190 |
+
2024-06-03 10:27:57,415 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 191 |
+
2024-06-03 10:46:47,458 [INFO ] Epoch: 70 | perc loss: 0.33324 | psnr: 21.89323 | lpips: 0.22725 | dists: 0.10659
|
| 192 |
+
2024-06-03 10:46:47,462 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 193 |
+
2024-06-03 11:05:28,932 [INFO ] Epoch: 71 | perc loss: 0.33461 | psnr: 21.83094 | lpips: 0.22864 | dists: 0.10749
|
| 194 |
+
2024-06-03 11:05:28,935 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
| 195 |
+
2024-06-03 11:24:26,070 [INFO ] Epoch: 72 | perc loss: 0.33356 | psnr: 21.72290 | lpips: 0.22737 | dists: 0.10574
|
| 196 |
+
2024-06-03 11:24:26,074 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
|
VQ_w_Prompt/checkpoint_69.pth.tar
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c013e63dc8fc03dfa3945c99aff18bd6e0e98c351bc205059bb44bd2e329ed45
|
| 3 |
+
size 281021679
|
__pycache__/model_vq.cpython-38.pyc
ADDED
|
Binary file (3.84 kB). View file
|
|
|
__pycache__/models_mage_codec.cpython-38.pyc
ADDED
|
Binary file (21 kB). View file
|
|
|
__pycache__/models_mage_codec_cls.cpython-38.pyc
ADDED
|
Binary file (21.7 kB). View file
|
|
|
__pycache__/models_mage_codec_full.cpython-38.pyc
ADDED
|
Binary file (22.2 kB). View file
|
|
|
__pycache__/models_mage_codec_high_resolu.cpython-38.pyc
ADDED
|
Binary file (22.4 kB). View file
|
|
|
__pycache__/models_mage_codec_hr_new.cpython-38.pyc
ADDED
|
Binary file (21.8 kB). View file
|
|
|
__pycache__/models_mage_codec_hs_w_p.cpython-38.pyc
ADDED
|
Binary file (24 kB). View file
|
|
|
__pycache__/models_mage_codec_rope.cpython-38.pyc
ADDED
|
Binary file (21.3 kB). View file
|
|
|
cal_upper_bound.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import math
|
| 3 |
+
import sys
|
| 4 |
+
import os
|
| 5 |
+
import time
|
| 6 |
+
import logging
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
from model_vq import Model_VQ
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from omegaconf import OmegaConf
|
| 12 |
+
|
| 13 |
+
import yaml
|
| 14 |
+
from pytorch_msssim import ms_ssim
|
| 15 |
+
from DISTS_pytorch import DISTS
|
| 16 |
+
import lpips
|
| 17 |
+
from torch.nn import functional as F
|
| 18 |
+
from torchvision import utils as vutils
|
| 19 |
+
import numpy as np
|
| 20 |
+
import glob
|
| 21 |
+
|
| 22 |
+
import util.misc as misc
|
| 23 |
+
import PIL.Image as Image
|
| 24 |
+
import torch.backends.cudnn as cudnn
|
| 25 |
+
from pathlib import Path
|
| 26 |
+
import os
|
| 27 |
+
|
| 28 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class CalMetrics(nn.Module):
|
| 32 |
+
"""Calculate BPP, PSNR, MS-SSIM, LPIPS and DISTS for the reconstructed image."""
|
| 33 |
+
|
| 34 |
+
def __init__(self):
|
| 35 |
+
super().__init__()
|
| 36 |
+
self.mse = nn.MSELoss()
|
| 37 |
+
|
| 38 |
+
def psnr(self, rec, ori):
|
| 39 |
+
mse = torch.mean((rec - ori) ** 2)
|
| 40 |
+
if(mse == 0):
|
| 41 |
+
return 100
|
| 42 |
+
max_pixel = 1.
|
| 43 |
+
psnr = 10 * torch.log10(max_pixel / mse)
|
| 44 |
+
return torch.mean(psnr)
|
| 45 |
+
|
| 46 |
+
def lpips_vgg(self, rec, ori):
|
| 47 |
+
loss_fn_vgg = lpips.LPIPS(net='vgg').cuda()
|
| 48 |
+
lipis_vgg = loss_fn_vgg(rec, ori)
|
| 49 |
+
return lipis_vgg
|
| 50 |
+
|
| 51 |
+
def lpips_alex(self, rec, ori):
|
| 52 |
+
loss_fn_alex = lpips.LPIPS(net='alex').cuda()
|
| 53 |
+
lipis_alex = loss_fn_alex(rec, ori)
|
| 54 |
+
return lipis_alex
|
| 55 |
+
|
| 56 |
+
def dists(self, rec, ori):
|
| 57 |
+
D = DISTS().cuda()
|
| 58 |
+
dists_value = D(rec, ori)
|
| 59 |
+
return dists_value
|
| 60 |
+
|
| 61 |
+
def forward(self, ori, rec):
|
| 62 |
+
out = {}
|
| 63 |
+
if rec is not None:
|
| 64 |
+
out["psnr"] = self.psnr(torch.clamp(rec, 0, 1), ori)
|
| 65 |
+
out["lpips_vgg"] = self.lpips_vgg(torch.clamp(rec, 0, 1), ori)
|
| 66 |
+
out["lpips_alex"] = self.lpips_alex(torch.clamp(rec, 0, 1), ori)
|
| 67 |
+
out["dists"] = self.dists(torch.clamp(rec, 0, 1), ori)
|
| 68 |
+
return out
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class AverageMeter:
|
| 72 |
+
"""Compute running average."""
|
| 73 |
+
|
| 74 |
+
def __init__(self):
|
| 75 |
+
self.val = 0
|
| 76 |
+
self.avg = 0
|
| 77 |
+
self.sum = 0
|
| 78 |
+
self.count = 0
|
| 79 |
+
|
| 80 |
+
def update(self, val, n=1):
|
| 81 |
+
self.val = val
|
| 82 |
+
self.sum += val * n
|
| 83 |
+
self.count += n
|
| 84 |
+
self.avg = self.sum / self.count
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class CustomDataParallel(nn.DataParallel):
|
| 88 |
+
"""Custom DataParallel to access the module methods."""
|
| 89 |
+
|
| 90 |
+
def __getattr__(self, key):
|
| 91 |
+
try:
|
| 92 |
+
return super().__getattr__(key)
|
| 93 |
+
except AttributeError:
|
| 94 |
+
return getattr(self.module, key)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def init(args):
|
| 98 |
+
base_dir = f'{args.root}/{args.exp_name}/'
|
| 99 |
+
os.makedirs(base_dir, exist_ok=True)
|
| 100 |
+
return base_dir
|
| 101 |
+
|
| 102 |
+
def setup_logger(log_dir):
|
| 103 |
+
log_formatter = logging.Formatter("%(asctime)s [%(levelname)-5.5s] %(message)s")
|
| 104 |
+
root_logger = logging.getLogger()
|
| 105 |
+
root_logger.setLevel(logging.INFO)
|
| 106 |
+
|
| 107 |
+
log_file_handler = logging.FileHandler(log_dir, encoding='utf-8')
|
| 108 |
+
log_file_handler.setFormatter(log_formatter)
|
| 109 |
+
root_logger.addHandler(log_file_handler)
|
| 110 |
+
|
| 111 |
+
log_stream_handler = logging.StreamHandler(sys.stdout)
|
| 112 |
+
log_stream_handler.setFormatter(log_formatter)
|
| 113 |
+
root_logger.addHandler(log_stream_handler)
|
| 114 |
+
|
| 115 |
+
logging.info('Logging file is %s' % log_dir)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def load_img(p, padding=True, factor=64):
|
| 119 |
+
x = Image.open(p)
|
| 120 |
+
x = torch.from_numpy(np.asarray(x))
|
| 121 |
+
if len(x.shape) == 2:
|
| 122 |
+
x = x.unsqueeze(-1).repeat(1, 1, 3) # h,w -> h,w,3
|
| 123 |
+
x = x.permute(2, 0, 1).unsqueeze(0).float().div(255)
|
| 124 |
+
h, w = x.shape[2:4]
|
| 125 |
+
|
| 126 |
+
if padding:
|
| 127 |
+
dh = factor * math.ceil(h / factor) - h
|
| 128 |
+
dw = factor * math.ceil(w / factor) - w
|
| 129 |
+
# 均匀添加padding
|
| 130 |
+
dh_half = dh // 2
|
| 131 |
+
dw_half = dw // 2
|
| 132 |
+
dh_extra = dh % 2
|
| 133 |
+
dw_extra = dw % 2
|
| 134 |
+
x = F.pad(x, (dw_half, dw_half + dw_extra, dh_half, dh_half + dh_extra))
|
| 135 |
+
return x, h, w
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def save_img(img: torch.Tensor, vis_path, input_p, rec=False):
|
| 140 |
+
img = img.clone().detach()
|
| 141 |
+
img = img.to(torch.device('cpu'))
|
| 142 |
+
if os.path.isdir(vis_path) is not True:
|
| 143 |
+
os.makedirs(vis_path)
|
| 144 |
+
end = '/'
|
| 145 |
+
if rec:
|
| 146 |
+
vis_path = vis_path + '/rec'
|
| 147 |
+
if os.path.isdir(vis_path) is not True:
|
| 148 |
+
os.makedirs(vis_path)
|
| 149 |
+
img_name = vis_path + str(input_p[input_p.rfind(end):])
|
| 150 |
+
else:
|
| 151 |
+
img_name = vis_path + str(input_p[input_p.rfind(end):])
|
| 152 |
+
vutils.save_image(img, os.path.join(img_name), nrow=8)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def inference(epoch, eval_path, model, metrics_criterion, device, stage='test'):
|
| 156 |
+
model.eval()
|
| 157 |
+
psnr = AverageMeter()
|
| 158 |
+
lpips_vgg = AverageMeter()
|
| 159 |
+
lpips_alex = AverageMeter()
|
| 160 |
+
dists = AverageMeter()
|
| 161 |
+
|
| 162 |
+
vis_path = os.path.join("./VQGAN/", stage)
|
| 163 |
+
os.makedirs(vis_path, exist_ok=True)
|
| 164 |
+
|
| 165 |
+
with torch.no_grad():
|
| 166 |
+
for input_p in eval_path:
|
| 167 |
+
x, hx, wx = load_img(input_p, padding=True, factor=64)
|
| 168 |
+
x = x.to(device)
|
| 169 |
+
rec = model(x)
|
| 170 |
+
|
| 171 |
+
x = x[:, :, :hx, :wx]
|
| 172 |
+
rec = rec[:, :, :hx, :wx]
|
| 173 |
+
rec = rec.to(device)
|
| 174 |
+
out_criterion = metrics_criterion(x, rec)
|
| 175 |
+
|
| 176 |
+
psnr.update(out_criterion['psnr'])
|
| 177 |
+
lpips_vgg.update(out_criterion['lpips_vgg'])
|
| 178 |
+
lpips_alex.update(out_criterion['lpips_alex'])
|
| 179 |
+
dists.update(out_criterion['dists'])
|
| 180 |
+
|
| 181 |
+
## ======================= update progress bar & visualization ======================= ##
|
| 182 |
+
# save_img(x, vis_path, input_p)
|
| 183 |
+
save_img(rec, vis_path, input_p, rec=True)
|
| 184 |
+
|
| 185 |
+
model.train()
|
| 186 |
+
log_txt = f"{epoch}|psnr:{psnr.avg:.5f}|lpips_vgg:{lpips_vgg.avg.mean().item():.5f}|lpips_alex:{lpips_alex.avg.mean().item():.5f}|dists:{dists.avg.mean().item():.5f}"
|
| 187 |
+
logging.info(log_txt)
|
| 188 |
+
return psnr
|
| 189 |
+
|
| 190 |
+
def parse_args(argv):
|
| 191 |
+
parser = argparse.ArgumentParser(description="Example training script.")
|
| 192 |
+
parser.add_argument(
|
| 193 |
+
"-c",
|
| 194 |
+
"--config",
|
| 195 |
+
default="/home/t2vg-a100-G4-10/project/qyp/mimc_rope/config/cal_upper_bound.yaml",
|
| 196 |
+
help="Path to config file",
|
| 197 |
+
)
|
| 198 |
+
parser.add_argument(
|
| 199 |
+
'--name',
|
| 200 |
+
default=datetime.now().strftime('%Y-%m-%d_%H_%M_%S'),
|
| 201 |
+
type=str,
|
| 202 |
+
help='Result dir name',
|
| 203 |
+
)
|
| 204 |
+
parser.add_argument(
|
| 205 |
+
'--eval_path',
|
| 206 |
+
default='/home/t2vg-a100-G4-10/project/qyp/datasets/COCO/val2017',
|
| 207 |
+
type=str,
|
| 208 |
+
help='path to the evaluation dataset',
|
| 209 |
+
)
|
| 210 |
+
parser.add_argument('--lr', type=float, default=None, metavar='LR',
|
| 211 |
+
help='learning rate (absolute lr)')
|
| 212 |
+
given_configs, remaining = parser.parse_known_args(argv)
|
| 213 |
+
# distributed training parameters
|
| 214 |
+
parser.add_argument('--world_size', default=1, type=int,
|
| 215 |
+
help='number of distributed processes')
|
| 216 |
+
parser.add_argument('--local-rank', default=-1, type=int)
|
| 217 |
+
parser.add_argument('--dist_on_itp', action='store_true')
|
| 218 |
+
parser.add_argument('--dist_url', default='env://',
|
| 219 |
+
help='url used to set up distributed training')
|
| 220 |
+
with open(given_configs.config) as file:
|
| 221 |
+
yaml_data= yaml.safe_load(file)
|
| 222 |
+
parser.set_defaults(**yaml_data)
|
| 223 |
+
|
| 224 |
+
parser.add_argument(
|
| 225 |
+
"-T",
|
| 226 |
+
"--TEST",
|
| 227 |
+
action='store_true',
|
| 228 |
+
help='Testing'
|
| 229 |
+
)
|
| 230 |
+
args = parser.parse_args(remaining)
|
| 231 |
+
return args
|
| 232 |
+
|
| 233 |
+
def load_eval_ps(eval_path):
|
| 234 |
+
eval_ps = sorted(glob.glob(os.path.join(eval_path, '*.jpg')))
|
| 235 |
+
return eval_ps
|
| 236 |
+
|
| 237 |
+
def main(argv):
|
| 238 |
+
args = parse_args(argv)
|
| 239 |
+
base_dir = init(args) # create the base dir for saving the results
|
| 240 |
+
|
| 241 |
+
if args.output_dir:
|
| 242 |
+
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
| 243 |
+
args.log_dir = args.output_dir
|
| 244 |
+
|
| 245 |
+
misc.init_distributed_mode(args)
|
| 246 |
+
|
| 247 |
+
print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
|
| 248 |
+
print("{}".format(args).replace(', ', ',\n'))
|
| 249 |
+
|
| 250 |
+
device = torch.device(args.device)
|
| 251 |
+
# fix the seed for reproducibility
|
| 252 |
+
seed = args.seed + misc.get_rank()
|
| 253 |
+
torch.manual_seed(seed)
|
| 254 |
+
np.random.seed(seed)
|
| 255 |
+
|
| 256 |
+
cudnn.benchmark = True
|
| 257 |
+
|
| 258 |
+
setup_logger(base_dir + '/' + time.strftime('%Y%m%d_%H%M%S') + '.log')
|
| 259 |
+
msg = f'======================= {args.name} ======================='
|
| 260 |
+
logging.info(msg)
|
| 261 |
+
for k in args.__dict__:
|
| 262 |
+
logging.info(k + ':' + str(args.__dict__[k]))
|
| 263 |
+
logging.info('=' * len(msg))
|
| 264 |
+
|
| 265 |
+
## ======================= prepare dataset ======================= ##
|
| 266 |
+
eval_path = sorted(glob.glob(os.path.join(args.eval_path, '*.jpg')))
|
| 267 |
+
|
| 268 |
+
device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"
|
| 269 |
+
|
| 270 |
+
## ======================= prepare model ======================= ##
|
| 271 |
+
vqgan_ckpt_path = '/home/t2vg-a100-G4-10/project/qyp/mage/vqgan_jax_strongaug.ckpt'
|
| 272 |
+
config = OmegaConf.load('config/vqgan.yaml').model
|
| 273 |
+
model = Model_VQ(ddconfig=config.params.ddconfig,
|
| 274 |
+
n_embed=config.params.n_embed, # 1024
|
| 275 |
+
embed_dim=config.params.embed_dim, # 256
|
| 276 |
+
ckpt_path=vqgan_ckpt_path)
|
| 277 |
+
|
| 278 |
+
model.to(device)
|
| 279 |
+
|
| 280 |
+
metrics_criterion = CalMetrics()
|
| 281 |
+
## ======================= pre validation ======================= ##
|
| 282 |
+
test_loss = inference(-1, eval_path, model, metrics_criterion, device, 'val')
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
if __name__ == "__main__":
|
| 287 |
+
main(sys.argv[1:])
|
ckpt_pretrained/vgg.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868
|
| 3 |
+
size 7289
|
coco_eval.py
ADDED
|
File without changes
|
config/cal_upper_bound.yaml
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# eval_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/kodak'
|
| 2 |
+
eval_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/COCO/val2017'
|
| 3 |
+
# eval_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/shark_new/'
|
| 4 |
+
device: "cuda"
|
| 5 |
+
seed: 0
|
| 6 |
+
save: True
|
| 7 |
+
name: "VQ"
|
| 8 |
+
exp_name: "VQ"
|
| 9 |
+
root: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope"
|
| 10 |
+
output_dir: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ" # path where to save, empty for no saving
|
| 11 |
+
log_dir: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ/logs" # path where to tensorboard log
|
| 12 |
+
strart_epoch: 0
|
| 13 |
+
cuda: True
|
config/eval copy.yaml
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model: mage_vit_base_patch16
|
| 2 |
+
dataset: "coco"
|
| 3 |
+
dataset_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/COCO'
|
| 4 |
+
# kodak_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/test/'
|
| 5 |
+
kodak_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/kodak/'
|
| 6 |
+
eval_path: ''
|
| 7 |
+
batch_size: 128 # Batch size per GPU
|
| 8 |
+
learning_rate: 1.e-4
|
| 9 |
+
test_batch_size: 1
|
| 10 |
+
num_workers: 4
|
| 11 |
+
epochs: 100
|
| 12 |
+
blr: 1.5e-4 # 1.5e-4 # base learning rate: absolute_lr = base_lr * total_batch_size / 256
|
| 13 |
+
min_lr: 5.e-6
|
| 14 |
+
accum_iter: 1 # Accumulate gradient iterations (for increasing the effective batch size under memory constraints)
|
| 15 |
+
input_size: 256
|
| 16 |
+
weight_decay: 0.05
|
| 17 |
+
warmup_epochs: 10
|
| 18 |
+
mask_ratio_min: 0.5
|
| 19 |
+
mask_ratio_max: 0.9
|
| 20 |
+
mask_ratio_mu: 0.55
|
| 21 |
+
mask_ratio_std: 0.25
|
| 22 |
+
grad_clip: 3.0
|
| 23 |
+
device: "cuda"
|
| 24 |
+
distributed: True
|
| 25 |
+
seed: 0
|
| 26 |
+
gpu_id: 0,1
|
| 27 |
+
resume: '/home/t2vg-a100-G4-10/project/qyp/lossless/checkpoint-330.pth' # Path to the checkpoint to resume from
|
| 28 |
+
# resume: '/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr/checkpoint-399.pth'
|
| 29 |
+
# resume: '/home/v-ruoyufeng/v-ruoyufeng/qyp/mimc_rope/random_with_rope.pth'
|
| 30 |
+
maskrcnn_path: '/home/t2vg-a100-G4-10/project/qyp/ckpt/detectron2/mask_rcnn_r50_fpn_3x.pkl'
|
| 31 |
+
start_epoch: 0
|
| 32 |
+
pin_mem: True
|
| 33 |
+
save: True
|
| 34 |
+
name: "exp"
|
| 35 |
+
exp_name: "VQ_w_Prompt"
|
| 36 |
+
root: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope"
|
| 37 |
+
output_dir: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt" # path where to save, empty for no saving
|
| 38 |
+
log_dir: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/logs" # path where to tensorboard log
|
| 39 |
+
strart_epoch: 0
|
| 40 |
+
cuda: True
|
| 41 |
+
TEST: True
|
config/eval.yaml
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model: mage_vit_base_patch16
|
| 2 |
+
dataset: "coco"
|
| 3 |
+
dataset_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/COCO'
|
| 4 |
+
# kodak_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/test/'
|
| 5 |
+
kodak_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/kodak/'
|
| 6 |
+
eval_path: ''
|
| 7 |
+
batch_size: 128 # Batch size per GPU
|
| 8 |
+
learning_rate: 1.e-4
|
| 9 |
+
test_batch_size: 1
|
| 10 |
+
num_workers: 4
|
| 11 |
+
epochs: 100
|
| 12 |
+
blr: 1.5e-4 # 1.5e-4 # base learning rate: absolute_lr = base_lr * total_batch_size / 256
|
| 13 |
+
min_lr: 5.e-6
|
| 14 |
+
accum_iter: 1 # Accumulate gradient iterations (for increasing the effective batch size under memory constraints)
|
| 15 |
+
input_size: 256
|
| 16 |
+
weight_decay: 0.05
|
| 17 |
+
warmup_epochs: 10
|
| 18 |
+
mask_ratio_min: 0.5
|
| 19 |
+
mask_ratio_max: 0.9
|
| 20 |
+
mask_ratio_mu: 0.55
|
| 21 |
+
mask_ratio_std: 0.25
|
| 22 |
+
grad_clip: 3.0
|
| 23 |
+
device: "cuda"
|
| 24 |
+
distributed: True
|
| 25 |
+
seed: 0
|
| 26 |
+
gpu_id: 0,1
|
| 27 |
+
resume: '/home/t2vg-a100-G4-10/project/qyp/lossless/checkpoint-330.pth' # Path to the checkpoint to resume from
|
| 28 |
+
# resume: '/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr/checkpoint-399.pth'
|
| 29 |
+
# resume: '/home/v-ruoyufeng/v-ruoyufeng/qyp/mimc_rope/random_with_rope.pth'
|
| 30 |
+
maskrcnn_path: '/home/t2vg-a100-G4-10/project/qyp/ckpt/detectron2/mask_rcnn_r50_fpn_3x.pkl'
|
| 31 |
+
start_epoch: 0
|
| 32 |
+
pin_mem: True
|
| 33 |
+
save: True
|
| 34 |
+
name: "exp"
|
| 35 |
+
exp_name: "VQ_w_Prompt"
|
| 36 |
+
root: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope"
|
| 37 |
+
output_dir: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt" # path where to save, empty for no saving
|
| 38 |
+
log_dir: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/logs" # path where to tensorboard log
|
| 39 |
+
strart_epoch: 0
|
| 40 |
+
cuda: True
|
| 41 |
+
TEST: True
|
config/eval_config.yaml
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model: mage_vit_base_patch16
|
| 2 |
+
dataset: "imagenet"
|
| 3 |
+
dataset_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/imagenet'
|
| 4 |
+
kodak_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/kodak/'
|
| 5 |
+
eval_path: ''
|
| 6 |
+
batch_size: 64 # Batch size per GPU
|
| 7 |
+
test_batch_size: 200
|
| 8 |
+
num_workers: 4
|
| 9 |
+
epochs: 400
|
| 10 |
+
blr: 2.e-4 # 1.5e-4 # base learning rate: absolute_lr = base_lr * total_batch_size / 256
|
| 11 |
+
min_lr: 5.e-6
|
| 12 |
+
accum_iter: 1 # Accumulate gradient iterations (for increasing the effective batch size under memory constraints)
|
| 13 |
+
input_size: 256
|
| 14 |
+
weight_decay: 0.05
|
| 15 |
+
warmup_epochs: 20
|
| 16 |
+
mask_ratio_min: 0.5
|
| 17 |
+
mask_ratio_max: 0.9
|
| 18 |
+
mask_ratio_mu: 0.55
|
| 19 |
+
mask_ratio_std: 0.25
|
| 20 |
+
grad_clip: 3.0
|
| 21 |
+
device: "cuda"
|
| 22 |
+
distributed: True
|
| 23 |
+
seed: 0
|
| 24 |
+
resume: '/home/t2vg-a100-G4-10/project/qyp/lossless/checkpoint-330.pth' # Path to the checkpoint to resume from
|
| 25 |
+
start_epoch: 0
|
| 26 |
+
pin_mem: True
|
| 27 |
+
save: True
|
| 28 |
+
name: "exp"
|
| 29 |
+
exp_name: "MIM_vbr_eval_lossless"
|
| 30 |
+
root: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope"
|
| 31 |
+
output_dir: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr_eval_lossless" # path where to save, empty for no saving
|
| 32 |
+
log_dir: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr_eval_lossless/logs_eval" # path where to tensorboard log
|
| 33 |
+
strart_epoch: 0
|
| 34 |
+
cuda: True
|
config/eval_high_res.yaml
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model: mage_vit_base_patch16
|
| 2 |
+
dataset: "coco"
|
| 3 |
+
dataset_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/COCO'
|
| 4 |
+
kodak_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/test/'
|
| 5 |
+
# kodak_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/kodak/'
|
| 6 |
+
eval_path: ''
|
| 7 |
+
batch_size: 128 # Batch size per GPU
|
| 8 |
+
learning_rate: 1.e-4
|
| 9 |
+
test_batch_size: 1
|
| 10 |
+
num_workers: 4
|
| 11 |
+
epochs: 100
|
| 12 |
+
blr: 1.5e-4 # 1.5e-4 # base learning rate: absolute_lr = base_lr * total_batch_size / 256
|
| 13 |
+
min_lr: 5.e-6
|
| 14 |
+
accum_iter: 1 # Accumulate gradient iterations (for increasing the effective batch size under memory constraints)
|
| 15 |
+
input_size: 256
|
| 16 |
+
weight_decay: 0.05
|
| 17 |
+
warmup_epochs: 10
|
| 18 |
+
mask_ratio_min: 0.5
|
| 19 |
+
mask_ratio_max: 0.9
|
| 20 |
+
mask_ratio_mu: 0.55
|
| 21 |
+
mask_ratio_std: 0.25
|
| 22 |
+
grad_clip: 3.0
|
| 23 |
+
device: "cuda"
|
| 24 |
+
distributed: True
|
| 25 |
+
seed: 0
|
| 26 |
+
gpu_id: 0,1
|
| 27 |
+
resume: '/home/t2vg-a100-G4-10/project/qyp/lossless/checkpoint-330.pth' # Path to the checkpoint to resume from
|
| 28 |
+
# resume: '/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr/checkpoint-399.pth'
|
| 29 |
+
# resume: '/home/v-ruoyufeng/v-ruoyufeng/qyp/mimc_rope/random_with_rope.pth'
|
| 30 |
+
maskrcnn_path: '/home/t2vg-a100-G4-10/project/qyp/ckpt/detectron2/mask_rcnn_r50_fpn_3x.pkl'
|
| 31 |
+
start_epoch: 0
|
| 32 |
+
pin_mem: True
|
| 33 |
+
save: True
|
| 34 |
+
name: "exp"
|
| 35 |
+
exp_name: "Test_High_Res"
|
| 36 |
+
root: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope"
|
| 37 |
+
output_dir: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/Test_High_Res" # path where to save, empty for no saving
|
| 38 |
+
log_dir: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/Test_High_Res/logs" # path where to tensorboard log
|
| 39 |
+
strart_epoch: 0
|
| 40 |
+
cuda: True
|
| 41 |
+
TEST: True
|
config/eval_vq.yaml
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model: mage_vit_base_patch16
|
| 2 |
+
dataset: "coco"
|
| 3 |
+
dataset_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/COCO'
|
| 4 |
+
# kodak_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/test/'
|
| 5 |
+
kodak_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/COCO/val2017/'
|
| 6 |
+
eval_path: ''
|
| 7 |
+
batch_size: 32 # Batch size per GPU
|
| 8 |
+
test_batch_size: 1
|
| 9 |
+
num_workers: 4
|
| 10 |
+
epochs: 400
|
| 11 |
+
blr: 1.5e-4 # 1.5e-4 # base learning rate: absolute_lr = base_lr * total_batch_size / 256
|
| 12 |
+
min_lr: 5.e-6
|
| 13 |
+
accum_iter: 1 # Accumulate gradient iterations (for increasing the effective batch size under memory constraints)
|
| 14 |
+
input_size: 256
|
| 15 |
+
weight_decay: 0.05
|
| 16 |
+
warmup_epochs: 20
|
| 17 |
+
mask_ratio_min: 0.5
|
| 18 |
+
mask_ratio_max: 0.9
|
| 19 |
+
mask_ratio_mu: 0.55
|
| 20 |
+
mask_ratio_std: 0.25
|
| 21 |
+
grad_clip: 3.0
|
| 22 |
+
device: "cuda"
|
| 23 |
+
distributed: True
|
| 24 |
+
seed: 0
|
| 25 |
+
resume: '/home/t2vg-a100-G4-10/project/qyp/lossless/checkpoint-330.pth' # Path to the checkpoint to resume from
|
| 26 |
+
# resume: '/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr/checkpoint-399.pth'
|
| 27 |
+
# resume: '/home/v-ruoyufeng/v-ruoyufeng/qyp/mimc_rope/random_with_rope.pth'
|
| 28 |
+
start_epoch: 0
|
| 29 |
+
pin_mem: True
|
| 30 |
+
save: True
|
| 31 |
+
name: "exp"
|
| 32 |
+
exp_name: "VQGAN"
|
| 33 |
+
root: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope"
|
| 34 |
+
output_dir: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQGAN" # path where to save, empty for no saving
|
| 35 |
+
log_dir: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQGAN/logs" # path where to tensorboard log
|
| 36 |
+
strart_epoch: 0
|
| 37 |
+
cuda: True
|
| 38 |
+
TEST: True
|
config/fine_tune_vq_w_mask.yaml
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model: mage_vit_base_patch16
|
| 2 |
+
dataset: "coco"
|
| 3 |
+
dataset_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/COCO'
|
| 4 |
+
kodak_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/test/'
|
| 5 |
+
# kodak_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/kodak/'
|
| 6 |
+
eval_path: ''
|
| 7 |
+
batch_size: 64 # Batch size per GPU
|
| 8 |
+
learning_rate: 1.e-4
|
| 9 |
+
test_batch_size: 1
|
| 10 |
+
num_workers: 4
|
| 11 |
+
epochs: 100
|
| 12 |
+
blr: 1.e-4 # 1.5e-4 # base learning rate: absolute_lr = base_lr * total_batch_size / 256
|
| 13 |
+
min_lr: 5.e-6
|
| 14 |
+
accum_iter: 1 # Accumulate gradient iterations (for increasing the effective batch size under memory constraints)
|
| 15 |
+
input_size: 256
|
| 16 |
+
weight_decay: 0.05
|
| 17 |
+
warmup_epochs: 10
|
| 18 |
+
mask_ratio_min: 0.25
|
| 19 |
+
mask_ratio_max: 0.75
|
| 20 |
+
grad_clip: 3.0
|
| 21 |
+
device: "cuda"
|
| 22 |
+
distributed: True
|
| 23 |
+
seed: 0
|
| 24 |
+
gpu_id: 2
|
| 25 |
+
resume: '/home/t2vg-a100-G4-10/project/qyp/lossless/checkpoint-330.pth' # Path to the checkpoint to resume from
|
| 26 |
+
# resume: '/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr/checkpoint-399.pth'
|
| 27 |
+
# resume: '/home/v-ruoyufeng/v-ruoyufeng/qyp/mimc_rope/random_with_rope.pth'
|
| 28 |
+
maskrcnn_path: '/home/t2vg-a100-G4-10/project/qyp/ckpt/detectron2/mask_rcnn_r50_fpn_3x.pkl'
|
| 29 |
+
start_epoch: 0
|
| 30 |
+
pin_mem: True
|
| 31 |
+
save: True
|
| 32 |
+
name: "exp"
|
| 33 |
+
exp_name: "Fine_Tune_VQ_w_P"
|
| 34 |
+
root: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope"
|
| 35 |
+
output_dir: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/Fine_Tune_VQ_w_P" # path where to save, empty for no saving
|
| 36 |
+
log_dir: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/Fine_Tune_VQ_w_P/logs" # path where to tensorboard log
|
| 37 |
+
strart_epoch: 0
|
| 38 |
+
cuda: True
|
| 39 |
+
TEST: True
|
config/train_det.yaml
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model: mage_vit_base_patch16
|
| 2 |
+
dataset: "coco"
|
| 3 |
+
dataset_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/COCO'
|
| 4 |
+
kodak_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/kodak/'
|
| 5 |
+
# kodak_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/test/'
|
| 6 |
+
eval_path: ''
|
| 7 |
+
batch_size: 32 # Batch size per GPU
|
| 8 |
+
test_batch_size: 1
|
| 9 |
+
num_workers: 4
|
| 10 |
+
epochs: 400
|
| 11 |
+
blr: 1.5e-4 # 1.5e-4 # base learning rate: absolute_lr = base_lr * total_batch_size / 256
|
| 12 |
+
min_lr: 5.e-6
|
| 13 |
+
accum_iter: 1 # Accumulate gradient iterations (for increasing the effective batch size under memory constraints)
|
| 14 |
+
input_size: 256
|
| 15 |
+
weight_decay: 0.05
|
| 16 |
+
warmup_epochs: 20
|
| 17 |
+
mask_ratio_min: 0.5
|
| 18 |
+
mask_ratio_max: 0.9
|
| 19 |
+
mask_ratio_mu: 0.55
|
| 20 |
+
mask_ratio_std: 0.25
|
| 21 |
+
grad_clip: 3.0
|
| 22 |
+
device: "cuda"
|
| 23 |
+
distributed: True
|
| 24 |
+
seed: 0
|
| 25 |
+
resume: '/home/t2vg-a100-G4-10/project/qyp/lossless/checkpoint-330.pth' # Path to the checkpoint to resume from
|
| 26 |
+
# resume: '/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr/checkpoint-399.pth'
|
| 27 |
+
# resume: '/home/v-ruoyufeng/v-ruoyufeng/qyp/mimc_rope/random_with_rope.pth'
|
| 28 |
+
start_epoch: 0
|
| 29 |
+
pin_mem: True
|
| 30 |
+
save: True
|
| 31 |
+
name: "exp"
|
| 32 |
+
exp_name: "MIM_vbr_kodak"
|
| 33 |
+
root: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope"
|
| 34 |
+
output_dir: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr_kodak" # path where to save, empty for no saving
|
| 35 |
+
log_dir: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr_kodak/logs" # path where to tensorboard log
|
| 36 |
+
strart_epoch: 0
|
| 37 |
+
cuda: True
|
| 38 |
+
TEST: True
|
config/train_im_cls.yaml
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model: mage_vit_base_patch16
|
| 2 |
+
dataset: "imagenet"
|
| 3 |
+
dataset_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/imagenet'
|
| 4 |
+
eval_path: ''
|
| 5 |
+
batch_size: 256 # Batch size per GPU
|
| 6 |
+
test_batch_size: 64
|
| 7 |
+
num_workers: 4
|
| 8 |
+
epochs: 100
|
| 9 |
+
learning_rate: 1.e-4
|
| 10 |
+
blr: 1.5e-4 # 1.5e-4 # base learning rate: absolute_lr = base_lr * total_batch_size / 256
|
| 11 |
+
min_lr: 5.e-6
|
| 12 |
+
accum_iter: 1 # Accumulate gradient iterations (for increasing the effective batch size under memory constraints)
|
| 13 |
+
input_size: 256
|
| 14 |
+
weight_decay: 0.05
|
| 15 |
+
warmup_epochs: 10
|
| 16 |
+
mask_ratio_min: 0.3
|
| 17 |
+
mask_ratio_max: 0.75
|
| 18 |
+
grad_clip: 3.0
|
| 19 |
+
device: "cuda"
|
| 20 |
+
distributed: True
|
| 21 |
+
seed: 0
|
| 22 |
+
resume: '/home/t2vg-a100-G4-10/project/qyp/mimc_rope/train_im_cls/checkpoint-lastbest.pth' # Path to the checkpoint to resume from
|
| 23 |
+
start_epoch: 0
|
| 24 |
+
pin_mem: True
|
| 25 |
+
save: True
|
| 26 |
+
name: "exp"
|
| 27 |
+
exp_name: "train_im_cls"
|
| 28 |
+
root: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope"
|
| 29 |
+
output_dir: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/train_im_cls" # path where to save, empty for no saving
|
| 30 |
+
log_dir: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/train_im_cls/logs" # path where to tensorboard log
|
| 31 |
+
strart_epoch: 0
|
| 32 |
+
cuda: True
|
config/train_mim_config.yaml
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model: mage_vit_base_patch16
|
| 2 |
+
dataset: "imagenet"
|
| 3 |
+
dataset_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/imagenet'
|
| 4 |
+
eval_path: ''
|
| 5 |
+
batch_size: 64 # Batch size per GPU
|
| 6 |
+
test_batch_size: 64
|
| 7 |
+
num_workers: 4
|
| 8 |
+
epochs: 400
|
| 9 |
+
blr: 1.5e-4 # 1.5e-4 # base learning rate: absolute_lr = base_lr * total_batch_size / 256
|
| 10 |
+
min_lr: 5.e-6
|
| 11 |
+
accum_iter: 1 # Accumulate gradient iterations (for increasing the effective batch size under memory constraints)
|
| 12 |
+
input_size: 256
|
| 13 |
+
weight_decay: 0.05
|
| 14 |
+
warmup_epochs: 20
|
| 15 |
+
mask_ratio_min: 0.5
|
| 16 |
+
mask_ratio_max: 0.9
|
| 17 |
+
mask_ratio_mu: 0.55
|
| 18 |
+
mask_ratio_std: 0.25
|
| 19 |
+
grad_clip: 3.0
|
| 20 |
+
device: "cuda"
|
| 21 |
+
distributed: True
|
| 22 |
+
seed: 0
|
| 23 |
+
resume: '/home/t2vg-a100-G4-10/project/qyp/mimc_rope/random_with_rope.pth' # Path to the checkpoint to resume from
|
| 24 |
+
start_epoch: 0
|
| 25 |
+
pin_mem: True
|
| 26 |
+
save: True
|
| 27 |
+
name: "exp"
|
| 28 |
+
exp_name: "MIM_vbr"
|
| 29 |
+
root: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope"
|
| 30 |
+
output_dir: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr" # path where to save, empty for no saving
|
| 31 |
+
log_dir: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr/logs" # path where to tensorboard log
|
| 32 |
+
strart_epoch: 0
|
| 33 |
+
cuda: True
|
config/vqgan.yaml
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model:
|
| 2 |
+
target: taming.models.vqgan.VQModel
|
| 3 |
+
params:
|
| 4 |
+
embed_dim: 256
|
| 5 |
+
n_embed: 1024
|
| 6 |
+
ddconfig:
|
| 7 |
+
double_z: False
|
| 8 |
+
z_channels: 256
|
| 9 |
+
resolution: 256
|
| 10 |
+
in_channels: 3
|
| 11 |
+
out_ch: 3
|
| 12 |
+
ch: 128
|
| 13 |
+
ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1
|
| 14 |
+
num_res_blocks: 2
|
| 15 |
+
attn_resolutions: [16]
|
| 16 |
+
dropout: 0.0
|
deploited/models_mage_codec_pos_linear.py
ADDED
|
@@ -0,0 +1,767 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import partial
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
from timm.models.vision_transformer import PatchEmbed, DropPath, Mlp
|
| 7 |
+
|
| 8 |
+
from util.pos_embed import get_2d_sincos_pos_embed
|
| 9 |
+
|
| 10 |
+
from taming.models.vqgan import VQModel
|
| 11 |
+
from omegaconf import OmegaConf
|
| 12 |
+
import numpy as np
|
| 13 |
+
import scipy.stats as stats
|
| 14 |
+
from compressai.entropy_models import EntropyBottleneck
|
| 15 |
+
from compressai.layers import conv3x3, subpel_conv3x3
|
| 16 |
+
import math
|
| 17 |
+
from torch import Tensor
|
| 18 |
+
from einops import rearrange, repeat
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
import torchac
|
| 21 |
+
from typing import Any, Callable, List, Optional, Tuple, Union
|
| 22 |
+
|
| 23 |
+
SCALES_MIN = 0.11
|
| 24 |
+
SCALES_MAX = 256
|
| 25 |
+
SCALES_LEVELS = 64
|
| 26 |
+
def get_scale_table(min=SCALES_MIN, max=SCALES_MAX, levels=SCALES_LEVELS):
|
| 27 |
+
return torch.exp(torch.linspace(math.log(min), math.log(max), levels))
|
| 28 |
+
|
| 29 |
+
def ste_round(x: Tensor) -> Tensor:
|
| 30 |
+
return torch.round(x) - x.detach() + x
|
| 31 |
+
|
| 32 |
+
def conv(in_channels, out_channels, kernel_size=5, stride=2):
|
| 33 |
+
return nn.Conv2d(
|
| 34 |
+
in_channels,
|
| 35 |
+
out_channels,
|
| 36 |
+
kernel_size=kernel_size,
|
| 37 |
+
stride=stride,
|
| 38 |
+
padding=kernel_size // 2,
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
def mask_by_random_topk(mask_len, probs, temperature=1.0):
|
| 42 |
+
mask_len = mask_len.squeeze()
|
| 43 |
+
# 使用Gumbel分布进行采样,增加随机性
|
| 44 |
+
confidence = torch.log(probs) + torch.Tensor(temperature * np.random.gumbel(size=probs.shape)).cuda()
|
| 45 |
+
sorted_confidence, _ = torch.sort(confidence, axis=-1)
|
| 46 |
+
# Obtains cut off threshold given the mask lengths.
|
| 47 |
+
cut_off = sorted_confidence[:, mask_len.long()-1:mask_len.long()]
|
| 48 |
+
# Masks tokens with lower confidence.
|
| 49 |
+
masking = (confidence <= cut_off)
|
| 50 |
+
return masking
|
| 51 |
+
|
| 52 |
+
def adjust_mask_and_drop_embeddings(token_keep_mask):
|
| 53 |
+
"""
|
| 54 |
+
Adjusts the token_keep_mask to the nearest square number of True values by randomly setting
|
| 55 |
+
some of them to False, and then applies this adjusted mask to input_embeddings.
|
| 56 |
+
|
| 57 |
+
Parameters:
|
| 58 |
+
- input_embeddings: Tensor, The embeddings tensor.
|
| 59 |
+
- token_keep_mask: BoolTensor, The mask tensor indicating which tokens to keep.
|
| 60 |
+
|
| 61 |
+
Returns:
|
| 62 |
+
- Tensor, Adjusted input embeddings after applying the modified token_keep_mask.
|
| 63 |
+
"""
|
| 64 |
+
# 获取非零(即值为True)元素的索引
|
| 65 |
+
non_zero_indices = token_keep_mask.nonzero(as_tuple=True)
|
| 66 |
+
# 计算非零元素的数量
|
| 67 |
+
non_zero_count = non_zero_indices[0].size(0)
|
| 68 |
+
# 计算最近的整数平方倍
|
| 69 |
+
next_square = math.floor(math.sqrt(non_zero_count))**2
|
| 70 |
+
# 计算需要移除的元素数量
|
| 71 |
+
remove_count = non_zero_count - next_square
|
| 72 |
+
if remove_count > 0:
|
| 73 |
+
# 如果需要移除元素以达到整数平方倍
|
| 74 |
+
permuted_indices = torch.randperm(non_zero_count)[:remove_count]
|
| 75 |
+
for idx in permuted_indices:
|
| 76 |
+
token_keep_mask[non_zero_indices[0][idx], non_zero_indices[1][idx]] = False
|
| 77 |
+
# 使用更新后的token_keep_mask
|
| 78 |
+
# input_embeddings_after_drop = input_embeddings[token_keep_mask]
|
| 79 |
+
|
| 80 |
+
return token_keep_mask
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class FactorizedEntropyModel(EntropyBottleneck):
|
| 84 |
+
def __init__(self, *args, **kwargs):
|
| 85 |
+
super().__init__(*args, **kwargs)
|
| 86 |
+
|
| 87 |
+
def forward(self, x: Tensor, training: Optional[bool] = None) -> Tuple[Tensor, Tensor]:
|
| 88 |
+
if training is None:
|
| 89 |
+
training = self.training
|
| 90 |
+
|
| 91 |
+
# 输入形状已经是 [b, c, seq_len],无需转置
|
| 92 |
+
shape = x.size()
|
| 93 |
+
|
| 94 |
+
# Add noise or quantize
|
| 95 |
+
means = self._get_medians()
|
| 96 |
+
# outputs = self.quantize(
|
| 97 |
+
# x, "noise" if training else "dequantize", means.long()
|
| 98 |
+
# )
|
| 99 |
+
outputs = self.quantize(
|
| 100 |
+
x, "dequantize", means.long()
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
if not torch.jit.is_scripting():
|
| 104 |
+
likelihood = self._likelihood(outputs)
|
| 105 |
+
if self.use_likelihood_bound:
|
| 106 |
+
likelihood = self.likelihood_lower_bound(likelihood)
|
| 107 |
+
else:
|
| 108 |
+
raise NotImplementedError("TorchScript is not yet supported")
|
| 109 |
+
|
| 110 |
+
return outputs, likelihood
|
| 111 |
+
|
| 112 |
+
def compress(self, x):
|
| 113 |
+
# 构建索引,适用于单通道序列数据
|
| 114 |
+
indexes = self._build_indexes(x.size())
|
| 115 |
+
# 获取中位数,已经适配为单通道
|
| 116 |
+
medians = self._get_medians().detach()
|
| 117 |
+
# 调整 medians 的形状以匹配 x 的形状
|
| 118 |
+
medians = medians.expand_as(x)
|
| 119 |
+
# 调用基类的 compress 方法进行压缩
|
| 120 |
+
return super().compress(x, indexes, medians)
|
| 121 |
+
|
| 122 |
+
def decompress(self, strings, size):
|
| 123 |
+
# 预期的输出大小应包括单个通道
|
| 124 |
+
output_size = (len(strings), 1, *size) # 这里 size 应该是 seq_len
|
| 125 |
+
# 构建索引
|
| 126 |
+
indexes = self._build_indexes(output_size).to(self._quantized_cdf.device)
|
| 127 |
+
# 获取中位数并调整其形状以匹配预期输出的形状
|
| 128 |
+
medians = self._extend_ndims(self._get_medians().detach(), len(size))
|
| 129 |
+
medians = medians.expand(len(strings), 1, *([-1] * len(size)))
|
| 130 |
+
# 调用基类的 decompress 方法进行解压缩
|
| 131 |
+
return super().decompress(strings, indexes, medians.dtype, medians)
|
| 132 |
+
|
| 133 |
+
def _preprocess(self, x):
|
| 134 |
+
x = x.permute(0, 2, 3, 1).contiguous()
|
| 135 |
+
return x
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class Attention(nn.Module):
|
| 139 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
| 140 |
+
super().__init__()
|
| 141 |
+
self.num_heads = num_heads
|
| 142 |
+
head_dim = dim // num_heads
|
| 143 |
+
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
|
| 144 |
+
self.scale = qk_scale or head_dim ** -0.5
|
| 145 |
+
|
| 146 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 147 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 148 |
+
self.proj = nn.Linear(dim, dim)
|
| 149 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 150 |
+
|
| 151 |
+
def forward(self, x):
|
| 152 |
+
B, N, C = x.shape
|
| 153 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # (3,B,num_heads,N,head_dim)
|
| 154 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
| 155 |
+
|
| 156 |
+
with torch.cuda.amp.autocast(enabled=False):
|
| 157 |
+
attn = (q.float() @ k.float().transpose(-2, -1)) * self.scale
|
| 158 |
+
|
| 159 |
+
attn = attn - torch.max(attn, dim=-1, keepdim=True)[0]
|
| 160 |
+
attn = attn.softmax(dim=-1)
|
| 161 |
+
attn = self.attn_drop(attn)
|
| 162 |
+
|
| 163 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
| 164 |
+
x = self.proj(x)
|
| 165 |
+
x = self.proj_drop(x)
|
| 166 |
+
# x是经过self-attn后的feature,attn是注意力权重矩阵,描述输入序列中各个元素之间的相关性
|
| 167 |
+
return x, attn
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
class Block(nn.Module):
|
| 171 |
+
|
| 172 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
| 173 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
| 174 |
+
super().__init__()
|
| 175 |
+
self.norm1 = norm_layer(dim)
|
| 176 |
+
self.attn = Attention(
|
| 177 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
| 178 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
| 179 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() # drop_path=0
|
| 180 |
+
self.norm2 = norm_layer(dim)
|
| 181 |
+
mlp_hidden_dim = int(dim * mlp_ratio) # mlp_ratio=4
|
| 182 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) # drop=0.1
|
| 183 |
+
|
| 184 |
+
def forward(self, x, return_attention=False):
|
| 185 |
+
if return_attention:
|
| 186 |
+
_, attn = self.attn(self.norm1(x))
|
| 187 |
+
return attn
|
| 188 |
+
else:
|
| 189 |
+
y, _ = self.attn(self.norm1(x))
|
| 190 |
+
x = x + self.drop_path(y)
|
| 191 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
| 192 |
+
return x
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class LabelSmoothingCrossEntropy(nn.Module):
|
| 196 |
+
""" NLL loss with label smoothing.
|
| 197 |
+
"""
|
| 198 |
+
def __init__(self, smoothing=0.1):
|
| 199 |
+
super(LabelSmoothingCrossEntropy, self).__init__()
|
| 200 |
+
assert smoothing < 1.0
|
| 201 |
+
self.smoothing = smoothing
|
| 202 |
+
self.confidence = 1. - smoothing
|
| 203 |
+
|
| 204 |
+
def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
| 205 |
+
logprobs = torch.nn.functional.log_softmax(x, dim=-1)
|
| 206 |
+
nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
|
| 207 |
+
nll_loss = nll_loss.squeeze(1)
|
| 208 |
+
smooth_loss = -logprobs.mean(dim=-1)
|
| 209 |
+
loss = self.confidence * nll_loss + self.smoothing * smooth_loss
|
| 210 |
+
return loss
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
class BertEmbeddings(nn.Module):
|
| 214 |
+
"""Construct the embeddings from word, position and token_type embeddings."""
|
| 215 |
+
|
| 216 |
+
def __init__(self, vocab_size, hidden_size, max_position_embeddings, dropout=0.1):
|
| 217 |
+
super().__init__()
|
| 218 |
+
self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
|
| 219 |
+
self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
|
| 220 |
+
|
| 221 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
| 222 |
+
# any TensorFlow checkpoint file
|
| 223 |
+
self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-6)
|
| 224 |
+
self.dropout = nn.Dropout(dropout)
|
| 225 |
+
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
| 226 |
+
self.register_buffer("position_ids", torch.arange(max_position_embeddings).expand((1, -1))) # (1, 257)
|
| 227 |
+
# 创建一个形状为(1, max_position_embeddings)的缓冲张量position_ids,其包含了从0到max_position_embeddings-1的整数。
|
| 228 |
+
# 这个缓冲张量将被用于获取position_embeddings的位置信息,以便在前向传播过程中使用
|
| 229 |
+
|
| 230 |
+
torch.nn.init.normal_(self.word_embeddings.weight, std=.02)
|
| 231 |
+
torch.nn.init.normal_(self.position_embeddings.weight, std=.02)
|
| 232 |
+
|
| 233 |
+
def forward(
|
| 234 |
+
self, input_ids, z_H, z_W
|
| 235 |
+
):
|
| 236 |
+
input_shape = input_ids.size() # input_ids: (B, N)(32,1 + latent_h * latent_w)
|
| 237 |
+
seq_length = input_shape[1]
|
| 238 |
+
position_ids = self.position_ids[:, :seq_length]
|
| 239 |
+
inputs_embeds = self.word_embeddings(input_ids) # (B, seq_len, embed_dim)
|
| 240 |
+
position_embeddings = self.position_embeddings(position_ids) # (1, seq_len, embed_dim)
|
| 241 |
+
|
| 242 |
+
# if seq_length != (1 + z_H * z_W):
|
| 243 |
+
# Exclude the first embedding (position 0) before interpolation
|
| 244 |
+
position_embeddings_to_interp = position_embeddings[:, 1:, :] # (b, 256, embed_dim)
|
| 245 |
+
# Perform bilinear interpolation
|
| 246 |
+
interpolated_embeddings = F.interpolate(
|
| 247 |
+
position_embeddings_to_interp.permute(0, 2, 1).reshape(1, -1, 16, 16),
|
| 248 |
+
size=(z_H, z_W), mode="bicubic", align_corners=False
|
| 249 |
+
).reshape(1, -1, z_H * z_W).permute(0, 2, 1) # (b, seq_len-1, embed_dim)
|
| 250 |
+
# Concatenate the first embedding back
|
| 251 |
+
position_embeddings = torch.cat([position_embeddings[:, :1, :], interpolated_embeddings], dim=1)
|
| 252 |
+
|
| 253 |
+
embeddings = inputs_embeds + position_embeddings
|
| 254 |
+
|
| 255 |
+
embeddings = self.LayerNorm(embeddings)
|
| 256 |
+
embeddings = self.dropout(embeddings)
|
| 257 |
+
return embeddings
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
class MlmLayer(nn.Module):
|
| 261 |
+
|
| 262 |
+
def __init__(self, feat_emb_dim, word_emb_dim, vocab_size):
|
| 263 |
+
super().__init__()
|
| 264 |
+
self.fc = nn.Linear(feat_emb_dim, word_emb_dim)
|
| 265 |
+
self.gelu = nn.GELU()
|
| 266 |
+
self.ln = nn.LayerNorm(word_emb_dim)
|
| 267 |
+
self.bias = nn.Parameter(torch.zeros(1, 1, vocab_size))
|
| 268 |
+
|
| 269 |
+
def forward(self, x, word_embeddings): # x: (b, seq_len, embed_dim)
|
| 270 |
+
mlm_hidden = self.fc(x)
|
| 271 |
+
mlm_hidden = self.gelu(mlm_hidden)
|
| 272 |
+
mlm_hidden = self.ln(mlm_hidden)
|
| 273 |
+
word_embeddings = word_embeddings.transpose(0, 1)
|
| 274 |
+
logits = torch.matmul(mlm_hidden, word_embeddings)
|
| 275 |
+
logits = logits + self.bias
|
| 276 |
+
return logits # (b, seq_len, vocab_size) 表示对于输入序列中的每个位置,模型预测它对应词汇表中每个单词的原始单词的未归一化概率
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
class MaskedGenerativeEncoderViT(nn.Module):
    """Masked generative codec with a VisionTransformer backbone.

    Pipeline: a frozen VQGAN tokenizer maps images to discrete tokens, a ViT
    encoder processes the unmasked tokens, and a ViT decoder predicts the
    masked ones; a factorized entropy model handles token-rate estimation.
    """

    def __init__(self, img_size=256, patch_size=16, in_chans=3,
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False,
                 mask_ratio_min=0.5, mask_ratio_max=0.8, mask_ratio_mu=0.55, mask_ratio_std=0.25,
                 vqgan_ckpt_path='vqgan_jax_strongaug.ckpt',
                 vqgan_cfg_path='/home/v-ruoyufeng/v-ruoyufeng/qyp/mage/config/vqgan.yaml'):
        super().__init__()

        # --------------------------------------------------------------------------
        # VQGAN specifics (frozen tokenizer)
        # NOTE(review): the config path used to be hard-coded inline; it is now a
        # parameter whose default preserves the original behavior.
        config = OmegaConf.load(vqgan_cfg_path).model
        self.vqgan = VQModel(ddconfig=config.params.ddconfig,
                             n_embed=config.params.n_embed,      # 1024
                             embed_dim=config.params.embed_dim,  # 256
                             ckpt_path=vqgan_ckpt_path)
        for param in self.vqgan.parameters():
            param.requires_grad = False

        self.codebook_size = config.params.n_embed  # 1024
        vocab_size = self.codebook_size + 1000 + 1  # codebook + 1000 classes + 1 mask token
        self.fake_class_label = self.codebook_size + 1100 - 1024  # == 1100
        self.mask_token_label = vocab_size - 1                    # == 2024
        self.token_emb = BertEmbeddings(vocab_size=vocab_size,
                                        hidden_size=embed_dim,
                                        max_position_embeddings=256 + 1,  # 256 patches + [CLS]
                                        dropout=0.1)

        # MAGE variant masking ratio (uniform sampling between min and max;
        # the original truncated-normal generator is intentionally disabled)
        self.mask_ratio_min = mask_ratio_min
        self.mask_ratio_max = mask_ratio_max

        # --------------------------------------------------------------------------
        # MAGE encoder specifics
        dropout_rate = 0.1
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)  # (B, 256, embed_dim)
        num_patches = self.patch_embed.num_patches  # 256

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # fixed sin-cos embedding (filled in initialize_weights)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer,
                  drop=dropout_rate, attn_drop=dropout_rate)
            for i in range(depth)])  # depth=12 for mage-vitb, embed_dim=768
        self.norm = norm_layer(embed_dim)
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # MAGE decoder specifics
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.pad_with_cls_token = True

        # fixed sin-cos embedding plus a learnable positional table
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)
        self.decoder_pos_embed_learned = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim))

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer,
                  drop=dropout_rate, attn_drop=dropout_rate)
            for i in range(decoder_depth)])  # decoder_depth=8 for mage-vitb

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True)  # decoder to patch
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # MlmLayer head, loss, and entropy model
        self.mlm_layer = MlmLayer(feat_emb_dim=decoder_embed_dim, word_emb_dim=embed_dim, vocab_size=vocab_size)

        self.norm_pix_loss = norm_pix_loss

        self.criterion = LabelSmoothingCrossEntropy(smoothing=0.1)
        # --------------------------------------------------------------------------
        self.entropy_bottleneck = FactorizedEntropyModel(1)

        self.initialize_weights()
def initialize_weights(self):
    """Fill positional tables, init the patch projection and special tokens."""
    # fixed (frozen) sin-cos positional tables for encoder and decoder
    grid_size = int(self.patch_embed.num_patches ** .5)
    enc_table = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], grid_size, cls_token=True)
    self.pos_embed.data.copy_(torch.from_numpy(enc_table).float().unsqueeze(0))

    dec_table = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], grid_size, cls_token=True)
    self.decoder_pos_embed.data.copy_(torch.from_numpy(dec_table).float().unsqueeze(0))

    # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
    proj_weight = self.patch_embed.proj.weight.data
    torch.nn.init.xavier_uniform_(proj_weight.view([proj_weight.shape[0], -1]))

    # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) since the
    # cutoff (2.) is far in the tail
    torch.nn.init.normal_(self.cls_token, std=.02)
    torch.nn.init.normal_(self.mask_token, std=.02)
    torch.nn.init.normal_(self.decoder_pos_embed_learned, std=.02)

    # recurse into all submodules for Linear / LayerNorm defaults
    self.apply(self._init_weights)
def _init_weights(self, m):
|
| 386 |
+
if isinstance(m, nn.Linear):
|
| 387 |
+
# we use xavier_uniform following official JAX ViT:
|
| 388 |
+
torch.nn.init.xavier_uniform_(m.weight)
|
| 389 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 390 |
+
nn.init.constant_(m.bias, 0)
|
| 391 |
+
elif isinstance(m, nn.LayerNorm):
|
| 392 |
+
nn.init.constant_(m.bias, 0)
|
| 393 |
+
nn.init.constant_(m.weight, 1.0)
|
| 394 |
+
|
| 395 |
+
def random_sample_mask_rate(self):
    """Draw a mask rate uniformly from (mask_ratio_min, mask_ratio_max]."""
    # 1 - rand() is uniform on (0, 1], so the max is attainable but not the min
    u = 1 - torch.rand(1)
    span = self.mask_ratio_max - self.mask_ratio_min
    rate = self.mask_ratio_min + u * span
    return rate.item()
def get_cdf_token_mask(self, token_all_mask):
    """Build the CDF table used by torchac to entropy-code the binary mask.

    Returns (bsz, seq_len, seq_len + 1): a Normal(0, 2) CDF evaluated at
    positions 1..seq_len, rescaled from [.5, 1] to [0, 1], broadcast to every
    (batch, position) slot, and left-padded with a zero column.
    NOTE(review): the table depends only on seq_len, not on the mask values.
    """
    bsz, seq_len = token_all_mask.size()
    gauss = torch.distributions.Normal(0, 2)
    table = gauss.cdf(torch.arange(1, seq_len + 1))
    table = (table - .5) * 2
    # broadcast the 1-D table over batch and sequence (same values as the
    # original einops repeat 'Lp -> b s Lp')
    table = table.expand(bsz, seq_len, seq_len)
    return F.pad(table, (1, 0))
def get_dec_pos_embed(self, z_H, z_W):
    """Resample the learned decoder position table to a (z_H, z_W) token grid.

    The [CLS] slot (index 0) is kept untouched; the remaining 16x16 grid of
    embeddings is bicubically resized, then [CLS] is prepended again.
    Returns (1, z_H * z_W + 1, decoder_embed_dim).
    """
    table = self.decoder_pos_embed_learned
    cls_slot, grid = table[:, :1, :], table[:, 1:, :]
    # (1, N, C) -> (1, C, 16, 16) for spatial interpolation
    grid_2d = grid.permute(0, 2, 1).reshape(1, -1, 16, 16)
    resized = F.interpolate(grid_2d, size=(z_H, z_W), mode="bicubic", align_corners=False)
    resized = resized.reshape(1, -1, z_H * z_W).permute(0, 2, 1)
    return torch.cat([cls_slot, resized], dim=1)
def pre_encoding(self, x, is_training=False, manual_mask_rate=None):
    """Tokenize an image with the frozen VQGAN and draw a random token mask.

    x: (B, 3, H, W) input images.
    is_training: if True, the mask rate is sampled and at least
        mask_ratio_min of the tokens are additionally dropped from the
        encoder input; if False, `manual_mask_rate` must be given and no
        tokens are dropped.
    Returns (gt_indices, token_indices, unmaksed_token_indices,
             token_all_mask, token_drop_mask, mask_ratio, z_H, z_W).
    """
    # ============ 1. tokenization (VQGAN is frozen, no grads) ============ #
    with torch.no_grad():
        # z_q: (B, C, z_H, z_W); token_tuple carries the flat codebook indices
        z_q, _, token_tuple = self.vqgan.encode(x)

    _, _, z_H, z_W = z_q.size()
    _, _, token_indices = token_tuple  # flat indices, (B*z_H*z_W,)
    token_indices = token_indices.reshape(z_q.size(0), -1)  # (B, z_H*z_W)
    gt_indices = token_indices.clone().detach().long()  # kept as reconstruction targets

    # ============ 2. masking process ============ #
    bsz, seq_len = token_indices.size()  # seq_len = z_H * z_W
    mask_ratio_min = self.mask_ratio_min

    if is_training:
        mask_rate = self.random_sample_mask_rate()
        # during training also drop (remove from encoder input) a fixed floor
        num_dropped_tokens = int(np.ceil(seq_len * mask_ratio_min))
    else:
        num_dropped_tokens = 0
        if manual_mask_rate is not None:
            mask_rate = manual_mask_rate
        else:
            raise ValueError("mask_rate should be provided for inference!")

    num_masked_tokens = int(np.ceil(seq_len * mask_rate))
    mask_ratio = num_masked_tokens / seq_len  # used to compute the vbr lambda
    # Noise ties could make the <=-cutoff select the wrong count, so re-draw
    # until exactly the intended number of tokens is dropped/masked per row.
    while True:
        noise = torch.rand(bsz, seq_len, device=x.device)  # noise in [0, 1]
        sorted_noise, _ = torch.sort(noise, dim=1)  # ascend: small is remove, large is keep
        if num_dropped_tokens > 0:
            cutoff_drop = sorted_noise[:, num_dropped_tokens-1:num_dropped_tokens]
        else:
            cutoff_drop = torch.zeros((bsz, 1), device=x.device)
        cutoff_mask = sorted_noise[:, num_masked_tokens-1:num_masked_tokens]
        # 1.0 marks a dropped token; drops are a subset of the masked set
        token_drop_mask = (noise <= cutoff_drop).float()
        # 1.0 marks a masked token
        token_all_mask = (noise <= cutoff_mask).float()
        if token_drop_mask.sum() == bsz*num_dropped_tokens and token_all_mask.sum() == bsz*num_masked_tokens:
            break
        else:
            print("Rerandom the noise!")

    # collect the surviving (unmasked) tokens and their positions; every row
    # has the same count, so the reshape back to (bsz, -1) is well-defined
    unmasked_pos = token_all_mask == 0
    unmaksed_token_indices = token_indices[unmasked_pos].reshape(bsz, -1)

    return gt_indices, token_indices, unmaksed_token_indices, token_all_mask, token_drop_mask, mask_ratio, z_H, z_W
def pre_decoding(self, gt_indices, unmaksed_token_indices, token_all_mask, token_drop_mask, z_H, z_W):
    """Re-assemble the token sequence on the decoder side and run the encoder.

    gt_indices: (bsz, seq_len) reference indices, used only for shape/device.
    unmaksed_token_indices: (bsz, n_unmasked) surviving tokens to scatter back.
    token_all_mask / token_drop_mask: (bsz, seq_len) 1.0-marks-masked/dropped.
    Returns (encoder features, token_indices, token_all_mask, token_drop_mask),
    each with the [CLS] slot prepended.
    """
    device = gt_indices.device
    bsz, seq_len = gt_indices.size()
    padded_token_indices = torch.full_like(gt_indices, fill_value=self.mask_token_label)
    # Vectorized scatter replacing the original O(bsz*seq_len) Python double
    # loop; row-major boolean indexing visits positions in the same order.
    # Assumes every row has the same number of unmasked tokens, which
    # pre_encoding guarantees by construction.
    padded_token_indices[token_all_mask == 0] = unmaksed_token_indices.reshape(-1)
    token_indices = padded_token_indices

    # ============ 3. Adding class token ============ #
    # Prepend a [CLS] slot to aggregate sequence-level representations.
    token_indices = torch.cat([torch.zeros(token_indices.size(0), 1, device=device), token_indices], dim=1)
    token_indices[:, 0] = self.fake_class_label  # (B, seq_len + 1)
    # Extend both masks with a leading 0 so [CLS] is never masked/dropped.
    # FIX: zeros are created on gt_indices' device instead of bare .cuda(),
    # which broke runs on a non-default CUDA device (and CPU).
    token_drop_mask = torch.cat([torch.zeros(token_drop_mask.size(0), 1, device=device), token_drop_mask], dim=1)
    token_all_mask = torch.cat([torch.zeros(token_all_mask.size(0), 1, device=device), token_all_mask], dim=1)
    token_indices = token_indices.long()

    # ============ 4. Embedding and Dropout ============ #
    # BERT-style token + (interpolated) positional embedding
    input_embeddings = self.token_emb(token_indices, z_H, z_W)  # (B, seq_len + 1, embed_dim)
    bsz, seq_len, emb_dim = input_embeddings.shape

    # remove dropped positions from the encoder input
    token_keep_mask = 1 - token_drop_mask
    input_embeddings_after_drop = input_embeddings[token_keep_mask.nonzero(as_tuple=True)].reshape(bsz, -1, emb_dim)

    # ============ 5. Transformer encoding ============ #
    x = input_embeddings_after_drop
    for blk in self.blocks:
        x = blk(x)
    x = self.norm(x)

    return x, token_indices, token_all_mask, token_drop_mask
def forward_decoding(self, x, token_drop_mask, token_all_mask, z_H, z_W):
    """Run the ViT decoder over the padded token sequence and predict logits.

    x: encoder output from pre_decoding (dropped positions absent).
    token_drop_mask: (B, seq_len+1), 1.0 where tokens were dropped.
    token_all_mask: (B, seq_len+1), 1.0 where tokens were masked.
    Returns (B, seq_len+1, vocab_size) logits from the MLM head.
    """
    # ============ 1. Prepare Embedding and padding tokens ============ #
    # project encoder features into the decoder width
    x = self.decoder_embed(x)

    # The padding value for masked/dropped slots: either the per-sample [CLS]
    # embedding replicated across the sequence, or the learned mask token.
    if self.pad_with_cls_token:  # True
        mask_tokens = x[:, 0:1].repeat(1, token_all_mask.shape[1], 1)
    else:
        mask_tokens = self.mask_token.repeat(token_all_mask.shape[0], token_all_mask.shape[1], 1)

    # ============ 2. Pad sequence back to full length ============ #
    # scatter the undropped tokens back into their original positions
    x_after_pad = mask_tokens.clone()
    x_after_pad[(1 - token_drop_mask).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
    # positions that were kept by the encoder but masked still get the pad value
    x_after_pad = torch.where(token_all_mask.unsqueeze(-1).bool(), mask_tokens, x_after_pad)

    # add the learnable positional table, interpolated to the (z_H, z_W) grid
    dec_pos_embed_learned = self.get_dec_pos_embed(z_H, z_W)
    x = x_after_pad + dec_pos_embed_learned

    # apply Transformer decoder blocks
    for blk in self.decoder_blocks:
        x = blk(x)

    x = self.decoder_norm(x)

    # score features against the (detached) word-embedding table
    word_embeddings = self.token_emb.word_embeddings.weight.data.detach()
    logits = self.mlm_layer(x, word_embeddings)

    return logits
def forward_loss(self, gt_indices, logits, mask):
    """Cross-entropy over codebook logits, averaged on masked positions only.

    gt_indices: (bsz, seq_len) target codebook indices.
    logits / mask: carry one extra leading [CLS] slot, which is stripped here.
    """
    bsz, seq_len = gt_indices.size()
    # drop [CLS] and any extra vocab slots beyond the codebook
    flat_logits = logits[:, 1:, :self.codebook_size].reshape(bsz * seq_len, -1)
    per_token = self.criterion(flat_logits, gt_indices.reshape(bsz * seq_len))
    per_token = per_token.reshape(bsz, seq_len)
    token_mask = mask[:, 1:]
    # mean loss over masked (removed) positions
    return (per_token * token_mask).sum() / token_mask.sum()
def cal_lmbda(self, mask_ratio, A=5e-1, B=8):
    """Rate-distortion weight: grows exponentially as fewer tokens are masked
    (i.e. as the kept fraction 1 - mask_ratio increases)."""
    keep_ratio = 1 - mask_ratio
    return A * torch.exp(B * keep_ratio)
def cal_loss(self, logits, gt_indices, mask, mask_ratio):
    """Return (cross-entropy task loss, rate-distortion lambda).

    mask_ratio is converted to a tensor so cal_lmbda can use torch.exp.
    """
    mask_ratio = torch.tensor(mask_ratio)
    lmbda = self.cal_lmbda(mask_ratio)
    task_loss = self.forward_loss(gt_indices, logits, mask)
    return task_loss, lmbda
def forward(self, imgs, is_training=False, manual_mask_rate=None):
    """Full codec pass: tokenize + mask, entropy-code the mask with torchac,
    entropy-model the kept tokens, then decode and score the predictions.

    Returns a dict with logits, likelihoods, losses, masks, the torchac
    bitstream for the mask, and the latent grid size.
    """
    ## ---------- encoding process ---------- ##
    # `latent` here is the sequence of unmasked token indices
    gt_indices, token_indices, latent, token_all_mask, token_drop_mask, mask_ratio, z_H, z_W = self.pre_encoding(imgs, is_training, manual_mask_rate)
    latent = latent.unsqueeze(1)  # add a channel dim for the entropy model

    latent_hat, latent_likelihoods = self.entropy_bottleneck(latent)
    # NOTE(review): latent_hat should round-trip to latent for integer indices
    # (the commented check below); not verified here.
    # print((latent_hat == latent).all())

    # losslessly encode the binary mask with torchac (CPU-side arithmetic coding)
    cdf_mask_token = self.get_cdf_token_mask(token_all_mask).cpu()
    sym = (token_all_mask.short() + 1).cpu()  # symbols shifted to {1, 2}
    bs_mask_token = torchac.encode_float_cdf(cdf_mask_token, sym, check_input_bounds=True)
    mask_vis = rearrange(token_all_mask, 'b (h w) -> b h w', h=z_H, w=z_W).unsqueeze(1)

    ## ---------- decoding process ---------- ##
    # decode the mask back from the bitstream (exercises the real pipeline)
    decoded_sym = torchac.decode_float_cdf(cdf_mask_token, bs_mask_token)
    decoded_mask = (decoded_sym - 1).to(device=imgs.device)  # back to {0, 1}
    latent_hat = latent_hat.squeeze(1)
    x, token_indices, token_all_mask, token_drop_mask = self.pre_decoding(gt_indices, latent_hat, decoded_mask, token_drop_mask, z_H, z_W)
    logits = self.forward_decoding(x, token_drop_mask, token_all_mask, z_H, z_W)
    ## calculate loss
    task_loss, lmbda = self.cal_loss(logits, gt_indices, token_all_mask, mask_ratio)
    return_dict = {
        'logits': logits,
        'likelihoods': latent_likelihoods,
        'task_loss': task_loss,
        'token_indices': token_indices,
        'token_all_mask': token_all_mask,
        'bs_mask_token': bs_mask_token,
        'mask_ratio': mask_ratio,
        'lambda': lmbda,
        'mask_vis': 1 - mask_vis,  # 1 marks *kept* tokens for visualization
        'z_H': z_H,
        'z_W': z_W,
    }
    return return_dict
# def update(self, scale_table=None, force=False):
|
| 625 |
+
# if scale_table is None:
|
| 626 |
+
# scale_table = get_scale_table()
|
| 627 |
+
# updated = self.gaussian_conditional.update_scale_table(scale_table, force=force)
|
| 628 |
+
# updated |= super().update(force=force)
|
| 629 |
+
# return updated
|
| 630 |
+
|
| 631 |
+
def gen_img(self, logits, token_all_mask, token_indices, z_H, z_W, num_iter=12, choice_temperature=4.5):
    """Generate an image at inference via MaskGIT-style iterative decoding.

    logits: decoder predictions used for the first iteration.
    token_all_mask: indicator of the initially masked token positions.
    token_indices: token sequence from the VQ tokenizer (with [CLS] slot).
    num_iter: number of refinement iterations.
    choice_temperature: Gumbel temperature for confidence-based re-masking.
    Returns the VQGAN-decoded images. Requires CUDA (uses bare .cuda() calls).
    """
    bsz = logits.size(0)
    codebook_emb_dim = 256
    codebook_size = 1024
    mask_token_id = self.mask_token_label
    _CONFIDENCE_OF_KNOWN_TOKENS = +np.inf
    unknown_number_in_the_beginning = torch.sum(token_all_mask, dim=-1, keepdims=True).float()
    for step in range(num_iter):
        if step == 0:
            # first iteration reuses the logits passed in; strip the [CLS] slot
            cur_ids = token_indices.clone().long()
            cur_ids = cur_ids[:, 1:]
            logits = logits[:, 1:, :codebook_size]
            # Authors' note: iter=1 would use argmax/temp=0; with more iters
            # categorical sampling with temperature 4.5 is used instead.
            sample_dist = torch.distributions.categorical.Categorical(logits=logits)
            sampled_ids = sample_dist.sample()  # alternative: torch.argmax(logits, dim=-1)
            # unknown_map marks positions whose value is still the mask token;
            # keep known tokens, take sampled predictions only where unknown
            unknown_map = (cur_ids == mask_token_id)
            sampled_ids = torch.where(unknown_map, sampled_ids, cur_ids)
            # cosine schedule: fraction of tokens to re-mask for the next round
            ratio = 1. * (step + 1) / num_iter
            mask_ratio = np.cos(math.pi / 2. * ratio)

            # per-position confidence of the sampled token; known positions get
            # +inf so they are never re-masked
            probs = torch.nn.functional.softmax(logits, dim=-1)
            selected_probs = torch.squeeze(
                torch.gather(probs, dim=-1, index=torch.unsqueeze(sampled_ids, -1)), -1)

            selected_probs = torch.where(unknown_map, selected_probs.double(), _CONFIDENCE_OF_KNOWN_TOKENS).float()
            unknown_number_in_the_beginning = unknown_number_in_the_beginning.clone().detach().cuda()
            mask_ratio = torch.tensor(mask_ratio).cuda()
            # number of tokens left masked after this iteration
            mask_len = torch.floor(unknown_number_in_the_beginning * mask_ratio).long()
            # keep at least one prediction this round and mask at least one for
            # the next iteration
            mask_len = torch.maximum(torch.Tensor([1]).cuda(),
                                     torch.minimum(torch.sum(unknown_map, dim=-1, keepdims=True) - 1, mask_len))

            # Gumbel-perturbed confidence ranking decides what gets re-masked
            masking = mask_by_random_topk(mask_len[0], selected_probs, choice_temperature * (1 - ratio))
            token_indices = torch.where(masking, mask_token_id, sampled_ids)
        else:
            # subsequent iterations: re-encode and re-decode the current tokens
            cur_ids = token_indices.clone().long()
            token_indices = torch.cat(
                [torch.zeros(token_indices.size(0), 1).cuda(device=token_indices.device), token_indices], dim=1)
            token_indices[:, 0] = self.fake_class_label
            token_indices = token_indices.long()
            token_all_mask = token_indices == mask_token_id

            token_drop_mask = torch.zeros_like(token_indices)  # no dropping at inference

            # token embedding
            input_embeddings = self.token_emb(token_indices, z_H, z_W)

            # encoder
            x = input_embeddings
            for blk in self.blocks:
                x = blk(x)
            x = self.norm(x)

            # decoder
            logits = self.forward_decoding(x, token_drop_mask, token_all_mask, z_H, z_W)
            logits = logits[:, 1:, :codebook_size]  # strip [CLS] and extra vocab slots

            # same sampling / confidence / re-masking scheme as step 0
            sample_dist = torch.distributions.categorical.Categorical(logits=logits)
            sampled_ids = sample_dist.sample()

            unknown_map = (cur_ids == mask_token_id)
            sampled_ids = torch.where(unknown_map, sampled_ids, cur_ids)
            ratio = 1. * (step + 1) / num_iter

            mask_ratio = np.cos(math.pi / 2. * ratio)

            probs = torch.nn.functional.softmax(logits, dim=-1)
            selected_probs = torch.squeeze(
                torch.gather(probs, dim=-1, index=torch.unsqueeze(sampled_ids, -1)), -1)

            selected_probs = torch.where(unknown_map, selected_probs.double(), _CONFIDENCE_OF_KNOWN_TOKENS).float()
            unknown_number_in_the_beginning = unknown_number_in_the_beginning.clone().detach().cuda()
            mask_ratio = torch.tensor(mask_ratio).cuda()
            mask_len = torch.floor(unknown_number_in_the_beginning * mask_ratio).long()
            mask_len = torch.maximum(torch.Tensor([1]).cuda(),
                                     torch.minimum(torch.sum(unknown_map, dim=-1, keepdims=True) - 1, mask_len))

            masking = mask_by_random_topk(mask_len[0], selected_probs, choice_temperature * (1 - ratio))
            token_indices = torch.where(masking, mask_token_id, sampled_ids)

    # map the final token ids back through the VQGAN decoder
    z_q = self.vqgan.quantize.get_codebook_entry(sampled_ids, shape=(bsz, z_H, z_W, codebook_emb_dim))
    gen_images = self.vqgan.decode(z_q)
    return gen_images
def mage_vit_base_patch16(**kwargs):
    """Build the ViT-Base/16 MAGE codec configuration."""
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return MaskedGenerativeEncoderViT(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
        decoder_embed_dim=768, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=layer_norm, **kwargs)
|
| 762 |
+
def mage_vit_large_patch16(**kwargs):
    """Build the ViT-Large/16 MAGE codec configuration."""
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return MaskedGenerativeEncoderViT(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        decoder_embed_dim=1024, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=layer_norm, **kwargs)
deploited/models_mage_codec_rm_pos.py
ADDED
|
@@ -0,0 +1,759 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import partial
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
from timm.models.vision_transformer import PatchEmbed, DropPath, Mlp
|
| 7 |
+
|
| 8 |
+
from util.pos_embed import get_2d_sincos_pos_embed
|
| 9 |
+
|
| 10 |
+
from taming.models.vqgan import VQModel
|
| 11 |
+
from omegaconf import OmegaConf
|
| 12 |
+
import numpy as np
|
| 13 |
+
import scipy.stats as stats
|
| 14 |
+
from compressai.entropy_models import EntropyBottleneck
|
| 15 |
+
from compressai.layers import conv3x3, subpel_conv3x3
|
| 16 |
+
import math
|
| 17 |
+
from torch import Tensor
|
| 18 |
+
from einops import rearrange, repeat
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
import torchac
|
| 21 |
+
from typing import Any, Callable, List, Optional, Tuple, Union
|
| 22 |
+
|
| 23 |
+
SCALES_MIN = 0.11
SCALES_MAX = 256
SCALES_LEVELS = 64


def get_scale_table(min=SCALES_MIN, max=SCALES_MAX, levels=SCALES_LEVELS):
    """Return `levels` scales geometrically spaced between `min` and `max`.

    Used to build the scale table for a Gaussian conditional entropy model.
    """
    log_lo = math.log(min)
    log_hi = math.log(max)
    return torch.linspace(log_lo, log_hi, levels).exp()
|
| 28 |
+
|
| 29 |
+
def ste_round(x: Tensor) -> Tensor:
    """Round `x` with a straight-through gradient estimator.

    Forward: round(x). Backward: identity (the rounding residual is detached),
    so gradients flow through as if no rounding happened.
    """
    return x + (torch.round(x) - x).detach()
|
| 31 |
+
|
| 32 |
+
def conv(in_channels, out_channels, kernel_size=5, stride=2):
    """Build a 2-D convolution with "same"-style padding (kernel_size // 2)."""
    padding = kernel_size // 2
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
    )
|
| 40 |
+
|
| 41 |
+
def mask_by_random_topk(mask_len, probs, temperature=1.0):
    """Stochastically select `mask_len` tokens per row to mask, biased toward
    low-confidence tokens.

    Gumbel noise scaled by `temperature` is added to log-probabilities, so the
    selection is random but low-probability tokens are masked more often.

    Parameters:
    - mask_len: tensor squeezable to a scalar count of tokens to mask per row.
    - probs: (B, N) strictly-positive per-token confidences.
    - temperature: Gumbel noise scale; must be > 0 for ties to be improbable.

    Returns a (B, N) boolean tensor, True where the token should be masked.
    """
    mask_len = mask_len.squeeze()
    # Gumbel-perturbed confidence. Fix: follow probs' device/dtype instead of
    # the original hard-coded .cuda(), which crashed on CPU tensors and used
    # the wrong device under multi-GPU.
    gumbel = np.random.gumbel(size=probs.shape)
    noise = torch.as_tensor(gumbel, dtype=probs.dtype, device=probs.device)
    confidence = torch.log(probs) + temperature * noise
    sorted_confidence, _ = torch.sort(confidence, dim=-1)
    # Per-row cut-off at the mask_len-th smallest perturbed confidence.
    cut_off = sorted_confidence[:, mask_len.long() - 1:mask_len.long()]
    # Mask every token at or below the cut-off.
    masking = (confidence <= cut_off)
    return masking
|
| 51 |
+
|
| 52 |
+
def adjust_mask_and_drop_embeddings(token_keep_mask):
    """Trim a 2-D boolean keep-mask so its True count is a perfect square.

    Surplus True entries are flipped to False at random; the mask is modified
    in place and returned, so callers can apply it to their embeddings
    afterwards. (The original docstring mentioned an `input_embeddings`
    parameter that does not exist; this function only adjusts the mask.)

    Parameters:
    - token_keep_mask: BoolTensor (B, N); True marks tokens to keep.

    Returns:
    - The same tensor, now with floor(sqrt(k))**2 True entries (k = original
      True count).
    """
    rows, cols = token_keep_mask.nonzero(as_tuple=True)
    kept = rows.numel()
    # Largest perfect square not exceeding the current keep count.
    target = math.isqrt(kept) ** 2
    surplus = kept - target
    if surplus > 0:
        # Vectorized replacement of the original per-index Python loop:
        # choose `surplus` kept positions at random and flip them off.
        drop = torch.randperm(kept, device=rows.device)[:surplus]
        token_keep_mask[rows[drop], cols[drop]] = False
    return token_keep_mask
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class FactorizedEntropyModel(EntropyBottleneck):
    """Fully-factorized entropy bottleneck for single-channel sequences
    shaped (B, 1, seq_len).

    Thin wrapper over compressai's ``EntropyBottleneck``: ``forward`` always
    hard-quantizes (the additive-noise training path is disabled), and
    ``compress``/``decompress`` adapt the index/median shapes to the
    single-channel layout.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x: Tensor, training: Optional[bool] = None) -> Tuple[Tensor, Tensor]:
        """Quantize `x` and return ``(outputs, likelihoods)`` of the same shape."""
        if training is None:
            training = self.training

        # Input is already (b, c, seq_len); no transpose needed here.
        shape = x.size()

        # Add noise or quantize
        means = self._get_medians()
        # outputs = self.quantize(
        #     x, "noise" if training else "dequantize", means.long()
        # )
        # NOTE(review): the noise path above is intentionally disabled — both
        # train and eval use hard "dequantize". means.long() truncates the
        # medians to integers before offsetting; presumably deliberate since
        # the inputs are integer token indices — TODO confirm.
        outputs = self.quantize(
            x, "dequantize", means.long()
        )

        if not torch.jit.is_scripting():
            likelihood = self._likelihood(outputs)
            if self.use_likelihood_bound:
                likelihood = self.likelihood_lower_bound(likelihood)
        else:
            raise NotImplementedError("TorchScript is not yet supported")

        return outputs, likelihood

    def compress(self, x):
        """Entropy-encode `x` into byte strings using the learned CDFs."""
        # Build indexes suited to single-channel sequence data.
        indexes = self._build_indexes(x.size())
        # Medians (single-channel) act as quantization offsets.
        medians = self._get_medians().detach()
        # Broadcast medians to x's shape.
        medians = medians.expand_as(x)
        # Delegate the actual range coding to the base class.
        return super().compress(x, indexes, medians)

    def decompress(self, strings, size):
        """Decode byte strings back into a ``(len(strings), 1, *size)`` tensor."""
        # Expected output size includes the single channel; `size` is the seq_len.
        output_size = (len(strings), 1, *size)
        # Build indexes matching the output layout.
        indexes = self._build_indexes(output_size).to(self._quantized_cdf.device)
        # Expand the medians to the expected output shape.
        medians = self._extend_ndims(self._get_medians().detach(), len(size))
        medians = medians.expand(len(strings), 1, *([-1] * len(size)))
        # Delegate decoding to the base class.
        return super().decompress(strings, indexes, medians.dtype, medians)

    def _preprocess(self, x):
        # Channels-last reordering: (b, c, h, w) -> (b, h, w, c).
        # NOTE(review): not referenced by the visible methods — verify callers.
        x = x.permute(0, 2, 3, 1).contiguous()
        return x
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class Attention(nn.Module):
    """Multi-head self-attention that also returns the attention map."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE: scale can be overridden via qk_scale for weight compatibility.
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        bsz, n_tok, dim = x.shape
        head_dim = dim // self.num_heads
        # Project to q/k/v: (3, B, heads, N, head_dim), then split.
        qkv = self.qkv(x).reshape(bsz, n_tok, 3, self.num_heads, head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        # Attention scores in fp32 for numerical stability under AMP.
        with torch.cuda.amp.autocast(enabled=False):
            scores = torch.matmul(q.float(), k.float().transpose(-2, -1)) * self.scale

        # Subtract the row max before softmax (stability).
        scores = scores - scores.amax(dim=-1, keepdim=True)
        attn = self.attn_drop(scores.softmax(dim=-1))

        out = torch.matmul(attn, v).transpose(1, 2).reshape(bsz, n_tok, dim)
        out = self.proj_drop(self.proj(out))
        # Return both the attended features and the attention weights.
        return out, attn
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
class Block(nn.Module):
    """Pre-norm transformer block: self-attention + MLP, each residual,
    with optional stochastic depth (DropPath)."""

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias,
                              qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # Stochastic depth; identity when drop_path == 0.
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, return_attention=False):
        attn_out, attn_map = self.attn(self.norm1(x))
        if return_attention:
            # Expose the raw attention weights (e.g. for visualisation).
            return attn_map
        x = x + self.drop_path(attn_out)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class LabelSmoothingCrossEntropy(nn.Module):
    """Per-sample NLL loss with uniform label smoothing (no reduction)."""

    def __init__(self, smoothing=0.1):
        super().__init__()
        assert smoothing < 1.0
        self.smoothing = smoothing
        self.confidence = 1. - smoothing

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        log_p = torch.nn.functional.log_softmax(x, dim=-1)
        # Log-probability assigned to the target class of each sample.
        target_logp = log_p.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
        # Mean log-probability = smoothing toward the uniform distribution.
        uniform_logp = log_p.mean(dim=-1)
        # Convex mix of the hard NLL and the uniform-smoothing term.
        return -(self.confidence * target_logp + self.smoothing * uniform_logp)
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
class BertEmbeddings(nn.Module):
    """Token-id -> embedding lookup followed by LayerNorm and dropout.

    Positional embeddings are intentionally disabled in this variant;
    `max_position_embeddings` is kept only for interface compatibility.
    """

    def __init__(self, vocab_size, hidden_size, max_position_embeddings, dropout=0.1):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        # `LayerNorm` keeps the TensorFlow-style capitalization so BERT
        # checkpoints can be loaded by variable name.
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-6)
        self.dropout = nn.Dropout(dropout)
        torch.nn.init.normal_(self.word_embeddings.weight, std=.02)

    def forward(self, input_ids):
        # input_ids: (B, N) integer token ids.
        embeddings = self.word_embeddings(input_ids)  # (B, N, hidden_size)
        embeddings = self.LayerNorm(embeddings)
        return self.dropout(embeddings)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
class MlmLayer(nn.Module):
    """MLM head: project decoder features into the word-embedding space and
    score them against the (tied) embedding table, plus a learned bias.

    Returns unnormalized logits of shape (B, seq_len, vocab_size).
    """

    def __init__(self, feat_emb_dim, word_emb_dim, vocab_size):
        super().__init__()
        self.fc = nn.Linear(feat_emb_dim, word_emb_dim)
        self.gelu = nn.GELU()
        self.ln = nn.LayerNorm(word_emb_dim)
        self.bias = nn.Parameter(torch.zeros(1, 1, vocab_size))

    def forward(self, x, word_embeddings):
        # x: (B, seq_len, feat_emb_dim); word_embeddings: (vocab, word_emb_dim)
        hidden = self.ln(self.gelu(self.fc(x)))
        logits = torch.matmul(hidden, word_embeddings.transpose(0, 1))
        return logits + self.bias
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
class MaskedGenerativeEncoderViT(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone

    MAGE-style codec: a frozen VQGAN tokenizer, a ViT encoder over the kept
    tokens, a ViT decoder predicting the masked tokens, and a factorized
    entropy bottleneck for rate estimation. Positional embeddings are
    disabled throughout this variant.
    """
    def __init__(self, img_size=256, patch_size=16, in_chans=3, # need to change the default value of img_size
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False,
                 mask_ratio_min=0.5, mask_ratio_max=0.8, mask_ratio_mu=0.55, mask_ratio_std=0.25,
                 vqgan_ckpt_path='vqgan_jax_strongaug.ckpt'):
        super().__init__()

        # --------------------------------------------------------------------------
        # VQGAN specifics — frozen image tokenizer loaded from a local config.
        config = OmegaConf.load('config/vqgan.yaml').model
        self.vqgan = VQModel(ddconfig=config.params.ddconfig,
                             n_embed=config.params.n_embed, # 1024
                             embed_dim=config.params.embed_dim, # 256
                             ckpt_path=vqgan_ckpt_path)
        for param in self.vqgan.parameters():
            param.requires_grad = False
        self.codebook_size = config.params.n_embed # 1024
        vocab_size = self.codebook_size + 1000 + 1  # 1024 codebook size, 1000 classes, 1 for mask token.
        # fake_class_label = 1100: a fixed pseudo-class id placed at the [CLS]
        # position (this unconditional variant does not use real class labels).
        self.fake_class_label = self.codebook_size + 1100 - 1024  # 1100
        self.mask_token_label = vocab_size - 1  # 2024
        # Vocabulary = 1024 codebook entries + 1000 class ids + 1 mask token.
        self.token_emb = BertEmbeddings(vocab_size=vocab_size,
                                        hidden_size=embed_dim,
                                        max_position_embeddings=256 +1,
                                        # max_position_embeddings=256+1, # 256 patches + 1 class token
                                        dropout=0.1)

        # MAGE variant masking ratio — uniform in [min, max]; the original
        # truncated-normal sampler is kept below for reference.
        self.mask_ratio_min = mask_ratio_min
        self.mask_ratio_max = mask_ratio_max
        # self.mask_ratio_generator = stats.truncnorm((mask_ratio_min - mask_ratio_mu) / mask_ratio_std,
        #                                             (mask_ratio_max - mask_ratio_mu) / mask_ratio_std,
        #                                             loc=mask_ratio_mu, scale=mask_ratio_std)

        # --------------------------------------------------------------------------
        # MAGE encoder specifics
        dropout_rate = 0.1
        # NOTE(review): patch_embed is initialized but the visible forward path
        # feeds VQ token embeddings, not patch embeddings — confirm it is used.
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim) # (B, N, C), N = (256/16)^2 = 256
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding

        self.blocks = nn.ModuleList([ # encoder
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer,
                  drop=dropout_rate, attn_drop=dropout_rate)
            for i in range(depth)]) # depth=12 for mage-vitb, embed_dim=768
        self.norm = norm_layer(embed_dim) # layer norm
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # MAGE decoder specifics
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) # decoder_embed_dim=512
        # Pad masked positions with the per-sample [CLS] embedding rather than
        # the learned mask token (see forward_decoding).
        self.pad_with_cls_token = True

        # self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding
        # self.decoder_pos_embed_learned = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim)) # learnable pos embedding

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer,
                  drop=dropout_rate, attn_drop=dropout_rate)
            for i in range(decoder_depth)]) # decoder_depth=8 for mage-vitb

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # MlmLayer — scores decoder features against the tied token embeddings.
        self.mlm_layer = MlmLayer(feat_emb_dim=decoder_embed_dim, word_emb_dim=embed_dim, vocab_size=vocab_size)

        self.norm_pix_loss = norm_pix_loss

        self.criterion = LabelSmoothingCrossEntropy(smoothing=0.1)
        # --------------------------------------------------------------------------
        # Single-channel factorized bottleneck for the kept token indices.
        self.entropy_bottleneck = FactorizedEntropyModel(1)

        self.initialize_weights()
|
| 355 |
+
|
| 356 |
+
def initialize_weights(self):
    """Initialize learnable parameters: patch projection (xavier-uniform),
    special tokens (normal, std=.02), then every Linear/LayerNorm via
    `_init_weights`. Positional tables are disabled in this variant, so no
    sin-cos embeddings are built here.
    """
    # Treat the patch-embedding conv kernel as one flattened linear layer.
    proj_w = self.patch_embed.proj.weight.data
    torch.nn.init.xavier_uniform_(proj_w.view([proj_w.shape[0], -1]))

    # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) since
    # the cutoff (2.) is far in the tail.
    for special in (self.cls_token, self.mask_token):
        torch.nn.init.normal_(special, std=.02)

    # Recursively initialize nn.Linear and nn.LayerNorm modules.
    self.apply(self._init_weights)
|
| 376 |
+
|
| 377 |
+
def _init_weights(self, m):
|
| 378 |
+
if isinstance(m, nn.Linear):
|
| 379 |
+
# we use xavier_uniform following official JAX ViT:
|
| 380 |
+
torch.nn.init.xavier_uniform_(m.weight)
|
| 381 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 382 |
+
nn.init.constant_(m.bias, 0)
|
| 383 |
+
elif isinstance(m, nn.LayerNorm):
|
| 384 |
+
nn.init.constant_(m.bias, 0)
|
| 385 |
+
nn.init.constant_(m.weight, 1.0)
|
| 386 |
+
|
| 387 |
+
def random_sample_mask_rate(self):
|
| 388 |
+
# 生成一个 (0, 1] 范围内的随机数
|
| 389 |
+
random_sample = 1 - torch.rand(1)
|
| 390 |
+
# 映射到 mask_ratio_min 到 mask_ratio_max 的范围
|
| 391 |
+
mask_rate = self.mask_ratio_min + random_sample * (self.mask_ratio_max - self.mask_ratio_min)
|
| 392 |
+
return mask_rate.item() # 转换为Python的标量值
|
| 393 |
+
|
| 394 |
+
def get_cdf_token_mask(self, token_all_mask):
    """Build a per-position CDF table for entropy-coding the mask symbols.

    token_all_mask: (B, S) mask tensor; only its shape is used.
    Returns a (B, S, S+1) float tensor of non-decreasing CDF values starting
    at 0, as expected by ``torchac.encode_float_cdf`` (see forward()).
    """
    bsz, seq_len = token_all_mask.size()
    # Normal(0, 2) CDF over positions 1..seq_len, rescaled from [0.5, 1) to
    # [0, 1).
    dist_normal = torch.distributions.Normal(0, 2)
    cdf_row = (dist_normal.cdf(torch.arange(1, seq_len + 1)) - .5) * 2
    # Share the same row across all (batch, position) pairs. expand() is a
    # broadcast view, so this avoids materialising bsz*seq_len copies (the
    # original einops.repeat copied) and removes the einops dependency here.
    cdf_mask_token = cdf_row.unsqueeze(0).unsqueeze(0).expand(bsz, seq_len, -1)
    # Prepend a zero column so the table starts at probability mass 0.
    return F.pad(cdf_mask_token, (1, 0))
|
| 405 |
+
|
| 406 |
+
# def get_cdf_token_mask(self, token_all_mask):
|
| 407 |
+
# bsz, seq_len = token_all_mask.size()
|
| 408 |
+
# # 直接生成一个0到1之间的线性空间
|
| 409 |
+
# linear_space = torch.linspace(0, 1, steps=seq_len+1)
|
| 410 |
+
# # 无需映射到-1到1
|
| 411 |
+
# cdf_mask_token = linear_space
|
| 412 |
+
# # 调整形状以匹配token_all_mask,并扩展到每个batch
|
| 413 |
+
# cdf_mask_token = repeat(cdf_mask_token, 'Lp -> b s Lp',
|
| 414 |
+
# b=bsz, s=seq_len)
|
| 415 |
+
# # cdf_mask_token = cdf_mask_token.unsqueeze(0).unsqueeze(-1).repeat(bsz, 1, seq_len)
|
| 416 |
+
# # 添加填充以匹配原始代码的操作
|
| 417 |
+
# cdf_mask_token = F.pad(cdf_mask_token, (1, 0))
|
| 418 |
+
# return cdf_mask_token
|
| 419 |
+
|
| 420 |
+
def pre_encoding(self, x, is_training=False, manual_mask_rate=None):
    """Tokenize an image with the frozen VQGAN and build drop/mask patterns.

    input: x: (B, 3, H, W)
    Returns (gt_indices, token_indices, unmaksed_token_indices,
    token_all_mask, token_drop_mask, mask_ratio, z_H, z_W); the masks are
    float 0/1 tensors over (B, h*w) and every dropped token is also masked.
    """
    # ============ 1. tokenization ============ #
    with torch.no_grad():
        z_q, _, token_tuple = self.vqgan.encode(x)  # z_q: (B, C, h, w)

    _, _, z_H, z_W = z_q.size()
    _, _, token_indices = token_tuple  # flat indices: (B*h*w,)
    token_indices = token_indices.reshape(z_q.size(0), -1)  # (B, h*w)
    gt_indices = token_indices.clone().detach().long()

    # ============ 2. masking process ============ #
    bsz, seq_len = token_indices.size()  # seq_len = h*w
    mask_ratio_min = self.mask_ratio_min  # 0.5

    if is_training:
        # mask_rate = self.mask_ratio_generator.rvs(1)[0]
        # Training: random mask rate; additionally DROP mask_ratio_min of the
        # tokens so the encoder never sees them.
        mask_rate = self.random_sample_mask_rate()
        num_dropped_tokens = int(np.ceil(seq_len * mask_ratio_min))
    else:
        # Inference: nothing is dropped; the mask rate must be given.
        num_dropped_tokens = 0
        if manual_mask_rate is not None:
            mask_rate = manual_mask_rate
        else:
            raise ValueError("mask_rate should be provided for inference!")

    num_masked_tokens = int(np.ceil(seq_len * mask_rate))
    mask_ratio = num_masked_tokens / seq_len  # for calculate vbr lambda
    # it is possible that two elements of the noise is the same, so do a while loop to avoid it
    while True:
        noise = torch.rand(bsz, seq_len, device=x.device)  # noise in [0, 1]
        sorted_noise, _ = torch.sort(noise, dim=1)  # ascend: small is remove, large is keep
        if num_dropped_tokens > 0:
            cutoff_drop = sorted_noise[:, num_dropped_tokens-1:num_dropped_tokens]
        else:
            cutoff_drop = torch.zeros((bsz, 1), device=x.device)
        cutoff_mask = sorted_noise[:, num_masked_tokens-1:num_masked_tokens]
        # Dropped tokens are the num_dropped_tokens smallest noise values;
        # masked tokens the num_masked_tokens smallest, so drop ⊆ mask.
        token_drop_mask = (noise <= cutoff_drop).float()  # 1 = token is dropped
        token_all_mask = (noise <= cutoff_mask).float()  # 1 = token is masked
        if token_drop_mask.sum() == bsz*num_dropped_tokens and token_all_mask.sum() == bsz*num_masked_tokens:
            break
        else:
            print("Rerandom the noise!")

    # Gather the unmasked tokens (these are the ones sent to the bottleneck).
    unmasked_pos = token_all_mask == 0  # positions that were NOT masked
    unmaksed_token_indices = token_indices[unmasked_pos].reshape(bsz, -1)  # kept tokens, (B, K)

    return gt_indices, token_indices, unmaksed_token_indices, token_all_mask, token_drop_mask, mask_ratio, z_H, z_W
|
| 471 |
+
|
| 472 |
+
def pre_decoding(self, gt_indices, unmaksed_token_indices, token_all_mask, token_drop_mask):
|
| 473 |
+
bsz, seq_len = gt_indices.size()
|
| 474 |
+
padded_token_indices = torch.full_like(gt_indices, fill_value=self.mask_token_label)
|
| 475 |
+
# 将未被mask的token填充回去
|
| 476 |
+
# 我们需要一个计数器来追踪每个batch中已经填充了多少个unmaksed_token_indices
|
| 477 |
+
unmasked_token_counter = [0 for _ in range(bsz)]
|
| 478 |
+
|
| 479 |
+
for b in range(bsz):
|
| 480 |
+
for idx in range(seq_len):
|
| 481 |
+
# 如果当前位置未被mask,则从unmaksed_token_indices填充;否则,保留mask_token_label
|
| 482 |
+
if (token_all_mask[b, idx] == 0): # 检查是否未被mask
|
| 483 |
+
# 替换相应的unmaksed token
|
| 484 |
+
padded_token_indices[b, idx] = unmaksed_token_indices[b, unmasked_token_counter[b]]
|
| 485 |
+
# 更新计数器
|
| 486 |
+
unmasked_token_counter[b] += 1
|
| 487 |
+
|
| 488 |
+
token_indices = padded_token_indices
|
| 489 |
+
# ============ 3. Adding class token ============ #
|
| 490 |
+
# concate class token, add [CLS] token to aggregate sequence-level representations
|
| 491 |
+
token_indices = torch.cat([torch.zeros(token_indices.size(0), 1).cuda(device=token_indices.device), token_indices], dim=1)
|
| 492 |
+
token_indices[:, 0] = self.fake_class_label # [B, 257]
|
| 493 |
+
# Masks (token_drop_mask and token_all_mask) are updated to account for the added class token,
|
| 494 |
+
# ensuring the first position is always kept by setting it to 0 (indicating "do not mask/drop")
|
| 495 |
+
# 添加0向量,和token_indices,表示[CLS] token不会被mask/drop
|
| 496 |
+
token_drop_mask = torch.cat([torch.zeros(token_indices.size(0), 1).cuda(), token_drop_mask], dim=1)
|
| 497 |
+
token_all_mask = torch.cat([torch.zeros(token_indices.size(0), 1).cuda(), token_all_mask], dim=1)
|
| 498 |
+
token_indices = token_indices.long()
|
| 499 |
+
|
| 500 |
+
# ============ 4. Embedding and Dropout ============ #
|
| 501 |
+
# bert embedding
|
| 502 |
+
input_embeddings = self.token_emb(token_indices) # get embeddings [B, 257, 768]
|
| 503 |
+
# print("Input embedding shape:", input_embeddings.shape)
|
| 504 |
+
bsz, seq_len, emb_dim = input_embeddings.shape
|
| 505 |
+
|
| 506 |
+
# dropping
|
| 507 |
+
token_keep_mask = 1 - token_drop_mask
|
| 508 |
+
input_embeddings_after_drop = input_embeddings[token_keep_mask.nonzero(as_tuple=True)].reshape(bsz, -1, emb_dim)
|
| 509 |
+
# print("Input embedding after drop shape:", input_embeddings_after_drop.shape)
|
| 510 |
+
|
| 511 |
+
# ============ 5. Transformer encoding ============ #
|
| 512 |
+
x = input_embeddings_after_drop # (B, seq_len_after_drop, embed_dim) # 32, 129, 768
|
| 513 |
+
for blk in self.blocks:
|
| 514 |
+
x = blk(x) # each block has a multi-head self-attention and a mlp
|
| 515 |
+
x = self.norm(x)
|
| 516 |
+
# print("Encoder representation shape:", x.shape)
|
| 517 |
+
|
| 518 |
+
return x, token_indices, token_all_mask, token_drop_mask
|
| 519 |
+
|
| 520 |
+
def forward_decoding(self, x, token_drop_mask, token_all_mask):
|
| 521 |
+
"""
|
| 522 |
+
x: output x of forward_encoder()
|
| 523 |
+
token_drop_mask: positions for dropped tokens
|
| 524 |
+
token_all_mask: positions for masked tokens
|
| 525 |
+
"""
|
| 526 |
+
# ============ 1. Prepare Embedding and padding tokens ============ #
|
| 527 |
+
# embed tokens
|
| 528 |
+
x = self.decoder_embed(x) # input_embedding_after_padding
|
| 529 |
+
|
| 530 |
+
# append mask tokens to sequence
|
| 531 |
+
# replicates the [CLS] token embedding across the sequence length where masking is to be applied
|
| 532 |
+
if self.pad_with_cls_token: # True
|
| 533 |
+
mask_tokens = x[:, 0:1].repeat(1, token_all_mask.shape[1], 1)
|
| 534 |
+
else:
|
| 535 |
+
mask_tokens = self.mask_token.repeat(token_all_mask.shape[0], token_all_mask.shape[1], 1)
|
| 536 |
+
|
| 537 |
+
# ============ 2. Prepare positional embedding ============ #
|
| 538 |
+
# put undropped tokens into original sequence
|
| 539 |
+
x_after_pad = mask_tokens.clone() # 未被drop的tokens被填充回去
|
| 540 |
+
x_after_pad[(1 - token_drop_mask).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
|
| 541 |
+
# set undropped but masked positions with mask
|
| 542 |
+
x_after_pad = torch.where(token_all_mask.unsqueeze(-1).bool(), mask_tokens, x_after_pad) # 被drop的也padding
|
| 543 |
+
|
| 544 |
+
# add pos embed
|
| 545 |
+
# x = x_after_pad + self.decoder_pos_embed_learned # add learnable pos embedding
|
| 546 |
+
x = x_after_pad
|
| 547 |
+
|
| 548 |
+
# apply Transformer blocks
|
| 549 |
+
for blk in self.decoder_blocks:
|
| 550 |
+
x = blk(x)
|
| 551 |
+
|
| 552 |
+
x = self.decoder_norm(x)
|
| 553 |
+
|
| 554 |
+
word_embeddings = self.token_emb.word_embeddings.weight.data.detach()
|
| 555 |
+
logits = self.mlm_layer(x, word_embeddings) # produce predictions for masked tokens
|
| 556 |
+
# print("Logits shape:", x.shape)
|
| 557 |
+
|
| 558 |
+
return logits
|
| 559 |
+
|
| 560 |
+
def forward_loss(self, gt_indices, logits, mask):
|
| 561 |
+
bsz, seq_len = gt_indices.size()
|
| 562 |
+
# logits and mask are with seq_len+1 but gt_indices is with seq_len
|
| 563 |
+
loss = self.criterion(logits[:, 1:, :self.codebook_size].reshape(bsz*seq_len, -1), gt_indices.reshape(bsz*seq_len))
|
| 564 |
+
loss = loss.reshape(bsz, seq_len)
|
| 565 |
+
loss = (loss * mask[:, 1:]).sum() / mask[:, 1:].sum() # mean loss on removed patches
|
| 566 |
+
return loss
|
| 567 |
+
|
| 568 |
+
def cal_lmbda(self, mask_ratio, A=5e-1, B=8):
    """Rate-distortion weight A * exp(B * (1 - mask_ratio)).

    Grows exponentially as the mask ratio shrinks (more tokens transmitted).
    """
    visible_fraction = 1 - mask_ratio
    return A * torch.exp(B * visible_fraction)
|
| 571 |
+
|
| 572 |
+
def cal_loss(self, logits, gt_indices, mask, mask_ratio):
|
| 573 |
+
mask_ratio = torch.tensor(mask_ratio)
|
| 574 |
+
## cal cross entropy loss
|
| 575 |
+
task_loss = self.forward_loss(gt_indices, logits, mask)
|
| 576 |
+
lmbda = self.cal_lmbda(mask_ratio)
|
| 577 |
+
## cal total loss for codec optimization
|
| 578 |
+
return task_loss, lmbda
|
| 579 |
+
|
| 580 |
+
def forward(self, imgs, is_training=False, manual_mask_rate=None):
    """Full codec pass: tokenize + mask, entropy-code the mask pattern with
    torchac, pass the kept token indices through the factorized bottleneck,
    decode, and compute the rate-weighted task loss.

    imgs: (B, 3, H, W). Returns a dict with logits, likelihoods, losses and
    bookkeeping tensors (mask bitstream, mask visualisation, lambda, ...).
    """
    ## ---------- encoding process ---------- ##
    gt_indices, token_indices, latent, token_all_mask, token_drop_mask, mask_ratio, z_H, z_W = self.pre_encoding(imgs, is_training, manual_mask_rate)
    latent = latent.unsqueeze(1)

    # Rate estimate for the kept tokens. NOTE(review): `latent` holds raw
    # integer token indices, not continuous features — presumably intentional
    # for this index-level bottleneck; confirm.
    latent_hat, latent_likelihoods = self.entropy_bottleneck(latent)
    # check whether latent_hat equals latent
    # print((latent_hat == latent).all())
    # Arithmetic-code the binary mask pattern (symbols shifted to {1, 2};
    # torchac runs on CPU tensors).
    cdf_mask_token = self.get_cdf_token_mask(token_all_mask).cpu()
    sym = (token_all_mask.short() + 1).cpu()
    bs_mask_token = torchac.encode_float_cdf(cdf_mask_token, sym, check_input_bounds=True)
    mask_vis = rearrange(token_all_mask, 'b (h w) -> b h w', h=z_H, w=z_W).unsqueeze(1)

    ## ---------- decoding process ---------- ##
    # Round-trip the mask through torchac to emulate the real bitstream.
    decoded_sym = torchac.decode_float_cdf(cdf_mask_token, bs_mask_token)
    decoded_mask = (decoded_sym - 1).to(device=imgs.device)
    latent_hat = latent_hat.squeeze(1)
    x, token_indices, token_all_mask, token_drop_mask = self.pre_decoding(gt_indices, latent_hat, decoded_mask, token_drop_mask)
    logits = self.forward_decoding(x, token_drop_mask, token_all_mask)
    ## calculate loss
    task_loss, lmbda = self.cal_loss(logits, gt_indices, token_all_mask, mask_ratio)
    return_dict = {
        'logits': logits,
        'likelihoods': latent_likelihoods,
        'task_loss': task_loss,
        'token_indices': token_indices,
        'token_all_mask': token_all_mask,
        'bs_mask_token': bs_mask_token,
        'mask_ratio': mask_ratio,
        'lambda': lmbda,
        'mask_vis': 1 - mask_vis,  # 1 = kept token, 0 = masked, for display
        'z_H': z_H,
        'z_W': z_W,
    }
    return return_dict
|
| 615 |
+
|
| 616 |
+
# def update(self, scale_table=None, force=False):
|
| 617 |
+
# if scale_table is None:
|
| 618 |
+
# scale_table = get_scale_table()
|
| 619 |
+
# updated = self.gaussian_conditional.update_scale_table(scale_table, force=force)
|
| 620 |
+
# updated |= super().update(force=force)
|
| 621 |
+
# return updated
|
| 622 |
+
|
| 623 |
+
def gen_img(self, logits, token_all_mask, token_indices, z_H, z_W, num_iter=12, choice_temperature=4.5):
    """
    Iteratively generate images at inference time (MaskGIT-style sampling).

    logits: decoder logits from the first prediction pass (cls position still included)
    token_all_mask: mask indicator over the token sequence (nonzero = masked)
    token_indices: token indices of the input image after the vq tokenizer
    z_H, z_W: height/width of the VQGAN latent grid
    num_iter: number of refinement iterations
    choice_temperature: Gumbel temperature for choosing which tokens to re-mask

    Returns images decoded by the frozen VQGAN decoder.

    NOTE(review): several .cuda() calls below hard-code a CUDA device; this
    path will fail on CPU-only machines — confirm before reuse.
    """
    bsz = logits.size(0)
    codebook_emb_dim = 256
    codebook_size = 1024
    mask_token_id = self.mask_token_label
    _CONFIDENCE_OF_KNOWN_TOKENS = +np.inf
    # Number of masked tokens at the start; drives the cosine masking schedule.
    unknown_number_in_the_beginning = torch.sum(token_all_mask, dim=-1, keepdims=True).float()
    for step in range(num_iter):
        if step == 0:
            # First iteration: reuse the logits handed in by the caller.
            cur_ids = token_indices.clone().long()  # token_indices represent the current state of the sequence (unmasked tokens)
            cur_ids = cur_ids[:, 1:]  # drop the leading cls/fake-class position
            logits = logits[:, 1:, :codebook_size]  # drop cls position and ids beyond the codebook
            # Authors' note: "For iter=1 they use argmax and temp=0.0;
            # for iter=6 we use categorical sampling and temp=4.5."
            sample_dist = torch.distributions.categorical.Categorical(logits=logits)
            sampled_ids = sample_dist.sample()  # alternative: sampled_ids = torch.argmax(logits, dim=-1)
            # unknown_map (bool, same shape as cur_ids/sampled_ids) marks
            # still-masked positions: keep known tokens, accept fresh samples
            # only where the token was unknown.
            unknown_map = (cur_ids == mask_token_id)
            sampled_ids = torch.where(unknown_map, sampled_ids, cur_ids)
            # Mask ratio for the next round; the number to re-mask is
            # mask_ratio * unknown_number_in_the_beginning.
            ratio = 1. * (step + 1) / num_iter
            mask_ratio = np.cos(math.pi / 2. * ratio)  # cosine(pi/2 * i/num_iter)

            # Confidence of each sampled token = its predicted probability;
            # known tokens get +inf so they are never re-masked.
            probs = torch.nn.functional.softmax(logits, dim=-1)
            selected_probs = torch.squeeze(
                torch.gather(probs, dim=-1, index=torch.unsqueeze(sampled_ids, -1)), -1)

            selected_probs = torch.where(unknown_map, selected_probs.double(), _CONFIDENCE_OF_KNOWN_TOKENS).float()
            unknown_number_in_the_beginning = unknown_number_in_the_beginning.clone().detach().cuda()
            mask_ratio = torch.tensor(mask_ratio).cuda()
            mask_len = torch.floor(unknown_number_in_the_beginning * mask_ratio).long()  # tokens left masked after this iter
            # Keep at least one prediction this round and leave at least one
            # token masked for the next iteration.
            mask_len = torch.maximum(torch.Tensor([1]).cuda(),
                                     torch.minimum(torch.sum(unknown_map, dim=-1, keepdims=True) - 1, mask_len))

            # Sample which positions stay masked for the next iteration;
            # tokens with lower confidence are re-masked.
            masking = mask_by_random_topk(mask_len[0], selected_probs, choice_temperature * (1 - ratio))
            token_indices = torch.where(masking, mask_token_id, sampled_ids)
        else:
            # Subsequent iterations: re-encode the partially filled sequence
            # and predict the remaining masked tokens again.
            cur_ids = token_indices.clone().long()  # .long(): to int64
            # Re-attach the fake-class position at index 0.
            token_indices = torch.cat(
                [torch.zeros(token_indices.size(0), 1).cuda(device=token_indices.device), token_indices], dim=1)
            token_indices[:, 0] = self.fake_class_label
            token_indices = token_indices.long()
            token_all_mask = token_indices == mask_token_id

            token_drop_mask = torch.zeros_like(token_indices)

            # token embedding
            input_embeddings = self.token_emb(token_indices)  # get input embeddings

            # encoder
            x = input_embeddings
            for blk in self.blocks:
                x = blk(x)
            x = self.norm(x)

            # decoder
            logits = self.forward_decoding(x, token_drop_mask, token_all_mask)
            logits = logits[:, 1:, :codebook_size]  # remove the cls token and dims > codebook_size

            # Same sampling trick as in the first iteration.
            sample_dist = torch.distributions.categorical.Categorical(logits=logits)
            sampled_ids = sample_dist.sample()  # alternative: sampled_ids = torch.argmax(logits, dim=-1)

            # Merge fresh samples (at unknown positions) with tokens already
            # fixed in earlier iterations (at known positions).
            unknown_map = (cur_ids == mask_token_id)
            sampled_ids = torch.where(unknown_map, sampled_ids, cur_ids)
            # Cosine schedule for the next round's mask count.
            ratio = 1. * (step + 1) / num_iter

            mask_ratio = np.cos(math.pi / 2. * ratio)  # cosine(pi/2 * i/num_iter)

            # Re-mask the lowest-confidence predictions.
            probs = torch.nn.functional.softmax(logits, dim=-1)
            selected_probs = torch.squeeze(
                torch.gather(probs, dim=-1, index=torch.unsqueeze(sampled_ids, -1)), -1)

            selected_probs = torch.where(unknown_map, selected_probs.double(), _CONFIDENCE_OF_KNOWN_TOKENS).float()
            unknown_number_in_the_beginning = unknown_number_in_the_beginning.clone().detach().cuda()
            mask_ratio = torch.tensor(mask_ratio).cuda()
            mask_len = torch.floor(unknown_number_in_the_beginning * mask_ratio).long()  # tokens left masked after this iter
            # Keep at least one prediction this round and leave at least one
            # token masked for the next iteration.
            mask_len = torch.maximum(torch.Tensor([1]).cuda(),
                                     torch.minimum(torch.sum(unknown_map, dim=-1, keepdims=True) - 1, mask_len))

            # Sample masking tokens for next iteration
            masking = mask_by_random_topk(mask_len[0], selected_probs, choice_temperature * (1 - ratio))
            # Masks tokens with lower confidence.
            token_indices = torch.where(masking, mask_token_id, sampled_ids)

    # vqgan visualization: decode the final token grid back to pixels.
    z_q = self.vqgan.quantize.get_codebook_entry(sampled_ids, shape=(bsz, z_H, z_W, codebook_emb_dim))
    gen_images = self.vqgan.decode(z_q)
    return gen_images
|
| 744 |
+
|
| 745 |
+
|
| 746 |
+
def mage_vit_base_patch16(**kwargs):
    """Build the ViT-Base MAGE codec (12-layer/768-dim encoder, 8-layer decoder)."""
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return MaskedGenerativeEncoderViT(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        decoder_embed_dim=768,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=layer_norm,
        **kwargs,
    )
|
| 752 |
+
|
| 753 |
+
|
| 754 |
+
def mage_vit_large_patch16(**kwargs):
    """Build the ViT-Large MAGE codec (24-layer/1024-dim encoder, 8-layer decoder)."""
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return MaskedGenerativeEncoderViT(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        decoder_embed_dim=1024,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=layer_norm,
        **kwargs,
    )
|
deploited/models_mage_codec_rope.py
ADDED
|
@@ -0,0 +1,767 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import partial
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
from timm.models.vision_transformer import PatchEmbed, DropPath, Mlp
|
| 7 |
+
|
| 8 |
+
from util.pos_embed import get_2d_sincos_pos_embed
|
| 9 |
+
|
| 10 |
+
from taming.models.vqgan import VQModel
|
| 11 |
+
from omegaconf import OmegaConf
|
| 12 |
+
import numpy as np
|
| 13 |
+
import scipy.stats as stats
|
| 14 |
+
from compressai.entropy_models import EntropyBottleneck
|
| 15 |
+
from compressai.layers import conv3x3, subpel_conv3x3
|
| 16 |
+
import math
|
| 17 |
+
from torch import Tensor
|
| 18 |
+
from einops import rearrange, repeat
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
import torchac
|
| 21 |
+
from typing import Any, Callable, List, Optional, Tuple, Union
|
| 22 |
+
from util.RoPE import LlamaDynamicNTKScalingRotaryEmbedding, apply_rotary_pos_emb
|
| 23 |
+
|
| 24 |
+
SCALES_MIN = 0.11
SCALES_MAX = 256
SCALES_LEVELS = 64


def get_scale_table(min=SCALES_MIN, max=SCALES_MAX, levels=SCALES_LEVELS):
    """Return `levels` scales spaced log-uniformly between `min` and `max`."""
    log_lo, log_hi = math.log(min), math.log(max)
    return torch.linspace(log_lo, log_hi, levels).exp()
|
| 29 |
+
|
| 30 |
+
def ste_round(x: Tensor) -> Tensor:
    """Round `x` with a straight-through gradient (identity in the backward pass)."""
    # Forward value is round(x); the detached correction keeps d(out)/dx == 1.
    return x + (torch.round(x) - x).detach()
|
| 32 |
+
|
| 33 |
+
def conv(in_channels, out_channels, kernel_size=5, stride=2):
    """2-D convolution with half-kernel padding (spatial size halves at stride 2)."""
    pad = kernel_size // 2
    return nn.Conv2d(
        in_channels, out_channels,
        kernel_size=kernel_size, stride=stride, padding=pad,
    )
|
| 41 |
+
|
| 42 |
+
def mask_by_random_topk(mask_len, probs, temperature=1.0):
    """
    Select the `mask_len` lowest-confidence positions in each row.

    Confidence is log-prob perturbed by Gumbel noise scaled by `temperature`
    (temperature 0 makes the selection purely greedy). Returns a bool tensor,
    True where the position should be (re)masked.
    """
    mask_len = mask_len.squeeze()
    # Gumbel perturbation for stochastic selection. BUGFIX: keep the noise on
    # the same device/dtype as `probs` instead of hard-coding .cuda(), which
    # crashed on CPU-only runs.
    noise = torch.as_tensor(
        temperature * np.random.gumbel(size=probs.shape),
        dtype=probs.dtype, device=probs.device,
    )
    confidence = torch.log(probs) + noise
    sorted_confidence, _ = torch.sort(confidence, axis=-1)
    # Cut-off is the mask_len-th smallest confidence in each row.
    cut_off = sorted_confidence[:, mask_len.long() - 1:mask_len.long()]
    # Mask everything at or below the cut-off.
    masking = (confidence <= cut_off)
    return masking
|
| 52 |
+
|
| 53 |
+
def adjust_mask_and_drop_embeddings(token_keep_mask):
    """
    Trim `token_keep_mask` (in place) so its True-count is a perfect square.

    Randomly flips True entries to False until the number of kept tokens
    equals floor(sqrt(count))**2, so the kept tokens can be laid out on a
    square grid. The mask is mutated in place and also returned.

    Parameters:
    - token_keep_mask: 2-D BoolTensor; True marks tokens to keep.

    Returns:
    - The same (mutated) mask tensor.
    """
    # Indices of currently kept tokens.
    rows, cols = token_keep_mask.nonzero(as_tuple=True)
    kept = rows.size(0)
    # Largest perfect square not exceeding the current keep-count.
    target = math.floor(math.sqrt(kept)) ** 2
    surplus = kept - target
    if surplus > 0:
        # Drop a uniformly random subset of the kept positions; vectorized
        # indexed assignment replaces the old per-element Python loop.
        drop = torch.randperm(kept)[:surplus]
        token_keep_mask[rows[drop], cols[drop]] = False
    return token_keep_mask
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class FactorizedEntropyModel(EntropyBottleneck):
    """Fully-factorized entropy model for single-channel token sequences.

    Thin wrapper over compressai's EntropyBottleneck that adapts shapes for
    [b, c, seq_len] inputs. NOTE(review): behavior of quantize/_likelihood/
    _build_indexes follows the installed compressai version — confirm against
    its API when upgrading.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x: Tensor, training: Optional[bool] = None) -> Tuple[Tensor, Tensor]:
        """Quantize `x` and return (quantized outputs, likelihoods)."""
        if training is None:
            training = self.training

        # Input is already [b, c, seq_len]; no transpose needed.
        shape = x.size()

        # Add noise or quantize
        means = self._get_medians()
        # Noise-during-training variant kept for reference:
        # outputs = self.quantize(
        #     x, "noise" if training else "dequantize", means.long()
        # )
        # NOTE(review): `training` is computed but unused below — always
        # dequantizes; and .long() truncates medians to integers. Confirm both
        # are intentional.
        outputs = self.quantize(
            x, "dequantize", means.long()
        )

        if not torch.jit.is_scripting():
            likelihood = self._likelihood(outputs)
            if self.use_likelihood_bound:
                # Clamp likelihoods away from zero for numerical stability.
                likelihood = self.likelihood_lower_bound(likelihood)
        else:
            raise NotImplementedError("TorchScript is not yet supported")

        return outputs, likelihood

    def compress(self, x):
        """Entropy-encode `x` to byte strings via the base-class range coder."""
        # Build indexes suited to single-channel sequence data.
        indexes = self._build_indexes(x.size())
        # Medians, already single-channel here.
        medians = self._get_medians().detach()
        # Broadcast medians to match x's shape.
        medians = medians.expand_as(x)
        # Delegate the actual coding to the base class.
        return super().compress(x, indexes, medians)

    def decompress(self, strings, size):
        """Decode byte `strings` back to a tensor of shape (len(strings), 1, *size)."""
        # Expected output size includes the single channel; `size` carries the
        # remaining (sequence) dimensions.
        output_size = (len(strings), 1, *size)
        # Build indexes on the same device as the quantized CDF table.
        indexes = self._build_indexes(output_size).to(self._quantized_cdf.device)
        # Extend and broadcast medians to the expected output shape.
        medians = self._extend_ndims(self._get_medians().detach(), len(size))
        medians = medians.expand(len(strings), 1, *([-1] * len(size)))
        # Delegate decoding to the base class.
        return super().decompress(strings, indexes, medians.dtype, medians)

    def _preprocess(self, x):
        # NCHW -> NHWC layout change; contiguous() materializes the permute.
        x = x.permute(0, 2, 3, 1).contiguous()
        return x
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class Attention(nn.Module):
    """Multi-head self-attention with rotary position embeddings (RoPE).

    Returns both the attended features and the attention-weight matrix.
    """

    def __init__(self, dim, max_position_embeddings, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.max_position_embeddings = max_position_embeddings
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        # NTK-scaled rotary embedding replaces additive positional embeddings.
        self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
            head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=8
        )
        # Precomputed-position-ids buffer kept for reference:
        # self.register_buffer("position_ids", torch.arange(self.max_position_embeddings).expand((1, -1)))

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)  # (3, B, num_heads, N, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        cos, sin = self.rotary_emb(v, seq_len=N)
        # NOTE(review): position_ids is created on CPU; assumes
        # apply_rotary_pos_emb handles the device mismatch — confirm.
        position_ids = torch.arange(N).expand((1, -1))
        q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)

        # Compute attention scores in fp32 even under AMP, for stability.
        with torch.cuda.amp.autocast(enabled=False):
            attn = (q.float() @ k.float().transpose(-2, -1)) * self.scale

        # Subtract the per-row max before softmax for numerical stability.
        attn = attn - torch.max(attn, dim=-1, keepdim=True)[0]
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        # x: attended features; attn: attention-weight matrix describing the
        # pairwise affinities between input positions.
        return x, attn
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
class Block(nn.Module):
    """Transformer block: pre-norm RoPE self-attention and pre-norm MLP, both residual."""

    def __init__(self, dim, max_position_embeddings, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, max_position_embeddings, num_heads=num_heads, qkv_bias=qkv_bias,
            qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # Stochastic depth; identity when drop_path == 0.
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        hidden_features = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=hidden_features, act_layer=act_layer, drop=drop)

    def forward(self, x, return_attention=False):
        attended, attn_weights = self.attn(self.norm1(x))
        if return_attention:
            # Expose the raw attention map (e.g. for visualization).
            return attn_weights
        x = x + self.drop_path(attended)
        return x + self.drop_path(self.mlp(self.norm2(x)))
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
class LabelSmoothingCrossEntropy(nn.Module):
    """Per-sample NLL with label smoothing.

    loss = (1 - smoothing) * NLL(target) + smoothing * NLL(uniform).
    Returns an unreduced loss tensor, one value per sample.
    """

    def __init__(self, smoothing=0.1):
        super(LabelSmoothingCrossEntropy, self).__init__()
        assert smoothing < 1.0
        self.smoothing = smoothing
        self.confidence = 1. - smoothing

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        log_probs = torch.nn.functional.log_softmax(x, dim=-1)
        # NLL of the target class for each sample.
        target_nll = -log_probs.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
        # Mean NLL over all classes == NLL against the uniform distribution.
        uniform_nll = -log_probs.mean(dim=-1)
        return self.confidence * target_nll + self.smoothing * uniform_nll
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
class BertEmbeddings(nn.Module):
    """Token embedding layer (BERT-style).

    Positional embeddings exist as parameters but are intentionally NOT added
    in forward — this RoPE variant injects position inside attention instead.
    Output = Dropout(LayerNorm(word_embedding(ids))).
    """

    def __init__(self, vocab_size, hidden_size, max_position_embeddings, dropout=0.1):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)

        # Name kept capitalized for compatibility with TensorFlow checkpoints.
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-6)
        self.dropout = nn.Dropout(dropout)
        # (1, max_position_embeddings) index buffer, serialized with the model.
        self.register_buffer("position_ids", torch.arange(max_position_embeddings).expand((1, -1)))

        nn.init.normal_(self.word_embeddings.weight, std=.02)
        nn.init.normal_(self.position_embeddings.weight, std=.02)

    def forward(
        self, input_ids
    ):
        # input_ids: (B, N) -> (B, N, hidden_size)
        token_embeds = self.word_embeddings(input_ids)
        # No additive positional term here (see class docstring).
        normed = self.LayerNorm(token_embeds)
        return self.dropout(normed)
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
class MlmLayer(nn.Module):
    """MLM prediction head.

    Projects decoder features into the word-embedding space, then scores them
    against the (transposed) word-embedding matrix, yielding unnormalized
    logits over the vocabulary for every sequence position.
    """

    def __init__(self, feat_emb_dim, word_emb_dim, vocab_size):
        super().__init__()
        self.fc = nn.Linear(feat_emb_dim, word_emb_dim)
        self.gelu = nn.GELU()
        self.ln = nn.LayerNorm(word_emb_dim)
        self.bias = nn.Parameter(torch.zeros(1, 1, vocab_size))

    def forward(self, x, word_embeddings):
        # x: (b, seq_len, feat_emb_dim)
        hidden = self.ln(self.gelu(self.fc(x)))
        # (b, seq_len, word_emb_dim) @ (word_emb_dim, vocab_size)
        scores = torch.matmul(hidden, word_embeddings.transpose(0, 1))
        # (b, seq_len, vocab_size): unnormalized per-token vocabulary scores.
        return scores + self.bias
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
class MaskedGenerativeEncoderViT(nn.Module):
|
| 279 |
+
""" Masked Autoencoder with VisionTransformer backbone
|
| 280 |
+
"""
|
| 281 |
+
def __init__(self, img_size=256, patch_size=16, in_chans=3,  # need to change the default value of img_size
             embed_dim=1024, depth=24, num_heads=16,
             decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
             mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False,
             mask_ratio_min=0.5, mask_ratio_max=0.8, mask_ratio_mu=0.55, mask_ratio_std=0.25,
             vqgan_ckpt_path='vqgan_jax_strongaug.ckpt'):
    """Build the MAGE codec: frozen VQGAN tokenizer, RoPE ViT encoder/decoder,
    MLM head, label-smoothing loss and factorized entropy bottleneck.

    mask_ratio_mu / mask_ratio_std are kept for interface compatibility; the
    truncated-normal mask-ratio sampler that used them is currently disabled.
    """
    super().__init__()

    # --------------------------------------------------------------------------
    # VQGAN tokenizer (frozen)
    # TODO(review): absolute config path ties this to one machine; consider
    # making it a constructor argument.
    config = OmegaConf.load('/home/t2vg-a100-G4-10/project/qyp/mimc_rope/config/vqgan.yaml').model
    self.vqgan = VQModel(ddconfig=config.params.ddconfig,
                         n_embed=config.params.n_embed,      # 1024
                         embed_dim=config.params.embed_dim,  # 256
                         ckpt_path=vqgan_ckpt_path)
    for param in self.vqgan.parameters():
        param.requires_grad = False

    self.codebook_size = config.params.n_embed  # 1024
    vocab_size = self.codebook_size + 1000 + 1  # 1024 codebook entries, 1000 classes, 1 mask token
    self.fake_class_label = self.codebook_size + 1100 - 1024  # 1100
    self.mask_token_label = vocab_size - 1  # 2024
    self.max_position_embeddings = 256 + 1  # 256 patches + 1 class token
    self.token_emb = BertEmbeddings(vocab_size=vocab_size,
                                    hidden_size=embed_dim,
                                    max_position_embeddings=256 + 1,
                                    dropout=0.1)

    # MAGE variant masking-ratio range.
    self.mask_ratio_min = mask_ratio_min
    self.mask_ratio_max = mask_ratio_max

    # --------------------------------------------------------------------------
    # MAGE encoder
    dropout_rate = 0.1
    self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)  # (B, 256, embed_dim) for 256px/16px
    num_patches = self.patch_embed.num_patches  # 256

    self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
    self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding

    self.blocks = nn.ModuleList([
        Block(embed_dim, self.max_position_embeddings, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer,
              drop=dropout_rate, attn_drop=dropout_rate)
        for i in range(depth)])  # depth=12 for mage-vitb, embed_dim=768
    self.norm = norm_layer(embed_dim)

    # --------------------------------------------------------------------------
    # MAGE decoder
    self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

    self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
    self.pad_with_cls_token = True

    self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)  # fixed sin-cos embedding
    self.decoder_pos_embed_learned = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim))  # learnable pos embedding

    # BUGFIX: Block's signature in this RoPE variant is
    # (dim, max_position_embeddings, num_heads, mlp_ratio, ...). The previous
    # call used the pre-RoPE order, so decoder_num_heads was consumed as
    # max_position_embeddings and mlp_ratio as num_heads.
    self.decoder_blocks = nn.ModuleList([
        Block(decoder_embed_dim, self.max_position_embeddings, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None,
              norm_layer=norm_layer, drop=dropout_rate, attn_drop=dropout_rate)
        for i in range(decoder_depth)])  # decoder_depth=8 for mage-vitb

    self.decoder_norm = norm_layer(decoder_embed_dim)
    self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True)  # decoder to patch

    # --------------------------------------------------------------------------
    # MLM head, loss and entropy model
    self.mlm_layer = MlmLayer(feat_emb_dim=decoder_embed_dim, word_emb_dim=embed_dim, vocab_size=vocab_size)

    self.norm_pix_loss = norm_pix_loss

    self.criterion = LabelSmoothingCrossEntropy(smoothing=0.1)
    self.entropy_bottleneck = FactorizedEntropyModel(1)

    self.initialize_weights()
|
| 363 |
+
|
| 364 |
+
def initialize_weights(self):
    """Initialize all weights: fixed sin-cos positional tables, xavier patch
    projection, and normal(0.02) init for tokens/learned positions; then apply
    per-module init to every Linear/LayerNorm via self.apply."""
    # initialize (and freeze) pos_embed by sin-cos embedding
    pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
    self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

    decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
    self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

    # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
    w = self.patch_embed.proj.weight.data
    torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

    # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
    torch.nn.init.normal_(self.cls_token, std=.02)
    torch.nn.init.normal_(self.mask_token, std=.02)
    torch.nn.init.normal_(self.decoder_pos_embed_learned, std=.02)

    # initialize nn.Linear and nn.LayerNorm recursively
    self.apply(self._init_weights)
|
| 384 |
+
|
| 385 |
+
def _init_weights(self, m):
|
| 386 |
+
if isinstance(m, nn.Linear):
|
| 387 |
+
# we use xavier_uniform following official JAX ViT:
|
| 388 |
+
torch.nn.init.xavier_uniform_(m.weight)
|
| 389 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 390 |
+
nn.init.constant_(m.bias, 0)
|
| 391 |
+
elif isinstance(m, nn.LayerNorm):
|
| 392 |
+
nn.init.constant_(m.bias, 0)
|
| 393 |
+
nn.init.constant_(m.weight, 1.0)
|
| 394 |
+
|
| 395 |
+
def random_sample_mask_rate(self):
|
| 396 |
+
# 生成一个 (0, 1] 范围内的随机数
|
| 397 |
+
random_sample = 1 - torch.rand(1)
|
| 398 |
+
# 映射到 mask_ratio_min 到 mask_ratio_max 的范围
|
| 399 |
+
mask_rate = self.mask_ratio_min + random_sample * (self.mask_ratio_max - self.mask_ratio_min)
|
| 400 |
+
return mask_rate.item() # 转换为Python的标量值
|
| 401 |
+
|
| 402 |
+
def get_cdf_token_mask(self, token_all_mask):
|
| 403 |
+
bsz, seq_len = token_all_mask.size()
|
| 404 |
+
# --- use Normal distribution.
|
| 405 |
+
dist_normal = torch.distributions.Normal(0, 2)
|
| 406 |
+
cdf_mask_token = dist_normal.cdf(torch.arange(1, seq_len + 1))
|
| 407 |
+
cdf_mask_token = (cdf_mask_token - .5) * 2
|
| 408 |
+
cdf_mask_token = repeat(cdf_mask_token, 'Lp -> b s Lp',
|
| 409 |
+
b=bsz, s=seq_len)
|
| 410 |
+
|
| 411 |
+
cdf_mask_token = F.pad(cdf_mask_token, (1, 0))
|
| 412 |
+
return cdf_mask_token
|
| 413 |
+
|
| 414 |
+
# def get_dec_pos_embed(self, z_H, z_W):
|
| 415 |
+
# # Exclude the first embedding (position 0) before interpolation
|
| 416 |
+
# position_embeddings = self.decoder_pos_embed_learned # (b, 256, embed_dim)
|
| 417 |
+
# # Perform bilinear interpolation
|
| 418 |
+
# interpolated_embeddings = F.interpolate(
|
| 419 |
+
# position_embeddings.permute(0, 2, 1),
|
| 420 |
+
# size=(1+z_H*z_W), mode="linear", align_corners=False
|
| 421 |
+
# ).permute(0, 2, 1) # (b, seq_len -1 , embed_dim)
|
| 422 |
+
# # Concatenate the first embedding back
|
| 423 |
+
# # position_embeddings = torch.cat([self.decoder_pos_embed_learned[:, :1, :], interpolated_embeddings], dim=1)
|
| 424 |
+
# position_embeddings = interpolated_embeddings
|
| 425 |
+
|
| 426 |
+
# return position_embeddings
|
| 427 |
+
|
| 428 |
+
    def pre_encoding(self, x, is_training=False, manual_mask_rate=None):
        """
        Tokenize an image batch with the frozen VQGAN and sample a token mask.

        input: x: (B, 3, H, W)

        During training the mask rate is sampled randomly and a fixed fraction
        (mask_ratio_min) of tokens is additionally dropped from the encoder
        input; at inference the mask rate must be supplied via
        manual_mask_rate and no tokens are dropped.

        Returns:
            gt_indices, token_indices, unmaksed_token_indices, token_all_mask,
            token_drop_mask, mask_ratio, z_H, z_W
        """
        # ============ 1. tokenization ============ #
        with torch.no_grad():
            # Frozen VQ tokenizer; gradients never flow into it.
            z_q, _, token_tuple = self.vqgan.encode(x)  # z_q: (B, 256, 16, 16), token_tuple: (B, 256, 16, 16)

        _, _, z_H, z_W = z_q.size()
        _, _, token_indices = token_tuple  # token_indices: (B*H*W,)(8192)
        token_indices = token_indices.reshape(z_q.size(0), -1)  # token_indices: (B, H*W)
        gt_indices = token_indices.clone().detach().long()

        # ============ 2. masking process ============ #
        bsz, seq_len = token_indices.size()  # seq_len = h*w
        mask_ratio_min = self.mask_ratio_min  # 0.5

        if is_training:
            # mask_rate = self.mask_ratio_generator.rvs(1)[0]
            mask_rate = self.random_sample_mask_rate()
            # A fixed minimum fraction of tokens is physically dropped from the
            # encoder input during training.
            num_dropped_tokens = int(np.ceil(seq_len * mask_ratio_min))
        else:
            num_dropped_tokens = 0
            if manual_mask_rate is not None:
                mask_rate = manual_mask_rate
            else:
                raise ValueError("mask_rate should be provided for inference!")

        num_masked_tokens = int(np.ceil(seq_len * mask_rate))
        mask_ratio = num_masked_tokens / seq_len  # used to compute the vbr lambda
        # it is possible that two elements of the noise is the same, so do a
        # while loop to avoid it (ties would make the cutoff select the wrong
        # number of tokens).
        while True:
            noise = torch.rand(bsz, seq_len, device=x.device)  # noise in [0, 1]
            sorted_noise, _ = torch.sort(noise, dim=1)  # ascend: small is remove, large is keep
            if num_dropped_tokens > 0:
                cutoff_drop = sorted_noise[:, num_dropped_tokens-1:num_dropped_tokens]
            else:
                cutoff_drop = torch.zeros((bsz, 1), device=x.device)
            cutoff_mask = sorted_noise[:, num_masked_tokens-1:num_masked_tokens]
            token_drop_mask = (noise <= cutoff_drop).float()  # 1 where a token is dropped
            token_all_mask = (noise <= cutoff_mask).float()   # 1 where a token is masked
            if token_drop_mask.sum() == bsz*num_dropped_tokens and token_all_mask.sum() == bsz*num_masked_tokens:
                break
            else:
                print("Rerandom the noise!")

        # Collect the surviving (unmasked) tokens, in row-major position order.
        unmasked_pos = token_all_mask == 0
        unmaksed_token_indices = token_indices[unmasked_pos].reshape(bsz, -1)

        return gt_indices, token_indices, unmaksed_token_indices, token_all_mask, token_drop_mask, mask_ratio, z_H, z_W
def pre_decoding(self, gt_indices, unmaksed_token_indices, token_all_mask, token_drop_mask):
|
| 481 |
+
bsz, seq_len = gt_indices.size()
|
| 482 |
+
padded_token_indices = torch.full_like(gt_indices, fill_value=self.mask_token_label)
|
| 483 |
+
# 将未被mask的token填充回去
|
| 484 |
+
# 我们需要一个计数器来追踪每个batch中已经填充了多少个unmaksed_token_indices
|
| 485 |
+
unmasked_token_counter = [0 for _ in range(bsz)]
|
| 486 |
+
|
| 487 |
+
for b in range(bsz):
|
| 488 |
+
for idx in range(seq_len):
|
| 489 |
+
# 如果当前位置未被mask,则从unmaksed_token_indices填充;否则,保留mask_token_label
|
| 490 |
+
if (token_all_mask[b, idx] == 0): # 检查是否未被mask
|
| 491 |
+
# 替换相应的unmaksed token
|
| 492 |
+
padded_token_indices[b, idx] = unmaksed_token_indices[b, unmasked_token_counter[b]]
|
| 493 |
+
# 更新计数器
|
| 494 |
+
unmasked_token_counter[b] += 1
|
| 495 |
+
|
| 496 |
+
token_indices = padded_token_indices
|
| 497 |
+
# ============ 3. Adding class token ============ #
|
| 498 |
+
# concate class token, add [CLS] token to aggregate sequence-level representations
|
| 499 |
+
token_indices = torch.cat([torch.zeros(token_indices.size(0), 1).cuda(device=token_indices.device), token_indices], dim=1)
|
| 500 |
+
token_indices[:, 0] = self.fake_class_label # [B, 257]
|
| 501 |
+
# Masks (token_drop_mask and token_all_mask) are updated to account for the added class token,
|
| 502 |
+
# ensuring the first position is always kept by setting it to 0 (indicating "do not mask/drop")
|
| 503 |
+
# 添加0向量,和token_indices,表示[CLS] token不会被mask/drop
|
| 504 |
+
token_drop_mask = torch.cat([torch.zeros(token_indices.size(0), 1).cuda(), token_drop_mask], dim=1)
|
| 505 |
+
token_all_mask = torch.cat([torch.zeros(token_indices.size(0), 1).cuda(), token_all_mask], dim=1)
|
| 506 |
+
token_indices = token_indices.long()
|
| 507 |
+
|
| 508 |
+
# ============ 4. Embedding and Dropout ============ #
|
| 509 |
+
# bert embedding
|
| 510 |
+
input_embeddings = self.token_emb(token_indices) # get embeddings [B, 257, 768]
|
| 511 |
+
# print("Input embedding shape:", input_embeddings.shape)
|
| 512 |
+
bsz, seq_len, emb_dim = input_embeddings.shape
|
| 513 |
+
|
| 514 |
+
# dropping
|
| 515 |
+
token_keep_mask = 1 - token_drop_mask
|
| 516 |
+
input_embeddings_after_drop = input_embeddings[token_keep_mask.nonzero(as_tuple=True)].reshape(bsz, -1, emb_dim)
|
| 517 |
+
# print("Input embedding after drop shape:", input_embeddings_after_drop.shape)
|
| 518 |
+
|
| 519 |
+
# ============ 5. Transformer encoding ============ #
|
| 520 |
+
x = input_embeddings_after_drop # (B, seq_len_after_drop, embed_dim) # 32, 129, 768
|
| 521 |
+
for blk in self.blocks:
|
| 522 |
+
x = blk(x) # each block has a multi-head self-attention and a mlp
|
| 523 |
+
x = self.norm(x)
|
| 524 |
+
# print("Encoder representation shape:", x.shape)
|
| 525 |
+
|
| 526 |
+
return x, token_indices, token_all_mask, token_drop_mask
|
| 527 |
+
|
| 528 |
+
    def forward_decoding(self, x, token_drop_mask, token_all_mask):
        """
        Run the decoder over the full (re-padded) token sequence.

        x: output x of forward_encoder()
        token_drop_mask: positions for dropped tokens (with [CLS] column)
        token_all_mask: positions for masked tokens (with [CLS] column)

        Returns logits over the vocabulary for every position, including [CLS].
        """
        # ============ 1. Prepare Embedding and padding tokens ============ #
        # Project encoder features into the decoder embedding space.
        x = self.decoder_embed(x)  # input_embedding_after_padding

        # Build the filler embeddings for masked/dropped positions: either the
        # per-sample [CLS] embedding replicated along the sequence, or a
        # learned mask token.
        if self.pad_with_cls_token:  # True
            mask_tokens = x[:, 0:1].repeat(1, token_all_mask.shape[1], 1)
        else:
            mask_tokens = self.mask_token.repeat(token_all_mask.shape[0], token_all_mask.shape[1], 1)

        # ============ 2. Prepare positional embedding ============ #
        # Scatter the undropped encoder outputs back to their original
        # positions. NOTE(review): this relies on nonzero() enumerating the
        # kept positions in the same row-major order the encoder sequence was
        # built in — confirm if the drop logic ever changes.
        x_after_pad = mask_tokens.clone()
        x_after_pad[(1 - token_drop_mask).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
        # Overwrite undropped-but-masked positions with the mask filler.
        x_after_pad = torch.where(token_all_mask.unsqueeze(-1).bool(), mask_tokens, x_after_pad)

        # Positional embedding is intentionally disabled here.
        # x = x_after_pad + self.decoder_pos_embed_learned  # add learnable pos embedding
        x = x_after_pad

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)

        x = self.decoder_norm(x)

        # Tie the output projection to the (detached) input word embeddings.
        word_embeddings = self.token_emb.word_embeddings.weight.data.detach()
        logits = self.mlm_layer(x, word_embeddings)  # produce predictions for masked tokens

        return logits
def forward_loss(self, gt_indices, logits, mask):
|
| 569 |
+
bsz, seq_len = gt_indices.size()
|
| 570 |
+
# logits and mask are with seq_len+1 but gt_indices is with seq_len
|
| 571 |
+
loss = self.criterion(logits[:, 1:, :self.codebook_size].reshape(bsz*seq_len, -1), gt_indices.reshape(bsz*seq_len))
|
| 572 |
+
loss = loss.reshape(bsz, seq_len)
|
| 573 |
+
loss = (loss * mask[:, 1:]).sum() / mask[:, 1:].sum() # mean loss on removed patches
|
| 574 |
+
return loss
|
| 575 |
+
|
| 576 |
+
def cal_lmbda(self, mask_ratio, A=5e-1, B=8):
|
| 577 |
+
lmbda = A * torch.exp(B * (1 - mask_ratio))
|
| 578 |
+
return lmbda
|
| 579 |
+
|
| 580 |
+
def cal_loss(self, logits, gt_indices, mask, mask_ratio):
|
| 581 |
+
mask_ratio = torch.tensor(mask_ratio)
|
| 582 |
+
## cal cross entropy loss
|
| 583 |
+
task_loss = self.forward_loss(gt_indices, logits, mask)
|
| 584 |
+
lmbda = self.cal_lmbda(mask_ratio)
|
| 585 |
+
## cal total loss for codec optimization
|
| 586 |
+
return task_loss, lmbda
|
| 587 |
+
|
| 588 |
+
    def forward(self, imgs, is_training=False, manual_mask_rate=None):
        """Full tokenize → entropy-code → decode pass for one image batch.

        Tokenizes imgs with the VQGAN, samples a token mask, entropy-codes the
        unmasked tokens (FactorizedEntropyModel) and the binary mask (torchac),
        then round-trips the mask bitstream and runs the MIM decoder to
        predict the masked tokens. Returns a dict with logits, likelihoods,
        losses, and bookkeeping tensors.
        """
        ## ---------- encoding process ---------- ##
        # `latent` here is the sequence of unmasked token indices.
        gt_indices, token_indices, latent, token_all_mask, token_drop_mask, mask_ratio, z_H, z_W = self.pre_encoding(imgs, is_training, manual_mask_rate)
        latent = latent.unsqueeze(1)

        # Quantize + model the unmasked tokens' rate with the factorized prior.
        latent_hat, latent_likelihoods = self.entropy_bottleneck(latent)
        # check whether latent_hat equals latent:
        # print((latent_hat == latent).all())
        # Arithmetic-code the binary mask; torchac works on CPU tensors, and
        # symbols are shifted to {1, 2} to stay inside the CDF table bounds.
        cdf_mask_token = self.get_cdf_token_mask(token_all_mask).cpu()
        sym = (token_all_mask.short() + 1).cpu()
        bs_mask_token = torchac.encode_float_cdf(cdf_mask_token, sym, check_input_bounds=True)
        mask_vis = rearrange(token_all_mask, 'b (h w) -> b h w', h=z_H, w=z_W).unsqueeze(1)

        ## ---------- decoding process ---------- ##
        # Round-trip the mask bitstream to prove decodability, then undo the
        # {1, 2} symbol shift.
        decoded_sym = torchac.decode_float_cdf(cdf_mask_token, bs_mask_token)
        decoded_mask = (decoded_sym - 1).to(device=imgs.device)
        latent_hat = latent_hat.squeeze(1)
        # NOTE(review): latent_hat (float, from the bottleneck) is passed where
        # pre_decoding expects integer token indices — it is cast on
        # assignment inside pre_decoding; confirm the round-trip is lossless.
        x, token_indices, token_all_mask, token_drop_mask = self.pre_decoding(gt_indices, latent_hat, decoded_mask, token_drop_mask)
        logits = self.forward_decoding(x, token_drop_mask, token_all_mask)
        ## calculate loss
        task_loss, lmbda = self.cal_loss(logits, gt_indices, token_all_mask, mask_ratio)
        return_dict = {
            'logits': logits,
            'likelihoods': latent_likelihoods,
            'task_loss': task_loss,
            'token_indices': token_indices,
            'token_all_mask': token_all_mask,
            'bs_mask_token': bs_mask_token,
            'mask_ratio': mask_ratio,
            'lambda': lmbda,
            'mask_vis': 1 - mask_vis,
            'z_H': z_H,
            'z_W': z_W,
        }
        return return_dict
# def update(self, scale_table=None, force=False):
|
| 625 |
+
# if scale_table is None:
|
| 626 |
+
# scale_table = get_scale_table()
|
| 627 |
+
# updated = self.gaussian_conditional.update_scale_table(scale_table, force=force)
|
| 628 |
+
# updated |= super().update(force=force)
|
| 629 |
+
# return updated
|
| 630 |
+
|
| 631 |
+
    def gen_img(self, logits, token_all_mask, token_indices, z_H, z_W, num_iter=12, choice_temperature=4.5):
        """
        Generate an image at inference via MaskGIT-style iterative sampling.

        logits: predicted logits by model decoder (first iteration reuses them)
        token_all_mask: mask token indices
        token_indices: token indices of the input image after the vq tokenizer
        z_H, z_W: latent grid height/width
        num_iter: number of iterations for sampling
        choice_temperature: temperature for sampling

        Each iteration samples tokens for the still-unknown positions, keeps
        the most confident predictions, and re-masks the rest following a
        cosine schedule; the final ids are decoded by the VQGAN.
        """
        # torch.manual_seed(seed)
        # np.random.seed(seed)
        bsz = logits.size(0)
        codebook_emb_dim = 256
        codebook_size = 1024
        mask_token_id = self.mask_token_label
        # Known tokens get infinite confidence so they are never re-masked.
        _CONFIDENCE_OF_KNOWN_TOKENS = +np.inf
        unknown_number_in_the_beginning = torch.sum(token_all_mask, dim=-1, keepdims=True).float()
        for step in range(num_iter):
            if step == 0:
                # First iteration: reuse the logits computed by the caller
                # instead of re-running encoder/decoder.
                cur_ids = token_indices.clone().long()  # current state of the sequence (unmasked tokens)
                cur_ids = cur_ids[:, 1:]  # drop the [CLS] column
                logits = logits[:, 1:, :codebook_size]
                # the author said a little tricky here, "For iter=1, they use argmax and temp=0.0.
                # For iter=6, we use categorical sampling and temp=4.5."
                sample_dist = torch.distributions.categorical.Categorical(logits=logits)
                sampled_ids = sample_dist.sample()  # sampled_ids = torch.argmax(logits, dim=-1)
                # unknown_map: bool, same shape as cur_ids/sampled_ids; marks
                # positions whose value will be replaced. Merge the model's
                # predictions (at unknown positions) with the already-known
                # tokens (elsewhere) into one tensor.
                unknown_map = (cur_ids == mask_token_id)
                sampled_ids = torch.where(unknown_map, sampled_ids, cur_ids)
                # Defines the mask ratio for the next round. The number to mask out is
                # determined by mask_ratio * unknown_number_in_the_beginning.
                ratio = 1. * (step + 1) / num_iter
                mask_ratio = np.cos(math.pi / 2. * ratio)  # ratio = cosine(pi/2 * i/num_iter)

                # sample ids according to prediction confidence
                probs = torch.nn.functional.softmax(logits, dim=-1)
                selected_probs = torch.squeeze(
                    torch.gather(probs, dim=-1, index=torch.unsqueeze(sampled_ids, -1)), -1)

                selected_probs = torch.where(unknown_map, selected_probs.double(), _CONFIDENCE_OF_KNOWN_TOKENS).float()
                unknown_number_in_the_beginning = unknown_number_in_the_beginning.clone().detach().cuda()
                mask_ratio = torch.tensor(mask_ratio).cuda()
                # Remaining number of tokens to keep masked for the next iteration.
                mask_len = torch.floor(unknown_number_in_the_beginning * mask_ratio).long()
                # Keeps at least one of prediction in this round and also masks out at least
                # one and for the next iteration
                mask_len = torch.maximum(torch.Tensor([1]).cuda(),
                                         torch.minimum(torch.sum(unknown_map, dim=-1, keepdims=True) - 1, mask_len))

                # Sample masking tokens for next iteration
                masking = mask_by_random_topk(mask_len[0], selected_probs, choice_temperature * (1 - ratio))
                # Masks tokens with lower confidence.
                token_indices = torch.where(masking, mask_token_id, sampled_ids)
            else:
                cur_ids = token_indices.clone().long()  # .long(): to int64
                # Re-attach the [CLS] column before re-encoding.
                token_indices = torch.cat(
                    [torch.zeros(token_indices.size(0), 1).cuda(device=token_indices.device), token_indices], dim=1)
                token_indices[:, 0] = self.fake_class_label
                token_indices = token_indices.long()
                token_all_mask = token_indices == mask_token_id

                token_drop_mask = torch.zeros_like(token_indices)

                # token embedding
                input_embeddings = self.token_emb(token_indices)  # get input embeddings

                # encoder
                x = input_embeddings
                for blk in self.blocks:
                    x = blk(x)
                x = self.norm(x)

                # decoder
                logits = self.forward_decoding(x, token_drop_mask, token_all_mask)
                logits = logits[:, 1:, :codebook_size]  # remove the cls token and dims > codebook_size

                # get token prediction
                # the author said a little tricky here, "For iter=1, they use argmax and temp=0.0.
                # For iter=6, we use categorical sampling and temp=4.5."
                sample_dist = torch.distributions.categorical.Categorical(logits=logits)
                sampled_ids = sample_dist.sample()  # sampled_ids = torch.argmax(logits, dim=-1)

                # Merge predictions at unknown positions with known tokens,
                # exactly as in the first iteration.
                unknown_map = (cur_ids == mask_token_id)
                sampled_ids = torch.where(unknown_map, sampled_ids, cur_ids)
                # Defines the mask ratio for the next round. The number to mask out is
                # determined by mask_ratio * unknown_number_in_the_beginning.
                ratio = 1. * (step + 1) / num_iter

                mask_ratio = np.cos(math.pi / 2. * ratio)  # ratio = cosine(pi/2 * i/num_iter)

                # sample ids according to prediction confidence
                probs = torch.nn.functional.softmax(logits, dim=-1)
                selected_probs = torch.squeeze(
                    torch.gather(probs, dim=-1, index=torch.unsqueeze(sampled_ids, -1)), -1)

                selected_probs = torch.where(unknown_map, selected_probs.double(), _CONFIDENCE_OF_KNOWN_TOKENS).float()
                unknown_number_in_the_beginning = unknown_number_in_the_beginning.clone().detach().cuda()
                mask_ratio = torch.tensor(mask_ratio).cuda()
                mask_len = torch.floor(unknown_number_in_the_beginning * mask_ratio).long()  # remaining masked tokens per iter
                # Keeps at least one of prediction in this round and also masks out at least
                # one and for the next iteration
                mask_len = torch.maximum(torch.Tensor([1]).cuda(),
                                         torch.minimum(torch.sum(unknown_map, dim=-1, keepdims=True) - 1, mask_len))

                # Sample masking tokens for next iteration
                masking = mask_by_random_topk(mask_len[0], selected_probs, choice_temperature * (1 - ratio))
                # Masks tokens with lower confidence.
                token_indices = torch.where(masking, mask_token_id, sampled_ids)

        # vqgan visualization
        z_q = self.vqgan.quantize.get_codebook_entry(sampled_ids, shape=(bsz, z_H, z_W, codebook_emb_dim))
        gen_images = self.vqgan.decode(z_q)
        return gen_images
def mage_vit_base_patch16(**kwargs):
    """Build a ViT-Base MAGE codec: 12-layer/768-dim encoder, 8-layer/768-dim decoder."""
    return MaskedGenerativeEncoderViT(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        decoder_embed_dim=768,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)
def mage_vit_large_patch16(**kwargs):
    """Build a ViT-Large MAGE codec: 24-layer/1024-dim encoder, 8-layer/1024-dim decoder."""
    return MaskedGenerativeEncoderViT(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        decoder_embed_dim=1024,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)
deploited/random_with_rope.log
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
2024-04-26 09:40:33,300 [INFO ] Logging file is /home/v-ruoyufeng/v-ruoyufeng/qyp/mage/MIM_vbr//20240426_094033.log
|
| 2 |
+
2024-04-26 09:40:33,301 [INFO ] ======================= exp =======================
|
| 3 |
+
2024-04-26 09:40:33,301 [INFO ] config:config/vpt_default.yaml
|
| 4 |
+
2024-04-26 09:40:33,301 [INFO ] name:exp
|
| 5 |
+
2024-04-26 09:40:33,301 [INFO ] lr:None
|
| 6 |
+
2024-04-26 09:40:33,301 [INFO ] world_size:1
|
| 7 |
+
2024-04-26 09:40:33,301 [INFO ] local_rank:-1
|
| 8 |
+
2024-04-26 09:40:33,301 [INFO ] dist_on_itp:False
|
| 9 |
+
2024-04-26 09:40:33,301 [INFO ] dist_url:env://
|
| 10 |
+
2024-04-26 09:40:33,301 [INFO ] TEST:False
|
| 11 |
+
2024-04-26 09:40:33,301 [INFO ] model:mage_vit_base_patch16
|
| 12 |
+
2024-04-26 09:40:33,301 [INFO ] dataset:imagenet
|
| 13 |
+
2024-04-26 09:40:33,301 [INFO ] dataset_path:/home/v-ruoyufeng/v-ruoyufeng/qyp/datasets/imagenet
|
| 14 |
+
2024-04-26 09:40:33,301 [INFO ] eval_path:
|
| 15 |
+
2024-04-26 09:40:33,302 [INFO ] batch_size:64
|
| 16 |
+
2024-04-26 09:40:33,302 [INFO ] test_batch_size:96
|
| 17 |
+
2024-04-26 09:40:33,302 [INFO ] num_workers:4
|
| 18 |
+
2024-04-26 09:40:33,302 [INFO ] epochs:400
|
| 19 |
+
2024-04-26 09:40:33,302 [INFO ] blr:0.0002
|
| 20 |
+
2024-04-26 09:40:33,302 [INFO ] min_lr:5e-06
|
| 21 |
+
2024-04-26 09:40:33,302 [INFO ] accum_iter:1
|
| 22 |
+
2024-04-26 09:40:33,302 [INFO ] input_size:256
|
| 23 |
+
2024-04-26 09:40:33,302 [INFO ] weight_decay:0.05
|
| 24 |
+
2024-04-26 09:40:33,302 [INFO ] warmup_epochs:20
|
| 25 |
+
2024-04-26 09:40:33,302 [INFO ] mask_ratio_min:0.5
|
| 26 |
+
2024-04-26 09:40:33,302 [INFO ] mask_ratio_max:0.9
|
| 27 |
+
2024-04-26 09:40:33,302 [INFO ] mask_ratio_mu:0.55
|
| 28 |
+
2024-04-26 09:40:33,302 [INFO ] mask_ratio_std:0.25
|
| 29 |
+
2024-04-26 09:40:33,303 [INFO ] grad_clip:3.0
|
| 30 |
+
2024-04-26 09:40:33,303 [INFO ] device:cuda
|
| 31 |
+
2024-04-26 09:40:33,303 [INFO ] distributed:True
|
| 32 |
+
2024-04-26 09:40:33,303 [INFO ] seed:0
|
| 33 |
+
2024-04-26 09:40:33,303 [INFO ] resume:/home/v-ruoyufeng/v-ruoyufeng/qyp/mimc_cls/checkpoint-220.pth
|
| 34 |
+
2024-04-26 09:40:33,303 [INFO ] start_epoch:0
|
| 35 |
+
2024-04-26 09:40:33,303 [INFO ] pin_mem:True
|
| 36 |
+
2024-04-26 09:40:33,303 [INFO ] save:True
|
| 37 |
+
2024-04-26 09:40:33,303 [INFO ] exp_name:MIM_vbr
|
| 38 |
+
2024-04-26 09:40:33,303 [INFO ] root:/home/v-ruoyufeng/v-ruoyufeng/qyp/mage
|
| 39 |
+
2024-04-26 09:40:33,303 [INFO ] output_dir:/home/v-ruoyufeng/v-ruoyufeng/qyp/mage/MIM_vbr
|
| 40 |
+
2024-04-26 09:40:33,303 [INFO ] log_dir:/home/v-ruoyufeng/v-ruoyufeng/qyp/mage/MIM_vbr
|
| 41 |
+
2024-04-26 09:40:33,303 [INFO ] strart_epoch:0
|
| 42 |
+
2024-04-26 09:40:33,303 [INFO ] cuda:True
|
| 43 |
+
2024-04-26 09:40:33,304 [INFO ] rank:0
|
| 44 |
+
2024-04-26 09:40:33,304 [INFO ] gpu:0
|
| 45 |
+
2024-04-26 09:40:33,304 [INFO ] dist_backend:nccl
|
| 46 |
+
2024-04-26 09:40:33,304 [INFO ] ===================================================
|
| 47 |
+
2024-04-26 09:43:59,448 [INFO ] -1|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:11.19228|msssim:0.19877|lpips:0.65322|dists:0.41351|Test loss:26.43571
|
| 48 |
+
2024-04-26 11:03:41,816 [INFO ] 0|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:11.59349|msssim:0.26058|lpips:0.58663|dists:0.36222|Test loss:24.07364
|
| 49 |
+
2024-04-26 12:23:19,122 [INFO ] 1|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:11.76288|msssim:0.28263|lpips:0.56872|dists:0.33883|Test loss:23.62683
|
| 50 |
+
2024-04-26 13:43:34,495 [INFO ] 2|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:11.85301|msssim:0.28844|lpips:0.55959|dists:0.32906|Test loss:23.47551
|
| 51 |
+
2024-04-26 15:03:52,749 [INFO ] 3|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:11.92868|msssim:0.29865|lpips:0.55056|dists:0.32090|Test loss:23.33249
|
| 52 |
+
2024-04-26 16:23:59,470 [INFO ] 4|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:11.89536|msssim:0.30246|lpips:0.54562|dists:0.31920|Test loss:23.25514
|
| 53 |
+
2024-04-26 17:44:34,630 [INFO ] 5|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:11.77614|msssim:0.30227|lpips:0.54257|dists:0.30956|Test loss:23.14211
|
| 54 |
+
2024-04-26 19:05:03,170 [INFO ] 6|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:11.99645|msssim:0.30600|lpips:0.53496|dists:0.30253|Test loss:23.07454
|
| 55 |
+
2024-04-26 20:25:38,328 [INFO ] 7|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.01199|msssim:0.31352|lpips:0.53121|dists:0.29895|Test loss:23.00766
|
| 56 |
+
2024-04-26 21:46:25,922 [INFO ] 8|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:11.95318|msssim:0.31135|lpips:0.53315|dists:0.30050|Test loss:22.96596
|
| 57 |
+
2024-04-26 23:07:15,459 [INFO ] 9|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:11.86282|msssim:0.31913|lpips:0.52871|dists:0.30449|Test loss:22.93078
|
| 58 |
+
2024-04-27 00:27:59,646 [INFO ] 10|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.02011|msssim:0.32396|lpips:0.52660|dists:0.29969|Test loss:22.88336
|
| 59 |
+
2024-04-27 01:48:56,864 [INFO ] 11|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.07232|msssim:0.32188|lpips:0.52460|dists:0.29983|Test loss:22.86315
|
| 60 |
+
2024-04-27 03:09:11,839 [INFO ] 12|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.01043|msssim:0.32597|lpips:0.52504|dists:0.30010|Test loss:22.79985
|
| 61 |
+
2024-04-27 04:29:49,673 [INFO ] 13|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.10261|msssim:0.32433|lpips:0.52505|dists:0.29726|Test loss:22.78871
|
| 62 |
+
2024-04-27 05:50:35,081 [INFO ] 14|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.11125|msssim:0.32866|lpips:0.51948|dists:0.29619|Test loss:22.71350
|
| 63 |
+
2024-04-27 07:10:26,373 [INFO ] 15|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.06473|msssim:0.32313|lpips:0.51909|dists:0.29284|Test loss:22.72064
|
| 64 |
+
2024-04-27 08:29:03,434 [INFO ] 16|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:11.93886|msssim:0.32386|lpips:0.51924|dists:0.29350|Test loss:22.67830
|
| 65 |
+
2024-04-27 09:48:42,090 [INFO ] 17|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:11.91387|msssim:0.32367|lpips:0.51942|dists:0.29057|Test loss:22.68104
|
| 66 |
+
2024-04-27 11:07:05,282 [INFO ] 18|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.04036|msssim:0.32969|lpips:0.51792|dists:0.29092|Test loss:22.64099
|
| 67 |
+
2024-04-27 12:26:41,635 [INFO ] 19|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:11.97499|msssim:0.32280|lpips:0.51902|dists:0.28883|Test loss:22.61567
|
| 68 |
+
2024-04-27 13:46:27,048 [INFO ] 20|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.04421|msssim:0.32968|lpips:0.51996|dists:0.28981|Test loss:22.59735
|
| 69 |
+
2024-04-27 15:07:26,901 [INFO ] 21|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.09963|msssim:0.32759|lpips:0.51643|dists:0.29144|Test loss:22.58986
|
| 70 |
+
2024-04-27 16:27:18,924 [INFO ] 22|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.25422|msssim:0.33293|lpips:0.51218|dists:0.28637|Test loss:22.55061
|
| 71 |
+
2024-04-27 17:47:28,605 [INFO ] 23|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.03031|msssim:0.32350|lpips:0.51621|dists:0.28594|Test loss:22.55859
|
| 72 |
+
2024-04-27 19:06:30,307 [INFO ] 24|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.01311|msssim:0.33033|lpips:0.51452|dists:0.28868|Test loss:22.51194
|
| 73 |
+
2024-04-27 20:26:25,446 [INFO ] 25|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.12269|msssim:0.33470|lpips:0.51502|dists:0.28773|Test loss:22.53424
|
| 74 |
+
2024-04-27 21:45:53,181 [INFO ] 26|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.22099|msssim:0.33310|lpips:0.51191|dists:0.28884|Test loss:22.50715
|
| 75 |
+
2024-04-27 23:06:53,646 [INFO ] 27|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.08395|msssim:0.33514|lpips:0.51467|dists:0.28832|Test loss:22.50186
|
| 76 |
+
2024-04-28 00:26:46,986 [INFO ] 28|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.23953|msssim:0.33534|lpips:0.51394|dists:0.28696|Test loss:22.49509
|
| 77 |
+
2024-04-28 01:47:49,741 [INFO ] 29|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.11600|msssim:0.33224|lpips:0.51388|dists:0.28858|Test loss:22.48610
|
| 78 |
+
2024-04-28 03:08:05,651 [INFO ] 30|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.04495|msssim:0.33490|lpips:0.51515|dists:0.28526|Test loss:22.47889
|
| 79 |
+
2024-04-28 04:29:31,374 [INFO ] 31|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.13572|msssim:0.33166|lpips:0.51150|dists:0.28472|Test loss:22.44420
|
| 80 |
+
2024-04-28 04:36:00,208 [INFO ] 32|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.29226|msssim:0.33557|lpips:0.51159|dists:0.28618|Test loss:22.46866
|
| 81 |
+
2024-04-28 05:55:42,942 [INFO ] 33|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.24471|msssim:0.33982|lpips:0.50918|dists:0.28688|Test loss:22.45550
|
| 82 |
+
2024-04-28 07:14:35,318 [INFO ] 34|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.26866|msssim:0.34018|lpips:0.50957|dists:0.28656|Test loss:22.46537
|
| 83 |
+
2024-04-28 08:33:33,967 [INFO ] 35|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.17031|msssim:0.33502|lpips:0.51055|dists:0.28262|Test loss:22.42946
|
| 84 |
+
2024-04-28 09:54:29,978 [INFO ] 36|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.25919|msssim:0.33783|lpips:0.51071|dists:0.28604|Test loss:22.45442
|
| 85 |
+
2024-04-28 11:15:39,983 [INFO ] 37|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.20635|msssim:0.33632|lpips:0.50824|dists:0.28454|Test loss:22.45101
|
| 86 |
+
2024-04-28 12:35:52,799 [INFO ] 38|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.32195|msssim:0.34213|lpips:0.50806|dists:0.28493|Test loss:22.43027
|
| 87 |
+
2024-04-28 13:55:26,853 [INFO ] 39|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.36062|msssim:0.33969|lpips:0.50819|dists:0.28332|Test loss:22.41071
|
| 88 |
+
2024-04-28 15:15:39,162 [INFO ] 40|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.22780|msssim:0.33774|lpips:0.50971|dists:0.28230|Test loss:22.41198
|
| 89 |
+
2024-04-28 16:36:09,180 [INFO ] 41|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.29715|msssim:0.34114|lpips:0.50735|dists:0.28232|Test loss:22.39130
|
| 90 |
+
2024-04-28 17:57:27,808 [INFO ] 42|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.09102|msssim:0.33483|lpips:0.50913|dists:0.28347|Test loss:22.41543
|
| 91 |
+
2024-04-28 19:17:26,206 [INFO ] 43|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.20306|msssim:0.34007|lpips:0.51021|dists:0.28400|Test loss:22.40152
|
| 92 |
+
2024-04-28 20:35:40,022 [INFO ] 44|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.22824|msssim:0.33591|lpips:0.50983|dists:0.28228|Test loss:22.42137
|
| 93 |
+
2024-04-28 21:54:40,195 [INFO ] 45|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.38086|msssim:0.34395|lpips:0.50842|dists:0.28619|Test loss:22.39918
|
| 94 |
+
2024-04-28 23:13:35,391 [INFO ] 46|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.16532|msssim:0.33896|lpips:0.50933|dists:0.28314|Test loss:22.41701
|
| 95 |
+
2024-04-29 00:32:11,881 [INFO ] 47|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.26819|msssim:0.33832|lpips:0.50565|dists:0.28173|Test loss:22.37620
|
| 96 |
+
2024-04-29 01:52:28,969 [INFO ] 48|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.29298|msssim:0.34225|lpips:0.50837|dists:0.28685|Test loss:22.38713
|
| 97 |
+
2024-04-29 03:11:36,092 [INFO ] 49|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.10050|msssim:0.33825|lpips:0.50617|dists:0.28070|Test loss:22.36497
|
| 98 |
+
2024-04-29 04:31:26,489 [INFO ] 50|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.33884|msssim:0.34209|lpips:0.50693|dists:0.28271|Test loss:22.37217
|
| 99 |
+
2024-04-29 05:52:26,051 [INFO ] 51|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.32687|msssim:0.34543|lpips:0.50534|dists:0.28436|Test loss:22.36113
|
| 100 |
+
2024-04-29 07:12:41,027 [INFO ] 52|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.23419|msssim:0.34313|lpips:0.50683|dists:0.28346|Test loss:22.34546
|
| 101 |
+
2024-04-29 08:32:37,029 [INFO ] 53|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.19908|msssim:0.34212|lpips:0.51034|dists:0.28401|Test loss:22.35217
|
| 102 |
+
2024-04-29 09:53:24,004 [INFO ] 54|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.20177|msssim:0.34039|lpips:0.50786|dists:0.28354|Test loss:22.35897
|
| 103 |
+
2024-04-29 11:13:36,377 [INFO ] 55|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.26382|msssim:0.34287|lpips:0.50529|dists:0.28344|Test loss:22.36005
|
| 104 |
+
2024-04-29 12:34:01,726 [INFO ] 56|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.17537|msssim:0.34354|lpips:0.50563|dists:0.28147|Test loss:22.34528
|
| 105 |
+
2024-04-29 13:55:09,515 [INFO ] 57|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.12880|msssim:0.34119|lpips:0.50656|dists:0.28078|Test loss:22.33371
|
| 106 |
+
2024-04-29 15:16:53,478 [INFO ] 58|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.24665|msssim:0.34374|lpips:0.50576|dists:0.28350|Test loss:22.32770
|
| 107 |
+
2024-04-29 16:37:56,041 [INFO ] 59|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.21803|msssim:0.34118|lpips:0.50617|dists:0.28309|Test loss:22.35576
|
| 108 |
+
2024-04-29 17:57:24,559 [INFO ] 60|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.27436|msssim:0.34233|lpips:0.50439|dists:0.28176|Test loss:22.33442
|
| 109 |
+
2024-04-29 19:16:29,152 [INFO ] 61|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.18612|msssim:0.33822|lpips:0.50917|dists:0.28183|Test loss:22.35925
|
| 110 |
+
2024-04-29 20:35:53,546 [INFO ] 62|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.28599|msssim:0.34549|lpips:0.50387|dists:0.28009|Test loss:22.33322
|
| 111 |
+
2024-04-29 21:55:16,529 [INFO ] 63|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.21795|msssim:0.34211|lpips:0.50608|dists:0.27972|Test loss:22.33488
|
| 112 |
+
2024-04-29 23:14:30,846 [INFO ] 64|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.29407|msssim:0.34144|lpips:0.50596|dists:0.28011|Test loss:22.35140
|
| 113 |
+
2024-04-29 23:19:58,151 [INFO ] 65|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.28235|msssim:0.34247|lpips:0.50468|dists:0.27957|Test loss:22.33808
|
| 114 |
+
2024-04-30 00:39:16,656 [INFO ] 66|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.25703|msssim:0.34225|lpips:0.50430|dists:0.28089|Test loss:22.31090
|
| 115 |
+
2024-04-30 01:59:45,043 [INFO ] 67|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.26808|msssim:0.34043|lpips:0.50586|dists:0.28125|Test loss:22.32784
|
| 116 |
+
2024-04-30 03:21:24,464 [INFO ] 68|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.31984|msssim:0.34534|lpips:0.50445|dists:0.28171|Test loss:22.30779
|
| 117 |
+
2024-04-30 04:42:24,830 [INFO ] 69|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.35958|msssim:0.34179|lpips:0.50447|dists:0.28130|Test loss:22.31575
|
| 118 |
+
2024-04-30 06:03:32,581 [INFO ] 70|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.32468|msssim:0.34391|lpips:0.50387|dists:0.28060|Test loss:22.32640
|
deploited/random_with_rope.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b81f2e9114c831f49bf4926d5a40d8655a9d21376f3e71fcc66ecfb1c5db1cb7
|
| 3 |
+
size 1960373872
|
deploited/test_crop.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from util.utils import pad_to_multiple_of_256, split_into_blocks, merge_blocks, crop_to_original_shape
|
| 2 |
+
import glob
|
| 3 |
+
import os
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
import math
|
| 7 |
+
from torch.nn import functional as F
|
| 8 |
+
import PIL.Image as Image
|
| 9 |
+
from torchvision import utils as vutils
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def load_img(p, padding=True, factor=64):
|
| 13 |
+
x = Image.open(p)
|
| 14 |
+
x = torch.from_numpy(np.asarray(x))
|
| 15 |
+
if len(x.shape) == 2:
|
| 16 |
+
x = x.unsqueeze(-1).repeat(1, 1, 3) # h,w -> h,w,3
|
| 17 |
+
x = x.permute(2, 0, 1).unsqueeze(0).float().div(255)
|
| 18 |
+
h, w = x.shape[2:4]
|
| 19 |
+
|
| 20 |
+
if padding:
|
| 21 |
+
dh = factor * math.ceil(h / factor) - h
|
| 22 |
+
dw = factor * math.ceil(w / factor) - w
|
| 23 |
+
# 均匀添加padding
|
| 24 |
+
dh_half = dh // 2
|
| 25 |
+
dw_half = dw // 2
|
| 26 |
+
dh_extra = dh % 2
|
| 27 |
+
dw_extra = dw % 2
|
| 28 |
+
x = F.pad(x, (dw_half, dw_half + dw_extra, dh_half, dh_half + dh_extra))
|
| 29 |
+
return x, h, w
|
| 30 |
+
|
| 31 |
+
def save_img(img: torch.Tensor, vis_path, input_p, rec=False):
|
| 32 |
+
img = img.clone().detach()
|
| 33 |
+
img = img.to(torch.device('cpu'))
|
| 34 |
+
if os.path.isdir(vis_path) is not True:
|
| 35 |
+
os.makedirs(vis_path)
|
| 36 |
+
end = '/'
|
| 37 |
+
if rec:
|
| 38 |
+
vis_path = vis_path + '/rec'
|
| 39 |
+
if os.path.isdir(vis_path) is not True:
|
| 40 |
+
os.makedirs(vis_path)
|
| 41 |
+
img_name = vis_path + str(input_p[input_p.rfind(end):])
|
| 42 |
+
else:
|
| 43 |
+
img_name = vis_path + str(input_p[input_p.rfind(end):])
|
| 44 |
+
vutils.save_image(img, os.path.join(img_name), nrow=8)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
eval_path = sorted(glob.glob(os.path.join('/home/t2vg-a100-G4-10/project/qyp/datasets/test', '*.jpg')))
|
| 48 |
+
vis_path = os.path.join("./test_crop/")
|
| 49 |
+
os.makedirs(vis_path, exist_ok=True)
|
| 50 |
+
|
| 51 |
+
for input_p in eval_path:
|
| 52 |
+
x, hx, wx = load_img(input_p, padding=True, factor=64)
|
| 53 |
+
print("ori height", hx, "ori width", wx)
|
| 54 |
+
ori_shape = x.shape
|
| 55 |
+
print("input shape", ori_shape)
|
| 56 |
+
x = pad_to_multiple_of_256(x, 0)
|
| 57 |
+
save_img(x, vis_path, input_p, rec=False)
|
| 58 |
+
print("shape after padding", x.shape)
|
| 59 |
+
_, _, new_h, new_w = x.shape
|
| 60 |
+
x = split_into_blocks(x)
|
| 61 |
+
print("new shape", x.shape)
|
| 62 |
+
new_bsz = x.size(0)
|
| 63 |
+
new_shape = [ori_shape[0], 3, new_h, new_w]
|
| 64 |
+
x = merge_blocks(x, new_shape)
|
| 65 |
+
print("shape after merge", x.shape)
|
| 66 |
+
x = crop_to_original_shape(x, ori_shape)
|
| 67 |
+
print("shape after crop", x.shape)
|
| 68 |
+
save_img(x, vis_path, input_p, rec=True)
|
deploited/train_vbr_codec_rope.py
ADDED
|
@@ -0,0 +1,673 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tqdm
|
| 2 |
+
import argparse
|
| 3 |
+
import math
|
| 4 |
+
# import torchac
|
| 5 |
+
import sys
|
| 6 |
+
import os
|
| 7 |
+
import time
|
| 8 |
+
import logging
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.optim as optim
|
| 14 |
+
|
| 15 |
+
import torchvision
|
| 16 |
+
from torch.utils.data import DataLoader
|
| 17 |
+
from torchvision import transforms
|
| 18 |
+
from torchvision.models import resnet50
|
| 19 |
+
|
| 20 |
+
import yaml
|
| 21 |
+
from pytorch_msssim import ms_ssim
|
| 22 |
+
from DISTS_pytorch import DISTS
|
| 23 |
+
from util.lpips import LPIPS
|
| 24 |
+
from torch.nn import functional as F
|
| 25 |
+
from torchvision import utils as vutils
|
| 26 |
+
import matplotlib.pyplot as plt
|
| 27 |
+
import numpy as np
|
| 28 |
+
import glob
|
| 29 |
+
|
| 30 |
+
import util.misc as misc
|
| 31 |
+
import util.lr_sched as lr_sched
|
| 32 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 33 |
+
import models_mage_codec
|
| 34 |
+
import models_mage_codec_rope
|
| 35 |
+
import timm.optim.optim_factory as optim_factory
|
| 36 |
+
from util.misc import NativeScalerWithGradNormCount as NativeScaler
|
| 37 |
+
import json
|
| 38 |
+
import PIL.Image as Image
|
| 39 |
+
import torch.backends.cudnn as cudnn
|
| 40 |
+
from pathlib import Path
|
| 41 |
+
import random
|
| 42 |
+
import torch.distributed as dist
|
| 43 |
+
|
| 44 |
+
class CalMetrics(nn.Module):
|
| 45 |
+
"""Calculate BPP, PSNR, MS-SSIM, LPIPS and DISTS for the reconstructed image."""
|
| 46 |
+
|
| 47 |
+
def __init__(self):
|
| 48 |
+
super().__init__()
|
| 49 |
+
self.mse = nn.MSELoss()
|
| 50 |
+
|
| 51 |
+
def bpp_loss(self, ori, out_net):
|
| 52 |
+
b, _, h, w = ori.shape
|
| 53 |
+
num_pixels = b * h * w
|
| 54 |
+
# bpp = sum(
|
| 55 |
+
# (torch.log(likelihoods).sum() / (-math.log(2) * num_pixels))
|
| 56 |
+
# for likelihoods in out_net["likelihoods"].values()
|
| 57 |
+
# )
|
| 58 |
+
bpp = torch.log(out_net["likelihoods"]).sum() / (-math.log(2) * num_pixels)
|
| 59 |
+
bs_mask_token = out_net['bs_mask_token']
|
| 60 |
+
bytes_length = len(bs_mask_token)
|
| 61 |
+
# 因为每个字节包含8位,所以总位数是字节数 * 8
|
| 62 |
+
total_bits = bytes_length * 8
|
| 63 |
+
# 计算每像素的位数(bpp)
|
| 64 |
+
bpp_mask = total_bits / num_pixels
|
| 65 |
+
return bpp, bpp_mask
|
| 66 |
+
|
| 67 |
+
def psnr(self, rec, ori):
|
| 68 |
+
mse = torch.mean((rec - ori) ** 2)
|
| 69 |
+
if(mse == 0):
|
| 70 |
+
return 100
|
| 71 |
+
max_pixel = 1.
|
| 72 |
+
psnr = 10 * torch.log10(max_pixel / mse)
|
| 73 |
+
return torch.mean(psnr)
|
| 74 |
+
|
| 75 |
+
def lpips(self, rec, ori):
|
| 76 |
+
lpips_func = LPIPS().eval().to(device=rec.device)
|
| 77 |
+
lipis_value = lpips_func(rec, ori)
|
| 78 |
+
return lipis_value.mean()
|
| 79 |
+
|
| 80 |
+
def dists(self, rec, ori):
|
| 81 |
+
D = DISTS().cuda()
|
| 82 |
+
dists_value = D(rec, ori)
|
| 83 |
+
return dists_value.mean()
|
| 84 |
+
|
| 85 |
+
def cal_total_loss(self, lpips, bpp, out_net):
|
| 86 |
+
# task_loss = out_net['task_loss'] + 0.1 * lpips
|
| 87 |
+
task_loss = out_net['task_loss']
|
| 88 |
+
total_loss = bpp + out_net['lambda'] * task_loss
|
| 89 |
+
return total_loss
|
| 90 |
+
|
| 91 |
+
def forward(self, ori, out_net, rec=None):
|
| 92 |
+
out = {}
|
| 93 |
+
out["bpp"], out["bpp_mask"] = self.bpp_loss(ori, out_net)
|
| 94 |
+
out["bpp_loss"] = out["bpp"] + out["bpp_mask"]
|
| 95 |
+
# out["total_loss"] = self.cal_total_loss(out["bpp_loss"], out_net)
|
| 96 |
+
if rec is not None:
|
| 97 |
+
out["psnr"] = self.psnr(torch.clamp(rec, 0, 1), ori)
|
| 98 |
+
out["msssim"] = ms_ssim(torch.clamp(rec, 0, 1), ori, data_range=1, size_average=True)
|
| 99 |
+
out["lpips"] = self.lpips(torch.clamp(rec, 0, 1), ori)
|
| 100 |
+
out["dists"] = self.dists(torch.clamp(rec, 0, 1), ori)
|
| 101 |
+
out["total_loss"] = self.cal_total_loss(out["lpips"], out["bpp_loss"], out_net)
|
| 102 |
+
return out
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class FeatureHook():
|
| 106 |
+
def __init__(self, module):
|
| 107 |
+
module.register_forward_hook(self.attach)
|
| 108 |
+
|
| 109 |
+
def attach(self, model, input, output):
|
| 110 |
+
self.feature = output
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class Clsloss(nn.Module):
|
| 114 |
+
def __init__(self, device, cls_loss=True) -> None:
|
| 115 |
+
super().__init__()
|
| 116 |
+
self.ce = nn.CrossEntropyLoss()
|
| 117 |
+
self.classifier = resnet50(True)
|
| 118 |
+
self.classifier.requires_grad_(False)
|
| 119 |
+
self.hooks = [FeatureHook(i) for i in [ # for calculating perceptual loss
|
| 120 |
+
self.classifier.layer1,
|
| 121 |
+
self.classifier.layer2,
|
| 122 |
+
self.classifier.layer3,
|
| 123 |
+
self.classifier.layer4,
|
| 124 |
+
]]
|
| 125 |
+
self.classifier = self.classifier.to(device)
|
| 126 |
+
for k, p in self.classifier.named_parameters():
|
| 127 |
+
p.requires_grad = False
|
| 128 |
+
self.classifier.eval()
|
| 129 |
+
self.cls_loss = cls_loss
|
| 130 |
+
self.transform = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
| 131 |
+
|
| 132 |
+
def accuracy(output, target, topk=(1,)):
|
| 133 |
+
maxk = max(topk)
|
| 134 |
+
batch_size = target.size(0)
|
| 135 |
+
|
| 136 |
+
_, pred = output.topk(maxk, 1, True, True)
|
| 137 |
+
pred = pred.t()
|
| 138 |
+
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
| 139 |
+
|
| 140 |
+
res = []
|
| 141 |
+
for k in topk:
|
| 142 |
+
correct_k = correct[:k].view(-1).float().sum(0)
|
| 143 |
+
res.append(correct_k.mul_(100.0 / batch_size))
|
| 144 |
+
return res
|
| 145 |
+
|
| 146 |
+
def forward(self, d, rec, y_true):
|
| 147 |
+
# output = codec(d), d = original image, y_true = label
|
| 148 |
+
x_hat = torch.clamp(rec,0,1)
|
| 149 |
+
pred = self.classifier(self.transform(x_hat)) # transform is normalization
|
| 150 |
+
cls_loss = self.ce(pred, y_true)
|
| 151 |
+
accu = sum(torch.argmax(pred,-1)==y_true)/pred.shape[0]
|
| 152 |
+
if self.perceptual_loss:
|
| 153 |
+
pred_feat = [i.feature.clone() for i in self.hooks]
|
| 154 |
+
_ = self.classifier(self.transform(d))
|
| 155 |
+
ori_feat = [i.feature.clone() for i in self.hooks]
|
| 156 |
+
perc_loss = torch.stack([nn.functional.mse_loss(p,o, reduction='none').mean((1,2,3)) for p,o in zip(pred_feat, ori_feat)])
|
| 157 |
+
perc_loss = perc_loss.mean()
|
| 158 |
+
return perc_loss
|
| 159 |
+
|
| 160 |
+
return cls_loss, accu, None
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class AverageMeter:
|
| 164 |
+
"""Compute running average."""
|
| 165 |
+
|
| 166 |
+
def __init__(self):
|
| 167 |
+
self.val = 0
|
| 168 |
+
self.avg = 0
|
| 169 |
+
self.sum = 0
|
| 170 |
+
self.count = 0
|
| 171 |
+
|
| 172 |
+
def update(self, val, n=1):
|
| 173 |
+
self.val = val
|
| 174 |
+
self.sum += val * n
|
| 175 |
+
self.count += n
|
| 176 |
+
self.avg = self.sum / self.count
|
| 177 |
+
|
| 178 |
+
class CustomDataParallel(nn.DataParallel):
|
| 179 |
+
"""Custom DataParallel to access the module methods."""
|
| 180 |
+
|
| 181 |
+
def __getattr__(self, key):
|
| 182 |
+
try:
|
| 183 |
+
return super().__getattr__(key)
|
| 184 |
+
except AttributeError:
|
| 185 |
+
return getattr(self.module, key)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def init(args):
|
| 189 |
+
base_dir = f'{args.root}/{args.exp_name}/'
|
| 190 |
+
os.makedirs(base_dir, exist_ok=True)
|
| 191 |
+
return base_dir
|
| 192 |
+
|
| 193 |
+
def setup_logger(log_dir):
|
| 194 |
+
log_formatter = logging.Formatter("%(asctime)s [%(levelname)-5.5s] %(message)s")
|
| 195 |
+
root_logger = logging.getLogger()
|
| 196 |
+
root_logger.setLevel(logging.INFO)
|
| 197 |
+
|
| 198 |
+
log_file_handler = logging.FileHandler(log_dir, encoding='utf-8')
|
| 199 |
+
log_file_handler.setFormatter(log_formatter)
|
| 200 |
+
root_logger.addHandler(log_file_handler)
|
| 201 |
+
|
| 202 |
+
log_stream_handler = logging.StreamHandler(sys.stdout)
|
| 203 |
+
log_stream_handler.setFormatter(log_formatter)
|
| 204 |
+
root_logger.addHandler(log_stream_handler)
|
| 205 |
+
|
| 206 |
+
logging.info('Logging file is %s' % log_dir)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def load_img(p, padding=True, factor=64):
|
| 210 |
+
x = Image.open(p)
|
| 211 |
+
x = torch.from_numpy(np.asarray(x))
|
| 212 |
+
if len(x.shape) == 2:
|
| 213 |
+
x = x.unsqueeze(-1).repeat(1,1,3) # h,w -> h,w,3
|
| 214 |
+
x = x.permute(2, 0, 1).unsqueeze(0).float().div(255)
|
| 215 |
+
h, w = x.shape[2:4]
|
| 216 |
+
|
| 217 |
+
if padding:
|
| 218 |
+
dh = factor * math.ceil(h / factor) - h
|
| 219 |
+
dw = factor * math.ceil(w / factor) - w
|
| 220 |
+
x = F.pad(x, (0, dw, 0, dh))
|
| 221 |
+
return x, h, w
|
| 222 |
+
|
| 223 |
+
def save_img(img: torch.Tensor, vis_path, input_p, mask=False):
|
| 224 |
+
img = img.clone().detach()
|
| 225 |
+
img = img.to(torch.device('cpu'))
|
| 226 |
+
if os.path.isdir(vis_path) is not True:
|
| 227 |
+
os.makedirs(vis_path)
|
| 228 |
+
end = '/'
|
| 229 |
+
if mask:
|
| 230 |
+
img_name = vis_path + 'mask_' + str(input_p[input_p.rfind(end):])
|
| 231 |
+
else:
|
| 232 |
+
img_name = vis_path + str(input_p[input_p.rfind(end):])
|
| 233 |
+
vutils.save_image(img, os.path.join(vis_path, img_name), nrow=8)
|
| 234 |
+
|
| 235 |
+
def train_one_epoch(model, data_loader, metrics_criterion, device,
|
| 236 |
+
optimizer, epoch, loss_scaler, log_writer, args, val_dataloader=None, stage='train'):
|
| 237 |
+
## ======================= set configs ======================= ##
|
| 238 |
+
model.train(True)
|
| 239 |
+
metric_logger = misc.MetricLogger(delimiter=" ")
|
| 240 |
+
metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
|
| 241 |
+
header = 'Epoch: [{}]'.format(epoch)
|
| 242 |
+
print_freq = 20
|
| 243 |
+
accum_iter = args.accum_iter
|
| 244 |
+
optimizer.zero_grad()
|
| 245 |
+
if log_writer is not None:
|
| 246 |
+
print('log_dir: {}'.format(log_writer.log_dir))
|
| 247 |
+
|
| 248 |
+
vis_path = os.path.join("./MIM_vbr/", stage)
|
| 249 |
+
os.makedirs(vis_path, exist_ok=True)
|
| 250 |
+
|
| 251 |
+
# tqdm_emu = tqdm.tqdm(enumerate(data_loader_train), total=len(data_loader_train), leave=False)
|
| 252 |
+
for data_iter_step, (samples, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
|
| 253 |
+
samples = samples.to(device, non_blocking=True) # samples = original image
|
| 254 |
+
|
| 255 |
+
# we use a per iteration (instead of per epoch) lr scheduler
|
| 256 |
+
if data_iter_step % accum_iter == 0:
|
| 257 |
+
lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
|
| 258 |
+
|
| 259 |
+
with torch.cuda.amp.autocast():
|
| 260 |
+
out_net = model(samples, is_training=True, manual_mask_rate=None)
|
| 261 |
+
|
| 262 |
+
rec = model.module.gen_img(out_net['logits'], out_net['token_all_mask'], out_net['token_indices'], out_net['z_H'], out_net['z_W'])
|
| 263 |
+
# rec = model.gen_img(out_net['logits'], out_net['token_all_mask'], out_net['token_indices'], out_net['z_H'], out_net['z_W'])
|
| 264 |
+
rec = rec.to(device)
|
| 265 |
+
out_criterion = metrics_criterion(samples, out_net, rec)
|
| 266 |
+
loss_value = out_criterion['total_loss'].item()
|
| 267 |
+
|
| 268 |
+
if not math.isfinite(loss_value):
|
| 269 |
+
print("Loss is {}, stopping training".format(loss_value))
|
| 270 |
+
sys.exit(1)
|
| 271 |
+
|
| 272 |
+
out_criterion['total_loss'] /= accum_iter
|
| 273 |
+
loss_scaler(out_criterion['total_loss'], optimizer, clip_grad=args.grad_clip, parameters=model.parameters(),
|
| 274 |
+
update_grad=(data_iter_step + 1) % accum_iter == 0)
|
| 275 |
+
if (data_iter_step + 1) % accum_iter == 0:
|
| 276 |
+
optimizer.zero_grad()
|
| 277 |
+
|
| 278 |
+
torch.cuda.synchronize()
|
| 279 |
+
|
| 280 |
+
metric_logger.update(loss=loss_value)
|
| 281 |
+
|
| 282 |
+
lr = optimizer.param_groups[0]["lr"]
|
| 283 |
+
metric_logger.update(lr=lr)
|
| 284 |
+
metric_logger.update(bpp=out_criterion['bpp_loss'])
|
| 285 |
+
metric_logger.update(bpp_mask=out_criterion['bpp_mask'])
|
| 286 |
+
metric_logger.update(task_loss=out_net['task_loss'].item()) # task_loss未更新,均值更新了
|
| 287 |
+
metric_logger.update(lmbda=out_net['lambda'])
|
| 288 |
+
metric_logger.update(mask_ratio=out_net['mask_ratio']) # mask_ratio未更新,均值更新了
|
| 289 |
+
metric_logger.update(lpips=out_criterion['lpips'].item()) # lpips未更新,均值更新了
|
| 290 |
+
metric_logger.update(dists=out_criterion['dists'].item())
|
| 291 |
+
|
| 292 |
+
loss_value_reduce = misc.all_reduce_mean(loss_value)
|
| 293 |
+
if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
|
| 294 |
+
""" We use epoch_1000x as the x-axis in tensorboard.
|
| 295 |
+
This calibrates different curves when batch size changes.
|
| 296 |
+
"""
|
| 297 |
+
epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
|
| 298 |
+
log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
|
| 299 |
+
log_writer.add_scalar('lr', lr, epoch_1000x)
|
| 300 |
+
|
| 301 |
+
## ======================= update progress bar & visualization ======================= ##
|
| 302 |
+
if data_iter_step % 500 == 0:
|
| 303 |
+
with torch.no_grad():
|
| 304 |
+
real_fake_images = torch.cat((samples, rec), dim=0)
|
| 305 |
+
vutils.save_image(real_fake_images, os.path.join(vis_path, f"{epoch}_{data_iter_step}.jpg"), nrow=8)
|
| 306 |
+
# vutils.save_image(samples, os.path.join(vis_path, f"{epoch}_{data_iter_step}_ori.jpg"), nrow=6)
|
| 307 |
+
# vutils.save_image(rec, os.path.join(vis_path, f"{epoch}_{data_iter_step}_rec.jpg"), nrow=6)
|
| 308 |
+
vutils.save_image(out_net['mask_vis'], os.path.join(vis_path, f"{epoch}_{data_iter_step}_mask.jpg"), nrow=8)
|
| 309 |
+
|
| 310 |
+
# if (data_iter_step % 10000 == 0) and (data_iter_step != 0):
|
| 311 |
+
# test_loss = inference(epoch, val_dataloader, model, metrics_criterion, device, 0.75, args, 'val')
|
| 312 |
+
# gather the stats from all processes
|
| 313 |
+
metric_logger.synchronize_between_processes()
|
| 314 |
+
print("Averaged stats:", metric_logger)
|
| 315 |
+
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
| 316 |
+
|
| 317 |
+
def inference(epoch, test_loader, model, metrics_criterion, device, manual_mask_ratio, args, stage='test'):
|
| 318 |
+
model.eval()
|
| 319 |
+
bpp_loss = AverageMeter()
|
| 320 |
+
bpp_mask = AverageMeter()
|
| 321 |
+
psnr = AverageMeter()
|
| 322 |
+
msssim = AverageMeter()
|
| 323 |
+
lpips = AverageMeter()
|
| 324 |
+
dists = AverageMeter()
|
| 325 |
+
test_loss = AverageMeter()
|
| 326 |
+
|
| 327 |
+
vis_path = os.path.join("./MIM_vbr/", stage)
|
| 328 |
+
os.makedirs(vis_path, exist_ok=True)
|
| 329 |
+
if stage == 'test':
|
| 330 |
+
test_vis_path = os.path.join("/home/t2vg-a100-G4-10/project/qyp/rec_fid", manual_mask_ratio)
|
| 331 |
+
os.makedirs(test_vis_path, exist_ok=True)
|
| 332 |
+
|
| 333 |
+
with torch.no_grad():
|
| 334 |
+
# tqdm_meter = tqdm.tqdm(eval_path, leave=False)
|
| 335 |
+
tqdm_meter = tqdm.tqdm(enumerate(test_loader),leave=False, total=len(test_loader))
|
| 336 |
+
for i, (d, l) in tqdm_meter:
|
| 337 |
+
d = d.to(device)
|
| 338 |
+
# l = l.to(device)
|
| 339 |
+
# for input_p in eval_path:
|
| 340 |
+
# x, hx, wx = load_img(input_p, padding=True, factor=64)
|
| 341 |
+
# x = x.to(device)
|
| 342 |
+
out_net = model(d, is_training=False, manual_mask_rate=manual_mask_ratio)
|
| 343 |
+
rec = model.module.gen_img(out_net['logits'], out_net['token_all_mask'], out_net['token_indices'], out_net['z_H'], out_net['z_W'])
|
| 344 |
+
|
| 345 |
+
# x = x[:, :, :hx, :wx]
|
| 346 |
+
# rec = rec[:, :, :hx, :wx]
|
| 347 |
+
rec = rec.to(device)
|
| 348 |
+
out_criterion = metrics_criterion(d, out_net, rec)
|
| 349 |
+
|
| 350 |
+
bpp_loss.update(out_criterion["bpp_loss"])
|
| 351 |
+
bpp_mask.update(out_criterion["bpp_mask"])
|
| 352 |
+
psnr.update(out_criterion['psnr'])
|
| 353 |
+
msssim.update(out_criterion['msssim'])
|
| 354 |
+
lpips.update(out_criterion['lpips'])
|
| 355 |
+
dists.update(out_criterion['dists'])
|
| 356 |
+
test_loss.update(out_criterion['total_loss'])
|
| 357 |
+
|
| 358 |
+
## ======================= update progress bar & visualization ======================= ##
|
| 359 |
+
if stage == 'val':
|
| 360 |
+
if i % 5 == 0:
|
| 361 |
+
with torch.no_grad():
|
| 362 |
+
real_fake_images = torch.cat((d, rec), dim=0)
|
| 363 |
+
vutils.save_image(real_fake_images, os.path.join(vis_path, f"{epoch}_{i}.jpg"), nrow=8)
|
| 364 |
+
vutils.save_image(out_net['mask_vis'], os.path.join(vis_path, f"{epoch}_{i}_mask.jpg"), nrow=8)
|
| 365 |
+
if stage == 'test':
|
| 366 |
+
with torch.no_grad():
|
| 367 |
+
vutils.save_image(rec, os.path.join(test_vis_path, f"{i}.jpg"), nrow=8)
|
| 368 |
+
# txt = f"Rec Loss:{test_loss.avg:.4f}|Bpp:{bpp_loss.avg:.4f}|lpips:{lpips.avg:.4f}|msssim:{msssim.avg:.4f}|dists:{dists.avg:.4f}|psnr:{psnr.avg:.4f}\n"
|
| 369 |
+
# tqdm_meter.set_postfix_str(txt)
|
| 370 |
+
|
| 371 |
+
model.train()
|
| 372 |
+
|
| 373 |
+
# 假设其它变量和环境已经正确设置
|
| 374 |
+
if torch.distributed.is_initialized():
|
| 375 |
+
rank = dist.get_rank()
|
| 376 |
+
else:
|
| 377 |
+
rank = 0 # 假设未使用DDP,则默认为单进程模式,rank为0
|
| 378 |
+
|
| 379 |
+
if rank == 0:
|
| 380 |
+
log_txt = f"{epoch}|bpp:{bpp_loss.avg.item():.5f}|mask:{bpp_mask.avg:.5f}|mask_ratio:{manual_mask_ratio}|psnr:{psnr.avg.item():.5f}|msssim:{msssim.avg.item():.5f}|lpips:{lpips.avg.item():.5f}|dists:{dists.avg.item():.5f}|Test loss:{test_loss.avg.item():.5f}"
|
| 381 |
+
logging.info(log_txt)
|
| 382 |
+
return test_loss.avg
|
| 383 |
+
|
| 384 |
+
def inference_with_acc(epoch, test_loader, model, metrics_criterion, cls_criterion, device, manual_mask_ratio, args, stage='test'):
|
| 385 |
+
model.eval()
|
| 386 |
+
bpp_loss = AverageMeter()
|
| 387 |
+
bpp_mask = AverageMeter()
|
| 388 |
+
psnr = AverageMeter()
|
| 389 |
+
msssim = AverageMeter()
|
| 390 |
+
lpips = AverageMeter()
|
| 391 |
+
dists = AverageMeter()
|
| 392 |
+
accuracy = AverageMeter()
|
| 393 |
+
test_loss = AverageMeter()
|
| 394 |
+
|
| 395 |
+
if stage == 'test':
|
| 396 |
+
# test_vis_path = os.path.join("/home/v-ruoyufeng/v-ruoyufeng/qyp/rec_fid", manual_mask_ratio)
|
| 397 |
+
test_vis_path = os.path.join("/home/v-ruoyufeng/v-ruoyufeng/qyp/test_pos", str(manual_mask_ratio))
|
| 398 |
+
os.makedirs(test_vis_path, exist_ok=True)
|
| 399 |
+
|
| 400 |
+
with torch.no_grad():
|
| 401 |
+
# tqdm_meter = tqdm.tqdm(eval_path, leave=False)
|
| 402 |
+
tqdm_meter = tqdm.tqdm(enumerate(test_loader),leave=False, total=len(test_loader))
|
| 403 |
+
for i, (d, l) in tqdm_meter:
|
| 404 |
+
d = d.to(device)
|
| 405 |
+
# l = l.to(device)
|
| 406 |
+
# for input_p in eval_path:
|
| 407 |
+
# x, hx, wx = load_img(input_p, padding=True, factor=64)
|
| 408 |
+
# x = x.to(device)
|
| 409 |
+
# _, _, hx, wx = d.shape
|
| 410 |
+
out_net = model(d, is_training=False, manual_mask_rate=manual_mask_ratio)
|
| 411 |
+
rec = model.module.gen_img(out_net['logits'], out_net['token_all_mask'], out_net['token_indices'], out_net['z_H'], out_net['z_W'], num_iter=12)
|
| 412 |
+
|
| 413 |
+
# x = x[:, :, :hx, :wx]
|
| 414 |
+
# rec = rec[:, :, :hx, :wx]
|
| 415 |
+
rec = rec.to(device)
|
| 416 |
+
out_criterion = metrics_criterion(d, out_net, rec)
|
| 417 |
+
_, accu, _ = cls_criterion(d, rec, l)
|
| 418 |
+
|
| 419 |
+
bpp_loss.update(out_criterion["bpp_loss"])
|
| 420 |
+
bpp_mask.update(out_criterion["bpp_mask"])
|
| 421 |
+
psnr.update(out_criterion['psnr'])
|
| 422 |
+
msssim.update(out_criterion['msssim'])
|
| 423 |
+
lpips.update(out_criterion['lpips'])
|
| 424 |
+
dists.update(out_criterion['dists'])
|
| 425 |
+
test_loss.update(out_criterion['total_loss'])
|
| 426 |
+
accuracy.update(accu)
|
| 427 |
+
|
| 428 |
+
## ======================= update progress bar & visualization ======================= ##
|
| 429 |
+
if stage == 'test':
|
| 430 |
+
with torch.no_grad():
|
| 431 |
+
vutils.save_image(rec, os.path.join(test_vis_path, f"{i}.jpg"))
|
| 432 |
+
# txt = f"Rec Loss:{test_loss.avg:.4f}|Bpp:{bpp_loss.avg:.4f}|lpips:{lpips.avg:.4f}|msssim:{msssim.avg:.4f}|dists:{dists.avg:.4f}|psnr:{psnr.avg:.4f}\n"
|
| 433 |
+
# tqdm_meter.set_postfix_str(txt)
|
| 434 |
+
|
| 435 |
+
model.train()
|
| 436 |
+
|
| 437 |
+
# 假设其它变量和环境已经正确设置
|
| 438 |
+
if torch.distributed.is_initialized():
|
| 439 |
+
rank = dist.get_rank()
|
| 440 |
+
else:
|
| 441 |
+
rank = 0 # 假设未使用DDP,则默认为单进程模式,rank为0
|
| 442 |
+
|
| 443 |
+
if rank == 0:
|
| 444 |
+
log_txt = f"{epoch}|bpp:{bpp_loss.avg.item():.5f}|mask:{bpp_mask.avg:.5f}|mask_ratio:{manual_mask_ratio}|psnr:{psnr.avg.item():.5f}|msssim:{msssim.avg.item():.5f}|lpips:{lpips.avg.item():.5f}|dists:{dists.avg.item():.5f}|accu:{accuracy.avg:.5f}|Test loss:{test_loss.avg.item():.5f}"
|
| 445 |
+
logging.info(log_txt)
|
| 446 |
+
return test_loss.avg
|
| 447 |
+
|
| 448 |
+
def save_checkpoint(state, is_best, base_dir, filename="checkpoint.pth.tar"):
    """Persist a training checkpoint, duplicating it as the 'best' snapshot when flagged.

    Paths are built by plain string concatenation, so ``base_dir`` is expected
    to end with a path separator (as produced by ``init``).
    """
    torch.save(state, base_dir + filename)
    if not is_best:
        return
    # Keep a stable-named copy of the best-performing checkpoint.
    torch.save(state, base_dir + "checkpoint_best.pth.tar")
|
| 453 |
+
|
| 454 |
+
def parse_args(argv):
    """Parse CLI arguments, using a YAML config file to supply defaults.

    Parsing happens in two passes: first only the base options (including
    -c/--config) are read via parse_known_args, then the YAML file's keys are
    installed as argparse defaults and the remaining arguments are parsed on
    top of them — so the command line always overrides the config file.
    """
    parser = argparse.ArgumentParser(description="Example training script.")
    parser.add_argument(
        "-c",
        "--config",
        default="config/vpt_default.yaml",
        help="Path to config file",
    )
    parser.add_argument(
        '--name',
        default=datetime.now().strftime('%Y-%m-%d_%H_%M_%S'),
        type=str,
        help='Result dir name',
    )
    parser.add_argument('--lr', type=float, default=None, metavar='LR',
                        help='learning rate (absolute lr)')
    # First pass: pick up --config; everything else stays in `remaining`.
    given_configs, remaining = parser.parse_known_args(argv)
    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local-rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')
    # Every key in the YAML file becomes an argparse default.
    with open(given_configs.config) as file:
        yaml_data= yaml.safe_load(file)
    parser.set_defaults(**yaml_data)

    parser.add_argument(
        "-T",
        "--TEST",
        # NOTE(review): without action='store_true', any value passed on the
        # command line (even "--TEST False") is a non-empty string and is
        # therefore truthy in `if args.TEST`.
        default=False,
        help='Testing'
    )
    args = parser.parse_args(remaining)
    return args
|
| 491 |
+
|
| 492 |
+
def load_eval_ps(eval_path):
    """Return the paths of all .png files directly under `eval_path`, sorted by name."""
    pattern = os.path.join(eval_path, '*.png')
    return sorted(glob.glob(pattern))
|
| 495 |
+
|
| 496 |
+
def main(argv):
    """Entry point: set up distributed training, data and model, then train/evaluate.

    NOTE(review): this source was recovered from a diff rendering; the block
    nesting at a few points (dataset-split placement, `if is_best`) is the
    most plausible reading and should be verified against the original file.
    """
    args = parse_args(argv)
    base_dir = init(args)  # create the base dir for saving the results

    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
        args.log_dir = args.output_dir

    misc.init_distributed_mode(args)

    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))

    device = torch.device(args.device)
    # fix the seed for reproducibility (offset by rank so workers differ)
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    cudnn.benchmark = True

    setup_logger(base_dir + '/' + time.strftime('%Y%m%d_%H%M%S') + '.log')
    msg = f'======================= {args.name} ======================='
    logging.info(msg)
    for k in args.__dict__:
        logging.info(k + ':' + str(args.__dict__[k]))
    logging.info('=' * len(msg))

    ## ======================= prepare dataset ======================= ##
    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(args.input_size, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()])
    transform_test = transforms.Compose(  # crop to 256x256
        [transforms.Resize(256), transforms.CenterCrop(256), transforms.ToTensor()]
    )

    if args.dataset=='imagenet':
        train_dataset = torchvision.datasets.ImageFolder(os.path.join(args.dataset_path, "train"), transform=transform_train)
        test_dataset = torchvision.datasets.ImageFolder(os.path.join(args.dataset_path, "val"), transform=transform_test)

    # 2000-image validation subset; training set sharded into 33 chunks, one
    # chunk per epoch. The hard-coded sizes assume the full ImageNet splits.
    val_dataset, _ = torch.utils.data.random_split(test_dataset, [2000, 48000])
    small_train_datasets = torch.utils.data.random_split(train_dataset, [40000]*32+[1167])
    eval_path = sorted(glob.glob(os.path.join(args.eval_path, '*.png')))

    # NOTE(review): this string assignment overwrites the torch.device created above.
    device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"

    # if args.distributed:
    if True:  # NOTE(review): the non-distributed sampler branch below is unreachable
        num_tasks = misc.get_world_size()
        global_rank = misc.get_rank()
        sampler_test = torch.utils.data.DistributedSampler(
            test_dataset, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )
        sampler_val = torch.utils.data.DistributedSampler(
            val_dataset, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )
    else:
        sampler_train = torch.utils.data.RandomSampler(train_dataset)

    # Only rank 0 writes tensorboard logs.
    if global_rank == 0 and args.log_dir is not None:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = SummaryWriter(log_dir=args.log_dir)
    else:
        log_writer = None

    val_dataloader = DataLoader(val_dataset, sampler=sampler_val, batch_size=args.test_batch_size,
                                num_workers=args.num_workers, shuffle=False, pin_memory=args.pin_mem, drop_last=True)
    test_dataloader = DataLoader(test_dataset, sampler=sampler_test, batch_size=1,
                                 num_workers=args.num_workers, shuffle=False, pin_memory=args.pin_mem,)

    ## ======================= prepare model ======================= ##
    vqgan_ckpt_path = '/home/t2vg-a100-G4-10/project/qyp/mage/vqgan_jax_strongaug.ckpt'
    model = models_mage_codec_rope.__dict__[args.model](mask_ratio_mu=args.mask_ratio_mu, mask_ratio_std=args.mask_ratio_std,
                                                        mask_ratio_min=args.mask_ratio_min, mask_ratio_max=args.mask_ratio_max,
                                                        vqgan_ckpt_path=vqgan_ckpt_path)

    model.to(device)
    model_without_ddp = model
    print("Model = %s" % str(model_without_ddp))
    eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
    if args.lr is None:  # only base_lr is specified; derive absolute lr
        args.lr = args.blr * eff_batch_size / 256
    print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
    print("actual lr: %.2e" % args.lr)

    print("accumulate grad iterations: %d" % args.accum_iter)
    print("effective batch size: %d" % eff_batch_size)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module

    # following timm: set wd as 0 for bias and norm layers
    param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay)
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
    print(optimizer)
    loss_scaler = NativeScaler()

    # resume from a checkpoint (strict=False tolerates missing/extra keys)
    misc.load_model(args=args, model_without_ddp=model_without_ddp,
                    optimizer=optimizer, loss_scaler=loss_scaler, strict=False)

    metrics_criterion = CalMetrics()
    cls_criterion = Clsloss(device, True)
    ## ======================= Start Training ======================= ##
    last_epoch = args.start_epoch

    if args.TEST:  # test only: run accuracy-aware inference and exit
        best_loss = float("inf")
        tqrange = tqdm.trange(last_epoch, args.epochs)
        for manual_mask_ratio in [0.3]:
            loss = inference_with_acc(-1, test_dataloader, model, metrics_criterion, cls_criterion, device, manual_mask_ratio, args, 'test')
        return

    ## ======================= pre validation ======================= ##
    print("############## pre validation ##############")
    best_loss = float("inf")
    tqrange = tqdm.trange(last_epoch, args.epochs)
    val_mask_ratio = 0.75
    test_loss = inference(-1, val_dataloader, model, metrics_criterion, device, val_mask_ratio, args, 'val')

    ## ======================= start training ======================= ##
    print(f"############## Start training for {args.epochs} epochs ##############")
    start_time = time.time()
    for epoch in tqrange:
        # rotate through the pre-split training shards, one shard per epoch
        current_dataset = small_train_datasets[epoch % len(small_train_datasets)]
        sampler_train = torch.utils.data.DistributedSampler(current_dataset, shuffle=True)
        data_loader_train = DataLoader(
            current_dataset, sampler=sampler_train,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            pin_memory=args.pin_mem,
            drop_last=True,
        )
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)
        train_stats = train_one_epoch(model, data_loader_train, metrics_criterion, device,
                                      optimizer, epoch, loss_scaler, log_writer=log_writer, args=args, val_dataloader=val_dataloader, stage='train')

        test_loss = inference(epoch, val_dataloader, model, metrics_criterion, device, val_mask_ratio, args, 'val')

        is_best = test_loss < best_loss
        best_loss = min(test_loss, best_loss)

        # periodic checkpoint every 10 epochs plus the final epoch
        if args.output_dir and (epoch % 10 == 0 or epoch + 1 == args.epochs):
            misc.save_model(
                args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
                loss_scaler=loss_scaler, epoch=epoch)
        if is_best:
            misc.save_model_last(
                args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
                loss_scaler=loss_scaler, epoch=epoch, is_best=is_best)
        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     'epoch': epoch,}
        # only the main process appends to the JSON-lines log
        if args.output_dir and misc.is_main_process():
            if log_writer is not None:
                log_writer.flush()
            with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    # NOTE(review): if this module imports `from datetime import datetime`,
    # `datetime.timedelta` raises AttributeError here — confirm the import style.
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
|
| 669 |
+
|
| 670 |
+
|
| 671 |
+
|
| 672 |
+
if __name__ == "__main__":
    # Entry point: forward the CLI arguments (minus the program name) to main().
    main(sys.argv[1:])
|
detection.py
ADDED
|
@@ -0,0 +1,635 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tqdm
|
| 2 |
+
import argparse
|
| 3 |
+
import math
|
| 4 |
+
# import torchac
|
| 5 |
+
import sys
|
| 6 |
+
import os
|
| 7 |
+
import time
|
| 8 |
+
import logging
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.optim as optim
|
| 14 |
+
|
| 15 |
+
import torchvision
|
| 16 |
+
from torch.utils.data import DataLoader
|
| 17 |
+
from torchvision import transforms
|
| 18 |
+
from torchvision.models import resnet50
|
| 19 |
+
|
| 20 |
+
import yaml
|
| 21 |
+
from pytorch_msssim import ms_ssim
|
| 22 |
+
from DISTS_pytorch import DISTS
|
| 23 |
+
from util.lpips import LPIPS
|
| 24 |
+
from torch.nn import functional as F
|
| 25 |
+
from torchvision import utils as vutils
|
| 26 |
+
import numpy as np
|
| 27 |
+
|
| 28 |
+
import util.misc as misc
|
| 29 |
+
import util.lr_sched as lr_sched
|
| 30 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 31 |
+
import models_mage_codec_high_resolu
|
| 32 |
+
import timm.optim.optim_factory as optim_factory
|
| 33 |
+
from util.misc import NativeScalerWithGradNormCount as NativeScaler
|
| 34 |
+
from collections import OrderedDict
|
| 35 |
+
import pickle
|
| 36 |
+
import torch.backends.cudnn as cudnn
|
| 37 |
+
from pathlib import Path
|
| 38 |
+
import random
|
| 39 |
+
import torch.distributed as dist
|
| 40 |
+
from util.dataloader import MSCOCO, Kodak, prepadding
|
| 41 |
+
from util.utils import adaptively_split_and_pad, crop_and_reconstruct
|
| 42 |
+
from util.alignment import Alignment
|
| 43 |
+
|
| 44 |
+
## General
|
| 45 |
+
from detectron2.config import get_cfg
|
| 46 |
+
from detectron2.layers import ShapeSpec
|
| 47 |
+
from detectron2.modeling.backbone.fpn import build_resnet_fpn_backbone
|
| 48 |
+
|
| 49 |
+
## Test
|
| 50 |
+
from detectron2.evaluation import COCOEvaluator
|
| 51 |
+
from detectron2.data.datasets import register_coco_instances
|
| 52 |
+
from detectron2.data import build_detection_test_loader
|
| 53 |
+
from detectron2.data.detection_utils import read_image
|
| 54 |
+
|
| 55 |
+
from contextlib import ExitStack, contextmanager
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
## Context manager: temporarily switch the model to eval mode and restore its original train/eval state once the enclosed operation completes
|
| 59 |
+
@contextmanager
def inference_context(model):
    """Temporarily switch `model` to eval mode, restoring its previous mode on exit.

    Fix: the restore now runs in a ``finally`` block, so the original
    train/eval state survives exceptions raised inside the ``with`` body
    (the previous version left the model stuck in eval mode on error).
    """
    training_mode = model.training
    model.eval()
    try:
        yield
    finally:
        model.train(training_mode)
|
| 65 |
+
|
| 66 |
+
class CalMetrics(nn.Module):
    """Calculate BPP, PSNR, MS-SSIM, LPIPS and DISTS for the reconstructed image."""

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def bpp_loss(self, ori, out_net):
        """Rate of the entropy-coded tokens plus the rate of the transmitted mask."""
        batch, _, height, width = ori.shape
        num_pixels = batch * height * width
        # bits for the token likelihoods: -sum(log2 p) / pixels
        bpp = torch.log(out_net["likelihoods"]).sum() / (-math.log(2) * num_pixels)
        # the mask is transmitted as a raw byte string: 8 bits per byte
        mask_bits = len(out_net['bs_mask_token']) * 8
        bpp_mask = mask_bits / num_pixels
        return bpp, bpp_mask

    def psnr(self, rec, ori):
        """PSNR in dB for signals in [0, 1]; identical inputs report a fixed 100."""
        mse = torch.mean((rec - ori) ** 2)
        if mse == 0:
            return 100
        return 10 * torch.log10(1. / mse)

    def lpips(self, rec, ori):
        # NOTE(review): the LPIPS network is rebuilt on every call, which is slow;
        # consider caching it on the instance.
        measure = LPIPS().eval().to(device=rec.device)
        return measure(rec, ori).mean()

    def dists(self, rec, ori):
        # NOTE(review): DISTS is likewise re-instantiated per call.
        measure = DISTS().cuda()
        return measure(rec, ori).mean()

    def cal_total_loss(self, lpips, bpp, out_net):
        # `lpips` is currently unused: total = bpp + lambda * task_loss
        return bpp + out_net['lambda'] * out_net['task_loss']

    def forward(self, ori, out_net, rec=None):
        """Return a dict of rate metrics, plus distortion metrics when `rec` is given."""
        out = {}
        out["bpp"], out["bpp_mask"] = self.bpp_loss(ori, out_net)
        out["bpp_loss"] = out["bpp"] + out["bpp_mask"]
        if rec is not None:
            rec_c = torch.clamp(rec, 0, 1)
            out["psnr"] = self.psnr(rec_c, ori)
            out["msssim"] = ms_ssim(rec_c, ori, data_range=1, size_average=True)
            out["lpips"] = self.lpips(rec_c, ori)
            out["dists"] = self.dists(rec_c, ori)
            out["total_loss"] = self.cal_total_loss(out["lpips"], out["bpp_loss"], out_net)
        return out
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class TaskLoss(nn.Module):
    """Feature-space distortion loss against a frozen detectron2 ResNet-FPN backbone.

    Compares the FPN pyramid features (p2..p6) of the reconstruction with
    those of the original image. Backbone weights come from a pickled
    detectron2 checkpoint and are frozen.
    """
    def __init__(self, cfg, device) -> None:
        super().__init__()
        self.ce = nn.CrossEntropyLoss()  # NOTE(review): defined but never used below
        self.task_net = build_resnet_fpn_backbone(cfg, ShapeSpec(channels=3))
        checkpoint = OrderedDict()
        # detectron2 model-zoo weights are a pickled dict of numpy arrays;
        # keep only the 'backbone.*' entries and strip that prefix.
        with open(cfg.MODEL.WEIGHTS, 'rb') as f:
            FPN_ckpt = pickle.load(f)
        for k, v in FPN_ckpt['model'].items():
            if 'backbone' in k:
                checkpoint['.'.join(k.split('.')[1:])] = torch.from_numpy(v)
        self.task_net.load_state_dict(checkpoint, strict=True)
        self.task_net = self.task_net.to(device)
        # freeze the perceptual backbone
        for k, p in self.task_net.named_parameters():
            p.requires_grad = False
        self.task_net.eval()
        # Alignment pads/crops inputs to a multiple of 32 for the FPN.
        self.align = Alignment(divisor=32).to(device)
        # detectron2 BGR pixel mean
        self.pixel_mean = torch.Tensor([103.530, 116.280, 123.675]).view(-1, 1, 1).to(device)

    def forward(self, output, d, train_mode=False):
        """Mean MSE between FPN features of `d` (ground truth) and `output['x_hat']`.

        NOTE(review): block nesting recovered from a mangled source — the
        ground-truth branch is assumed to be the only part under no_grad so
        gradients can flow through the reconstruction; confirm upstream.
        """
        with torch.no_grad():
            ## Ground truth features for the perceptual loss
            d = d.flip(1).mul(255)  # RGB to BGR, [0,1] to [0,255]
            d = d - self.pixel_mean
            if not train_mode:
                d = self.align.align(d)
            gt_out = self.task_net(d)

        x_hat = torch.clamp(output["x_hat"], 0, 1)
        x_hat = x_hat.flip(1).mul(255)
        x_hat = x_hat - self.pixel_mean
        if not train_mode:
            x_hat = self.align.align(x_hat)
        task_net_out = self.task_net(x_hat)

        # per-level feature MSE, averaged and scaled by 1/5
        distortion_p2 = nn.MSELoss(reduction='none')(gt_out["p2"], task_net_out["p2"])
        distortion_p3 = nn.MSELoss(reduction='none')(gt_out["p3"], task_net_out["p3"])
        distortion_p4 = nn.MSELoss(reduction='none')(gt_out["p4"], task_net_out["p4"])
        distortion_p5 = nn.MSELoss(reduction='none')(gt_out["p5"], task_net_out["p5"])
        distortion_p6 = nn.MSELoss(reduction='none')(gt_out["p6"], task_net_out["p6"])

        return 0.2*(distortion_p2.mean()+distortion_p3.mean()+distortion_p4.mean()+distortion_p5.mean()+distortion_p6.mean())
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
class AverageMeter:
    """Track a running (weighted) average of a scalar quantity."""

    def __init__(self):
        self.val = 0    # most recent value
        self.avg = 0    # running mean
        self.sum = 0    # weighted sum of observed values
        self.count = 0  # total weight seen so far

    def update(self, val, n=1):
        """Record `val` observed with weight `n` and refresh the average."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def init(args):
    """Create (if needed) and return the experiment output directory, with trailing slash."""
    base_dir = '{}/{}/'.format(args.root, args.exp_name)
    os.makedirs(base_dir, exist_ok=True)
    return base_dir
|
| 187 |
+
|
| 188 |
+
def setup_logger(log_dir):
    """Route INFO-level logging of the root logger to a file at `log_dir` and to stdout."""
    formatter = logging.Formatter("%(asctime)s [%(levelname)-5.5s] %(message)s")
    root = logging.getLogger()
    root.setLevel(logging.INFO)

    # attach the file handler first, then the stdout handler
    for handler in (logging.FileHandler(log_dir, encoding='utf-8'),
                    logging.StreamHandler(sys.stdout)):
        handler.setFormatter(formatter)
        root.addHandler(handler)

    logging.info('Logging file is %s' % log_dir)
|
| 202 |
+
|
| 203 |
+
def save_img(img: torch.Tensor, vis_path, input_p, mask=False):
    """Save tensor `img` under `vis_path`, named after the basename of `input_p`.

    When `mask` is True the file name gets a 'mask_' prefix. The output
    directory is created on demand.

    Fix: the previous path construction sliced `input_p` from its last '/'
    (keeping the slash) and concatenated `vis_path` into the file name, then
    joined with `vis_path` again — producing nonexistent nested paths like
    'out/out/mask_/img.png'. The name is now just the input's basename.
    """
    img = img.clone().detach().to(torch.device('cpu'))
    os.makedirs(vis_path, exist_ok=True)
    name = os.path.basename(input_p)
    if mask:
        name = 'mask_' + name
    vutils.save_image(img, os.path.join(vis_path, name), nrow=8)
|
| 214 |
+
|
| 215 |
+
def train_one_epoch(model, data_loader, metrics_criterion, device,
                    optimizer, epoch, loss_scaler, log_writer, args, val_dataloader=None, stage='train'):
    """One epoch of AMP training with gradient accumulation and periodic visualization.

    NOTE(review): this function is shadowed by a later `train_one_epoch`
    definition in this module with a different signature, so as written it is
    unreachable through the module-level name.
    """
    ## ======================= set configs ======================= ##
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 20
    accum_iter = args.accum_iter
    optimizer.zero_grad()
    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    # visualizations go to a fixed directory, grouped by stage
    vis_path = os.path.join("./MIM_vbr/", stage)
    os.makedirs(vis_path, exist_ok=True)

    for data_iter_step, (samples, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        samples = samples.to(device, non_blocking=True)  # samples = original image

        # we use a per iteration (instead of per epoch) lr scheduler
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)

        # NOTE(review): nesting recovered from a mangled source — only the
        # forward pass is assumed to be under autocast; confirm upstream.
        with torch.cuda.amp.autocast():
            out_net = model(samples, is_training=True, manual_mask_rate=None)

        # decode the masked-token logits back to an image (model wrapped in DDP)
        rec = model.module.gen_img(out_net['logits'], out_net['token_all_mask'], out_net['token_indices'])
        rec = rec.to(device)
        out_criterion = metrics_criterion(samples, out_net, rec)
        loss_value = out_criterion['total_loss'].item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # gradient accumulation: scale the loss and step every `accum_iter` iters
        out_criterion['total_loss'] /= accum_iter
        loss_scaler(out_criterion['total_loss'], optimizer, clip_grad=args.grad_clip, parameters=model.parameters(),
                    update_grad=(data_iter_step + 1) % accum_iter == 0)
        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()

        torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)

        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)
        metric_logger.update(bpp=out_criterion['bpp_loss'])
        metric_logger.update(bpp_mask=out_criterion['bpp_mask'])
        # per-step values; the metric logger maintains the running average
        metric_logger.update(task_loss=out_net['task_loss'].item())
        metric_logger.update(lmbda=out_net['lambda'])
        metric_logger.update(mask_ratio=out_net['mask_ratio'])
        metric_logger.update(lpips=out_criterion['lpips'].item())
        metric_logger.update(dists=out_criterion['dists'].item())

        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
            """ We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('lr', lr, epoch_1000x)

        ## ======================= update progress bar & visualization ======================= ##
        if data_iter_step % 1000 == 0:
            with torch.no_grad():
                real_fake_images = torch.cat((samples, rec), dim=0)
                vutils.save_image(real_fake_images, os.path.join(vis_path, f"{epoch}_{data_iter_step}.jpg"), nrow=8)
                vutils.save_image(out_net['mask_vis'], os.path.join(vis_path, f"{epoch}_{data_iter_step}_mask.jpg"), nrow=8)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
| 296 |
+
|
| 297 |
+
def train_one_epoch(train_dataloader, optimizer, model, criterion_rd, criterion_task, lmbda):
    """Run one optimization pass over `train_dataloader`.

    The objective is the perceptual (task) loss plus `lmbda` times the bpp
    term from the rate-distortion criterion.

    NOTE(review): this definition shadows an earlier `train_one_epoch` in
    this module; only this one is reachable via the module-level name.
    """
    model.train()
    device = next(model.parameters()).device
    progress = tqdm.tqdm(enumerate(train_dataloader), total=len(train_dataloader), leave=False)

    for i, d in progress:
        d = d.to(device)
        optimizer.zero_grad()

        out_net = model(d)
        out_criterion = criterion_rd(out_net, d)
        perc_loss = criterion_task(out_net, d)
        total_loss = perc_loss + lmbda * out_criterion['bpp_loss']

        total_loss.backward()
        optimizer.step()

        update_txt = f'[{i*len(d)}/{len(train_dataloader.dataset)}] | Loss: {total_loss.item():.3f} | Distortion loss: {perc_loss.item():.5f} | Bpp loss: {out_criterion["bpp_loss"].item():.4f}'
        progress.set_postfix_str(update_txt, refresh=True)
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def validation_epoch(epoch, val_dataloader, model, criterion_rd, criterion_task, lmbda):
    """Evaluate rate-distortion + perceptual loss on the validation set.

    Returns the running-average total loss (perceptual + lmbda * bpp).
    """
    model.eval()
    device = next(model.parameters()).device

    bpp_loss = AverageMeter()
    mse_loss = AverageMeter()
    psnr = AverageMeter()
    percloss = AverageMeter()
    totalloss = AverageMeter()

    with torch.no_grad():
        tqdm_meter = tqdm.tqdm(enumerate(val_dataloader),leave=False, total=len(val_dataloader))
        for i, d in tqdm_meter:
            # NOTE(review): a new Alignment module is constructed for every batch.
            align = Alignment(divisor=256, mode='resize').to(device)

            d = d.to(device)
            align_d = align.align(d)  # resize so H and W are multiples of 256

            out_net = model(align_d)
            # undo the alignment resize and clamp to valid pixel range
            out_net['x_hat'] = align.resume(out_net['x_hat']).clamp_(0, 1)
            out_criterion = criterion_rd(out_net, d)
            perc_loss = criterion_task(out_net, d)
            total_loss = perc_loss + lmbda * out_criterion['bpp_loss']

            bpp_loss.update(out_criterion["bpp_loss"])
            mse_loss.update(out_criterion["mse_loss"])
            psnr.update(out_criterion['psnr'])
            percloss.update(perc_loss)
            totalloss.update(total_loss)

            txt = f"Loss: {totalloss.avg:.3f} | MSE loss: {mse_loss.avg:.5f} | Perception loss: {percloss.avg:.4f} | Bpp loss: {bpp_loss.avg:.4f}"
            tqdm_meter.set_postfix_str(txt)

    # restore training mode before returning to the training loop
    model.train()
    print(f"Epoch: {epoch} | bpp loss: {bpp_loss.avg:.5f} | psnr: {psnr.avg:.5f}")
    return totalloss.avg
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
def test_epoch(test_dataloader, model, criterion_rd, predictor, evaluator):
    """Evaluate compression quality and downstream detection on a test set.

    For each COCO-style batch the image is compressed/reconstructed, rate and
    PSNR are accumulated, and the reconstruction is fed to a Mask R-CNN style
    ``predictor`` whose outputs are collected by ``evaluator``.

    Args:
        test_dataloader: detectron-style loader; each batch item carries
            ``file_name`` and a uint8 CHW ``image`` tensor.
        model: compression model under evaluation.
        criterion_rd: rate-distortion criterion producing ``bpp_loss``/``psnr``.
        predictor: detection predictor with a ``.model`` attribute.
        evaluator: COCO evaluator accumulating per-image predictions.
    """
    model.eval()
    device = next(model.parameters()).device
    # BGR per-channel means subtracted before the detector (detectron2 convention).
    pixel_mean = torch.Tensor([103.530, 116.280, 123.675]).view(-1, 1, 1).to(device)

    bpp_loss = AverageMeter()
    psnr = AverageMeter()

    with torch.no_grad():
        # Hoisted out of the loop: both aligners are batch-independent.
        align = Alignment(divisor=256, mode='resize').to(device)
        rcnn_align = Alignment(divisor=32).to(device)

        tqdm_meter = tqdm.tqdm(enumerate(test_dataloader), leave=False, total=len(test_dataloader))
        for i, batch in tqdm_meter:
            with ExitStack() as stack:
                ## model to eval()
                if isinstance(predictor.model, nn.Module):
                    # inference_context: put the predictor's model in eval mode
                    # for this iteration (restored when the stack unwinds).
                    stack.enter_context(inference_context(predictor.model))
                stack.enter_context(torch.no_grad())

                img = read_image(batch[0]["file_name"], format="BGR")
                # BGR uint8 -> RGB float in [0, 1]; flip(1) swaps the channel order.
                d = torch.stack([batch[0]['image'].float().div(255)]).flip(1).to(device)
                align_d = align.align(d)

                out_net = model(align_d)
                out_net['x_hat'] = align.resume(out_net['x_hat']).clamp_(0, 1)
                out_criterion = criterion_rd(out_net, d)

                # Back to BGR in [0, 255], mean-subtracted and 32-aligned for the R-CNN.
                trand_y_tilde = out_net['x_hat'].flip(1).mul(255)
                trand_y_tilde = rcnn_align.align(trand_y_tilde - pixel_mean)

                bpp_loss.update(out_criterion["bpp_loss"])
                psnr.update(out_criterion['psnr'])

                ## MaskRCNN
                predictions = predictor(img, trand_y_tilde)
                evaluator.process(batch, [predictions])
                txt = f"Bpp loss: {bpp_loss.avg:.4f} | PSNR loss: {psnr.avg:.4f}"
                tqdm_meter.set_postfix_str(txt)

    # evaluate() aggregates/prints the detection metrics; its return value
    # was never used, so the unused binding is dropped.
    evaluator.evaluate()
    model.train()
    print(f"bpp loss: {bpp_loss.avg:.5f} | psnr: {psnr.avg:.5f}")
    return
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
def inference(epoch, test_loader, model, metrics_criterion, device, manual_mask_ratio, args, stage='test'):
    """Run masked-image-modeling codec inference over a loader and log metrics.

    Large images are split into model-sized patches, reconstructed token-by-token,
    then stitched back together. For stage 'val' this saves original/reconstruction
    pairs plus the mask visualization; for stage 'test' it saves reconstructions
    (e.g. for FID). Returns the running-average total loss.
    """
    model.eval()
    bpp_loss = AverageMeter()
    bpp_mask = AverageMeter()
    psnr = AverageMeter()
    msssim = AverageMeter()
    lpips = AverageMeter()
    dists = AverageMeter()
    test_loss = AverageMeter()

    # Output directories for qualitative results.
    vis_path = os.path.join("./MIM_test_high_resolu/", stage)
    os.makedirs(vis_path, exist_ok=True)
    if stage == 'test':
        # NOTE(review): joining manual_mask_ratio into a path assumes it is a
        # string; callers passing a float (see the 'val' call site) would fail
        # here -- confirm.
        test_vis_path = os.path.join("/home/v-ruoyufeng/v-ruoyufeng/qyp/rec_fid", manual_mask_ratio)
        os.makedirs(test_vis_path, exist_ok=True)

    with torch.no_grad():
        tqdm_meter = tqdm.tqdm(enumerate(test_loader), leave=False, total=len(test_loader))
        for i, d in tqdm_meter:
            d = d.to(device)
            d0 = d  # keep the original (unsplit) image for metrics/visualization
            b_ori, _, h_ori, w_ori = d.shape
            # Split the image into padded patches the model can consume.
            d, patch_sizes, num_blocks_h, num_blocks_w = adaptively_split_and_pad(d)
            # d, h_ori, w_ori = prepadding(d)
            out_net = model(d, is_training=False, manual_mask_rate=manual_mask_ratio)
            # rec = model.gen_img(out_net['logits'], out_net['token_all_mask'], out_net['token_indices'], out_net['ori_shape'], out_net['new_shape'])
            # NOTE(review): `.module` assumes the model is DDP-wrapped; a bare
            # model has no `.module` attribute -- confirm callers always wrap.
            rec = model.module.gen_img(out_net['logits'], out_net['token_all_mask'], out_net['token_indices'], num_iter=20)
            # Stitch the per-patch reconstructions back into one full image.
            rec = crop_and_reconstruct(rec, patch_sizes, num_blocks_h, num_blocks_w)
            rec = rec.unsqueeze(0)
            rec = rec.to(device)
            print('d0:', d0.shape)
            print('rec:', rec.shape)

            # d = d[:, :, :h_ori, :w_ori]
            # rec = rec[:, :, :h_ori, :w_ori]
            # rec = rec[:, :, :hx, :wx]
            out_criterion = metrics_criterion(d0, out_net, rec)

            bpp_loss.update(out_criterion["bpp_loss"])
            bpp_mask.update(out_criterion["bpp_mask"])
            psnr.update(out_criterion['psnr'])
            msssim.update(out_criterion['msssim'])
            lpips.update(out_criterion['lpips'])
            dists.update(out_criterion['dists'])
            test_loss.update(out_criterion['total_loss'])

            ## ======================= update progress bar & visualization ======================= ##
            if stage == 'val':
                # if i % 5 == 0:
                with torch.no_grad():
                    real_fake_images = torch.cat((d0, rec), dim=0)
                    vutils.save_image(real_fake_images, os.path.join(vis_path, f"{epoch}_{i}.jpg"))
                    vutils.save_image(out_net['mask_vis'], os.path.join(vis_path, f"{epoch}_{i}_mask.jpg"))
            if stage == 'test':
                with torch.no_grad():
                    vutils.save_image(rec, os.path.join(test_vis_path, f"{i}.jpg"), nrow=8)
            # txt = f"Rec Loss:{test_loss.avg:.4f}|Bpp:{bpp_loss.avg:.4f}|lpips:{lpips.avg:.4f}|msssim:{msssim.avg:.4f}|dists:{dists.avg:.4f}|psnr:{psnr.avg:.4f}\n"
            # tqdm_meter.set_postfix_str(txt)

    model.train()

    # Assumes the distributed environment (if any) has been initialized.
    if torch.distributed.is_initialized():
        rank = dist.get_rank()
    else:
        rank = 0  # without DDP, default to single-process mode: rank 0

    # Only rank 0 writes the metrics line, to avoid duplicate log entries.
    if rank == 0:
        log_txt = f"{epoch}|bpp:{bpp_loss.avg.item():.5f}|mask:{bpp_mask.avg:.5f}|mask_ratio:{manual_mask_ratio}|psnr:{psnr.avg.item():.5f}|msssim:{msssim.avg.item():.5f}|lpips:{lpips.avg.item():.5f}|dists:{dists.avg.item():.5f}|Test loss:{test_loss.avg.item():.5f}"
        logging.info(log_txt)
    return test_loss.avg
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
def save_checkpoint(state, is_best, base_dir, filename="checkpoint.pth.tar"):
    """Persist a training checkpoint; duplicate it as the best checkpoint.

    Fix: the original concatenated ``base_dir + filename`` with ``+``, which
    produces a broken path (e.g. ``outcheckpoint.pth.tar``) whenever
    ``base_dir`` lacks a trailing separator. ``os.path.join`` handles both
    cases correctly.

    Args:
        state: serializable object (typically a dict of model/optimizer state).
        is_best: if True, also write a ``checkpoint_best.pth.tar`` copy.
        base_dir: directory the checkpoint files are written into.
        filename: name of the regular checkpoint file.
    """
    torch.save(state, os.path.join(base_dir, filename))
    if is_best:
        torch.save(state, os.path.join(base_dir, "checkpoint_best.pth.tar"))
|
| 479 |
+
|
| 480 |
+
def parse_args(argv):
    """Parse command-line arguments, filling defaults from a YAML config.

    Two-phase parse: first ``-c/--config`` (plus ``--name``/``--lr``) is read
    with ``parse_known_args``, then the YAML file's keys are installed as
    parser defaults before the remaining arguments are parsed.

    Fix: ``--TEST`` previously took a free-form positional value whose string
    was always truthy (``-T False`` evaluated as True). It is now the boolean
    flag the commented-out ``action='store_true'`` intended.

    Args:
        argv: argument list (typically ``sys.argv[1:]``).

    Returns:
        argparse.Namespace combining CLI arguments and YAML defaults.
    """
    parser = argparse.ArgumentParser(description="Example training script.")
    parser.add_argument(
        "-c",
        "--config",
        default="config/vpt_default.yaml",
        help="Path to config file",
    )
    parser.add_argument(
        '--name',
        default=datetime.now().strftime('%Y-%m-%d_%H_%M_%S'),
        type=str,
        help='Result dir name',
    )
    parser.add_argument('--lr', type=float, default=None, metavar='LR',
                        help='learning rate (absolute lr)')
    given_configs, remaining = parser.parse_known_args(argv)
    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local-rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')
    # YAML keys become parser defaults, so CLI flags still override them.
    with open(given_configs.config) as file:
        yaml_data = yaml.safe_load(file)
    parser.set_defaults(**yaml_data)

    parser.add_argument(
        "-T",
        "--TEST",
        action='store_true',
        default=False,
        help='Testing'
    )
    args = parser.parse_args(remaining)
    return args
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
def main(argv):
    """Entry point: parse config, set up distributed mode/logging/seeds,
    build the validation dataset and codec model, then run one
    pre-validation inference pass at a fixed mask ratio.
    """
    args = parse_args(argv)
    base_dir = init(args)  # create the base dir for saving the results

    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
        args.log_dir = args.output_dir

    misc.init_distributed_mode(args)

    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))

    device = torch.device(args.device)
    # fix the seed for reproducibility (offset by rank so workers differ)
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    cudnn.benchmark = True

    # File logger named by wall-clock start time, under the run's base dir.
    setup_logger(base_dir + '/' + time.strftime('%Y%m%d_%H%M%S') + '.log')
    msg = f'======================= {args.name} ======================='
    logging.info(msg)
    for k in args.__dict__:
        logging.info(k + ':' + str(args.__dict__[k]))
    logging.info('=' * len(msg))

    ## ======================= prepare dataset ======================= ##
    transform_det = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()])
    transform_val = transforms.Compose([
        # transforms.Resize(224),
        # transforms.CenterCrop(224),
        transforms.ToTensor()
    ])

    # NOTE(review): datasets are only built for 'coco'; any other value of
    # args.dataset leaves val_dataset undefined below -- confirm intended.
    if args.dataset=='coco':
        train_dataset = MSCOCO(args.dataset_path + "/train2017/",
                               transform_det,
                               "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/util/img_list.txt")
        # val_dataset = Kodak(args.kodak_path, transform_val)
        val_dataset = MSCOCO(args.kodak_path, transform_val)

    device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"

    # if args.distributed:
    if True:
        num_tasks = misc.get_world_size()
        global_rank = misc.get_rank()
        sampler_val = torch.utils.data.DistributedSampler(
            val_dataset, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )
    else:
        # NOTE(review): dead branch (condition hard-coded True). If re-enabled,
        # global_rank and sampler_val would be undefined on this path.
        sampler_train = torch.utils.data.RandomSampler(train_dataset)

    # TensorBoard writer only on the main process.
    if global_rank == 0 and args.log_dir is not None:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = SummaryWriter(log_dir=args.log_dir)
    else:
        log_writer = None

    val_dataloader = DataLoader(val_dataset, sampler=sampler_val, batch_size=1,
                                num_workers=args.num_workers, shuffle=False, pin_memory=args.pin_mem, drop_last=True)

    ## ======================= prepare model ======================= ##
    vqgan_ckpt_path = '/home/t2vg-a100-G4-10/project/qyp/mage/vqgan_jax_strongaug.ckpt'
    # Model class is looked up by name in the models_mage_codec_high_resolu module.
    model = models_mage_codec_high_resolu.__dict__[args.model](mask_ratio_mu=args.mask_ratio_mu, mask_ratio_std=args.mask_ratio_std,
                                                               mask_ratio_min=args.mask_ratio_min, mask_ratio_max=args.mask_ratio_max,
                                                               vqgan_ckpt_path=vqgan_ckpt_path)

    model.to(device)
    model_without_ddp = model
    print("Model = %s" % str(model_without_ddp))
    eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
    # Linear LR scaling rule: absolute lr = base lr * eff_batch_size / 256.
    if args.lr is None:  # only base_lr is specified
        args.lr = args.blr * eff_batch_size / 256
    print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
    print("actual lr: %.2e" % args.lr)

    print("accumulate grad iterations: %d" % args.accum_iter)
    print("effective batch size: %d" % eff_batch_size)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module

    # following timm: set wd as 0 for bias and norm layers
    param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay)
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
    print(optimizer)
    loss_scaler = NativeScaler()

    # resume from a checkpoint
    misc.load_model(args=args, model_without_ddp=model_without_ddp,
                    optimizer=optimizer, loss_scaler=loss_scaler, strict=False)

    metrics_criterion = CalMetrics()
    # cls_criterion = Clsloss(device, True)
    ## ======================= Start Training ======================= ##
    last_epoch = args.start_epoch

    ## ======================= pre validation ======================= ##
    print("############## pre validation ##############")
    best_loss = float("inf")
    tqrange = tqdm.trange(last_epoch, args.epochs)
    # NOTE(review): inference() joins manual_mask_ratio into a filesystem path
    # for stage 'test'; here a float is passed with stage 'val' -- confirm.
    val_mask_ratio = 0.5
    test_loss = inference(-1, val_dataloader, model, metrics_criterion, device, val_mask_ratio, args, 'val')
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
if __name__ == "__main__":
    # Entry point: forward CLI args (minus the program name) to main().
    main(sys.argv[1:])
|
engine_finetune.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import sys
|
| 3 |
+
from typing import Iterable, Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from timm.data import Mixup
|
| 8 |
+
from timm.utils import accuracy
|
| 9 |
+
|
| 10 |
+
import util.misc as misc
|
| 11 |
+
import util.lr_sched as lr_sched
|
| 12 |
+
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    mixup_fn: Optional[Mixup] = None, log_writer=None,
                    args=None, denormalize=False):
    """Fine-tune `model` for one classification epoch with AMP and gradient accumulation.

    Args:
        model: network to train (switched to train mode here).
        criterion: classification loss applied to the model outputs.
        data_loader: iterable yielding (samples, targets) batches.
        optimizer: optimizer; stepping is delegated to `loss_scaler`.
        device: device the batches are moved to.
        epoch: current epoch index; drives the per-iteration LR schedule.
        loss_scaler: AMP scaler callable handling backward, clipping and step.
        max_norm: gradient-clipping threshold passed to the scaler.
        mixup_fn: optional Mixup/CutMix transform applied per batch.
        log_writer: optional TensorBoard SummaryWriter.
        args: namespace providing accum_iter and LR-schedule fields.
        denormalize: if True, undo ImageNet normalization on the inputs.

    Returns:
        Dict of globally averaged metrics for the epoch.
    """
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 20

    accum_iter = args.accum_iter

    optimizer.zero_grad()

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):

        # we use a per iteration (instead of per epoch) lr scheduler
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)

        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        if denormalize:
            # Undo the per-channel ImageNet normalization in place.
            for c_id, (mean, std) in enumerate(zip(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)):
                samples[:, c_id] = samples[:, c_id] * std + mean

        with torch.cuda.amp.autocast():
            outputs = model(samples)
            loss = criterion(outputs, targets)

        loss_value = loss.item()

        # Abort on NaN/Inf loss rather than corrupting the weights.
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # Scale so gradients accumulated over accum_iter steps average out.
        loss /= accum_iter
        loss_scaler(loss, optimizer, clip_grad=max_norm,
                    parameters=model.parameters(), create_graph=False,
                    update_grad=(data_iter_step + 1) % accum_iter == 0)
        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()

        torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)
        # Track min/max LR across param groups (they differ under layer decay).
        min_lr = 10.
        max_lr = 0.
        for group in optimizer.param_groups:
            min_lr = min(min_lr, group["lr"])
            max_lr = max(max_lr, group["lr"])

        metric_logger.update(lr=max_lr)

        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
            """ We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('lr', max_lr, epoch_1000x)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
@torch.no_grad()
def evaluate(data_loader, model, device, denormalize=False):
    """Evaluate top-1/top-5 accuracy and cross-entropy loss on a loader.

    Args:
        data_loader: iterable of batches whose first element is the images
            and last element is the targets.
        model: classifier to evaluate (set to eval mode here).
        device: device batches are moved to.
        denormalize: if True, undo ImageNet normalization on the inputs.

    Returns:
        Dict of globally averaged metrics (loss, acc1, acc5).
    """
    criterion = torch.nn.CrossEntropyLoss()

    metric_logger = misc.MetricLogger(delimiter=" ")
    header = 'Test:'

    # Evaluation mode: disables dropout / freezes batch-norm statistics.
    model.eval()

    for batch in metric_logger.log_every(data_loader, 10, header):
        images, target = batch[0], batch[-1]
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        if denormalize:
            # Undo the per-channel ImageNet normalization in place.
            for channel, (mean, std) in enumerate(zip(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)):
                images[:, channel] = images[:, channel] * std + mean

        # Forward pass under autocast, mirroring the training precision.
        with torch.cuda.amp.autocast():
            output = model(images)
            loss = criterion(output, target)

        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        n = images.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=n)
        metric_logger.meters['acc5'].update(acc5.item(), n=n)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
engine_pretrain.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import sys
|
| 3 |
+
from typing import Iterable
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
import util.misc as misc
|
| 8 |
+
import util.lr_sched as lr_sched
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def train_one_epoch(model: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler,
                    log_writer=None,
                    args=None):
    """Pre-train `model` for one epoch (self-supervised: targets are ignored).

    The model itself returns the loss as the first element of its output
    tuple; AMP autocast and gradient accumulation are applied as in the
    fine-tuning loop.

    Args:
        model: network to pre-train (set to train mode here).
        data_loader: iterable of (samples, _) batches; labels are unused.
        optimizer: optimizer; stepping is delegated to `loss_scaler`.
        device: device the samples are moved to.
        epoch: current epoch index; drives the per-iteration LR schedule.
        loss_scaler: AMP scaler callable handling backward, clipping and step.
        log_writer: optional TensorBoard SummaryWriter.
        args: namespace providing accum_iter, grad_clip and LR-schedule fields.

    Returns:
        Dict of globally averaged metrics for the epoch.
    """
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 20

    accum_iter = args.accum_iter

    optimizer.zero_grad()

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    for data_iter_step, (samples, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):

        # we use a per iteration (instead of per epoch) lr scheduler
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)

        samples = samples.to(device, non_blocking=True)

        with torch.cuda.amp.autocast():
            # Model computes its own loss; the remaining outputs are unused here.
            loss, _, _ = model(samples)

        loss_value = loss.item()

        # Abort on NaN/Inf loss rather than corrupting the weights.
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # Scale so gradients accumulated over accum_iter steps average out.
        loss /= accum_iter
        loss_scaler(loss, optimizer, clip_grad=args.grad_clip, parameters=model.parameters(),
                    update_grad=(data_iter_step + 1) % accum_iter == 0)
        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()

        torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)

        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
            """ We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('lr', lr, epoch_1000x)


    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
environment.yaml
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: mage
|
| 2 |
+
channels:
|
| 3 |
+
- pytorch
|
| 4 |
+
- defaults
|
| 5 |
+
dependencies:
|
| 6 |
+
- python=3.8.5
|
| 7 |
+
- pip=20.3
|
| 8 |
+
- cudatoolkit=10.2
|
| 9 |
+
- pytorch=1.7.1
|
| 10 |
+
- torchvision=0.8.2
|
| 11 |
+
- numpy=1.19.2
|
| 12 |
+
- pip:
|
| 13 |
+
- albumentations==0.4.3
|
| 14 |
+
- opencv-python==4.1.2.30
|
| 15 |
+
- pudb==2019.2
|
| 16 |
+
- imageio==2.9.0
|
| 17 |
+
- imageio-ffmpeg==0.4.2
|
| 18 |
+
- pytorch-lightning==1.0.8
|
| 19 |
+
- omegaconf==2.0.0
|
| 20 |
+
- test-tube>=0.7.5
|
| 21 |
+
- streamlit>=0.73.1
|
| 22 |
+
- einops==0.3.0
|
| 23 |
+
- more-itertools>=8.0.0
|
| 24 |
+
- transformers==4.3.1
|
| 25 |
+
- timm==0.3.2
|
| 26 |
+
- -e .
|