YuanGao-YG committed on
Commit
484df85
·
verified ·
1 Parent(s): d8fe602

Upload 19 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ img/dynamic_prompting.jpg filter=lfs diff=lfs merge=lfs -text
37
+ img/introduction_benchmark.jpg filter=lfs diff=lfs merge=lfs -text
38
+ img/results.jpg filter=lfs diff=lfs merge=lfs -text
39
+ img/vision_main.jpg filter=lfs diff=lfs merge=lfs -text
checkpoint_VISION/best_mse.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bf069f9e3e6fc97c00084b264f1762705c2de0e137c8232233a4db381efa9e8b
3
+ size 55591502
checkpoint_VISION/readme.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ The complete project is available on Hugging Face.
data/KD48_demo.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b23de2cfd4c4eb2f6a3dfeb3c811f957444752e07b1d0cd1d791cf7d53be2b23
3
+ size 471861248
data/mean.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:309903f1e30d424cbccae45933b646110fe6f52f162fe056531d81b827ac3ecf
3
+ size 200
data/readme.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ The complete project is available on Hugging Face.
data/std.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5795b455977d2622434d87d161d053ceef568a2d94dbdbe22699e756b8dbf59c
3
+ size 200
environment.yml ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: vision
2
+ channels:
3
+ - pytorch
4
+ - dglteam/label/th24_cu118
5
+ - nvidia
6
+ - defaults
7
+ dependencies:
8
+ - _libgcc_mutex=0.1=main
9
+ - _openmp_mutex=5.1=1_gnu
10
+ - blas=1.0=mkl
11
+ - brotli-python=1.0.9=py310h6a678d5_8
12
+ - bzip2=1.0.8=h5eee18b_6
13
+ - c-ares=1.34.5=hef5626c_0
14
+ - ca-certificates=2025.9.9=h06a4308_0
15
+ - certifi=2025.8.3=py310h06a4308_0
16
+ - charset-normalizer=3.3.2=pyhd3eb1b0_0
17
+ - click=8.2.1=py310h06a4308_0
18
+ - cloudpickle=3.1.1=py310h06a4308_0
19
+ - cuda-cudart=11.8.89=0
20
+ - cuda-cupti=11.8.87=0
21
+ - cuda-libraries=11.8.0=0
22
+ - cuda-nvrtc=11.8.89=0
23
+ - cuda-nvtx=11.8.86=0
24
+ - cuda-runtime=11.8.0=0
25
+ - cyrus-sasl=2.1.28=h1110e0f_3
26
+ - cytoolz=1.0.1=py310h5eee18b_0
27
+ - dask-core=2025.7.0=py310h06a4308_0
28
+ - dgl=2.4.0.th24.cu118=py310_0
29
+ - expat=2.7.1=h6a678d5_0
30
+ - ffmpeg=4.3=hf484d3e_0
31
+ - filelock=3.13.1=py310h06a4308_0
32
+ - fontconfig=2.14.1=h55d465d_3
33
+ - freetype=2.12.1=h4a9f257_0
34
+ - gmp=6.2.1=h295c915_3
35
+ - gmpy2=2.1.2=py310heeb90bb_0
36
+ - gnutls=3.6.15=he1e5248_0
37
+ - icu=73.1=h6a678d5_0
38
+ - idna=3.7=py310h06a4308_0
39
+ - imageio=2.37.0=py310h06a4308_0
40
+ - importlib-metadata=8.5.0=py310h06a4308_0
41
+ - intel-openmp=2023.1.0=hdb19cb5_46306
42
+ - jinja2=3.1.4=py310h06a4308_0
43
+ - jpeg=9e=h5eee18b_3
44
+ - krb5=1.20.1=h143b758_1
45
+ - lame=3.100=h7b6447c_0
46
+ - lcms2=2.12=h3be6417_0
47
+ - ld_impl_linux-64=2.38=h1181459_1
48
+ - lerc=3.0=h295c915_0
49
+ - libabseil=20250127.0=cxx17_h6a678d5_0
50
+ - libcublas=11.11.3.6=0
51
+ - libcufft=10.9.0.58=0
52
+ - libcufile=1.9.1.3=0
53
+ - libcups=2.4.2=h252cb56_2
54
+ - libcurand=10.3.5.147=0
55
+ - libcurl=8.12.1=hc9e6f67_0
56
+ - libcusolver=11.4.1.48=0
57
+ - libcusparse=11.7.5.86=0
58
+ - libdeflate=1.17=h5eee18b_1
59
+ - libedit=3.1.20230828=h5eee18b_0
60
+ - libev=4.33=h7f8727e_1
61
+ - libevent=2.1.12=hdbd6064_1
62
+ - libffi=3.4.4=h6a678d5_1
63
+ - libgcc-ng=11.2.0=h1234567_1
64
+ - libgfortran-ng=11.2.0=h00389a5_1
65
+ - libgfortran5=11.2.0=h1234567_1
66
+ - libgl=1.7.0=h5eee18b_2
67
+ - libglib=2.84.2=h37c7471_0
68
+ - libglvnd=1.7.0=h5eee18b_2
69
+ - libglx=1.7.0=h5eee18b_2
70
+ - libgomp=11.2.0=h1234567_1
71
+ - libiconv=1.16=h5eee18b_3
72
+ - libidn2=2.3.4=h5eee18b_0
73
+ - libjpeg-turbo=2.0.0=h9bf148f_0
74
+ - libkrb5=1.21.3=h520c7b4_4
75
+ - libnghttp2=1.57.0=h2d74bed_0
76
+ - libnpp=11.8.0.86=0
77
+ - libnvjpeg=11.9.0.86=0
78
+ - libpng=1.6.39=h5eee18b_0
79
+ - libpq=17.4=h02b6914_2
80
+ - libprotobuf=5.29.3=h3cdef7c_1
81
+ - libsodium=1.0.20=heac8642_0
82
+ - libssh2=1.11.1=h251f7ec_0
83
+ - libstdcxx-ng=11.2.0=h1234567_1
84
+ - libtasn1=4.19.0=h5eee18b_0
85
+ - libtiff=4.5.1=h6a678d5_0
86
+ - libunistring=0.9.10=h27cfd23_0
87
+ - libuuid=1.41.5=h5eee18b_0
88
+ - libwebp-base=1.3.2=h5eee18b_0
89
+ - libxcb=1.17.0=h9b100fa_0
90
+ - libxkbcommon=1.0.1=h097e994_2
91
+ - libxml2=2.13.5=hfdd30dd_0
92
+ - llvm-openmp=14.0.6=h9e868ea_0
93
+ - lmdb=0.9.31=hb25bd0a_0
94
+ - locket=1.0.0=py310h06a4308_0
95
+ - lz4-c=1.9.4=h6a678d5_1
96
+ - markupsafe=2.1.3=py310h5eee18b_0
97
+ - matplotlib-base=3.10.0=py310hbfdbfaf_0
98
+ - mkl=2023.1.0=h213fc3f_46344
99
+ - mkl-service=2.4.0=py310h5eee18b_1
100
+ - mkl_fft=1.3.10=py310h5eee18b_0
101
+ - mkl_random=1.2.7=py310h1128e8f_0
102
+ - mpc=1.1.0=h10f8cd9_1
103
+ - mpfr=4.0.2=hb69a4c5_1
104
+ - mpmath=1.3.0=py310h06a4308_0
105
+ - mysql=8.4.0=h721767e_2
106
+ - ncurses=6.4=h6a678d5_0
107
+ - nettle=3.7.3=hbbd107a_1
108
+ - networkx=3.3=py310h06a4308_0
109
+ - numpy=1.26.4=py310h5f9d8c6_0
110
+ - numpy-base=1.26.4=py310hb5e798b_0
111
+ - openh264=2.1.1=h4ff587b_0
112
+ - openjpeg=2.5.2=he7f1fd0_0
113
+ - openldap=2.6.10=he9288cc_1
114
+ - openssl=3.0.17=h5eee18b_0
115
+ - partd=1.4.2=py310h06a4308_0
116
+ - pcre2=10.42=hebb0a14_1
117
+ - pillow=10.4.0=py310h5eee18b_0
118
+ - pip=24.2=py310h06a4308_0
119
+ - psutil=5.9.0=py310h5eee18b_0
120
+ - pthread-stubs=0.3=h0ce48e5_1
121
+ - pybind11-abi=4=hd3eb1b0_1
122
+ - pyparsing=3.2.0=py310h06a4308_0
123
+ - pyqt=6.7.1=py310h8dad735_2
124
+ - pyqt6-sip=13.9.1=py310h6ce4db3_2
125
+ - pysocks=1.7.1=py310h06a4308_0
126
+ - python=3.10.15=he870216_1
127
+ - python-dateutil=2.9.0post0=py310h06a4308_2
128
+ - pytorch=2.4.0=py3.10_cuda11.8_cudnn9.1.0_0
129
+ - pytorch-cuda=11.8=h7e8668a_5
130
+ - pytorch-mutex=1.0=cuda
131
+ - pywavelets=1.8.0=py310h5eee18b_0
132
+ - pyyaml=6.0.1=py310h5eee18b_0
133
+ - qtbase=6.7.3=hdaa5aa8_0
134
+ - qtdeclarative=6.7.3=h7934f7d_1
135
+ - qtsvg=6.7.3=he4bddd1_1
136
+ - qttools=6.7.3=h80c7b02_0
137
+ - qtwebchannel=6.7.3=h7934f7d_1
138
+ - qtwebsockets=6.7.3=h7934f7d_1
139
+ - readline=8.2=h5eee18b_0
140
+ - requests=2.32.3=py310h06a4308_0
141
+ - scipy=1.13.1=py310h5f9d8c6_0
142
+ - setuptools=75.1.0=py310h06a4308_0
143
+ - sip=6.10.0=py310h6a678d5_0
144
+ - sqlite=3.45.3=h5eee18b_0
145
+ - sympy=1.13.2=py310h06a4308_0
146
+ - tbb=2021.8.0=hdb19cb5_0
147
+ - tifffile=2025.2.18=py310h06a4308_0
148
+ - tk=8.6.14=h39e8969_0
149
+ - toolz=1.0.0=py310h06a4308_0
150
+ - torchaudio=2.4.0=py310_cu118
151
+ - torchtriton=3.0.0=py310
152
+ - torchvision=0.19.0=py310_cu118
153
+ - tqdm=4.66.5=py310h2f386ee_0
154
+ - typing_extensions=4.11.0=py310h06a4308_0
155
+ - unicodedata2=15.1.0=py310h5eee18b_1
156
+ - urllib3=2.2.3=py310h06a4308_0
157
+ - wheel=0.44.0=py310h06a4308_0
158
+ - xcb-util=0.4.1=h5eee18b_2
159
+ - xcb-util-cursor=0.1.5=h5eee18b_0
160
+ - xcb-util-image=0.4.0=h5eee18b_2
161
+ - xcb-util-renderutil=0.3.10=h5eee18b_0
162
+ - xorg-libx11=1.8.12=h9b100fa_1
163
+ - xorg-libxau=1.0.12=h9b100fa_0
164
+ - xorg-libxdmcp=1.1.5=h9b100fa_0
165
+ - xorg-xorgproto=2024.1=h5eee18b_1
166
+ - xz=5.4.6=h5eee18b_1
167
+ - yaml=0.2.5=h7b6447c_0
168
+ - zlib=1.2.13=h5eee18b_1
169
+ - zstd=1.5.5=hc292b87_2
170
+ - pip:
171
+ - aiobotocore==2.15.1
172
+ - aiohappyeyeballs==2.4.3
173
+ - aiohttp==3.10.8
174
+ - aioitertools==0.12.0
175
+ - aiosignal==1.3.1
176
+ - anyio==4.6.0
177
+ - argon2-cffi==23.1.0
178
+ - argon2-cffi-bindings==21.2.0
179
+ - arrow==1.3.0
180
+ - asttokens==2.4.1
181
+ - async-lru==2.0.4
182
+ - async-timeout==4.0.3
183
+ - attrs==24.2.0
184
+ - babel==2.16.0
185
+ - beautifulsoup4==4.12.3
186
+ - bleach==6.1.0
187
+ - blessed==1.20.0
188
+ - botocore==1.35.23
189
+ - cartopy==0.24.1
190
+ - cffi==1.17.1
191
+ - cftime==1.6.4.post1
192
+ - cmocean==4.0.3
193
+ - colorama==0.4.6
194
+ - comm==0.2.2
195
+ - contourpy==1.3.0
196
+ - cycler==0.12.1
197
+ - debugpy==1.8.6
198
+ - decorator==5.1.1
199
+ - defusedxml==0.7.1
200
+ - einops==0.8.0
201
+ - exceptiongroup==1.2.2
202
+ - executing==2.1.0
203
+ - fastjsonschema==2.20.0
204
+ - fonttools==4.54.1
205
+ - fqdn==1.5.1
206
+ - frozenlist==1.4.1
207
+ - fsspec==2024.9.0
208
+ - gpustat==1.1.1
209
+ - gsw==3.6.20
210
+ - h11==0.14.0
211
+ - h5netcdf==1.4.0
212
+ - h5py==3.12.1
213
+ - httpcore==1.0.6
214
+ - httpx==0.27.2
215
+ - huggingface-hub==0.25.1
216
+ - icecream==2.1.3
217
+ - ipykernel==6.29.5
218
+ - ipython==8.28.0
219
+ - isoduration==20.11.0
220
+ - jedi==0.19.1
221
+ - jmespath==1.0.1
222
+ - joblib==1.4.2
223
+ - json5==0.9.25
224
+ - jsonpointer==3.0.0
225
+ - jsonschema==4.23.0
226
+ - jsonschema-specifications==2024.10.1
227
+ - jupyter-client==8.6.3
228
+ - jupyter-core==5.7.2
229
+ - jupyter-events==0.10.0
230
+ - jupyter-lsp==2.2.5
231
+ - jupyter-server==2.14.2
232
+ - jupyter-server-terminals==0.5.3
233
+ - jupyterlab==4.2.5
234
+ - jupyterlab-pygments==0.3.0
235
+ - jupyterlab-server==2.27.3
236
+ - kiwisolver==1.4.7
237
+ - lazy-loader==0.4
238
+ - matplotlib==3.9.2
239
+ - matplotlib-inline==0.1.7
240
+ - mistune==3.0.2
241
+ - multidict==6.1.0
242
+ - nbclient==0.10.0
243
+ - nbconvert==7.16.4
244
+ - nbformat==5.10.4
245
+ - nest-asyncio==1.6.0
246
+ - netcdf4==1.7.2
247
+ - notebook==7.2.2
248
+ - notebook-shim==0.2.4
249
+ - nvfuser-cu118-torch24==0.2.9.dev20240808
250
+ - nvidia-cuda-cupti-cu11==11.8.87
251
+ - nvidia-cuda-nvrtc-cu11==11.8.89
252
+ - nvidia-cuda-runtime-cu11==11.8.89
253
+ - nvidia-ml-py==12.560.30
254
+ - nvidia-nvtx-cu11==11.8.86
255
+ - overrides==7.7.0
256
+ - packaging==24.1
257
+ - pandas==2.2.3
258
+ - pandocfilters==1.5.1
259
+ - parso==0.8.4
260
+ - pexpect==4.9.0
261
+ - platformdirs==4.3.6
262
+ - prometheus-client==0.21.0
263
+ - prompt-toolkit==3.0.48
264
+ - ptyprocess==0.7.0
265
+ - pure-eval==0.2.3
266
+ - pycparser==2.22
267
+ - pygments==2.18.0
268
+ - pyproj==3.7.0
269
+ - pyshp==2.3.1
270
+ - python-json-logger==2.0.7
271
+ - pytz==2024.2
272
+ - pyzmq==26.2.0
273
+ - referencing==0.35.1
274
+ - rfc3339-validator==0.1.4
275
+ - rfc3986-validator==0.1.1
276
+ - rpds-py==0.20.0
277
+ - ruamel-yaml==0.18.6
278
+ - ruamel-yaml-clib==0.2.8
279
+ - s3fs==2024.9.0
280
+ - safetensors==0.4.5
281
+ - scikit-image==0.25.2
282
+ - scikit-learn==1.5.2
283
+ - send2trash==1.8.3
284
+ - shapely==2.0.6
285
+ - six==1.16.0
286
+ - sniffio==1.3.1
287
+ - soupsieve==2.6
288
+ - stack-data==0.6.3
289
+ - terminado==0.18.1
290
+ - thop==0.1.1-2209072238
291
+ - threadpoolctl==3.5.0
292
+ - timm==1.0.9
293
+ - tinycss2==1.3.0
294
+ - tomli==2.0.2
295
+ - torchsummary==1.5.1
296
+ - tornado==6.4.1
297
+ - traitlets==5.14.3
298
+ - treelib==1.7.0
299
+ - types-python-dateutil==2.9.0.20241003
300
+ - tzdata==2024.2
301
+ - uri-template==1.3.0
302
+ - wcwidth==0.2.13
303
+ - webcolors==24.8.0
304
+ - webencodings==0.5.1
305
+ - websocket-client==1.8.0
306
+ - wrapt==1.16.0
307
+ - xarray==2024.9.0
308
+ - yarl==1.13.1
309
+ - zipp==3.20.2
310
+ prefix: /miniconda3/envs/vision
img/.DS_Store ADDED
Binary file (6.15 kB). View file
 
img/dynamic_prompting.jpg ADDED

Git LFS Details

  • SHA256: 840657c36051aaac4391f3b909a4a382443f92c3a831ff80fc367b6e8f33f7c9
  • Pointer size: 131 Bytes
  • Size of remote file: 462 kB
img/introduction_benchmark.jpg ADDED

Git LFS Details

  • SHA256: 3397240f4007bfe93b81c9b585133578f0bc2ae4881889ead06991ff628f89c1
  • Pointer size: 132 Bytes
  • Size of remote file: 3.76 MB
img/results.jpg ADDED

Git LFS Details

  • SHA256: 34b92aae3b5bd4971d01e5ad8aaf0ccdca283828e48d30c2fa7e1078875323eb
  • Pointer size: 131 Bytes
  • Size of remote file: 697 kB
img/vision_main.jpg ADDED

Git LFS Details

  • SHA256: e9be3c7a4aac02d6d43119553569c3d04c5ec4bf2c97b3a36ba74cff87d34987
  • Pointer size: 131 Bytes
  • Size of remote file: 786 kB
inference_co_ssh_u_v_b_vision.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import xarray as xr
import torch as pt
from torch import nn
import torch.utils.data as Data
import numpy as np
import torch.nn.functional as F
import os
from model.vision import VISION
import os.path as osp
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import h5py


os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Spatial crop applied to every field (here: the full 512x512 tile).
high1 = 0
high2 = 512
width1 = 0
width2 = 512


class testDataset(Dataset):
    """Test-time dataset reading normalized ocean fields from an HDF5 demo file.

    Each item is a (7, H, W) float array stacked in the order
    (ssh, u, v, b, w_20, w_40, w_60); every channel is z-scored with the
    per-channel mean/std loaded from ./data.
    """

    # Channel indices in the HDF5 'fields' dataset, in output-stack order:
    # ssh=0, u=1, v=2, b=8, w_20=3, w_40=4, w_60=5.
    CHANNELS = (0, 1, 2, 8, 3, 4, 5)

    def __init__(self):
        super().__init__()
        self.data_path = './data/KD48_demo.h5'
        self.data_file = h5py.File(self.data_path, 'r')
        # mean/std are indexed as [0, ch, :, :] below,
        # so they are assumed shaped (1, C, H, W) -- TODO confirm.
        self.mean = np.load('./data/mean.npy')
        self.std = np.load('./data/std.npy')
        self.size = 50  # number of samples exposed from the demo file

    def _load_channel(self, index, ch):
        """Read one channel, replace NaNs with 0, and z-score it."""
        field = self.data_file['fields'][index, ch, high1:high2, width1:width2]
        field = np.nan_to_num(field, nan=0)
        return (field - self.mean[0, ch, :, :]) / (self.std[0, ch, :, :])

    def __getitem__(self, index):
        return np.stack(
            [self._load_channel(index, ch) for ch in self.CHANNELS], axis=0
        )

    def __len__(self):
        return self.size

    def __del__(self):
        # Guard: __del__ may run even if __init__ failed before the file opened.
        data_file = getattr(self, 'data_file', None)
        if data_file is not None:
            data_file.close()


testdataset = testDataset()
testloader = Data.DataLoader(
    dataset=testdataset,
    batch_size=1,
    shuffle=False,
    num_workers=0
)

model = VISION().cuda()
checkpoint_path = './checkpoint_VISION/best_mse.pt'
ckpt = pt.load(checkpoint_path, map_location='cpu')
# Accept either a raw state_dict or a checkpoint dict with a 'model' entry,
# and strip any DataParallel 'module.' prefixes.
state_dict = ckpt['model'] if isinstance(ckpt, dict) and 'model' in ckpt else ckpt
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
model.load_state_dict(state_dict, strict=True)

model.eval()

folder_path = './result'
os.makedirs(folder_path, exist_ok=True)

output_path = osp.join(folder_path, 'results_co_ssh_u_v_b_vision.h5')
if os.path.exists(output_path):
    os.remove(output_path)

N = len(testdataset)
H = 512
W = 512


def _flush(dset_pred, dset_true, buffer_preds, buffer_trues, buffer_indices):
    """Write buffered predictions/targets to the HDF5 datasets and clear the buffers."""
    if not buffer_preds:
        return
    preds_block = np.concatenate(buffer_preds, axis=0)
    trues_block = np.concatenate(buffer_trues, axis=0)
    # One fancy-indexed write per dataset covers all three output channels.
    dset_pred[buffer_indices, :, :, :] = preds_block
    dset_true[buffer_indices, :, :, :] = trues_block
    buffer_preds.clear()
    buffer_trues.clear()
    buffer_indices.clear()


batch_size_to_save = 1  # flush to disk after this many batches

with h5py.File(output_path, 'w') as f_out:
    dset_pred = f_out.create_dataset('predicted', shape=(N, 3, H, W), dtype='float32')
    dset_true = f_out.create_dataset('ground_truth', shape=(N, 3, H, W), dtype='float32')

    buffer_preds = []
    buffer_trues = []
    buffer_indices = []

    with pt.no_grad():
        num = 0
        for data in tqdm(testloader, desc="Loading data"):
            # Inputs: ssh, u, v, b; targets: w_20, w_40, w_60.
            xbatch = data[:, 0:4, :, :].cuda().float()
            ybatch = data[:, 4:7, :, :].cuda().float()
            out = model(xbatch)
            print(out.shape)

            mse = pt.mean((ybatch - out) ** 2)
            print(num, mse)

            buffer_preds.append(out.detach().cpu().numpy().astype(np.float32))
            buffer_trues.append(ybatch.detach().cpu().numpy().astype(np.float32))
            buffer_indices.append(num)

            if len(buffer_preds) == batch_size_to_save:
                _flush(dset_pred, dset_true, buffer_preds, buffer_trues, buffer_indices)

            num += 1

        # Flush any remaining buffered batches.
        _flush(dset_pred, dset_true, buffer_preds, buffer_trues, buffer_indices)

print("Results successfully saved to HDF5 file.")
inference_io_ssh_u_v_vision.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import xarray as xr
import torch as pt
from torch import nn
import torch.utils.data as Data
import numpy as np
import torch.nn.functional as F
import os
from model.vision import VISION
import os.path as osp
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import h5py


os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Spatial crop applied to every field (here: the full 512x512 tile).
high1 = 0
high2 = 512
width1 = 0
width2 = 512


class testDataset(Dataset):
    """Test-time dataset reading normalized ocean fields from an HDF5 demo file.

    Each item is a (7, H, W) float array stacked in the order
    (ssh, u, v, b, w_20, w_40, w_60); every channel is z-scored with the
    per-channel mean/std loaded from ./data.
    """

    # Channel indices in the HDF5 'fields' dataset, in output-stack order:
    # ssh=0, u=1, v=2, b=8, w_20=3, w_40=4, w_60=5.
    CHANNELS = (0, 1, 2, 8, 3, 4, 5)

    def __init__(self):
        super().__init__()
        self.data_path = './data/KD48_demo.h5'
        self.data_file = h5py.File(self.data_path, 'r')
        # mean/std are indexed as [0, ch, :, :] below,
        # so they are assumed shaped (1, C, H, W) -- TODO confirm.
        self.mean = np.load('./data/mean.npy')
        self.std = np.load('./data/std.npy')
        self.size = 50  # number of samples exposed from the demo file

    def _load_channel(self, index, ch):
        """Read one channel, replace NaNs with 0, and z-score it."""
        field = self.data_file['fields'][index, ch, high1:high2, width1:width2]
        field = np.nan_to_num(field, nan=0)
        return (field - self.mean[0, ch, :, :]) / (self.std[0, ch, :, :])

    def __getitem__(self, index):
        return np.stack(
            [self._load_channel(index, ch) for ch in self.CHANNELS], axis=0
        )

    def __len__(self):
        return self.size

    def __del__(self):
        # Guard: __del__ may run even if __init__ failed before the file opened.
        data_file = getattr(self, 'data_file', None)
        if data_file is not None:
            data_file.close()


testdataset = testDataset()
testloader = Data.DataLoader(
    dataset=testdataset,
    batch_size=1,
    shuffle=False,
    num_workers=0
)

model = VISION().cuda()
checkpoint_path = './checkpoint_VISION/best_mse.pt'
ckpt = pt.load(checkpoint_path, map_location='cpu')
# Accept either a raw state_dict or a checkpoint dict with a 'model' entry,
# and strip any DataParallel 'module.' prefixes.
state_dict = ckpt['model'] if isinstance(ckpt, dict) and 'model' in ckpt else ckpt
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
model.load_state_dict(state_dict, strict=True)

model.eval()

folder_path = './result'
os.makedirs(folder_path, exist_ok=True)

output_path = osp.join(folder_path, 'results_io_ssh_u_v_vision.h5')
if os.path.exists(output_path):
    os.remove(output_path)

N = len(testdataset)
H = 512
W = 512


def _flush(dset_pred, dset_true, buffer_preds, buffer_trues, buffer_indices):
    """Write buffered predictions/targets to the HDF5 datasets and clear the buffers."""
    if not buffer_preds:
        return
    preds_block = np.concatenate(buffer_preds, axis=0)
    trues_block = np.concatenate(buffer_trues, axis=0)
    # One fancy-indexed write per dataset covers all three output channels.
    dset_pred[buffer_indices, :, :, :] = preds_block
    dset_true[buffer_indices, :, :, :] = trues_block
    buffer_preds.clear()
    buffer_trues.clear()
    buffer_indices.clear()


batch_size_to_save = 1  # flush to disk after this many batches

with h5py.File(output_path, 'w') as f_out:
    dset_pred = f_out.create_dataset('predicted', shape=(N, 3, H, W), dtype='float32')
    dset_true = f_out.create_dataset('ground_truth', shape=(N, 3, H, W), dtype='float32')

    buffer_preds = []
    buffer_trues = []
    buffer_indices = []

    with pt.no_grad():
        num = 0
        for data in tqdm(testloader, desc="Loading data"):
            # Inputs: ssh, u, v; targets: w_20, w_40, w_60.
            xbatch = data[:, 0:3, :, :].cuda().float()
            ybatch = data[:, 4:7, :, :].cuda().float()
            out = model(xbatch)
            print(out.shape)

            mse = pt.mean((ybatch - out) ** 2)
            print(num, mse)

            buffer_preds.append(out.detach().cpu().numpy().astype(np.float32))
            buffer_trues.append(ybatch.detach().cpu().numpy().astype(np.float32))
            buffer_indices.append(num)

            if len(buffer_preds) == batch_size_to_save:
                _flush(dset_pred, dset_true, buffer_preds, buffer_trues, buffer_indices)

            num += 1

        # Flush any remaining buffered batches.
        _flush(dset_pred, dset_true, buffer_preds, buffer_trues, buffer_indices)

print("Results successfully saved to HDF5 file.")
inference_io_ssh_vision.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import xarray as xr
import torch as pt
from torch import nn
import torch.utils.data as Data
import numpy as np
import torch.nn.functional as F
import os
from model.vision import VISION
import os.path as osp
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import h5py


os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Spatial crop applied to every field (here: the full 512x512 tile).
high1 = 0
high2 = 512
width1 = 0
width2 = 512


class testDataset(Dataset):
    """Test-time dataset reading normalized ocean fields from an HDF5 demo file.

    Each item is a (7, H, W) float array stacked in the order
    (ssh, u, v, b, w_20, w_40, w_60); every channel is z-scored with the
    per-channel mean/std loaded from ./data.
    """

    # Channel indices in the HDF5 'fields' dataset, in output-stack order:
    # ssh=0, u=1, v=2, b=8, w_20=3, w_40=4, w_60=5.
    CHANNELS = (0, 1, 2, 8, 3, 4, 5)

    def __init__(self):
        super().__init__()
        self.data_path = './data/KD48_demo.h5'
        self.data_file = h5py.File(self.data_path, 'r')
        # mean/std are indexed as [0, ch, :, :] below,
        # so they are assumed shaped (1, C, H, W) -- TODO confirm.
        self.mean = np.load('./data/mean.npy')
        self.std = np.load('./data/std.npy')
        self.size = 50  # number of samples exposed from the demo file

    def _load_channel(self, index, ch):
        """Read one channel, replace NaNs with 0, and z-score it."""
        field = self.data_file['fields'][index, ch, high1:high2, width1:width2]
        field = np.nan_to_num(field, nan=0)
        return (field - self.mean[0, ch, :, :]) / (self.std[0, ch, :, :])

    def __getitem__(self, index):
        return np.stack(
            [self._load_channel(index, ch) for ch in self.CHANNELS], axis=0
        )

    def __len__(self):
        return self.size

    def __del__(self):
        # Guard: __del__ may run even if __init__ failed before the file opened.
        data_file = getattr(self, 'data_file', None)
        if data_file is not None:
            data_file.close()


testdataset = testDataset()
testloader = Data.DataLoader(
    dataset=testdataset,
    batch_size=1,
    shuffle=False,
    num_workers=0
)

model = VISION().cuda()
checkpoint_path = './checkpoint_VISION/best_mse.pt'
ckpt = pt.load(checkpoint_path, map_location='cpu')
# Accept either a raw state_dict or a checkpoint dict with a 'model' entry,
# and strip any DataParallel 'module.' prefixes.
state_dict = ckpt['model'] if isinstance(ckpt, dict) and 'model' in ckpt else ckpt
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
model.load_state_dict(state_dict, strict=True)

model.eval()

folder_path = './result'
os.makedirs(folder_path, exist_ok=True)

output_path = osp.join(folder_path, 'results_io_ssh_vision.h5')
if os.path.exists(output_path):
    os.remove(output_path)

N = len(testdataset)
H = 512
W = 512


def _flush(dset_pred, dset_true, buffer_preds, buffer_trues, buffer_indices):
    """Write buffered predictions/targets to the HDF5 datasets and clear the buffers."""
    if not buffer_preds:
        return
    preds_block = np.concatenate(buffer_preds, axis=0)
    trues_block = np.concatenate(buffer_trues, axis=0)
    # One fancy-indexed write per dataset covers all three output channels.
    dset_pred[buffer_indices, :, :, :] = preds_block
    dset_true[buffer_indices, :, :, :] = trues_block
    buffer_preds.clear()
    buffer_trues.clear()
    buffer_indices.clear()


batch_size_to_save = 1  # flush to disk after this many batches

with h5py.File(output_path, 'w') as f_out:
    dset_pred = f_out.create_dataset('predicted', shape=(N, 3, H, W), dtype='float32')
    dset_true = f_out.create_dataset('ground_truth', shape=(N, 3, H, W), dtype='float32')

    buffer_preds = []
    buffer_trues = []
    buffer_indices = []

    with pt.no_grad():
        num = 0
        for data in tqdm(testloader, desc="Loading data"):
            # Input: ssh only; targets: w_20, w_40, w_60.
            xbatch = data[:, 0:1, :, :].cuda().float()
            ybatch = data[:, 4:7, :, :].cuda().float()
            out = model(xbatch)
            print(out.shape)

            mse = pt.mean((ybatch - out) ** 2)
            print(num, mse)

            buffer_preds.append(out.detach().cpu().numpy().astype(np.float32))
            buffer_trues.append(ybatch.detach().cpu().numpy().astype(np.float32))
            buffer_indices.append(num)

            if len(buffer_preds) == batch_size_to_save:
                _flush(dset_pred, dset_true, buffer_preds, buffer_trues, buffer_indices)

            num += 1

        # Flush any remaining buffered batches.
        _flush(dset_pred, dset_true, buffer_preds, buffer_trues, buffer_indices)

print("Results successfully saved to HDF5 file.")
model/.DS_Store ADDED
Binary file (6.15 kB). View file
 
model/.ipynb_checkpoints/vision-checkpoint.py ADDED
@@ -0,0 +1,559 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from thop import profile
10
+
11
+
12
class VISION(nn.Module):
    """Top-level model: adaptive input embedding (AOE) followed by the
    encoder/bottleneck/decoder stack (GSAO)."""

    def __init__(self, channel=16):
        super(VISION, self).__init__()
        self.aoe = AOE(channel)
        self.gsao = GSAO(channel)

    def forward(self, x):
        # Lift the raw input to `channel` feature maps, then decode to 3 channels.
        return self.gsao(self.aoe(x))
23
+
24
class GSAO(nn.Module):
    """Encoder (GSAO_Left) -> bottleneck (SSDC) -> decoder (GSAO_Right),
    finished by a 1x1 conv projecting to 3 output channels."""

    def __init__(self, channel=16):
        super(GSAO, self).__init__()
        self.gsao_left = GSAO_Left(channel)
        self.ssdc = SSDC(channel)
        self.gsao_right = GSAO_Right(channel)
        self.gsao_out = nn.Conv2d(channel, 3, kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, x):
        # Multi-scale encoder features: full, 1/2, 1/4 and 1/8 resolution.
        large, mid, small, smallest = self.gsao_left(x)
        bottleneck = self.ssdc(smallest)
        decoded = self.gsao_right(bottleneck, smallest, small, mid, large)
        return self.gsao_out(decoded)
44
+
45
+
46
class AOE(nn.Module):
    """Adaptive input embedding: UOA maps the raw 1/3/4-channel input to
    `channel` feature maps, then SCP refines them with a learned prompt."""

    def __init__(self, channel=16):
        super(AOE, self).__init__()
        self.uoa = UOA(channel)
        self.scp = SCP(channel)

    def forward(self, x):
        embedded = self.uoa(x)  # (B, 1|3|4, H, W) -> (B, channel, H, W)
        return self.scp(embedded)
58
+
59
class UOA(nn.Module):
    """Universal input adapter: maps a 1-, 3- or 4-channel input to `channel`
    feature maps via a per-arity 1x1 convolution.

    Raises:
        ValueError: if the input channel count is not 1, 3 or 4. (The original
            code fell through all branches and crashed later with an
            UnboundLocalError on `x_in`.)
    """

    def __init__(self, channel=16):
        super(UOA, self).__init__()
        self.Haze_in1 = nn.Conv2d(1, channel, kernel_size=1, stride=1, padding=0, bias=False)
        self.Haze_in3 = nn.Conv2d(3, channel, kernel_size=1, stride=1, padding=0, bias=False)
        self.Haze_in4 = nn.Conv2d(4, channel, kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, x):
        c = x.shape[1]
        if c == 1:
            return self.Haze_in1(x)
        if c == 3:
            return self.Haze_in3(x)
        if c == 4:
            return self.Haze_in4(x)
        raise ValueError(f"UOA supports 1-, 3- or 4-channel inputs, got {c} channels")
76
+
77
class SCP(nn.Module):
    """Prompt-based refinement: CGM generates an input-conditioned prompt,
    CIM fuses that prompt back into the feature maps."""

    def __init__(self, channel):
        super(SCP, self).__init__()
        self.cgm = CGM(channel)
        self.cim = CIM(channel)

    def forward(self, x):
        prompt = self.cgm(x)
        return self.cim(prompt, x)
88
+
89
class GSAO_Left(nn.Module):
    """Encoder half of GSAO: four GARO stages at 1x, 1/2, 1/4 and 1/8
    resolution. Each downsampling step is a stride-2 max-pool followed by a
    1x1 conv doubling the channel count (C -> 2C -> 4C -> 8C)."""

    def __init__(self, channel):
        super(GSAO_Left, self).__init__()
        self.el = GARO(channel)        # full resolution, C channels
        self.em = GARO(channel * 2)    # 1/2 resolution, 2C
        self.es = GARO(channel * 4)    # 1/4 resolution, 4C
        self.ess = GARO(channel * 8)   # 1/8 resolution, 8C
        # NOTE(review): `esss` is constructed but never used in forward();
        # kept so existing checkpoints (state_dict keys) still load.
        self.esss = GARO(channel * 16)

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv_eltem = nn.Conv2d(channel, 2 * channel, kernel_size=1, stride=1, padding=0, bias=False)
        self.conv_emtes = nn.Conv2d(2 * channel, 4 * channel, kernel_size=1, stride=1, padding=0, bias=False)
        self.conv_estess = nn.Conv2d(4 * channel, 8 * channel, kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, x):
        elout = self.el(x)
        emout = self.em(self.conv_eltem(self.maxpool(elout)))
        esout = self.es(self.conv_emtes(self.maxpool(emout)))
        essout = self.ess(self.conv_estess(self.maxpool(esout)))
        return elout, emout, esout, essout
115
+
116
class SSDC(nn.Module):
    """Bottleneck: two residual selective-kernel (SKO) blocks at 8C channels."""

    def __init__(self, channel):
        super(SSDC, self).__init__()
        self.s1 = SKO(channel * 8)
        self.s2 = SKO(channel * 8)

    def forward(self, x):
        y = x + self.s1(x)       # residual SK block 1
        return y + self.s2(y)    # residual SK block 2
128
+
129
class GSAO_Right(nn.Module):
    """Decoder half of GSAO: GARO stages from 1/8 back to full resolution.

    Each step bilinearly doubles the spatial size and halves the channel count
    with a 1x1 conv; the matching encoder feature is added (skip connection)
    before each GARO stage.
    """

    def __init__(self, channel):
        super(GSAO_Right, self).__init__()
        self.dss = GARO(channel * 8)
        self.ds = GARO(channel * 4)
        self.dm = GARO(channel * 2)
        self.dl = GARO(channel)

        # NOTE(review): `conv_dssstdss` is never used in forward(); kept so
        # existing checkpoints (state_dict keys) still load.
        self.conv_dssstdss = nn.Conv2d(16 * channel, 8 * channel, kernel_size=1, stride=1, padding=0, bias=False)
        self.conv_dsstds = nn.Conv2d(8 * channel, 4 * channel, kernel_size=1, stride=1, padding=0, bias=False)
        self.conv_dstdm = nn.Conv2d(4 * channel, 2 * channel, kernel_size=1, stride=1, padding=0, bias=False)
        self.conv_dmtdl = nn.Conv2d(2 * channel, channel, kernel_size=1, stride=1, padding=0, bias=False)

    def _upsample(self, x):
        """Double the spatial size of x (bilinear)."""
        # F.upsample is deprecated; F.interpolate is the drop-in replacement.
        _, _, H, W = x.size()
        return F.interpolate(x, size=(2 * H, 2 * W), mode='bilinear')

    def forward(self, x, ss, s, m, l):
        dssout = self.dss(x + ss)
        dsout = self.ds(self.conv_dsstds(self._upsample(dssout)) + s)
        dmout = self.dm(self.conv_dstdm(self._upsample(dsout)) + m)
        return self.dl(self.conv_dmtdl(self._upsample(dmout)) + l)
158
+
159
+
160
class SKO(nn.Module):
    """Selective-kernel block (SKNet-style).

    M parallel conv branches with kernel sizes 3, 5, ..., 3 + 2(M-1) produce
    candidate features; a squeeze-excite style gate softmax-weights the M
    branches per channel and sums them.

    Args:
        in_ch: number of input (= output) channels.
        M: number of parallel conv branches.
        G: conv groups.
        r: channel reduction ratio for the gating MLP.
        stride: conv stride of each branch.
        L: lower bound on the reduced gating dimension.
    """

    def __init__(self, in_ch, M=3, G=1, r=4, stride=1, L=32) -> None:
        super().__init__()
        d = max(int(in_ch / r), L)
        self.M = M
        self.in_ch = in_ch
        self.convs = nn.ModuleList([])
        for i in range(M):
            self.convs.append(
                nn.Sequential(
                    nn.Conv2d(in_ch, in_ch, kernel_size=3 + i * 2, stride=stride, padding=1 + i, groups=G),
                    nn.BatchNorm2d(in_ch),
                    nn.ReLU(inplace=True),
                )
            )
        self.fc = nn.Linear(in_ch, d)
        self.fcs = nn.ModuleList([])
        for i in range(M):
            self.fcs.append(nn.Linear(d, in_ch))
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        # (B, M, C, H, W): all branch outputs. The original built this with a
        # cat-in-a-loop plus many redundant .clone()/unsqueeze_ calls;
        # torch.stack is the value-identical equivalent.
        feas = torch.stack([conv(x) for conv in self.convs], dim=1)
        fea_U = feas.sum(dim=1)            # fuse branches: (B, C, H, W)
        fea_s = fea_U.mean(-1).mean(-1)    # global average pool: (B, C)
        fea_z = self.fc(fea_s)             # reduce: (B, d)
        # Per-branch channel logits, softmax-normalized across branches.
        attention_vectors = torch.stack([fc(fea_z) for fc in self.fcs], dim=1)  # (B, M, C)
        attention_vectors = self.softmax(attention_vectors)
        attention_vectors = attention_vectors.unsqueeze(-1).unsqueeze(-1)       # (B, M, C, 1, 1)
        return (feas * attention_vectors).sum(dim=1)
203
+
204
+
205
class GARO(nn.Module):
    """Residual block of two deformable 3x3 convolutions, each followed by
    GroupNorm (1 group) and PReLU, with an identity skip.

    NOTE(review): the `norm` __init__ argument is accepted but ignored
    (GroupNorm is always applied); kept for signature compatibility.
    """

    def __init__(self, channel, norm=False):
        super(GARO, self).__init__()
        self.conv_1_1 = DeformConv2d(channel, channel, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv_2_1 = DeformConv2d(channel, channel, kernel_size=3, stride=1, padding=1, bias=False)
        self.act = nn.PReLU(channel)
        self.norm = nn.GroupNorm(num_channels=channel, num_groups=1)

    def _upsample(self, x, y):
        """Resize x to y's spatial size (bilinear). Not used by forward()."""
        # F.upsample is deprecated; F.interpolate is the drop-in replacement.
        _, _, H, W = y.size()
        return F.interpolate(x, size=(H, W), mode='bilinear')

    def forward(self, x):
        x_1 = self.act(self.norm(self.conv_1_1(x)))
        return self.act(self.norm(self.conv_2_1(x_1))) + x
223
+
224
class CGM(nn.Module):
    """Prompt generation module: learns `prompt_len` prompt tensors and blends
    them with softmax weights derived from the input's channel-wise mean.

    Assumes `lin_dim` equals the input channel count (default 16) — the linear
    layer consumes the (B, C) channel means of x.
    """

    def __init__(self, channel, prompt_len=3, prompt_size=96, lin_dim=16):
        super(CGM, self).__init__()
        self.prompt_param = nn.Parameter(torch.rand(1, prompt_len, channel, prompt_size, prompt_size))
        self.linear_layer = nn.Linear(lin_dim, prompt_len)
        self.conv3x3 = nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, x):
        B, C, H, W = x.shape
        emb = x.mean(dim=(-2, -1))                                  # (B, C)
        prompt_weights = F.softmax(self.linear_layer(emb), dim=1)   # (B, prompt_len)
        # Broadcast the learned prompts to the batch and weight each one.
        # (Value-identical to the original unsqueeze(0).repeat(...).squeeze(1) chain.)
        prompts = self.prompt_param.repeat(B, 1, 1, 1, 1)           # (B, L, C, ps, ps)
        weighted = prompt_weights.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * prompts
        prompt = weighted.sum(dim=1)                                # (B, C, ps, ps)
        prompt = F.interpolate(prompt, (H, W), mode="bilinear")
        return self.conv3x3(prompt)
245
+
246
class CIM(nn.Module):
    """Prompt interaction module: concatenates prompt and features, applies a
    residual block at 2C channels, then projects back to C with a 3x3 conv."""

    def __init__(self, channel):
        super(CIM, self).__init__()
        self.res = ResBlock(2 * channel, 2 * channel)
        self.conv3x3 = nn.Conv2d(2 * channel, channel, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, prompt, x):
        fused = torch.cat((prompt, x), dim=1)  # (B, 2C, H, W)
        return self.conv3x3(self.res(fused))
259
+
260
+
261
class DeformConv2d(nn.Module):
    """Deformable 2D convolution (Dai et al., 2017), pure-PyTorch version.

    A small conv (`p_conv`) predicts per-pixel (dy, dx) offsets for each of the
    kernel_size^2 sampling locations; the input is bilinearly sampled at those
    offset positions and the gathered ks x ks patches are convolved by a conv
    whose stride equals kernel_size.

    NOTE(review): this class was defined TWICE in the file, byte-for-byte
    identical — the second definition merely rebound the name. The duplicate
    has been removed; final module behavior is unchanged.
    NOTE(review): `_set_lr` builds scaled generators but never returns them,
    so the registered backward hook is a no-op. Left unchanged so training
    behavior matches the shipped checkpoint. `register_backward_hook` is also
    deprecated in modern PyTorch (`register_full_backward_hook`).
    """

    def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, bias=None, modulation=False):
        super(DeformConv2d, self).__init__()
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.zero_padding = nn.ZeroPad2d(padding)
        # stride = kernel_size: the input to this conv is the ks*ks-expanded map.
        self.conv = nn.Conv2d(inc, outc, kernel_size=kernel_size, stride=kernel_size, bias=bias)

        # Offset predictor: 2 coordinates per kernel sampling point, zero-init
        # so the layer starts as a plain convolution.
        self.p_conv = nn.Conv2d(inc, 2*kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride)
        nn.init.constant_(self.p_conv.weight, 0)
        self.p_conv.register_backward_hook(self._set_lr)

        self.modulation = modulation  # DCNv2-style per-point scalar masks
        if modulation:
            self.m_conv = nn.Conv2d(inc, kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride)
            nn.init.constant_(self.m_conv.weight, 0)
            self.m_conv.register_backward_hook(self._set_lr)

    @staticmethod
    def _set_lr(module, grad_input, grad_output):
        # Intended to scale gradients by 0.1, but the generators are discarded
        # and nothing is returned, so this hook has no effect (see class note).
        grad_input = (grad_input[i] * 0.1 for i in range(len(grad_input)))
        grad_output = (grad_output[i] * 0.1 for i in range(len(grad_output)))

    def forward(self, x):
        offset = self.p_conv(x)  # (B, 2N, H, W) with N = ks*ks
        if self.modulation:
            m = torch.sigmoid(self.m_conv(x))

        dtype = offset.data.type()
        ks = self.kernel_size
        N = offset.size(1) // 2

        if self.padding:
            x = self.zero_padding(x)

        # Absolute sampling positions: base grid + kernel offsets + predicted offsets.
        p = self._get_p(offset, dtype)

        p = p.contiguous().permute(0, 2, 3, 1)  # (B, H, W, 2N)
        q_lt = p.detach().floor()
        q_rb = q_lt + 1

        # Clamp the four integer neighbours to the padded input bounds.
        q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2)-1), torch.clamp(q_lt[..., N:], 0, x.size(3)-1)], dim=-1).long()
        q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2)-1), torch.clamp(q_rb[..., N:], 0, x.size(3)-1)], dim=-1).long()
        q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1)
        q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1)

        p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2)-1), torch.clamp(p[..., N:], 0, x.size(3)-1)], dim=-1)

        # Bilinear interpolation weights for the four neighbours.
        g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:]))
        g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:]))
        g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:]))
        g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:]))

        x_q_lt = self._get_x_q(x, q_lt, N)
        x_q_rb = self._get_x_q(x, q_rb, N)
        x_q_lb = self._get_x_q(x, q_lb, N)
        x_q_rt = self._get_x_q(x, q_rt, N)

        # Weighted sum of the four neighbours: (B, C, H, W, N).
        x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \
                   g_rb.unsqueeze(dim=1) * x_q_rb + \
                   g_lb.unsqueeze(dim=1) * x_q_lb + \
                   g_rt.unsqueeze(dim=1) * x_q_rt

        if self.modulation:
            m = m.contiguous().permute(0, 2, 3, 1)
            m = m.unsqueeze(dim=1)
            m = torch.cat([m for _ in range(x_offset.size(1))], dim=1)
            x_offset *= m

        x_offset = self._reshape_x_offset(x_offset, ks)
        out = self.conv(x_offset)

        return out

    def _get_p_n(self, N, dtype):
        # Relative kernel offsets, e.g. (-1..1) x (-1..1) for a 3x3 kernel.
        p_n_x, p_n_y = torch.meshgrid(
            torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1),
            torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1))
        p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0)
        p_n = p_n.view(1, 2*N, 1, 1).type(dtype)

        return p_n

    def _get_p_0(self, h, w, N, dtype):
        # Base sampling grid; 1-based because the input was zero-padded.
        p_0_x, p_0_y = torch.meshgrid(
            torch.arange(1, h*self.stride+1, self.stride),
            torch.arange(1, w*self.stride+1, self.stride))
        p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1)
        p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1)
        p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype)

        return p_0

    def _get_p(self, offset, dtype):
        N, h, w = offset.size(1)//2, offset.size(2), offset.size(3)

        p_n = self._get_p_n(N, dtype)
        p_0 = self._get_p_0(h, w, N, dtype)
        p = p_0 + p_n + offset
        return p

    def _get_x_q(self, x, q, N):
        # Gather x at integer positions q: returns (B, C, H, W, N).
        b, h, w, _ = q.size()
        padded_w = x.size(3)
        c = x.size(1)
        x = x.contiguous().view(b, c, -1)

        index = q[..., :N]*padded_w + q[..., N:]  # flat index = y*W + x
        index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1)

        x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N)

        return x_offset

    @staticmethod
    def _reshape_x_offset(x_offset, ks):
        # (B, C, H, W, ks*ks) -> (B, C, H*ks, W*ks): lays each sampled ks x ks
        # patch out as a spatial tile so a stride-ks conv consumes one patch
        # per output pixel.
        b, c, h, w, N = x_offset.size()
        x_offset = torch.cat([x_offset[..., s:s+ks].contiguous().view(b, c, h, w*ks) for s in range(0, N, ks)], dim=-1)
        x_offset = x_offset.contiguous().view(b, c, h*ks, w*ks)

        return x_offset
506
+
507
+
508
class BasicConv(nn.Module):
    """Conv / ConvTranspose + optional BatchNorm + optional GELU in one module.

    `bias` is forced off when `norm` is on (BatchNorm makes it redundant).
    Despite the `relu` flag name, the activation used is GELU.
    """

    def __init__(self, in_channel, out_channel, kernel_size, stride, bias=True, norm=False, relu=True, transpose=False):
        super(BasicConv, self).__init__()
        if bias and norm:
            bias = False

        padding = kernel_size // 2
        layers = []
        if transpose:
            padding = kernel_size // 2 - 1
            layers.append(nn.ConvTranspose2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias))
        else:
            layers.append(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias))
        if norm:
            layers.append(nn.BatchNorm2d(out_channel))
        if relu:
            layers.append(nn.GELU())
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        return self.main(x)
530
+
531
+
532
class ResBlock(nn.Module):
    """Residual block: two 3x3 BasicConvs (activation after the first only)
    plus an identity skip connection."""

    def __init__(self, in_channel, out_channel):
        super(ResBlock, self).__init__()
        self.main = nn.Sequential(
            BasicConv(in_channel, out_channel, kernel_size=3, stride=1, relu=True),
            BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False),
        )

    def forward(self, x):
        return self.main(x) + x
542
+
543
+
544
from thop import profile  # NOTE(review): duplicate of the top-of-file import

if __name__ == '__main__':
    # Smoke test: push a 4-channel 512x512 input through the model and report
    # MACs / parameter counts as measured by thop.profile.

    device = "cuda" if torch.cuda.is_available() else "cpu"

    net = VISION().to(device)

    input = torch.randn(1, 4, 512, 512).to(device)
    output = net(input)

    macs, params = profile(net, inputs=(input, ))

    print('macs: ', macs, 'params: ', params)
    print('macs: %.2f G, params: %.2f M' % (macs / 1000000000.0, params / 1000000.0))
    print(output.shape)
model/vision.py ADDED
@@ -0,0 +1,559 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from thop import profile
10
+
11
+
12
class VISION(nn.Module):
    """Top-level model: adaptive input embedding (AOE) followed by the
    encoder/bottleneck/decoder stack (GSAO)."""

    def __init__(self, channel=16):
        super(VISION, self).__init__()
        self.aoe = AOE(channel)
        self.gsao = GSAO(channel)

    def forward(self, x):
        # Lift the raw input to `channel` feature maps, then decode to 3 channels.
        return self.gsao(self.aoe(x))
23
+
24
class GSAO(nn.Module):
    """Encoder (GSAO_Left) -> bottleneck (SSDC) -> decoder (GSAO_Right),
    finished by a 1x1 conv projecting to 3 output channels."""

    def __init__(self, channel=16):
        super(GSAO, self).__init__()
        self.gsao_left = GSAO_Left(channel)
        self.ssdc = SSDC(channel)
        self.gsao_right = GSAO_Right(channel)
        self.gsao_out = nn.Conv2d(channel, 3, kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, x):
        # Multi-scale encoder features: full, 1/2, 1/4 and 1/8 resolution.
        large, mid, small, smallest = self.gsao_left(x)
        bottleneck = self.ssdc(smallest)
        decoded = self.gsao_right(bottleneck, smallest, small, mid, large)
        return self.gsao_out(decoded)
44
+
45
+
46
class AOE(nn.Module):
    """Adaptive input embedding: UOA maps the raw 1/3/4-channel input to
    `channel` feature maps, then SCP refines them with a learned prompt."""

    def __init__(self, channel=16):
        super(AOE, self).__init__()
        self.uoa = UOA(channel)
        self.scp = SCP(channel)

    def forward(self, x):
        embedded = self.uoa(x)  # (B, 1|3|4, H, W) -> (B, channel, H, W)
        return self.scp(embedded)
58
+
59
class UOA(nn.Module):
    """Universal input adapter: maps a 1-, 3- or 4-channel input to `channel`
    feature maps via a per-arity 1x1 convolution.

    Raises:
        ValueError: if the input channel count is not 1, 3 or 4. (The original
            code fell through all branches and crashed later with an
            UnboundLocalError on `x_in`.)
    """

    def __init__(self, channel=16):
        super(UOA, self).__init__()
        self.Haze_in1 = nn.Conv2d(1, channel, kernel_size=1, stride=1, padding=0, bias=False)
        self.Haze_in3 = nn.Conv2d(3, channel, kernel_size=1, stride=1, padding=0, bias=False)
        self.Haze_in4 = nn.Conv2d(4, channel, kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, x):
        c = x.shape[1]
        if c == 1:
            return self.Haze_in1(x)
        if c == 3:
            return self.Haze_in3(x)
        if c == 4:
            return self.Haze_in4(x)
        raise ValueError(f"UOA supports 1-, 3- or 4-channel inputs, got {c} channels")
76
+
77
class SCP(nn.Module):
    """Prompt-based refinement: CGM generates an input-conditioned prompt,
    CIM fuses that prompt back into the feature maps."""

    def __init__(self, channel):
        super(SCP, self).__init__()
        self.cgm = CGM(channel)
        self.cim = CIM(channel)

    def forward(self, x):
        prompt = self.cgm(x)
        return self.cim(prompt, x)
88
+
89
class GSAO_Left(nn.Module):
    """Encoder half of GSAO: four GARO stages at 1x, 1/2, 1/4 and 1/8
    resolution. Each downsampling step is a stride-2 max-pool followed by a
    1x1 conv doubling the channel count (C -> 2C -> 4C -> 8C)."""

    def __init__(self, channel):
        super(GSAO_Left, self).__init__()
        self.el = GARO(channel)        # full resolution, C channels
        self.em = GARO(channel * 2)    # 1/2 resolution, 2C
        self.es = GARO(channel * 4)    # 1/4 resolution, 4C
        self.ess = GARO(channel * 8)   # 1/8 resolution, 8C
        # NOTE(review): `esss` is constructed but never used in forward();
        # kept so existing checkpoints (state_dict keys) still load.
        self.esss = GARO(channel * 16)

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv_eltem = nn.Conv2d(channel, 2 * channel, kernel_size=1, stride=1, padding=0, bias=False)
        self.conv_emtes = nn.Conv2d(2 * channel, 4 * channel, kernel_size=1, stride=1, padding=0, bias=False)
        self.conv_estess = nn.Conv2d(4 * channel, 8 * channel, kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, x):
        elout = self.el(x)
        emout = self.em(self.conv_eltem(self.maxpool(elout)))
        esout = self.es(self.conv_emtes(self.maxpool(emout)))
        essout = self.ess(self.conv_estess(self.maxpool(esout)))
        return elout, emout, esout, essout
115
+
116
class SSDC(nn.Module):
    """Bottleneck: two residual selective-kernel (SKO) blocks at 8C channels."""

    def __init__(self, channel):
        super(SSDC, self).__init__()
        self.s1 = SKO(channel * 8)
        self.s2 = SKO(channel * 8)

    def forward(self, x):
        y = x + self.s1(x)       # residual SK block 1
        return y + self.s2(y)    # residual SK block 2
128
+
129
class GSAO_Right(nn.Module):
    """Decoder half of GSAO: GARO stages from 1/8 back to full resolution.

    Each step bilinearly doubles the spatial size and halves the channel count
    with a 1x1 conv; the matching encoder feature is added (skip connection)
    before each GARO stage.
    """

    def __init__(self, channel):
        super(GSAO_Right, self).__init__()
        self.dss = GARO(channel * 8)
        self.ds = GARO(channel * 4)
        self.dm = GARO(channel * 2)
        self.dl = GARO(channel)

        # NOTE(review): `conv_dssstdss` is never used in forward(); kept so
        # existing checkpoints (state_dict keys) still load.
        self.conv_dssstdss = nn.Conv2d(16 * channel, 8 * channel, kernel_size=1, stride=1, padding=0, bias=False)
        self.conv_dsstds = nn.Conv2d(8 * channel, 4 * channel, kernel_size=1, stride=1, padding=0, bias=False)
        self.conv_dstdm = nn.Conv2d(4 * channel, 2 * channel, kernel_size=1, stride=1, padding=0, bias=False)
        self.conv_dmtdl = nn.Conv2d(2 * channel, channel, kernel_size=1, stride=1, padding=0, bias=False)

    def _upsample(self, x):
        """Double the spatial size of x (bilinear)."""
        # F.upsample is deprecated; F.interpolate is the drop-in replacement.
        _, _, H, W = x.size()
        return F.interpolate(x, size=(2 * H, 2 * W), mode='bilinear')

    def forward(self, x, ss, s, m, l):
        dssout = self.dss(x + ss)
        dsout = self.ds(self.conv_dsstds(self._upsample(dssout)) + s)
        dmout = self.dm(self.conv_dstdm(self._upsample(dsout)) + m)
        return self.dl(self.conv_dmtdl(self._upsample(dmout)) + l)
158
+
159
+
160
class SKO(nn.Module):
    """Selective-kernel block (SKNet-style).

    M parallel conv branches with kernel sizes 3, 5, ..., 3 + 2(M-1) produce
    candidate features; a squeeze-excite style gate softmax-weights the M
    branches per channel and sums them.

    Args:
        in_ch: number of input (= output) channels.
        M: number of parallel conv branches.
        G: conv groups.
        r: channel reduction ratio for the gating MLP.
        stride: conv stride of each branch.
        L: lower bound on the reduced gating dimension.
    """

    def __init__(self, in_ch, M=3, G=1, r=4, stride=1, L=32) -> None:
        super().__init__()
        d = max(int(in_ch / r), L)
        self.M = M
        self.in_ch = in_ch
        self.convs = nn.ModuleList([])
        for i in range(M):
            self.convs.append(
                nn.Sequential(
                    nn.Conv2d(in_ch, in_ch, kernel_size=3 + i * 2, stride=stride, padding=1 + i, groups=G),
                    nn.BatchNorm2d(in_ch),
                    nn.ReLU(inplace=True),
                )
            )
        self.fc = nn.Linear(in_ch, d)
        self.fcs = nn.ModuleList([])
        for i in range(M):
            self.fcs.append(nn.Linear(d, in_ch))
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        # (B, M, C, H, W): all branch outputs. The original built this with a
        # cat-in-a-loop plus many redundant .clone()/unsqueeze_ calls;
        # torch.stack is the value-identical equivalent.
        feas = torch.stack([conv(x) for conv in self.convs], dim=1)
        fea_U = feas.sum(dim=1)            # fuse branches: (B, C, H, W)
        fea_s = fea_U.mean(-1).mean(-1)    # global average pool: (B, C)
        fea_z = self.fc(fea_s)             # reduce: (B, d)
        # Per-branch channel logits, softmax-normalized across branches.
        attention_vectors = torch.stack([fc(fea_z) for fc in self.fcs], dim=1)  # (B, M, C)
        attention_vectors = self.softmax(attention_vectors)
        attention_vectors = attention_vectors.unsqueeze(-1).unsqueeze(-1)       # (B, M, C, 1, 1)
        return (feas * attention_vectors).sum(dim=1)
203
+
204
+
205
class GARO(nn.Module):
    """Residual block of two deformable 3x3 convolutions, each followed by
    GroupNorm (1 group) and PReLU, with an identity skip.

    NOTE(review): the `norm` __init__ argument is accepted but ignored
    (GroupNorm is always applied); kept for signature compatibility.
    """

    def __init__(self, channel, norm=False):
        super(GARO, self).__init__()
        self.conv_1_1 = DeformConv2d(channel, channel, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv_2_1 = DeformConv2d(channel, channel, kernel_size=3, stride=1, padding=1, bias=False)
        self.act = nn.PReLU(channel)
        self.norm = nn.GroupNorm(num_channels=channel, num_groups=1)

    def _upsample(self, x, y):
        """Resize x to y's spatial size (bilinear). Not used by forward()."""
        # F.upsample is deprecated; F.interpolate is the drop-in replacement.
        _, _, H, W = y.size()
        return F.interpolate(x, size=(H, W), mode='bilinear')

    def forward(self, x):
        x_1 = self.act(self.norm(self.conv_1_1(x)))
        return self.act(self.norm(self.conv_2_1(x_1))) + x
223
+
224
class CGM(nn.Module):
    """Prompt generation module: learns `prompt_len` prompt tensors and blends
    them with softmax weights derived from the input's channel-wise mean.

    Assumes `lin_dim` equals the input channel count (default 16) — the linear
    layer consumes the (B, C) channel means of x.
    """

    def __init__(self, channel, prompt_len=3, prompt_size=96, lin_dim=16):
        super(CGM, self).__init__()
        self.prompt_param = nn.Parameter(torch.rand(1, prompt_len, channel, prompt_size, prompt_size))
        self.linear_layer = nn.Linear(lin_dim, prompt_len)
        self.conv3x3 = nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, x):
        B, C, H, W = x.shape
        emb = x.mean(dim=(-2, -1))                                  # (B, C)
        prompt_weights = F.softmax(self.linear_layer(emb), dim=1)   # (B, prompt_len)
        # Broadcast the learned prompts to the batch and weight each one.
        # (Value-identical to the original unsqueeze(0).repeat(...).squeeze(1) chain.)
        prompts = self.prompt_param.repeat(B, 1, 1, 1, 1)           # (B, L, C, ps, ps)
        weighted = prompt_weights.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * prompts
        prompt = weighted.sum(dim=1)                                # (B, C, ps, ps)
        prompt = F.interpolate(prompt, (H, W), mode="bilinear")
        return self.conv3x3(prompt)
245
+
246
class CIM(nn.Module):
    """Prompt interaction module: concatenates prompt and features, applies a
    residual block at 2C channels, then projects back to C with a 3x3 conv."""

    def __init__(self, channel):
        super(CIM, self).__init__()
        self.res = ResBlock(2 * channel, 2 * channel)
        self.conv3x3 = nn.Conv2d(2 * channel, channel, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, prompt, x):
        fused = torch.cat((prompt, x), dim=1)  # (B, 2C, H, W)
        return self.conv3x3(self.res(fused))
259
+
260
+
261
class DeformConv2d(nn.Module):
    """Deformable 2D convolution implemented in pure PyTorch.

    ``p_conv`` predicts 2*k*k sampling offsets per output location, the
    input is bilinearly interpolated at the offset positions, and a plain
    conv with stride=k consumes the re-assembled (h*k, w*k) sample map.
    With ``modulation=True`` an extra branch predicts per-tap scalar masks
    (DCNv2-style modulated deformable convolution).
    """

    def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, bias=None, modulation=False):
        """
        Args:
            inc: number of input channels.
            outc: number of output channels.
            kernel_size: kernel size k; each location samples k*k points.
            padding: zero padding applied to the input before sampling.
            stride: stride of the offset/mask predictor convs.
            bias: forwarded to the final nn.Conv2d.
            modulation: enable DCNv2-style per-tap modulation masks.
        """
        super(DeformConv2d, self).__init__()
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.zero_padding = nn.ZeroPad2d(padding)
        # stride == kernel_size because this conv consumes the (h*k, w*k)
        # grid of gathered samples, one k*k patch per output pixel.
        self.conv = nn.Conv2d(inc, outc, kernel_size=kernel_size, stride=kernel_size, bias=bias)

        # Offset predictor: 2 coords (x, y) per kernel tap. Zero-init so
        # training starts from a regular (undeformed) convolution.
        self.p_conv = nn.Conv2d(inc, 2 * kernel_size * kernel_size, kernel_size=3, padding=1, stride=stride)
        nn.init.constant_(self.p_conv.weight, 0)
        # NOTE(review): register_backward_hook is deprecated in recent
        # PyTorch in favor of register_full_backward_hook; kept as-is for
        # backward compatibility.
        self.p_conv.register_backward_hook(self._set_lr)

        self.modulation = modulation
        if modulation:
            # Mask predictor: one scalar weight per kernel tap.
            self.m_conv = nn.Conv2d(inc, kernel_size * kernel_size, kernel_size=3, padding=1, stride=stride)
            nn.init.constant_(self.m_conv.weight, 0)
            self.m_conv.register_backward_hook(self._set_lr)

    @staticmethod
    def _set_lr(module, grad_input, grad_output):
        """Backward hook: scale offset/mask gradients by 0.1.

        BUGFIX: the original version assigned generator expressions to the
        local names and discarded them, making the hook a no-op. A backward
        hook only takes effect when it *returns* the replacement
        grad_input, so the intended 0.1x effective learning rate on the
        offset/mask branches never applied.
        """
        # grad_input entries may be None (e.g. for inputs not requiring grad).
        return tuple(g * 0.1 if g is not None else None for g in grad_input)

    def forward(self, x):
        offset = self.p_conv(x)  # (b, 2N, h, w) with N = k*k
        if self.modulation:
            m = torch.sigmoid(self.m_conv(x))  # per-tap masks in (0, 1)

        dtype = offset.data.type()
        ks = self.kernel_size
        N = offset.size(1) // 2

        if self.padding:
            x = self.zero_padding(x)

        # Absolute sampling coords: base grid + kernel-tap pattern + offsets.
        p = self._get_p(offset, dtype)  # (b, 2N, h, w)

        p = p.contiguous().permute(0, 2, 3, 1)  # (b, h, w, 2N)
        # Four integer neighbours for bilinear interpolation.
        q_lt = p.detach().floor()  # left-top
        q_rb = q_lt + 1            # right-bottom

        q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2)-1), torch.clamp(q_lt[..., N:], 0, x.size(3)-1)], dim=-1).long()
        q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2)-1), torch.clamp(q_rb[..., N:], 0, x.size(3)-1)], dim=-1).long()
        q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1)  # left-bottom
        q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1)  # right-top

        # Clip fractional coordinates into the (padded) image.
        p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2)-1), torch.clamp(p[..., N:], 0, x.size(3)-1)], dim=-1)

        # Bilinear interpolation weights for the four neighbours.
        g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:]))
        g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:]))
        g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:]))
        g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:]))

        # Gather input values at the four neighbours: each (b, c, h, w, N).
        x_q_lt = self._get_x_q(x, q_lt, N)
        x_q_rb = self._get_x_q(x, q_rb, N)
        x_q_lb = self._get_x_q(x, q_lb, N)
        x_q_rt = self._get_x_q(x, q_rt, N)

        # Weighted sum = bilinear sample at each deformed tap position.
        x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \
                   g_rb.unsqueeze(dim=1) * x_q_rb + \
                   g_lb.unsqueeze(dim=1) * x_q_lb + \
                   g_rt.unsqueeze(dim=1) * x_q_rt

        if self.modulation:
            m = m.contiguous().permute(0, 2, 3, 1)  # (b, h, w, N)
            m = m.unsqueeze(dim=1)
            m = torch.cat([m for _ in range(x_offset.size(1))], dim=1)
            x_offset *= m

        # (b, c, h, w, N) -> (b, c, h*ks, w*ks), then the stride-ks conv.
        x_offset = self._reshape_x_offset(x_offset, ks)
        out = self.conv(x_offset)

        return out

    def _get_p_n(self, N, dtype):
        """Relative coordinates of the k*k kernel taps, shape (1, 2N, 1, 1)."""
        # NOTE(review): meshgrid without indexing= defaults to 'ij' and
        # warns on recent PyTorch; behavior is unchanged.
        p_n_x, p_n_y = torch.meshgrid(
            torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1),
            torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1))
        p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0)
        p_n = p_n.view(1, 2*N, 1, 1).type(dtype)

        return p_n

    def _get_p_0(self, h, w, N, dtype):
        """Base grid coordinate of each output location, shape (1, 2N, h, w)."""
        p_0_x, p_0_y = torch.meshgrid(
            torch.arange(1, h*self.stride+1, self.stride),
            torch.arange(1, w*self.stride+1, self.stride))
        p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1)
        p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1)
        p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype)

        return p_0

    def _get_p(self, offset, dtype):
        """Absolute sampling positions: base grid + tap pattern + offsets."""
        N, h, w = offset.size(1)//2, offset.size(2), offset.size(3)

        p_n = self._get_p_n(N, dtype)
        p_0 = self._get_p_0(h, w, N, dtype)
        p = p_0 + p_n + offset
        return p

    def _get_x_q(self, x, q, N):
        """Gather x at integer positions q; returns (b, c, h, w, N)."""
        b, h, w, _ = q.size()
        padded_w = x.size(3)
        c = x.size(1)
        x = x.contiguous().view(b, c, -1)  # flatten spatial dims

        index = q[..., :N]*padded_w + q[..., N:]  # row*width + col
        index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1)

        x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N)

        return x_offset

    @staticmethod
    def _reshape_x_offset(x_offset, ks):
        """Tile the N = ks*ks samples into a (h*ks, w*ks) spatial grid."""
        b, c, h, w, N = x_offset.size()
        x_offset = torch.cat([x_offset[..., s:s+ks].contiguous().view(b, c, h, w*ks) for s in range(0, N, ks)], dim=-1)
        x_offset = x_offset.contiguous().view(b, c, h*ks, w*ks)

        return x_offset
383
+
384
# NOTE(review): this class is an exact duplicate of the DeformConv2d defined
# earlier in this file and shadows it at import time; one copy should be
# removed. The same _set_lr bug fix is applied here so whichever copy
# survives behaves correctly.
class DeformConv2d(nn.Module):
    """Deformable 2D convolution implemented in pure PyTorch.

    ``p_conv`` predicts 2*k*k sampling offsets per output location, the
    input is bilinearly interpolated at the offset positions, and a plain
    conv with stride=k consumes the re-assembled (h*k, w*k) sample map.
    With ``modulation=True`` an extra branch predicts per-tap scalar masks
    (DCNv2-style modulated deformable convolution).
    """

    def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, bias=None, modulation=False):
        """
        Args:
            inc: number of input channels.
            outc: number of output channels.
            kernel_size: kernel size k; each location samples k*k points.
            padding: zero padding applied to the input before sampling.
            stride: stride of the offset/mask predictor convs.
            bias: forwarded to the final nn.Conv2d.
            modulation: enable DCNv2-style per-tap modulation masks.
        """
        super(DeformConv2d, self).__init__()
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.zero_padding = nn.ZeroPad2d(padding)
        # stride == kernel_size because this conv consumes the (h*k, w*k)
        # grid of gathered samples, one k*k patch per output pixel.
        self.conv = nn.Conv2d(inc, outc, kernel_size=kernel_size, stride=kernel_size, bias=bias)

        # Offset predictor: 2 coords (x, y) per kernel tap. Zero-init so
        # training starts from a regular (undeformed) convolution.
        self.p_conv = nn.Conv2d(inc, 2 * kernel_size * kernel_size, kernel_size=3, padding=1, stride=stride)
        nn.init.constant_(self.p_conv.weight, 0)
        # NOTE(review): register_backward_hook is deprecated in recent
        # PyTorch in favor of register_full_backward_hook; kept as-is for
        # backward compatibility.
        self.p_conv.register_backward_hook(self._set_lr)

        self.modulation = modulation
        if modulation:
            # Mask predictor: one scalar weight per kernel tap.
            self.m_conv = nn.Conv2d(inc, kernel_size * kernel_size, kernel_size=3, padding=1, stride=stride)
            nn.init.constant_(self.m_conv.weight, 0)
            self.m_conv.register_backward_hook(self._set_lr)

    @staticmethod
    def _set_lr(module, grad_input, grad_output):
        """Backward hook: scale offset/mask gradients by 0.1.

        BUGFIX: the original version assigned generator expressions to the
        local names and discarded them, making the hook a no-op. A backward
        hook only takes effect when it *returns* the replacement
        grad_input.
        """
        # grad_input entries may be None (e.g. for inputs not requiring grad).
        return tuple(g * 0.1 if g is not None else None for g in grad_input)

    def forward(self, x):
        offset = self.p_conv(x)  # (b, 2N, h, w) with N = k*k
        if self.modulation:
            m = torch.sigmoid(self.m_conv(x))  # per-tap masks in (0, 1)

        dtype = offset.data.type()
        ks = self.kernel_size
        N = offset.size(1) // 2

        if self.padding:
            x = self.zero_padding(x)

        # Absolute sampling coords: base grid + kernel-tap pattern + offsets.
        p = self._get_p(offset, dtype)  # (b, 2N, h, w)

        p = p.contiguous().permute(0, 2, 3, 1)  # (b, h, w, 2N)
        # Four integer neighbours for bilinear interpolation.
        q_lt = p.detach().floor()  # left-top
        q_rb = q_lt + 1            # right-bottom

        q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2)-1), torch.clamp(q_lt[..., N:], 0, x.size(3)-1)], dim=-1).long()
        q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2)-1), torch.clamp(q_rb[..., N:], 0, x.size(3)-1)], dim=-1).long()
        q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1)  # left-bottom
        q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1)  # right-top

        # Clip fractional coordinates into the (padded) image.
        p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2)-1), torch.clamp(p[..., N:], 0, x.size(3)-1)], dim=-1)

        # Bilinear interpolation weights for the four neighbours.
        g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:]))
        g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:]))
        g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:]))
        g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:]))

        # Gather input values at the four neighbours: each (b, c, h, w, N).
        x_q_lt = self._get_x_q(x, q_lt, N)
        x_q_rb = self._get_x_q(x, q_rb, N)
        x_q_lb = self._get_x_q(x, q_lb, N)
        x_q_rt = self._get_x_q(x, q_rt, N)

        # Weighted sum = bilinear sample at each deformed tap position.
        x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \
                   g_rb.unsqueeze(dim=1) * x_q_rb + \
                   g_lb.unsqueeze(dim=1) * x_q_lb + \
                   g_rt.unsqueeze(dim=1) * x_q_rt

        if self.modulation:
            m = m.contiguous().permute(0, 2, 3, 1)  # (b, h, w, N)
            m = m.unsqueeze(dim=1)
            m = torch.cat([m for _ in range(x_offset.size(1))], dim=1)
            x_offset *= m

        # (b, c, h, w, N) -> (b, c, h*ks, w*ks), then the stride-ks conv.
        x_offset = self._reshape_x_offset(x_offset, ks)
        out = self.conv(x_offset)

        return out

    def _get_p_n(self, N, dtype):
        """Relative coordinates of the k*k kernel taps, shape (1, 2N, 1, 1)."""
        p_n_x, p_n_y = torch.meshgrid(
            torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1),
            torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1))
        p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0)
        p_n = p_n.view(1, 2*N, 1, 1).type(dtype)

        return p_n

    def _get_p_0(self, h, w, N, dtype):
        """Base grid coordinate of each output location, shape (1, 2N, h, w)."""
        p_0_x, p_0_y = torch.meshgrid(
            torch.arange(1, h*self.stride+1, self.stride),
            torch.arange(1, w*self.stride+1, self.stride))
        p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1)
        p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1)
        p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype)

        return p_0

    def _get_p(self, offset, dtype):
        """Absolute sampling positions: base grid + tap pattern + offsets."""
        N, h, w = offset.size(1)//2, offset.size(2), offset.size(3)

        p_n = self._get_p_n(N, dtype)
        p_0 = self._get_p_0(h, w, N, dtype)
        p = p_0 + p_n + offset
        return p

    def _get_x_q(self, x, q, N):
        """Gather x at integer positions q; returns (b, c, h, w, N)."""
        b, h, w, _ = q.size()
        padded_w = x.size(3)
        c = x.size(1)
        x = x.contiguous().view(b, c, -1)  # flatten spatial dims

        index = q[..., :N]*padded_w + q[..., N:]  # row*width + col
        index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1)

        x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N)

        return x_offset

    @staticmethod
    def _reshape_x_offset(x_offset, ks):
        """Tile the N = ks*ks samples into a (h*ks, w*ks) spatial grid."""
        b, c, h, w, N = x_offset.size()
        x_offset = torch.cat([x_offset[..., s:s+ks].contiguous().view(b, c, h, w*ks) for s in range(0, N, ks)], dim=-1)
        x_offset = x_offset.contiguous().view(b, c, h*ks, w*ks)

        return x_offset
506
+
507
+
508
class BasicConv(nn.Module):
    """A conv (or transposed conv) bundled with optional BatchNorm and GELU.

    For stride-1 forward convs the padding preserves spatial size; with
    ``transpose=True`` a ConvTranspose2d is used (padding k//2 - 1).
    """

    def __init__(self, in_channel, out_channel, kernel_size, stride, bias=True, norm=False, relu=True, transpose=False):
        super(BasicConv, self).__init__()
        # A conv bias is redundant when it is followed by BatchNorm.
        use_bias = bias and not norm

        ops = []
        if transpose:
            pad = kernel_size // 2 - 1
            ops.append(nn.ConvTranspose2d(in_channel, out_channel, kernel_size,
                                          padding=pad, stride=stride, bias=use_bias))
        else:
            pad = kernel_size // 2
            ops.append(nn.Conv2d(in_channel, out_channel, kernel_size,
                                 padding=pad, stride=stride, bias=use_bias))
        if norm:
            ops.append(nn.BatchNorm2d(out_channel))
        if relu:
            # NOTE: despite the flag name, the activation used is GELU.
            ops.append(nn.GELU())
        self.main = nn.Sequential(*ops)

    def forward(self, x):
        return self.main(x)
530
+
531
+
532
class ResBlock(nn.Module):
    """Residual block: two 3x3 BasicConvs plus an identity skip connection.

    The first conv is followed by GELU, the second is linear; the input is
    added back to the conv output (channel counts must match for the add).
    """

    def __init__(self, in_channel, out_channel):
        super(ResBlock, self).__init__()
        self.main = nn.Sequential(
            BasicConv(in_channel, out_channel, kernel_size=3, stride=1, relu=True),
            BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False),
        )

    def forward(self, x):
        residual = self.main(x)
        return residual + x
542
+
543
+
544
if __name__ == '__main__':
    # BUGFIX: thop was imported at module level, so merely importing this
    # file failed when thop was not installed even though `profile` is only
    # used by this smoke test. Import it lazily inside the guard instead.
    from thop import profile

    device = "cuda" if torch.cuda.is_available() else "cpu"

    net = VISION().to(device)

    # Dummy 4-channel 512x512 batch for a shape / complexity smoke test.
    # (Renamed from `input`, which shadowed the builtin.)
    input_tensor = torch.randn(1, 4, 512, 512).to(device)
    output = net(input_tensor)

    macs, params = profile(net, inputs=(input_tensor, ))

    print('macs: ', macs, 'params: ', params)
    print('macs: %.2f G, params: %.2f M' % (macs / 1000000000.0, params / 1000000.0))
    print(output.shape)
result/readme.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ The inference results will be saved in this folder.