MonsterMMORPG committed
Commit 5678a7b · verified · Parent: 05b02bc

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full list.
Files changed (50)
  1. .gitattributes +2 -0
  2. Trellv2/ZhengPeng7--BiRefNet/.gitattributes +35 -0
  3. Trellv2/ZhengPeng7--BiRefNet/.gitignore +142 -0
  4. Trellv2/ZhengPeng7--BiRefNet/BiRefNet_config.py +11 -0
  5. Trellv2/ZhengPeng7--BiRefNet/README.md +226 -0
  6. Trellv2/ZhengPeng7--BiRefNet/birefnet.py +2252 -0
  7. Trellv2/ZhengPeng7--BiRefNet/config.json +20 -0
  8. Trellv2/ZhengPeng7--BiRefNet/handler.py +139 -0
  9. Trellv2/ZhengPeng7--BiRefNet/model.safetensors +3 -0
  10. Trellv2/ZhengPeng7--BiRefNet/requirements.txt +16 -0
  11. Trellv2/briaai--RMBG-2.0/.gitattributes +40 -0
  12. Trellv2/briaai--RMBG-2.0/BiRefNet_config.py +11 -0
  13. Trellv2/briaai--RMBG-2.0/README.md +218 -0
  14. Trellv2/briaai--RMBG-2.0/birefnet.py +2245 -0
  15. Trellv2/briaai--RMBG-2.0/collage5.png +3 -0
  16. Trellv2/briaai--RMBG-2.0/config.json +20 -0
  17. Trellv2/briaai--RMBG-2.0/diagram1.png +0 -0
  18. Trellv2/briaai--RMBG-2.0/model.safetensors +3 -0
  19. Trellv2/briaai--RMBG-2.0/onnx/model_bnb4.onnx +3 -0
  20. Trellv2/briaai--RMBG-2.0/onnx/model_fp16.onnx +3 -0
  21. Trellv2/briaai--RMBG-2.0/onnx/model_int8.onnx +3 -0
  22. Trellv2/briaai--RMBG-2.0/onnx/model_q4.onnx +3 -0
  23. Trellv2/briaai--RMBG-2.0/onnx/model_q4f16.onnx +3 -0
  24. Trellv2/briaai--RMBG-2.0/onnx/model_quantized.onnx +3 -0
  25. Trellv2/briaai--RMBG-2.0/onnx/model_uint8.onnx +3 -0
  26. Trellv2/briaai--RMBG-2.0/preprocessor_config.json +23 -0
  27. Trellv2/briaai--RMBG-2.0/pytorch_model.bin +3 -0
  28. Trellv2/briaai--RMBG-2.0/t4.png +3 -0
  29. Trellv2/facebook--dinov3-vitl16-pretrain-lvd1689m/.gitattributes +35 -0
  30. Trellv2/facebook--dinov3-vitl16-pretrain-lvd1689m/LICENSE.md +66 -0
  31. Trellv2/facebook--dinov3-vitl16-pretrain-lvd1689m/README.md +477 -0
  32. Trellv2/facebook--dinov3-vitl16-pretrain-lvd1689m/config.json +32 -0
  33. Trellv2/facebook--dinov3-vitl16-pretrain-lvd1689m/model.safetensors +3 -0
  34. Trellv2/facebook--dinov3-vitl16-pretrain-lvd1689m/preprocessor_config.json +31 -0
  35. Trellv2/microsoft--TRELLIS.2-4B/.gitattributes +35 -0
  36. Trellv2/microsoft--TRELLIS.2-4B/README.md +139 -0
  37. Trellv2/microsoft--TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16.json +24 -0
  38. Trellv2/microsoft--TRELLIS.2-4B/ckpts/shape_enc_next_dc_f16c32_fp16.json +23 -0
  39. Trellv2/microsoft--TRELLIS.2-4B/ckpts/shape_enc_next_dc_f16c32_fp16.safetensors +3 -0
  40. Trellv2/microsoft--TRELLIS.2-4B/ckpts/slat_flow_img2shape_dit_1_3B_1024_bf16.json +19 -0
  41. Trellv2/microsoft--TRELLIS.2-4B/ckpts/slat_flow_img2shape_dit_1_3B_1024_bf16.safetensors +3 -0
  42. Trellv2/microsoft--TRELLIS.2-4B/ckpts/slat_flow_img2shape_dit_1_3B_512_bf16.json +19 -0
  43. Trellv2/microsoft--TRELLIS.2-4B/ckpts/slat_flow_imgshape2tex_dit_1_3B_1024_bf16.json +19 -0
  44. Trellv2/microsoft--TRELLIS.2-4B/ckpts/slat_flow_imgshape2tex_dit_1_3B_512_bf16.json +19 -0
  45. Trellv2/microsoft--TRELLIS.2-4B/ckpts/slat_flow_imgshape2tex_dit_1_3B_512_bf16.safetensors +3 -0
  46. Trellv2/microsoft--TRELLIS.2-4B/ckpts/ss_flow_img_dit_1_3B_64_bf16.json +19 -0
  47. Trellv2/microsoft--TRELLIS.2-4B/ckpts/ss_flow_img_dit_1_3B_64_bf16.safetensors +3 -0
  48. Trellv2/microsoft--TRELLIS.2-4B/ckpts/tex_dec_next_dc_f16c32_fp16.json +25 -0
  49. Trellv2/microsoft--TRELLIS.2-4B/ckpts/tex_enc_next_dc_f16c32_fp16.json +24 -0
  50. Trellv2/microsoft--TRELLIS.2-4B/pipeline.json +95 -0
.gitattributes CHANGED
@@ -112,3 +112,5 @@ pytorchvideo-0.1.5-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
  flex_gemm-0.0.1-cp310-cp310-win_amd64.whl filter=lfs diff=lfs merge=lfs -text
  nvdiffrast-0.4.0-cp310-cp310-win_amd64.whl filter=lfs diff=lfs merge=lfs -text
  cumesh-0.0.1-cp310-cp310-win_amd64.whl filter=lfs diff=lfs merge=lfs -text
+ Trellv2/briaai--RMBG-2.0/collage5.png filter=lfs diff=lfs merge=lfs -text
+ Trellv2/briaai--RMBG-2.0/t4.png filter=lfs diff=lfs merge=lfs -text
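The two PNGs added here are routed through Git LFS by the attribute lines above. If you only need individual LFS-tracked files rather than a full clone, a minimal sketch with `huggingface_hub` works; the repo ID below is a placeholder for whichever repository this commit was pushed to:

```python
from huggingface_hub import hf_hub_download

# Placeholder repo ID -- substitute the repository this commit belongs to.
REPO_ID = "MonsterMMORPG/REPO_NAME"

# Fetch a single LFS-tracked file without cloning the whole repository.
local_path = hf_hub_download(
    repo_id=REPO_ID,
    filename="Trellv2/briaai--RMBG-2.0/collage5.png",
)
print(local_path)
```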
Trellv2/ZhengPeng7--BiRefNet/.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
Trellv2/ZhengPeng7--BiRefNet/.gitignore ADDED
@@ -0,0 +1,142 @@
+ # Custom
+ e_*
+ .vscode
+ ckpt
+ preds
+ evaluation/eval-*
+ nohup.out*
+ tmp*
+ *.pth
+ core-*-python-*
+ .DS_Store
+ __MACOSX/
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
Trellv2/ZhengPeng7--BiRefNet/BiRefNet_config.py ADDED
@@ -0,0 +1,11 @@
+ from transformers import PretrainedConfig
+
+ class BiRefNetConfig(PretrainedConfig):
+     model_type = "SegformerForSemanticSegmentation"
+     def __init__(
+         self,
+         bb_pretrained=False,
+         **kwargs
+     ):
+         self.bb_pretrained = bb_pretrained
+         super().__init__(**kwargs)
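For context, this small config class is what `trust_remote_code=True` loading instantiates for the model below. A minimal sketch of using it directly, assuming the file above is on your Python path (the save path is arbitrary):

```python
from BiRefNet_config import BiRefNetConfig

# bb_pretrained=False skips loading separate backbone weights,
# since the released checkpoint already contains them.
config = BiRefNetConfig(bb_pretrained=False)
print(config.model_type)     # "SegformerForSemanticSegmentation"
print(config.bb_pretrained)  # False

# Like any transformers PretrainedConfig, it can be serialized to config.json.
config.save_pretrained("./birefnet-config")
```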
Trellv2/ZhengPeng7--BiRefNet/README.md ADDED
@@ -0,0 +1,226 @@
+ ---
+ library_name: birefnet
+ tags:
+ - background-removal
+ - mask-generation
+ - Dichotomous Image Segmentation
+ - Camouflaged Object Detection
+ - Salient Object Detection
+ - pytorch_model_hub_mixin
+ - model_hub_mixin
+ - transformers
+ repo_url: https://github.com/ZhengPeng7/BiRefNet
+ pipeline_tag: image-segmentation
+ license: mit
+ ---
+ <h1 align="center">Bilateral Reference for High-Resolution Dichotomous Image Segmentation</h1>
+
+ <div align='center'>
+     <a href='https://scholar.google.com/citations?user=TZRzWOsAAAAJ' target='_blank'><strong>Peng Zheng</strong></a><sup> 1,4,5,6</sup>,&thinsp;
+     <a href='https://scholar.google.com/citations?user=0uPb8MMAAAAJ' target='_blank'><strong>Dehong Gao</strong></a><sup> 2</sup>,&thinsp;
+     <a href='https://scholar.google.com/citations?user=kakwJ5QAAAAJ' target='_blank'><strong>Deng-Ping Fan</strong></a><sup> 1*</sup>,&thinsp;
+     <a href='https://scholar.google.com/citations?user=9cMQrVsAAAAJ' target='_blank'><strong>Li Liu</strong></a><sup> 3</sup>,&thinsp;
+     <a href='https://scholar.google.com/citations?user=qQP6WXIAAAAJ' target='_blank'><strong>Jorma Laaksonen</strong></a><sup> 4</sup>,&thinsp;
+     <a href='https://scholar.google.com/citations?user=pw_0Z_UAAAAJ' target='_blank'><strong>Wanli Ouyang</strong></a><sup> 5</sup>,&thinsp;
+     <a href='https://scholar.google.com/citations?user=stFCYOAAAAAJ' target='_blank'><strong>Nicu Sebe</strong></a><sup> 6</sup>
+ </div>
+
+ <div align='center'>
+     <sup>1 </sup>Nankai University&ensp; <sup>2 </sup>Northwestern Polytechnical University&ensp; <sup>3 </sup>National University of Defense Technology&ensp; <sup>4 </sup>Aalto University&ensp; <sup>5 </sup>Shanghai AI Laboratory&ensp; <sup>6 </sup>University of Trento&ensp;
+ </div>
+
+ <div align="center" style="display: flex; justify-content: center; flex-wrap: wrap;">
+     <a href='https://www.sciopen.com/article/pdf/10.26599/AIR.2024.9150038.pdf'><img src='https://img.shields.io/badge/Journal-Paper-red'></a>&ensp;
+     <a href='https://arxiv.org/pdf/2401.03407'><img src='https://img.shields.io/badge/arXiv-BiRefNet-red'></a>&ensp;
+     <a href='https://drive.google.com/file/d/1aBnJ_R9lbnC2dm8dqD0-pzP2Cu-U1Xpt/view?usp=drive_link'><img src='https://img.shields.io/badge/中文版-BiRefNet-red'></a>&ensp;
+     <a href='https://www.birefnet.top'><img src='https://img.shields.io/badge/Page-BiRefNet-red'></a>&ensp;
+     <a href='https://drive.google.com/drive/folders/1s2Xe0cjq-2ctnJBR24563yMSCOu4CcxM'><img src='https://img.shields.io/badge/Drive-Stuff-green'></a>&ensp;
+     <a href='LICENSE'><img src='https://img.shields.io/badge/License-MIT-yellow'></a>&ensp;
+     <a href='https://huggingface.co/spaces/ZhengPeng7/BiRefNet_demo'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20HF%20Spaces-BiRefNet-blue'></a>&ensp;
+     <a href='https://huggingface.co/ZhengPeng7/BiRefNet'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20HF%20Models-BiRefNet-blue'></a>&ensp;
+     <a href='https://colab.research.google.com/drive/14Dqg7oeBkFEtchaHLNpig2BcdkZEogba?usp=drive_link'><img src='https://img.shields.io/badge/Single_Image_Inference-F9AB00?style=for-the-badge&logo=googlecolab&color=525252'></a>&ensp;
+     <a href='https://colab.research.google.com/drive/1MaEiBfJ4xIaZZn0DqKrhydHB8X97hNXl#scrollTo=DJ4meUYjia6S'><img src='https://img.shields.io/badge/Inference_&_Evaluation-F9AB00?style=for-the-badge&logo=googlecolab&color=525252'></a>&ensp;
+ </div>
+
+
+ | *DIS-Sample_1* | *DIS-Sample_2* |
+ | :------------------------------: | :-------------------------------: |
+ | <img src="https://drive.google.com/thumbnail?id=1ItXaA26iYnE8XQ_GgNLy71MOWePoS2-g&sz=w400" /> | <img src="https://drive.google.com/thumbnail?id=1Z-esCujQF_uEa_YJjkibc3NUrW4aR_d4&sz=w400" /> |
+
+ This repo is the official implementation of "[**Bilateral Reference for High-Resolution Dichotomous Image Segmentation**](https://arxiv.org/pdf/2401.03407.pdf)" (___CAAI AIR 2024___).
+
+ Visit our GitHub repo, [https://github.com/ZhengPeng7/BiRefNet](https://github.com/ZhengPeng7/BiRefNet), for more details -- **codes**, **docs**, and the **model zoo**!
+
+ ## How to use
+
+ ### 0. Install packages:
+ ```shell
+ pip install -qr https://raw.githubusercontent.com/ZhengPeng7/BiRefNet/main/requirements.txt
+ ```
+
+ ### 1. Load BiRefNet:
+
+ #### Use codes + weights from HuggingFace
+ > Use both the codes and weights from HuggingFace -- Pro: no need to download the BiRefNet codes manually; Con: the codes on HuggingFace might not be the latest version (I'll try to keep them up to date).
+
+ ```python
+ # Load BiRefNet with weights
+ from transformers import AutoModelForImageSegmentation
+ birefnet = AutoModelForImageSegmentation.from_pretrained('ZhengPeng7/BiRefNet', trust_remote_code=True)
+ ```
+
+ #### Use codes from GitHub + weights from HuggingFace
+ > Use only the weights from HuggingFace -- Pro: the codes are always the latest; Con: you need to clone the BiRefNet repo from my GitHub.
+
+ ```shell
+ # Download codes
+ git clone https://github.com/ZhengPeng7/BiRefNet.git
+ cd BiRefNet
+ ```
+
+ ```python
+ # Use codes locally
+ from models.birefnet import BiRefNet
+
+ # Load weights from Hugging Face Models
+ birefnet = BiRefNet.from_pretrained('ZhengPeng7/BiRefNet')
+ ```
+
+ #### Use codes from GitHub + weights from local space
+ > Use both the weights and codes locally.
+
+ ```python
+ # Use codes and weights locally (BiRefNet imported from models.birefnet as above)
+ import torch
+ from utils import check_state_dict
+
+ birefnet = BiRefNet(bb_pretrained=False)
+ state_dict = torch.load(PATH_TO_WEIGHT, map_location='cpu')
+ state_dict = check_state_dict(state_dict)
+ birefnet.load_state_dict(state_dict)
+ ```
+
+ #### Use the loaded BiRefNet for inference
+ ```python
+ # Imports
+ from PIL import Image
+ import matplotlib.pyplot as plt
+ import torch
+ from torchvision import transforms
+ from models.birefnet import BiRefNet
+
+ birefnet = ...  # BiRefNet should be loaded with one of the methods above.
+ torch.set_float32_matmul_precision(['high', 'highest'][0])
+ birefnet.to('cuda')
+ birefnet.eval()
+ birefnet.half()
+
+ def extract_object(birefnet, imagepath):
+     # Data settings
+     image_size = (1024, 1024)
+     transform_image = transforms.Compose([
+         transforms.Resize(image_size),
+         transforms.ToTensor(),
+         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+     ])
+
+     image = Image.open(imagepath)
+     input_images = transform_image(image).unsqueeze(0).to('cuda').half()
+
+     # Prediction
+     with torch.no_grad():
+         preds = birefnet(input_images)[-1].sigmoid().cpu()
+     pred = preds[0].squeeze()
+     pred_pil = transforms.ToPILImage()(pred)
+     mask = pred_pil.resize(image.size)
+     image.putalpha(mask)
+     return image, mask
+
+ # Visualization
+ plt.axis("off")
+ plt.imshow(extract_object(birefnet, imagepath='PATH-TO-YOUR_IMAGE.jpg')[0])
+ plt.show()
+
+ ```
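As a small follow-up, the image/mask pair returned above can be written straight to disk instead of plotted; a minimal sketch reusing `extract_object` from the block above (output filenames are arbitrary):

```python
# Reuse extract_object() from the snippet above and save the results.
image, mask = extract_object(birefnet, imagepath='PATH-TO-YOUR_IMAGE.jpg')
image.save('foreground_rgba.png')  # RGBA image with the background removed
mask.save('mask.png')              # grayscale alpha mask
```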
+
+ ### 2. Use the inference endpoint locally:
+ > You may need to click *Deploy* and set up the endpoint yourself, which incurs some cost.
+ ```python
+ import requests
+ import base64
+ from io import BytesIO
+ from PIL import Image
+
+
+ YOUR_HF_TOKEN = 'xxx'
+ API_URL = "xxx"
+ headers = {
+     "Authorization": "Bearer {}".format(YOUR_HF_TOKEN)
+ }
+
+ def base64_to_bytes(base64_string):
+     # Remove the data URI prefix if present
+     if "data:image" in base64_string:
+         base64_string = base64_string.split(",")[1]
+
+     # Decode the Base64 string into bytes
+     image_bytes = base64.b64decode(base64_string)
+     return image_bytes
+
+ def bytes_to_image(image_bytes):
+     # Wrap the raw bytes in a BytesIO stream
+     image_stream = BytesIO(image_bytes)
+
+     # Open the image using Pillow (PIL)
+     image = Image.open(image_stream)
+     return image
+
+ def query(payload):
+     response = requests.post(API_URL, headers=headers, json=payload)
+     return response.json()
+
+ output = query({
+     "inputs": "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg",
+     "parameters": {}
+ })
+
+ # Decode the base64 response into a PIL image.
+ output_image = bytes_to_image(base64_to_bytes(output))
+ output_image
+ ```
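One practical caveat: `response.json()` fails with an opaque error when the endpoint returns a non-JSON body (e.g., a cold start or a bad token). A hedged variant of `query` that surfaces HTTP errors first, plus persisting the decoded image (filename arbitrary):

```python
import requests

def query_checked(payload):
    # Same request as query() above, but raise on HTTP errors
    # (e.g., 401 bad token, 503 cold start) before parsing JSON.
    response = requests.post(API_URL, headers=headers, json=payload, timeout=60)
    response.raise_for_status()
    return response.json()

# The decoded PIL image from above can be saved directly.
output_image.save('birefnet_output.png')
```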
+
+ > This BiRefNet for standard dichotomous image segmentation (DIS) was trained on **DIS-TR** and validated on **DIS-TEs and DIS-VD**.
+
+ ## This repo holds the official model weights of "[<ins>Bilateral Reference for High-Resolution Dichotomous Image Segmentation</ins>](https://arxiv.org/pdf/2401.03407)" (_CAAI AIR 2024_).
+
+ This repo contains the weights of BiRefNet proposed in our paper, which achieves SOTA performance on three tasks (DIS, HRSOD, and COD).
+
+ Go to my GitHub page for the BiRefNet codes and the latest updates: https://github.com/ZhengPeng7/BiRefNet :)
+
+ #### Try our online demos for inference:
+
+ + Online **image inference** on Colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/14Dqg7oeBkFEtchaHLNpig2BcdkZEogba?usp=drive_link)
+ + **Online inference with GUI on Hugging Face** with adjustable resolutions: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/ZhengPeng7/BiRefNet_demo)
+ + **Inference and evaluation** of your given weights: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1MaEiBfJ4xIaZZn0DqKrhydHB8X97hNXl#scrollTo=DJ4meUYjia6S)
+ <img src="https://drive.google.com/thumbnail?id=12XmDhKtO1o2fEvBu4OE4ULVB2BK0ecWi&sz=w1080" />
+
+ ## Acknowledgement:
+
+ + Many thanks to @Freepik for their generous support with GPU resources for training higher-resolution BiRefNet models and more of my explorations.
+ + Many thanks to @fal for their generous support with GPU resources for training better general BiRefNet models.
+ + Many thanks to @not-lain for his help with the better deployment of our BiRefNet model on HuggingFace.
+
+ ## Citation
+
+ ```
+ @article{zheng2024birefnet,
+   title={Bilateral Reference for High-Resolution Dichotomous Image Segmentation},
+   author={Zheng, Peng and Gao, Dehong and Fan, Deng-Ping and Liu, Li and Laaksonen, Jorma and Ouyang, Wanli and Sebe, Nicu},
+   journal={CAAI Artificial Intelligence Research},
+   volume={3},
+   pages={9150038},
+   year={2024}
+ }
+ ```
Trellv2/ZhengPeng7--BiRefNet/birefnet.py ADDED
@@ -0,0 +1,2252 @@
1
+ ### config.py
2
+
3
+ import os
4
+ import math
5
+ from transformers import PretrainedConfig
6
+
7
+
8
+ class Config(PretrainedConfig):
9
+ def __init__(self) -> None:
10
+ # Compatible with the latest version of transformers
11
+ super().__init__()
12
+
13
+ # PATH settings
14
+ self.sys_home_dir = os.path.expanduser('~') # Make up your file system as: SYS_HOME_DIR/codes/dis/BiRefNet, SYS_HOME_DIR/datasets/dis/xx, SYS_HOME_DIR/weights/xx
15
+
16
+ # TASK settings
17
+ self.task = ['DIS5K', 'COD', 'HRSOD', 'DIS5K+HRSOD+HRS10K', 'P3M-10k'][0]
18
+ self.training_set = {
19
+ 'DIS5K': ['DIS-TR', 'DIS-TR+DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4'][0],
20
+ 'COD': 'TR-COD10K+TR-CAMO',
21
+ 'HRSOD': ['TR-DUTS', 'TR-HRSOD', 'TR-UHRSD', 'TR-DUTS+TR-HRSOD', 'TR-DUTS+TR-UHRSD', 'TR-HRSOD+TR-UHRSD', 'TR-DUTS+TR-HRSOD+TR-UHRSD'][5],
22
+ 'DIS5K+HRSOD+HRS10K': 'DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4+DIS-TR+TE-HRS10K+TE-HRSOD+TE-UHRSD+TR-HRS10K+TR-HRSOD+TR-UHRSD', # leave DIS-VD for evaluation.
23
+ 'P3M-10k': 'TR-P3M-10k',
24
+ }[self.task]
25
+ self.prompt4loc = ['dense', 'sparse'][0]
26
+
27
+ # Faster-Training settings
28
+ self.load_all = True
29
+ self.compile = True # 1. Trigger CPU memory leak in some extend, which is an inherent problem of PyTorch.
30
+ # Machines with > 70GB CPU memory can run the whole training on DIS5K with default setting.
31
+ # 2. Higher PyTorch version may fix it: https://github.com/pytorch/pytorch/issues/119607.
32
+ # 3. But compile in Pytorch > 2.0.1 seems to bring no acceleration for training.
33
+ self.precisionHigh = True
34
+
35
+ # MODEL settings
36
+ self.ms_supervision = True
37
+ self.out_ref = self.ms_supervision and True
38
+ self.dec_ipt = True
39
+ self.dec_ipt_split = True
40
+ self.cxt_num = [0, 3][1] # multi-scale skip connections from encoder
41
+ self.mul_scl_ipt = ['', 'add', 'cat'][2]
42
+ self.dec_att = ['', 'ASPP', 'ASPPDeformable'][2]
43
+ self.squeeze_block = ['', 'BasicDecBlk_x1', 'ResBlk_x4', 'ASPP_x3', 'ASPPDeformable_x3'][1]
44
+ self.dec_blk = ['BasicDecBlk', 'ResBlk', 'HierarAttDecBlk'][0]
45
+
46
+ # TRAINING settings
47
+ self.batch_size = 4
48
+ self.IoU_finetune_last_epochs = [
49
+ 0,
50
+ {
51
+ 'DIS5K': -50,
52
+ 'COD': -20,
53
+ 'HRSOD': -20,
54
+ 'DIS5K+HRSOD+HRS10K': -20,
55
+ 'P3M-10k': -20,
56
+ }[self.task]
57
+ ][1] # choose 0 to skip
58
+ self.lr = (1e-4 if 'DIS5K' in self.task else 1e-5) * math.sqrt(self.batch_size / 4) # DIS needs high lr to converge faster. Adapt the lr linearly
59
+ self.size = 1024
60
+ self.num_workers = max(4, self.batch_size) # will be decrease to min(it, batch_size) at the initialization of the data_loader
61
+
62
+ # Backbone settings
63
+ self.bb = [
64
+ 'vgg16', 'vgg16bn', 'resnet50', # 0, 1, 2
65
+ 'swin_v1_t', 'swin_v1_s', # 3, 4
66
+ 'swin_v1_b', 'swin_v1_l', # 5-bs9, 6-bs4
67
+ 'pvt_v2_b0', 'pvt_v2_b1', # 7, 8
68
+ 'pvt_v2_b2', 'pvt_v2_b5', # 9-bs10, 10-bs5
69
+ ][6]
70
+ self.lateral_channels_in_collection = {
71
+ 'vgg16': [512, 256, 128, 64], 'vgg16bn': [512, 256, 128, 64], 'resnet50': [1024, 512, 256, 64],
72
+ 'pvt_v2_b2': [512, 320, 128, 64], 'pvt_v2_b5': [512, 320, 128, 64],
73
+ 'swin_v1_b': [1024, 512, 256, 128], 'swin_v1_l': [1536, 768, 384, 192],
74
+ 'swin_v1_t': [768, 384, 192, 96], 'swin_v1_s': [768, 384, 192, 96],
75
+ 'pvt_v2_b0': [256, 160, 64, 32], 'pvt_v2_b1': [512, 320, 128, 64],
76
+ }[self.bb]
77
+ if self.mul_scl_ipt == 'cat':
78
+ self.lateral_channels_in_collection = [channel * 2 for channel in self.lateral_channels_in_collection]
79
+ self.cxt = self.lateral_channels_in_collection[1:][::-1][-self.cxt_num:] if self.cxt_num else []
80
+
81
+ # MODEL settings - inactive
82
+ self.lat_blk = ['BasicLatBlk'][0]
83
+ self.dec_channels_inter = ['fixed', 'adap'][0]
84
+ self.refine = ['', 'itself', 'RefUNet', 'Refiner', 'RefinerPVTInChannels4'][0]
85
+ self.progressive_ref = self.refine and True
86
+ self.ender = self.progressive_ref and False
87
+ self.scale = self.progressive_ref and 2
88
+ self.auxiliary_classification = False # Only for DIS5K, where class labels are saved in `dataset.py`.
89
+ self.refine_iteration = 1
90
+ self.freeze_bb = False
91
+ self.model = [
92
+ 'BiRefNet',
93
+ ][0]
94
+ if self.dec_blk == 'HierarAttDecBlk':
95
+ self.batch_size = 2 ** [0, 1, 2, 3, 4][2]
96
+
97
+ # TRAINING settings - inactive
98
+ self.preproc_methods = ['flip', 'enhance', 'rotate', 'pepper', 'crop'][:4]
99
+ self.optimizer = ['Adam', 'AdamW'][1]
100
+ self.lr_decay_epochs = [1e5] # Set to negative N to decay the lr in the last N-th epoch.
101
+ self.lr_decay_rate = 0.5
102
+ # Loss
103
+ self.lambdas_pix_last = {
104
+ # not 0 means opening this loss
105
+ # original rate -- 1 : 30 : 1.5 : 0.2, bce x 30
106
+ 'bce': 30 * 1, # high performance
107
+ 'iou': 0.5 * 1, # 0 / 255
108
+ 'iou_patch': 0.5 * 0, # 0 / 255, win_size = (64, 64)
109
+ 'mse': 150 * 0, # can smooth the saliency map
110
+ 'triplet': 3 * 0,
111
+ 'reg': 100 * 0,
112
+ 'ssim': 10 * 1, # help contours,
113
+ 'cnt': 5 * 0, # help contours
114
+ 'structure': 5 * 0, # structure loss from codes of MVANet. A little improvement on DIS-TE[1,2,3], a bit more decrease on DIS-TE4.
115
+ }
116
+ self.lambdas_cls = {
117
+ 'ce': 5.0
118
+ }
119
+ # Adv
120
+ self.lambda_adv_g = 10. * 0 # turn to 0 to avoid adv training
121
+ self.lambda_adv_d = 3. * (self.lambda_adv_g > 0)
122
+
123
+ # PATH settings - inactive
124
+ self.data_root_dir = os.path.join(self.sys_home_dir, 'datasets/dis')
125
+ self.weights_root_dir = os.path.join(self.sys_home_dir, 'weights')
126
+ self.weights = {
127
+ 'pvt_v2_b2': os.path.join(self.weights_root_dir, 'pvt_v2_b2.pth'),
128
+ 'pvt_v2_b5': os.path.join(self.weights_root_dir, ['pvt_v2_b5.pth', 'pvt_v2_b5_22k.pth'][0]),
129
+ 'swin_v1_b': os.path.join(self.weights_root_dir, ['swin_base_patch4_window12_384_22kto1k.pth', 'swin_base_patch4_window12_384_22k.pth'][0]),
130
+ 'swin_v1_l': os.path.join(self.weights_root_dir, ['swin_large_patch4_window12_384_22kto1k.pth', 'swin_large_patch4_window12_384_22k.pth'][0]),
131
+ 'swin_v1_t': os.path.join(self.weights_root_dir, ['swin_tiny_patch4_window7_224_22kto1k_finetune.pth'][0]),
132
+ 'swin_v1_s': os.path.join(self.weights_root_dir, ['swin_small_patch4_window7_224_22kto1k_finetune.pth'][0]),
133
+ 'pvt_v2_b0': os.path.join(self.weights_root_dir, ['pvt_v2_b0.pth'][0]),
134
+ 'pvt_v2_b1': os.path.join(self.weights_root_dir, ['pvt_v2_b1.pth'][0]),
135
+ }
136
+
137
+ # Callbacks - inactive
138
+ self.verbose_eval = True
139
+ self.only_S_MAE = False
140
+ self.use_fp16 = False # Bugs. It may cause nan in training.
141
+ self.SDPA_enabled = False # Bugs. Slower and errors occur in multi-GPUs
142
+
143
+ # others
144
+ self.device = [0, 'cpu'][0] # .to(0) == .to('cuda:0')
145
+
146
+ self.batch_size_valid = 1
147
+ self.rand_seed = 7
148
+ # run_sh_file = [f for f in os.listdir('.') if 'train.sh' == f] + [os.path.join('..', f) for f in os.listdir('..') if 'train.sh' == f]
149
+ # with open(run_sh_file[0], 'r') as f:
150
+ # lines = f.readlines()
151
+ # self.save_last = int([l.strip() for l in lines if '"{}")'.format(self.task) in l and 'val_last=' in l][0].split('val_last=')[-1].split()[0])
152
+ # self.save_step = int([l.strip() for l in lines if '"{}")'.format(self.task) in l and 'step=' in l][0].split('step=')[-1].split()[0])
153
+ # self.val_step = [0, self.save_step][0]
154
+
155
+ def print_task(self) -> None:
156
+ # Return task for choosing settings in shell scripts.
157
+ print(self.task)
158
+
159
+
160
+
161
+ ### models/backbones/pvt_v2.py
162
+
163
+ import torch
164
+ import torch.nn as nn
165
+ from functools import partial
166
+
167
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
168
+ from timm.models.registry import register_model
169
+
170
+ import math
171
+
172
+ # from config import Config
173
+
174
+ # config = Config()
175
+
176
+ class Mlp(nn.Module):
177
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
178
+ super().__init__()
179
+ out_features = out_features or in_features
180
+ hidden_features = hidden_features or in_features
181
+ self.fc1 = nn.Linear(in_features, hidden_features)
182
+ self.dwconv = DWConv(hidden_features)
183
+ self.act = act_layer()
184
+ self.fc2 = nn.Linear(hidden_features, out_features)
185
+ self.drop = nn.Dropout(drop)
186
+
187
+ self.apply(self._init_weights)
188
+
189
+ def _init_weights(self, m):
190
+ if isinstance(m, nn.Linear):
191
+ trunc_normal_(m.weight, std=.02)
192
+ if isinstance(m, nn.Linear) and m.bias is not None:
193
+ nn.init.constant_(m.bias, 0)
194
+ elif isinstance(m, nn.LayerNorm):
195
+ nn.init.constant_(m.bias, 0)
196
+ nn.init.constant_(m.weight, 1.0)
197
+ elif isinstance(m, nn.Conv2d):
198
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
199
+ fan_out //= m.groups
200
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
201
+ if m.bias is not None:
202
+ m.bias.data.zero_()
203
+
204
+ def forward(self, x, H, W):
205
+ x = self.fc1(x)
206
+ x = self.dwconv(x, H, W)
207
+ x = self.act(x)
208
+ x = self.drop(x)
209
+ x = self.fc2(x)
210
+ x = self.drop(x)
211
+ return x
212
+
213
+
214
+ class Attention(nn.Module):
215
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
216
+ super().__init__()
217
+ assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
218
+
219
+ self.dim = dim
220
+ self.num_heads = num_heads
221
+ head_dim = dim // num_heads
222
+ self.scale = qk_scale or head_dim ** -0.5
223
+
224
+ self.q = nn.Linear(dim, dim, bias=qkv_bias)
225
+ self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
226
+ self.attn_drop_prob = attn_drop
227
+ self.attn_drop = nn.Dropout(attn_drop)
228
+ self.proj = nn.Linear(dim, dim)
229
+ self.proj_drop = nn.Dropout(proj_drop)
230
+
231
+ self.sr_ratio = sr_ratio
232
+ if sr_ratio > 1:
233
+ self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
234
+ self.norm = nn.LayerNorm(dim)
235
+
236
+ self.apply(self._init_weights)
237
+
238
+ def _init_weights(self, m):
239
+ if isinstance(m, nn.Linear):
240
+ trunc_normal_(m.weight, std=.02)
241
+ if isinstance(m, nn.Linear) and m.bias is not None:
242
+ nn.init.constant_(m.bias, 0)
243
+ elif isinstance(m, nn.LayerNorm):
244
+ nn.init.constant_(m.bias, 0)
245
+ nn.init.constant_(m.weight, 1.0)
246
+ elif isinstance(m, nn.Conv2d):
247
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
248
+ fan_out //= m.groups
249
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
250
+ if m.bias is not None:
251
+ m.bias.data.zero_()
252
+
253
+ def forward(self, x, H, W):
254
+ B, N, C = x.shape
255
+ q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
256
+
257
+ if self.sr_ratio > 1:
258
+ x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
259
+ x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
260
+ x_ = self.norm(x_)
261
+ kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
262
+ else:
263
+ kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
264
+ k, v = kv[0], kv[1]
265
+
266
+ if config.SDPA_enabled:
267
+ x = torch.nn.functional.scaled_dot_product_attention(
268
+ q, k, v,
269
+ attn_mask=None, dropout_p=self.attn_drop_prob, is_causal=False
270
+ ).transpose(1, 2).reshape(B, N, C)
271
+ else:
272
+ attn = (q @ k.transpose(-2, -1)) * self.scale
273
+ attn = attn.softmax(dim=-1)
274
+ attn = self.attn_drop(attn)
275
+
276
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
277
+ x = self.proj(x)
278
+ x = self.proj_drop(x)
279
+
280
+ return x
281
+
282
+
283
+ class Block(nn.Module):
284
+
285
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
286
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
287
+ super().__init__()
288
+ self.norm1 = norm_layer(dim)
289
+ self.attn = Attention(
290
+ dim,
291
+ num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
292
+ attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
293
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
294
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
295
+ self.norm2 = norm_layer(dim)
296
+ mlp_hidden_dim = int(dim * mlp_ratio)
297
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
298
+
299
+ self.apply(self._init_weights)
300
+
301
+ def _init_weights(self, m):
302
+ if isinstance(m, nn.Linear):
303
+ trunc_normal_(m.weight, std=.02)
304
+ if isinstance(m, nn.Linear) and m.bias is not None:
305
+ nn.init.constant_(m.bias, 0)
306
+ elif isinstance(m, nn.LayerNorm):
307
+ nn.init.constant_(m.bias, 0)
308
+ nn.init.constant_(m.weight, 1.0)
309
+ elif isinstance(m, nn.Conv2d):
310
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
311
+ fan_out //= m.groups
312
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
313
+ if m.bias is not None:
314
+ m.bias.data.zero_()
315
+
316
+ def forward(self, x, H, W):
317
+ x = x + self.drop_path(self.attn(self.norm1(x), H, W))
318
+ x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
319
+
320
+ return x
321
+
322
+
323
+ class OverlapPatchEmbed(nn.Module):
324
+ """ Image to Patch Embedding
325
+ """
326
+
327
+ def __init__(self, img_size=224, patch_size=7, stride=4, in_channels=3, embed_dim=768):
328
+ super().__init__()
329
+ img_size = to_2tuple(img_size)
330
+ patch_size = to_2tuple(patch_size)
331
+
332
+ self.img_size = img_size
333
+ self.patch_size = patch_size
334
+ self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
335
+ self.num_patches = self.H * self.W
336
+ self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=stride,
337
+ padding=(patch_size[0] // 2, patch_size[1] // 2))
338
+ self.norm = nn.LayerNorm(embed_dim)
339
+
340
+ self.apply(self._init_weights)
341
+
342
+ def _init_weights(self, m):
343
+ if isinstance(m, nn.Linear):
344
+ trunc_normal_(m.weight, std=.02)
345
+ if isinstance(m, nn.Linear) and m.bias is not None:
346
+ nn.init.constant_(m.bias, 0)
347
+ elif isinstance(m, nn.LayerNorm):
348
+ nn.init.constant_(m.bias, 0)
349
+ nn.init.constant_(m.weight, 1.0)
350
+ elif isinstance(m, nn.Conv2d):
351
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
352
+ fan_out //= m.groups
353
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
354
+ if m.bias is not None:
355
+ m.bias.data.zero_()
356
+
357
+ def forward(self, x):
358
+ x = self.proj(x)
359
+ _, _, H, W = x.shape
360
+ x = x.flatten(2).transpose(1, 2)
361
+ x = self.norm(x)
362
+
363
+ return x, H, W
364
+
365
+
366
+ class PyramidVisionTransformerImpr(nn.Module):
367
+ def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
368
+ num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
369
+ attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
370
+ depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]):
371
+ super().__init__()
372
+ self.num_classes = num_classes
373
+ self.depths = depths
374
+
375
+ # patch_embed
376
+ self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_channels=in_channels,
377
+ embed_dim=embed_dims[0])
378
+ self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_channels=embed_dims[0],
379
+ embed_dim=embed_dims[1])
380
+ self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_channels=embed_dims[1],
381
+ embed_dim=embed_dims[2])
382
+ self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_channels=embed_dims[2],
383
+ embed_dim=embed_dims[3])
384
+
385
+ # transformer encoder
386
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
387
+ cur = 0
388
+ self.block1 = nn.ModuleList([Block(
389
+ dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
390
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
391
+ sr_ratio=sr_ratios[0])
392
+ for i in range(depths[0])])
393
+ self.norm1 = norm_layer(embed_dims[0])
394
+
395
+ cur += depths[0]
396
+ self.block2 = nn.ModuleList([Block(
397
+ dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
398
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
399
+ sr_ratio=sr_ratios[1])
400
+ for i in range(depths[1])])
401
+ self.norm2 = norm_layer(embed_dims[1])
402
+
403
+ cur += depths[1]
404
+ self.block3 = nn.ModuleList([Block(
405
+ dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
406
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
407
+ sr_ratio=sr_ratios[2])
408
+ for i in range(depths[2])])
409
+ self.norm3 = norm_layer(embed_dims[2])
410
+
411
+ cur += depths[2]
412
+ self.block4 = nn.ModuleList([Block(
413
+ dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
414
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
415
+ sr_ratio=sr_ratios[3])
416
+ for i in range(depths[3])])
417
+ self.norm4 = norm_layer(embed_dims[3])
418
+
419
+ # classification head
420
+ # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()
421
+
422
+ self.apply(self._init_weights)
423
+
424
+ def _init_weights(self, m):
425
+ if isinstance(m, nn.Linear):
426
+ trunc_normal_(m.weight, std=.02)
427
+ if isinstance(m, nn.Linear) and m.bias is not None:
428
+ nn.init.constant_(m.bias, 0)
429
+ elif isinstance(m, nn.LayerNorm):
430
+ nn.init.constant_(m.bias, 0)
431
+ nn.init.constant_(m.weight, 1.0)
432
+ elif isinstance(m, nn.Conv2d):
433
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
434
+ fan_out //= m.groups
435
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
436
+ if m.bias is not None:
437
+ m.bias.data.zero_()
438
+
439
+ def init_weights(self, pretrained=None):
440
+ if isinstance(pretrained, str):
441
+ logger = 1
442
+ #load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)
443
+
444
+ def reset_drop_path(self, drop_path_rate):
445
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
446
+ cur = 0
447
+ for i in range(self.depths[0]):
448
+ self.block1[i].drop_path.drop_prob = dpr[cur + i]
449
+
450
+ cur += self.depths[0]
451
+ for i in range(self.depths[1]):
452
+ self.block2[i].drop_path.drop_prob = dpr[cur + i]
453
+
454
+ cur += self.depths[1]
455
+ for i in range(self.depths[2]):
456
+ self.block3[i].drop_path.drop_prob = dpr[cur + i]
457
+
458
+ cur += self.depths[2]
459
+ for i in range(self.depths[3]):
460
+ self.block4[i].drop_path.drop_prob = dpr[cur + i]
461
+
462
+ def freeze_patch_emb(self):
463
+ self.patch_embed1.requires_grad = False
464
+
465
+ @torch.jit.ignore
466
+ def no_weight_decay(self):
467
+ return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better
468
+
469
+ def get_classifier(self):
470
+ return self.head
471
+
472
+ def reset_classifier(self, num_classes, global_pool=''):
473
+ self.num_classes = num_classes
474
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
475
+
476
+ def forward_features(self, x):
477
+ B = x.shape[0]
478
+ outs = []
479
+
480
+ # stage 1
481
+ x, H, W = self.patch_embed1(x)
482
+ for i, blk in enumerate(self.block1):
483
+ x = blk(x, H, W)
484
+ x = self.norm1(x)
485
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
486
+ outs.append(x)
487
+
488
+ # stage 2
489
+ x, H, W = self.patch_embed2(x)
490
+ for i, blk in enumerate(self.block2):
491
+ x = blk(x, H, W)
492
+ x = self.norm2(x)
493
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
494
+ outs.append(x)
495
+
496
+ # stage 3
497
+ x, H, W = self.patch_embed3(x)
498
+ for i, blk in enumerate(self.block3):
499
+ x = blk(x, H, W)
500
+ x = self.norm3(x)
501
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
502
+ outs.append(x)
503
+
504
+ # stage 4
505
+ x, H, W = self.patch_embed4(x)
506
+ for i, blk in enumerate(self.block4):
507
+ x = blk(x, H, W)
508
+ x = self.norm4(x)
509
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
510
+ outs.append(x)
511
+
512
+ return outs
513
+
514
+ # return x.mean(dim=1)
515
+
516
+ def forward(self, x):
517
+ x = self.forward_features(x)
518
+ # x = self.head(x)
519
+
520
+ return x
521
+
522
+
523
+ class DWConv(nn.Module):
524
+ def __init__(self, dim=768):
525
+ super(DWConv, self).__init__()
526
+ self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
527
+
528
+ def forward(self, x, H, W):
529
+ B, N, C = x.shape
530
+ x = x.transpose(1, 2).view(B, C, H, W).contiguous()
531
+ x = self.dwconv(x)
532
+ x = x.flatten(2).transpose(1, 2)
533
+
534
+ return x
535
+
536
+
537
+ def _conv_filter(state_dict, patch_size=16):
538
+ """ convert patch embedding weight from manual patchify + linear proj to conv"""
539
+ out_dict = {}
540
+ for k, v in state_dict.items():
541
+ if 'patch_embed.proj.weight' in k:
542
+ v = v.reshape((v.shape[0], 3, patch_size, patch_size))
543
+ out_dict[k] = v
544
+
545
+ return out_dict
546
+
547
+
548
+ ## @register_model
549
+ class pvt_v2_b0(PyramidVisionTransformerImpr):
550
+ def __init__(self, **kwargs):
551
+ super(pvt_v2_b0, self).__init__(
552
+ patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
553
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
554
+ drop_rate=0.0, drop_path_rate=0.1)
555
+
556
+
557
+
558
+ ## @register_model
559
+ class pvt_v2_b1(PyramidVisionTransformerImpr):
560
+ def __init__(self, **kwargs):
561
+ super(pvt_v2_b1, self).__init__(
562
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
563
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
564
+ drop_rate=0.0, drop_path_rate=0.1)
565
+
566
+ ## @register_model
567
+ class pvt_v2_b2(PyramidVisionTransformerImpr):
568
+ def __init__(self, in_channels=3, **kwargs):
569
+ super(pvt_v2_b2, self).__init__(
570
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
571
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1],
572
+ drop_rate=0.0, drop_path_rate=0.1, in_channels=in_channels)
573
+
574
+ ## @register_model
575
+ class pvt_v2_b3(PyramidVisionTransformerImpr):
576
+ def __init__(self, **kwargs):
577
+ super(pvt_v2_b3, self).__init__(
578
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
579
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
580
+ drop_rate=0.0, drop_path_rate=0.1)
581
+
582
+ ## @register_model
583
+ class pvt_v2_b4(PyramidVisionTransformerImpr):
584
+ def __init__(self, **kwargs):
585
+ super(pvt_v2_b4, self).__init__(
586
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
587
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1],
588
+ drop_rate=0.0, drop_path_rate=0.1)
589
+
590
+
591
+ ## @register_model
592
+ class pvt_v2_b5(PyramidVisionTransformerImpr):
593
+ def __init__(self, **kwargs):
594
+ super(pvt_v2_b5, self).__init__(
595
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
596
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],
597
+ drop_rate=0.0, drop_path_rate=0.1)
598
+
599
+
600
+
601
+ ### models/backbones/swin_v1.py
602
+
603
+ # --------------------------------------------------------
604
+ # Swin Transformer
605
+ # Copyright (c) 2021 Microsoft
606
+ # Licensed under The MIT License [see LICENSE for details]
607
+ # Written by Ze Liu, Yutong Lin, Yixuan Wei
608
+ # --------------------------------------------------------
609
+
610
+ import torch
611
+ import torch.nn as nn
612
+ import torch.nn.functional as F
613
+ import torch.utils.checkpoint as checkpoint
614
+ import numpy as np
615
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
616
+
617
+ # from config import Config
618
+
619
+
620
+ # config = Config()
621
+
622
+
623
+ class Mlp(nn.Module):
624
+ """ Multilayer perceptron."""
625
+
626
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
627
+ super().__init__()
628
+ out_features = out_features or in_features
629
+ hidden_features = hidden_features or in_features
630
+ self.fc1 = nn.Linear(in_features, hidden_features)
631
+ self.act = act_layer()
632
+ self.fc2 = nn.Linear(hidden_features, out_features)
633
+ self.drop = nn.Dropout(drop)
634
+
635
+ def forward(self, x):
636
+ x = self.fc1(x)
637
+ x = self.act(x)
638
+ x = self.drop(x)
639
+ x = self.fc2(x)
640
+ x = self.drop(x)
641
+ return x
642
+
643
+
644
+ def window_partition(x, window_size):
645
+ """
646
+ Args:
647
+ x: (B, H, W, C)
648
+ window_size (int): window size
649
+
650
+ Returns:
651
+ windows: (num_windows*B, window_size, window_size, C)
652
+ """
653
+ B, H, W, C = x.shape
654
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
655
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
656
+ return windows
657
+
658
+
659
+ def window_reverse(windows, window_size, H, W):
660
+ """
661
+ Args:
662
+ windows: (num_windows*B, window_size, window_size, C)
663
+ window_size (int): Window size
664
+ H (int): Height of image
665
+ W (int): Width of image
666
+
667
+ Returns:
668
+ x: (B, H, W, C)
669
+ """
670
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
671
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
672
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
673
+ return x
674
+
675
+
676
+ class WindowAttention(nn.Module):
677
+ """ Window based multi-head self attention (W-MSA) module with relative position bias.
678
+ It supports both of shifted and non-shifted window.
679
+
680
+ Args:
681
+ dim (int): Number of input channels.
682
+ window_size (tuple[int]): The height and width of the window.
683
+ num_heads (int): Number of attention heads.
684
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
685
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
686
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
687
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
688
+ """
689
+
690
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
691
+
692
+ super().__init__()
693
+ self.dim = dim
694
+ self.window_size = window_size # Wh, Ww
695
+ self.num_heads = num_heads
696
+ head_dim = dim // num_heads
697
+ self.scale = qk_scale or head_dim ** -0.5
698
+
699
+ # define a parameter table of relative position bias
700
+ self.relative_position_bias_table = nn.Parameter(
701
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
702
+
703
+ # get pair-wise relative position index for each token inside the window
704
+ coords_h = torch.arange(self.window_size[0])
705
+ coords_w = torch.arange(self.window_size[1])
706
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww
707
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
708
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
709
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
710
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
711
+ relative_coords[:, :, 1] += self.window_size[1] - 1
712
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
713
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
714
+ self.register_buffer("relative_position_index", relative_position_index)
715
+
716
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
717
+ self.attn_drop_prob = attn_drop
718
+ self.attn_drop = nn.Dropout(attn_drop)
719
+ self.proj = nn.Linear(dim, dim)
720
+ self.proj_drop = nn.Dropout(proj_drop)
721
+
722
+ trunc_normal_(self.relative_position_bias_table, std=.02)
723
+ self.softmax = nn.Softmax(dim=-1)
724
+
725
+ def forward(self, x, mask=None):
726
+ """ Forward function.
727
+
728
+ Args:
729
+ x: input features with shape of (num_windows*B, N, C)
730
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
731
+ """
732
+ B_, N, C = x.shape
733
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
734
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
735
+
736
+ q = q * self.scale
737
+
738
+ if config.SDPA_enabled:
739
+ x = torch.nn.functional.scaled_dot_product_attention(
740
+ q, k, v,
741
+ attn_mask=None, dropout_p=self.attn_drop_prob, is_causal=False
742
+ ).transpose(1, 2).reshape(B_, N, C)
743
+ else:
744
+ attn = (q @ k.transpose(-2, -1))
745
+
746
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
747
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
748
+ ) # Wh*Ww, Wh*Ww, nH
749
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
750
+ attn = attn + relative_position_bias.unsqueeze(0)
751
+
752
+ if mask is not None:
753
+ nW = mask.shape[0]
754
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
755
+ attn = attn.view(-1, self.num_heads, N, N)
756
+ attn = self.softmax(attn)
757
+ else:
758
+ attn = self.softmax(attn)
759
+
760
+ attn = self.attn_drop(attn)
761
+
762
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
763
+ x = self.proj(x)
764
+ x = self.proj_drop(x)
765
+ return x
+
+
+class SwinTransformerBlock(nn.Module):
+    """ Swin Transformer Block.
+
+    Args:
+        dim (int): Number of input channels.
+        num_heads (int): Number of attention heads.
+        window_size (int): Window size.
+        shift_size (int): Shift size for SW-MSA.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float, optional): Stochastic depth rate. Default: 0.0
+        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+    """
+
+    def __init__(self, dim, num_heads, window_size=7, shift_size=0,
+                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
+                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.dim = dim
+        self.num_heads = num_heads
+        self.window_size = window_size
+        self.shift_size = shift_size
+        self.mlp_ratio = mlp_ratio
+        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
+
+        self.norm1 = norm_layer(dim)
+        self.attn = WindowAttention(
+            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
+            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+        self.H = None
+        self.W = None
+
+    def forward(self, x, mask_matrix):
+        """ Forward function.
+
+        Args:
+            x: Input feature, tensor size (B, H*W, C); the spatial resolution H, W
+                is taken from self.H and self.W, set by the parent BasicLayer.
+            mask_matrix: Attention mask for cyclic shift.
+        """
+        B, L, C = x.shape
+        H, W = self.H, self.W
+        assert L == H * W, "input feature has wrong size"
+
+        shortcut = x
+        x = self.norm1(x)
+        x = x.view(B, H, W, C)
+
+        # pad feature maps to multiples of window size
+        pad_l = pad_t = 0
+        pad_r = (self.window_size - W % self.window_size) % self.window_size
+        pad_b = (self.window_size - H % self.window_size) % self.window_size
+        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+        _, Hp, Wp, _ = x.shape
+
+        # cyclic shift
+        if self.shift_size > 0:
+            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+            attn_mask = mask_matrix
+        else:
+            shifted_x = x
+            attn_mask = None
+
+        # partition windows
+        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
+        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C
+
+        # W-MSA/SW-MSA
+        attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C
+
+        # merge windows
+        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C
+
+        # reverse cyclic shift
+        if self.shift_size > 0:
+            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+        else:
+            x = shifted_x
+
+        if pad_r > 0 or pad_b > 0:
+            x = x[:, :H, :W, :].contiguous()
+
+        x = x.view(B, H * W, C)
+
+        # FFN
+        x = shortcut + self.drop_path(x)
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+        return x
+
+
+class PatchMerging(nn.Module):
+    """ Patch Merging Layer
+
+    Args:
+        dim (int): Number of input channels.
+        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+    """
+    def __init__(self, dim, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.dim = dim
+        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+        self.norm = norm_layer(4 * dim)
+
+    def forward(self, x, H, W):
+        """ Forward function.
+
+        Args:
+            x: Input feature, tensor size (B, H*W, C).
+            H, W: Spatial resolution of the input feature.
+        """
+        B, L, C = x.shape
+        assert L == H * W, "input feature has wrong size"
+
+        x = x.view(B, H, W, C)
+
+        # padding
+        pad_input = (H % 2 == 1) or (W % 2 == 1)
+        if pad_input:
+            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
+
+        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
+        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
+        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
+        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
+        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
+        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
+
+        x = self.norm(x)
+        x = self.reduction(x)
+
+        return x
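+# Shape sketch (illustrative): with dim=96 and a 56x56 grid, PatchMerging maps
+# (B, 3136, 96) -> (B, 784, 192): spatial resolution halves, channels double.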
+
+
+class BasicLayer(nn.Module):
+    """ A basic Swin Transformer layer for one stage.
+
+    Args:
+        dim (int): Number of feature channels
+        depth (int): Depth (number of blocks) of this stage.
+        num_heads (int): Number of attention heads.
+        window_size (int): Local window size. Default: 7.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+    """
+
+    def __init__(self,
+                 dim,
+                 depth,
+                 num_heads,
+                 window_size=7,
+                 mlp_ratio=4.,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 drop=0.,
+                 attn_drop=0.,
+                 drop_path=0.,
+                 norm_layer=nn.LayerNorm,
+                 downsample=None,
+                 use_checkpoint=False):
+        super().__init__()
+        self.window_size = window_size
+        self.shift_size = window_size // 2
+        self.depth = depth
+        self.use_checkpoint = use_checkpoint
+
+        # build blocks
+        self.blocks = nn.ModuleList([
+            SwinTransformerBlock(
+                dim=dim,
+                num_heads=num_heads,
+                window_size=window_size,
+                shift_size=0 if (i % 2 == 0) else window_size // 2,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                drop=drop,
+                attn_drop=attn_drop,
+                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+                norm_layer=norm_layer)
+            for i in range(depth)])
+
+        # patch merging layer
+        if downsample is not None:
+            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
+        else:
+            self.downsample = None
+
+    def forward(self, x, H, W):
+        """ Forward function.
+
+        Args:
+            x: Input feature, tensor size (B, H*W, C).
+            H, W: Spatial resolution of the input feature.
+        """
+
+        # calculate attention mask for SW-MSA
+        # Cast ints to torch.Tensor for compatibility with torch.compile in PyTorch 2.5.
+        Hp = torch.ceil(torch.tensor(H) / self.window_size).to(torch.int64) * self.window_size
+        Wp = torch.ceil(torch.tensor(W) / self.window_size).to(torch.int64) * self.window_size
+        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
+        h_slices = (slice(0, -self.window_size),
+                    slice(-self.window_size, -self.shift_size),
+                    slice(-self.shift_size, None))
+        w_slices = (slice(0, -self.window_size),
+                    slice(-self.window_size, -self.shift_size),
+                    slice(-self.shift_size, None))
+        cnt = 0
+        for h in h_slices:
+            for w in w_slices:
+                img_mask[:, h, w, :] = cnt
+                cnt += 1
+
+        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
+        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)).to(x.dtype)
+
+        for blk in self.blocks:
+            blk.H, blk.W = H, W
+            if self.use_checkpoint:
+                x = checkpoint.checkpoint(blk, x, attn_mask)
+            else:
+                x = blk(x, attn_mask)
+        if self.downsample is not None:
+            x_down = self.downsample(x, H, W)
+            Wh, Ww = (H + 1) // 2, (W + 1) // 2
+            return x, H, W, x_down, Wh, Ww
+        else:
+            return x, H, W, x, H, W
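+# The mask above assigns each of the 9 regions created by the cyclic shift a distinct id;
+# token pairs with different ids get -100 added to their attention logits, so SW-MSA never
+# mixes tokens that were not spatial neighbours before the shift.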
+
+
+class PatchEmbed(nn.Module):
+    """ Image to Patch Embedding
+
+    Args:
+        patch_size (int): Patch token size. Default: 4.
+        in_channels (int): Number of input image channels. Default: 3.
+        embed_dim (int): Number of linear projection output channels. Default: 96.
+        norm_layer (nn.Module, optional): Normalization layer. Default: None
+    """
+
+    def __init__(self, patch_size=4, in_channels=3, embed_dim=96, norm_layer=None):
+        super().__init__()
+        patch_size = to_2tuple(patch_size)
+        self.patch_size = patch_size
+
+        self.in_channels = in_channels
+        self.embed_dim = embed_dim
+
+        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
+        if norm_layer is not None:
+            self.norm = norm_layer(embed_dim)
+        else:
+            self.norm = None
+
+    def forward(self, x):
+        """Forward function."""
+        # padding
+        _, _, H, W = x.size()
+        if W % self.patch_size[1] != 0:
+            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
+        if H % self.patch_size[0] != 0:
+            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
+
+        x = self.proj(x)  # B C Wh Ww
+        if self.norm is not None:
+            Wh, Ww = x.size(2), x.size(3)
+            x = x.flatten(2).transpose(1, 2)
+            x = self.norm(x)
+            x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
+
+        return x
+
+
+class SwinTransformer(nn.Module):
+    """ Swin Transformer backbone.
+    A PyTorch impl of: `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
+    https://arxiv.org/pdf/2103.14030
+
+    Args:
+        pretrain_img_size (int): Input image size for training the pretrained model,
+            used in absolute position embedding. Default 224.
+        patch_size (int | tuple(int)): Patch size. Default: 4.
+        in_channels (int): Number of input image channels. Default: 3.
+        embed_dim (int): Number of linear projection output channels. Default: 96.
+        depths (tuple[int]): Depths of each Swin Transformer stage.
+        num_heads (tuple[int]): Number of attention heads of each stage.
+        window_size (int): Window size. Default: 7.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
+        drop_rate (float): Dropout rate.
+        attn_drop_rate (float): Attention dropout rate. Default: 0.
+        drop_path_rate (float): Stochastic depth rate. Default: 0.2.
+        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
+        patch_norm (bool): If True, add normalization after patch embedding. Default: True.
+        out_indices (Sequence[int]): Output from which stages.
+        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+            -1 means not freezing any parameters.
+        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+    """
+
+    def __init__(self,
+                 pretrain_img_size=224,
+                 patch_size=4,
+                 in_channels=3,
+                 embed_dim=96,
+                 depths=[2, 2, 6, 2],
+                 num_heads=[3, 6, 12, 24],
+                 window_size=7,
+                 mlp_ratio=4.,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 drop_rate=0.,
+                 attn_drop_rate=0.,
+                 drop_path_rate=0.2,
+                 norm_layer=nn.LayerNorm,
+                 ape=False,
+                 patch_norm=True,
+                 out_indices=(0, 1, 2, 3),
+                 frozen_stages=-1,
+                 use_checkpoint=False):
+        super().__init__()
+
+        self.pretrain_img_size = pretrain_img_size
+        self.num_layers = len(depths)
+        self.embed_dim = embed_dim
+        self.ape = ape
+        self.patch_norm = patch_norm
+        self.out_indices = out_indices
+        self.frozen_stages = frozen_stages
+
+        # split image into non-overlapping patches
+        self.patch_embed = PatchEmbed(
+            patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim,
+            norm_layer=norm_layer if self.patch_norm else None)
+
+        # absolute position embedding
+        if self.ape:
+            pretrain_img_size = to_2tuple(pretrain_img_size)
+            patch_size = to_2tuple(patch_size)
+            patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]
+
+            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
+            trunc_normal_(self.absolute_pos_embed, std=.02)
+
+        self.pos_drop = nn.Dropout(p=drop_rate)
+
+        # stochastic depth
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
+
+        # build layers
+        self.layers = nn.ModuleList()
+        for i_layer in range(self.num_layers):
+            layer = BasicLayer(
+                dim=int(embed_dim * 2 ** i_layer),
+                depth=depths[i_layer],
+                num_heads=num_heads[i_layer],
+                window_size=window_size,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                drop=drop_rate,
+                attn_drop=attn_drop_rate,
+                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+                norm_layer=norm_layer,
+                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
+                use_checkpoint=use_checkpoint)
+            self.layers.append(layer)
+
+        num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
+        self.num_features = num_features
+
+        # add a norm layer for each output
+        for i_layer in out_indices:
+            layer = norm_layer(num_features[i_layer])
+            layer_name = f'norm{i_layer}'
+            self.add_module(layer_name, layer)
+
+        self._freeze_stages()
+
+    def _freeze_stages(self):
+        if self.frozen_stages >= 0:
+            self.patch_embed.eval()
+            for param in self.patch_embed.parameters():
+                param.requires_grad = False
+
+        if self.frozen_stages >= 1 and self.ape:
+            self.absolute_pos_embed.requires_grad = False
+
+        if self.frozen_stages >= 2:
+            self.pos_drop.eval()
+            for i in range(0, self.frozen_stages - 1):
+                m = self.layers[i]
+                m.eval()
+                for param in m.parameters():
+                    param.requires_grad = False
+
+    def forward(self, x):
+        """Forward function."""
+        x = self.patch_embed(x)
+
+        Wh, Ww = x.size(2), x.size(3)
+        if self.ape:
+            # interpolate the position embedding to the corresponding size
+            absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
+            x = (x + absolute_pos_embed)  # B C Wh Ww
+
+        outs = []
+        x = x.flatten(2).transpose(1, 2)
+        x = self.pos_drop(x)
+        for i in range(self.num_layers):
+            layer = self.layers[i]
+            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
+
+            if i in self.out_indices:
+                norm_layer = getattr(self, f'norm{i}')
+                x_out = norm_layer(x_out)
+
+                out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
+                outs.append(out)
+
+        return tuple(outs)
+
+    def train(self, mode=True):
+        """Convert the model into training mode while keeping frozen layers frozen."""
+        super(SwinTransformer, self).train(mode)
+        self._freeze_stages()
+
+
+def swin_v1_t():
+    model = SwinTransformer(embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7)
+    return model
+
+
+def swin_v1_s():
+    model = SwinTransformer(embed_dim=96, depths=[2, 2, 18, 2], num_heads=[3, 6, 12, 24], window_size=7)
+    return model
+
+
+def swin_v1_b():
+    model = SwinTransformer(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12)
+    return model
+
+
+def swin_v1_l():
+    model = SwinTransformer(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12)
+    return model
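+# Minimal usage sketch (illustrative only; assumes a dummy input tensor):
+#     backbone = swin_v1_l()
+#     feats = backbone(torch.randn(1, 3, 1024, 1024))
+#     # tuple of 4 feature maps at strides 4/8/16/32,
+#     # with 192/384/768/1536 channels for swin_v1_l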
+
+
+### models/modules/deform_conv.py
+
+import torch
+import torch.nn as nn
+from torchvision.ops import deform_conv2d
+
+
+class DeformableConv2d(nn.Module):
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size=3,
+                 stride=1,
+                 padding=1,
+                 bias=False):
+        super(DeformableConv2d, self).__init__()
+
+        assert type(kernel_size) == tuple or type(kernel_size) == int
+
+        kernel_size = kernel_size if type(kernel_size) == tuple else (kernel_size, kernel_size)
+        self.stride = stride if type(stride) == tuple else (stride, stride)
+        self.padding = padding
+
+        self.offset_conv = nn.Conv2d(in_channels,
+                                     2 * kernel_size[0] * kernel_size[1],
+                                     kernel_size=kernel_size,
+                                     stride=stride,
+                                     padding=self.padding,
+                                     bias=True)
+
+        nn.init.constant_(self.offset_conv.weight, 0.)
+        nn.init.constant_(self.offset_conv.bias, 0.)
+
+        self.modulator_conv = nn.Conv2d(in_channels,
+                                        1 * kernel_size[0] * kernel_size[1],
+                                        kernel_size=kernel_size,
+                                        stride=stride,
+                                        padding=self.padding,
+                                        bias=True)
+
+        nn.init.constant_(self.modulator_conv.weight, 0.)
+        nn.init.constant_(self.modulator_conv.bias, 0.)
+
+        self.regular_conv = nn.Conv2d(in_channels,
+                                      out_channels=out_channels,
+                                      kernel_size=kernel_size,
+                                      stride=stride,
+                                      padding=self.padding,
+                                      bias=bias)
+
+    def forward(self, x):
+        # h, w = x.shape[2:]
+        # max_offset = max(h, w) / 4.
+
+        offset = self.offset_conv(x)  # .clamp(-max_offset, max_offset)
+        modulator = 2. * torch.sigmoid(self.modulator_conv(x))
+
+        x = deform_conv2d(
+            input=x,
+            offset=offset,
+            weight=self.regular_conv.weight,
+            bias=self.regular_conv.bias,
+            padding=self.padding,
+            mask=modulator,
+            stride=self.stride,
+        )
+        return x
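+# At initialization the offset and modulator convs are zeroed, so DeformableConv2d
+# starts out equivalent to a plain convolution (offset 0, modulator 2*sigmoid(0) = 1)
+# and learns where to deform its sampling grid during training.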
+
+
+### utils.py
+
+import torch.nn as nn
+
+
+def build_act_layer(act_layer):
+    if act_layer == 'ReLU':
+        return nn.ReLU(inplace=True)
+    elif act_layer == 'SiLU':
+        return nn.SiLU(inplace=True)
+    elif act_layer == 'GELU':
+        return nn.GELU()
+
+    raise NotImplementedError(f'build_act_layer does not support {act_layer}')
+
+
+def build_norm_layer(dim,
+                     norm_layer,
+                     in_format='channels_last',
+                     out_format='channels_last',
+                     eps=1e-6):
+    layers = []
+    if norm_layer == 'BN':
+        if in_format == 'channels_last':
+            layers.append(to_channels_first())
+        layers.append(nn.BatchNorm2d(dim))
+        if out_format == 'channels_last':
+            layers.append(to_channels_last())
+    elif norm_layer == 'LN':
+        if in_format == 'channels_first':
+            layers.append(to_channels_last())
+        layers.append(nn.LayerNorm(dim, eps=eps))
+        if out_format == 'channels_first':
+            layers.append(to_channels_first())
+    else:
+        raise NotImplementedError(
+            f'build_norm_layer does not support {norm_layer}')
+    return nn.Sequential(*layers)
+
+
+class to_channels_first(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return x.permute(0, 3, 1, 2)
+
+
+class to_channels_last(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return x.permute(0, 2, 3, 1)
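+# Sketch (illustrative): build_norm_layer(48, 'BN', 'channels_first', 'channels_first')
+# yields a bare BatchNorm2d, while the 'LN' variant wraps LayerNorm with the permute
+# helpers above so it normalizes (B, H, W, C) inputs.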
+
+
+### dataset.py
+
+_class_labels_TR_sorted = (
+    'Airplane, Ant, Antenna, Archery, Axe, BabyCarriage, Bag, BalanceBeam, Balcony, Balloon, Basket, BasketballHoop, Beatle, Bed, Bee, Bench, Bicycle, '
+    'BicycleFrame, BicycleStand, Boat, Bonsai, BoomLift, Bridge, BunkBed, Butterfly, Button, Cable, CableLift, Cage, Camcorder, Cannon, Canoe, Car, '
+    'CarParkDropArm, Carriage, Cart, Caterpillar, CeilingLamp, Centipede, Chair, Clip, Clock, Clothes, CoatHanger, Comb, ConcretePumpTruck, Crack, Crane, '
+    'Cup, DentalChair, Desk, DeskChair, Diagram, DishRack, DoorHandle, Dragonfish, Dragonfly, Drum, Earphone, Easel, ElectricIron, Excavator, Eyeglasses, '
+    'Fan, Fence, Fencing, FerrisWheel, FireExtinguisher, Fishing, Flag, FloorLamp, Forklift, GasStation, Gate, Gear, Goal, Golf, GymEquipment, Hammock, '
+    'Handcart, Handcraft, Handrail, HangGlider, Harp, Harvester, Headset, Helicopter, Helmet, Hook, HorizontalBar, Hydrovalve, IroningTable, Jewelry, Key, '
+    'KidsPlayground, Kitchenware, Kite, Knife, Ladder, LaundryRack, Lightning, Lobster, Locust, Machine, MachineGun, MagazineRack, Mantis, Medal, MemorialArchway, '
+    'Microphone, Missile, MobileHolder, Monitor, Mosquito, Motorcycle, MovingTrolley, Mower, MusicPlayer, MusicStand, ObservationTower, Octopus, OilWell, '
+    'OlympicLogo, OperatingTable, OutdoorFitnessEquipment, Parachute, Pavilion, Piano, Pipe, PlowHarrow, PoleVault, Punchbag, Rack, Racket, Rifle, Ring, Robot, '
+    'RockClimbing, Rope, Sailboat, Satellite, Scaffold, Scale, Scissor, Scooter, Sculpture, Seadragon, Seahorse, Seal, SewingMachine, Ship, Shoe, ShoppingCart, '
+    'ShoppingTrolley, Shower, Shrimp, Signboard, Skateboarding, Skeleton, Skiing, Spade, SpeedBoat, Spider, Spoon, Stair, Stand, Stationary, SteeringWheel, '
+    'Stethoscope, Stool, Stove, StreetLamp, SweetStand, Swing, Sword, TV, Table, TableChair, TableLamp, TableTennis, Tank, Tapeline, Teapot, Telescope, Tent, '
+    'TobaccoPipe, Toy, Tractor, TrafficLight, TrafficSign, Trampoline, TransmissionTower, Tree, Tricycle, TrimmerCover, Tripod, Trombone, Truck, Trumpet, Tuba, '
+    'UAV, Umbrella, UnevenBars, UtilityPole, VacuumCleaner, Violin, Wakesurfing, Watch, WaterTower, WateringPot, Well, WellLid, Wheel, Wheelchair, WindTurbine, Windmill, WineGlass, WireWhisk, Yacht'
+)
+class_labels_TR_sorted = _class_labels_TR_sorted.split(', ')
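+# These class names feed BiRefNet's optional auxiliary classification head:
+# cls_head below uses len(class_labels_TR_sorted) as its output dimension.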
+
+
+### models/backbones/build_backbones.py
+
+import torch
+import torch.nn as nn
+from collections import OrderedDict
+from torchvision.models import vgg16, vgg16_bn, VGG16_Weights, VGG16_BN_Weights, resnet50, ResNet50_Weights
+# from models.pvt_v2 import pvt_v2_b0, pvt_v2_b1, pvt_v2_b2, pvt_v2_b5
+# from models.swin_v1 import swin_v1_t, swin_v1_s, swin_v1_b, swin_v1_l
+# from config import Config
+
+
+config = Config()
+
+
+def build_backbone(bb_name, pretrained=True, params_settings=''):
+    if bb_name == 'vgg16':
+        bb_net = list(vgg16(pretrained=VGG16_Weights.DEFAULT if pretrained else None).children())[0]
+        bb = nn.Sequential(OrderedDict({'conv1': bb_net[:4], 'conv2': bb_net[4:9], 'conv3': bb_net[9:16], 'conv4': bb_net[16:23]}))
+    elif bb_name == 'vgg16bn':
+        bb_net = list(vgg16_bn(pretrained=VGG16_BN_Weights.DEFAULT if pretrained else None).children())[0]
+        bb = nn.Sequential(OrderedDict({'conv1': bb_net[:6], 'conv2': bb_net[6:13], 'conv3': bb_net[13:23], 'conv4': bb_net[23:33]}))
+    elif bb_name == 'resnet50':
+        bb_net = list(resnet50(pretrained=ResNet50_Weights.DEFAULT if pretrained else None).children())
+        bb = nn.Sequential(OrderedDict({'conv1': nn.Sequential(*bb_net[0:3]), 'conv2': bb_net[4], 'conv3': bb_net[5], 'conv4': bb_net[6]}))
+    else:
+        bb = eval('{}({})'.format(bb_name, params_settings))
+        if pretrained:
+            bb = load_weights(bb, bb_name)
+    return bb
+
+
+def load_weights(model, model_name):
+    save_model = torch.load(config.weights[model_name], map_location='cpu')
+    model_dict = model.state_dict()
+    state_dict = {k: v if v.size() == model_dict[k].size() else model_dict[k] for k, v in save_model.items() if k in model_dict.keys()}
+    # Ignore weights with mismatched sizes when the backbone itself has been modified.
+    if not state_dict:
+        save_model_keys = list(save_model.keys())
+        sub_item = save_model_keys[0] if len(save_model_keys) == 1 else None
+        state_dict = {k: v if v.size() == model_dict[k].size() else model_dict[k] for k, v in save_model[sub_item].items() if k in model_dict.keys()}
+        if not state_dict or not sub_item:
+            print('Weights were not successfully loaded. Check the state dict of the weights file.')
+            return None
+        else:
+            print('Found correct weights in the "{}" item of the loaded state_dict.'.format(sub_item))
+    model_dict.update(state_dict)
+    model.load_state_dict(model_dict)
+    return model
+
+
+### models/modules/decoder_blocks.py
+
+import torch
+import torch.nn as nn
+# from models.aspp import ASPP, ASPPDeformable
+# from config import Config
+
+
+# config = Config()
+
+
+class BasicDecBlk(nn.Module):
+    def __init__(self, in_channels=64, out_channels=64, inter_channels=64):
+        super(BasicDecBlk, self).__init__()
+        inter_channels = in_channels // 4 if config.dec_channels_inter == 'adap' else 64
+        self.conv_in = nn.Conv2d(in_channels, inter_channels, 3, 1, padding=1)
+        self.relu_in = nn.ReLU(inplace=True)
+        if config.dec_att == 'ASPP':
+            self.dec_att = ASPP(in_channels=inter_channels)
+        elif config.dec_att == 'ASPPDeformable':
+            self.dec_att = ASPPDeformable(in_channels=inter_channels)
+        self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, padding=1)
+        self.bn_in = nn.BatchNorm2d(inter_channels) if config.batch_size > 1 else nn.Identity()
+        self.bn_out = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity()
+
+    def forward(self, x):
+        x = self.conv_in(x)
+        x = self.bn_in(x)
+        x = self.relu_in(x)
+        if hasattr(self, 'dec_att'):
+            x = self.dec_att(x)
+        x = self.conv_out(x)
+        x = self.bn_out(x)
+        return x
+
+
+class ResBlk(nn.Module):
+    def __init__(self, in_channels=64, out_channels=None, inter_channels=64):
+        super(ResBlk, self).__init__()
+        if out_channels is None:
+            out_channels = in_channels
+        inter_channels = in_channels // 4 if config.dec_channels_inter == 'adap' else 64
+
+        self.conv_in = nn.Conv2d(in_channels, inter_channels, 3, 1, padding=1)
+        self.bn_in = nn.BatchNorm2d(inter_channels) if config.batch_size > 1 else nn.Identity()
+        self.relu_in = nn.ReLU(inplace=True)
+
+        if config.dec_att == 'ASPP':
+            self.dec_att = ASPP(in_channels=inter_channels)
+        elif config.dec_att == 'ASPPDeformable':
+            self.dec_att = ASPPDeformable(in_channels=inter_channels)
+
+        self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, padding=1)
+        self.bn_out = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity()
+
+        self.conv_resi = nn.Conv2d(in_channels, out_channels, 1, 1, 0)
+
+    def forward(self, x):
+        _x = self.conv_resi(x)
+        x = self.conv_in(x)
+        x = self.bn_in(x)
+        x = self.relu_in(x)
+        if hasattr(self, 'dec_att'):
+            x = self.dec_att(x)
+        x = self.conv_out(x)
+        x = self.bn_out(x)
+        return x + _x
+
+
+### models/modules/lateral_blocks.py
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from functools import partial
+
+# from config import Config
+
+
+# config = Config()
+
+
+class BasicLatBlk(nn.Module):
+    def __init__(self, in_channels=64, out_channels=64, inter_channels=64):
+        super(BasicLatBlk, self).__init__()
+        inter_channels = in_channels // 4 if config.dec_channels_inter == 'adap' else 64
+        self.conv = nn.Conv2d(in_channels, out_channels, 1, 1, 0)
+
+    def forward(self, x):
+        x = self.conv(x)
+        return x
+
+
+### models/modules/aspp.py
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+# from models.deform_conv import DeformableConv2d
+# from config import Config
+
+
+# config = Config()
+
+
+class _ASPPModule(nn.Module):
+    def __init__(self, in_channels, planes, kernel_size, padding, dilation):
+        super(_ASPPModule, self).__init__()
+        self.atrous_conv = nn.Conv2d(in_channels, planes, kernel_size=kernel_size,
+                                     stride=1, padding=padding, dilation=dilation, bias=False)
+        self.bn = nn.BatchNorm2d(planes) if config.batch_size > 1 else nn.Identity()
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        x = self.atrous_conv(x)
+        x = self.bn(x)
+
+        return self.relu(x)
+
+
+class ASPP(nn.Module):
+    def __init__(self, in_channels=64, out_channels=None, output_stride=16):
+        super(ASPP, self).__init__()
+        self.down_scale = 1
+        if out_channels is None:
+            out_channels = in_channels
+        self.in_channelster = 256 // self.down_scale
+        if output_stride == 16:
+            dilations = [1, 6, 12, 18]
+        elif output_stride == 8:
+            dilations = [1, 12, 24, 36]
+        else:
+            raise NotImplementedError
+
+        self.aspp1 = _ASPPModule(in_channels, self.in_channelster, 1, padding=0, dilation=dilations[0])
+        self.aspp2 = _ASPPModule(in_channels, self.in_channelster, 3, padding=dilations[1], dilation=dilations[1])
+        self.aspp3 = _ASPPModule(in_channels, self.in_channelster, 3, padding=dilations[2], dilation=dilations[2])
+        self.aspp4 = _ASPPModule(in_channels, self.in_channelster, 3, padding=dilations[3], dilation=dilations[3])
+
+        self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
+                                             nn.Conv2d(in_channels, self.in_channelster, 1, stride=1, bias=False),
+                                             nn.BatchNorm2d(self.in_channelster) if config.batch_size > 1 else nn.Identity(),
+                                             nn.ReLU(inplace=True))
+        self.conv1 = nn.Conv2d(self.in_channelster * 5, out_channels, 1, bias=False)
+        self.bn1 = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity()
+        self.relu = nn.ReLU(inplace=True)
+        self.dropout = nn.Dropout(0.5)
+
+    def forward(self, x):
+        x1 = self.aspp1(x)
+        x2 = self.aspp2(x)
+        x3 = self.aspp3(x)
+        x4 = self.aspp4(x)
+        x5 = self.global_avg_pool(x)
+        x5 = F.interpolate(x5, size=x1.size()[2:], mode='bilinear', align_corners=True)
+        x = torch.cat((x1, x2, x3, x4, x5), dim=1)
+
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+
+        return self.dropout(x)
+
+
+##################### Deformable
+class _ASPPModuleDeformable(nn.Module):
+    def __init__(self, in_channels, planes, kernel_size, padding):
+        super(_ASPPModuleDeformable, self).__init__()
+        self.atrous_conv = DeformableConv2d(in_channels, planes, kernel_size=kernel_size,
+                                            stride=1, padding=padding, bias=False)
+        self.bn = nn.BatchNorm2d(planes) if config.batch_size > 1 else nn.Identity()
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        x = self.atrous_conv(x)
+        x = self.bn(x)
+
+        return self.relu(x)
+
+
+class ASPPDeformable(nn.Module):
+    def __init__(self, in_channels, out_channels=None, parallel_block_sizes=[1, 3, 7]):
+        super(ASPPDeformable, self).__init__()
+        self.down_scale = 1
+        if out_channels is None:
+            out_channels = in_channels
+        self.in_channelster = 256 // self.down_scale
+
+        self.aspp1 = _ASPPModuleDeformable(in_channels, self.in_channelster, 1, padding=0)
+        self.aspp_deforms = nn.ModuleList([
+            _ASPPModuleDeformable(in_channels, self.in_channelster, conv_size, padding=int(conv_size//2)) for conv_size in parallel_block_sizes
+        ])
+
+        self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
+                                             nn.Conv2d(in_channels, self.in_channelster, 1, stride=1, bias=False),
+                                             nn.BatchNorm2d(self.in_channelster) if config.batch_size > 1 else nn.Identity(),
+                                             nn.ReLU(inplace=True))
+        self.conv1 = nn.Conv2d(self.in_channelster * (2 + len(self.aspp_deforms)), out_channels, 1, bias=False)
+        self.bn1 = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity()
+        self.relu = nn.ReLU(inplace=True)
+        self.dropout = nn.Dropout(0.5)
+
+    def forward(self, x):
+        x1 = self.aspp1(x)
+        x_aspp_deforms = [aspp_deform(x) for aspp_deform in self.aspp_deforms]
+        x5 = self.global_avg_pool(x)
+        x5 = F.interpolate(x5, size=x1.size()[2:], mode='bilinear', align_corners=True)
+        x = torch.cat((x1, *x_aspp_deforms, x5), dim=1)
+
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+
+        return self.dropout(x)
+
+
+### models/refinement/refiner.py
+
+import torch
+import torch.nn as nn
+from collections import OrderedDict
+import torch.nn.functional as F
+from torchvision.models import vgg16, vgg16_bn
+from torchvision.models import resnet50
+
+# from config import Config
+# from dataset import class_labels_TR_sorted
+# from models.build_backbone import build_backbone
+# from models.decoder_blocks import BasicDecBlk
+# from models.lateral_blocks import BasicLatBlk
+# from models.ing import *
+# from models.stem_layer import StemLayer
+
+
+class RefinerPVTInChannels4(nn.Module):
+    def __init__(self, in_channels=3+1):
+        super(RefinerPVTInChannels4, self).__init__()
+        self.config = Config()
+        self.epoch = 1
+        self.bb = build_backbone(self.config.bb, params_settings='in_channels=4')
+
+        lateral_channels_in_collection = {
+            'vgg16': [512, 256, 128, 64], 'vgg16bn': [512, 256, 128, 64], 'resnet50': [1024, 512, 256, 64],
+            'pvt_v2_b2': [512, 320, 128, 64], 'pvt_v2_b5': [512, 320, 128, 64],
+            'swin_v1_b': [1024, 512, 256, 128], 'swin_v1_l': [1536, 768, 384, 192],
+        }
+        channels = lateral_channels_in_collection[self.config.bb]
+        self.squeeze_module = BasicDecBlk(channels[0], channels[0])
+
+        self.decoder = Decoder(channels)
+
+        if 0:
+            for key, value in self.named_parameters():
+                if 'bb.' in key:
+                    value.requires_grad = False
+
+    def forward(self, x):
+        if isinstance(x, list):
+            x = torch.cat(x, dim=1)
+        ########## Encoder ##########
+        if self.config.bb in ['vgg16', 'vgg16bn', 'resnet50']:
+            x1 = self.bb.conv1(x)
+            x2 = self.bb.conv2(x1)
+            x3 = self.bb.conv3(x2)
+            x4 = self.bb.conv4(x3)
+        else:
+            x1, x2, x3, x4 = self.bb(x)
+
+        x4 = self.squeeze_module(x4)
+
+        ########## Decoder ##########
+
+        features = [x, x1, x2, x3, x4]
+        scaled_preds = self.decoder(features)
+
+        return scaled_preds
+
+
+class Refiner(nn.Module):
+    def __init__(self, in_channels=3+1):
+        super(Refiner, self).__init__()
+        self.config = Config()
+        self.epoch = 1
+        self.stem_layer = StemLayer(in_channels=in_channels, inter_channels=48, out_channels=3, norm_layer='BN' if self.config.batch_size > 1 else 'LN')
+        self.bb = build_backbone(self.config.bb)
+
+        lateral_channels_in_collection = {
+            'vgg16': [512, 256, 128, 64], 'vgg16bn': [512, 256, 128, 64], 'resnet50': [1024, 512, 256, 64],
+            'pvt_v2_b2': [512, 320, 128, 64], 'pvt_v2_b5': [512, 320, 128, 64],
+            'swin_v1_b': [1024, 512, 256, 128], 'swin_v1_l': [1536, 768, 384, 192],
+        }
+        channels = lateral_channels_in_collection[self.config.bb]
+        self.squeeze_module = BasicDecBlk(channels[0], channels[0])
+
+        self.decoder = Decoder(channels)
+
+        if 0:
+            for key, value in self.named_parameters():
+                if 'bb.' in key:
+                    value.requires_grad = False
+
+    def forward(self, x):
+        if isinstance(x, list):
+            x = torch.cat(x, dim=1)
+        x = self.stem_layer(x)
+        ########## Encoder ##########
+        if self.config.bb in ['vgg16', 'vgg16bn', 'resnet50']:
+            x1 = self.bb.conv1(x)
+            x2 = self.bb.conv2(x1)
+            x3 = self.bb.conv3(x2)
+            x4 = self.bb.conv4(x3)
+        else:
+            x1, x2, x3, x4 = self.bb(x)
+
+        x4 = self.squeeze_module(x4)
+
+        ########## Decoder ##########
+
+        features = [x, x1, x2, x3, x4]
+        scaled_preds = self.decoder(features)
+
+        return scaled_preds
+
+
+class Decoder(nn.Module):
+    def __init__(self, channels):
+        super(Decoder, self).__init__()
+        self.config = Config()
+        DecoderBlock = eval('BasicDecBlk')
+        LateralBlock = eval('BasicLatBlk')
+
+        self.decoder_block4 = DecoderBlock(channels[0], channels[1])
+        self.decoder_block3 = DecoderBlock(channels[1], channels[2])
+        self.decoder_block2 = DecoderBlock(channels[2], channels[3])
+        self.decoder_block1 = DecoderBlock(channels[3], channels[3]//2)
+
+        self.lateral_block4 = LateralBlock(channels[1], channels[1])
+        self.lateral_block3 = LateralBlock(channels[2], channels[2])
+        self.lateral_block2 = LateralBlock(channels[3], channels[3])
+
+        if self.config.ms_supervision:
+            self.conv_ms_spvn_4 = nn.Conv2d(channels[1], 1, 1, 1, 0)
+            self.conv_ms_spvn_3 = nn.Conv2d(channels[2], 1, 1, 1, 0)
+            self.conv_ms_spvn_2 = nn.Conv2d(channels[3], 1, 1, 1, 0)
+        self.conv_out1 = nn.Sequential(nn.Conv2d(channels[3]//2, 1, 1, 1, 0))
+
+    def forward(self, features):
+        x, x1, x2, x3, x4 = features
+        outs = []
+        p4 = self.decoder_block4(x4)
+        _p4 = F.interpolate(p4, size=x3.shape[2:], mode='bilinear', align_corners=True)
+        _p3 = _p4 + self.lateral_block4(x3)
+
+        p3 = self.decoder_block3(_p3)
+        _p3 = F.interpolate(p3, size=x2.shape[2:], mode='bilinear', align_corners=True)
+        _p2 = _p3 + self.lateral_block3(x2)
+
+        p2 = self.decoder_block2(_p2)
+        _p2 = F.interpolate(p2, size=x1.shape[2:], mode='bilinear', align_corners=True)
+        _p1 = _p2 + self.lateral_block2(x1)
+
+        _p1 = self.decoder_block1(_p1)
+        _p1 = F.interpolate(_p1, size=x.shape[2:], mode='bilinear', align_corners=True)
+        p1_out = self.conv_out1(_p1)
+
+        if self.config.ms_supervision:
+            outs.append(self.conv_ms_spvn_4(p4))
+            outs.append(self.conv_ms_spvn_3(p3))
+            outs.append(self.conv_ms_spvn_2(p2))
+        outs.append(p1_out)
+        return outs
+
+
+class RefUNet(nn.Module):
+    # Refinement
+    def __init__(self, in_channels=3+1):
+        super(RefUNet, self).__init__()
+        self.encoder_1 = nn.Sequential(
+            nn.Conv2d(in_channels, 64, 3, 1, 1),
+            nn.Conv2d(64, 64, 3, 1, 1),
+            nn.BatchNorm2d(64),
+            nn.ReLU(inplace=True)
+        )
+
+        self.encoder_2 = nn.Sequential(
+            nn.MaxPool2d(2, 2, ceil_mode=True),
+            nn.Conv2d(64, 64, 3, 1, 1),
+            nn.BatchNorm2d(64),
+            nn.ReLU(inplace=True)
+        )
+
+        self.encoder_3 = nn.Sequential(
+            nn.MaxPool2d(2, 2, ceil_mode=True),
+            nn.Conv2d(64, 64, 3, 1, 1),
+            nn.BatchNorm2d(64),
+            nn.ReLU(inplace=True)
+        )
+
+        self.encoder_4 = nn.Sequential(
+            nn.MaxPool2d(2, 2, ceil_mode=True),
+            nn.Conv2d(64, 64, 3, 1, 1),
+            nn.BatchNorm2d(64),
+            nn.ReLU(inplace=True)
+        )
+
+        self.pool4 = nn.MaxPool2d(2, 2, ceil_mode=True)
+        #####
+        self.decoder_5 = nn.Sequential(
+            nn.Conv2d(64, 64, 3, 1, 1),
+            nn.BatchNorm2d(64),
+            nn.ReLU(inplace=True)
+        )
+        #####
+        self.decoder_4 = nn.Sequential(
+            nn.Conv2d(128, 64, 3, 1, 1),
+            nn.BatchNorm2d(64),
+            nn.ReLU(inplace=True)
+        )
+
+        self.decoder_3 = nn.Sequential(
+            nn.Conv2d(128, 64, 3, 1, 1),
+            nn.BatchNorm2d(64),
+            nn.ReLU(inplace=True)
+        )
+
+        self.decoder_2 = nn.Sequential(
+            nn.Conv2d(128, 64, 3, 1, 1),
+            nn.BatchNorm2d(64),
+            nn.ReLU(inplace=True)
+        )
+
+        self.decoder_1 = nn.Sequential(
+            nn.Conv2d(128, 64, 3, 1, 1),
+            nn.BatchNorm2d(64),
+            nn.ReLU(inplace=True)
+        )
+
+        self.conv_d0 = nn.Conv2d(64, 1, 3, 1, 1)
+
+        self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
+
+    def forward(self, x):
+        outs = []
+        if isinstance(x, list):
+            x = torch.cat(x, dim=1)
+        hx = x
+
+        hx1 = self.encoder_1(hx)
+        hx2 = self.encoder_2(hx1)
+        hx3 = self.encoder_3(hx2)
+        hx4 = self.encoder_4(hx3)
+
+        hx = self.decoder_5(self.pool4(hx4))
+        hx = torch.cat((self.upscore2(hx), hx4), 1)
+
+        d4 = self.decoder_4(hx)
+        hx = torch.cat((self.upscore2(d4), hx3), 1)
+
+        d3 = self.decoder_3(hx)
+        hx = torch.cat((self.upscore2(d3), hx2), 1)
+
+        d2 = self.decoder_2(hx)
+        hx = torch.cat((self.upscore2(d2), hx1), 1)
+
+        d1 = self.decoder_1(hx)
+
+        x = self.conv_d0(d1)
+        outs.append(x)
+        return outs
+
+
+### models/stem_layer.py
+
+import torch.nn as nn
+# from utils import build_act_layer, build_norm_layer
+
+
+class StemLayer(nn.Module):
+    r""" Stem layer of InternImage
+    Args:
+        in_channels (int): number of input channels
+        inter_channels (int): number of intermediate channels
+        out_channels (int): number of output channels
+        act_layer (str): activation layer
+        norm_layer (str): normalization layer
+    """
+
+    def __init__(self,
+                 in_channels=3+1,
+                 inter_channels=48,
+                 out_channels=96,
+                 act_layer='GELU',
+                 norm_layer='BN'):
+        super().__init__()
+        self.conv1 = nn.Conv2d(in_channels,
+                               inter_channels,
+                               kernel_size=3,
+                               stride=1,
+                               padding=1)
+        self.norm1 = build_norm_layer(
+            inter_channels, norm_layer, 'channels_first', 'channels_first'
+        )
+        self.act = build_act_layer(act_layer)
+        self.conv2 = nn.Conv2d(inter_channels,
+                               out_channels,
+                               kernel_size=3,
+                               stride=1,
+                               padding=1)
+        self.norm2 = build_norm_layer(
+            out_channels, norm_layer, 'channels_first', 'channels_first'
+        )
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.norm1(x)
+        x = self.act(x)
+        x = self.conv2(x)
+        x = self.norm2(x)
+        return x
+
+
+### models/birefnet.py
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from kornia.filters import laplacian
+from transformers import PreTrainedModel
+from einops import rearrange
+
+# from config import Config
+# from dataset import class_labels_TR_sorted
+# from models.build_backbone import build_backbone
+# from models.decoder_blocks import BasicDecBlk, ResBlk, HierarAttDecBlk
+# from models.lateral_blocks import BasicLatBlk
+# from models.aspp import ASPP, ASPPDeformable
+# from models.ing import *
+# from models.refiner import Refiner, RefinerPVTInChannels4, RefUNet
+# from models.stem_layer import StemLayer
+from .BiRefNet_config import BiRefNetConfig
+
+
+def image2patches(image, grid_h=2, grid_w=2, patch_ref=None, transformation='b c (hg h) (wg w) -> (b hg wg) c h w'):
+    if patch_ref is not None:
+        grid_h, grid_w = image.shape[-2] // patch_ref.shape[-2], image.shape[-1] // patch_ref.shape[-1]
+    patches = rearrange(image, transformation, hg=grid_h, wg=grid_w)
+    return patches
+
+
+def patches2image(patches, grid_h=2, grid_w=2, patch_ref=None, transformation='(b hg wg) c h w -> b c (hg h) (wg w)'):
+    if patch_ref is not None:
+        grid_h, grid_w = patch_ref.shape[-2] // patches[0].shape[-2], patch_ref.shape[-1] // patches[0].shape[-1]
+    image = rearrange(patches, transformation, hg=grid_h, wg=grid_w)
+    return image
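+# Round-trip sketch (illustrative): for x = torch.randn(1, 3, 64, 64),
+#     image2patches(x) -> (4, 3, 32, 32)   # 2x2 grid of patches stacked onto the batch dim
+#     patches2image(image2patches(x))      # -> (1, 3, 64, 64) again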
1996
+
1997
+ class BiRefNet(
1998
+ PreTrainedModel
1999
+ ):
2000
+ config_class = BiRefNetConfig
2001
+ def __init__(self, bb_pretrained=True, config=BiRefNetConfig()):
2002
+ super(BiRefNet, self).__init__(config)
2003
+ bb_pretrained = config.bb_pretrained
2004
+ self.config = Config()
2005
+ self.epoch = 1
2006
+ self.bb = build_backbone(self.config.bb, pretrained=bb_pretrained)
2007
+
2008
+ channels = self.config.lateral_channels_in_collection
2009
+
2010
+ if self.config.auxiliary_classification:
2011
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
2012
+ self.cls_head = nn.Sequential(
2013
+ nn.Linear(channels[0], len(class_labels_TR_sorted))
2014
+ )
2015
+
2016
+ if self.config.squeeze_block:
2017
+ self.squeeze_module = nn.Sequential(*[
2018
+ eval(self.config.squeeze_block.split('_x')[0])(channels[0]+sum(self.config.cxt), channels[0])
2019
+ for _ in range(eval(self.config.squeeze_block.split('_x')[1]))
2020
+ ])
2021
+
2022
+ self.decoder = Decoder(channels)
2023
+
2024
+ if self.config.ender:
2025
+ self.dec_end = nn.Sequential(
2026
+ nn.Conv2d(1, 16, 3, 1, 1),
2027
+ nn.Conv2d(16, 1, 3, 1, 1),
2028
+ nn.ReLU(inplace=True),
2029
+ )
2030
+
2031
+ # refine patch-level segmentation
2032
+ if self.config.refine:
2033
+ if self.config.refine == 'itself':
2034
+ self.stem_layer = StemLayer(in_channels=3+1, inter_channels=48, out_channels=3, norm_layer='BN' if self.config.batch_size > 1 else 'LN')
2035
+ else:
2036
+ self.refiner = eval('{}({})'.format(self.config.refine, 'in_channels=3+1'))
2037
+
2038
+ if self.config.freeze_bb:
2039
+ # Freeze the backbone...
2040
+ print(self.named_parameters())
2041
+ for key, value in self.named_parameters():
2042
+ if 'bb.' in key and 'refiner.' not in key:
2043
+ value.requires_grad = False
2044
+
2045
+ def forward_enc(self, x):
2046
+ if self.config.bb in ['vgg16', 'vgg16bn', 'resnet50']:
2047
+ x1 = self.bb.conv1(x); x2 = self.bb.conv2(x1); x3 = self.bb.conv3(x2); x4 = self.bb.conv4(x3)
2048
+ else:
2049
+ x1, x2, x3, x4 = self.bb(x)
2050
+ if self.config.mul_scl_ipt == 'cat':
2051
+ B, C, H, W = x.shape
2052
+ x1_, x2_, x3_, x4_ = self.bb(F.interpolate(x, size=(H//2, W//2), mode='bilinear', align_corners=True))
2053
+ x1 = torch.cat([x1, F.interpolate(x1_, size=x1.shape[2:], mode='bilinear', align_corners=True)], dim=1)
2054
+ x2 = torch.cat([x2, F.interpolate(x2_, size=x2.shape[2:], mode='bilinear', align_corners=True)], dim=1)
2055
+ x3 = torch.cat([x3, F.interpolate(x3_, size=x3.shape[2:], mode='bilinear', align_corners=True)], dim=1)
2056
+ x4 = torch.cat([x4, F.interpolate(x4_, size=x4.shape[2:], mode='bilinear', align_corners=True)], dim=1)
2057
+ elif self.config.mul_scl_ipt == 'add':
2058
+ B, C, H, W = x.shape
2059
+ x1_, x2_, x3_, x4_ = self.bb(F.interpolate(x, size=(H//2, W//2), mode='bilinear', align_corners=True))
2060
+ x1 = x1 + F.interpolate(x1_, size=x1.shape[2:], mode='bilinear', align_corners=True)
2061
+ x2 = x2 + F.interpolate(x2_, size=x2.shape[2:], mode='bilinear', align_corners=True)
2062
+ x3 = x3 + F.interpolate(x3_, size=x3.shape[2:], mode='bilinear', align_corners=True)
2063
+ x4 = x4 + F.interpolate(x4_, size=x4.shape[2:], mode='bilinear', align_corners=True)
2064
+ class_preds = self.cls_head(self.avgpool(x4).view(x4.shape[0], -1)) if self.training and self.config.auxiliary_classification else None
2065
+ if self.config.cxt:
2066
+ x4 = torch.cat(
2067
+ (
2068
+ *[
2069
+ F.interpolate(x1, size=x4.shape[2:], mode='bilinear', align_corners=True),
2070
+ F.interpolate(x2, size=x4.shape[2:], mode='bilinear', align_corners=True),
2071
+ F.interpolate(x3, size=x4.shape[2:], mode='bilinear', align_corners=True),
2072
+ ][-len(self.config.cxt):],
2073
+ x4
2074
+ ),
2075
+ dim=1
2076
+ )
2077
+ return (x1, x2, x3, x4), class_preds
2078
+
2079
+ def forward_ori(self, x):
2080
+ ########## Encoder ##########
2081
+ (x1, x2, x3, x4), class_preds = self.forward_enc(x)
2082
+ if self.config.squeeze_block:
2083
+ x4 = self.squeeze_module(x4)
2084
+ ########## Decoder ##########
2085
+ features = [x, x1, x2, x3, x4]
2086
+ if self.training and self.config.out_ref:
2087
+ features.append(laplacian(torch.mean(x, dim=1).unsqueeze(1), kernel_size=5))
2088
+ scaled_preds = self.decoder(features)
2089
+ return scaled_preds, class_preds
2090
+
2091
+ def forward(self, x):
2092
+ scaled_preds, class_preds = self.forward_ori(x)
2093
+ class_preds_lst = [class_preds]
2094
+ return [scaled_preds, class_preds_lst] if self.training else scaled_preds
2095
+
2096
+
2097
+ class Decoder(nn.Module):
2098
+ def __init__(self, channels):
2099
+ super(Decoder, self).__init__()
2100
+ self.config = Config()
2101
+ DecoderBlock = eval(self.config.dec_blk)
2102
+ LateralBlock = eval(self.config.lat_blk)
2103
+
2104
+ if self.config.dec_ipt:
2105
+ self.split = self.config.dec_ipt_split
2106
+ N_dec_ipt = 64
2107
+ DBlock = SimpleConvs
2108
+ ic = 64
2109
+ ipt_cha_opt = 1
2110
+ self.ipt_blk5 = DBlock(2**10*3 if self.split else 3, [N_dec_ipt, channels[0]//8][ipt_cha_opt], inter_channels=ic)
2111
+ self.ipt_blk4 = DBlock(2**8*3 if self.split else 3, [N_dec_ipt, channels[0]//8][ipt_cha_opt], inter_channels=ic)
2112
+ self.ipt_blk3 = DBlock(2**6*3 if self.split else 3, [N_dec_ipt, channels[1]//8][ipt_cha_opt], inter_channels=ic)
2113
+ self.ipt_blk2 = DBlock(2**4*3 if self.split else 3, [N_dec_ipt, channels[2]//8][ipt_cha_opt], inter_channels=ic)
2114
+ self.ipt_blk1 = DBlock(2**0*3 if self.split else 3, [N_dec_ipt, channels[3]//8][ipt_cha_opt], inter_channels=ic)
2115
+ else:
2116
+ self.split = None
2117
+
2118
+ self.decoder_block4 = DecoderBlock(channels[0]+([N_dec_ipt, channels[0]//8][ipt_cha_opt] if self.config.dec_ipt else 0), channels[1])
2119
+ self.decoder_block3 = DecoderBlock(channels[1]+([N_dec_ipt, channels[0]//8][ipt_cha_opt] if self.config.dec_ipt else 0), channels[2])
2120
+ self.decoder_block2 = DecoderBlock(channels[2]+([N_dec_ipt, channels[1]//8][ipt_cha_opt] if self.config.dec_ipt else 0), channels[3])
2121
+ self.decoder_block1 = DecoderBlock(channels[3]+([N_dec_ipt, channels[2]//8][ipt_cha_opt] if self.config.dec_ipt else 0), channels[3]//2)
2122
+ self.conv_out1 = nn.Sequential(nn.Conv2d(channels[3]//2+([N_dec_ipt, channels[3]//8][ipt_cha_opt] if self.config.dec_ipt else 0), 1, 1, 1, 0))
2123
+
2124
+ self.lateral_block4 = LateralBlock(channels[1], channels[1])
2125
+ self.lateral_block3 = LateralBlock(channels[2], channels[2])
2126
+ self.lateral_block2 = LateralBlock(channels[3], channels[3])
2127
+
2128
+ if self.config.ms_supervision:
2129
+ self.conv_ms_spvn_4 = nn.Conv2d(channels[1], 1, 1, 1, 0)
2130
+ self.conv_ms_spvn_3 = nn.Conv2d(channels[2], 1, 1, 1, 0)
2131
+ self.conv_ms_spvn_2 = nn.Conv2d(channels[3], 1, 1, 1, 0)
2132
+
2133
+ if self.config.out_ref:
2134
+ _N = 16
2135
+ self.gdt_convs_4 = nn.Sequential(nn.Conv2d(channels[1], _N, 3, 1, 1), nn.BatchNorm2d(_N) if self.config.batch_size > 1 else nn.Identity(), nn.ReLU(inplace=True))
2136
+ self.gdt_convs_3 = nn.Sequential(nn.Conv2d(channels[2], _N, 3, 1, 1), nn.BatchNorm2d(_N) if self.config.batch_size > 1 else nn.Identity(), nn.ReLU(inplace=True))
2137
+ self.gdt_convs_2 = nn.Sequential(nn.Conv2d(channels[3], _N, 3, 1, 1), nn.BatchNorm2d(_N) if self.config.batch_size > 1 else nn.Identity(), nn.ReLU(inplace=True))
2138
+
2139
+ self.gdt_convs_pred_4 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
2140
+ self.gdt_convs_pred_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
2141
+ self.gdt_convs_pred_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
2142
+
2143
+ self.gdt_convs_attn_4 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
2144
+ self.gdt_convs_attn_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
2145
+ self.gdt_convs_attn_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
2146
+
2147
+ def forward(self, features):
2148
+ if self.training and self.config.out_ref:
2149
+ outs_gdt_pred = []
2150
+ outs_gdt_label = []
2151
+ x, x1, x2, x3, x4, gdt_gt = features
2152
+ else:
2153
+ x, x1, x2, x3, x4 = features
2154
+ outs = []
2155
+
2156
+ if self.config.dec_ipt:
2157
+ patches_batch = image2patches(x, patch_ref=x4, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x
2158
+ x4 = torch.cat((x4, self.ipt_blk5(F.interpolate(patches_batch, size=x4.shape[2:], mode='bilinear', align_corners=True))), 1)
2159
+ p4 = self.decoder_block4(x4)
2160
+ m4 = self.conv_ms_spvn_4(p4) if self.config.ms_supervision and self.training else None
2161
+ if self.config.out_ref:
2162
+ p4_gdt = self.gdt_convs_4(p4)
2163
+ if self.training:
2164
+ # >> GT:
2165
+ m4_dia = m4
2166
+ gdt_label_main_4 = gdt_gt * F.interpolate(m4_dia, size=gdt_gt.shape[2:], mode='bilinear', align_corners=True)
2167
+ outs_gdt_label.append(gdt_label_main_4)
2168
+ # >> Pred:
2169
+ gdt_pred_4 = self.gdt_convs_pred_4(p4_gdt)
2170
+ outs_gdt_pred.append(gdt_pred_4)
2171
+ gdt_attn_4 = self.gdt_convs_attn_4(p4_gdt).sigmoid()
2172
+ # >> Finally:
2173
+ p4 = p4 * gdt_attn_4
2174
+ _p4 = F.interpolate(p4, size=x3.shape[2:], mode='bilinear', align_corners=True)
2175
+ _p3 = _p4 + self.lateral_block4(x3)
2176
+
2177
+ if self.config.dec_ipt:
2178
+ patches_batch = image2patches(x, patch_ref=_p3, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x
2179
+ _p3 = torch.cat((_p3, self.ipt_blk4(F.interpolate(patches_batch, size=x3.shape[2:], mode='bilinear', align_corners=True))), 1)
2180
+ p3 = self.decoder_block3(_p3)
2181
+ m3 = self.conv_ms_spvn_3(p3) if self.config.ms_supervision and self.training else None
2182
+ if self.config.out_ref:
2183
+ p3_gdt = self.gdt_convs_3(p3)
2184
+ if self.training:
2185
+ # >> GT:
2186
+ # m3 --dilation--> m3_dia
2187
+ # G_3^gt * m3_dia --> G_3^m, which is the label of gradient
2188
+ m3_dia = m3
2189
+ gdt_label_main_3 = gdt_gt * F.interpolate(m3_dia, size=gdt_gt.shape[2:], mode='bilinear', align_corners=True)
2190
+ outs_gdt_label.append(gdt_label_main_3)
2191
+ # >> Pred:
2192
+ # p3 --conv--BN--> F_3^G, where F_3^G predicts the \hat{G_3} with xx
2193
+ # F_3^G --sigmoid--> A_3^G
2194
+ gdt_pred_3 = self.gdt_convs_pred_3(p3_gdt)
2195
+ outs_gdt_pred.append(gdt_pred_3)
2196
+ gdt_attn_3 = self.gdt_convs_attn_3(p3_gdt).sigmoid()
2197
+ # >> Finally:
2198
+ # p3 = p3 * A_3^G
2199
+ p3 = p3 * gdt_attn_3
2200
+ _p3 = F.interpolate(p3, size=x2.shape[2:], mode='bilinear', align_corners=True)
2201
+ _p2 = _p3 + self.lateral_block3(x2)
2202
+
2203
+ if self.config.dec_ipt:
2204
+ patches_batch = image2patches(x, patch_ref=_p2, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x
2205
+ _p2 = torch.cat((_p2, self.ipt_blk3(F.interpolate(patches_batch, size=x2.shape[2:], mode='bilinear', align_corners=True))), 1)
2206
+ p2 = self.decoder_block2(_p2)
2207
+ m2 = self.conv_ms_spvn_2(p2) if self.config.ms_supervision and self.training else None
2208
+ if self.config.out_ref:
2209
+ p2_gdt = self.gdt_convs_2(p2)
2210
+ if self.training:
2211
+ # >> GT:
2212
+ m2_dia = m2
2213
+ gdt_label_main_2 = gdt_gt * F.interpolate(m2_dia, size=gdt_gt.shape[2:], mode='bilinear', align_corners=True)
2214
+ outs_gdt_label.append(gdt_label_main_2)
2215
+ # >> Pred:
2216
+ gdt_pred_2 = self.gdt_convs_pred_2(p2_gdt)
2217
+ outs_gdt_pred.append(gdt_pred_2)
2218
+ gdt_attn_2 = self.gdt_convs_attn_2(p2_gdt).sigmoid()
2219
+ # >> Finally:
2220
+ p2 = p2 * gdt_attn_2
2221
+ _p2 = F.interpolate(p2, size=x1.shape[2:], mode='bilinear', align_corners=True)
2222
+ _p1 = _p2 + self.lateral_block2(x1)
2223
+
2224
+ if self.config.dec_ipt:
2225
+ patches_batch = image2patches(x, patch_ref=_p1, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x
2226
+ _p1 = torch.cat((_p1, self.ipt_blk2(F.interpolate(patches_batch, size=x1.shape[2:], mode='bilinear', align_corners=True))), 1)
2227
+ _p1 = self.decoder_block1(_p1)
2228
+ _p1 = F.interpolate(_p1, size=x.shape[2:], mode='bilinear', align_corners=True)
2229
+
2230
+ if self.config.dec_ipt:
2231
+ patches_batch = image2patches(x, patch_ref=_p1, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x
2232
+ _p1 = torch.cat((_p1, self.ipt_blk1(F.interpolate(patches_batch, size=x.shape[2:], mode='bilinear', align_corners=True))), 1)
2233
+ p1_out = self.conv_out1(_p1)
2234
+
2235
+ if self.config.ms_supervision and self.training:
2236
+ outs.append(m4)
2237
+ outs.append(m3)
2238
+ outs.append(m2)
2239
+ outs.append(p1_out)
2240
+ return outs if not (self.config.out_ref and self.training) else ([outs_gdt_pred, outs_gdt_label], outs)
2241
+
2242
+
2243
+ class SimpleConvs(nn.Module):
2244
+ def __init__(
2245
+ self, in_channels: int, out_channels: int, inter_channels=64
2246
+ ) -> None:
2247
+ super().__init__()
2248
+ self.conv1 = nn.Conv2d(in_channels, inter_channels, 3, 1, 1)
2249
+ self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, 1)
2250
+
2251
+ def forward(self, x):
2252
+ return self.conv_out(self.conv1(x))
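+
+ # Note on the decoder forward above: `outs` always ends with the full-resolution
+ # logits `p1_out`; the multi-scale logits m4/m3/m2 are appended before it only
+ # when `config.ms_supervision` is on and the model is in training mode. With
+ # `config.out_ref` also enabled during training, the forward returns
+ # ([outs_gdt_pred, outs_gdt_label], outs) so the gradient predictions and labels
+ # can be supervised alongside the segmentation outputs.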
Trellv2/ZhengPeng7--BiRefNet/config.json ADDED
@@ -0,0 +1,20 @@
1
+ {
2
+ "_name_or_path": "ZhengPeng7/BiRefNet",
3
+ "architectures": [
4
+ "BiRefNet"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "BiRefNet_config.BiRefNetConfig",
8
+ "AutoModelForImageSegmentation": "birefnet.BiRefNet"
9
+ },
10
+ "custom_pipelines": {
11
+ "image-segmentation": {
12
+ "pt": [
13
+ "AutoModelForImageSegmentation"
14
+ ],
15
+ "tf": [],
16
+ "type": "image"
17
+ }
18
+ },
19
+ "bb_pretrained": false
20
+ }
Trellv2/ZhengPeng7--BiRefNet/handler.py ADDED
@@ -0,0 +1,139 @@
1
+ # These HF deployment codes refer to https://huggingface.co/not-lain/BiRefNet/raw/main/handler.py.
2
+ from typing import Dict, List, Any, Tuple
3
+ import os
4
+ import requests
5
+ from io import BytesIO
6
+ import cv2
7
+ import numpy as np
8
+ from PIL import Image
9
+ import torch
10
+ from torchvision import transforms
11
+ from transformers import AutoModelForImageSegmentation
12
+
13
+ torch.set_float32_matmul_precision(["high", "highest"][0])
14
+
15
+ device = "cuda" if torch.cuda.is_available() else "cpu"
16
+
17
+ ### image_proc.py
18
+ def refine_foreground(image, mask, r=90):
19
+ if mask.size != image.size:
20
+ mask = mask.resize(image.size)
21
+ image = np.array(image) / 255.0
22
+ mask = np.array(mask) / 255.0
23
+ estimated_foreground = FB_blur_fusion_foreground_estimator_2(image, mask, r=r)
24
+ image_masked = Image.fromarray((estimated_foreground * 255.0).astype(np.uint8))
25
+ return image_masked
26
+
27
+
28
+ def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
29
+ # Thanks to the source: https://github.com/Photoroom/fast-foreground-estimation
30
+ alpha = alpha[:, :, None]
31
+ F, blur_B = FB_blur_fusion_foreground_estimator(image, image, image, alpha, r)
32
+ return FB_blur_fusion_foreground_estimator(image, F, blur_B, alpha, r=6)[0]
33
+
34
+
35
+ def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
36
+ if isinstance(image, Image.Image):
37
+ image = np.array(image) / 255.0
38
+ blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None]
39
+
40
+ blurred_FA = cv2.blur(F * alpha, (r, r))
41
+ blurred_F = blurred_FA / (blurred_alpha + 1e-5)
42
+
43
+ blurred_B1A = cv2.blur(B * (1 - alpha), (r, r))
44
+ blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)
45
+ F = blurred_F + alpha * \
46
+ (image - alpha * blurred_F - (1 - alpha) * blurred_B)
47
+ F = np.clip(F, 0, 1)
48
+ return F, blurred_B
49
+
50
+
51
+ class ImagePreprocessor():
52
+ def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
53
+ self.transform_image = transforms.Compose([
54
+ transforms.Resize(resolution),
55
+ transforms.ToTensor(),
56
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
57
+ ])
58
+
59
+ def proc(self, image: Image.Image) -> torch.Tensor:
60
+ image = self.transform_image(image)
61
+ return image
62
+
63
+ usage_to_weights_file = {
64
+ 'General': 'BiRefNet',
65
+ 'General-HR': 'BiRefNet_HR',
66
+ 'General-Lite': 'BiRefNet_lite',
67
+ 'General-Lite-2K': 'BiRefNet_lite-2K',
68
+ 'General-reso_512': 'BiRefNet-reso_512',
69
+ 'Matting': 'BiRefNet-matting',
70
+ 'Matting-HR': 'BiRefNet_HR-Matting',
71
+ 'Portrait': 'BiRefNet-portrait',
72
+ 'DIS': 'BiRefNet-DIS5K',
73
+ 'HRSOD': 'BiRefNet-HRSOD',
74
+ 'COD': 'BiRefNet-COD',
75
+ 'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs',
76
+ 'General-legacy': 'BiRefNet-legacy'
77
+ }
78
+
79
+ # Choose the version of BiRefNet here.
80
+ usage = 'General'
81
+
82
+ # Set resolution
83
+ if usage in ['General-Lite-2K']:
84
+ resolution = (2560, 1440)
85
+ elif usage in ['General-reso_512']:
86
+ resolution = (512, 512)
87
+ elif usage in ['General-HR', 'Matting-HR']:
88
+ resolution = (2048, 2048)
89
+ else:
90
+ resolution = (1024, 1024)
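+ # e.g., usage = 'General-HR' loads 'zhengpeng7/BiRefNet_HR' below and runs at 2048x2048.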
91
+
92
+ half_precision = True
93
+
94
+ class EndpointHandler():
95
+ def __init__(self, path=''):
96
+ self.birefnet = AutoModelForImageSegmentation.from_pretrained(
97
+ '/'.join(('zhengpeng7', usage_to_weights_file[usage])), trust_remote_code=True
98
+ )
99
+ self.birefnet.to(device)
100
+ self.birefnet.eval()
101
+ if half_precision:
102
+ self.birefnet.half()
103
+
104
+ def __call__(self, data: Dict[str, Any]):
105
+ """
106
+ data args:
+ inputs (:obj:`str` | `np.ndarray`): an image file path, an image URL, or an image array
109
+ Return:
110
+ A :obj:`list` | `dict`: will be serialized and returned
111
+ """
112
+ print('data["inputs"] = ', data["inputs"])
113
+ image_src = data["inputs"]
114
+ if isinstance(image_src, str):
115
+ if os.path.isfile(image_src):
116
+ image_ori = Image.open(image_src)
117
+ else:
118
+ response = requests.get(image_src)
119
+ image_data = BytesIO(response.content)
120
+ image_ori = Image.open(image_data)
121
+ else:
122
+ image_ori = Image.fromarray(image_src)
123
+
124
+ image = image_ori.convert('RGB')
125
+ # Preprocess the image
126
+ image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
127
+ image_proc = image_preprocessor.proc(image)
128
+ image_proc = image_proc.unsqueeze(0)
129
+
130
+ # Prediction
131
+ with torch.no_grad():
132
+ preds = self.birefnet(image_proc.to(device).half() if half_precision else image_proc.to(device))[-1].sigmoid().cpu()
133
+ pred = preds[0].squeeze()
134
+
135
+ # Show Results
136
+ pred_pil = transforms.ToPILImage()(pred)
137
+ image_masked = refine_foreground(image, pred_pil)
138
+ image_masked.putalpha(pred_pil.resize(image.size))
139
+ return image_masked
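+
+ # Example of a local sanity check (hypothetical URL/paths; the deployed
+ # endpoint calls __call__ with the request payload instead):
+ # handler = EndpointHandler()
+ # rgba = handler({"inputs": "https://example.com/some_image.jpg"})
+ # rgba.save("output_no_bg.png")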
Trellv2/ZhengPeng7--BiRefNet/model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9ab37426bf4de0567af6b5d21b16151357149139362e6e8992021b8ce356a154
3
+ size 444473596
Trellv2/ZhengPeng7--BiRefNet/requirements.txt ADDED
@@ -0,0 +1,16 @@
1
+ torch==2.5.1
2
+ torchvision
3
+ numpy<2
4
+ opencv-python
5
+ timm
6
+ scipy
7
+ scikit-image
8
+ kornia
9
+ einops
10
+
11
+ tqdm
12
+ prettytable
13
+
14
+ transformers
15
+ huggingface-hub>0.25
16
+ accelerate
Trellv2/briaai--RMBG-2.0/.gitattributes ADDED
@@ -0,0 +1,40 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ model_not_working.not_safetensors filter=lfs diff=lfs merge=lfs -text
37
+ t4.png filter=lfs diff=lfs merge=lfs -text
38
+ collage.png filter=lfs diff=lfs merge=lfs -text
39
+ collage3.png filter=lfs diff=lfs merge=lfs -text
40
+ collage5.png filter=lfs diff=lfs merge=lfs -text
Trellv2/briaai--RMBG-2.0/BiRefNet_config.py ADDED
@@ -0,0 +1,11 @@
1
+ from transformers import PretrainedConfig
2
+
3
+ class BiRefNetConfig(PretrainedConfig):
4
+ model_type = "SegformerForSemanticSegmentation"
5
+ def __init__(
6
+ self,
7
+ bb_pretrained=False,
8
+ **kwargs
9
+ ):
10
+ self.bb_pretrained = bb_pretrained
11
+ super().__init__(**kwargs)
Trellv2/briaai--RMBG-2.0/README.md ADDED
@@ -0,0 +1,218 @@
1
+ ---
2
+ license: other
3
+ license_name: bria-rmbg-2.0
4
+ license_link: https://creativecommons.org/licenses/by-nc/4.0/deed.en
5
+ pipeline_tag: image-segmentation
6
+ tags:
7
+ - remove background
8
+ - background
9
+ - background-removal
10
+ - Pytorch
11
+ - vision
12
+ - legal liability
13
+ - transformers
14
+ - transformers.js
15
+ extra_gated_description: >-
16
+ Bria AI Model weights are open source for non commercial use only, per the
17
+ provided [license](https://creativecommons.org/licenses/by-nc/4.0/deed.en).
18
+ extra_gated_heading: Fill in this form to immediately access the model for non-commercial use
19
+ extra_gated_fields:
20
+ Name: text
21
+ Email: text
22
+ Company/Org name: text
23
+ Company Website URL: text
24
+ Discord user: text
25
+ I agree to BRIA’s Privacy policy, Terms & conditions, and acknowledge Non commercial use to be Personal use / Academy / Non profit (direct or indirect): checkbox
26
+ ---
27
+
28
+ # BRIA Background Removal v2.0 Model Card
29
+ <p align="center"><img src="https://platform.bria.ai/assets/Bria-logo-5e0c53b1.svg" alt="BRIA Logo" width="200" /></p>
30
+
31
+ <!-- RMBG Card wrapper -->
32
+ <div class="rmbg-card" style="position: relative; border-radius: 12px; overflow: hidden;">
33
+
34
+ <!-- FIBO Promo Banner (Top) -->
35
+ <a
36
+ href="https://huggingface.co/briaai/FIBO"
37
+ target="_blank"
38
+ rel="noopener"
39
+ aria-label="Explore FIBO on Hugging Face"
40
+ style="
41
+ position: absolute;
42
+ top: 0;
43
+ left: 0;
44
+ width: 100%;
45
+ display: flex;
46
+ align-items: center;
47
+ justify-content: center;
48
+ gap: 10px;
49
+ background: linear-gradient(90deg, #fff6b7 0%, #fde047 100%);
50
+ color: #1f2937;
51
+ text-decoration: none;
52
+ font-family: Inter, system-ui, -apple-system, Segoe UI, Roboto, Arial, sans-serif;
53
+ font-weight: 600;
54
+ font-size: 13px;
55
+ padding: 10px 0;
56
+ border-bottom: 1px solid rgba(0,0,0,0.08);
57
+ box-shadow: 0 2px 8px rgba(0,0,0,0.08);
58
+ z-index: 10;
59
+ ">
60
+ <img
61
+ src="https://huggingface.co/front/assets/huggingface_logo-noborder.svg"
62
+ alt="Hugging Face"
63
+ width="18"
64
+ height="18"
65
+ style="display:block"
66
+ />
67
+ <span>✨ Discover <strong>FIBO</strong> on Hugging Face</span>
68
+ </a>
69
+
70
+ <!-- ... your RMBG content below ... -->
71
+ <p align="center">
72
+ 💜 <a href="https://go.bria.ai/46gzn20"><b>Bria AI</b></a>&nbsp;&nbsp; | &nbsp;&nbsp;🤗 <a href="https://huggingface.co/briaai/">Hugging Face</a>&nbsp;&nbsp; | &nbsp;&nbsp;📑 <a href="https://blog.bria.ai/">Blog</a>&nbsp;&nbsp;
+ <br>
+ 🖥️ <a href="https://huggingface.co/spaces/briaai/BRIA-RMBG-2.0">Demo</a>&nbsp;&nbsp;| &nbsp;&nbsp;<a href="https://github.com/Bria-AI/RMBG-2.0">Github</a>&nbsp;&nbsp;
75
+ </p>
76
+
77
+ RMBG v2.0 is our new state-of-the-art background removal model, which significantly improves on RMBG v1.4. The model is designed to effectively separate foreground from background across a range of
+ categories and image types. It has been trained on a carefully selected dataset, which includes:
+ general stock images, e-commerce, gaming, and advertising content, making it suitable for commercial use cases powering enterprise content creation at scale.
+ Its accuracy, efficiency, and versatility currently rival leading source-available models.
+ It is ideal where content safety, legally licensed datasets, and bias mitigation are paramount.
82
+
83
+ Developed by BRIA AI, RMBG v2.0 is available as a source-available model for non-commercial use.
84
+
85
+ ### Get Access
86
+
87
+ Bria RMBG 2.0 is available everywhere you build: as source code and weights, ComfyUI nodes, or API endpoints.
88
+
89
+ - **Purchase:** To purchase a commercial license for RMBG V2.0 **or** an API package [Click Here](https://share-eu1.hsforms.com/2sj9FVZTGSFmFRibDLhr_ZAf4e04).
90
+ - **API Endpoint**: [Bria.ai](https://docs.bria.ai/image-editing/v2-endpoints/background-remove), [fal.ai](https://fal.ai/models/fal-ai/bria/background/remove), [Replicate](https://replicate.com/bria/remove-background)
91
+ - **ComfyUI**: [Use it in workflows](https://github.com/Bria-AI/ComfyUI-BRIA-API)
92
+ - **GitHub**: [github.com/Bria-AI/RMBG-2.0](https://github.com/Bria-AI/RMBG-2.0)
93
+
94
+ For more information, please visit our [website](https://bria.ai/).
95
+
96
+ Join our [Discord community](https://discord.gg/Nxe9YW9zHS) for more information, tutorials, tools, and to connect with other users!
97
+
98
+ [CLICK HERE FOR A DEMO](https://huggingface.co/spaces/briaai/BRIA-RMBG-2.0)
99
+
100
+
101
+
102
+ ![examples](t4.png)
103
+
104
+ ## Model Details
105
+ #####
106
+ ### Model Description
107
+
108
+ - **Developed by:** [BRIA AI](https://bria.ai/)
109
+ - **Model type:** Background Removal
110
+ - **License:** [Creative Commons Attribution–Non-Commercial (CC BY-NC 4.0)](https://creativecommons.org/licenses/by-nc/4.0/deed.en)
111
+ - The model is released under a CC BY-NC 4.0 license for non-commercial use.
112
+ - Commercial use is subject to a commercial agreement with BRIA. Available [here](https://share-eu1.hsforms.com/2sj9FVZTGSFmFRibDLhr_ZAf4e04)
113
+
114
+
115
+ - **Model Description:** BRIA RMBG-2.0 is a dichotomous image segmentation model trained exclusively on a professional-grade dataset. The model output includes a single-channel 8-bit grayscale alpha matte, where each pixel value indicates the opacity level of the corresponding pixel in the original image. This non-binary output approach offers developers the flexibility to define custom thresholds for foreground-background separation, catering to varied use-case requirements and enhancing integration into complex pipelines.
116
+ - **BRIA:** Resources for more information: [BRIA AI](https://bria.ai/)
117
+
118
+
119
+
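+ As a minimal sketch of that thresholding idea (the cutoff value here is illustrative, not a recommendation), the alpha matte produced by the Usage snippet below can be binarized like this:
+
+ ```python
+ import numpy as np
+
+ # `pred` is the [0, 1] alpha matte tensor from the Usage example below.
+ threshold = 0.5  # pick per use case; lower keeps more soft edges
+ binary_mask = (pred.numpy() > threshold).astype(np.uint8) * 255
+ ```
+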
120
+ ## Training data
121
+ Bria-RMBG model was trained with over 15,000 high-quality, high-resolution, manually labeled (pixel-wise accuracy), fully licensed images.
122
+ Our benchmark included balanced gender, balanced ethnicity, and people with different types of disabilities.
123
+ For clarity, we provide our data distribution according to different categories, demonstrating our model’s versatility.
124
+
125
+ ### Distribution of images:
126
+
127
+ | Category | Distribution |
128
+ | -----------------------------------| -----------------------------------:|
129
+ | Objects only | 45.11% |
130
+ | People with objects/animals | 25.24% |
131
+ | People only | 17.35% |
132
+ | People/objects/animals with text | 8.52% |
133
+ | Text only | 2.52% |
134
+ | Animals only | 1.89% |
135
+
136
+ | Category | Distribution |
137
+ | -----------------------------------| -----------------------------------------:|
138
+ | Photorealistic | 87.70% |
139
+ | Non-Photorealistic | 12.30% |
140
+
141
+
142
+ | Category | Distribution |
143
+ | -----------------------------------| -----------------------------------:|
144
+ | Non Solid Background | 52.05% |
145
+ | Solid Background | 47.95% |
146
+
147
+
148
+ | Category | Distribution |
149
+ | -----------------------------------| -----------------------------------:|
150
+ | Single main foreground object | 51.42% |
151
+ | Multiple objects in the foreground | 48.58% |
152
+
153
+
154
+ ## Qualitative Evaluation
155
+ Open source models comparison
156
+ ![diagram](diagram1.png)
157
+ ![examples](collage5.png)
158
+
159
+ ### Architecture
160
+ RMBG-2.0 is developed on the [BiRefNet](https://github.com/ZhengPeng7/BiRefNet) architecture, enhanced with our proprietary dataset and training scheme. This training data significantly improves the model’s accuracy and effectiveness for the background-removal task.<br>
161
+ If you use this model in your research, please cite:
162
+
163
+ ```
164
+ @article{BiRefNet,
165
+ title={Bilateral Reference for High-Resolution Dichotomous Image Segmentation},
166
+ author={Zheng, Peng and Gao, Dehong and Fan, Deng-Ping and Liu, Li and Laaksonen, Jorma and Ouyang, Wanli and Sebe, Nicu},
167
+ journal={CAAI Artificial Intelligence Research},
168
+ year={2024}
169
+ }
170
+ ```
171
+
172
+ #### Requirements
173
+ ```bash
174
+ torch
175
+ torchvision
176
+ pillow
177
+ kornia
178
+ transformers
179
+ ```
180
+
181
+ ### Usage
182
+
183
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
184
+
185
+
186
+ ```python
187
+ from PIL import Image
188
+ import torch
189
+ from torchvision import transforms
190
+ from transformers import AutoModelForImageSegmentation
191
+
192
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
193
+ model = AutoModelForImageSegmentation.from_pretrained('briaai/RMBG-2.0', trust_remote_code=True).eval().to(device)
194
+
195
+ # Data settings
196
+ image_size = (1024, 1024)
197
+ transform_image = transforms.Compose([
198
+ transforms.Resize(image_size),
199
+ transforms.ToTensor(),
200
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
201
+ ])
202
+
203
+ input_image_path = "example_input.jpg"  # placeholder: path to your input image
+ image = Image.open(input_image_path).convert("RGB")
204
+ input_images = transform_image(image).unsqueeze(0).to(device)
205
+
206
+ # Prediction
207
+ with torch.no_grad():
208
+ preds = model(input_images)[-1].sigmoid().cpu()
209
+ pred = preds[0].squeeze()
210
+ pred_pil = transforms.ToPILImage()(pred)
211
+ mask = pred_pil.resize(image.size)
212
+ image.putalpha(mask)
213
+
214
+ image.save("no_bg_image.png")
215
+ ```
216
+
217
+
218
+ </div>
Trellv2/briaai--RMBG-2.0/birefnet.py ADDED
@@ -0,0 +1,2245 @@
1
+ ### config.py
2
+
3
+ import os
4
+ import math
5
+ from transformers import PretrainedConfig
6
+
7
+ class Config(PretrainedConfig):
8
+ def __init__(self) -> None:
9
+ super().__init__()
10
+ # PATH settings
11
+ self.sys_home_dir = os.path.expanduser('~') # Make up your file system as: SYS_HOME_DIR/codes/dis/BiRefNet, SYS_HOME_DIR/datasets/dis/xx, SYS_HOME_DIR/weights/xx
12
+
13
+ # TASK settings
14
+ self.task = ['DIS5K', 'COD', 'HRSOD', 'DIS5K+HRSOD+HRS10K', 'P3M-10k'][0]
15
+ self.training_set = {
16
+ 'DIS5K': ['DIS-TR', 'DIS-TR+DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4'][0],
17
+ 'COD': 'TR-COD10K+TR-CAMO',
18
+ 'HRSOD': ['TR-DUTS', 'TR-HRSOD', 'TR-UHRSD', 'TR-DUTS+TR-HRSOD', 'TR-DUTS+TR-UHRSD', 'TR-HRSOD+TR-UHRSD', 'TR-DUTS+TR-HRSOD+TR-UHRSD'][5],
19
+ 'DIS5K+HRSOD+HRS10K': 'DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4+DIS-TR+TE-HRS10K+TE-HRSOD+TE-UHRSD+TR-HRS10K+TR-HRSOD+TR-UHRSD', # leave DIS-VD for evaluation.
20
+ 'P3M-10k': 'TR-P3M-10k',
21
+ }[self.task]
22
+ self.prompt4loc = ['dense', 'sparse'][0]
23
+
24
+ # Faster-Training settings
25
+ self.load_all = True
26
+ self.compile = True # 1. May trigger a CPU memory leak to some extent, which is an inherent problem of PyTorch.
27
+ # Machines with > 70GB CPU memory can run the whole training on DIS5K with default setting.
28
+ # 2. Higher PyTorch version may fix it: https://github.com/pytorch/pytorch/issues/119607.
29
+ # 3. But compile in PyTorch > 2.0.1 seems to bring no training acceleration.
30
+ self.precisionHigh = True
31
+
32
+ # MODEL settings
33
+ self.ms_supervision = True
34
+ self.out_ref = self.ms_supervision and True
35
+ self.dec_ipt = True
36
+ self.dec_ipt_split = True
37
+ self.cxt_num = [0, 3][1] # multi-scale skip connections from encoder
38
+ self.mul_scl_ipt = ['', 'add', 'cat'][2]
39
+ self.dec_att = ['', 'ASPP', 'ASPPDeformable'][2]
40
+ self.squeeze_block = ['', 'BasicDecBlk_x1', 'ResBlk_x4', 'ASPP_x3', 'ASPPDeformable_x3'][1]
41
+ self.dec_blk = ['BasicDecBlk', 'ResBlk', 'HierarAttDecBlk'][0]
42
+
43
+ # TRAINING settings
44
+ self.batch_size = 4
45
+ self.IoU_finetune_last_epochs = [
46
+ 0,
47
+ {
48
+ 'DIS5K': -50,
49
+ 'COD': -20,
50
+ 'HRSOD': -20,
51
+ 'DIS5K+HRSOD+HRS10K': -20,
52
+ 'P3M-10k': -20,
53
+ }[self.task]
54
+ ][1] # choose 0 to skip
55
+ self.lr = (1e-4 if 'DIS5K' in self.task else 1e-5) * math.sqrt(self.batch_size / 4) # DIS needs a high lr to converge faster. Scale the lr with the square root of the batch-size ratio.
56
+ self.size = 1024
57
+ self.num_workers = max(4, self.batch_size) # will be decreased to min(num_workers, batch_size) when the data_loader is initialized
58
+
59
+ # Backbone settings
60
+ self.bb = [
61
+ 'vgg16', 'vgg16bn', 'resnet50', # 0, 1, 2
62
+ 'swin_v1_t', 'swin_v1_s', # 3, 4
63
+ 'swin_v1_b', 'swin_v1_l', # 5-bs9, 6-bs4
64
+ 'pvt_v2_b0', 'pvt_v2_b1', # 7, 8
65
+ 'pvt_v2_b2', 'pvt_v2_b5', # 9-bs10, 10-bs5
66
+ ][6]
67
+ self.lateral_channels_in_collection = {
68
+ 'vgg16': [512, 256, 128, 64], 'vgg16bn': [512, 256, 128, 64], 'resnet50': [1024, 512, 256, 64],
69
+ 'pvt_v2_b2': [512, 320, 128, 64], 'pvt_v2_b5': [512, 320, 128, 64],
70
+ 'swin_v1_b': [1024, 512, 256, 128], 'swin_v1_l': [1536, 768, 384, 192],
71
+ 'swin_v1_t': [768, 384, 192, 96], 'swin_v1_s': [768, 384, 192, 96],
72
+ 'pvt_v2_b0': [256, 160, 64, 32], 'pvt_v2_b1': [512, 320, 128, 64],
73
+ }[self.bb]
74
+ if self.mul_scl_ipt == 'cat':
75
+ self.lateral_channels_in_collection = [channel * 2 for channel in self.lateral_channels_in_collection]
76
+ self.cxt = self.lateral_channels_in_collection[1:][::-1][-self.cxt_num:] if self.cxt_num else []
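+ # e.g., swin_v1_l with mul_scl_ipt='cat' doubles the channels to [3072, 1536, 768, 384],
+ # so cxt_num=3 yields cxt == [384, 768, 1536].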
77
+
78
+ # MODEL settings - inactive
79
+ self.lat_blk = ['BasicLatBlk'][0]
80
+ self.dec_channels_inter = ['fixed', 'adap'][0]
81
+ self.refine = ['', 'itself', 'RefUNet', 'Refiner', 'RefinerPVTInChannels4'][0]
82
+ self.progressive_ref = self.refine and True
83
+ self.ender = self.progressive_ref and False
84
+ self.scale = self.progressive_ref and 2
85
+ self.auxiliary_classification = False # Only for DIS5K, where class labels are saved in `dataset.py`.
86
+ self.refine_iteration = 1
87
+ self.freeze_bb = False
88
+ self.model = [
89
+ 'BiRefNet',
90
+ ][0]
91
+ if self.dec_blk == 'HierarAttDecBlk':
92
+ self.batch_size = 2 ** [0, 1, 2, 3, 4][2]
93
+
94
+ # TRAINING settings - inactive
95
+ self.preproc_methods = ['flip', 'enhance', 'rotate', 'pepper', 'crop'][:4]
96
+ self.optimizer = ['Adam', 'AdamW'][1]
97
+ self.lr_decay_epochs = [1e5] # Set to negative N to decay the lr in the last N epochs.
98
+ self.lr_decay_rate = 0.5
99
+ # Loss
100
+ self.lambdas_pix_last = {
101
+ # not 0 means opening this loss
102
+ # original rate -- 1 : 30 : 1.5 : 0.2, bce x 30
103
+ 'bce': 30 * 1, # high performance
104
+ 'iou': 0.5 * 1, # 0 / 255
105
+ 'iou_patch': 0.5 * 0, # 0 / 255, win_size = (64, 64)
106
+ 'mse': 150 * 0, # can smooth the saliency map
107
+ 'triplet': 3 * 0,
108
+ 'reg': 100 * 0,
109
+ 'ssim': 10 * 1, # help contours,
110
+ 'cnt': 5 * 0, # help contours
111
+ 'structure': 5 * 0, # structure loss from the MVANet codebase. Slightly improves DIS-TE[1,2,3], but degrades DIS-TE4 a bit more.
112
+ }
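+ # The pixel-level loss is the weighted sum of the entries above; a zero weight disables that term.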
113
+ self.lambdas_cls = {
114
+ 'ce': 5.0
115
+ }
116
+ # Adv
117
+ self.lambda_adv_g = 10. * 0 # turn to 0 to avoid adv training
118
+ self.lambda_adv_d = 3. * (self.lambda_adv_g > 0)
119
+
120
+ # PATH settings - inactive
121
+ self.data_root_dir = os.path.join(self.sys_home_dir, 'datasets/dis')
122
+ self.weights_root_dir = os.path.join(self.sys_home_dir, 'weights')
123
+ self.weights = {
124
+ 'pvt_v2_b2': os.path.join(self.weights_root_dir, 'pvt_v2_b2.pth'),
125
+ 'pvt_v2_b5': os.path.join(self.weights_root_dir, ['pvt_v2_b5.pth', 'pvt_v2_b5_22k.pth'][0]),
126
+ 'swin_v1_b': os.path.join(self.weights_root_dir, ['swin_base_patch4_window12_384_22kto1k.pth', 'swin_base_patch4_window12_384_22k.pth'][0]),
127
+ 'swin_v1_l': os.path.join(self.weights_root_dir, ['swin_large_patch4_window12_384_22kto1k.pth', 'swin_large_patch4_window12_384_22k.pth'][0]),
128
+ 'swin_v1_t': os.path.join(self.weights_root_dir, ['swin_tiny_patch4_window7_224_22kto1k_finetune.pth'][0]),
129
+ 'swin_v1_s': os.path.join(self.weights_root_dir, ['swin_small_patch4_window7_224_22kto1k_finetune.pth'][0]),
130
+ 'pvt_v2_b0': os.path.join(self.weights_root_dir, ['pvt_v2_b0.pth'][0]),
131
+ 'pvt_v2_b1': os.path.join(self.weights_root_dir, ['pvt_v2_b1.pth'][0]),
132
+ }
133
+
134
+ # Callbacks - inactive
135
+ self.verbose_eval = True
136
+ self.only_S_MAE = False
137
+ self.use_fp16 = False # Buggy: may produce NaNs in training.
+ self.SDPA_enabled = False # Buggy: slower, and errors occur in multi-GPU runs.
139
+
140
+ # others
141
+ self.device = [0, 'cpu'][0] # .to(0) == .to('cuda:0')
142
+
143
+ self.batch_size_valid = 1
144
+ self.rand_seed = 7
145
+ # run_sh_file = [f for f in os.listdir('.') if 'train.sh' == f] + [os.path.join('..', f) for f in os.listdir('..') if 'train.sh' == f]
146
+ # with open(run_sh_file[0], 'r') as f:
147
+ # lines = f.readlines()
148
+ # self.save_last = int([l.strip() for l in lines if '"{}")'.format(self.task) in l and 'val_last=' in l][0].split('val_last=')[-1].split()[0])
149
+ # self.save_step = int([l.strip() for l in lines if '"{}")'.format(self.task) in l and 'step=' in l][0].split('step=')[-1].split()[0])
150
+ # self.val_step = [0, self.save_step][0]
151
+
152
+ def print_task(self) -> None:
153
+ # Return task for choosing settings in shell scripts.
154
+ print(self.task)
155
+
156
+
157
+
158
+ ### models/backbones/pvt_v2.py
159
+
160
+ import torch
161
+ import torch.nn as nn
162
+ from functools import partial
163
+
164
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
165
+ from timm.models.registry import register_model
166
+
167
+ import math
168
+
169
+ # from config import Config
170
+
171
+ # config = Config()
172
+
173
+ class Mlp(nn.Module):
174
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
175
+ super().__init__()
176
+ out_features = out_features or in_features
177
+ hidden_features = hidden_features or in_features
178
+ self.fc1 = nn.Linear(in_features, hidden_features)
179
+ self.dwconv = DWConv(hidden_features)
180
+ self.act = act_layer()
181
+ self.fc2 = nn.Linear(hidden_features, out_features)
182
+ self.drop = nn.Dropout(drop)
183
+
184
+ self.apply(self._init_weights)
185
+
186
+ def _init_weights(self, m):
187
+ if isinstance(m, nn.Linear):
188
+ trunc_normal_(m.weight, std=.02)
189
+ if isinstance(m, nn.Linear) and m.bias is not None:
190
+ nn.init.constant_(m.bias, 0)
191
+ elif isinstance(m, nn.LayerNorm):
192
+ nn.init.constant_(m.bias, 0)
193
+ nn.init.constant_(m.weight, 1.0)
194
+ elif isinstance(m, nn.Conv2d):
195
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
196
+ fan_out //= m.groups
197
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
198
+ if m.bias is not None:
199
+ m.bias.data.zero_()
200
+
201
+ def forward(self, x, H, W):
202
+ x = self.fc1(x)
203
+ x = self.dwconv(x, H, W)
204
+ x = self.act(x)
205
+ x = self.drop(x)
206
+ x = self.fc2(x)
207
+ x = self.drop(x)
208
+ return x
209
+
210
+
211
+ class Attention(nn.Module):
212
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
213
+ super().__init__()
214
+ assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
215
+
216
+ self.dim = dim
217
+ self.num_heads = num_heads
218
+ head_dim = dim // num_heads
219
+ self.scale = qk_scale or head_dim ** -0.5
220
+
221
+ self.q = nn.Linear(dim, dim, bias=qkv_bias)
222
+ self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
223
+ self.attn_drop_prob = attn_drop
224
+ self.attn_drop = nn.Dropout(attn_drop)
225
+ self.proj = nn.Linear(dim, dim)
226
+ self.proj_drop = nn.Dropout(proj_drop)
227
+
228
+ self.sr_ratio = sr_ratio
229
+ if sr_ratio > 1:
230
+ self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
231
+ self.norm = nn.LayerNorm(dim)
232
+
233
+ self.apply(self._init_weights)
234
+
235
+ def _init_weights(self, m):
236
+ if isinstance(m, nn.Linear):
237
+ trunc_normal_(m.weight, std=.02)
238
+ if isinstance(m, nn.Linear) and m.bias is not None:
239
+ nn.init.constant_(m.bias, 0)
240
+ elif isinstance(m, nn.LayerNorm):
241
+ nn.init.constant_(m.bias, 0)
242
+ nn.init.constant_(m.weight, 1.0)
243
+ elif isinstance(m, nn.Conv2d):
244
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
245
+ fan_out //= m.groups
246
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
247
+ if m.bias is not None:
248
+ m.bias.data.zero_()
249
+
250
+ def forward(self, x, H, W):
251
+ B, N, C = x.shape
252
+ q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
253
+
254
+ if self.sr_ratio > 1:
255
+ x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
256
+ x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
257
+ x_ = self.norm(x_)
258
+ kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
259
+ else:
260
+ kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
261
+ k, v = kv[0], kv[1]
262
+
263
+ if config.SDPA_enabled:
264
+ x = torch.nn.functional.scaled_dot_product_attention(
265
+ q, k, v,
266
+ attn_mask=None, dropout_p=self.attn_drop_prob, is_causal=False
267
+ ).transpose(1, 2).reshape(B, N, C)
268
+ else:
269
+ attn = (q @ k.transpose(-2, -1)) * self.scale
270
+ attn = attn.softmax(dim=-1)
271
+ attn = self.attn_drop(attn)
272
+
273
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
274
+ x = self.proj(x)
275
+ x = self.proj_drop(x)
276
+
277
+ return x
278
+
279
+
280
+ class Block(nn.Module):
281
+
282
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
283
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
284
+ super().__init__()
285
+ self.norm1 = norm_layer(dim)
286
+ self.attn = Attention(
287
+ dim,
288
+ num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
289
+ attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
290
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
291
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
292
+ self.norm2 = norm_layer(dim)
293
+ mlp_hidden_dim = int(dim * mlp_ratio)
294
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
295
+
296
+ self.apply(self._init_weights)
297
+
298
+ def _init_weights(self, m):
299
+ if isinstance(m, nn.Linear):
300
+ trunc_normal_(m.weight, std=.02)
301
+ if isinstance(m, nn.Linear) and m.bias is not None:
302
+ nn.init.constant_(m.bias, 0)
303
+ elif isinstance(m, nn.LayerNorm):
304
+ nn.init.constant_(m.bias, 0)
305
+ nn.init.constant_(m.weight, 1.0)
306
+ elif isinstance(m, nn.Conv2d):
307
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
308
+ fan_out //= m.groups
309
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
310
+ if m.bias is not None:
311
+ m.bias.data.zero_()
312
+
313
+ def forward(self, x, H, W):
314
+ x = x + self.drop_path(self.attn(self.norm1(x), H, W))
315
+ x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
316
+
317
+ return x
318
+
319
+
320
+ class OverlapPatchEmbed(nn.Module):
321
+ """ Image to Patch Embedding
322
+ """
323
+
324
+ def __init__(self, img_size=224, patch_size=7, stride=4, in_channels=3, embed_dim=768):
325
+ super().__init__()
326
+ img_size = to_2tuple(img_size)
327
+ patch_size = to_2tuple(patch_size)
328
+
329
+ self.img_size = img_size
330
+ self.patch_size = patch_size
331
+ self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
332
+ self.num_patches = self.H * self.W
333
+ self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=stride,
334
+ padding=(patch_size[0] // 2, patch_size[1] // 2))
335
+ self.norm = nn.LayerNorm(embed_dim)
336
+
337
+ self.apply(self._init_weights)
338
+
339
+ def _init_weights(self, m):
340
+ if isinstance(m, nn.Linear):
341
+ trunc_normal_(m.weight, std=.02)
342
+ if isinstance(m, nn.Linear) and m.bias is not None:
343
+ nn.init.constant_(m.bias, 0)
344
+ elif isinstance(m, nn.LayerNorm):
345
+ nn.init.constant_(m.bias, 0)
346
+ nn.init.constant_(m.weight, 1.0)
347
+ elif isinstance(m, nn.Conv2d):
348
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
349
+ fan_out //= m.groups
350
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
351
+ if m.bias is not None:
352
+ m.bias.data.zero_()
353
+
354
+ def forward(self, x):
355
+ x = self.proj(x)
356
+ _, _, H, W = x.shape
357
+ x = x.flatten(2).transpose(1, 2)
358
+ x = self.norm(x)
359
+
360
+ return x, H, W
361
+
362
+
363
+ class PyramidVisionTransformerImpr(nn.Module):
364
+ def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
365
+ num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
366
+ attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
367
+ depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]):
368
+ super().__init__()
369
+ self.num_classes = num_classes
370
+ self.depths = depths
371
+
372
+ # patch_embed
373
+ self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_channels=in_channels,
374
+ embed_dim=embed_dims[0])
375
+ self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_channels=embed_dims[0],
376
+ embed_dim=embed_dims[1])
377
+ self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_channels=embed_dims[1],
378
+ embed_dim=embed_dims[2])
379
+ self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_channels=embed_dims[2],
380
+ embed_dim=embed_dims[3])
381
+
382
+ # transformer encoder
383
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
384
+ cur = 0
385
+ self.block1 = nn.ModuleList([Block(
386
+ dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
387
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
388
+ sr_ratio=sr_ratios[0])
389
+ for i in range(depths[0])])
390
+ self.norm1 = norm_layer(embed_dims[0])
391
+
392
+ cur += depths[0]
393
+ self.block2 = nn.ModuleList([Block(
394
+ dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
395
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
396
+ sr_ratio=sr_ratios[1])
397
+ for i in range(depths[1])])
398
+ self.norm2 = norm_layer(embed_dims[1])
399
+
400
+ cur += depths[1]
401
+ self.block3 = nn.ModuleList([Block(
402
+ dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
403
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
404
+ sr_ratio=sr_ratios[2])
405
+ for i in range(depths[2])])
406
+ self.norm3 = norm_layer(embed_dims[2])
407
+
408
+ cur += depths[2]
409
+ self.block4 = nn.ModuleList([Block(
410
+ dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
411
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
412
+ sr_ratio=sr_ratios[3])
413
+ for i in range(depths[3])])
414
+ self.norm4 = norm_layer(embed_dims[3])
415
+
416
+ # classification head
417
+ # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()
418
+
419
+ self.apply(self._init_weights)
420
+
421
+ def _init_weights(self, m):
422
+ if isinstance(m, nn.Linear):
423
+ trunc_normal_(m.weight, std=.02)
424
+ if isinstance(m, nn.Linear) and m.bias is not None:
425
+ nn.init.constant_(m.bias, 0)
426
+ elif isinstance(m, nn.LayerNorm):
427
+ nn.init.constant_(m.bias, 0)
428
+ nn.init.constant_(m.weight, 1.0)
429
+ elif isinstance(m, nn.Conv2d):
430
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
431
+ fan_out //= m.groups
432
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
433
+ if m.bias is not None:
434
+ m.bias.data.zero_()
435
+
436
+ def init_weights(self, pretrained=None):
437
+ if isinstance(pretrained, str):
438
+ logger = 1
439
+ #load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)
440
+
441
+ def reset_drop_path(self, drop_path_rate):
442
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
443
+ cur = 0
444
+ for i in range(self.depths[0]):
445
+ self.block1[i].drop_path.drop_prob = dpr[cur + i]
446
+
447
+ cur += self.depths[0]
448
+ for i in range(self.depths[1]):
449
+ self.block2[i].drop_path.drop_prob = dpr[cur + i]
450
+
451
+ cur += self.depths[1]
452
+ for i in range(self.depths[2]):
453
+ self.block3[i].drop_path.drop_prob = dpr[cur + i]
454
+
455
+ cur += self.depths[2]
456
+ for i in range(self.depths[3]):
457
+ self.block4[i].drop_path.drop_prob = dpr[cur + i]
458
+
459
+ def freeze_patch_emb(self):
460
+ self.patch_embed1.requires_grad_(False) # requires_grad_ recurses into the module's parameters; plain attribute assignment would not freeze them
461
+
462
+ @torch.jit.ignore
463
+ def no_weight_decay(self):
464
+ return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # keeping pos_embed out of weight decay may be better
465
+
466
+ def get_classifier(self):
467
+ return self.head
468
+
469
+ def reset_classifier(self, num_classes, global_pool=''):
470
+ self.num_classes = num_classes
471
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
472
+
473
+ def forward_features(self, x):
474
+ B = x.shape[0]
475
+ outs = []
476
+
477
+ # stage 1
478
+ x, H, W = self.patch_embed1(x)
479
+ for i, blk in enumerate(self.block1):
480
+ x = blk(x, H, W)
481
+ x = self.norm1(x)
482
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
483
+ outs.append(x)
484
+
485
+ # stage 2
486
+ x, H, W = self.patch_embed2(x)
487
+ for i, blk in enumerate(self.block2):
488
+ x = blk(x, H, W)
489
+ x = self.norm2(x)
490
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
491
+ outs.append(x)
492
+
493
+ # stage 3
494
+ x, H, W = self.patch_embed3(x)
495
+ for i, blk in enumerate(self.block3):
496
+ x = blk(x, H, W)
497
+ x = self.norm3(x)
498
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
499
+ outs.append(x)
500
+
501
+ # stage 4
502
+ x, H, W = self.patch_embed4(x)
503
+ for i, blk in enumerate(self.block4):
504
+ x = blk(x, H, W)
505
+ x = self.norm4(x)
506
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
507
+ outs.append(x)
508
+
509
+ return outs
510
+
511
+ # return x.mean(dim=1)
512
+
513
+ def forward(self, x):
514
+ x = self.forward_features(x)
515
+ # x = self.head(x)
516
+
517
+ return x
518
+
519
+
520
+ class DWConv(nn.Module):
521
+ def __init__(self, dim=768):
522
+ super(DWConv, self).__init__()
523
+ self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
524
+
525
+ def forward(self, x, H, W):
526
+ B, N, C = x.shape
527
+ x = x.transpose(1, 2).view(B, C, H, W).contiguous()
528
+ x = self.dwconv(x)
529
+ x = x.flatten(2).transpose(1, 2)
530
+
531
+ return x
532
+
533
+
534
+ def _conv_filter(state_dict, patch_size=16):
535
+ """ convert patch embedding weight from manual patchify + linear proj to conv"""
536
+ out_dict = {}
537
+ for k, v in state_dict.items():
538
+ if 'patch_embed.proj.weight' in k:
539
+ v = v.reshape((v.shape[0], 3, patch_size, patch_size))
540
+ out_dict[k] = v
541
+
542
+ return out_dict
543
+
544
+
545
+ ## @register_model
546
+ class pvt_v2_b0(PyramidVisionTransformerImpr):
547
+ def __init__(self, **kwargs):
548
+ super(pvt_v2_b0, self).__init__(
549
+ patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
550
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
551
+ drop_rate=0.0, drop_path_rate=0.1)
552
+
553
+
554
+
555
+ ## @register_model
556
+ class pvt_v2_b1(PyramidVisionTransformerImpr):
557
+ def __init__(self, **kwargs):
558
+ super(pvt_v2_b1, self).__init__(
559
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
560
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
561
+ drop_rate=0.0, drop_path_rate=0.1)
562
+
563
+ ## @register_model
564
+ class pvt_v2_b2(PyramidVisionTransformerImpr):
565
+ def __init__(self, in_channels=3, **kwargs):
566
+ super(pvt_v2_b2, self).__init__(
567
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
568
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1],
569
+ drop_rate=0.0, drop_path_rate=0.1, in_channels=in_channels)
570
+
571
+ ## @register_model
572
+ class pvt_v2_b3(PyramidVisionTransformerImpr):
573
+ def __init__(self, **kwargs):
574
+ super(pvt_v2_b3, self).__init__(
575
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
576
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
577
+ drop_rate=0.0, drop_path_rate=0.1)
578
+
579
+ ## @register_model
580
+ class pvt_v2_b4(PyramidVisionTransformerImpr):
581
+ def __init__(self, **kwargs):
582
+ super(pvt_v2_b4, self).__init__(
583
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
584
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1],
585
+ drop_rate=0.0, drop_path_rate=0.1)
586
+
587
+
588
+ ## @register_model
589
+ class pvt_v2_b5(PyramidVisionTransformerImpr):
590
+ def __init__(self, **kwargs):
591
+ super(pvt_v2_b5, self).__init__(
592
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
593
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],
594
+ drop_rate=0.0, drop_path_rate=0.1)
595
+
596
+
597
+
598
+ ### models/backbones/swin_v1.py
599
+
600
+ # --------------------------------------------------------
601
+ # Swin Transformer
602
+ # Copyright (c) 2021 Microsoft
603
+ # Licensed under The MIT License [see LICENSE for details]
604
+ # Written by Ze Liu, Yutong Lin, Yixuan Wei
605
+ # --------------------------------------------------------
606
+
607
+ import torch
608
+ import torch.nn as nn
609
+ import torch.nn.functional as F
610
+ import torch.utils.checkpoint as checkpoint
611
+ import numpy as np
612
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
613
+
614
+ # from config import Config
615
+
616
+
617
+ # config = Config()
618
+
619
+ class Mlp(nn.Module):
620
+ """ Multilayer perceptron."""
621
+
622
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
623
+ super().__init__()
624
+ out_features = out_features or in_features
625
+ hidden_features = hidden_features or in_features
626
+ self.fc1 = nn.Linear(in_features, hidden_features)
627
+ self.act = act_layer()
628
+ self.fc2 = nn.Linear(hidden_features, out_features)
629
+ self.drop = nn.Dropout(drop)
630
+
631
+ def forward(self, x):
632
+ x = self.fc1(x)
633
+ x = self.act(x)
634
+ x = self.drop(x)
635
+ x = self.fc2(x)
636
+ x = self.drop(x)
637
+ return x
638
+
639
+
640
+ def window_partition(x, window_size):
641
+ """
642
+ Args:
643
+ x: (B, H, W, C)
644
+ window_size (int): window size
645
+
646
+ Returns:
647
+ windows: (num_windows*B, window_size, window_size, C)
648
+ """
649
+ B, H, W, C = x.shape
650
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
651
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
652
+ return windows
653
+
654
+
655
+ def window_reverse(windows, window_size, H, W):
656
+ """
657
+ Args:
658
+ windows: (num_windows*B, window_size, window_size, C)
659
+ window_size (int): Window size
660
+ H (int): Height of image
661
+ W (int): Width of image
662
+
663
+ Returns:
664
+ x: (B, H, W, C)
665
+ """
666
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
667
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
668
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
669
+ return x
670
+
671
+
672
+ class WindowAttention(nn.Module):
673
+ """ Window based multi-head self attention (W-MSA) module with relative position bias.
674
+ It supports both shifted and non-shifted windows.
675
+
676
+ Args:
677
+ dim (int): Number of input channels.
678
+ window_size (tuple[int]): The height and width of the window.
679
+ num_heads (int): Number of attention heads.
680
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
681
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
682
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
683
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
684
+ """
685
+
686
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
687
+
688
+ super().__init__()
689
+ self.dim = dim
690
+ self.window_size = window_size # Wh, Ww
691
+ self.num_heads = num_heads
692
+ head_dim = dim // num_heads
693
+ self.scale = qk_scale or head_dim ** -0.5
694
+
695
+ # define a parameter table of relative position bias
696
+ self.relative_position_bias_table = nn.Parameter(
697
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
698
+
699
+ # get pair-wise relative position index for each token inside the window
700
+ coords_h = torch.arange(self.window_size[0])
701
+ coords_w = torch.arange(self.window_size[1])
702
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww
703
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
704
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
705
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
706
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
707
+ relative_coords[:, :, 1] += self.window_size[1] - 1
708
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
709
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
710
+ self.register_buffer("relative_position_index", relative_position_index)
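+ # e.g., for a 7x7 window the bias table has (2*7-1)*(2*7-1) = 169 entries per head,
+ # indexed by the 49x49 relative_position_index computed above.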
711
+
712
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
713
+ self.attn_drop_prob = attn_drop
714
+ self.attn_drop = nn.Dropout(attn_drop)
715
+ self.proj = nn.Linear(dim, dim)
716
+ self.proj_drop = nn.Dropout(proj_drop)
717
+
718
+ trunc_normal_(self.relative_position_bias_table, std=.02)
719
+ self.softmax = nn.Softmax(dim=-1)
720
+
721
+ def forward(self, x, mask=None):
722
+ """ Forward function.
723
+
724
+ Args:
725
+ x: input features with shape of (num_windows*B, N, C)
726
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
727
+ """
728
+ B_, N, C = x.shape
729
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
730
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
731
+
732
+ q = q * self.scale
733
+
734
+ if config.SDPA_enabled:
735
+ x = torch.nn.functional.scaled_dot_product_attention(
736
+ q, k, v,
737
+ attn_mask=None, dropout_p=self.attn_drop_prob, is_causal=False
738
+ ).transpose(1, 2).reshape(B_, N, C)
739
+ else:
740
+ attn = (q @ k.transpose(-2, -1))
741
+
742
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
743
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
744
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
745
+ attn = attn + relative_position_bias.unsqueeze(0)
746
+
747
+ if mask is not None:
748
+ nW = mask.shape[0]
749
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
750
+ attn = attn.view(-1, self.num_heads, N, N)
751
+ attn = self.softmax(attn)
752
+ else:
753
+ attn = self.softmax(attn)
754
+
755
+ attn = self.attn_drop(attn)
756
+
757
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
758
+ x = self.proj(x)
759
+ x = self.proj_drop(x)
760
+ return x
761
+
762
+
763
+ class SwinTransformerBlock(nn.Module):
764
+ """ Swin Transformer Block.
765
+
766
+ Args:
767
+ dim (int): Number of input channels.
768
+ num_heads (int): Number of attention heads.
769
+ window_size (int): Window size.
770
+ shift_size (int): Shift size for SW-MSA.
771
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
772
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
773
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
774
+ drop (float, optional): Dropout rate. Default: 0.0
775
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
776
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
777
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
778
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
779
+ """
780
+
781
+ def __init__(self, dim, num_heads, window_size=7, shift_size=0,
782
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
783
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
784
+ super().__init__()
785
+ self.dim = dim
786
+ self.num_heads = num_heads
787
+ self.window_size = window_size
788
+ self.shift_size = shift_size
789
+ self.mlp_ratio = mlp_ratio
790
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
791
+
792
+ self.norm1 = norm_layer(dim)
793
+ self.attn = WindowAttention(
794
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
795
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
796
+
797
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
798
+ self.norm2 = norm_layer(dim)
799
+ mlp_hidden_dim = int(dim * mlp_ratio)
800
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
801
+
802
+ self.H = None
803
+ self.W = None
804
+
805
+ def forward(self, x, mask_matrix):
806
+ """ Forward function.
807
+
808
+ Args:
809
+ x: Input feature, tensor size (B, H*W, C).
810
+ H, W: Spatial resolution of the input feature.
811
+ mask_matrix: Attention mask for cyclic shift.
812
+ """
813
+ B, L, C = x.shape
814
+ H, W = self.H, self.W
815
+ assert L == H * W, "input feature has wrong size"
816
+
817
+ shortcut = x
818
+ x = self.norm1(x)
819
+ x = x.view(B, H, W, C)
820
+
821
+ # pad feature maps to multiples of window size
822
+ pad_l = pad_t = 0
823
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
824
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
825
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
826
+ _, Hp, Wp, _ = x.shape
827
+
828
+ # cyclic shift
829
+ if self.shift_size > 0:
830
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
831
+ attn_mask = mask_matrix
832
+ else:
833
+ shifted_x = x
834
+ attn_mask = None
835
+
836
+ # partition windows
837
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
838
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
839
+
840
+ # W-MSA/SW-MSA
841
+ attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
842
+
843
+ # merge windows
844
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
845
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
846
+
847
+ # reverse cyclic shift
848
+ if self.shift_size > 0:
849
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
850
+ else:
851
+ x = shifted_x
852
+
853
+ if pad_r > 0 or pad_b > 0:
854
+ x = x[:, :H, :W, :].contiguous()
855
+
856
+ x = x.view(B, H * W, C)
857
+
858
+ # FFN
859
+ x = shortcut + self.drop_path(x)
860
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
861
+
862
+ return x
863
+
864
+
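+
+ # A round-trip sketch for the windowing used above (editor assumption: the
+ # window_partition / window_reverse helpers are the ones defined earlier in
+ # this file, following the reference Swin implementation):
+ # >>> x = torch.randn(1, 8, 8, 32)                     # B, H, W, C
+ # >>> wins = window_partition(x, 4)                    # (nW*B, 4, 4, 32) -> (4, 4, 4, 32)
+ # >>> torch.equal(window_reverse(wins, 4, 8, 8), x)
+ # True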
865
+ class PatchMerging(nn.Module):
866
+ """ Patch Merging Layer
867
+
868
+ Args:
869
+ dim (int): Number of input channels.
870
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
871
+ """
872
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
873
+ super().__init__()
874
+ self.dim = dim
875
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
876
+ self.norm = norm_layer(4 * dim)
877
+
878
+ def forward(self, x, H, W):
879
+ """ Forward function.
880
+
881
+ Args:
882
+ x: Input feature, tensor size (B, H*W, C).
883
+ H, W: Spatial resolution of the input feature.
884
+ """
885
+ B, L, C = x.shape
886
+ assert L == H * W, "input feature has wrong size"
887
+
888
+ x = x.view(B, H, W, C)
889
+
890
+ # padding
891
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
892
+ if pad_input:
893
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
894
+
895
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
896
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
897
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
898
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
899
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
900
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
901
+
902
+ x = self.norm(x)
903
+ x = self.reduction(x)
904
+
905
+ return x
906
+
907
+
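+
+ # Shape sketch for PatchMerging (hedged example with dummy data): tokens are
+ # regrouped 2x2, so H and W halve while the channel dim doubles via the
+ # 4C -> 2C linear reduction.
+ # >>> pm = PatchMerging(dim=96)
+ # >>> pm(torch.randn(1, 56 * 56, 96), 56, 56).shape
+ # torch.Size([1, 784, 192])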
908
+ class BasicLayer(nn.Module):
909
+ """ A basic Swin Transformer layer for one stage.
910
+
911
+ Args:
912
+ dim (int): Number of feature channels
913
+ depth (int): Depth of this stage.
914
+ num_heads (int): Number of attention heads.
915
+ window_size (int): Local window size. Default: 7.
916
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
917
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
918
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
919
+ drop (float, optional): Dropout rate. Default: 0.0
920
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
921
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
922
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
923
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
924
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
925
+ """
926
+
927
+ def __init__(self,
928
+ dim,
929
+ depth,
930
+ num_heads,
931
+ window_size=7,
932
+ mlp_ratio=4.,
933
+ qkv_bias=True,
934
+ qk_scale=None,
935
+ drop=0.,
936
+ attn_drop=0.,
937
+ drop_path=0.,
938
+ norm_layer=nn.LayerNorm,
939
+ downsample=None,
940
+ use_checkpoint=False):
941
+ super().__init__()
942
+ self.window_size = window_size
943
+ self.shift_size = window_size // 2
944
+ self.depth = depth
945
+ self.use_checkpoint = use_checkpoint
946
+
947
+ # build blocks
948
+ self.blocks = nn.ModuleList([
949
+ SwinTransformerBlock(
950
+ dim=dim,
951
+ num_heads=num_heads,
952
+ window_size=window_size,
953
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
954
+ mlp_ratio=mlp_ratio,
955
+ qkv_bias=qkv_bias,
956
+ qk_scale=qk_scale,
957
+ drop=drop,
958
+ attn_drop=attn_drop,
959
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
960
+ norm_layer=norm_layer)
961
+ for i in range(depth)])
962
+
963
+ # patch merging layer
964
+ if downsample is not None:
965
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
966
+ else:
967
+ self.downsample = None
968
+
969
+ def forward(self, x, H, W):
970
+ """ Forward function.
971
+
972
+ Args:
973
+ x: Input feature, tensor size (B, H*W, C).
974
+ H, W: Spatial resolution of the input feature.
975
+ """
976
+
977
+ # calculate attention mask for SW-MSA
978
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
979
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
980
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
981
+ h_slices = (slice(0, -self.window_size),
982
+ slice(-self.window_size, -self.shift_size),
983
+ slice(-self.shift_size, None))
984
+ w_slices = (slice(0, -self.window_size),
985
+ slice(-self.window_size, -self.shift_size),
986
+ slice(-self.shift_size, None))
987
+ cnt = 0
988
+ for h in h_slices:
989
+ for w in w_slices:
990
+ img_mask[:, h, w, :] = cnt
991
+ cnt += 1
992
+
993
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
994
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
995
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
996
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)).to(x.dtype)
997
+
998
+ for blk in self.blocks:
999
+ blk.H, blk.W = H, W
1000
+ if self.use_checkpoint:
1001
+ x = checkpoint.checkpoint(blk, x, attn_mask)
1002
+ else:
1003
+ x = blk(x, attn_mask)
1004
+ if self.downsample is not None:
1005
+ x_down = self.downsample(x, H, W)
1006
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
1007
+ return x, H, W, x_down, Wh, Ww
1008
+ else:
1009
+ return x, H, W, x, H, W
1010
+
1011
+
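+
+ # Worked example of the SW-MSA mask above (editor sketch): with Hp = Wp = 4,
+ # window_size = 2 and shift_size = 1, img_mask receives the region labels
+ #   [[0, 0, 1, 2],
+ #    [0, 0, 1, 2],
+ #    [3, 3, 4, 5],
+ #    [6, 6, 7, 8]]
+ # Inside each 2x2 window, token pairs with differing labels get -100, so the
+ # softmax suppresses attention between pixels that were not neighbours before
+ # the cyclic shift.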
1012
+ class PatchEmbed(nn.Module):
1013
+ """ Image to Patch Embedding
1014
+
1015
+ Args:
1016
+ patch_size (int): Patch token size. Default: 4.
1017
+ in_channels (int): Number of input image channels. Default: 3.
1018
+ embed_dim (int): Number of linear projection output channels. Default: 96.
1019
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
1020
+ """
1021
+
1022
+ def __init__(self, patch_size=4, in_channels=3, embed_dim=96, norm_layer=None):
1023
+ super().__init__()
1024
+ patch_size = to_2tuple(patch_size)
1025
+ self.patch_size = patch_size
1026
+
1027
+ self.in_channels = in_channels
1028
+ self.embed_dim = embed_dim
1029
+
1030
+ self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
1031
+ if norm_layer is not None:
1032
+ self.norm = norm_layer(embed_dim)
1033
+ else:
1034
+ self.norm = None
1035
+
1036
+ def forward(self, x):
1037
+ """Forward function."""
1038
+ # padding
1039
+ _, _, H, W = x.size()
1040
+ if W % self.patch_size[1] != 0:
1041
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
1042
+ if H % self.patch_size[0] != 0:
1043
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
1044
+
1045
+ x = self.proj(x) # B C Wh Ww
1046
+ if self.norm is not None:
1047
+ Wh, Ww = x.size(2), x.size(3)
1048
+ x = x.flatten(2).transpose(1, 2)
1049
+ x = self.norm(x)
1050
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
1051
+
1052
+ return x
1053
+
1054
+
1055
+ class SwinTransformer(nn.Module):
1056
+ """ Swin Transformer backbone.
1057
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
1058
+ https://arxiv.org/pdf/2103.14030
1059
+
1060
+ Args:
1061
+ pretrain_img_size (int): Input image size for training the pretrained model,
1062
+ used in absolute position embedding. Default: 224.
1063
+ patch_size (int | tuple(int)): Patch size. Default: 4.
1064
+ in_channels (int): Number of input image channels. Default: 3.
1065
+ embed_dim (int): Number of linear projection output channels. Default: 96.
1066
+ depths (tuple[int]): Depths of each Swin Transformer stage.
1067
+ num_heads (tuple[int]): Number of attention heads of each stage.
1068
+ window_size (int): Window size. Default: 7.
1069
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
1070
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
1071
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
1072
+ drop_rate (float): Dropout rate.
1073
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
1074
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
1075
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
1076
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
1077
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
1078
+ out_indices (Sequence[int]): Output from which stages.
1079
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
1080
+ -1 means not freezing any parameters.
1081
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
1082
+ """
1083
+
1084
+ def __init__(self,
1085
+ pretrain_img_size=224,
1086
+ patch_size=4,
1087
+ in_channels=3,
1088
+ embed_dim=96,
1089
+ depths=[2, 2, 6, 2],
1090
+ num_heads=[3, 6, 12, 24],
1091
+ window_size=7,
1092
+ mlp_ratio=4.,
1093
+ qkv_bias=True,
1094
+ qk_scale=None,
1095
+ drop_rate=0.,
1096
+ attn_drop_rate=0.,
1097
+ drop_path_rate=0.2,
1098
+ norm_layer=nn.LayerNorm,
1099
+ ape=False,
1100
+ patch_norm=True,
1101
+ out_indices=(0, 1, 2, 3),
1102
+ frozen_stages=-1,
1103
+ use_checkpoint=False):
1104
+ super().__init__()
1105
+
1106
+ self.pretrain_img_size = pretrain_img_size
1107
+ self.num_layers = len(depths)
1108
+ self.embed_dim = embed_dim
1109
+ self.ape = ape
1110
+ self.patch_norm = patch_norm
1111
+ self.out_indices = out_indices
1112
+ self.frozen_stages = frozen_stages
1113
+
1114
+ # split image into non-overlapping patches
1115
+ self.patch_embed = PatchEmbed(
1116
+ patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim,
1117
+ norm_layer=norm_layer if self.patch_norm else None)
1118
+
1119
+ # absolute position embedding
1120
+ if self.ape:
1121
+ pretrain_img_size = to_2tuple(pretrain_img_size)
1122
+ patch_size = to_2tuple(patch_size)
1123
+ patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]
1124
+
1125
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
1126
+ trunc_normal_(self.absolute_pos_embed, std=.02)
1127
+
1128
+ self.pos_drop = nn.Dropout(p=drop_rate)
1129
+
1130
+ # stochastic depth
1131
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
1132
+
1133
+ # build layers
1134
+ self.layers = nn.ModuleList()
1135
+ for i_layer in range(self.num_layers):
1136
+ layer = BasicLayer(
1137
+ dim=int(embed_dim * 2 ** i_layer),
1138
+ depth=depths[i_layer],
1139
+ num_heads=num_heads[i_layer],
1140
+ window_size=window_size,
1141
+ mlp_ratio=mlp_ratio,
1142
+ qkv_bias=qkv_bias,
1143
+ qk_scale=qk_scale,
1144
+ drop=drop_rate,
1145
+ attn_drop=attn_drop_rate,
1146
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
1147
+ norm_layer=norm_layer,
1148
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
1149
+ use_checkpoint=use_checkpoint)
1150
+ self.layers.append(layer)
1151
+
1152
+ num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
1153
+ self.num_features = num_features
1154
+
1155
+ # add a norm layer for each output
1156
+ for i_layer in out_indices:
1157
+ layer = norm_layer(num_features[i_layer])
1158
+ layer_name = f'norm{i_layer}'
1159
+ self.add_module(layer_name, layer)
1160
+
1161
+ self._freeze_stages()
1162
+
1163
+ def _freeze_stages(self):
1164
+ if self.frozen_stages >= 0:
1165
+ self.patch_embed.eval()
1166
+ for param in self.patch_embed.parameters():
1167
+ param.requires_grad = False
1168
+
1169
+ if self.frozen_stages >= 1 and self.ape:
1170
+ self.absolute_pos_embed.requires_grad = False
1171
+
1172
+ if self.frozen_stages >= 2:
1173
+ self.pos_drop.eval()
1174
+ for i in range(0, self.frozen_stages - 1):
1175
+ m = self.layers[i]
1176
+ m.eval()
1177
+ for param in m.parameters():
1178
+ param.requires_grad = False
1179
+
1180
+
1181
+ def forward(self, x):
1182
+ """Forward function."""
1183
+ x = self.patch_embed(x)
1184
+
1185
+ Wh, Ww = x.size(2), x.size(3)
1186
+ if self.ape:
1187
+ # interpolate the position embedding to the corresponding size
1188
+ absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
1189
+ x = (x + absolute_pos_embed) # B C Wh Ww
1190
+
1191
+ outs = []
1192
+ x = x.flatten(2).transpose(1, 2)
1193
+ x = self.pos_drop(x)
1194
+ for i in range(self.num_layers):
1195
+ layer = self.layers[i]
1196
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
1197
+
1198
+ if i in self.out_indices:
1199
+ norm_layer = getattr(self, f'norm{i}')
1200
+ x_out = norm_layer(x_out)
1201
+
1202
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
1203
+ outs.append(out)
1204
+
1205
+ return tuple(outs)
1206
+
1207
+ def train(self, mode=True):
1208
+ """Convert the model into training mode while keep layers freezed."""
1209
+ super(SwinTransformer, self).train(mode)
1210
+ self._freeze_stages()
1211
+
1212
+ def swin_v1_t():
1213
+ model = SwinTransformer(embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7)
1214
+ return model
1215
+
1216
+ def swin_v1_s():
1217
+ model = SwinTransformer(embed_dim=96, depths=[2, 2, 18, 2], num_heads=[3, 6, 12, 24], window_size=7)
1218
+ return model
1219
+
1220
+ def swin_v1_b():
1221
+ model = SwinTransformer(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12)
1222
+ return model
1223
+
1224
+ def swin_v1_l():
1225
+ model = SwinTransformer(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12)
1226
+ return model
1227
+
1228
+
1229
+
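+
+ # Minimal usage sketch for the backbone variants above (hedged: CPU, dummy
+ # input; the tiny variant with a 224x224 image yields a 4-level feature pyramid):
+ # >>> model = swin_v1_t().eval()
+ # >>> feats = model(torch.randn(1, 3, 224, 224))
+ # >>> [tuple(f.shape) for f in feats]
+ # [(1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)]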
1230
+ ### models/modules/deform_conv.py
1231
+
1232
+ import torch
1233
+ import torch.nn as nn
1234
+ from torchvision.ops import deform_conv2d
1235
+
1236
+
1237
+ class DeformableConv2d(nn.Module):
1238
+ def __init__(self,
1239
+ in_channels,
1240
+ out_channels,
1241
+ kernel_size=3,
1242
+ stride=1,
1243
+ padding=1,
1244
+ bias=False):
1245
+
1246
+ super(DeformableConv2d, self).__init__()
1247
+
1248
+ assert type(kernel_size) == tuple or type(kernel_size) == int
1249
+
1250
+ kernel_size = kernel_size if type(kernel_size) == tuple else (kernel_size, kernel_size)
1251
+ self.stride = stride if type(stride) == tuple else (stride, stride)
1252
+ self.padding = padding
1253
+
1254
+ self.offset_conv = nn.Conv2d(in_channels,
1255
+ 2 * kernel_size[0] * kernel_size[1],
1256
+ kernel_size=kernel_size,
1257
+ stride=stride,
1258
+ padding=self.padding,
1259
+ bias=True)
1260
+
1261
+ nn.init.constant_(self.offset_conv.weight, 0.)
1262
+ nn.init.constant_(self.offset_conv.bias, 0.)
1263
+
1264
+ self.modulator_conv = nn.Conv2d(in_channels,
1265
+ 1 * kernel_size[0] * kernel_size[1],
1266
+ kernel_size=kernel_size,
1267
+ stride=stride,
1268
+ padding=self.padding,
1269
+ bias=True)
1270
+
1271
+ nn.init.constant_(self.modulator_conv.weight, 0.)
1272
+ nn.init.constant_(self.modulator_conv.bias, 0.)
1273
+
1274
+ self.regular_conv = nn.Conv2d(in_channels,
1275
+ out_channels=out_channels,
1276
+ kernel_size=kernel_size,
1277
+ stride=stride,
1278
+ padding=self.padding,
1279
+ bias=bias)
1280
+
1281
+ def forward(self, x):
1282
+ #h, w = x.shape[2:]
1283
+ #max_offset = max(h, w)/4.
1284
+
1285
+ offset = self.offset_conv(x)#.clamp(-max_offset, max_offset)
1286
+ modulator = 2. * torch.sigmoid(self.modulator_conv(x))
1287
+
1288
+ x = deform_conv2d(
1289
+ input=x,
1290
+ offset=offset,
1291
+ weight=self.regular_conv.weight,
1292
+ bias=self.regular_conv.bias,
1293
+ padding=self.padding,
1294
+ mask=modulator,
1295
+ stride=self.stride,
1296
+ )
1297
+ return x
1298
+
1299
+
1300
+
1301
+
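+
+ # Usage sketch (editor example with dummy data): DeformableConv2d is a drop-in
+ # replacement for nn.Conv2d with the same output shape. Because the offset and
+ # modulator convs are zero-initialised, sigmoid(0) = 0.5 makes the initial
+ # modulation mask all ones, so the module starts out as a plain convolution.
+ # >>> conv = DeformableConv2d(16, 32, kernel_size=3, stride=1, padding=1)
+ # >>> conv(torch.randn(2, 16, 64, 64)).shape
+ # torch.Size([2, 32, 64, 64])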
1302
+ ### utils.py
1303
+
1304
+ import torch.nn as nn
1305
+
1306
+
1307
+ def build_act_layer(act_layer):
1308
+ if act_layer == 'ReLU':
1309
+ return nn.ReLU(inplace=True)
1310
+ elif act_layer == 'SiLU':
1311
+ return nn.SiLU(inplace=True)
1312
+ elif act_layer == 'GELU':
1313
+ return nn.GELU()
1314
+
1315
+ raise NotImplementedError(f'build_act_layer does not support {act_layer}')
1316
+
1317
+
1318
+ def build_norm_layer(dim,
1319
+ norm_layer,
1320
+ in_format='channels_last',
1321
+ out_format='channels_last',
1322
+ eps=1e-6):
1323
+ layers = []
1324
+ if norm_layer == 'BN':
1325
+ if in_format == 'channels_last':
1326
+ layers.append(to_channels_first())
1327
+ layers.append(nn.BatchNorm2d(dim))
1328
+ if out_format == 'channels_last':
1329
+ layers.append(to_channels_last())
1330
+ elif norm_layer == 'LN':
1331
+ if in_format == 'channels_first':
1332
+ layers.append(to_channels_last())
1333
+ layers.append(nn.LayerNorm(dim, eps=eps))
1334
+ if out_format == 'channels_first':
1335
+ layers.append(to_channels_first())
1336
+ else:
1337
+ raise NotImplementedError(
1338
+ f'build_norm_layer does not support {norm_layer}')
1339
+ return nn.Sequential(*layers)
1340
+
1341
+
1342
+ class to_channels_first(nn.Module):
1343
+
1344
+ def __init__(self):
1345
+ super().__init__()
1346
+
1347
+ def forward(self, x):
1348
+ return x.permute(0, 3, 1, 2)
1349
+
1350
+
1351
+ class to_channels_last(nn.Module):
1352
+
1353
+ def __init__(self):
1354
+ super().__init__()
1355
+
1356
+ def forward(self, x):
1357
+ return x.permute(0, 2, 3, 1)
1358
+
1359
+
1360
+
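+
+ # Usage sketch (hedged example): a 'BN' norm over channels-last features wraps
+ # BatchNorm2d with the permute adapters above, preserving the input layout.
+ # >>> norm = build_norm_layer(64, 'BN', in_format='channels_last', out_format='channels_last')
+ # >>> norm(torch.randn(2, 8, 8, 64)).shape
+ # torch.Size([2, 8, 8, 64])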
1361
+ ### dataset.py
1362
+
1363
+ _class_labels_TR_sorted = (
1364
+ 'Airplane, Ant, Antenna, Archery, Axe, BabyCarriage, Bag, BalanceBeam, Balcony, Balloon, Basket, BasketballHoop, Beatle, Bed, Bee, Bench, Bicycle, '
1365
+ 'BicycleFrame, BicycleStand, Boat, Bonsai, BoomLift, Bridge, BunkBed, Butterfly, Button, Cable, CableLift, Cage, Camcorder, Cannon, Canoe, Car, '
1366
+ 'CarParkDropArm, Carriage, Cart, Caterpillar, CeilingLamp, Centipede, Chair, Clip, Clock, Clothes, CoatHanger, Comb, ConcretePumpTruck, Crack, Crane, '
1367
+ 'Cup, DentalChair, Desk, DeskChair, Diagram, DishRack, DoorHandle, Dragonfish, Dragonfly, Drum, Earphone, Easel, ElectricIron, Excavator, Eyeglasses, '
1368
+ 'Fan, Fence, Fencing, FerrisWheel, FireExtinguisher, Fishing, Flag, FloorLamp, Forklift, GasStation, Gate, Gear, Goal, Golf, GymEquipment, Hammock, '
1369
+ 'Handcart, Handcraft, Handrail, HangGlider, Harp, Harvester, Headset, Helicopter, Helmet, Hook, HorizontalBar, Hydrovalve, IroningTable, Jewelry, Key, '
1370
+ 'KidsPlayground, Kitchenware, Kite, Knife, Ladder, LaundryRack, Lightning, Lobster, Locust, Machine, MachineGun, MagazineRack, Mantis, Medal, MemorialArchway, '
1371
+ 'Microphone, Missile, MobileHolder, Monitor, Mosquito, Motorcycle, MovingTrolley, Mower, MusicPlayer, MusicStand, ObservationTower, Octopus, OilWell, '
1372
+ 'OlympicLogo, OperatingTable, OutdoorFitnessEquipment, Parachute, Pavilion, Piano, Pipe, PlowHarrow, PoleVault, Punchbag, Rack, Racket, Rifle, Ring, Robot, '
1373
+ 'RockClimbing, Rope, Sailboat, Satellite, Scaffold, Scale, Scissor, Scooter, Sculpture, Seadragon, Seahorse, Seal, SewingMachine, Ship, Shoe, ShoppingCart, '
1374
+ 'ShoppingTrolley, Shower, Shrimp, Signboard, Skateboarding, Skeleton, Skiing, Spade, SpeedBoat, Spider, Spoon, Stair, Stand, Stationary, SteeringWheel, '
1375
+ 'Stethoscope, Stool, Stove, StreetLamp, SweetStand, Swing, Sword, TV, Table, TableChair, TableLamp, TableTennis, Tank, Tapeline, Teapot, Telescope, Tent, '
1376
+ 'TobaccoPipe, Toy, Tractor, TrafficLight, TrafficSign, Trampoline, TransmissionTower, Tree, Tricycle, TrimmerCover, Tripod, Trombone, Truck, Trumpet, Tuba, '
1377
+ 'UAV, Umbrella, UnevenBars, UtilityPole, VacuumCleaner, Violin, Wakesurfing, Watch, WaterTower, WateringPot, Well, WellLid, Wheel, Wheelchair, WindTurbine, Windmill, WineGlass, WireWhisk, Yacht'
1378
+ )
1379
+ class_labels_TR_sorted = _class_labels_TR_sorted.split(', ')
1380
+
1381
+
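+
+ # This sorted label list sizes the auxiliary classification head in
+ # models/birefnet.py below: nn.Linear(channels[0], len(class_labels_TR_sorted)).
+ # >>> class_labels_TR_sorted[:3]
+ # ['Airplane', 'Ant', 'Antenna']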
1382
+ ### models/backbones/build_backbones.py
1383
+
1384
+ import torch
1385
+ import torch.nn as nn
1386
+ from collections import OrderedDict
1387
+ from torchvision.models import vgg16, vgg16_bn, VGG16_Weights, VGG16_BN_Weights, resnet50, ResNet50_Weights
1388
+ # from models.pvt_v2 import pvt_v2_b0, pvt_v2_b1, pvt_v2_b2, pvt_v2_b5
1389
+ # from models.swin_v1 import swin_v1_t, swin_v1_s, swin_v1_b, swin_v1_l
1390
+ # from config import Config
1391
+
1392
+
1393
+ config = Config()
1394
+
1395
+ def build_backbone(bb_name, pretrained=True, params_settings=''):
1396
+ if bb_name == 'vgg16':
1397
+ bb_net = list(vgg16(pretrained=VGG16_Weights.DEFAULT if pretrained else None).children())[0]
1398
+ bb = nn.Sequential(OrderedDict({'conv1': bb_net[:4], 'conv2': bb_net[4:9], 'conv3': bb_net[9:16], 'conv4': bb_net[16:23]}))
1399
+ elif bb_name == 'vgg16bn':
1400
+ bb_net = list(vgg16_bn(pretrained=VGG16_BN_Weights.DEFAULT if pretrained else None).children())[0]
1401
+ bb = nn.Sequential(OrderedDict({'conv1': bb_net[:6], 'conv2': bb_net[6:13], 'conv3': bb_net[13:23], 'conv4': bb_net[23:33]}))
1402
+ elif bb_name == 'resnet50':
1403
+ bb_net = list(resnet50(pretrained=ResNet50_Weights.DEFAULT if pretrained else None).children())
1404
+ bb = nn.Sequential(OrderedDict({'conv1': nn.Sequential(*bb_net[0:3]), 'conv2': bb_net[4], 'conv3': bb_net[5], 'conv4': bb_net[6]}))
1405
+ else:
1406
+ bb = eval('{}({})'.format(bb_name, params_settings))
1407
+ if pretrained:
1408
+ bb = load_weights(bb, bb_name)
1409
+ return bb
1410
+
1411
+ def load_weights(model, model_name):
1412
+ save_model = torch.load(config.weights[model_name], map_location='cpu')
1413
+ model_dict = model.state_dict()
1414
+ state_dict = {k: v if v.size() == model_dict[k].size() else model_dict[k] for k, v in save_model.items() if k in model_dict.keys()}
1415
+ # Ignore weights with mismatched sizes, e.g. when the backbone itself has been modified.
1416
+ if not state_dict:
1417
+ save_model_keys = list(save_model.keys())
1418
+ sub_item = save_model_keys[0] if len(save_model_keys) == 1 else None
1419
+ state_dict = {k: v if v.size() == model_dict[k].size() else model_dict[k] for k, v in save_model[sub_item].items() if k in model_dict.keys()}
1420
+ if not state_dict or not sub_item:
1421
+ print('Weights were not successfully loaded. Check the state_dict of the weights file.')
1422
+ return None
1423
+ else:
1424
+ print('Found correct weights in the "{}" item of loaded state_dict.'.format(sub_item))
1425
+ model_dict.update(state_dict)
1426
+ model.load_state_dict(model_dict)
1427
+ return model
1428
+
1429
+
1430
+
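+
+ # Usage sketch (editor example; pretrained=False avoids any download, and the
+ # resnet50 split above exposes four staged feature maps whose channel counts
+ # match lateral_channels_in_collection['resnet50'] used by the decoders below):
+ # >>> bb = build_backbone('resnet50', pretrained=False)
+ # >>> x1 = bb.conv1(torch.randn(1, 3, 224, 224))
+ # >>> x4 = bb.conv4(bb.conv3(bb.conv2(x1)))
+ # >>> x4.shape
+ # torch.Size([1, 1024, 28, 28])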
1431
+ ### models/modules/decoder_blocks.py
1432
+
1433
+ import torch
1434
+ import torch.nn as nn
1435
+ # from models.aspp import ASPP, ASPPDeformable
1436
+ # from config import Config
1437
+
1438
+
1439
+ # config = Config()
1440
+
1441
+
1442
+ class BasicDecBlk(nn.Module):
1443
+ def __init__(self, in_channels=64, out_channels=64, inter_channels=64):
1444
+ super(BasicDecBlk, self).__init__()
1445
+ inter_channels = in_channels // 4 if config.dec_channels_inter == 'adap' else 64
1446
+ self.conv_in = nn.Conv2d(in_channels, inter_channels, 3, 1, padding=1)
1447
+ self.relu_in = nn.ReLU(inplace=True)
1448
+ if config.dec_att == 'ASPP':
1449
+ self.dec_att = ASPP(in_channels=inter_channels)
1450
+ elif config.dec_att == 'ASPPDeformable':
1451
+ self.dec_att = ASPPDeformable(in_channels=inter_channels)
1452
+ self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, padding=1)
1453
+ self.bn_in = nn.BatchNorm2d(inter_channels) if config.batch_size > 1 else nn.Identity()
1454
+ self.bn_out = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity()
1455
+
1456
+ def forward(self, x):
1457
+ x = self.conv_in(x)
1458
+ x = self.bn_in(x)
1459
+ x = self.relu_in(x)
1460
+ if hasattr(self, 'dec_att'):
1461
+ x = self.dec_att(x)
1462
+ x = self.conv_out(x)
1463
+ x = self.bn_out(x)
1464
+ return x
1465
+
1466
+
1467
+ class ResBlk(nn.Module):
1468
+ def __init__(self, in_channels=64, out_channels=None, inter_channels=64):
1469
+ super(ResBlk, self).__init__()
1470
+ if out_channels is None:
1471
+ out_channels = in_channels
1472
+ inter_channels = in_channels // 4 if config.dec_channels_inter == 'adap' else 64
1473
+
1474
+ self.conv_in = nn.Conv2d(in_channels, inter_channels, 3, 1, padding=1)
1475
+ self.bn_in = nn.BatchNorm2d(inter_channels) if config.batch_size > 1 else nn.Identity()
1476
+ self.relu_in = nn.ReLU(inplace=True)
1477
+
1478
+ if config.dec_att == 'ASPP':
1479
+ self.dec_att = ASPP(in_channels=inter_channels)
1480
+ elif config.dec_att == 'ASPPDeformable':
1481
+ self.dec_att = ASPPDeformable(in_channels=inter_channels)
1482
+
1483
+ self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, padding=1)
1484
+ self.bn_out = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity()
1485
+
1486
+ self.conv_resi = nn.Conv2d(in_channels, out_channels, 1, 1, 0)
1487
+
1488
+ def forward(self, x):
1489
+ _x = self.conv_resi(x)
1490
+ x = self.conv_in(x)
1491
+ x = self.bn_in(x)
1492
+ x = self.relu_in(x)
1493
+ if hasattr(self, 'dec_att'):
1494
+ x = self.dec_att(x)
1495
+ x = self.conv_out(x)
1496
+ x = self.bn_out(x)
1497
+ return x + _x
1498
+
1499
+
1500
+
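+
+ # Shape sketch (hedged: assumes a valid global Config(); both blocks preserve
+ # spatial size, and ResBlk adds a 1x1 conv_resi shortcut around the body):
+ # >>> blk = ResBlk(in_channels=64, out_channels=64)
+ # >>> blk(torch.randn(2, 64, 32, 32)).shape
+ # torch.Size([2, 64, 32, 32])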
1501
+ ### models/modules/lateral_blocks.py
1502
+
1503
+ import numpy as np
1504
+ import torch
1505
+ import torch.nn as nn
1506
+ import torch.nn.functional as F
1507
+ from functools import partial
1508
+
1509
+ # from config import Config
1510
+
1511
+
1512
+ # config = Config()
1513
+
1514
+
1515
+ class BasicLatBlk(nn.Module):
1516
+ def __init__(self, in_channels=64, out_channels=64, inter_channels=64):
1517
+ super(BasicLatBlk, self).__init__()
1518
+ inter_channels = in_channels // 4 if config.dec_channels_inter == 'adap' else 64
1519
+ self.conv = nn.Conv2d(in_channels, out_channels, 1, 1, 0)
1520
+
1521
+ def forward(self, x):
1522
+ x = self.conv(x)
1523
+ return x
1524
+
1525
+
1526
+
1527
+ ### models/modules/aspp.py
1528
+
1529
+ import torch
1530
+ import torch.nn as nn
1531
+ import torch.nn.functional as F
1532
+ # from models.deform_conv import DeformableConv2d
1533
+ # from config import Config
1534
+
1535
+
1536
+ # config = Config()
1537
+
1538
+
1539
+ class _ASPPModule(nn.Module):
1540
+ def __init__(self, in_channels, planes, kernel_size, padding, dilation):
1541
+ super(_ASPPModule, self).__init__()
1542
+ self.atrous_conv = nn.Conv2d(in_channels, planes, kernel_size=kernel_size,
1543
+ stride=1, padding=padding, dilation=dilation, bias=False)
1544
+ self.bn = nn.BatchNorm2d(planes) if config.batch_size > 1 else nn.Identity()
1545
+ self.relu = nn.ReLU(inplace=True)
1546
+
1547
+ def forward(self, x):
1548
+ x = self.atrous_conv(x)
1549
+ x = self.bn(x)
1550
+
1551
+ return self.relu(x)
1552
+
1553
+
1554
+ class ASPP(nn.Module):
1555
+ def __init__(self, in_channels=64, out_channels=None, output_stride=16):
1556
+ super(ASPP, self).__init__()
1557
+ self.down_scale = 1
1558
+ if out_channels is None:
1559
+ out_channels = in_channels
1560
+ self.in_channelster = 256 // self.down_scale
1561
+ if output_stride == 16:
1562
+ dilations = [1, 6, 12, 18]
1563
+ elif output_stride == 8:
1564
+ dilations = [1, 12, 24, 36]
1565
+ else:
1566
+ raise NotImplementedError
1567
+
1568
+ self.aspp1 = _ASPPModule(in_channels, self.in_channelster, 1, padding=0, dilation=dilations[0])
1569
+ self.aspp2 = _ASPPModule(in_channels, self.in_channelster, 3, padding=dilations[1], dilation=dilations[1])
1570
+ self.aspp3 = _ASPPModule(in_channels, self.in_channelster, 3, padding=dilations[2], dilation=dilations[2])
1571
+ self.aspp4 = _ASPPModule(in_channels, self.in_channelster, 3, padding=dilations[3], dilation=dilations[3])
1572
+
1573
+ self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
1574
+ nn.Conv2d(in_channels, self.in_channelster, 1, stride=1, bias=False),
1575
+ nn.BatchNorm2d(self.in_channelster) if config.batch_size > 1 else nn.Identity(),
1576
+ nn.ReLU(inplace=True))
1577
+ self.conv1 = nn.Conv2d(self.in_channelster * 5, out_channels, 1, bias=False)
1578
+ self.bn1 = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity()
1579
+ self.relu = nn.ReLU(inplace=True)
1580
+ self.dropout = nn.Dropout(0.5)
1581
+
1582
+ def forward(self, x):
1583
+ x1 = self.aspp1(x)
1584
+ x2 = self.aspp2(x)
1585
+ x3 = self.aspp3(x)
1586
+ x4 = self.aspp4(x)
1587
+ x5 = self.global_avg_pool(x)
1588
+ x5 = F.interpolate(x5, size=x1.size()[2:], mode='bilinear', align_corners=True)
1589
+ x = torch.cat((x1, x2, x3, x4, x5), dim=1)
1590
+
1591
+ x = self.conv1(x)
1592
+ x = self.bn1(x)
1593
+ x = self.relu(x)
1594
+
1595
+ return self.dropout(x)
1596
+
1597
+
1598
+ ##################### Deformable
1599
+ class _ASPPModuleDeformable(nn.Module):
1600
+ def __init__(self, in_channels, planes, kernel_size, padding):
1601
+ super(_ASPPModuleDeformable, self).__init__()
1602
+ self.atrous_conv = DeformableConv2d(in_channels, planes, kernel_size=kernel_size,
1603
+ stride=1, padding=padding, bias=False)
1604
+ self.bn = nn.BatchNorm2d(planes) if config.batch_size > 1 else nn.Identity()
1605
+ self.relu = nn.ReLU(inplace=True)
1606
+
1607
+ def forward(self, x):
1608
+ x = self.atrous_conv(x)
1609
+ x = self.bn(x)
1610
+
1611
+ return self.relu(x)
1612
+
1613
+
1614
+ class ASPPDeformable(nn.Module):
1615
+ def __init__(self, in_channels, out_channels=None, parallel_block_sizes=[1, 3, 7]):
1616
+ super(ASPPDeformable, self).__init__()
1617
+ self.down_scale = 1
1618
+ if out_channels is None:
1619
+ out_channels = in_channels
1620
+ self.in_channelster = 256 // self.down_scale
1621
+
1622
+ self.aspp1 = _ASPPModuleDeformable(in_channels, self.in_channelster, 1, padding=0)
1623
+ self.aspp_deforms = nn.ModuleList([
1624
+ _ASPPModuleDeformable(in_channels, self.in_channelster, conv_size, padding=int(conv_size//2)) for conv_size in parallel_block_sizes
1625
+ ])
1626
+
1627
+ self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
1628
+ nn.Conv2d(in_channels, self.in_channelster, 1, stride=1, bias=False),
1629
+ nn.BatchNorm2d(self.in_channelster) if config.batch_size > 1 else nn.Identity(),
1630
+ nn.ReLU(inplace=True))
1631
+ self.conv1 = nn.Conv2d(self.in_channelster * (2 + len(self.aspp_deforms)), out_channels, 1, bias=False)
1632
+ self.bn1 = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity()
1633
+ self.relu = nn.ReLU(inplace=True)
1634
+ self.dropout = nn.Dropout(0.5)
1635
+
1636
+ def forward(self, x):
1637
+ x1 = self.aspp1(x)
1638
+ x_aspp_deforms = [aspp_deform(x) for aspp_deform in self.aspp_deforms]
1639
+ x5 = self.global_avg_pool(x)
1640
+ x5 = F.interpolate(x5, size=x1.size()[2:], mode='bilinear', align_corners=True)
1641
+ x = torch.cat((x1, *x_aspp_deforms, x5), dim=1)
1642
+
1643
+ x = self.conv1(x)
1644
+ x = self.bn1(x)
1645
+ x = self.relu(x)
1646
+
1647
+ return self.dropout(x)
1648
+
1649
+
1650
+
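+
+ # Usage sketch (editor example): with out_channels left as None, both ASPP
+ # variants return feature maps matching the input, so they can be dropped into
+ # a decoder block without channel bookkeeping.
+ # >>> aspp = ASPPDeformable(in_channels=64)
+ # >>> aspp(torch.randn(1, 64, 32, 32)).shape
+ # torch.Size([1, 64, 32, 32])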
1651
+ ### models/refinement/refiner.py
1652
+
1653
+ import torch
1654
+ import torch.nn as nn
1655
+ from collections import OrderedDict
1656
+ import torch
1657
+ import torch.nn as nn
1658
+ import torch.nn.functional as F
1659
+ from torchvision.models import vgg16, vgg16_bn
1660
+ from torchvision.models import resnet50
1661
+
1662
+ # from config import Config
1663
+ # from dataset import class_labels_TR_sorted
1664
+ # from models.build_backbone import build_backbone
1665
+ # from models.decoder_blocks import BasicDecBlk
1666
+ # from models.lateral_blocks import BasicLatBlk
1667
+ # from models.ing import *
1668
+ # from models.stem_layer import StemLayer
1669
+
1670
+
1671
+ class RefinerPVTInChannels4(nn.Module):
1672
+ def __init__(self, in_channels=3+1):
1673
+ super(RefinerPVTInChannels4, self).__init__()
1674
+ self.config = Config()
1675
+ self.epoch = 1
1676
+ self.bb = build_backbone(self.config.bb, params_settings='in_channels=4')
1677
+
1678
+ lateral_channels_in_collection = {
1679
+ 'vgg16': [512, 256, 128, 64], 'vgg16bn': [512, 256, 128, 64], 'resnet50': [1024, 512, 256, 64],
1680
+ 'pvt_v2_b2': [512, 320, 128, 64], 'pvt_v2_b5': [512, 320, 128, 64],
1681
+ 'swin_v1_b': [1024, 512, 256, 128], 'swin_v1_l': [1536, 768, 384, 192],
1682
+ }
1683
+ channels = lateral_channels_in_collection[self.config.bb]
1684
+ self.squeeze_module = BasicDecBlk(channels[0], channels[0])
1685
+
1686
+ self.decoder = Decoder(channels)
1687
+
1688
+ if 0:
1689
+ for key, value in self.named_parameters():
1690
+ if 'bb.' in key:
1691
+ value.requires_grad = False
1692
+
1693
+ def forward(self, x):
1694
+ if isinstance(x, list):
1695
+ x = torch.cat(x, dim=1)
1696
+ ########## Encoder ##########
1697
+ if self.config.bb in ['vgg16', 'vgg16bn', 'resnet50']:
1698
+ x1 = self.bb.conv1(x)
1699
+ x2 = self.bb.conv2(x1)
1700
+ x3 = self.bb.conv3(x2)
1701
+ x4 = self.bb.conv4(x3)
1702
+ else:
1703
+ x1, x2, x3, x4 = self.bb(x)
1704
+
1705
+ x4 = self.squeeze_module(x4)
1706
+
1707
+ ########## Decoder ##########
1708
+
1709
+ features = [x, x1, x2, x3, x4]
1710
+ scaled_preds = self.decoder(features)
1711
+
1712
+ return scaled_preds
1713
+
1714
+
1715
+ class Refiner(nn.Module):
1716
+ def __init__(self, in_channels=3+1):
1717
+ super(Refiner, self).__init__()
1718
+ self.config = Config()
1719
+ self.epoch = 1
1720
+ self.stem_layer = StemLayer(in_channels=in_channels, inter_channels=48, out_channels=3, norm_layer='BN' if self.config.batch_size > 1 else 'LN')
1721
+ self.bb = build_backbone(self.config.bb)
1722
+
1723
+ lateral_channels_in_collection = {
1724
+ 'vgg16': [512, 256, 128, 64], 'vgg16bn': [512, 256, 128, 64], 'resnet50': [1024, 512, 256, 64],
1725
+ 'pvt_v2_b2': [512, 320, 128, 64], 'pvt_v2_b5': [512, 320, 128, 64],
1726
+ 'swin_v1_b': [1024, 512, 256, 128], 'swin_v1_l': [1536, 768, 384, 192],
1727
+ }
1728
+ channels = lateral_channels_in_collection[self.config.bb]
1729
+ self.squeeze_module = BasicDecBlk(channels[0], channels[0])
1730
+
1731
+ self.decoder = Decoder(channels)
1732
+
1733
+ if 0:
1734
+ for key, value in self.named_parameters():
1735
+ if 'bb.' in key:
1736
+ value.requires_grad = False
1737
+
1738
+ def forward(self, x):
1739
+ if isinstance(x, list):
1740
+ x = torch.cat(x, dim=1)
1741
+ x = self.stem_layer(x)
1742
+ ########## Encoder ##########
1743
+ if self.config.bb in ['vgg16', 'vgg16bn', 'resnet50']:
1744
+ x1 = self.bb.conv1(x)
1745
+ x2 = self.bb.conv2(x1)
1746
+ x3 = self.bb.conv3(x2)
1747
+ x4 = self.bb.conv4(x3)
1748
+ else:
1749
+ x1, x2, x3, x4 = self.bb(x)
1750
+
1751
+ x4 = self.squeeze_module(x4)
1752
+
1753
+ ########## Decoder ##########
1754
+
1755
+ features = [x, x1, x2, x3, x4]
1756
+ scaled_preds = self.decoder(features)
1757
+
1758
+ return scaled_preds
1759
+
1760
+
1761
+ class Decoder(nn.Module):
1762
+ def __init__(self, channels):
1763
+ super(Decoder, self).__init__()
1764
+ self.config = Config()
1765
+ DecoderBlock = eval('BasicDecBlk')
1766
+ LateralBlock = eval('BasicLatBlk')
1767
+
1768
+ self.decoder_block4 = DecoderBlock(channels[0], channels[1])
1769
+ self.decoder_block3 = DecoderBlock(channels[1], channels[2])
1770
+ self.decoder_block2 = DecoderBlock(channels[2], channels[3])
1771
+ self.decoder_block1 = DecoderBlock(channels[3], channels[3]//2)
1772
+
1773
+ self.lateral_block4 = LateralBlock(channels[1], channels[1])
1774
+ self.lateral_block3 = LateralBlock(channels[2], channels[2])
1775
+ self.lateral_block2 = LateralBlock(channels[3], channels[3])
1776
+
1777
+ if self.config.ms_supervision:
1778
+ self.conv_ms_spvn_4 = nn.Conv2d(channels[1], 1, 1, 1, 0)
1779
+ self.conv_ms_spvn_3 = nn.Conv2d(channels[2], 1, 1, 1, 0)
1780
+ self.conv_ms_spvn_2 = nn.Conv2d(channels[3], 1, 1, 1, 0)
1781
+ self.conv_out1 = nn.Sequential(nn.Conv2d(channels[3]//2, 1, 1, 1, 0))
1782
+
1783
+ def forward(self, features):
1784
+ x, x1, x2, x3, x4 = features
1785
+ outs = []
1786
+ p4 = self.decoder_block4(x4)
1787
+ _p4 = F.interpolate(p4, size=x3.shape[2:], mode='bilinear', align_corners=True)
1788
+ _p3 = _p4 + self.lateral_block4(x3)
1789
+
1790
+ p3 = self.decoder_block3(_p3)
1791
+ _p3 = F.interpolate(p3, size=x2.shape[2:], mode='bilinear', align_corners=True)
1792
+ _p2 = _p3 + self.lateral_block3(x2)
1793
+
1794
+ p2 = self.decoder_block2(_p2)
1795
+ _p2 = F.interpolate(p2, size=x1.shape[2:], mode='bilinear', align_corners=True)
1796
+ _p1 = _p2 + self.lateral_block2(x1)
1797
+
1798
+ _p1 = self.decoder_block1(_p1)
1799
+ _p1 = F.interpolate(_p1, size=x.shape[2:], mode='bilinear', align_corners=True)
1800
+ p1_out = self.conv_out1(_p1)
1801
+
1802
+ if self.config.ms_supervision:
1803
+ outs.append(self.conv_ms_spvn_4(p4))
1804
+ outs.append(self.conv_ms_spvn_3(p3))
1805
+ outs.append(self.conv_ms_spvn_2(p2))
1806
+ outs.append(p1_out)
1807
+ return outs
1808
+
1809
+
1810
+ class RefUNet(nn.Module):
1811
+ # Refinement
1812
+ def __init__(self, in_channels=3+1):
1813
+ super(RefUNet, self).__init__()
1814
+ self.encoder_1 = nn.Sequential(
1815
+ nn.Conv2d(in_channels, 64, 3, 1, 1),
1816
+ nn.Conv2d(64, 64, 3, 1, 1),
1817
+ nn.BatchNorm2d(64),
1818
+ nn.ReLU(inplace=True)
1819
+ )
1820
+
1821
+ self.encoder_2 = nn.Sequential(
1822
+ nn.MaxPool2d(2, 2, ceil_mode=True),
1823
+ nn.Conv2d(64, 64, 3, 1, 1),
1824
+ nn.BatchNorm2d(64),
1825
+ nn.ReLU(inplace=True)
1826
+ )
1827
+
1828
+ self.encoder_3 = nn.Sequential(
1829
+ nn.MaxPool2d(2, 2, ceil_mode=True),
1830
+ nn.Conv2d(64, 64, 3, 1, 1),
1831
+ nn.BatchNorm2d(64),
1832
+ nn.ReLU(inplace=True)
1833
+ )
1834
+
1835
+ self.encoder_4 = nn.Sequential(
1836
+ nn.MaxPool2d(2, 2, ceil_mode=True),
1837
+ nn.Conv2d(64, 64, 3, 1, 1),
1838
+ nn.BatchNorm2d(64),
1839
+ nn.ReLU(inplace=True)
1840
+ )
1841
+
1842
+ self.pool4 = nn.MaxPool2d(2, 2, ceil_mode=True)
1843
+ #####
1844
+ self.decoder_5 = nn.Sequential(
1845
+ nn.Conv2d(64, 64, 3, 1, 1),
1846
+ nn.BatchNorm2d(64),
1847
+ nn.ReLU(inplace=True)
1848
+ )
1849
+ #####
1850
+ self.decoder_4 = nn.Sequential(
1851
+ nn.Conv2d(128, 64, 3, 1, 1),
1852
+ nn.BatchNorm2d(64),
1853
+ nn.ReLU(inplace=True)
1854
+ )
1855
+
1856
+ self.decoder_3 = nn.Sequential(
1857
+ nn.Conv2d(128, 64, 3, 1, 1),
1858
+ nn.BatchNorm2d(64),
1859
+ nn.ReLU(inplace=True)
1860
+ )
1861
+
1862
+ self.decoder_2 = nn.Sequential(
1863
+ nn.Conv2d(128, 64, 3, 1, 1),
1864
+ nn.BatchNorm2d(64),
1865
+ nn.ReLU(inplace=True)
1866
+ )
1867
+
1868
+ self.decoder_1 = nn.Sequential(
1869
+ nn.Conv2d(128, 64, 3, 1, 1),
1870
+ nn.BatchNorm2d(64),
1871
+ nn.ReLU(inplace=True)
1872
+ )
1873
+
1874
+ self.conv_d0 = nn.Conv2d(64, 1, 3, 1, 1)
1875
+
1876
+ self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
1877
+
1878
+ def forward(self, x):
1879
+ outs = []
1880
+ if isinstance(x, list):
1881
+ x = torch.cat(x, dim=1)
1882
+ hx = x
1883
+
1884
+ hx1 = self.encoder_1(hx)
1885
+ hx2 = self.encoder_2(hx1)
1886
+ hx3 = self.encoder_3(hx2)
1887
+ hx4 = self.encoder_4(hx3)
1888
+
1889
+ hx = self.decoder_5(self.pool4(hx4))
1890
+ hx = torch.cat((self.upscore2(hx), hx4), 1)
1891
+
1892
+ d4 = self.decoder_4(hx)
1893
+ hx = torch.cat((self.upscore2(d4), hx3), 1)
1894
+
1895
+ d3 = self.decoder_3(hx)
1896
+ hx = torch.cat((self.upscore2(d3), hx2), 1)
1897
+
1898
+ d2 = self.decoder_2(hx)
1899
+ hx = torch.cat((self.upscore2(d2), hx1), 1)
1900
+
1901
+ d1 = self.decoder_1(hx)
1902
+
1903
+ x = self.conv_d0(d1)
1904
+ outs.append(x)
1905
+ return outs
1906
+
1907
+
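+
+ # Usage sketch for the refiner above (hedged example with dummy data): it takes
+ # the RGB image plus a coarse 1-channel prediction and returns a refined map at
+ # the input resolution.
+ # >>> ref = RefUNet(in_channels=3 + 1)
+ # >>> preds = ref([torch.randn(1, 3, 256, 256), torch.randn(1, 1, 256, 256)])
+ # >>> preds[-1].shape
+ # torch.Size([1, 1, 256, 256])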
1908
+
1909
+ ### models/stem_layer.py
1910
+
1911
+ import torch.nn as nn
1912
+ # from utils import build_act_layer, build_norm_layer
1913
+
1914
+
1915
+ class StemLayer(nn.Module):
1916
+ r""" Stem layer of InternImage
1917
+ Args:
1918
+ in_channels (int): number of input channels
1919
+ inter_channels (int): number of intermediate channels
+ out_channels (int): number of output channels
1920
+ act_layer (str): activation layer
1921
+ norm_layer (str): normalization layer
1922
+ """
1923
+
1924
+ def __init__(self,
1925
+ in_channels=3+1,
1926
+ inter_channels=48,
1927
+ out_channels=96,
1928
+ act_layer='GELU',
1929
+ norm_layer='BN'):
1930
+ super().__init__()
1931
+ self.conv1 = nn.Conv2d(in_channels,
1932
+ inter_channels,
1933
+ kernel_size=3,
1934
+ stride=1,
1935
+ padding=1)
1936
+ self.norm1 = build_norm_layer(
1937
+ inter_channels, norm_layer, 'channels_first', 'channels_first'
1938
+ )
1939
+ self.act = build_act_layer(act_layer)
1940
+ self.conv2 = nn.Conv2d(inter_channels,
1941
+ out_channels,
1942
+ kernel_size=3,
1943
+ stride=1,
1944
+ padding=1)
1945
+ self.norm2 = build_norm_layer(
1946
+ out_channels, norm_layer, 'channels_first', 'channels_first'
1947
+ )
1948
+
1949
+ def forward(self, x):
1950
+ x = self.conv1(x)
1951
+ x = self.norm1(x)
1952
+ x = self.act(x)
1953
+ x = self.conv2(x)
1954
+ x = self.norm2(x)
1955
+ return x
1956
+
1957
+
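+
+ # Usage sketch (editor example): as configured in Refiner above, the stem maps
+ # a 4-channel (image + coarse mask) input back to 3 channels so a standard RGB
+ # backbone can consume it.
+ # >>> stem = StemLayer(in_channels=3 + 1, inter_channels=48, out_channels=3, norm_layer='BN')
+ # >>> stem(torch.randn(1, 4, 128, 128)).shape
+ # torch.Size([1, 3, 128, 128])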
1958
+ ### models/birefnet.py
1959
+
1960
+ import torch
1961
+ import torch.nn as nn
1962
+ import torch.nn.functional as F
1963
+ from kornia.filters import laplacian
1964
+ from transformers import PreTrainedModel
1965
+
1966
+ # from config import Config
1967
+ # from dataset import class_labels_TR_sorted
1968
+ # from models.build_backbone import build_backbone
1969
+ # from models.decoder_blocks import BasicDecBlk, ResBlk, HierarAttDecBlk
1970
+ # from models.lateral_blocks import BasicLatBlk
1971
+ # from models.aspp import ASPP, ASPPDeformable
1972
+ # from models.ing import *
1973
+ # from models.refiner import Refiner, RefinerPVTInChannels4, RefUNet
1974
+ # from models.stem_layer import StemLayer
1975
+ from .BiRefNet_config import BiRefNetConfig
1976
+
1977
+
1978
+ class BiRefNet(
1979
+ PreTrainedModel
1980
+ ):
1981
+ config_class = BiRefNetConfig
1982
+ def __init__(self, bb_pretrained=True, config=BiRefNetConfig()):
1983
+ super(BiRefNet, self).__init__(config)
1984
+ bb_pretrained = config.bb_pretrained
1985
+ self.config = Config()
1986
+ self.epoch = 1
1987
+ self.bb = build_backbone(self.config.bb, pretrained=bb_pretrained)
1988
+
1989
+ channels = self.config.lateral_channels_in_collection
1990
+
1991
+ if self.config.auxiliary_classification:
1992
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
1993
+ self.cls_head = nn.Sequential(
1994
+ nn.Linear(channels[0], len(class_labels_TR_sorted))
1995
+ )
1996
+
1997
+ if self.config.squeeze_block:
1998
+ self.squeeze_module = nn.Sequential(*[
1999
+ eval(self.config.squeeze_block.split('_x')[0])(channels[0]+sum(self.config.cxt), channels[0])
2000
+ for _ in range(eval(self.config.squeeze_block.split('_x')[1]))
2001
+ ])
2002
+
2003
+ self.decoder = Decoder(channels)
2004
+
2005
+ if self.config.ender:
2006
+ self.dec_end = nn.Sequential(
2007
+ nn.Conv2d(1, 16, 3, 1, 1),
2008
+ nn.Conv2d(16, 1, 3, 1, 1),
2009
+ nn.ReLU(inplace=True),
2010
+ )
2011
+
2012
+ # refine patch-level segmentation
2013
+ if self.config.refine:
2014
+ if self.config.refine == 'itself':
2015
+ self.stem_layer = StemLayer(in_channels=3+1, inter_channels=48, out_channels=3, norm_layer='BN' if self.config.batch_size > 1 else 'LN')
2016
+ else:
2017
+ self.refiner = eval('{}({})'.format(self.config.refine, 'in_channels=3+1'))
2018
+
2019
+ if self.config.freeze_bb:
2020
+ # Freeze the backbone...
2021
+ print(self.named_parameters())
2022
+ for key, value in self.named_parameters():
2023
+ if 'bb.' in key and 'refiner.' not in key:
2024
+ value.requires_grad = False
2025
+
2026
+ def forward_enc(self, x):
2027
+ if self.config.bb in ['vgg16', 'vgg16bn', 'resnet50']:
2028
+ x1 = self.bb.conv1(x); x2 = self.bb.conv2(x1); x3 = self.bb.conv3(x2); x4 = self.bb.conv4(x3)
2029
+ else:
2030
+ x1, x2, x3, x4 = self.bb(x)
2031
+ if self.config.mul_scl_ipt == 'cat':
2032
+ B, C, H, W = x.shape
2033
+ x1_, x2_, x3_, x4_ = self.bb(F.interpolate(x, size=(H//2, W//2), mode='bilinear', align_corners=True))
2034
+ x1 = torch.cat([x1, F.interpolate(x1_, size=x1.shape[2:], mode='bilinear', align_corners=True)], dim=1)
2035
+ x2 = torch.cat([x2, F.interpolate(x2_, size=x2.shape[2:], mode='bilinear', align_corners=True)], dim=1)
2036
+ x3 = torch.cat([x3, F.interpolate(x3_, size=x3.shape[2:], mode='bilinear', align_corners=True)], dim=1)
2037
+ x4 = torch.cat([x4, F.interpolate(x4_, size=x4.shape[2:], mode='bilinear', align_corners=True)], dim=1)
2038
+ elif self.config.mul_scl_ipt == 'add':
2039
+ B, C, H, W = x.shape
2040
+ x1_, x2_, x3_, x4_ = self.bb(F.interpolate(x, size=(H//2, W//2), mode='bilinear', align_corners=True))
2041
+ x1 = x1 + F.interpolate(x1_, size=x1.shape[2:], mode='bilinear', align_corners=True)
2042
+ x2 = x2 + F.interpolate(x2_, size=x2.shape[2:], mode='bilinear', align_corners=True)
2043
+ x3 = x3 + F.interpolate(x3_, size=x3.shape[2:], mode='bilinear', align_corners=True)
2044
+ x4 = x4 + F.interpolate(x4_, size=x4.shape[2:], mode='bilinear', align_corners=True)
2045
+ class_preds = self.cls_head(self.avgpool(x4).view(x4.shape[0], -1)) if self.training and self.config.auxiliary_classification else None
2046
+ if self.config.cxt:
2047
+ x4 = torch.cat(
2048
+ (
2049
+ *[
2050
+ F.interpolate(x1, size=x4.shape[2:], mode='bilinear', align_corners=True),
2051
+ F.interpolate(x2, size=x4.shape[2:], mode='bilinear', align_corners=True),
2052
+ F.interpolate(x3, size=x4.shape[2:], mode='bilinear', align_corners=True),
2053
+ ][-len(self.config.cxt):],
2054
+ x4
2055
+ ),
2056
+ dim=1
2057
+ )
2058
+ return (x1, x2, x3, x4), class_preds
2059
+
2060
+ def forward_ori(self, x):
2061
+ ########## Encoder ##########
2062
+ (x1, x2, x3, x4), class_preds = self.forward_enc(x)
2063
+ if self.config.squeeze_block:
2064
+ x4 = self.squeeze_module(x4)
2065
+ ########## Decoder ##########
2066
+ features = [x, x1, x2, x3, x4]
2067
+ if self.training and self.config.out_ref:
2068
+ features.append(laplacian(torch.mean(x, dim=1).unsqueeze(1), kernel_size=5))
2069
+ scaled_preds = self.decoder(features)
2070
+ return scaled_preds, class_preds
2071
+
2072
+ def forward(self, x):
2073
+ scaled_preds, class_preds = self.forward_ori(x)
2074
+ class_preds_lst = [class_preds]
2075
+ return [scaled_preds, class_preds_lst] if self.training else scaled_preds
2076
+
2077
+
2078
+ class Decoder(nn.Module):
2079
+ def __init__(self, channels):
2080
+ super(Decoder, self).__init__()
2081
+ self.config = Config()
2082
+ DecoderBlock = eval(self.config.dec_blk)
2083
+ LateralBlock = eval(self.config.lat_blk)
2084
+
2085
+ if self.config.dec_ipt:
2086
+ self.split = self.config.dec_ipt_split
2087
+ N_dec_ipt = 64
2088
+ DBlock = SimpleConvs
2089
+ ic = 64
2090
+ ipt_cha_opt = 1
2091
+ self.ipt_blk5 = DBlock(2**10*3 if self.split else 3, [N_dec_ipt, channels[0]//8][ipt_cha_opt], inter_channels=ic)
2092
+ self.ipt_blk4 = DBlock(2**8*3 if self.split else 3, [N_dec_ipt, channels[0]//8][ipt_cha_opt], inter_channels=ic)
2093
+ self.ipt_blk3 = DBlock(2**6*3 if self.split else 3, [N_dec_ipt, channels[1]//8][ipt_cha_opt], inter_channels=ic)
2094
+ self.ipt_blk2 = DBlock(2**4*3 if self.split else 3, [N_dec_ipt, channels[2]//8][ipt_cha_opt], inter_channels=ic)
2095
+ self.ipt_blk1 = DBlock(2**0*3 if self.split else 3, [N_dec_ipt, channels[3]//8][ipt_cha_opt], inter_channels=ic)
2096
+ else:
2097
+ self.split = None
2098
+
2099
+ self.decoder_block4 = DecoderBlock(channels[0]+([N_dec_ipt, channels[0]//8][ipt_cha_opt] if self.config.dec_ipt else 0), channels[1])
2100
+ self.decoder_block3 = DecoderBlock(channels[1]+([N_dec_ipt, channels[0]//8][ipt_cha_opt] if self.config.dec_ipt else 0), channels[2])
2101
+ self.decoder_block2 = DecoderBlock(channels[2]+([N_dec_ipt, channels[1]//8][ipt_cha_opt] if self.config.dec_ipt else 0), channels[3])
2102
+ self.decoder_block1 = DecoderBlock(channels[3]+([N_dec_ipt, channels[2]//8][ipt_cha_opt] if self.config.dec_ipt else 0), channels[3]//2)
2103
+ self.conv_out1 = nn.Sequential(nn.Conv2d(channels[3]//2+([N_dec_ipt, channels[3]//8][ipt_cha_opt] if self.config.dec_ipt else 0), 1, 1, 1, 0))
2104
+
2105
+ self.lateral_block4 = LateralBlock(channels[1], channels[1])
2106
+ self.lateral_block3 = LateralBlock(channels[2], channels[2])
2107
+ self.lateral_block2 = LateralBlock(channels[3], channels[3])
2108
+
2109
+ if self.config.ms_supervision:
2110
+ self.conv_ms_spvn_4 = nn.Conv2d(channels[1], 1, 1, 1, 0)
2111
+ self.conv_ms_spvn_3 = nn.Conv2d(channels[2], 1, 1, 1, 0)
2112
+ self.conv_ms_spvn_2 = nn.Conv2d(channels[3], 1, 1, 1, 0)
2113
+
2114
+ if self.config.out_ref:
2115
+ _N = 16
2116
+ self.gdt_convs_4 = nn.Sequential(nn.Conv2d(channels[1], _N, 3, 1, 1), nn.BatchNorm2d(_N) if self.config.batch_size > 1 else nn.Identity(), nn.ReLU(inplace=True))
2117
+ self.gdt_convs_3 = nn.Sequential(nn.Conv2d(channels[2], _N, 3, 1, 1), nn.BatchNorm2d(_N) if self.config.batch_size > 1 else nn.Identity(), nn.ReLU(inplace=True))
2118
+ self.gdt_convs_2 = nn.Sequential(nn.Conv2d(channels[3], _N, 3, 1, 1), nn.BatchNorm2d(_N) if self.config.batch_size > 1 else nn.Identity(), nn.ReLU(inplace=True))
2119
+
2120
+ self.gdt_convs_pred_4 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
2121
+ self.gdt_convs_pred_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
2122
+ self.gdt_convs_pred_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
2123
+
2124
+ self.gdt_convs_attn_4 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
2125
+ self.gdt_convs_attn_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
2126
+ self.gdt_convs_attn_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
2127
+
2128
+ def get_patches_batch(self, x, p):
2129
+ _size_h, _size_w = p.shape[2:]
2130
+ patches_batch = []
2131
+ for idx in range(x.shape[0]):
2132
+ columns_x = torch.split(x[idx], split_size_or_sections=_size_w, dim=-1)
2133
+ patches_x = []
2134
+ for column_x in columns_x:
2135
+ patches_x += [p.unsqueeze(0) for p in torch.split(column_x, split_size_or_sections=_size_h, dim=-2)]
2136
+ patch_sample = torch.cat(patches_x, dim=1)
2137
+ patches_batch.append(patch_sample)
2138
+ return torch.cat(patches_batch, dim=0)
2139
+
2140
+ def forward(self, features):
2141
+ if self.training and self.config.out_ref:
2142
+ outs_gdt_pred = []
2143
+ outs_gdt_label = []
2144
+ x, x1, x2, x3, x4, gdt_gt = features
2145
+ else:
2146
+ x, x1, x2, x3, x4 = features
2147
+ outs = []
2148
+
2149
+ if self.config.dec_ipt:
2150
+ patches_batch = self.get_patches_batch(x, x4) if self.split else x
2151
+ x4 = torch.cat((x4, self.ipt_blk5(F.interpolate(patches_batch, size=x4.shape[2:], mode='bilinear', align_corners=True))), 1)
2152
+ p4 = self.decoder_block4(x4)
2153
+ m4 = self.conv_ms_spvn_4(p4) if self.config.ms_supervision else None
2154
+ if self.config.out_ref:
2155
+ p4_gdt = self.gdt_convs_4(p4)
2156
+ if self.training:
2157
+ # >> GT:
2158
+ m4_dia = m4
2159
+ gdt_label_main_4 = gdt_gt * F.interpolate(m4_dia, size=gdt_gt.shape[2:], mode='bilinear', align_corners=True)
2160
+ outs_gdt_label.append(gdt_label_main_4)
2161
+ # >> Pred:
2162
+ gdt_pred_4 = self.gdt_convs_pred_4(p4_gdt)
2163
+ outs_gdt_pred.append(gdt_pred_4)
2164
+ gdt_attn_4 = self.gdt_convs_attn_4(p4_gdt).sigmoid()
2165
+ # >> Finally:
2166
+ p4 = p4 * gdt_attn_4
2167
+ _p4 = F.interpolate(p4, size=x3.shape[2:], mode='bilinear', align_corners=True)
2168
+ _p3 = _p4 + self.lateral_block4(x3)
2169
+
2170
+ if self.config.dec_ipt:
2171
+ patches_batch = self.get_patches_batch(x, _p3) if self.split else x
2172
+ _p3 = torch.cat((_p3, self.ipt_blk4(F.interpolate(patches_batch, size=x3.shape[2:], mode='bilinear', align_corners=True))), 1)
2173
+ p3 = self.decoder_block3(_p3)
2174
+ m3 = self.conv_ms_spvn_3(p3) if self.config.ms_supervision else None
2175
+ if self.config.out_ref:
2176
+ p3_gdt = self.gdt_convs_3(p3)
2177
+ if self.training:
2178
+ # >> GT:
2179
+ # m3 --dilation--> m3_dia
2180
+ # G_3^gt * m3_dia --> G_3^m, which is the label of gradient
2181
+ m3_dia = m3
2182
+ gdt_label_main_3 = gdt_gt * F.interpolate(m3_dia, size=gdt_gt.shape[2:], mode='bilinear', align_corners=True)
2183
+ outs_gdt_label.append(gdt_label_main_3)
2184
+ # >> Pred:
2185
+ # p3 --conv--BN--> F_3^G, where F_3^G predicts the \hat{G_3} with xx
2186
+ # F_3^G --sigmoid--> A_3^G
2187
+ gdt_pred_3 = self.gdt_convs_pred_3(p3_gdt)
2188
+ outs_gdt_pred.append(gdt_pred_3)
2189
+ gdt_attn_3 = self.gdt_convs_attn_3(p3_gdt).sigmoid()
2190
+ # >> Finally:
2191
+ # p3 = p3 * A_3^G
2192
+ p3 = p3 * gdt_attn_3
2193
+ _p3 = F.interpolate(p3, size=x2.shape[2:], mode='bilinear', align_corners=True)
2194
+ _p2 = _p3 + self.lateral_block3(x2)
2195
+
2196
+ if self.config.dec_ipt:
2197
+ patches_batch = self.get_patches_batch(x, _p2) if self.split else x
2198
+ _p2 = torch.cat((_p2, self.ipt_blk3(F.interpolate(patches_batch, size=x2.shape[2:], mode='bilinear', align_corners=True))), 1)
2199
+ p2 = self.decoder_block2(_p2)
2200
+ m2 = self.conv_ms_spvn_2(p2) if self.config.ms_supervision else None
2201
+ if self.config.out_ref:
2202
+ p2_gdt = self.gdt_convs_2(p2)
2203
+ if self.training:
2204
+ # >> GT:
2205
+ m2_dia = m2
2206
+ gdt_label_main_2 = gdt_gt * F.interpolate(m2_dia, size=gdt_gt.shape[2:], mode='bilinear', align_corners=True)
2207
+ outs_gdt_label.append(gdt_label_main_2)
2208
+ # >> Pred:
2209
+ gdt_pred_2 = self.gdt_convs_pred_2(p2_gdt)
2210
+ outs_gdt_pred.append(gdt_pred_2)
2211
+ gdt_attn_2 = self.gdt_convs_attn_2(p2_gdt).sigmoid()
2212
+ # >> Finally:
2213
+ p2 = p2 * gdt_attn_2
2214
+ _p2 = F.interpolate(p2, size=x1.shape[2:], mode='bilinear', align_corners=True)
2215
+ _p1 = _p2 + self.lateral_block2(x1)
2216
+
2217
+ if self.config.dec_ipt:
2218
+ patches_batch = self.get_patches_batch(x, _p1) if self.split else x
2219
+ _p1 = torch.cat((_p1, self.ipt_blk2(F.interpolate(patches_batch, size=x1.shape[2:], mode='bilinear', align_corners=True))), 1)
2220
+ _p1 = self.decoder_block1(_p1)
2221
+ _p1 = F.interpolate(_p1, size=x.shape[2:], mode='bilinear', align_corners=True)
2222
+
2223
+ if self.config.dec_ipt:
2224
+ patches_batch = self.get_patches_batch(x, _p1) if self.split else x
2225
+ _p1 = torch.cat((_p1, self.ipt_blk1(F.interpolate(patches_batch, size=x.shape[2:], mode='bilinear', align_corners=True))), 1)
2226
+ p1_out = self.conv_out1(_p1)
2227
+
2228
+ if self.config.ms_supervision:
2229
+ outs.append(m4)
2230
+ outs.append(m3)
2231
+ outs.append(m2)
2232
+ outs.append(p1_out)
2233
+ return outs if not (self.config.out_ref and self.training) else ([outs_gdt_pred, outs_gdt_label], outs)
2234
+
2235
+
2236
+ class SimpleConvs(nn.Module):
2237
+ def __init__(
2238
+ self, in_channels: int, out_channels: int, inter_channels=64
2239
+ ) -> None:
2240
+ super().__init__()
2241
+ self.conv1 = nn.Conv2d(in_channels, inter_channels, 3, 1, 1)
2242
+ self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, 1)
2243
+
2244
+ def forward(self, x):
2245
+ return self.conv_out(self.conv1(x))
Trellv2/briaai--RMBG-2.0/collage5.png ADDED

Git LFS Details

  • SHA256: f9f802564aa1e3a7c90762c7e65b77007f081cb179cdd9b42607bad3b1fdaf16
  • Pointer size: 132 Bytes
  • Size of remote file: 4.52 MB
Trellv2/briaai--RMBG-2.0/config.json ADDED
@@ -0,0 +1,20 @@
1
+ {
2
+ "_name_or_path": "ZhengPeng7/BiRefNet",
3
+ "architectures": [
4
+ "BiRefNet"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "BiRefNet_config.BiRefNetConfig",
8
+ "AutoModelForImageSegmentation": "birefnet.BiRefNet"
9
+ },
10
+ "custom_pipelines": {
11
+ "image-segmentation": {
12
+ "pt": [
13
+ "AutoModelForImageSegmentation"
14
+ ],
15
+ "tf": [],
16
+ "type": "image"
17
+ }
18
+ },
19
+ "bb_pretrained": false
20
+ }
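
The auto_map above exposes the custom BiRefNet classes through transformers, so this repository can be loaded directly. A minimal loading sketch (an assumption based on this config, not an official snippet; the repo id is inferred from the folder name):

    from transformers import AutoModelForImageSegmentation
    model = AutoModelForImageSegmentation.from_pretrained(
        "briaai/RMBG-2.0", trust_remote_code=True
    ).eval()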
Trellv2/briaai--RMBG-2.0/diagram1.png ADDED
Trellv2/briaai--RMBG-2.0/model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:566ed80c3d95f87ada6864d4cbe2290a1c5eb1c7bb0b123e984f60f76b02c3a7
3
+ size 884878856
Trellv2/briaai--RMBG-2.0/onnx/model_bnb4.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dadc9222fbffa53a348efea52d97475350ecee463a4a46f452e6e6b7b8757d25
3
+ size 355288046
Trellv2/briaai--RMBG-2.0/onnx/model_fp16.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9dc47db40d113090ba5d7a13d8fcfd9ee4eda510ce92613219b2fe19da4746f6
3
+ size 513576499
Trellv2/briaai--RMBG-2.0/onnx/model_int8.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f8ee7690d8c5e7fc45d7b4938ac2fe4eab63fdeddd537673cda2d4c6e74809af
+ size 366087445
Trellv2/briaai--RMBG-2.0/onnx/model_q4.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a813e0eab56c982b71254214f41fa860cc7b565a6f2aab55d1f99f41c646ece1
+ size 367451512
Trellv2/briaai--RMBG-2.0/onnx/model_q4f16.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8bfeb5f93220eb19f6747c217b62cf04342840c4e973f55bf64e9762919f446d
+ size 233815293
Trellv2/briaai--RMBG-2.0/onnx/model_quantized.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fcea23951a378f92634834888896cc1eec54655366ae6e949282646ce17c5420
+ size 366087549
Trellv2/briaai--RMBG-2.0/onnx/model_uint8.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fcea23951a378f92634834888896cc1eec54655366ae6e949282646ce17c5420
+ size 366087549
Trellv2/briaai--RMBG-2.0/preprocessor_config.json ADDED
@@ -0,0 +1,23 @@
+ {
+   "do_normalize": true,
+   "do_rescale": true,
+   "do_resize": true,
+   "feature_extractor_type": "ViTFeatureExtractor",
+   "image_mean": [
+     0.485,
+     0.456,
+     0.406
+   ],
+   "image_processor_type": "ViTFeatureExtractor",
+   "image_std": [
+     0.229,
+     0.224,
+     0.225
+   ],
+   "resample": 2,
+   "rescale_factor": 0.00392156862745098,
+   "size": {
+     "height": 1024,
+     "width": 1024
+   }
+ }
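The fields above map one-to-one onto a standard transform: resize to 1024×1024 with bilinear resampling (`"resample": 2`), rescale by 1/255, then normalize with the ImageNet mean/std. A minimal equivalent sketch using torchvision:

```python
# Torchvision equivalent of the preprocessor config above (sketch).
from torchvision import transforms

preprocess = transforms.Compose([
    transforms.Resize((1024, 1024)),  # size.height x size.width, bilinear by default
    transforms.ToTensor(),            # applies the 1/255 rescale_factor, channels-first
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
```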
Trellv2/briaai--RMBG-2.0/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0986c2881028a2d0ef9b638ab06bc4cfe7c529760d451eaa7098ade2592015f2
+ size 885079136
Trellv2/briaai--RMBG-2.0/t4.png ADDED

Git LFS Details

  • SHA256: 43a9453f567d9bff7fe4481205575bbf302499379047ee6073247315452ba8fb
  • Pointer size: 132 Bytes
  • Size of remote file: 2.16 MB
Trellv2/facebook--dinov3-vitl16-pretrain-lvd1689m/.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
Trellv2/facebook--dinov3-vitl16-pretrain-lvd1689m/LICENSE.md ADDED
@@ -0,0 +1,66 @@
+ # DINOv3 License
+
+ *Last Updated: August 19, 2025*
+
+ **“Agreement”** means the terms and conditions for use, reproduction, distribution and modification of the DINO Materials set forth herein.
+
+ **“DINO Materials”** means, collectively, Documentation and the models, software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code, and other elements of the foregoing distributed by Meta and made available under this Agreement.
+
+ **“Documentation”** means the specifications, manuals and documentation accompanying
+ DINO Materials distributed by Meta.
+
+ **“Licensee”** or **“you”** means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf.
+
+ **“Meta”** or **“we”** means Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) or Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland).
+
+ **“Sanctions”** means any economic or trade sanctions or restrictions administered or enforced by the United States (including the Office of Foreign Assets Control of the U.S. Department of the Treasury (“OFAC”), the U.S. Department of State and the U.S. Department of Commerce), the United Nations, the European Union, or the United Kingdom.
+
+ **“Trade Controls”** means any of the following: Sanctions and applicable export and import controls.
+
+ By clicking “I Accept” below or by using or distributing any portion or element of the DINO Materials, you agree to be bound by this Agreement.
+
+ ## 1. License Rights and Redistribution.
+
+ a. <ins>Grant of Rights</ins>. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta’s intellectual property or other rights owned by Meta embodied in the DINO Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the DINO Materials.
+
+ b. <ins>Redistribution and Use</ins>.
+
+ i. Distribution of DINO Materials, and any derivative works thereof, are subject to the terms of this Agreement. If you distribute or make the DINO Materials, or any derivative works thereof, available to a third party, you may only do so under the terms of this Agreement and you shall provide a copy of this Agreement with any such DINO Materials.
+
+ ii. If you submit for publication the results of research you perform on, using, or otherwise in connection with DINO Materials, you must acknowledge the use of DINO Materials in your publication.
+
+ iii. Your use of the DINO Materials must comply with applicable laws and regulations, including Trade Control Laws and applicable privacy and data protection laws.
+
+ iv. Your use of the DINO Materials will not involve or encourage others to reverse engineer, decompile or discover the underlying components of the DINO Materials.
+
+ v. You are not the target of Trade Controls and your use of DINO Materials must comply with Trade Controls. You agree not to use, or permit others to use, DINO Materials for any activities subject to the International Traffic in Arms Regulations (ITAR) or end uses prohibited by Trade Controls, including those related to military or warfare purposes, nuclear industries or applications, espionage, or the development or use of guns or illegal weapons.
+
+ ## 2. User Support.
+
+ Your use of the DINO Materials is done at your own discretion; Meta does not process any information nor provide any service in relation to such use. Meta is under no obligation to provide any support services for the DINO Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind.
+
+ ## 3. Disclaimer of Warranty.
+
+ UNLESS REQUIRED BY APPLICABLE LAW, THE DINO MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE DINO MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE DINO MATERIALS AND ANY OUTPUT AND RESULTS.
+
+ ## 4. Limitation of Liability.
+
+ IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT OR INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
+
+ ## 5. Intellectual Property.
+
+ a. Subject to Meta’s ownership of DINO Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the DINO Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications.
+
+ b. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the DINO Materials, outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the DINO Materials.
+
+ ## 6. Term and Termination.
+
+ The term of this Agreement will commence upon your acceptance of this Agreement or access to the DINO Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the DINO Materials. Sections 3, 4 and 7 shall survive the termination of this Agreement.
+
+ ## 7. Governing Law and Jurisdiction.
+
+ This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement.
+
+ ## 8. Modifications and Amendments.
+
+ Meta may modify this Agreement from time to time; provided that they are similar in spirit to the current version of the Agreement, but may differ in detail to address new problems or concerns. All such changes will be effective immediately. Your continued use of the DINO Materials after any modification to this Agreement constitutes your agreement to such modification. Except as provided in this Agreement, no modification or addition to any provision of this Agreement will be binding unless it is in writing and signed by an authorized representative of both you and Meta.
Trellv2/facebook--dinov3-vitl16-pretrain-lvd1689m/README.md ADDED
@@ -0,0 +1,477 @@
+ ---
+ extra_gated_fields:
+   First Name: text
+   Last Name: text
+   Date of birth: date_picker
+   Country: country
+   Affiliation: text
+   Job title:
+     type: select
+     options:
+       - Student
+       - Research Graduate
+       - AI researcher
+       - AI developer/engineer
+       - Reporter
+       - Other
+   geo: ip_location
+   By clicking Submit below I accept the terms of the license and acknowledge that the information I provide will be collected stored processed and shared in accordance with the Meta Privacy Policy: checkbox
+ extra_gated_description: >-
+   The information you provide will be collected, stored, processed and shared in
+   accordance with the [Meta Privacy
+   Policy](https://www.facebook.com/privacy/policy/).
+ extra_gated_button_content: Submit
+ language:
+ - en
+ tags:
+ - dino
+ - dinov3
+ - arxiv:2508.10104
+ license: other
+ license_name: dinov3-license
+ license_link: https://ai.meta.com/resources/models-and-libraries/dinov3-license
+ base_model: facebook/dinov3-vit7b16-pretrain-lvd1689m
+ pipeline_tag: image-feature-extraction
+ library_name: transformers
+ ---
+
+ # Model Card for DINOv3
+
+ DINOv3 is a family of versatile vision foundation models that outperforms the specialized state of the art across a broad range of settings, without fine-tuning. DINOv3 produces high-quality dense features that achieve outstanding performance on various vision tasks, significantly surpassing previous self- and weakly-supervised foundation models.
+
+ ## Model Details
+
+ These are Vision Transformer and ConvNeXt models trained following the method described in the DINOv3 paper. 12 models are provided:
+
+ - 10 models pretrained on web data (LVD-1689M dataset)
+   - 1 ViT-7B trained from scratch,
+   - 5 ViT-S/S+/B/L/H+ models distilled from the ViT-7B,
+   - 4 ConvNeXt-{T/S/B/L} models distilled from the ViT-7B,
+ - 2 models pretrained on satellite data (SAT-493M dataset)
+   - 1 ViT-7B trained from scratch
+   - 1 ViT-L distilled from the ViT-7B
+
+
+ Each Transformer-based model takes an image as input and returns a class token, patch tokens (and register tokens). These models follow a ViT architecture, with a patch size of 16. For a 224x224 image, this results in 1 class token + 4 register tokens + 196 patch tokens = 201 tokens (for DINOv2 with registers this resulted in 1 + 4 + 256 = 261 tokens).
+
+ The models can accept larger images provided the image dimensions are multiples of the patch size (16). If this condition is not met, the model crops to the closest smaller multiple of the patch size.
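Note: the token count generalizes directly from the patch size. A small sketch of the arithmetic (1 class token + 4 register tokens + (H/16)·(W/16) patch tokens):

```python
# Token-count arithmetic for DINOv3 ViT models (patch size 16, 4 register tokens).
def num_tokens(height: int, width: int, patch: int = 16, registers: int = 4) -> int:
    assert height % patch == 0 and width % patch == 0, "dims must be multiples of the patch size"
    return 1 + registers + (height // patch) * (width // patch)

print(num_tokens(224, 224))  # 1 + 4 + 196 = 201
print(num_tokens(512, 512))  # 1 + 4 + 1024 = 1029
```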
+
+ ### Model Description
+
+ - **Developed by:** Meta AI
+ - **Model type:** Vision Transformer, ConvNeXt
+ - **License:** [DINOv3 License](https://ai.meta.com/resources/models-and-libraries/dinov3-license/)
+
+ ### Model Sources
+
+ - **Repository:** [https://github.com/facebookresearch/dinov3](https://github.com/facebookresearch/dinov3)
+ - **Paper:** [https://arxiv.org/abs/2508.10104](https://arxiv.org/abs/2508.10104)
+
+ ## Uses
+
+ The models are vision backbones providing multi-purpose features for downstream tasks.
+
+ ### Direct Use
+
+ The models can be used without fine-tuning, with downstream classifiers as simple as linear layers, to obtain competitive results:
+
+ - on image classification, using k-NN classifiers on the class token
+ - on image classification, with logistic regression classifiers applied on the class token
+ - on image classification, with a linear layer applied on the class token and the average of the patch tokens
+ - on image retrieval using nearest neighbors
+ - on geometric and semantic 3D keypoint correspondences
+ - on depth estimation, semantic segmentation, using linear layers
+ - on unsupervised object discovery
+ - on video segmentation tracking
+ - on video classification, using a small 4-layer attentive probe
+
+ ### Downstream Use
+
+ While fine-tuning the models can yield some gains, it is recommended to keep this option as a last resort: the frozen features are expected to provide good performance out-of-the-box.
+
+ ## Bias, Risks, and Limitations
+
+ Compared to DINOv2 and SEERv2, DINOv3 delivers relatively consistent performance across income categories on geographical fairness and diversity, although with a notable performance drop in the low-income bucket compared to the highest-income bucket.
+
+ DINOv3 also achieves relatively good scores across different regions, improving over its predecessor DINOv2. However, a relative difference is still observed between Europe and Africa.
+
+ ### Recommendations
+
+ Fine-tuning is expected to increase the biases in the features produced by the model, as they will be tuned to the fine-tuning labels.
+
+ ## How to Get Started with the Model
+
+ The examples below demonstrate how to obtain an image embedding with the [Pipeline] or the [AutoModel] class.
+
+ ```python
+ from transformers import pipeline
+ from transformers.image_utils import load_image
+
+ url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
+ image = load_image(url)
+
+ feature_extractor = pipeline(
+     model="facebook/dinov3-vitl16-pretrain-lvd1689m",
+     task="image-feature-extraction",
+ )
+ features = feature_extractor(image)
+ ```
+
+ ```python
+ import torch
+ from transformers import AutoImageProcessor, AutoModel
+ from transformers.image_utils import load_image
+
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ image = load_image(url)
+
+ pretrained_model_name = "facebook/dinov3-vitl16-pretrain-lvd1689m"
+ processor = AutoImageProcessor.from_pretrained(pretrained_model_name)
+ model = AutoModel.from_pretrained(
+     pretrained_model_name,
+     device_map="auto",
+ )
+
+ inputs = processor(images=image, return_tensors="pt").to(model.device)
+ with torch.inference_mode():
+     outputs = model(**inputs)
+
+ pooled_output = outputs.pooler_output
+ print("Pooled output shape:", pooled_output.shape)
+ ```
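Note: beyond `pooler_output`, the dense features live in `last_hidden_state`, laid out as described earlier in the card (class token, then register tokens, then patch tokens). A minimal sketch continuing the snippet above, assuming the 4-register layout:

```python
# Split last_hidden_state into its token groups (sketch; layout per the model card).
num_register_tokens = 4
hidden = outputs.last_hidden_state                  # (batch, 1 + 4 + num_patches, hidden_size)
cls_token = hidden[:, 0]                            # global image embedding
patch_tokens = hidden[:, 1 + num_register_tokens:]  # dense features, one per 16x16 patch
print("Patch tokens shape:", patch_tokens.shape)    # (1, 196, 1024) for a 224x224 input
```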
+
+ ## Training Details
+
+ ### Training Data
+
+ - Web dataset (LVD-1689M): a curated dataset of 1,689 million images extracted from a large data
+ pool of 17 billion web images collected from public posts on Instagram
+
+ - Satellite dataset (SAT-493M): a dataset of 493 million 512x512 images sampled randomly from Maxar RGB ortho-rectified imagery at 0.6 meter resolution
+
+ ### Training Procedure
+
+ **Training objective:**
+
+ - DINO self-distillation loss with multi-crop
+ - iBOT masked-image modeling loss
+ - KoLeo regularization on [CLS] tokens
+ - Gram anchoring
+
+ - **Training regime:** PyTorch FSDP2 (with bf16 and fp8 matrix multiplications)
+
+ **Distillation:**
+
+ - Distillation follows the standard DINOv3 pretraining procedure, except the teacher is a frozen pretrained ViT-7B.
+
+ ## Evaluation
+
+ **Results**
+
+ The reader is referred to the associated paper for details on the evaluation protocols.
+
+ *Results for ViT backbones pretrained (or distilled) on web (LVD-1689M)*
+
+ <table>
+   <tr>
+     <th></th>
+     <!-- <th></th> -->
+     <th colspan="4">Global Tasks</th>
+     <th colspan="5">Dense Tasks</th>
+   </tr>
+   <tr>
+     <th>Model</th>
+     <!-- <th>Dataset</th> -->
+     <th>IN-ReaL</th>
+     <th>IN-R</th>
+     <th>Obj.Net</th>
+     <th>Ox.-H</th>
+     <th>ADE20k</th>
+     <th>NYU↓</th>
+     <th>DAVIS</th>
+     <th>NAVI</th>
+     <th>SPair</th>
+   </tr>
+   <tr>
+     <td>DINOv3 ViT-S/16</td>
+     <!-- <td>LVD-1689M</td> -->
+     <td align="right">87.0</td>
+     <td align="right">60.4</td>
+     <td align="right">50.9</td>
+     <td align="right">49.5</td>
+     <td align="right">47.0</td>
+     <td align="right">0.403</td>
+     <td align="right">72.7</td>
+     <td align="right">56.3</td>
+     <td align="right">50.4</td>
+   </tr>
+   <tr>
+     <td>DINOv3 ViT-S+/16</td>
+     <!-- <td>LVD-1689M</td> -->
+     <td align="right">88.0</td>
+     <td align="right">68.8</td>
+     <td align="right">54.6</td>
+     <td align="right">50.0</td>
+     <td align="right">48.8</td>
+     <td align="right">0.399</td>
+     <td align="right">75.5</td>
+     <td align="right">57.1</td>
+     <td align="right">55.2</td>
+   </tr>
+   <tr>
+     <td>DINOv3 ViT-B/16</td>
+     <!-- <td>LVD-1689M</td> -->
+     <td align="right">89.3</td>
+     <td align="right">76.7</td>
+     <td align="right">64.1</td>
+     <td align="right">58.5</td>
+     <td align="right">51.8</td>
+     <td align="right">0.373</td>
+     <td align="right">77.2</td>
+     <td align="right">58.8</td>
+     <td align="right">57.2</td>
+   </tr>
+   <tr>
+     <td>DINOv3 ViT-L/16</td>
+     <!-- <td>LVD-1689M</td> -->
+     <td align="right">90.2</td>
+     <td align="right">88.1</td>
+     <td align="right">74.8</td>
+     <td align="right">63.1</td>
+     <td align="right">54.9</td>
+     <td align="right">0.352</td>
+     <td align="right">79.9</td>
+     <td align="right">62.3</td>
+     <td align="right">61.3</td>
+   </tr>
+   <tr>
+     <td>DINOv3 ViT-H+/16</td>
+     <!-- <td>LVD-1689M</td> -->
+     <td align="right">90.3</td>
+     <td align="right">90.0</td>
+     <td align="right">78.6</td>
+     <td align="right">64.5</td>
+     <td align="right">54.8</td>
+     <td align="right">0.352</td>
+     <td align="right">79.3</td>
+     <td align="right">63.3</td>
+     <td align="right">56.3</td>
+   </tr>
+   <tr>
+     <td>DINOv3 ViT-7B/16</td>
+     <!-- <td>LVD-1689M</td> -->
+     <td align="right">90.4</td>
+     <td align="right">91.1</td>
+     <td align="right">91.1</td>
+     <td align="right">72.8</td>
+     <td align="right">55.9</td>
+     <td align="right">0.309</td>
+     <td align="right">79.7</td>
+     <td align="right">64.4</td>
+     <td align="right">58.7</td>
+   </tr>
+ </table>
+
+ *Results for ConvNeXt backbones distilled on web (LVD-1689M)*
+
+ <table>
+   <tr>
+     <th></th>
+     <th colspan="6">Global Tasks</th>
+     <th colspan="2">Dense Tasks</th>
+   </tr>
+   <tr>
+     <th>Model</th>
+     <th colspan="2">IN-ReaL</th>
+     <th colspan="2">IN-R</th>
+     <th colspan="2">Obj.Net</th>
+     <th>ADE20k</th>
+     <th>NYU↓</th>
+   </tr>
+   <tr>
+     <td></td>
+     <td>@256px</td>
+     <td>@512px</td>
+     <td>@256px</td>
+     <td>@512px</td>
+     <td>@256px</td>
+     <td>@512px</td>
+     <td colspan="2"></td>
+   </tr>
+   <tr>
+     <td>DINOv3 ConvNeXt Tiny</td>
+     <td align="right">86.6</td>
+     <td align="right">87.7</td>
+     <td align="right">73.7</td>
+     <td align="right">74.1</td>
+     <td align="right">52.6</td>
+     <td align="right">58.7</td>
+     <td align="right">42.7</td>
+     <td align="right">0.448</td>
+   </tr>
+   <tr>
+     <td>DINOv3 ConvNeXt Small</td>
+     <td align="right">87.9</td>
+     <td align="right">88.7</td>
+     <td align="right">73.7</td>
+     <td align="right">74.1</td>
+     <td align="right">52.6</td>
+     <td align="right">58.7</td>
+     <td align="right">44.8</td>
+     <td align="right">0.432</td>
+   </tr>
+   <tr>
+     <td>DINOv3 ConvNeXt Base</td>
+     <td align="right">88.5</td>
+     <td align="right">89.2</td>
+     <td align="right">77.2</td>
+     <td align="right">78.2</td>
+     <td align="right">56.2</td>
+     <td align="right">61.3</td>
+     <td align="right">46.3</td>
+     <td align="right">0.420</td>
+   </tr>
+   <tr>
+     <td>DINOv3 ConvNeXt Large</td>
+     <td align="right">88.9</td>
+     <td align="right">89.4</td>
+     <td align="right">81.3</td>
+     <td align="right">82.4</td>
+     <td align="right">59.3</td>
+     <td align="right">65.2</td>
+     <td align="right">47.8</td>
+     <td align="right">0.403</td>
+   </tr>
+ </table>
+
+ *Results for ViT backbones pretrained (or distilled) on satellite (SAT-493M)*
+
+ <table>
+   <tr>
+     <th></th>
+     <th colspan="7">(GEO-Bench) Classification</th>
+   </tr>
+   <tr>
+     <th>Model</th>
+     <th>m-BEnet</th>
+     <th>m-brick-kiln</th>
+     <th>m-eurosat</th>
+     <th>m-forestnet</th>
+     <th>m-pv4ger</th>
+     <th>m-so2sat</th>
+     <th>mean</th>
+   </tr>
+   <tr>
+     <td>DINOv3 ViT-L/16</td>
+     <td>73.0</td>
+     <td>96.5</td>
+     <td>94.1</td>
+     <td>60.6</td>
+     <td>96.0</td>
+     <td>57.4</td>
+     <td>79.6</td>
+   </tr>
+   <tr>
+     <td>DINOv3 ViT-7B/16</td>
+     <td>74.0</td>
+     <td>97.2</td>
+     <td>94.8</td>
+     <td>62.3</td>
+     <td>96.1</td>
+     <td>62.1</td>
+     <td>81.1</td>
+   </tr>
+   <tr>
+     <th></th>
+     <th colspan="7">(GEO-Bench) Segmentation</th>
+   </tr>
+   <tr>
+     <th>Model</th>
+     <th>m-cashew</th>
+     <th>m-chesapeake</th>
+     <th>m-NeonTree</th>
+     <th>m-nz-cattle</th>
+     <th>m-pv4ger-seg</th>
+     <th>m-SA-crop</th>
+     <th>mean</th>
+   </tr>
+   <tr>
+     <td>DINOv3 ViT-L/16</td>
+     <td>94.2</td>
+     <td>75.6</td>
+     <td>61.8</td>
+     <td>83.7</td>
+     <td>95.2</td>
+     <td>36.8</td>
+     <td>74.5</td>
+   </tr>
+   <tr>
+     <td>DINOv3 ViT-7B/16</td>
+     <td>94.1</td>
+     <td>76.6</td>
+     <td>62.6</td>
+     <td>83.4</td>
+     <td>95.5</td>
+     <td>37.6</td>
+     <td>75.0</td>
+   </tr>
+ </table>
+
+
+ ## Environmental Impact
+
+ - **Hardware Type:** Nvidia H100
+ - **Hours used:** 61,440 hours for ViT-7B model training
+ - **Cloud Provider:** Private infrastructure
+ - **Compute Region:** USA
+ - **Carbon Emitted:** 18t CO2eq
+
+ ## Technical Specifications
+
+ ### Model Architecture and Objective
+
+ Vision Transformer models:
+
+ - ViT-S (21M parameters): patch size 16, embedding dimension 384, 4 register tokens, 6 heads, MLP FFN, RoPE
+ - ViT-S+ (29M parameters): patch size 16, embedding dimension 384, 4 register tokens, 6 heads, SwiGLU FFN, RoPE
+ - ViT-B (86M parameters): patch size 16, embedding dimension 768, 4 register tokens, 12 heads, MLP FFN, RoPE
+ - ViT-L (300M parameters): patch size 16, embedding dimension 1024, 4 register tokens, 16 heads, MLP FFN, RoPE
+ - ViT-H+ (840M parameters): patch size 16, embedding dimension 1280, 4 register tokens, 20 heads, SwiGLU FFN, RoPE
+ - ViT-7B (6716M parameters): patch size 16, embedding dimension 4096, 4 register tokens, 32 heads, SwiGLU FFN, RoPE
+
+ ConvNeXt models:
+
+ - ConvNeXt Tiny (29M parameters)
+ - ConvNeXt Small (50M parameters)
+ - ConvNeXt Base (89M parameters)
+ - ConvNeXt Large (198M parameters)
+
+ ### Compute Infrastructure
+
+ #### Hardware
+
+ Nvidia H100 GPUs
+
+ #### Software
+
+ PyTorch 2.7
+
+ ## More Information
+
+ See the [blog post](https://ai.meta.com/blog/dinov3-self-supervised-vision-model/) and the associated [website](https://ai.meta.com/dinov3/).
+
+ ## Citation
+
+ **BibTeX**
+
+ ```
+ @misc{simeoni2025dinov3,
+   title={{DINOv3}},
+   author={Sim{\'e}oni, Oriane and Vo, Huy V. and Seitzer, Maximilian and Baldassarre, Federico and Oquab, Maxime and Jose, Cijo and Khalidov, Vasil and Szafraniec, Marc and Yi, Seungeun and Ramamonjisoa, Micha{\"e}l and Massa, Francisco and Haziza, Daniel and Wehrstedt, Luca and Wang, Jianyuan and Darcet, Timoth{\'e}e and Moutakanni, Th{\'e}o and Sentana, Leonel and Roberts, Claire and Vedaldi, Andrea and Tolan, Jamie and Brandt, John and Couprie, Camille and Mairal, Julien and J{\'e}gou, Herv{\'e} and Labatut, Patrick and Bojanowski, Piotr},
+   year={2025},
+   eprint={2508.10104},
+   archivePrefix={arXiv},
+   primaryClass={cs.CV},
+   url={https://arxiv.org/abs/2508.10104},
+ }
+ ```
Trellv2/facebook--dinov3-vitl16-pretrain-lvd1689m/config.json ADDED
@@ -0,0 +1,32 @@
+ {
+   "architectures": [
+     "DINOv3ViTModel"
+   ],
+   "attention_dropout": 0.0,
+   "drop_path_rate": 0.0,
+   "hidden_act": "gelu",
+   "hidden_size": 1024,
+   "image_size": 224,
+   "initializer_range": 0.02,
+   "intermediate_size": 4096,
+   "key_bias": false,
+   "layer_norm_eps": 1e-05,
+   "layerscale_value": 1.0,
+   "mlp_bias": true,
+   "model_type": "dinov3_vit",
+   "num_attention_heads": 16,
+   "num_channels": 3,
+   "num_hidden_layers": 24,
+   "num_register_tokens": 4,
+   "patch_size": 16,
+   "pos_embed_jitter": null,
+   "pos_embed_rescale": 2.0,
+   "pos_embed_shift": null,
+   "proj_bias": true,
+   "query_bias": true,
+   "rope_theta": 100.0,
+   "torch_dtype": "float32",
+   "transformers_version": "4.56.0.dev0",
+   "use_gated_mlp": false,
+   "value_bias": true
+ }
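Note: two quantities implicit in this config are worth spelling out when reading it against the model card. A quick arithmetic sketch using the values above:

```python
# Derived quantities from the ViT-L/16 config above (plain arithmetic only).
hidden_size = 1024
num_attention_heads = 16
intermediate_size = 4096

head_dim = hidden_size // num_attention_heads  # 64 per attention head
mlp_ratio = intermediate_size / hidden_size    # 4.0, consistent with a standard (non-gated) MLP FFN
print(head_dim, mlp_ratio)
```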
Trellv2/facebook--dinov3-vitl16-pretrain-lvd1689m/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dcb2e45127cccbf1601e5f42fef165eea275c8e5213197e8dcf3f48822718179
+ size 1212559808
Trellv2/facebook--dinov3-vitl16-pretrain-lvd1689m/preprocessor_config.json ADDED
@@ -0,0 +1,31 @@
+ {
+   "crop_size": null,
+   "data_format": "channels_first",
+   "default_to_square": true,
+   "device": null,
+   "disable_grouping": null,
+   "do_center_crop": null,
+   "do_convert_rgb": null,
+   "do_normalize": true,
+   "do_rescale": true,
+   "do_resize": true,
+   "image_mean": [
+     0.485,
+     0.456,
+     0.406
+   ],
+   "image_processor_type": "DINOv3ViTImageProcessorFast",
+   "image_std": [
+     0.229,
+     0.224,
+     0.225
+   ],
+   "input_data_format": null,
+   "resample": 2,
+   "rescale_factor": 0.00392156862745098,
+   "return_tensors": null,
+   "size": {
+     "height": 224,
+     "width": 224
+   }
+ }
Trellv2/microsoft--TRELLIS.2-4B/.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
Trellv2/microsoft--TRELLIS.2-4B/README.md ADDED
@@ -0,0 +1,139 @@
+ ---
+ license: mit
+ pipeline_tag: image-to-3d
+ library_name: trellis2
+ language:
+ - en
+ ---
+
+ # TRELLIS.2: Native and Compact Structured Latents for 3D Generation
+
+ **Model Name:** TRELLIS.2-4B
+
+ **Paper:** [https://arxiv.org/abs/2512.14692](https://arxiv.org/abs/2512.14692)
+
+ **Repository:** [https://github.com/microsoft/TRELLIS.2](https://github.com/microsoft/TRELLIS.2)
+
+ **Project Page:** [https://microsoft.github.io/trellis.2](https://microsoft.github.io/trellis.2)
+
+ ## Introduction
+
+ **TRELLIS.2** is a state-of-the-art large 3D generative model designed for high-fidelity **image-to-3D** generation. It leverages a novel "field-free" sparse voxel structure termed **O-Voxel** and a large-scale flow-matching transformer (4 billion parameters).
+
+ Unlike previous methods that rely on iso-surface fields (e.g., SDF, Flexicubes), which struggle with open surfaces or non-manifold geometry, TRELLIS.2 can reconstruct and generate **arbitrary 3D assets** with complex topologies, sharp features, and full Physically Based Rendering (PBR) materials, including transparency and translucency.
+
+ ## Model Details
+
+ * **Developed by:** Jianfeng Xiang, Xiaoxue Chen, Sicheng Xu, Ruicheng Wang, Zelong Lv, Yu Deng, Hongyuan Zhu, Yue Dong, Hao Zhao, Nicholas Jing Yuan, Jiaolong Yang
+ * **Model Type:** Flow-Matching Transformers with a Sparse-Voxel-based 3D VAE
+ * **Parameters:** 4 Billion
+ * **Input:** Single Image
+ * **Output:** 3D Asset (Mesh with PBR Materials)
+ * **Resolution:** Varies from 512³ to 1536³ (Voxel Grid Resolution)
+
+ ## Key Features
+
+ * **O-Voxel Representation:** An omni-voxel structure that encodes both geometry and appearance. It supports:
+   * **Arbitrary Topology:** Handles open surfaces, non-manifold geometry, and fully-enclosed structures without lossy conversion.
+   * **Rich Appearance:** Captures PBR attributes (including opacity for translucent surfaces) aligned with geometry.
+   * **Efficiency:** Instant, optimization-free bidirectional conversion between meshes and O-Voxels (milliseconds to seconds).
+ * **High-Resolution Generation:** The model is trained to generate fully textured assets at **up to 1536³ resolution**.
+ * **High-Fidelity yet Compact Latent Space:** Utilizes a Sparse 3D VAE with **16× spatial downsampling**, encoding a 1024³ asset into only ~9.6K latent tokens with negligible perceptual degradation (see the arithmetic sketch after this list).
+ * **Shape-Conditioned Texture Generation:** Generates textures for input 3D meshes from reference images.
+ * **State-of-the-Art Speed:** Inference is highly efficient; see the table below.
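To make the latent-space bullet concrete, here is the back-of-the-envelope arithmetic behind it (a sketch; the ~9.6K figure reflects sparsity, since only voxels near the surface are occupied):

```python
# Latent-grid arithmetic for the 16x downsampling claim (illustrative only).
asset_resolution = 1024                               # dense voxel resolution of the asset
downsampling = 16                                     # spatial downsampling of the sparse 3D VAE
latent_resolution = asset_resolution // downsampling  # 64: latents live on a 64^3 grid
dense_cells = latent_resolution ** 3                  # 262,144 cells if the grid were dense
sparse_tokens = 9_600                                 # ~9.6K active latent tokens for a typical asset
print(f"occupancy ~ {sparse_tokens / dense_cells:.1%}")  # ~3.7% of the dense grid
```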
+
+ ## Inference Speed (NVIDIA H100 GPU)
+
+ | Resolution | Time |
+ | :--- | :--- |
+ | 512³ | ~3 seconds |
+ | 1024³ | ~17 seconds |
+ | 1536³ | ~60 seconds |
+
+ ## Requirements
+
+ - **System**: The model is currently tested only on **Linux**.
+ - **Hardware**: An NVIDIA GPU with at least 24GB of memory is necessary. The code has been verified on NVIDIA A100 and H100 GPUs.
+ - **Software**:
+   - The [CUDA Toolkit](https://developer.nvidia.com/cuda-toolkit-archive) is needed to compile certain packages; version 12.4 is recommended.
+   - [Conda](https://docs.anaconda.com/miniconda/install/#quick-command-line-install) is recommended for managing dependencies.
+   - Python version 3.8 or higher is required.
+
+ ## Known Limitations
+
+ * **Geometric Artifacts (Small Holes):** While O-Voxels handle complex topology well, the generated raw meshes may occasionally contain small holes or minor topological discontinuities. For applications requiring strictly watertight geometry (e.g., 3D printing), we provide accompanying mesh post-processing scripts, such as hole-filling algorithms.
+ * **Base Model w/o Alignment:** TRELLIS.2-4B is a pre-trained foundation model. It has **not** been aligned with human preferences (e.g., via RLHF) or fine-tuned for specific aesthetic standards. Consequently, the outputs reflect the distribution of the training data and may vary in style; users may need to experiment with inputs to achieve the desired artistic result.
+
+ We are actively working on improving the model and addressing these limitations.
+
+ ## Usage
+
+ *Note: Please refer to the official [GitHub Repository](https://github.com/microsoft/TRELLIS.2) for installation instructions and dependencies.*
+
+ ```python
+ import os
+ os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"  # Can save GPU memory
+ import cv2
+ import imageio
+ from PIL import Image
+ import torch
+ from trellis2.pipelines import Trellis2ImageTo3DPipeline
+ from trellis2.utils import render_utils
+ from trellis2.renderers import EnvMap
+ import o_voxel
+
+ # 1. Set up the environment map
+ envmap = EnvMap(torch.tensor(
+     cv2.cvtColor(cv2.imread('assets/hdri/forest.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
+     dtype=torch.float32, device='cuda'
+ ))
+
+ # 2. Load the pipeline
+ pipeline = Trellis2ImageTo3DPipeline.from_pretrained("microsoft/TRELLIS.2-4B")
+ pipeline.cuda()
+
+ # 3. Load the image & run
+ image = Image.open("assets/example_image/T.png")
+ mesh = pipeline.run(image)[0]
+ mesh.simplify(16777216)  # nvdiffrast limit
+
+ # 4. Render a video
+ video = render_utils.make_pbr_vis_frames(render_utils.render_video(mesh, envmap=envmap))
+ imageio.mimsave("sample.mp4", video, fps=15)
+
+ # 5. Export to GLB
+ glb = o_voxel.postprocess.to_glb(
+     vertices=mesh.vertices,
+     faces=mesh.faces,
+     attr_volume=mesh.attrs,
+     coords=mesh.coords,
+     attr_layout=mesh.layout,
+     voxel_size=mesh.voxel_size,
+     aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
+     decimation_target=1000000,
+     texture_size=4096,
+     remesh=True,
+     remesh_band=1,
+     remesh_project=0,
+     verbose=True
+ )
+ glb.export("sample.glb", extension_webp=True)
+ ```
+
+ ## Citation
+
+ If you find this model useful for your research, please cite our work:
+
+ ```
+ @article{
+     xiang2025trellis2,
+     title={Native and Compact Structured Latents for 3D Generation},
+     author={Xiang, Jianfeng and Chen, Xiaoxue and Xu, Sicheng and Wang, Ruicheng and Lv, Zelong and Deng, Yu and Zhu, Hongyuan and Dong, Yue and Zhao, Hao and Yuan, Nicholas Jing and Yang, Jiaolong},
+     journal={Tech report},
+     year={2025}
+ }
+ ```
+
+ ## License
+
+ This model is released under the MIT License. The code and dataset are publicly released to facilitate reproduction and further research.
Trellv2/microsoft--TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16.json ADDED
@@ -0,0 +1,24 @@
+ {
+     "name": "FlexiDualGridVaeDecoder",
+     "args": {
+         "resolution": 256,
+         "model_channels": [1024, 512, 256, 128, 64],
+         "latent_channels": 32,
+         "num_blocks": [4, 16, 8, 4, 0],
+         "block_type": [
+             "SparseConvNeXtBlock3d",
+             "SparseConvNeXtBlock3d",
+             "SparseConvNeXtBlock3d",
+             "SparseConvNeXtBlock3d",
+             "SparseConvNeXtBlock3d"
+         ],
+         "up_block_type": [
+             "SparseResBlockC2S3d",
+             "SparseResBlockC2S3d",
+             "SparseResBlockC2S3d",
+             "SparseResBlockC2S3d"
+         ],
+         "block_args": [{}, {}, {}, {}, {}],
+         "use_fp16": true
+     }
+ }
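Note: the file name appears to encode the architecture: the four `SparseResBlockC2S3d` up-sampling stages each double the spatial resolution, giving the 16× factor (`f16`), and `c32` matches `latent_channels`. A one-liner sanity check:

```python
# Sanity check of the "f16c32" naming against the decoder config above.
up_blocks = 4                    # four SparseResBlockC2S3d stages
spatial_factor = 2 ** up_blocks  # each stage doubles resolution -> 16
latent_channels = 32
print(f"f{spatial_factor}c{latent_channels}")  # -> f16c32
```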
Trellv2/microsoft--TRELLIS.2-4B/ckpts/shape_enc_next_dc_f16c32_fp16.json ADDED
@@ -0,0 +1,23 @@
+ {
+     "name": "FlexiDualGridVaeEncoder",
+     "args": {
+         "model_channels": [64, 128, 256, 512, 1024],
+         "latent_channels": 32,
+         "num_blocks": [0, 4, 8, 16, 4],
+         "block_type": [
+             "SparseConvNeXtBlock3d",
+             "SparseConvNeXtBlock3d",
+             "SparseConvNeXtBlock3d",
+             "SparseConvNeXtBlock3d",
+             "SparseConvNeXtBlock3d"
+         ],
+         "down_block_type": [
+             "SparseResBlockS2C3d",
+             "SparseResBlockS2C3d",
+             "SparseResBlockS2C3d",
+             "SparseResBlockS2C3d"
+         ],
+         "block_args": [{}, {}, {}, {}, {}],
+         "use_fp16": true
+     }
+ }
Trellv2/microsoft--TRELLIS.2-4B/ckpts/shape_enc_next_dc_f16c32_fp16.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f37c5ff5b983b68e9946060000f09bc131f3e84318a2c8b7430a81e4b4636c41
+ size 708797208
Trellv2/microsoft--TRELLIS.2-4B/ckpts/slat_flow_img2shape_dit_1_3B_1024_bf16.json ADDED
@@ -0,0 +1,19 @@
+ {
+     "name": "SLatFlowModel",
+     "args": {
+         "resolution": 64,
+         "in_channels": 32,
+         "out_channels": 32,
+         "model_channels": 1536,
+         "cond_channels": 1024,
+         "num_blocks": 30,
+         "num_heads": 12,
+         "mlp_ratio": 5.3334,
+         "pe_mode": "rope",
+         "share_mod": true,
+         "initialization": "scaled",
+         "qk_rms_norm": true,
+         "qk_rms_norm_cross": true,
+         "dtype": "bfloat16"
+     }
+ }
Trellv2/microsoft--TRELLIS.2-4B/ckpts/slat_flow_img2shape_dit_1_3B_1024_bf16.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:07cd0596f634c5adc1890023d16023afc5eed02fb84b22bb23aff5bf0030fbbd
+ size 2584574424
Trellv2/microsoft--TRELLIS.2-4B/ckpts/slat_flow_img2shape_dit_1_3B_512_bf16.json ADDED
@@ -0,0 +1,19 @@
+ {
+     "name": "SLatFlowModel",
+     "args": {
+         "resolution": 32,
+         "in_channels": 32,
+         "out_channels": 32,
+         "model_channels": 1536,
+         "cond_channels": 1024,
+         "num_blocks": 30,
+         "num_heads": 12,
+         "mlp_ratio": 5.3334,
+         "pe_mode": "rope",
+         "share_mod": true,
+         "initialization": "scaled",
+         "qk_rms_norm": true,
+         "qk_rms_norm_cross": true,
+         "dtype": "bfloat16"
+     }
+ }
Trellv2/microsoft--TRELLIS.2-4B/ckpts/slat_flow_imgshape2tex_dit_1_3B_1024_bf16.json ADDED
@@ -0,0 +1,19 @@
+ {
+     "name": "SLatFlowModel",
+     "args": {
+         "resolution": 64,
+         "in_channels": 64,
+         "out_channels": 32,
+         "model_channels": 1536,
+         "cond_channels": 1024,
+         "num_blocks": 30,
+         "num_heads": 12,
+         "mlp_ratio": 5.3334,
+         "pe_mode": "rope",
+         "share_mod": true,
+         "initialization": "scaled",
+         "qk_rms_norm": true,
+         "qk_rms_norm_cross": true,
+         "dtype": "bfloat16"
+     }
+ }
Trellv2/microsoft--TRELLIS.2-4B/ckpts/slat_flow_imgshape2tex_dit_1_3B_512_bf16.json ADDED
@@ -0,0 +1,19 @@
+ {
+     "name": "SLatFlowModel",
+     "args": {
+         "resolution": 32,
+         "in_channels": 64,
+         "out_channels": 32,
+         "model_channels": 1536,
+         "cond_channels": 1024,
+         "num_blocks": 30,
+         "num_heads": 12,
+         "mlp_ratio": 5.3334,
+         "pe_mode": "rope",
+         "share_mod": true,
+         "initialization": "scaled",
+         "qk_rms_norm": true,
+         "qk_rms_norm_cross": true,
+         "dtype": "bfloat16"
+     }
+ }
Trellv2/microsoft--TRELLIS.2-4B/ckpts/slat_flow_imgshape2tex_dit_1_3B_512_bf16.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8371aa1c5d13be79dcd5ddfd2cf3835e902e204dc34427169a1c702828e1a94d
+ size 2584672728
Trellv2/microsoft--TRELLIS.2-4B/ckpts/ss_flow_img_dit_1_3B_64_bf16.json ADDED
@@ -0,0 +1,19 @@
+ {
+     "name": "SparseStructureFlowModel",
+     "args": {
+         "resolution": 16,
+         "in_channels": 8,
+         "out_channels": 8,
+         "model_channels": 1536,
+         "cond_channels": 1024,
+         "num_blocks": 30,
+         "num_heads": 12,
+         "mlp_ratio": 5.3334,
+         "pe_mode": "rope",
+         "share_mod": true,
+         "initialization": "scaled",
+         "qk_rms_norm": true,
+         "qk_rms_norm_cross": true,
+         "dtype": "bfloat16"
+     }
+ }
Trellv2/microsoft--TRELLIS.2-4B/ckpts/ss_flow_img_dit_1_3B_64_bf16.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ca01377c485bec418076d38ee80166d32dc776d744f2553b835cba1e97a7abf6
+ size 2584426920
Trellv2/microsoft--TRELLIS.2-4B/ckpts/tex_dec_next_dc_f16c32_fp16.json ADDED
@@ -0,0 +1,25 @@
+ {
+     "name": "SparseUnetVaeDecoder",
+     "args": {
+         "out_channels": 6,
+         "model_channels": [1024, 512, 256, 128, 64],
+         "latent_channels": 32,
+         "num_blocks": [4, 16, 8, 4, 0],
+         "block_type": [
+             "SparseConvNeXtBlock3d",
+             "SparseConvNeXtBlock3d",
+             "SparseConvNeXtBlock3d",
+             "SparseConvNeXtBlock3d",
+             "SparseConvNeXtBlock3d"
+         ],
+         "up_block_type": [
+             "SparseResBlockC2S3d",
+             "SparseResBlockC2S3d",
+             "SparseResBlockC2S3d",
+             "SparseResBlockC2S3d"
+         ],
+         "block_args": [{}, {}, {}, {}, {}],
+         "pred_subdiv": false,
+         "use_fp16": true
+     }
+ }
Trellv2/microsoft--TRELLIS.2-4B/ckpts/tex_enc_next_dc_f16c32_fp16.json ADDED
@@ -0,0 +1,24 @@
+ {
+     "name": "SparseUnetVaeEncoder",
+     "args": {
+         "in_channels": 6,
+         "model_channels": [64, 128, 256, 512, 1024],
+         "latent_channels": 32,
+         "num_blocks": [0, 4, 8, 16, 4],
+         "block_type": [
+             "SparseConvNeXtBlock3d",
+             "SparseConvNeXtBlock3d",
+             "SparseConvNeXtBlock3d",
+             "SparseConvNeXtBlock3d",
+             "SparseConvNeXtBlock3d"
+         ],
+         "down_block_type": [
+             "SparseResBlockS2C3d",
+             "SparseResBlockS2C3d",
+             "SparseResBlockS2C3d",
+             "SparseResBlockS2C3d"
+         ],
+         "block_args": [{}, {}, {}, {}, {}],
+         "use_fp16": true
+     }
+ }
Trellv2/microsoft--TRELLIS.2-4B/pipeline.json ADDED
@@ -0,0 +1,95 @@
+ {
+     "name": "Trellis2ImageTo3DPipeline",
+     "args": {
+         "models": {
+             "sparse_structure_decoder": "microsoft/TRELLIS-image-large/ckpts/ss_dec_conv3d_16l8_fp16",
+             "sparse_structure_flow_model": "ckpts/ss_flow_img_dit_1_3B_64_bf16",
+             "shape_slat_decoder": "ckpts/shape_dec_next_dc_f16c32_fp16",
+             "shape_slat_flow_model_512": "ckpts/slat_flow_img2shape_dit_1_3B_512_bf16",
+             "shape_slat_flow_model_1024": "ckpts/slat_flow_img2shape_dit_1_3B_1024_bf16",
+             "tex_slat_decoder": "ckpts/tex_dec_next_dc_f16c32_fp16",
+             "tex_slat_flow_model_512": "ckpts/slat_flow_imgshape2tex_dit_1_3B_512_bf16",
+             "tex_slat_flow_model_1024": "ckpts/slat_flow_imgshape2tex_dit_1_3B_1024_bf16"
+         },
+         "sparse_structure_sampler": {
+             "name": "FlowEulerGuidanceIntervalSampler",
+             "args": {
+                 "sigma_min": 1e-5
+             },
+             "params": {
+                 "steps": 12,
+                 "guidance_strength": 7.5,
+                 "guidance_rescale": 0.7,
+                 "guidance_interval": [0.6, 1.0],
+                 "rescale_t": 5.0
+             }
+         },
+         "shape_slat_sampler": {
+             "name": "FlowEulerGuidanceIntervalSampler",
+             "args": {
+                 "sigma_min": 1e-5
+             },
+             "params": {
+                 "steps": 12,
+                 "guidance_strength": 7.5,
+                 "guidance_rescale": 0.5,
+                 "guidance_interval": [0.6, 1.0],
+                 "rescale_t": 3.0
+             }
+         },
+         "shape_slat_normalization": {
+             "mean": [
+                 0.781296, 0.018091, -0.495192, -0.558457, 1.060530, 0.093252, 1.518149, -0.933218,
+                 -0.732996, 2.604095, -0.118341, -2.143904, 0.495076, -2.179512, -2.130751, -0.996944,
+                 0.261421, -2.217463, 1.260067, -0.150213, 3.790713, 1.481266, -1.046058, -1.523667,
+                 -0.059621, 2.220780, 1.621212, 0.877230, 0.567247, -3.175944, -3.186688, 1.578665
+             ],
+             "std": [
+                 5.972266, 4.706852, 5.445010, 5.209927, 5.320220, 4.547237, 5.020802, 5.444004,
+                 5.226681, 5.683095, 4.831436, 5.286469, 5.652043, 5.367606, 5.525084, 4.730578,
+                 4.805265, 5.124013, 5.530808, 5.619001, 5.103930, 5.417670, 5.269677, 5.547194,
+                 5.634698, 5.235274, 6.110351, 5.511298, 6.237273, 4.879207, 5.347008, 5.405691
+             ]
+         },
+         "tex_slat_sampler": {
+             "name": "FlowEulerGuidanceIntervalSampler",
+             "args": {
+                 "sigma_min": 1e-5
+             },
+             "params": {
+                 "steps": 12,
+                 "guidance_strength": 1.0,
+                 "guidance_rescale": 0.0,
+                 "guidance_interval": [0.6, 0.9],
+                 "rescale_t": 3.0
+             }
+         },
+         "tex_slat_normalization": {
+             "mean": [
+                 3.501659, 2.212398, 2.226094, 0.251093, -0.026248, -0.687364, 0.439898, -0.928075,
+                 0.029398, -0.339596, -0.869527, 1.038479, -0.972385, 0.126042, -1.129303, 0.455149,
+                 -1.209521, 2.069067, 0.544735, 2.569128, -0.323407, 2.293000, -1.925608, -1.217717,
+                 1.213905, 0.971588, -0.023631, 0.106750, 2.021786, 0.250524, -0.662387, -0.768862
+             ],
+             "std": [
+                 2.665652, 2.743913, 2.765121, 2.595319, 3.037293, 2.291316, 2.144656, 2.911822,
+                 2.969419, 2.501689, 2.154811, 3.163343, 2.621215, 2.381943, 3.186697, 3.021588,
+                 2.295916, 3.234985, 3.233086, 2.260140, 2.874801, 2.810596, 3.292720, 2.674999,
+                 2.680878, 2.372054, 2.451546, 2.353556, 2.995195, 2.379849, 2.786195, 2.775190
+             ]
+         },
+         "image_cond_model": {
+             "name": "DinoV3FeatureExtractor",
+             "args": {
+                 "model_name": "facebook/dinov3-vitl16-pretrain-lvd1689m"
+             }
+         },
+         "rembg_model": {
+             "name": "BiRefNet",
+             "args": {
+                 "model_name": "briaai/RMBG-2.0"
+             }
+         },
+         "default_pipeline_type": "1024_cascade"
+     }
+ }
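Note: the `*_normalization` blocks carry per-channel statistics for the 32-channel structured latents; a reasonable reading (an assumption, not the pipeline's verbatim code) is that encoder latents are standardized with them before flow sampling and de-standardized before decoding:

```python
# Hypothetical use of the shape_slat_normalization stats above (sketch only).
import json
import torch

cfg = json.load(open("pipeline.json"))["args"]["shape_slat_normalization"]
mean = torch.tensor(cfg["mean"])  # (32,)
std = torch.tensor(cfg["std"])    # (32,)

def normalize(feats: torch.Tensor) -> torch.Tensor:
    """Standardize per-voxel latent features of shape (num_tokens, 32)."""
    return (feats - mean) / std

def denormalize(feats: torch.Tensor) -> torch.Tensor:
    """Invert the standardization before handing latents to the VAE decoder."""
    return feats * std + mean
```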