wangyanhui666 commited on
Commit
9cf79cf
·
1 Parent(s): 361e832

fine tune decoder with mask

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes copy +35 -0
  2. .vscode/launch.json +25 -0
  3. Fine_Tune_VQ_w_P/20240604_030821.log +174 -0
  4. Fine_Tune_VQ_w_P/checkpoint-20.pth +3 -0
  5. Fine_Tune_VQ_w_P/checkpoint-last.pth +3 -0
  6. Fine_Tune_VQ_w_P/events.out.tfevents.1717470501.t2vg-a100-G4-10.976065.0 +3 -0
  7. Fine_Tune_VQ_w_P/log.txt +24 -0
  8. MIM_vbr/20240516_010714.log +53 -0
  9. MIM_vbr/checkpoint-399.pth +3 -0
  10. MIM_vbr/continue_tune_rope.log +514 -0
  11. MIM_vbr/events.out.tfevents.1714742678.t2vg-a100-G4-10.4118013.0 +3 -0
  12. MIM_vbr/log.txt +0 -0
  13. MIM_vbr_eval_lossless/20240518_044257.log +53 -0
  14. MIM_vbr_eval_lossless/20240521_120210.log +52 -0
  15. MIM_vbr_eval_lossless/20240521_121146.log +63 -0
  16. VQ_w_Prompt/20240602_123240.log +196 -0
  17. VQ_w_Prompt/checkpoint_69.pth.tar +3 -0
  18. __pycache__/model_vq.cpython-38.pyc +0 -0
  19. __pycache__/models_mage_codec.cpython-38.pyc +0 -0
  20. __pycache__/models_mage_codec_cls.cpython-38.pyc +0 -0
  21. __pycache__/models_mage_codec_full.cpython-38.pyc +0 -0
  22. __pycache__/models_mage_codec_high_resolu.cpython-38.pyc +0 -0
  23. __pycache__/models_mage_codec_hr_new.cpython-38.pyc +0 -0
  24. __pycache__/models_mage_codec_hs_w_p.cpython-38.pyc +0 -0
  25. __pycache__/models_mage_codec_rope.cpython-38.pyc +0 -0
  26. cal_upper_bound.py +287 -0
  27. ckpt_pretrained/vgg.pth +3 -0
  28. coco_eval.py +0 -0
  29. config/cal_upper_bound.yaml +13 -0
  30. config/eval copy.yaml +41 -0
  31. config/eval.yaml +41 -0
  32. config/eval_config.yaml +34 -0
  33. config/eval_high_res.yaml +41 -0
  34. config/eval_vq.yaml +38 -0
  35. config/fine_tune_vq_w_mask.yaml +39 -0
  36. config/train_det.yaml +38 -0
  37. config/train_im_cls.yaml +32 -0
  38. config/train_mim_config.yaml +33 -0
  39. config/vqgan.yaml +16 -0
  40. deploited/models_mage_codec_pos_linear.py +767 -0
  41. deploited/models_mage_codec_rm_pos.py +759 -0
  42. deploited/models_mage_codec_rope.py +767 -0
  43. deploited/random_with_rope.log +118 -0
  44. deploited/random_with_rope.pth +3 -0
  45. deploited/test_crop.py +68 -0
  46. deploited/train_vbr_codec_rope.py +673 -0
  47. detection.py +635 -0
  48. engine_finetune.py +128 -0
  49. engine_pretrain.py +72 -0
  50. environment.yaml +26 -0
.gitattributes copy ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.vscode/launch.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "version": "0.2.0",
3
+ "configurations": [
4
+ {
5
+ "name": "Train VBR Codec",
6
+ "type": "debugpy",
7
+ "request": "launch",
8
+ "program": "${workspaceFolder}/fine_tune_vq_w_mask.py",
9
+ "console": "integratedTerminal",
10
+ "args": [
11
+ "-c", "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/config/fine_tune_vq_w_mask.yaml",
12
+ "--world_size", "2"
13
+ ],
14
+ "env": {
15
+ "CUDA_VISIBLE_DEVICES": "2,3",
16
+ "MASTER_ADDR": "localhost",
17
+ "MASTER_PORT": "12345",
18
+ "NODE_RANK": "0",
19
+ "NPROC_PER_NODE": "2",
20
+ "NNODES": "1"
21
+ },
22
+ "preLaunchTask": ""
23
+ }
24
+ ]
25
+ }
Fine_Tune_VQ_w_P/20240604_030821.log ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2024-06-04 03:08:21,626 [INFO ] Logging file is /home/t2vg-a100-G4-10/project/qyp/mimc_rope/Fine_Tune_VQ_w_P//20240604_030821.log
2
+ 2024-06-04 03:08:21,626 [INFO ] ======================= exp =======================
3
+ 2024-06-04 03:08:21,626 [INFO ] config:config/vpt_default.yaml
4
+ 2024-06-04 03:08:21,626 [INFO ] name:exp
5
+ 2024-06-04 03:08:21,626 [INFO ] lr:None
6
+ 2024-06-04 03:08:21,626 [INFO ] world_size:4
7
+ 2024-06-04 03:08:21,626 [INFO ] local_rank:-1
8
+ 2024-06-04 03:08:21,627 [INFO ] dist_on_itp:False
9
+ 2024-06-04 03:08:21,627 [INFO ] dist_url:env://
10
+ 2024-06-04 03:08:21,627 [INFO ] TEST:False
11
+ 2024-06-04 03:08:21,627 [INFO ] model:mage_vit_base_patch16
12
+ 2024-06-04 03:08:21,627 [INFO ] dataset:coco
13
+ 2024-06-04 03:08:21,627 [INFO ] dataset_path:/home/t2vg-a100-G4-10/project/qyp/datasets/COCO
14
+ 2024-06-04 03:08:21,627 [INFO ] kodak_path:/home/t2vg-a100-G4-10/project/qyp/datasets/test/
15
+ 2024-06-04 03:08:21,627 [INFO ] eval_path:
16
+ 2024-06-04 03:08:21,627 [INFO ] batch_size:64
17
+ 2024-06-04 03:08:21,627 [INFO ] learning_rate:0.0001
18
+ 2024-06-04 03:08:21,627 [INFO ] test_batch_size:1
19
+ 2024-06-04 03:08:21,627 [INFO ] num_workers:4
20
+ 2024-06-04 03:08:21,627 [INFO ] epochs:100
21
+ 2024-06-04 03:08:21,627 [INFO ] blr:0.0001
22
+ 2024-06-04 03:08:21,627 [INFO ] min_lr:5e-06
23
+ 2024-06-04 03:08:21,627 [INFO ] accum_iter:1
24
+ 2024-06-04 03:08:21,627 [INFO ] input_size:256
25
+ 2024-06-04 03:08:21,627 [INFO ] weight_decay:0.05
26
+ 2024-06-04 03:08:21,627 [INFO ] warmup_epochs:10
27
+ 2024-06-04 03:08:21,627 [INFO ] mask_ratio_min:0.25
28
+ 2024-06-04 03:08:21,627 [INFO ] mask_ratio_max:0.75
29
+ 2024-06-04 03:08:21,627 [INFO ] grad_clip:3.0
30
+ 2024-06-04 03:08:21,628 [INFO ] device:cuda
31
+ 2024-06-04 03:08:21,628 [INFO ] distributed:True
32
+ 2024-06-04 03:08:21,628 [INFO ] seed:0
33
+ 2024-06-04 03:08:21,628 [INFO ] gpu_id:2
34
+ 2024-06-04 03:08:21,628 [INFO ] resume:/home/t2vg-a100-G4-10/project/qyp/lossless/checkpoint-330.pth
35
+ 2024-06-04 03:08:21,628 [INFO ] maskrcnn_path:/home/t2vg-a100-G4-10/project/qyp/ckpt/detectron2/mask_rcnn_r50_fpn_3x.pkl
36
+ 2024-06-04 03:08:21,628 [INFO ] start_epoch:0
37
+ 2024-06-04 03:08:21,628 [INFO ] pin_mem:True
38
+ 2024-06-04 03:08:21,628 [INFO ] save:True
39
+ 2024-06-04 03:08:21,628 [INFO ] exp_name:Fine_Tune_VQ_w_P
40
+ 2024-06-04 03:08:21,628 [INFO ] root:/home/t2vg-a100-G4-10/project/qyp/mimc_rope
41
+ 2024-06-04 03:08:21,628 [INFO ] output_dir:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/Fine_Tune_VQ_w_P
42
+ 2024-06-04 03:08:21,628 [INFO ] log_dir:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/Fine_Tune_VQ_w_P
43
+ 2024-06-04 03:08:21,628 [INFO ] strart_epoch:0
44
+ 2024-06-04 03:08:21,628 [INFO ] cuda:True
45
+ 2024-06-04 03:08:21,628 [INFO ] rank:0
46
+ 2024-06-04 03:08:21,628 [INFO ] gpu:0
47
+ 2024-06-04 03:08:21,628 [INFO ] dist_backend:nccl
48
+ 2024-06-04 03:08:21,628 [INFO ] ===================================================
49
+ 2024-06-04 03:10:20,905 [INFO ] -1|taskloss:2.49146|mask ratio:0.75||bpp:0.01396|mask:0.00003|psnr:14.06076|lpips:0.46229|dists:0.19295
50
+ 2024-06-04 03:10:48,101 [INFO ] -1|taskloss:1.68947|mask ratio:0.65||bpp:0.01952|mask:0.00003|psnr:15.15691|lpips:0.41068|dists:0.16344
51
+ 2024-06-04 03:11:14,835 [INFO ] -1|taskloss:1.33339|mask ratio:0.55||bpp:0.02512|mask:0.00003|psnr:15.94234|lpips:0.37990|dists:0.14601
52
+ 2024-06-04 03:11:41,295 [INFO ] -1|taskloss:1.03034|mask ratio:0.45||bpp:0.03068|mask:0.00003|psnr:16.56941|lpips:0.35252|dists:0.13884
53
+ 2024-06-04 03:11:41,296 [INFO ] Test loss=1.636163592338562
54
+ 2024-06-04 03:28:26,377 [INFO ] 0|taskloss:2.33836|mask ratio:0.75||bpp:0.01397|mask:0.00003|psnr:14.29709|lpips:0.46480|dists:0.21399
55
+ 2024-06-04 03:28:53,053 [INFO ] 0|taskloss:1.62612|mask ratio:0.65||bpp:0.01953|mask:0.00003|psnr:15.42644|lpips:0.42104|dists:0.19210
56
+ 2024-06-04 03:29:20,005 [INFO ] 0|taskloss:1.29490|mask ratio:0.55||bpp:0.02512|mask:0.00003|psnr:15.93270|lpips:0.38919|dists:0.17758
57
+ 2024-06-04 03:29:47,006 [INFO ] 0|taskloss:0.99774|mask ratio:0.45||bpp:0.03067|mask:0.00003|psnr:16.65782|lpips:0.36090|dists:0.16477
58
+ 2024-06-04 03:29:47,006 [INFO ] Test loss=1.973321333527565
59
+ 2024-06-04 03:29:47,960 [INFO ] Saving BEST checkpoint!
60
+ 2024-06-04 03:46:15,723 [INFO ] 1|taskloss:2.27418|mask ratio:0.75||bpp:0.01398|mask:0.00003|psnr:14.28722|lpips:0.47362|dists:0.22182
61
+ 2024-06-04 03:46:42,375 [INFO ] 1|taskloss:1.65933|mask ratio:0.65||bpp:0.01952|mask:0.00003|psnr:15.20894|lpips:0.41849|dists:0.19428
62
+ 2024-06-04 03:47:09,078 [INFO ] 1|taskloss:1.28544|mask ratio:0.55||bpp:0.02511|mask:0.00003|psnr:15.95363|lpips:0.39304|dists:0.18215
63
+ 2024-06-04 03:47:35,953 [INFO ] 1|taskloss:1.02331|mask ratio:0.45||bpp:0.03067|mask:0.00003|psnr:16.67013|lpips:0.35942|dists:0.17008
64
+ 2024-06-04 03:47:35,954 [INFO ] Test loss=2.0538930036127567
65
+ 2024-06-04 04:04:02,717 [INFO ] 2|taskloss:2.25254|mask ratio:0.75||bpp:0.01397|mask:0.00003|psnr:13.97699|lpips:0.47056|dists:0.22627
66
+ 2024-06-04 04:04:29,767 [INFO ] 2|taskloss:1.66452|mask ratio:0.65||bpp:0.01954|mask:0.00003|psnr:15.12406|lpips:0.42705|dists:0.20062
67
+ 2024-06-04 04:04:56,853 [INFO ] 2|taskloss:1.26301|mask ratio:0.55||bpp:0.02512|mask:0.00003|psnr:15.65730|lpips:0.39007|dists:0.18716
68
+ 2024-06-04 04:05:23,967 [INFO ] 2|taskloss:1.03645|mask ratio:0.45||bpp:0.03068|mask:0.00003|psnr:16.02926|lpips:0.36933|dists:0.18181
69
+ 2024-06-04 04:05:23,968 [INFO ] Test loss=2.067601718939841
70
+ 2024-06-04 04:21:49,690 [INFO ] 3|taskloss:2.28804|mask ratio:0.75||bpp:0.01397|mask:0.00003|psnr:14.03534|lpips:0.48001|dists:0.22667
71
+ 2024-06-04 04:22:16,508 [INFO ] 3|taskloss:1.66096|mask ratio:0.65||bpp:0.01952|mask:0.00003|psnr:15.18743|lpips:0.42347|dists:0.20229
72
+ 2024-06-04 04:22:43,283 [INFO ] 3|taskloss:1.28439|mask ratio:0.55||bpp:0.02512|mask:0.00003|psnr:15.80286|lpips:0.39074|dists:0.18586
73
+ 2024-06-04 04:23:09,940 [INFO ] 3|taskloss:1.01509|mask ratio:0.45||bpp:0.03067|mask:0.00003|psnr:16.39696|lpips:0.37266|dists:0.18077
74
+ 2024-06-04 04:23:09,941 [INFO ] Test loss=2.0790173208806664
75
+ 2024-06-04 04:39:36,248 [INFO ] 4|taskloss:2.17187|mask ratio:0.75||bpp:0.01397|mask:0.00003|psnr:14.04568|lpips:0.46708|dists:0.21667
76
+ 2024-06-04 04:40:03,220 [INFO ] 4|taskloss:1.68821|mask ratio:0.65||bpp:0.01952|mask:0.00003|psnr:15.43250|lpips:0.42161|dists:0.19500
77
+ 2024-06-04 04:40:29,994 [INFO ] 4|taskloss:1.23894|mask ratio:0.55||bpp:0.02511|mask:0.00003|psnr:15.90158|lpips:0.38726|dists:0.18233
78
+ 2024-06-04 04:40:56,799 [INFO ] 4|taskloss:1.01278|mask ratio:0.45||bpp:0.03067|mask:0.00003|psnr:16.31897|lpips:0.36575|dists:0.17080
79
+ 2024-06-04 04:40:56,799 [INFO ] Test loss=2.047704319062177
80
+ 2024-06-04 04:57:23,679 [INFO ] 5|taskloss:2.24287|mask ratio:0.75||bpp:0.01395|mask:0.00003|psnr:14.29049|lpips:0.46545|dists:0.21852
81
+ 2024-06-04 04:57:50,801 [INFO ] 5|taskloss:1.62839|mask ratio:0.65||bpp:0.01952|mask:0.00003|psnr:15.02618|lpips:0.41894|dists:0.19633
82
+ 2024-06-04 04:58:17,569 [INFO ] 5|taskloss:1.29102|mask ratio:0.55||bpp:0.02511|mask:0.00003|psnr:15.74514|lpips:0.39395|dists:0.18185
83
+ 2024-06-04 04:58:44,510 [INFO ] 5|taskloss:1.01440|mask ratio:0.45||bpp:0.03066|mask:0.00003|psnr:16.38981|lpips:0.36329|dists:0.17676
84
+ 2024-06-04 04:58:44,511 [INFO ] Test loss=2.0560930917417863
85
+ 2024-06-04 05:15:10,098 [INFO ] 6|taskloss:2.32205|mask ratio:0.75||bpp:0.01398|mask:0.00003|psnr:14.04229|lpips:0.46922|dists:0.21845
86
+ 2024-06-04 05:15:36,966 [INFO ] 6|taskloss:1.60729|mask ratio:0.65||bpp:0.01952|mask:0.00003|psnr:14.98244|lpips:0.42790|dists:0.20636
87
+ 2024-06-04 05:16:03,975 [INFO ] 6|taskloss:1.26320|mask ratio:0.55||bpp:0.02509|mask:0.00003|psnr:15.94796|lpips:0.38764|dists:0.18432
88
+ 2024-06-04 05:16:30,771 [INFO ] 6|taskloss:1.03816|mask ratio:0.45||bpp:0.03067|mask:0.00003|psnr:16.29310|lpips:0.37118|dists:0.17678
89
+ 2024-06-04 05:16:30,771 [INFO ] Test loss=2.0717015551963414
90
+ 2024-06-04 05:32:57,502 [INFO ] 7|taskloss:2.36936|mask ratio:0.75||bpp:0.01397|mask:0.00003|psnr:14.30189|lpips:0.46484|dists:0.21939
91
+ 2024-06-04 05:33:24,434 [INFO ] 7|taskloss:1.63491|mask ratio:0.65||bpp:0.01952|mask:0.00003|psnr:15.08945|lpips:0.42373|dists:0.19809
92
+ 2024-06-04 05:33:51,433 [INFO ] 7|taskloss:1.23911|mask ratio:0.55||bpp:0.02511|mask:0.00003|psnr:15.95378|lpips:0.38986|dists:0.18503
93
+ 2024-06-04 05:34:18,373 [INFO ] 7|taskloss:0.99703|mask ratio:0.45||bpp:0.03068|mask:0.00003|psnr:16.33615|lpips:0.36766|dists:0.17350
94
+ 2024-06-04 05:34:18,373 [INFO ] Test loss=2.0780272555211923
95
+ 2024-06-04 05:50:45,587 [INFO ] 8|taskloss:2.23844|mask ratio:0.75||bpp:0.01397|mask:0.00003|psnr:14.26759|lpips:0.46414|dists:0.22619
96
+ 2024-06-04 05:51:12,540 [INFO ] 8|taskloss:1.66498|mask ratio:0.65||bpp:0.01953|mask:0.00003|psnr:14.91135|lpips:0.43537|dists:0.21694
97
+ 2024-06-04 05:51:39,386 [INFO ] 8|taskloss:1.27593|mask ratio:0.55||bpp:0.02510|mask:0.00003|psnr:15.52610|lpips:0.40321|dists:0.20067
98
+ 2024-06-04 05:52:05,972 [INFO ] 8|taskloss:1.03960|mask ratio:0.45||bpp:0.03066|mask:0.00003|psnr:15.80372|lpips:0.38153|dists:0.19366
99
+ 2024-06-04 05:52:05,973 [INFO ] Test loss=2.0742442923749422
100
+ 2024-06-04 06:08:33,064 [INFO ] 9|taskloss:2.32985|mask ratio:0.75||bpp:0.01397|mask:0.00003|psnr:14.33968|lpips:0.47202|dists:0.22093
101
+ 2024-06-04 06:08:59,961 [INFO ] 9|taskloss:1.62963|mask ratio:0.65||bpp:0.01952|mask:0.00003|psnr:15.15994|lpips:0.42402|dists:0.19959
102
+ 2024-06-04 06:09:26,718 [INFO ] 9|taskloss:1.25288|mask ratio:0.55||bpp:0.02510|mask:0.00003|psnr:15.62853|lpips:0.38951|dists:0.17977
103
+ 2024-06-04 06:09:53,638 [INFO ] 9|taskloss:1.00600|mask ratio:0.45||bpp:0.03066|mask:0.00003|psnr:16.26209|lpips:0.36578|dists:0.17169
104
+ 2024-06-04 06:09:53,639 [INFO ] Test loss=2.0731508512786263
105
+ 2024-06-04 06:26:19,204 [INFO ] 10|taskloss:2.34149|mask ratio:0.75||bpp:0.01398|mask:0.00003|psnr:14.26575|lpips:0.47004|dists:0.22524
106
+ 2024-06-04 06:26:45,967 [INFO ] 10|taskloss:1.62147|mask ratio:0.65||bpp:0.01953|mask:0.00003|psnr:15.09978|lpips:0.42139|dists:0.19693
107
+ 2024-06-04 06:27:12,796 [INFO ] 10|taskloss:1.27020|mask ratio:0.55||bpp:0.02510|mask:0.00003|psnr:15.77150|lpips:0.38540|dists:0.17910
108
+ 2024-06-04 06:27:39,548 [INFO ] 10|taskloss:1.02932|mask ratio:0.45||bpp:0.03067|mask:0.00003|psnr:16.26428|lpips:0.36994|dists:0.17073
109
+ 2024-06-04 06:27:39,549 [INFO ] Test loss=2.0839089696479647
110
+ 2024-06-04 06:44:06,036 [INFO ] 11|taskloss:2.24896|mask ratio:0.75||bpp:0.01397|mask:0.00003|psnr:14.07473|lpips:0.47022|dists:0.21116
111
+ 2024-06-04 06:44:32,940 [INFO ] 11|taskloss:1.69199|mask ratio:0.65||bpp:0.01952|mask:0.00003|psnr:15.18235|lpips:0.42282|dists:0.18916
112
+ 2024-06-04 06:44:59,784 [INFO ] 11|taskloss:1.26253|mask ratio:0.55||bpp:0.02512|mask:0.00003|psnr:16.03878|lpips:0.38759|dists:0.17844
113
+ 2024-06-04 06:45:26,464 [INFO ] 11|taskloss:1.01063|mask ratio:0.45||bpp:0.03067|mask:0.00003|psnr:16.42851|lpips:0.35960|dists:0.16572
114
+ 2024-06-04 06:45:26,465 [INFO ] Test loss=2.0745057896989785
115
+ 2024-06-04 07:01:51,689 [INFO ] 12|taskloss:2.26208|mask ratio:0.75||bpp:0.01398|mask:0.00003|psnr:14.23028|lpips:0.46908|dists:0.22384
116
+ 2024-06-04 07:02:18,690 [INFO ] 12|taskloss:1.60826|mask ratio:0.65||bpp:0.01953|mask:0.00003|psnr:15.05205|lpips:0.42262|dists:0.19667
117
+ 2024-06-04 07:02:45,653 [INFO ] 12|taskloss:1.24774|mask ratio:0.55||bpp:0.02510|mask:0.00003|psnr:15.88546|lpips:0.38749|dists:0.18423
118
+ 2024-06-04 07:03:12,706 [INFO ] 12|taskloss:1.00849|mask ratio:0.45||bpp:0.03068|mask:0.00003|psnr:16.33344|lpips:0.36366|dists:0.17469
119
+ 2024-06-04 07:03:12,707 [INFO ] Test loss=2.0502657253597727
120
+ 2024-06-04 07:19:36,537 [INFO ] 13|taskloss:2.27210|mask ratio:0.75||bpp:0.01396|mask:0.00003|psnr:14.07016|lpips:0.47514|dists:0.21749
121
+ 2024-06-04 07:20:03,327 [INFO ] 13|taskloss:1.62562|mask ratio:0.65||bpp:0.01954|mask:0.00003|psnr:15.32130|lpips:0.41837|dists:0.19381
122
+ 2024-06-04 07:20:30,312 [INFO ] 13|taskloss:1.20922|mask ratio:0.55||bpp:0.02510|mask:0.00003|psnr:16.01362|lpips:0.38459|dists:0.17793
123
+ 2024-06-04 07:20:57,097 [INFO ] 13|taskloss:1.02460|mask ratio:0.45||bpp:0.03067|mask:0.00003|psnr:16.43850|lpips:0.36615|dists:0.17285
124
+ 2024-06-04 07:20:57,097 [INFO ] Test loss=2.045453323897087
125
+ 2024-06-04 07:37:23,230 [INFO ] 14|taskloss:2.26558|mask ratio:0.75||bpp:0.01397|mask:0.00003|psnr:14.31213|lpips:0.47121|dists:0.22154
126
+ 2024-06-04 07:37:50,002 [INFO ] 14|taskloss:1.65192|mask ratio:0.65||bpp:0.01953|mask:0.00003|psnr:15.25740|lpips:0.43203|dists:0.20555
127
+ 2024-06-04 07:38:16,697 [INFO ] 14|taskloss:1.28469|mask ratio:0.55||bpp:0.02511|mask:0.00003|psnr:15.82785|lpips:0.39241|dists:0.18424
128
+ 2024-06-04 07:38:43,483 [INFO ] 14|taskloss:0.98623|mask ratio:0.45||bpp:0.03068|mask:0.00003|psnr:16.52078|lpips:0.36753|dists:0.17625
129
+ 2024-06-04 07:38:43,484 [INFO ] Test loss=2.058469656600645
130
+ 2024-06-04 07:55:08,283 [INFO ] 15|taskloss:2.34680|mask ratio:0.75||bpp:0.01397|mask:0.00003|psnr:14.18895|lpips:0.46331|dists:0.22039
131
+ 2024-06-04 07:55:34,925 [INFO ] 15|taskloss:1.61486|mask ratio:0.65||bpp:0.01953|mask:0.00003|psnr:15.15459|lpips:0.42977|dists:0.20644
132
+ 2024-06-04 07:56:01,738 [INFO ] 15|taskloss:1.21967|mask ratio:0.55||bpp:0.02512|mask:0.00003|psnr:15.76448|lpips:0.38905|dists:0.18471
133
+ 2024-06-04 07:56:28,420 [INFO ] 15|taskloss:1.01635|mask ratio:0.45||bpp:0.03067|mask:0.00003|psnr:16.40507|lpips:0.36936|dists:0.17892
134
+ 2024-06-04 07:56:28,421 [INFO ] Test loss=2.064034671459121
135
+ 2024-06-04 08:12:55,104 [INFO ] 16|taskloss:2.24783|mask ratio:0.75||bpp:0.01398|mask:0.00003|psnr:14.38436|lpips:0.46735|dists:0.21890
136
+ 2024-06-04 08:13:22,005 [INFO ] 16|taskloss:1.68004|mask ratio:0.65||bpp:0.01954|mask:0.00003|psnr:15.20924|lpips:0.42151|dists:0.19677
137
+ 2024-06-04 08:13:48,870 [INFO ] 16|taskloss:1.22699|mask ratio:0.55||bpp:0.02511|mask:0.00003|psnr:16.00963|lpips:0.38648|dists:0.18326
138
+ 2024-06-04 08:14:15,705 [INFO ] 16|taskloss:1.00658|mask ratio:0.45||bpp:0.03067|mask:0.00003|psnr:16.69925|lpips:0.36080|dists:0.17115
139
+ 2024-06-04 08:14:15,706 [INFO ] Test loss=2.0563709961557195
140
+ 2024-06-04 08:30:41,819 [INFO ] 17|taskloss:2.26464|mask ratio:0.75||bpp:0.01397|mask:0.00003|psnr:14.51690|lpips:0.46617|dists:0.21955
141
+ 2024-06-04 08:31:08,480 [INFO ] 17|taskloss:1.69014|mask ratio:0.65||bpp:0.01954|mask:0.00003|psnr:15.10275|lpips:0.42308|dists:0.19810
142
+ 2024-06-04 08:31:35,927 [INFO ] 17|taskloss:1.24362|mask ratio:0.55||bpp:0.02510|mask:0.00003|psnr:15.88523|lpips:0.39605|dists:0.19024
143
+ 2024-06-04 08:32:04,307 [INFO ] 17|taskloss:1.00460|mask ratio:0.45||bpp:0.03068|mask:0.00003|psnr:16.62445|lpips:0.36797|dists:0.17678
144
+ 2024-06-04 08:32:04,308 [INFO ] Test loss=2.064842736402745
145
+ 2024-06-04 08:48:38,838 [INFO ] 18|taskloss:2.21708|mask ratio:0.75||bpp:0.01399|mask:0.00003|psnr:14.15950|lpips:0.47041|dists:0.22174
146
+ 2024-06-04 08:49:06,574 [INFO ] 18|taskloss:1.61804|mask ratio:0.65||bpp:0.01951|mask:0.00003|psnr:15.24594|lpips:0.42104|dists:0.19944
147
+ 2024-06-04 08:49:33,943 [INFO ] 18|taskloss:1.25278|mask ratio:0.55||bpp:0.02511|mask:0.00003|psnr:15.76423|lpips:0.39182|dists:0.18584
148
+ 2024-06-04 08:50:01,283 [INFO ] 18|taskloss:1.00574|mask ratio:0.45||bpp:0.03067|mask:0.00003|psnr:16.33763|lpips:0.36889|dists:0.17756
149
+ 2024-06-04 08:50:01,284 [INFO ] Test loss=2.039618709603845
150
+ 2024-06-04 09:06:34,348 [INFO ] 19|taskloss:2.28699|mask ratio:0.75||bpp:0.01397|mask:0.00003|psnr:14.06225|lpips:0.47773|dists:0.22275
151
+ 2024-06-04 09:07:01,971 [INFO ] 19|taskloss:1.62656|mask ratio:0.65||bpp:0.01953|mask:0.00003|psnr:15.01703|lpips:0.42262|dists:0.19507
152
+ 2024-06-04 09:07:29,380 [INFO ] 19|taskloss:1.26701|mask ratio:0.55||bpp:0.02511|mask:0.00003|psnr:16.04263|lpips:0.38800|dists:0.18259
153
+ 2024-06-04 09:07:56,604 [INFO ] 19|taskloss:1.02889|mask ratio:0.45||bpp:0.03066|mask:0.00003|psnr:16.22125|lpips:0.36636|dists:0.17257
154
+ 2024-06-04 09:07:56,605 [INFO ] Test loss=2.0622663147548583
155
+ 2024-06-04 09:24:25,377 [INFO ] 20|taskloss:2.24449|mask ratio:0.75||bpp:0.01396|mask:0.00003|psnr:14.10004|lpips:0.46875|dists:0.22301
156
+ 2024-06-04 09:24:52,350 [INFO ] 20|taskloss:1.67153|mask ratio:0.65||bpp:0.01953|mask:0.00003|psnr:15.36281|lpips:0.42634|dists:0.20213
157
+ 2024-06-04 09:25:19,089 [INFO ] 20|taskloss:1.27836|mask ratio:0.55||bpp:0.02511|mask:0.00003|psnr:15.84954|lpips:0.39339|dists:0.18800
158
+ 2024-06-04 09:25:45,995 [INFO ] 20|taskloss:1.01734|mask ratio:0.45||bpp:0.03067|mask:0.00003|psnr:16.35971|lpips:0.36891|dists:0.17725
159
+ 2024-06-04 09:25:45,996 [INFO ] Test loss=2.0684968145514464
160
+ 2024-06-04 09:42:14,168 [INFO ] 21|taskloss:2.27269|mask ratio:0.75||bpp:0.01398|mask:0.00003|psnr:14.36867|lpips:0.46679|dists:0.21023
161
+ 2024-06-04 09:42:41,882 [INFO ] 21|taskloss:1.59626|mask ratio:0.65||bpp:0.01952|mask:0.00003|psnr:15.07057|lpips:0.42891|dists:0.19607
162
+ 2024-06-04 09:43:09,575 [INFO ] 21|taskloss:1.23144|mask ratio:0.55||bpp:0.02511|mask:0.00003|psnr:16.13393|lpips:0.38757|dists:0.18141
163
+ 2024-06-04 09:43:37,232 [INFO ] 21|taskloss:1.02276|mask ratio:0.45||bpp:0.03068|mask:0.00003|psnr:16.52278|lpips:0.36648|dists:0.17367
164
+ 2024-06-04 09:43:37,233 [INFO ] Test loss=2.0479125060594314
165
+ 2024-06-04 10:00:03,966 [INFO ] 22|taskloss:2.29455|mask ratio:0.75||bpp:0.01397|mask:0.00003|psnr:14.29998|lpips:0.46570|dists:0.21669
166
+ 2024-06-04 10:00:32,019 [INFO ] 22|taskloss:1.68339|mask ratio:0.65||bpp:0.01951|mask:0.00003|psnr:15.14034|lpips:0.42580|dists:0.20284
167
+ 2024-06-04 10:00:59,949 [INFO ] 22|taskloss:1.26892|mask ratio:0.55||bpp:0.02510|mask:0.00003|psnr:16.01256|lpips:0.39080|dists:0.18718
168
+ 2024-06-04 10:01:27,782 [INFO ] 22|taskloss:1.02441|mask ratio:0.45||bpp:0.03067|mask:0.00003|psnr:16.42309|lpips:0.36846|dists:0.17773
169
+ 2024-06-04 10:01:27,783 [INFO ] Test loss=2.0797956952938494
170
+ 2024-06-04 10:17:48,799 [INFO ] 23|taskloss:2.28370|mask ratio:0.75||bpp:0.01397|mask:0.00003|psnr:14.41319|lpips:0.47015|dists:0.21995
171
+ 2024-06-04 10:18:16,236 [INFO ] 23|taskloss:1.61646|mask ratio:0.65||bpp:0.01952|mask:0.00003|psnr:15.18517|lpips:0.42225|dists:0.20285
172
+ 2024-06-04 10:18:43,717 [INFO ] 23|taskloss:1.26966|mask ratio:0.55||bpp:0.02511|mask:0.00003|psnr:16.16500|lpips:0.39132|dists:0.18260
173
+ 2024-06-04 10:19:11,303 [INFO ] 23|taskloss:1.00680|mask ratio:0.45||bpp:0.03068|mask:0.00003|psnr:16.58024|lpips:0.36382|dists:0.17243
174
+ 2024-06-04 10:19:11,304 [INFO ] Test loss=2.0641073229285345
Fine_Tune_VQ_w_P/checkpoint-20.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:54885516ebb1f758a23f3f74c232ecf58376c561d0ed59fe51607b2103fdbfc8
3
+ size 867021559
Fine_Tune_VQ_w_P/checkpoint-last.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1af8fab142de32e2b7a9f341dc0eef48b6f5f2c91cfd42840697c817240cf78a
3
+ size 867021559
Fine_Tune_VQ_w_P/events.out.tfevents.1717470501.t2vg-a100-G4-10.976065.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:88722834393587807a9428e2fb5b03c41b660c9fed2948159db4e2b773bb74a1
3
+ size 999224
Fine_Tune_VQ_w_P/log.txt ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {"train_lr": 4.9890829694323135e-06, "train_loss": 1.243749070896332, "train_mask_ratio": 0.5038230929312227, "epoch": 0}
2
+ {"train_lr": 1.4989082969432316e-05, "train_loss": 1.215295555061909, "train_mask_ratio": 0.502963803220524, "epoch": 1}
3
+ {"train_lr": 2.498908296943232e-05, "train_loss": 1.2075574521349508, "train_mask_ratio": 0.5024819186681223, "epoch": 2}
4
+ {"train_lr": 3.498908296943235e-05, "train_loss": 1.2206204718702744, "train_mask_ratio": 0.50749479735262, "epoch": 3}
5
+ {"train_lr": 4.498908296943233e-05, "train_loss": 1.2098940466754302, "train_mask_ratio": 0.5005287936681223, "epoch": 4}
6
+ {"train_lr": 5.498908296943233e-05, "train_loss": 1.190811416142372, "train_mask_ratio": 0.4951235841975983, "epoch": 5}
7
+ {"train_lr": 6.49890829694323e-05, "train_loss": 1.21159366228695, "train_mask_ratio": 0.5036205308406113, "epoch": 6}
8
+ {"train_lr": 7.498908296943243e-05, "train_loss": 1.2252963672038248, "train_mask_ratio": 0.508678186408297, "epoch": 7}
9
+ {"train_lr": 8.498908296943224e-05, "train_loss": 1.1960869143233028, "train_mask_ratio": 0.49665879162117904, "epoch": 8}
10
+ {"train_lr": 9.49890829694323e-05, "train_loss": 1.1917035718932423, "train_mask_ratio": 0.49613426241812225, "epoch": 9}
11
+ {"train_lr": 9.999038593559283e-05, "train_loss": 1.2097623654718483, "train_mask_ratio": 0.5052154407751092, "epoch": 10}
12
+ {"train_lr": 9.993258939408399e-05, "train_loss": 1.1992728840966411, "train_mask_ratio": 0.4989019002456332, "epoch": 11}
13
+ {"train_lr": 9.981700354882874e-05, "train_loss": 1.2060506974050051, "train_mask_ratio": 0.50310026610262, "epoch": 12}
14
+ {"train_lr": 9.96437692233749e-05, "train_loss": 1.2169636118685314, "train_mask_ratio": 0.5097805506277293, "epoch": 13}
15
+ {"train_lr": 9.94130974770633e-05, "train_loss": 1.1891578605081317, "train_mask_ratio": 0.49562039437772926, "epoch": 14}
16
+ {"train_lr": 9.912526934788473e-05, "train_loss": 1.2026429740817786, "train_mask_ratio": 0.5030107123362445, "epoch": 15}
17
+ {"train_lr": 9.878063551007767e-05, "train_loss": 1.1993349365336927, "train_mask_ratio": 0.5014925627729258, "epoch": 16}
18
+ {"train_lr": 9.837961584688702e-05, "train_loss": 1.1961212191063764, "train_mask_ratio": 0.5002942480895196, "epoch": 17}
19
+ {"train_lr": 9.792269893900009e-05, "train_loss": 1.1912841638753509, "train_mask_ratio": 0.49881874317685587, "epoch": 18}
20
+ {"train_lr": 9.741044146928648e-05, "train_loss": 1.1925643646691044, "train_mask_ratio": 0.4981215031386463, "epoch": 19}
21
+ {"train_lr": 9.68434675445656e-05, "train_loss": 1.1990113721013589, "train_mask_ratio": 0.5008848765010917, "epoch": 20}
22
+ {"train_lr": 9.622246793522915e-05, "train_loss": 1.2057850010790678, "train_mask_ratio": 0.5041706468340611, "epoch": 21}
23
+ {"train_lr": 9.554819923364366e-05, "train_loss": 1.2057653185972481, "train_mask_ratio": 0.5051941184497817, "epoch": 22}
24
+ {"train_lr": 9.482148293235867e-05, "train_loss": 1.1847582247645054, "train_mask_ratio": 0.49711722161572053, "epoch": 23}
MIM_vbr/20240516_010714.log ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2024-05-16 01:07:14,223 [INFO ] Logging file is /home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr//20240516_010714.log
2
+ 2024-05-16 01:07:14,223 [INFO ] ======================= exp =======================
3
+ 2024-05-16 01:07:14,223 [INFO ] config:config/vpt_default.yaml
4
+ 2024-05-16 01:07:14,223 [INFO ] name:exp
5
+ 2024-05-16 01:07:14,223 [INFO ] lr:None
6
+ 2024-05-16 01:07:14,223 [INFO ] world_size:1
7
+ 2024-05-16 01:07:14,223 [INFO ] local_rank:-1
8
+ 2024-05-16 01:07:14,223 [INFO ] dist_on_itp:False
9
+ 2024-05-16 01:07:14,223 [INFO ] dist_url:env://
10
+ 2024-05-16 01:07:14,223 [INFO ] TEST:False
11
+ 2024-05-16 01:07:14,224 [INFO ] model:mage_vit_base_patch16
12
+ 2024-05-16 01:07:14,224 [INFO ] dataset:coco
13
+ 2024-05-16 01:07:14,224 [INFO ] dataset_path:/home/t2vg-a100-G4-10/project/qyp/datasets/COCO
14
+ 2024-05-16 01:07:14,224 [INFO ] kodak_path:/home/t2vg-a100-G4-10/project/qyp/datasets/COCO/val2017/
15
+ 2024-05-16 01:07:14,224 [INFO ] eval_path:
16
+ 2024-05-16 01:07:14,224 [INFO ] batch_size:32
17
+ 2024-05-16 01:07:14,224 [INFO ] test_batch_size:1
18
+ 2024-05-16 01:07:14,224 [INFO ] num_workers:4
19
+ 2024-05-16 01:07:14,224 [INFO ] epochs:400
20
+ 2024-05-16 01:07:14,224 [INFO ] blr:0.00015
21
+ 2024-05-16 01:07:14,224 [INFO ] min_lr:5e-06
22
+ 2024-05-16 01:07:14,224 [INFO ] accum_iter:1
23
+ 2024-05-16 01:07:14,224 [INFO ] input_size:256
24
+ 2024-05-16 01:07:14,224 [INFO ] weight_decay:0.05
25
+ 2024-05-16 01:07:14,224 [INFO ] warmup_epochs:20
26
+ 2024-05-16 01:07:14,224 [INFO ] mask_ratio_min:0.5
27
+ 2024-05-16 01:07:14,224 [INFO ] mask_ratio_max:0.9
28
+ 2024-05-16 01:07:14,224 [INFO ] mask_ratio_mu:0.55
29
+ 2024-05-16 01:07:14,224 [INFO ] mask_ratio_std:0.25
30
+ 2024-05-16 01:07:14,225 [INFO ] grad_clip:3.0
31
+ 2024-05-16 01:07:14,225 [INFO ] device:cuda
32
+ 2024-05-16 01:07:14,225 [INFO ] distributed:True
33
+ 2024-05-16 01:07:14,225 [INFO ] seed:0
34
+ 2024-05-16 01:07:14,225 [INFO ] resume:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/random_cls_ori_pos_embed.pth
35
+ 2024-05-16 01:07:14,225 [INFO ] start_epoch:0
36
+ 2024-05-16 01:07:14,225 [INFO ] pin_mem:True
37
+ 2024-05-16 01:07:14,225 [INFO ] save:True
38
+ 2024-05-16 01:07:14,225 [INFO ] exp_name:MIM_vbr
39
+ 2024-05-16 01:07:14,225 [INFO ] root:/home/t2vg-a100-G4-10/project/qyp/mimc_rope
40
+ 2024-05-16 01:07:14,225 [INFO ] output_dir:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_high_resolu_eval
41
+ 2024-05-16 01:07:14,225 [INFO ] log_dir:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_high_resolu_eval
42
+ 2024-05-16 01:07:14,225 [INFO ] strart_epoch:0
43
+ 2024-05-16 01:07:14,225 [INFO ] cuda:True
44
+ 2024-05-16 01:07:14,225 [INFO ] rank:0
45
+ 2024-05-16 01:07:14,225 [INFO ] gpu:0
46
+ 2024-05-16 01:07:14,225 [INFO ] dist_backend:nccl
47
+ 2024-05-16 01:07:14,225 [INFO ] ===================================================
48
+ 2024-05-16 04:50:34,512 [INFO ] -1|bpp:0.06320|mask:0.01002|mask_ratio:0.05|psnr:19.27399|lpips:0.29888|dists:0.10206|Test loss:3932.46558
49
+ 2024-05-16 08:29:19,640 [INFO ] -1|bpp:0.05797|mask:0.01040|mask_ratio:0.15|psnr:18.67500|lpips:0.31491|dists:0.10963|Test loss:1802.33093
50
+ 2024-05-16 12:06:32,090 [INFO ] -1|bpp:0.05017|mask:0.01099|mask_ratio:0.3|psnr:17.72943|lpips:0.34331|dists:0.12283|Test loss:563.63690
51
+ 2024-05-16 15:45:18,037 [INFO ] -1|bpp:0.04235|mask:0.01158|mask_ratio:0.45|psnr:16.67805|lpips:0.37921|dists:0.13982|Test loss:177.07100
52
+ 2024-05-16 19:29:09,057 [INFO ] -1|bpp:0.03453|mask:0.01215|mask_ratio:0.6|psnr:15.40944|lpips:0.42819|dists:0.16429|Test loss:56.36422
53
+ 2024-05-16 22:59:34,921 [INFO ] -1|bpp:0.02675|mask:0.01275|mask_ratio:0.75|psnr:13.72992|lpips:0.49992|dists:0.20360|Test loss:18.22452
MIM_vbr/checkpoint-399.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4435ef340ca3965a9c47c956fea2bb48e8613162260e443c4a37d428e6890569
3
+ size 1960373936
MIM_vbr/continue_tune_rope.log ADDED
@@ -0,0 +1,514 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2024-05-03 13:24:36,083 [INFO ] Logging file is /home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr//20240503_132436.log
2
+ 2024-05-03 13:24:36,083 [INFO ] Logging file is /home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr//20240503_132436.log
3
+ 2024-05-03 13:24:36,083 [INFO ] ======================= exp =======================
4
+ 2024-05-03 13:24:36,083 [INFO ] ======================= exp =======================
5
+ 2024-05-03 13:24:36,083 [INFO ] config:config/vpt_default.yaml
6
+ 2024-05-03 13:24:36,083 [INFO ] config:config/vpt_default.yaml
7
+ 2024-05-03 13:24:36,083 [INFO ] name:exp
8
+ 2024-05-03 13:24:36,083 [INFO ] name:exp
9
+ 2024-05-03 13:24:36,083 [INFO ] lr:None
10
+ 2024-05-03 13:24:36,083 [INFO ] lr:None
11
+ 2024-05-03 13:24:36,083 [INFO ] world_size:4
12
+ 2024-05-03 13:24:36,083 [INFO ] world_size:4
13
+ 2024-05-03 13:24:36,084 [INFO ] local_rank:-1
14
+ 2024-05-03 13:24:36,084 [INFO ] local_rank:-1
15
+ 2024-05-03 13:24:36,084 [INFO ] dist_on_itp:False
16
+ 2024-05-03 13:24:36,084 [INFO ] dist_on_itp:False
17
+ 2024-05-03 13:24:36,084 [INFO ] dist_url:env://
18
+ 2024-05-03 13:24:36,084 [INFO ] dist_url:env://
19
+ 2024-05-03 13:24:36,084 [INFO ] TEST:False
20
+ 2024-05-03 13:24:36,084 [INFO ] TEST:False
21
+ 2024-05-03 13:24:36,084 [INFO ] model:mage_vit_base_patch16
22
+ 2024-05-03 13:24:36,084 [INFO ] model:mage_vit_base_patch16
23
+ 2024-05-03 13:24:36,084 [INFO ] dataset:imagenet
24
+ 2024-05-03 13:24:36,084 [INFO ] dataset:imagenet
25
+ 2024-05-03 13:24:36,084 [INFO ] Logging file is /home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr//20240503_132436.log
26
+ 2024-05-03 13:24:36,084 [INFO ] dataset_path:/home/t2vg-a100-G4-10/project/qyp/datasets/imagenet
27
+ 2024-05-03 13:24:36,084 [INFO ] dataset_path:/home/t2vg-a100-G4-10/project/qyp/datasets/imagenet
28
+ 2024-05-03 13:24:36,084 [INFO ] eval_path:
29
+ 2024-05-03 13:24:36,084 [INFO ] eval_path:
30
+ 2024-05-03 13:24:36,084 [INFO ] batch_size:64
31
+ 2024-05-03 13:24:36,084 [INFO ] batch_size:64
32
+ 2024-05-03 13:24:36,084 [INFO ] Logging file is /home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr//20240503_132436.log
33
+ 2024-05-03 13:24:36,084 [INFO ] ======================= exp =======================
34
+ 2024-05-03 13:24:36,084 [INFO ] test_batch_size:64
35
+ 2024-05-03 13:24:36,084 [INFO ] test_batch_size:64
36
+ 2024-05-03 13:24:36,084 [INFO ] config:config/vpt_default.yaml
37
+ 2024-05-03 13:24:36,084 [INFO ] num_workers:4
38
+ 2024-05-03 13:24:36,084 [INFO ] num_workers:4
39
+ 2024-05-03 13:24:36,084 [INFO ] ======================= exp =======================
40
+ 2024-05-03 13:24:36,084 [INFO ] name:exp
41
+ 2024-05-03 13:24:36,084 [INFO ] epochs:400
42
+ 2024-05-03 13:24:36,084 [INFO ] epochs:400
43
+ 2024-05-03 13:24:36,084 [INFO ] lr:None
44
+ 2024-05-03 13:24:36,084 [INFO ] blr:0.00015
45
+ 2024-05-03 13:24:36,084 [INFO ] config:config/vpt_default.yaml
46
+ 2024-05-03 13:24:36,084 [INFO ] blr:0.00015
47
+ 2024-05-03 13:24:36,084 [INFO ] world_size:4
48
+ 2024-05-03 13:24:36,084 [INFO ] min_lr:5e-06
49
+ 2024-05-03 13:24:36,084 [INFO ] name:exp
50
+ 2024-05-03 13:24:36,084 [INFO ] min_lr:5e-06
51
+ 2024-05-03 13:24:36,084 [INFO ] accum_iter:1
52
+ 2024-05-03 13:24:36,084 [INFO ] local_rank:-1
53
+ 2024-05-03 13:24:36,084 [INFO ] lr:None
54
+ 2024-05-03 13:24:36,084 [INFO ] accum_iter:1
55
+ 2024-05-03 13:24:36,084 [INFO ] world_size:4
56
+ 2024-05-03 13:24:36,084 [INFO ] dist_on_itp:False
57
+ 2024-05-03 13:24:36,084 [INFO ] input_size:256
58
+ 2024-05-03 13:24:36,084 [INFO ] input_size:256
59
+ 2024-05-03 13:24:36,084 [INFO ] weight_decay:0.05
60
+ 2024-05-03 13:24:36,084 [INFO ] dist_url:env://
61
+ 2024-05-03 13:24:36,084 [INFO ] local_rank:-1
62
+ 2024-05-03 13:24:36,084 [INFO ] weight_decay:0.05
63
+ 2024-05-03 13:24:36,084 [INFO ] warmup_epochs:20
64
+ 2024-05-03 13:24:36,084 [INFO ] warmup_epochs:20
65
+ 2024-05-03 13:24:36,084 [INFO ] TEST:False
66
+ 2024-05-03 13:24:36,084 [INFO ] dist_on_itp:False
67
+ 2024-05-03 13:24:36,084 [INFO ] mask_ratio_min:0.5
68
+ 2024-05-03 13:24:36,084 [INFO ] mask_ratio_min:0.5
69
+ 2024-05-03 13:24:36,084 [INFO ] model:mage_vit_base_patch16
70
+ 2024-05-03 13:24:36,084 [INFO ] dist_url:env://
71
+ 2024-05-03 13:24:36,084 [INFO ] mask_ratio_max:0.9
72
+ 2024-05-03 13:24:36,084 [INFO ] mask_ratio_max:0.9
73
+ 2024-05-03 13:24:36,084 [INFO ] TEST:False
74
+ 2024-05-03 13:24:36,084 [INFO ] dataset:imagenet
75
+ 2024-05-03 13:24:36,084 [INFO ] mask_ratio_mu:0.55
76
+ 2024-05-03 13:24:36,084 [INFO ] mask_ratio_mu:0.55
77
+ 2024-05-03 13:24:36,084 [INFO ] model:mage_vit_base_patch16
78
+ 2024-05-03 13:24:36,084 [INFO ] dataset_path:/home/t2vg-a100-G4-10/project/qyp/datasets/imagenet
79
+ 2024-05-03 13:24:36,084 [INFO ] mask_ratio_std:0.25
80
+ 2024-05-03 13:24:36,084 [INFO ] mask_ratio_std:0.25
81
+ 2024-05-03 13:24:36,084 [INFO ] dataset:imagenet
82
+ 2024-05-03 13:24:36,084 [INFO ] eval_path:
83
+ 2024-05-03 13:24:36,084 [INFO ] grad_clip:3.0
84
+ 2024-05-03 13:24:36,084 [INFO ] grad_clip:3.0
85
+ 2024-05-03 13:24:36,085 [INFO ] dataset_path:/home/t2vg-a100-G4-10/project/qyp/datasets/imagenet
86
+ 2024-05-03 13:24:36,085 [INFO ] batch_size:64
87
+ 2024-05-03 13:24:36,085 [INFO ] device:cuda
88
+ 2024-05-03 13:24:36,085 [INFO ] device:cuda
89
+ 2024-05-03 13:24:36,085 [INFO ] test_batch_size:64
90
+ 2024-05-03 13:24:36,085 [INFO ] eval_path:
91
+ 2024-05-03 13:24:36,085 [INFO ] distributed:True
92
+ 2024-05-03 13:24:36,085 [INFO ] distributed:True
93
+ 2024-05-03 13:24:36,085 [INFO ] num_workers:4
94
+ 2024-05-03 13:24:36,085 [INFO ] batch_size:64
95
+ 2024-05-03 13:24:36,085 [INFO ] seed:0
96
+ 2024-05-03 13:24:36,085 [INFO ] seed:0
97
+ 2024-05-03 13:24:36,085 [INFO ] test_batch_size:64
98
+ 2024-05-03 13:24:36,085 [INFO ] epochs:400
99
+ 2024-05-03 13:24:36,085 [INFO ] resume:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/random_with_rope.pth
100
+ 2024-05-03 13:24:36,085 [INFO ] resume:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/random_with_rope.pth
101
+ 2024-05-03 13:24:36,085 [INFO ] blr:0.00015
102
+ 2024-05-03 13:24:36,085 [INFO ] num_workers:4
103
+ 2024-05-03 13:24:36,085 [INFO ] start_epoch:0
104
+ 2024-05-03 13:24:36,085 [INFO ] start_epoch:0
105
+ 2024-05-03 13:24:36,085 [INFO ] epochs:400
106
+ 2024-05-03 13:24:36,085 [INFO ] min_lr:5e-06
107
+ 2024-05-03 13:24:36,085 [INFO ] pin_mem:True
108
+ 2024-05-03 13:24:36,085 [INFO ] pin_mem:True
109
+ 2024-05-03 13:24:36,085 [INFO ] blr:0.00015
110
+ 2024-05-03 13:24:36,085 [INFO ] accum_iter:1
111
+ 2024-05-03 13:24:36,085 [INFO ] save:True
112
+ 2024-05-03 13:24:36,085 [INFO ] save:True
113
+ 2024-05-03 13:24:36,085 [INFO ] input_size:256
114
+ 2024-05-03 13:24:36,085 [INFO ] min_lr:5e-06
115
+ 2024-05-03 13:24:36,085 [INFO ] exp_name:MIM_vbr
116
+ 2024-05-03 13:24:36,085 [INFO ] exp_name:MIM_vbr
117
+ 2024-05-03 13:24:36,085 [INFO ] weight_decay:0.05
118
+ 2024-05-03 13:24:36,085 [INFO ] accum_iter:1
119
+ 2024-05-03 13:24:36,085 [INFO ] warmup_epochs:20
120
+ 2024-05-03 13:24:36,085 [INFO ] root:/home/t2vg-a100-G4-10/project/qyp/mimc_rope
121
+ 2024-05-03 13:24:36,085 [INFO ] root:/home/t2vg-a100-G4-10/project/qyp/mimc_rope
122
+ 2024-05-03 13:24:36,085 [INFO ] input_size:256
123
+ 2024-05-03 13:24:36,085 [INFO ] output_dir:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr
124
+ 2024-05-03 13:24:36,085 [INFO ] mask_ratio_min:0.5
125
+ 2024-05-03 13:24:36,085 [INFO ] output_dir:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr
126
+ 2024-05-03 13:24:36,085 [INFO ] weight_decay:0.05
127
+ 2024-05-03 13:24:36,085 [INFO ] log_dir:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr
128
+ 2024-05-03 13:24:36,085 [INFO ] mask_ratio_max:0.9
129
+ 2024-05-03 13:24:36,085 [INFO ] log_dir:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr
130
+ 2024-05-03 13:24:36,085 [INFO ] warmup_epochs:20
131
+ 2024-05-03 13:24:36,085 [INFO ] strart_epoch:0
132
+ 2024-05-03 13:24:36,085 [INFO ] mask_ratio_mu:0.55
133
+ 2024-05-03 13:24:36,085 [INFO ] strart_epoch:0
134
+ 2024-05-03 13:24:36,085 [INFO ] mask_ratio_min:0.5
135
+ 2024-05-03 13:24:36,085 [INFO ] cuda:True
136
+ 2024-05-03 13:24:36,085 [INFO ] mask_ratio_std:0.25
137
+ 2024-05-03 13:24:36,085 [INFO ] cuda:True
138
+ 2024-05-03 13:24:36,085 [INFO ] mask_ratio_max:0.9
139
+ 2024-05-03 13:24:36,085 [INFO ] rank:0
140
+ 2024-05-03 13:24:36,085 [INFO ] grad_clip:3.0
141
+ 2024-05-03 13:24:36,085 [INFO ] rank:1
142
+ 2024-05-03 13:24:36,085 [INFO ] gpu:0
143
+ 2024-05-03 13:24:36,085 [INFO ] mask_ratio_mu:0.55
144
+ 2024-05-03 13:24:36,085 [INFO ] device:cuda
145
+ 2024-05-03 13:24:36,085 [INFO ] gpu:1
146
+ 2024-05-03 13:24:36,085 [INFO ] dist_backend:nccl
147
+ 2024-05-03 13:24:36,085 [INFO ] distributed:True
148
+ 2024-05-03 13:24:36,085 [INFO ] mask_ratio_std:0.25
149
+ 2024-05-03 13:24:36,085 [INFO ] dist_backend:nccl
150
+ 2024-05-03 13:24:36,085 [INFO ] ===================================================
151
+ 2024-05-03 13:24:36,085 [INFO ] grad_clip:3.0
152
+ 2024-05-03 13:24:36,085 [INFO ] seed:0
153
+ 2024-05-03 13:24:36,085 [INFO ] ===================================================
154
+ 2024-05-03 13:24:36,085 [INFO ] device:cuda
155
+ 2024-05-03 13:24:36,085 [INFO ] resume:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/random_with_rope.pth
156
+ 2024-05-03 13:24:36,085 [INFO ] start_epoch:0
157
+ 2024-05-03 13:24:36,085 [INFO ] distributed:True
158
+ 2024-05-03 13:24:36,085 [INFO ] seed:0
159
+ 2024-05-03 13:24:36,085 [INFO ] pin_mem:True
160
+ 2024-05-03 13:24:36,085 [INFO ] save:True
161
+ 2024-05-03 13:24:36,085 [INFO ] resume:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/random_with_rope.pth
162
+ 2024-05-03 13:24:36,085 [INFO ] exp_name:MIM_vbr
163
+ 2024-05-03 13:24:36,085 [INFO ] start_epoch:0
164
+ 2024-05-03 13:24:36,086 [INFO ] root:/home/t2vg-a100-G4-10/project/qyp/mimc_rope
165
+ 2024-05-03 13:24:36,086 [INFO ] pin_mem:True
166
+ 2024-05-03 13:24:36,086 [INFO ] save:True
167
+ 2024-05-03 13:24:36,086 [INFO ] output_dir:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr
168
+ 2024-05-03 13:24:36,086 [INFO ] log_dir:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr
169
+ 2024-05-03 13:24:36,086 [INFO ] exp_name:MIM_vbr
170
+ 2024-05-03 13:24:36,086 [INFO ] strart_epoch:0
171
+ 2024-05-03 13:24:36,086 [INFO ] root:/home/t2vg-a100-G4-10/project/qyp/mimc_rope
172
+ 2024-05-03 13:24:36,086 [INFO ] cuda:True
173
+ 2024-05-03 13:24:36,086 [INFO ] output_dir:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr
174
+ 2024-05-03 13:24:36,086 [INFO ] log_dir:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr
175
+ 2024-05-03 13:24:36,086 [INFO ] rank:3
176
+ 2024-05-03 13:24:36,086 [INFO ] gpu:3
177
+ 2024-05-03 13:24:36,086 [INFO ] strart_epoch:0
178
+ 2024-05-03 13:24:36,086 [INFO ] cuda:True
179
+ 2024-05-03 13:24:36,086 [INFO ] dist_backend:nccl
180
+ 2024-05-03 13:24:36,086 [INFO ] ===================================================
181
+ 2024-05-03 13:24:36,086 [INFO ] rank:2
182
+ 2024-05-03 13:24:36,086 [INFO ] gpu:2
183
+ 2024-05-03 13:24:36,086 [INFO ] dist_backend:nccl
184
+ 2024-05-03 13:24:36,086 [INFO ] ===================================================
185
+ 2024-05-03 13:25:40,076 [INFO ] -1|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.28939|msssim:0.33530|lpips:0.50428|dists:0.27922|Test loss:22.31097
186
+ 2024-05-03 13:45:27,931 [INFO ] 71|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.27695|msssim:0.33996|lpips:0.50599|dists:0.28065|Test loss:22.37941
187
+ 2024-05-03 14:05:20,381 [INFO ] 72|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.18272|msssim:0.33581|lpips:0.50830|dists:0.28347|Test loss:22.39342
188
+ 2024-05-03 14:25:08,798 [INFO ] 73|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.07158|msssim:0.33378|lpips:0.50512|dists:0.28206|Test loss:22.34897
189
+ 2024-05-03 14:45:13,282 [INFO ] 74|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.25096|msssim:0.33580|lpips:0.50984|dists:0.28295|Test loss:22.41223
190
+ 2024-05-03 15:05:00,304 [INFO ] 75|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.34433|msssim:0.34183|lpips:0.50830|dists:0.28667|Test loss:22.41598
191
+ 2024-05-03 15:24:49,830 [INFO ] 76|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.18746|msssim:0.33532|lpips:0.50817|dists:0.28051|Test loss:22.38807
192
+ 2024-05-03 15:44:36,569 [INFO ] 77|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.24094|msssim:0.33337|lpips:0.50695|dists:0.28093|Test loss:22.35640
193
+ 2024-05-03 16:04:23,158 [INFO ] 78|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.25324|msssim:0.33446|lpips:0.51007|dists:0.28490|Test loss:22.40084
194
+ 2024-05-03 16:24:10,702 [INFO ] 79|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.22266|msssim:0.33541|lpips:0.50979|dists:0.28361|Test loss:22.35502
195
+ 2024-05-03 16:43:58,827 [INFO ] 80|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.10645|msssim:0.32806|lpips:0.51042|dists:0.28348|Test loss:22.37403
196
+ 2024-05-03 17:03:49,009 [INFO ] 81|bpp:0.01873|mask:0.00888|mask_ratio:0.75|psnr:12.16949|msssim:0.33372|lpips:0.50971|dists:0.28431|Test loss:22.38677
197
+ 2024-05-03 17:23:36,872 [INFO ] 82|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.19599|msssim:0.33479|lpips:0.50738|dists:0.27982|Test loss:22.39563
198
+ 2024-05-03 17:43:24,129 [INFO ] 83|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.13946|msssim:0.33657|lpips:0.50608|dists:0.28026|Test loss:22.37709
199
+ 2024-05-03 18:03:09,632 [INFO ] 84|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.19059|msssim:0.32992|lpips:0.51449|dists:0.28200|Test loss:22.39230
200
+ 2024-05-03 18:22:55,813 [INFO ] 85|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.18655|msssim:0.33290|lpips:0.50831|dists:0.28350|Test loss:22.38299
201
+ 2024-05-03 18:42:41,651 [INFO ] 86|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.21415|msssim:0.33196|lpips:0.50763|dists:0.28348|Test loss:22.34028
202
+ 2024-05-03 19:02:43,711 [INFO ] 87|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.12306|msssim:0.33658|lpips:0.50986|dists:0.28280|Test loss:22.35450
203
+ 2024-05-03 19:22:30,211 [INFO ] 88|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.27434|msssim:0.33857|lpips:0.50953|dists:0.28144|Test loss:22.32993
204
+ 2024-05-03 19:42:31,870 [INFO ] 89|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.29707|msssim:0.33239|lpips:0.51159|dists:0.28729|Test loss:22.39535
205
+ 2024-05-03 20:02:18,523 [INFO ] 90|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:11.93105|msssim:0.32874|lpips:0.51374|dists:0.28500|Test loss:22.37451
206
+ 2024-05-03 20:22:40,704 [INFO ] 91|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.15259|msssim:0.33332|lpips:0.51196|dists:0.28565|Test loss:22.39588
207
+ 2024-05-03 20:42:50,402 [INFO ] 92|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.22207|msssim:0.33683|lpips:0.50868|dists:0.28256|Test loss:22.34765
208
+ 2024-05-03 21:03:00,830 [INFO ] 93|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.11886|msssim:0.33659|lpips:0.50814|dists:0.28405|Test loss:22.34739
209
+ 2024-05-03 21:23:11,053 [INFO ] 94|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.13203|msssim:0.33091|lpips:0.50892|dists:0.28377|Test loss:22.36003
210
+ 2024-05-03 21:43:19,380 [INFO ] 95|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.13799|msssim:0.33715|lpips:0.51032|dists:0.28500|Test loss:22.35577
211
+ 2024-05-03 22:03:29,276 [INFO ] 96|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.29663|msssim:0.33942|lpips:0.50849|dists:0.28102|Test loss:22.39275
212
+ 2024-05-03 22:23:40,310 [INFO ] 97|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.31019|msssim:0.34155|lpips:0.50632|dists:0.28538|Test loss:22.33029
213
+ 2024-05-03 22:25:02,985 [INFO ] 98|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.28031|msssim:0.33755|lpips:0.50549|dists:0.28245|Test loss:22.33734
214
+ 2024-05-03 22:45:13,341 [INFO ] 99|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.17199|msssim:0.33258|lpips:0.50680|dists:0.28107|Test loss:22.37611
215
+ 2024-05-03 23:05:22,081 [INFO ] 100|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.12885|msssim:0.33378|lpips:0.50697|dists:0.27960|Test loss:22.35676
216
+ 2024-05-03 23:25:59,251 [INFO ] 101|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.23955|msssim:0.33759|lpips:0.50643|dists:0.28164|Test loss:22.33918
217
+ 2024-05-03 23:46:28,120 [INFO ] 102|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.28548|msssim:0.33363|lpips:0.50599|dists:0.28383|Test loss:22.36715
218
+ 2024-05-04 00:07:06,785 [INFO ] 103|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.38201|msssim:0.33652|lpips:0.50561|dists:0.28420|Test loss:22.34940
219
+ 2024-05-04 00:27:18,585 [INFO ] 104|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.23971|msssim:0.33526|lpips:0.50535|dists:0.28196|Test loss:22.32486
220
+ 2024-05-04 00:47:46,770 [INFO ] 105|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.16863|msssim:0.33425|lpips:0.50665|dists:0.28272|Test loss:22.35667
221
+ 2024-05-04 01:07:48,995 [INFO ] 106|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.33955|msssim:0.34124|lpips:0.50687|dists:0.28063|Test loss:22.33323
222
+ 2024-05-04 01:27:57,424 [INFO ] 107|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.32910|msssim:0.33785|lpips:0.50615|dists:0.28166|Test loss:22.33379
223
+ 2024-05-04 01:48:04,206 [INFO ] 108|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.17313|msssim:0.33479|lpips:0.50718|dists:0.28483|Test loss:22.34892
224
+ 2024-05-04 02:08:13,034 [INFO ] 109|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.25400|msssim:0.33878|lpips:0.50396|dists:0.27973|Test loss:22.27535
225
+ 2024-05-04 02:28:39,568 [INFO ] 110|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.18919|msssim:0.33519|lpips:0.50766|dists:0.28370|Test loss:22.34272
226
+ 2024-05-04 02:48:31,198 [INFO ] 111|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.25287|msssim:0.34151|lpips:0.50405|dists:0.28154|Test loss:22.29224
227
+ 2024-05-04 03:08:18,193 [INFO ] 112|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.25592|msssim:0.33658|lpips:0.50613|dists:0.27908|Test loss:22.31837
228
+ 2024-05-04 03:28:06,312 [INFO ] 113|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.19013|msssim:0.33855|lpips:0.50963|dists:0.28313|Test loss:22.35398
229
+ 2024-05-04 03:47:53,623 [INFO ] 114|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.22372|msssim:0.33530|lpips:0.50483|dists:0.27558|Test loss:22.27828
230
+ 2024-05-04 04:07:39,992 [INFO ] 115|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.35171|msssim:0.34350|lpips:0.50526|dists:0.28044|Test loss:22.30720
231
+ 2024-05-04 04:27:27,792 [INFO ] 116|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.14400|msssim:0.33273|lpips:0.51011|dists:0.28111|Test loss:22.33026
232
+ 2024-05-04 04:47:15,548 [INFO ] 117|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.22451|msssim:0.33556|lpips:0.50722|dists:0.28211|Test loss:22.30748
233
+ 2024-05-04 05:07:04,750 [INFO ] 118|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.39794|msssim:0.33993|lpips:0.50137|dists:0.28058|Test loss:22.32248
234
+ 2024-05-04 05:26:52,003 [INFO ] 119|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.35083|msssim:0.34228|lpips:0.50405|dists:0.28113|Test loss:22.29005
235
+ 2024-05-04 05:46:54,175 [INFO ] 120|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.19306|msssim:0.33767|lpips:0.50508|dists:0.28069|Test loss:22.30016
236
+ 2024-05-04 06:07:56,985 [INFO ] 121|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.30696|msssim:0.34229|lpips:0.50263|dists:0.27623|Test loss:22.32207
237
+ 2024-05-04 06:28:36,777 [INFO ] 122|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.32276|msssim:0.34709|lpips:0.50338|dists:0.27890|Test loss:22.27271
238
+ 2024-05-04 06:49:34,456 [INFO ] 123|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.15545|msssim:0.33990|lpips:0.50738|dists:0.28200|Test loss:22.31753
239
+ 2024-05-04 07:10:10,082 [INFO ] 124|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.26117|msssim:0.33900|lpips:0.50466|dists:0.28039|Test loss:22.31486
240
+ 2024-05-04 07:30:52,259 [INFO ] 125|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.22142|msssim:0.33723|lpips:0.50712|dists:0.28167|Test loss:22.32770
241
+ 2024-05-04 07:51:29,029 [INFO ] 126|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.29230|msssim:0.33539|lpips:0.50329|dists:0.27957|Test loss:22.30125
242
+ 2024-05-04 08:11:57,295 [INFO ] 127|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.10057|msssim:0.33604|lpips:0.50390|dists:0.27900|Test loss:22.32180
243
+ 2024-05-04 08:32:29,651 [INFO ] 128|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.35982|msssim:0.34019|lpips:0.50363|dists:0.27863|Test loss:22.30414
244
+ 2024-05-04 08:53:13,202 [INFO ] 129|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.15058|msssim:0.33351|lpips:0.50549|dists:0.27943|Test loss:22.26283
245
+ 2024-05-04 09:13:57,329 [INFO ] 130|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.27512|msssim:0.33742|lpips:0.50353|dists:0.27905|Test loss:22.30725
246
+ 2024-05-04 09:15:24,858 [INFO ] 131|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.20872|msssim:0.34373|lpips:0.50210|dists:0.28035|Test loss:22.27275
247
+ 2024-05-04 09:36:10,653 [INFO ] 132|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.20383|msssim:0.33974|lpips:0.50339|dists:0.27824|Test loss:22.29464
248
+ 2024-05-04 09:56:51,249 [INFO ] 133|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.12968|msssim:0.33882|lpips:0.50371|dists:0.27564|Test loss:22.30516
249
+ 2024-05-04 10:17:09,926 [INFO ] 134|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.15733|msssim:0.33909|lpips:0.50420|dists:0.27941|Test loss:22.28371
250
+ 2024-05-04 10:37:22,751 [INFO ] 135|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.14190|msssim:0.33220|lpips:0.50398|dists:0.27657|Test loss:22.25691
251
+ 2024-05-04 10:58:32,045 [INFO ] 136|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.21952|msssim:0.34333|lpips:0.50654|dists:0.28236|Test loss:22.29695
252
+ 2024-05-04 11:19:04,383 [INFO ] 137|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.12107|msssim:0.33130|lpips:0.50644|dists:0.27993|Test loss:22.30494
253
+ 2024-05-04 11:39:23,509 [INFO ] 138|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.28914|msssim:0.34730|lpips:0.50144|dists:0.27736|Test loss:22.26708
254
+ 2024-05-04 11:59:44,396 [INFO ] 139|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.12318|msssim:0.33862|lpips:0.50449|dists:0.27968|Test loss:22.28541
255
+ 2024-05-04 12:20:03,670 [INFO ] 140|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.36024|msssim:0.33926|lpips:0.50315|dists:0.27705|Test loss:22.24402
256
+ 2024-05-04 12:40:47,447 [INFO ] 141|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.37666|msssim:0.33869|lpips:0.50093|dists:0.27720|Test loss:22.27834
257
+ 2024-05-04 13:01:04,137 [INFO ] 142|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.21903|msssim:0.34257|lpips:0.50445|dists:0.27656|Test loss:22.27222
258
+ 2024-05-04 13:21:33,832 [INFO ] 143|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.23086|msssim:0.33986|lpips:0.50459|dists:0.28040|Test loss:22.30935
259
+ 2024-05-04 13:42:00,911 [INFO ] 144|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.18145|msssim:0.34159|lpips:0.50526|dists:0.27986|Test loss:22.29733
260
+ 2024-05-04 14:02:21,466 [INFO ] 145|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.41649|msssim:0.34486|lpips:0.50061|dists:0.27830|Test loss:22.27924
261
+ 2024-05-04 14:23:01,203 [INFO ] 146|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.40444|msssim:0.33984|lpips:0.50394|dists:0.27939|Test loss:22.27717
262
+ 2024-05-04 14:44:11,569 [INFO ] 147|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.10670|msssim:0.33935|lpips:0.50644|dists:0.28129|Test loss:22.28859
263
+ 2024-05-04 15:05:12,252 [INFO ] 148|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.31263|msssim:0.34371|lpips:0.50163|dists:0.27719|Test loss:22.28988
264
+ 2024-05-04 15:26:22,270 [INFO ] 149|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.11357|msssim:0.34112|lpips:0.50444|dists:0.28143|Test loss:22.27471
265
+ 2024-05-04 15:47:18,685 [INFO ] 150|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.37583|msssim:0.34108|lpips:0.50440|dists:0.27920|Test loss:22.28879
266
+ 2024-05-04 16:08:26,498 [INFO ] 151|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.23233|msssim:0.33481|lpips:0.50793|dists:0.28015|Test loss:22.30692
267
+ 2024-05-04 16:29:29,985 [INFO ] 152|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.38831|msssim:0.34042|lpips:0.50221|dists:0.27659|Test loss:22.24423
268
+ 2024-05-04 16:50:31,249 [INFO ] 153|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.19962|msssim:0.34101|lpips:0.50407|dists:0.28065|Test loss:22.29826
269
+ 2024-05-04 17:11:31,226 [INFO ] 154|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.11960|msssim:0.34173|lpips:0.50215|dists:0.27693|Test loss:22.27951
270
+ 2024-05-04 17:32:28,025 [INFO ] 155|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.17072|msssim:0.33614|lpips:0.50420|dists:0.27642|Test loss:22.25476
271
+ 2024-05-04 17:53:03,788 [INFO ] 156|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.18252|msssim:0.33534|lpips:0.50560|dists:0.27726|Test loss:22.23966
272
+ 2024-05-04 18:13:53,631 [INFO ] 157|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.30066|msssim:0.34024|lpips:0.50406|dists:0.27919|Test loss:22.27285
273
+ 2024-05-04 18:34:02,548 [INFO ] 158|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.16105|msssim:0.33724|lpips:0.50539|dists:0.27802|Test loss:22.28563
274
+ 2024-05-04 18:54:17,391 [INFO ] 159|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.26286|msssim:0.33958|lpips:0.50246|dists:0.27786|Test loss:22.25693
275
+ 2024-05-04 19:14:45,980 [INFO ] 160|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.10013|msssim:0.33552|lpips:0.50203|dists:0.27745|Test loss:22.27180
276
+ 2024-05-04 19:35:55,093 [INFO ] 161|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.19999|msssim:0.34075|lpips:0.50399|dists:0.28136|Test loss:22.28090
277
+ 2024-05-04 19:57:04,549 [INFO ] 162|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.35397|msssim:0.34255|lpips:0.50396|dists:0.28067|Test loss:22.30754
278
+ 2024-05-04 20:18:09,435 [INFO ] 163|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.07471|msssim:0.33934|lpips:0.50488|dists:0.27944|Test loss:22.29782
279
+ 2024-05-04 20:19:36,675 [INFO ] 164|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.15191|msssim:0.33803|lpips:0.50363|dists:0.28080|Test loss:22.30785
280
+ 2024-05-04 20:40:45,907 [INFO ] 165|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.35093|msssim:0.34222|lpips:0.50133|dists:0.27822|Test loss:22.23812
281
+ 2024-05-04 21:02:06,997 [INFO ] 166|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.14147|msssim:0.34612|lpips:0.50051|dists:0.27732|Test loss:22.26769
282
+ 2024-05-04 21:23:13,250 [INFO ] 167|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.28212|msssim:0.34412|lpips:0.50155|dists:0.27944|Test loss:22.26505
283
+ 2024-05-04 21:44:09,446 [INFO ] 168|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.23860|msssim:0.34012|lpips:0.50238|dists:0.27605|Test loss:22.24193
284
+ 2024-05-04 22:04:20,528 [INFO ] 169|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.18714|msssim:0.34326|lpips:0.50174|dists:0.27552|Test loss:22.19141
285
+ 2024-05-04 22:24:52,856 [INFO ] 170|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.23269|msssim:0.34707|lpips:0.50277|dists:0.28087|Test loss:22.25782
286
+ 2024-05-04 22:46:02,786 [INFO ] 171|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.34741|msssim:0.34084|lpips:0.50177|dists:0.27559|Test loss:22.23672
287
+ 2024-05-04 23:06:52,141 [INFO ] 172|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.28709|msssim:0.34182|lpips:0.50183|dists:0.27629|Test loss:22.25575
288
+ 2024-05-04 23:27:45,873 [INFO ] 173|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.32183|msssim:0.34279|lpips:0.50424|dists:0.27647|Test loss:22.26211
289
+ 2024-05-04 23:49:09,385 [INFO ] 174|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.27799|msssim:0.34067|lpips:0.50149|dists:0.27606|Test loss:22.25085
290
+ 2024-05-05 00:09:54,865 [INFO ] 175|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.23614|msssim:0.34026|lpips:0.50595|dists:0.27867|Test loss:22.26061
291
+ 2024-05-05 00:30:37,643 [INFO ] 176|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.26992|msssim:0.34440|lpips:0.50466|dists:0.27647|Test loss:22.26275
292
+ 2024-05-05 00:51:58,336 [INFO ] 177|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.27768|msssim:0.34192|lpips:0.50217|dists:0.27718|Test loss:22.24697
293
+ 2024-05-05 01:12:41,638 [INFO ] 178|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.22162|msssim:0.34181|lpips:0.50407|dists:0.28091|Test loss:22.28326
294
+ 2024-05-05 01:33:33,913 [INFO ] 179|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.33287|msssim:0.34137|lpips:0.50010|dists:0.27889|Test loss:22.25626
295
+ 2024-05-05 01:54:24,919 [INFO ] 180|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.26020|msssim:0.33738|lpips:0.50198|dists:0.27898|Test loss:22.27012
296
+ 2024-05-05 02:15:11,781 [INFO ] 181|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.24128|msssim:0.34296|lpips:0.50206|dists:0.27911|Test loss:22.27168
297
+ 2024-05-05 02:36:00,258 [INFO ] 182|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.29608|msssim:0.34143|lpips:0.50258|dists:0.27963|Test loss:22.23007
298
+ 2024-05-05 02:56:51,997 [INFO ] 183|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.14427|msssim:0.33923|lpips:0.50353|dists:0.27847|Test loss:22.26153
299
+ 2024-05-05 03:17:42,729 [INFO ] 184|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.24628|msssim:0.33477|lpips:0.50120|dists:0.27460|Test loss:22.22222
300
+ 2024-05-05 03:38:55,556 [INFO ] 185|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.22856|msssim:0.33995|lpips:0.49928|dists:0.27506|Test loss:22.21741
301
+ 2024-05-05 03:59:54,038 [INFO ] 186|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.25829|msssim:0.34396|lpips:0.50041|dists:0.27953|Test loss:22.21000
302
+ 2024-05-05 04:20:57,361 [INFO ] 187|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.33900|msssim:0.34563|lpips:0.50119|dists:0.27801|Test loss:22.24254
303
+ 2024-05-05 04:41:56,894 [INFO ] 188|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.25541|msssim:0.34364|lpips:0.50117|dists:0.27644|Test loss:22.22793
304
+ 2024-05-05 05:02:57,947 [INFO ] 189|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.29414|msssim:0.34397|lpips:0.50046|dists:0.27675|Test loss:22.22504
305
+ 2024-05-05 05:24:00,200 [INFO ] 190|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.42643|msssim:0.34251|lpips:0.50034|dists:0.27756|Test loss:22.20704
306
+ 2024-05-05 05:45:06,616 [INFO ] 191|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.24620|msssim:0.33765|lpips:0.50179|dists:0.27763|Test loss:22.24989
307
+ 2024-05-05 06:05:48,232 [INFO ] 192|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.20766|msssim:0.34222|lpips:0.50102|dists:0.27429|Test loss:22.20998
308
+ 2024-05-05 06:26:41,032 [INFO ] 193|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.28995|msssim:0.34232|lpips:0.50011|dists:0.27672|Test loss:22.19870
309
+ 2024-05-05 06:47:28,088 [INFO ] 194|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.31111|msssim:0.34537|lpips:0.50108|dists:0.27792|Test loss:22.22558
310
+ 2024-05-05 07:08:15,894 [INFO ] 195|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.19613|msssim:0.34356|lpips:0.50329|dists:0.27837|Test loss:22.27292
311
+ 2024-05-05 07:29:07,588 [INFO ] 196|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.34946|msssim:0.34073|lpips:0.50324|dists:0.27687|Test loss:22.25968
312
+ 2024-05-05 07:30:34,517 [INFO ] 197|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.28973|msssim:0.34913|lpips:0.49958|dists:0.27724|Test loss:22.21089
313
+ 2024-05-05 07:51:24,205 [INFO ] 198|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.22106|msssim:0.33796|lpips:0.50167|dists:0.27715|Test loss:22.27205
314
+ 2024-05-05 08:12:37,156 [INFO ] 199|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.26782|msssim:0.34472|lpips:0.50100|dists:0.27782|Test loss:22.21741
315
+ 2024-05-05 08:33:49,412 [INFO ] 200|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.28700|msssim:0.34200|lpips:0.50071|dists:0.27705|Test loss:22.24541
316
+ 2024-05-05 08:55:00,482 [INFO ] 201|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.26943|msssim:0.34567|lpips:0.50009|dists:0.27623|Test loss:22.23442
317
+ 2024-05-05 09:15:55,407 [INFO ] 202|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.28424|msssim:0.33845|lpips:0.50023|dists:0.27742|Test loss:22.21661
318
+ 2024-05-05 09:37:07,037 [INFO ] 203|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.36869|msssim:0.34921|lpips:0.49663|dists:0.27317|Test loss:22.18738
319
+ 2024-05-05 09:58:22,992 [INFO ] 204|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.29465|msssim:0.34732|lpips:0.49851|dists:0.27560|Test loss:22.21542
320
+ 2024-05-05 10:19:30,345 [INFO ] 205|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.34941|msssim:0.34522|lpips:0.49991|dists:0.27673|Test loss:22.23310
321
+ 2024-05-05 10:40:34,778 [INFO ] 206|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.29729|msssim:0.34074|lpips:0.49963|dists:0.27439|Test loss:22.18471
322
+ 2024-05-05 11:01:52,830 [INFO ] 207|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.20647|msssim:0.34440|lpips:0.50193|dists:0.27544|Test loss:22.16906
323
+ 2024-05-05 11:23:10,869 [INFO ] 208|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.33049|msssim:0.34167|lpips:0.49881|dists:0.27573|Test loss:22.22115
324
+ 2024-05-05 11:44:29,730 [INFO ] 209|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.36687|msssim:0.34201|lpips:0.50120|dists:0.27688|Test loss:22.21192
325
+ 2024-05-05 12:05:44,940 [INFO ] 210|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.19353|msssim:0.33601|lpips:0.50266|dists:0.27510|Test loss:22.24716
326
+ 2024-05-05 12:27:06,882 [INFO ] 211|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.31839|msssim:0.34303|lpips:0.49930|dists:0.27422|Test loss:22.19657
327
+ 2024-05-05 12:48:21,046 [INFO ] 212|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.42858|msssim:0.34387|lpips:0.49946|dists:0.27734|Test loss:22.23478
328
+ 2024-05-05 13:09:31,493 [INFO ] 213|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.31374|msssim:0.34684|lpips:0.49643|dists:0.27400|Test loss:22.22582
329
+ 2024-05-05 13:30:34,522 [INFO ] 214|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.21293|msssim:0.34112|lpips:0.49883|dists:0.27535|Test loss:22.21174
330
+ 2024-05-05 13:51:47,865 [INFO ] 215|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.41841|msssim:0.34328|lpips:0.49828|dists:0.27544|Test loss:22.16456
331
+ 2024-05-05 14:13:13,905 [INFO ] 216|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.25790|msssim:0.34093|lpips:0.49955|dists:0.27657|Test loss:22.21031
332
+ 2024-05-05 14:34:27,772 [INFO ] 217|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.25331|msssim:0.34462|lpips:0.50060|dists:0.27798|Test loss:22.19541
333
+ 2024-05-05 14:55:44,043 [INFO ] 218|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.34120|msssim:0.34909|lpips:0.49698|dists:0.27423|Test loss:22.18064
334
+ 2024-05-05 15:17:02,181 [INFO ] 219|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.37122|msssim:0.34335|lpips:0.49592|dists:0.27407|Test loss:22.21357
335
+ 2024-05-05 15:38:13,952 [INFO ] 220|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.33670|msssim:0.34260|lpips:0.50183|dists:0.27681|Test loss:22.21649
336
+ 2024-05-05 15:59:11,744 [INFO ] 221|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.28591|msssim:0.34584|lpips:0.50050|dists:0.27571|Test loss:22.23926
337
+ 2024-05-05 16:20:07,029 [INFO ] 222|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.26039|msssim:0.34677|lpips:0.49812|dists:0.27303|Test loss:22.18227
338
+ 2024-05-05 16:41:01,438 [INFO ] 223|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.35407|msssim:0.34511|lpips:0.49698|dists:0.27452|Test loss:22.20431
339
+ 2024-05-05 17:01:56,573 [INFO ] 224|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.36067|msssim:0.34552|lpips:0.49578|dists:0.27214|Test loss:22.16769
340
+ 2024-05-05 17:23:09,208 [INFO ] 225|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.34319|msssim:0.34962|lpips:0.49547|dists:0.27391|Test loss:22.15918
341
+ 2024-05-05 17:44:29,823 [INFO ] 226|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.21200|msssim:0.34131|lpips:0.49971|dists:0.27562|Test loss:22.19288
342
+ 2024-05-05 18:05:38,805 [INFO ] 227|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.34673|msssim:0.34305|lpips:0.49692|dists:0.27394|Test loss:22.18381
343
+ 2024-05-05 18:26:44,515 [INFO ] 228|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.26413|msssim:0.33844|lpips:0.50093|dists:0.27558|Test loss:22.21518
344
+ 2024-05-05 18:47:55,570 [INFO ] 229|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.20001|msssim:0.34085|lpips:0.50149|dists:0.27539|Test loss:22.16549
345
+ 2024-05-05 18:49:21,297 [INFO ] 230|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.24847|msssim:0.34171|lpips:0.50002|dists:0.27554|Test loss:22.18368
346
+ 2024-05-05 19:10:31,373 [INFO ] 231|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.26719|msssim:0.34548|lpips:0.50115|dists:0.27678|Test loss:22.23358
347
+ 2024-05-05 19:31:27,088 [INFO ] 232|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.29815|msssim:0.34711|lpips:0.49873|dists:0.27557|Test loss:22.20006
348
+ 2024-05-05 19:52:35,830 [INFO ] 233|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.45504|msssim:0.35044|lpips:0.50058|dists:0.27448|Test loss:22.19251
349
+ 2024-05-05 20:13:42,913 [INFO ] 234|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.36103|msssim:0.34949|lpips:0.49969|dists:0.27417|Test loss:22.19380
350
+ 2024-05-05 20:34:48,991 [INFO ] 235|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.43152|msssim:0.34284|lpips:0.49837|dists:0.27674|Test loss:22.17983
351
+ 2024-05-05 20:55:57,801 [INFO ] 236|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.31431|msssim:0.34494|lpips:0.49805|dists:0.27595|Test loss:22.18135
352
+ 2024-05-05 21:17:01,921 [INFO ] 237|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.39373|msssim:0.34481|lpips:0.49903|dists:0.27500|Test loss:22.19592
353
+ 2024-05-05 21:38:15,187 [INFO ] 238|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.32557|msssim:0.34869|lpips:0.50099|dists:0.27830|Test loss:22.18304
354
+ 2024-05-05 21:59:20,693 [INFO ] 239|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.35618|msssim:0.34089|lpips:0.49914|dists:0.27589|Test loss:22.22733
355
+ 2024-05-05 22:20:28,843 [INFO ] 240|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.38320|msssim:0.34359|lpips:0.49845|dists:0.27558|Test loss:22.20330
356
+ 2024-05-05 22:41:33,103 [INFO ] 241|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.41832|msssim:0.34829|lpips:0.49624|dists:0.27300|Test loss:22.20170
357
+ 2024-05-05 23:02:36,039 [INFO ] 242|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.26759|msssim:0.34710|lpips:0.49868|dists:0.27406|Test loss:22.18713
358
+ 2024-05-05 23:23:41,292 [INFO ] 243|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.25480|msssim:0.33822|lpips:0.49764|dists:0.27286|Test loss:22.20413
359
+ 2024-05-05 23:44:51,601 [INFO ] 244|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.24137|msssim:0.34107|lpips:0.49667|dists:0.27336|Test loss:22.19491
360
+ 2024-05-06 00:05:59,368 [INFO ] 245|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.44126|msssim:0.34903|lpips:0.49630|dists:0.27284|Test loss:22.21318
361
+ 2024-05-06 00:26:57,417 [INFO ] 246|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.35589|msssim:0.34458|lpips:0.49846|dists:0.27602|Test loss:22.18793
362
+ 2024-05-06 00:48:01,606 [INFO ] 247|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.29865|msssim:0.34271|lpips:0.49826|dists:0.27014|Test loss:22.17141
363
+ 2024-05-06 01:09:07,370 [INFO ] 248|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.31535|msssim:0.34943|lpips:0.49795|dists:0.27377|Test loss:22.17951
364
+ 2024-05-06 01:30:20,686 [INFO ] 249|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.39809|msssim:0.34563|lpips:0.50040|dists:0.27597|Test loss:22.16323
365
+ 2024-05-06 01:51:22,543 [INFO ] 250|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.28734|msssim:0.35247|lpips:0.49776|dists:0.27438|Test loss:22.14602
366
+ 2024-05-06 02:12:51,537 [INFO ] 251|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.28380|msssim:0.34202|lpips:0.49683|dists:0.27599|Test loss:22.20221
367
+ 2024-05-06 02:33:50,217 [INFO ] 252|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.33769|msssim:0.35137|lpips:0.49936|dists:0.27664|Test loss:22.18598
368
+ 2024-05-06 02:55:05,177 [INFO ] 253|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.30681|msssim:0.34182|lpips:0.50133|dists:0.27604|Test loss:22.19516
369
+ 2024-05-06 03:16:12,944 [INFO ] 254|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.39329|msssim:0.34256|lpips:0.49706|dists:0.27318|Test loss:22.16968
370
+ 2024-05-06 03:37:19,768 [INFO ] 255|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.44825|msssim:0.34998|lpips:0.49653|dists:0.27555|Test loss:22.16102
371
+ 2024-05-06 03:58:30,008 [INFO ] 256|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.30121|msssim:0.34926|lpips:0.49515|dists:0.27333|Test loss:22.14309
372
+ 2024-05-06 04:19:52,471 [INFO ] 257|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.35113|msssim:0.34655|lpips:0.49584|dists:0.27406|Test loss:22.19664
373
+ 2024-05-06 04:40:58,643 [INFO ] 258|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.34788|msssim:0.34555|lpips:0.49748|dists:0.27429|Test loss:22.15390
374
+ 2024-05-06 05:02:09,625 [INFO ] 259|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.37536|msssim:0.34993|lpips:0.49926|dists:0.27615|Test loss:22.16938
375
+ 2024-05-06 05:23:15,111 [INFO ] 260|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.30780|msssim:0.34366|lpips:0.49741|dists:0.27334|Test loss:22.14079
376
+ 2024-05-06 05:44:58,198 [INFO ] 261|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.32658|msssim:0.34584|lpips:0.49735|dists:0.27514|Test loss:22.15270
377
+ 2024-05-06 06:06:38,107 [INFO ] 262|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.34172|msssim:0.34937|lpips:0.49673|dists:0.27184|Test loss:22.16732
378
+ 2024-05-06 06:08:06,494 [INFO ] 263|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.34207|msssim:0.34853|lpips:0.49675|dists:0.27406|Test loss:22.15906
379
+ 2024-05-06 06:29:36,637 [INFO ] 264|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.46720|msssim:0.34606|lpips:0.49734|dists:0.27330|Test loss:22.16713
380
+ 2024-05-06 06:50:49,014 [INFO ] 265|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.39163|msssim:0.34577|lpips:0.49680|dists:0.27300|Test loss:22.15004
381
+ 2024-05-06 07:11:43,798 [INFO ] 266|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.34610|msssim:0.34627|lpips:0.49790|dists:0.27368|Test loss:22.18821
382
+ 2024-05-06 07:32:32,314 [INFO ] 267|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.35793|msssim:0.34650|lpips:0.49353|dists:0.27379|Test loss:22.15183
383
+ 2024-05-06 07:53:23,679 [INFO ] 268|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.35679|msssim:0.34936|lpips:0.49728|dists:0.27346|Test loss:22.14839
384
+ 2024-05-06 08:14:16,352 [INFO ] 269|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.33630|msssim:0.34671|lpips:0.49758|dists:0.27488|Test loss:22.11851
385
+ 2024-05-06 08:35:41,030 [INFO ] 270|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.50002|msssim:0.35202|lpips:0.49321|dists:0.27390|Test loss:22.13897
386
+ 2024-05-06 08:55:44,311 [INFO ] 271|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.26539|msssim:0.33970|lpips:0.49583|dists:0.27329|Test loss:22.18124
387
+ 2024-05-06 09:15:39,615 [INFO ] 272|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.35656|msssim:0.34859|lpips:0.49631|dists:0.27369|Test loss:22.16513
388
+ 2024-05-06 09:35:34,569 [INFO ] 273|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.37810|msssim:0.34932|lpips:0.49532|dists:0.27358|Test loss:22.14227
389
+ 2024-05-06 09:55:25,915 [INFO ] 274|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.33323|msssim:0.34434|lpips:0.49772|dists:0.27299|Test loss:22.13977
390
+ 2024-05-06 10:15:09,984 [INFO ] 275|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.29836|msssim:0.34814|lpips:0.49551|dists:0.27238|Test loss:22.17381
391
+ 2024-05-06 10:34:55,762 [INFO ] 276|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.41804|msssim:0.34762|lpips:0.49558|dists:0.27152|Test loss:22.13700
392
+ 2024-05-06 10:54:41,271 [INFO ] 277|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.34791|msssim:0.34759|lpips:0.49562|dists:0.27405|Test loss:22.12424
393
+ 2024-05-06 11:14:24,991 [INFO ] 278|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.43791|msssim:0.34807|lpips:0.49478|dists:0.27218|Test loss:22.15800
394
+ 2024-05-06 11:34:06,688 [INFO ] 279|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.37573|msssim:0.34698|lpips:0.49517|dists:0.27491|Test loss:22.17768
395
+ 2024-05-06 11:53:52,546 [INFO ] 280|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.32452|msssim:0.34915|lpips:0.49737|dists:0.27415|Test loss:22.14754
396
+ 2024-05-06 12:13:42,463 [INFO ] 281|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.23457|msssim:0.34536|lpips:0.49823|dists:0.27429|Test loss:22.18300
397
+ 2024-05-06 12:33:26,800 [INFO ] 282|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.32926|msssim:0.34179|lpips:0.49565|dists:0.27257|Test loss:22.12766
398
+ 2024-05-06 12:53:12,784 [INFO ] 283|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.37897|msssim:0.34765|lpips:0.49597|dists:0.27514|Test loss:22.13906
399
+ 2024-05-06 13:13:01,155 [INFO ] 284|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.35453|msssim:0.34555|lpips:0.49322|dists:0.27190|Test loss:22.15379
400
+ 2024-05-06 13:32:55,503 [INFO ] 285|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.25997|msssim:0.34366|lpips:0.49811|dists:0.27320|Test loss:22.14294
401
+ 2024-05-06 13:53:05,337 [INFO ] 286|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.33252|msssim:0.34884|lpips:0.49568|dists:0.27505|Test loss:22.14334
402
+ 2024-05-06 14:12:57,225 [INFO ] 287|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.37169|msssim:0.34353|lpips:0.49783|dists:0.27417|Test loss:22.14881
403
+ 2024-05-06 14:32:49,867 [INFO ] 288|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.28825|msssim:0.34587|lpips:0.49870|dists:0.27497|Test loss:22.16187
404
+ 2024-05-06 14:52:38,558 [INFO ] 289|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.37071|msssim:0.34675|lpips:0.49692|dists:0.27600|Test loss:22.16026
405
+ 2024-05-06 15:12:23,303 [INFO ] 290|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.37343|msssim:0.35109|lpips:0.49530|dists:0.27273|Test loss:22.14378
406
+ 2024-05-06 15:32:11,252 [INFO ] 291|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.37243|msssim:0.34936|lpips:0.49572|dists:0.27396|Test loss:22.17331
407
+ 2024-05-06 15:51:58,068 [INFO ] 292|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.55883|msssim:0.35747|lpips:0.49342|dists:0.27048|Test loss:22.13242
408
+ 2024-05-06 16:11:44,645 [INFO ] 293|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.33236|msssim:0.34500|lpips:0.49713|dists:0.27188|Test loss:22.13761
409
+ 2024-05-06 16:31:32,164 [INFO ] 294|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.37576|msssim:0.35218|lpips:0.49406|dists:0.27047|Test loss:22.14124
410
+ 2024-05-06 16:51:17,715 [INFO ] 295|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.38965|msssim:0.34790|lpips:0.49487|dists:0.27294|Test loss:22.12060
411
+ 2024-05-06 16:52:39,431 [INFO ] 296|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.32517|msssim:0.35096|lpips:0.49530|dists:0.27325|Test loss:22.13534
412
+ 2024-05-06 17:12:41,901 [INFO ] 297|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.38625|msssim:0.35519|lpips:0.49261|dists:0.27088|Test loss:22.13709
413
+ 2024-05-06 17:32:30,223 [INFO ] 298|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.45197|msssim:0.35502|lpips:0.49372|dists:0.27574|Test loss:22.10319
414
+ 2024-05-06 17:53:02,560 [INFO ] 299|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.27800|msssim:0.34538|lpips:0.49640|dists:0.27273|Test loss:22.16596
415
+ 2024-05-06 18:13:17,675 [INFO ] 300|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.43854|msssim:0.35064|lpips:0.49498|dists:0.27317|Test loss:22.13541
416
+ 2024-05-06 18:33:40,420 [INFO ] 301|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.42806|msssim:0.35004|lpips:0.49591|dists:0.27262|Test loss:22.13021
417
+ 2024-05-06 18:54:35,721 [INFO ] 302|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.40629|msssim:0.35309|lpips:0.49433|dists:0.27251|Test loss:22.10585
418
+ 2024-05-06 19:15:23,438 [INFO ] 303|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.43635|msssim:0.35212|lpips:0.49461|dists:0.27331|Test loss:22.09783
419
+ 2024-05-06 19:36:11,827 [INFO ] 304|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.41165|msssim:0.34843|lpips:0.49495|dists:0.27378|Test loss:22.13802
420
+ 2024-05-06 19:56:38,743 [INFO ] 305|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.36958|msssim:0.35154|lpips:0.49564|dists:0.27358|Test loss:22.18889
421
+ 2024-05-06 20:17:02,992 [INFO ] 306|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.36006|msssim:0.34681|lpips:0.49732|dists:0.27467|Test loss:22.16472
422
+ 2024-05-06 20:37:27,146 [INFO ] 307|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.33619|msssim:0.34884|lpips:0.49251|dists:0.27202|Test loss:22.10112
423
+ 2024-05-06 20:57:51,479 [INFO ] 308|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.38734|msssim:0.34704|lpips:0.49590|dists:0.27333|Test loss:22.16504
424
+ 2024-05-06 21:17:40,325 [INFO ] 309|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.47638|msssim:0.34979|lpips:0.49419|dists:0.27257|Test loss:22.09890
425
+ 2024-05-06 21:38:20,823 [INFO ] 310|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.45121|msssim:0.35461|lpips:0.49393|dists:0.27205|Test loss:22.10250
426
+ 2024-05-06 21:58:54,290 [INFO ] 311|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.42236|msssim:0.34856|lpips:0.49429|dists:0.27074|Test loss:22.12690
427
+ 2024-05-06 22:19:28,569 [INFO ] 312|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.42022|msssim:0.34901|lpips:0.49516|dists:0.27319|Test loss:22.14768
428
+ 2024-05-06 22:40:03,891 [INFO ] 313|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.35907|msssim:0.34569|lpips:0.49461|dists:0.27076|Test loss:22.15050
429
+ 2024-05-06 23:00:39,129 [INFO ] 314|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.28936|msssim:0.34797|lpips:0.49777|dists:0.27248|Test loss:22.13753
430
+ 2024-05-06 23:21:22,649 [INFO ] 315|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.29058|msssim:0.35014|lpips:0.49607|dists:0.27623|Test loss:22.14135
431
+ 2024-05-06 23:42:05,520 [INFO ] 316|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.37230|msssim:0.34671|lpips:0.49491|dists:0.27353|Test loss:22.08937
432
+ 2024-05-07 00:02:59,358 [INFO ] 317|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.33447|msssim:0.35346|lpips:0.49475|dists:0.27089|Test loss:22.10210
433
+ 2024-05-07 00:23:07,579 [INFO ] 318|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.49132|msssim:0.35204|lpips:0.49550|dists:0.27231|Test loss:22.11998
434
+ 2024-05-07 00:42:57,724 [INFO ] 319|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.37582|msssim:0.35174|lpips:0.49482|dists:0.27242|Test loss:22.16880
435
+ 2024-05-07 01:02:46,713 [INFO ] 320|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.35898|msssim:0.35062|lpips:0.49379|dists:0.27271|Test loss:22.11210
436
+ 2024-05-07 01:22:57,462 [INFO ] 321|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.32817|msssim:0.35568|lpips:0.49295|dists:0.26980|Test loss:22.14307
437
+ 2024-05-07 01:42:55,818 [INFO ] 322|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.40482|msssim:0.34941|lpips:0.49731|dists:0.27440|Test loss:22.12967
438
+ 2024-05-07 02:03:35,260 [INFO ] 323|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.48860|msssim:0.35103|lpips:0.49498|dists:0.27326|Test loss:22.12212
439
+ 2024-05-07 02:23:46,368 [INFO ] 324|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.41354|msssim:0.35028|lpips:0.49496|dists:0.27339|Test loss:22.11742
440
+ 2024-05-07 02:43:38,446 [INFO ] 325|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.37060|msssim:0.34866|lpips:0.49335|dists:0.27179|Test loss:22.13050
441
+ 2024-05-07 03:03:35,008 [INFO ] 326|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.52411|msssim:0.35067|lpips:0.49280|dists:0.27236|Test loss:22.11754
442
+ 2024-05-07 03:23:27,535 [INFO ] 327|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.39696|msssim:0.35174|lpips:0.49652|dists:0.27232|Test loss:22.12615
443
+ 2024-05-07 03:43:44,363 [INFO ] 328|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.39724|msssim:0.34908|lpips:0.49480|dists:0.27135|Test loss:22.14223
444
+ 2024-05-07 03:45:05,759 [INFO ] 329|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.40106|msssim:0.35324|lpips:0.49468|dists:0.27436|Test loss:22.10712
445
+ 2024-05-07 04:04:47,920 [INFO ] 330|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.32596|msssim:0.34426|lpips:0.49445|dists:0.26987|Test loss:22.13806
446
+ 2024-05-07 04:24:31,360 [INFO ] 331|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.26787|msssim:0.34595|lpips:0.49652|dists:0.27052|Test loss:22.12575
447
+ 2024-05-07 04:44:11,953 [INFO ] 332|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.43838|msssim:0.34961|lpips:0.49307|dists:0.27184|Test loss:22.11797
448
+ 2024-05-07 05:03:51,738 [INFO ] 333|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.42003|msssim:0.35362|lpips:0.49367|dists:0.27208|Test loss:22.11651
449
+ 2024-05-07 05:23:34,532 [INFO ] 334|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.37144|msssim:0.35248|lpips:0.49673|dists:0.27496|Test loss:22.14811
450
+ 2024-05-07 05:43:17,906 [INFO ] 335|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.41272|msssim:0.35291|lpips:0.49492|dists:0.27193|Test loss:22.12853
451
+ 2024-05-07 06:03:26,128 [INFO ] 336|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.31984|msssim:0.34795|lpips:0.49517|dists:0.27108|Test loss:22.11144
452
+ 2024-05-07 06:23:14,028 [INFO ] 337|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.35224|msssim:0.35102|lpips:0.49385|dists:0.27123|Test loss:22.12615
453
+ 2024-05-07 06:43:29,005 [INFO ] 338|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.34989|msssim:0.34449|lpips:0.49436|dists:0.27227|Test loss:22.14807
454
+ 2024-05-07 07:03:22,585 [INFO ] 339|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.35900|msssim:0.35273|lpips:0.49441|dists:0.27247|Test loss:22.11002
455
+ 2024-05-07 07:23:32,865 [INFO ] 340|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.43373|msssim:0.35449|lpips:0.49218|dists:0.27258|Test loss:22.09663
456
+ 2024-05-07 07:44:08,078 [INFO ] 341|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.49582|msssim:0.35435|lpips:0.49275|dists:0.27220|Test loss:22.09673
457
+ 2024-05-07 08:05:12,924 [INFO ] 342|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.27418|msssim:0.35315|lpips:0.49452|dists:0.27191|Test loss:22.11007
458
+ 2024-05-07 08:26:23,259 [INFO ] 343|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.35699|msssim:0.34731|lpips:0.49411|dists:0.27074|Test loss:22.13691
459
+ 2024-05-07 08:47:40,259 [INFO ] 344|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.43843|msssim:0.35445|lpips:0.49354|dists:0.27134|Test loss:22.09072
460
+ 2024-05-07 09:08:57,546 [INFO ] 345|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.40076|msssim:0.35221|lpips:0.49355|dists:0.27157|Test loss:22.12100
461
+ 2024-05-07 09:30:12,427 [INFO ] 346|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.36718|msssim:0.35203|lpips:0.49355|dists:0.27277|Test loss:22.08085
462
+ 2024-05-07 09:50:36,931 [INFO ] 347|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.34232|msssim:0.34843|lpips:0.49354|dists:0.27177|Test loss:22.12479
463
+ 2024-05-07 10:10:45,355 [INFO ] 348|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.42333|msssim:0.34844|lpips:0.49300|dists:0.27182|Test loss:22.11209
464
+ 2024-05-07 10:30:50,802 [INFO ] 349|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.42717|msssim:0.35086|lpips:0.49458|dists:0.27253|Test loss:22.11386
465
+ 2024-05-07 10:50:55,389 [INFO ] 350|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.35927|msssim:0.34962|lpips:0.49381|dists:0.26980|Test loss:22.10434
466
+ 2024-05-07 11:11:33,800 [INFO ] 351|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.29633|msssim:0.34823|lpips:0.49724|dists:0.27674|Test loss:22.14577
467
+ 2024-05-07 11:32:25,539 [INFO ] 352|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.36162|msssim:0.35235|lpips:0.49426|dists:0.26995|Test loss:22.10663
468
+ 2024-05-07 11:53:24,658 [INFO ] 353|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.37880|msssim:0.35225|lpips:0.49382|dists:0.26963|Test loss:22.11384
469
+ 2024-05-07 12:14:36,359 [INFO ] 354|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.40548|msssim:0.34668|lpips:0.49391|dists:0.27245|Test loss:22.14980
470
+ 2024-05-07 12:35:38,612 [INFO ] 355|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.42286|msssim:0.35354|lpips:0.49382|dists:0.27242|Test loss:22.12479
471
+ 2024-05-07 12:56:38,507 [INFO ] 356|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.33143|msssim:0.35170|lpips:0.49645|dists:0.27390|Test loss:22.13347
472
+ 2024-05-07 13:17:39,596 [INFO ] 357|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.36672|msssim:0.34573|lpips:0.49592|dists:0.27143|Test loss:22.13349
473
+ 2024-05-07 13:38:44,893 [INFO ] 358|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.42944|msssim:0.35298|lpips:0.49284|dists:0.27126|Test loss:22.12769
474
+ 2024-05-07 13:59:43,389 [INFO ] 359|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.38038|msssim:0.35037|lpips:0.49457|dists:0.27234|Test loss:22.09782
475
+ 2024-05-07 14:20:17,253 [INFO ] 360|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.38950|msssim:0.35237|lpips:0.49389|dists:0.27129|Test loss:22.09906
476
+ 2024-05-07 14:40:43,289 [INFO ] 361|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.43936|msssim:0.34640|lpips:0.49196|dists:0.27161|Test loss:22.11428
477
+ 2024-05-07 14:42:07,243 [INFO ] 362|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.40003|msssim:0.35458|lpips:0.49361|dists:0.27126|Test loss:22.08922
478
+ 2024-05-07 15:02:21,429 [INFO ] 363|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.48445|msssim:0.35553|lpips:0.49189|dists:0.27160|Test loss:22.08326
479
+ 2024-05-07 15:22:34,155 [INFO ] 364|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.35317|msssim:0.35344|lpips:0.49614|dists:0.27294|Test loss:22.15190
480
+ 2024-05-07 15:42:52,603 [INFO ] 365|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.42529|msssim:0.35113|lpips:0.49577|dists:0.27466|Test loss:22.11909
481
+ 2024-05-07 16:03:38,749 [INFO ] 366|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.33904|msssim:0.35202|lpips:0.49448|dists:0.27241|Test loss:22.12104
482
+ 2024-05-07 16:24:24,975 [INFO ] 367|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.42592|msssim:0.35292|lpips:0.49367|dists:0.27379|Test loss:22.11291
483
+ 2024-05-07 16:44:56,650 [INFO ] 368|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.39894|msssim:0.35064|lpips:0.49435|dists:0.27276|Test loss:22.12427
484
+ 2024-05-07 17:05:25,010 [INFO ] 369|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.48808|msssim:0.35531|lpips:0.49202|dists:0.27039|Test loss:22.11523
485
+ 2024-05-07 17:25:58,021 [INFO ] 370|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.39756|msssim:0.35016|lpips:0.49330|dists:0.27168|Test loss:22.10594
486
+ 2024-05-07 17:46:18,607 [INFO ] 371|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.49306|msssim:0.35298|lpips:0.49468|dists:0.27251|Test loss:22.11654
487
+ 2024-05-07 18:06:49,127 [INFO ] 372|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.34533|msssim:0.35142|lpips:0.49453|dists:0.27251|Test loss:22.12963
488
+ 2024-05-07 18:27:08,641 [INFO ] 373|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.34550|msssim:0.35208|lpips:0.49310|dists:0.27301|Test loss:22.11551
489
+ 2024-05-07 18:47:50,057 [INFO ] 374|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.37816|msssim:0.35055|lpips:0.49417|dists:0.27248|Test loss:22.10346
490
+ 2024-05-07 19:08:22,282 [INFO ] 375|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.36707|msssim:0.34627|lpips:0.49279|dists:0.27090|Test loss:22.07935
491
+ 2024-05-07 19:29:07,038 [INFO ] 376|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.42175|msssim:0.35269|lpips:0.49278|dists:0.27011|Test loss:22.07009
492
+ 2024-05-07 19:49:50,578 [INFO ] 377|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.40540|msssim:0.35338|lpips:0.49272|dists:0.27135|Test loss:22.12318
493
+ 2024-05-07 20:10:17,385 [INFO ] 378|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.39138|msssim:0.34611|lpips:0.49605|dists:0.27153|Test loss:22.13597
494
+ 2024-05-07 20:30:46,458 [INFO ] 379|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.39260|msssim:0.34934|lpips:0.49563|dists:0.27480|Test loss:22.14463
495
+ 2024-05-07 20:51:11,028 [INFO ] 380|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.42045|msssim:0.34795|lpips:0.49372|dists:0.26907|Test loss:22.10082
496
+ 2024-05-07 21:11:40,325 [INFO ] 381|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.27697|msssim:0.34880|lpips:0.49537|dists:0.27470|Test loss:22.09123
497
+ 2024-05-07 21:32:25,704 [INFO ] 382|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.35365|msssim:0.34741|lpips:0.49410|dists:0.27204|Test loss:22.12331
498
+ 2024-05-07 21:53:04,306 [INFO ] 383|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.28105|msssim:0.35066|lpips:0.49763|dists:0.27276|Test loss:22.13340
499
+ 2024-05-07 22:13:43,291 [INFO ] 384|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.31601|msssim:0.34885|lpips:0.49536|dists:0.27251|Test loss:22.12415
500
+ 2024-05-07 22:34:19,100 [INFO ] 385|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.34422|msssim:0.35122|lpips:0.49521|dists:0.27288|Test loss:22.14277
501
+ 2024-05-07 22:54:56,647 [INFO ] 386|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.41440|msssim:0.34890|lpips:0.49383|dists:0.27097|Test loss:22.10782
502
+ 2024-05-07 23:15:36,868 [INFO ] 387|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.45614|msssim:0.35336|lpips:0.49236|dists:0.27192|Test loss:22.09378
503
+ 2024-05-07 23:36:07,945 [INFO ] 388|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.38647|msssim:0.34605|lpips:0.49631|dists:0.27383|Test loss:22.12287
504
+ 2024-05-07 23:56:49,417 [INFO ] 389|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.38224|msssim:0.35091|lpips:0.49457|dists:0.27112|Test loss:22.07871
505
+ 2024-05-08 00:17:12,737 [INFO ] 390|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.43890|msssim:0.35296|lpips:0.49275|dists:0.27223|Test loss:22.13008
506
+ 2024-05-08 00:37:32,168 [INFO ] 391|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.35077|msssim:0.34966|lpips:0.49267|dists:0.27072|Test loss:22.13726
507
+ 2024-05-08 00:58:00,221 [INFO ] 392|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.43491|msssim:0.34659|lpips:0.49620|dists:0.27359|Test loss:22.16497
508
+ 2024-05-08 01:18:28,620 [INFO ] 393|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.46330|msssim:0.35150|lpips:0.49221|dists:0.27154|Test loss:22.10838
509
+ 2024-05-08 01:39:01,279 [INFO ] 394|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.48316|msssim:0.35406|lpips:0.49092|dists:0.26870|Test loss:22.06746
510
+ 2024-05-08 01:40:42,930 [INFO ] 395|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.39606|msssim:0.35099|lpips:0.49426|dists:0.27250|Test loss:22.11529
511
+ 2024-05-08 02:00:57,924 [INFO ] 396|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.37678|msssim:0.35195|lpips:0.49500|dists:0.27399|Test loss:22.10022
512
+ 2024-05-08 02:21:26,695 [INFO ] 397|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.36464|msssim:0.35303|lpips:0.49387|dists:0.27059|Test loss:22.10485
513
+ 2024-05-08 02:42:04,184 [INFO ] 398|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.42602|msssim:0.34833|lpips:0.49350|dists:0.27095|Test loss:22.11267
514
+ 2024-05-08 03:02:23,282 [INFO ] 399|bpp:0.01875|mask:0.00888|mask_ratio:0.75|psnr:12.46092|msssim:0.34956|lpips:0.49528|dists:0.27270|Test loss:22.10216
MIM_vbr/events.out.tfevents.1714742678.t2vg-a100-G4-10.4118013.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:03724613c625203caa59f67da24b894285dbf610923f2262214ffab63cca8c59
3
+ size 4582056
MIM_vbr/log.txt ADDED
The diff for this file is too large to render. See raw diff
 
MIM_vbr_eval_lossless/20240518_044257.log ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2024-05-18 04:42:57,952 [INFO ] Logging file is /home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr_eval_lossless//20240518_044257.log
2
+ 2024-05-18 04:42:57,952 [INFO ] ======================= exp =======================
3
+ 2024-05-18 04:42:57,952 [INFO ] config:config/vpt_default.yaml
4
+ 2024-05-18 04:42:57,952 [INFO ] name:exp
5
+ 2024-05-18 04:42:57,952 [INFO ] lr:None
6
+ 2024-05-18 04:42:57,952 [INFO ] world_size:1
7
+ 2024-05-18 04:42:57,952 [INFO ] local_rank:-1
8
+ 2024-05-18 04:42:57,952 [INFO ] dist_on_itp:False
9
+ 2024-05-18 04:42:57,952 [INFO ] dist_url:env://
10
+ 2024-05-18 04:42:57,952 [INFO ] TEST:False
11
+ 2024-05-18 04:42:57,952 [INFO ] model:mage_vit_base_patch16
12
+ 2024-05-18 04:42:57,953 [INFO ] dataset:imagenet
13
+ 2024-05-18 04:42:57,953 [INFO ] dataset_path:/home/t2vg-a100-G4-10/project/qyp/datasets/imagenet
14
+ 2024-05-18 04:42:57,953 [INFO ] eval_path:
15
+ 2024-05-18 04:42:57,953 [INFO ] batch_size:64
16
+ 2024-05-18 04:42:57,953 [INFO ] test_batch_size:200
17
+ 2024-05-18 04:42:57,953 [INFO ] num_workers:4
18
+ 2024-05-18 04:42:57,953 [INFO ] epochs:400
19
+ 2024-05-18 04:42:57,953 [INFO ] blr:0.0002
20
+ 2024-05-18 04:42:57,953 [INFO ] min_lr:5e-06
21
+ 2024-05-18 04:42:57,953 [INFO ] accum_iter:1
22
+ 2024-05-18 04:42:57,953 [INFO ] input_size:256
23
+ 2024-05-18 04:42:57,953 [INFO ] weight_decay:0.05
24
+ 2024-05-18 04:42:57,953 [INFO ] warmup_epochs:20
25
+ 2024-05-18 04:42:57,953 [INFO ] mask_ratio_min:0.5
26
+ 2024-05-18 04:42:57,953 [INFO ] mask_ratio_max:0.9
27
+ 2024-05-18 04:42:57,953 [INFO ] mask_ratio_mu:0.55
28
+ 2024-05-18 04:42:57,953 [INFO ] mask_ratio_std:0.25
29
+ 2024-05-18 04:42:57,953 [INFO ] grad_clip:3.0
30
+ 2024-05-18 04:42:57,953 [INFO ] device:cuda
31
+ 2024-05-18 04:42:57,953 [INFO ] distributed:True
32
+ 2024-05-18 04:42:57,953 [INFO ] seed:0
33
+ 2024-05-18 04:42:57,953 [INFO ] resume:/home/t2vg-a100-G4-10/project/qyp/lossless/checkpoint-330.pth
34
+ 2024-05-18 04:42:57,953 [INFO ] start_epoch:0
35
+ 2024-05-18 04:42:57,953 [INFO ] pin_mem:True
36
+ 2024-05-18 04:42:57,953 [INFO ] save:True
37
+ 2024-05-18 04:42:57,953 [INFO ] exp_name:MIM_vbr_eval_lossless
38
+ 2024-05-18 04:42:57,954 [INFO ] root:/home/t2vg-a100-G4-10/project/qyp/mimc_rope
39
+ 2024-05-18 04:42:57,954 [INFO ] output_dir:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr_eval_lossless
40
+ 2024-05-18 04:42:57,954 [INFO ] log_dir:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr_eval_lossless
41
+ 2024-05-18 04:42:57,954 [INFO ] strart_epoch:0
42
+ 2024-05-18 04:42:57,954 [INFO ] cuda:True
43
+ 2024-05-18 04:42:57,954 [INFO ] rank:0
44
+ 2024-05-18 04:42:57,954 [INFO ] gpu:0
45
+ 2024-05-18 04:42:57,954 [INFO ] dist_backend:nccl
46
+ 2024-05-18 04:42:57,954 [INFO ] ===================================================
47
+ 2024-05-18 06:20:16,116 [INFO ] -1|bpp:0.02418|mask:0.00847|mask_ratio:0.6|psnr:14.80285|msssim:0.51750|lpips:0.27254|dists:0.19907|accu:0.63626|Test loss:64.62691
48
+ 2024-05-18 07:56:33,739 [INFO ] -1|bpp:0.02791|mask:0.00819|mask_ratio:0.5|psnr:15.67397|msssim:0.56751|lpips:0.23728|dists:0.18073|accu:0.67258|Test loss:139.71758
49
+ 2024-05-18 09:30:08,858 [INFO ] -1|bpp:0.02046|mask:0.00875|mask_ratio:0.7|psnr:13.74718|msssim:0.45169|lpips:0.32138|dists:0.22512|accu:0.56262|Test loss:30.11541
50
+ 2024-05-18 11:05:40,441 [INFO ] -1|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:13.17649|msssim:0.41227|lpips:0.35134|dists:0.24154|accu:0.50616|Test loss:21.23853
51
+ 2024-05-18 12:36:08,830 [INFO ] -1|bpp:0.03521|mask:0.00764|mask_ratio:0.3|psnr:17.07054|msssim:0.63998|lpips:0.18754|dists:0.15498|accu:0.70446|Test loss:643.00879
52
+ 2024-05-18 14:00:50,675 [INFO ] -1|bpp:0.04252|mask:0.00709|mask_ratio:0.1|psnr:18.25655|msssim:0.69397|lpips:0.15150|dists:0.13635|accu:0.72014|Test loss:2993.12524
53
+ 2024-05-18 15:25:33,426 [INFO ] -1|bpp:0.04581|mask:0.00684|mask_ratio:0.01|psnr:18.75086|msssim:0.71406|lpips:0.13808|dists:0.12934|accu:0.72174|Test loss:6007.86475
MIM_vbr_eval_lossless/20240521_120210.log ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2024-05-21 12:02:10,417 [INFO ] Logging file is /home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr_eval_lossless//20240521_120210.log
2
+ 2024-05-21 12:02:10,418 [INFO ] ======================= exp =======================
3
+ 2024-05-21 12:02:10,418 [INFO ] config:config/vpt_default.yaml
4
+ 2024-05-21 12:02:10,418 [INFO ] name:exp
5
+ 2024-05-21 12:02:10,418 [INFO ] lr:None
6
+ 2024-05-21 12:02:10,418 [INFO ] world_size:1
7
+ 2024-05-21 12:02:10,418 [INFO ] local_rank:-1
8
+ 2024-05-21 12:02:10,418 [INFO ] dist_on_itp:False
9
+ 2024-05-21 12:02:10,418 [INFO ] dist_url:env://
10
+ 2024-05-21 12:02:10,418 [INFO ] TEST:False
11
+ 2024-05-21 12:02:10,418 [INFO ] model:mage_vit_base_patch16
12
+ 2024-05-21 12:02:10,418 [INFO ] dataset:imagenet
13
+ 2024-05-21 12:02:10,418 [INFO ] dataset_path:/home/t2vg-a100-G4-10/project/qyp/datasets/imagenet
14
+ 2024-05-21 12:02:10,418 [INFO ] kodak_path:/home/t2vg-a100-G4-10/project/qyp/datasets/kodak/
15
+ 2024-05-21 12:02:10,418 [INFO ] eval_path:
16
+ 2024-05-21 12:02:10,418 [INFO ] batch_size:64
17
+ 2024-05-21 12:02:10,418 [INFO ] test_batch_size:200
18
+ 2024-05-21 12:02:10,418 [INFO ] num_workers:4
19
+ 2024-05-21 12:02:10,418 [INFO ] epochs:400
20
+ 2024-05-21 12:02:10,418 [INFO ] blr:0.0002
21
+ 2024-05-21 12:02:10,419 [INFO ] min_lr:5e-06
22
+ 2024-05-21 12:02:10,419 [INFO ] accum_iter:1
23
+ 2024-05-21 12:02:10,419 [INFO ] input_size:256
24
+ 2024-05-21 12:02:10,419 [INFO ] weight_decay:0.05
25
+ 2024-05-21 12:02:10,419 [INFO ] warmup_epochs:20
26
+ 2024-05-21 12:02:10,419 [INFO ] mask_ratio_min:0.5
27
+ 2024-05-21 12:02:10,419 [INFO ] mask_ratio_max:0.9
28
+ 2024-05-21 12:02:10,419 [INFO ] mask_ratio_mu:0.55
29
+ 2024-05-21 12:02:10,419 [INFO ] mask_ratio_std:0.25
30
+ 2024-05-21 12:02:10,419 [INFO ] grad_clip:3.0
31
+ 2024-05-21 12:02:10,419 [INFO ] device:cuda
32
+ 2024-05-21 12:02:10,419 [INFO ] distributed:True
33
+ 2024-05-21 12:02:10,419 [INFO ] seed:0
34
+ 2024-05-21 12:02:10,419 [INFO ] resume:/home/t2vg-a100-G4-10/project/qyp/lossless/checkpoint-330.pth
35
+ 2024-05-21 12:02:10,419 [INFO ] start_epoch:0
36
+ 2024-05-21 12:02:10,419 [INFO ] pin_mem:True
37
+ 2024-05-21 12:02:10,419 [INFO ] save:True
38
+ 2024-05-21 12:02:10,419 [INFO ] exp_name:MIM_vbr_eval_lossless
39
+ 2024-05-21 12:02:10,419 [INFO ] root:/home/t2vg-a100-G4-10/project/qyp/mimc_rope
40
+ 2024-05-21 12:02:10,419 [INFO ] output_dir:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr_eval_lossless
41
+ 2024-05-21 12:02:10,419 [INFO ] log_dir:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr_eval_lossless
42
+ 2024-05-21 12:02:10,419 [INFO ] strart_epoch:0
43
+ 2024-05-21 12:02:10,420 [INFO ] cuda:True
44
+ 2024-05-21 12:02:10,420 [INFO ] rank:0
45
+ 2024-05-21 12:02:10,420 [INFO ] gpu:0
46
+ 2024-05-21 12:02:10,420 [INFO ] dist_backend:nccl
47
+ 2024-05-21 12:02:10,420 [INFO ] ===================================================
48
+ 2024-05-21 12:03:00,851 [INFO ] -1|bpp:0.00784|mask:0.00908|mask_ratio:0.8|psnr:13.82288|msssim:0.34806|lpips:0.42462|dists:0.27493|Test loss:15.10523
49
+ 2024-05-21 12:03:41,020 [INFO ] -1|bpp:0.00987|mask:0.00891|mask_ratio:0.75|psnr:14.37770|msssim:0.39104|lpips:0.38284|dists:0.26121|Test loss:22.23518
50
+ 2024-05-21 12:04:19,453 [INFO ] -1|bpp:0.01170|mask:0.00879|mask_ratio:0.7|psnr:15.33662|msssim:0.45445|lpips:0.34781|dists:0.23473|Test loss:31.04415
51
+ 2024-05-21 12:05:01,263 [INFO ] -1|bpp:0.01369|mask:0.00867|mask_ratio:0.65|psnr:15.06704|msssim:0.46401|lpips:0.32283|dists:0.22039|Test loss:45.53056
52
+ 2024-05-21 12:05:46,943 [INFO ] -1|bpp:0.01568|mask:0.00854|mask_ratio:0.6|psnr:16.10357|msssim:0.51571|lpips:0.28952|dists:0.20649|Test loss:66.32143
MIM_vbr_eval_lossless/20240521_121146.log ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2024-05-21 12:11:46,727 [INFO ] Logging file is /home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr_eval_lossless//20240521_121146.log
2
+ 2024-05-21 12:11:46,727 [INFO ] ======================= exp =======================
3
+ 2024-05-21 12:11:46,727 [INFO ] config:config/vpt_default.yaml
4
+ 2024-05-21 12:11:46,727 [INFO ] name:exp
5
+ 2024-05-21 12:11:46,727 [INFO ] lr:None
6
+ 2024-05-21 12:11:46,727 [INFO ] world_size:1
7
+ 2024-05-21 12:11:46,727 [INFO ] local_rank:-1
8
+ 2024-05-21 12:11:46,727 [INFO ] dist_on_itp:False
9
+ 2024-05-21 12:11:46,727 [INFO ] dist_url:env://
10
+ 2024-05-21 12:11:46,727 [INFO ] TEST:False
11
+ 2024-05-21 12:11:46,727 [INFO ] model:mage_vit_base_patch16
12
+ 2024-05-21 12:11:46,728 [INFO ] dataset:imagenet
13
+ 2024-05-21 12:11:46,728 [INFO ] dataset_path:/home/t2vg-a100-G4-10/project/qyp/datasets/imagenet
14
+ 2024-05-21 12:11:46,728 [INFO ] kodak_path:/home/t2vg-a100-G4-10/project/qyp/datasets/kodak/
15
+ 2024-05-21 12:11:46,728 [INFO ] eval_path:
16
+ 2024-05-21 12:11:46,728 [INFO ] batch_size:64
17
+ 2024-05-21 12:11:46,728 [INFO ] test_batch_size:200
18
+ 2024-05-21 12:11:46,728 [INFO ] num_workers:4
19
+ 2024-05-21 12:11:46,728 [INFO ] epochs:400
20
+ 2024-05-21 12:11:46,728 [INFO ] blr:0.0002
21
+ 2024-05-21 12:11:46,728 [INFO ] min_lr:5e-06
22
+ 2024-05-21 12:11:46,728 [INFO ] accum_iter:1
23
+ 2024-05-21 12:11:46,728 [INFO ] input_size:256
24
+ 2024-05-21 12:11:46,728 [INFO ] weight_decay:0.05
25
+ 2024-05-21 12:11:46,728 [INFO ] warmup_epochs:20
26
+ 2024-05-21 12:11:46,728 [INFO ] mask_ratio_min:0.5
27
+ 2024-05-21 12:11:46,728 [INFO ] mask_ratio_max:0.9
28
+ 2024-05-21 12:11:46,728 [INFO ] mask_ratio_mu:0.55
29
+ 2024-05-21 12:11:46,728 [INFO ] mask_ratio_std:0.25
30
+ 2024-05-21 12:11:46,728 [INFO ] grad_clip:3.0
31
+ 2024-05-21 12:11:46,728 [INFO ] device:cuda
32
+ 2024-05-21 12:11:46,729 [INFO ] distributed:True
33
+ 2024-05-21 12:11:46,729 [INFO ] seed:0
34
+ 2024-05-21 12:11:46,729 [INFO ] resume:/home/t2vg-a100-G4-10/project/qyp/lossless/checkpoint-330.pth
35
+ 2024-05-21 12:11:46,729 [INFO ] start_epoch:0
36
+ 2024-05-21 12:11:46,729 [INFO ] pin_mem:True
37
+ 2024-05-21 12:11:46,729 [INFO ] save:True
38
+ 2024-05-21 12:11:46,729 [INFO ] exp_name:MIM_vbr_eval_lossless
39
+ 2024-05-21 12:11:46,729 [INFO ] root:/home/t2vg-a100-G4-10/project/qyp/mimc_rope
40
+ 2024-05-21 12:11:46,729 [INFO ] output_dir:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr_eval_lossless
41
+ 2024-05-21 12:11:46,729 [INFO ] log_dir:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr_eval_lossless
42
+ 2024-05-21 12:11:46,729 [INFO ] strart_epoch:0
43
+ 2024-05-21 12:11:46,729 [INFO ] cuda:True
44
+ 2024-05-21 12:11:46,729 [INFO ] rank:0
45
+ 2024-05-21 12:11:46,729 [INFO ] gpu:0
46
+ 2024-05-21 12:11:46,729 [INFO ] dist_backend:nccl
47
+ 2024-05-21 12:11:46,729 [INFO ] ===================================================
48
+ 2024-05-21 12:12:38,993 [INFO ] -1|bpp:0.00784|mask:0.00908|mask_ratio:0.8|psnr:13.82276|msssim:0.34805|lpips:0.42462|dists:0.27493|Test loss:15.10523
49
+ 2024-05-21 12:13:18,373 [INFO ] -1|bpp:0.00987|mask:0.00891|mask_ratio:0.75|psnr:14.37755|msssim:0.39104|lpips:0.38284|dists:0.26121|Test loss:22.23518
50
+ 2024-05-21 12:13:58,355 [INFO ] -1|bpp:0.01170|mask:0.00879|mask_ratio:0.7|psnr:15.33651|msssim:0.45445|lpips:0.34781|dists:0.23473|Test loss:31.04415
51
+ 2024-05-21 12:14:38,346 [INFO ] -1|bpp:0.01369|mask:0.00867|mask_ratio:0.65|psnr:15.06691|msssim:0.46401|lpips:0.32283|dists:0.22039|Test loss:45.53056
52
+ 2024-05-21 12:15:17,341 [INFO ] -1|bpp:0.01568|mask:0.00854|mask_ratio:0.6|psnr:16.10348|msssim:0.51571|lpips:0.28952|dists:0.20649|Test loss:66.32143
53
+ 2024-05-21 12:15:59,518 [INFO ] -1|bpp:0.01769|mask:0.00842|mask_ratio:0.55|psnr:16.66886|msssim:0.54068|lpips:0.27387|dists:0.19902|Test loss:97.59130
54
+ 2024-05-21 12:16:43,259 [INFO ] -1|bpp:0.01969|mask:0.00830|mask_ratio:0.5|psnr:17.05398|msssim:0.55989|lpips:0.25349|dists:0.18627|Test loss:143.87851
55
+ 2024-05-21 12:17:21,314 [INFO ] -1|bpp:0.02154|mask:0.00818|mask_ratio:0.45|psnr:17.42430|msssim:0.58354|lpips:0.24210|dists:0.17977|Test loss:203.66936
56
+ 2024-05-21 12:18:04,417 [INFO ] -1|bpp:0.02357|mask:0.00794|mask_ratio:0.4|psnr:17.89758|msssim:0.60007|lpips:0.22977|dists:0.17287|Test loss:304.47241
57
+ 2024-05-21 12:18:46,874 [INFO ] -1|bpp:0.02555|mask:0.00781|mask_ratio:0.35|psnr:18.47392|msssim:0.62136|lpips:0.21420|dists:0.16763|Test loss:447.87213
58
+ 2024-05-21 12:19:30,648 [INFO ] -1|bpp:0.02756|mask:0.00769|mask_ratio:0.3|psnr:18.70894|msssim:0.64106|lpips:0.19699|dists:0.16134|Test loss:662.32788
59
+ 2024-05-21 12:20:10,918 [INFO ] -1|bpp:0.02956|mask:0.00757|mask_ratio:0.25|psnr:18.98025|msssim:0.65567|lpips:0.18680|dists:0.15801|Test loss:983.28394
60
+ 2024-05-21 12:20:50,854 [INFO ] -1|bpp:0.03138|mask:0.00745|mask_ratio:0.2|psnr:19.34285|msssim:0.66431|lpips:0.18044|dists:0.15229|Test loss:1385.45557
61
+ 2024-05-21 12:21:32,470 [INFO ] -1|bpp:0.03339|mask:0.00732|mask_ratio:0.15|psnr:19.71813|msssim:0.68537|lpips:0.16575|dists:0.14260|Test loss:2053.51758
62
+ 2024-05-21 12:22:12,619 [INFO ] -1|bpp:0.03539|mask:0.00720|mask_ratio:0.1|psnr:20.06234|msssim:0.69720|lpips:0.15856|dists:0.14339|Test loss:3017.49292
63
+ 2024-05-21 12:22:51,448 [INFO ] -1|bpp:0.03739|mask:0.00703|mask_ratio:0.05|psnr:20.41125|msssim:0.71023|lpips:0.15008|dists:0.14040|Test loss:4599.72754
VQ_w_Prompt/20240602_123240.log ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2024-06-02 12:32:40,143 [INFO ] Logging file is /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt//20240602_123240.log
2
+ 2024-06-02 12:32:40,143 [INFO ] ======================= exp =======================
3
+ 2024-06-02 12:32:40,143 [INFO ] config:config/vpt_default.yaml
4
+ 2024-06-02 12:32:40,143 [INFO ] name:exp
5
+ 2024-06-02 12:32:40,144 [INFO ] lr:None
6
+ 2024-06-02 12:32:40,144 [INFO ] world_size:1
7
+ 2024-06-02 12:32:40,144 [INFO ] local_rank:-1
8
+ 2024-06-02 12:32:40,144 [INFO ] dist_on_itp:False
9
+ 2024-06-02 12:32:40,144 [INFO ] dist_url:env://
10
+ 2024-06-02 12:32:40,144 [INFO ] TEST:False
11
+ 2024-06-02 12:32:40,144 [INFO ] model:mage_vit_base_patch16
12
+ 2024-06-02 12:32:40,144 [INFO ] dataset:coco
13
+ 2024-06-02 12:32:40,144 [INFO ] dataset_path:/home/t2vg-a100-G4-10/project/qyp/datasets/COCO
14
+ 2024-06-02 12:32:40,144 [INFO ] kodak_path:/home/t2vg-a100-G4-10/project/qyp/datasets/kodak/
15
+ 2024-06-02 12:32:40,144 [INFO ] eval_path:
16
+ 2024-06-02 12:32:40,144 [INFO ] batch_size:128
17
+ 2024-06-02 12:32:40,144 [INFO ] learning_rate:0.0001
18
+ 2024-06-02 12:32:40,144 [INFO ] test_batch_size:1
19
+ 2024-06-02 12:32:40,144 [INFO ] num_workers:4
20
+ 2024-06-02 12:32:40,144 [INFO ] epochs:100
21
+ 2024-06-02 12:32:40,144 [INFO ] blr:0.00015
22
+ 2024-06-02 12:32:40,144 [INFO ] min_lr:5e-06
23
+ 2024-06-02 12:32:40,144 [INFO ] accum_iter:1
24
+ 2024-06-02 12:32:40,144 [INFO ] input_size:256
25
+ 2024-06-02 12:32:40,144 [INFO ] weight_decay:0.05
26
+ 2024-06-02 12:32:40,144 [INFO ] warmup_epochs:10
27
+ 2024-06-02 12:32:40,144 [INFO ] mask_ratio_min:0.5
28
+ 2024-06-02 12:32:40,144 [INFO ] mask_ratio_max:0.9
29
+ 2024-06-02 12:32:40,144 [INFO ] mask_ratio_mu:0.55
30
+ 2024-06-02 12:32:40,145 [INFO ] mask_ratio_std:0.25
31
+ 2024-06-02 12:32:40,145 [INFO ] grad_clip:3.0
32
+ 2024-06-02 12:32:40,145 [INFO ] device:cuda
33
+ 2024-06-02 12:32:40,145 [INFO ] distributed:False
34
+ 2024-06-02 12:32:40,145 [INFO ] seed:0
35
+ 2024-06-02 12:32:40,145 [INFO ] gpu_id:0,1
36
+ 2024-06-02 12:32:40,145 [INFO ] resume:/home/t2vg-a100-G4-10/project/qyp/lossless/checkpoint-330.pth
37
+ 2024-06-02 12:32:40,145 [INFO ] maskrcnn_path:/home/t2vg-a100-G4-10/project/qyp/ckpt/detectron2/mask_rcnn_r50_fpn_3x.pkl
38
+ 2024-06-02 12:32:40,145 [INFO ] start_epoch:0
39
+ 2024-06-02 12:32:40,145 [INFO ] pin_mem:True
40
+ 2024-06-02 12:32:40,145 [INFO ] save:True
41
+ 2024-06-02 12:32:40,145 [INFO ] exp_name:VQ_w_Prompt
42
+ 2024-06-02 12:32:40,145 [INFO ] root:/home/t2vg-a100-G4-10/project/qyp/mimc_rope
43
+ 2024-06-02 12:32:40,145 [INFO ] output_dir:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt
44
+ 2024-06-02 12:32:40,145 [INFO ] log_dir:/home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt
45
+ 2024-06-02 12:32:40,145 [INFO ] strart_epoch:0
46
+ 2024-06-02 12:32:40,145 [INFO ] cuda:True
47
+ 2024-06-02 12:32:40,145 [INFO ] ===================================================
48
+ 2024-06-02 12:36:03,409 [INFO ] Epoch: -1 | perc loss: 0.33139 | psnr: 21.90954 | lpips: 0.22209 | dists: 0.09723
49
+ 2024-06-02 12:55:03,529 [INFO ] Epoch: 0 | perc loss: 0.33168 | psnr: 21.83200 | lpips: 0.22776 | dists: 0.11052
50
+ 2024-06-02 12:55:03,533 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
51
+ 2024-06-02 12:55:03,812 [INFO ] Saving BEST checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
52
+ 2024-06-02 13:13:51,301 [INFO ] Epoch: 1 | perc loss: 0.33369 | psnr: 21.56722 | lpips: 0.22748 | dists: 0.10901
53
+ 2024-06-02 13:13:51,305 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
54
+ 2024-06-02 13:32:37,056 [INFO ] Epoch: 2 | perc loss: 0.32811 | psnr: 21.75561 | lpips: 0.22586 | dists: 0.10543
55
+ 2024-06-02 13:32:37,060 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
56
+ 2024-06-02 13:32:37,942 [INFO ] Saving BEST checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
57
+ 2024-06-02 13:51:19,223 [INFO ] Epoch: 3 | perc loss: 0.33057 | psnr: 21.71532 | lpips: 0.22620 | dists: 0.10827
58
+ 2024-06-02 13:51:19,227 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
59
+ 2024-06-02 14:10:08,103 [INFO ] Epoch: 4 | perc loss: 0.33012 | psnr: 21.70657 | lpips: 0.22693 | dists: 0.10772
60
+ 2024-06-02 14:10:08,108 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
61
+ 2024-06-02 14:28:55,372 [INFO ] Epoch: 5 | perc loss: 0.33078 | psnr: 21.71807 | lpips: 0.22735 | dists: 0.10827
62
+ 2024-06-02 14:28:55,376 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
63
+ 2024-06-02 14:47:49,287 [INFO ] Epoch: 6 | perc loss: 0.33090 | psnr: 21.79235 | lpips: 0.22534 | dists: 0.10498
64
+ 2024-06-02 14:47:49,291 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
65
+ 2024-06-02 15:06:38,406 [INFO ] Epoch: 7 | perc loss: 0.33033 | psnr: 21.74950 | lpips: 0.22632 | dists: 0.10654
66
+ 2024-06-02 15:06:38,410 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
67
+ 2024-06-02 15:25:29,010 [INFO ] Epoch: 8 | perc loss: 0.32939 | psnr: 21.71603 | lpips: 0.22598 | dists: 0.10593
68
+ 2024-06-02 15:25:29,014 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
69
+ 2024-06-02 15:44:21,498 [INFO ] Epoch: 9 | perc loss: 0.33421 | psnr: 21.67705 | lpips: 0.22861 | dists: 0.10923
70
+ 2024-06-02 15:44:21,502 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
71
+ 2024-06-02 16:03:15,138 [INFO ] Epoch: 10 | perc loss: 0.33003 | psnr: 21.75746 | lpips: 0.22706 | dists: 0.10813
72
+ 2024-06-02 16:03:15,141 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
73
+ 2024-06-02 16:21:55,573 [INFO ] Epoch: 11 | perc loss: 0.33178 | psnr: 21.70939 | lpips: 0.22774 | dists: 0.10836
74
+ 2024-06-02 16:21:55,577 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
75
+ 2024-06-02 16:40:34,144 [INFO ] Epoch: 12 | perc loss: 0.33171 | psnr: 21.89004 | lpips: 0.22690 | dists: 0.10719
76
+ 2024-06-02 16:40:34,147 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
77
+ 2024-06-02 16:59:27,863 [INFO ] Epoch: 13 | perc loss: 0.32893 | psnr: 21.72392 | lpips: 0.22659 | dists: 0.10690
78
+ 2024-06-02 16:59:27,866 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
79
+ 2024-06-02 17:18:09,724 [INFO ] Epoch: 14 | perc loss: 0.33351 | psnr: 21.66716 | lpips: 0.22815 | dists: 0.10899
80
+ 2024-06-02 17:18:09,727 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
81
+ 2024-06-02 17:36:50,797 [INFO ] Epoch: 15 | perc loss: 0.33180 | psnr: 21.96293 | lpips: 0.22651 | dists: 0.10622
82
+ 2024-06-02 17:36:50,801 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
83
+ 2024-06-02 17:55:29,103 [INFO ] Epoch: 16 | perc loss: 0.33189 | psnr: 21.86891 | lpips: 0.22776 | dists: 0.10837
84
+ 2024-06-02 17:55:29,107 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
85
+ 2024-06-02 18:14:11,530 [INFO ] Epoch: 17 | perc loss: 0.33210 | psnr: 21.85856 | lpips: 0.22759 | dists: 0.10800
86
+ 2024-06-02 18:14:11,534 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
87
+ 2024-06-02 18:32:51,595 [INFO ] Epoch: 18 | perc loss: 0.33330 | psnr: 21.89861 | lpips: 0.22754 | dists: 0.10816
88
+ 2024-06-02 18:32:51,598 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
89
+ 2024-06-02 18:51:33,222 [INFO ] Epoch: 19 | perc loss: 0.33425 | psnr: 21.91737 | lpips: 0.22715 | dists: 0.10672
90
+ 2024-06-02 18:51:33,226 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
91
+ 2024-06-02 19:10:20,889 [INFO ] Epoch: 20 | perc loss: 0.33207 | psnr: 21.79711 | lpips: 0.22724 | dists: 0.10738
92
+ 2024-06-02 19:10:20,892 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
93
+ 2024-06-02 19:29:09,240 [INFO ] Epoch: 21 | perc loss: 0.33146 | psnr: 21.83736 | lpips: 0.22665 | dists: 0.10620
94
+ 2024-06-02 19:29:09,244 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
95
+ 2024-06-02 19:47:50,856 [INFO ] Epoch: 22 | perc loss: 0.33375 | psnr: 21.78199 | lpips: 0.22757 | dists: 0.10758
96
+ 2024-06-02 19:47:50,860 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
97
+ 2024-06-02 20:06:27,901 [INFO ] Epoch: 23 | perc loss: 0.33233 | psnr: 21.79922 | lpips: 0.22696 | dists: 0.10653
98
+ 2024-06-02 20:06:27,905 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
99
+ 2024-06-02 20:25:12,269 [INFO ] Epoch: 24 | perc loss: 0.33376 | psnr: 21.88994 | lpips: 0.22836 | dists: 0.10865
100
+ 2024-06-02 20:25:12,272 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
101
+ 2024-06-02 20:43:57,739 [INFO ] Epoch: 25 | perc loss: 0.33358 | psnr: 21.91691 | lpips: 0.22740 | dists: 0.10682
102
+ 2024-06-02 20:43:57,743 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
103
+ 2024-06-02 21:02:38,787 [INFO ] Epoch: 26 | perc loss: 0.33293 | psnr: 21.82560 | lpips: 0.22706 | dists: 0.10724
104
+ 2024-06-02 21:02:38,790 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
105
+ 2024-06-02 21:21:23,150 [INFO ] Epoch: 27 | perc loss: 0.33276 | psnr: 21.93110 | lpips: 0.22654 | dists: 0.10670
106
+ 2024-06-02 21:21:23,154 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
107
+ 2024-06-02 21:40:05,814 [INFO ] Epoch: 28 | perc loss: 0.33448 | psnr: 21.88258 | lpips: 0.22824 | dists: 0.10863
108
+ 2024-06-02 21:40:05,818 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
109
+ 2024-06-02 21:58:50,101 [INFO ] Epoch: 29 | perc loss: 0.33435 | psnr: 21.86732 | lpips: 0.22764 | dists: 0.10790
110
+ 2024-06-02 21:58:50,105 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
111
+ 2024-06-02 22:17:35,636 [INFO ] Epoch: 30 | perc loss: 0.33510 | psnr: 21.83126 | lpips: 0.22859 | dists: 0.10916
112
+ 2024-06-02 22:17:35,640 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
113
+ 2024-06-02 22:36:25,572 [INFO ] Epoch: 31 | perc loss: 0.33503 | psnr: 21.87967 | lpips: 0.22826 | dists: 0.10809
114
+ 2024-06-02 22:36:25,576 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
115
+ 2024-06-02 22:55:03,799 [INFO ] Epoch: 32 | perc loss: 0.33180 | psnr: 21.86306 | lpips: 0.22669 | dists: 0.10689
116
+ 2024-06-02 22:55:03,803 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
117
+ 2024-06-02 23:13:45,201 [INFO ] Epoch: 33 | perc loss: 0.33209 | psnr: 21.90117 | lpips: 0.22692 | dists: 0.10679
118
+ 2024-06-02 23:13:45,206 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
119
+ 2024-06-02 23:32:29,814 [INFO ] Epoch: 34 | perc loss: 0.33378 | psnr: 21.78412 | lpips: 0.22783 | dists: 0.10720
120
+ 2024-06-02 23:32:29,818 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
121
+ 2024-06-02 23:51:14,307 [INFO ] Epoch: 35 | perc loss: 0.33245 | psnr: 21.80858 | lpips: 0.22757 | dists: 0.10710
122
+ 2024-06-02 23:51:14,311 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
123
+ 2024-06-03 00:09:57,617 [INFO ] Epoch: 36 | perc loss: 0.33471 | psnr: 21.78929 | lpips: 0.22853 | dists: 0.10903
124
+ 2024-06-03 00:09:57,620 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
125
+ 2024-06-03 00:28:36,627 [INFO ] Epoch: 37 | perc loss: 0.33349 | psnr: 21.98594 | lpips: 0.22731 | dists: 0.10736
126
+ 2024-06-03 00:28:36,630 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
127
+ 2024-06-03 00:47:20,972 [INFO ] Epoch: 38 | perc loss: 0.33425 | psnr: 21.97987 | lpips: 0.22731 | dists: 0.10701
128
+ 2024-06-03 00:47:20,976 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
129
+ 2024-06-03 01:05:59,529 [INFO ] Epoch: 39 | perc loss: 0.33325 | psnr: 21.69044 | lpips: 0.22760 | dists: 0.10743
130
+ 2024-06-03 01:05:59,534 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
131
+ 2024-06-03 01:24:45,951 [INFO ] Epoch: 40 | perc loss: 0.33306 | psnr: 21.84931 | lpips: 0.22768 | dists: 0.10715
132
+ 2024-06-03 01:24:45,955 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
133
+ 2024-06-03 01:43:29,237 [INFO ] Epoch: 41 | perc loss: 0.33394 | psnr: 21.76097 | lpips: 0.22829 | dists: 0.10813
134
+ 2024-06-03 01:43:29,240 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
135
+ 2024-06-03 02:02:16,998 [INFO ] Epoch: 42 | perc loss: 0.33422 | psnr: 22.03626 | lpips: 0.22767 | dists: 0.10776
136
+ 2024-06-03 02:02:17,001 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
137
+ 2024-06-03 02:20:56,025 [INFO ] Epoch: 43 | perc loss: 0.33415 | psnr: 21.82381 | lpips: 0.22773 | dists: 0.10845
138
+ 2024-06-03 02:20:56,028 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
139
+ 2024-06-03 02:39:40,184 [INFO ] Epoch: 44 | perc loss: 0.33257 | psnr: 21.83755 | lpips: 0.22701 | dists: 0.10675
140
+ 2024-06-03 02:39:40,189 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
141
+ 2024-06-03 02:58:26,226 [INFO ] Epoch: 45 | perc loss: 0.33221 | psnr: 21.75377 | lpips: 0.22706 | dists: 0.10657
142
+ 2024-06-03 02:58:26,230 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
143
+ 2024-06-03 03:17:13,245 [INFO ] Epoch: 46 | perc loss: 0.33469 | psnr: 21.82597 | lpips: 0.22778 | dists: 0.10752
144
+ 2024-06-03 03:17:13,248 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
145
+ 2024-06-03 03:35:52,662 [INFO ] Epoch: 47 | perc loss: 0.33427 | psnr: 21.91075 | lpips: 0.22704 | dists: 0.10703
146
+ 2024-06-03 03:35:52,666 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
147
+ 2024-06-03 03:54:34,629 [INFO ] Epoch: 48 | perc loss: 0.33344 | psnr: 21.73634 | lpips: 0.22771 | dists: 0.10786
148
+ 2024-06-03 03:54:34,634 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
149
+ 2024-06-03 04:13:18,077 [INFO ] Epoch: 49 | perc loss: 0.33254 | psnr: 21.84265 | lpips: 0.22720 | dists: 0.10674
150
+ 2024-06-03 04:13:18,081 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
151
+ 2024-06-03 04:31:57,392 [INFO ] Epoch: 50 | perc loss: 0.33315 | psnr: 21.84459 | lpips: 0.22754 | dists: 0.10744
152
+ 2024-06-03 04:31:57,396 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
153
+ 2024-06-03 04:50:35,589 [INFO ] Epoch: 51 | perc loss: 0.33395 | psnr: 21.88569 | lpips: 0.22835 | dists: 0.10805
154
+ 2024-06-03 04:50:35,593 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
155
+ 2024-06-03 05:09:26,310 [INFO ] Epoch: 52 | perc loss: 0.33427 | psnr: 21.80637 | lpips: 0.22858 | dists: 0.10838
156
+ 2024-06-03 05:09:26,314 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
157
+ 2024-06-03 05:28:16,003 [INFO ] Epoch: 53 | perc loss: 0.33332 | psnr: 21.78990 | lpips: 0.22774 | dists: 0.10752
158
+ 2024-06-03 05:28:16,008 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
159
+ 2024-06-03 05:46:56,485 [INFO ] Epoch: 54 | perc loss: 0.33440 | psnr: 21.92330 | lpips: 0.22701 | dists: 0.10652
160
+ 2024-06-03 05:46:56,489 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
161
+ 2024-06-03 06:05:39,283 [INFO ] Epoch: 55 | perc loss: 0.33324 | psnr: 21.91578 | lpips: 0.22739 | dists: 0.10760
162
+ 2024-06-03 06:05:39,287 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
163
+ 2024-06-03 06:24:16,107 [INFO ] Epoch: 56 | perc loss: 0.33304 | psnr: 21.88298 | lpips: 0.22694 | dists: 0.10703
164
+ 2024-06-03 06:24:16,111 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
165
+ 2024-06-03 06:43:07,089 [INFO ] Epoch: 57 | perc loss: 0.33419 | psnr: 21.84341 | lpips: 0.22722 | dists: 0.10693
166
+ 2024-06-03 06:43:07,093 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
167
+ 2024-06-03 07:01:46,596 [INFO ] Epoch: 58 | perc loss: 0.33342 | psnr: 21.82555 | lpips: 0.22763 | dists: 0.10721
168
+ 2024-06-03 07:01:46,600 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
169
+ 2024-06-03 07:20:46,044 [INFO ] Epoch: 59 | perc loss: 0.33426 | psnr: 21.89526 | lpips: 0.22781 | dists: 0.10761
170
+ 2024-06-03 07:20:46,047 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
171
+ 2024-06-03 07:39:26,255 [INFO ] Epoch: 60 | perc loss: 0.33444 | psnr: 21.91903 | lpips: 0.22711 | dists: 0.10757
172
+ 2024-06-03 07:39:26,259 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
173
+ 2024-06-03 07:58:06,507 [INFO ] Epoch: 61 | perc loss: 0.33431 | psnr: 21.97837 | lpips: 0.22842 | dists: 0.10904
174
+ 2024-06-03 07:58:06,512 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
175
+ 2024-06-03 08:16:49,835 [INFO ] Epoch: 62 | perc loss: 0.33514 | psnr: 21.99198 | lpips: 0.22782 | dists: 0.10722
176
+ 2024-06-03 08:16:49,839 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
177
+ 2024-06-03 08:35:33,361 [INFO ] Epoch: 63 | perc loss: 0.33423 | psnr: 21.85692 | lpips: 0.22768 | dists: 0.10727
178
+ 2024-06-03 08:35:33,366 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
179
+ 2024-06-03 08:54:14,939 [INFO ] Epoch: 64 | perc loss: 0.33443 | psnr: 21.82958 | lpips: 0.22788 | dists: 0.10829
180
+ 2024-06-03 08:54:14,942 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
181
+ 2024-06-03 09:12:53,969 [INFO ] Epoch: 65 | perc loss: 0.33494 | psnr: 21.82883 | lpips: 0.22806 | dists: 0.10793
182
+ 2024-06-03 09:12:53,973 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
183
+ 2024-06-03 09:31:37,870 [INFO ] Epoch: 66 | perc loss: 0.33452 | psnr: 21.88860 | lpips: 0.22774 | dists: 0.10777
184
+ 2024-06-03 09:31:37,874 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
185
+ 2024-06-03 09:50:15,999 [INFO ] Epoch: 67 | perc loss: 0.33397 | psnr: 21.87563 | lpips: 0.22739 | dists: 0.10670
186
+ 2024-06-03 09:50:16,002 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
187
+ 2024-06-03 10:09:03,337 [INFO ] Epoch: 68 | perc loss: 0.33403 | psnr: 21.84018 | lpips: 0.22757 | dists: 0.10754
188
+ 2024-06-03 10:09:03,340 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
189
+ 2024-06-03 10:27:57,411 [INFO ] Epoch: 69 | perc loss: 0.33503 | psnr: 21.81171 | lpips: 0.22727 | dists: 0.10667
190
+ 2024-06-03 10:27:57,415 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
191
+ 2024-06-03 10:46:47,458 [INFO ] Epoch: 70 | perc loss: 0.33324 | psnr: 21.89323 | lpips: 0.22725 | dists: 0.10659
192
+ 2024-06-03 10:46:47,462 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
193
+ 2024-06-03 11:05:28,932 [INFO ] Epoch: 71 | perc loss: 0.33461 | psnr: 21.83094 | lpips: 0.22864 | dists: 0.10749
194
+ 2024-06-03 11:05:28,935 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
195
+ 2024-06-03 11:24:26,070 [INFO ] Epoch: 72 | perc loss: 0.33356 | psnr: 21.72290 | lpips: 0.22737 | dists: 0.10574
196
+ 2024-06-03 11:24:26,074 [INFO ] Saving checkpoint: /home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/checkpoint.pth.tar
VQ_w_Prompt/checkpoint_69.pth.tar ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c013e63dc8fc03dfa3945c99aff18bd6e0e98c351bc205059bb44bd2e329ed45
3
+ size 281021679
__pycache__/model_vq.cpython-38.pyc ADDED
Binary file (3.84 kB). View file
 
__pycache__/models_mage_codec.cpython-38.pyc ADDED
Binary file (21 kB). View file
 
__pycache__/models_mage_codec_cls.cpython-38.pyc ADDED
Binary file (21.7 kB). View file
 
__pycache__/models_mage_codec_full.cpython-38.pyc ADDED
Binary file (22.2 kB). View file
 
__pycache__/models_mage_codec_high_resolu.cpython-38.pyc ADDED
Binary file (22.4 kB). View file
 
__pycache__/models_mage_codec_hr_new.cpython-38.pyc ADDED
Binary file (21.8 kB). View file
 
__pycache__/models_mage_codec_hs_w_p.cpython-38.pyc ADDED
Binary file (24 kB). View file
 
__pycache__/models_mage_codec_rope.cpython-38.pyc ADDED
Binary file (21.3 kB). View file
 
cal_upper_bound.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import math
3
+ import sys
4
+ import os
5
+ import time
6
+ import logging
7
+ from datetime import datetime
8
+ from model_vq import Model_VQ
9
+ import torch
10
+ import torch.nn as nn
11
+ from omegaconf import OmegaConf
12
+
13
+ import yaml
14
+ from pytorch_msssim import ms_ssim
15
+ from DISTS_pytorch import DISTS
16
+ import lpips
17
+ from torch.nn import functional as F
18
+ from torchvision import utils as vutils
19
+ import numpy as np
20
+ import glob
21
+
22
+ import util.misc as misc
23
+ import PIL.Image as Image
24
+ import torch.backends.cudnn as cudnn
25
+ from pathlib import Path
26
+ import os
27
+
28
+ os.environ['CUDA_VISIBLE_DEVICES'] = '3'
29
+
30
+
31
class CalMetrics(nn.Module):
    """Compute PSNR, LPIPS (VGG and Alex backbones) and DISTS for a reconstruction.

    The metric networks are created lazily on first use and cached on the
    instance; the previous version rebuilt them (re-loading pretrained
    weights) on every single call, which dominated evaluation time.
    """

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()
        # Lazily-constructed metric networks (built on first use, on CUDA).
        self._lpips_vgg_net = None
        self._lpips_alex_net = None
        self._dists_net = None

    def psnr(self, rec, ori):
        """PSNR in dB, assuming pixel values in [0, 1]; returns 100 on a perfect match."""
        mse = torch.mean((rec - ori) ** 2)
        if mse == 0:
            return 100
        max_pixel = 1.
        psnr = 10 * torch.log10(max_pixel / mse)
        return torch.mean(psnr)

    def lpips_vgg(self, rec, ori):
        """LPIPS distance with the VGG backbone (lower is better)."""
        if self._lpips_vgg_net is None:
            self._lpips_vgg_net = lpips.LPIPS(net='vgg').cuda()
        return self._lpips_vgg_net(rec, ori)

    def lpips_alex(self, rec, ori):
        """LPIPS distance with the AlexNet backbone (lower is better)."""
        if self._lpips_alex_net is None:
            self._lpips_alex_net = lpips.LPIPS(net='alex').cuda()
        return self._lpips_alex_net(rec, ori)

    def dists(self, rec, ori):
        """DISTS perceptual distance (lower is better)."""
        if self._dists_net is None:
            self._dists_net = DISTS().cuda()
        return self._dists_net(rec, ori)

    def forward(self, ori, rec):
        """Return a dict of metrics comparing *rec* (clamped to [0, 1]) against *ori*.

        Returns an empty dict when *rec* is None.
        """
        out = {}
        if rec is not None:
            rec_clamped = torch.clamp(rec, 0, 1)
            out["psnr"] = self.psnr(rec_clamped, ori)
            out["lpips_vgg"] = self.lpips_vgg(rec_clamped, ori)
            out["lpips_alex"] = self.lpips_alex(rec_clamped, ori)
            out["dists"] = self.dists(rec_clamped, ori)
        return out
69
+
70
+
71
class AverageMeter:
    """Track the most recent value and the running mean of a series of updates."""

    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record *val* observed *n* times and refresh the running mean."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count
85
+
86
+
87
class CustomDataParallel(nn.DataParallel):
    """Custom DataParallel to access the module methods.

    nn.DataParallel wraps the real model in ``self.module``; plain attribute
    access on the wrapper therefore misses methods/attributes defined on the
    wrapped model.  This override falls back to the wrapped model whenever
    normal nn.Module resolution fails.
    """

    def __getattr__(self, key):
        try:
            # First try the standard nn.Module resolution (parameters,
            # buffers, submodules, and DataParallel's own attributes).
            return super().__getattr__(key)
        except AttributeError:
            # Fall back to the wrapped model so e.g. custom methods work.
            return getattr(self.module, key)
95
+
96
+
97
def init(args):
    """Create (if missing) and return the experiment output directory.

    The directory is ``<args.root>/<args.exp_name>/`` (with trailing slash).
    """
    out_dir = '{}/{}/'.format(args.root, args.exp_name)
    os.makedirs(out_dir, exist_ok=True)
    return out_dir
101
+
102
def setup_logger(log_dir):
    """Configure the root logger to write INFO+ records to *log_dir* and stdout.

    Both handlers share the same timestamped format; the file handler is
    attached first, then the stream handler (same order as before).
    """
    formatter = logging.Formatter("%(asctime)s [%(levelname)-5.5s] %(message)s")
    root = logging.getLogger()
    root.setLevel(logging.INFO)

    file_handler = logging.FileHandler(log_dir, encoding='utf-8')
    stream_handler = logging.StreamHandler(sys.stdout)
    for handler in (file_handler, stream_handler):
        handler.setFormatter(formatter)
        root.addHandler(handler)

    logging.info('Logging file is %s' % log_dir)
116
+
117
+
118
def load_img(p, padding=True, factor=64):
    """Load image *p* as a 1x3xHxW float tensor with values in [0, 1].

    Grayscale (2-D) images are replicated to three channels.  When *padding*
    is enabled, the tensor is zero-padded so height and width become
    multiples of *factor*; the pad is split evenly between the two sides,
    with any odd remainder going to the right/bottom edge.

    Returns (tensor, original_height, original_width).
    """
    img = Image.open(p)
    x = torch.from_numpy(np.asarray(img))
    if len(x.shape) == 2:
        # h,w -> h,w,3: replicate the single channel.
        x = x.unsqueeze(-1).repeat(1, 1, 3)
    x = x.permute(2, 0, 1).unsqueeze(0).float().div(255)
    h, w = x.shape[2:4]

    if padding:
        dh = factor * math.ceil(h / factor) - h
        dw = factor * math.ceil(w / factor) - w
        # Split the padding evenly; the odd extra pixel goes right/bottom.
        top, left = dh // 2, dw // 2
        bottom, right = top + dh % 2, left + dw % 2
        x = F.pad(x, (left, right, top, bottom))
    return x, h, w
136
+
137
+
138
+
139
def save_img(img: torch.Tensor, vis_path, input_p, rec=False):
    """Save *img* under *vis_path* (or *vis_path*/rec when rec=True).

    The output file keeps the basename of *input_p* (everything from the
    last '/').  Directories are created on demand.
    """
    snapshot = img.clone().detach().to(torch.device('cpu'))
    if not os.path.isdir(vis_path):
        os.makedirs(vis_path)
    if rec:
        vis_path = vis_path + '/rec'
        if not os.path.isdir(vis_path):
            os.makedirs(vis_path)
    # Reuse the source file name, slash included, as the suffix.
    img_name = vis_path + str(input_p[input_p.rfind('/'):])
    vutils.save_image(snapshot, os.path.join(img_name), nrow=8)
153
+
154
+
155
def inference(epoch, eval_path, model, metrics_criterion, device, stage='test'):
    """Evaluate *model* on every image path in *eval_path* and log averaged metrics.

    Args:
        epoch: epoch index, used only in the log line (-1 for pre-validation).
        eval_path: iterable of image file paths.
        model: reconstruction model, called as model(x) on a padded 1-image batch.
        metrics_criterion: callable returning a dict with keys
            'psnr', 'lpips_vgg', 'lpips_alex', 'dists'.
        device: torch device (or device string) to run on.
        stage: sub-directory name under ./VQGAN/ for saved reconstructions.

    Returns:
        The PSNR AverageMeter (its .avg holds the mean PSNR over the set).
    """
    model.eval()
    psnr = AverageMeter()
    lpips_vgg = AverageMeter()
    lpips_alex = AverageMeter()
    dists = AverageMeter()

    vis_path = os.path.join("./VQGAN/", stage)
    os.makedirs(vis_path, exist_ok=True)

    with torch.no_grad():
        for input_p in eval_path:
            # Pad to a multiple of 64 so the model's downsampling stack fits.
            x, hx, wx = load_img(input_p, padding=True, factor=64)
            x = x.to(device)
            rec = model(x)

            # NOTE(review): load_img pads symmetrically (top/left too), but this
            # crop removes only from the bottom/right, so the compared region is
            # shifted by the top/left pad relative to the source image — both
            # tensors are cropped identically, but confirm this is intended.
            x = x[:, :, :hx, :wx]
            rec = rec[:, :, :hx, :wx]
            rec = rec.to(device)
            out_criterion = metrics_criterion(x, rec)

            psnr.update(out_criterion['psnr'])
            lpips_vgg.update(out_criterion['lpips_vgg'])
            lpips_alex.update(out_criterion['lpips_alex'])
            dists.update(out_criterion['dists'])

            ## ======================= update progress bar & visualization ======================= ##
            # save_img(x, vis_path, input_p)
            save_img(rec, vis_path, input_p, rec=True)

    # Restore training mode for the caller.
    model.train()
    log_txt = f"{epoch}|psnr:{psnr.avg:.5f}|lpips_vgg:{lpips_vgg.avg.mean().item():.5f}|lpips_alex:{lpips_alex.avg.mean().item():.5f}|dists:{dists.avg.mean().item():.5f}"
    logging.info(log_txt)
    return psnr
189
+
190
def parse_args(argv):
    """Parse CLI arguments, with a YAML config file supplying the defaults.

    Two-pass scheme: the first parse_known_args resolves --config (and the
    early flags) so the YAML can be loaded; its values are then installed as
    argparse defaults, so explicit command-line flags still take precedence.
    """
    parser = argparse.ArgumentParser(description="Example training script.")
    parser.add_argument(
        "-c",
        "--config",
        default="/home/t2vg-a100-G4-10/project/qyp/mimc_rope/config/cal_upper_bound.yaml",
        help="Path to config file",
    )
    parser.add_argument(
        '--name',
        # NOTE: the default timestamp is evaluated once at import time,
        # not at each invocation.
        default=datetime.now().strftime('%Y-%m-%d_%H_%M_%S'),
        type=str,
        help='Result dir name',
    )
    parser.add_argument(
        '--eval_path',
        default='/home/t2vg-a100-G4-10/project/qyp/datasets/COCO/val2017',
        type=str,
        help='path to the evaluation dataset',
    )
    parser.add_argument('--lr', type=float, default=None, metavar='LR',
                        help='learning rate (absolute lr)')
    # First pass: only the flags declared so far; leftovers go to `remaining`.
    given_configs, remaining = parser.parse_known_args(argv)
    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local-rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')
    # YAML values become argparse defaults; unknown YAML keys become plain
    # attributes on the resulting namespace.
    with open(given_configs.config) as file:
        yaml_data= yaml.safe_load(file)
    parser.set_defaults(**yaml_data)

    parser.add_argument(
        "-T",
        "--TEST",
        action='store_true',
        help='Testing'
    )
    args = parser.parse_args(remaining)
    return args
232
+
233
def load_eval_ps(eval_path):
    """Return the sorted list of .jpg files directly inside *eval_path*."""
    pattern = os.path.join(eval_path, '*.jpg')
    return sorted(glob.glob(pattern))
236
+
237
def main(argv):
    """Evaluate a pretrained VQGAN's reconstruction quality on a folder of JPEGs.

    This computes the distortion upper bound of the codec: images are padded,
    reconstructed by the VQGAN, and scored with PSNR/LPIPS/DISTS; results are
    logged and reconstructions saved by inference().
    """
    args = parse_args(argv)
    base_dir = init(args)  # create the base dir for saving the results

    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
        args.log_dir = args.output_dir

    misc.init_distributed_mode(args)

    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))

    device = torch.device(args.device)
    # fix the seed for reproducibility
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    setup_logger(base_dir + '/' + time.strftime('%Y%m%d_%H%M%S') + '.log')
    msg = f'======================= {args.name} ======================='
    logging.info(msg)
    for k in args.__dict__:
        logging.info(k + ':' + str(args.__dict__[k]))
    logging.info('=' * len(msg))

    ## ======================= prepare dataset ======================= ##
    eval_path = sorted(glob.glob(os.path.join(args.eval_path, '*.jpg')))

    # NOTE(review): this overwrites the torch.device built above with a plain
    # string; both forms work with .to(), but only one assignment is needed —
    # confirm which is intended.
    device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"

    ## ======================= prepare model ======================= ##
    vqgan_ckpt_path = '/home/t2vg-a100-G4-10/project/qyp/mage/vqgan_jax_strongaug.ckpt'
    config = OmegaConf.load('config/vqgan.yaml').model
    model = Model_VQ(ddconfig=config.params.ddconfig,
                     n_embed=config.params.n_embed, # 1024
                     embed_dim=config.params.embed_dim, # 256
                     ckpt_path=vqgan_ckpt_path)

    model.to(device)

    metrics_criterion = CalMetrics()
    ## ======================= pre validation ======================= ##
    test_loss = inference(-1, eval_path, model, metrics_criterion, device, 'val')
283
+
284
+
285
+
286
if __name__ == "__main__":
    # Script entry point: forward CLI arguments (minus the program name).
    main(sys.argv[1:])
ckpt_pretrained/vgg.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868
3
+ size 7289
coco_eval.py ADDED
File without changes
config/cal_upper_bound.yaml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # eval_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/kodak'
2
+ eval_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/COCO/val2017'
3
+ # eval_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/shark_new/'
4
+ device: "cuda"
5
+ seed: 0
6
+ save: True
7
+ name: "VQ"
8
+ exp_name: "VQ"
9
+ root: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope"
10
+ output_dir: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ" # path where to save, empty for no saving
11
+ log_dir: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ/logs" # path where to tensorboard log
12
+ strart_epoch: 0
13
+ cuda: True
config/eval copy.yaml ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model: mage_vit_base_patch16
2
+ dataset: "coco"
3
+ dataset_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/COCO'
4
+ # kodak_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/test/'
5
+ kodak_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/kodak/'
6
+ eval_path: ''
7
+ batch_size: 128 # Batch size per GPU
8
+ learning_rate: 1.e-4
9
+ test_batch_size: 1
10
+ num_workers: 4
11
+ epochs: 100
12
+ blr: 1.5e-4 # 1.5e-4 # base learning rate: absolute_lr = base_lr * total_batch_size / 256
13
+ min_lr: 5.e-6
14
+ accum_iter: 1 # Accumulate gradient iterations (for increasing the effective batch size under memory constraints)
15
+ input_size: 256
16
+ weight_decay: 0.05
17
+ warmup_epochs: 10
18
+ mask_ratio_min: 0.5
19
+ mask_ratio_max: 0.9
20
+ mask_ratio_mu: 0.55
21
+ mask_ratio_std: 0.25
22
+ grad_clip: 3.0
23
+ device: "cuda"
24
+ distributed: True
25
+ seed: 0
26
+ gpu_id: 0,1
27
+ resume: '/home/t2vg-a100-G4-10/project/qyp/lossless/checkpoint-330.pth' # Path to the checkpoint to resume from
28
+ # resume: '/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr/checkpoint-399.pth'
29
+ # resume: '/home/v-ruoyufeng/v-ruoyufeng/qyp/mimc_rope/random_with_rope.pth'
30
+ maskrcnn_path: '/home/t2vg-a100-G4-10/project/qyp/ckpt/detectron2/mask_rcnn_r50_fpn_3x.pkl'
31
+ start_epoch: 0
32
+ pin_mem: True
33
+ save: True
34
+ name: "exp"
35
+ exp_name: "VQ_w_Prompt"
36
+ root: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope"
37
+ output_dir: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt" # path where to save, empty for no saving
38
+ log_dir: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/logs" # path where to tensorboard log
39
+ strart_epoch: 0
40
+ cuda: True
41
+ TEST: True
config/eval.yaml ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model: mage_vit_base_patch16
2
+ dataset: "coco"
3
+ dataset_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/COCO'
4
+ # kodak_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/test/'
5
+ kodak_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/kodak/'
6
+ eval_path: ''
7
+ batch_size: 128 # Batch size per GPU
8
+ learning_rate: 1.e-4
9
+ test_batch_size: 1
10
+ num_workers: 4
11
+ epochs: 100
12
+ blr: 1.5e-4 # 1.5e-4 # base learning rate: absolute_lr = base_lr * total_batch_size / 256
13
+ min_lr: 5.e-6
14
+ accum_iter: 1 # Accumulate gradient iterations (for increasing the effective batch size under memory constraints)
15
+ input_size: 256
16
+ weight_decay: 0.05
17
+ warmup_epochs: 10
18
+ mask_ratio_min: 0.5
19
+ mask_ratio_max: 0.9
20
+ mask_ratio_mu: 0.55
21
+ mask_ratio_std: 0.25
22
+ grad_clip: 3.0
23
+ device: "cuda"
24
+ distributed: True
25
+ seed: 0
26
+ gpu_id: 0,1
27
+ resume: '/home/t2vg-a100-G4-10/project/qyp/lossless/checkpoint-330.pth' # Path to the checkpoint to resume from
28
+ # resume: '/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr/checkpoint-399.pth'
29
+ # resume: '/home/v-ruoyufeng/v-ruoyufeng/qyp/mimc_rope/random_with_rope.pth'
30
+ maskrcnn_path: '/home/t2vg-a100-G4-10/project/qyp/ckpt/detectron2/mask_rcnn_r50_fpn_3x.pkl'
31
+ start_epoch: 0
32
+ pin_mem: True
33
+ save: True
34
+ name: "exp"
35
+ exp_name: "VQ_w_Prompt"
36
+ root: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope"
37
+ output_dir: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt" # path where to save, empty for no saving
38
+ log_dir: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQ_w_Prompt/logs" # path where to tensorboard log
39
+ strart_epoch: 0
40
+ cuda: True
41
+ TEST: True
config/eval_config.yaml ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model: mage_vit_base_patch16
2
+ dataset: "imagenet"
3
+ dataset_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/imagenet'
4
+ kodak_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/kodak/'
5
+ eval_path: ''
6
+ batch_size: 64 # Batch size per GPU
7
+ test_batch_size: 200
8
+ num_workers: 4
9
+ epochs: 400
10
+ blr: 2.e-4 # 1.5e-4 # base learning rate: absolute_lr = base_lr * total_batch_size / 256
11
+ min_lr: 5.e-6
12
+ accum_iter: 1 # Accumulate gradient iterations (for increasing the effective batch size under memory constraints)
13
+ input_size: 256
14
+ weight_decay: 0.05
15
+ warmup_epochs: 20
16
+ mask_ratio_min: 0.5
17
+ mask_ratio_max: 0.9
18
+ mask_ratio_mu: 0.55
19
+ mask_ratio_std: 0.25
20
+ grad_clip: 3.0
21
+ device: "cuda"
22
+ distributed: True
23
+ seed: 0
24
+ resume: '/home/t2vg-a100-G4-10/project/qyp/lossless/checkpoint-330.pth' # Path to the checkpoint to resume from
25
+ start_epoch: 0
26
+ pin_mem: True
27
+ save: True
28
+ name: "exp"
29
+ exp_name: "MIM_vbr_eval_lossless"
30
+ root: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope"
31
+ output_dir: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr_eval_lossless" # path where to save, empty for no saving
32
+ log_dir: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr_eval_lossless/logs_eval" # path where to tensorboard log
33
+ strart_epoch: 0
34
+ cuda: True
config/eval_high_res.yaml ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model: mage_vit_base_patch16
2
+ dataset: "coco"
3
+ dataset_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/COCO'
4
+ kodak_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/test/'
5
+ # kodak_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/kodak/'
6
+ eval_path: ''
7
+ batch_size: 128 # Batch size per GPU
8
+ learning_rate: 1.e-4
9
+ test_batch_size: 1
10
+ num_workers: 4
11
+ epochs: 100
12
+ blr: 1.5e-4 # 1.5e-4 # base learning rate: absolute_lr = base_lr * total_batch_size / 256
13
+ min_lr: 5.e-6
14
+ accum_iter: 1 # Accumulate gradient iterations (for increasing the effective batch size under memory constraints)
15
+ input_size: 256
16
+ weight_decay: 0.05
17
+ warmup_epochs: 10
18
+ mask_ratio_min: 0.5
19
+ mask_ratio_max: 0.9
20
+ mask_ratio_mu: 0.55
21
+ mask_ratio_std: 0.25
22
+ grad_clip: 3.0
23
+ device: "cuda"
24
+ distributed: True
25
+ seed: 0
26
+ gpu_id: 0,1
27
+ resume: '/home/t2vg-a100-G4-10/project/qyp/lossless/checkpoint-330.pth' # Path to the checkpoint to resume from
28
+ # resume: '/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr/checkpoint-399.pth'
29
+ # resume: '/home/v-ruoyufeng/v-ruoyufeng/qyp/mimc_rope/random_with_rope.pth'
30
+ maskrcnn_path: '/home/t2vg-a100-G4-10/project/qyp/ckpt/detectron2/mask_rcnn_r50_fpn_3x.pkl'
31
+ start_epoch: 0
32
+ pin_mem: True
33
+ save: True
34
+ name: "exp"
35
+ exp_name: "Test_High_Res"
36
+ root: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope"
37
+ output_dir: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/Test_High_Res" # path where to save, empty for no saving
38
+ log_dir: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/Test_High_Res/logs" # path where to tensorboard log
39
+ strart_epoch: 0
40
+ cuda: True
41
+ TEST: True
config/eval_vq.yaml ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model: mage_vit_base_patch16
2
+ dataset: "coco"
3
+ dataset_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/COCO'
4
+ # kodak_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/test/'
5
+ kodak_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/COCO/val2017/'
6
+ eval_path: ''
7
+ batch_size: 32 # Batch size per GPU
8
+ test_batch_size: 1
9
+ num_workers: 4
10
+ epochs: 400
11
+ blr: 1.5e-4 # 1.5e-4 # base learning rate: absolute_lr = base_lr * total_batch_size / 256
12
+ min_lr: 5.e-6
13
+ accum_iter: 1 # Accumulate gradient iterations (for increasing the effective batch size under memory constraints)
14
+ input_size: 256
15
+ weight_decay: 0.05
16
+ warmup_epochs: 20
17
+ mask_ratio_min: 0.5
18
+ mask_ratio_max: 0.9
19
+ mask_ratio_mu: 0.55
20
+ mask_ratio_std: 0.25
21
+ grad_clip: 3.0
22
+ device: "cuda"
23
+ distributed: True
24
+ seed: 0
25
+ resume: '/home/t2vg-a100-G4-10/project/qyp/lossless/checkpoint-330.pth' # Path to the checkpoint to resume from
26
+ # resume: '/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr/checkpoint-399.pth'
27
+ # resume: '/home/v-ruoyufeng/v-ruoyufeng/qyp/mimc_rope/random_with_rope.pth'
28
+ start_epoch: 0
29
+ pin_mem: True
30
+ save: True
31
+ name: "exp"
32
+ exp_name: "VQGAN"
33
+ root: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope"
34
+ output_dir: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQGAN" # path where to save, empty for no saving
35
+ log_dir: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/VQGAN/logs" # path where to tensorboard log
36
+ strart_epoch: 0
37
+ cuda: True
38
+ TEST: True
config/fine_tune_vq_w_mask.yaml ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model: mage_vit_base_patch16
2
+ dataset: "coco"
3
+ dataset_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/COCO'
4
+ kodak_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/test/'
5
+ # kodak_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/kodak/'
6
+ eval_path: ''
7
+ batch_size: 64 # Batch size per GPU
8
+ learning_rate: 1.e-4
9
+ test_batch_size: 1
10
+ num_workers: 4
11
+ epochs: 100
12
+ blr: 1.e-4 # 1.5e-4 # base learning rate: absolute_lr = base_lr * total_batch_size / 256
13
+ min_lr: 5.e-6
14
+ accum_iter: 1 # Accumulate gradient iterations (for increasing the effective batch size under memory constraints)
15
+ input_size: 256
16
+ weight_decay: 0.05
17
+ warmup_epochs: 10
18
+ mask_ratio_min: 0.25
19
+ mask_ratio_max: 0.75
20
+ grad_clip: 3.0
21
+ device: "cuda"
22
+ distributed: True
23
+ seed: 0
24
+ gpu_id: 2
25
+ resume: '/home/t2vg-a100-G4-10/project/qyp/lossless/checkpoint-330.pth' # Path to the checkpoint to resume from
26
+ # resume: '/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr/checkpoint-399.pth'
27
+ # resume: '/home/v-ruoyufeng/v-ruoyufeng/qyp/mimc_rope/random_with_rope.pth'
28
+ maskrcnn_path: '/home/t2vg-a100-G4-10/project/qyp/ckpt/detectron2/mask_rcnn_r50_fpn_3x.pkl'
29
+ start_epoch: 0
30
+ pin_mem: True
31
+ save: True
32
+ name: "exp"
33
+ exp_name: "Fine_Tune_VQ_w_P"
34
+ root: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope"
35
+ output_dir: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/Fine_Tune_VQ_w_P" # path where to save, empty for no saving
36
+ log_dir: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/Fine_Tune_VQ_w_P/logs" # path where to tensorboard log
37
+ strart_epoch: 0
38
+ cuda: True
39
+ TEST: True
config/train_det.yaml ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model: mage_vit_base_patch16
2
+ dataset: "coco"
3
+ dataset_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/COCO'
4
+ kodak_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/kodak/'
5
+ # kodak_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/test/'
6
+ eval_path: ''
7
+ batch_size: 32 # Batch size per GPU
8
+ test_batch_size: 1
9
+ num_workers: 4
10
+ epochs: 400
11
+ blr: 1.5e-4 # 1.5e-4 # base learning rate: absolute_lr = base_lr * total_batch_size / 256
12
+ min_lr: 5.e-6
13
+ accum_iter: 1 # Accumulate gradient iterations (for increasing the effective batch size under memory constraints)
14
+ input_size: 256
15
+ weight_decay: 0.05
16
+ warmup_epochs: 20
17
+ mask_ratio_min: 0.5
18
+ mask_ratio_max: 0.9
19
+ mask_ratio_mu: 0.55
20
+ mask_ratio_std: 0.25
21
+ grad_clip: 3.0
22
+ device: "cuda"
23
+ distributed: True
24
+ seed: 0
25
+ resume: '/home/t2vg-a100-G4-10/project/qyp/lossless/checkpoint-330.pth' # Path to the checkpoint to resume from
26
+ # resume: '/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr/checkpoint-399.pth'
27
+ # resume: '/home/v-ruoyufeng/v-ruoyufeng/qyp/mimc_rope/random_with_rope.pth'
28
+ start_epoch: 0
29
+ pin_mem: True
30
+ save: True
31
+ name: "exp"
32
+ exp_name: "MIM_vbr_kodak"
33
+ root: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope"
34
+ output_dir: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr_kodak" # path where to save, empty for no saving
35
+ log_dir: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr_kodak/logs" # path where to tensorboard log
36
+ strart_epoch: 0
37
+ cuda: True
38
+ TEST: True
config/train_im_cls.yaml ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model: mage_vit_base_patch16
2
+ dataset: "imagenet"
3
+ dataset_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/imagenet'
4
+ eval_path: ''
5
+ batch_size: 256 # Batch size per GPU
6
+ test_batch_size: 64
7
+ num_workers: 4
8
+ epochs: 100
9
+ learning_rate: 1.e-4
10
+ blr: 1.5e-4 # 1.5e-4 # base learning rate: absolute_lr = base_lr * total_batch_size / 256
11
+ min_lr: 5.e-6
12
+ accum_iter: 1 # Accumulate gradient iterations (for increasing the effective batch size under memory constraints)
13
+ input_size: 256
14
+ weight_decay: 0.05
15
+ warmup_epochs: 10
16
+ mask_ratio_min: 0.3
17
+ mask_ratio_max: 0.75
18
+ grad_clip: 3.0
19
+ device: "cuda"
20
+ distributed: True
21
+ seed: 0
22
+ resume: '/home/t2vg-a100-G4-10/project/qyp/mimc_rope/train_im_cls/checkpoint-lastbest.pth' # Path to the checkpoint to resume from
23
+ start_epoch: 0
24
+ pin_mem: True
25
+ save: True
26
+ name: "exp"
27
+ exp_name: "train_im_cls"
28
+ root: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope"
29
+ output_dir: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/train_im_cls" # path where to save, empty for no saving
30
+ log_dir: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/train_im_cls/logs" # path where to tensorboard log
31
+ strart_epoch: 0
32
+ cuda: True
config/train_mim_config.yaml ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model: mage_vit_base_patch16
2
+ dataset: "imagenet"
3
+ dataset_path: '/home/t2vg-a100-G4-10/project/qyp/datasets/imagenet'
4
+ eval_path: ''
5
+ batch_size: 64 # Batch size per GPU
6
+ test_batch_size: 64
7
+ num_workers: 4
8
+ epochs: 400
9
+ blr: 1.5e-4 # 1.5e-4 # base learning rate: absolute_lr = base_lr * total_batch_size / 256
10
+ min_lr: 5.e-6
11
+ accum_iter: 1 # Accumulate gradient iterations (for increasing the effective batch size under memory constraints)
12
+ input_size: 256
13
+ weight_decay: 0.05
14
+ warmup_epochs: 20
15
+ mask_ratio_min: 0.5
16
+ mask_ratio_max: 0.9
17
+ mask_ratio_mu: 0.55
18
+ mask_ratio_std: 0.25
19
+ grad_clip: 3.0
20
+ device: "cuda"
21
+ distributed: True
22
+ seed: 0
23
+ resume: '/home/t2vg-a100-G4-10/project/qyp/mimc_rope/random_with_rope.pth' # Path to the checkpoint to resume from
24
+ start_epoch: 0
25
+ pin_mem: True
26
+ save: True
27
+ name: "exp"
28
+ exp_name: "MIM_vbr"
29
+ root: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope"
30
+ output_dir: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr" # path where to save, empty for no saving
31
+ log_dir: "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/MIM_vbr/logs" # path where to tensorboard log
32
+ strart_epoch: 0
33
+ cuda: True
config/vqgan.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ target: taming.models.vqgan.VQModel
3
+ params:
4
+ embed_dim: 256
5
+ n_embed: 1024
6
+ ddconfig:
7
+ double_z: False
8
+ z_channels: 256
9
+ resolution: 256
10
+ in_channels: 3
11
+ out_ch: 3
12
+ ch: 128
13
+ ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1
14
+ num_res_blocks: 2
15
+ attn_resolutions: [16]
16
+ dropout: 0.0
deploited/models_mage_codec_pos_linear.py ADDED
@@ -0,0 +1,767 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from timm.models.vision_transformer import PatchEmbed, DropPath, Mlp
7
+
8
+ from util.pos_embed import get_2d_sincos_pos_embed
9
+
10
+ from taming.models.vqgan import VQModel
11
+ from omegaconf import OmegaConf
12
+ import numpy as np
13
+ import scipy.stats as stats
14
+ from compressai.entropy_models import EntropyBottleneck
15
+ from compressai.layers import conv3x3, subpel_conv3x3
16
+ import math
17
+ from torch import Tensor
18
+ from einops import rearrange, repeat
19
+ import torch.nn.functional as F
20
+ import torchac
21
+ from typing import Any, Callable, List, Optional, Tuple, Union
22
+
23
+ SCALES_MIN = 0.11
24
+ SCALES_MAX = 256
25
+ SCALES_LEVELS = 64
26
+ def get_scale_table(min=SCALES_MIN, max=SCALES_MAX, levels=SCALES_LEVELS):
27
+ return torch.exp(torch.linspace(math.log(min), math.log(max), levels))
28
+
29
+ def ste_round(x: Tensor) -> Tensor:
30
+ return torch.round(x) - x.detach() + x
31
+
32
+ def conv(in_channels, out_channels, kernel_size=5, stride=2):
33
+ return nn.Conv2d(
34
+ in_channels,
35
+ out_channels,
36
+ kernel_size=kernel_size,
37
+ stride=stride,
38
+ padding=kernel_size // 2,
39
+ )
40
+
41
+ def mask_by_random_topk(mask_len, probs, temperature=1.0):
42
+ mask_len = mask_len.squeeze()
43
+ # 使用Gumbel分布进行采样,增加随机性
44
+ confidence = torch.log(probs) + torch.Tensor(temperature * np.random.gumbel(size=probs.shape)).cuda()
45
+ sorted_confidence, _ = torch.sort(confidence, axis=-1)
46
+ # Obtains cut off threshold given the mask lengths.
47
+ cut_off = sorted_confidence[:, mask_len.long()-1:mask_len.long()]
48
+ # Masks tokens with lower confidence.
49
+ masking = (confidence <= cut_off)
50
+ return masking
51
+
52
+ def adjust_mask_and_drop_embeddings(token_keep_mask):
53
+ """
54
+ Adjusts the token_keep_mask to the nearest square number of True values by randomly setting
55
+ some of them to False, and then applies this adjusted mask to input_embeddings.
56
+
57
+ Parameters:
58
+ - input_embeddings: Tensor, The embeddings tensor.
59
+ - token_keep_mask: BoolTensor, The mask tensor indicating which tokens to keep.
60
+
61
+ Returns:
62
+ - Tensor, Adjusted input embeddings after applying the modified token_keep_mask.
63
+ """
64
+ # 获取非零(即值为True)元素的索引
65
+ non_zero_indices = token_keep_mask.nonzero(as_tuple=True)
66
+ # 计算非零元素的数量
67
+ non_zero_count = non_zero_indices[0].size(0)
68
+ # 计算最近的整数平方倍
69
+ next_square = math.floor(math.sqrt(non_zero_count))**2
70
+ # 计算需要移除的元素数量
71
+ remove_count = non_zero_count - next_square
72
+ if remove_count > 0:
73
+ # 如果需要移除元素以达到整数平方倍
74
+ permuted_indices = torch.randperm(non_zero_count)[:remove_count]
75
+ for idx in permuted_indices:
76
+ token_keep_mask[non_zero_indices[0][idx], non_zero_indices[1][idx]] = False
77
+ # 使用更新后的token_keep_mask
78
+ # input_embeddings_after_drop = input_embeddings[token_keep_mask]
79
+
80
+ return token_keep_mask
81
+
82
+
83
+ class FactorizedEntropyModel(EntropyBottleneck):
84
+ def __init__(self, *args, **kwargs):
85
+ super().__init__(*args, **kwargs)
86
+
87
+ def forward(self, x: Tensor, training: Optional[bool] = None) -> Tuple[Tensor, Tensor]:
88
+ if training is None:
89
+ training = self.training
90
+
91
+ # 输入形状已经是 [b, c, seq_len],无需转置
92
+ shape = x.size()
93
+
94
+ # Add noise or quantize
95
+ means = self._get_medians()
96
+ # outputs = self.quantize(
97
+ # x, "noise" if training else "dequantize", means.long()
98
+ # )
99
+ outputs = self.quantize(
100
+ x, "dequantize", means.long()
101
+ )
102
+
103
+ if not torch.jit.is_scripting():
104
+ likelihood = self._likelihood(outputs)
105
+ if self.use_likelihood_bound:
106
+ likelihood = self.likelihood_lower_bound(likelihood)
107
+ else:
108
+ raise NotImplementedError("TorchScript is not yet supported")
109
+
110
+ return outputs, likelihood
111
+
112
+ def compress(self, x):
113
+ # 构建索引,适用于单通道序列数据
114
+ indexes = self._build_indexes(x.size())
115
+ # 获取中位数,已经适配为单通道
116
+ medians = self._get_medians().detach()
117
+ # 调整 medians 的形状以匹配 x 的形状
118
+ medians = medians.expand_as(x)
119
+ # 调用基类的 compress 方法进行压缩
120
+ return super().compress(x, indexes, medians)
121
+
122
+ def decompress(self, strings, size):
123
+ # 预期的输出大小应包括单个通道
124
+ output_size = (len(strings), 1, *size) # 这里 size 应该是 seq_len
125
+ # 构建索引
126
+ indexes = self._build_indexes(output_size).to(self._quantized_cdf.device)
127
+ # 获取中位数并调整其形状以匹配预期输出的形状
128
+ medians = self._extend_ndims(self._get_medians().detach(), len(size))
129
+ medians = medians.expand(len(strings), 1, *([-1] * len(size)))
130
+ # 调用基类的 decompress 方法进行解压缩
131
+ return super().decompress(strings, indexes, medians.dtype, medians)
132
+
133
+ def _preprocess(self, x):
134
+ x = x.permute(0, 2, 3, 1).contiguous()
135
+ return x
136
+
137
+
138
+ class Attention(nn.Module):
139
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
140
+ super().__init__()
141
+ self.num_heads = num_heads
142
+ head_dim = dim // num_heads
143
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
144
+ self.scale = qk_scale or head_dim ** -0.5
145
+
146
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
147
+ self.attn_drop = nn.Dropout(attn_drop)
148
+ self.proj = nn.Linear(dim, dim)
149
+ self.proj_drop = nn.Dropout(proj_drop)
150
+
151
+ def forward(self, x):
152
+ B, N, C = x.shape
153
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # (3,B,num_heads,N,head_dim)
154
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
155
+
156
+ with torch.cuda.amp.autocast(enabled=False):
157
+ attn = (q.float() @ k.float().transpose(-2, -1)) * self.scale
158
+
159
+ attn = attn - torch.max(attn, dim=-1, keepdim=True)[0]
160
+ attn = attn.softmax(dim=-1)
161
+ attn = self.attn_drop(attn)
162
+
163
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
164
+ x = self.proj(x)
165
+ x = self.proj_drop(x)
166
+ # x是经过self-attn后的feature,attn是注意力权重矩阵,描述输入序列中各个元素之间的相关性
167
+ return x, attn
168
+
169
+
170
+ class Block(nn.Module):
171
+
172
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
173
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
174
+ super().__init__()
175
+ self.norm1 = norm_layer(dim)
176
+ self.attn = Attention(
177
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
178
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
179
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() # drop_path=0
180
+ self.norm2 = norm_layer(dim)
181
+ mlp_hidden_dim = int(dim * mlp_ratio) # mlp_ratio=4
182
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) # drop=0.1
183
+
184
+ def forward(self, x, return_attention=False):
185
+ if return_attention:
186
+ _, attn = self.attn(self.norm1(x))
187
+ return attn
188
+ else:
189
+ y, _ = self.attn(self.norm1(x))
190
+ x = x + self.drop_path(y)
191
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
192
+ return x
193
+
194
+
195
+ class LabelSmoothingCrossEntropy(nn.Module):
196
+ """ NLL loss with label smoothing.
197
+ """
198
+ def __init__(self, smoothing=0.1):
199
+ super(LabelSmoothingCrossEntropy, self).__init__()
200
+ assert smoothing < 1.0
201
+ self.smoothing = smoothing
202
+ self.confidence = 1. - smoothing
203
+
204
+ def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
205
+ logprobs = torch.nn.functional.log_softmax(x, dim=-1)
206
+ nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
207
+ nll_loss = nll_loss.squeeze(1)
208
+ smooth_loss = -logprobs.mean(dim=-1)
209
+ loss = self.confidence * nll_loss + self.smoothing * smooth_loss
210
+ return loss
211
+
212
+
213
+ class BertEmbeddings(nn.Module):
214
+ """Construct the embeddings from word, position and token_type embeddings."""
215
+
216
+ def __init__(self, vocab_size, hidden_size, max_position_embeddings, dropout=0.1):
217
+ super().__init__()
218
+ self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
219
+ self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
220
+
221
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
222
+ # any TensorFlow checkpoint file
223
+ self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-6)
224
+ self.dropout = nn.Dropout(dropout)
225
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
226
+ self.register_buffer("position_ids", torch.arange(max_position_embeddings).expand((1, -1))) # (1, 257)
227
+ # 创建一个形状为(1, max_position_embeddings)的缓冲张量position_ids,其包含了从0到max_position_embeddings-1的整数。
228
+ # 这个缓冲张量将被用于获取position_embeddings的位置信息,以便在前向传播过程中使用
229
+
230
+ torch.nn.init.normal_(self.word_embeddings.weight, std=.02)
231
+ torch.nn.init.normal_(self.position_embeddings.weight, std=.02)
232
+
233
+ def forward(
234
+ self, input_ids, z_H, z_W
235
+ ):
236
+ input_shape = input_ids.size() # input_ids: (B, N)(32,1 + latent_h * latent_w)
237
+ seq_length = input_shape[1]
238
+ position_ids = self.position_ids[:, :seq_length]
239
+ inputs_embeds = self.word_embeddings(input_ids) # (B, seq_len, embed_dim)
240
+ position_embeddings = self.position_embeddings(position_ids) # (1, seq_len, embed_dim)
241
+
242
+ # if seq_length != (1 + z_H * z_W):
243
+ # Exclude the first embedding (position 0) before interpolation
244
+ position_embeddings_to_interp = position_embeddings[:, 1:, :] # (b, 256, embed_dim)
245
+ # Perform bilinear interpolation
246
+ interpolated_embeddings = F.interpolate(
247
+ position_embeddings_to_interp.permute(0, 2, 1).reshape(1, -1, 16, 16),
248
+ size=(z_H, z_W), mode="bicubic", align_corners=False
249
+ ).reshape(1, -1, z_H * z_W).permute(0, 2, 1) # (b, seq_len-1, embed_dim)
250
+ # Concatenate the first embedding back
251
+ position_embeddings = torch.cat([position_embeddings[:, :1, :], interpolated_embeddings], dim=1)
252
+
253
+ embeddings = inputs_embeds + position_embeddings
254
+
255
+ embeddings = self.LayerNorm(embeddings)
256
+ embeddings = self.dropout(embeddings)
257
+ return embeddings
258
+
259
+
260
+ class MlmLayer(nn.Module):
261
+
262
+ def __init__(self, feat_emb_dim, word_emb_dim, vocab_size):
263
+ super().__init__()
264
+ self.fc = nn.Linear(feat_emb_dim, word_emb_dim)
265
+ self.gelu = nn.GELU()
266
+ self.ln = nn.LayerNorm(word_emb_dim)
267
+ self.bias = nn.Parameter(torch.zeros(1, 1, vocab_size))
268
+
269
+ def forward(self, x, word_embeddings): # x: (b, seq_len, embed_dim)
270
+ mlm_hidden = self.fc(x)
271
+ mlm_hidden = self.gelu(mlm_hidden)
272
+ mlm_hidden = self.ln(mlm_hidden)
273
+ word_embeddings = word_embeddings.transpose(0, 1)
274
+ logits = torch.matmul(mlm_hidden, word_embeddings)
275
+ logits = logits + self.bias
276
+ return logits # (b, seq_len, vocab_size) 表示对于输入序列中的每个位置,模型预测它对应词汇表中每个单词的原始单词的未归一化概率
277
+
278
+
279
+ class MaskedGenerativeEncoderViT(nn.Module):
280
+ """ Masked Autoencoder with VisionTransformer backbone
281
+ """
282
+ def __init__(self, img_size=256, patch_size=16, in_chans=3, # need to change the default value of img_size
283
+ embed_dim=1024, depth=24, num_heads=16,
284
+ decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
285
+ mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False,
286
+ mask_ratio_min=0.5, mask_ratio_max=0.8, mask_ratio_mu=0.55, mask_ratio_std=0.25,
287
+ vqgan_ckpt_path='vqgan_jax_strongaug.ckpt'):
288
+ super().__init__()
289
+
290
+ # --------------------------------------------------------------------------
291
+ # VQGAN specifics
292
+ config = OmegaConf.load('/home/v-ruoyufeng/v-ruoyufeng/qyp/mage/config/vqgan.yaml').model
293
+ self.vqgan = VQModel(ddconfig=config.params.ddconfig,
294
+ n_embed=config.params.n_embed, # 1024
295
+ embed_dim=config.params.embed_dim, # 256
296
+ ckpt_path=vqgan_ckpt_path)
297
+ for param in self.vqgan.parameters():
298
+ param.requires_grad = False
299
+
300
+ self.codebook_size = config.params.n_embed # 1024
301
+ vocab_size = self.codebook_size + 1000 + 1 # 1024 codebook size, 1000 classes, 1 for mask token.
302
+ self.fake_class_label = self.codebook_size + 1100 - 1024 # 1100
303
+ self.mask_token_label = vocab_size - 1 # 2024
304
+ self.token_emb = BertEmbeddings(vocab_size=vocab_size, # 向量空间大小,1024个embedding + 1000 class + 1 mask token
305
+ hidden_size=embed_dim,
306
+ max_position_embeddings=256 +1,
307
+ # max_position_embeddings=256+1, # 256个patch + 1 class token
308
+ dropout=0.1)
309
+
310
+ # MAGE variant masking ratio
311
+ self.mask_ratio_min = mask_ratio_min
312
+ self.mask_ratio_max = mask_ratio_max
313
+ # self.mask_ratio_generator = stats.truncnorm((mask_ratio_min - mask_ratio_mu) / mask_ratio_std,
314
+ # (mask_ratio_max - mask_ratio_mu) / mask_ratio_std,
315
+ # loc=mask_ratio_mu, scale=mask_ratio_std)
316
+
317
+ # --------------------------------------------------------------------------
318
+ # MAGE encoder specifics
319
+ dropout_rate = 0.1
320
+ self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim) # 256, 16, 3, 1024, (B,N,C) n: 256/16*256/16=256, c=1024
321
+ num_patches = self.patch_embed.num_patches # 256
322
+
323
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
324
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding
325
+
326
+ self.blocks = nn.ModuleList([ # encoder
327
+ Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer,
328
+ drop=dropout_rate, attn_drop=dropout_rate)
329
+ for i in range(depth)]) # depth=12 for mage-vitb, embed_dim=768
330
+ self.norm = norm_layer(embed_dim) # layer norm
331
+ # --------------------------------------------------------------------------
332
+
333
+ # --------------------------------------------------------------------------
334
+ # MAGE decoder specifics
335
+ self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
336
+
337
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) # decoder_embed_dim=512
338
+ self.pad_with_cls_token = True
339
+
340
+ self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding
341
+ self.decoder_pos_embed_learned = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim)) # learnable pos embedding
342
+
343
+ self.decoder_blocks = nn.ModuleList([
344
+ Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer,
345
+ drop=dropout_rate, attn_drop=dropout_rate)
346
+ for i in range(decoder_depth)]) # decoder_depth=8 for mage-vitb
347
+
348
+ self.decoder_norm = norm_layer(decoder_embed_dim)
349
+ self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch
350
+ # --------------------------------------------------------------------------
351
+
352
+ # --------------------------------------------------------------------------
353
+ # MlmLayer
354
+ self.mlm_layer = MlmLayer(feat_emb_dim=decoder_embed_dim, word_emb_dim=embed_dim, vocab_size=vocab_size)
355
+
356
+ self.norm_pix_loss = norm_pix_loss
357
+
358
+ self.criterion = LabelSmoothingCrossEntropy(smoothing=0.1)
359
+ # --------------------------------------------------------------------------
360
+ self.entropy_bottleneck = FactorizedEntropyModel(1)
361
+
362
+ self.initialize_weights()
363
+
364
+ def initialize_weights(self):
365
+ # initialization
366
+ # initialize (and freeze) pos_embed by sin-cos embedding
367
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
368
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
369
+
370
+ decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
371
+ self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
372
+
373
+ # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
374
+ w = self.patch_embed.proj.weight.data
375
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
376
+
377
+ # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
378
+ torch.nn.init.normal_(self.cls_token, std=.02)
379
+ torch.nn.init.normal_(self.mask_token, std=.02)
380
+ torch.nn.init.normal_(self.decoder_pos_embed_learned, std=.02)
381
+
382
+ # initialize nn.Linear and nn.LayerNorm
383
+ self.apply(self._init_weights)
384
+
385
+ def _init_weights(self, m):
386
+ if isinstance(m, nn.Linear):
387
+ # we use xavier_uniform following official JAX ViT:
388
+ torch.nn.init.xavier_uniform_(m.weight)
389
+ if isinstance(m, nn.Linear) and m.bias is not None:
390
+ nn.init.constant_(m.bias, 0)
391
+ elif isinstance(m, nn.LayerNorm):
392
+ nn.init.constant_(m.bias, 0)
393
+ nn.init.constant_(m.weight, 1.0)
394
+
395
+ def random_sample_mask_rate(self):
396
+ # 生成一个 (0, 1] 范围内的随机数
397
+ random_sample = 1 - torch.rand(1)
398
+ # 映射到 mask_ratio_min 到 mask_ratio_max 的范围
399
+ mask_rate = self.mask_ratio_min + random_sample * (self.mask_ratio_max - self.mask_ratio_min)
400
+ return mask_rate.item() # 转换为Python的标量值
401
+
402
+ def get_cdf_token_mask(self, token_all_mask):
403
+ bsz, seq_len = token_all_mask.size()
404
+ # --- use Normal distribution.
405
+ dist_normal = torch.distributions.Normal(0, 2)
406
+ cdf_mask_token = dist_normal.cdf(torch.arange(1, seq_len + 1))
407
+ cdf_mask_token = (cdf_mask_token - .5) * 2
408
+ cdf_mask_token = repeat(cdf_mask_token, 'Lp -> b s Lp',
409
+ b=bsz, s=seq_len)
410
+
411
+ cdf_mask_token = F.pad(cdf_mask_token, (1, 0))
412
+ return cdf_mask_token
413
+
414
+ def get_dec_pos_embed(self, z_H, z_W):
415
+ # Exclude the first embedding (position 0) before interpolation
416
+ position_embeddings_to_interp = self.decoder_pos_embed_learned[:, 1:, :] # (b, 256, embed_dim)
417
+ # Perform bilinear interpolation
418
+ interpolated_embeddings = F.interpolate(
419
+ position_embeddings_to_interp.permute(0, 2, 1).reshape(1, -1, 16, 16),
420
+ size=(z_H, z_W), mode="bicubic", align_corners=False
421
+ ).reshape(1, -1, z_H * z_W).permute(0, 2, 1) # (b, seq_len -1 , embed_dim)
422
+ # Concatenate the first embedding back
423
+ position_embeddings = torch.cat([self.decoder_pos_embed_learned[:, :1, :], interpolated_embeddings], dim=1)
424
+
425
+ return position_embeddings
426
+
427
+ def pre_encoding(self, x, is_training=False, manual_mask_rate=None):
428
+ """
429
+ input: x: (B, 3, H, W)
430
+ """
431
+ # ============ 1. tokenization ============ #
432
+ with torch.no_grad():
433
+ z_q, _, token_tuple = self.vqgan.encode(x) # z_q: (B, 256, 16, 16), token_tuple: (B, 256, 16, 16)
434
+
435
+ _, _, z_H, z_W = z_q.size()
436
+ _, _, token_indices = token_tuple # token_indices: (B*H*W,)(8192)
437
+ token_indices = token_indices.reshape(z_q.size(0), -1) # token_indices: (B, H*W)
438
+ gt_indices = token_indices.clone().detach().long()
439
+
440
+ # ============ 2. masking process ============ #
441
+ bsz, seq_len = token_indices.size() # seq_len=h*w
442
+ mask_ratio_min = self.mask_ratio_min # 0.5
443
+
444
+ if is_training:
445
+ # mask_rate = self.mask_ratio_generator.rvs(1)[0]
446
+ mask_rate = self.random_sample_mask_rate()
447
+ num_dropped_tokens = int(np.ceil(seq_len * mask_ratio_min))
448
+ else:
449
+ num_dropped_tokens = 0
450
+ if manual_mask_rate is not None:
451
+ mask_rate = manual_mask_rate
452
+ else:
453
+ raise ValueError("mask_rate should be provided for inference!")
454
+
455
+ num_masked_tokens = int(np.ceil(seq_len * mask_rate))
456
+ mask_ratio = num_masked_tokens / seq_len # for calculate vbr lambda
457
+ # it is possible that two elements of the noise is the same, so do a while loop to avoid it
458
+ while True:
459
+ noise = torch.rand(bsz, seq_len, device=x.device) # noise in [0, 1]
460
+ sorted_noise, _ = torch.sort(noise, dim=1) # ascend: small is remove, large is keep
461
+ if num_dropped_tokens > 0:
462
+ cutoff_drop = sorted_noise[:, num_dropped_tokens-1:num_dropped_tokens]
463
+ else:
464
+ cutoff_drop = torch.zeros((bsz, 1), device=x.device)
465
+ cutoff_mask = sorted_noise[:, num_masked_tokens-1:num_masked_tokens]
466
+ token_drop_mask = (noise <= cutoff_drop).float() # 逻辑上标记那些token是被drop掉的
467
+ token_all_mask = (noise <= cutoff_mask).float() # 逻辑上标记那些token是被mask掉的
468
+ if token_drop_mask.sum() == bsz*num_dropped_tokens and token_all_mask.sum() == bsz*num_masked_tokens:
469
+ break
470
+ else:
471
+ print("Rerandom the noise!")
472
+
473
+ # 获取unmasked token及其位置
474
+ unmasked_pos = token_all_mask == 0 # 未被mask的位置
475
+ unmaksed_token_indices = token_indices[unmasked_pos].reshape(bsz, -1) # 未被mask的token
476
+
477
+ return gt_indices, token_indices, unmaksed_token_indices, token_all_mask, token_drop_mask, mask_ratio, z_H, z_W
478
+
479
+ def pre_decoding(self, gt_indices, unmaksed_token_indices, token_all_mask, token_drop_mask, z_H, z_W):
480
+ bsz, seq_len = gt_indices.size()
481
+ padded_token_indices = torch.full_like(gt_indices, fill_value=self.mask_token_label)
482
+ # 将未被mask的token填充回去
483
+ # 我们需要一个计数器来追踪每个batch中已经填充了多少个unmaksed_token_indices
484
+ unmasked_token_counter = [0 for _ in range(bsz)]
485
+
486
+ for b in range(bsz):
487
+ for idx in range(seq_len):
488
+ # 如果当前位置未被mask,则从unmaksed_token_indices填充;否则,保留mask_token_label
489
+ if (token_all_mask[b, idx] == 0): # 检查是否未被mask
490
+ # 替换相应的unmaksed token
491
+ padded_token_indices[b, idx] = unmaksed_token_indices[b, unmasked_token_counter[b]]
492
+ # 更新计数器
493
+ unmasked_token_counter[b] += 1
494
+
495
+ token_indices = padded_token_indices
496
+ # ============ 3. Adding class token ============ #
497
+ # concate class token, add [CLS] token to aggregate sequence-level representations
498
+ token_indices = torch.cat([torch.zeros(token_indices.size(0), 1).cuda(device=token_indices.device), token_indices], dim=1)
499
+ token_indices[:, 0] = self.fake_class_label # [B, 257]
500
+ # Masks (token_drop_mask and token_all_mask) are updated to account for the added class token,
501
+ # ensuring the first position is always kept by setting it to 0 (indicating "do not mask/drop")
502
+ # 添加0向量,和token_indices,表示[CLS] token不会被mask/drop
503
+ token_drop_mask = torch.cat([torch.zeros(token_indices.size(0), 1).cuda(), token_drop_mask], dim=1)
504
+ token_all_mask = torch.cat([torch.zeros(token_indices.size(0), 1).cuda(), token_all_mask], dim=1)
505
+ token_indices = token_indices.long()
506
+
507
+ # ============ 4. Embedding and Dropout ============ #
508
+ # bert embedding
509
+ input_embeddings = self.token_emb(token_indices, z_H, z_W) # get embeddings [B, 257, 768]
510
+ # print("Input embedding shape:", input_embeddings.shape)
511
+ bsz, seq_len, emb_dim = input_embeddings.shape
512
+
513
+ # dropping
514
+ token_keep_mask = 1 - token_drop_mask
515
+ input_embeddings_after_drop = input_embeddings[token_keep_mask.nonzero(as_tuple=True)].reshape(bsz, -1, emb_dim)
516
+ # print("Input embedding after drop shape:", input_embeddings_after_drop.shape)
517
+
518
+ # ============ 5. Transformer encoding ============ #
519
+ x = input_embeddings_after_drop # (B, seq_len_after_drop, embed_dim) # 32, 129, 768
520
+ for blk in self.blocks:
521
+ x = blk(x) # each block has a multi-head self-attention and a mlp
522
+ x = self.norm(x)
523
+ # print("Encoder representation shape:", x.shape)
524
+
525
+ return x, token_indices, token_all_mask, token_drop_mask
526
+
527
+ def forward_decoding(self, x, token_drop_mask, token_all_mask, z_H, z_W):
528
+ """
529
+ x: output x of forward_encoder()
530
+ token_drop_mask: positions for dropped tokens
531
+ token_all_mask: positions for masked tokens
532
+ """
533
+ # ============ 1. Prepare Embedding and padding tokens ============ #
534
+ # embed tokens
535
+ x = self.decoder_embed(x) # input_embedding_after_padding
536
+
537
+ # append mask tokens to sequence
538
+ # replicates the [CLS] token embedding across the sequence length where masking is to be applied
539
+ if self.pad_with_cls_token: # True
540
+ mask_tokens = x[:, 0:1].repeat(1, token_all_mask.shape[1], 1)
541
+ else:
542
+ mask_tokens = self.mask_token.repeat(token_all_mask.shape[0], token_all_mask.shape[1], 1)
543
+
544
+ # ============ 2. Prepare positional embedding ============ #
545
+ # put undropped tokens into original sequence
546
+ x_after_pad = mask_tokens.clone() # 未被drop的tokens被填充回去
547
+ x_after_pad[(1 - token_drop_mask).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
548
+ # set undropped but masked positions with mask
549
+ x_after_pad = torch.where(token_all_mask.unsqueeze(-1).bool(), mask_tokens, x_after_pad) # 被drop的也padding
550
+
551
+ # add pos embed
552
+ dec_pos_embed_learned = self.get_dec_pos_embed(z_H, z_W)
553
+ x = x_after_pad + dec_pos_embed_learned # add learnable pos embedding
554
+ # x = x_after_pad + self.decoder_pos_embed_learned # add learnable pos embedding
555
+
556
+ # apply Transformer blocks
557
+ for blk in self.decoder_blocks:
558
+ x = blk(x)
559
+
560
+ x = self.decoder_norm(x)
561
+
562
+ word_embeddings = self.token_emb.word_embeddings.weight.data.detach()
563
+ logits = self.mlm_layer(x, word_embeddings) # produce predictions for masked tokens
564
+ # print("Logits shape:", x.shape)
565
+
566
+ return logits
567
+
568
+ def forward_loss(self, gt_indices, logits, mask):
569
+ bsz, seq_len = gt_indices.size()
570
+ # logits and mask are with seq_len+1 but gt_indices is with seq_len
571
+ loss = self.criterion(logits[:, 1:, :self.codebook_size].reshape(bsz*seq_len, -1), gt_indices.reshape(bsz*seq_len))
572
+ loss = loss.reshape(bsz, seq_len)
573
+ loss = (loss * mask[:, 1:]).sum() / mask[:, 1:].sum() # mean loss on removed patches
574
+ return loss
575
+
576
+ def cal_lmbda(self, mask_ratio, A=5e-1, B=8):
577
+ lmbda = A * torch.exp(B * (1 - mask_ratio))
578
+ return lmbda
579
+
580
+ def cal_loss(self, logits, gt_indices, mask, mask_ratio):
581
+ mask_ratio = torch.tensor(mask_ratio)
582
+ ## cal cross entropy loss
583
+ task_loss = self.forward_loss(gt_indices, logits, mask)
584
+ lmbda = self.cal_lmbda(mask_ratio)
585
+ ## cal total loss for codec optimization
586
+ return task_loss, lmbda
587
+
588
+ def forward(self, imgs, is_training=False, manual_mask_rate=None):
589
+ ## ---------- encoding process ---------- ##
590
+ gt_indices, token_indices, latent, token_all_mask, token_drop_mask, mask_ratio, z_H, z_W = self.pre_encoding(imgs, is_training, manual_mask_rate)
591
+ latent = latent.unsqueeze(1)
592
+
593
+ latent_hat, latent_likelihoods = self.entropy_bottleneck(latent)
594
+ # 判断latent_hat和latent是否相等
595
+ # print((latent_hat == latent).all())
596
+ cdf_mask_token = self.get_cdf_token_mask(token_all_mask).cpu()
597
+ sym = (token_all_mask.short() + 1).cpu()
598
+ bs_mask_token = torchac.encode_float_cdf(cdf_mask_token, sym, check_input_bounds=True)
599
+ mask_vis = rearrange(token_all_mask, 'b (h w) -> b h w', h=z_H, w=z_W).unsqueeze(1)
600
+
601
+ ## ---------- decoding process ---------- ##
602
+ decoded_sym = torchac.decode_float_cdf(cdf_mask_token, bs_mask_token)
603
+ decoded_mask = (decoded_sym - 1).to(device=imgs.device)
604
+ latent_hat = latent_hat.squeeze(1)
605
+ x, token_indices, token_all_mask, token_drop_mask = self.pre_decoding(gt_indices, latent_hat, decoded_mask, token_drop_mask, z_H, z_W)
606
+ logits = self.forward_decoding(x, token_drop_mask, token_all_mask, z_H, z_W)
607
+ ## calculate loss
608
+ task_loss, lmbda = self.cal_loss(logits, gt_indices, token_all_mask, mask_ratio)
609
+ return_dict = {
610
+ 'logits': logits,
611
+ 'likelihoods': latent_likelihoods,
612
+ 'task_loss': task_loss,
613
+ 'token_indices': token_indices,
614
+ 'token_all_mask': token_all_mask,
615
+ 'bs_mask_token': bs_mask_token,
616
+ 'mask_ratio': mask_ratio,
617
+ 'lambda': lmbda,
618
+ 'mask_vis': 1 - mask_vis,
619
+ 'z_H': z_H,
620
+ 'z_W': z_W,
621
+ }
622
+ return return_dict
623
+
624
+ # def update(self, scale_table=None, force=False):
625
+ # if scale_table is None:
626
+ # scale_table = get_scale_table()
627
+ # updated = self.gaussian_conditional.update_scale_table(scale_table, force=force)
628
+ # updated |= super().update(force=force)
629
+ # return updated
630
+
631
+ def gen_img(self, logits, token_all_mask, token_indices, z_H, z_W, num_iter=12, choice_temperature=4.5):
632
+ """
633
+ generated image at inference
634
+ seed: random seed
635
+ logits: predicted logits by model decoder
636
+ token_all_mask: mask token indices
637
+ token_indices: token indices of the input image after the vq tokenizer
638
+ num_iter: number of iterations for sampling
639
+ choice_temperature: temperature for sampling
640
+ """
641
+ # torch.manual_seed(seed)
642
+ # np.random.seed(seed)
643
+ bsz = logits.size(0)
644
+ codebook_emb_dim = 256
645
+ codebook_size = 1024
646
+ mask_token_id = self.mask_token_label
647
+ _CONFIDENCE_OF_KNOWN_TOKENS = +np.inf
648
+ unknown_number_in_the_beginning = torch.sum(token_all_mask, dim=-1, keepdims=True).float()
649
+ for step in range(num_iter):
650
+ if step == 0:
651
+ cur_ids = token_indices.clone().long() # token_indices represent the current state of the sequence(unmasked tokens)
652
+ cur_ids = cur_ids[:, 1:] # 从第二列开始到最后一列
653
+ logits = logits[:, 1:, :codebook_size]
654
+ # the author said a little tricky here, "For iter=1, they use argmax and temp=0.0.
655
+ # For iter=6, we use categorical sampling and temp=4.5."
656
+ sample_dist = torch.distributions.categorical.Categorical(logits=logits)
657
+ sampled_ids = sample_dist.sample() # sampled_ids = torch.argmax(logits, dim=-1)
658
+ # get ids for next step
659
+ # unknown_map: type bool,shape is same as cur_ids and sampled_ids, indicate where the value will be replace
660
+ # 根据unknown_map的值,在相应位置上选择sampled_ids或cur_ids中的值,并将其存储到sampled_ids张量中。
661
+ # 换句话说,它将模型预测的类别(在未知位置)与之前已知的类别(在已知位置)合并到一个张量中。
662
+ unknown_map = (cur_ids == mask_token_id)
663
+ sampled_ids = torch.where(unknown_map, sampled_ids, cur_ids)
664
+ # Defines the mask ratio for the next round. The number to mask out is
665
+ # determined by mask_ratio * unknown_number_in_the_beginning.
666
+ ratio = 1. * (step + 1) / num_iter
667
+ mask_ratio = np.cos(math.pi / 2. * ratio) # ratio = cosine(Π/2 * i/num_iter)
668
+
669
+ # sample ids according to prediction confidence
670
+ probs = torch.nn.functional.softmax(logits, dim=-1)
671
+ selected_probs = torch.squeeze(
672
+ torch.gather(probs, dim=-1, index=torch.unsqueeze(sampled_ids, -1)), -1)
673
+
674
+ selected_probs = torch.where(unknown_map, selected_probs.double(), _CONFIDENCE_OF_KNOWN_TOKENS).float()
675
+ unknown_number_in_the_beginning = unknown_number_in_the_beginning.clone().detach().cuda()
676
+ mask_ratio = torch.tensor(mask_ratio).cuda()
677
+ # mask_len = torch.tensor([np.floor(unknown_number_in_the_beginning.numpy() * mask_ratio.numpy())]).cuda()
678
+ mask_len = torch.floor(unknown_number_in_the_beginning * mask_ratio).long() # 每个iter剩余的mask token数
679
+ # Keeps at least one of prediction in this round and also masks out at least
680
+ # one and for the next iteration
681
+ mask_len = torch.maximum(torch.Tensor([1]).cuda(),
682
+ torch.minimum(torch.sum(unknown_map, dim=-1, keepdims=True) - 1, mask_len))
683
+
684
+ # Sample masking tokens for next iteration
685
+ masking = mask_by_random_topk(mask_len[0], selected_probs, choice_temperature * (1 - ratio))
686
+ # Masks tokens with lower confidence.
687
+ token_indices = torch.where(masking, mask_token_id, sampled_ids)
688
+ else:
689
+ cur_ids = token_indices.clone().long() # .long(): to int64
690
+ token_indices = torch.cat(
691
+ [torch.zeros(token_indices.size(0), 1).cuda(device=token_indices.device), token_indices], dim=1)
692
+ token_indices[:, 0] = self.fake_class_label
693
+ token_indices = token_indices.long()
694
+ token_all_mask = token_indices == mask_token_id
695
+
696
+ token_drop_mask = torch.zeros_like(token_indices)
697
+
698
+ # token embedding
699
+ input_embeddings = self.token_emb(token_indices, z_H, z_W) # get input embeddings
700
+
701
+ # encoder
702
+ x = input_embeddings
703
+ for blk in self.blocks:
704
+ x = blk(x)
705
+ x = self.norm(x)
706
+
707
+ # decoder
708
+ logits = self.forward_decoding(x, token_drop_mask, token_all_mask, z_H, z_W)
709
+ logits = logits[:, 1:, :codebook_size] # remove the cls token and dims > codebook_size
710
+
711
+ # get token prediction
712
+ # the author said a little tricky here, "For iter=1, they use argmax and temp=0.0.
713
+ # For iter=6, we use categorical sampling and temp=4.5."
714
+ sample_dist = torch.distributions.categorical.Categorical(logits=logits)
715
+ sampled_ids = sample_dist.sample() # sampled_ids = torch.argmax(logits, dim=-1)
716
+
717
+ # get ids for next step
718
+ # unknown_map: type bool,shape is same as cur_ids and sampled_ids, indicate where the value will be replace
719
+ # 根据unknown_map的值,在相应位置上选择sampled_ids或cur_ids中的值,并将其存储到sampled_ids张量中。
720
+ # 换句话说,它将模型预测的类别(在未知位置)与之前已知的类别(在已知位置)合并到一个张量中。
721
+ unknown_map = (cur_ids == mask_token_id)
722
+ sampled_ids = torch.where(unknown_map, sampled_ids, cur_ids)
723
+ # Defines the mask ratio for the next round. The number to mask out is
724
+ # determined by mask_ratio * unknown_number_in_the_beginning.
725
+ ratio = 1. * (step + 1) / num_iter
726
+
727
+ mask_ratio = np.cos(math.pi / 2. * ratio) # ratio = cosine(Π/2 * i/num_iter)
728
+
729
+ # sample ids according to prediction confidence
730
+ probs = torch.nn.functional.softmax(logits, dim=-1)
731
+ selected_probs = torch.squeeze(
732
+ torch.gather(probs, dim=-1, index=torch.unsqueeze(sampled_ids, -1)), -1)
733
+
734
+ selected_probs = torch.where(unknown_map, selected_probs.double(), _CONFIDENCE_OF_KNOWN_TOKENS).float()
735
+ unknown_number_in_the_beginning = unknown_number_in_the_beginning.clone().detach().cuda()
736
+ mask_ratio = torch.tensor(mask_ratio).cuda()
737
+ mask_len = torch.floor(unknown_number_in_the_beginning * mask_ratio).long() # 每个iter剩余的mask token数
738
+ # Keeps at least one of prediction in this round and also masks out at least
739
+ # one and for the next iteration
740
+ mask_len = torch.maximum(torch.Tensor([1]).cuda(),
741
+ torch.minimum(torch.sum(unknown_map, dim=-1, keepdims=True) - 1, mask_len))
742
+
743
+ # Sample masking tokens for next iteration
744
+ masking = mask_by_random_topk(mask_len[0], selected_probs, choice_temperature * (1 - ratio))
745
+ # Masks tokens with lower confidence.
746
+ token_indices = torch.where(masking, mask_token_id, sampled_ids)
747
+
748
+ # vqgan visualization
749
+ z_q = self.vqgan.quantize.get_codebook_entry(sampled_ids, shape=(bsz, z_H, z_W, codebook_emb_dim))
750
+ gen_images = self.vqgan.decode(z_q)
751
+ return gen_images
752
+
753
+
754
def mage_vit_base_patch16(**kwargs):
    """MAGE ViT-Base/16 factory: 12-layer, 768-dim encoder with 12 heads and
    an 8-layer, 768-dim decoder with 16 heads. Extra kwargs are forwarded to
    MaskedGenerativeEncoderViT (e.g. vqgan_ckpt_path)."""
    model = MaskedGenerativeEncoderViT(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
        decoder_embed_dim=768, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model
760
+
761
+
762
def mage_vit_large_patch16(**kwargs):
    """MAGE ViT-Large/16 factory: 24-layer, 1024-dim encoder with 16 heads and
    an 8-layer, 1024-dim decoder with 16 heads. Extra kwargs are forwarded to
    MaskedGenerativeEncoderViT (e.g. vqgan_ckpt_path)."""
    model = MaskedGenerativeEncoderViT(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        decoder_embed_dim=1024, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model
deploited/models_mage_codec_rm_pos.py ADDED
@@ -0,0 +1,759 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from timm.models.vision_transformer import PatchEmbed, DropPath, Mlp
7
+
8
+ from util.pos_embed import get_2d_sincos_pos_embed
9
+
10
+ from taming.models.vqgan import VQModel
11
+ from omegaconf import OmegaConf
12
+ import numpy as np
13
+ import scipy.stats as stats
14
+ from compressai.entropy_models import EntropyBottleneck
15
+ from compressai.layers import conv3x3, subpel_conv3x3
16
+ import math
17
+ from torch import Tensor
18
+ from einops import rearrange, repeat
19
+ import torch.nn.functional as F
20
+ import torchac
21
+ from typing import Any, Callable, List, Optional, Tuple, Union
22
+
23
SCALES_MIN = 0.11
SCALES_MAX = 256
SCALES_LEVELS = 64


def get_scale_table(min=SCALES_MIN, max=SCALES_MAX, levels=SCALES_LEVELS):
    """Return `levels` scale values log-uniformly spaced in [min, max].

    Note: the parameter names shadow the builtins `min`/`max`; they are kept
    unchanged for keyword-argument compatibility with existing callers.
    """
    log_lo = math.log(min)
    log_hi = math.log(max)
    return torch.linspace(log_lo, log_hi, levels).exp()
28
+
29
def ste_round(x: Tensor) -> Tensor:
    """Round `x` with a straight-through gradient estimator.

    Forward pass returns torch.round(x); backward pass treats the rounding
    as identity, so gradients flow through unchanged.
    """
    rounded = torch.round(x)
    return x + (rounded - x).detach()
31
+
32
def conv(in_channels, out_channels, kernel_size=5, stride=2):
    """Strided 2D convolution with symmetric padding of kernel_size // 2."""
    return nn.Conv2d(in_channels, out_channels, kernel_size,
                     stride=stride, padding=kernel_size // 2)
40
+
41
def mask_by_random_topk(mask_len, probs, temperature=1.0):
    """Choose which positions to (re-)mask: the `mask_len` entries per row
    with the lowest Gumbel-perturbed log-confidence.

    Args:
        mask_len: scalar-like tensor, number of positions to mask per row.
        probs: (B, N) confidence of the sampled token at each position.
        temperature: scale of the Gumbel noise; 0 gives a deterministic
            bottom-k by log-probability.

    Returns:
        Bool tensor of shape (B, N); True marks positions to mask.
    """
    mask_len = mask_len.squeeze()
    # Gumbel noise adds sampling randomness. It is created on probs' device
    # (the original hard-coded .cuda(), which broke CPU execution).
    noise = torch.from_numpy(np.random.gumbel(size=probs.shape)).to(
        device=probs.device, dtype=probs.dtype)
    confidence = torch.log(probs) + temperature * noise
    sorted_confidence, _ = torch.sort(confidence, dim=-1)
    # Cut-off threshold: the mask_len-th smallest confidence in each row.
    cut_off = sorted_confidence[:, mask_len.long() - 1:mask_len.long()]
    # Mask every position at or below the threshold.
    masking = (confidence <= cut_off)
    return masking
51
+
52
def adjust_mask_and_drop_embeddings(token_keep_mask):
    """Trim a boolean keep-mask so its True count becomes a perfect square.

    Randomly flips just enough True entries to False so that the number of
    kept tokens equals floor(sqrt(count))**2 (so the kept tokens can be laid
    out on a square grid). The mask is modified in place and also returned.

    Args:
        token_keep_mask: 2D BoolTensor; True marks tokens to keep.

    Returns:
        The same tensor, with `count - floor(sqrt(count))**2` randomly chosen
        True entries flipped to False (no change when count is already square).
    """
    rows, cols = token_keep_mask.nonzero(as_tuple=True)
    kept = rows.numel()
    # Largest perfect square not exceeding the current keep-count.
    target = math.floor(math.sqrt(kept)) ** 2
    surplus = kept - target
    if surplus > 0:
        # Vectorized: pick `surplus` random kept positions and clear them in
        # one indexed assignment (the original looped over indices in Python).
        drop = torch.randperm(kept)[:surplus]
        token_keep_mask[rows[drop], cols[drop]] = False
    return token_keep_mask
81
+
82
+
83
class FactorizedEntropyModel(EntropyBottleneck):
    """compressai EntropyBottleneck adapted to [B, C, seq_len] sequences.

    The stock EntropyBottleneck targets [B, C, H, W] feature maps; this
    subclass adapts the forward/compress/decompress plumbing so it can be
    applied to a single-channel token sequence.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x: Tensor, training: Optional[bool] = None) -> Tuple[Tensor, Tensor]:
        """Quantize `x` and return (quantized outputs, per-element likelihoods)."""
        if training is None:
            training = self.training

        # Input is already [B, C, seq_len]; no transpose needed.
        shape = x.size()

        # Add noise or quantize
        means = self._get_medians()
        # NOTE(review): quantization mode is always "dequantize" (the
        # training/"noise" branch below is disabled) and the medians are
        # truncated with .long() — confirm both are intentional; the
        # commented-out variant switched on `training`.
        # outputs = self.quantize(
        #     x, "noise" if training else "dequantize", means.long()
        # )
        outputs = self.quantize(
            x, "dequantize", means.long()
        )

        if not torch.jit.is_scripting():
            likelihood = self._likelihood(outputs)
            if self.use_likelihood_bound:
                likelihood = self.likelihood_lower_bound(likelihood)
        else:
            raise NotImplementedError("TorchScript is not yet supported")

        return outputs, likelihood

    def compress(self, x):
        """Entropy-encode `x` into byte strings via the base-class coder."""
        # Build indexes for single-channel sequence data.
        indexes = self._build_indexes(x.size())
        # Medians are already single-channel here.
        medians = self._get_medians().detach()
        # Broadcast the medians to match x's shape.
        medians = medians.expand_as(x)
        # Delegate the actual range coding to the base class.
        return super().compress(x, indexes, medians)

    def decompress(self, strings, size):
        """Decode byte strings back to a [len(strings), 1, *size] tensor.

        `size` is expected to be the sequence-length shape (e.g. (seq_len,)).
        """
        # Expected output size includes the single channel.
        output_size = (len(strings), 1, *size)
        # Build indexes on the same device as the quantized CDF table.
        indexes = self._build_indexes(output_size).to(self._quantized_cdf.device)
        # Reshape/broadcast the medians to the expected output shape.
        medians = self._extend_ndims(self._get_medians().detach(), len(size))
        medians = medians.expand(len(strings), 1, *([-1] * len(size)))
        # Delegate the actual decoding to the base class.
        return super().decompress(strings, indexes, medians.dtype, medians)

    def _preprocess(self, x):
        # [B, C, H, W] -> [B, H, W, C] (channel-last), contiguous copy.
        x = x.permute(0, 2, 3, 1).contiguous()
        return x
136
+
137
+
138
class Attention(nn.Module):
    """Multi-head self-attention that returns both the attended features and
    the attention-weight map."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE: scale can be overridden (qk_scale) for compatibility with
        # weights trained under a different scaling convention.
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        head_dim = C // self.num_heads
        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        qkv = (self.qkv(x)
               .reshape(B, N, 3, self.num_heads, head_dim)
               .permute(2, 0, 3, 1, 4))
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Attention scores are computed in fp32 even under autocast.
        with torch.cuda.amp.autocast(enabled=False):
            scores = (q.float() @ k.float().transpose(-2, -1)) * self.scale

        # Subtract the row-wise max before softmax for numerical stability.
        scores = scores - torch.max(scores, dim=-1, keepdim=True)[0]
        attn = self.attn_drop(scores.softmax(dim=-1))

        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        out = self.proj_drop(self.proj(out))
        # Return the features and the (post-dropout) attention weights.
        return out, attn
168
+
169
+
170
class Block(nn.Module):
    """Pre-norm Transformer block: self-attention and MLP sub-layers, each
    wrapped in a residual connection with optional stochastic depth."""

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop)
        # Stochastic depth; identity when drop_path == 0 (the default here).
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio),
                       act_layer=act_layer, drop=drop)

    def forward(self, x, return_attention=False):
        # Optionally expose the attention map instead of the block output.
        attn_out, attn_map = self.attn(self.norm1(x))
        if return_attention:
            return attn_map
        x = x + self.drop_path(attn_out)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
193
+
194
+
195
class LabelSmoothingCrossEntropy(nn.Module):
    """NLL loss with label smoothing.

    loss = (1 - smoothing) * NLL(target) + smoothing * mean(-log_probs),
    returned per-sample (no reduction).
    """

    def __init__(self, smoothing=0.1):
        super(LabelSmoothingCrossEntropy, self).__init__()
        assert smoothing < 1.0
        self.smoothing = smoothing
        self.confidence = 1. - smoothing

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        log_probs = torch.nn.functional.log_softmax(x, dim=-1)
        # Negative log-likelihood of the target class, one value per sample.
        nll = -log_probs.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
        # Uniform-smoothing term: mean negative log-prob over all classes.
        uniform = -log_probs.mean(dim=-1)
        return self.confidence * nll + self.smoothing * uniform
211
+
212
+
213
class BertEmbeddings(nn.Module):
    """Word embeddings followed by LayerNorm and dropout.

    Positional embeddings are deliberately removed in this variant; the
    `max_position_embeddings` parameter is kept only for signature
    compatibility with existing callers.
    """

    def __init__(self, vocab_size, hidden_size, max_position_embeddings, dropout=0.1):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        # Attribute name "LayerNorm" (not snake_case) is kept so TensorFlow
        # checkpoint variable names still map onto this module.
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-6)
        self.dropout = nn.Dropout(dropout)
        torch.nn.init.normal_(self.word_embeddings.weight, std=.02)

    def forward(self, input_ids):
        # input_ids: (B, N) token ids -> (B, N, hidden_size) embeddings.
        embeddings = self.word_embeddings(input_ids)
        embeddings = self.LayerNorm(embeddings)
        return self.dropout(embeddings)
251
+
252
+
253
class MlmLayer(nn.Module):
    """MLM prediction head: project decoder features into word-embedding
    space and score them against the word-embedding matrix (tied weights),
    plus a learned per-word bias."""

    def __init__(self, feat_emb_dim, word_emb_dim, vocab_size):
        super().__init__()
        self.fc = nn.Linear(feat_emb_dim, word_emb_dim)
        self.gelu = nn.GELU()
        self.ln = nn.LayerNorm(word_emb_dim)
        self.bias = nn.Parameter(torch.zeros(1, 1, vocab_size))

    def forward(self, x, word_embeddings):
        # x: (B, seq_len, feat_emb_dim); word_embeddings: (vocab, word_emb_dim).
        hidden = self.ln(self.gelu(self.fc(x)))
        # Similarity of each position to every word embedding.
        logits = torch.matmul(hidden, word_embeddings.transpose(0, 1))
        # (B, seq_len, vocab_size) unnormalized scores over the vocabulary.
        return logits + self.bias
270
+
271
+
272
+ class MaskedGenerativeEncoderViT(nn.Module):
273
+ """ Masked Autoencoder with VisionTransformer backbone
274
+ """
275
    def __init__(self, img_size=256, patch_size=16, in_chans=3,  # need to change the default value of img_size
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False,
                 mask_ratio_min=0.5, mask_ratio_max=0.8, mask_ratio_mu=0.55, mask_ratio_std=0.25,
                 vqgan_ckpt_path='vqgan_jax_strongaug.ckpt'):
        """Build the MAGE-style masked token codec (positional embeddings removed).

        Components: frozen VQGAN tokenizer, BERT-style token embedding, ViT
        encoder over kept tokens, ViT decoder predicting masked tokens, an
        MLM head tied to the word embeddings, and a factorized entropy model
        for rate estimation. mask_ratio_mu/std are unused here (the
        truncated-normal mask-rate sampler is disabled below) but kept for
        signature compatibility.
        """
        super().__init__()

        # --------------------------------------------------------------------------
        # VQGAN specifics (frozen image tokenizer)
        config = OmegaConf.load('config/vqgan.yaml').model
        self.vqgan = VQModel(ddconfig=config.params.ddconfig,
                             n_embed=config.params.n_embed,  # 1024
                             embed_dim=config.params.embed_dim,  # 256
                             ckpt_path=vqgan_ckpt_path)
        for param in self.vqgan.parameters():
            param.requires_grad = False
        self.codebook_size = config.params.n_embed  # 1024
        vocab_size = self.codebook_size + 1000 + 1  # 1024 codebook size, 1000 classes, 1 for mask token.
        self.fake_class_label = self.codebook_size + 1100 - 1024  # 1100
        self.mask_token_label = vocab_size - 1  # 2024
        self.token_emb = BertEmbeddings(vocab_size=vocab_size,  # 1024 codebook entries + 1000 classes + 1 mask token
                                        hidden_size=embed_dim,
                                        max_position_embeddings=256 + 1,  # 256 patches + 1 class token (unused: no pos. emb.)
                                        dropout=0.1)

        # MAGE variant masking ratio (uniform sampling; the original MAGE
        # truncated-normal generator is disabled in this variant).
        self.mask_ratio_min = mask_ratio_min
        self.mask_ratio_max = mask_ratio_max
        # self.mask_ratio_generator = stats.truncnorm((mask_ratio_min - mask_ratio_mu) / mask_ratio_std,
        #                                             (mask_ratio_max - mask_ratio_mu) / mask_ratio_std,
        #                                             loc=mask_ratio_mu, scale=mask_ratio_std)

        # --------------------------------------------------------------------------
        # MAGE encoder specifics
        dropout_rate = 0.1
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)  # (B, N, C); N = (256/16)**2 = 256
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding (removed)

        self.blocks = nn.ModuleList([  # encoder
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer,
                  drop=dropout_rate, attn_drop=dropout_rate)
            for i in range(depth)])  # depth=12 for mage-vitb, embed_dim=768
        self.norm = norm_layer(embed_dim)  # layer norm
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # MAGE decoder specifics
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))  # decoder_embed_dim=512
        self.pad_with_cls_token = True

        # self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)  # fixed sin-cos embedding
        # self.decoder_pos_embed_learned = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim))  # learnable pos embedding

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer,
                  drop=dropout_rate, attn_drop=dropout_rate)
            for i in range(decoder_depth)])  # decoder_depth=8 for mage-vitb

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True)  # decoder to patch
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # MlmLayer: scores decoder features against the word-embedding matrix
        self.mlm_layer = MlmLayer(feat_emb_dim=decoder_embed_dim, word_emb_dim=embed_dim, vocab_size=vocab_size)

        self.norm_pix_loss = norm_pix_loss

        self.criterion = LabelSmoothingCrossEntropy(smoothing=0.1)
        # --------------------------------------------------------------------------
        # Single-channel factorized entropy model for the kept-token latents.
        self.entropy_bottleneck = FactorizedEntropyModel(1)

        self.initialize_weights()
355
+
356
    def initialize_weights(self):
        """Initialize learnable parameters (positional embeddings removed).

        The sin-cos positional-embedding initialization of the original MAGE
        is deliberately left disabled in this "rm_pos" variant.
        """
        # pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        # self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        # self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)
        # torch.nn.init.normal_(self.decoder_pos_embed_learned, std=.02)

        # initialize nn.Linear and nn.LayerNorm recursively
        self.apply(self._init_weights)
376
+
377
+ def _init_weights(self, m):
378
+ if isinstance(m, nn.Linear):
379
+ # we use xavier_uniform following official JAX ViT:
380
+ torch.nn.init.xavier_uniform_(m.weight)
381
+ if isinstance(m, nn.Linear) and m.bias is not None:
382
+ nn.init.constant_(m.bias, 0)
383
+ elif isinstance(m, nn.LayerNorm):
384
+ nn.init.constant_(m.bias, 0)
385
+ nn.init.constant_(m.weight, 1.0)
386
+
387
+ def random_sample_mask_rate(self):
388
+ # 生成一个 (0, 1] 范围内的随机数
389
+ random_sample = 1 - torch.rand(1)
390
+ # 映射到 mask_ratio_min 到 mask_ratio_max 的范围
391
+ mask_rate = self.mask_ratio_min + random_sample * (self.mask_ratio_max - self.mask_ratio_min)
392
+ return mask_rate.item() # 转换为Python的标量值
393
+
394
+ def get_cdf_token_mask(self, token_all_mask):
395
+ bsz, seq_len = token_all_mask.size()
396
+ # --- use Normal distribution.
397
+ dist_normal = torch.distributions.Normal(0, 2)
398
+ cdf_mask_token = dist_normal.cdf(torch.arange(1, seq_len + 1))
399
+ cdf_mask_token = (cdf_mask_token - .5) * 2
400
+ cdf_mask_token = repeat(cdf_mask_token, 'Lp -> b s Lp',
401
+ b=bsz, s=seq_len)
402
+
403
+ cdf_mask_token = F.pad(cdf_mask_token, (1, 0))
404
+ return cdf_mask_token
405
+
406
+ # def get_cdf_token_mask(self, token_all_mask):
407
+ # bsz, seq_len = token_all_mask.size()
408
+ # # 直接生成一个0到1之间的线性空间
409
+ # linear_space = torch.linspace(0, 1, steps=seq_len+1)
410
+ # # 无需映射到-1到1
411
+ # cdf_mask_token = linear_space
412
+ # # 调整形状以匹配token_all_mask,并扩展到每个batch
413
+ # cdf_mask_token = repeat(cdf_mask_token, 'Lp -> b s Lp',
414
+ # b=bsz, s=seq_len)
415
+ # # cdf_mask_token = cdf_mask_token.unsqueeze(0).unsqueeze(-1).repeat(bsz, 1, seq_len)
416
+ # # 添加填充以匹配原始代码的操作
417
+ # cdf_mask_token = F.pad(cdf_mask_token, (1, 0))
418
+ # return cdf_mask_token
419
+
420
    def pre_encoding(self, x, is_training=False, manual_mask_rate=None):
        """Tokenize images with the frozen VQGAN and draw the token mask.

        Args:
            x: (B, 3, H, W) input images.
            is_training: if True, sample the mask rate uniformly and also
                drop mask_ratio_min of the tokens; otherwise nothing is
                dropped and manual_mask_rate must be provided.
            manual_mask_rate: fixed mask rate for inference.

        Returns:
            (gt_indices, token_indices, unmasked_token_indices,
             token_all_mask, token_drop_mask, mask_ratio, z_H, z_W)

        Raises:
            ValueError: if not training and manual_mask_rate is None.
        """
        # ============ 1. tokenization ============ #
        with torch.no_grad():
            z_q, _, token_tuple = self.vqgan.encode(x)  # z_q: (B, C, z_H, z_W)

        _, _, z_H, z_W = z_q.size()
        _, _, token_indices = token_tuple  # flat codebook indices, (B*z_H*z_W,)
        token_indices = token_indices.reshape(z_q.size(0), -1)  # (B, z_H*z_W)
        gt_indices = token_indices.clone().detach().long()

        # ============ 2. masking process ============ #
        bsz, seq_len = token_indices.size()  # seq_len = z_H * z_W
        mask_ratio_min = self.mask_ratio_min  # 0.5

        if is_training:
            # mask_rate = self.mask_ratio_generator.rvs(1)[0]
            mask_rate = self.random_sample_mask_rate()
            num_dropped_tokens = int(np.ceil(seq_len * mask_ratio_min))
        else:
            num_dropped_tokens = 0
            if manual_mask_rate is not None:
                mask_rate = manual_mask_rate
            else:
                raise ValueError("mask_rate should be provided for inference!")

        num_masked_tokens = int(np.ceil(seq_len * mask_rate))
        mask_ratio = num_masked_tokens / seq_len  # used to compute the vbr lambda
        # Per-position random noise, ranked, decides which tokens are dropped
        # and which are masked. Ties in the noise would make the cutoff select
        # the wrong count, so retry until the counts are exact.
        while True:
            noise = torch.rand(bsz, seq_len, device=x.device)  # noise in [0, 1)
            sorted_noise, _ = torch.sort(noise, dim=1)  # ascend: small is remove, large is keep
            if num_dropped_tokens > 0:
                cutoff_drop = sorted_noise[:, num_dropped_tokens-1:num_dropped_tokens]
            else:
                cutoff_drop = torch.zeros((bsz, 1), device=x.device)
            cutoff_mask = sorted_noise[:, num_masked_tokens-1:num_masked_tokens]
            token_drop_mask = (noise <= cutoff_drop).float()  # 1 where the token is dropped
            token_all_mask = (noise <= cutoff_mask).float()   # 1 where the token is masked
            if token_drop_mask.sum() == bsz*num_dropped_tokens and token_all_mask.sum() == bsz*num_masked_tokens:
                break
            else:
                print("Rerandom the noise!")

        # Collect the unmasked tokens (same count per row) to be transmitted.
        unmasked_pos = token_all_mask == 0  # positions that were not masked
        unmaksed_token_indices = token_indices[unmasked_pos].reshape(bsz, -1)

        return gt_indices, token_indices, unmaksed_token_indices, token_all_mask, token_drop_mask, mask_ratio, z_H, z_W
471
+
472
+ def pre_decoding(self, gt_indices, unmaksed_token_indices, token_all_mask, token_drop_mask):
473
+ bsz, seq_len = gt_indices.size()
474
+ padded_token_indices = torch.full_like(gt_indices, fill_value=self.mask_token_label)
475
+ # 将未被mask的token填充回去
476
+ # 我们需要一个计数器来追踪每个batch中已经填充了多少个unmaksed_token_indices
477
+ unmasked_token_counter = [0 for _ in range(bsz)]
478
+
479
+ for b in range(bsz):
480
+ for idx in range(seq_len):
481
+ # 如果当前位置未被mask,则从unmaksed_token_indices填充;否则,保留mask_token_label
482
+ if (token_all_mask[b, idx] == 0): # 检查是否未被mask
483
+ # 替换相应的unmaksed token
484
+ padded_token_indices[b, idx] = unmaksed_token_indices[b, unmasked_token_counter[b]]
485
+ # 更新计数器
486
+ unmasked_token_counter[b] += 1
487
+
488
+ token_indices = padded_token_indices
489
+ # ============ 3. Adding class token ============ #
490
+ # concate class token, add [CLS] token to aggregate sequence-level representations
491
+ token_indices = torch.cat([torch.zeros(token_indices.size(0), 1).cuda(device=token_indices.device), token_indices], dim=1)
492
+ token_indices[:, 0] = self.fake_class_label # [B, 257]
493
+ # Masks (token_drop_mask and token_all_mask) are updated to account for the added class token,
494
+ # ensuring the first position is always kept by setting it to 0 (indicating "do not mask/drop")
495
+ # 添加0向量,和token_indices,表示[CLS] token不会被mask/drop
496
+ token_drop_mask = torch.cat([torch.zeros(token_indices.size(0), 1).cuda(), token_drop_mask], dim=1)
497
+ token_all_mask = torch.cat([torch.zeros(token_indices.size(0), 1).cuda(), token_all_mask], dim=1)
498
+ token_indices = token_indices.long()
499
+
500
+ # ============ 4. Embedding and Dropout ============ #
501
+ # bert embedding
502
+ input_embeddings = self.token_emb(token_indices) # get embeddings [B, 257, 768]
503
+ # print("Input embedding shape:", input_embeddings.shape)
504
+ bsz, seq_len, emb_dim = input_embeddings.shape
505
+
506
+ # dropping
507
+ token_keep_mask = 1 - token_drop_mask
508
+ input_embeddings_after_drop = input_embeddings[token_keep_mask.nonzero(as_tuple=True)].reshape(bsz, -1, emb_dim)
509
+ # print("Input embedding after drop shape:", input_embeddings_after_drop.shape)
510
+
511
+ # ============ 5. Transformer encoding ============ #
512
+ x = input_embeddings_after_drop # (B, seq_len_after_drop, embed_dim) # 32, 129, 768
513
+ for blk in self.blocks:
514
+ x = blk(x) # each block has a multi-head self-attention and a mlp
515
+ x = self.norm(x)
516
+ # print("Encoder representation shape:", x.shape)
517
+
518
+ return x, token_indices, token_all_mask, token_drop_mask
519
+
520
    def forward_decoding(self, x, token_drop_mask, token_all_mask):
        """Decode encoder features back to vocabulary logits.

        Args:
            x: encoder output (dropped positions are absent from it).
            token_drop_mask: 1 at dropped positions (incl. leading CLS slot).
            token_all_mask: 1 at masked positions (incl. leading CLS slot).

        Returns:
            (B, seq_len + 1, vocab_size) unnormalized token predictions.
        """
        # ============ 1. Prepare Embedding and padding tokens ============ #
        # embed tokens into the decoder's width
        x = self.decoder_embed(x)

        # Padding embedding used to fill masked/dropped slots: either the
        # per-sample [CLS] embedding or the learned mask token.
        if self.pad_with_cls_token:  # True
            mask_tokens = x[:, 0:1].repeat(1, token_all_mask.shape[1], 1)
        else:
            mask_tokens = self.mask_token.repeat(token_all_mask.shape[0], token_all_mask.shape[1], 1)

        # ============ 2. Restore the full-length sequence ============ #
        # put undropped tokens back into their original sequence positions
        x_after_pad = mask_tokens.clone()
        x_after_pad[(1 - token_drop_mask).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
        # undropped-but-masked positions are overwritten with the pad token too
        x_after_pad = torch.where(token_all_mask.unsqueeze(-1).bool(), mask_tokens, x_after_pad)

        # No positional embedding is added in this "rm_pos" variant.
        # x = x_after_pad + self.decoder_pos_embed_learned
        x = x_after_pad

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)

        x = self.decoder_norm(x)

        # Score against the (detached) word-embedding matrix (tied weights).
        word_embeddings = self.token_emb.word_embeddings.weight.data.detach()
        logits = self.mlm_layer(x, word_embeddings)  # predictions for masked tokens

        return logits
559
+
560
+ def forward_loss(self, gt_indices, logits, mask):
561
+ bsz, seq_len = gt_indices.size()
562
+ # logits and mask are with seq_len+1 but gt_indices is with seq_len
563
+ loss = self.criterion(logits[:, 1:, :self.codebook_size].reshape(bsz*seq_len, -1), gt_indices.reshape(bsz*seq_len))
564
+ loss = loss.reshape(bsz, seq_len)
565
+ loss = (loss * mask[:, 1:]).sum() / mask[:, 1:].sum() # mean loss on removed patches
566
+ return loss
567
+
568
+ def cal_lmbda(self, mask_ratio, A=5e-1, B=8):
569
+ lmbda = A * torch.exp(B * (1 - mask_ratio))
570
+ return lmbda
571
+
572
+ def cal_loss(self, logits, gt_indices, mask, mask_ratio):
573
+ mask_ratio = torch.tensor(mask_ratio)
574
+ ## cal cross entropy loss
575
+ task_loss = self.forward_loss(gt_indices, logits, mask)
576
+ lmbda = self.cal_lmbda(mask_ratio)
577
+ ## cal total loss for codec optimization
578
+ return task_loss, lmbda
579
+
580
    def forward(self, imgs, is_training=False, manual_mask_rate=None):
        """Full compress/decompress round trip for a batch of images.

        Encoding: VQ-tokenize, mask, pass the kept-token latents through the
        entropy bottleneck, and arithmetic-code the binary mask (torchac).
        Decoding: recover the mask from the bitstream, re-insert the kept
        tokens and predict the masked ones with the decoder.

        Returns:
            dict with logits, likelihoods, task loss, masks/bitstream and
            bookkeeping values (mask_ratio, lambda, latent grid size).
        """
        ## ---------- encoding process ---------- ##
        gt_indices, token_indices, latent, token_all_mask, token_drop_mask, mask_ratio, z_H, z_W = self.pre_encoding(imgs, is_training, manual_mask_rate)
        latent = latent.unsqueeze(1)  # add channel dim for the entropy model

        latent_hat, latent_likelihoods = self.entropy_bottleneck(latent)
        # NOTE(review): latent_hat is expected to equal latent when the
        # bottleneck dequantizes losslessly — the check below was used to
        # verify that invariant.
        # print((latent_hat == latent).all())
        # Arithmetic-code the binary mask with a fixed CDF table.
        cdf_mask_token = self.get_cdf_token_mask(token_all_mask).cpu()
        # Shift mask symbols {0,1} -> {1,2}; presumably to keep them inside
        # the CDF support built by get_cdf_token_mask — TODO confirm.
        sym = (token_all_mask.short() + 1).cpu()
        bs_mask_token = torchac.encode_float_cdf(cdf_mask_token, sym, check_input_bounds=True)
        mask_vis = rearrange(token_all_mask, 'b (h w) -> b h w', h=z_H, w=z_W).unsqueeze(1)

        ## ---------- decoding process ---------- ##
        decoded_sym = torchac.decode_float_cdf(cdf_mask_token, bs_mask_token)
        decoded_mask = (decoded_sym - 1).to(device=imgs.device)  # back to {0,1}
        latent_hat = latent_hat.squeeze(1)
        x, token_indices, token_all_mask, token_drop_mask = self.pre_decoding(gt_indices, latent_hat, decoded_mask, token_drop_mask)
        logits = self.forward_decoding(x, token_drop_mask, token_all_mask)
        ## calculate loss
        task_loss, lmbda = self.cal_loss(logits, gt_indices, token_all_mask, mask_ratio)
        return_dict = {
            'logits': logits,
            'likelihoods': latent_likelihoods,
            'task_loss': task_loss,
            'token_indices': token_indices,
            'token_all_mask': token_all_mask,
            'bs_mask_token': bs_mask_token,
            'mask_ratio': mask_ratio,
            'lambda': lmbda,
            'mask_vis': 1 - mask_vis,  # 1 = kept token (for visualization)
            'z_H': z_H,
            'z_W': z_W,
        }
        return return_dict
615
+
616
+ # def update(self, scale_table=None, force=False):
617
+ # if scale_table is None:
618
+ # scale_table = get_scale_table()
619
+ # updated = self.gaussian_conditional.update_scale_table(scale_table, force=force)
620
+ # updated |= super().update(force=force)
621
+ # return updated
622
+
623
+ def gen_img(self, logits, token_all_mask, token_indices, z_H, z_W, num_iter=12, choice_temperature=4.5):
624
+ """
625
+ generated image at inference
626
+ seed: random seed
627
+ logits: predicted logits by model decoder
628
+ token_all_mask: mask token indices
629
+ token_indices: token indices of the input image after the vq tokenizer
630
+ num_iter: number of iterations for sampling
631
+ choice_temperature: temperature for sampling
632
+ """
633
+ # torch.manual_seed(seed)
634
+ # np.random.seed(seed)
635
+ bsz = logits.size(0)
636
+ codebook_emb_dim = 256
637
+ codebook_size = 1024
638
+ mask_token_id = self.mask_token_label
639
+ _CONFIDENCE_OF_KNOWN_TOKENS = +np.inf
640
+ unknown_number_in_the_beginning = torch.sum(token_all_mask, dim=-1, keepdims=True).float()
641
+ for step in range(num_iter):
642
+ if step == 0:
643
+ cur_ids = token_indices.clone().long() # token_indices represent the current state of the sequence(unmasked tokens)
644
+ cur_ids = cur_ids[:, 1:] # 从第二列开始到最后一列
645
+ logits = logits[:, 1:, :codebook_size]
646
+ # the author said a little tricky here, "For iter=1, they use argmax and temp=0.0.
647
+ # For iter=6, we use categorical sampling and temp=4.5."
648
+ sample_dist = torch.distributions.categorical.Categorical(logits=logits)
649
+ sampled_ids = sample_dist.sample() # sampled_ids = torch.argmax(logits, dim=-1)
650
+ # get ids for next step
651
+ # unknown_map: type bool,shape is same as cur_ids and sampled_ids, indicate where the value will be replace
652
+ # 根据unknown_map的值,在相应位置上选择sampled_ids或cur_ids中的值,并将其存储到sampled_ids张量中。
653
+ # 换句话说,它将模型预测的类别(在未知位置)与之前已知的类别(在已知位置)合并到一个张量中。
654
+ unknown_map = (cur_ids == mask_token_id)
655
+ sampled_ids = torch.where(unknown_map, sampled_ids, cur_ids)
656
+ # Defines the mask ratio for the next round. The number to mask out is
657
+ # determined by mask_ratio * unknown_number_in_the_beginning.
658
+ ratio = 1. * (step + 1) / num_iter
659
+ mask_ratio = np.cos(math.pi / 2. * ratio) # ratio = cosine(Π/2 * i/num_iter)
660
+
661
+ # sample ids according to prediction confidence
662
+ probs = torch.nn.functional.softmax(logits, dim=-1)
663
+ selected_probs = torch.squeeze(
664
+ torch.gather(probs, dim=-1, index=torch.unsqueeze(sampled_ids, -1)), -1)
665
+
666
+ selected_probs = torch.where(unknown_map, selected_probs.double(), _CONFIDENCE_OF_KNOWN_TOKENS).float()
667
+ unknown_number_in_the_beginning = unknown_number_in_the_beginning.clone().detach().cuda()
668
+ mask_ratio = torch.tensor(mask_ratio).cuda()
669
+ # mask_len = torch.tensor([np.floor(unknown_number_in_the_beginning.numpy() * mask_ratio.numpy())]).cuda()
670
+ mask_len = torch.floor(unknown_number_in_the_beginning * mask_ratio).long() # 每个iter剩余的mask token数
671
+ # Keeps at least one of prediction in this round and also masks out at least
672
+ # one and for the next iteration
673
+ mask_len = torch.maximum(torch.Tensor([1]).cuda(),
674
+ torch.minimum(torch.sum(unknown_map, dim=-1, keepdims=True) - 1, mask_len))
675
+
676
+ # Sample masking tokens for next iteration
677
+ masking = mask_by_random_topk(mask_len[0], selected_probs, choice_temperature * (1 - ratio))
678
+ # Masks tokens with lower confidence.
679
+ token_indices = torch.where(masking, mask_token_id, sampled_ids)
680
+ else:
681
+ cur_ids = token_indices.clone().long() # .long(): to int64
682
+ token_indices = torch.cat(
683
+ [torch.zeros(token_indices.size(0), 1).cuda(device=token_indices.device), token_indices], dim=1)
684
+ token_indices[:, 0] = self.fake_class_label
685
+ token_indices = token_indices.long()
686
+ token_all_mask = token_indices == mask_token_id
687
+
688
+ token_drop_mask = torch.zeros_like(token_indices)
689
+
690
+ # token embedding
691
+ input_embeddings = self.token_emb(token_indices) # get input embeddings
692
+
693
+ # encoder
694
+ x = input_embeddings
695
+ for blk in self.blocks:
696
+ x = blk(x)
697
+ x = self.norm(x)
698
+
699
+ # decoder
700
+ logits = self.forward_decoding(x, token_drop_mask, token_all_mask)
701
+ logits = logits[:, 1:, :codebook_size] # remove the cls token and dims > codebook_size
702
+
703
+ # get token prediction
704
+ # the author said a little tricky here, "For iter=1, they use argmax and temp=0.0.
705
+ # For iter=6, we use categorical sampling and temp=4.5."
706
+ sample_dist = torch.distributions.categorical.Categorical(logits=logits)
707
+ sampled_ids = sample_dist.sample() # sampled_ids = torch.argmax(logits, dim=-1)
708
+
709
+ # get ids for next step
710
+ # unknown_map: type bool,shape is same as cur_ids and sampled_ids, indicate where the value will be replace
711
+ # 根据unknown_map的值,在相应位置上选择sampled_ids或cur_ids中的值,并将其存储到sampled_ids张量中。
712
+ # 换句话说,它将模型预测的类别(在未知位置)与之前已知的类别(在已知位置)合并到一个张量中。
713
+ unknown_map = (cur_ids == mask_token_id)
714
+ sampled_ids = torch.where(unknown_map, sampled_ids, cur_ids)
715
+ # Defines the mask ratio for the next round. The number to mask out is
716
+ # determined by mask_ratio * unknown_number_in_the_beginning.
717
+ ratio = 1. * (step + 1) / num_iter
718
+
719
+ mask_ratio = np.cos(math.pi / 2. * ratio) # ratio = cosine(Π/2 * i/num_iter)
720
+
721
+ # sample ids according to prediction confidence
722
+ probs = torch.nn.functional.softmax(logits, dim=-1)
723
+ selected_probs = torch.squeeze(
724
+ torch.gather(probs, dim=-1, index=torch.unsqueeze(sampled_ids, -1)), -1)
725
+
726
+ selected_probs = torch.where(unknown_map, selected_probs.double(), _CONFIDENCE_OF_KNOWN_TOKENS).float()
727
+ unknown_number_in_the_beginning = unknown_number_in_the_beginning.clone().detach().cuda()
728
+ mask_ratio = torch.tensor(mask_ratio).cuda()
729
+ mask_len = torch.floor(unknown_number_in_the_beginning * mask_ratio).long() # 每个iter剩余的mask token数
730
+ # Keeps at least one of prediction in this round and also masks out at least
731
+ # one and for the next iteration
732
+ mask_len = torch.maximum(torch.Tensor([1]).cuda(),
733
+ torch.minimum(torch.sum(unknown_map, dim=-1, keepdims=True) - 1, mask_len))
734
+
735
+ # Sample masking tokens for next iteration
736
+ masking = mask_by_random_topk(mask_len[0], selected_probs, choice_temperature * (1 - ratio))
737
+ # Masks tokens with lower confidence.
738
+ token_indices = torch.where(masking, mask_token_id, sampled_ids)
739
+
740
+ # vqgan visualization
741
+ z_q = self.vqgan.quantize.get_codebook_entry(sampled_ids, shape=(bsz, z_H, z_W, codebook_emb_dim))
742
+ gen_images = self.vqgan.decode(z_q)
743
+ return gen_images
744
+
745
+
746
+ def mage_vit_base_patch16(**kwargs):
747
+ model = MaskedGenerativeEncoderViT(
748
+ patch_size=16, embed_dim=768, depth=12, num_heads=12,
749
+ decoder_embed_dim=768, decoder_depth=8, decoder_num_heads=16,
750
+ mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
751
+ return model
752
+
753
+
754
+ def mage_vit_large_patch16(**kwargs):
755
+ model = MaskedGenerativeEncoderViT(
756
+ patch_size=16, embed_dim=1024, depth=24, num_heads=16,
757
+ decoder_embed_dim=1024, decoder_depth=8, decoder_num_heads=16,
758
+ mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
759
+ return model
deploited/models_mage_codec_rope.py ADDED
@@ -0,0 +1,767 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from timm.models.vision_transformer import PatchEmbed, DropPath, Mlp
7
+
8
+ from util.pos_embed import get_2d_sincos_pos_embed
9
+
10
+ from taming.models.vqgan import VQModel
11
+ from omegaconf import OmegaConf
12
+ import numpy as np
13
+ import scipy.stats as stats
14
+ from compressai.entropy_models import EntropyBottleneck
15
+ from compressai.layers import conv3x3, subpel_conv3x3
16
+ import math
17
+ from torch import Tensor
18
+ from einops import rearrange, repeat
19
+ import torch.nn.functional as F
20
+ import torchac
21
+ from typing import Any, Callable, List, Optional, Tuple, Union
22
+ from util.RoPE import LlamaDynamicNTKScalingRotaryEmbedding, apply_rotary_pos_emb
23
+
24
+ SCALES_MIN = 0.11
25
+ SCALES_MAX = 256
26
+ SCALES_LEVELS = 64
27
+ def get_scale_table(min=SCALES_MIN, max=SCALES_MAX, levels=SCALES_LEVELS):
28
+ return torch.exp(torch.linspace(math.log(min), math.log(max), levels))
29
+
30
+ def ste_round(x: Tensor) -> Tensor:
31
+ return torch.round(x) - x.detach() + x
32
+
33
+ def conv(in_channels, out_channels, kernel_size=5, stride=2):
34
+ return nn.Conv2d(
35
+ in_channels,
36
+ out_channels,
37
+ kernel_size=kernel_size,
38
+ stride=stride,
39
+ padding=kernel_size // 2,
40
+ )
41
+
42
+ def mask_by_random_topk(mask_len, probs, temperature=1.0):
43
+ mask_len = mask_len.squeeze()
44
+ # 使用Gumbel分布进行采样,增加随机性
45
+ confidence = torch.log(probs) + torch.Tensor(temperature * np.random.gumbel(size=probs.shape)).cuda()
46
+ sorted_confidence, _ = torch.sort(confidence, axis=-1)
47
+ # Obtains cut off threshold given the mask lengths.
48
+ cut_off = sorted_confidence[:, mask_len.long()-1:mask_len.long()]
49
+ # Masks tokens with lower confidence.
50
+ masking = (confidence <= cut_off)
51
+ return masking
52
+
53
+ def adjust_mask_and_drop_embeddings(token_keep_mask):
54
+ """
55
+ Adjusts the token_keep_mask to the nearest square number of True values by randomly setting
56
+ some of them to False, and then applies this adjusted mask to input_embeddings.
57
+
58
+ Parameters:
59
+ - input_embeddings: Tensor, The embeddings tensor.
60
+ - token_keep_mask: BoolTensor, The mask tensor indicating which tokens to keep.
61
+
62
+ Returns:
63
+ - Tensor, Adjusted input embeddings after applying the modified token_keep_mask.
64
+ """
65
+ # 获取非零(即值为True)元素的索引
66
+ non_zero_indices = token_keep_mask.nonzero(as_tuple=True)
67
+ # 计算非零元素的数量
68
+ non_zero_count = non_zero_indices[0].size(0)
69
+ # 计算最近的整数平方倍
70
+ next_square = math.floor(math.sqrt(non_zero_count))**2
71
+ # 计算需要移除的元素数量
72
+ remove_count = non_zero_count - next_square
73
+ if remove_count > 0:
74
+ # 如果需要移除元素以达到整数平方倍
75
+ permuted_indices = torch.randperm(non_zero_count)[:remove_count]
76
+ for idx in permuted_indices:
77
+ token_keep_mask[non_zero_indices[0][idx], non_zero_indices[1][idx]] = False
78
+ # 使用更新后的token_keep_mask
79
+ # input_embeddings_after_drop = input_embeddings[token_keep_mask]
80
+
81
+ return token_keep_mask
82
+
83
+
84
+ class FactorizedEntropyModel(EntropyBottleneck):
85
+ def __init__(self, *args, **kwargs):
86
+ super().__init__(*args, **kwargs)
87
+
88
+ def forward(self, x: Tensor, training: Optional[bool] = None) -> Tuple[Tensor, Tensor]:
89
+ if training is None:
90
+ training = self.training
91
+
92
+ # 输入形状已经是 [b, c, seq_len],无需转置
93
+ shape = x.size()
94
+
95
+ # Add noise or quantize
96
+ means = self._get_medians()
97
+ # outputs = self.quantize(
98
+ # x, "noise" if training else "dequantize", means.long()
99
+ # )
100
+ outputs = self.quantize(
101
+ x, "dequantize", means.long()
102
+ )
103
+
104
+ if not torch.jit.is_scripting():
105
+ likelihood = self._likelihood(outputs)
106
+ if self.use_likelihood_bound:
107
+ likelihood = self.likelihood_lower_bound(likelihood)
108
+ else:
109
+ raise NotImplementedError("TorchScript is not yet supported")
110
+
111
+ return outputs, likelihood
112
+
113
+ def compress(self, x):
114
+ # 构建索引,适用于单通道序列数据
115
+ indexes = self._build_indexes(x.size())
116
+ # 获取中位数,已经适配为单通道
117
+ medians = self._get_medians().detach()
118
+ # 调整 medians 的形状以匹配 x 的形状
119
+ medians = medians.expand_as(x)
120
+ # 调用基类的 compress 方法进行压缩
121
+ return super().compress(x, indexes, medians)
122
+
123
+ def decompress(self, strings, size):
124
+ # 预期的输出大小应包括单个通道
125
+ output_size = (len(strings), 1, *size) # 这里 size 应该是 seq_len
126
+ # 构建索引
127
+ indexes = self._build_indexes(output_size).to(self._quantized_cdf.device)
128
+ # 获取中位数并调整其形状以匹配预期输出的形状
129
+ medians = self._extend_ndims(self._get_medians().detach(), len(size))
130
+ medians = medians.expand(len(strings), 1, *([-1] * len(size)))
131
+ # 调用基类的 decompress 方法进行解压缩
132
+ return super().decompress(strings, indexes, medians.dtype, medians)
133
+
134
+ def _preprocess(self, x):
135
+ x = x.permute(0, 2, 3, 1).contiguous()
136
+ return x
137
+
138
+
139
+ class Attention(nn.Module):
140
+ def __init__(self, dim, max_position_embeddings, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
141
+ super().__init__()
142
+ self.max_position_embeddings = max_position_embeddings
143
+ self.num_heads = num_heads
144
+ head_dim = dim // num_heads
145
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
146
+ self.scale = qk_scale or head_dim ** -0.5
147
+
148
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
149
+ self.attn_drop = nn.Dropout(attn_drop)
150
+ self.proj = nn.Linear(dim, dim)
151
+ self.proj_drop = nn.Dropout(proj_drop)
152
+ self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
153
+ head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=8
154
+ )
155
+ # self.register_buffer("position_ids", torch.arange(self.max_position_embeddings).expand((1, -1)))
156
+
157
+ def forward(self, x):
158
+ B, N, C = x.shape
159
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # (3,B,num_heads,N,head_dim)
160
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
161
+ cos, sin = self.rotary_emb(v, seq_len=N)
162
+ position_ids = torch.arange(N).expand((1, -1))
163
+ q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
164
+
165
+ with torch.cuda.amp.autocast(enabled=False):
166
+ attn = (q.float() @ k.float().transpose(-2, -1)) * self.scale
167
+
168
+ attn = attn - torch.max(attn, dim=-1, keepdim=True)[0]
169
+ attn = attn.softmax(dim=-1)
170
+ attn = self.attn_drop(attn)
171
+
172
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
173
+ x = self.proj(x)
174
+ x = self.proj_drop(x)
175
+ # x是经过self-attn后的feature,attn是注意力权重矩阵,描述输入序列中各个元素之间的相关性
176
+ return x, attn
177
+
178
+
179
+ class Block(nn.Module):
180
+
181
+ def __init__(self, dim, max_position_embeddings, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
182
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
183
+ super().__init__()
184
+ self.norm1 = norm_layer(dim)
185
+ self.attn = Attention(
186
+ dim, max_position_embeddings, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
187
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
188
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() # drop_path=0
189
+ self.norm2 = norm_layer(dim)
190
+ mlp_hidden_dim = int(dim * mlp_ratio) # mlp_ratio=4
191
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) # drop=0.1
192
+
193
+ def forward(self, x, return_attention=False):
194
+ if return_attention:
195
+ _, attn = self.attn(self.norm1(x))
196
+ return attn
197
+ else:
198
+ y, _ = self.attn(self.norm1(x))
199
+ x = x + self.drop_path(y)
200
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
201
+ return x
202
+
203
+
204
+ class LabelSmoothingCrossEntropy(nn.Module):
205
+ """ NLL loss with label smoothing.
206
+ """
207
+ def __init__(self, smoothing=0.1):
208
+ super(LabelSmoothingCrossEntropy, self).__init__()
209
+ assert smoothing < 1.0
210
+ self.smoothing = smoothing
211
+ self.confidence = 1. - smoothing
212
+
213
+ def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
214
+ logprobs = torch.nn.functional.log_softmax(x, dim=-1)
215
+ nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
216
+ nll_loss = nll_loss.squeeze(1)
217
+ smooth_loss = -logprobs.mean(dim=-1)
218
+ loss = self.confidence * nll_loss + self.smoothing * smooth_loss
219
+ return loss
220
+
221
+
222
+ class BertEmbeddings(nn.Module):
223
+ """Construct the embeddings from word, position and token_type embeddings."""
224
+
225
+ def __init__(self, vocab_size, hidden_size, max_position_embeddings, dropout=0.1):
226
+ super().__init__()
227
+ self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
228
+ self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
229
+
230
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
231
+ # any TensorFlow checkpoint file
232
+ self.LayerNorm = nn.LayerNorm(hidden_size, eps=1e-6)
233
+ self.dropout = nn.Dropout(dropout)
234
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
235
+ self.register_buffer("position_ids", torch.arange(max_position_embeddings).expand((1, -1))) # (1, 257)
236
+ # 创建一个形状为(1, max_position_embeddings)的缓冲张量position_ids,其包含了从0���max_position_embeddings-1的整数。
237
+ # 这个缓冲张量将被用于获取position_embeddings的位置信息,以便在前向传播过程中使用
238
+
239
+ torch.nn.init.normal_(self.word_embeddings.weight, std=.02)
240
+ torch.nn.init.normal_(self.position_embeddings.weight, std=.02)
241
+
242
+ def forward(
243
+ self, input_ids
244
+ ):
245
+ input_shape = input_ids.size() # input_ids: (B, N)(32,1 + latent_h * latent_w)
246
+ seq_length = input_shape[1]
247
+ # position_ids = self.position_ids[:, :seq_length]
248
+ inputs_embeds = self.word_embeddings(input_ids) # (B, seq_len, embed_dim)
249
+ # position_embeddings = self.position_embeddings(position_ids) # (1, seq_len, embed_dim)
250
+
251
+ # embeddings = inputs_embeds + position_embeddings
252
+ embeddings = inputs_embeds
253
+
254
+ embeddings = self.LayerNorm(embeddings)
255
+ embeddings = self.dropout(embeddings)
256
+ return embeddings
257
+
258
+
259
+ class MlmLayer(nn.Module):
260
+
261
+ def __init__(self, feat_emb_dim, word_emb_dim, vocab_size):
262
+ super().__init__()
263
+ self.fc = nn.Linear(feat_emb_dim, word_emb_dim)
264
+ self.gelu = nn.GELU()
265
+ self.ln = nn.LayerNorm(word_emb_dim)
266
+ self.bias = nn.Parameter(torch.zeros(1, 1, vocab_size))
267
+
268
+ def forward(self, x, word_embeddings): # x: (b, seq_len, embed_dim)
269
+ mlm_hidden = self.fc(x)
270
+ mlm_hidden = self.gelu(mlm_hidden)
271
+ mlm_hidden = self.ln(mlm_hidden)
272
+ word_embeddings = word_embeddings.transpose(0, 1)
273
+ logits = torch.matmul(mlm_hidden, word_embeddings)
274
+ logits = logits + self.bias
275
+ return logits # (b, seq_len, vocab_size) 表示对于输入序列中的每个位置,模型预测它对应词汇表中每个单词的原始单词的未归一化概率
276
+
277
+
278
+ class MaskedGenerativeEncoderViT(nn.Module):
279
+ """ Masked Autoencoder with VisionTransformer backbone
280
+ """
281
+ def __init__(self, img_size=256, patch_size=16, in_chans=3, # need to change the default value of img_size
282
+ embed_dim=1024, depth=24, num_heads=16,
283
+ decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
284
+ mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False,
285
+ mask_ratio_min=0.5, mask_ratio_max=0.8, mask_ratio_mu=0.55, mask_ratio_std=0.25,
286
+ vqgan_ckpt_path='vqgan_jax_strongaug.ckpt'):
287
+ super().__init__()
288
+
289
+ # --------------------------------------------------------------------------
290
+ # VQGAN specifics
291
+ config = OmegaConf.load('/home/t2vg-a100-G4-10/project/qyp/mimc_rope/config/vqgan.yaml').model
292
+ self.vqgan = VQModel(ddconfig=config.params.ddconfig,
293
+ n_embed=config.params.n_embed, # 1024
294
+ embed_dim=config.params.embed_dim, # 256
295
+ ckpt_path=vqgan_ckpt_path)
296
+ for param in self.vqgan.parameters():
297
+ param.requires_grad = False
298
+
299
+ self.codebook_size = config.params.n_embed # 1024
300
+ vocab_size = self.codebook_size + 1000 + 1 # 1024 codebook size, 1000 classes, 1 for mask token.
301
+ self.fake_class_label = self.codebook_size + 1100 - 1024 # 1100
302
+ self.mask_token_label = vocab_size - 1 # 2024
303
+ self.max_position_embeddings = 256 + 1
304
+ self.token_emb = BertEmbeddings(vocab_size=vocab_size, # 向量空间大小,1024个embedding + 1000 class + 1 mask token
305
+ hidden_size=embed_dim,
306
+ max_position_embeddings=256 +1,
307
+ # max_position_embeddings=256+1, # 256个patch + 1 class token
308
+ dropout=0.1)
309
+
310
+ # MAGE variant masking ratio
311
+ self.mask_ratio_min = mask_ratio_min
312
+ self.mask_ratio_max = mask_ratio_max
313
+ # self.mask_ratio_generator = stats.truncnorm((mask_ratio_min - mask_ratio_mu) / mask_ratio_std,
314
+ # (mask_ratio_max - mask_ratio_mu) / mask_ratio_std,
315
+ # loc=mask_ratio_mu, scale=mask_ratio_std)
316
+
317
+ # --------------------------------------------------------------------------
318
+ # MAGE encoder specifics
319
+ dropout_rate = 0.1
320
+ self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim) # 256, 16, 3, 1024, (B,N,C) n: 256/16*256/16=256, c=1024
321
+ num_patches = self.patch_embed.num_patches # 256
322
+
323
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
324
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding
325
+
326
+ self.blocks = nn.ModuleList([ # encoder
327
+ Block(embed_dim, self.max_position_embeddings, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer,
328
+ drop=dropout_rate, attn_drop=dropout_rate)
329
+ for i in range(depth)]) # depth=12 for mage-vitb, embed_dim=768
330
+ self.norm = norm_layer(embed_dim) # layer norm
331
+ # --------------------------------------------------------------------------
332
+
333
+ # --------------------------------------------------------------------------
334
+ # MAGE decoder specifics
335
+ self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
336
+
337
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) # decoder_embed_dim=512
338
+ self.pad_with_cls_token = True
339
+
340
+ self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding
341
+ self.decoder_pos_embed_learned = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim)) # learnable pos embedding
342
+
343
+ self.decoder_blocks = nn.ModuleList([
344
+ Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer,
345
+ drop=dropout_rate, attn_drop=dropout_rate)
346
+ for i in range(decoder_depth)]) # decoder_depth=8 for mage-vitb
347
+
348
+ self.decoder_norm = norm_layer(decoder_embed_dim)
349
+ self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch
350
+ # --------------------------------------------------------------------------
351
+
352
+ # --------------------------------------------------------------------------
353
+ # MlmLayer
354
+ self.mlm_layer = MlmLayer(feat_emb_dim=decoder_embed_dim, word_emb_dim=embed_dim, vocab_size=vocab_size)
355
+
356
+ self.norm_pix_loss = norm_pix_loss
357
+
358
+ self.criterion = LabelSmoothingCrossEntropy(smoothing=0.1)
359
+ # --------------------------------------------------------------------------
360
+ self.entropy_bottleneck = FactorizedEntropyModel(1)
361
+
362
+ self.initialize_weights()
363
+
364
+ def initialize_weights(self):
365
+ # initialization
366
+ # initialize (and freeze) pos_embed by sin-cos embedding
367
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
368
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
369
+
370
+ decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
371
+ self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
372
+
373
+ # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
374
+ w = self.patch_embed.proj.weight.data
375
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
376
+
377
+ # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
378
+ torch.nn.init.normal_(self.cls_token, std=.02)
379
+ torch.nn.init.normal_(self.mask_token, std=.02)
380
+ torch.nn.init.normal_(self.decoder_pos_embed_learned, std=.02)
381
+
382
+ # initialize nn.Linear and nn.LayerNorm
383
+ self.apply(self._init_weights)
384
+
385
+ def _init_weights(self, m):
386
+ if isinstance(m, nn.Linear):
387
+ # we use xavier_uniform following official JAX ViT:
388
+ torch.nn.init.xavier_uniform_(m.weight)
389
+ if isinstance(m, nn.Linear) and m.bias is not None:
390
+ nn.init.constant_(m.bias, 0)
391
+ elif isinstance(m, nn.LayerNorm):
392
+ nn.init.constant_(m.bias, 0)
393
+ nn.init.constant_(m.weight, 1.0)
394
+
395
+ def random_sample_mask_rate(self):
396
+ # 生成一个 (0, 1] 范围内的随机数
397
+ random_sample = 1 - torch.rand(1)
398
+ # 映射到 mask_ratio_min 到 mask_ratio_max 的范围
399
+ mask_rate = self.mask_ratio_min + random_sample * (self.mask_ratio_max - self.mask_ratio_min)
400
+ return mask_rate.item() # 转换为Python的标量值
401
+
402
+ def get_cdf_token_mask(self, token_all_mask):
403
+ bsz, seq_len = token_all_mask.size()
404
+ # --- use Normal distribution.
405
+ dist_normal = torch.distributions.Normal(0, 2)
406
+ cdf_mask_token = dist_normal.cdf(torch.arange(1, seq_len + 1))
407
+ cdf_mask_token = (cdf_mask_token - .5) * 2
408
+ cdf_mask_token = repeat(cdf_mask_token, 'Lp -> b s Lp',
409
+ b=bsz, s=seq_len)
410
+
411
+ cdf_mask_token = F.pad(cdf_mask_token, (1, 0))
412
+ return cdf_mask_token
413
+
414
+ # def get_dec_pos_embed(self, z_H, z_W):
415
+ # # Exclude the first embedding (position 0) before interpolation
416
+ # position_embeddings = self.decoder_pos_embed_learned # (b, 256, embed_dim)
417
+ # # Perform bilinear interpolation
418
+ # interpolated_embeddings = F.interpolate(
419
+ # position_embeddings.permute(0, 2, 1),
420
+ # size=(1+z_H*z_W), mode="linear", align_corners=False
421
+ # ).permute(0, 2, 1) # (b, seq_len -1 , embed_dim)
422
+ # # Concatenate the first embedding back
423
+ # # position_embeddings = torch.cat([self.decoder_pos_embed_learned[:, :1, :], interpolated_embeddings], dim=1)
424
+ # position_embeddings = interpolated_embeddings
425
+
426
+ # return position_embeddings
427
+
428
+ def pre_encoding(self, x, is_training=False, manual_mask_rate=None):
429
+ """
430
+ input: x: (B, 3, H, W)
431
+ """
432
+ # ============ 1. tokenization ============ #
433
+ with torch.no_grad():
434
+ z_q, _, token_tuple = self.vqgan.encode(x) # z_q: (B, 256, 16, 16), token_tuple: (B, 256, 16, 16)
435
+
436
+ _, _, z_H, z_W = z_q.size()
437
+ _, _, token_indices = token_tuple # token_indices: (B*H*W,)(8192)
438
+ token_indices = token_indices.reshape(z_q.size(0), -1) # token_indices: (B, H*W)
439
+ gt_indices = token_indices.clone().detach().long()
440
+
441
+ # ============ 2. masking process ============ #
442
+ bsz, seq_len = token_indices.size() # seq_len=h*w
443
+ mask_ratio_min = self.mask_ratio_min # 0.5
444
+
445
+ if is_training:
446
+ # mask_rate = self.mask_ratio_generator.rvs(1)[0]
447
+ mask_rate = self.random_sample_mask_rate()
448
+ num_dropped_tokens = int(np.ceil(seq_len * mask_ratio_min))
449
+ else:
450
+ num_dropped_tokens = 0
451
+ if manual_mask_rate is not None:
452
+ mask_rate = manual_mask_rate
453
+ else:
454
+ raise ValueError("mask_rate should be provided for inference!")
455
+
456
+ num_masked_tokens = int(np.ceil(seq_len * mask_rate))
457
+ mask_ratio = num_masked_tokens / seq_len # for calculate vbr lambda
458
+ # it is possible that two elements of the noise is the same, so do a while loop to avoid it
459
+ while True:
460
+ noise = torch.rand(bsz, seq_len, device=x.device) # noise in [0, 1]
461
+ sorted_noise, _ = torch.sort(noise, dim=1) # ascend: small is remove, large is keep
462
+ if num_dropped_tokens > 0:
463
+ cutoff_drop = sorted_noise[:, num_dropped_tokens-1:num_dropped_tokens]
464
+ else:
465
+ cutoff_drop = torch.zeros((bsz, 1), device=x.device)
466
+ cutoff_mask = sorted_noise[:, num_masked_tokens-1:num_masked_tokens]
467
+ token_drop_mask = (noise <= cutoff_drop).float() # 逻辑上标记那些token是被drop掉的
468
+ token_all_mask = (noise <= cutoff_mask).float() # 逻辑上标记那些token是被mask掉的
469
+ if token_drop_mask.sum() == bsz*num_dropped_tokens and token_all_mask.sum() == bsz*num_masked_tokens:
470
+ break
471
+ else:
472
+ print("Rerandom the noise!")
473
+
474
+ # 获取unmasked token及其位置
475
+ unmasked_pos = token_all_mask == 0 # 未被mask的位置
476
+ unmaksed_token_indices = token_indices[unmasked_pos].reshape(bsz, -1) # 未被mask的token
477
+
478
+ return gt_indices, token_indices, unmaksed_token_indices, token_all_mask, token_drop_mask, mask_ratio, z_H, z_W
479
+
480
+ def pre_decoding(self, gt_indices, unmaksed_token_indices, token_all_mask, token_drop_mask):
481
+ bsz, seq_len = gt_indices.size()
482
+ padded_token_indices = torch.full_like(gt_indices, fill_value=self.mask_token_label)
483
+ # 将未被mask的token填充回去
484
+ # 我们需要一个计数器来追踪每个batch中已经填充了多少个unmaksed_token_indices
485
+ unmasked_token_counter = [0 for _ in range(bsz)]
486
+
487
+ for b in range(bsz):
488
+ for idx in range(seq_len):
489
+ # 如果当前位置未被mask,则从unmaksed_token_indices填充;否则,保留mask_token_label
490
+ if (token_all_mask[b, idx] == 0): # 检查是否未被mask
491
+ # 替换相应的unmaksed token
492
+ padded_token_indices[b, idx] = unmaksed_token_indices[b, unmasked_token_counter[b]]
493
+ # 更新计数器
494
+ unmasked_token_counter[b] += 1
495
+
496
+ token_indices = padded_token_indices
497
+ # ============ 3. Adding class token ============ #
498
+ # concate class token, add [CLS] token to aggregate sequence-level representations
499
+ token_indices = torch.cat([torch.zeros(token_indices.size(0), 1).cuda(device=token_indices.device), token_indices], dim=1)
500
+ token_indices[:, 0] = self.fake_class_label # [B, 257]
501
+ # Masks (token_drop_mask and token_all_mask) are updated to account for the added class token,
502
+ # ensuring the first position is always kept by setting it to 0 (indicating "do not mask/drop")
503
+ # 添加0向量,和token_indices,表示[CLS] token不会被mask/drop
504
+ token_drop_mask = torch.cat([torch.zeros(token_indices.size(0), 1).cuda(), token_drop_mask], dim=1)
505
+ token_all_mask = torch.cat([torch.zeros(token_indices.size(0), 1).cuda(), token_all_mask], dim=1)
506
+ token_indices = token_indices.long()
507
+
508
+ # ============ 4. Embedding and Dropout ============ #
509
+ # bert embedding
510
+ input_embeddings = self.token_emb(token_indices) # get embeddings [B, 257, 768]
511
+ # print("Input embedding shape:", input_embeddings.shape)
512
+ bsz, seq_len, emb_dim = input_embeddings.shape
513
+
514
+ # dropping
515
+ token_keep_mask = 1 - token_drop_mask
516
+ input_embeddings_after_drop = input_embeddings[token_keep_mask.nonzero(as_tuple=True)].reshape(bsz, -1, emb_dim)
517
+ # print("Input embedding after drop shape:", input_embeddings_after_drop.shape)
518
+
519
+ # ============ 5. Transformer encoding ============ #
520
+ x = input_embeddings_after_drop # (B, seq_len_after_drop, embed_dim) # 32, 129, 768
521
+ for blk in self.blocks:
522
+ x = blk(x) # each block has a multi-head self-attention and a mlp
523
+ x = self.norm(x)
524
+ # print("Encoder representation shape:", x.shape)
525
+
526
+ return x, token_indices, token_all_mask, token_drop_mask
527
+
528
def forward_decoding(self, x, token_drop_mask, token_all_mask):
    """
    Decode encoder features back to vocabulary logits.

    x: output of the encoder (e.g. from pre_decoding), including the [CLS] slot
    token_drop_mask: 1 at positions whose tokens were dropped before encoding
    token_all_mask: 1 at positions that are masked (superset of dropped)
    Returns logits over the vocabulary for every position (incl. [CLS]).
    """
    # ============ 1. Prepare Embedding and padding tokens ============ #
    # project encoder features to the decoder width
    x = self.decoder_embed(x)  # input_embedding_after_padding

    # append mask tokens to sequence
    # replicates the [CLS] token embedding across the sequence length where masking is to be applied
    if self.pad_with_cls_token:  # True
        mask_tokens = x[:, 0:1].repeat(1, token_all_mask.shape[1], 1)
    else:
        mask_tokens = self.mask_token.repeat(token_all_mask.shape[0], token_all_mask.shape[1], 1)

    # ============ 2. Prepare positional embedding ============ #
    # scatter the undropped tokens back into their original sequence positions
    x_after_pad = mask_tokens.clone()
    x_after_pad[(1 - token_drop_mask).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
    # positions that survived the drop but are masked still get the mask embedding
    x_after_pad = torch.where(token_all_mask.unsqueeze(-1).bool(), mask_tokens, x_after_pad)

    # add pos embed
    # x = x_after_pad + self.decoder_pos_embed_learned  # add learnable pos embedding
    # NOTE(review): learned positional embedding is deliberately disabled here;
    # presumably position comes from RoPE inside the blocks — TODO confirm.
    x = x_after_pad

    # apply Transformer blocks
    for blk in self.decoder_blocks:
        x = blk(x)

    x = self.decoder_norm(x)

    # tie the output projection to the (detached) input word embeddings
    word_embeddings = self.token_emb.word_embeddings.weight.data.detach()
    logits = self.mlm_layer(x, word_embeddings)  # produce predictions for masked tokens
    # print("Logits shape:", x.shape)

    return logits
567
+
568
def forward_loss(self, gt_indices, logits, mask):
    """Cross-entropy over codebook logits, averaged on masked positions only.

    gt_indices: (B, L) ground-truth token ids (no [CLS] slot).
    logits: (B, L+1, V) predictions including the [CLS] slot at position 0.
    mask: (B, L+1) weighting mask including the [CLS] slot.
    """
    bsz, seq_len = gt_indices.size()
    # Drop the [CLS] slot and any logits beyond the codebook, then flatten
    # so the (per-element) criterion sees one big batch of tokens.
    flat_logits = logits[:, 1:, :self.codebook_size].reshape(bsz * seq_len, -1)
    flat_targets = gt_indices.reshape(bsz * seq_len)
    per_token = self.criterion(flat_logits, flat_targets).reshape(bsz, seq_len)
    # Average only over the masked (removed) positions.
    token_mask = mask[:, 1:]
    return (per_token * token_mask).sum() / token_mask.sum()
575
+
576
def cal_lmbda(self, mask_ratio, A=5e-1, B=8):
    """Rate-control weight: lambda = A * exp(B * (1 - mask_ratio)).

    A lower mask ratio means more tokens are transmitted (higher bitrate),
    which yields a larger lambda for that sample.
    """
    kept_fraction = 1 - mask_ratio
    return A * torch.exp(B * kept_fraction)
579
+
580
def cal_loss(self, logits, gt_indices, mask, mask_ratio):
    """Return (task_loss, lmbda): masked cross-entropy and its rate weight."""
    ratio_tensor = torch.tensor(mask_ratio)
    return (
        self.forward_loss(gt_indices, logits, mask),  # CE on masked tokens
        self.cal_lmbda(ratio_tensor),                 # vbr weighting factor
    )
587
+
588
def forward(self, imgs, is_training=False, manual_mask_rate=None):
    """
    Full compression round trip: tokenize + mask the image (encode), entropy
    code the kept token indices and the mask, then decode logits for the
    masked tokens and compute the training losses.

    imgs: input image batch.
    is_training: forwarded to pre_encoding (controls how the mask is sampled).
    manual_mask_rate: fixed mask rate for inference; None during training.
    Returns a dict with logits, likelihoods, losses and bookkeeping tensors.
    """
    ## ---------- encoding process ---------- ##
    gt_indices, token_indices, latent, token_all_mask, token_drop_mask, mask_ratio, z_H, z_W = self.pre_encoding(imgs, is_training, manual_mask_rate)
    # NOTE(review): "latent" is the sequence of unmasked token indices from
    # pre_encoding; it is fed to the entropy bottleneck as a 1-channel tensor.
    latent = latent.unsqueeze(1)

    latent_hat, latent_likelihoods = self.entropy_bottleneck(latent)
    # sanity check: latent_hat should round-trip to latent
    # print((latent_hat == latent).all())
    # arithmetic-code the binary mask itself (it must also be transmitted)
    cdf_mask_token = self.get_cdf_token_mask(token_all_mask).cpu()
    # +1 offset: torchac symbols appear to start at 1 here — TODO confirm
    sym = (token_all_mask.short() + 1).cpu()
    bs_mask_token = torchac.encode_float_cdf(cdf_mask_token, sym, check_input_bounds=True)
    mask_vis = rearrange(token_all_mask, 'b (h w) -> b h w', h=z_H, w=z_W).unsqueeze(1)

    ## ---------- decoding process ---------- ##
    # decode the arithmetic-coded mask and undo the +1 symbol offset
    decoded_sym = torchac.decode_float_cdf(cdf_mask_token, bs_mask_token)
    decoded_mask = (decoded_sym - 1).to(device=imgs.device)
    latent_hat = latent_hat.squeeze(1)
    # latent_hat plays the role of the "unmasked token indices" argument
    x, token_indices, token_all_mask, token_drop_mask = self.pre_decoding(gt_indices, latent_hat, decoded_mask, token_drop_mask)
    logits = self.forward_decoding(x, token_drop_mask, token_all_mask)
    ## calculate loss
    task_loss, lmbda = self.cal_loss(logits, gt_indices, token_all_mask, mask_ratio)
    return_dict = {
        'logits': logits,
        'likelihoods': latent_likelihoods,
        'task_loss': task_loss,
        'token_indices': token_indices,
        'token_all_mask': token_all_mask,
        'bs_mask_token': bs_mask_token,
        'mask_ratio': mask_ratio,
        'lambda': lmbda,
        'mask_vis': 1 - mask_vis,  # inverted for visualization: 1 = kept
        'z_H': z_H,
        'z_W': z_W,
    }
    return return_dict
623
+
624
+ # def update(self, scale_table=None, force=False):
625
+ # if scale_table is None:
626
+ # scale_table = get_scale_table()
627
+ # updated = self.gaussian_conditional.update_scale_table(scale_table, force=force)
628
+ # updated |= super().update(force=force)
629
+ # return updated
630
+
631
def gen_img(self, logits, token_all_mask, token_indices, z_H, z_W, num_iter=12, choice_temperature=4.5):
    """
    generated image at inference (MaskGIT-style iterative decoding)
    logits: predicted logits by model decoder (first decoding pass)
    token_all_mask: mask token indices
    token_indices: token indices of the input image after the vq tokenizer
    z_H, z_W: latent grid height/width of the VQ tokenizer
    num_iter: number of iterations for sampling
    choice_temperature: temperature for confidence-based re-masking
    """
    # torch.manual_seed(seed)
    # np.random.seed(seed)
    bsz = logits.size(0)
    # hard-coded VQGAN codebook geometry — presumably matches self.vqgan; TODO confirm
    codebook_emb_dim = 256
    codebook_size = 1024
    mask_token_id = self.mask_token_label
    _CONFIDENCE_OF_KNOWN_TOKENS = +np.inf  # known tokens are never re-masked
    unknown_number_in_the_beginning = torch.sum(token_all_mask, dim=-1, keepdims=True).float()
    for step in range(num_iter):
        if step == 0:
            # First step reuses the logits already produced by the caller.
            cur_ids = token_indices.clone().long()  # current sequence state (unmasked tokens)
            cur_ids = cur_ids[:, 1:]  # strip the [CLS] slot
            logits = logits[:, 1:, :codebook_size]
            # the author said a little tricky here, "For iter=1, they use argmax and temp=0.0.
            # For iter=6, we use categorical sampling and temp=4.5."
            sample_dist = torch.distributions.categorical.Categorical(logits=logits)
            sampled_ids = sample_dist.sample()  # sampled_ids = torch.argmax(logits, dim=-1)
            # get ids for next step:
            # unknown_map marks positions still holding the mask token; at those
            # positions take the freshly sampled id, elsewhere keep the known id.
            unknown_map = (cur_ids == mask_token_id)
            sampled_ids = torch.where(unknown_map, sampled_ids, cur_ids)
            # Defines the mask ratio for the next round. The number to mask out is
            # determined by mask_ratio * unknown_number_in_the_beginning.
            ratio = 1. * (step + 1) / num_iter
            mask_ratio = np.cos(math.pi / 2. * ratio)  # cosine schedule

            # sample ids according to prediction confidence
            probs = torch.nn.functional.softmax(logits, dim=-1)
            selected_probs = torch.squeeze(
                torch.gather(probs, dim=-1, index=torch.unsqueeze(sampled_ids, -1)), -1)

            # known positions get infinite confidence so they are never re-masked
            selected_probs = torch.where(unknown_map, selected_probs.double(), _CONFIDENCE_OF_KNOWN_TOKENS).float()
            unknown_number_in_the_beginning = unknown_number_in_the_beginning.clone().detach().cuda()
            mask_ratio = torch.tensor(mask_ratio).cuda()
            # mask_len = torch.tensor([np.floor(unknown_number_in_the_beginning.numpy() * mask_ratio.numpy())]).cuda()
            mask_len = torch.floor(unknown_number_in_the_beginning * mask_ratio).long()  # tokens left masked this iter
            # Keeps at least one of prediction in this round and also masks out at least
            # one and for the next iteration
            mask_len = torch.maximum(torch.Tensor([1]).cuda(),
                                     torch.minimum(torch.sum(unknown_map, dim=-1, keepdims=True) - 1, mask_len))

            # Sample masking tokens for next iteration
            masking = mask_by_random_topk(mask_len[0], selected_probs, choice_temperature * (1 - ratio))
            # Masks tokens with lower confidence.
            token_indices = torch.where(masking, mask_token_id, sampled_ids)
        else:
            # Subsequent steps re-encode and re-decode the partially filled sequence.
            cur_ids = token_indices.clone().long()  # .long(): to int64
            token_indices = torch.cat(
                [torch.zeros(token_indices.size(0), 1).cuda(device=token_indices.device), token_indices], dim=1)
            token_indices[:, 0] = self.fake_class_label
            token_indices = token_indices.long()
            token_all_mask = token_indices == mask_token_id

            token_drop_mask = torch.zeros_like(token_indices)  # nothing is dropped at inference

            # token embedding
            input_embeddings = self.token_emb(token_indices)  # get input embeddings

            # encoder
            x = input_embeddings
            for blk in self.blocks:
                x = blk(x)
            x = self.norm(x)

            # decoder
            logits = self.forward_decoding(x, token_drop_mask, token_all_mask)
            logits = logits[:, 1:, :codebook_size]  # remove the cls token and dims > codebook_size

            # get token prediction
            # the author said a little tricky here, "For iter=1, they use argmax and temp=0.0.
            # For iter=6, we use categorical sampling and temp=4.5."
            sample_dist = torch.distributions.categorical.Categorical(logits=logits)
            sampled_ids = sample_dist.sample()  # sampled_ids = torch.argmax(logits, dim=-1)

            # get ids for next step:
            # take sampled ids at still-unknown positions, keep known ids elsewhere
            unknown_map = (cur_ids == mask_token_id)
            sampled_ids = torch.where(unknown_map, sampled_ids, cur_ids)
            # Defines the mask ratio for the next round. The number to mask out is
            # determined by mask_ratio * unknown_number_in_the_beginning.
            ratio = 1. * (step + 1) / num_iter

            mask_ratio = np.cos(math.pi / 2. * ratio)  # cosine schedule

            # sample ids according to prediction confidence
            probs = torch.nn.functional.softmax(logits, dim=-1)
            selected_probs = torch.squeeze(
                torch.gather(probs, dim=-1, index=torch.unsqueeze(sampled_ids, -1)), -1)

            selected_probs = torch.where(unknown_map, selected_probs.double(), _CONFIDENCE_OF_KNOWN_TOKENS).float()
            unknown_number_in_the_beginning = unknown_number_in_the_beginning.clone().detach().cuda()
            mask_ratio = torch.tensor(mask_ratio).cuda()
            mask_len = torch.floor(unknown_number_in_the_beginning * mask_ratio).long()  # tokens left masked this iter
            # Keeps at least one of prediction in this round and also masks out at least
            # one and for the next iteration
            mask_len = torch.maximum(torch.Tensor([1]).cuda(),
                                     torch.minimum(torch.sum(unknown_map, dim=-1, keepdims=True) - 1, mask_len))

            # Sample masking tokens for next iteration
            masking = mask_by_random_topk(mask_len[0], selected_probs, choice_temperature * (1 - ratio))
            # Masks tokens with lower confidence.
            token_indices = torch.where(masking, mask_token_id, sampled_ids)

    # vqgan visualization: decode the final token grid with the VQGAN decoder
    z_q = self.vqgan.quantize.get_codebook_entry(sampled_ids, shape=(bsz, z_H, z_W, codebook_emb_dim))
    gen_images = self.vqgan.decode(z_q)
    return gen_images
752
+
753
+
754
def mage_vit_base_patch16(**kwargs):
    """Build the ViT-Base MAGE codec: 12-layer/768-dim encoder, 8-layer/768-dim decoder."""
    norm = partial(nn.LayerNorm, eps=1e-6)
    return MaskedGenerativeEncoderViT(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        decoder_embed_dim=768,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=norm,
        **kwargs,
    )
760
+
761
+
762
def mage_vit_large_patch16(**kwargs):
    """Build the ViT-Large MAGE codec: 24-layer/1024-dim encoder, 8-layer/1024-dim decoder."""
    norm = partial(nn.LayerNorm, eps=1e-6)
    return MaskedGenerativeEncoderViT(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        decoder_embed_dim=1024,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=norm,
        **kwargs,
    )
deploited/random_with_rope.log ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2024-04-26 09:40:33,300 [INFO ] Logging file is /home/v-ruoyufeng/v-ruoyufeng/qyp/mage/MIM_vbr//20240426_094033.log
2
+ 2024-04-26 09:40:33,301 [INFO ] ======================= exp =======================
3
+ 2024-04-26 09:40:33,301 [INFO ] config:config/vpt_default.yaml
4
+ 2024-04-26 09:40:33,301 [INFO ] name:exp
5
+ 2024-04-26 09:40:33,301 [INFO ] lr:None
6
+ 2024-04-26 09:40:33,301 [INFO ] world_size:1
7
+ 2024-04-26 09:40:33,301 [INFO ] local_rank:-1
8
+ 2024-04-26 09:40:33,301 [INFO ] dist_on_itp:False
9
+ 2024-04-26 09:40:33,301 [INFO ] dist_url:env://
10
+ 2024-04-26 09:40:33,301 [INFO ] TEST:False
11
+ 2024-04-26 09:40:33,301 [INFO ] model:mage_vit_base_patch16
12
+ 2024-04-26 09:40:33,301 [INFO ] dataset:imagenet
13
+ 2024-04-26 09:40:33,301 [INFO ] dataset_path:/home/v-ruoyufeng/v-ruoyufeng/qyp/datasets/imagenet
14
+ 2024-04-26 09:40:33,301 [INFO ] eval_path:
15
+ 2024-04-26 09:40:33,302 [INFO ] batch_size:64
16
+ 2024-04-26 09:40:33,302 [INFO ] test_batch_size:96
17
+ 2024-04-26 09:40:33,302 [INFO ] num_workers:4
18
+ 2024-04-26 09:40:33,302 [INFO ] epochs:400
19
+ 2024-04-26 09:40:33,302 [INFO ] blr:0.0002
20
+ 2024-04-26 09:40:33,302 [INFO ] min_lr:5e-06
21
+ 2024-04-26 09:40:33,302 [INFO ] accum_iter:1
22
+ 2024-04-26 09:40:33,302 [INFO ] input_size:256
23
+ 2024-04-26 09:40:33,302 [INFO ] weight_decay:0.05
24
+ 2024-04-26 09:40:33,302 [INFO ] warmup_epochs:20
25
+ 2024-04-26 09:40:33,302 [INFO ] mask_ratio_min:0.5
26
+ 2024-04-26 09:40:33,302 [INFO ] mask_ratio_max:0.9
27
+ 2024-04-26 09:40:33,302 [INFO ] mask_ratio_mu:0.55
28
+ 2024-04-26 09:40:33,302 [INFO ] mask_ratio_std:0.25
29
+ 2024-04-26 09:40:33,303 [INFO ] grad_clip:3.0
30
+ 2024-04-26 09:40:33,303 [INFO ] device:cuda
31
+ 2024-04-26 09:40:33,303 [INFO ] distributed:True
32
+ 2024-04-26 09:40:33,303 [INFO ] seed:0
33
+ 2024-04-26 09:40:33,303 [INFO ] resume:/home/v-ruoyufeng/v-ruoyufeng/qyp/mimc_cls/checkpoint-220.pth
34
+ 2024-04-26 09:40:33,303 [INFO ] start_epoch:0
35
+ 2024-04-26 09:40:33,303 [INFO ] pin_mem:True
36
+ 2024-04-26 09:40:33,303 [INFO ] save:True
37
+ 2024-04-26 09:40:33,303 [INFO ] exp_name:MIM_vbr
38
+ 2024-04-26 09:40:33,303 [INFO ] root:/home/v-ruoyufeng/v-ruoyufeng/qyp/mage
39
+ 2024-04-26 09:40:33,303 [INFO ] output_dir:/home/v-ruoyufeng/v-ruoyufeng/qyp/mage/MIM_vbr
40
+ 2024-04-26 09:40:33,303 [INFO ] log_dir:/home/v-ruoyufeng/v-ruoyufeng/qyp/mage/MIM_vbr
41
+ 2024-04-26 09:40:33,303 [INFO ] strart_epoch:0
42
+ 2024-04-26 09:40:33,303 [INFO ] cuda:True
43
+ 2024-04-26 09:40:33,304 [INFO ] rank:0
44
+ 2024-04-26 09:40:33,304 [INFO ] gpu:0
45
+ 2024-04-26 09:40:33,304 [INFO ] dist_backend:nccl
46
+ 2024-04-26 09:40:33,304 [INFO ] ===================================================
47
+ 2024-04-26 09:43:59,448 [INFO ] -1|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:11.19228|msssim:0.19877|lpips:0.65322|dists:0.41351|Test loss:26.43571
48
+ 2024-04-26 11:03:41,816 [INFO ] 0|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:11.59349|msssim:0.26058|lpips:0.58663|dists:0.36222|Test loss:24.07364
49
+ 2024-04-26 12:23:19,122 [INFO ] 1|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:11.76288|msssim:0.28263|lpips:0.56872|dists:0.33883|Test loss:23.62683
50
+ 2024-04-26 13:43:34,495 [INFO ] 2|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:11.85301|msssim:0.28844|lpips:0.55959|dists:0.32906|Test loss:23.47551
51
+ 2024-04-26 15:03:52,749 [INFO ] 3|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:11.92868|msssim:0.29865|lpips:0.55056|dists:0.32090|Test loss:23.33249
52
+ 2024-04-26 16:23:59,470 [INFO ] 4|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:11.89536|msssim:0.30246|lpips:0.54562|dists:0.31920|Test loss:23.25514
53
+ 2024-04-26 17:44:34,630 [INFO ] 5|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:11.77614|msssim:0.30227|lpips:0.54257|dists:0.30956|Test loss:23.14211
54
+ 2024-04-26 19:05:03,170 [INFO ] 6|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:11.99645|msssim:0.30600|lpips:0.53496|dists:0.30253|Test loss:23.07454
55
+ 2024-04-26 20:25:38,328 [INFO ] 7|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.01199|msssim:0.31352|lpips:0.53121|dists:0.29895|Test loss:23.00766
56
+ 2024-04-26 21:46:25,922 [INFO ] 8|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:11.95318|msssim:0.31135|lpips:0.53315|dists:0.30050|Test loss:22.96596
57
+ 2024-04-26 23:07:15,459 [INFO ] 9|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:11.86282|msssim:0.31913|lpips:0.52871|dists:0.30449|Test loss:22.93078
58
+ 2024-04-27 00:27:59,646 [INFO ] 10|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.02011|msssim:0.32396|lpips:0.52660|dists:0.29969|Test loss:22.88336
59
+ 2024-04-27 01:48:56,864 [INFO ] 11|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.07232|msssim:0.32188|lpips:0.52460|dists:0.29983|Test loss:22.86315
60
+ 2024-04-27 03:09:11,839 [INFO ] 12|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.01043|msssim:0.32597|lpips:0.52504|dists:0.30010|Test loss:22.79985
61
+ 2024-04-27 04:29:49,673 [INFO ] 13|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.10261|msssim:0.32433|lpips:0.52505|dists:0.29726|Test loss:22.78871
62
+ 2024-04-27 05:50:35,081 [INFO ] 14|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.11125|msssim:0.32866|lpips:0.51948|dists:0.29619|Test loss:22.71350
63
+ 2024-04-27 07:10:26,373 [INFO ] 15|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.06473|msssim:0.32313|lpips:0.51909|dists:0.29284|Test loss:22.72064
64
+ 2024-04-27 08:29:03,434 [INFO ] 16|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:11.93886|msssim:0.32386|lpips:0.51924|dists:0.29350|Test loss:22.67830
65
+ 2024-04-27 09:48:42,090 [INFO ] 17|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:11.91387|msssim:0.32367|lpips:0.51942|dists:0.29057|Test loss:22.68104
66
+ 2024-04-27 11:07:05,282 [INFO ] 18|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.04036|msssim:0.32969|lpips:0.51792|dists:0.29092|Test loss:22.64099
67
+ 2024-04-27 12:26:41,635 [INFO ] 19|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:11.97499|msssim:0.32280|lpips:0.51902|dists:0.28883|Test loss:22.61567
68
+ 2024-04-27 13:46:27,048 [INFO ] 20|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.04421|msssim:0.32968|lpips:0.51996|dists:0.28981|Test loss:22.59735
69
+ 2024-04-27 15:07:26,901 [INFO ] 21|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.09963|msssim:0.32759|lpips:0.51643|dists:0.29144|Test loss:22.58986
70
+ 2024-04-27 16:27:18,924 [INFO ] 22|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.25422|msssim:0.33293|lpips:0.51218|dists:0.28637|Test loss:22.55061
71
+ 2024-04-27 17:47:28,605 [INFO ] 23|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.03031|msssim:0.32350|lpips:0.51621|dists:0.28594|Test loss:22.55859
72
+ 2024-04-27 19:06:30,307 [INFO ] 24|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.01311|msssim:0.33033|lpips:0.51452|dists:0.28868|Test loss:22.51194
73
+ 2024-04-27 20:26:25,446 [INFO ] 25|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.12269|msssim:0.33470|lpips:0.51502|dists:0.28773|Test loss:22.53424
74
+ 2024-04-27 21:45:53,181 [INFO ] 26|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.22099|msssim:0.33310|lpips:0.51191|dists:0.28884|Test loss:22.50715
75
+ 2024-04-27 23:06:53,646 [INFO ] 27|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.08395|msssim:0.33514|lpips:0.51467|dists:0.28832|Test loss:22.50186
76
+ 2024-04-28 00:26:46,986 [INFO ] 28|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.23953|msssim:0.33534|lpips:0.51394|dists:0.28696|Test loss:22.49509
77
+ 2024-04-28 01:47:49,741 [INFO ] 29|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.11600|msssim:0.33224|lpips:0.51388|dists:0.28858|Test loss:22.48610
78
+ 2024-04-28 03:08:05,651 [INFO ] 30|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.04495|msssim:0.33490|lpips:0.51515|dists:0.28526|Test loss:22.47889
79
+ 2024-04-28 04:29:31,374 [INFO ] 31|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.13572|msssim:0.33166|lpips:0.51150|dists:0.28472|Test loss:22.44420
80
+ 2024-04-28 04:36:00,208 [INFO ] 32|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.29226|msssim:0.33557|lpips:0.51159|dists:0.28618|Test loss:22.46866
81
+ 2024-04-28 05:55:42,942 [INFO ] 33|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.24471|msssim:0.33982|lpips:0.50918|dists:0.28688|Test loss:22.45550
82
+ 2024-04-28 07:14:35,318 [INFO ] 34|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.26866|msssim:0.34018|lpips:0.50957|dists:0.28656|Test loss:22.46537
83
+ 2024-04-28 08:33:33,967 [INFO ] 35|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.17031|msssim:0.33502|lpips:0.51055|dists:0.28262|Test loss:22.42946
84
+ 2024-04-28 09:54:29,978 [INFO ] 36|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.25919|msssim:0.33783|lpips:0.51071|dists:0.28604|Test loss:22.45442
85
+ 2024-04-28 11:15:39,983 [INFO ] 37|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.20635|msssim:0.33632|lpips:0.50824|dists:0.28454|Test loss:22.45101
86
+ 2024-04-28 12:35:52,799 [INFO ] 38|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.32195|msssim:0.34213|lpips:0.50806|dists:0.28493|Test loss:22.43027
87
+ 2024-04-28 13:55:26,853 [INFO ] 39|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.36062|msssim:0.33969|lpips:0.50819|dists:0.28332|Test loss:22.41071
88
+ 2024-04-28 15:15:39,162 [INFO ] 40|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.22780|msssim:0.33774|lpips:0.50971|dists:0.28230|Test loss:22.41198
89
+ 2024-04-28 16:36:09,180 [INFO ] 41|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.29715|msssim:0.34114|lpips:0.50735|dists:0.28232|Test loss:22.39130
90
+ 2024-04-28 17:57:27,808 [INFO ] 42|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.09102|msssim:0.33483|lpips:0.50913|dists:0.28347|Test loss:22.41543
91
+ 2024-04-28 19:17:26,206 [INFO ] 43|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.20306|msssim:0.34007|lpips:0.51021|dists:0.28400|Test loss:22.40152
92
+ 2024-04-28 20:35:40,022 [INFO ] 44|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.22824|msssim:0.33591|lpips:0.50983|dists:0.28228|Test loss:22.42137
93
+ 2024-04-28 21:54:40,195 [INFO ] 45|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.38086|msssim:0.34395|lpips:0.50842|dists:0.28619|Test loss:22.39918
94
+ 2024-04-28 23:13:35,391 [INFO ] 46|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.16532|msssim:0.33896|lpips:0.50933|dists:0.28314|Test loss:22.41701
95
+ 2024-04-29 00:32:11,881 [INFO ] 47|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.26819|msssim:0.33832|lpips:0.50565|dists:0.28173|Test loss:22.37620
96
+ 2024-04-29 01:52:28,969 [INFO ] 48|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.29298|msssim:0.34225|lpips:0.50837|dists:0.28685|Test loss:22.38713
97
+ 2024-04-29 03:11:36,092 [INFO ] 49|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.10050|msssim:0.33825|lpips:0.50617|dists:0.28070|Test loss:22.36497
98
+ 2024-04-29 04:31:26,489 [INFO ] 50|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.33884|msssim:0.34209|lpips:0.50693|dists:0.28271|Test loss:22.37217
99
+ 2024-04-29 05:52:26,051 [INFO ] 51|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.32687|msssim:0.34543|lpips:0.50534|dists:0.28436|Test loss:22.36113
100
+ 2024-04-29 07:12:41,027 [INFO ] 52|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.23419|msssim:0.34313|lpips:0.50683|dists:0.28346|Test loss:22.34546
101
+ 2024-04-29 08:32:37,029 [INFO ] 53|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.19908|msssim:0.34212|lpips:0.51034|dists:0.28401|Test loss:22.35217
102
+ 2024-04-29 09:53:24,004 [INFO ] 54|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.20177|msssim:0.34039|lpips:0.50786|dists:0.28354|Test loss:22.35897
103
+ 2024-04-29 11:13:36,377 [INFO ] 55|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.26382|msssim:0.34287|lpips:0.50529|dists:0.28344|Test loss:22.36005
104
+ 2024-04-29 12:34:01,726 [INFO ] 56|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.17537|msssim:0.34354|lpips:0.50563|dists:0.28147|Test loss:22.34528
105
+ 2024-04-29 13:55:09,515 [INFO ] 57|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.12880|msssim:0.34119|lpips:0.50656|dists:0.28078|Test loss:22.33371
106
+ 2024-04-29 15:16:53,478 [INFO ] 58|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.24665|msssim:0.34374|lpips:0.50576|dists:0.28350|Test loss:22.32770
107
+ 2024-04-29 16:37:56,041 [INFO ] 59|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.21803|msssim:0.34118|lpips:0.50617|dists:0.28309|Test loss:22.35576
108
+ 2024-04-29 17:57:24,559 [INFO ] 60|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.27436|msssim:0.34233|lpips:0.50439|dists:0.28176|Test loss:22.33442
109
+ 2024-04-29 19:16:29,152 [INFO ] 61|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.18612|msssim:0.33822|lpips:0.50917|dists:0.28183|Test loss:22.35925
110
+ 2024-04-29 20:35:53,546 [INFO ] 62|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.28599|msssim:0.34549|lpips:0.50387|dists:0.28009|Test loss:22.33322
111
+ 2024-04-29 21:55:16,529 [INFO ] 63|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.21795|msssim:0.34211|lpips:0.50608|dists:0.27972|Test loss:22.33488
112
+ 2024-04-29 23:14:30,846 [INFO ] 64|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.29407|msssim:0.34144|lpips:0.50596|dists:0.28011|Test loss:22.35140
113
+ 2024-04-29 23:19:58,151 [INFO ] 65|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.28235|msssim:0.34247|lpips:0.50468|dists:0.27957|Test loss:22.33808
114
+ 2024-04-30 00:39:16,656 [INFO ] 66|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.25703|msssim:0.34225|lpips:0.50430|dists:0.28089|Test loss:22.31090
115
+ 2024-04-30 01:59:45,043 [INFO ] 67|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.26808|msssim:0.34043|lpips:0.50586|dists:0.28125|Test loss:22.32784
116
+ 2024-04-30 03:21:24,464 [INFO ] 68|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.31984|msssim:0.34534|lpips:0.50445|dists:0.28171|Test loss:22.30779
117
+ 2024-04-30 04:42:24,830 [INFO ] 69|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.35958|msssim:0.34179|lpips:0.50447|dists:0.28130|Test loss:22.31575
118
+ 2024-04-30 06:03:32,581 [INFO ] 70|bpp:0.01874|mask:0.00888|mask_ratio:0.75|psnr:12.32468|msssim:0.34391|lpips:0.50387|dists:0.28060|Test loss:22.32640
deploited/random_with_rope.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b81f2e9114c831f49bf4926d5a40d8655a9d21376f3e71fcc66ecfb1c5db1cb7
3
+ size 1960373872
deploited/test_crop.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from util.utils import pad_to_multiple_of_256, split_into_blocks, merge_blocks, crop_to_original_shape
2
+ import glob
3
+ import os
4
+ import torch
5
+ import numpy as np
6
+ import math
7
+ from torch.nn import functional as F
8
+ import PIL.Image as Image
9
+ from torchvision import utils as vutils
10
+
11
+
12
def load_img(p, padding=True, factor=64):
    """Load an image as a (1, 3, H, W) float tensor scaled to [0, 1].

    Grayscale inputs are replicated to 3 channels. When `padding` is True the
    tensor is symmetrically zero-padded so H and W become multiples of
    `factor`; any odd remainder goes to the bottom/right edge.

    Returns (tensor, original_height, original_width).
    """
    img = torch.from_numpy(np.asarray(Image.open(p)))
    if len(img.shape) == 2:  # grayscale -> h,w,3
        img = img.unsqueeze(-1).repeat(1, 1, 3)
    img = img.permute(2, 0, 1).unsqueeze(0).float().div(255)
    h, w = img.shape[2:4]

    if padding:
        pad_h = factor * math.ceil(h / factor) - h
        pad_w = factor * math.ceil(w / factor) - w
        # split padding evenly; the extra pixel (odd pad) lands bottom/right
        top, left = pad_h // 2, pad_w // 2
        bottom, right = pad_h - top, pad_w - left
        img = F.pad(img, (left, right, top, bottom))
    return img, h, w
30
+
31
def save_img(img: torch.Tensor, vis_path, input_p, rec=False):
    """Save a tensor image under vis_path, named after input_p's basename.

    img: image tensor to save (moved to CPU before writing).
    vis_path: output directory; created if missing.
    input_p: source path whose trailing '/name' component names the output.
    rec: when True, save into the 'rec' subdirectory instead.
    """
    img = img.clone().detach().to(torch.device('cpu'))
    if rec:
        vis_path = vis_path + '/rec'
    # os.makedirs with exist_ok replaces the original
    # `if os.path.isdir(...) is not True` anti-idiom and also creates parents.
    os.makedirs(vis_path, exist_ok=True)
    end = '/'
    # rfind keeps the leading '/', so this is vis_path + '/<basename>'
    img_name = vis_path + str(input_p[input_p.rfind(end):])
    vutils.save_image(img, img_name, nrow=8)
45
+
46
+
47
# Smoke test for the 256-block split/merge utilities: load each test image,
# pad it, split into blocks, merge them back and crop, saving both the padded
# input and the round-tripped result for visual comparison.
eval_path = sorted(glob.glob(os.path.join('/home/t2vg-a100-G4-10/project/qyp/datasets/test', '*.jpg')))
vis_path = os.path.join("./test_crop/")
os.makedirs(vis_path, exist_ok=True)

for input_p in eval_path:
    # pad to a multiple of 64 first, remembering the original size
    x, hx, wx = load_img(input_p, padding=True, factor=64)
    print("ori height", hx, "ori width", wx)
    ori_shape = x.shape
    print("input shape", ori_shape)
    # further pad to a multiple of 256 so the image tiles into fixed blocks
    x = pad_to_multiple_of_256(x, 0)
    save_img(x, vis_path, input_p, rec=False)
    print("shape after padding", x.shape)
    _, _, new_h, new_w = x.shape
    x = split_into_blocks(x)
    print("new shape", x.shape)
    new_bsz = x.size(0)
    new_shape = [ori_shape[0], 3, new_h, new_w]
    # merging then cropping should reproduce the padded input exactly
    x = merge_blocks(x, new_shape)
    print("shape after merge", x.shape)
    x = crop_to_original_shape(x, ori_shape)
    print("shape after crop", x.shape)
    save_img(x, vis_path, input_p, rec=True)
deploited/train_vbr_codec_rope.py ADDED
@@ -0,0 +1,673 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tqdm
2
+ import argparse
3
+ import math
4
+ # import torchac
5
+ import sys
6
+ import os
7
+ import time
8
+ import logging
9
+ from datetime import datetime
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.optim as optim
14
+
15
+ import torchvision
16
+ from torch.utils.data import DataLoader
17
+ from torchvision import transforms
18
+ from torchvision.models import resnet50
19
+
20
+ import yaml
21
+ from pytorch_msssim import ms_ssim
22
+ from DISTS_pytorch import DISTS
23
+ from util.lpips import LPIPS
24
+ from torch.nn import functional as F
25
+ from torchvision import utils as vutils
26
+ import matplotlib.pyplot as plt
27
+ import numpy as np
28
+ import glob
29
+
30
+ import util.misc as misc
31
+ import util.lr_sched as lr_sched
32
+ from torch.utils.tensorboard import SummaryWriter
33
+ import models_mage_codec
34
+ import models_mage_codec_rope
35
+ import timm.optim.optim_factory as optim_factory
36
+ from util.misc import NativeScalerWithGradNormCount as NativeScaler
37
+ import json
38
+ import PIL.Image as Image
39
+ import torch.backends.cudnn as cudnn
40
+ from pathlib import Path
41
+ import random
42
+ import torch.distributed as dist
43
+
44
class CalMetrics(nn.Module):
    """Compute rate (bpp) and distortion metrics (PSNR, MS-SSIM, LPIPS, DISTS)
    plus the combined rate-distortion training loss for a reconstruction.

    Expects images in [0, 1] with shape (B, 3, H, W). `out_net` is the codec's
    output dict; the keys read here are 'likelihoods', 'bs_mask_token',
    'task_loss' and 'lambda'.
    """

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()
        # Lazily-built perceptual metric networks, cached across calls.
        # The original code constructed LPIPS (a VGG) and DISTS from scratch
        # on EVERY metric evaluation, which is extremely wasteful.
        self._lpips_net = None
        self._dists_net = None

    def bpp_loss(self, ori, out_net):
        """Return (bpp of the entropy-coded tokens, bpp of the mask bitstream)."""
        b, _, h, w = ori.shape
        num_pixels = b * h * w
        # Bits = -log2(likelihood); averaged over all pixels of the batch.
        bpp = torch.log(out_net["likelihoods"]).sum() / (-math.log(2) * num_pixels)
        bs_mask_token = out_net['bs_mask_token']
        bytes_length = len(bs_mask_token)
        # Each byte carries 8 bits, so total bits = byte count * 8.
        total_bits = bytes_length * 8
        # Bits per pixel of the mask side-information stream.
        bpp_mask = total_bits / num_pixels
        return bpp, bpp_mask

    def psnr(self, rec, ori):
        """PSNR in dB for images in [0, 1]; returns 100 for a perfect match."""
        mse = torch.mean((rec - ori) ** 2)
        if(mse == 0):
            return 100
        max_pixel = 1.
        psnr = 10 * torch.log10(max_pixel / mse)
        return torch.mean(psnr)

    def lpips(self, rec, ori):
        """Mean LPIPS distance; the LPIPS net is built once and reused."""
        if self._lpips_net is None:
            self._lpips_net = LPIPS().eval()
        lpips_func = self._lpips_net.to(device=rec.device)
        lipis_value = lpips_func(rec, ori)
        return lipis_value.mean()

    def dists(self, rec, ori):
        """Mean DISTS distance; the DISTS net is built once and reused."""
        if self._dists_net is None:
            self._dists_net = DISTS()
        D = self._dists_net.cuda()
        dists_value = D(rec, ori)
        return dists_value.mean()

    def cal_total_loss(self, lpips, bpp, out_net):
        """Rate-distortion objective: bpp + lambda * task_loss."""
        # task_loss = out_net['task_loss'] + 0.1 * lpips
        task_loss = out_net['task_loss']
        total_loss = bpp + out_net['lambda'] * task_loss
        return total_loss

    def forward(self, ori, out_net, rec=None):
        """Return a dict of rate metrics, and distortion metrics when `rec` is given."""
        out = {}
        out["bpp"], out["bpp_mask"] = self.bpp_loss(ori, out_net)
        out["bpp_loss"] = out["bpp"] + out["bpp_mask"]
        if rec is not None:
            # Clamp to the valid image range before comparing.
            out["psnr"] = self.psnr(torch.clamp(rec, 0, 1), ori)
            out["msssim"] = ms_ssim(torch.clamp(rec, 0, 1), ori, data_range=1, size_average=True)
            out["lpips"] = self.lpips(torch.clamp(rec, 0, 1), ori)
            out["dists"] = self.dists(torch.clamp(rec, 0, 1), ori)
            out["total_loss"] = self.cal_total_loss(out["lpips"], out["bpp_loss"], out_net)
        return out
103
+
104
+
105
class FeatureHook():
    """Caches the most recent forward output of a module.

    Registers `attach` as a forward hook on `module`; after every forward
    pass, the activation is available as `self.feature` (used later to
    compute a perceptual loss over intermediate layers).
    """

    def __init__(self, module):
        # The hook fires on every forward call of `module`.
        module.register_forward_hook(self.attach)

    def attach(self, model, input, output):
        # Keep a handle to the latest activation tensor.
        self.feature = output
111
+
112
+
113
class Clsloss(nn.Module):
    """Classification / perceptual loss on reconstructed images.

    Uses a frozen ImageNet-pretrained ResNet-50 either as a classifier
    (cross-entropy + top-1 accuracy on the reconstruction) or as a feature
    extractor for a layer-wise perceptual loss against the original image.
    """

    def __init__(self, device, cls_loss=True) -> None:
        super().__init__()
        self.ce = nn.CrossEntropyLoss()
        self.classifier = resnet50(True)
        self.classifier.requires_grad_(False)
        self.hooks = [FeatureHook(i) for i in [  # for calculating perceptual loss
            self.classifier.layer1,
            self.classifier.layer2,
            self.classifier.layer3,
            self.classifier.layer4,
        ]]
        self.classifier = self.classifier.to(device)
        for k, p in self.classifier.named_parameters():
            p.requires_grad = False
        self.classifier.eval()
        self.cls_loss = cls_loss
        # BUGFIX: `forward` branches on `self.perceptual_loss`, which was never
        # assigned and raised AttributeError. The call sites construct
        # Clsloss(device, True) and unpack three return values, so the
        # perceptual branch is the complement of the classification branch.
        self.perceptual_loss = not cls_loss
        # ImageNet normalization applied before feeding images to the classifier.
        self.transform = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    @staticmethod
    def accuracy(output, target, topk=(1,)):
        """Top-k accuracy (percent) of `output` logits against `target` labels.

        BUGFIX: this was declared as an instance method without `self`, so any
        instance call bound `output` to the instance; it is a pure function and
        is now a @staticmethod.
        """
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # reshape() instead of view(): the slice may be non-contiguous.
            correct_k = correct[:k].reshape(-1).float().sum(0)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

    def forward(self, d, rec, y_true):
        # d = original image, rec = reconstruction, y_true = class label.
        x_hat = torch.clamp(rec, 0, 1)
        pred = self.classifier(self.transform(x_hat))  # transform is normalization
        cls_loss = self.ce(pred, y_true)
        accu = sum(torch.argmax(pred, -1) == y_true) / pred.shape[0]
        if self.perceptual_loss:
            # Layer-wise MSE between classifier features of rec and original.
            pred_feat = [i.feature.clone() for i in self.hooks]
            _ = self.classifier(self.transform(d))
            ori_feat = [i.feature.clone() for i in self.hooks]
            perc_loss = torch.stack([nn.functional.mse_loss(p, o, reduction='none').mean((1, 2, 3)) for p, o in zip(pred_feat, ori_feat)])
            perc_loss = perc_loss.mean()
            return perc_loss

        return cls_loss, accu, None
161
+
162
+
163
class AverageMeter:
    """Running (optionally weighted) average tracker.

    Attributes:
        val: last value passed to `update`.
        sum: weighted sum of all values seen so far.
        count: total weight seen so far.
        avg: sum / count.
    """

    def __init__(self):
        # Start from zero; `avg` is only meaningful after the first update.
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record `val` with weight `n` and refresh the running average."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count
177
+
178
class CustomDataParallel(nn.DataParallel):
    """DataParallel wrapper that forwards unknown attributes to the wrapped module.

    Plain nn.DataParallel hides the underlying model's custom attributes and
    methods behind `.module`; this subclass falls back to the wrapped module
    whenever normal attribute lookup fails.
    """

    def __getattr__(self, key):
        try:
            # First give nn.DataParallel / nn.Module lookup a chance.
            return super().__getattr__(key)
        except AttributeError:
            # Not found on the wrapper — delegate to the wrapped module.
            return getattr(self.module, key)
186
+
187
+
188
def init(args):
    """Create (if missing) and return the experiment output directory.

    The directory is `<args.root>/<args.exp_name>/`, returned with a
    trailing slash so callers can concatenate file names directly.
    """
    base_dir = "{}/{}/".format(args.root, args.exp_name)
    os.makedirs(base_dir, exist_ok=True)
    return base_dir
192
+
193
def setup_logger(log_dir):
    """Configure the root logger to write INFO+ records to both a file and stdout.

    `log_dir` is the full path of the log file (despite the name). Handlers
    are appended, so calling this repeatedly adds duplicate handlers.
    """
    log_formatter = logging.Formatter("%(asctime)s [%(levelname)-5.5s] %(message)s")
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)

    # File handler first, then stdout — same order and format for both.
    for handler in (logging.FileHandler(log_dir, encoding='utf-8'),
                    logging.StreamHandler(sys.stdout)):
        handler.setFormatter(log_formatter)
        root_logger.addHandler(handler)

    logging.info('Logging file is %s' % log_dir)
207
+
208
+
209
def load_img(p, padding=True, factor=64):
    """Load image `p` as a (1, 3, H', W') float tensor in [0, 1].

    Grayscale images are replicated to 3 channels; RGBA images have their
    alpha channel dropped (the original code crashed on 4-channel inputs).
    When `padding` is True the image is zero-padded on the right/bottom so
    both sides are multiples of `factor`.

    Returns:
        (tensor, original_height, original_width) — the original size is
        needed to crop the reconstruction back after decoding.
    """
    x = Image.open(p)
    x = torch.from_numpy(np.asarray(x))
    if len(x.shape) == 2:
        x = x.unsqueeze(-1).repeat(1, 1, 3)  # grayscale h,w -> h,w,3
    elif x.shape[-1] == 4:
        x = x[..., :3]  # RGBA -> RGB: discard the alpha channel
    # Assumes 8-bit pixel values — TODO confirm for 16-bit PNG inputs.
    x = x.permute(2, 0, 1).unsqueeze(0).float().div(255)
    h, w = x.shape[2:4]

    if padding:
        dh = factor * math.ceil(h / factor) - h
        dw = factor * math.ceil(w / factor) - w
        x = F.pad(x, (0, dw, 0, dh))
    return x, h, w
222
+
223
def save_img(img: torch.Tensor, vis_path, input_p, mask=False):
    """Save `img` into directory `vis_path`, named after the source image path.

    The file name is the basename of `input_p`, optionally prefixed with
    'mask_' when `mask` is True.

    BUGFIX: the original built `img_name` by concatenating `vis_path` into it
    and then passed that through os.path.join(vis_path, img_name) again,
    duplicating the directory prefix; the mask branch additionally produced a
    path inside a nonexistent '<vis_path>mask_/' directory.
    """
    img = img.clone().detach()
    img = img.to(torch.device('cpu'))
    os.makedirs(vis_path, exist_ok=True)
    base = os.path.basename(input_p)
    img_name = ('mask_' + base) if mask else base
    vutils.save_image(img, os.path.join(vis_path, img_name), nrow=8)
234
+
235
def train_one_epoch(model, data_loader, metrics_criterion, device,
                    optimizer, epoch, loss_scaler, log_writer, args, val_dataloader=None, stage='train'):
    """Train the codec for one epoch with AMP + gradient accumulation.

    Args:
        model: DDP-wrapped codec (gen_img is reached via model.module).
        data_loader: training loader yielding (image, label) batches.
        metrics_criterion: CalMetrics instance producing the R-D loss.
        loss_scaler: AMP grad scaler wrapper (steps optimizer internally).
        log_writer: optional TensorBoard SummaryWriter.
    Returns:
        dict of global-average meter values for the epoch.
    """
    ## ======================= set configs ======================= ##
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 20
    accum_iter = args.accum_iter
    optimizer.zero_grad()
    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    # Visualization dumps for this stage go under ./MIM_vbr/<stage>/.
    vis_path = os.path.join("./MIM_vbr/", stage)
    os.makedirs(vis_path, exist_ok=True)

    for data_iter_step, (samples, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        samples = samples.to(device, non_blocking=True)  # samples = original image

        # we use a per iteration (instead of per epoch) lr scheduler
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)

        # Forward under autocast; decoding (gen_img) runs outside autocast.
        with torch.cuda.amp.autocast():
            out_net = model(samples, is_training=True, manual_mask_rate=None)

        rec = model.module.gen_img(out_net['logits'], out_net['token_all_mask'], out_net['token_indices'], out_net['z_H'], out_net['z_W'])
        rec = rec.to(device)
        out_criterion = metrics_criterion(samples, out_net, rec)
        loss_value = out_criterion['total_loss'].item()

        # Abort the run on NaN/Inf loss rather than corrupting the weights.
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # Scale for gradient accumulation; the scaler only steps the optimizer
        # on the last micro-batch of each accumulation window.
        out_criterion['total_loss'] /= accum_iter
        loss_scaler(out_criterion['total_loss'], optimizer, clip_grad=args.grad_clip, parameters=model.parameters(),
                    update_grad=(data_iter_step + 1) % accum_iter == 0)
        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()

        torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)

        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)
        metric_logger.update(bpp=out_criterion['bpp_loss'])
        metric_logger.update(bpp_mask=out_criterion['bpp_mask'])
        # NOTE: raw per-batch values are logged; the meters track running means.
        metric_logger.update(task_loss=out_net['task_loss'].item())
        metric_logger.update(lmbda=out_net['lambda'])
        metric_logger.update(mask_ratio=out_net['mask_ratio'])
        metric_logger.update(lpips=out_criterion['lpips'].item())
        metric_logger.update(dists=out_criterion['dists'].item())

        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
            """ We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('lr', lr, epoch_1000x)

        ## ======================= update progress bar & visualization ======================= ##
        # Dump original/reconstruction pairs and the mask every 500 steps.
        if data_iter_step % 500 == 0:
            with torch.no_grad():
                real_fake_images = torch.cat((samples, rec), dim=0)
                vutils.save_image(real_fake_images, os.path.join(vis_path, f"{epoch}_{data_iter_step}.jpg"), nrow=8)
                vutils.save_image(out_net['mask_vis'], os.path.join(vis_path, f"{epoch}_{data_iter_step}_mask.jpg"), nrow=8)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
316
+
317
def inference(epoch, test_loader, model, metrics_criterion, device, manual_mask_ratio, args, stage='test'):
    """Evaluate the codec on `test_loader` at a fixed mask ratio.

    Computes bpp, PSNR, MS-SSIM, LPIPS, DISTS and the total R-D loss over the
    loader; logs an averaged summary line on rank 0. Dumps reconstructions
    (every 5th batch for 'val', every batch for 'test').

    Returns:
        Average total loss over the evaluation set.
    """
    model.eval()
    bpp_loss = AverageMeter()
    bpp_mask = AverageMeter()
    psnr = AverageMeter()
    msssim = AverageMeter()
    lpips = AverageMeter()
    dists = AverageMeter()
    test_loss = AverageMeter()

    vis_path = os.path.join("./MIM_vbr/", stage)
    os.makedirs(vis_path, exist_ok=True)
    if stage == 'test':
        # Test reconstructions are saved per mask ratio for later FID scoring.
        test_vis_path = os.path.join("/home/t2vg-a100-G4-10/project/qyp/rec_fid", manual_mask_ratio)
        os.makedirs(test_vis_path, exist_ok=True)

    with torch.no_grad():
        tqdm_meter = tqdm.tqdm(enumerate(test_loader), leave=False, total=len(test_loader))
        for i, (d, l) in tqdm_meter:
            d = d.to(device)
            # Label `l` is unused here; this pass measures rate/distortion only.
            out_net = model(d, is_training=False, manual_mask_rate=manual_mask_ratio)
            rec = model.module.gen_img(out_net['logits'], out_net['token_all_mask'], out_net['token_indices'], out_net['z_H'], out_net['z_W'])

            rec = rec.to(device)
            out_criterion = metrics_criterion(d, out_net, rec)

            bpp_loss.update(out_criterion["bpp_loss"])
            bpp_mask.update(out_criterion["bpp_mask"])
            psnr.update(out_criterion['psnr'])
            msssim.update(out_criterion['msssim'])
            lpips.update(out_criterion['lpips'])
            dists.update(out_criterion['dists'])
            test_loss.update(out_criterion['total_loss'])

            ## ======================= update progress bar & visualization ======================= ##
            if stage == 'val':
                if i % 5 == 0:
                    with torch.no_grad():
                        real_fake_images = torch.cat((d, rec), dim=0)
                        vutils.save_image(real_fake_images, os.path.join(vis_path, f"{epoch}_{i}.jpg"), nrow=8)
                        vutils.save_image(out_net['mask_vis'], os.path.join(vis_path, f"{epoch}_{i}_mask.jpg"), nrow=8)
            if stage == 'test':
                with torch.no_grad():
                    vutils.save_image(rec, os.path.join(test_vis_path, f"{i}.jpg"), nrow=8)

    # Restore training mode before returning to the training loop.
    model.train()

    # Determine the process rank; default to 0 when not running under DDP.
    if torch.distributed.is_initialized():
        rank = dist.get_rank()
    else:
        rank = 0  # single-process (no DDP): treat as rank 0

    if rank == 0:
        log_txt = f"{epoch}|bpp:{bpp_loss.avg.item():.5f}|mask:{bpp_mask.avg:.5f}|mask_ratio:{manual_mask_ratio}|psnr:{psnr.avg.item():.5f}|msssim:{msssim.avg.item():.5f}|lpips:{lpips.avg.item():.5f}|dists:{dists.avg.item():.5f}|Test loss:{test_loss.avg.item():.5f}"
        logging.info(log_txt)
    return test_loss.avg
383
+
384
def inference_with_acc(epoch, test_loader, model, metrics_criterion, cls_criterion, device, manual_mask_ratio, args, stage='test'):
    """Evaluate the codec like `inference`, additionally measuring downstream
    classification accuracy of the reconstructions via `cls_criterion`.

    Reconstruction uses num_iter=12 decoding iterations. Saves every test
    reconstruction to a fixed per-mask-ratio directory.

    Returns:
        Average total loss over the evaluation set.
    """
    model.eval()
    bpp_loss = AverageMeter()
    bpp_mask = AverageMeter()
    psnr = AverageMeter()
    msssim = AverageMeter()
    lpips = AverageMeter()
    dists = AverageMeter()
    accuracy = AverageMeter()
    test_loss = AverageMeter()

    if stage == 'test':
        test_vis_path = os.path.join("/home/v-ruoyufeng/v-ruoyufeng/qyp/test_pos", str(manual_mask_ratio))
        os.makedirs(test_vis_path, exist_ok=True)

    with torch.no_grad():
        tqdm_meter = tqdm.tqdm(enumerate(test_loader), leave=False, total=len(test_loader))
        for i, (d, l) in tqdm_meter:
            d = d.to(device)
            out_net = model(d, is_training=False, manual_mask_rate=manual_mask_ratio)
            rec = model.module.gen_img(out_net['logits'], out_net['token_all_mask'], out_net['token_indices'], out_net['z_H'], out_net['z_W'], num_iter=12)

            rec = rec.to(device)
            out_criterion = metrics_criterion(d, out_net, rec)
            # Classification accuracy of the reconstruction against label `l`.
            _, accu, _ = cls_criterion(d, rec, l)

            bpp_loss.update(out_criterion["bpp_loss"])
            bpp_mask.update(out_criterion["bpp_mask"])
            psnr.update(out_criterion['psnr'])
            msssim.update(out_criterion['msssim'])
            lpips.update(out_criterion['lpips'])
            dists.update(out_criterion['dists'])
            test_loss.update(out_criterion['total_loss'])
            accuracy.update(accu)

            ## ======================= update progress bar & visualization ======================= ##
            if stage == 'test':
                with torch.no_grad():
                    vutils.save_image(rec, os.path.join(test_vis_path, f"{i}.jpg"))

    # Restore training mode before returning to the caller.
    model.train()

    # Determine the process rank; default to 0 when not running under DDP.
    if torch.distributed.is_initialized():
        rank = dist.get_rank()
    else:
        rank = 0  # single-process (no DDP): treat as rank 0

    if rank == 0:
        log_txt = f"{epoch}|bpp:{bpp_loss.avg.item():.5f}|mask:{bpp_mask.avg:.5f}|mask_ratio:{manual_mask_ratio}|psnr:{psnr.avg.item():.5f}|msssim:{msssim.avg.item():.5f}|lpips:{lpips.avg.item():.5f}|dists:{dists.avg.item():.5f}|accu:{accuracy.avg:.5f}|Test loss:{test_loss.avg.item():.5f}"
        logging.info(log_txt)
    return test_loss.avg
447
+
448
def save_checkpoint(state, is_best, base_dir, filename="checkpoint.pth.tar"):
    """Persist `state` to `base_dir + filename`; when `is_best`, also snapshot
    it as `checkpoint_best.pth.tar` (base_dir is expected to end with '/')."""
    target = base_dir + filename
    torch.save(state, target)
    if is_best:
        # Keep a separate copy of the best-so-far checkpoint.
        torch.save(state, base_dir + "checkpoint_best.pth.tar")
453
+
454
def parse_args(argv):
    """Parse CLI args in two phases: first read --config (and --name/--lr),
    then load the YAML config as parser defaults, then parse the remaining
    args so explicit CLI flags override YAML values.

    Returns:
        argparse.Namespace with YAML values merged under CLI overrides.
    """
    parser = argparse.ArgumentParser(description="Example training script.")
    parser.add_argument(
        "-c",
        "--config",
        default="config/vpt_default.yaml",
        help="Path to config file",
    )
    parser.add_argument(
        '--name',
        default=datetime.now().strftime('%Y-%m-%d_%H_%M_%S'),
        type=str,
        help='Result dir name',
    )
    parser.add_argument('--lr', type=float, default=None, metavar='LR',
                        help='learning rate (absolute lr)')
    # Phase 1: only the args above; everything else stays in `remaining`.
    given_configs, remaining = parser.parse_known_args(argv)
    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local-rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')
    # Phase 2: YAML values become parser defaults (CLI still wins).
    with open(given_configs.config) as file:
        yaml_data= yaml.safe_load(file)
    parser.set_defaults(**yaml_data)

    parser.add_argument(
        "-T",
        "--TEST",
        # action='store_true',
        # NOTE(review): without store_true, any non-empty string passed to -T
        # is truthy — confirm intended semantics.
        default=False,
        help='Testing'
    )
    args = parser.parse_args(remaining)
    return args
491
+
492
def load_eval_ps(eval_path):
    """Return the sorted list of .png file paths directly under `eval_path`."""
    pattern = os.path.join(eval_path, '*.png')
    return sorted(glob.glob(pattern))
495
+
496
def main(argv):
    """Entry point: set up DDP, data, model and optimizer, then run either a
    test-only pass (args.TEST) or the full train/validate loop.

    BUGFIX: the final timing line called `datetime.timedelta(...)`, but this
    module does `from datetime import datetime`, so `datetime` is the class
    and has no `timedelta` attribute — the run crashed right at the end.
    Fixed with a function-scope `from datetime import timedelta`.
    """
    from datetime import timedelta  # module-level `datetime` is the class, not the module

    args = parse_args(argv)
    base_dir = init(args)  # create the base dir for saving the results

    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
        args.log_dir = args.output_dir

    misc.init_distributed_mode(args)

    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))

    device = torch.device(args.device)
    # fix the seed for reproducibility (per-rank offset for DDP)
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    cudnn.benchmark = True

    setup_logger(base_dir + '/' + time.strftime('%Y%m%d_%H%M%S') + '.log')
    msg = f'======================= {args.name} ======================='
    logging.info(msg)
    for k in args.__dict__:
        logging.info(k + ':' + str(args.__dict__[k]))
    logging.info('=' * len(msg))

    ## ======================= prepare dataset ======================= ##
    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(args.input_size, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()])
    transform_test = transforms.Compose(  # crop to 256x256
        [transforms.Resize(256), transforms.CenterCrop(256), transforms.ToTensor()]
    )

    if args.dataset=='imagenet':
        train_dataset = torchvision.datasets.ImageFolder(os.path.join(args.dataset_path, "train"), transform=transform_train)
        test_dataset = torchvision.datasets.ImageFolder(os.path.join(args.dataset_path, "val"), transform=transform_test)

    # 2000-image validation subset; training set split into 33 shards cycled per epoch.
    val_dataset, _ = torch.utils.data.random_split(test_dataset, [2000, 48000])
    small_train_datasets = torch.utils.data.random_split(train_dataset, [40000]*32+[1167])
    eval_path = sorted(glob.glob(os.path.join(args.eval_path, '*.png')))

    device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"

    if True:
        num_tasks = misc.get_world_size()
        global_rank = misc.get_rank()
        sampler_test = torch.utils.data.DistributedSampler(
            test_dataset, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )
        sampler_val = torch.utils.data.DistributedSampler(
            val_dataset, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )
    else:
        sampler_train = torch.utils.data.RandomSampler(train_dataset)

    # TensorBoard only on the main process.
    if global_rank == 0 and args.log_dir is not None:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = SummaryWriter(log_dir=args.log_dir)
    else:
        log_writer = None

    val_dataloader = DataLoader(val_dataset, sampler=sampler_val, batch_size=args.test_batch_size,
                                num_workers=args.num_workers, shuffle=False, pin_memory=args.pin_mem, drop_last=True)
    test_dataloader = DataLoader(test_dataset, sampler=sampler_test, batch_size=1,
                                 num_workers=args.num_workers, shuffle=False, pin_memory=args.pin_mem,)

    ## ======================= prepare model ======================= ##
    vqgan_ckpt_path = '/home/t2vg-a100-G4-10/project/qyp/mage/vqgan_jax_strongaug.ckpt'
    model = models_mage_codec_rope.__dict__[args.model](mask_ratio_mu=args.mask_ratio_mu, mask_ratio_std=args.mask_ratio_std,
                                                        mask_ratio_min=args.mask_ratio_min, mask_ratio_max=args.mask_ratio_max,
                                                        vqgan_ckpt_path=vqgan_ckpt_path)

    model.to(device)
    model_without_ddp = model
    print("Model = %s" % str(model_without_ddp))
    eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
    if args.lr is None:  # only base_lr is specified
        args.lr = args.blr * eff_batch_size / 256
    print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
    print("actual lr: %.2e" % args.lr)

    print("accumulate grad iterations: %d" % args.accum_iter)
    print("effective batch size: %d" % eff_batch_size)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module

    # following timm: set wd as 0 for bias and norm layers
    param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay)
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
    print(optimizer)
    loss_scaler = NativeScaler()

    # resume from a checkpoint
    misc.load_model(args=args, model_without_ddp=model_without_ddp,
                    optimizer=optimizer, loss_scaler=loss_scaler, strict=False)

    metrics_criterion = CalMetrics()
    cls_criterion = Clsloss(device, True)
    ## ======================= Start Training ======================= ##
    last_epoch = args.start_epoch

    if args.TEST:  # test only
        best_loss = float("inf")
        tqrange = tqdm.trange(last_epoch, args.epochs)
        for manual_mask_ratio in [0.3]:
            loss = inference_with_acc(-1, test_dataloader, model, metrics_criterion, cls_criterion, device, manual_mask_ratio, args, 'test')
        return

    ## ======================= pre validation ======================= ##
    print("############## pre validation ##############")
    best_loss = float("inf")
    tqrange = tqdm.trange(last_epoch, args.epochs)
    val_mask_ratio = 0.75
    test_loss = inference(-1, val_dataloader, model, metrics_criterion, device, val_mask_ratio, args, 'val')

    ## ======================= start training ======================= ##
    print(f"############## Start training for {args.epochs} epochs ##############")
    start_time = time.time()
    for epoch in tqrange:
        # Rotate through the pre-split training shards, one per epoch.
        current_dataset = small_train_datasets[epoch % len(small_train_datasets)]
        sampler_train = torch.utils.data.DistributedSampler(current_dataset, shuffle=True)
        data_loader_train = DataLoader(
            current_dataset, sampler=sampler_train,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            pin_memory=args.pin_mem,
            drop_last=True,
        )
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)
        train_stats = train_one_epoch(model, data_loader_train, metrics_criterion, device,
                                      optimizer, epoch, loss_scaler, log_writer=log_writer, args=args, val_dataloader=val_dataloader, stage='train')

        test_loss = inference(epoch, val_dataloader, model, metrics_criterion, device, val_mask_ratio, args, 'val')

        is_best = test_loss < best_loss
        best_loss = min(test_loss, best_loss)

        # Periodic checkpoint plus a separate best-model snapshot.
        if args.output_dir and (epoch % 10 == 0 or epoch + 1 == args.epochs):
            misc.save_model(
                args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
                loss_scaler=loss_scaler, epoch=epoch)
        if is_best:
            misc.save_model_last(
                args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
                loss_scaler=loss_scaler, epoch=epoch, is_best=is_best)
        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     'epoch': epoch,}
        if args.output_dir and misc.is_main_process():
            if log_writer is not None:
                log_writer.flush()
            with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
669
+
670
+
671
+
672
# Script entry point: forward CLI args (without the program name) to main().
if __name__ == "__main__":
    main(sys.argv[1:])
detection.py ADDED
@@ -0,0 +1,635 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tqdm
2
+ import argparse
3
+ import math
4
+ # import torchac
5
+ import sys
6
+ import os
7
+ import time
8
+ import logging
9
+ from datetime import datetime
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.optim as optim
14
+
15
+ import torchvision
16
+ from torch.utils.data import DataLoader
17
+ from torchvision import transforms
18
+ from torchvision.models import resnet50
19
+
20
+ import yaml
21
+ from pytorch_msssim import ms_ssim
22
+ from DISTS_pytorch import DISTS
23
+ from util.lpips import LPIPS
24
+ from torch.nn import functional as F
25
+ from torchvision import utils as vutils
26
+ import numpy as np
27
+
28
+ import util.misc as misc
29
+ import util.lr_sched as lr_sched
30
+ from torch.utils.tensorboard import SummaryWriter
31
+ import models_mage_codec_high_resolu
32
+ import timm.optim.optim_factory as optim_factory
33
+ from util.misc import NativeScalerWithGradNormCount as NativeScaler
34
+ from collections import OrderedDict
35
+ import pickle
36
+ import torch.backends.cudnn as cudnn
37
+ from pathlib import Path
38
+ import random
39
+ import torch.distributed as dist
40
+ from util.dataloader import MSCOCO, Kodak, prepadding
41
+ from util.utils import adaptively_split_and_pad, crop_and_reconstruct
42
+ from util.alignment import Alignment
43
+
44
+ ## General
45
+ from detectron2.config import get_cfg
46
+ from detectron2.layers import ShapeSpec
47
+ from detectron2.modeling.backbone.fpn import build_resnet_fpn_backbone
48
+
49
+ ## Test
50
+ from detectron2.evaluation import COCOEvaluator
51
+ from detectron2.data.datasets import register_coco_instances
52
+ from detectron2.data import build_detection_test_loader
53
+ from detectron2.data.detection_utils import read_image
54
+
55
+ from contextlib import ExitStack, contextmanager
56
+
57
+
58
+ ## Function for model to eval 用于临时将模型切换到评估模式并在操作完成后恢复模型的原始模式
59
@contextmanager
def inference_context(model):
    """Temporarily put `model` in eval mode, restoring its previous mode on exit.

    BUGFIX: the original restored the mode after a bare `yield`, so an
    exception raised inside the `with` body left the model stuck in eval
    mode. The restore now runs in a `finally` block.
    """
    training_mode = model.training
    model.eval()
    try:
        yield
    finally:
        model.train(training_mode)
65
+
66
class CalMetrics(nn.Module):
    """Compute rate (bpp) and distortion metrics (PSNR, MS-SSIM, LPIPS, DISTS)
    plus the combined rate-distortion training loss for a reconstruction.

    Expects images in [0, 1] with shape (B, 3, H, W). `out_net` is the codec's
    output dict; the keys read here are 'likelihoods', 'bs_mask_token',
    'task_loss' and 'lambda'.
    """

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()
        # Lazily-built perceptual metric networks, cached across calls.
        # The original code constructed LPIPS (a VGG) and DISTS from scratch
        # on EVERY metric evaluation, which is extremely wasteful.
        self._lpips_net = None
        self._dists_net = None

    def bpp_loss(self, ori, out_net):
        """Return (bpp of the entropy-coded tokens, bpp of the mask bitstream)."""
        b, _, h, w = ori.shape
        num_pixels = b * h * w
        # Bits = -log2(likelihood); averaged over all pixels of the batch.
        bpp = torch.log(out_net["likelihoods"]).sum() / (-math.log(2) * num_pixels)
        bs_mask_token = out_net['bs_mask_token']
        bytes_length = len(bs_mask_token)
        # Each byte carries 8 bits, so total bits = byte count * 8.
        total_bits = bytes_length * 8
        # Bits per pixel of the mask side-information stream.
        bpp_mask = total_bits / num_pixels
        return bpp, bpp_mask

    def psnr(self, rec, ori):
        """PSNR in dB for images in [0, 1]; returns 100 for a perfect match."""
        mse = torch.mean((rec - ori) ** 2)
        if(mse == 0):
            return 100
        max_pixel = 1.
        psnr = 10 * torch.log10(max_pixel / mse)
        return torch.mean(psnr)

    def lpips(self, rec, ori):
        """Mean LPIPS distance; the LPIPS net is built once and reused."""
        if self._lpips_net is None:
            self._lpips_net = LPIPS().eval()
        lpips_func = self._lpips_net.to(device=rec.device)
        lipis_value = lpips_func(rec, ori)
        return lipis_value.mean()

    def dists(self, rec, ori):
        """Mean DISTS distance; the DISTS net is built once and reused."""
        if self._dists_net is None:
            self._dists_net = DISTS()
        D = self._dists_net.cuda()
        dists_value = D(rec, ori)
        return dists_value.mean()

    def cal_total_loss(self, lpips, bpp, out_net):
        """Rate-distortion objective: bpp + lambda * task_loss."""
        # task_loss = out_net['task_loss'] + 0.1 * lpips
        task_loss = out_net['task_loss']
        total_loss = bpp + out_net['lambda'] * task_loss
        return total_loss

    def forward(self, ori, out_net, rec=None):
        """Return a dict of rate metrics, and distortion metrics when `rec` is given."""
        out = {}
        out["bpp"], out["bpp_mask"] = self.bpp_loss(ori, out_net)
        out["bpp_loss"] = out["bpp"] + out["bpp_mask"]
        if rec is not None:
            # Clamp to the valid image range before comparing.
            out["psnr"] = self.psnr(torch.clamp(rec, 0, 1), ori)
            out["msssim"] = ms_ssim(torch.clamp(rec, 0, 1), ori, data_range=1, size_average=True)
            out["lpips"] = self.lpips(torch.clamp(rec, 0, 1), ori)
            out["dists"] = self.dists(torch.clamp(rec, 0, 1), ori)
            out["total_loss"] = self.cal_total_loss(out["lpips"], out["bpp_loss"], out_net)
        return out
121
+
122
+
123
class TaskLoss(nn.Module):
    """Feature-space distortion against a frozen Detectron2 ResNet-FPN backbone.

    The backbone weights are read from ``cfg.MODEL.WEIGHTS`` (a Detectron2
    pickle), frozen, and the loss compares FPN feature maps (p2..p6) of the
    reconstruction with those of the ground-truth image.
    """

    def __init__(self, cfg, device) -> None:
        super().__init__()
        self.ce = nn.CrossEntropyLoss()  # NOTE(review): not used below
        self.task_net = build_resnet_fpn_backbone(cfg, ShapeSpec(channels=3))
        checkpoint = OrderedDict()
        with open(cfg.MODEL.WEIGHTS, 'rb') as f:
            FPN_ckpt = pickle.load(f)
        # keep only the 'backbone.*' weights, stripping the 'backbone.' prefix
        for k, v in FPN_ckpt['model'].items():
            if 'backbone' in k:
                checkpoint['.'.join(k.split('.')[1:])] = torch.from_numpy(v)
        self.task_net.load_state_dict(checkpoint, strict=True)
        self.task_net = self.task_net.to(device)
        # freeze the task network: it serves as a fixed perceptual reference
        for k, p in self.task_net.named_parameters():
            p.requires_grad = False
        self.task_net.eval()
        self.align = Alignment(divisor=32).to(device)  # alignment module used to resize/pad images for the backbone
        self.pixel_mean = torch.Tensor([103.530, 116.280, 123.675]).view(-1, 1, 1).to(device)  # imagenet mean

    def forward(self, output, d, train_mode=False):
        """Return the averaged MSE between FPN features of output['x_hat'] and d.

        `d` is expected in RGB, [0, 1]; both inputs are converted to the
        backbone's BGR, mean-subtracted [0, 255] convention first.
        """
        with torch.no_grad():
            ## Ground truth for perceptual loss
            d = d.flip(1).mul(255)  # RGB to BGR, [0,1] to [0,255]
            d = d - self.pixel_mean
            if not train_mode:
                d = self.align.align(d)
            gt_out = self.task_net(d)

        x_hat = torch.clamp(output["x_hat"], 0, 1)
        x_hat = x_hat.flip(1).mul(255)
        x_hat = x_hat - self.pixel_mean
        if not train_mode:
            x_hat = self.align.align(x_hat)
        task_net_out = self.task_net(x_hat)

        distortion_p2 = nn.MSELoss(reduction='none')(gt_out["p2"], task_net_out["p2"])
        distortion_p3 = nn.MSELoss(reduction='none')(gt_out["p3"], task_net_out["p3"])
        distortion_p4 = nn.MSELoss(reduction='none')(gt_out["p4"], task_net_out["p4"])
        distortion_p5 = nn.MSELoss(reduction='none')(gt_out["p5"], task_net_out["p5"])
        distortion_p6 = nn.MSELoss(reduction='none')(gt_out["p6"], task_net_out["p6"])

        # equally weighted (0.2 each) average over the five pyramid levels
        return 0.2*(distortion_p2.mean()+distortion_p3.mean()+distortion_p4.mean()+distortion_p5.mean()+distortion_p6.mean())
165
+
166
+
167
class AverageMeter:
    """Track the latest value and running average of a scalar quantity.

    Attributes:
        val:   most recently observed value.
        sum:   weighted sum of all observed values.
        count: total weight (number of samples) observed so far.
        avg:   running average, i.e. ``sum / count``.
    """

    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record *val* observed *n* times and refresh the running average."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count
181
+
182
+
183
def init(args):
    """Ensure the experiment output directory exists and return its path.

    The directory is ``<args.root>/<args.exp_name>/`` (trailing slash kept
    so callers can concatenate file names directly).
    """
    base_dir = '{}/{}/'.format(args.root, args.exp_name)
    os.makedirs(base_dir, exist_ok=True)
    return base_dir
187
+
188
def setup_logger(log_dir):
    """Configure the root logger to log INFO+ to both a file and stdout.

    ``log_dir`` is the log *file* path (despite the name).  Both handlers
    share a timestamped format.  Note this mutates the process-wide root
    logger; calling it twice adds duplicate handlers.
    """
    log_formatter = logging.Formatter("%(asctime)s [%(levelname)-5.5s] %(message)s")
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)

    # file sink
    log_file_handler = logging.FileHandler(log_dir, encoding='utf-8')
    log_file_handler.setFormatter(log_formatter)
    root_logger.addHandler(log_file_handler)

    # console sink
    log_stream_handler = logging.StreamHandler(sys.stdout)
    log_stream_handler.setFormatter(log_formatter)
    root_logger.addHandler(log_stream_handler)

    logging.info('Logging file is %s' % log_dir)
202
+
203
def save_img(img: torch.Tensor, vis_path, input_p, mask=False):
    """Save a tensor image under *vis_path*, named after *input_p*.

    The output file is ``<vis_path>/<basename(input_p)>``, prefixed with
    ``mask_`` when *mask* is True.  The tensor is detached and moved to CPU
    before saving.

    Bug fix: the original built ``vis_path + 'mask_' + input_p[rfind('/'):]``
    (keeping the slash from the slice) and then joined *vis_path* in again,
    producing a doubled/invalid path such as ``vis/vis/mask_/name.png``.
    Using ``os.path.basename`` plus a single ``os.path.join`` yields the
    intended location.
    """
    img = img.clone().detach()
    img = img.to(torch.device('cpu'))
    os.makedirs(vis_path, exist_ok=True)
    base_name = os.path.basename(str(input_p))
    if mask:
        base_name = 'mask_' + base_name
    vutils.save_image(img, os.path.join(vis_path, base_name), nrow=8)
214
+
215
def train_one_epoch(model, data_loader, metrics_criterion, device,
                    optimizer, epoch, loss_scaler, log_writer, args, val_dataloader=None, stage='train'):
    """Train the MIM codec for one epoch with AMP and gradient accumulation.

    Returns a dict mapping meter name -> epoch-averaged value.

    NOTE(review): this definition is shadowed by a second ``train_one_epoch``
    defined later in this module, so this version is unreachable under that
    name as the file stands.
    """
    ## ======================= set configs ======================= ##
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 20
    accum_iter = args.accum_iter
    optimizer.zero_grad()
    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    # visualization snapshots go under ./MIM_vbr/<stage>/
    vis_path = os.path.join("./MIM_vbr/", stage)
    os.makedirs(vis_path, exist_ok=True)

    # tqdm_emu = tqdm.tqdm(enumerate(data_loader_train), total=len(data_loader_train), leave=False)
    for data_iter_step, (samples, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        samples = samples.to(device, non_blocking=True)  # samples = original image

        # we use a per iteration (instead of per epoch) lr scheduler
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)

        with torch.cuda.amp.autocast():
            out_net = model(samples, is_training=True, manual_mask_rate=None)

        # decode the logits back to pixels; `model.module` assumes DDP wrapping
        rec = model.module.gen_img(out_net['logits'], out_net['token_all_mask'], out_net['token_indices'])
        # rec = model.gen_img(out_net['logits'], out_net['token_all_mask'], out_net['token_indices'])
        rec = rec.to(device)
        out_criterion = metrics_criterion(samples, out_net, rec)
        loss_value = out_criterion['total_loss'].item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # scale for gradient accumulation; loss_scaler runs backward and,
        # on accumulation boundaries, the optimizer step
        out_criterion['total_loss'] /= accum_iter
        loss_scaler(out_criterion['total_loss'], optimizer, clip_grad=args.grad_clip, parameters=model.parameters(),
                    update_grad=(data_iter_step + 1) % accum_iter == 0)
        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()

        torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)

        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)
        metric_logger.update(bpp=out_criterion['bpp_loss'])
        metric_logger.update(bpp_mask=out_criterion['bpp_mask'])
        metric_logger.update(task_loss=out_net['task_loss'].item())  # raw per-batch value; the meter tracks the running mean
        metric_logger.update(lmbda=out_net['lambda'])
        metric_logger.update(mask_ratio=out_net['mask_ratio'])  # raw per-batch value; the meter tracks the running mean
        metric_logger.update(lpips=out_criterion['lpips'].item())  # raw per-batch value; the meter tracks the running mean
        metric_logger.update(dists=out_criterion['dists'].item())

        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
            """ We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('lr', lr, epoch_1000x)

        ## ======================= update progress bar & visualization ======================= ##
        if data_iter_step % 1000 == 0:
            with torch.no_grad():
                # original + reconstruction side by side, plus the mask visualization
                real_fake_images = torch.cat((samples, rec), dim=0)
                vutils.save_image(real_fake_images, os.path.join(vis_path, f"{epoch}_{data_iter_step}.jpg"), nrow=8)
                # vutils.save_image(samples, os.path.join(vis_path, f"{epoch}_{data_iter_step}_ori.jpg"), nrow=6)
                # vutils.save_image(rec, os.path.join(vis_path, f"{epoch}_{data_iter_step}_rec.jpg"), nrow=6)
                vutils.save_image(out_net['mask_vis'], os.path.join(vis_path, f"{epoch}_{data_iter_step}_mask.jpg"), nrow=8)

    # if (data_iter_step % 10000 == 0) and (data_iter_step != 0):
    #     test_loss = inference(epoch, val_dataloader, model, metrics_criterion, device, 0.75, args, 'val')
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
296
+
297
def train_one_epoch(train_dataloader, optimizer, model, criterion_rd, criterion_task, lmbda):
    """Plain (non-AMP) training loop: task/perceptual loss plus ``lmbda``-weighted rate loss.

    NOTE(review): this redefinition shadows the earlier ``train_one_epoch``
    in this module — only this version is reachable by name.
    """
    model.train()
    device = next(model.parameters()).device
    tqdm_emu = tqdm.tqdm(enumerate(train_dataloader), total=len(train_dataloader), leave=False)
    for i, d in tqdm_emu:
        d = d.to(device)

        optimizer.zero_grad()

        out_net = model(d)

        # rate-distortion criterion supplies 'bpp_loss'; the task criterion
        # supplies the (perceptual) distortion term
        out_criterion = criterion_rd(out_net, d)
        perc_loss = criterion_task(out_net, d)
        total_loss = perc_loss + lmbda * out_criterion['bpp_loss']
        total_loss.backward()
        optimizer.step()

        update_txt=f'[{i*len(d)}/{len(train_dataloader.dataset)}] | Loss: {total_loss.item():.3f} | Distortion loss: {perc_loss.item():.5f} | Bpp loss: {out_criterion["bpp_loss"].item():.4f}'
        tqdm_emu.set_postfix_str(update_txt, refresh=True)
316
+
317
+
318
def validation_epoch(epoch, val_dataloader, model, criterion_rd, criterion_task, lmbda):
    """Evaluate rate/distortion and task loss on the validation set.

    Inputs are resized/aligned to a multiple of 256 before the model, and the
    reconstruction is resumed to the original size.  Returns the average
    total loss; leaves the model back in train mode.
    """
    model.eval()
    device = next(model.parameters()).device

    bpp_loss = AverageMeter()
    mse_loss = AverageMeter()
    psnr = AverageMeter()
    percloss = AverageMeter()
    totalloss = AverageMeter()

    with torch.no_grad():
        tqdm_meter = tqdm.tqdm(enumerate(val_dataloader),leave=False, total=len(val_dataloader))
        for i, d in tqdm_meter:
            # NOTE(review): a fresh Alignment module is built every iteration;
            # it could be hoisted out of the loop.
            align = Alignment(divisor=256, mode='resize').to(device)

            d = d.to(device)
            align_d = align.align(d)

            out_net = model(align_d)
            # undo the alignment so metrics are computed at the original size
            out_net['x_hat'] = align.resume(out_net['x_hat']).clamp_(0, 1)
            out_criterion = criterion_rd(out_net, d)
            perc_loss = criterion_task(out_net, d)
            total_loss = perc_loss + lmbda * out_criterion['bpp_loss']

            bpp_loss.update(out_criterion["bpp_loss"])
            mse_loss.update(out_criterion["mse_loss"])
            psnr.update(out_criterion['psnr'])
            percloss.update(perc_loss)
            totalloss.update(total_loss)

            txt = f"Loss: {totalloss.avg:.3f} | MSE loss: {mse_loss.avg:.5f} | Perception loss: {percloss.avg:.4f} | Bpp loss: {bpp_loss.avg:.4f}"
            tqdm_meter.set_postfix_str(txt)

    model.train()
    print(f"Epoch: {epoch} | bpp loss: {bpp_loss.avg:.5f} | psnr: {psnr.avg:.5f}")
    return totalloss.avg
354
+
355
+
356
def test_epoch(test_dataloader, model, criterion_rd, predictor, evaluator):
    """Run the codec on detection test data and score with a Detectron2 evaluator.

    For each image: align -> compress/reconstruct -> rate/distortion metrics,
    then feed the (BGR, mean-subtracted) reconstruction to the detection
    predictor and accumulate results in *evaluator*.
    """
    model.eval()
    device = next(model.parameters()).device
    pixel_mean = torch.Tensor([103.530, 116.280, 123.675]).view(-1, 1, 1).to(device)

    bpp_loss = AverageMeter()
    psnr = AverageMeter()

    with torch.no_grad():
        tqdm_meter = tqdm.tqdm(enumerate(test_dataloader),leave=False, total=len(test_dataloader))
        for i, batch in tqdm_meter:
            with ExitStack() as stack:
                ## model to eval()
                if isinstance(predictor.model, nn.Module):
                    stack.enter_context(inference_context(predictor.model))  # inference_context: switch the predictor's model to eval mode for this block
                stack.enter_context(torch.no_grad())

                # codec alignment (resize to /256) and detector alignment (/32)
                align = Alignment(divisor=256, mode='resize').to(device)
                rcnn_align = Alignment(divisor=32).to(device)

                img = read_image(batch[0]["file_name"], format="BGR")
                # dataset image is uint8 BGR; convert to float RGB in [0,1]
                d = torch.stack([batch[0]['image'].float().div(255)]).flip(1).to(device)
                align_d = align.align(d)

                out_net = model(align_d)
                out_net['x_hat'] = align.resume(out_net['x_hat']).clamp_(0, 1)
                out_criterion = criterion_rd(out_net, d)

                # back to the detector's BGR [0,255] mean-subtracted convention
                trand_y_tilde = out_net['x_hat'].flip(1).mul(255)
                trand_y_tilde = rcnn_align.align(trand_y_tilde - pixel_mean)

                bpp_loss.update(out_criterion["bpp_loss"])
                psnr.update(out_criterion['psnr'])

                ## MaskRCNN
                predictions = predictor(img, trand_y_tilde)
                evaluator.process(batch, [predictions])
                txt = f"Bpp loss: {bpp_loss.avg:.4f} | PSNR loss: {psnr.avg:.4f}"
                tqdm_meter.set_postfix_str(txt)

    # NOTE(review): `results` is computed but not returned
    results = evaluator.evaluate()
    model.train()
    print(f"bpp loss: {bpp_loss.avg:.5f} | psnr: {psnr.avg:.5f}")
    return
400
+
401
+
402
def inference(epoch, test_loader, model, metrics_criterion, device, manual_mask_ratio, args, stage='test'):
    """Evaluate the codec on *test_loader* at a fixed mask ratio.

    Each image is adaptively split/padded, compressed and reconstructed, then
    scored with *metrics_criterion* (bpp / psnr / ms-ssim / lpips / dists).
    Visualizations are written for the 'val' stage; reconstructions are dumped
    for the 'test' stage.  Returns the average total loss.

    Bug fix: ``os.path.join(..., manual_mask_ratio)`` raised ``TypeError``
    because the mask ratio is a float (``main`` passes 0.5); it is now
    converted with ``str()`` before building the path.
    """
    model.eval()
    bpp_loss = AverageMeter()
    bpp_mask = AverageMeter()
    psnr = AverageMeter()
    msssim = AverageMeter()
    lpips = AverageMeter()
    dists = AverageMeter()
    test_loss = AverageMeter()

    vis_path = os.path.join("./MIM_test_high_resolu/", stage)
    os.makedirs(vis_path, exist_ok=True)
    if stage == 'test':
        # str(): manual_mask_ratio is numeric; os.path.join needs a string
        test_vis_path = os.path.join("/home/v-ruoyufeng/v-ruoyufeng/qyp/rec_fid", str(manual_mask_ratio))
        os.makedirs(test_vis_path, exist_ok=True)

    with torch.no_grad():
        tqdm_meter = tqdm.tqdm(enumerate(test_loader), leave=False, total=len(test_loader))
        for i, d in tqdm_meter:
            d = d.to(device)
            d0 = d  # keep the unpadded original for metric computation
            b_ori, _, h_ori, w_ori = d.shape
            # split large images into model-sized patches (with padding)
            d, patch_sizes, num_blocks_h, num_blocks_w = adaptively_split_and_pad(d)
            # d, h_ori, w_ori = prepadding(d)
            out_net = model(d, is_training=False, manual_mask_rate=manual_mask_ratio)
            # rec = model.gen_img(out_net['logits'], out_net['token_all_mask'], out_net['token_indices'], out_net['ori_shape'], out_net['new_shape'])
            # iterative decoding of the masked tokens (DDP: use .module)
            rec = model.module.gen_img(out_net['logits'], out_net['token_all_mask'], out_net['token_indices'], num_iter=20)
            # stitch the patches back into the full-resolution image
            rec = crop_and_reconstruct(rec, patch_sizes, num_blocks_h, num_blocks_w)
            rec = rec.unsqueeze(0)
            rec = rec.to(device)
            print('d0:', d0.shape)
            print('rec:', rec.shape)

            # d = d[:, :, :h_ori, :w_ori]
            # rec = rec[:, :, :h_ori, :w_ori]
            # rec = rec[:, :, :hx, :wx]
            out_criterion = metrics_criterion(d0, out_net, rec)

            bpp_loss.update(out_criterion["bpp_loss"])
            bpp_mask.update(out_criterion["bpp_mask"])
            psnr.update(out_criterion['psnr'])
            msssim.update(out_criterion['msssim'])
            lpips.update(out_criterion['lpips'])
            dists.update(out_criterion['dists'])
            test_loss.update(out_criterion['total_loss'])

            ## ======================= update progress bar & visualization ======================= ##
            if stage == 'val':
                # if i % 5 == 0:
                with torch.no_grad():
                    real_fake_images = torch.cat((d0, rec), dim=0)
                    vutils.save_image(real_fake_images, os.path.join(vis_path, f"{epoch}_{i}.jpg"))
                    vutils.save_image(out_net['mask_vis'], os.path.join(vis_path, f"{epoch}_{i}_mask.jpg"))
            if stage == 'test':
                with torch.no_grad():
                    vutils.save_image(rec, os.path.join(test_vis_path, f"{i}.jpg"), nrow=8)
            # txt = f"Rec Loss:{test_loss.avg:.4f}|Bpp:{bpp_loss.avg:.4f}|lpips:{lpips.avg:.4f}|msssim:{msssim.avg:.4f}|dists:{dists.avg:.4f}|psnr:{psnr.avg:.4f}\n"
            # tqdm_meter.set_postfix_str(txt)

    model.train()

    # only rank 0 writes the summary line; fall back to rank 0 when DDP is
    # not initialized (single-process run)
    if torch.distributed.is_initialized():
        rank = dist.get_rank()
    else:
        rank = 0

    if rank == 0:
        log_txt = f"{epoch}|bpp:{bpp_loss.avg.item():.5f}|mask:{bpp_mask.avg:.5f}|mask_ratio:{manual_mask_ratio}|psnr:{psnr.avg.item():.5f}|msssim:{msssim.avg.item():.5f}|lpips:{lpips.avg.item():.5f}|dists:{dists.avg.item():.5f}|Test loss:{test_loss.avg.item():.5f}"
        logging.info(log_txt)
    return test_loss.avg
473
+
474
+
475
def save_checkpoint(state, is_best, base_dir, filename="checkpoint.pth.tar"):
    """Persist a training checkpoint under *base_dir*.

    When *is_best* is set, the same state is additionally written to
    ``checkpoint_best.pth.tar`` so the best model is always recoverable.
    """
    target = base_dir + filename
    torch.save(state, target)
    if is_best:
        best_target = base_dir + "checkpoint_best.pth.tar"
        torch.save(state, best_target)
479
+
480
def parse_args(argv):
    """Parse CLI arguments layered over a YAML config file.

    Two-phase parse: first only -c/--config, --name and --lr are read; the
    YAML file is then loaded and its keys installed as argparse defaults, so
    any remaining CLI flags override the config values.
    """
    parser = argparse.ArgumentParser(description="Example training script.")
    parser.add_argument(
        "-c",
        "--config",
        default="config/vpt_default.yaml",
        help="Path to config file",
    )
    parser.add_argument(
        '--name',
        default=datetime.now().strftime('%Y-%m-%d_%H_%M_%S'),
        type=str,
        help='Result dir name',
    )
    parser.add_argument('--lr', type=float, default=None, metavar='LR',
                        help='learning rate (absolute lr)')
    given_configs, remaining = parser.parse_known_args(argv)
    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local-rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')
    # YAML keys become argparse defaults; CLI flags in `remaining` win
    with open(given_configs.config) as file:
        yaml_data= yaml.safe_load(file)
    parser.set_defaults(**yaml_data)

    parser.add_argument(
        "-T",
        "--TEST",
        # action='store_true',
        default=False,
        help='Testing'
    )
    args = parser.parse_args(remaining)
    return args
517
+
518
+
519
def main(argv):
    """Entry point: set up distributed training, data, model, then run a
    validation pass ('pre validation') at a fixed mask ratio."""
    args = parse_args(argv)
    base_dir = init(args)  # create the base dir for saving the results

    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
        args.log_dir = args.output_dir

    misc.init_distributed_mode(args)

    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))

    device = torch.device(args.device)
    # fix the seed for reproducibility (offset by rank so workers differ)
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    cudnn.benchmark = True

    setup_logger(base_dir + '/' + time.strftime('%Y%m%d_%H%M%S') + '.log')
    msg = f'======================= {args.name} ======================='
    logging.info(msg)
    for k in args.__dict__:
        logging.info(k + ':' + str(args.__dict__[k]))
    logging.info('=' * len(msg))

    ## ======================= prepare dataset ======================= ##
    transform_det = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()])
    transform_val = transforms.Compose([
        # transforms.Resize(224),
        # transforms.CenterCrop(224),
        transforms.ToTensor()
    ])


    if args.dataset=='coco':
        train_dataset = MSCOCO(args.dataset_path + "/train2017/",
                               transform_det,
                               "/home/t2vg-a100-G4-10/project/qyp/mimc_rope/util/img_list.txt")
        # val_dataset = Kodak(args.kodak_path, transform_val)
        val_dataset = MSCOCO(args.kodak_path, transform_val)

    # NOTE(review): `device` is rebound here from a torch.device to a string
    device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"

    # if args.distributed:
    # NOTE(review): the distributed branch is forced on; the else-branch
    # (sampler_train) is dead code and sampler_train is never used.
    if True:
        num_tasks = misc.get_world_size()
        global_rank = misc.get_rank()
        sampler_val = torch.utils.data.DistributedSampler(
            val_dataset, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )
    else:
        sampler_train = torch.utils.data.RandomSampler(train_dataset)

    # tensorboard writer only on rank 0
    if global_rank == 0 and args.log_dir is not None:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = SummaryWriter(log_dir=args.log_dir)
    else:
        log_writer = None

    val_dataloader = DataLoader(val_dataset, sampler=sampler_val, batch_size=1,
                                num_workers=args.num_workers, shuffle=False, pin_memory=args.pin_mem, drop_last=True)

    ## ======================= prepare model ======================= ##
    vqgan_ckpt_path = '/home/t2vg-a100-G4-10/project/qyp/mage/vqgan_jax_strongaug.ckpt'
    model = models_mage_codec_high_resolu.__dict__[args.model](mask_ratio_mu=args.mask_ratio_mu, mask_ratio_std=args.mask_ratio_std,
                                                               mask_ratio_min=args.mask_ratio_min, mask_ratio_max=args.mask_ratio_max,
                                                               vqgan_ckpt_path=vqgan_ckpt_path)

    model.to(device)
    model_without_ddp = model
    print("Model = %s" % str(model_without_ddp))
    # linear lr scaling rule: absolute lr = base lr * eff batch size / 256
    eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
    if args.lr is None:  # only base_lr is specified
        args.lr = args.blr * eff_batch_size / 256
    print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
    print("actual lr: %.2e" % args.lr)

    print("accumulate grad iterations: %d" % args.accum_iter)
    print("effective batch size: %d" % eff_batch_size)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module

    # following timm: set wd as 0 for bias and norm layers
    param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay)
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
    print(optimizer)
    loss_scaler = NativeScaler()

    # resume from a checkpoint
    misc.load_model(args=args, model_without_ddp=model_without_ddp,
                    optimizer=optimizer, loss_scaler=loss_scaler, strict=False)

    metrics_criterion = CalMetrics()
    # cls_criterion = Clsloss(device, True)
    ## ======================= Start Training ======================= ##
    last_epoch = args.start_epoch

    ## ======================= pre validation ======================= ##
    # NOTE(review): only the validation pass below is executed; best_loss and
    # tqrange are set up but no training loop follows in this file.
    print("############## pre validation ##############")
    best_loss = float("inf")
    tqrange = tqdm.trange(last_epoch, args.epochs)
    val_mask_ratio = 0.5
    test_loss = inference(-1, val_dataloader, model, metrics_criterion, device, val_mask_ratio, args, 'val')



if __name__ == "__main__":
    main(sys.argv[1:])
engine_finetune.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import sys
3
+ from typing import Iterable, Optional
4
+
5
+ import torch
6
+
7
+ from timm.data import Mixup
8
+ from timm.utils import accuracy
9
+
10
+ import util.misc as misc
11
+ import util.lr_sched as lr_sched
12
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
13
+
14
+
15
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    mixup_fn: Optional[Mixup] = None, log_writer=None,
                    args=None, denormalize=False):
    """Fine-tune a classifier for one epoch with AMP + gradient accumulation.

    `mixup_fn` optionally applies mixup/cutmix; `denormalize` undoes ImageNet
    normalization on the inputs before the forward pass.  Returns a dict of
    epoch-averaged meter values.
    """
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 20

    accum_iter = args.accum_iter

    optimizer.zero_grad()

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):

        # we use a per iteration (instead of per epoch) lr scheduler
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)

        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        if denormalize:
            # undo per-channel ImageNet normalization: x * std + mean
            for c_id, (mean, std) in enumerate(zip(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)):
                samples[:, c_id] = samples[:, c_id] * std + mean

        with torch.cuda.amp.autocast():
            outputs = model(samples)
            loss = criterion(outputs, targets)

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # scale for gradient accumulation; loss_scaler does backward + step
        loss /= accum_iter
        loss_scaler(loss, optimizer, clip_grad=max_norm,
                    parameters=model.parameters(), create_graph=False,
                    update_grad=(data_iter_step + 1) % accum_iter == 0)
        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()

        torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)
        # track the lr range across param groups (layer-wise decay may vary it)
        min_lr = 10.
        max_lr = 0.
        for group in optimizer.param_groups:
            min_lr = min(min_lr, group["lr"])
            max_lr = max(max_lr, group["lr"])

        metric_logger.update(lr=max_lr)

        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
            """ We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('lr', max_lr, epoch_1000x)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
90
+
91
+
92
@torch.no_grad()
def evaluate(data_loader, model, device, denormalize=False):
    """Evaluate classification accuracy (top-1/top-5) and CE loss.

    `denormalize` undoes ImageNet normalization on the inputs before the
    forward pass.  Returns a dict of epoch-averaged meter values.
    """
    criterion = torch.nn.CrossEntropyLoss()

    metric_logger = misc.MetricLogger(delimiter=" ")
    header = 'Test:'

    # switch to evaluation mode
    model.eval()

    for batch in metric_logger.log_every(data_loader, 10, header):
        images = batch[0]
        target = batch[-1]
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        if denormalize:
            # undo per-channel ImageNet normalization: x * std + mean
            for c_id, (mean, std) in enumerate(zip(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)):
                images[:, c_id] = images[:, c_id] * std + mean

        # compute output
        with torch.cuda.amp.autocast():
            output = model(images)
            loss = criterion(output, target)

        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        batch_size = images.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
engine_pretrain.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import sys
3
+ from typing import Iterable
4
+
5
+ import torch
6
+
7
+ import util.misc as misc
8
+ import util.lr_sched as lr_sched
9
+
10
+
11
def train_one_epoch(model: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler,
                    log_writer=None,
                    args=None):
    """Pretrain the model for one epoch with AMP + gradient accumulation.

    The model's forward is expected to return ``(loss, _, _)``.  Returns a
    dict of epoch-averaged meter values.
    """
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 20

    accum_iter = args.accum_iter

    optimizer.zero_grad()

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    for data_iter_step, (samples, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):

        # we use a per iteration (instead of per epoch) lr scheduler
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)

        samples = samples.to(device, non_blocking=True)

        with torch.cuda.amp.autocast():
            loss, _, _ = model(samples)

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # scale for gradient accumulation; loss_scaler does backward + step
        loss /= accum_iter
        loss_scaler(loss, optimizer, clip_grad=args.grad_clip, parameters=model.parameters(),
                    update_grad=(data_iter_step + 1) % accum_iter == 0)
        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()

        torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)

        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
            """ We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('lr', lr, epoch_1000x)


    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
environment.yaml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: mage
2
+ channels:
3
+ - pytorch
4
+ - defaults
5
+ dependencies:
6
+ - python=3.8.5
7
+ - pip=20.3
8
+ - cudatoolkit=10.2
9
+ - pytorch=1.7.1
10
+ - torchvision=0.8.2
11
+ - numpy=1.19.2
12
+ - pip:
13
+ - albumentations==0.4.3
14
+ - opencv-python==4.1.2.30
15
+ - pudb==2019.2
16
+ - imageio==2.9.0
17
+ - imageio-ffmpeg==0.4.2
18
+ - pytorch-lightning==1.0.8
19
+ - omegaconf==2.0.0
20
+ - test-tube>=0.7.5
21
+ - streamlit>=0.73.1
22
+ - einops==0.3.0
23
+ - more-itertools>=8.0.0
24
+ - transformers==4.3.1
25
+ - timm==0.3.2
26
+ - -e .