vera6 committed on
Commit 463545c · verified · 1 Parent(s): d3b2701

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -1,35 +1,2 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ app/1510dnpnr-15K.ckpt filter=lfs diff=lfs merge=lfs -text
+ app/*.ckpt filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -1,162 +1,162 @@
  # Byte-compiled / optimized / DLL files
  __pycache__/
  *.py[cod]
  *$py.class

  # C extensions
  *.so

  # Distribution / packaging
  .Python
  build/
  develop-eggs/
  dist/
  downloads/
  eggs/
  .eggs/
  lib/
  lib64/
  parts/
  sdist/
  var/
  wheels/
  share/python-wheels/
  *.egg-info/
  .installed.cfg
  *.egg
  MANIFEST

  # PyInstaller
  # Usually these files are written by a python script from a template
  # before PyInstaller builds the exe, so as to inject date/other infos into it.
  *.manifest
  *.spec

  # Installer logs
  pip-log.txt
  pip-delete-this-directory.txt

  # Unit test / coverage reports
  htmlcov/
  .tox/
  .nox/
  .coverage
  .coverage.*
  .cache
  nosetests.xml
  coverage.xml
  *.cover
  *.py,cover
  .hypothesis/
  .pytest_cache/
  cover/

  # Translations
  *.mo
  *.pot

  # Django stuff:
  *.log
  local_settings.py
  db.sqlite3
  db.sqlite3-journal

  # Flask stuff:
  instance/
  .webassets-cache

  # Scrapy stuff:
  .scrapy

  # Sphinx documentation
  docs/_build/

  # PyBuilder
  .pybuilder/
  target/

  # Jupyter Notebook
  .ipynb_checkpoints

  # IPython
  profile_default/
  ipython_config.py

  # pyenv
  # For a library or package, you might want to ignore these files since the code is
  # intended to run in multiple environments; otherwise, check them in:
  # .python-version

  # pipenv
  # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
  # However, in case of collaboration, if having platform-specific dependencies or dependencies
  # having no cross-platform support, pipenv may install dependencies that don't work, or not
  # install all needed dependencies.
  #Pipfile.lock

  # poetry
  # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
  # This is especially recommended for binary packages to ensure reproducibility, and is more
  # commonly ignored for libraries.
  # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
  #poetry.lock

  # pdm
  # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
  #pdm.lock
  # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
  # in version control.
  # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
  .pdm.toml
  .pdm-python
  .pdm-build/

  # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
  __pypackages__/

  # Celery stuff
  celerybeat-schedule
  celerybeat.pid

  # SageMath parsed files
  *.sage.py

  # Environments
  .env
  .venv
  env/
  venv/
  ENV/
  env.bak/
  venv.bak/

  # Spyder project settings
  .spyderproject
  .spyproject

  # Rope project settings
  .ropeproject

  # mkdocs documentation
  /site

  # mypy
  .mypy_cache/
  .dmypy.json
  dmypy.json

  # Pyre type checker
  .pyre/

  # pytype static type analyzer
  .pytype/

  # Cython debug symbols
  cython_debug/

  # PyCharm
  # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
  # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
  # and can be added to the global gitignore or merged into this file. For a more nuclear
  # option (not recommended) you can uncomment the following to ignore the entire idea folder.
  #.idea/
Dockerfile CHANGED
@@ -1,33 +1,33 @@
- FROM python:3.10.14-bookworm
-
- ARG USER_UID=10002
- ARG USER_GID=$USER_UID
- ARG USERNAME=modelapi
-
- RUN groupadd --gid $USER_GID $USERNAME \
-     && useradd --uid $USER_UID --gid $USER_GID -m $USERNAME
-
- # Copy required files
- RUN mkdir -p /modelapi && mkdir -p /home/$USERNAME/.modelapi
- COPY app /modelapi/app
- COPY sgmse /modelapi/sgmse
- COPY pyproject.toml /modelapi/pyproject.toml
-
- ENV CUDA_HOME=/usr/local/cuda-12.6
-
- # Setup permissions
- RUN chown -R $USER_UID:$USER_GID /modelapi \
-     && chown -R $USER_UID:$USER_GID /home/$USERNAME/.modelapi \
-     && chown -R $USER_UID:$USER_GID /home/$USERNAME \
-     && chmod -R 755 /home/$USERNAME \
-     && chmod -R 755 /modelapi \
-     && chmod -R 755 /home/$USERNAME/.modelapi
-
- # Change to the user and do subnet installation
- USER $USERNAME
-
- RUN /bin/bash -c "python3 -m venv /modelapi/.venv && source /modelapi/.venv/bin/activate && pip3 install -e /modelapi/."
-
- EXPOSE 6500
-
  CMD ["/bin/bash", "-c", "source /modelapi/.venv/bin/activate && python3 /modelapi/app/run.py"]
+ FROM python:3.10.14-bookworm
+
+ ARG USER_UID=10002
+ ARG USER_GID=$USER_UID
+ ARG USERNAME=modelapi
+
+ RUN groupadd --gid $USER_GID $USERNAME \
+     && useradd --uid $USER_UID --gid $USER_GID -m $USERNAME
+
+ # Copy required files
+ RUN mkdir -p /modelapi && mkdir -p /home/$USERNAME/.modelapi
+ COPY app /modelapi/app
+ COPY sgmse /modelapi/sgmse
+ COPY pyproject.toml /modelapi/pyproject.toml
+
+ ENV CUDA_HOME=/usr/local/cuda-12.6
+
+ # Setup permissions
+ RUN chown -R $USER_UID:$USER_GID /modelapi \
+     && chown -R $USER_UID:$USER_GID /home/$USERNAME/.modelapi \
+     && chown -R $USER_UID:$USER_GID /home/$USERNAME \
+     && chmod -R 755 /home/$USERNAME \
+     && chmod -R 755 /modelapi \
+     && chmod -R 755 /home/$USERNAME/.modelapi
+
+ # Change to the user
+ USER $USERNAME
+
+ RUN /bin/bash -c "python3 -m venv /modelapi/.venv && source /modelapi/.venv/bin/activate && pip3 install -e /modelapi/."
+
+ EXPOSE 6500
+
  CMD ["/bin/bash", "-c", "source /modelapi/.venv/bin/activate && python3 /modelapi/app/run.py"]
LICENSE CHANGED
@@ -1,34 +1,21 @@
- MIT License
-
- Copyright (c) 2024 synapsec.ai
-
- Permission is hereby granted, free of charge, to any person obtaining a copy
- of this software and associated documentation files (the "Software"), to deal
- in the Software without restriction, including without limitation the rights
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- copies of the Software, and to permit persons to whom the Software is
- furnished to do so, subject to the following conditions:
-
- The above copyright notice and this permission notice shall be included in all
- copies or substantial portions of the Software.
-
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- SOFTWARE.
-
- ---
-
- ### Third-Party Code:
-
- Portions of this software are derived from code in the following project(s):
-
- - speech-enhancement-sgmse by sp-uhh (MIT License)
-     - Repository: https://huggingface.co/sp-uhh/speech-enhancement-sgmse
-     - Copyright (c) 2022 Signal Processing (SP), Universität Hamburg
-     - Licensed under the MIT License (included in the `THIRD_PARTY_LICENSES` file)
-
- ---
+ MIT License
+
+ Copyright (c) 2024 synapsec.ai
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1 +1,30 @@
- DENOISING speech enhancement model
+ # Container Template for SoundsRight Subnet Miners
+
+ Miners in [Bittensor's](https://bittensor.com/) [SoundsRight Subnet](https://github.com/synapsec-ai/soundsright-subnet) must containerize their models before uploading them to HuggingFace. This repo serves as a template.
+
+ The branches `DENOISING_16000HZ` and `DEREVERBERATION_16000HZ` contain this template fitted with [SGMSE+](https://huggingface.co/sp-uhh/speech-enhancement-sgmse) for 16 kHz speech enhancement, and the branches `DENOISING_48000HZ` and `DEREVERBERATION_48000HZ` are fitted with SGMSE+ for 48 kHz speech enhancement, covering the denoising and dereverberation tasks. These branches are also helpful references for incorporating your own model.
+
+ The `main` branch contains a template for a container that spins up an API to communicate with the validator. The following entrypoints cannot be altered:
+
+ 1. `/status/` : Communicates API status
+ 2. `/prepare/` : Makes necessary preparations (downloading checkpoints, etc.) and initializes the model
+ 3. `/upload-audio/` : Uploads audio files and saves them to the noisy audio directory
+ 4. `/enhance/` : Initializes the model, enhances the audio files, and saves them to the enhanced audio directory
+ 5. `/download-enhanced/` : Downloads the enhanced audio files
+ 6. `/reset/` : Removes all existing audio files in preparation for another batch of enhancement
+
+ To add your own model to this template, a miner must do a few things:
+
+ 1. Add the model files under the `model` directory.
+ 2. Modify the `modelapi.prepare` method in `app/app.py` with the preparations necessary to initialize your model.
+ 3. Modify the `modelapi.enhance` method in `app/app.py` with the logic your model uses to enhance audio.
+ 4. Update `dependencies` in `pyproject.toml` with the dependencies used by your model.
+ 5. If you have directories other than `app` in your repository, be sure to modify the `Dockerfile` accordingly (see line 12 of the `Dockerfile` for how to do this).
+ 6. Cite your sources (if applicable).
+
+ For your model to be processed by validators, there are a few formatting requirements. Note that the template has already been formatted to fit these guidelines.
+
+ 1. API endpoints must be as outlined above.
+ 2. The port must be 6500.
+ 3. The container must be configured to run as a non-root user.
+ 4. The container must not rely on network access to function as intended.
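A minimal sketch of how a validator-style client might exercise these endpoints once the container is running, assuming it is reachable at `http://localhost:6500` (the endpoint paths and the `files` form field match `app/app.py` in this commit; the file names and the `requests` client are illustrative assumptions, not part of the repo):

```python
import requests

BASE = "http://localhost:6500"  # assumption: container published on this host/port

# 1. Check that the API is up, then initialize the model
print(requests.get(f"{BASE}/status/").json())    # expect {"container_running": True}
print(requests.post(f"{BASE}/prepare/").json())  # expect {"preparations": True}

# 2. Upload one or more noisy .wav files (the multipart field name is "files")
with open("noisy.wav", "rb") as fh:  # hypothetical input file
    resp = requests.post(
        f"{BASE}/upload-audio/",
        files=[("files", ("noisy.wav", fh, "audio/wav"))],
    )
print(resp.json())

# 3. Run enhancement, then download the results as a zip archive
requests.post(f"{BASE}/enhance/")
with open("enhanced_audio_files.zip", "wb") as out:
    out.write(requests.get(f"{BASE}/download-enhanced/").content)

# 4. Clear both audio directories before the next batch
requests.post(f"{BASE}/reset/")
```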
app/app.py CHANGED
@@ -1,213 +1,205 @@
- import fastapi
- import shutil
- import os
- import zipfile
- import io
- import uvicorn
- import threading
  import glob
- from typing import List
  import torch
  import gdown
  from soundfile import write
  from torchaudio import load
  from librosa import resample
  import logging
-
  logging.basicConfig(level=logging.DEBUG)

  from sgmse import ScoreModel
  from sgmse.util.other import pad_spec

-
  class ModelAPI:
-
      def __init__(self, host, port):
-
-         self.host = host
          self.port = port
-
          self.base_path = os.path.join(os.path.expanduser("~"), ".modelapi")
          self.noisy_audio_path = os.path.join(self.base_path, "noisy_audio")
          self.enhanced_audio_path = os.path.join(self.base_path, "enhanced_audio")
          app_dir = os.path.dirname(os.path.abspath(__file__))
-         self.ckpt_path = glob.glob(os.path.join(app_dir, "*.ckpt"))[0]
-         self.device = "cuda" if torch.cuda.is_available() else "cpu"
          self.corrector = "ald"
          self.corrector_steps = 1
          self.snr = 0.5
          self.N = 30

          for audio_path in [self.noisy_audio_path, self.enhanced_audio_path]:
              if not os.path.exists(audio_path):
                  os.makedirs(audio_path)
-
              for filename in os.listdir(audio_path):
                  file_path = os.path.join(audio_path, filename)
-
                  try:
                      if os.path.isfile(file_path) or os.path.islink(file_path):
-                         os.unlink(file_path)
                      elif os.path.isdir(file_path):
-                         shutil.rmtree(file_path)
                  except Exception as e:
                      raise e
-
          self.app = fastapi.FastAPI()
          self._setup_routes()
-
      def _prepare(self):
          """Miners should modify this function to fit their fine-tuned models.
-
          This function will make any preparations necessary to initialize the
          speech enhancement model (i.e. downloading checkpoint files, etc.)
          """
-
          self.model = ScoreModel.load_from_checkpoint(self.ckpt_path, self.device)
          self.model.t_eps = 0.03
          self.model.eval()
-
      def _enhance(self):
          """
          Miners should modify this function to fit their fine-tuned models.
-
          This function will:
          1. Open each noisy .wav file
-         2. Enhance the audio with the model
-         3. Save the enhanced audio in .wav format to ModelAPI.enhanced_audio_path
          """
-
-         if self.model.backbone == "ncsnpp_48k":
              target_sr = 48000
              pad_mode = "reflection"
-         elif self.model.backbone == "ncsnpp_v2":
              target_sr = 16000
              pad_mode = "reflection"
-             print("using ncsnpp_v2")
          else:
              target_sr = 16000
              pad_mode = "zero_pad"
-
-         noisy_files = sorted(glob.glob(os.path.join(self.noisy_audio_path, "*.wav")))
-         for noisy_file in noisy_files:
-
              filename = noisy_file.replace(self.noisy_audio_path, "")
              filename = filename[1:] if filename.startswith("/") else filename

              y, sr = load(noisy_file)
-
              if sr != target_sr:
                  y = torch.tensor(resample(y.numpy(), orig_sr=sr, target_sr=target_sr))

-             T_orig = y.size(1)
-
              # Normalize
              norm_factor = y.abs().max()
              y = y / norm_factor
-
              # Prepare DNN input
-             Y = torch.unsqueeze(
-                 self.model._forward_transform(self.model._stft(y.to(self.device))), 0
-             )
              Y = pad_spec(Y, mode=pad_mode)
-
              # Reverse sampling
-             if self.model.sde.__class__.__name__ == "OUVESDE":
-                 if self.model.sde.sampler_type == "pc":
-                     sampler = self.model.get_pc_sampler(
-                         "reverse_diffusion",
-                         self.corrector,
-                         Y.to(self.device),
-                         N=self.N,
-                         corrector_steps=self.corrector_steps,
-                         snr=self.snr,
-                     )
-                 elif self.model.sde.sampler_type == "ode":
                      sampler = self.model.get_ode_sampler(Y.to(self.device), N=self.N)
                  else:
-                     raise ValueError(f"Sampler type {args.sampler_type} not supported")
-             elif self.model.sde.__class__.__name__ == "SBVESDE":
-                 sampler_type = (
-                     "ode"
-                     if self.model.sde.sampler_type == "pc"
-                     else self.model.sde.sampler_type
-                 )
-                 sampler = self.model.get_sb_sampler(
-                     sde=self.model.sde, y=Y.cuda(), sampler_type=sampler_type
-                 )
              else:
-                 raise ValueError(
-                     f"SDE {self.model.sde.__class__.__name__} not supported"
-                 )
-
              sample, _ = sampler()
-
              x_hat = self.model.to_audio(sample.squeeze(), T_orig)
-
              x_hat = x_hat * norm_factor
-
-             os.makedirs(
-                 os.path.dirname(os.path.join(self.enhanced_audio_path, filename)),
-                 exist_ok=True,
-             )
-             write(
-                 os.path.join(self.enhanced_audio_path, filename),
-                 x_hat.cpu().numpy(),
-                 target_sr,
-             )
-
      def _setup_routes(self):
          self.app.get("/status/")(self.get_status)
          self.app.post("/prepare/")(self.prepare)
          self.app.post("/upload-audio/")(self.upload_audio)
          self.app.post("/enhance/")(self.enhance_audio)
          self.app.get("/download-enhanced/")(self.download_enhanced)
          self.app.post("/reset/")(self.reset)
-
      def get_status(self):
          try:
              return {"container_running": True}
          except Exception as e:
              logging.error(f"Error getting status: {e}")
-             raise fastapi.HTTPException(
-                 status_code=500, detail="An error occurred while fetching API status."
-             )
-
      def prepare(self):
          try:
              self._prepare()
-             return {"preparations": True}
          except Exception as e:
              logging.error(f"Error during preparations: {e}")
-             return fastapi.HTTPException(
-                 status_code=500, detail="An error occurred while fetching API status."
-             )
-
      def upload_audio(self, files: List[fastapi.UploadFile] = fastapi.File(...)):
-
          uploaded_files = []
-
          for file in files:
-             try:
                  file_path = os.path.join(self.noisy_audio_path, file.filename)
-
                  with open(file_path, "wb") as f:
-                     while contents := file.file.read(1024 * 1024):
                          f.write(contents)
-
-                 uploaded_files.append(file.filename)
-
              except Exception as e:
-                 logging.error(f"Error uploading files: {e}")
-                 raise fastapi.HTTPException(
-                     status_code=500,
-                     detail="An error occurred while uploading the noisy files.",
-                 )
              finally:
                  file.file.close()
-
          print(f"uploaded files: {uploaded_files}")
-
          return {"uploaded_files": uploaded_files, "status": True}

      def enhance_audio(self):
@@ -215,44 +207,39 @@ class ModelAPI:
              # Enhance audio
              self._enhance()
              # Obtain list of file paths for enhanced audio
-             wav_files = glob.glob(os.path.join(self.enhanced_audio_path, "*.wav"))
              # Extract just the file names
              enhanced_files = [os.path.basename(file) for file in wav_files]
              return {"status": True}
-
          except Exception as e:
              print(f"Exception occurred during enhancement: {e}")
-             raise fastapi.HTTPException(
-                 status_code=500,
-                 detail="An error occurred while enhancing the noisy files.",
-             )
-
      def download_enhanced(self):
          try:
              zip_buffer = io.BytesIO()

              with zipfile.ZipFile(zip_buffer, "w") as zip_file:
-                 for wav_file in glob.glob(
-                     os.path.join(self.enhanced_audio_path, "*.wav")
-                 ):
                      zip_file.write(wav_file, arcname=os.path.basename(wav_file))

              zip_buffer.seek(0)

              return fastapi.responses.StreamingResponse(
                  iter([zip_buffer.getvalue()]),  # Stream the in-memory content
                  media_type="application/zip",
-                 headers={
-                     "Content-Disposition": "attachment; filename=enhanced_audio_files.zip"
-                 },
              )

          except Exception as e:
              logging.error(f"Error during enhanced files download: {e}")
-             raise fastapi.HTTPException(
-                 status_code=500,
-                 detail=f"An error occurred while creating the download file: {str(e)}",
-             )
-
      def reset(self):
          """
          Removes all audio files in preparation for another batch of enhancement.
@@ -268,17 +255,9 @@ class ModelAPI:
                  os.remove(filepath)
              except Exception as e:
                  print(f"Error removing {filepath}: {e}")
-                 return {
-                     "status": False,
-                     "noisy": os.listdir(self.noisy_audio_path),
-                     "enhanced": os.listdir(self.enhanced_audio_path),
-                 }
-         return {
-             "status": True,
-             "noisy": os.listdir(self.noisy_audio_path),
-             "enhanced": os.listdir(self.enhanced_audio_path),
-         }

      def run(self):
-
-         uvicorn.run(self.app, host=self.host, port=self.port)
+ import fastapi
+ import shutil
+ import os
+ import zipfile
+ import io
+ import uvicorn
  import glob
+ from typing import List
  import torch
  import gdown
  from soundfile import write
  from torchaudio import load
  from librosa import resample
  import logging
  logging.basicConfig(level=logging.DEBUG)

  from sgmse import ScoreModel
  from sgmse.util.other import pad_spec

  class ModelAPI:
+
      def __init__(self, host, port):
+
+         self.host = host
          self.port = port
+
          self.base_path = os.path.join(os.path.expanduser("~"), ".modelapi")
          self.noisy_audio_path = os.path.join(self.base_path, "noisy_audio")
          self.enhanced_audio_path = os.path.join(self.base_path, "enhanced_audio")
          app_dir = os.path.dirname(os.path.abspath(__file__))
+
+         ckpt_files = glob.glob(os.path.join(app_dir, "*.ckpt"))
+
+         if not ckpt_files:
+             raise FileNotFoundError("No .ckpt file found in app_dir.")
+         elif len(ckpt_files) > 1:
+             raise RuntimeError("Multiple .ckpt files found in app_dir. Please keep only one.")
+         else:
+             self.ckpt_path = ckpt_files[0]
+
+         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
          self.corrector = "ald"
          self.corrector_steps = 1
          self.snr = 0.5
          self.N = 30

+         # Create directories if they do not exist
          for audio_path in [self.noisy_audio_path, self.enhanced_audio_path]:
              if not os.path.exists(audio_path):
                  os.makedirs(audio_path)
+
+             # Loop through all the files and subdirectories in the directory
              for filename in os.listdir(audio_path):
                  file_path = os.path.join(audio_path, filename)
+
+                 # Check if it's a file or directory and remove accordingly
                  try:
                      if os.path.isfile(file_path) or os.path.islink(file_path):
+                         os.unlink(file_path)  # Remove the file or link
                      elif os.path.isdir(file_path):
+                         shutil.rmtree(file_path)  # Remove the directory and its contents
                  except Exception as e:
                      raise e
+
          self.app = fastapi.FastAPI()
          self._setup_routes()
+
      def _prepare(self):
          """Miners should modify this function to fit their fine-tuned models.
+
          This function will make any preparations necessary to initialize the
          speech enhancement model (i.e. downloading checkpoint files, etc.)
          """
+         # Initialize model
          self.model = ScoreModel.load_from_checkpoint(self.ckpt_path, self.device)
          self.model.t_eps = 0.03
          self.model.eval()
      def _enhance(self):
          """
          Miners should modify this function to fit their fine-tuned models.
+
          This function will:
          1. Open each noisy .wav file
+         2. Enhance the audio with the model
+         3. Save the enhanced audio in .wav format to ModelAPI.enhanced_audio_path
          """
+
+         # Check if the model is trained on 48 kHz data
+         if self.model.backbone == 'ncsnpp_48k':
              target_sr = 48000
              pad_mode = "reflection"
+         elif self.model.backbone == 'ncsnpp_v2':
              target_sr = 16000
              pad_mode = "reflection"
          else:
              target_sr = 16000
              pad_mode = "zero_pad"
+
+         # Define file paths for all noisy files to be enhanced
+         noisy_files = sorted(glob.glob(os.path.join(self.noisy_audio_path, '*.wav')))
+         for noisy_file in noisy_files:
+
              filename = noisy_file.replace(self.noisy_audio_path, "")
              filename = filename[1:] if filename.startswith("/") else filename

+             # Load wav
              y, sr = load(noisy_file)
+             # Resample if necessary
              if sr != target_sr:
                  y = torch.tensor(resample(y.numpy(), orig_sr=sr, target_sr=target_sr))

+             T_orig = y.size(1)

              # Normalize
              norm_factor = y.abs().max()
              y = y / norm_factor

              # Prepare DNN input
+             Y = torch.unsqueeze(self.model._forward_transform(self.model._stft(y.to(self.device))), 0)
              Y = pad_spec(Y, mode=pad_mode)
+
              # Reverse sampling
+             if self.model.sde.__class__.__name__ == 'OUVESDE':
+                 if self.model.sde.sampler_type == 'pc':
+                     sampler = self.model.get_pc_sampler('reverse_diffusion', self.corrector, Y.to(self.device), N=self.N,
+                                                         corrector_steps=self.corrector_steps, snr=self.snr)
+                 elif self.model.sde.sampler_type == 'ode':
                      sampler = self.model.get_ode_sampler(Y.to(self.device), N=self.N)
                  else:
+                     raise ValueError(f"Sampler type {self.model.sde.sampler_type} not supported")
+             elif self.model.sde.__class__.__name__ == 'SBVESDE':
+                 sampler_type = 'ode' if self.model.sde.sampler_type == 'pc' else self.model.sde.sampler_type
+                 sampler = self.model.get_sb_sampler(sde=self.model.sde, y=Y.cuda(), sampler_type=sampler_type)
              else:
+                 raise ValueError(f"SDE {self.model.sde.__class__.__name__} not supported")

              sample, _ = sampler()
+
+             # Backward transform in time domain
              x_hat = self.model.to_audio(sample.squeeze(), T_orig)
+
+             # Renormalize
              x_hat = x_hat * norm_factor
+
+             # Write enhanced wav file
+             os.makedirs(os.path.dirname(os.path.join(self.enhanced_audio_path, filename)), exist_ok=True)
+             write(os.path.join(self.enhanced_audio_path, filename), x_hat.cpu().numpy(), target_sr)
+
      def _setup_routes(self):
+         """
+         Setup API routes:
+
+         /status/ : Communicates API status
+         /upload-audio/ : Upload audio files, save to noisy audio directory
+         /enhance/ : Enhance audio files, save to enhanced audio directory
+         /download-enhanced/ : Download enhanced audio files
+         /reset/ : Reset noisy and enhanced file cache
+         """
          self.app.get("/status/")(self.get_status)
          self.app.post("/prepare/")(self.prepare)
          self.app.post("/upload-audio/")(self.upload_audio)
          self.app.post("/enhance/")(self.enhance_audio)
          self.app.get("/download-enhanced/")(self.download_enhanced)
          self.app.post("/reset/")(self.reset)
+
      def get_status(self):
          try:
              return {"container_running": True}
          except Exception as e:
              logging.error(f"Error getting status: {e}")
+             raise fastapi.HTTPException(status_code=500, detail="An error occurred while fetching API status.")
+
      def prepare(self):
          try:
              self._prepare()
+             return {'preparations': True}
          except Exception as e:
              logging.error(f"Error during preparations: {e}")
+             return fastapi.HTTPException(status_code=500, detail="An error occurred while fetching API status.")
+
      def upload_audio(self, files: List[fastapi.UploadFile] = fastapi.File(...)):
+
          uploaded_files = []
+
          for file in files:
+             try:
+                 # Define the path to save the file
                  file_path = os.path.join(self.noisy_audio_path, file.filename)
+
+                 # Save the uploaded file
                  with open(file_path, "wb") as f:
+                     while contents := file.file.read(1024*1024):
                          f.write(contents)
+
+                 # Append the file name to the list of uploaded files
+                 uploaded_files.append(file.filename)
+
              except Exception as e:
+                 logging.error(f"Error uploading files: {e}")
+                 raise fastapi.HTTPException(status_code=500, detail="An error occurred while uploading the noisy files.")
              finally:
                  file.file.close()
+
          print(f"uploaded files: {uploaded_files}")
+
          return {"uploaded_files": uploaded_files, "status": True}

      def enhance_audio(self):

              # Enhance audio
              self._enhance()
              # Obtain list of file paths for enhanced audio
+             wav_files = glob.glob(os.path.join(self.enhanced_audio_path, '*.wav'))
              # Extract just the file names
              enhanced_files = [os.path.basename(file) for file in wav_files]
              return {"status": True}
+
          except Exception as e:
              print(f"Exception occurred during enhancement: {e}")
+             raise fastapi.HTTPException(status_code=500, detail="An error occurred while enhancing the noisy files.")
+
      def download_enhanced(self):
          try:
+             # Create an in-memory zip file to hold all the enhanced audio files
              zip_buffer = io.BytesIO()

              with zipfile.ZipFile(zip_buffer, "w") as zip_file:
+                 # Add each .wav file in the enhanced_audio_path directory to the zip file
+                 for wav_file in glob.glob(os.path.join(self.enhanced_audio_path, '*.wav')):
                      zip_file.write(wav_file, arcname=os.path.basename(wav_file))
+
+             # Make sure to seek back to the start of the BytesIO object before sending it
              zip_buffer.seek(0)

+             # Send the zip file to the client as a downloadable file
              return fastapi.responses.StreamingResponse(
                  iter([zip_buffer.getvalue()]),  # Stream the in-memory content
                  media_type="application/zip",
+                 headers={"Content-Disposition": "attachment; filename=enhanced_audio_files.zip"}
              )

          except Exception as e:
              logging.error(f"Error during enhanced files download: {e}")
+             raise fastapi.HTTPException(status_code=500, detail=f"An error occurred while creating the download file: {str(e)}")
+
      def reset(self):
          """
          Removes all audio files in preparation for another batch of enhancement.

                  os.remove(filepath)
              except Exception as e:
                  print(f"Error removing {filepath}: {e}")
+                 return {"status": False, "noisy": os.listdir(self.noisy_audio_path), "enhanced": os.listdir(self.enhanced_audio_path)}
+         return {"status": True, "noisy": os.listdir(self.noisy_audio_path), "enhanced": os.listdir(self.enhanced_audio_path)}

      def run(self):
+
+         uvicorn.run(self.app, host=self.host, port=self.port)
app/miner_4.ckpt CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:91992fe2205bee9a15c7c4302b7053e3ef8a9d15889c26bb19ec8529fc8a0903
- size 1312970157
  version https://git-lfs.github.com/spec/v1
+ oid sha256:b546fe7ee37fa22db34470deff369aba15f31406da4a43d04fa80bf485f316d5
+ size 1312981921
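The pointer file above is Git LFS metadata: the repo stores only a SHA-256 digest (`oid`) and a byte `size`, so this commit repoints `app/miner_4.ckpt` at a different ~1.3 GB checkpoint. A quick, illustrative way to verify a pulled checkpoint against the new pointer (the expected values are copied from the diff above; the script itself is not part of the repo):

```python
import hashlib
import os

EXPECTED_OID = "b546fe7ee37fa22db34470deff369aba15f31406da4a43d04fa80bf485f316d5"
EXPECTED_SIZE = 1312981921

def sha256_of(path, chunk=1024 * 1024):
    # Stream in 1 MiB chunks so the ~1.3 GB file never sits in memory at once
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        while block := f.read(chunk):
            digest.update(block)
    return digest.hexdigest()

path = "app/miner_4.ckpt"  # the real file, after `git lfs pull`
assert os.path.getsize(path) == EXPECTED_SIZE, "size mismatch"
assert sha256_of(path) == EXPECTED_OID, "checksum mismatch"
```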
app/run.py CHANGED
@@ -1,14 +1,8 @@
- import sys
- from pathlib import Path
-
- # Add parent directory to PYTHONPATH
- sys.path.append(str(Path(__file__).resolve().parent.parent))
-
- from app import ModelAPI
-
- api = ModelAPI(
-     host = "0.0.0.0",
-     port = 6500
- )
-
  api.run()
+ from app import ModelAPI
+
+ api = ModelAPI(
+     host = "0.0.0.0",
+     port = 6500
+ )
+
  api.run()
pyproject.toml CHANGED
@@ -1,58 +1,57 @@
- [build-system]
- requires = ["setuptools", "wheel"]
- build-backend = "setuptools.build_meta"
-
- [project]
- name = "modelapi"
- version = "1.0.0"
- description = "This project implements a container for a fine-tuned audio enhancement model."
- readme = { file = "README.md", content-type = "text/markdown" }
- license = { file = "LICENSE" }
- classifiers = [
-     "Development Status :: 3 - Beta",
-     "Intended Audience :: Developers",
-     "Topic :: Software Development :: Build Tools",
-     "License :: OSI Approved :: MIT License",
-     "Programming Language :: Python :: 3 :: Only",
-     "Programming Language :: Python :: 3.10",
-     "Topic :: Scientific/Engineering",
-     "Topic :: Scientific/Engineering :: Mathematics",
-     "Topic :: Scientific/Engineering :: Artificial Intelligence",
-     "Topic :: Software Development",
-     "Topic :: Software Development :: Libraries",
-     "Topic :: Software Development :: Libraries :: Python Modules"
- ]
- requires-python = ">=3.10,<3.12"
- dependencies = [
-     "fastapi==0.115.5",
-     "uvicorn==0.32.0",
-     "python-multipart==0.0.17",
-     "h5py==3.10.0",
-     "ipympl==0.9.3",
-     "librosa==0.10.1",
-     "ninja==1.11.1.1",
-     "numpy==1.24.4",
-     "pandas==2.0.3",
-     "pesq==0.0.4",
-     "pillow==10.2.0",
-     "protobuf==4.25.2",
-     "pyarrow==15.0.0",
-     "pyroomacoustics==0.7.3",
-     "pystoi==0.4.1",
-     "pytorch-lightning==2.1.4",
-     "scipy==1.10.1",
-     "setuptools==44.0.0",
-     "seaborn==0.13.2",
-     "torch==2.2.0",
-     "torch-ema==0.3",
-     "torchaudio==2.2.0",
-     "torchvision==0.17.0",
-     "torchinfo==1.8.0",
-     "torchsde==0.2.6",
-     "gdown==5.2.0",
-     "huggingface-hub==0.31.4",
-     "torch_pesq==0.1.2",
- ]
-
- [tool.setuptools.packages.find]
- include = ["app","model"]
+ [build-system]
+ requires = ["setuptools", "wheel"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "modelapi"
+ version = "1.0.0"
+ description = "This project implements a container for a fine-tuned audio enhancement model."
+ readme = { file = "README.md", content-type = "text/markdown" }
+ license = { file = "LICENSE" }
+ classifiers = [
+     "Development Status :: 3 - Beta",
+     "Intended Audience :: Developers",
+     "Topic :: Software Development :: Build Tools",
+     "License :: OSI Approved :: MIT License",
+     "Programming Language :: Python :: 3 :: Only",
+     "Programming Language :: Python :: 3.10",
+     "Topic :: Scientific/Engineering",
+     "Topic :: Scientific/Engineering :: Mathematics",
+     "Topic :: Scientific/Engineering :: Artificial Intelligence",
+     "Topic :: Software Development",
+     "Topic :: Software Development :: Libraries",
+     "Topic :: Software Development :: Libraries :: Python Modules"
+ ]
+ requires-python = ">=3.10,<3.12"
+ dependencies = [
+     "fastapi==0.115.5",
+     "uvicorn==0.32.0",
+     "python-multipart==0.0.17",
+     "h5py==3.10.0",
+     "ipympl==0.9.3",
+     "librosa==0.10.1",
+     "ninja==1.11.1.1",
+     "numpy==1.24.4",
+     "pandas==2.0.3",
+     "pesq==0.0.4",
+     "pillow==10.2.0",
+     "protobuf==4.25.2",
+     "pyarrow==15.0.0",
+     "pyroomacoustics==0.7.3",
+     "pystoi==0.4.1",
+     "pytorch-lightning==2.5.1",
+     "scipy==1.10.1",
+     "setuptools==44.0.0",
+     "seaborn==0.13.2",
+     "torch==2.2.0",
+     "torch-ema==0.3",
+     "torchaudio==2.2.0",
+     "torchvision==0.17.0",
+     "torchinfo==1.8.0",
+     "torchsde==0.2.6",
+     "gdown==5.2.0",
+     "torch_pesq==0.1.2"
+ ]
+
+ [tool.setuptools.packages.find]
+ include = ["app","model", "sgmse"]
sgmse/backbones/__init__.py CHANGED
@@ -1,7 +1,7 @@
  from .shared import BackboneRegistry
  from .ncsnpp import NCSNpp
  from .ncsnpp_v2 import NCSNpp_v2
  from .ncsnpp_48k import NCSNpp_48k
  from .dcunet import DCUNet

  __all__ = ['BackboneRegistry', 'NCSNpp', 'NCSNpp_v2', 'NCSNpp_48k', 'DCUNet']
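`BackboneRegistry` is re-exported here so that each backbone can register itself by name (see the `@BackboneRegistry.register("dcunet")` decorator in `dcunet.py` below). The actual implementation lives in `sgmse/backbones/shared.py`, which is not part of this diff; as a rough sketch of how a decorator-based registry of this kind typically works (illustrative only, not the sgmse code):

```python
class BackboneRegistry:
    """Maps a string name (e.g. "dcunet") to the backbone class that registered it."""
    _registry = {}

    @classmethod
    def register(cls, name):
        def decorator(backbone_cls):
            cls._registry[name] = backbone_cls
            return backbone_cls  # leave the decorated class itself untouched
        return decorator

    @classmethod
    def get_by_name(cls, name):
        return cls._registry[name]

# Usage mirrors the decorator seen in dcunet.py:
# @BackboneRegistry.register("dcunet")
# class DCUNet(nn.Module): ...
```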
sgmse/backbones/dcunet.py CHANGED
@@ -1,627 +1,627 @@
- from functools import partial
- import numpy as np
-
- import torch
- from torch import nn, Tensor
- from torch.nn.modules.batchnorm import _BatchNorm
-
- from .shared import BackboneRegistry, ComplexConv2d, ComplexConvTranspose2d, ComplexLinear, \
-     DiffusionStepEmbedding, GaussianFourierProjection, FeatureMapDense, torch_complex_from_reim
-
-
- def get_activation(name):
-     if name == "silu":
-         return nn.SiLU
-     elif name == "relu":
-         return nn.ReLU
-     elif name == "leaky_relu":
-         return nn.LeakyReLU
-     else:
-         raise NotImplementedError(f"Unknown activation: {name}")
-
-
- class BatchNorm(_BatchNorm):
-     def _check_input_dim(self, input):
-         if input.dim() < 2 or input.dim() > 4:
-             raise ValueError("expected 4D or 3D input (got {}D input)".format(input.dim()))
-
-
- class OnReIm(nn.Module):
-     def __init__(self, module_cls, *args, **kwargs):
-         super().__init__()
-         self.re_module = module_cls(*args, **kwargs)
-         self.im_module = module_cls(*args, **kwargs)
-
-     def forward(self, x):
-         return torch_complex_from_reim(self.re_module(x.real), self.im_module(x.imag))
-
-
- # Code for DCUNet largely copied from Danilo's `informedenh` repo, cheers!
-
- def unet_decoder_args(encoders, *, skip_connections):
-     """Get list of decoder arguments for upsampling (right) side of a symmetric u-net,
-     given the arguments used to construct the encoder.
-     Args:
-         encoders (tuple of length `N` of tuples of (in_chan, out_chan, kernel_size, stride, padding)):
-             List of arguments used to construct the encoders
-         skip_connections (bool): Whether to include skip connections in the
-             calculation of decoder input channels.
-     Return:
-         tuple of length `N` of tuples of (in_chan, out_chan, kernel_size, stride, padding):
-             Arguments to be used to construct decoders
-     """
-     decoder_args = []
-     for enc_in_chan, enc_out_chan, enc_kernel_size, enc_stride, enc_padding, enc_dilation in reversed(encoders):
-         if skip_connections and decoder_args:
-             skip_in_chan = enc_out_chan
-         else:
-             skip_in_chan = 0
-         decoder_args.append(
-             (enc_out_chan + skip_in_chan, enc_in_chan, enc_kernel_size, enc_stride, enc_padding, enc_dilation)
-         )
-     return tuple(decoder_args)
-
-
- def make_unet_encoder_decoder_args(encoder_args, decoder_args):
-     encoder_args = tuple(
-         (
-             in_chan,
-             out_chan,
-             tuple(kernel_size),
-             tuple(stride),
-             tuple([n // 2 for n in kernel_size]) if padding == "auto" else tuple(padding),
-             tuple(dilation)
-         )
-         for in_chan, out_chan, kernel_size, stride, padding, dilation in encoder_args
-     )
-
-     if decoder_args == "auto":
-         decoder_args = unet_decoder_args(
-             encoder_args,
-             skip_connections=True,
-         )
-     else:
-         decoder_args = tuple(
-             (
-                 in_chan,
-                 out_chan,
-                 tuple(kernel_size),
-                 tuple(stride),
-                 tuple([n // 2 for n in kernel_size]) if padding == "auto" else padding,
-                 tuple(dilation),
-                 output_padding,
-             )
-             for in_chan, out_chan, kernel_size, stride, padding, dilation, output_padding in decoder_args
-         )
-
-     return encoder_args, decoder_args
-
-
- DCUNET_ARCHITECTURES = {
-     "DCUNet-10": make_unet_encoder_decoder_args(
-         # Encoders:
-         # (in_chan, out_chan, kernel_size, stride, padding, dilation)
-         (
-             (1, 32, (7, 5), (2, 2), "auto", (1,1)),
-             (32, 64, (7, 5), (2, 2), "auto", (1,1)),
-             (64, 64, (5, 3), (2, 2), "auto", (1,1)),
-             (64, 64, (5, 3), (2, 2), "auto", (1,1)),
-             (64, 64, (5, 3), (2, 1), "auto", (1,1)),
-         ),
-         # Decoders: automatic inverse
-         "auto",
-     ),
-     "DCUNet-16": make_unet_encoder_decoder_args(
-         # Encoders:
-         # (in_chan, out_chan, kernel_size, stride, padding, dilation)
-         (
-             (1, 32, (7, 5), (2, 2), "auto", (1,1)),
-             (32, 32, (7, 5), (2, 1), "auto", (1,1)),
-             (32, 64, (7, 5), (2, 2), "auto", (1,1)),
-             (64, 64, (5, 3), (2, 1), "auto", (1,1)),
-             (64, 64, (5, 3), (2, 2), "auto", (1,1)),
-             (64, 64, (5, 3), (2, 1), "auto", (1,1)),
-             (64, 64, (5, 3), (2, 2), "auto", (1,1)),
-             (64, 64, (5, 3), (2, 1), "auto", (1,1)),
-         ),
-         # Decoders: automatic inverse
-         "auto",
-     ),
-     "DCUNet-20": make_unet_encoder_decoder_args(
-         # Encoders:
-         # (in_chan, out_chan, kernel_size, stride, padding, dilation)
-         (
-             (1, 32, (7, 1), (1, 1), "auto", (1,1)),
-             (32, 32, (1, 7), (1, 1), "auto", (1,1)),
-             (32, 64, (7, 5), (2, 2), "auto", (1,1)),
-             (64, 64, (7, 5), (2, 1), "auto", (1,1)),
-             (64, 64, (5, 3), (2, 2), "auto", (1,1)),
-             (64, 64, (5, 3), (2, 1), "auto", (1,1)),
-             (64, 64, (5, 3), (2, 2), "auto", (1,1)),
-             (64, 64, (5, 3), (2, 1), "auto", (1,1)),
-             (64, 64, (5, 3), (2, 2), "auto", (1,1)),
-             (64, 90, (5, 3), (2, 1), "auto", (1,1)),
-         ),
-         # Decoders: automatic inverse
-         "auto",
-     ),
-     "DilDCUNet-v2": make_unet_encoder_decoder_args(  # architecture used in SGMSE / Interspeech paper
-         # Encoders:
-         # (in_chan, out_chan, kernel_size, stride, padding, dilation)
-         (
-             (1, 32, (4, 4), (1, 1), "auto", (1, 1)),
-             (32, 32, (4, 4), (1, 1), "auto", (1, 1)),
-             (32, 32, (4, 4), (1, 1), "auto", (1, 1)),
-             (32, 64, (4, 4), (2, 1), "auto", (2, 1)),
-             (64, 128, (4, 4), (2, 2), "auto", (4, 1)),
-             (128, 256, (4, 4), (2, 2), "auto", (8, 1)),
-         ),
-         # Decoders: automatic inverse
-         "auto",
-     ),
- }
-
-
- @BackboneRegistry.register("dcunet")
- class DCUNet(nn.Module):
-     @staticmethod
-     def add_argparse_args(parser):
-         parser.add_argument("--dcunet-architecture", type=str, default="DilDCUNet-v2", choices=DCUNET_ARCHITECTURES.keys(), help="The concrete DCUNet architecture. 'DilDCUNet-v2' by default.")
-         parser.add_argument("--dcunet-time-embedding", type=str, choices=("gfp", "ds", "none"), default="gfp", help="Timestep embedding style. 'gfp' (Gaussian Fourier Projections) by default.")
-         parser.add_argument("--dcunet-temb-layers-global", type=int, default=1, help="Number of global linear+activation layers for the time embedding. 1 by default.")
-         parser.add_argument("--dcunet-temb-layers-local", type=int, default=1, help="Number of local (per-encoder/per-decoder) linear+activation layers for the time embedding. 1 by default.")
-         parser.add_argument("--dcunet-temb-activation", type=str, default="silu", help="The (complex) activation to use between all (global&local) time embedding layers.")
-         parser.add_argument("--dcunet-time-embedding-complex", action="store_true", help="Use complex-valued timestep embedding. Compatible with 'gfp' and 'ds' embeddings.")
-         parser.add_argument("--dcunet-fix-length", type=str, default="pad", choices=("pad", "trim", "none"), help="DCUNet strategy to 'fix' mismatched input timespan. 'pad' by default.")
-         parser.add_argument("--dcunet-mask-bound", type=str, choices=("tanh", "sigmoid", "none"), default="none", help="DCUNet output bounding strategy. 'none' by default.")
-         parser.add_argument("--dcunet-norm-type", type=str, choices=("bN", "CbN"), default="bN", help="The type of norm to use within each encoder and decoder layer. 'bN' (real/imaginary separate batch norm) by default.")
-         parser.add_argument("--dcunet-activation", type=str, choices=("leaky_relu", "relu", "silu"), default="leaky_relu", help="The activation to use within each encoder and decoder layer. 'leaky_relu' by default.")
-         return parser
-
-     def __init__(
-         self,
-         dcunet_architecture: str = "DilDCUNet-v2",
-         dcunet_time_embedding: str = "gfp",
-         dcunet_temb_layers_global: int = 2,
-         dcunet_temb_layers_local: int = 1,
-         dcunet_temb_activation: str = "silu",
-         dcunet_time_embedding_complex: bool = False,
-         dcunet_fix_length: str = "pad",
-         dcunet_mask_bound: str = "none",
-         dcunet_norm_type: str = "bN",
-         dcunet_activation: str = "relu",
-         embed_dim: int = 128,
-         **kwargs
-     ):
-         super().__init__()
-
-         self.architecture = dcunet_architecture
-         self.fix_length_mode = (dcunet_fix_length if dcunet_fix_length != "none" else None)
-         self.norm_type = dcunet_norm_type
-         self.activation = dcunet_activation
-         self.input_channels = 2  # for x_t and y -- note that this is 2 rather than 4, because we directly treat complex channels in this DNN
-         self.time_embedding = (dcunet_time_embedding if dcunet_time_embedding != "none" else None)
-         self.time_embedding_complex = dcunet_time_embedding_complex
-         self.temb_layers_global = dcunet_temb_layers_global
-         self.temb_layers_local = dcunet_temb_layers_local
-         self.temb_activation = dcunet_temb_activation
-         conf_encoders, conf_decoders = DCUNET_ARCHITECTURES[dcunet_architecture]
-
-         # Replace `input_channels` in encoders config
-         _replaced_input_channels, *rest = conf_encoders[0]
-         encoders = ((self.input_channels, *rest), *conf_encoders[1:])
-         decoders = conf_decoders
-         self.encoders_stride_product = np.prod(
-             [enc_stride for _, _, _, enc_stride, _, _ in encoders], axis=0
-         )
-
-         # Prepare kwargs for encoder and decoder (to potentially be modified before layer instantiation)
-         encoder_decoder_kwargs = dict(
-             norm_type=self.norm_type, activation=self.activation,
-             temb_layers=self.temb_layers_local, temb_activation=self.temb_activation)
-
-         # Instantiate (global) time embedding layer
-         embed_ops = []
-         if self.time_embedding is not None:
-             complex_valued = self.time_embedding_complex
-             if self.time_embedding == "gfp":
-                 embed_ops += [GaussianFourierProjection(embed_dim=embed_dim, complex_valued=complex_valued)]
-                 encoder_decoder_kwargs["embed_dim"] = embed_dim
-             elif self.time_embedding == "ds":
-                 embed_ops += [DiffusionStepEmbedding(embed_dim=embed_dim, complex_valued=complex_valued)]
-                 encoder_decoder_kwargs["embed_dim"] = embed_dim
-
-             if self.time_embedding_complex:
-                 assert self.time_embedding in ("gfp", "ds"), "Complex timestep embedding only available for gfp and ds"
-                 encoder_decoder_kwargs["complex_time_embedding"] = True
-             for _ in range(self.temb_layers_global):
-                 embed_ops += [
-                     ComplexLinear(embed_dim, embed_dim, complex_valued=True),
-                     OnReIm(get_activation(dcunet_temb_activation))
-                 ]
-         self.embed = nn.Sequential(*embed_ops)
-
-         ### Instantiate DCUNet layers ###
-         output_layer = ComplexConvTranspose2d(*decoders[-1])
-         encoders = [DCUNetComplexEncoderBlock(*args, **encoder_decoder_kwargs) for args in encoders]
-         decoders = [DCUNetComplexDecoderBlock(*args, **encoder_decoder_kwargs) for args in decoders[:-1]]
-
-         self.mask_bound = (dcunet_mask_bound if dcunet_mask_bound != "none" else None)
-         if self.mask_bound is not None:
-             raise NotImplementedError("sorry, mask bounding not implemented at the moment")
-             # TODO we can't use nn.Sequential since the ComplexConvTranspose2d needs a second `output_size` argument
-             #operations = (output_layer, complex_nn.BoundComplexMask(self.mask_bound))
-             #output_layer = nn.Sequential(*[x for x in operations if x is not None])
-
-         assert len(encoders) == len(decoders) + 1
-         self.encoders = nn.ModuleList(encoders)
-         self.decoders = nn.ModuleList(decoders)
-         self.output_layer = output_layer or nn.Identity()
-
-     def forward(self, spec, t) -> Tensor:
-         """
-         Input shape is expected to be $(batch, nfreqs, time)$, with $nfreqs - 1$ divisible
-         by $f_0 * f_1 * ... * f_N$ where $f_k$ are the frequency strides of the encoders,
-         and $time - 1$ is divisible by $t_0 * t_1 * ... * t_N$ where $t_N$ are the time
-         strides of the encoders.
-         Args:
-             spec (Tensor): complex spectrogram tensor. 1D, 2D or 3D tensor, time last.
-         Returns:
-             Tensor, of shape (batch, time) or (time).
-         """
-         # TF-rep shape: (batch, self.input_channels, n_fft, frames)
-         # Estimate mask from time-frequency representation.
-         x_in = self.fix_input_dims(spec)
-         x = x_in
-         t_embed = self.embed(t+0j) if self.time_embedding is not None else None
-
-         enc_outs = []
-         for idx, enc in enumerate(self.encoders):
-             x = enc(x, t_embed)
-             # UNet skip connection
-             enc_outs.append(x)
-         for (enc_out, dec) in zip(reversed(enc_outs[:-1]), self.decoders):
-             x = dec(x, t_embed, output_size=enc_out.shape)
-             x = torch.cat([x, enc_out], dim=1)
-
-         output = self.output_layer(x, output_size=x_in.shape)
-         # output shape: (batch, 1, n_fft, frames)
-         output = self.fix_output_dims(output, spec)
-         return output
-
-     def fix_input_dims(self, x):
-         return _fix_dcu_input_dims(
-             self.fix_length_mode, x, torch.from_numpy(self.encoders_stride_product)
-         )
-
-     def fix_output_dims(self, out, x):
-         return _fix_dcu_output_dims(self.fix_length_mode, out, x)
-
-
- def _fix_dcu_input_dims(fix_length_mode, x, encoders_stride_product):
-     """Pad or trim `x` to a length compatible with DCUNet."""
-     freq_prod = int(encoders_stride_product[0])
-     time_prod = int(encoders_stride_product[1])
-     if (x.shape[2] - 1) % freq_prod:
-         raise TypeError(
-             f"Input shape must be [batch, ch, freq + 1, time + 1] with freq divisible by "
-             f"{freq_prod}, got {x.shape} instead"
-         )
-     time_remainder = (x.shape[3] - 1) % time_prod
-     if time_remainder:
-         if fix_length_mode is None:
-             raise TypeError(
-                 f"Input shape must be [batch, ch, freq + 1, time + 1] with time divisible by "
-                 f"{time_prod}, got {x.shape} instead. Set the 'fix_length_mode' argument "
-                 f"in 'DCUNet' to 'pad' or 'trim' to fix shapes automatically."
-             )
-         elif fix_length_mode == "pad":
-             pad_shape = [0, time_prod - time_remainder]
-             x = nn.functional.pad(x, pad_shape, mode="constant")
-         elif fix_length_mode == "trim":
-             pad_shape = [0, -time_remainder]
-             x = nn.functional.pad(x, pad_shape, mode="constant")
-         else:
-             raise ValueError(f"Unknown fix_length mode '{fix_length_mode}'")
-     return x
-
-
- def _fix_dcu_output_dims(fix_length_mode, out, x):
-     """Fix shape of `out` to the original shape of `x` by padding/cropping."""
-     inp_len = x.shape[-1]
-     output_len = out.shape[-1]
-     return nn.functional.pad(out, [0, inp_len - output_len])
-
-
- def _get_norm(norm_type):
-     if norm_type == "CbN":
-         return ComplexBatchNorm
-     elif norm_type == "bN":
-         return partial(OnReIm, BatchNorm)
-     else:
-         raise NotImplementedError(f"Unknown norm type: {norm_type}")
-
-
- class DCUNetComplexEncoderBlock(nn.Module):
-     def __init__(
-         self,
-         in_chan,
-         out_chan,
-         kernel_size,
-         stride,
-         padding,
-         dilation,
-         norm_type="bN",
-         activation="leaky_relu",
-         embed_dim=None,
-         complex_time_embedding=False,
-         temb_layers=1,
-         temb_activation="silu"
-     ):
-         super().__init__()
-
-         self.in_chan = in_chan
-         self.out_chan = out_chan
-         self.kernel_size = kernel_size
-         self.stride = stride
-         self.padding = padding
-         self.dilation = dilation
-         self.temb_layers = temb_layers
-         self.temb_activation = temb_activation
-         self.complex_time_embedding = complex_time_embedding
-
-         self.conv = ComplexConv2d(
-             in_chan, out_chan, kernel_size, stride, padding, bias=norm_type is None, dilation=dilation
-         )
-         self.norm = _get_norm(norm_type)(out_chan)
-         self.activation = OnReIm(get_activation(activation))
-         self.embed_dim = embed_dim
-         if self.embed_dim is not None:
-             ops = []
-             for _ in range(max(0, self.temb_layers - 1)):
-                 ops += [
-                     ComplexLinear(self.embed_dim, self.embed_dim, complex_valued=True),
-                     OnReIm(get_activation(self.temb_activation))
-                 ]
-             ops += [
-                 FeatureMapDense(self.embed_dim, self.out_chan, complex_valued=True),
-                 OnReIm(get_activation(self.temb_activation))
-             ]
-             self.embed_layer = nn.Sequential(*ops)
-
-     def forward(self, x, t_embed):
-         y = self.conv(x)
-         if self.embed_dim is not None:
-             y = y + self.embed_layer(t_embed)
-         return self.activation(self.norm(y))
-
-
- class DCUNetComplexDecoderBlock(nn.Module):
-     def __init__(
-         self,
-         in_chan,
-         out_chan,
-         kernel_size,
-         stride,
-         padding,
-         dilation,
-         output_padding=(0, 0),
-         norm_type="bN",
-         activation="leaky_relu",
-         embed_dim=None,
-         temb_layers=1,
-         temb_activation='swish',
-         complex_time_embedding=False,
-     ):
-         super().__init__()
-
-         self.in_chan = in_chan
-         self.out_chan = out_chan
-         self.kernel_size = kernel_size
-         self.stride = stride
-         self.padding = padding
-         self.dilation = dilation
-         self.output_padding = output_padding
-         self.complex_time_embedding = complex_time_embedding
-         self.temb_layers = temb_layers
-         self.temb_activation = temb_activation
-
-         self.deconv = ComplexConvTranspose2d(
-             in_chan, out_chan, kernel_size, stride, padding, output_padding, dilation=dilation, bias=norm_type is None
-         )
-         self.norm = _get_norm(norm_type)(out_chan)
-         self.activation = OnReIm(get_activation(activation))
-         self.embed_dim = embed_dim
-         if self.embed_dim is not None:
-             ops = []
-             for _ in range(max(0, self.temb_layers - 1)):
-                 ops += [
-                     ComplexLinear(self.embed_dim, self.embed_dim, complex_valued=True),
-                     OnReIm(get_activation(self.temb_activation))
-                 ]
-             ops += [
-                 FeatureMapDense(self.embed_dim, self.out_chan, complex_valued=True),
-                 OnReIm(get_activation(self.temb_activation))
-             ]
-             self.embed_layer = nn.Sequential(*ops)
-
-     def forward(self, x, t_embed, output_size=None):
-         y = self.deconv(x, output_size=output_size)
-         if self.embed_dim is not None:
-             y = y + self.embed_layer(t_embed)
-         return self.activation(self.norm(y))
-
-
- # From https://github.com/chanil1218/DCUnet.pytorch/blob/2dcdd30804be47a866fde6435cbb7e2f81585213/models/layers/complexnn.py
- class ComplexBatchNorm(torch.nn.Module):
-     def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=False):
-         super(ComplexBatchNorm, self).__init__()
-         self.num_features = num_features
-         self.eps = eps
-         self.momentum = momentum
-         self.affine = affine
-         self.track_running_stats = track_running_stats
-         if self.affine:
-             self.Wrr = torch.nn.Parameter(torch.Tensor(num_features))
-             self.Wri = torch.nn.Parameter(torch.Tensor(num_features))
-             self.Wii = torch.nn.Parameter(torch.Tensor(num_features))
-             self.Br = torch.nn.Parameter(torch.Tensor(num_features))
-             self.Bi = torch.nn.Parameter(torch.Tensor(num_features))
-         else:
-             self.register_parameter('Wrr', None)
-             self.register_parameter('Wri', None)
-             self.register_parameter('Wii', None)
-             self.register_parameter('Br', None)
-             self.register_parameter('Bi', None)
-         if self.track_running_stats:
-             self.register_buffer('RMr', torch.zeros(num_features))
-             self.register_buffer('RMi', torch.zeros(num_features))
479
- self.register_buffer('RVrr', torch.ones (num_features))
480
- self.register_buffer('RVri', torch.zeros(num_features))
481
- self.register_buffer('RVii', torch.ones (num_features))
482
- self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
483
- else:
484
- self.register_parameter('RMr', None)
485
- self.register_parameter('RMi', None)
486
- self.register_parameter('RVrr', None)
487
- self.register_parameter('RVri', None)
488
- self.register_parameter('RVii', None)
489
- self.register_parameter('num_batches_tracked', None)
490
- self.reset_parameters()
491
-
492
- def reset_running_stats(self):
493
- if self.track_running_stats:
494
- self.RMr.zero_()
495
- self.RMi.zero_()
496
- self.RVrr.fill_(1)
497
- self.RVri.zero_()
498
- self.RVii.fill_(1)
499
- self.num_batches_tracked.zero_()
500
-
501
- def reset_parameters(self):
502
- self.reset_running_stats()
503
- if self.affine:
504
- self.Br.data.zero_()
505
- self.Bi.data.zero_()
506
- self.Wrr.data.fill_(1)
507
- self.Wri.data.uniform_(-.9, +.9) # W will be positive-definite
508
- self.Wii.data.fill_(1)
509
-
510
- def _check_input_dim(self, xr, xi):
511
- assert(xr.shape == xi.shape)
512
- assert(xr.size(1) == self.num_features)
513
-
514
- def forward(self, x):
515
- xr, xi = x.real, x.imag
516
- self._check_input_dim(xr, xi)
517
-
518
- exponential_average_factor = 0.0
519
-
520
- if self.training and self.track_running_stats:
521
- self.num_batches_tracked += 1
522
- if self.momentum is None: # use cumulative moving average
523
- exponential_average_factor = 1.0 / self.num_batches_tracked.item()
524
- else: # use exponential moving average
525
- exponential_average_factor = self.momentum
526
-
527
- #
528
- # NOTE: The precise meaning of the "training flag" is:
529
- # True: Normalize using batch statistics, update running statistics
530
- # if they are being collected.
531
- # False: Normalize using running statistics, ignore batch statistics.
532
- #
533
- training = self.training or not self.track_running_stats
534
- redux = [i for i in reversed(range(xr.dim())) if i!=1]
535
- vdim = [1] * xr.dim()
536
- vdim[1] = xr.size(1)
537
-
538
- #
539
- # Mean M Computation and Centering
540
- #
541
- # Includes running mean update if training and running.
542
- #
543
- if training:
544
- Mr, Mi = xr, xi
545
- for d in redux:
546
- Mr = Mr.mean(d, keepdim=True)
547
- Mi = Mi.mean(d, keepdim=True)
548
- if self.track_running_stats:
549
- self.RMr.lerp_(Mr.squeeze(), exponential_average_factor)
550
- self.RMi.lerp_(Mi.squeeze(), exponential_average_factor)
551
- else:
552
- Mr = self.RMr.view(vdim)
553
- Mi = self.RMi.view(vdim)
554
- xr, xi = xr-Mr, xi-Mi
555
-
556
- #
557
- # Variance Matrix V Computation
558
- #
559
- # Includes epsilon numerical stabilizer/Tikhonov regularizer.
560
- # Includes running variance update if training and running.
561
- #
562
- if training:
563
- Vrr = xr * xr
564
- Vri = xr * xi
565
- Vii = xi * xi
566
- for d in redux:
567
- Vrr = Vrr.mean(d, keepdim=True)
568
- Vri = Vri.mean(d, keepdim=True)
569
- Vii = Vii.mean(d, keepdim=True)
570
- if self.track_running_stats:
571
- self.RVrr.lerp_(Vrr.squeeze(), exponential_average_factor)
572
- self.RVri.lerp_(Vri.squeeze(), exponential_average_factor)
573
- self.RVii.lerp_(Vii.squeeze(), exponential_average_factor)
574
- else:
575
- Vrr = self.RVrr.view(vdim)
576
- Vri = self.RVri.view(vdim)
577
- Vii = self.RVii.view(vdim)
578
- Vrr = Vrr + self.eps
579
- Vri = Vri
580
- Vii = Vii + self.eps
581
-
582
- #
583
- # Matrix Inverse Square Root U = V^-0.5
584
- #
585
- # sqrt of a 2x2 matrix,
586
- # - https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix
587
- tau = Vrr + Vii
588
- delta = torch.addcmul(Vrr * Vii, Vri, Vri, value=-1)
589
- s = delta.sqrt()
590
- t = (tau + 2*s).sqrt()
591
-
592
- # matrix inverse, http://mathworld.wolfram.com/MatrixInverse.html
593
- rst = (s * t).reciprocal()
594
- Urr = (s + Vii) * rst
595
- Uii = (s + Vrr) * rst
596
- Uri = ( - Vri) * rst
597
-
598
- #
599
- # Optionally left-multiply U by affine weights W to produce combined
600
- # weights Z, left-multiply the inputs by Z, then optionally bias them.
601
- #
602
- # y = Zx + B
603
- # y = WUx + B
604
- # y = [Wrr Wri][Urr Uri] [xr] + [Br]
605
- # [Wir Wii][Uir Uii] [xi] [Bi]
606
- #
607
- if self.affine:
608
- Wrr, Wri, Wii = self.Wrr.view(vdim), self.Wri.view(vdim), self.Wii.view(vdim)
609
- Zrr = (Wrr * Urr) + (Wri * Uri)
610
- Zri = (Wrr * Uri) + (Wri * Uii)
611
- Zir = (Wri * Urr) + (Wii * Uri)
612
- Zii = (Wri * Uri) + (Wii * Uii)
613
- else:
614
- Zrr, Zri, Zir, Zii = Urr, Uri, Uri, Uii
615
-
616
- yr = (Zrr * xr) + (Zri * xi)
617
- yi = (Zir * xr) + (Zii * xi)
618
-
619
- if self.affine:
620
- yr = yr + self.Br.view(vdim)
621
- yi = yi + self.Bi.view(vdim)
622
-
623
- return torch.view_as_complex(torch.stack([yr, yi], dim=-1))
624
-
625
- def extra_repr(self):
626
- return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \
627
- 'track_running_stats={track_running_stats}'.format(**self.__dict__)
 
1
+ from functools import partial
2
+ import numpy as np
3
+
4
+ import torch
5
+ from torch import nn, Tensor
6
+ from torch.nn.modules.batchnorm import _BatchNorm
7
+
8
+ from .shared import BackboneRegistry, ComplexConv2d, ComplexConvTranspose2d, ComplexLinear, \
9
+ DiffusionStepEmbedding, GaussianFourierProjection, FeatureMapDense, torch_complex_from_reim
10
+
11
+
12
+ def get_activation(name):
13
+ if name == "silu":
14
+ return nn.SiLU
15
+ elif name == "relu":
16
+ return nn.ReLU
17
+ elif name == "leaky_relu":
18
+ return nn.LeakyReLU
19
+ else:
20
+ raise NotImplementedError(f"Unknown activation: {name}")
21
+
22
+
23
+ class BatchNorm(_BatchNorm):
24
+ def _check_input_dim(self, input):
25
+ if input.dim() < 2 or input.dim() > 4:
26
+ raise ValueError("expected 4D or 3D input (got {}D input)".format(input.dim()))
27
+
28
+
29
+ class OnReIm(nn.Module):
30
+ def __init__(self, module_cls, *args, **kwargs):
31
+ super().__init__()
32
+ self.re_module = module_cls(*args, **kwargs)
33
+ self.im_module = module_cls(*args, **kwargs)
34
+
35
+ def forward(self, x):
36
+ return torch_complex_from_reim(self.re_module(x.real), self.im_module(x.imag))
37
+
38
+
39
+ # Code for DCUNet largely copied from Danilo's `informedenh` repo, cheers!
40
+
41
+ def unet_decoder_args(encoders, *, skip_connections):
42
+ """Get list of decoder arguments for upsampling (right) side of a symmetric u-net,
43
+ given the arguments used to construct the encoder.
44
+ Args:
45
+ encoders (tuple of length `N` of tuples of (in_chan, out_chan, kernel_size, stride, padding, dilation)):
46
+ List of arguments used to construct the encoders
47
+ skip_connections (bool): Whether to include skip connections in the
48
+ calculation of decoder input channels.
49
+ Returns:
+ tuple of length `N` of tuples of (in_chan, out_chan, kernel_size, stride, padding, dilation):
51
+ Arguments to be used to construct decoders
52
+ """
53
+ decoder_args = []
54
+ for enc_in_chan, enc_out_chan, enc_kernel_size, enc_stride, enc_padding, enc_dilation in reversed(encoders):
55
+ if skip_connections and decoder_args:
56
+ skip_in_chan = enc_out_chan
57
+ else:
58
+ skip_in_chan = 0
59
+ decoder_args.append(
60
+ (enc_out_chan + skip_in_chan, enc_in_chan, enc_kernel_size, enc_stride, enc_padding, enc_dilation)
61
+ )
62
+ return tuple(decoder_args)
63
+
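+ # Illustrative example (added, not part of the original source): for a single
+ # encoder spec (1, 32, (7, 5), (2, 2), (3, 2), (1, 1)) with skip_connections=True,
+ # this returns the mirrored decoder spec ((32, 1, (7, 5), (2, 2), (3, 2), (1, 1)),).
+ # With more encoders, every decoder except the deepest one gets
+ # in_chan = 2 * enc_out_chan, since the forward pass concatenates the matching
+ # encoder output (skip connection) onto the previous decoder output.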
64
+
65
+ def make_unet_encoder_decoder_args(encoder_args, decoder_args):
66
+ encoder_args = tuple(
67
+ (
68
+ in_chan,
69
+ out_chan,
70
+ tuple(kernel_size),
71
+ tuple(stride),
72
+ tuple([n // 2 for n in kernel_size]) if padding == "auto" else tuple(padding),
73
+ tuple(dilation)
74
+ )
75
+ for in_chan, out_chan, kernel_size, stride, padding, dilation in encoder_args
76
+ )
77
+
78
+ if decoder_args == "auto":
79
+ decoder_args = unet_decoder_args(
80
+ encoder_args,
81
+ skip_connections=True,
82
+ )
83
+ else:
84
+ decoder_args = tuple(
85
+ (
86
+ in_chan,
87
+ out_chan,
88
+ tuple(kernel_size),
89
+ tuple(stride),
90
+ tuple([n // 2 for n in kernel_size]) if padding == "auto" else padding,
91
+ tuple(dilation),
92
+ output_padding,
93
+ )
94
+ for in_chan, out_chan, kernel_size, stride, padding, dilation, output_padding in decoder_args
95
+ )
96
+
97
+ return encoder_args, decoder_args
98
+
99
+
100
+ DCUNET_ARCHITECTURES = {
101
+ "DCUNet-10": make_unet_encoder_decoder_args(
102
+ # Encoders:
103
+ # (in_chan, out_chan, kernel_size, stride, padding, dilation)
104
+ (
105
+ (1, 32, (7, 5), (2, 2), "auto", (1,1)),
106
+ (32, 64, (7, 5), (2, 2), "auto", (1,1)),
107
+ (64, 64, (5, 3), (2, 2), "auto", (1,1)),
108
+ (64, 64, (5, 3), (2, 2), "auto", (1,1)),
109
+ (64, 64, (5, 3), (2, 1), "auto", (1,1)),
110
+ ),
111
+ # Decoders: automatic inverse
112
+ "auto",
113
+ ),
114
+ "DCUNet-16": make_unet_encoder_decoder_args(
115
+ # Encoders:
116
+ # (in_chan, out_chan, kernel_size, stride, padding, dilation)
117
+ (
118
+ (1, 32, (7, 5), (2, 2), "auto", (1,1)),
119
+ (32, 32, (7, 5), (2, 1), "auto", (1,1)),
120
+ (32, 64, (7, 5), (2, 2), "auto", (1,1)),
121
+ (64, 64, (5, 3), (2, 1), "auto", (1,1)),
122
+ (64, 64, (5, 3), (2, 2), "auto", (1,1)),
123
+ (64, 64, (5, 3), (2, 1), "auto", (1,1)),
124
+ (64, 64, (5, 3), (2, 2), "auto", (1,1)),
125
+ (64, 64, (5, 3), (2, 1), "auto", (1,1)),
126
+ ),
127
+ # Decoders: automatic inverse
128
+ "auto",
129
+ ),
130
+ "DCUNet-20": make_unet_encoder_decoder_args(
131
+ # Encoders:
132
+ # (in_chan, out_chan, kernel_size, stride, padding, dilation)
133
+ (
134
+ (1, 32, (7, 1), (1, 1), "auto", (1,1)),
135
+ (32, 32, (1, 7), (1, 1), "auto", (1,1)),
136
+ (32, 64, (7, 5), (2, 2), "auto", (1,1)),
137
+ (64, 64, (7, 5), (2, 1), "auto", (1,1)),
138
+ (64, 64, (5, 3), (2, 2), "auto", (1,1)),
139
+ (64, 64, (5, 3), (2, 1), "auto", (1,1)),
140
+ (64, 64, (5, 3), (2, 2), "auto", (1,1)),
141
+ (64, 64, (5, 3), (2, 1), "auto", (1,1)),
142
+ (64, 64, (5, 3), (2, 2), "auto", (1,1)),
143
+ (64, 90, (5, 3), (2, 1), "auto", (1,1)),
144
+ ),
145
+ # Decoders: automatic inverse
146
+ "auto",
147
+ ),
148
+ "DilDCUNet-v2": make_unet_encoder_decoder_args( # architecture used in SGMSE / Interspeech paper
149
+ # Encoders:
150
+ # (in_chan, out_chan, kernel_size, stride, padding, dilation)
151
+ (
152
+ (1, 32, (4, 4), (1, 1), "auto", (1, 1)),
153
+ (32, 32, (4, 4), (1, 1), "auto", (1, 1)),
154
+ (32, 32, (4, 4), (1, 1), "auto", (1, 1)),
155
+ (32, 64, (4, 4), (2, 1), "auto", (2, 1)),
156
+ (64, 128, (4, 4), (2, 2), "auto", (4, 1)),
157
+ (128, 256, (4, 4), (2, 2), "auto", (8, 1)),
158
+ ),
159
+ # Decoders: automatic inverse
160
+ "auto",
161
+ ),
162
+ }
163
+
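+ # Added note (not in the original source): for "DilDCUNet-v2" the encoder stride
+ # products are 8 (frequency) and 4 (time), so inputs must satisfy
+ # (nfreqs - 1) % 8 == 0; the time axis is padded/trimmed automatically when
+ # fix_length_mode is set (default "pad") -- e.g. (batch, 2, 257, 257) works.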
164
+
165
+ @BackboneRegistry.register("dcunet")
166
+ class DCUNet(nn.Module):
167
+ @staticmethod
168
+ def add_argparse_args(parser):
169
+ parser.add_argument("--dcunet-architecture", type=str, default="DilDCUNet-v2", choices=DCUNET_ARCHITECTURES.keys(), help="The concrete DCUNet architecture. 'DilDCUNet-v2' by default.")
170
+ parser.add_argument("--dcunet-time-embedding", type=str, choices=("gfp", "ds", "none"), default="gfp", help="Timestep embedding style. 'gfp' (Gaussian Fourier Projections) by default.")
171
+ parser.add_argument("--dcunet-temb-layers-global", type=int, default=1, help="Number of global linear+activation layers for the time embedding. 1 by default.")
172
+ parser.add_argument("--dcunet-temb-layers-local", type=int, default=1, help="Number of local (per-encoder/per-decoder) linear+activation layers for the time embedding. 1 by default.")
173
+ parser.add_argument("--dcunet-temb-activation", type=str, default="silu", help="The (complex) activation to use between all (global&local) time embedding layers.")
174
+ parser.add_argument("--dcunet-time-embedding-complex", action="store_true", help="Use complex-valued timestep embedding. Compatible with 'gfp' and 'ds' embeddings.")
175
+ parser.add_argument("--dcunet-fix-length", type=str, default="pad", choices=("pad", "trim", "none"), help="DCUNet strategy to 'fix' mismatched input timespan. 'pad' by default.")
176
+ parser.add_argument("--dcunet-mask-bound", type=str, choices=("tanh", "sigmoid", "none"), default="none", help="DCUNet output bounding strategy. 'none' by default.")
177
+ parser.add_argument("--dcunet-norm-type", type=str, choices=("bN", "CbN"), default="bN", help="The type of norm to use within each encoder and decoder layer. 'bN' (real/imaginary separate batch norm) by default.")
178
+ parser.add_argument("--dcunet-activation", type=str, choices=("leaky_relu", "relu", "silu"), default="leaky_relu", help="The activation to use within each encoder and decoder layer. 'leaky_relu' by default.")
179
+ return parser
180
+
181
+ def __init__(
182
+ self,
183
+ dcunet_architecture: str = "DilDCUNet-v2",
184
+ dcunet_time_embedding: str = "gfp",
185
+ dcunet_temb_layers_global: int = 2,
186
+ dcunet_temb_layers_local: int = 1,
187
+ dcunet_temb_activation: str = "silu",
188
+ dcunet_time_embedding_complex: bool = False,
189
+ dcunet_fix_length: str = "pad",
190
+ dcunet_mask_bound: str = "none",
191
+ dcunet_norm_type: str = "bN",
192
+ dcunet_activation: str = "relu",
193
+ embed_dim: int = 128,
194
+ **kwargs
195
+ ):
196
+ super().__init__()
197
+
198
+ self.architecture = dcunet_architecture
199
+ self.fix_length_mode = (dcunet_fix_length if dcunet_fix_length != "none" else None)
200
+ self.norm_type = dcunet_norm_type
201
+ self.activation = dcunet_activation
202
+ self.input_channels = 2 # for x_t and y -- note that this is 2 rather than 4, because we directly treat complex channels in this DNN
203
+ self.time_embedding = (dcunet_time_embedding if dcunet_time_embedding != "none" else None)
204
+ self.time_embedding_complex = dcunet_time_embedding_complex
205
+ self.temb_layers_global = dcunet_temb_layers_global
206
+ self.temb_layers_local = dcunet_temb_layers_local
207
+ self.temb_activation = dcunet_temb_activation
208
+ conf_encoders, conf_decoders = DCUNET_ARCHITECTURES[dcunet_architecture]
209
+
210
+ # Replace `input_channels` in encoders config
211
+ _replaced_input_channels, *rest = conf_encoders[0]
212
+ encoders = ((self.input_channels, *rest), *conf_encoders[1:])
213
+ decoders = conf_decoders
214
+ self.encoders_stride_product = np.prod(
215
+ [enc_stride for _, _, _, enc_stride, _, _ in encoders], axis=0
216
+ )
217
+
218
+ # Prepare kwargs for encoder and decoder (to potentially be modified before layer instantiation)
219
+ encoder_decoder_kwargs = dict(
220
+ norm_type=self.norm_type, activation=self.activation,
221
+ temb_layers=self.temb_layers_local, temb_activation=self.temb_activation)
222
+
223
+ # Instantiate (global) time embedding layer
224
+ embed_ops = []
225
+ if self.time_embedding is not None:
226
+ complex_valued = self.time_embedding_complex
227
+ if self.time_embedding == "gfp":
228
+ embed_ops += [GaussianFourierProjection(embed_dim=embed_dim, complex_valued=complex_valued)]
229
+ encoder_decoder_kwargs["embed_dim"] = embed_dim
230
+ elif self.time_embedding == "ds":
231
+ embed_ops += [DiffusionStepEmbedding(embed_dim=embed_dim, complex_valued=complex_valued)]
232
+ encoder_decoder_kwargs["embed_dim"] = embed_dim
233
+
234
+ if self.time_embedding_complex:
235
+ assert self.time_embedding in ("gfp", "ds"), "Complex timestep embedding only available for gfp and ds"
236
+ encoder_decoder_kwargs["complex_time_embedding"] = True
237
+ for _ in range(self.temb_layers_global):
238
+ embed_ops += [
239
+ ComplexLinear(embed_dim, embed_dim, complex_valued=True),
240
+ OnReIm(get_activation(dcunet_temb_activation))
241
+ ]
242
+ self.embed = nn.Sequential(*embed_ops)
243
+
244
+ ### Instantiate DCUNet layers ###
245
+ output_layer = ComplexConvTranspose2d(*decoders[-1])
246
+ encoders = [DCUNetComplexEncoderBlock(*args, **encoder_decoder_kwargs) for args in encoders]
247
+ decoders = [DCUNetComplexDecoderBlock(*args, **encoder_decoder_kwargs) for args in decoders[:-1]]
248
+
249
+ self.mask_bound = (dcunet_mask_bound if dcunet_mask_bound != "none" else None)
250
+ if self.mask_bound is not None:
251
+ raise NotImplementedError("sorry, mask bounding not implemented at the moment")
252
+ # TODO we can't use nn.Sequential since the ComplexConvTranspose2d needs a second `output_size` argument
253
+ #operations = (output_layer, complex_nn.BoundComplexMask(self.mask_bound))
254
+ #output_layer = nn.Sequential(*[x for x in operations if x is not None])
255
+
256
+ assert len(encoders) == len(decoders) + 1
257
+ self.encoders = nn.ModuleList(encoders)
258
+ self.decoders = nn.ModuleList(decoders)
259
+ self.output_layer = output_layer or nn.Identity()
260
+
261
+ def forward(self, spec, t) -> Tensor:
262
+ """
263
+ Input shape is expected to be $(batch, channels, nfreqs, time)$, with $nfreqs - 1$
+ divisible by $f_0 * f_1 * ... * f_N$ where $f_k$ are the frequency strides of the
+ encoders, and $time - 1$ divisible by $t_0 * t_1 * ... * t_N$ where $t_k$ are the
+ time strides of the encoders.
+ Args:
+ spec (Tensor): complex spectrogram tensor of shape (batch, channels, nfreqs, time).
+ t (Tensor): timestep/noise-level conditioning tensor of shape (batch,).
+ Returns:
+ Tensor: complex spectrogram estimate of shape (batch, 1, nfreqs, time).
271
+ """
272
+ # TF-rep shape: (batch, self.input_channels, n_fft, frames)
273
+ # Estimate mask from time-frequency representation.
274
+ x_in = self.fix_input_dims(spec)
275
+ x = x_in
276
+ t_embed = self.embed(t+0j) if self.time_embedding is not None else None
277
+
278
+ enc_outs = []
279
+ for enc in self.encoders:
280
+ x = enc(x, t_embed)
281
+ # UNet skip connection
282
+ enc_outs.append(x)
283
+ for (enc_out, dec) in zip(reversed(enc_outs[:-1]), self.decoders):
284
+ x = dec(x, t_embed, output_size=enc_out.shape)
285
+ x = torch.cat([x, enc_out], dim=1)
286
+
287
+ output = self.output_layer(x, output_size=x_in.shape)
288
+ # output shape: (batch, 1, n_fft, frames)
289
+ output = self.fix_output_dims(output, spec)
290
+ return output
291
+
292
+ def fix_input_dims(self, x):
293
+ return _fix_dcu_input_dims(
294
+ self.fix_length_mode, x, torch.from_numpy(self.encoders_stride_product)
295
+ )
296
+
297
+ def fix_output_dims(self, out, x):
298
+ return _fix_dcu_output_dims(self.fix_length_mode, out, x)
299
+
300
+
301
+ def _fix_dcu_input_dims(fix_length_mode, x, encoders_stride_product):
302
+ """Pad or trim `x` to a length compatible with DCUNet."""
303
+ freq_prod = int(encoders_stride_product[0])
304
+ time_prod = int(encoders_stride_product[1])
305
+ if (x.shape[2] - 1) % freq_prod:
306
+ raise TypeError(
307
+ f"Input shape must be [batch, ch, freq + 1, time + 1] with freq divisible by "
308
+ f"{freq_prod}, got {x.shape} instead"
309
+ )
310
+ time_remainder = (x.shape[3] - 1) % time_prod
311
+ if time_remainder:
312
+ if fix_length_mode is None:
313
+ raise TypeError(
314
+ f"Input shape must be [batch, ch, freq + 1, time + 1] with time divisible by "
315
+ f"{time_prod}, got {x.shape} instead. Set the 'fix_length_mode' argument "
316
+ f"in 'DCUNet' to 'pad' or 'trim' to fix shapes automatically."
317
+ )
318
+ elif fix_length_mode == "pad":
319
+ pad_shape = [0, time_prod - time_remainder]
320
+ x = nn.functional.pad(x, pad_shape, mode="constant")
321
+ elif fix_length_mode == "trim":
322
+ pad_shape = [0, -time_remainder]
323
+ x = nn.functional.pad(x, pad_shape, mode="constant")
324
+ else:
325
+ raise ValueError(f"Unknown fix_length mode '{fix_length_mode}'")
326
+ return x
327
+
328
+
329
+ def _fix_dcu_output_dims(fix_length_mode, out, x):
330
+ """Fix shape of `out` to the original shape of `x` by padding/cropping."""
331
+ inp_len = x.shape[-1]
332
+ output_len = out.shape[-1]
333
+ return nn.functional.pad(out, [0, inp_len - output_len])
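+ # (nn.functional.pad accepts negative amounts, so this crops `out` when it is
+ # longer than the input and zero-pads it when it is shorter)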
334
+
335
+
336
+ def _get_norm(norm_type):
337
+ if norm_type == "CbN":
338
+ return ComplexBatchNorm
339
+ elif norm_type == "bN":
340
+ return partial(OnReIm, BatchNorm)
341
+ else:
342
+ raise NotImplementedError(f"Unknown norm type: {norm_type}")
343
+
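+ # Note (added for clarity): "bN" normalizes real and imaginary parts with two
+ # independent real-valued BatchNorms via OnReIm, whereas "CbN" whitens them
+ # jointly using the full 2x2 covariance (see ComplexBatchNorm below).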
344
+
345
+ class DCUNetComplexEncoderBlock(nn.Module):
346
+ def __init__(
347
+ self,
348
+ in_chan,
349
+ out_chan,
350
+ kernel_size,
351
+ stride,
352
+ padding,
353
+ dilation,
354
+ norm_type="bN",
355
+ activation="leaky_relu",
356
+ embed_dim=None,
357
+ complex_time_embedding=False,
358
+ temb_layers=1,
359
+ temb_activation="silu"
360
+ ):
361
+ super().__init__()
362
+
363
+ self.in_chan = in_chan
364
+ self.out_chan = out_chan
365
+ self.kernel_size = kernel_size
366
+ self.stride = stride
367
+ self.padding = padding
368
+ self.dilation = dilation
369
+ self.temb_layers = temb_layers
370
+ self.temb_activation = temb_activation
371
+ self.complex_time_embedding = complex_time_embedding
372
+
373
+ self.conv = ComplexConv2d(
374
+ in_chan, out_chan, kernel_size, stride, padding, bias=norm_type is None, dilation=dilation
375
+ )
376
+ self.norm = _get_norm(norm_type)(out_chan)
377
+ self.activation = OnReIm(get_activation(activation))
378
+ self.embed_dim = embed_dim
379
+ if self.embed_dim is not None:
380
+ ops = []
381
+ for _ in range(max(0, self.temb_layers - 1)):
382
+ ops += [
383
+ ComplexLinear(self.embed_dim, self.embed_dim, complex_valued=True),
384
+ OnReIm(get_activation(self.temb_activation))
385
+ ]
386
+ ops += [
387
+ FeatureMapDense(self.embed_dim, self.out_chan, complex_valued=True),
388
+ OnReIm(get_activation(self.temb_activation))
389
+ ]
390
+ self.embed_layer = nn.Sequential(*ops)
391
+
392
+ def forward(self, x, t_embed):
393
+ y = self.conv(x)
394
+ if self.embed_dim is not None:
395
+ y = y + self.embed_layer(t_embed)
396
+ return self.activation(self.norm(y))
397
+
398
+
399
+ class DCUNetComplexDecoderBlock(nn.Module):
400
+ def __init__(
401
+ self,
402
+ in_chan,
403
+ out_chan,
404
+ kernel_size,
405
+ stride,
406
+ padding,
407
+ dilation,
408
+ output_padding=(0, 0),
409
+ norm_type="bN",
410
+ activation="leaky_relu",
411
+ embed_dim=None,
412
+ temb_layers=1,
413
+ temb_activation='swish',
414
+ complex_time_embedding=False,
415
+ ):
416
+ super().__init__()
417
+
418
+ self.in_chan = in_chan
419
+ self.out_chan = out_chan
420
+ self.kernel_size = kernel_size
421
+ self.stride = stride
422
+ self.padding = padding
423
+ self.dilation = dilation
424
+ self.output_padding = output_padding
425
+ self.complex_time_embedding = complex_time_embedding
426
+ self.temb_layers = temb_layers
427
+ self.temb_activation = temb_activation
428
+
429
+ self.deconv = ComplexConvTranspose2d(
430
+ in_chan, out_chan, kernel_size, stride, padding, output_padding, dilation=dilation, bias=norm_type is None
431
+ )
432
+ self.norm = _get_norm(norm_type)(out_chan)
433
+ self.activation = OnReIm(get_activation(activation))
434
+ self.embed_dim = embed_dim
435
+ if self.embed_dim is not None:
436
+ ops = []
437
+ for _ in range(max(0, self.temb_layers - 1)):
438
+ ops += [
439
+ ComplexLinear(self.embed_dim, self.embed_dim, complex_valued=True),
440
+ OnReIm(get_activation(self.temb_activation))
441
+ ]
442
+ ops += [
443
+ FeatureMapDense(self.embed_dim, self.out_chan, complex_valued=True),
444
+ OnReIm(get_activation(self.temb_activation))
445
+ ]
446
+ self.embed_layer = nn.Sequential(*ops)
447
+
448
+ def forward(self, x, t_embed, output_size=None):
449
+ y = self.deconv(x, output_size=output_size)
450
+ if self.embed_dim is not None:
451
+ y = y + self.embed_layer(t_embed)
452
+ return self.activation(self.norm(y))
453
+
454
+
455
+ # From https://github.com/chanil1218/DCUnet.pytorch/blob/2dcdd30804be47a866fde6435cbb7e2f81585213/models/layers/complexnn.py
456
+ class ComplexBatchNorm(torch.nn.Module):
457
+ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=False):
458
+ super(ComplexBatchNorm, self).__init__()
459
+ self.num_features = num_features
460
+ self.eps = eps
461
+ self.momentum = momentum
462
+ self.affine = affine
463
+ self.track_running_stats = track_running_stats
464
+ if self.affine:
465
+ self.Wrr = torch.nn.Parameter(torch.Tensor(num_features))
466
+ self.Wri = torch.nn.Parameter(torch.Tensor(num_features))
467
+ self.Wii = torch.nn.Parameter(torch.Tensor(num_features))
468
+ self.Br = torch.nn.Parameter(torch.Tensor(num_features))
469
+ self.Bi = torch.nn.Parameter(torch.Tensor(num_features))
470
+ else:
471
+ self.register_parameter('Wrr', None)
472
+ self.register_parameter('Wri', None)
473
+ self.register_parameter('Wii', None)
474
+ self.register_parameter('Br', None)
475
+ self.register_parameter('Bi', None)
476
+ if self.track_running_stats:
477
+ self.register_buffer('RMr', torch.zeros(num_features))
478
+ self.register_buffer('RMi', torch.zeros(num_features))
479
+ self.register_buffer('RVrr', torch.ones(num_features))
+ self.register_buffer('RVri', torch.zeros(num_features))
+ self.register_buffer('RVii', torch.ones(num_features))
482
+ self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
483
+ else:
484
+ self.register_parameter('RMr', None)
485
+ self.register_parameter('RMi', None)
486
+ self.register_parameter('RVrr', None)
487
+ self.register_parameter('RVri', None)
488
+ self.register_parameter('RVii', None)
489
+ self.register_parameter('num_batches_tracked', None)
490
+ self.reset_parameters()
491
+
492
+ def reset_running_stats(self):
493
+ if self.track_running_stats:
494
+ self.RMr.zero_()
495
+ self.RMi.zero_()
496
+ self.RVrr.fill_(1)
497
+ self.RVri.zero_()
498
+ self.RVii.fill_(1)
499
+ self.num_batches_tracked.zero_()
500
+
501
+ def reset_parameters(self):
502
+ self.reset_running_stats()
503
+ if self.affine:
504
+ self.Br.data.zero_()
505
+ self.Bi.data.zero_()
506
+ self.Wrr.data.fill_(1)
507
+ self.Wri.data.uniform_(-.9, +.9) # W will be positive-definite
508
+ self.Wii.data.fill_(1)
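+ # (with Wrr = Wii = 1 and |Wri| < 1, det(W) = 1 - Wri**2 > 0, so W is positive-definite)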
509
+
510
+ def _check_input_dim(self, xr, xi):
511
+ assert(xr.shape == xi.shape)
512
+ assert(xr.size(1) == self.num_features)
513
+
514
+ def forward(self, x):
515
+ xr, xi = x.real, x.imag
516
+ self._check_input_dim(xr, xi)
517
+
518
+ exponential_average_factor = 0.0
519
+
520
+ if self.training and self.track_running_stats:
521
+ self.num_batches_tracked += 1
522
+ if self.momentum is None: # use cumulative moving average
523
+ exponential_average_factor = 1.0 / self.num_batches_tracked.item()
524
+ else: # use exponential moving average
525
+ exponential_average_factor = self.momentum
526
+
527
+ #
528
+ # NOTE: The precise meaning of the "training flag" is:
529
+ # True: Normalize using batch statistics, update running statistics
530
+ # if they are being collected.
531
+ # False: Normalize using running statistics, ignore batch statistics.
532
+ #
533
+ training = self.training or not self.track_running_stats
534
+ redux = [i for i in reversed(range(xr.dim())) if i!=1]
535
+ vdim = [1] * xr.dim()
536
+ vdim[1] = xr.size(1)
537
+
538
+ #
539
+ # Mean M Computation and Centering
540
+ #
541
+ # Includes running mean update if training and running.
542
+ #
543
+ if training:
544
+ Mr, Mi = xr, xi
545
+ for d in redux:
546
+ Mr = Mr.mean(d, keepdim=True)
547
+ Mi = Mi.mean(d, keepdim=True)
548
+ if self.track_running_stats:
549
+ self.RMr.lerp_(Mr.squeeze(), exponential_average_factor)
550
+ self.RMi.lerp_(Mi.squeeze(), exponential_average_factor)
551
+ else:
552
+ Mr = self.RMr.view(vdim)
553
+ Mi = self.RMi.view(vdim)
554
+ xr, xi = xr-Mr, xi-Mi
555
+
556
+ #
557
+ # Variance Matrix V Computation
558
+ #
559
+ # Includes epsilon numerical stabilizer/Tikhonov regularizer.
560
+ # Includes running variance update if training and running.
561
+ #
562
+ if training:
563
+ Vrr = xr * xr
564
+ Vri = xr * xi
565
+ Vii = xi * xi
566
+ for d in redux:
567
+ Vrr = Vrr.mean(d, keepdim=True)
568
+ Vri = Vri.mean(d, keepdim=True)
569
+ Vii = Vii.mean(d, keepdim=True)
570
+ if self.track_running_stats:
571
+ self.RVrr.lerp_(Vrr.squeeze(), exponential_average_factor)
572
+ self.RVri.lerp_(Vri.squeeze(), exponential_average_factor)
573
+ self.RVii.lerp_(Vii.squeeze(), exponential_average_factor)
574
+ else:
575
+ Vrr = self.RVrr.view(vdim)
576
+ Vri = self.RVri.view(vdim)
577
+ Vii = self.RVii.view(vdim)
578
+ Vrr = Vrr + self.eps
579
+ Vri = Vri
580
+ Vii = Vii + self.eps
581
+
582
+ #
583
+ # Matrix Inverse Square Root U = V^-0.5
584
+ #
585
+ # sqrt of a 2x2 matrix,
586
+ # - https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix
587
+ tau = Vrr + Vii
588
+ delta = torch.addcmul(Vrr * Vii, Vri, Vri, value=-1)
589
+ s = delta.sqrt()
590
+ t = (tau + 2*s).sqrt()
591
+
592
+ # matrix inverse, http://mathworld.wolfram.com/MatrixInverse.html
593
+ rst = (s * t).reciprocal()
594
+ Urr = (s + Vii) * rst
595
+ Uii = (s + Vrr) * rst
596
+ Uri = ( - Vri) * rst
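+ # Added note: sqrt(V) = (V + s*I) / t has determinant s, so its inverse is
+ # U = adj(sqrt(V)) / s = [[Vii + s, -Vri], [-Vri, Vrr + s]] / (s * t).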
597
+
598
+ #
599
+ # Optionally left-multiply U by affine weights W to produce combined
600
+ # weights Z, left-multiply the inputs by Z, then optionally bias them.
601
+ #
602
+ # y = Zx + B
603
+ # y = WUx + B
604
+ # y = [Wrr Wri][Urr Uri] [xr] + [Br]
605
+ # [Wir Wii][Uir Uii] [xi] [Bi]
606
+ #
607
+ if self.affine:
608
+ Wrr, Wri, Wii = self.Wrr.view(vdim), self.Wri.view(vdim), self.Wii.view(vdim)
609
+ Zrr = (Wrr * Urr) + (Wri * Uri)
610
+ Zri = (Wrr * Uri) + (Wri * Uii)
611
+ Zir = (Wri * Urr) + (Wii * Uri)
612
+ Zii = (Wri * Uri) + (Wii * Uii)
613
+ else:
614
+ Zrr, Zri, Zir, Zii = Urr, Uri, Uri, Uii
615
+
616
+ yr = (Zrr * xr) + (Zri * xi)
617
+ yi = (Zir * xr) + (Zii * xi)
618
+
619
+ if self.affine:
620
+ yr = yr + self.Br.view(vdim)
621
+ yi = yi + self.Bi.view(vdim)
622
+
623
+ return torch.view_as_complex(torch.stack([yr, yi], dim=-1))
624
+
625
+ def extra_repr(self):
626
+ return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \
627
+ 'track_running_stats={track_running_stats}'.format(**self.__dict__)
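
A minimal usage sketch for the DCUNet backbone above (illustrative only: it assumes this file lives at sgmse/backbones/dcunet.py and that the helpers imported from .shared behave as their names suggest):

import torch
from sgmse.backbones.dcunet import DCUNet

model = DCUNet(dcunet_architecture="DilDCUNet-v2")      # default "gfp" time embedding
spec = torch.randn(1, 2, 257, 257, dtype=torch.cfloat)  # (batch, [x_t, y], nfreqs, time)
t = torch.rand(1)                                       # one diffusion timestep per batch item
out = model(spec, t)                                    # complex, (1, 1, 257, 257) per the shape comment in forward

Here 257 is chosen so that (257 - 1) is divisible by this architecture's stride products (8 along frequency, 4 along time).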
sgmse/backbones/ncsnpp.py CHANGED
@@ -1,420 +1,419 @@
1
+ # coding=utf-8
2
+ # Copyright 2020 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # pylint: skip-file
17
+
18
+ from .ncsnpp_utils import layers, layerspp, normalization
19
+ import torch.nn as nn
20
+ import functools
21
+ import torch
22
+ import numpy as np
23
+
24
+ from .shared import BackboneRegistry
25
+
26
+ ResnetBlockDDPM = layerspp.ResnetBlockDDPMpp
27
+ ResnetBlockBigGAN = layerspp.ResnetBlockBigGANpp
28
+ Combine = layerspp.Combine
29
+ conv3x3 = layerspp.conv3x3
30
+ conv1x1 = layerspp.conv1x1
31
+ get_act = layers.get_act
32
+ get_normalization = normalization.get_normalization
33
+ default_initializer = layers.default_init
34
+
35
+
36
+ @BackboneRegistry.register("ncsnpp")
37
+ class NCSNpp(nn.Module):
38
+ """NCSN++ model, adapted from https://github.com/yang-song/score_sde repository"""
39
+
40
+ @staticmethod
41
+ def add_argparse_args(parser):
42
+ parser.add_argument("--ch_mult",type=int, nargs='+', default=[1,1,2,2,2,2,2])
43
+ parser.add_argument("--num_res_blocks", type=int, default=2)
44
+ parser.add_argument("--attn_resolutions", type=int, nargs='+', default=[16])
45
+ parser.add_argument("--no-centered", dest="centered", action="store_false", help="The data is not centered [-1, 1]")
46
+ parser.add_argument("--centered", dest="centered", action="store_true", help="The data is centered [-1, 1]")
47
+ parser.set_defaults(centered=True)
48
+ return parser
49
+
50
+ def __init__(self,
51
+ scale_by_sigma = True,
52
+ nonlinearity = 'swish',
53
+ nf = 128,
54
+ ch_mult = (1, 1, 2, 2, 2, 2, 2),
55
+ num_res_blocks = 2,
56
+ attn_resolutions = (16,),
57
+ resamp_with_conv = True,
58
+ conditional = True,
59
+ fir = True,
60
+ fir_kernel = [1, 3, 3, 1],
61
+ skip_rescale = True,
62
+ resblock_type = 'biggan',
63
+ progressive = 'output_skip',
64
+ progressive_input = 'input_skip',
65
+ progressive_combine = 'sum',
66
+ init_scale = 0.,
67
+ fourier_scale = 16,
68
+ image_size = 256,
69
+ embedding_type = 'fourier',
70
+ dropout = .0,
71
+ centered = True,
72
+ **unused_kwargs
73
+ ):
74
+ super().__init__()
75
+ self.act = act = get_act(nonlinearity)
76
+
77
+ self.nf = nf
+ self.num_res_blocks = num_res_blocks
+ self.attn_resolutions = attn_resolutions
+ self.num_resolutions = num_resolutions = len(ch_mult)
+ self.all_resolutions = all_resolutions = [image_size // (2 ** i) for i in range(num_resolutions)]
+
+ self.conditional = conditional # noise-conditional
+ self.centered = centered
+ self.scale_by_sigma = scale_by_sigma
+
+ self.skip_rescale = skip_rescale
+ self.resblock_type = resblock_type = resblock_type.lower()
+ self.progressive = progressive = progressive.lower()
+ self.progressive_input = progressive_input = progressive_input.lower()
+ self.embedding_type = embedding_type = embedding_type.lower()
98
+ assert progressive in ['none', 'output_skip', 'residual']
99
+ assert progressive_input in ['none', 'input_skip', 'residual']
100
+ assert embedding_type in ['fourier', 'positional']
101
+ combine_method = progressive_combine.lower()
102
+ combiner = functools.partial(Combine, method=combine_method)
103
+
104
+ num_channels = 4 # x.real, x.imag, y.real, y.imag
105
+ self.output_layer = nn.Conv2d(num_channels, 2, 1)
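+ # (these 2 real output channels are recombined into a single complex channel
+ # with torch.view_as_complex at the end of forward)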
106
+
107
+ modules = []
108
+ # timestep/noise_level embedding
109
+ if embedding_type == 'fourier':
110
+ # Gaussian Fourier features embeddings.
111
+ modules.append(layerspp.GaussianFourierProjection(
112
+ embedding_size=nf, scale=fourier_scale
113
+ ))
114
+ embed_dim = 2 * nf
115
+ elif embedding_type == 'positional':
116
+ embed_dim = nf
117
+ else:
118
+ raise ValueError(f'embedding type {embedding_type} unknown.')
119
+
120
+ if conditional:
121
+ modules.append(nn.Linear(embed_dim, nf * 4))
122
+ modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
123
+ nn.init.zeros_(modules[-1].bias)
124
+ modules.append(nn.Linear(nf * 4, nf * 4))
125
+ modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
126
+ nn.init.zeros_(modules[-1].bias)
127
+
128
+ AttnBlock = functools.partial(layerspp.AttnBlockpp,
129
+ init_scale=init_scale, skip_rescale=skip_rescale)
130
+
131
+ Upsample = functools.partial(layerspp.Upsample,
132
+ with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
133
+
134
+ if progressive == 'output_skip':
135
+ self.pyramid_upsample = layerspp.Upsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
136
+ elif progressive == 'residual':
137
+ pyramid_upsample = functools.partial(layerspp.Upsample, fir=fir,
138
+ fir_kernel=fir_kernel, with_conv=True)
139
+
140
+ Downsample = functools.partial(layerspp.Downsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
141
+
142
+ if progressive_input == 'input_skip':
143
+ self.pyramid_downsample = layerspp.Downsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
144
+ elif progressive_input == 'residual':
145
+ pyramid_downsample = functools.partial(layerspp.Downsample,
146
+ fir=fir, fir_kernel=fir_kernel, with_conv=True)
147
+
148
+ if resblock_type == 'ddpm':
149
+ ResnetBlock = functools.partial(ResnetBlockDDPM, act=act,
150
+ dropout=dropout, init_scale=init_scale,
151
+ skip_rescale=skip_rescale, temb_dim=nf * 4)
152
+
153
+ elif resblock_type == 'biggan':
154
+ ResnetBlock = functools.partial(ResnetBlockBigGAN, act=act,
155
+ dropout=dropout, fir=fir, fir_kernel=fir_kernel,
156
+ init_scale=init_scale, skip_rescale=skip_rescale, temb_dim=nf * 4)
157
+
158
+ else:
159
+ raise ValueError(f'resblock type {resblock_type} unrecognized.')
160
+
161
+ # Downsampling block
162
+
163
+ channels = num_channels
164
+ if progressive_input != 'none':
165
+ input_pyramid_ch = channels
166
+
167
+ modules.append(conv3x3(channels, nf))
168
+ hs_c = [nf]
169
+
170
+ in_ch = nf
171
+ for i_level in range(num_resolutions):
172
+ # Residual blocks for this resolution
173
+ for i_block in range(num_res_blocks):
174
+ out_ch = nf * ch_mult[i_level]
175
+ modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
176
+ in_ch = out_ch
177
+
178
+ if all_resolutions[i_level] in attn_resolutions:
179
+ modules.append(AttnBlock(channels=in_ch))
180
+ hs_c.append(in_ch)
181
+
182
+ if i_level != num_resolutions - 1:
183
+ if resblock_type == 'ddpm':
184
+ modules.append(Downsample(in_ch=in_ch))
185
+ else:
186
+ modules.append(ResnetBlock(down=True, in_ch=in_ch))
187
+
188
+ if progressive_input == 'input_skip':
189
+ modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
190
+ if combine_method == 'cat':
191
+ in_ch *= 2
192
+
193
+ elif progressive_input == 'residual':
194
+ modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch))
195
+ input_pyramid_ch = in_ch
196
+
197
+ hs_c.append(in_ch)
198
+
199
+ in_ch = hs_c[-1]
200
+ modules.append(ResnetBlock(in_ch=in_ch))
201
+ modules.append(AttnBlock(channels=in_ch))
202
+ modules.append(ResnetBlock(in_ch=in_ch))
203
+
204
+ pyramid_ch = 0
205
+ # Upsampling block
206
+ for i_level in reversed(range(num_resolutions)):
207
+ for i_block in range(num_res_blocks + 1): # +1 blocks in upsampling because of skip connection from combiner (after downsampling)
208
+ out_ch = nf * ch_mult[i_level]
209
+ modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
210
+ in_ch = out_ch
211
+
212
+ if all_resolutions[i_level] in attn_resolutions:
213
+ modules.append(AttnBlock(channels=in_ch))
214
+
215
+ if progressive != 'none':
216
+ if i_level == num_resolutions - 1:
217
+ if progressive == 'output_skip':
218
+ modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
219
+ num_channels=in_ch, eps=1e-6))
220
+ modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
221
+ pyramid_ch = channels
222
+ elif progressive == 'residual':
223
+ modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
224
+ modules.append(conv3x3(in_ch, in_ch, bias=True))
225
+ pyramid_ch = in_ch
226
+ else:
227
+ raise ValueError(f'{progressive} is not a valid name.')
228
+ else:
229
+ if progressive == 'output_skip':
230
+ modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
231
+ num_channels=in_ch, eps=1e-6))
232
+ modules.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale))
233
+ pyramid_ch = channels
234
+ elif progressive == 'residual':
235
+ modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
236
+ pyramid_ch = in_ch
237
+ else:
238
+ raise ValueError(f'{progressive} is not a valid name')
239
+
240
+ if i_level != 0:
241
+ if resblock_type == 'ddpm':
242
+ modules.append(Upsample(in_ch=in_ch))
243
+ else:
244
+ modules.append(ResnetBlock(in_ch=in_ch, up=True))
245
+
246
+ assert not hs_c
247
+
248
+ if progressive != 'output_skip':
249
+ modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
250
+ num_channels=in_ch, eps=1e-6))
251
+ modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
252
+
253
+ self.all_modules = nn.ModuleList(modules)
254
+
255
+
256
+ def forward(self, x, time_cond):
257
+ # timestep/noise_level embedding; only for continuous training
258
+ modules = self.all_modules
259
+ m_idx = 0
260
+
261
+ # Convert real and imaginary parts of (x,y) into four channel dimensions
262
+ x = torch.cat((x[:,[0],:,:].real, x[:,[0],:,:].imag,
263
+ x[:,[1],:,:].real, x[:,[1],:,:].imag), dim=1)
264
+
265
+ if self.embedding_type == 'fourier':
266
+ # Gaussian Fourier features embeddings.
267
+ used_sigmas = time_cond
268
+ temb = modules[m_idx](torch.log(used_sigmas))
269
+ m_idx += 1
270
+
271
+ elif self.embedding_type == 'positional':
272
+ # Sinusoidal positional embeddings.
273
+ timesteps = time_cond
274
+ used_sigmas = self.sigmas[time_cond.long()]
275
+ temb = layers.get_timestep_embedding(timesteps, self.nf)
276
+
277
+ else:
278
+ raise ValueError(f'embedding type {self.embedding_type} unknown.')
279
+
280
+ if self.conditional:
281
+ temb = modules[m_idx](temb)
282
+ m_idx += 1
283
+ temb = modules[m_idx](self.act(temb))
284
+ m_idx += 1
285
+ else:
286
+ temb = None
287
+
288
+ if not self.centered:
289
+ # If input data is in [0, 1]
290
+ x = 2 * x - 1.
291
+
292
+ # Downsampling block
293
+ input_pyramid = None
294
+ if self.progressive_input != 'none':
295
+ input_pyramid = x
296
+
297
+ # Input layer: Conv2d: 4ch -> 128ch
298
+ hs = [modules[m_idx](x)]
299
+ m_idx += 1
300
+
301
+ # Down path in U-Net
302
+ for i_level in range(self.num_resolutions):
303
+ # Residual blocks for this resolution
304
+ for i_block in range(self.num_res_blocks):
305
+ h = modules[m_idx](hs[-1], temb)
306
+ m_idx += 1
307
+ # Attention layer (optional)
308
+ if h.shape[-2] in self.attn_resolutions: # edit: check H dim (-2) not W dim (-1)
309
+ h = modules[m_idx](h)
310
+ m_idx += 1
311
+ hs.append(h)
312
+
313
+ # Downsampling
314
+ if i_level != self.num_resolutions - 1:
315
+ if self.resblock_type == 'ddpm':
316
+ h = modules[m_idx](hs[-1])
317
+ m_idx += 1
318
+ else:
319
+ h = modules[m_idx](hs[-1], temb)
320
+ m_idx += 1
321
+
322
+ if self.progressive_input == 'input_skip': # Combine h with x
323
+ input_pyramid = self.pyramid_downsample(input_pyramid)
324
+ h = modules[m_idx](input_pyramid, h)
325
+ m_idx += 1
326
+
327
+ elif self.progressive_input == 'residual':
328
+ input_pyramid = modules[m_idx](input_pyramid)
329
+ m_idx += 1
330
+ if self.skip_rescale:
331
+ input_pyramid = (input_pyramid + h) / np.sqrt(2.)
332
+ else:
333
+ input_pyramid = input_pyramid + h
334
+ h = input_pyramid
335
+ hs.append(h)
336
+
337
+ h = hs[-1] # effectively a no-op: h already equals hs[-1]
338
+ h = modules[m_idx](h, temb) # ResNet block
339
+ m_idx += 1
340
+ h = modules[m_idx](h) # Attention block
341
+ m_idx += 1
342
+ h = modules[m_idx](h, temb) # ResNet block
343
+ m_idx += 1
344
+
345
+ pyramid = None
346
+
347
+ # Upsampling block
348
+ for i_level in reversed(range(self.num_resolutions)):
349
+ for i_block in range(self.num_res_blocks + 1):
350
+ h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb)
351
+ m_idx += 1
352
+
353
+ # edit: changed index from -1 (W) to -2 (H), as in the down path
354
+ if h.shape[-2] in self.attn_resolutions:
355
+ h = modules[m_idx](h)
356
+ m_idx += 1
357
+
358
+ if self.progressive != 'none':
359
+ if i_level == self.num_resolutions - 1:
360
+ if self.progressive == 'output_skip':
361
+ pyramid = self.act(modules[m_idx](h)) # GroupNorm
362
+ m_idx += 1
363
+ pyramid = modules[m_idx](pyramid) # Conv2D: 256 -> 4
364
+ m_idx += 1
365
+ elif self.progressive == 'residual':
366
+ pyramid = self.act(modules[m_idx](h))
367
+ m_idx += 1
368
+ pyramid = modules[m_idx](pyramid)
369
+ m_idx += 1
370
+ else:
371
+ raise ValueError(f'{self.progressive} is not a valid name.')
372
+ else:
373
+ if self.progressive == 'output_skip':
374
+ pyramid = self.pyramid_upsample(pyramid) # Upsample
375
+ pyramid_h = self.act(modules[m_idx](h)) # GroupNorm
376
+ m_idx += 1
377
+ pyramid_h = modules[m_idx](pyramid_h)
378
+ m_idx += 1
379
+ pyramid = pyramid + pyramid_h
380
+ elif self.progressive == 'residual':
381
+ pyramid = modules[m_idx](pyramid)
382
+ m_idx += 1
383
+ if self.skip_rescale:
384
+ pyramid = (pyramid + h) / np.sqrt(2.)
385
+ else:
386
+ pyramid = pyramid + h
387
+ h = pyramid
388
+ else:
389
+ raise ValueError(f'{self.progressive} is not a valid name')
390
+
391
+ # Upsampling Layer
392
+ if i_level != 0:
393
+ if self.resblock_type == 'ddpm':
394
+ h = modules[m_idx](h)
395
+ m_idx += 1
396
+ else:
397
+ h = modules[m_idx](h, temb) # Upsampling
398
+ m_idx += 1
399
+
400
+ assert not hs
401
+
402
+ if self.progressive == 'output_skip':
403
+ h = pyramid
404
+ else:
405
+ h = self.act(modules[m_idx](h))
406
+ m_idx += 1
407
+ h = modules[m_idx](h)
408
+ m_idx += 1
409
+
410
+ assert m_idx == len(modules), "Implementation error"
411
+ if self.scale_by_sigma:
412
+ used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:]))))
413
+ h = h / used_sigmas
414
+
415
+ # Convert back to complex number
416
+ h = self.output_layer(h)
417
+ h = torch.permute(h, (0, 2, 3, 1)).contiguous()
418
+ h = torch.view_as_complex(h)[:,None, :, :]
419
+ return h
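For reference, a minimal sketch (not part of the diff) of the complex-to-channel round trip this forward pass relies on; shapes and tensor names here are illustrative assumptions:

import torch

# Pack two complex channels (x, y) into four real channels, as at the top of forward():
xc = torch.randn(2, 2, 16, 16, dtype=torch.cfloat)             # (B, 2, H, W) complex
packed = torch.cat((xc[:, [0]].real, xc[:, [0]].imag,
                    xc[:, [1]].real, xc[:, [1]].imag), dim=1)  # (B, 4, H, W) real

# Unpack a 2-channel real output back into one complex channel, as at the bottom:
h = torch.randn(2, 2, 16, 16)                                  # (B, 2, H, W) real
h = torch.permute(h, (0, 2, 3, 1)).contiguous()                # (B, H, W, 2)
h = torch.view_as_complex(h)[:, None, :, :]                    # (B, 1, H, W) complex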
 
sgmse/backbones/ncsnpp_48k.py CHANGED
@@ -1,424 +1,424 @@
1
- # coding=utf-8
2
- # Copyright 2020 The Google Research Authors.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- # pylint: skip-file
17
-
18
- from .ncsnpp_utils import layers, layerspp, normalization
19
- import torch.nn as nn
20
- import functools
21
- import torch
22
- import numpy as np
23
-
24
- from .shared import BackboneRegistry
25
-
26
- ResnetBlockDDPM = layerspp.ResnetBlockDDPMpp
27
- ResnetBlockBigGAN = layerspp.ResnetBlockBigGANpp
28
- Combine = layerspp.Combine
29
- conv3x3 = layerspp.conv3x3
30
- conv1x1 = layerspp.conv1x1
31
- get_act = layers.get_act
32
- get_normalization = normalization.get_normalization
33
- default_initializer = layers.default_init
34
-
35
-
36
- @BackboneRegistry.register("ncsnpp_48k")
37
- class NCSNpp_48k(nn.Module):
38
- """NCSN++ model, adapted from the https://github.com/yang-song/score_sde repository"""
39
-
40
- @staticmethod
41
- def add_argparse_args(parser):
42
- parser.add_argument("--ch_mult",type=int, nargs='+', default=[1,1,2,2,2,2,2])
43
- parser.add_argument("--num_res_blocks", type=int, default=2)
44
- parser.add_argument("--attn_resolutions", type=int, nargs='+', default=[])
45
- parser.add_argument("--nf", type=int, default=128, help="Number of channels to use in the model")
46
- parser.add_argument("--no-centered", dest="centered", action="store_false", help="The data is not centered [-1, 1]")
47
- parser.add_argument("--centered", dest="centered", action="store_true", help="The data is centered [-1, 1]")
48
- parser.add_argument("--progressive", type=str, default='none', help="Progressive scheme for the output (upsampling) path")
49
- parser.add_argument("--progressive_input", type=str, default='none', help="Progressive scheme for the input (downsampling) path")
50
- parser.set_defaults(centered=True)
51
- return parser
52
-
53
- def __init__(self,
54
- scale_by_sigma = True,
55
- nonlinearity = 'swish',
56
- nf = 128,
57
- ch_mult = (1, 1, 2, 2, 2, 2, 2),
58
- num_res_blocks = 2,
59
- attn_resolutions = (),
60
- resamp_with_conv = True,
61
- conditional = True,
62
- fir = True,
63
- fir_kernel = [1, 3, 3, 1],
64
- skip_rescale = True,
65
- resblock_type = 'biggan',
66
- progressive = 'none',
67
- progressive_input = 'none',
68
- progressive_combine = 'sum',
69
- init_scale = 0.,
70
- fourier_scale = 16,
71
- image_size = 256,
72
- embedding_type = 'fourier',
73
- dropout = .0,
74
- centered = True,
75
- **unused_kwargs
76
- ):
77
- super().__init__()
78
- self.act = act = get_act(nonlinearity)
79
-
80
- self.nf = nf = nf
81
- ch_mult = ch_mult
82
- self.num_res_blocks = num_res_blocks = num_res_blocks
83
- self.attn_resolutions = attn_resolutions
84
- dropout = dropout
85
- resamp_with_conv = resamp_with_conv
86
- self.num_resolutions = num_resolutions = len(ch_mult)
87
- self.all_resolutions = all_resolutions = [image_size // (2 ** i) for i in range(num_resolutions)]
88
-
89
- self.conditional = conditional = conditional # noise-conditional
90
- self.centered = centered
91
- self.scale_by_sigma = scale_by_sigma
92
-
93
- fir = fir
94
- fir_kernel = fir_kernel
95
- self.skip_rescale = skip_rescale = skip_rescale
96
- self.resblock_type = resblock_type = resblock_type.lower()
97
- self.progressive = progressive = progressive.lower()
98
- self.progressive_input = progressive_input = progressive_input.lower()
99
- self.embedding_type = embedding_type = embedding_type.lower()
100
- init_scale = init_scale
101
- assert progressive in ['none', 'output_skip', 'residual']
102
- assert progressive_input in ['none', 'input_skip', 'residual']
103
- assert embedding_type in ['fourier', 'positional']
104
- combine_method = progressive_combine.lower()
105
- combiner = functools.partial(Combine, method=combine_method)
106
-
107
- num_channels = 4 # x.real, x.imag, y.real, y.imag
108
- self.output_layer = nn.Conv2d(num_channels, 2, 1)
109
-
110
- modules = []
111
- # timestep/noise_level embedding
112
- if embedding_type == 'fourier':
113
- # Gaussian Fourier features embeddings.
114
- modules.append(layerspp.GaussianFourierProjection(
115
- embedding_size=nf, scale=fourier_scale
116
- ))
117
- embed_dim = 2 * nf
118
- elif embedding_type == 'positional':
119
- embed_dim = nf
120
- else:
121
- raise ValueError(f'embedding type {embedding_type} unknown.')
122
-
123
- if conditional:
124
- modules.append(nn.Linear(embed_dim, nf * 4))
125
- modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
126
- nn.init.zeros_(modules[-1].bias)
127
- modules.append(nn.Linear(nf * 4, nf * 4))
128
- modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
129
- nn.init.zeros_(modules[-1].bias)
130
-
131
- AttnBlock = functools.partial(layerspp.AttnBlockpp,
132
- init_scale=init_scale, skip_rescale=skip_rescale)
133
-
134
- Upsample = functools.partial(layerspp.Upsample,
135
- with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
136
-
137
- if progressive == 'output_skip':
138
- self.pyramid_upsample = layerspp.Upsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
139
- elif progressive == 'residual':
140
- pyramid_upsample = functools.partial(layerspp.Upsample, fir=fir,
141
- fir_kernel=fir_kernel, with_conv=True)
142
-
143
- Downsample = functools.partial(layerspp.Downsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
144
-
145
- if progressive_input == 'input_skip':
146
- self.pyramid_downsample = layerspp.Downsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
147
- elif progressive_input == 'residual':
148
- pyramid_downsample = functools.partial(layerspp.Downsample,
149
- fir=fir, fir_kernel=fir_kernel, with_conv=True)
150
-
151
- if resblock_type == 'ddpm':
152
- ResnetBlock = functools.partial(ResnetBlockDDPM, act=act,
153
- dropout=dropout, init_scale=init_scale,
154
- skip_rescale=skip_rescale, temb_dim=nf * 4)
155
-
156
- elif resblock_type == 'biggan':
157
- ResnetBlock = functools.partial(ResnetBlockBigGAN, act=act,
158
- dropout=dropout, fir=fir, fir_kernel=fir_kernel,
159
- init_scale=init_scale, skip_rescale=skip_rescale, temb_dim=nf * 4)
160
-
161
- else:
162
- raise ValueError(f'resblock type {resblock_type} unrecognized.')
163
-
164
- # Downsampling block
165
-
166
- channels = num_channels
167
- if progressive_input != 'none':
168
- input_pyramid_ch = channels
169
-
170
- modules.append(conv3x3(channels, nf))
171
- hs_c = [nf]
172
-
173
- in_ch = nf
174
- for i_level in range(num_resolutions):
175
- # Residual blocks for this resolution
176
- for i_block in range(num_res_blocks):
177
- out_ch = nf * ch_mult[i_level]
178
- modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
179
- in_ch = out_ch
180
-
181
- if all_resolutions[i_level] in attn_resolutions:
182
- modules.append(AttnBlock(channels=in_ch))
183
- hs_c.append(in_ch)
184
-
185
- if i_level != num_resolutions - 1:
186
- if resblock_type == 'ddpm':
187
- modules.append(Downsample(in_ch=in_ch))
188
- else:
189
- modules.append(ResnetBlock(down=True, in_ch=in_ch))
190
-
191
- if progressive_input == 'input_skip':
192
- modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
193
- if combine_method == 'cat':
194
- in_ch *= 2
195
-
196
- elif progressive_input == 'residual':
197
- modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch))
198
- input_pyramid_ch = in_ch
199
-
200
- hs_c.append(in_ch)
201
-
202
- in_ch = hs_c[-1]
203
- modules.append(ResnetBlock(in_ch=in_ch))
204
- modules.append(AttnBlock(channels=in_ch))
205
- modules.append(ResnetBlock(in_ch=in_ch))
206
-
207
- pyramid_ch = 0
208
- # Upsampling block
209
- for i_level in reversed(range(num_resolutions)):
210
- for i_block in range(num_res_blocks + 1): # one extra block per level: the up path consumes the additional skip connection stored after downsampling
211
- out_ch = nf * ch_mult[i_level]
212
- modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
213
- in_ch = out_ch
214
-
215
- if all_resolutions[i_level] in attn_resolutions:
216
- modules.append(AttnBlock(channels=in_ch))
217
-
218
- if progressive != 'none':
219
- if i_level == num_resolutions - 1:
220
- if progressive == 'output_skip':
221
- modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
222
- num_channels=in_ch, eps=1e-6))
223
- modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
224
- pyramid_ch = channels
225
- elif progressive == 'residual':
226
- modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
227
- modules.append(conv3x3(in_ch, in_ch, bias=True))
228
- pyramid_ch = in_ch
229
- else:
230
- raise ValueError(f'{progressive} is not a valid name.')
231
- else:
232
- if progressive == 'output_skip':
233
- modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
234
- num_channels=in_ch, eps=1e-6))
235
- modules.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale))
236
- pyramid_ch = channels
237
- elif progressive == 'residual':
238
- modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
239
- pyramid_ch = in_ch
240
- else:
241
- raise ValueError(f'{progressive} is not a valid name')
242
-
243
- if i_level != 0:
244
- if resblock_type == 'ddpm':
245
- modules.append(Upsample(in_ch=in_ch))
246
- else:
247
- modules.append(ResnetBlock(in_ch=in_ch, up=True))
248
-
249
- assert not hs_c
250
-
251
- if progressive != 'output_skip':
252
- modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
253
- num_channels=in_ch, eps=1e-6))
254
- modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
255
-
256
- self.all_modules = nn.ModuleList(modules)
257
-
258
-
259
- def forward(self, x, time_cond):
260
- # timestep/noise_level embedding; only for continuous training
261
- modules = self.all_modules
262
- m_idx = 0
263
-
264
- # Convert real and imaginary parts of (x,y) into four channel dimensions
265
- x = torch.cat((x[:,[0],:,:].real, x[:,[0],:,:].imag,
266
- x[:,[1],:,:].real, x[:,[1],:,:].imag), dim=1)
267
-
268
- if self.embedding_type == 'fourier':
269
- # Gaussian Fourier features embeddings.
270
- used_sigmas = time_cond
271
- temb = modules[m_idx](torch.log(used_sigmas))
272
- m_idx += 1
273
-
274
- elif self.embedding_type == 'positional':
275
- # Sinusoidal positional embeddings.
276
- timesteps = time_cond
277
- used_sigmas = self.sigmas[time_cond.long()]
278
- temb = layers.get_timestep_embedding(timesteps, self.nf)
279
-
280
- else:
281
- raise ValueError(f'embedding type {self.embedding_type} unknown.')
282
-
283
- if self.conditional:
284
- temb = modules[m_idx](temb)
285
- m_idx += 1
286
- temb = modules[m_idx](self.act(temb))
287
- m_idx += 1
288
- else:
289
- temb = None
290
-
291
- if not self.centered:
292
- # If input data is in [0, 1]
293
- x = 2 * x - 1.
294
-
295
- # Downsampling block
296
- input_pyramid = None
297
- if self.progressive_input != 'none':
298
- input_pyramid = x
299
-
300
- # Input layer: Conv2d: 4ch -> 128ch
301
- hs = [modules[m_idx](x)]
302
- m_idx += 1
303
-
304
- # Down path in U-Net
305
- for i_level in range(self.num_resolutions):
306
- # Residual blocks for this resolution
307
- for i_block in range(self.num_res_blocks):
308
- h = modules[m_idx](hs[-1], temb)
309
- m_idx += 1
310
- # Attention layer (optional)
311
- if h.shape[-2] in self.attn_resolutions: # edit: check H dim (-2) not W dim (-1)
312
- h = modules[m_idx](h)
313
- m_idx += 1
314
- hs.append(h)
315
-
316
- # Downsampling
317
- if i_level != self.num_resolutions - 1:
318
- if self.resblock_type == 'ddpm':
319
- h = modules[m_idx](hs[-1])
320
- m_idx += 1
321
- else:
322
- h = modules[m_idx](hs[-1], temb)
323
- m_idx += 1
324
-
325
- if self.progressive_input == 'input_skip': # Combine h with x
326
- input_pyramid = self.pyramid_downsample(input_pyramid)
327
- h = modules[m_idx](input_pyramid, h)
328
- m_idx += 1
329
-
330
- elif self.progressive_input == 'residual':
331
- input_pyramid = modules[m_idx](input_pyramid)
332
- m_idx += 1
333
- if self.skip_rescale:
334
- input_pyramid = (input_pyramid + h) / np.sqrt(2.)
335
- else:
336
- input_pyramid = input_pyramid + h
337
- h = input_pyramid
338
- hs.append(h)
339
-
340
- h = hs[-1] # effectively a no-op: h already equals hs[-1]
341
- h = modules[m_idx](h, temb) # ResNet block
342
- m_idx += 1
343
- h = modules[m_idx](h) # Attention block
344
- m_idx += 1
345
- h = modules[m_idx](h, temb) # ResNet block
346
- m_idx += 1
347
-
348
- pyramid = None
349
-
350
- # Upsampling block
351
- for i_level in reversed(range(self.num_resolutions)):
352
- for i_block in range(self.num_res_blocks + 1):
353
- h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb)
354
- m_idx += 1
355
-
356
- # edit: changed index from -1 (W) to -2 (H), as in the down path
357
- if h.shape[-2] in self.attn_resolutions:
358
- h = modules[m_idx](h)
359
- m_idx += 1
360
-
361
- if self.progressive != 'none':
362
- if i_level == self.num_resolutions - 1:
363
- if self.progressive == 'output_skip':
364
- pyramid = self.act(modules[m_idx](h)) # GroupNorm
365
- m_idx += 1
366
- pyramid = modules[m_idx](pyramid) # Conv2D: 256 -> 4
367
- m_idx += 1
368
- elif self.progressive == 'residual':
369
- pyramid = self.act(modules[m_idx](h))
370
- m_idx += 1
371
- pyramid = modules[m_idx](pyramid)
372
- m_idx += 1
373
- else:
374
- raise ValueError(f'{self.progressive} is not a valid name.')
375
- else:
376
- if self.progressive == 'output_skip':
377
- pyramid = self.pyramid_upsample(pyramid) # Upsample
378
- pyramid_h = self.act(modules[m_idx](h)) # GroupNorm
379
- m_idx += 1
380
- pyramid_h = modules[m_idx](pyramid_h)
381
- m_idx += 1
382
- pyramid = pyramid + pyramid_h
383
- elif self.progressive == 'residual':
384
- pyramid = modules[m_idx](pyramid)
385
- m_idx += 1
386
- if self.skip_rescale:
387
- pyramid = (pyramid + h) / np.sqrt(2.)
388
- else:
389
- pyramid = pyramid + h
390
- h = pyramid
391
- else:
392
- raise ValueError(f'{self.progressive} is not a valid name')
393
-
394
- # Upsampling Layer
395
- if i_level != 0:
396
- if self.resblock_type == 'ddpm':
397
- h = modules[m_idx](h)
398
- m_idx += 1
399
- else:
400
- h = modules[m_idx](h, temb) # Upsampling
401
- m_idx += 1
402
-
403
- assert not hs
404
-
405
- if self.progressive == 'output_skip':
406
- h = pyramid
407
- else:
408
- h = self.act(modules[m_idx](h))
409
- m_idx += 1
410
- h = modules[m_idx](h)
411
- m_idx += 1
412
-
413
- assert m_idx == len(modules), "Implementation error"
414
-
415
- # Convert back to complex number
416
- h = self.output_layer(h)
417
-
418
- if self.scale_by_sigma:
419
- used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:]))))
420
- h = h / used_sigmas
421
-
422
- h = torch.permute(h, (0, 2, 3, 1)).contiguous()
423
- h = torch.view_as_complex(h)[:,None, :, :]
424
- return h
 
1
+ # coding=utf-8
2
+ # Copyright 2020 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # pylint: skip-file
17
+
18
+ from .ncsnpp_utils import layers, layerspp, normalization
19
+ import torch.nn as nn
20
+ import functools
21
+ import torch
22
+ import numpy as np
23
+
24
+ from .shared import BackboneRegistry
25
+
26
+ ResnetBlockDDPM = layerspp.ResnetBlockDDPMpp
27
+ ResnetBlockBigGAN = layerspp.ResnetBlockBigGANpp
28
+ Combine = layerspp.Combine
29
+ conv3x3 = layerspp.conv3x3
30
+ conv1x1 = layerspp.conv1x1
31
+ get_act = layers.get_act
32
+ get_normalization = normalization.get_normalization
33
+ default_initializer = layers.default_init
34
+
35
+
36
+ @BackboneRegistry.register("ncsnpp_48k")
37
+ class NCSNpp_48k(nn.Module):
38
+ """NCSN++ model, adapted from the https://github.com/yang-song/score_sde repository"""
39
+
40
+ @staticmethod
41
+ def add_argparse_args(parser):
42
+ parser.add_argument("--ch_mult",type=int, nargs='+', default=[1,1,2,2,2,2,2])
43
+ parser.add_argument("--num_res_blocks", type=int, default=2)
44
+ parser.add_argument("--attn_resolutions", type=int, nargs='+', default=[])
45
+ parser.add_argument("--nf", type=int, default=128, help="Number of channels to use in the model")
46
+ parser.add_argument("--no-centered", dest="centered", action="store_false", help="The data is not centered [-1, 1]")
47
+ parser.add_argument("--centered", dest="centered", action="store_true", help="The data is centered [-1, 1]")
48
+ parser.add_argument("--progressive", type=str, default='none', help="Progressive scheme for the output (upsampling) path")
49
+ parser.add_argument("--progressive_input", type=str, default='none', help="Progressive scheme for the input (downsampling) path")
50
+ parser.set_defaults(centered=True)
51
+ return parser
52
+
53
+ def __init__(self,
54
+ scale_by_sigma = True,
55
+ nonlinearity = 'swish',
56
+ nf = 128,
57
+ ch_mult = (1, 1, 2, 2, 2, 2, 2),
58
+ num_res_blocks = 2,
59
+ attn_resolutions = (),
60
+ resamp_with_conv = True,
61
+ conditional = True,
62
+ fir = True,
63
+ fir_kernel = [1, 3, 3, 1],
64
+ skip_rescale = True,
65
+ resblock_type = 'biggan',
66
+ progressive = 'none',
67
+ progressive_input = 'none',
68
+ progressive_combine = 'sum',
69
+ init_scale = 0.,
70
+ fourier_scale = 16,
71
+ image_size = 256,
72
+ embedding_type = 'fourier',
73
+ dropout = .0,
74
+ centered = True,
75
+ **unused_kwargs
76
+ ):
77
+ super().__init__()
78
+ self.act = act = get_act(nonlinearity)
79
+
80
+ self.nf = nf = nf
81
+ ch_mult = ch_mult
82
+ self.num_res_blocks = num_res_blocks = num_res_blocks
83
+ self.attn_resolutions = attn_resolutions
84
+ dropout = dropout
85
+ resamp_with_conv = resamp_with_conv
86
+ self.num_resolutions = num_resolutions = len(ch_mult)
87
+ self.all_resolutions = all_resolutions = [image_size // (2 ** i) for i in range(num_resolutions)]
88
+
89
+ self.conditional = conditional = conditional # noise-conditional
90
+ self.centered = centered
91
+ self.scale_by_sigma = scale_by_sigma
92
+
93
+ fir = fir
94
+ fir_kernel = fir_kernel
95
+ self.skip_rescale = skip_rescale = skip_rescale
96
+ self.resblock_type = resblock_type = resblock_type.lower()
97
+ self.progressive = progressive = progressive.lower()
98
+ self.progressive_input = progressive_input = progressive_input.lower()
99
+ self.embedding_type = embedding_type = embedding_type.lower()
100
+ init_scale = init_scale
101
+ assert progressive in ['none', 'output_skip', 'residual']
102
+ assert progressive_input in ['none', 'input_skip', 'residual']
103
+ assert embedding_type in ['fourier', 'positional']
104
+ combine_method = progressive_combine.lower()
105
+ combiner = functools.partial(Combine, method=combine_method)
106
+
107
+ num_channels = 4 # x.real, x.imag, y.real, y.imag
108
+ self.output_layer = nn.Conv2d(num_channels, 2, 1)
109
+
110
+ modules = []
111
+ # timestep/noise_level embedding
112
+ if embedding_type == 'fourier':
113
+ # Gaussian Fourier features embeddings.
114
+ modules.append(layerspp.GaussianFourierProjection(
115
+ embedding_size=nf, scale=fourier_scale
116
+ ))
117
+ embed_dim = 2 * nf
118
+ elif embedding_type == 'positional':
119
+ embed_dim = nf
120
+ else:
121
+ raise ValueError(f'embedding type {embedding_type} unknown.')
122
+
123
+ if conditional:
124
+ modules.append(nn.Linear(embed_dim, nf * 4))
125
+ modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
126
+ nn.init.zeros_(modules[-1].bias)
127
+ modules.append(nn.Linear(nf * 4, nf * 4))
128
+ modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
129
+ nn.init.zeros_(modules[-1].bias)
130
+
131
+ AttnBlock = functools.partial(layerspp.AttnBlockpp,
132
+ init_scale=init_scale, skip_rescale=skip_rescale)
133
+
134
+ Upsample = functools.partial(layerspp.Upsample,
135
+ with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
136
+
137
+ if progressive == 'output_skip':
138
+ self.pyramid_upsample = layerspp.Upsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
139
+ elif progressive == 'residual':
140
+ pyramid_upsample = functools.partial(layerspp.Upsample, fir=fir,
141
+ fir_kernel=fir_kernel, with_conv=True)
142
+
143
+ Downsample = functools.partial(layerspp.Downsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
144
+
145
+ if progressive_input == 'input_skip':
146
+ self.pyramid_downsample = layerspp.Downsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
147
+ elif progressive_input == 'residual':
148
+ pyramid_downsample = functools.partial(layerspp.Downsample,
149
+ fir=fir, fir_kernel=fir_kernel, with_conv=True)
150
+
151
+ if resblock_type == 'ddpm':
152
+ ResnetBlock = functools.partial(ResnetBlockDDPM, act=act,
153
+ dropout=dropout, init_scale=init_scale,
154
+ skip_rescale=skip_rescale, temb_dim=nf * 4)
155
+
156
+ elif resblock_type == 'biggan':
157
+ ResnetBlock = functools.partial(ResnetBlockBigGAN, act=act,
158
+ dropout=dropout, fir=fir, fir_kernel=fir_kernel,
159
+ init_scale=init_scale, skip_rescale=skip_rescale, temb_dim=nf * 4)
160
+
161
+ else:
162
+ raise ValueError(f'resblock type {resblock_type} unrecognized.')
163
+
164
+ # Downsampling block
165
+
166
+ channels = num_channels
167
+ if progressive_input != 'none':
168
+ input_pyramid_ch = channels
169
+
170
+ modules.append(conv3x3(channels, nf))
171
+ hs_c = [nf]
172
+
173
+ in_ch = nf
174
+ for i_level in range(num_resolutions):
175
+ # Residual blocks for this resolution
176
+ for i_block in range(num_res_blocks):
177
+ out_ch = nf * ch_mult[i_level]
178
+ modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
179
+ in_ch = out_ch
180
+
181
+ if all_resolutions[i_level] in attn_resolutions:
182
+ modules.append(AttnBlock(channels=in_ch))
183
+ hs_c.append(in_ch)
184
+
185
+ if i_level != num_resolutions - 1:
186
+ if resblock_type == 'ddpm':
187
+ modules.append(Downsample(in_ch=in_ch))
188
+ else:
189
+ modules.append(ResnetBlock(down=True, in_ch=in_ch))
190
+
191
+ if progressive_input == 'input_skip':
192
+ modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
193
+ if combine_method == 'cat':
194
+ in_ch *= 2
195
+
196
+ elif progressive_input == 'residual':
197
+ modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch))
198
+ input_pyramid_ch = in_ch
199
+
200
+ hs_c.append(in_ch)
201
+
202
+ in_ch = hs_c[-1]
203
+ modules.append(ResnetBlock(in_ch=in_ch))
204
+ modules.append(AttnBlock(channels=in_ch))
205
+ modules.append(ResnetBlock(in_ch=in_ch))
206
+
207
+ pyramid_ch = 0
208
+ # Upsampling block
209
+ for i_level in reversed(range(num_resolutions)):
210
+ for i_block in range(num_res_blocks + 1): # one extra block per level: the up path consumes the additional skip connection stored after downsampling
211
+ out_ch = nf * ch_mult[i_level]
212
+ modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
213
+ in_ch = out_ch
214
+
215
+ if all_resolutions[i_level] in attn_resolutions:
216
+ modules.append(AttnBlock(channels=in_ch))
217
+
218
+ if progressive != 'none':
219
+ if i_level == num_resolutions - 1:
220
+ if progressive == 'output_skip':
221
+ modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
222
+ num_channels=in_ch, eps=1e-6))
223
+ modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
224
+ pyramid_ch = channels
225
+ elif progressive == 'residual':
226
+ modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
227
+ modules.append(conv3x3(in_ch, in_ch, bias=True))
228
+ pyramid_ch = in_ch
229
+ else:
230
+ raise ValueError(f'{progressive} is not a valid name.')
231
+ else:
232
+ if progressive == 'output_skip':
233
+ modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
234
+ num_channels=in_ch, eps=1e-6))
235
+ modules.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale))
236
+ pyramid_ch = channels
237
+ elif progressive == 'residual':
238
+ modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
239
+ pyramid_ch = in_ch
240
+ else:
241
+ raise ValueError(f'{progressive} is not a valid name')
242
+
243
+ if i_level != 0:
244
+ if resblock_type == 'ddpm':
245
+ modules.append(Upsample(in_ch=in_ch))
246
+ else:
247
+ modules.append(ResnetBlock(in_ch=in_ch, up=True))
248
+
249
+ assert not hs_c
250
+
251
+ if progressive != 'output_skip':
252
+ modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
253
+ num_channels=in_ch, eps=1e-6))
254
+ modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
255
+
256
+ self.all_modules = nn.ModuleList(modules)
257
+
258
+
259
+ def forward(self, x, time_cond):
260
+ # timestep/noise_level embedding; only for continuous training
261
+ modules = self.all_modules
262
+ m_idx = 0
263
+
264
+ # Convert real and imaginary parts of (x,y) into four channel dimensions
265
+ x = torch.cat((x[:,[0],:,:].real, x[:,[0],:,:].imag,
266
+ x[:,[1],:,:].real, x[:,[1],:,:].imag), dim=1)
267
+
268
+ if self.embedding_type == 'fourier':
269
+ # Gaussian Fourier features embeddings.
270
+ used_sigmas = time_cond
271
+ temb = modules[m_idx](torch.log(used_sigmas))
272
+ m_idx += 1
273
+
274
+ elif self.embedding_type == 'positional':
275
+ # Sinusoidal positional embeddings.
276
+ timesteps = time_cond
277
+ used_sigmas = self.sigmas[time_cond.long()]
278
+ temb = layers.get_timestep_embedding(timesteps, self.nf)
279
+
280
+ else:
281
+ raise ValueError(f'embedding type {self.embedding_type} unknown.')
282
+
283
+ if self.conditional:
284
+ temb = modules[m_idx](temb)
285
+ m_idx += 1
286
+ temb = modules[m_idx](self.act(temb))
287
+ m_idx += 1
288
+ else:
289
+ temb = None
290
+
291
+ if not self.centered:
292
+ # If input data is in [0, 1]
293
+ x = 2 * x - 1.
294
+
295
+ # Downsampling block
296
+ input_pyramid = None
297
+ if self.progressive_input != 'none':
298
+ input_pyramid = x
299
+
300
+ # Input layer: Conv2d: 4ch -> 128ch
301
+ hs = [modules[m_idx](x)]
302
+ m_idx += 1
303
+
304
+ # Down path in U-Net
305
+ for i_level in range(self.num_resolutions):
306
+ # Residual blocks for this resolution
307
+ for i_block in range(self.num_res_blocks):
308
+ h = modules[m_idx](hs[-1], temb)
309
+ m_idx += 1
310
+ # Attention layer (optional)
311
+ if h.shape[-2] in self.attn_resolutions: # edit: check H dim (-2) not W dim (-1)
312
+ h = modules[m_idx](h)
313
+ m_idx += 1
314
+ hs.append(h)
315
+
316
+ # Downsampling
317
+ if i_level != self.num_resolutions - 1:
318
+ if self.resblock_type == 'ddpm':
319
+ h = modules[m_idx](hs[-1])
320
+ m_idx += 1
321
+ else:
322
+ h = modules[m_idx](hs[-1], temb)
323
+ m_idx += 1
324
+
325
+ if self.progressive_input == 'input_skip': # Combine h with x
326
+ input_pyramid = self.pyramid_downsample(input_pyramid)
327
+ h = modules[m_idx](input_pyramid, h)
328
+ m_idx += 1
329
+
330
+ elif self.progressive_input == 'residual':
331
+ input_pyramid = modules[m_idx](input_pyramid)
332
+ m_idx += 1
333
+ if self.skip_rescale:
334
+ input_pyramid = (input_pyramid + h) / np.sqrt(2.)
335
+ else:
336
+ input_pyramid = input_pyramid + h
337
+ h = input_pyramid
338
+ hs.append(h)
339
+
340
+ h = hs[-1] # effectively a no-op: h already equals hs[-1]
341
+ h = modules[m_idx](h, temb) # ResNet block
342
+ m_idx += 1
343
+ h = modules[m_idx](h) # Attention block
344
+ m_idx += 1
345
+ h = modules[m_idx](h, temb) # ResNet block
346
+ m_idx += 1
347
+
348
+ pyramid = None
349
+
350
+ # Upsampling block
351
+ for i_level in reversed(range(self.num_resolutions)):
352
+ for i_block in range(self.num_res_blocks + 1):
353
+ h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb)
354
+ m_idx += 1
355
+
356
+ # edit: changed index from -1 (W) to -2 (H), as in the down path
357
+ if h.shape[-2] in self.attn_resolutions:
358
+ h = modules[m_idx](h)
359
+ m_idx += 1
360
+
361
+ if self.progressive != 'none':
362
+ if i_level == self.num_resolutions - 1:
363
+ if self.progressive == 'output_skip':
364
+ pyramid = self.act(modules[m_idx](h)) # GroupNorm
365
+ m_idx += 1
366
+ pyramid = modules[m_idx](pyramid) # Conv2D: 256 -> 4
367
+ m_idx += 1
368
+ elif self.progressive == 'residual':
369
+ pyramid = self.act(modules[m_idx](h))
370
+ m_idx += 1
371
+ pyramid = modules[m_idx](pyramid)
372
+ m_idx += 1
373
+ else:
374
+ raise ValueError(f'{self.progressive} is not a valid name.')
375
+ else:
376
+ if self.progressive == 'output_skip':
377
+ pyramid = self.pyramid_upsample(pyramid) # Upsample
378
+ pyramid_h = self.act(modules[m_idx](h)) # GroupNorm
379
+ m_idx += 1
380
+ pyramid_h = modules[m_idx](pyramid_h)
381
+ m_idx += 1
382
+ pyramid = pyramid + pyramid_h
383
+ elif self.progressive == 'residual':
384
+ pyramid = modules[m_idx](pyramid)
385
+ m_idx += 1
386
+ if self.skip_rescale:
387
+ pyramid = (pyramid + h) / np.sqrt(2.)
388
+ else:
389
+ pyramid = pyramid + h
390
+ h = pyramid
391
+ else:
392
+ raise ValueError(f'{self.progressive} is not a valid name')
393
+
394
+ # Upsampling Layer
395
+ if i_level != 0:
396
+ if self.resblock_type == 'ddpm':
397
+ h = modules[m_idx](h)
398
+ m_idx += 1
399
+ else:
400
+ h = modules[m_idx](h, temb) # Upsampling
401
+ m_idx += 1
402
+
403
+ assert not hs
404
+
405
+ if self.progressive == 'output_skip':
406
+ h = pyramid
407
+ else:
408
+ h = self.act(modules[m_idx](h))
409
+ m_idx += 1
410
+ h = modules[m_idx](h)
411
+ m_idx += 1
412
+
413
+ assert m_idx == len(modules), "Implementation error"
414
+
415
+ # Convert back to complex number
416
+ h = self.output_layer(h)
417
+
418
+ if self.scale_by_sigma:
419
+ used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:]))))
420
+ h = h / used_sigmas
421
+
422
+ h = torch.permute(h, (0, 2, 3, 1)).contiguous()
423
+ h = torch.view_as_complex(h)[:,None, :, :]
424
+ return h
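A hedged usage sketch for this backbone (not part of the diff); the constructor defaults above are assumed, image_size=256 keeps all seven resolutions integral, and the call may be memory-hungry at full size:

import torch
from sgmse.backbones.ncsnpp_48k import NCSNpp_48k

model = NCSNpp_48k()                                  # defaults: nf=128, biggan blocks, fourier embedding
x = torch.randn(1, 2, 256, 256, dtype=torch.cfloat)   # stacked (x, y) complex spectrograms
sigmas = torch.rand(1) + 0.1                          # positive noise levels; forward() takes their log
out = model(x, sigmas)                                # complex tensor of shape (1, 1, 256, 256)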
sgmse/backbones/ncsnpp_utils/layers.py CHANGED
@@ -1,662 +1,662 @@
1
- # coding=utf-8
2
- # Copyright 2020 The Google Research Authors.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- # pylint: skip-file
17
- """Common layers for defining score networks.
18
- """
19
- import math
20
- import string
21
- from functools import partial
22
- import torch.nn as nn
23
- import torch
24
- import torch.nn.functional as F
25
- import numpy as np
26
- from .normalization import ConditionalInstanceNorm2dPlus
27
-
28
-
29
- def get_act(config):
30
- """Get the activation function from its name."""
31
-
32
- if config == 'elu':
33
- return nn.ELU()
34
- elif config == 'relu':
35
- return nn.ReLU()
36
- elif config == 'lrelu':
37
- return nn.LeakyReLU(negative_slope=0.2)
38
- elif config == 'swish':
39
- return nn.SiLU()
40
- else:
41
- raise NotImplementedError('activation function does not exist!')
42
-
43
-
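A brief usage note for get_act above (illustrative, not part of the diff):

act = get_act('swish')   # returns nn.SiLU(); unrecognized names raise NotImplementedError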
44
- def ncsn_conv1x1(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=0):
45
- """1x1 convolution. Same as NCSNv1/v2."""
46
- conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias, dilation=dilation,
47
- padding=padding)
48
- init_scale = 1e-10 if init_scale == 0 else init_scale
49
- conv.weight.data *= init_scale
50
- conv.bias.data *= init_scale
51
- return conv
52
-
53
-
54
- def variance_scaling(scale, mode, distribution,
55
- in_axis=1, out_axis=0,
56
- dtype=torch.float32,
57
- device='cpu'):
58
- """Ported from JAX. """
59
-
60
- def _compute_fans(shape, in_axis=1, out_axis=0):
61
- receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
62
- fan_in = shape[in_axis] * receptive_field_size
63
- fan_out = shape[out_axis] * receptive_field_size
64
- return fan_in, fan_out
65
-
66
- def init(shape, dtype=dtype, device=device):
67
- fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
68
- if mode == "fan_in":
69
- denominator = fan_in
70
- elif mode == "fan_out":
71
- denominator = fan_out
72
- elif mode == "fan_avg":
73
- denominator = (fan_in + fan_out) / 2
74
- else:
75
- raise ValueError(
76
- "invalid mode for variance scaling initializer: {}".format(mode))
77
- variance = scale / denominator
78
- if distribution == "normal":
79
- return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
80
- elif distribution == "uniform":
81
- return (torch.rand(*shape, dtype=dtype, device=device) * 2. - 1.) * np.sqrt(3 * variance)
82
- else:
83
- raise ValueError("invalid distribution for variance scaling initializer")
84
-
85
- return init
86
-
87
-
88
- def default_init(scale=1.):
89
- """The same initialization used in DDPM."""
90
- scale = 1e-10 if scale == 0 else scale
91
- return variance_scaling(scale, 'fan_avg', 'uniform')
92
-
93
-
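A hedged sketch (not part of the diff) of what default_init above produces; the weight shape is an assumption. For a conv weight of shape (out, in, k, k), fan_in = in*k*k and fan_out = out*k*k, so fan_avg = (fan_in + fan_out) / 2 and samples are uniform with variance scale / fan_avg:

init = default_init(scale=1.0)   # variance_scaling with 'fan_avg', 'uniform'
w = init((64, 32, 3, 3))         # fan_in=288, fan_out=576, fan_avg=432
assert w.shape == (64, 32, 3, 3)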
94
- class Dense(nn.Module):
95
- """Linear layer with `default_init`."""
96
- def __init__(self):
97
- super().__init__()
98
-
99
-
100
- def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1., padding=0):
101
- """1x1 convolution with DDPM initialization."""
102
- conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
103
- conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
104
- nn.init.zeros_(conv.bias)
105
- return conv
106
-
107
-
108
- def ncsn_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1):
109
- """3x3 convolution with PyTorch initialization. Same as NCSNv1/NCSNv2."""
110
- init_scale = 1e-10 if init_scale == 0 else init_scale
111
- conv = nn.Conv2d(in_planes, out_planes, stride=stride, bias=bias,
112
- dilation=dilation, padding=padding, kernel_size=3)
113
- conv.weight.data *= init_scale
114
- conv.bias.data *= init_scale
115
- return conv
116
-
117
-
118
- def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1):
119
- """3x3 convolution with DDPM initialization."""
120
- conv = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=padding,
121
- dilation=dilation, bias=bias)
122
- conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
123
- nn.init.zeros_(conv.bias)
124
- return conv
125
-
126
- ###########################################################################
127
- # Functions below are ported over from the NCSNv1/NCSNv2 codebase:
128
- # https://github.com/ermongroup/ncsn
129
- # https://github.com/ermongroup/ncsnv2
130
- ###########################################################################
131
-
132
-
133
- class CRPBlock(nn.Module):
134
- def __init__(self, features, n_stages, act=nn.ReLU(), maxpool=True):
135
- super().__init__()
136
- self.convs = nn.ModuleList()
137
- for i in range(n_stages):
138
- self.convs.append(ncsn_conv3x3(features, features, stride=1, bias=False))
139
- self.n_stages = n_stages
140
- if maxpool:
141
- self.pool = nn.MaxPool2d(kernel_size=5, stride=1, padding=2)
142
- else:
143
- self.pool = nn.AvgPool2d(kernel_size=5, stride=1, padding=2)
144
-
145
- self.act = act
146
-
147
- def forward(self, x):
148
- x = self.act(x)
149
- path = x
150
- for i in range(self.n_stages):
151
- path = self.pool(path)
152
- path = self.convs[i](path)
153
- x = path + x
154
- return x
155
-
156
-
157
- class CondCRPBlock(nn.Module):
158
- def __init__(self, features, n_stages, num_classes, normalizer, act=nn.ReLU()):
159
- super().__init__()
160
- self.convs = nn.ModuleList()
161
- self.norms = nn.ModuleList()
162
- self.normalizer = normalizer
163
- for i in range(n_stages):
164
- self.norms.append(normalizer(features, num_classes, bias=True))
165
- self.convs.append(ncsn_conv3x3(features, features, stride=1, bias=False))
166
-
167
- self.n_stages = n_stages
168
- self.pool = nn.AvgPool2d(kernel_size=5, stride=1, padding=2)
169
- self.act = act
170
-
171
- def forward(self, x, y):
172
- x = self.act(x)
173
- path = x
174
- for i in range(self.n_stages):
175
- path = self.norms[i](path, y)
176
- path = self.pool(path)
177
- path = self.convs[i](path)
178
-
179
- x = path + x
180
- return x
181
-
182
-
183
- class RCUBlock(nn.Module):
184
- def __init__(self, features, n_blocks, n_stages, act=nn.ReLU()):
185
- super().__init__()
186
-
187
- for i in range(n_blocks):
188
- for j in range(n_stages):
189
- setattr(self, '{}_{}_conv'.format(i + 1, j + 1), ncsn_conv3x3(features, features, stride=1, bias=False))
190
-
191
- self.stride = 1
192
- self.n_blocks = n_blocks
193
- self.n_stages = n_stages
194
- self.act = act
195
-
196
- def forward(self, x):
197
- for i in range(self.n_blocks):
198
- residual = x
199
- for j in range(self.n_stages):
200
- x = self.act(x)
201
- x = getattr(self, '{}_{}_conv'.format(i + 1, j + 1))(x)
202
-
203
- x += residual
204
- return x
205
-
206
-
207
- class CondRCUBlock(nn.Module):
208
- def __init__(self, features, n_blocks, n_stages, num_classes, normalizer, act=nn.ReLU()):
209
- super().__init__()
210
-
211
- for i in range(n_blocks):
212
- for j in range(n_stages):
213
- setattr(self, '{}_{}_norm'.format(i + 1, j + 1), normalizer(features, num_classes, bias=True))
214
- setattr(self, '{}_{}_conv'.format(i + 1, j + 1), ncsn_conv3x3(features, features, stride=1, bias=False))
215
-
216
- self.stride = 1
217
- self.n_blocks = n_blocks
218
- self.n_stages = n_stages
219
- self.act = act
220
- self.normalizer = normalizer
221
-
222
- def forward(self, x, y):
223
- for i in range(self.n_blocks):
224
- residual = x
225
- for j in range(self.n_stages):
226
- x = getattr(self, '{}_{}_norm'.format(i + 1, j + 1))(x, y)
227
- x = self.act(x)
228
- x = getattr(self, '{}_{}_conv'.format(i + 1, j + 1))(x)
229
-
230
- x += residual
231
- return x
232
-
233
-
234
- class MSFBlock(nn.Module):
235
- def __init__(self, in_planes, features):
236
- super().__init__()
237
- assert isinstance(in_planes, list) or isinstance(in_planes, tuple)
238
- self.convs = nn.ModuleList()
239
- self.features = features
240
-
241
- for i in range(len(in_planes)):
242
- self.convs.append(ncsn_conv3x3(in_planes[i], features, stride=1, bias=True))
243
-
244
- def forward(self, xs, shape):
245
- sums = torch.zeros(xs[0].shape[0], self.features, *shape, device=xs[0].device)
246
- for i in range(len(self.convs)):
247
- h = self.convs[i](xs[i])
248
- h = F.interpolate(h, size=shape, mode='bilinear', align_corners=True)
249
- sums += h
250
- return sums
251
-
252
-
253
- class CondMSFBlock(nn.Module):
254
- def __init__(self, in_planes, features, num_classes, normalizer):
255
- super().__init__()
256
- assert isinstance(in_planes, list) or isinstance(in_planes, tuple)
257
-
258
- self.convs = nn.ModuleList()
259
- self.norms = nn.ModuleList()
260
- self.features = features
261
- self.normalizer = normalizer
262
-
263
- for i in range(len(in_planes)):
264
- self.convs.append(ncsn_conv3x3(in_planes[i], features, stride=1, bias=True))
265
- self.norms.append(normalizer(in_planes[i], num_classes, bias=True))
266
-
267
- def forward(self, xs, y, shape):
268
- sums = torch.zeros(xs[0].shape[0], self.features, *shape, device=xs[0].device)
269
- for i in range(len(self.convs)):
270
- h = self.norms[i](xs[i], y)
271
- h = self.convs[i](h)
272
- h = F.interpolate(h, size=shape, mode='bilinear', align_corners=True)
273
- sums += h
274
- return sums
275
-
276
-
277
- class RefineBlock(nn.Module):
278
- def __init__(self, in_planes, features, act=nn.ReLU(), start=False, end=False, maxpool=True):
279
- super().__init__()
280
-
281
- assert isinstance(in_planes, tuple) or isinstance(in_planes, list)
282
- self.n_blocks = n_blocks = len(in_planes)
283
-
284
- self.adapt_convs = nn.ModuleList()
285
- for i in range(n_blocks):
286
- self.adapt_convs.append(RCUBlock(in_planes[i], 2, 2, act))
287
-
288
- self.output_convs = RCUBlock(features, 3 if end else 1, 2, act)
289
-
290
- if not start:
291
- self.msf = MSFBlock(in_planes, features)
292
-
293
- self.crp = CRPBlock(features, 2, act, maxpool=maxpool)
294
-
295
- def forward(self, xs, output_shape):
296
- assert isinstance(xs, tuple) or isinstance(xs, list)
297
- hs = []
298
- for i in range(len(xs)):
299
- h = self.adapt_convs[i](xs[i])
300
- hs.append(h)
301
-
302
- if self.n_blocks > 1:
303
- h = self.msf(hs, output_shape)
304
- else:
305
- h = hs[0]
306
-
307
- h = self.crp(h)
308
- h = self.output_convs(h)
309
-
310
- return h
311
-
312
-
313
- class CondRefineBlock(nn.Module):
314
- def __init__(self, in_planes, features, num_classes, normalizer, act=nn.ReLU(), start=False, end=False):
315
- super().__init__()
316
-
317
- assert isinstance(in_planes, tuple) or isinstance(in_planes, list)
318
- self.n_blocks = n_blocks = len(in_planes)
319
-
320
- self.adapt_convs = nn.ModuleList()
321
- for i in range(n_blocks):
322
- self.adapt_convs.append(
323
- CondRCUBlock(in_planes[i], 2, 2, num_classes, normalizer, act)
324
- )
325
-
326
- self.output_convs = CondRCUBlock(features, 3 if end else 1, 2, num_classes, normalizer, act)
327
-
328
- if not start:
329
- self.msf = CondMSFBlock(in_planes, features, num_classes, normalizer)
330
-
331
- self.crp = CondCRPBlock(features, 2, num_classes, normalizer, act)
332
-
333
- def forward(self, xs, y, output_shape):
334
- assert isinstance(xs, tuple) or isinstance(xs, list)
335
- hs = []
336
- for i in range(len(xs)):
337
- h = self.adapt_convs[i](xs[i], y)
338
- hs.append(h)
339
-
340
- if self.n_blocks > 1:
341
- h = self.msf(hs, y, output_shape)
342
- else:
343
- h = hs[0]
344
-
345
- h = self.crp(h, y)
346
- h = self.output_convs(h, y)
347
-
348
- return h
349
-
350
-
351
- class ConvMeanPool(nn.Module):
352
- def __init__(self, input_dim, output_dim, kernel_size=3, biases=True, adjust_padding=False):
353
- super().__init__()
354
- if not adjust_padding:
355
- conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
356
- self.conv = conv
357
- else:
358
- conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
359
-
360
- self.conv = nn.Sequential(
361
- nn.ZeroPad2d((1, 0, 1, 0)),
362
- conv
363
- )
364
-
365
- def forward(self, inputs):
366
- output = self.conv(inputs)
367
- output = sum([output[:, :, ::2, ::2], output[:, :, 1::2, ::2],
368
- output[:, :, ::2, 1::2], output[:, :, 1::2, 1::2]]) / 4.
369
- return output
370
-
371
-
372
- class MeanPoolConv(nn.Module):
373
- def __init__(self, input_dim, output_dim, kernel_size=3, biases=True):
374
- super().__init__()
375
- self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
376
-
377
- def forward(self, inputs):
378
- output = inputs
379
- output = sum([output[:, :, ::2, ::2], output[:, :, 1::2, ::2],
380
- output[:, :, ::2, 1::2], output[:, :, 1::2, 1::2]]) / 4.
381
- return self.conv(output)
382
-
383
-
384
- class UpsampleConv(nn.Module):
385
- def __init__(self, input_dim, output_dim, kernel_size=3, biases=True):
386
- super().__init__()
387
- self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
388
- self.pixelshuffle = nn.PixelShuffle(upscale_factor=2)
389
-
390
- def forward(self, inputs):
391
- output = inputs
392
- output = torch.cat([output, output, output, output], dim=1)
393
- output = self.pixelshuffle(output)
394
- return self.conv(output)
395
-
396
-
397
- class ConditionalResidualBlock(nn.Module):
398
- def __init__(self, input_dim, output_dim, num_classes, resample=1, act=nn.ELU(),
399
- normalization=ConditionalInstanceNorm2dPlus, adjust_padding=False, dilation=None):
400
- super().__init__()
401
- self.non_linearity = act
402
- self.input_dim = input_dim
403
- self.output_dim = output_dim
404
- self.resample = resample
405
- self.normalization = normalization
406
- if resample == 'down':
407
- if dilation > 1:
408
- self.conv1 = ncsn_conv3x3(input_dim, input_dim, dilation=dilation)
409
- self.normalize2 = normalization(input_dim, num_classes)
410
- self.conv2 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
411
- conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
412
- else:
413
- self.conv1 = ncsn_conv3x3(input_dim, input_dim)
414
- self.normalize2 = normalization(input_dim, num_classes)
415
- self.conv2 = ConvMeanPool(input_dim, output_dim, 3, adjust_padding=adjust_padding)
416
- conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding)
417
-
418
- elif resample is None:
419
- if dilation > 1:
420
- conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
421
- self.conv1 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
422
- self.normalize2 = normalization(output_dim, num_classes)
423
- self.conv2 = ncsn_conv3x3(output_dim, output_dim, dilation=dilation)
424
- else:
425
- conv_shortcut = nn.Conv2d
426
- self.conv1 = ncsn_conv3x3(input_dim, output_dim)
427
- self.normalize2 = normalization(output_dim, num_classes)
428
- self.conv2 = ncsn_conv3x3(output_dim, output_dim)
429
- else:
430
- raise Exception('invalid resample value')
431
-
432
- if output_dim != input_dim or resample is not None:
433
- self.shortcut = conv_shortcut(input_dim, output_dim)
434
-
435
- self.normalize1 = normalization(input_dim, num_classes)
436
-
437
- def forward(self, x, y):
438
- output = self.normalize1(x, y)
439
- output = self.non_linearity(output)
440
- output = self.conv1(output)
441
- output = self.normalize2(output, y)
442
- output = self.non_linearity(output)
443
- output = self.conv2(output)
444
-
445
- if self.output_dim == self.input_dim and self.resample is None:
446
- shortcut = x
447
- else:
448
- shortcut = self.shortcut(x)
449
-
450
- return shortcut + output
451
-
452
-
453
- class ResidualBlock(nn.Module):
454
- def __init__(self, input_dim, output_dim, resample=None, act=nn.ELU(),
455
- normalization=nn.InstanceNorm2d, adjust_padding=False, dilation=1):
456
- super().__init__()
457
- self.non_linearity = act
458
- self.input_dim = input_dim
459
- self.output_dim = output_dim
460
- self.resample = resample
461
- self.normalization = normalization
462
- if resample == 'down':
463
- if dilation > 1:
464
- self.conv1 = ncsn_conv3x3(input_dim, input_dim, dilation=dilation)
465
- self.normalize2 = normalization(input_dim)
466
- self.conv2 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
467
- conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
468
- else:
469
- self.conv1 = ncsn_conv3x3(input_dim, input_dim)
470
- self.normalize2 = normalization(input_dim)
471
- self.conv2 = ConvMeanPool(input_dim, output_dim, 3, adjust_padding=adjust_padding)
472
- conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding)
473
-
474
- elif resample is None:
475
- if dilation > 1:
476
- conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
477
- self.conv1 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
478
- self.normalize2 = normalization(output_dim)
479
- self.conv2 = ncsn_conv3x3(output_dim, output_dim, dilation=dilation)
480
- else:
481
- # conv_shortcut = nn.Conv2d ### Something weird here.
482
- conv_shortcut = partial(ncsn_conv1x1)
483
- self.conv1 = ncsn_conv3x3(input_dim, output_dim)
484
- self.normalize2 = normalization(output_dim)
485
- self.conv2 = ncsn_conv3x3(output_dim, output_dim)
486
- else:
487
- raise Exception('invalid resample value')
488
-
489
- if output_dim != input_dim or resample is not None:
490
- self.shortcut = conv_shortcut(input_dim, output_dim)
491
-
492
- self.normalize1 = normalization(input_dim)
493
-
494
- def forward(self, x):
495
- output = self.normalize1(x)
496
- output = self.non_linearity(output)
497
- output = self.conv1(output)
498
- output = self.normalize2(output)
499
- output = self.non_linearity(output)
500
- output = self.conv2(output)
501
-
502
- if self.output_dim == self.input_dim and self.resample is None:
503
- shortcut = x
504
- else:
505
- shortcut = self.shortcut(x)
506
-
507
- return shortcut + output
508
-
509
-
510
- ###########################################################################
511
- # Functions below are ported over from the DDPM codebase:
512
- # https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py
513
- ###########################################################################
514
-
515
- def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
516
- assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32
517
- half_dim = embedding_dim // 2
518
- # magic number 10000 is from transformers
519
- emb = math.log(max_positions) / (half_dim - 1)
520
- # emb = math.log(2.) / (half_dim - 1)
521
- emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
522
- # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
523
- # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
524
- emb = timesteps.float()[:, None] * emb[None, :]
525
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
526
- if embedding_dim % 2 == 1: # zero pad
527
- emb = F.pad(emb, (0, 1), mode='constant')
528
- assert emb.shape == (timesteps.shape[0], embedding_dim)
529
- return emb
530
-
531
-
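A hedged usage note for get_timestep_embedding above (illustrative, not part of the diff):

import torch
ts = torch.tensor([0, 10, 100])                       # must be 1-D (see the assert above)
emb = get_timestep_embedding(ts, embedding_dim=128)   # -> shape (3, 128)
# Columns are [sin | cos] halves; frequencies fall geometrically from 1 to 1/max_positions.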
532
- def _einsum(a, b, c, x, y):
533
- einsum_str = '{},{}->{}'.format(''.join(a), ''.join(b), ''.join(c))
534
- return torch.einsum(einsum_str, x, y)
535
-
536
-
537
- def contract_inner(x, y):
538
- """tensordot(x, y, 1)."""
539
- x_chars = list(string.ascii_lowercase[:len(x.shape)])
540
- y_chars = list(string.ascii_lowercase[len(x.shape):len(y.shape) + len(x.shape)])
541
- y_chars[0] = x_chars[-1] # first axis of y and last of x get summed
542
- out_chars = x_chars[:-1] + y_chars[1:]
543
- return _einsum(x_chars, y_chars, out_chars, x, y)
544
-
545
-
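As the docstring says, contract_inner(x, y) matches torch.tensordot(x, y, 1); a quick illustrative check (not part of the diff):

import torch
a = torch.randn(2, 3, 4)
b = torch.randn(4, 5)
assert torch.allclose(contract_inner(a, b),           # einsum 'abc,ce->abe'
                      torch.tensordot(a, b, dims=1), atol=1e-5)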
546
- class NIN(nn.Module):
547
- def __init__(self, in_dim, num_units, init_scale=0.1):
548
- super().__init__()
549
- self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True)
550
- self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
551
-
552
- def forward(self, x):
553
- x = x.permute(0, 2, 3, 1)
554
- y = contract_inner(x, self.W) + self.b
555
- return y.permute(0, 3, 1, 2)
556
-
557
-
558
- class AttnBlock(nn.Module):
559
- """Channel-wise self-attention block."""
560
- def __init__(self, channels):
561
- super().__init__()
562
- self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6)
563
- self.NIN_0 = NIN(channels, channels)
564
- self.NIN_1 = NIN(channels, channels)
565
- self.NIN_2 = NIN(channels, channels)
566
- self.NIN_3 = NIN(channels, channels, init_scale=0.)
567
-
568
- def forward(self, x):
569
- B, C, H, W = x.shape
570
- h = self.GroupNorm_0(x)
571
- q = self.NIN_0(h)
572
- k = self.NIN_1(h)
573
- v = self.NIN_2(h)
574
-
575
- w = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5))
576
- w = torch.reshape(w, (B, H, W, H * W))
577
- w = F.softmax(w, dim=-1)
578
- w = torch.reshape(w, (B, H, W, H, W))
579
- h = torch.einsum('bhwij,bcij->bchw', w, v)
580
- h = self.NIN_3(h)
581
- return x + h
582
-
583
-
584
- class Upsample(nn.Module):
585
- def __init__(self, channels, with_conv=False):
586
- super().__init__()
587
- if with_conv:
588
- self.Conv_0 = ddpm_conv3x3(channels, channels)
589
- self.with_conv = with_conv
590
-
591
- def forward(self, x):
592
- B, C, H, W = x.shape
593
- h = F.interpolate(x, (H * 2, W * 2), mode='nearest')
594
- if self.with_conv:
595
- h = self.Conv_0(h)
596
- return h
597
-
598
-
599
- class Downsample(nn.Module):
600
- def __init__(self, channels, with_conv=False):
601
- super().__init__()
602
- if with_conv:
603
- self.Conv_0 = ddpm_conv3x3(channels, channels, stride=2, padding=0)
604
- self.with_conv = with_conv
605
-
606
- def forward(self, x):
607
- B, C, H, W = x.shape
608
- # Emulate 'SAME' padding
609
- if self.with_conv:
610
- x = F.pad(x, (0, 1, 0, 1))
611
- x = self.Conv_0(x)
612
- else:
613
- x = F.avg_pool2d(x, kernel_size=2, stride=2, padding=0)
614
-
615
- assert x.shape == (B, C, H // 2, W // 2)
616
- return x
617
-
618
-
619
- class ResnetBlockDDPM(nn.Module):
620
- """The ResNet Blocks used in DDPM."""
621
- def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False, dropout=0.1):
622
- super().__init__()
623
- if out_ch is None:
624
- out_ch = in_ch
625
- self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=in_ch, eps=1e-6)
626
- self.act = act
627
- self.Conv_0 = ddpm_conv3x3(in_ch, out_ch)
628
- if temb_dim is not None:
629
- self.Dense_0 = nn.Linear(temb_dim, out_ch)
630
- self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
631
- nn.init.zeros_(self.Dense_0.bias)
632
-
633
- self.GroupNorm_1 = nn.GroupNorm(num_groups=32, num_channels=out_ch, eps=1e-6)
634
- self.Dropout_0 = nn.Dropout(dropout)
635
- self.Conv_1 = ddpm_conv3x3(out_ch, out_ch, init_scale=0.)
636
- if in_ch != out_ch:
637
- if conv_shortcut:
638
- self.Conv_2 = ddpm_conv3x3(in_ch, out_ch)
639
- else:
640
- self.NIN_0 = NIN(in_ch, out_ch)
641
- self.out_ch = out_ch
642
- self.in_ch = in_ch
643
- self.conv_shortcut = conv_shortcut
644
-
645
- def forward(self, x, temb=None):
646
- B, C, H, W = x.shape
647
- assert C == self.in_ch
648
- out_ch = self.out_ch if self.out_ch else self.in_ch
649
- h = self.act(self.GroupNorm_0(x))
650
- h = self.Conv_0(h)
651
- # Add bias to each feature map conditioned on the time embedding
652
- if temb is not None:
653
- h += self.Dense_0(self.act(temb))[:, :, None, None]
654
- h = self.act(self.GroupNorm_1(h))
655
- h = self.Dropout_0(h)
656
- h = self.Conv_1(h)
657
- if C != out_ch:
658
- if self.conv_shortcut:
659
- x = self.Conv_2(x)
660
- else:
661
- x = self.NIN_0(x)
662
  return x + h
 
+ # coding=utf-8
+ # Copyright 2020 The Google Research Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # pylint: skip-file
+ """Common layers for defining score networks.
+ """
+ import math
+ import string
+ from functools import partial
+ import torch.nn as nn
+ import torch
+ import torch.nn.functional as F
+ import numpy as np
+ from .normalization import ConditionalInstanceNorm2dPlus
+
+
+ def get_act(config):
+   """Get activation functions from the config file."""
+
+   if config == 'elu':
+     return nn.ELU()
+   elif config == 'relu':
+     return nn.ReLU()
+   elif config == 'lrelu':
+     return nn.LeakyReLU(negative_slope=0.2)
+   elif config == 'swish':
+     return nn.SiLU()
+   else:
+     raise NotImplementedError('activation function does not exist!')
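
A minimal usage sketch (editor's illustration, not part of the commit): in this fork get_act takes the activation name string directly rather than a full config object.

    # Illustrative only: 'swish' selects nn.SiLU(), applied elementwise.
    act = get_act('swish')
    y = act(torch.randn(4, 8))   # same shape as the input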


+ def ncsn_conv1x1(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=0):
+   """1x1 convolution. Same as NCSNv1/v2."""
+   conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias, dilation=dilation,
+                    padding=padding)
+   init_scale = 1e-10 if init_scale == 0 else init_scale
+   conv.weight.data *= init_scale
+   conv.bias.data *= init_scale
+   return conv


+ def variance_scaling(scale, mode, distribution,
+                      in_axis=1, out_axis=0,
+                      dtype=torch.float32,
+                      device='cpu'):
+   """Ported from JAX. """
+
+   def _compute_fans(shape, in_axis=1, out_axis=0):
+     receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
+     fan_in = shape[in_axis] * receptive_field_size
+     fan_out = shape[out_axis] * receptive_field_size
+     return fan_in, fan_out
+
+   def init(shape, dtype=dtype, device=device):
+     fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
+     if mode == "fan_in":
+       denominator = fan_in
+     elif mode == "fan_out":
+       denominator = fan_out
+     elif mode == "fan_avg":
+       denominator = (fan_in + fan_out) / 2
+     else:
+       raise ValueError(
+         "invalid mode for variance scaling initializer: {}".format(mode))
+     variance = scale / denominator
+     if distribution == "normal":
+       return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
+     elif distribution == "uniform":
+       return (torch.rand(*shape, dtype=dtype, device=device) * 2. - 1.) * np.sqrt(3 * variance)
+     else:
+       raise ValueError("invalid distribution for variance scaling initializer")
+
+   return init


+ def default_init(scale=1.):
+   """The same initialization used in DDPM."""
+   scale = 1e-10 if scale == 0 else scale
+   return variance_scaling(scale, 'fan_avg', 'uniform')
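
Note the calling convention: variance_scaling returns an init(shape) closure rather than a tensor, and default_init specializes it to DDPM's fan-averaged uniform scheme. A sketch (editor's illustration, not part of the commit):

    init_fn = default_init(scale=1.)   # closure over ('fan_avg', 'uniform')
    w = init_fn((64, 32, 3, 3))        # (out_ch, in_ch, kH, kW) conv weight
    # Var(w) is approximately scale / ((fan_in + fan_out) / 2).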


+ class Dense(nn.Module):
+   """Linear layer with `default_init`."""
+   def __init__(self):
+     super().__init__()


+ def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1., padding=0):
+   """1x1 convolution with DDPM initialization."""
+   conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
+   conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
+   nn.init.zeros_(conv.bias)
+   return conv


+ def ncsn_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1):
+   """3x3 convolution with PyTorch initialization. Same as NCSNv1/NCSNv2."""
+   init_scale = 1e-10 if init_scale == 0 else init_scale
+   conv = nn.Conv2d(in_planes, out_planes, stride=stride, bias=bias,
+                    dilation=dilation, padding=padding, kernel_size=3)
+   conv.weight.data *= init_scale
+   conv.bias.data *= init_scale
+   return conv


+ def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1):
+   """3x3 convolution with DDPM initialization."""
+   conv = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=padding,
+                    dilation=dilation, bias=bias)
+   conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
+   nn.init.zeros_(conv.bias)
+   return conv

+ ###########################################################################
+ # Functions below are ported over from the NCSNv1/NCSNv2 codebase:
+ #   https://github.com/ermongroup/ncsn
+ #   https://github.com/ermongroup/ncsnv2
+ ###########################################################################


+ class CRPBlock(nn.Module):
+   def __init__(self, features, n_stages, act=nn.ReLU(), maxpool=True):
+     super().__init__()
+     self.convs = nn.ModuleList()
+     for i in range(n_stages):
+       self.convs.append(ncsn_conv3x3(features, features, stride=1, bias=False))
+     self.n_stages = n_stages
+     if maxpool:
+       self.pool = nn.MaxPool2d(kernel_size=5, stride=1, padding=2)
+     else:
+       self.pool = nn.AvgPool2d(kernel_size=5, stride=1, padding=2)
+
+     self.act = act
+
+   def forward(self, x):
+     x = self.act(x)
+     path = x
+     for i in range(self.n_stages):
+       path = self.pool(path)
+       path = self.convs[i](path)
+       x = path + x
+     return x


+ class CondCRPBlock(nn.Module):
+   def __init__(self, features, n_stages, num_classes, normalizer, act=nn.ReLU()):
+     super().__init__()
+     self.convs = nn.ModuleList()
+     self.norms = nn.ModuleList()
+     self.normalizer = normalizer
+     for i in range(n_stages):
+       self.norms.append(normalizer(features, num_classes, bias=True))
+       self.convs.append(ncsn_conv3x3(features, features, stride=1, bias=False))
+
+     self.n_stages = n_stages
+     self.pool = nn.AvgPool2d(kernel_size=5, stride=1, padding=2)
+     self.act = act
+
+   def forward(self, x, y):
+     x = self.act(x)
+     path = x
+     for i in range(self.n_stages):
+       path = self.norms[i](path, y)
+       path = self.pool(path)
+       path = self.convs[i](path)
+
+       x = path + x
+     return x


+ class RCUBlock(nn.Module):
+   def __init__(self, features, n_blocks, n_stages, act=nn.ReLU()):
+     super().__init__()
+
+     for i in range(n_blocks):
+       for j in range(n_stages):
+         setattr(self, '{}_{}_conv'.format(i + 1, j + 1), ncsn_conv3x3(features, features, stride=1, bias=False))
+
+     self.stride = 1
+     self.n_blocks = n_blocks
+     self.n_stages = n_stages
+     self.act = act
+
+   def forward(self, x):
+     for i in range(self.n_blocks):
+       residual = x
+       for j in range(self.n_stages):
+         x = self.act(x)
+         x = getattr(self, '{}_{}_conv'.format(i + 1, j + 1))(x)
+
+       x += residual
+     return x


+ class CondRCUBlock(nn.Module):
+   def __init__(self, features, n_blocks, n_stages, num_classes, normalizer, act=nn.ReLU()):
+     super().__init__()
+
+     for i in range(n_blocks):
+       for j in range(n_stages):
+         setattr(self, '{}_{}_norm'.format(i + 1, j + 1), normalizer(features, num_classes, bias=True))
+         setattr(self, '{}_{}_conv'.format(i + 1, j + 1), ncsn_conv3x3(features, features, stride=1, bias=False))
+
+     self.stride = 1
+     self.n_blocks = n_blocks
+     self.n_stages = n_stages
+     self.act = act
+     self.normalizer = normalizer
+
+   def forward(self, x, y):
+     for i in range(self.n_blocks):
+       residual = x
+       for j in range(self.n_stages):
+         x = getattr(self, '{}_{}_norm'.format(i + 1, j + 1))(x, y)
+         x = self.act(x)
+         x = getattr(self, '{}_{}_conv'.format(i + 1, j + 1))(x)
+
+       x += residual
+     return x


+ class MSFBlock(nn.Module):
+   def __init__(self, in_planes, features):
+     super().__init__()
+     assert isinstance(in_planes, list) or isinstance(in_planes, tuple)
+     self.convs = nn.ModuleList()
+     self.features = features
+
+     for i in range(len(in_planes)):
+       self.convs.append(ncsn_conv3x3(in_planes[i], features, stride=1, bias=True))
+
+   def forward(self, xs, shape):
+     sums = torch.zeros(xs[0].shape[0], self.features, *shape, device=xs[0].device)
+     for i in range(len(self.convs)):
+       h = self.convs[i](xs[i])
+       h = F.interpolate(h, size=shape, mode='bilinear', align_corners=True)
+       sums += h
+     return sums


+ class CondMSFBlock(nn.Module):
+   def __init__(self, in_planes, features, num_classes, normalizer):
+     super().__init__()
+     assert isinstance(in_planes, list) or isinstance(in_planes, tuple)
+
+     self.convs = nn.ModuleList()
+     self.norms = nn.ModuleList()
+     self.features = features
+     self.normalizer = normalizer
+
+     for i in range(len(in_planes)):
+       self.convs.append(ncsn_conv3x3(in_planes[i], features, stride=1, bias=True))
+       self.norms.append(normalizer(in_planes[i], num_classes, bias=True))
+
+   def forward(self, xs, y, shape):
+     sums = torch.zeros(xs[0].shape[0], self.features, *shape, device=xs[0].device)
+     for i in range(len(self.convs)):
+       h = self.norms[i](xs[i], y)
+       h = self.convs[i](h)
+       h = F.interpolate(h, size=shape, mode='bilinear', align_corners=True)
+       sums += h
+     return sums


+ class RefineBlock(nn.Module):
+   def __init__(self, in_planes, features, act=nn.ReLU(), start=False, end=False, maxpool=True):
+     super().__init__()
+
+     assert isinstance(in_planes, tuple) or isinstance(in_planes, list)
+     self.n_blocks = n_blocks = len(in_planes)
+
+     self.adapt_convs = nn.ModuleList()
+     for i in range(n_blocks):
+       self.adapt_convs.append(RCUBlock(in_planes[i], 2, 2, act))
+
+     self.output_convs = RCUBlock(features, 3 if end else 1, 2, act)
+
+     if not start:
+       self.msf = MSFBlock(in_planes, features)
+
+     self.crp = CRPBlock(features, 2, act, maxpool=maxpool)
+
+   def forward(self, xs, output_shape):
+     assert isinstance(xs, tuple) or isinstance(xs, list)
+     hs = []
+     for i in range(len(xs)):
+       h = self.adapt_convs[i](xs[i])
+       hs.append(h)
+
+     if self.n_blocks > 1:
+       h = self.msf(hs, output_shape)
+     else:
+       h = hs[0]
+
+     h = self.crp(h)
+     h = self.output_convs(h)
+
+     return h
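
RefineBlock fuses several feature scales RefineNet-style: per-input RCU stacks, MSF fusion resized to output_shape, then chained residual pooling. A shape sketch (editor's illustration, not part of the commit):

    block = RefineBlock((128, 64), features=64)
    xs = [torch.randn(2, 128, 8, 8), torch.randn(2, 64, 16, 16)]
    out = block(xs, output_shape=(16, 16))   # -> (2, 64, 16, 16)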


+ class CondRefineBlock(nn.Module):
+   def __init__(self, in_planes, features, num_classes, normalizer, act=nn.ReLU(), start=False, end=False):
+     super().__init__()
+
+     assert isinstance(in_planes, tuple) or isinstance(in_planes, list)
+     self.n_blocks = n_blocks = len(in_planes)
+
+     self.adapt_convs = nn.ModuleList()
+     for i in range(n_blocks):
+       self.adapt_convs.append(
+         CondRCUBlock(in_planes[i], 2, 2, num_classes, normalizer, act)
+       )
+
+     self.output_convs = CondRCUBlock(features, 3 if end else 1, 2, num_classes, normalizer, act)
+
+     if not start:
+       self.msf = CondMSFBlock(in_planes, features, num_classes, normalizer)
+
+     self.crp = CondCRPBlock(features, 2, num_classes, normalizer, act)
+
+   def forward(self, xs, y, output_shape):
+     assert isinstance(xs, tuple) or isinstance(xs, list)
+     hs = []
+     for i in range(len(xs)):
+       h = self.adapt_convs[i](xs[i], y)
+       hs.append(h)
+
+     if self.n_blocks > 1:
+       h = self.msf(hs, y, output_shape)
+     else:
+       h = hs[0]
+
+     h = self.crp(h, y)
+     h = self.output_convs(h, y)
+
+     return h


+ class ConvMeanPool(nn.Module):
+   def __init__(self, input_dim, output_dim, kernel_size=3, biases=True, adjust_padding=False):
+     super().__init__()
+     if not adjust_padding:
+       conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
+       self.conv = conv
+     else:
+       conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
+
+       self.conv = nn.Sequential(
+         nn.ZeroPad2d((1, 0, 1, 0)),
+         conv
+       )
+
+   def forward(self, inputs):
+     output = self.conv(inputs)
+     output = sum([output[:, :, ::2, ::2], output[:, :, 1::2, ::2],
+                   output[:, :, ::2, 1::2], output[:, :, 1::2, 1::2]]) / 4.
+     return output


+ class MeanPoolConv(nn.Module):
+   def __init__(self, input_dim, output_dim, kernel_size=3, biases=True):
+     super().__init__()
+     self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
+
+   def forward(self, inputs):
+     output = inputs
+     output = sum([output[:, :, ::2, ::2], output[:, :, 1::2, ::2],
+                   output[:, :, ::2, 1::2], output[:, :, 1::2, 1::2]]) / 4.
+     return self.conv(output)


+ class UpsampleConv(nn.Module):
+   def __init__(self, input_dim, output_dim, kernel_size=3, biases=True):
+     super().__init__()
+     self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride=1, padding=kernel_size // 2, bias=biases)
+     self.pixelshuffle = nn.PixelShuffle(upscale_factor=2)
+
+   def forward(self, inputs):
+     output = inputs
+     output = torch.cat([output, output, output, output], dim=1)
+     output = self.pixelshuffle(output)
+     return self.conv(output)
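
These three wrappers implement stride-1 resampling: ConvMeanPool convolves and then averages the four 2-strided sub-grids (a 2x mean pool), MeanPoolConv does the same in the opposite order, and UpsampleConv emulates nearest-neighbour 2x upsampling by duplicating channels through PixelShuffle before convolving. A shape sketch (editor's illustration, not part of the commit):

    x = torch.randn(2, 16, 32, 32)
    down = ConvMeanPool(16, 32)(x)    # -> (2, 32, 16, 16)
    up = UpsampleConv(16, 8)(x)       # -> (2, 8, 64, 64)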


+ class ConditionalResidualBlock(nn.Module):
+   def __init__(self, input_dim, output_dim, num_classes, resample=1, act=nn.ELU(),
+                normalization=ConditionalInstanceNorm2dPlus, adjust_padding=False, dilation=None):
+     super().__init__()
+     self.non_linearity = act
+     self.input_dim = input_dim
+     self.output_dim = output_dim
+     self.resample = resample
+     self.normalization = normalization
+     if resample == 'down':
+       if dilation > 1:
+         self.conv1 = ncsn_conv3x3(input_dim, input_dim, dilation=dilation)
+         self.normalize2 = normalization(input_dim, num_classes)
+         self.conv2 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
+         conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
+       else:
+         self.conv1 = ncsn_conv3x3(input_dim, input_dim)
+         self.normalize2 = normalization(input_dim, num_classes)
+         self.conv2 = ConvMeanPool(input_dim, output_dim, 3, adjust_padding=adjust_padding)
+         conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding)
+
+     elif resample is None:
+       if dilation > 1:
+         conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
+         self.conv1 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
+         self.normalize2 = normalization(output_dim, num_classes)
+         self.conv2 = ncsn_conv3x3(output_dim, output_dim, dilation=dilation)
+       else:
+         conv_shortcut = nn.Conv2d
+         self.conv1 = ncsn_conv3x3(input_dim, output_dim)
+         self.normalize2 = normalization(output_dim, num_classes)
+         self.conv2 = ncsn_conv3x3(output_dim, output_dim)
+     else:
+       raise Exception('invalid resample value')
+
+     if output_dim != input_dim or resample is not None:
+       self.shortcut = conv_shortcut(input_dim, output_dim)
+
+     self.normalize1 = normalization(input_dim, num_classes)
+
+   def forward(self, x, y):
+     output = self.normalize1(x, y)
+     output = self.non_linearity(output)
+     output = self.conv1(output)
+     output = self.normalize2(output, y)
+     output = self.non_linearity(output)
+     output = self.conv2(output)
+
+     if self.output_dim == self.input_dim and self.resample is None:
+       shortcut = x
+     else:
+       shortcut = self.shortcut(x)
+
+     return shortcut + output


+ class ResidualBlock(nn.Module):
+   def __init__(self, input_dim, output_dim, resample=None, act=nn.ELU(),
+                normalization=nn.InstanceNorm2d, adjust_padding=False, dilation=1):
+     super().__init__()
+     self.non_linearity = act
+     self.input_dim = input_dim
+     self.output_dim = output_dim
+     self.resample = resample
+     self.normalization = normalization
+     if resample == 'down':
+       if dilation > 1:
+         self.conv1 = ncsn_conv3x3(input_dim, input_dim, dilation=dilation)
+         self.normalize2 = normalization(input_dim)
+         self.conv2 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
+         conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
+       else:
+         self.conv1 = ncsn_conv3x3(input_dim, input_dim)
+         self.normalize2 = normalization(input_dim)
+         self.conv2 = ConvMeanPool(input_dim, output_dim, 3, adjust_padding=adjust_padding)
+         conv_shortcut = partial(ConvMeanPool, kernel_size=1, adjust_padding=adjust_padding)
+
+     elif resample is None:
+       if dilation > 1:
+         conv_shortcut = partial(ncsn_conv3x3, dilation=dilation)
+         self.conv1 = ncsn_conv3x3(input_dim, output_dim, dilation=dilation)
+         self.normalize2 = normalization(output_dim)
+         self.conv2 = ncsn_conv3x3(output_dim, output_dim, dilation=dilation)
+       else:
+         # conv_shortcut = nn.Conv2d ### Something weird here.
+         conv_shortcut = partial(ncsn_conv1x1)
+         self.conv1 = ncsn_conv3x3(input_dim, output_dim)
+         self.normalize2 = normalization(output_dim)
+         self.conv2 = ncsn_conv3x3(output_dim, output_dim)
+     else:
+       raise Exception('invalid resample value')
+
+     if output_dim != input_dim or resample is not None:
+       self.shortcut = conv_shortcut(input_dim, output_dim)
+
+     self.normalize1 = normalization(input_dim)
+
+   def forward(self, x):
+     output = self.normalize1(x)
+     output = self.non_linearity(output)
+     output = self.conv1(output)
+     output = self.normalize2(output)
+     output = self.non_linearity(output)
+     output = self.conv2(output)
+
+     if self.output_dim == self.input_dim and self.resample is None:
+       shortcut = x
+     else:
+       shortcut = self.shortcut(x)
+
+     return shortcut + output


+ ###########################################################################
+ # Functions below are ported over from the DDPM codebase:
+ #   https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py
+ ###########################################################################

+ def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
+   assert len(timesteps.shape) == 1  # and timesteps.dtype == tf.int32
+   half_dim = embedding_dim // 2
+   # magic number 10000 is from transformers
+   emb = math.log(max_positions) / (half_dim - 1)
+   # emb = math.log(2.) / (half_dim - 1)
+   emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
+   # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
+   # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
+   emb = timesteps.float()[:, None] * emb[None, :]
+   emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+   if embedding_dim % 2 == 1:  # zero pad
+     emb = F.pad(emb, (0, 1), mode='constant')
+   assert emb.shape == (timesteps.shape[0], embedding_dim)
+   return emb
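
This is the standard transformer sinusoid: half the output dimensions carry sin(t * f_i) and half cos(t * f_i), with frequencies spaced geometrically between 1 and 1/max_positions. A sketch (editor's illustration, not part of the commit):

    t = torch.tensor([0, 10, 500])
    emb = get_timestep_embedding(t, embedding_dim=128)   # -> (3, 128)
    # emb[:, :64] are the sin features, emb[:, 64:] the cos features.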


+ def _einsum(a, b, c, x, y):
+   einsum_str = '{},{}->{}'.format(''.join(a), ''.join(b), ''.join(c))
+   return torch.einsum(einsum_str, x, y)


+ def contract_inner(x, y):
+   """tensordot(x, y, 1)."""
+   x_chars = list(string.ascii_lowercase[:len(x.shape)])
+   y_chars = list(string.ascii_lowercase[len(x.shape):len(y.shape) + len(x.shape)])
+   y_chars[0] = x_chars[-1]  # first axis of y and last of x get summed
+   out_chars = x_chars[:-1] + y_chars[1:]
+   return _einsum(x_chars, y_chars, out_chars, x, y)


+ class NIN(nn.Module):
+   def __init__(self, in_dim, num_units, init_scale=0.1):
+     super().__init__()
+     self.W = nn.Parameter(default_init(scale=init_scale)((in_dim, num_units)), requires_grad=True)
+     self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)
+
+   def forward(self, x):
+     x = x.permute(0, 2, 3, 1)
+     y = contract_inner(x, self.W) + self.b
+     return y.permute(0, 3, 1, 2)
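
NIN ("network-in-network") is a 1x1 convolution written as an einsum over the channel axis, which contract_inner assembles dynamically from the tensor ranks. A sketch (editor's illustration, not part of the commit):

    nin = NIN(in_dim=16, num_units=32)
    y = nin(torch.randn(2, 16, 8, 8))   # -> (2, 32, 8, 8)
    # Functionally an nn.Conv2d(16, 32, kernel_size=1); only the init differs.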


+ class AttnBlock(nn.Module):
+   """Channel-wise self-attention block."""
+   def __init__(self, channels):
+     super().__init__()
+     self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6)
+     self.NIN_0 = NIN(channels, channels)
+     self.NIN_1 = NIN(channels, channels)
+     self.NIN_2 = NIN(channels, channels)
+     self.NIN_3 = NIN(channels, channels, init_scale=0.)
+
+   def forward(self, x):
+     B, C, H, W = x.shape
+     h = self.GroupNorm_0(x)
+     q = self.NIN_0(h)
+     k = self.NIN_1(h)
+     v = self.NIN_2(h)
+
+     w = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5))
+     w = torch.reshape(w, (B, H, W, H * W))
+     w = F.softmax(w, dim=-1)
+     w = torch.reshape(w, (B, H, W, H, W))
+     h = torch.einsum('bhwij,bcij->bchw', w, v)
+     h = self.NIN_3(h)
+     return x + h
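
Note that the attention weights form a (B, H, W, H, W) tensor, so memory scales with (H*W)^2; the block is only practical at coarse resolutions. A sketch (editor's illustration, not part of the commit):

    attn = AttnBlock(channels=64)      # 64 is divisible by the 32 groups
    x = torch.randn(1, 64, 16, 16)
    assert attn(x).shape == x.shape    # shape-preserving residual block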


+ class Upsample(nn.Module):
+   def __init__(self, channels, with_conv=False):
+     super().__init__()
+     if with_conv:
+       self.Conv_0 = ddpm_conv3x3(channels, channels)
+     self.with_conv = with_conv
+
+   def forward(self, x):
+     B, C, H, W = x.shape
+     h = F.interpolate(x, (H * 2, W * 2), mode='nearest')
+     if self.with_conv:
+       h = self.Conv_0(h)
+     return h


+ class Downsample(nn.Module):
+   def __init__(self, channels, with_conv=False):
+     super().__init__()
+     if with_conv:
+       self.Conv_0 = ddpm_conv3x3(channels, channels, stride=2, padding=0)
+     self.with_conv = with_conv
+
+   def forward(self, x):
+     B, C, H, W = x.shape
+     # Emulate 'SAME' padding
+     if self.with_conv:
+       x = F.pad(x, (0, 1, 0, 1))
+       x = self.Conv_0(x)
+     else:
+       x = F.avg_pool2d(x, kernel_size=2, stride=2, padding=0)
+
+     assert x.shape == (B, C, H // 2, W // 2)
+     return x


+ class ResnetBlockDDPM(nn.Module):
+   """The ResNet Blocks used in DDPM."""
+   def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False, dropout=0.1):
+     super().__init__()
+     if out_ch is None:
+       out_ch = in_ch
+     self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=in_ch, eps=1e-6)
+     self.act = act
+     self.Conv_0 = ddpm_conv3x3(in_ch, out_ch)
+     if temb_dim is not None:
+       self.Dense_0 = nn.Linear(temb_dim, out_ch)
+       self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
+       nn.init.zeros_(self.Dense_0.bias)
+
+     self.GroupNorm_1 = nn.GroupNorm(num_groups=32, num_channels=out_ch, eps=1e-6)
+     self.Dropout_0 = nn.Dropout(dropout)
+     self.Conv_1 = ddpm_conv3x3(out_ch, out_ch, init_scale=0.)
+     if in_ch != out_ch:
+       if conv_shortcut:
+         self.Conv_2 = ddpm_conv3x3(in_ch, out_ch)
+       else:
+         self.NIN_0 = NIN(in_ch, out_ch)
+     self.out_ch = out_ch
+     self.in_ch = in_ch
+     self.conv_shortcut = conv_shortcut
+
+   def forward(self, x, temb=None):
+     B, C, H, W = x.shape
+     assert C == self.in_ch
+     out_ch = self.out_ch if self.out_ch else self.in_ch
+     h = self.act(self.GroupNorm_0(x))
+     h = self.Conv_0(h)
+     # Add bias to each feature map conditioned on the time embedding
+     if temb is not None:
+       h += self.Dense_0(self.act(temb))[:, :, None, None]
+     h = self.act(self.GroupNorm_1(h))
+     h = self.Dropout_0(h)
+     h = self.Conv_1(h)
+     if C != out_ch:
+       if self.conv_shortcut:
+         x = self.Conv_2(x)
+       else:
+         x = self.NIN_0(x)
+     return x + h
sgmse/backbones/ncsnpp_utils/layerspp.py CHANGED
@@ -1,274 +1,274 @@
+ # coding=utf-8
+ # Copyright 2020 The Google Research Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.

+ # pylint: skip-file
+ """Layers for defining NCSN++.
+ """
+ from . import layers
+ from . import up_or_down_sampling
+ import torch.nn as nn
+ import torch
+ import torch.nn.functional as F
+ import numpy as np

+ conv1x1 = layers.ddpm_conv1x1
+ conv3x3 = layers.ddpm_conv3x3
+ NIN = layers.NIN
+ default_init = layers.default_init


+ class GaussianFourierProjection(nn.Module):
+   """Gaussian Fourier embeddings for noise levels."""
+
+   def __init__(self, embedding_size=256, scale=1.0):
+     super().__init__()
+     self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+
+   def forward(self, x):
+     x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
+     return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
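
The projection embeds a continuous noise level (or diffusion time) with fixed random Fourier features: W is frozen, and the sin/cos concatenation doubles embedding_size. A sketch (editor's illustration, not part of the commit; embedding log-sigmas is one common choice):

    proj = GaussianFourierProjection(embedding_size=256, scale=16.)
    sigmas = torch.rand(8) + 0.01      # one noise level per batch item
    emb = proj(torch.log(sigmas))      # -> (8, 512)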


+ class Combine(nn.Module):
+   """Combine information from skip connections."""
+
+   def __init__(self, dim1, dim2, method='cat'):
+     super().__init__()
+     self.Conv_0 = conv1x1(dim1, dim2)
+     self.method = method
+
+   def forward(self, x, y):
+     h = self.Conv_0(x)
+     if self.method == 'cat':
+       return torch.cat([h, y], dim=1)
+     elif self.method == 'sum':
+       return h + y
+     else:
+       raise ValueError(f'Method {self.method} not recognized.')


+ class AttnBlockpp(nn.Module):
+   """Channel-wise self-attention block. Modified from DDPM."""
+
+   def __init__(self, channels, skip_rescale=False, init_scale=0.):
+     super().__init__()
+     self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels,
+                                     eps=1e-6)
+     self.NIN_0 = NIN(channels, channels)
+     self.NIN_1 = NIN(channels, channels)
+     self.NIN_2 = NIN(channels, channels)
+     self.NIN_3 = NIN(channels, channels, init_scale=init_scale)
+     self.skip_rescale = skip_rescale
+
+   def forward(self, x):
+     B, C, H, W = x.shape
+     h = self.GroupNorm_0(x)
+     q = self.NIN_0(h)
+     k = self.NIN_1(h)
+     v = self.NIN_2(h)
+
+     w = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5))
+     w = torch.reshape(w, (B, H, W, H * W))
+     w = F.softmax(w, dim=-1)
+     w = torch.reshape(w, (B, H, W, H, W))
+     h = torch.einsum('bhwij,bcij->bchw', w, v)
+     h = self.NIN_3(h)
+     if not self.skip_rescale:
+       return x + h
+     else:
+       return (x + h) / np.sqrt(2.)


+ class Upsample(nn.Module):
+   def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False,
+                fir_kernel=(1, 3, 3, 1)):
+     super().__init__()
+     out_ch = out_ch if out_ch else in_ch
+     if not fir:
+       if with_conv:
+         self.Conv_0 = conv3x3(in_ch, out_ch)
+     else:
+       if with_conv:
+         self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch,
+                                                    kernel=3, up=True,
+                                                    resample_kernel=fir_kernel,
+                                                    use_bias=True,
+                                                    kernel_init=default_init())
+     self.fir = fir
+     self.with_conv = with_conv
+     self.fir_kernel = fir_kernel
+     self.out_ch = out_ch
+
+   def forward(self, x):
+     B, C, H, W = x.shape
+     if not self.fir:
+       h = F.interpolate(x, (H * 2, W * 2), 'nearest')
+       if self.with_conv:
+         h = self.Conv_0(h)
+     else:
+       if not self.with_conv:
+         h = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2)
+       else:
+         h = self.Conv2d_0(x)
+
+     return h


+ class Downsample(nn.Module):
+   def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False,
+                fir_kernel=(1, 3, 3, 1)):
+     super().__init__()
+     out_ch = out_ch if out_ch else in_ch
+     if not fir:
+       if with_conv:
+         self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0)
+     else:
+       if with_conv:
+         self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch,
+                                                    kernel=3, down=True,
+                                                    resample_kernel=fir_kernel,
+                                                    use_bias=True,
+                                                    kernel_init=default_init())
+     self.fir = fir
+     self.fir_kernel = fir_kernel
+     self.with_conv = with_conv
+     self.out_ch = out_ch
+
+   def forward(self, x):
+     B, C, H, W = x.shape
+     if not self.fir:
+       if self.with_conv:
+         x = F.pad(x, (0, 1, 0, 1))
+         x = self.Conv_0(x)
+       else:
+         x = F.avg_pool2d(x, 2, stride=2)
+     else:
+       if not self.with_conv:
+         x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2)
+       else:
+         x = self.Conv2d_0(x)
+
+     return x


+ class ResnetBlockDDPMpp(nn.Module):
+   """ResBlock adapted from DDPM."""
+
+   def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False,
+                dropout=0.1, skip_rescale=False, init_scale=0.):
+     super().__init__()
+     out_ch = out_ch if out_ch else in_ch
+     self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
+     self.Conv_0 = conv3x3(in_ch, out_ch)
+     if temb_dim is not None:
+       self.Dense_0 = nn.Linear(temb_dim, out_ch)
+       self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
+       nn.init.zeros_(self.Dense_0.bias)
+     self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
+     self.Dropout_0 = nn.Dropout(dropout)
+     self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
+     if in_ch != out_ch:
+       if conv_shortcut:
+         self.Conv_2 = conv3x3(in_ch, out_ch)
+       else:
+         self.NIN_0 = NIN(in_ch, out_ch)
+
+     self.skip_rescale = skip_rescale
+     self.act = act
+     self.out_ch = out_ch
+     self.conv_shortcut = conv_shortcut
+
+   def forward(self, x, temb=None):
+     h = self.act(self.GroupNorm_0(x))
+     h = self.Conv_0(h)
+     if temb is not None:
+       h += self.Dense_0(self.act(temb))[:, :, None, None]
+     h = self.act(self.GroupNorm_1(h))
+     h = self.Dropout_0(h)
+     h = self.Conv_1(h)
+     if x.shape[1] != self.out_ch:
+       if self.conv_shortcut:
+         x = self.Conv_2(x)
+       else:
+         x = self.NIN_0(x)
+     if not self.skip_rescale:
+       return x + h
+     else:
+       return (x + h) / np.sqrt(2.)


+ class ResnetBlockBigGANpp(nn.Module):
+   def __init__(self, act, in_ch, out_ch=None, temb_dim=None, up=False, down=False,
+                dropout=0.1, fir=False, fir_kernel=(1, 3, 3, 1),
+                skip_rescale=True, init_scale=0.):
+     super().__init__()
+
+     out_ch = out_ch if out_ch else in_ch
+     self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6)
+     self.up = up
+     self.down = down
+     self.fir = fir
+     self.fir_kernel = fir_kernel
+
+     self.Conv_0 = conv3x3(in_ch, out_ch)
+     if temb_dim is not None:
+       self.Dense_0 = nn.Linear(temb_dim, out_ch)
+       self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)
+       nn.init.zeros_(self.Dense_0.bias)
+
+     self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6)
+     self.Dropout_0 = nn.Dropout(dropout)
+     self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale)
+     if in_ch != out_ch or up or down:
+       self.Conv_2 = conv1x1(in_ch, out_ch)
+
+     self.skip_rescale = skip_rescale
+     self.act = act
+     self.in_ch = in_ch
+     self.out_ch = out_ch
+
+   def forward(self, x, temb=None):
+     h = self.act(self.GroupNorm_0(x))
+
+     if self.up:
+       if self.fir:
+         h = up_or_down_sampling.upsample_2d(h, self.fir_kernel, factor=2)
+         x = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2)
+       else:
+         h = up_or_down_sampling.naive_upsample_2d(h, factor=2)
+         x = up_or_down_sampling.naive_upsample_2d(x, factor=2)
+     elif self.down:
+       if self.fir:
+         h = up_or_down_sampling.downsample_2d(h, self.fir_kernel, factor=2)
+         x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2)
+       else:
+         h = up_or_down_sampling.naive_downsample_2d(h, factor=2)
+         x = up_or_down_sampling.naive_downsample_2d(x, factor=2)
+
+     h = self.Conv_0(h)
+     # Add bias to each feature map conditioned on the time embedding
+     if temb is not None:
+       h += self.Dense_0(self.act(temb))[:, :, None, None]
+     h = self.act(self.GroupNorm_1(h))
+     h = self.Dropout_0(h)
+     h = self.Conv_1(h)
+
+     if self.in_ch != self.out_ch or self.up or self.down:
+       x = self.Conv_2(x)
+
+     if not self.skip_rescale:
+       return x + h
+     else:
+       return (x + h) / np.sqrt(2.)
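
ResnetBlockBigGANpp folds the 2x resampling into the residual computation itself, resampling both branches before the first convolution. A down-block sketch (editor's illustration, not part of the commit; relies on the naive_*_2d helpers from up_or_down_sampling):

    block = ResnetBlockBigGANpp(nn.SiLU(), in_ch=64, out_ch=128,
                                temb_dim=256, down=True)
    x, temb = torch.randn(2, 64, 32, 32), torch.randn(2, 256)
    out = block(x, temb)               # -> (2, 128, 16, 16)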
sgmse/backbones/ncsnpp_utils/normalization.py CHANGED
@@ -1,215 +1,215 @@
- # coding=utf-8
- # Copyright 2020 The Google Research Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- """Normalization layers."""
- import torch.nn as nn
- import torch
- import functools
-
-
- def get_normalization(config, conditional=False):
-   """Obtain normalization modules from the config file."""
-   norm = config.model.normalization
-   if conditional:
-     if norm == 'InstanceNorm++':
-       return functools.partial(ConditionalInstanceNorm2dPlus, num_classes=config.model.num_classes)
-     else:
-       raise NotImplementedError(f'{norm} not implemented yet.')
-   else:
-     if norm == 'InstanceNorm':
-       return nn.InstanceNorm2d
-     elif norm == 'InstanceNorm++':
-       return InstanceNorm2dPlus
-     elif norm == 'VarianceNorm':
-       return VarianceNorm2d
-     elif norm == 'GroupNorm':
-       return nn.GroupNorm
-     else:
-       raise ValueError('Unknown normalization: %s' % norm)
-
-
- class ConditionalBatchNorm2d(nn.Module):
-   def __init__(self, num_features, num_classes, bias=True):
-     super().__init__()
-     self.num_features = num_features
-     self.bias = bias
-     self.bn = nn.BatchNorm2d(num_features, affine=False)
-     if self.bias:
-       self.embed = nn.Embedding(num_classes, num_features * 2)
-       self.embed.weight.data[:, :num_features].uniform_()  # Initialise scale at N(1, 0.02)
-       self.embed.weight.data[:, num_features:].zero_()  # Initialise bias at 0
-     else:
-       self.embed = nn.Embedding(num_classes, num_features)
-       self.embed.weight.data.uniform_()
-
-   def forward(self, x, y):
-     out = self.bn(x)
-     if self.bias:
-       gamma, beta = self.embed(y).chunk(2, dim=1)
-       out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1)
-     else:
-       gamma = self.embed(y)
-       out = gamma.view(-1, self.num_features, 1, 1) * out
-     return out
-
-
- class ConditionalInstanceNorm2d(nn.Module):
-   def __init__(self, num_features, num_classes, bias=True):
-     super().__init__()
-     self.num_features = num_features
-     self.bias = bias
-     self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
-     if bias:
-       self.embed = nn.Embedding(num_classes, num_features * 2)
-       self.embed.weight.data[:, :num_features].uniform_()  # Initialise scale at N(1, 0.02)
-       self.embed.weight.data[:, num_features:].zero_()  # Initialise bias at 0
-     else:
-       self.embed = nn.Embedding(num_classes, num_features)
-       self.embed.weight.data.uniform_()
-
-   def forward(self, x, y):
-     h = self.instance_norm(x)
-     if self.bias:
-       gamma, beta = self.embed(y).chunk(2, dim=-1)
-       out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1)
-     else:
-       gamma = self.embed(y)
-       out = gamma.view(-1, self.num_features, 1, 1) * h
-     return out
-
-
- class ConditionalVarianceNorm2d(nn.Module):
-   def __init__(self, num_features, num_classes, bias=False):
-     super().__init__()
-     self.num_features = num_features
-     self.bias = bias
-     self.embed = nn.Embedding(num_classes, num_features)
-     self.embed.weight.data.normal_(1, 0.02)
-
-   def forward(self, x, y):
-     vars = torch.var(x, dim=(2, 3), keepdim=True)
-     h = x / torch.sqrt(vars + 1e-5)
-
-     gamma = self.embed(y)
-     out = gamma.view(-1, self.num_features, 1, 1) * h
-     return out
-
-
- class VarianceNorm2d(nn.Module):
-   def __init__(self, num_features, bias=False):
-     super().__init__()
-     self.num_features = num_features
-     self.bias = bias
-     self.alpha = nn.Parameter(torch.zeros(num_features))
-     self.alpha.data.normal_(1, 0.02)
-
-   def forward(self, x):
-     vars = torch.var(x, dim=(2, 3), keepdim=True)
-     h = x / torch.sqrt(vars + 1e-5)
-
-     out = self.alpha.view(-1, self.num_features, 1, 1) * h
-     return out
-
-
- class ConditionalNoneNorm2d(nn.Module):
-   def __init__(self, num_features, num_classes, bias=True):
-     super().__init__()
-     self.num_features = num_features
-     self.bias = bias
-     if bias:
-       self.embed = nn.Embedding(num_classes, num_features * 2)
-       self.embed.weight.data[:, :num_features].uniform_()  # Initialise scale at N(1, 0.02)
-       self.embed.weight.data[:, num_features:].zero_()  # Initialise bias at 0
-     else:
-       self.embed = nn.Embedding(num_classes, num_features)
-       self.embed.weight.data.uniform_()
-
-   def forward(self, x, y):
-     if self.bias:
-       gamma, beta = self.embed(y).chunk(2, dim=-1)
-       out = gamma.view(-1, self.num_features, 1, 1) * x + beta.view(-1, self.num_features, 1, 1)
-     else:
-       gamma = self.embed(y)
-       out = gamma.view(-1, self.num_features, 1, 1) * x
-     return out
-
-
- class NoneNorm2d(nn.Module):
-   def __init__(self, num_features, bias=True):
-     super().__init__()
-
-   def forward(self, x):
-     return x
-
-
- class InstanceNorm2dPlus(nn.Module):
-   def __init__(self, num_features, bias=True):
-     super().__init__()
-     self.num_features = num_features
-     self.bias = bias
-     self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
-     self.alpha = nn.Parameter(torch.zeros(num_features))
-     self.gamma = nn.Parameter(torch.zeros(num_features))
-     self.alpha.data.normal_(1, 0.02)
-     self.gamma.data.normal_(1, 0.02)
-     if bias:
-       self.beta = nn.Parameter(torch.zeros(num_features))
-
-   def forward(self, x):
-     means = torch.mean(x, dim=(2, 3))
-     m = torch.mean(means, dim=-1, keepdim=True)
-     v = torch.var(means, dim=-1, keepdim=True)
-     means = (means - m) / (torch.sqrt(v + 1e-5))
-     h = self.instance_norm(x)
-
-     if self.bias:
-       h = h + means[..., None, None] * self.alpha[..., None, None]
-       out = self.gamma.view(-1, self.num_features, 1, 1) * h + self.beta.view(-1, self.num_features, 1, 1)
-     else:
-       h = h + means[..., None, None] * self.alpha[..., None, None]
-       out = self.gamma.view(-1, self.num_features, 1, 1) * h
-     return out
-
-
- class ConditionalInstanceNorm2dPlus(nn.Module):
-   def __init__(self, num_features, num_classes, bias=True):
-     super().__init__()
-     self.num_features = num_features
-     self.bias = bias
-     self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
-     if bias:
-       self.embed = nn.Embedding(num_classes, num_features * 3)
-       self.embed.weight.data[:, :2 * num_features].normal_(1, 0.02)  # Initialise scale at N(1, 0.02)
-       self.embed.weight.data[:, 2 * num_features:].zero_()  # Initialise bias at 0
-     else:
-       self.embed = nn.Embedding(num_classes, 2 * num_features)
-       self.embed.weight.data.normal_(1, 0.02)
-
-   def forward(self, x, y):
-     means = torch.mean(x, dim=(2, 3))
-     m = torch.mean(means, dim=-1, keepdim=True)
-     v = torch.var(means, dim=-1, keepdim=True)
-     means = (means - m) / (torch.sqrt(v + 1e-5))
-     h = self.instance_norm(x)
-
-     if self.bias:
-       gamma, alpha, beta = self.embed(y).chunk(3, dim=-1)
-       h = h + means[..., None, None] * alpha[..., None, None]
-       out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1)
-     else:
-       gamma, alpha = self.embed(y).chunk(2, dim=-1)
-       h = h + means[..., None, None] * alpha[..., None, None]
-       out = gamma.view(-1, self.num_features, 1, 1) * h
-     return out
 
+ # coding=utf-8
+ # Copyright 2020 The Google Research Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.

+ """Normalization layers."""
+ import torch.nn as nn
+ import torch
+ import functools


+ def get_normalization(config, conditional=False):
+   """Obtain normalization modules from the config file."""
+   norm = config.model.normalization
+   if conditional:
+     if norm == 'InstanceNorm++':
+       return functools.partial(ConditionalInstanceNorm2dPlus, num_classes=config.model.num_classes)
+     else:
+       raise NotImplementedError(f'{norm} not implemented yet.')
+   else:
+     if norm == 'InstanceNorm':
+       return nn.InstanceNorm2d
+     elif norm == 'InstanceNorm++':
+       return InstanceNorm2dPlus
+     elif norm == 'VarianceNorm':
+       return VarianceNorm2d
+     elif norm == 'GroupNorm':
+       return nn.GroupNorm
+     else:
+       raise ValueError('Unknown normalization: %s' % norm)


+ class ConditionalBatchNorm2d(nn.Module):
+   def __init__(self, num_features, num_classes, bias=True):
+     super().__init__()
+     self.num_features = num_features
+     self.bias = bias
+     self.bn = nn.BatchNorm2d(num_features, affine=False)
+     if self.bias:
+       self.embed = nn.Embedding(num_classes, num_features * 2)
+       self.embed.weight.data[:, :num_features].uniform_()  # Initialise scale at N(1, 0.02)
+       self.embed.weight.data[:, num_features:].zero_()  # Initialise bias at 0
+     else:
+       self.embed = nn.Embedding(num_classes, num_features)
+       self.embed.weight.data.uniform_()
+
+   def forward(self, x, y):
+     out = self.bn(x)
+     if self.bias:
+       gamma, beta = self.embed(y).chunk(2, dim=1)
+       out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1)
+     else:
+       gamma = self.embed(y)
+       out = gamma.view(-1, self.num_features, 1, 1) * out
+     return out


+ class ConditionalInstanceNorm2d(nn.Module):
+   def __init__(self, num_features, num_classes, bias=True):
+     super().__init__()
+     self.num_features = num_features
+     self.bias = bias
+     self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
+     if bias:
+       self.embed = nn.Embedding(num_classes, num_features * 2)
+       self.embed.weight.data[:, :num_features].uniform_()  # Initialise scale at N(1, 0.02)
+       self.embed.weight.data[:, num_features:].zero_()  # Initialise bias at 0
+     else:
+       self.embed = nn.Embedding(num_classes, num_features)
+       self.embed.weight.data.uniform_()
+
+   def forward(self, x, y):
+     h = self.instance_norm(x)
+     if self.bias:
+       gamma, beta = self.embed(y).chunk(2, dim=-1)
+       out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1)
+     else:
+       gamma = self.embed(y)
+       out = gamma.view(-1, self.num_features, 1, 1) * h
+     return out


+ class ConditionalVarianceNorm2d(nn.Module):
+   def __init__(self, num_features, num_classes, bias=False):
+     super().__init__()
+     self.num_features = num_features
+     self.bias = bias
+     self.embed = nn.Embedding(num_classes, num_features)
+     self.embed.weight.data.normal_(1, 0.02)
+
+   def forward(self, x, y):
+     vars = torch.var(x, dim=(2, 3), keepdim=True)
+     h = x / torch.sqrt(vars + 1e-5)
+
+     gamma = self.embed(y)
+     out = gamma.view(-1, self.num_features, 1, 1) * h
+     return out


+ class VarianceNorm2d(nn.Module):
+   def __init__(self, num_features, bias=False):
+     super().__init__()
+     self.num_features = num_features
+     self.bias = bias
+     self.alpha = nn.Parameter(torch.zeros(num_features))
+     self.alpha.data.normal_(1, 0.02)
+
+   def forward(self, x):
+     vars = torch.var(x, dim=(2, 3), keepdim=True)
+     h = x / torch.sqrt(vars + 1e-5)
+
+     out = self.alpha.view(-1, self.num_features, 1, 1) * h
+     return out
124
+
125
+
126
+ class ConditionalNoneNorm2d(nn.Module):
127
+ def __init__(self, num_features, num_classes, bias=True):
128
+ super().__init__()
129
+ self.num_features = num_features
130
+ self.bias = bias
131
+ if bias:
132
+ self.embed = nn.Embedding(num_classes, num_features * 2)
133
+ self.embed.weight.data[:, :num_features].uniform_() # Initialise scale at N(1, 0.02)
134
+ self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0
135
+ else:
136
+ self.embed = nn.Embedding(num_classes, num_features)
137
+ self.embed.weight.data.uniform_()
138
+
139
+ def forward(self, x, y):
140
+ if self.bias:
141
+ gamma, beta = self.embed(y).chunk(2, dim=-1)
142
+ out = gamma.view(-1, self.num_features, 1, 1) * x + beta.view(-1, self.num_features, 1, 1)
143
+ else:
144
+ gamma = self.embed(y)
145
+ out = gamma.view(-1, self.num_features, 1, 1) * x
146
+ return out
147
+
148
+
149
+ class NoneNorm2d(nn.Module):
150
+ def __init__(self, num_features, bias=True):
151
+ super().__init__()
152
+
153
+ def forward(self, x):
154
+ return x
155
+
156
+
157
+ class InstanceNorm2dPlus(nn.Module):
158
+ def __init__(self, num_features, bias=True):
159
+ super().__init__()
160
+ self.num_features = num_features
161
+ self.bias = bias
162
+ self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
163
+ self.alpha = nn.Parameter(torch.zeros(num_features))
164
+ self.gamma = nn.Parameter(torch.zeros(num_features))
165
+ self.alpha.data.normal_(1, 0.02)
166
+ self.gamma.data.normal_(1, 0.02)
167
+ if bias:
168
+ self.beta = nn.Parameter(torch.zeros(num_features))
169
+
170
+ def forward(self, x):
171
+ means = torch.mean(x, dim=(2, 3))
172
+ m = torch.mean(means, dim=-1, keepdim=True)
173
+ v = torch.var(means, dim=-1, keepdim=True)
174
+ means = (means - m) / (torch.sqrt(v + 1e-5))
175
+ h = self.instance_norm(x)
176
+
177
+ if self.bias:
178
+ h = h + means[..., None, None] * self.alpha[..., None, None]
179
+ out = self.gamma.view(-1, self.num_features, 1, 1) * h + self.beta.view(-1, self.num_features, 1, 1)
180
+ else:
181
+ h = h + means[..., None, None] * self.alpha[..., None, None]
182
+ out = self.gamma.view(-1, self.num_features, 1, 1) * h
183
+ return out
184
+
185
+
186
+ class ConditionalInstanceNorm2dPlus(nn.Module):
187
+ def __init__(self, num_features, num_classes, bias=True):
188
+ super().__init__()
189
+ self.num_features = num_features
190
+ self.bias = bias
191
+ self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False)
192
+ if bias:
193
+ self.embed = nn.Embedding(num_classes, num_features * 3)
194
+ self.embed.weight.data[:, :2 * num_features].normal_(1, 0.02) # Initialise scale at N(1, 0.02)
195
+ self.embed.weight.data[:, 2 * num_features:].zero_() # Initialise bias at 0
196
+ else:
197
+ self.embed = nn.Embedding(num_classes, 2 * num_features)
198
+ self.embed.weight.data.normal_(1, 0.02)
199
+
200
+ def forward(self, x, y):
201
+ means = torch.mean(x, dim=(2, 3))
202
+ m = torch.mean(means, dim=-1, keepdim=True)
203
+ v = torch.var(means, dim=-1, keepdim=True)
204
+ means = (means - m) / (torch.sqrt(v + 1e-5))
205
+ h = self.instance_norm(x)
206
+
207
+ if self.bias:
208
+ gamma, alpha, beta = self.embed(y).chunk(3, dim=-1)
209
+ h = h + means[..., None, None] * alpha[..., None, None]
210
+ out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1)
211
+ else:
212
+ gamma, alpha = self.embed(y).chunk(2, dim=-1)
213
+ h = h + means[..., None, None] * alpha[..., None, None]
214
+ out = gamma.view(-1, self.num_features, 1, 1) * h
215
+ return out
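For orientation, the InstanceNorm++ layers in this file re-inject the per-channel means that plain instance normalization discards: the means are standardized across channels and added back, scaled by a learned alpha. A minimal sketch of that mean path (shapes only; the tensors and the 0.5 stand-in for alpha are illustrative, not part of this commit):

    import torch
    import torch.nn.functional as F

    x = torch.randn(8, 64, 32, 32)              # (batch, channels, H, W)
    means = x.mean(dim=(2, 3))                  # per-channel means, shape (8, 64)
    m = means.mean(dim=-1, keepdim=True)        # mean over channels, shape (8, 1)
    v = means.var(dim=-1, keepdim=True)         # variance over channels, shape (8, 1)
    means = (means - m) / torch.sqrt(v + 1e-5)  # standardized means, still (8, 64)
    # add the standardized means back onto the instance-normalized features:
    h = F.instance_norm(x) + means[..., None, None] * 0.5  # 0.5 stands in for alpha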
sgmse/backbones/ncsnpp_utils/op/__init__.py CHANGED
@@ -1,2 +1 @@
1
- from .upfirdn2d import upfirdn2d
2
- # from .upfirdn2d_native import upfirdn2d
 
1
+ from .upfirdn2d import upfirdn2d
 
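The re-exported `upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0))` pads, upsamples by zero-insertion, applies the 2D FIR `kernel`, and downsamples, all in one fused op; the resampling helpers further below call it exactly this way. A hedged usage sketch (the filter values are illustrative, and the compiled extension assumes a GPU):

    import torch
    from sgmse.backbones.ncsnpp_utils.op import upfirdn2d

    x = torch.randn(1, 3, 16, 16, device="cuda")
    k = torch.tensor([1., 3., 3., 1.], device="cuda")
    k = torch.outer(k, k) / k.sum() ** 2       # normalized separable FIR filter
    y = upfirdn2d(x, k, up=2, pad=(2, 1))      # 2x upsample plus smoothing
    assert y.shape[-2:] == (32, 32)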
sgmse/backbones/ncsnpp_utils/op/fused_act.py ADDED
@@ -0,0 +1,97 @@
1
+ import os
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ from torch.autograd import Function
7
+ from torch.utils.cpp_extension import load
8
+
9
+
10
+ module_path = os.path.dirname(__file__)
11
+ fused = load(
12
+ "fused",
13
+ sources=[
14
+ os.path.join(module_path, "fused_bias_act.cpp"),
15
+ os.path.join(module_path, "fused_bias_act_kernel.cu"),
16
+ ],
17
+ )
18
+
19
+
20
+ class FusedLeakyReLUFunctionBackward(Function):
21
+ @staticmethod
22
+ def forward(ctx, grad_output, out, negative_slope, scale):
23
+ ctx.save_for_backward(out)
24
+ ctx.negative_slope = negative_slope
25
+ ctx.scale = scale
26
+
27
+ empty = grad_output.new_empty(0)
28
+
29
+ grad_input = fused.fused_bias_act(
30
+ grad_output, empty, out, 3, 1, negative_slope, scale
31
+ )
32
+
33
+ dim = [0]
34
+
35
+ if grad_input.ndim > 2:
36
+ dim += list(range(2, grad_input.ndim))
37
+
38
+ grad_bias = grad_input.sum(dim).detach()
39
+
40
+ return grad_input, grad_bias
41
+
42
+ @staticmethod
43
+ def backward(ctx, gradgrad_input, gradgrad_bias):
44
+ out, = ctx.saved_tensors
45
+ gradgrad_out = fused.fused_bias_act(
46
+ gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
47
+ )
48
+
49
+ return gradgrad_out, None, None, None
50
+
51
+
52
+ class FusedLeakyReLUFunction(Function):
53
+ @staticmethod
54
+ def forward(ctx, input, bias, negative_slope, scale):
55
+ empty = input.new_empty(0)
56
+ out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
57
+ ctx.save_for_backward(out)
58
+ ctx.negative_slope = negative_slope
59
+ ctx.scale = scale
60
+
61
+ return out
62
+
63
+ @staticmethod
64
+ def backward(ctx, grad_output):
65
+ out, = ctx.saved_tensors
66
+
67
+ grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
68
+ grad_output, out, ctx.negative_slope, ctx.scale
69
+ )
70
+
71
+ return grad_input, grad_bias, None, None
72
+
73
+
74
+ class FusedLeakyReLU(nn.Module):
75
+ def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
76
+ super().__init__()
77
+
78
+ self.bias = nn.Parameter(torch.zeros(channel))
79
+ self.negative_slope = negative_slope
80
+ self.scale = scale
81
+
82
+ def forward(self, input):
83
+ return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
84
+
85
+
86
+ def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
87
+ if input.device.type == "cpu":
88
+ rest_dim = [1] * (input.ndim - bias.ndim - 1)
89
+ return (
90
+ F.leaky_relu(
91
+ input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
92
+ )
93
+ * scale
94
+ )
95
+
96
+ else:
97
+ return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
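The fused op above computes `scale * leaky_relu(input + bias)` with custom forward and backward kernels; the CPU branch of `fused_leaky_relu` is the plain-PyTorch reference for it. A small equivalence sketch (tensor shapes are illustrative):

    import torch
    import torch.nn.functional as F

    def fused_leaky_relu_ref(x, bias, negative_slope=0.2, scale=2 ** 0.5):
        # same semantics as the CUDA path, expressed in eager PyTorch
        return F.leaky_relu(x + bias.view(1, -1, 1, 1), negative_slope) * scale

    x = torch.randn(2, 8, 4, 4)
    b = torch.randn(8)
    y = fused_leaky_relu_ref(x, b)  # matches FusedLeakyReLU(8)(x) up to kernel numerics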
sgmse/backbones/ncsnpp_utils/op/fused_bias_act.cpp ADDED
@@ -0,0 +1,21 @@
1
+ #include <torch/extension.h>
2
+
3
+
4
+ torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
5
+ int act, int grad, float alpha, float scale);
6
+
7
+ #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
8
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
9
+ #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
10
+
11
+ torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
12
+ int act, int grad, float alpha, float scale) {
13
+ CHECK_CUDA(input);
14
+ CHECK_CUDA(bias);
15
+
16
+ return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
17
+ }
18
+
19
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
20
+ m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
21
+ }
sgmse/backbones/ncsnpp_utils/op/fused_bias_act_kernel.cu ADDED
@@ -0,0 +1,99 @@
1
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
+ //
3
+ // This work is made available under the Nvidia Source Code License-NC.
4
+ // To view a copy of this license, visit
5
+ // https://nvlabs.github.io/stylegan2/license.html
6
+
7
+ #include <torch/types.h>
8
+
9
+ #include <ATen/ATen.h>
10
+ #include <ATen/AccumulateType.h>
11
+ #include <ATen/cuda/CUDAContext.h>
12
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
13
+
14
+ #include <cuda.h>
15
+ #include <cuda_runtime.h>
16
+
17
+
18
+ template <typename scalar_t>
19
+ static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
20
+ int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
21
+ int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
22
+
23
+ scalar_t zero = 0.0;
24
+
25
+ for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
26
+ scalar_t x = p_x[xi];
27
+
28
+ if (use_bias) {
29
+ x += p_b[(xi / step_b) % size_b];
30
+ }
31
+
32
+ scalar_t ref = use_ref ? p_ref[xi] : zero;
33
+
34
+ scalar_t y;
35
+
36
+ switch (act * 10 + grad) {
37
+ default:
38
+ case 10: y = x; break;
39
+ case 11: y = x; break;
40
+ case 12: y = 0.0; break;
41
+
42
+ case 30: y = (x > 0.0) ? x : x * alpha; break;
43
+ case 31: y = (ref > 0.0) ? x : x * alpha; break;
44
+ case 32: y = 0.0; break;
45
+ }
46
+
47
+ out[xi] = y * scale;
48
+ }
49
+ }
50
+
51
+
52
+ torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
53
+ int act, int grad, float alpha, float scale) {
54
+ int curDevice = -1;
55
+ cudaGetDevice(&curDevice);
56
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
57
+
58
+ auto x = input.contiguous();
59
+ auto b = bias.contiguous();
60
+ auto ref = refer.contiguous();
61
+
62
+ int use_bias = b.numel() ? 1 : 0;
63
+ int use_ref = ref.numel() ? 1 : 0;
64
+
65
+ int size_x = x.numel();
66
+ int size_b = b.numel();
67
+ int step_b = 1;
68
+
69
+ for (int i = 1 + 1; i < x.dim(); i++) {
70
+ step_b *= x.size(i);
71
+ }
72
+
73
+ int loop_x = 4;
74
+ int block_size = 4 * 32;
75
+ int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
76
+
77
+ auto y = torch::empty_like(x);
78
+
79
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
80
+ fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
81
+ y.data_ptr<scalar_t>(),
82
+ x.data_ptr<scalar_t>(),
83
+ b.data_ptr<scalar_t>(),
84
+ ref.data_ptr<scalar_t>(),
85
+ act,
86
+ grad,
87
+ alpha,
88
+ scale,
89
+ loop_x,
90
+ size_x,
91
+ step_b,
92
+ size_b,
93
+ use_bias,
94
+ use_ref
95
+ );
96
+ });
97
+
98
+ return y;
99
+ }
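The kernel dispatches on `act * 10 + grad`: `act` 1 is linear and 3 is leaky ReLU, while `grad` 0/1/2 select the forward value, the first derivative, and the (zero) second derivative; for gradients the sign gate comes from the saved output `refer`. A Python sketch of that branch table (illustrative only, not the CUDA op):

    import torch

    def fused_bias_act_ref(x, ref, act, grad, alpha, scale):
        code = act * 10 + grad
        if code in (10, 11):                   # linear: value and first gradient
            y = x
        elif code == 30:                       # leaky ReLU forward: gate on x itself
            y = torch.where(x > 0, x, x * alpha)
        elif code == 31:                       # leaky ReLU backward: gate on saved output
            y = torch.where(ref > 0, x, x * alpha)
        else:                                  # codes 12 and 32: second derivative is zero
            y = torch.zeros_like(x)
        return y * scale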
sgmse/backbones/ncsnpp_utils/up_or_down_sampling.py CHANGED
@@ -1,257 +1,257 @@
1
+ """Layers used for up-sampling or down-sampling images.
2
+
3
+ Many functions are ported from https://github.com/NVlabs/stylegan2.
4
+ """
5
+
6
+ import torch.nn as nn
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import numpy as np
10
+ from .op import upfirdn2d
11
+
12
+
13
+ # Function ported from StyleGAN2
14
+ def get_weight(module,
15
+ shape,
16
+ weight_var='weight',
17
+ kernel_init=None):
18
+ """Get/create weight tensor for a convolution or fully-connected layer."""
19
+
20
+ return module.param(weight_var, kernel_init, shape)
21
+
22
+
23
+ class Conv2d(nn.Module):
24
+ """Conv2d layer with optimal upsampling and downsampling (StyleGAN2)."""
25
+
26
+ def __init__(self, in_ch, out_ch, kernel, up=False, down=False,
27
+ resample_kernel=(1, 3, 3, 1),
28
+ use_bias=True,
29
+ kernel_init=None):
30
+ super().__init__()
31
+ assert not (up and down)
32
+ assert kernel >= 1 and kernel % 2 == 1
33
+ self.weight = nn.Parameter(torch.zeros(out_ch, in_ch, kernel, kernel))
34
+ if kernel_init is not None:
35
+ self.weight.data = kernel_init(self.weight.data.shape)
36
+ if use_bias:
37
+ self.bias = nn.Parameter(torch.zeros(out_ch))
38
+
39
+ self.up = up
40
+ self.down = down
41
+ self.resample_kernel = resample_kernel
42
+ self.kernel = kernel
43
+ self.use_bias = use_bias
44
+
45
+ def forward(self, x):
46
+ if self.up:
47
+ x = upsample_conv_2d(x, self.weight, k=self.resample_kernel)
48
+ elif self.down:
49
+ x = conv_downsample_2d(x, self.weight, k=self.resample_kernel)
50
+ else:
51
+ x = F.conv2d(x, self.weight, stride=1, padding=self.kernel // 2)
52
+
53
+ if self.use_bias:
54
+ x = x + self.bias.reshape(1, -1, 1, 1)
55
+
56
+ return x
57
+
58
+
59
+ def naive_upsample_2d(x, factor=2):
60
+ _N, C, H, W = x.shape
61
+ x = torch.reshape(x, (-1, C, H, 1, W, 1))
62
+ x = x.repeat(1, 1, 1, factor, 1, factor)
63
+ return torch.reshape(x, (-1, C, H * factor, W * factor))
64
+
65
+
66
+ def naive_downsample_2d(x, factor=2):
67
+ _N, C, H, W = x.shape
68
+ x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor))
69
+ return torch.mean(x, dim=(3, 5))
70
+
71
+
72
+ def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
73
+ """Fused `upsample_2d()` followed by `tf.nn.conv2d()`.
74
+
75
+ Padding is performed only once at the beginning, not between the operations.
+ The fused op is considerably more efficient than performing the same
+ calculation using standard TensorFlow ops. It supports gradients of arbitrary order.
+ Args:
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+ w: Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`.
+ Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
+ k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
+ The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
89
+ factor: Integer upsampling factor (default: 2).
90
+ gain: Scaling factor for signal magnitude (default: 1.0).
91
+
92
+ Returns:
93
+ Tensor of the shape `[N, C, H * factor, W * factor]` or
94
+ `[N, H * factor, W * factor, C]`, and same datatype as `x`.
95
+ """
96
+
97
+ assert isinstance(factor, int) and factor >= 1
98
+
99
+ # Check weight shape.
100
+ assert len(w.shape) == 4
101
+ convH = w.shape[2]
102
+ convW = w.shape[3]
103
+ inC = w.shape[1]
104
+ outC = w.shape[0]
105
+
106
+ assert convW == convH
107
+
108
+ # Setup filter kernel.
109
+ if k is None:
110
+ k = [1] * factor
111
+ k = _setup_kernel(k) * (gain * (factor ** 2))
112
+ p = (k.shape[0] - factor) - (convW - 1)
113
+
114
+ stride = (factor, factor)
115
+
116
+ # Determine data dimensions.
117
+ stride = [1, 1, factor, factor]
118
+ output_shape = ((_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW)
119
+ output_padding = (output_shape[0] - (_shape(x, 2) - 1) * stride[0] - convH,
120
+ output_shape[1] - (_shape(x, 3) - 1) * stride[1] - convW)
121
+ assert output_padding[0] >= 0 and output_padding[1] >= 0
122
+ num_groups = _shape(x, 1) // inC
123
+
124
+ # Transpose weights.
125
+ w = torch.reshape(w, (num_groups, -1, inC, convH, convW))
126
+ w = w[..., ::-1, ::-1].permute(0, 2, 1, 3, 4)
127
+ w = torch.reshape(w, (num_groups * inC, -1, convH, convW))
128
+
129
+ x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0)
130
+ ## Original TF code.
131
+ # x = tf.nn.conv2d_transpose(
132
+ # x,
133
+ # w,
134
+ # output_shape=output_shape,
135
+ # strides=stride,
136
+ # padding='VALID',
137
+ # data_format=data_format)
138
+ ## JAX equivalent
139
+
140
+ return upfirdn2d(x, torch.tensor(k, device=x.device),
141
+ pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
142
+
143
+
144
+ def conv_downsample_2d(x, w, k=None, factor=2, gain=1):
145
+ """Fused `tf.nn.conv2d()` followed by `downsample_2d()`.
146
+
147
+ Padding is performed only once at the beginning, not between the operations.
148
+ The fused op is considerably more efficient than performing the same
+ calculation using standard TensorFlow ops. It supports gradients of arbitrary order.
+ Args:
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+ w: Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`.
+ Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
+ k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
+ The default is `[1] * factor`, which corresponds to average pooling.
160
+ factor: Integer downsampling factor (default: 2).
161
+ gain: Scaling factor for signal magnitude (default: 1.0).
162
+
163
+ Returns:
164
+ Tensor of the shape `[N, C, H // factor, W // factor]` or
165
+ `[N, H // factor, W // factor, C]`, and same datatype as `x`.
166
+ """
167
+
168
+ assert isinstance(factor, int) and factor >= 1
169
+ _outC, _inC, convH, convW = w.shape
170
+ assert convW == convH
171
+ if k is None:
172
+ k = [1] * factor
173
+ k = _setup_kernel(k) * gain
174
+ p = (k.shape[0] - factor) + (convW - 1)
175
+ s = [factor, factor]
176
+ x = upfirdn2d(x, torch.tensor(k, device=x.device),
177
+ pad=((p + 1) // 2, p // 2))
178
+ return F.conv2d(x, w, stride=s, padding=0)
179
+
180
+
181
+ def _setup_kernel(k):
182
+ k = np.asarray(k, dtype=np.float32)
183
+ if k.ndim == 1:
184
+ k = np.outer(k, k)
185
+ k /= np.sum(k)
186
+ assert k.ndim == 2
187
+ assert k.shape[0] == k.shape[1]
188
+ return k
189
+
190
+
191
+ def _shape(x, dim):
192
+ return x.shape[dim]
193
+
194
+
195
+ def upsample_2d(x, k=None, factor=2, gain=1):
196
+ r"""Upsample a batch of 2D images with the given filter.
197
+
198
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
199
+ and upsamples each image with the given filter. The filter is normalized so that
+ if the input pixels are constant, they will be scaled by the specified `gain`.
+ Pixels outside the image are assumed to be zero, and the filter is padded with
+ zeros so that its shape is a multiple of the upsampling factor.
+ Args:
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+ k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
+ The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
212
+ factor: Integer upsampling factor (default: 2).
213
+ gain: Scaling factor for signal magnitude (default: 1.0).
214
+
215
+ Returns:
216
+ Tensor of the shape `[N, C, H * factor, W * factor]`
217
+ """
218
+ assert isinstance(factor, int) and factor >= 1
219
+ if k is None:
220
+ k = [1] * factor
221
+ k = _setup_kernel(k) * (gain * (factor ** 2))
222
+ p = k.shape[0] - factor
223
+ return upfirdn2d(x, torch.tensor(k, device=x.device),
224
+ up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
225
+
226
+
227
+ def downsample_2d(x, k=None, factor=2, gain=1):
228
+ r"""Downsample a batch of 2D images with the given filter.
229
+
230
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
231
+ and downsamples each image with the given filter. The filter is normalized so that
+ if the input pixels are constant, they will be scaled by the specified `gain`.
+ Pixels outside the image are assumed to be zero, and the filter is padded with
+ zeros so that its shape is a multiple of the downsampling factor.
+ Args:
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+ k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
+ The default is `[1] * factor`, which corresponds to average pooling.
244
+ factor: Integer downsampling factor (default: 2).
245
+ gain: Scaling factor for signal magnitude (default: 1.0).
246
+
247
+ Returns:
248
+ Tensor of the shape `[N, C, H // factor, W // factor]`
249
+ """
250
+
251
+ assert isinstance(factor, int) and factor >= 1
252
+ if k is None:
253
+ k = [1] * factor
254
+ k = _setup_kernel(k) * gain
255
+ p = k.shape[0] - factor
256
+ return upfirdn2d(x, torch.tensor(k, device=x.device),
257
+ down=factor, pad=((p + 1) // 2, p // 2))
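All four resampling entry points share the `_setup_kernel` normalization and the upfirdn2d padding arithmetic above. A hedged shape check (module path as it appears in this repo; the custom op assumes a GPU, while the naive_* helpers run anywhere):

    import torch
    from sgmse.backbones.ncsnpp_utils import up_or_down_sampling as uds

    x = torch.randn(1, 2, 16, 16, device="cuda")
    up = uds.upsample_2d(x, k=[1, 3, 3, 1])         # -> (1, 2, 32, 32)
    down = uds.downsample_2d(up, k=[1, 3, 3, 1])    # -> (1, 2, 16, 16)
    assert down.shape == x.shape
    # CPU-friendly sanity check with the naive variants:
    z = uds.naive_downsample_2d(uds.naive_upsample_2d(x.cpu()))
    assert z.shape == x.shape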
sgmse/backbones/ncsnpp_utils/utils.py CHANGED
@@ -1,189 +1,189 @@
1
+ # coding=utf-8
2
+ # Copyright 2020 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """All functions and modules related to model definition.
17
+ """
18
+
19
+ import torch
20
+
21
+ import numpy as np
22
+ from ...sdes import OUVESDE, OUVPSDE
23
+
24
+
25
+ _MODELS = {}
26
+
27
+
28
+ def register_model(cls=None, *, name=None):
29
+ """A decorator for registering model classes."""
30
+
31
+ def _register(cls):
32
+ if name is None:
33
+ local_name = cls.__name__
34
+ else:
35
+ local_name = name
36
+ if local_name in _MODELS:
37
+ raise ValueError(f'Already registered model with name: {local_name}')
38
+ _MODELS[local_name] = cls
39
+ return cls
40
+
41
+ if cls is None:
42
+ return _register
43
+ else:
44
+ return _register(cls)
45
+
46
+
47
+ def get_model(name):
48
+ return _MODELS[name]
49
+
50
+
51
+ def get_sigmas(sigma_min, sigma_max, num_scales):
52
+ """Get sigmas --- the set of noise levels for SMLD from config files.
53
+ Args:
54
+ config: A ConfigDict object parsed from the config file
55
+ Returns:
56
+ sigmas: a jax numpy arrary of noise levels
57
+ """
58
+ sigmas = np.exp(
59
+ np.linspace(np.log(sigma_max), np.log(sigma_min), num_scales))
60
+
61
+ return sigmas
62
+
63
+
64
+ def get_ddpm_params(config):
65
+ """Get betas and alphas --- parameters used in the original DDPM paper."""
66
+ num_diffusion_timesteps = 1000
67
+ # parameters need to be adapted if number of time steps differs from 1000
68
+ beta_start = config.model.beta_min / config.model.num_scales
69
+ beta_end = config.model.beta_max / config.model.num_scales
70
+ betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
71
+
72
+ alphas = 1. - betas
73
+ alphas_cumprod = np.cumprod(alphas, axis=0)
74
+ sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
75
+ sqrt_1m_alphas_cumprod = np.sqrt(1. - alphas_cumprod)
76
+
77
+ return {
78
+ 'betas': betas,
79
+ 'alphas': alphas,
80
+ 'alphas_cumprod': alphas_cumprod,
81
+ 'sqrt_alphas_cumprod': sqrt_alphas_cumprod,
82
+ 'sqrt_1m_alphas_cumprod': sqrt_1m_alphas_cumprod,
83
+ 'beta_min': beta_start * (num_diffusion_timesteps - 1),
84
+ 'beta_max': beta_end * (num_diffusion_timesteps - 1),
85
+ 'num_diffusion_timesteps': num_diffusion_timesteps
86
+ }
87
+
88
+
89
+ def create_model(config):
90
+ """Create the score model."""
91
+ model_name = config.model.name
92
+ score_model = get_model(model_name)(config)
93
+ score_model = score_model.to(config.device)
94
+ score_model = torch.nn.DataParallel(score_model)
95
+ return score_model
96
+
97
+
98
+ def get_model_fn(model, train=False):
99
+ """Create a function to give the output of the score-based model.
100
+
101
+ Args:
102
+ model: The score model.
103
+ train: `True` for training and `False` for evaluation.
104
+
105
+ Returns:
106
+ A model function.
107
+ """
108
+
109
+ def model_fn(x, labels):
110
+ """Compute the output of the score-based model.
111
+
112
+ Args:
113
+ x: A mini-batch of input data.
114
+ labels: A mini-batch of conditioning variables for time steps. Should be interpreted differently
115
+ for different models.
116
+
117
+ Returns:
118
+ The model output.
119
+ """
120
+ if not train:
121
+ model.eval()
122
+ return model(x, labels)
123
+ else:
124
+ model.train()
125
+ return model(x, labels)
126
+
127
+ return model_fn
128
+
129
+
130
+ def get_score_fn(sde, model, train=False, continuous=False):
131
+ """Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.
132
+
133
+ Args:
134
+ sde: An `sde_lib.SDE` object that represents the forward SDE.
135
+ model: A score model.
136
+ train: `True` for training and `False` for evaluation.
137
+ continuous: If `True`, the score-based model is expected to directly take continuous time steps.
138
+
139
+ Returns:
140
+ A score function.
141
+ """
142
+ model_fn = get_model_fn(model, train=train)
143
+
144
+ if isinstance(sde, OUVPSDE):
145
+ def score_fn(x, t):
146
+ # Scale neural network output by standard deviation and flip sign
147
+ if continuous:
148
+ # For VP-trained models, t=0 corresponds to the lowest noise level
149
+ # The maximum value of time embedding is assumed to 999 for
150
+ # continuously-trained models.
151
+ labels = t * 999
152
+ score = model_fn(x, labels)
153
+ std = sde.marginal_prob(torch.zeros_like(x), t)[1]
154
+ else:
155
+ # For VP-trained models, t=0 corresponds to the lowest noise level
156
+ labels = t * (sde.N - 1)
157
+ score = model_fn(x, labels)
158
+ std = sde.sqrt_1m_alphas_cumprod.to(labels.device)[labels.long()]
159
+
160
+ score = -score / std[:, None, None, None]
161
+ return score
162
+
163
+ elif isinstance(sde, OUVESDE):
164
+ def score_fn(x, t):
165
+ if continuous:
166
+ labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
167
+ else:
168
+ # For VE-trained models, t=0 corresponds to the highest noise level
169
+ labels = sde.T - t
170
+ labels *= sde.N - 1
171
+ labels = torch.round(labels).long()
172
+
173
+ score = model_fn(x, labels)
174
+ return score
175
+
176
+ else:
177
+ raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")
178
+
179
+ return score_fn
180
+
181
+
182
+ def to_flattened_numpy(x):
183
+ """Flatten a torch tensor `x` and convert it to numpy."""
184
+ return x.detach().cpu().numpy().reshape((-1,))
185
+
186
+
187
+ def from_flattened_numpy(x, shape):
188
+ """Form a torch tensor with the given `shape` from a flattened numpy array `x`."""
189
  return torch.from_numpy(x.reshape(shape))
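For orientation, `_MODELS` together with `register_model`/`get_model` implements a simple name-to-class registry for score models. A hedged usage sketch (the 'tiny' name and TinyScoreModel are hypothetical, not part of this commit):

    import torch.nn as nn
    from sgmse.backbones.ncsnpp_utils import utils

    @utils.register_model(name='tiny')
    class TinyScoreModel(nn.Module):
        def __init__(self, config):
            super().__init__()
            self.net = nn.Conv2d(1, 1, 3, padding=1)

        def forward(self, x, labels):
            return self.net(x)

    model_cls = utils.get_model('tiny')   # -> TinyScoreModel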
sgmse/backbones/ncsnpp_v2.py CHANGED
@@ -1,395 +1,395 @@
+ # coding=utf-8
+ # Copyright 2020 The Google Research Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # pylint: skip-file
+
+ from .ncsnpp_utils import layers, layerspp, normalization
+ import torch.nn as nn
+ import functools
+ import torch
+ import numpy as np
+
+ from .shared import BackboneRegistry
+
+ ResnetBlockDDPM = layerspp.ResnetBlockDDPMpp
+ ResnetBlockBigGAN = layerspp.ResnetBlockBigGANpp
+ Combine = layerspp.Combine
+ conv3x3 = layerspp.conv3x3
+ conv1x1 = layerspp.conv1x1
+ get_act = layers.get_act
+ get_normalization = normalization.get_normalization
+ default_initializer = layers.default_init
+
+
+ @BackboneRegistry.register("ncsnpp_v2")
+ class NCSNpp_v2(nn.Module):
+     """NCSN++ model, adapted from the https://github.com/yang-song/score_sde repository."""
+
+     @staticmethod
+     def add_argparse_args(parser):
+         parser.add_argument("--nf", type=int, default=128)
+         parser.add_argument("--ch_mult", type=int, nargs='+', default=[1, 1, 2, 2, 2, 2, 2])
+         parser.add_argument("--num_res_blocks", type=int, default=2)
+         parser.add_argument("--attn_resolutions", type=int, nargs='+', default=[16])
+         return parser
+
+     def __init__(self,
+         nf = 128,
+         ch_mult = (1, 1, 2, 2, 2, 2, 2),
+         num_res_blocks = 2,
+         attn_resolutions = (16,),
+         nonlinearity = 'swish',
+         resamp_with_conv = True,
+         fir = True,
+         fir_kernel = [1, 3, 3, 1],
+         skip_rescale = True,
+         resblock_type = 'biggan',
+         progressive = 'output_skip',
+         progressive_input = 'input_skip',
+         progressive_combine = 'sum',
+         init_scale = 0.,
+         fourier_scale = 16,
+         image_size = 256,
+         embedding_type = 'fourier',
+         dropout = .0,
+         **unused_kwargs
+     ):
+         super().__init__()
+         self.act = act = get_act(nonlinearity)
+
+         self.nf = nf
+         self.num_res_blocks = num_res_blocks
+         self.attn_resolutions = attn_resolutions
+         self.num_resolutions = num_resolutions = len(ch_mult)
+         self.all_resolutions = all_resolutions = [image_size // (2 ** i) for i in range(num_resolutions)]
+         self.skip_rescale = skip_rescale
+         self.resblock_type = resblock_type = resblock_type.lower()
+         self.progressive = progressive = progressive.lower()
+         self.progressive_input = progressive_input = progressive_input.lower()
+         self.embedding_type = embedding_type = embedding_type.lower()
+
+         assert progressive in ['none', 'output_skip', 'residual']
+         assert progressive_input in ['none', 'input_skip', 'residual']
+         assert embedding_type in ['fourier', 'positional']
+         combine_method = progressive_combine.lower()
+         combiner = functools.partial(Combine, method=combine_method)
+
+         in_channels = 4   # x.real, x.imag, y.real, y.imag
+         out_channels = 2  # score.real, score.imag
+         self.output_layer = nn.Conv2d(in_channels, out_channels, 1)
+
+         modules = []
+         # timestep/noise_level embedding
+         if embedding_type == 'fourier':
+             # Gaussian Fourier features embeddings.
+             modules.append(layerspp.GaussianFourierProjection(
+                 embedding_size=nf, scale=fourier_scale
+             ))
+             embed_dim = 2 * nf
+         elif embedding_type == 'positional':
+             embed_dim = nf
+         else:
+             raise ValueError(f'embedding type {embedding_type} unknown.')
+
+         modules.append(nn.Linear(embed_dim, nf * 4))
+         modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
+         nn.init.zeros_(modules[-1].bias)
+         modules.append(nn.Linear(nf * 4, nf * 4))
+         modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
+         nn.init.zeros_(modules[-1].bias)
+
+         AttnBlock = functools.partial(layerspp.AttnBlockpp,
+             init_scale=init_scale, skip_rescale=skip_rescale)
+
+         Upsample = functools.partial(layerspp.Upsample,
+             with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
+
+         if progressive == 'output_skip':
+             self.pyramid_upsample = layerspp.Upsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
+         elif progressive == 'residual':
+             pyramid_upsample = functools.partial(layerspp.Upsample, fir=fir,
+                 fir_kernel=fir_kernel, with_conv=True)
+
+         Downsample = functools.partial(layerspp.Downsample, with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
+
+         if progressive_input == 'input_skip':
+             self.pyramid_downsample = layerspp.Downsample(fir=fir, fir_kernel=fir_kernel, with_conv=False)
+         elif progressive_input == 'residual':
+             pyramid_downsample = functools.partial(layerspp.Downsample,
+                 fir=fir, fir_kernel=fir_kernel, with_conv=True)
+
+         if resblock_type == 'ddpm':
+             ResnetBlock = functools.partial(ResnetBlockDDPM, act=act,
+                 dropout=dropout, init_scale=init_scale,
+                 skip_rescale=skip_rescale, temb_dim=nf * 4)
+         elif resblock_type == 'biggan':
+             ResnetBlock = functools.partial(ResnetBlockBigGAN, act=act,
+                 dropout=dropout, fir=fir, fir_kernel=fir_kernel,
+                 init_scale=init_scale, skip_rescale=skip_rescale, temb_dim=nf * 4)
+         else:
+             raise ValueError(f'resblock type {resblock_type} unrecognized.')
+
+         # Downsampling block
+
+         channels = in_channels
+         if progressive_input != 'none':
+             input_pyramid_ch = channels
+
+         modules.append(conv3x3(channels, nf))
+         hs_c = [nf]
+
+         in_ch = nf
+         for i_level in range(num_resolutions):
+             # Residual blocks for this resolution
+             for i_block in range(num_res_blocks):
+                 out_ch = nf * ch_mult[i_level]
+                 modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
+                 in_ch = out_ch
+
+                 if all_resolutions[i_level] in attn_resolutions:
+                     modules.append(AttnBlock(channels=in_ch))
+                 hs_c.append(in_ch)
+
+             if i_level != num_resolutions - 1:
+                 if resblock_type == 'ddpm':
+                     modules.append(Downsample(in_ch=in_ch))
+                 else:
+                     modules.append(ResnetBlock(down=True, in_ch=in_ch))
+
+                 if progressive_input == 'input_skip':
+                     modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch))
+                     if combine_method == 'cat':
+                         in_ch *= 2
+                 elif progressive_input == 'residual':
+                     modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch))
+                     input_pyramid_ch = in_ch
+
+                 hs_c.append(in_ch)
+
+         in_ch = hs_c[-1]
+         modules.append(ResnetBlock(in_ch=in_ch))
+         modules.append(AttnBlock(channels=in_ch))
+         modules.append(ResnetBlock(in_ch=in_ch))
+
+         pyramid_ch = 0
+         # Upsampling block
+         for i_level in reversed(range(num_resolutions)):
+             for i_block in range(num_res_blocks + 1):  # +1 blocks in upsampling because of skip connection from combiner (after downsampling)
+                 out_ch = nf * ch_mult[i_level]
+                 modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
+                 in_ch = out_ch
+
+             if all_resolutions[i_level] in attn_resolutions:
+                 modules.append(AttnBlock(channels=in_ch))
+
+             if progressive != 'none':
+                 if i_level == num_resolutions - 1:
+                     if progressive == 'output_skip':
+                         modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
+                             num_channels=in_ch, eps=1e-6))
+                         modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
+                         pyramid_ch = channels
+                     elif progressive == 'residual':
+                         modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
+                         modules.append(conv3x3(in_ch, in_ch, bias=True))
+                         pyramid_ch = in_ch
+                     else:
+                         raise ValueError(f'{progressive} is not a valid name.')
+                 else:
+                     if progressive == 'output_skip':
+                         modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32),
+                             num_channels=in_ch, eps=1e-6))
+                         modules.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale))
+                         pyramid_ch = channels
+                     elif progressive == 'residual':
+                         modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch))
+                         pyramid_ch = in_ch
+                     else:
+                         raise ValueError(f'{progressive} is not a valid name.')
+
+             if i_level != 0:
+                 if resblock_type == 'ddpm':
+                     modules.append(Upsample(in_ch=in_ch))
+                 else:
+                     modules.append(ResnetBlock(in_ch=in_ch, up=True))
+
+         assert not hs_c
+
+         if progressive != 'output_skip':
+             modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6))
+             modules.append(conv3x3(in_ch, channels, init_scale=init_scale))
+
+         self.all_modules = nn.ModuleList(modules)
+
+     def forward(self, x, y, t):
+         # timestep/noise_level embedding; only for continuous training
+         modules = self.all_modules
+         m_idx = 0
+
+         # Convert real and imaginary parts of (x, y) into four channel dimensions
+         x = torch.cat((x.real, x.imag, y.real, y.imag), dim=1)
+
+         if self.embedding_type == 'fourier':
+             # Gaussian Fourier features embeddings.
+             used_sigmas = t
+             temb = modules[m_idx](torch.log(used_sigmas))
+             m_idx += 1
+         elif self.embedding_type == 'positional':
+             # Sinusoidal positional embeddings.
+             timesteps = t
+             used_sigmas = self.sigmas[t.long()]
+             temb = layers.get_timestep_embedding(timesteps, self.nf)
+         else:
+             raise ValueError(f'embedding type {self.embedding_type} unknown.')
+
+         temb = modules[m_idx](temb)
+         m_idx += 1
+         temb = modules[m_idx](self.act(temb))
+         m_idx += 1
+
+         # Downsampling block
+         input_pyramid = None
+         if self.progressive_input != 'none':
+             input_pyramid = x
+
+         # Input layer: Conv2d: 4ch -> 128ch
+         hs = [modules[m_idx](x)]
+         m_idx += 1
+
+         # Down path in U-Net
+         for i_level in range(self.num_resolutions):
+             # Residual blocks for this resolution
+             for i_block in range(self.num_res_blocks):
+                 h = modules[m_idx](hs[-1], temb)
+                 m_idx += 1
+                 # Attention layer (optional)
+                 if h.shape[-2] in self.attn_resolutions:  # edit: check H dim (-2) not W dim (-1)
+                     h = modules[m_idx](h)
+                     m_idx += 1
+                 hs.append(h)
+
+             # Downsampling
+             if i_level != self.num_resolutions - 1:
+                 if self.resblock_type == 'ddpm':
+                     h = modules[m_idx](hs[-1])
+                     m_idx += 1
+                 else:
+                     h = modules[m_idx](hs[-1], temb)
+                     m_idx += 1
+
+                 if self.progressive_input == 'input_skip':  # Combine h with x
+                     input_pyramid = self.pyramid_downsample(input_pyramid)
+                     h = modules[m_idx](input_pyramid, h)
+                     m_idx += 1
+                 elif self.progressive_input == 'residual':
+                     input_pyramid = modules[m_idx](input_pyramid)
+                     m_idx += 1
+                     if self.skip_rescale:
+                         input_pyramid = (input_pyramid + h) / np.sqrt(2.)
+                     else:
+                         input_pyramid = input_pyramid + h
+                     h = input_pyramid
+                 hs.append(h)
+
+         h = hs[-1]  # actually equivalent to: h = h
+         h = modules[m_idx](h, temb)  # ResNet block
+         m_idx += 1
+         h = modules[m_idx](h)  # Attention block
+         m_idx += 1
+         h = modules[m_idx](h, temb)  # ResNet block
+         m_idx += 1
+
+         pyramid = None
+
+         # Upsampling block
+         for i_level in reversed(range(self.num_resolutions)):
+             for i_block in range(self.num_res_blocks + 1):
+                 h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb)
+                 m_idx += 1
+
+             # edit: check H dim (-2) not W dim (-1)
+             if h.shape[-2] in self.attn_resolutions:
+                 h = modules[m_idx](h)
+                 m_idx += 1
+
+             if self.progressive != 'none':
+                 if i_level == self.num_resolutions - 1:
+                     if self.progressive == 'output_skip':
+                         pyramid = self.act(modules[m_idx](h))  # GroupNorm
+                         m_idx += 1
+                         pyramid = modules[m_idx](pyramid)  # Conv2d: 256ch -> 4ch
+                         m_idx += 1
+                     elif self.progressive == 'residual':
+                         pyramid = self.act(modules[m_idx](h))
+                         m_idx += 1
+                         pyramid = modules[m_idx](pyramid)
+                         m_idx += 1
+                     else:
+                         raise ValueError(f'{self.progressive} is not a valid name.')
+                 else:
+                     if self.progressive == 'output_skip':
+                         pyramid = self.pyramid_upsample(pyramid)  # Upsample
+                         pyramid_h = self.act(modules[m_idx](h))  # GroupNorm
+                         m_idx += 1
+                         pyramid_h = modules[m_idx](pyramid_h)
+                         m_idx += 1
+                         pyramid = pyramid + pyramid_h
+                     elif self.progressive == 'residual':
+                         pyramid = modules[m_idx](pyramid)
+                         m_idx += 1
+                         if self.skip_rescale:
+                             pyramid = (pyramid + h) / np.sqrt(2.)
+                         else:
+                             pyramid = pyramid + h
+                         h = pyramid
+                     else:
+                         raise ValueError(f'{self.progressive} is not a valid name.')
+
+             # Upsampling layer
+             if i_level != 0:
+                 if self.resblock_type == 'ddpm':
+                     h = modules[m_idx](h)
+                     m_idx += 1
+                 else:
+                     h = modules[m_idx](h, temb)  # Upsampling
+                     m_idx += 1
+
+         assert not hs
+
+         if self.progressive == 'output_skip':
+             h = pyramid
+         else:
+             h = self.act(modules[m_idx](h))
+             m_idx += 1
+             h = modules[m_idx](h)
+             m_idx += 1
+
+         assert m_idx == len(modules), "Implementation error"
+
+         h = self.output_layer(h)
+         h = torch.permute(h, (0, 2, 3, 1)).contiguous()
+
+         # Convert back to complex number
+         h = torch.view_as_complex(h)[:, None, :, :]
+
+         return h
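A minimal sketch of how this backbone is driven (illustrative only, not part of the commit; it assumes the `sgmse` package is importable and uses the defaults above). forward() packs two complex (B, 1, F, T) spectrograms into four real channels and returns a complex score of the same shape:

import torch
from sgmse.backbones import BackboneRegistry

model = BackboneRegistry.get_by_name("ncsnpp_v2")()  # defaults as registered above

B, F_bins, T_frames = 2, 256, 256  # spatial dims must survive the 6 downsamplings
x_t = torch.randn(B, 1, F_bins, T_frames, dtype=torch.cfloat)  # current diffusion state
y = torch.randn(B, 1, F_bins, T_frames, dtype=torch.cfloat)    # noisy conditioner
t = torch.rand(B) * 0.97 + 0.03                                # diffusion times > 0 (log is taken)

score = model(x_t, y, t)
print(score.shape, score.dtype)  # torch.Size([2, 1, 256, 256]) torch.complex64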
sgmse/backbones/shared.py CHANGED
@@ -1,123 +1,123 @@
 
+ import functools
+ import numpy as np
+
+ import torch
+ import torch.nn as nn
+
+ from sgmse.util.registry import Registry
+
+
+ BackboneRegistry = Registry("Backbone")
+
+
+ class GaussianFourierProjection(nn.Module):
+     """Gaussian random features for encoding time steps."""
+
+     def __init__(self, embed_dim, scale=16, complex_valued=False):
+         super().__init__()
+         self.complex_valued = complex_valued
+         if not complex_valued:
+             # If the output is real-valued, we concatenate sin+cos of the features to avoid ambiguities.
+             # Therefore, in this case the effective embed_dim is cut in half. For the complex-valued case,
+             # we use complex numbers which each represent sin+cos directly, so the ambiguity is avoided directly,
+             # and this halving is not necessary.
+             embed_dim = embed_dim // 2
+         # Randomly sample weights during initialization. These weights are fixed
+         # during optimization and are not trainable.
+         self.W = nn.Parameter(torch.randn(embed_dim) * scale, requires_grad=False)
+
+     def forward(self, t):
+         t_proj = t[:, None] * self.W[None, :] * 2 * np.pi
+         if self.complex_valued:
+             return torch.exp(1j * t_proj)
+         else:
+             return torch.cat([torch.sin(t_proj), torch.cos(t_proj)], dim=-1)
+
+
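A quick shape check for the halving logic above (illustrative sketch, not part of the commit): the real-valued variant internally halves embed_dim and concatenates sin and cos, so both variants return the requested width.

import torch

t = torch.rand(4)  # batch of 4 time steps
emb_real = GaussianFourierProjection(embed_dim=128)
emb_cplx = GaussianFourierProjection(embed_dim=128, complex_valued=True)
print(emb_real(t).shape)                     # torch.Size([4, 128]), real-valued
print(emb_cplx(t).shape, emb_cplx(t).dtype)  # torch.Size([4, 128]) torch.complex64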
+ class DiffusionStepEmbedding(nn.Module):
+     """Diffusion-step embedding as in DiffWave / Vaswani et al. 2017."""
+
+     def __init__(self, embed_dim, complex_valued=False):
+         super().__init__()
+         self.complex_valued = complex_valued
+         if not complex_valued:
+             # If the output is real-valued, we concatenate sin+cos of the features to avoid ambiguities.
+             # Therefore, in this case the effective embed_dim is cut in half. For the complex-valued case,
+             # we use complex numbers which each represent sin+cos directly, so the ambiguity is avoided directly,
+             # and this halving is not necessary.
+             embed_dim = embed_dim // 2
+         self.embed_dim = embed_dim
+
+     def forward(self, t):
+         fac = 10**(4*torch.arange(self.embed_dim, device=t.device) / (self.embed_dim-1))
+         inner = t[:, None] * fac[None, :]
+         if self.complex_valued:
+             return torch.exp(1j * inner)
+         else:
+             return torch.cat([torch.sin(inner), torch.cos(inner)], dim=-1)
+
+
+ class ComplexLinear(nn.Module):
+     """A potentially complex-valued linear layer. Reduces to a regular linear layer if `complex_valued=False`."""
+
+     def __init__(self, input_dim, output_dim, complex_valued):
+         super().__init__()
+         self.complex_valued = complex_valued
+         if self.complex_valued:
+             self.re = nn.Linear(input_dim, output_dim)
+             self.im = nn.Linear(input_dim, output_dim)
+         else:
+             self.lin = nn.Linear(input_dim, output_dim)
+
+     def forward(self, x):
+         if self.complex_valued:
+             return (self.re(x.real) - self.im(x.imag)) + 1j*(self.re(x.imag) + self.im(x.real))
+         else:
+             return self.lin(x)
+
+
+ class FeatureMapDense(nn.Module):
+     """A fully connected layer that reshapes outputs to feature maps."""
+
+     def __init__(self, input_dim, output_dim, complex_valued=False):
+         super().__init__()
+         self.complex_valued = complex_valued
+         self.dense = ComplexLinear(input_dim, output_dim, complex_valued=complex_valued)
+
+     def forward(self, x):
+         return self.dense(x)[..., None, None]
+
+
+ def torch_complex_from_reim(re, im):
+     return torch.view_as_complex(torch.stack([re, im], dim=-1))
+
+
+ class ArgsComplexMultiplicationWrapper(nn.Module):
+     """Adapted from `asteroid`'s `complex_nn.py`, allowing args/kwargs to be passed through forward().
+
+     Make a complex-valued module `F` from a real-valued module `f` by applying
+     complex multiplication rules:
+
+     F(a + i b) = f1(a) - f2(b) + i (f1(b) + f2(a))
+
+     where `f1`, `f2` are instances of `f` that do *not* share weights.
+
+     Args:
+         module_cls (callable): A class or function that returns a Torch module/functional.
+             Constructor of `f` in the formula above. Called 2x with `*args`, `**kwargs`,
+             to construct the real and imaginary component modules.
+     """
+
+     def __init__(self, module_cls, *args, **kwargs):
+         super().__init__()
+         self.re_module = module_cls(*args, **kwargs)
+         self.im_module = module_cls(*args, **kwargs)
+
+     def forward(self, x, *args, **kwargs):
+         return torch_complex_from_reim(
+             self.re_module(x.real, *args, **kwargs) - self.im_module(x.imag, *args, **kwargs),
+             self.re_module(x.imag, *args, **kwargs) + self.im_module(x.real, *args, **kwargs),
+         )
+
+
+ ComplexConv2d = functools.partial(ArgsComplexMultiplicationWrapper, nn.Conv2d)
+ ComplexConvTranspose2d = functools.partial(ArgsComplexMultiplicationWrapper, nn.ConvTranspose2d)
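To make the docstring formula concrete (illustrative sketch, not part of the commit): ComplexConv2d realizes F(a + ib) = f1(a) - f2(b) + i (f1(b) + f2(a)) with f1 = re_module and f2 = im_module, two independent real convolutions.

import torch

conv = ComplexConv2d(1, 8, kernel_size=3, padding=1)
x = torch.randn(2, 1, 32, 32, dtype=torch.cfloat)
out = conv(x)
ref = torch_complex_from_reim(
    conv.re_module(x.real) - conv.im_module(x.imag),
    conv.re_module(x.imag) + conv.im_module(x.real),
)
assert torch.allclose(out, ref)
print(out.shape, out.dtype)  # torch.Size([2, 8, 32, 32]) torch.complex64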
sgmse/data_module.py CHANGED
@@ -1,236 +1,236 @@
-
- from os.path import join
- import torch
- import pytorch_lightning as pl
- from torch.utils.data import Dataset
- from torch.utils.data import DataLoader
- from glob import glob
- from torchaudio import load
- import numpy as np
- import torch.nn.functional as F
-
-
- def get_window(window_type, window_length):
-     if window_type == 'sqrthann':
-         return torch.sqrt(torch.hann_window(window_length, periodic=True))
-     elif window_type == 'hann':
-         return torch.hann_window(window_length, periodic=True)
-     else:
-         raise NotImplementedError(f"Window type {window_type} not implemented!")
-
-
- class Specs(Dataset):
-     def __init__(self, data_dir, subset, dummy, shuffle_spec, num_frames,
-             format='default', normalize="noisy", spec_transform=None,
-             stft_kwargs=None, **ignored_kwargs):
-
-         # Read file paths according to file naming format.
-         if format == "default":
-             self.clean_files = []
-             self.clean_files += sorted(glob(join(data_dir, subset, "clean", "*.wav")))
-             self.clean_files += sorted(glob(join(data_dir, subset, "clean", "**", "*.wav")))
-             self.noisy_files = []
-             self.noisy_files += sorted(glob(join(data_dir, subset, "noisy", "*.wav")))
-             self.noisy_files += sorted(glob(join(data_dir, subset, "noisy", "**", "*.wav")))
-         elif format == "reverb":
-             self.clean_files = []
-             self.clean_files += sorted(glob(join(data_dir, subset, "anechoic", "*.wav")))
-             self.clean_files += sorted(glob(join(data_dir, subset, "anechoic", "**", "*.wav")))
-             self.noisy_files = []
-             self.noisy_files += sorted(glob(join(data_dir, subset, "reverb", "*.wav")))
-             self.noisy_files += sorted(glob(join(data_dir, subset, "reverb", "**", "*.wav")))
-         else:
-             # Feel free to add your own directory format
-             raise NotImplementedError(f"Directory format {format} unknown!")
-
-         self.dummy = dummy
-         self.num_frames = num_frames
-         self.shuffle_spec = shuffle_spec
-         self.normalize = normalize
-         self.spec_transform = spec_transform
-
-         assert all(k in stft_kwargs.keys() for k in ["n_fft", "hop_length", "center", "window"]), "misconfigured STFT kwargs"
-         self.stft_kwargs = stft_kwargs
-         self.hop_length = self.stft_kwargs["hop_length"]
-         assert self.stft_kwargs.get("center", None) == True, "'center' must be True for current implementation"
-
-     def __getitem__(self, i):
-         x, _ = load(self.clean_files[i])
-         y, _ = load(self.noisy_files[i])
-
-         # formula applies for center=True
-         target_len = (self.num_frames - 1) * self.hop_length
-         current_len = x.size(-1)
-         pad = max(target_len - current_len, 0)
-         if pad == 0:
-             # extract a random part of the audio file
-             if self.shuffle_spec:
-                 start = int(np.random.uniform(0, current_len-target_len))
-             else:
-                 start = int((current_len-target_len)/2)
-             x = x[..., start:start+target_len]
-             y = y[..., start:start+target_len]
-         else:
-             # pad the audio if its length is smaller than num_frames requires
-             x = F.pad(x, (pad//2, pad//2+(pad%2)), mode='constant')
-             y = F.pad(y, (pad//2, pad//2+(pad%2)), mode='constant')
-
-         # normalize w.r.t. the noisy signal, the clean signal, or not at all,
-         # to ensure the same clean signal power in x and y.
-         if self.normalize == "noisy":
-             normfac = y.abs().max()
-         elif self.normalize == "clean":
-             normfac = x.abs().max()
-         elif self.normalize == "not":
-             normfac = 1.0
-         x = x / normfac
-         y = y / normfac
-
-         X = torch.stft(x, **self.stft_kwargs)
-         Y = torch.stft(y, **self.stft_kwargs)
-
-         X, Y = self.spec_transform(X), self.spec_transform(Y)
-         return X, Y
-
-     def __len__(self):
-         if self.dummy:
-             # for debugging, shrink the dataset size
-             return int(len(self.clean_files)/200)
-         else:
-             return len(self.clean_files)
-
-
- class SpecsDataModule(pl.LightningDataModule):
-     @staticmethod
-     def add_argparse_args(parser):
-         parser.add_argument("--base_dir", type=str, required=True, help="The base directory of the dataset. Should contain `train`, `valid` and `test` subdirectories, each of which contain `clean` and `noisy` subdirectories.")
-         parser.add_argument("--format", type=str, choices=("default", "reverb"), default="default", help="Read file paths according to file naming format.")
-         parser.add_argument("--batch_size", type=int, default=8, help="The batch size. 8 by default.")
-         parser.add_argument("--n_fft", type=int, default=510, help="Number of FFT bins. 510 by default.")  # to ensure 256 frequency bins
-         parser.add_argument("--hop_length", type=int, default=128, help="Window hop length. 128 by default.")
-         parser.add_argument("--num_frames", type=int, default=256, help="Number of frames for the dataset. 256 by default.")
-         parser.add_argument("--window", type=str, choices=("sqrthann", "hann"), default="hann", help="The window function to use for the STFT. 'hann' by default.")
-         parser.add_argument("--num_workers", type=int, default=16, help="Number of workers to use for DataLoaders. 4 by default.")
-         parser.add_argument("--dummy", action="store_true", help="Use reduced dummy dataset for prototyping.")
-         parser.add_argument("--spec_factor", type=float, default=0.15, help="Factor to multiply complex STFT coefficients by. 0.15 by default.")
-         parser.add_argument("--spec_abs_exponent", type=float, default=0.5, help="Exponent e for the transformation abs(z)**e * exp(1j*angle(z)). 0.5 by default.")
-         parser.add_argument("--normalize", type=str, choices=("clean", "noisy", "not"), default="noisy", help="Normalize the input waveforms by the clean signal, the noisy signal, or not at all.")
-         parser.add_argument("--transform_type", type=str, choices=("exponent", "log", "none"), default="exponent", help="Spectrogram transformation for input representation.")
-         return parser
-
-     def __init__(
-             self, base_dir, format='default', batch_size=8,
-             n_fft=510, hop_length=128, num_frames=256, window='hann',
-             num_workers=4, dummy=False, spec_factor=0.15, spec_abs_exponent=0.5,
-             gpu=True, normalize='noisy', transform_type="exponent", **kwargs
-     ):
-         super().__init__()
-         self.base_dir = base_dir
-         self.format = format
-         self.batch_size = batch_size
-         self.n_fft = n_fft
-         self.hop_length = hop_length
-         self.num_frames = num_frames
-         self.window = get_window(window, self.n_fft)
-         self.windows = {}
-         self.num_workers = num_workers
-         self.dummy = dummy
-         self.spec_factor = spec_factor
-         self.spec_abs_exponent = spec_abs_exponent
-         self.gpu = gpu
-         self.normalize = normalize
-         self.transform_type = transform_type
-         self.kwargs = kwargs
-
-     def setup(self, stage=None):
-         specs_kwargs = dict(
-             stft_kwargs=self.stft_kwargs, num_frames=self.num_frames,
-             spec_transform=self.spec_fwd, **self.kwargs
-         )
-         if stage == 'fit' or stage is None:
-             self.train_set = Specs(data_dir=self.base_dir, subset='train',
-                 dummy=self.dummy, shuffle_spec=True, format=self.format,
-                 normalize=self.normalize, **specs_kwargs)
-             self.valid_set = Specs(data_dir=self.base_dir, subset='valid',
-                 dummy=self.dummy, shuffle_spec=False, format=self.format,
-                 normalize=self.normalize, **specs_kwargs)
-         if stage == 'test' or stage is None:
-             self.test_set = Specs(data_dir=self.base_dir, subset='test',
-                 dummy=self.dummy, shuffle_spec=False, format=self.format,
-                 normalize=self.normalize, **specs_kwargs)
-
-     def spec_fwd(self, spec):
-         if self.transform_type == "exponent":
-             if self.spec_abs_exponent != 1:
-                 # only do this calculation if spec_abs_exponent != 1, otherwise it's quite a bit of
-                 # wasted computation and introduces numerical error
-                 e = self.spec_abs_exponent
-                 spec = spec.abs()**e * torch.exp(1j * spec.angle())
-             spec = spec * self.spec_factor
-         elif self.transform_type == "log":
-             spec = torch.log(1 + spec.abs()) * torch.exp(1j * spec.angle())
-             spec = spec * self.spec_factor
-         elif self.transform_type == "none":
-             spec = spec
-         return spec
-
-     def spec_back(self, spec):
-         if self.transform_type == "exponent":
-             spec = spec / self.spec_factor
-             if self.spec_abs_exponent != 1:
-                 e = self.spec_abs_exponent
-                 spec = spec.abs()**(1/e) * torch.exp(1j * spec.angle())
-         elif self.transform_type == "log":
-             spec = spec / self.spec_factor
-             spec = (torch.exp(spec.abs()) - 1) * torch.exp(1j * spec.angle())
-         elif self.transform_type == "none":
-             spec = spec
-         return spec
-
-     @property
-     def stft_kwargs(self):
-         return {**self.istft_kwargs, "return_complex": True}
-
-     @property
-     def istft_kwargs(self):
-         return dict(
-             n_fft=self.n_fft, hop_length=self.hop_length,
-             window=self.window, center=True
-         )
-
-     def _get_window(self, x):
-         """
-         Retrieve an appropriate window for the given tensor x, matching the device.
-         Caches the retrieved windows so that only one window tensor will be allocated per device.
-         """
-         window = self.windows.get(x.device, None)
-         if window is None:
-             window = self.window.to(x.device)
-             self.windows[x.device] = window
-         return window
-
-     def stft(self, sig):
-         window = self._get_window(sig)
-         return torch.stft(sig, **{**self.stft_kwargs, "window": window})
-
-     def istft(self, spec, length=None):
-         window = self._get_window(spec)
-         return torch.istft(spec, **{**self.istft_kwargs, "window": window, "length": length})
-
-     def train_dataloader(self):
-         return DataLoader(
-             self.train_set, batch_size=self.batch_size,
-             num_workers=self.num_workers, pin_memory=self.gpu, shuffle=True
-         )
-
-     def val_dataloader(self):
-         return DataLoader(
-             self.valid_set, batch_size=self.batch_size,
-             num_workers=self.num_workers, pin_memory=self.gpu, shuffle=False
-         )
-
-     def test_dataloader(self):
-         return DataLoader(
-             self.test_set, batch_size=self.batch_size,
-             num_workers=self.num_workers, pin_memory=self.gpu, shuffle=False
-         )
 
+
+ from os.path import join
+ import torch
+ import pytorch_lightning as pl
+ from torch.utils.data import Dataset
+ from torch.utils.data import DataLoader
+ from glob import glob
+ from torchaudio import load
+ import numpy as np
+ import torch.nn.functional as F
+
+
+ def get_window(window_type, window_length):
+     if window_type == 'sqrthann':
+         return torch.sqrt(torch.hann_window(window_length, periodic=True))
+     elif window_type == 'hann':
+         return torch.hann_window(window_length, periodic=True)
+     else:
+         raise NotImplementedError(f"Window type {window_type} not implemented!")
+
+
+ class Specs(Dataset):
+     def __init__(self, data_dir, subset, dummy, shuffle_spec, num_frames,
+             format='default', normalize="noisy", spec_transform=None,
+             stft_kwargs=None, **ignored_kwargs):
+
+         # Read file paths according to file naming format.
+         if format == "default":
+             self.clean_files = []
+             self.clean_files += sorted(glob(join(data_dir, subset, "clean", "*.wav")))
+             self.clean_files += sorted(glob(join(data_dir, subset, "clean", "**", "*.wav")))
+             self.noisy_files = []
+             self.noisy_files += sorted(glob(join(data_dir, subset, "noisy", "*.wav")))
+             self.noisy_files += sorted(glob(join(data_dir, subset, "noisy", "**", "*.wav")))
+         elif format == "reverb":
+             self.clean_files = []
+             self.clean_files += sorted(glob(join(data_dir, subset, "anechoic", "*.wav")))
+             self.clean_files += sorted(glob(join(data_dir, subset, "anechoic", "**", "*.wav")))
+             self.noisy_files = []
+             self.noisy_files += sorted(glob(join(data_dir, subset, "reverb", "*.wav")))
+             self.noisy_files += sorted(glob(join(data_dir, subset, "reverb", "**", "*.wav")))
+         else:
+             # Feel free to add your own directory format
+             raise NotImplementedError(f"Directory format {format} unknown!")
+
+         self.dummy = dummy
+         self.num_frames = num_frames
+         self.shuffle_spec = shuffle_spec
+         self.normalize = normalize
+         self.spec_transform = spec_transform
+
+         assert all(k in stft_kwargs.keys() for k in ["n_fft", "hop_length", "center", "window"]), "misconfigured STFT kwargs"
+         self.stft_kwargs = stft_kwargs
+         self.hop_length = self.stft_kwargs["hop_length"]
+         assert self.stft_kwargs.get("center", None) == True, "'center' must be True for current implementation"
+
+     def __getitem__(self, i):
+         x, _ = load(self.clean_files[i])
+         y, _ = load(self.noisy_files[i])
+
+         # formula applies for center=True
+         target_len = (self.num_frames - 1) * self.hop_length
+         current_len = x.size(-1)
+         pad = max(target_len - current_len, 0)
+         if pad == 0:
+             # extract a random part of the audio file
+             if self.shuffle_spec:
+                 start = int(np.random.uniform(0, current_len-target_len))
+             else:
+                 start = int((current_len-target_len)/2)
+             x = x[..., start:start+target_len]
+             y = y[..., start:start+target_len]
+         else:
+             # pad the audio if its length is smaller than num_frames requires
+             x = F.pad(x, (pad//2, pad//2+(pad%2)), mode='constant')
+             y = F.pad(y, (pad//2, pad//2+(pad%2)), mode='constant')
+
+         # normalize w.r.t. the noisy signal, the clean signal, or not at all,
+         # to ensure the same clean signal power in x and y.
+         if self.normalize == "noisy":
+             normfac = y.abs().max()
+         elif self.normalize == "clean":
+             normfac = x.abs().max()
+         elif self.normalize == "not":
+             normfac = 1.0
+         x = x / normfac
+         y = y / normfac
+
+         X = torch.stft(x, **self.stft_kwargs)
+         Y = torch.stft(y, **self.stft_kwargs)
+
+         X, Y = self.spec_transform(X), self.spec_transform(Y)
+         return X, Y
+
+     def __len__(self):
+         if self.dummy:
+             # for debugging, shrink the dataset size
+             return int(len(self.clean_files)/200)
+         else:
+             return len(self.clean_files)
+
+
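As a concrete check of the cropping formula above (illustrative sketch, not part of the commit): with the defaults num_frames=256 and hop_length=128, target_len = (256 - 1) * 128 = 32640 samples (about 2.04 s at 16 kHz), which a centered STFT turns into exactly 256 frames.

import torch

num_frames, hop_length, n_fft = 256, 128, 510
target_len = (num_frames - 1) * hop_length  # 32640 samples
sig = torch.randn(target_len)
spec = torch.stft(sig, n_fft=n_fft, hop_length=hop_length, center=True,
                  window=torch.hann_window(n_fft), return_complex=True)
print(spec.shape)  # torch.Size([256, 256]) -> 256 frequency bins, 256 frames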
+ class SpecsDataModule(pl.LightningDataModule):
+     @staticmethod
+     def add_argparse_args(parser):
+         parser.add_argument("--base_dir", type=str, required=True, help="The base directory of the dataset. Should contain `train`, `valid` and `test` subdirectories, each of which contain `clean` and `noisy` subdirectories.")
+         parser.add_argument("--format", type=str, choices=("default", "reverb"), default="default", help="Read file paths according to file naming format.")
+         parser.add_argument("--batch_size", type=int, default=8, help="The batch size. 8 by default.")
+         parser.add_argument("--n_fft", type=int, default=510, help="Number of FFT bins. 510 by default.")  # to ensure 256 frequency bins
+         parser.add_argument("--hop_length", type=int, default=128, help="Window hop length. 128 by default.")
+         parser.add_argument("--num_frames", type=int, default=256, help="Number of frames for the dataset. 256 by default.")
+         parser.add_argument("--window", type=str, choices=("sqrthann", "hann"), default="hann", help="The window function to use for the STFT. 'hann' by default.")
+         parser.add_argument("--num_workers", type=int, default=4, help="Number of workers to use for DataLoaders. 4 by default.")
+         parser.add_argument("--dummy", action="store_true", help="Use reduced dummy dataset for prototyping.")
+         parser.add_argument("--spec_factor", type=float, default=0.15, help="Factor to multiply complex STFT coefficients by. 0.15 by default.")
+         parser.add_argument("--spec_abs_exponent", type=float, default=0.5, help="Exponent e for the transformation abs(z)**e * exp(1j*angle(z)). 0.5 by default.")
+         parser.add_argument("--normalize", type=str, choices=("clean", "noisy", "not"), default="noisy", help="Normalize the input waveforms by the clean signal, the noisy signal, or not at all.")
+         parser.add_argument("--transform_type", type=str, choices=("exponent", "log", "none"), default="exponent", help="Spectrogram transformation for input representation.")
+         return parser
+
+     def __init__(
+             self, base_dir, format='default', batch_size=8,
+             n_fft=510, hop_length=128, num_frames=256, window='hann',
+             num_workers=4, dummy=False, spec_factor=0.15, spec_abs_exponent=0.5,
+             gpu=True, normalize='noisy', transform_type="exponent", **kwargs
+     ):
+         super().__init__()
+         self.base_dir = base_dir
+         self.format = format
+         self.batch_size = batch_size
+         self.n_fft = n_fft
+         self.hop_length = hop_length
+         self.num_frames = num_frames
+         self.window = get_window(window, self.n_fft)
+         self.windows = {}
+         self.num_workers = num_workers
+         self.dummy = dummy
+         self.spec_factor = spec_factor
+         self.spec_abs_exponent = spec_abs_exponent
+         self.gpu = gpu
+         self.normalize = normalize
+         self.transform_type = transform_type
+         self.kwargs = kwargs
+
+     def setup(self, stage=None):
+         specs_kwargs = dict(
+             stft_kwargs=self.stft_kwargs, num_frames=self.num_frames,
+             spec_transform=self.spec_fwd, **self.kwargs
+         )
+         if stage == 'fit' or stage is None:
+             self.train_set = Specs(data_dir=self.base_dir, subset='train',
+                 dummy=self.dummy, shuffle_spec=True, format=self.format,
+                 normalize=self.normalize, **specs_kwargs)
+             self.valid_set = Specs(data_dir=self.base_dir, subset='valid',
+                 dummy=self.dummy, shuffle_spec=False, format=self.format,
+                 normalize=self.normalize, **specs_kwargs)
+         if stage == 'test' or stage is None:
+             self.test_set = Specs(data_dir=self.base_dir, subset='test',
+                 dummy=self.dummy, shuffle_spec=False, format=self.format,
+                 normalize=self.normalize, **specs_kwargs)
+
+     def spec_fwd(self, spec):
+         if self.transform_type == "exponent":
+             if self.spec_abs_exponent != 1:
+                 # only do this calculation if spec_abs_exponent != 1, otherwise it's quite a bit of
+                 # wasted computation and introduces numerical error
+                 e = self.spec_abs_exponent
+                 spec = spec.abs()**e * torch.exp(1j * spec.angle())
+             spec = spec * self.spec_factor
+         elif self.transform_type == "log":
+             spec = torch.log(1 + spec.abs()) * torch.exp(1j * spec.angle())
+             spec = spec * self.spec_factor
+         elif self.transform_type == "none":
+             spec = spec
+         return spec
+
+     def spec_back(self, spec):
+         if self.transform_type == "exponent":
+             spec = spec / self.spec_factor
+             if self.spec_abs_exponent != 1:
+                 e = self.spec_abs_exponent
+                 spec = spec.abs()**(1/e) * torch.exp(1j * spec.angle())
+         elif self.transform_type == "log":
+             spec = spec / self.spec_factor
+             spec = (torch.exp(spec.abs()) - 1) * torch.exp(1j * spec.angle())
+         elif self.transform_type == "none":
+             spec = spec
+         return spec
+
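spec_back inverts spec_fwd exactly; a small round-trip check for the default "exponent" transform abs(z)**e * exp(1j*angle(z)) scaled by spec_factor (illustrative sketch, not part of the commit):

import torch

e, factor = 0.5, 0.15
z = torch.randn(8, dtype=torch.cfloat)
fwd = (z.abs()**e * torch.exp(1j * z.angle())) * factor       # spec_fwd
back = fwd / factor                                           # spec_back ...
back = back.abs()**(1/e) * torch.exp(1j * back.angle())
assert torch.allclose(back, z, atol=1e-5)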
+     @property
+     def stft_kwargs(self):
+         return {**self.istft_kwargs, "return_complex": True}
+
+     @property
+     def istft_kwargs(self):
+         return dict(
+             n_fft=self.n_fft, hop_length=self.hop_length,
+             window=self.window, center=True
+         )
+
+     def _get_window(self, x):
+         """
+         Retrieve an appropriate window for the given tensor x, matching the device.
+         Caches the retrieved windows so that only one window tensor will be allocated per device.
+         """
+         window = self.windows.get(x.device, None)
+         if window is None:
+             window = self.window.to(x.device)
+             self.windows[x.device] = window
+         return window
+
+     def stft(self, sig):
+         window = self._get_window(sig)
+         return torch.stft(sig, **{**self.stft_kwargs, "window": window})
+
+     def istft(self, spec, length=None):
+         window = self._get_window(spec)
+         return torch.istft(spec, **{**self.istft_kwargs, "window": window, "length": length})
+
+     def train_dataloader(self):
+         return DataLoader(
+             self.train_set, batch_size=self.batch_size,
+             num_workers=self.num_workers, pin_memory=self.gpu, shuffle=True
+         )
+
+     def val_dataloader(self):
+         return DataLoader(
+             self.valid_set, batch_size=self.batch_size,
+             num_workers=self.num_workers, pin_memory=self.gpu, shuffle=False
+         )
+
+     def test_dataloader(self):
+         return DataLoader(
+             self.test_set, batch_size=self.batch_size,
+             num_workers=self.num_workers, pin_memory=self.gpu, shuffle=False
+         )
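Typical standalone use of the data module (illustrative sketch, not part of the commit; "data" is a placeholder path laid out as the --base_dir help text above describes):

dm = SpecsDataModule(base_dir="data", batch_size=4, num_workers=0)
dm.setup(stage="fit")
X, Y = next(iter(dm.train_dataloader()))  # transformed complex spectrograms
print(X.shape, X.dtype)  # e.g. torch.Size([4, 1, 256, 256]) torch.complex64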
sgmse/model.py CHANGED
@@ -1,468 +1,471 @@
- import time
- from math import ceil
- import warnings
-
- import torch
- import pytorch_lightning as pl
- import torch.distributed as dist
- from torchaudio import load
- from torch_ema import ExponentialMovingAverage
- from librosa import resample
-
- from sgmse import sampling
- from sgmse.sdes import SDERegistry
- from sgmse.backbones import BackboneRegistry
- from sgmse.util.inference import evaluate_model
- from sgmse.util.other import pad_spec, si_sdr
- from pesq import pesq
- from pystoi import stoi
- from torch_pesq import PesqLoss
-
-
- class ScoreModel(pl.LightningModule):
-     @staticmethod
-     def add_argparse_args(parser):
-         parser.add_argument("--lr", type=float, default=1e-4, help="The learning rate (1e-4 by default)")
-         parser.add_argument("--ema_decay", type=float, default=0.999, help="The parameter EMA decay constant (0.999 by default)")
-         parser.add_argument("--t_eps", type=float, default=0.03, help="The minimum process time (0.03 by default)")
-         parser.add_argument("--num_eval_files", type=int, default=50, help="Number of files for speech enhancement performance evaluation during training. Pass 0 to turn off (no checkpoints based on evaluation metrics will be generated).")
-         parser.add_argument("--loss_type", type=str, default="score_matching", help="The type of loss function to use.")
-         parser.add_argument("--loss_weighting", type=str, default="sigma^2", help="The weighting of the loss function.")
-         parser.add_argument("--network_scaling", type=str, default=None, help="The type of network output scaling to use.")
-         parser.add_argument("--c_in", type=str, default="1", help="The input scaling for x.")
-         parser.add_argument("--c_out", type=str, default="1", help="The output scaling.")
-         parser.add_argument("--c_skip", type=str, default="0", help="The skip connection scaling.")
-         parser.add_argument("--sigma_data", type=float, default=0.1, help="The data standard deviation.")
-         parser.add_argument("--l1_weight", type=float, default=0.001, help="The balance between the time-frequency and time-domain losses.")
-         parser.add_argument("--pesq_weight", type=float, default=0.0, help="The weight of the PESQ-based loss term.")
-         parser.add_argument("--sr", type=int, default=16000, help="The sample rate of the audio files.")
-         return parser
-
-     def __init__(
-         self, backbone, sde, lr=1e-4, ema_decay=0.999, t_eps=0.03, num_eval_files=20, loss_type='score_matching',
-         loss_weighting='sigma^2', network_scaling=None, c_in='1', c_out='1', c_skip='0', sigma_data=0.1,
-         l1_weight=0.001, pesq_weight=0.0, sr=16000, data_module_cls=None, **kwargs
-     ):
-         """
-         Create a new ScoreModel.
-
-         Args:
-             backbone: Backbone DNN that serves as a score-based model.
-             sde: The SDE that defines the diffusion process.
-             lr: The learning rate of the optimizer (1e-4 by default).
-             ema_decay: The decay constant of the parameter EMA (0.999 by default).
-             t_eps: The minimum time to practically run for, to avoid issues very close to zero (0.03 by default).
-             loss_type: The type of loss to use. Options are 'score_matching' (default), 'denoiser' and 'data_prediction'.
-         """
-         super().__init__()
-         # Initialize Backbone DNN
-         self.backbone = backbone
-         dnn_cls = BackboneRegistry.get_by_name(backbone)
-         self.dnn = dnn_cls(**kwargs)
-         # Initialize SDE
-         sde_cls = SDERegistry.get_by_name(sde)
-         self.sde = sde_cls(**kwargs)
-         # Store hyperparams and save them
-         self.lr = lr
-         self.ema_decay = ema_decay
-         self.ema = ExponentialMovingAverage(self.parameters(), decay=self.ema_decay)
-         self._error_loading_ema = False
-         self.t_eps = t_eps
-         self.loss_type = loss_type
-         self.loss_weighting = loss_weighting
-         self.l1_weight = l1_weight
-         self.pesq_weight = pesq_weight
-         self.network_scaling = network_scaling
-         self.c_in = c_in
-         self.c_out = c_out
-         self.c_skip = c_skip
-         self.sigma_data = sigma_data
-         self.num_eval_files = num_eval_files
-         self.sr = sr
-         # Initialize PESQ loss if pesq_weight > 0.0
-         if pesq_weight > 0.0:
-             self.pesq_loss = PesqLoss(1.0, sample_rate=sr).eval()
-             for param in self.pesq_loss.parameters():
-                 param.requires_grad = False
-         self.save_hyperparameters(ignore=['no_wandb'])
-         self.data_module = data_module_cls(**kwargs, gpu=kwargs.get('gpus', 0) > 0)
-
-     def configure_optimizers(self):
-         optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
-         return optimizer
-
-     def optimizer_step(self, *args, **kwargs):
-         # Method overridden so that the EMA params are updated after each optimizer step
-         super().optimizer_step(*args, **kwargs)
-         self.ema.update(self.dnn.parameters())
-
-     # on_load_checkpoint / on_save_checkpoint needed for EMA storing/loading
-     def on_load_checkpoint(self, checkpoint):
-         ema = checkpoint.get('ema', None)
-         if ema is not None:
-             self.ema.load_state_dict(checkpoint['ema'])
-         else:
-             self._error_loading_ema = True
-             warnings.warn("EMA state_dict not found in checkpoint!")
-
-     def on_save_checkpoint(self, checkpoint):
-         checkpoint['ema'] = self.ema.state_dict()
-
-     def train(self, mode, no_ema=False):
-         res = super().train(mode)  # call the standard `train` method with the given mode
-         if not self._error_loading_ema:
-             if mode == False and not no_ema:
-                 # eval
-                 self.ema.store(self.dnn.parameters())    # store current params in EMA
-                 self.ema.copy_to(self.dnn.parameters())  # copy EMA parameters over current params for evaluation
-             else:
-                 # train
-                 if self.ema.collected_params is not None:
-                     self.ema.restore(self.dnn.parameters())  # restore the EMA weights (if stored)
-         return res
-
-     def eval(self, no_ema=False):
-         return self.train(False, no_ema=no_ema)
-
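The train()/eval() overrides wrap the standard torch_ema swap pattern; a minimal standalone sketch of that pattern on a toy module (illustrative only, not part of the commit):

import torch.nn as nn
from torch_ema import ExponentialMovingAverage

net = nn.Linear(4, 4)
ema = ExponentialMovingAverage(net.parameters(), decay=0.999)
ema.update(net.parameters())   # after each optimizer step during training
ema.store(net.parameters())    # evaluation: stash the live weights ...
ema.copy_to(net.parameters())  # ... and swap in the EMA weights
# ... run validation here ...
ema.restore(net.parameters())  # back to training: restore the live weights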
-     def _loss(self, forward_out, x_t, z, t, mean, x):
-         """
-         Different loss functions can be used to train the score model, see the paper:
-
-         Julius Richter, Danilo de Oliveira, and Timo Gerkmann
-         "Investigating Training Objectives for Generative Speech Enhancement"
-         https://arxiv.org/abs/2409.10753
-         """
-
-         sigma = self.sde._std(t)[:, None, None, None]
-
-         if self.loss_type == "score_matching":
-             score = forward_out
-             if self.loss_weighting == "sigma^2":
-                 losses = torch.square(torch.abs(score * sigma + z))  # Eq. (7)
-             else:
-                 raise ValueError("Invalid loss weighting for loss_type=score_matching: {}".format(self.loss_weighting))
-             # Sum over spatial dimensions and channels and mean over batch
-             loss = torch.mean(0.5*torch.sum(losses.reshape(losses.shape[0], -1), dim=-1))
-         elif self.loss_type == "denoiser":
-             score = forward_out
-             D = score * sigma.pow(2) + x_t  # equivalent to Eq. (10)
-             losses = torch.square(torch.abs(D - mean))  # Eq. (8)
-             if self.loss_weighting == "1":
-                 losses = losses
-             elif self.loss_weighting == "sigma^2":
-                 losses = losses * sigma**2
-             elif self.loss_weighting == "edm":
-                 # sigma is already broadcast to (B, 1, 1, 1) above
-                 losses = ((sigma**2 + self.sigma_data**2) / ((sigma*self.sigma_data)**2)) * losses
-             else:
-                 raise ValueError("Invalid loss weighting for loss_type=denoiser: {}".format(self.loss_weighting))
-             # Sum over spatial dimensions and channels and mean over batch
-             loss = torch.mean(0.5*torch.sum(losses.reshape(losses.shape[0], -1), dim=-1))
-         elif self.loss_type == "data_prediction":
-             x_hat = forward_out
-             B, C, F, T = x.shape
-
-             # losses in the time-frequency domain (tf)
-             losses_tf = (1/(F*T))*torch.square(torch.abs(x_hat - x))
-             losses_tf = torch.mean(0.5*torch.sum(losses_tf.reshape(losses_tf.shape[0], -1), dim=-1))
-
-             # losses in the time domain (td)
-             target_len = (self.data_module.num_frames - 1) * self.data_module.hop_length
-             x_hat_td = self.to_audio(x_hat.squeeze(), target_len)
-             x_td = self.to_audio(x.squeeze(), target_len)
-             losses_l1 = (1 / target_len) * torch.abs(x_hat_td - x_td)
-             losses_l1 = torch.mean(0.5*torch.sum(losses_l1.reshape(losses_l1.shape[0], -1), dim=-1))
-
-             # losses using PESQ
-             if self.pesq_weight > 0.0:
-                 losses_pesq = self.pesq_loss(x_td, x_hat_td)
-                 losses_pesq = torch.mean(losses_pesq)
-                 # combine the losses
-                 loss = losses_tf + self.l1_weight * losses_l1 + self.pesq_weight * losses_pesq
-             else:
-                 loss = losses_tf + self.l1_weight * losses_l1
-         else:
-             raise ValueError("Invalid loss type: {}".format(self.loss_type))
-
-         return loss
-
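A quick sanity check on the 'score_matching' branch (illustrative sketch, not part of the commit): for x_t = mean + sigma * z, the score of the Gaussian perturbation kernel is -z / sigma, so a perfect model makes the residual (score * sigma + z) vanish, which is exactly the quantity penalized in Eq. (7).

import torch

sigma = torch.tensor(0.4)
z = torch.randn(1, 1, 16, 16, dtype=torch.cfloat)
perfect_score = -z / sigma
residual = perfect_score * sigma + z
print(torch.square(torch.abs(residual)).sum())  # ~0 up to rounding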
-     def _step(self, batch, batch_idx):
-         x, y = batch
-         t = torch.rand(x.shape[0], device=x.device) * (self.sde.T - self.t_eps) + self.t_eps
-         mean, std = self.sde.marginal_prob(x, y, t)
-         z = torch.randn_like(x)  # complex i.i.d. normal; real and imaginary parts each have var=0.5
-         sigma = std[:, None, None, None]
-         x_t = mean + sigma * z
-         forward_out = self(x_t, y, t)
-         loss = self._loss(forward_out, x_t, z, t, mean, x)
-         return loss
-
-     def training_step(self, batch, batch_idx):
-         loss = self._step(batch, batch_idx)
-         self.log('train_loss', loss, on_step=True, on_epoch=True, sync_dist=True, prog_bar=True)
-         return loss
-
-     def validation_step(self, batch, batch_idx):
-         # Evaluate speech enhancement performance
-         if batch_idx == 0 and self.num_eval_files != 0:
-             rank = dist.get_rank()
-             world_size = dist.get_world_size()
-
-             # Split the evaluation files among the GPUs
-             eval_files_per_gpu = self.num_eval_files // world_size
-
-             clean_files = self.data_module.valid_set.clean_files[:self.num_eval_files]
-             noisy_files = self.data_module.valid_set.noisy_files[:self.num_eval_files]
-             print(f"Processing {len(clean_files)} evaluation files in total")
-
-             # Select the files for this GPU (the last rank also takes the remainder)
-             if rank == world_size - 1:
-                 clean_files = clean_files[rank*eval_files_per_gpu:]
-                 noisy_files = noisy_files[rank*eval_files_per_gpu:]
-             else:
-                 clean_files = clean_files[rank*eval_files_per_gpu:(rank+1)*eval_files_per_gpu]
-                 noisy_files = noisy_files[rank*eval_files_per_gpu:(rank+1)*eval_files_per_gpu]
-
-             # Evaluate the performance of the model
-             pesq_sum = 0; si_sdr_sum = 0; estoi_sum = 0
-             start_time = time.time()
-             for (clean_file, noisy_file) in zip(clean_files, noisy_files):
-                 # Load the clean and noisy speech
-                 x, sr_x = load(clean_file)
-                 x = x.squeeze().numpy()
-                 y, sr_y = load(noisy_file)
-                 assert sr_x == sr_y, "Sample rates of clean and noisy files do not match!"
-
-                 # Resample if necessary
-                 if sr_x != 16000:
-                     x_16k = resample(x, orig_sr=sr_x, target_sr=16000).squeeze()
-                 else:
-                     x_16k = x
-
-                 # Enhance the noisy speech
-                 x_hat = self.enhance(y, N=self.sde.N)
-                 if self.sr != 16000:
-                     x_hat_16k = resample(x_hat, orig_sr=self.sr, target_sr=16000).squeeze()
-                 else:
-                     x_hat_16k = x_hat
-
-                 pesq_sum += pesq(16000, x_16k, x_hat_16k, 'wb')
-                 si_sdr_sum += si_sdr(x, x_hat)
-                 estoi_sum += stoi(x, x_hat, self.sr, extended=True)
-
-             print(f"Processed {len(clean_files)} files in {time.time()-start_time:.1f} s")
-             pesq_avg = pesq_sum / len(clean_files)
-             si_sdr_avg = si_sdr_sum / len(clean_files)
-             estoi_avg = estoi_sum / len(clean_files)
-
-             self.log('pesq', pesq_avg, on_step=False, on_epoch=True, sync_dist=True)
-             self.log('si_sdr', si_sdr_avg, on_step=False, on_epoch=True, sync_dist=True)
-             self.log('estoi', estoi_avg, on_step=False, on_epoch=True, sync_dist=True)
-
-         loss = self._step(batch, batch_idx)
-         self.log('valid_loss', loss, on_step=False, on_epoch=True, sync_dist=True)
-
-         return loss
-
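The sharding above gives every rank eval_files_per_gpu files, with the last rank also absorbing the remainder; a small worked check of the index arithmetic (illustrative sketch, not part of the commit):

num_eval_files, world_size = 50, 4
eval_files_per_gpu = num_eval_files // world_size  # 12
for rank in range(world_size):
    lo = rank * eval_files_per_gpu
    hi = num_eval_files if rank == world_size - 1 else (rank + 1) * eval_files_per_gpu
    print(rank, lo, hi)  # ranks 0-2 get 12 files each, rank 3 gets 14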
267
- def forward(self, x_t, y, t):
268
- """
269
- The model forward pass. In [1] and [2], the model estimates the score function. In [3], the model estimates
270
- either the score function or the target data for the Schrödinger bridge (loss_type='data_prediction').
271
-
272
- [1] Julius Richter, Simon Welker, Jean-Marie Lemercier, Bunlong Lay, and Timo Gerkmann
273
- "Speech Enhancement and Dereverberation with Diffusion-Based Generative Models"
274
- IEEE/ACM Transactions on Audio, Speech, and Language Processing, vol. 31, pp. 2351-2364, 2023.
275
-
276
- [2] Julius Richter, Yi-Chiao Wu, Steven Krenn, Simon Welker, Bunlong Lay, Shinji Watanabe, Alexander Richard, and Timo Gerkmann
277
- "EARS: An Anechoic Fullband Speech Dataset Benchmarked for Speech Enhancement and Dereverberation"
278
- ISCA Interspeech, Kos, Greece, Sept. 2024.
279
-
280
- [3] Julius Richter, Danilo de Oliveira, and Timo Gerkmann
281
- "Investigating Training Objectives for Generative Speech Enhancement"
282
- https://arxiv.org/abs/2409.10753
283
-
284
- """
285
-
286
- # In [3], we use new code with backbone='ncsnpp_v2':
287
- if self.backbone == "ncsnpp_v2":
288
- F = self.dnn(self._c_in(t) * x_t, self._c_in(t) * y, t)
289
-
290
- # Scaling the network output, see below Eq. (7) in the paper
291
- if self.network_scaling == "1/sigma":
292
- std = self.sde._std(t)
293
- F = F / std[:, None, None, None]
294
- elif self.network_scaling == "1/t":
295
- F = F / t[:, None, None, None]
296
-
297
- # The loss type determines the output of the model
298
- if self.loss_type == "score_matching":
299
- score = self._c_skip(t) * x_t + self._c_out(t) * F
300
- return score
301
- elif self.loss_type == "denoiser":
302
- sigmas = self.sde._std(t)[:, None, None, None]
303
- score = (F - x_t) / sigmas.pow(2)
304
- return score
305
- elif self.loss_type == 'data_prediction':
306
- x_hat = self._c_skip(t) * x_t + self._c_out(t) * F
307
- return x_hat
308
-
309
- # In [1] and [2], we use the old code:
310
- else:
311
- dnn_input = torch.cat([x_t, y], dim=1)
312
- score = -self.dnn(dnn_input, t)
313
- return score
314
-
315
- def _c_in(self, t):
316
- if self.c_in == "1":
317
- return 1.0
318
- elif self.c_in == "edm":
319
- sigma = self.sde._std(t)
320
- return (1.0 / torch.sqrt(sigma**2 + self.sigma_data**2))[:, None, None, None]
321
- else:
322
- raise ValueError("Invalid c_in type: {}".format(self.c_in))
323
-
324
- def _c_out(self, t):
325
- if self.c_out == "1":
326
- return 1.0
327
- elif self.c_out == "sigma":
328
- return self.sde._std(t)[:, None, None, None]
329
- elif self.c_out == "1/sigma":
330
- return 1.0 / self.sde._std(t)[:, None, None, None]
331
- elif self.c_out == "edm":
332
- sigma = self.sde._std(t)
333
- return ((sigma * self.sigma_data) / torch.sqrt(self.sigma_data**2 + sigma**2))[:, None, None, None]
334
- else:
335
- raise ValueError("Invalid c_out type: {}".format(self.c_out))
336
-
337
- def _c_skip(self, t):
338
- if self.c_skip == "0":
339
- return 0.0
340
- elif self.c_skip == "edm":
341
- sigma = self.sde._std(t)
342
- return (self.sigma_data**2 / (sigma**2 + self.sigma_data**2))[:, None, None, None]
343
- else:
344
- raise ValueError("Invalid c_skip type: {}".format(self.c_skip))
345
-
346
- def to(self, *args, **kwargs):
347
- """Override PyTorch .to() to also transfer the EMA of the model weights"""
348
- self.ema.to(*args, **kwargs)
349
- return super().to(*args, **kwargs)
350
-
351
- def get_pc_sampler(self, predictor_name, corrector_name, y, N=None, minibatch=None, **kwargs):
352
- N = self.sde.N if N is None else N
353
- sde = self.sde.copy()
354
- sde.N = N
355
-
356
- kwargs = {"eps": self.t_eps, **kwargs}
357
- if minibatch is None:
358
- return sampling.get_pc_sampler(predictor_name, corrector_name, sde=sde, score_fn=self, y=y, **kwargs)
359
- else:
360
- M = y.shape[0]
361
- def batched_sampling_fn():
362
- samples, ns = [], []
363
- for i in range(int(ceil(M / minibatch))):
364
- y_mini = y[i*minibatch:(i+1)*minibatch]
365
- sampler = sampling.get_pc_sampler(predictor_name, corrector_name, sde=sde, score_fn=self, y=y_mini, **kwargs)
366
- sample, n = sampler()
367
- samples.append(sample)
368
- ns.append(n)
369
- samples = torch.cat(samples, dim=0)
370
- return samples, ns
371
- return batched_sampling_fn
372
-
373
- def get_ode_sampler(self, y, N=None, minibatch=None, **kwargs):
374
- N = self.sde.N if N is None else N
375
- sde = self.sde.copy()
376
- sde.N = N
377
-
378
- kwargs = {"eps": self.t_eps, **kwargs}
379
- if minibatch is None:
380
- return sampling.get_ode_sampler(sde, self, y=y, **kwargs)
381
- else:
382
- M = y.shape[0]
383
- def batched_sampling_fn():
384
- samples, ns = [], []
385
- for i in range(int(ceil(M / minibatch))):
386
- y_mini = y[i*minibatch:(i+1)*minibatch]
387
- sampler = sampling.get_ode_sampler(sde, self, y=y_mini, **kwargs)
388
- sample, n = sampler()
389
- samples.append(sample)
390
- ns.append(n)
391
- samples = torch.cat(samples, dim=0)
392
- return samples, ns
393
- return batched_sampling_fn
394
-
395
- def get_sb_sampler(self, sde, y, sampler_type="ode", N=None, **kwargs):
396
- N = sde.N if N is None else N
397
- sde = self.sde.copy()
398
- sde.N = N
399
-
400
- return sampling.get_sb_sampler(sde, self, y=y, sampler_type=sampler_type, **kwargs)
401
-
402
- def train_dataloader(self):
403
- return self.data_module.train_dataloader()
404
-
405
- def val_dataloader(self):
406
- return self.data_module.val_dataloader()
407
-
408
- def test_dataloader(self):
409
- return self.data_module.test_dataloader()
410
-
411
- def setup(self, stage=None):
412
- return self.data_module.setup(stage=stage)
413
-
414
- def to_audio(self, spec, length=None):
415
- return self._istft(self._backward_transform(spec), length)
416
-
417
- def _forward_transform(self, spec):
418
- return self.data_module.spec_fwd(spec)
419
-
420
- def _backward_transform(self, spec):
421
- return self.data_module.spec_back(spec)
422
-
423
- def _stft(self, sig):
424
- return self.data_module.stft(sig)
425
-
426
- def _istft(self, spec, length=None):
427
- return self.data_module.istft(spec, length)
428
-
429
- def enhance(self, y, sampler_type="pc", predictor="reverse_diffusion",
430
- corrector="ald", N=30, corrector_steps=1, snr=0.5, timeit=False,
431
- **kwargs
432
- ):
433
- """
434
- One-call speech enhancement of noisy speech `y`, for convenience.
435
- """
436
- start = time.time()
437
- T_orig = y.size(1)
438
- norm_factor = y.abs().max().item()
439
- y = y / norm_factor
440
- Y = torch.unsqueeze(self._forward_transform(self._stft(y.cuda())), 0)
441
- Y = pad_spec(Y)
442
-
443
- # SGMSE sampling with OUVE SDE
444
- if self.sde.__class__.__name__ == 'OUVESDE':
445
- if self.sde.sampler_type == "pc":
446
- sampler = self.get_pc_sampler(predictor, corrector, Y.cuda(), N=N,
447
- corrector_steps=corrector_steps, snr=snr, intermediate=False,
448
- **kwargs)
449
- elif self.sde.sampler_type == "ode":
450
- sampler = self.get_ode_sampler(Y.cuda(), N=N, **kwargs)
451
- else:
452
- raise ValueError("Invalid sampler type for SGMSE sampling: {}".format(sampler_type))
453
- # Schrödinger bridge sampling with VE SDE
454
- elif self.sde.__class__.__name__ == 'SBVESDE':
455
- sampler = self.get_sb_sampler(sde=self.sde, y=Y.cuda(), sampler_type=self.sde.sampler_type)
456
- else:
457
- raise ValueError("Invalid SDE type for speech enhancement: {}".format(self.sde.__class__.__name__))
458
-
459
- sample, nfe = sampler()
460
- x_hat = self.to_audio(sample.squeeze(), T_orig)
461
- x_hat = x_hat * norm_factor
462
- x_hat = x_hat.squeeze().cpu().numpy()
463
- end = time.time()
464
- if timeit:
465
- rtf = (end-start)/(len(x_hat)/self.sr)
466
- return x_hat, nfe, rtf
467
- else:
468
- return x_hat
 
 
 
 
1
+ import time
2
+ from math import ceil
3
+ import warnings
4
+
5
+ import torch
6
+ import pytorch_lightning as pl
7
+ import torch.distributed as dist
8
+ from torchaudio import load
9
+ from torch_ema import ExponentialMovingAverage
10
+ from librosa import resample
11
+
12
+ from sgmse import sampling
13
+ from sgmse.sdes import SDERegistry
14
+ from sgmse.backbones import BackboneRegistry
15
+ from sgmse.util.inference import evaluate_model
16
+ from sgmse.util.other import pad_spec, si_sdr
17
+ from pesq import pesq
18
+ from pystoi import stoi
19
+ from torch_pesq import PesqLoss
21
+
22
+ class ScoreModel(pl.LightningModule):
23
+ @staticmethod
24
+ def add_argparse_args(parser):
25
+ parser.add_argument("--lr", type=float, default=1e-4, help="The learning rate (1e-4 by default)")
26
+ parser.add_argument("--ema_decay", type=float, default=0.999, help="The parameter EMA decay constant (0.999 by default)")
27
+ parser.add_argument("--t_eps", type=float, default=0.03, help="The minimum process time (0.03 by default)")
28
+ parser.add_argument("--num_eval_files", type=int, default=50, help="Number of files for speech enhancement performance evaluation during training. Pass 0 to turn off (no checkpoints based on evaluation metrics will be generated).")
29
+ parser.add_argument("--loss_type", type=str, default="score_matching", help="The type of loss function to use.")
30
+ parser.add_argument("--loss_weighting", type=str, default="sigma^2", help="The weighting of the loss function.")
31
+ parser.add_argument("--network_scaling", type=str, default=None, help="The type of loss scaling to use.")
32
+ parser.add_argument("--c_in", type=str, default="1", help="The input scaling for x.")
33
+ parser.add_argument("--c_out", type=str, default="1", help="The output scaling.")
34
+ parser.add_argument("--c_skip", type=str, default="0", help="The skip connection scaling.")
35
+ parser.add_argument("--sigma_data", type=float, default=0.1, help="The data standard deviation.")
36
+ parser.add_argument("--l1_weight", type=float, default=0.001, help="The balance between the time-frequency and time-domain losses.")
37
+ parser.add_argument("--pesq_weight", type=float, default=0.0, help="The balance between the time-frequency and time-domain losses.")
38
+ parser.add_argument("--sr", type=int, default=16000, help="The sample rate of the audio files.")
39
+ parser.add_argument("--k", type=float, default=2.6, help="Parameter of the diffusion coefficient. 2.6 by default.")
40
+ parser.add_argument("--c", type=float, default=0.4, help="Parameter of the diffusion coefficient. 0.4 by default.")
41
+
42
+ return parser
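+ # Usage sketch (hypothetical; typically called from a training script):
+ #   parser = ScoreModel.add_argparse_args(argparse.ArgumentParser())
+ #   args = parser.parse_args(["--lr", "5e-5", "--loss_type", "denoiser"])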
43
+
44
+ def __init__(
45
+ self, backbone, sde, lr=1e-4, ema_decay=0.999, t_eps=0.03, num_eval_files=20, loss_type='score_matching',
46
+ loss_weighting='sigma^2', network_scaling=None, c_in='1', c_out='1', c_skip='0', sigma_data=0.1,
47
+ l1_weight=0.001, pesq_weight=0.0, sr=16000, data_module_cls=None, **kwargs
48
+ ):
49
+ """
50
+ Create a new ScoreModel.
51
+
52
+ Args:
53
+ backbone: Backbone DNN that serves as a score-based model.
54
+ sde: The SDE that defines the diffusion process.
55
+ lr: The learning rate of the optimizer. (1e-4 by default).
56
+ ema_decay: The decay constant of the parameter EMA (0.999 by default).
57
+ t_eps: The minimum process time, used to avoid numerical issues very close to zero (0.03 by default).
58
+ loss_type: The type of training objective. Options are 'score_matching' (default), 'denoiser', and 'data_prediction'.
59
+ """
60
+ super().__init__()
61
+ # Initialize Backbone DNN
62
+ self.backbone = backbone
63
+ dnn_cls = BackboneRegistry.get_by_name(backbone)
64
+ self.dnn = dnn_cls(**kwargs)
65
+ # Initialize SDE
66
+ sde_cls = SDERegistry.get_by_name(sde)
67
+ self.sde = sde_cls(**kwargs)
68
+ # Store hyperparams and save them
69
+ self.lr = lr
70
+ self.ema_decay = ema_decay
71
+ self.ema = ExponentialMovingAverage(self.parameters(), decay=self.ema_decay)
72
+ self._error_loading_ema = False
73
+ self.t_eps = t_eps
74
+ self.loss_type = loss_type
75
+ self.loss_weighting = loss_weighting
76
+ self.l1_weight = l1_weight
77
+ self.pesq_weight = pesq_weight
78
+ self.network_scaling = network_scaling
79
+ self.c_in = c_in
80
+ self.c_out = c_out
81
+ self.c_skip = c_skip
82
+ self.sigma_data = sigma_data
83
+ self.num_eval_files = num_eval_files
84
+ self.sr = sr
85
+ # Initialize PESQ loss if pesq_weight > 0.0
86
+ if pesq_weight > 0.0:
87
+ self.pesq_loss = PesqLoss(1.0, sample_rate=sr).eval()
88
+ for param in self.pesq_loss.parameters():
89
+ param.requires_grad = False
90
+ self.save_hyperparameters(ignore=['no_wandb'])
91
+ self.data_module = data_module_cls(**kwargs, gpu=kwargs.get('gpus', 0) > 0)
92
+
93
+ def configure_optimizers(self):
94
+ optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
95
+ return optimizer
96
+
97
+ def optimizer_step(self, *args, **kwargs):
98
+ # Method overridden so that the EMA params are updated after each optimizer step
99
+ super().optimizer_step(*args, **kwargs)
100
+ self.ema.update(self.dnn.parameters())
101
+
102
+ # on_load_checkpoint / on_save_checkpoint needed for EMA storing/loading
103
+ def on_load_checkpoint(self, checkpoint):
104
+ ema = checkpoint.get('ema', None)
105
+ if ema is not None:
106
+ self.ema.load_state_dict(ema)
107
+ else:
108
+ self._error_loading_ema = True
109
+ warnings.warn("EMA state_dict not found in checkpoint!")
110
+
111
+ def on_save_checkpoint(self, checkpoint):
112
+ checkpoint['ema'] = self.ema.state_dict()
113
+
114
+ def train(self, mode, no_ema=False):
115
+ res = super().train(mode) # call the standard `train` method with the given mode
116
+ if not self._error_loading_ema:
117
+ if not mode and not no_ema:
118
+ # eval
119
+ self.ema.store(self.dnn.parameters()) # store current params in EMA
120
+ self.ema.copy_to(self.dnn.parameters()) # copy EMA parameters over current params for evaluation
121
+ else:
122
+ # train
123
+ if self.ema.collected_params is not None:
124
+ self.ema.restore(self.dnn.parameters()) # restore the EMA weights (if stored)
125
+ return res
126
+
127
+ def eval(self, no_ema=False):
128
+ return self.train(False, no_ema=no_ema)
129
+
130
+ def _loss(self, forward_out, x_t, z, t, mean, x):
131
+ """
132
+ Different loss functions can be used to train the score model, see the paper:
133
+
134
+ Julius Richter, Danilo de Oliveira, and Timo Gerkmann
135
+ "Investigating Training Objectives for Generative Speech Enhancement"
136
+ https://arxiv.org/abs/2409.10753
137
+
138
+ """
139
+
140
+ sigma = self.sde._std(t)[:, None, None, None]
141
+
142
+ if self.loss_type == "score_matching":
143
+ score = forward_out
144
+ if self.loss_weighting == "sigma^2":
145
+ losses = torch.square(torch.abs(score * sigma + z)) # Eq. (7)
146
+ else:
147
+ raise ValueError("Invalid loss weighting for loss_type=score_matching: {}".format(self.loss_weighting))
148
+ # Sum over spatial dimensions and channels and mean over batch
149
+ loss = torch.mean(0.5*torch.sum(losses.reshape(losses.shape[0], -1), dim=-1))
150
+ elif self.loss_type == "denoiser":
151
+ score = forward_out
152
+ D = score * sigma.pow(2) + x_t # equivalent to Eq. (10)
153
+ losses = torch.square(torch.abs(D - mean)) # Eq. (8)
154
+ if self.loss_weighting == "1":
155
+ losses = losses
156
+ elif self.loss_weighting == "sigma^2":
157
+ losses = losses * sigma**2
158
+ elif self.loss_weighting == "edm":
159
+ losses = ((sigma**2 + self.sigma_data**2) / ((sigma * self.sigma_data)**2)) * losses # sigma is already shaped [B, 1, 1, 1]
160
+ else:
161
+ raise ValueError("Invalid loss weighting for loss_type=denoiser: {}".format(self.loss_weighting))
162
+ # Sum over spatial dimensions and channels and mean over batch
163
+ loss = torch.mean(0.5*torch.sum(losses.reshape(losses.shape[0], -1), dim=-1))
164
+ elif self.loss_type == "data_prediction":
165
+ x_hat = forward_out
166
+ B, C, F, T = x.shape
167
+
168
+ # losses in the time-frequency domain (tf)
169
+ losses_tf = (1/(F*T))*torch.square(torch.abs(x_hat - x))
170
+ losses_tf = torch.mean(0.5*torch.sum(losses_tf.reshape(losses_tf.shape[0], -1), dim=-1))
171
+
172
+ # losses in the time domain (td)
173
+ target_len = (self.data_module.num_frames - 1) * self.data_module.hop_length
174
+ x_hat_td = self.to_audio(x_hat.squeeze(), target_len)
175
+ x_td = self.to_audio(x.squeeze(), target_len)
176
+ losses_l1 = (1 / target_len) * torch.abs(x_hat_td - x_td)
177
+ losses_l1 = torch.mean(0.5*torch.sum(losses_l1.reshape(losses_l1.shape[0], -1), dim=-1))
178
+
179
+ # losses using PESQ
180
+ if self.pesq_weight > 0.0:
181
+ losses_pesq = self.pesq_loss(x_td, x_hat_td)
182
+ losses_pesq = torch.mean(losses_pesq)
183
+ # combine the losses
184
+ loss = losses_tf + self.l1_weight * losses_l1 + self.pesq_weight * losses_pesq
185
+ else:
186
+ loss = losses_tf + self.l1_weight * losses_l1
187
+ else:
188
+ raise ValueError("Invalid loss type: {}".format(self.loss_type))
189
+
190
+ return loss
191
+
192
+ def _step(self, batch, batch_idx):
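+ # Draw t ~ U(t_eps, T), perturb the clean target via the SDE marginals, and
+ # evaluate the configured training objective on the network output.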
193
+ x, y = batch
194
+ t = torch.rand(x.shape[0], device=x.device) * (self.sde.T - self.t_eps) + self.t_eps
195
+ mean, std = self.sde.marginal_prob(x, y, t)
196
+ z = torch.randn_like(x) # i.i.d. complex normal; real and imaginary parts each have variance 0.5
197
+ sigma = std[:, None, None, None]
198
+ x_t = mean + sigma * z
199
+ forward_out = self(x_t, y, t)
200
+ loss = self._loss(forward_out, x_t, z, t, mean, x)
201
+ return loss
202
+
203
+ def training_step(self, batch, batch_idx):
204
+ loss = self._step(batch, batch_idx)
205
+ self.log('train_loss', loss, on_step=True, on_epoch=True, sync_dist=True, prog_bar=True)
206
+ return loss
207
+
208
+ def validation_step(self, batch, batch_idx):
209
+ # Evaluate speech enhancement performance
210
+ if batch_idx == 0 and self.num_eval_files != 0:
211
+ rank = dist.get_rank()
212
+ world_size = dist.get_world_size()
213
+
214
+ # Split the evaluation files among the GPUs
215
+ eval_files_per_gpu = self.num_eval_files // world_size
216
+
217
+ clean_files = self.data_module.valid_set.clean_files[:self.num_eval_files]
218
+ noisy_files = self.data_module.valid_set.noisy_files[:self.num_eval_files]
219
+
220
+ print(f"Process: {len(clean_files)} files")
221
+ # Select the files for this GPU
222
+ if rank == world_size - 1:
223
+ clean_files = clean_files[rank*eval_files_per_gpu:]
224
+ noisy_files = noisy_files[rank*eval_files_per_gpu:]
225
+ else:
226
+ clean_files = clean_files[rank*eval_files_per_gpu:(rank+1)*eval_files_per_gpu]
227
+ noisy_files = noisy_files[rank*eval_files_per_gpu:(rank+1)*eval_files_per_gpu]
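+ # Example (hypothetical numbers): num_eval_files=50 on world_size=4 gives 12
+ # files each to ranks 0-2 and the remaining 14 to the last rank.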
228
+
229
+ # Evaluate the performance of the model
230
+ pesq_sum, si_sdr_sum, estoi_sum = 0.0, 0.0, 0.0
231
+ start_time = time.time()
232
+ for (clean_file, noisy_file) in zip(clean_files, noisy_files):
233
+ # Load the clean and noisy speech
234
+ x, sr_x = load(clean_file)
235
+ x = x.squeeze().numpy()
236
+ y, sr_y = load(noisy_file)
237
+ assert sr_x == sr_y, "Sample rates of clean and noisy files do not match!"
238
+
239
+ # Resample if necessary
240
+ if sr_x != 16000:
241
+ x_16k = resample(x, orig_sr=sr_x, target_sr=16000).squeeze()
242
+ else:
243
+ x_16k = x
244
+
245
+ # Enhance the noisy speech
246
+ x_hat = self.enhance(y, N=self.sde.N)
247
+ if self.sr != 16000:
248
+ x_hat_16k = resample(x_hat, orig_sr=self.sr, target_sr=16000).squeeze()
249
+ else:
250
+ x_hat_16k = x_hat
251
+
252
+ pesq_sum += pesq(16000, x_16k, x_hat_16k, 'wb')
253
+ si_sdr_sum += si_sdr(x, x_hat)
254
+ estoi_sum += stoi(x, x_hat, self.sr, extended=True)
255
+
256
+ print(f"process {eval_files_per_gpu} in {time.time()-start_time}")
257
+ pesq_avg = pesq_sum / len(clean_files)
258
+ si_sdr_avg = si_sdr_sum / len(clean_files)
259
+ estoi_avg = estoi_sum / len(clean_files)
260
+
261
+ self.log('pesq', pesq_avg, on_step=False, on_epoch=True, sync_dist=True)
262
+ self.log('si_sdr', si_sdr_avg, on_step=False, on_epoch=True, sync_dist=True)
263
+ self.log('estoi', estoi_avg, on_step=False, on_epoch=True, sync_dist=True)
264
+
265
+ loss = self._step(batch, batch_idx)
266
+ self.log('valid_loss', loss, on_step=False, on_epoch=True, sync_dist=True)
267
+
268
+ return loss
269
+
270
+ def forward(self, x_t, y, t):
271
+ """
272
+ The model forward pass. In [1] and [2], the model estimates the score function. In [3], the model estimates
273
+ either the score function or the target data for the Schrödinger bridge (loss_type='data_prediction').
274
+
275
+ [1] Julius Richter, Simon Welker, Jean-Marie Lemercier, Bunlong Lay, and Timo Gerkmann
276
+ "Speech Enhancement and Dereverberation with Diffusion-Based Generative Models"
277
+ IEEE/ACM Transactions on Audio, Speech, and Language Processing, vol. 31, pp. 2351-2364, 2023.
278
+
279
+ [2] Julius Richter, Yi-Chiao Wu, Steven Krenn, Simon Welker, Bunlong Lay, Shinji Watanabe, Alexander Richard, and Timo Gerkmann
280
+ "EARS: An Anechoic Fullband Speech Dataset Benchmarked for Speech Enhancement and Dereverberation"
281
+ ISCA Interspeech, Kos, Greece, Sept. 2024.
282
+
283
+ [3] Julius Richter, Danilo de Oliveira, and Timo Gerkmann
284
+ "Investigating Training Objectives for Generative Speech Enhancement"
285
+ https://arxiv.org/abs/2409.10753
286
+
287
+ """
288
+
289
+ # In [3], we use new code with backbone='ncsnpp_v2':
290
+ if self.backbone == "ncsnpp_v2":
291
+ F = self.dnn(self._c_in(t) * x_t, self._c_in(t) * y, t)
292
+
293
+ # Scaling the network output, see below Eq. (7) in the paper
294
+ if self.network_scaling == "1/sigma":
295
+ std = self.sde._std(t)
296
+ F = F / std[:, None, None, None]
297
+ elif self.network_scaling == "1/t":
298
+ F = F / t[:, None, None, None]
299
+
300
+ # The loss type determines the output of the model
301
+ if self.loss_type == "score_matching":
302
+ score = self._c_skip(t) * x_t + self._c_out(t) * F
303
+ return score
304
+ elif self.loss_type == "denoiser":
305
+ sigmas = self.sde._std(t)[:, None, None, None]
306
+ score = (F - x_t) / sigmas.pow(2)
307
+ return score
308
+ elif self.loss_type == 'data_prediction':
309
+ x_hat = self._c_skip(t) * x_t + self._c_out(t) * F
310
+ return x_hat
311
+
312
+ # In [1] and [2], we use the old code:
313
+ else:
314
+ dnn_input = torch.cat([x_t, y], dim=1)
315
+ score = -self.dnn(dnn_input, t)
316
+ return score
317
+
318
+ def _c_in(self, t):
319
+ if self.c_in == "1":
320
+ return 1.0
321
+ elif self.c_in == "edm":
322
+ sigma = self.sde._std(t)
323
+ return (1.0 / torch.sqrt(sigma**2 + self.sigma_data**2))[:, None, None, None]
324
+ else:
325
+ raise ValueError("Invalid c_in type: {}".format(self.c_in))
326
+
327
+ def _c_out(self, t):
328
+ if self.c_out == "1":
329
+ return 1.0
330
+ elif self.c_out == "sigma":
331
+ return self.sde._std(t)[:, None, None, None]
332
+ elif self.c_out == "1/sigma":
333
+ return 1.0 / self.sde._std(t)[:, None, None, None]
334
+ elif self.c_out == "edm":
335
+ sigma = self.sde._std(t)
336
+ return ((sigma * self.sigma_data) / torch.sqrt(self.sigma_data**2 + sigma**2))[:, None, None, None]
337
+ else:
338
+ raise ValueError("Invalid c_out type: {}".format(self.c_out))
339
+
340
+ def _c_skip(self, t):
341
+ if self.c_skip == "0":
342
+ return 0.0
343
+ elif self.c_skip == "edm":
344
+ sigma = self.sde._std(t)
345
+ return (self.sigma_data**2 / (sigma**2 + self.sigma_data**2))[:, None, None, None]
346
+ else:
347
+ raise ValueError("Invalid c_skip type: {}".format(self.c_skip))
348
+
349
+ def to(self, *args, **kwargs):
350
+ """Override PyTorch .to() to also transfer the EMA of the model weights"""
351
+ self.ema.to(*args, **kwargs)
352
+ return super().to(*args, **kwargs)
353
+
354
+ def get_pc_sampler(self, predictor_name, corrector_name, y, N=None, minibatch=None, **kwargs):
355
+ N = self.sde.N if N is None else N
356
+ sde = self.sde.copy()
357
+ sde.N = N
358
+
359
+ kwargs = {"eps": self.t_eps, **kwargs}
360
+ if minibatch is None:
361
+ return sampling.get_pc_sampler(predictor_name, corrector_name, sde=sde, score_fn=self, y=y, **kwargs)
362
+ else:
363
+ M = y.shape[0]
364
+ def batched_sampling_fn():
365
+ samples, ns = [], []
366
+ for i in range(int(ceil(M / minibatch))):
367
+ y_mini = y[i*minibatch:(i+1)*minibatch]
368
+ sampler = sampling.get_pc_sampler(predictor_name, corrector_name, sde=sde, score_fn=self, y=y_mini, **kwargs)
369
+ sample, n = sampler()
370
+ samples.append(sample)
371
+ ns.append(n)
372
+ samples = torch.cat(samples, dim=0)
373
+ return samples, ns
374
+ return batched_sampling_fn
375
+
376
+ def get_ode_sampler(self, y, N=None, minibatch=None, **kwargs):
377
+ N = self.sde.N if N is None else N
378
+ sde = self.sde.copy()
379
+ sde.N = N
380
+
381
+ kwargs = {"eps": self.t_eps, **kwargs}
382
+ if minibatch is None:
383
+ return sampling.get_ode_sampler(sde, self, y=y, **kwargs)
384
+ else:
385
+ M = y.shape[0]
386
+ def batched_sampling_fn():
387
+ samples, ns = [], []
388
+ for i in range(int(ceil(M / minibatch))):
389
+ y_mini = y[i*minibatch:(i+1)*minibatch]
390
+ sampler = sampling.get_ode_sampler(sde, self, y=y_mini, **kwargs)
391
+ sample, n = sampler()
392
+ samples.append(sample)
393
+ ns.append(n)
394
+ samples = torch.cat(samples, dim=0)
395
+ return samples, ns
396
+ return batched_sampling_fn
397
+
398
+ def get_sb_sampler(self, sde, y, sampler_type="ode", N=None, **kwargs):
399
+ N = sde.N if N is None else N
400
+ sde = self.sde.copy()
401
+ sde.N = N
402
+
403
+ return sampling.get_sb_sampler(sde, self, y=y, sampler_type=sampler_type, **kwargs)
404
+
405
+ def train_dataloader(self):
406
+ return self.data_module.train_dataloader()
407
+
408
+ def val_dataloader(self):
409
+ return self.data_module.val_dataloader()
410
+
411
+ def test_dataloader(self):
412
+ return self.data_module.test_dataloader()
413
+
414
+ def setup(self, stage=None):
415
+ return self.data_module.setup(stage=stage)
416
+
417
+ def to_audio(self, spec, length=None):
418
+ return self._istft(self._backward_transform(spec), length)
419
+
420
+ def _forward_transform(self, spec):
421
+ return self.data_module.spec_fwd(spec)
422
+
423
+ def _backward_transform(self, spec):
424
+ return self.data_module.spec_back(spec)
425
+
426
+ def _stft(self, sig):
427
+ return self.data_module.stft(sig)
428
+
429
+ def _istft(self, spec, length=None):
430
+ return self.data_module.istft(spec, length)
431
+
432
+ def enhance(self, y, sampler_type="pc", predictor="reverse_diffusion",
433
+ corrector="ald", N=30, corrector_steps=1, snr=0.5, timeit=False,
434
+ **kwargs
435
+ ):
436
+ """
437
+ One-call speech enhancement of noisy speech `y`, for convenience.
438
+ """
439
+ start = time.time()
440
+ T_orig = y.size(1)
441
+ norm_factor = y.abs().max().item()
442
+ y = y / norm_factor
443
+ Y = torch.unsqueeze(self._forward_transform(self._stft(y.cuda())), 0)
444
+ Y = pad_spec(Y)
445
+
446
+ # SGMSE sampling with OUVE SDE
447
+ if self.sde.__class__.__name__ == 'OUVESDE':
448
+ if self.sde.sampler_type == "pc":
449
+ sampler = self.get_pc_sampler(predictor, corrector, Y.cuda(), N=N,
450
+ corrector_steps=corrector_steps, snr=snr, intermediate=False,
451
+ **kwargs)
452
+ elif self.sde.sampler_type == "ode":
453
+ sampler = self.get_ode_sampler(Y.cuda(), N=N, **kwargs)
454
+ else:
455
+ raise ValueError("Invalid sampler type for SGMSE sampling: {}".format(sampler_type))
456
+ # Schrödinger bridge sampling with VE SDE
457
+ elif self.sde.__class__.__name__ == 'SBVESDE':
458
+ sampler = self.get_sb_sampler(sde=self.sde, y=Y.cuda(), sampler_type=self.sde.sampler_type)
459
+ else:
460
+ raise ValueError("Invalid SDE type for speech enhancement: {}".format(self.sde.__class__.__name__))
461
+
462
+ sample, nfe = sampler()
463
+ x_hat = self.to_audio(sample.squeeze(), T_orig)
464
+ x_hat = x_hat * norm_factor
465
+ x_hat = x_hat.squeeze().cpu().numpy()
466
+ end = time.time()
467
+ if timeit:
468
+ rtf = (end-start)/(len(x_hat)/self.sr)
469
+ return x_hat, nfe, rtf
470
+ else:
471
+ return x_hat
sgmse/sampling/__init__.py CHANGED
@@ -1,249 +1,249 @@
1
- # Adapted from https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/sampling.py
2
- """Various sampling methods."""
3
- from scipy import integrate
4
- import torch
5
-
6
- from .predictors import Predictor, PredictorRegistry, ReverseDiffusionPredictor
7
- from .correctors import Corrector, CorrectorRegistry
8
-
9
-
10
- __all__ = [
11
- 'PredictorRegistry', 'CorrectorRegistry', 'Predictor', 'Corrector',
12
- 'get_pc_sampler', 'get_ode_sampler', 'get_sb_sampler'
13
- ]
14
-
15
-
16
- def to_flattened_numpy(x):
17
- """Flatten a torch tensor `x` and convert it to numpy."""
18
- return x.detach().cpu().numpy().reshape((-1,))
19
-
20
-
21
- def from_flattened_numpy(x, shape):
22
- """Form a torch tensor with the given `shape` from a flattened numpy array `x`."""
23
- return torch.from_numpy(x.reshape(shape))
24
-
25
-
26
- def get_pc_sampler(
27
- predictor_name, corrector_name, sde, score_fn, y,
28
- denoise=True, eps=3e-2, snr=0.1, corrector_steps=1, probability_flow: bool = False,
29
- intermediate=False, **kwargs
30
- ):
31
- """Create a Predictor-Corrector (PC) sampler.
32
-
33
- Args:
34
- predictor_name: The name of a registered `sampling.Predictor`.
35
- corrector_name: The name of a registered `sampling.Corrector`.
36
- sde: An `sdes.SDE` object representing the forward SDE.
37
- score_fn: A function (typically learned model) that predicts the score.
38
- y: A `torch.Tensor`, representing the (non-white-)noisy starting point(s) to condition the prior on.
39
- denoise: If `True`, add one-step denoising to the final samples.
40
- eps: A `float` number. The reverse-time SDE and ODE are integrated to `epsilon` to avoid numerical issues.
41
- snr: The SNR to use for the corrector. 0.1 by default, and ignored for `NoneCorrector`.
42
- N: The number of reverse sampling steps is taken from the `sde.N` property of the passed SDE.
43
-
44
- Returns:
45
- A sampling function that returns samples and the number of function evaluations during sampling.
46
- """
47
- predictor_cls = PredictorRegistry.get_by_name(predictor_name)
48
- corrector_cls = CorrectorRegistry.get_by_name(corrector_name)
49
- predictor = predictor_cls(sde, score_fn, probability_flow=probability_flow)
50
- corrector = corrector_cls(sde, score_fn, snr=snr, n_steps=corrector_steps)
51
-
52
- def pc_sampler():
53
- """The PC sampler function."""
54
- with torch.no_grad():
55
- xt = sde.prior_sampling(y.shape, y).to(y.device)
56
- timesteps = torch.linspace(sde.T, eps, sde.N, device=y.device)
57
- for i in range(sde.N):
58
- t = timesteps[i]
59
- if i != len(timesteps) - 1:
60
- stepsize = t - timesteps[i+1]
61
- else:
62
- stepsize = timesteps[-1] # from eps to 0
63
- vec_t = torch.ones(y.shape[0], device=y.device) * t
64
- xt, xt_mean = corrector.update_fn(xt, y, vec_t)
65
- xt, xt_mean = predictor.update_fn(xt, y, vec_t, stepsize)
66
- x_result = xt_mean if denoise else xt
67
- ns = sde.N * (corrector.n_steps + 1)
68
- return x_result, ns
69
-
70
- return pc_sampler
71
-
72
-
73
- def get_ode_sampler(
74
- sde, score_fn, y, inverse_scaler=None,
75
- denoise=True, rtol=1e-5, atol=1e-5,
76
- method='RK45', eps=3e-2, device='cuda', **kwargs
77
- ):
78
- """Probability flow ODE sampler with the black-box ODE solver.
79
-
80
- Args:
81
- sde: An `sdes.SDE` object representing the forward SDE.
82
- score_fn: A function (typically learned model) that predicts the score.
83
- y: A `torch.Tensor`, representing the (non-white-)noisy starting point(s) to condition the prior on.
84
- inverse_scaler: The inverse data normalizer.
85
- denoise: If `True`, add one-step denoising to final samples.
86
- rtol: A `float` number. The relative tolerance level of the ODE solver.
87
- atol: A `float` number. The absolute tolerance level of the ODE solver.
88
- method: A `str`. The algorithm used for the black-box ODE solver.
89
- See the documentation of `scipy.integrate.solve_ivp`.
90
- eps: A `float` number. The reverse-time SDE/ODE will be integrated to `eps` for numerical stability.
91
- device: PyTorch device.
92
-
93
- Returns:
94
- A sampling function that returns samples and the number of function evaluations during sampling.
95
- """
96
- predictor = ReverseDiffusionPredictor(sde, score_fn, probability_flow=False)
97
- rsde = sde.reverse(score_fn, probability_flow=True)
98
-
99
- def denoise_update_fn(x):
100
- vec_eps = torch.ones(x.shape[0], device=x.device) * eps
101
- _, x = predictor.update_fn(x, y, vec_eps)
102
- return x
103
-
104
- def drift_fn(x, y, t):
105
- """Get the drift function of the reverse-time SDE."""
106
- return rsde.sde(x, y, t)[0]
107
-
108
- def ode_sampler(z=None, **kwargs):
109
- """The probability flow ODE sampler with black-box ODE solver.
110
-
111
- Args:
112
- z: If present, generate samples from latent code `z` (currently unused:
113
- the initial state is always drawn from the SDE prior).
114
- Returns:
115
- samples, number of function evaluations.
116
- """
117
- with torch.no_grad():
118
- # Sample the initial state from the prior distribution of the SDE.
119
- x = sde.prior_sampling(y.shape, y).to(device)
120
-
121
- def ode_func(t, x):
122
- x = from_flattened_numpy(x, y.shape).to(device).type(torch.complex64)
123
- vec_t = torch.ones(y.shape[0], device=x.device) * t
124
- drift = drift_fn(x, y, vec_t)
125
- return to_flattened_numpy(drift)
126
-
127
- # Black-box ODE solver for the probability flow ODE
128
- solution = integrate.solve_ivp(
129
- ode_func, (sde.T, eps), to_flattened_numpy(x),
130
- rtol=rtol, atol=atol, method=method, **kwargs
131
- )
132
- nfe = solution.nfev
133
- x = torch.tensor(solution.y[:, -1]).reshape(y.shape).to(device).type(torch.complex64)
134
-
135
- # Denoising is equivalent to running one predictor step without adding noise
136
- if denoise:
137
- x = denoise_update_fn(x)
138
-
139
- if inverse_scaler is not None:
140
- x = inverse_scaler(x)
141
- return x, nfe
142
-
143
- return ode_sampler
144
-
145
- def get_sb_sampler(sde, model, y, eps=1e-4, n_steps=50, sampler_type="ode", **kwargs):
146
- # adapted from https://github.com/NVIDIA/NeMo/blob/78357ae99ff2cf9f179f53fbcb02c88a5a67defb/nemo/collections/audio/parts/submodules/schroedinger_bridge.py#L382
147
- def sde_sampler():
148
- """The SB-SDE sampler function."""
149
- with torch.no_grad():
150
- xt = y[:, [0], :, :] # special case for storm_2ch
151
- time_steps = torch.linspace(sde.T, eps, sde.N + 1, device=y.device)
152
-
153
- # Initial values
154
- time_prev = time_steps[0] * torch.ones(xt.shape[0], device=xt.device)
155
- sigma_prev, sigma_T, sigma_bar_prev, alpha_prev, alpha_T, alpha_bar_prev = sde._sigmas_alphas(time_prev)
156
-
157
- for t in time_steps[1:]:
158
- # Prepare time steps for the whole batch
159
- time = t * torch.ones(xt.shape[0], device=xt.device)
160
-
161
- # Get noise schedule for current time
162
- sigma_t, sigma_T, sigma_bart, alpha_t, alpha_T, alpha_bart = sde._sigmas_alphas(time)
163
-
164
- # Run DNN
165
- current_estimate = model(xt, y, time)
166
-
167
- # Calculate scaling for the first-order discretization from the paper
168
- weight_prev = alpha_t * sigma_t**2 / (alpha_prev * sigma_prev**2 + sde.eps)
169
- tmp = 1 - sigma_t**2 / (sigma_prev**2 + sde.eps)
170
- weight_estimate = alpha_t * tmp
171
- weight_z = alpha_t * sigma_t * torch.sqrt(tmp)
172
-
173
- # View as [B, C, D, T]
174
- weight_prev = weight_prev[:, None, None, None]
175
- weight_estimate = weight_estimate[:, None, None, None]
176
- weight_z = weight_z[:, None, None, None]
177
-
178
- # Random sample
179
- z_norm = torch.randn_like(xt)
180
-
181
- if t == time_steps[-1]:
182
- weight_z = 0.0
183
-
184
- # Update state: weighted sum of previous state, current estimate and noise
185
- xt = weight_prev * xt + weight_estimate * current_estimate + weight_z * z_norm
186
-
187
- # Save previous values
188
- time_prev = time
189
- alpha_prev = alpha_t
190
- sigma_prev = sigma_t
191
- sigma_bar_prev = sigma_bart
192
-
193
- return xt, sde.N # one model evaluation per step
194
-
195
- def ode_sampler():
196
- """The SB-ODE sampler function."""
197
- with torch.no_grad():
198
- xt = y
199
- time_steps = torch.linspace(sde.T, eps, sde.N + 1, device=y.device)
200
-
201
- # Initial values
202
- time_prev = time_steps[0] * torch.ones(xt.shape[0], device=xt.device)
203
- sigma_prev, sigma_T, sigma_bar_prev, alpha_prev, alpha_T, alpha_bar_prev = sde._sigmas_alphas(time_prev)
204
-
205
- for t in time_steps[1:]:
206
- # Prepare time steps for the whole batch
207
- time = t * torch.ones(xt.shape[0], device=xt.device)
208
-
209
- # Get noise schedule for current time
210
- sigma_t, sigma_T, sigma_bart, alpha_t, alpha_T, alpha_bart = sde._sigmas_alphas(time)
211
-
212
- # Run DNN
213
- current_estimate = model(xt, y, time)
214
-
215
- # Calculate scaling for the first-order discretization from the paper
216
- weight_prev = alpha_t * sigma_t * sigma_bart / (alpha_prev * sigma_prev * sigma_bar_prev + sde.eps)
217
- weight_estimate = (
218
- alpha_t
219
- / (sigma_T**2 + sde.eps)
220
- * (sigma_bart**2 - sigma_bar_prev * sigma_t * sigma_bart / (sigma_prev + sde.eps))
221
- )
222
- weight_prior_mean = (
223
- alpha_t
224
- / (alpha_T * sigma_T**2 + sde.eps)
225
- * (sigma_t**2 - sigma_prev * sigma_t * sigma_bart / (sigma_bar_prev + sde.eps))
226
- )
227
-
228
- # View as [B, C, D, T]
229
- weight_prev = weight_prev[:, None, None, None]
230
- weight_estimate = weight_estimate[:, None, None, None]
231
- weight_prior_mean = weight_prior_mean[:, None, None, None]
232
-
233
- # Update state: weighted sum of previous state, current estimate and prior
234
- xt = weight_prev * xt + weight_estimate * current_estimate + weight_prior_mean * y
235
-
236
- # Save previous values
237
- time_prev = time
238
- alpha_prev = alpha_t
239
- sigma_prev = sigma_t
240
- sigma_bar_prev = sigma_bart
241
-
242
- return xt, sde.N # one model evaluation per step
243
-
244
- if sampler_type == "sde":
245
- return sde_sampler
246
- elif sampler_type == "ode":
247
- return ode_sampler
248
- else:
249
- raise ValueError("Invalid type. Choose 'ode' or 'sde'.")
 
1
+ # Adapted from https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/sampling.py
2
+ """Various sampling methods."""
3
+ from scipy import integrate
4
+ import torch
5
+
6
+ from .predictors import Predictor, PredictorRegistry, ReverseDiffusionPredictor
7
+ from .correctors import Corrector, CorrectorRegistry
8
+
9
+
10
+ __all__ = [
11
+ 'PredictorRegistry', 'CorrectorRegistry', 'Predictor', 'Corrector',
12
+ 'get_pc_sampler', 'get_ode_sampler', 'get_sb_sampler'
13
+ ]
14
+
15
+
16
+ def to_flattened_numpy(x):
17
+ """Flatten a torch tensor `x` and convert it to numpy."""
18
+ return x.detach().cpu().numpy().reshape((-1,))
19
+
20
+
21
+ def from_flattened_numpy(x, shape):
22
+ """Form a torch tensor with the given `shape` from a flattened numpy array `x`."""
23
+ return torch.from_numpy(x.reshape(shape))
24
+
25
+
26
+ def get_pc_sampler(
27
+ predictor_name, corrector_name, sde, score_fn, y,
28
+ denoise=True, eps=3e-2, snr=0.1, corrector_steps=1, probability_flow: bool = False,
29
+ intermediate=False, **kwargs
30
+ ):
31
+ """Create a Predictor-Corrector (PC) sampler.
32
+
33
+ Args:
34
+ predictor_name: The name of a registered `sampling.Predictor`.
35
+ corrector_name: The name of a registered `sampling.Corrector`.
36
+ sde: An `sdes.SDE` object representing the forward SDE.
37
+ score_fn: A function (typically learned model) that predicts the score.
38
+ y: A `torch.Tensor`, representing the (non-white-)noisy starting point(s) to condition the prior on.
39
+ denoise: If `True`, add one-step denoising to the final samples.
40
+ eps: A `float` number. The reverse-time SDE and ODE are integrated to `epsilon` to avoid numerical issues.
41
+ snr: The SNR to use for the corrector. 0.1 by default, and ignored for `NoneCorrector`.
42
+ N: The number of reverse sampling steps is taken from the `sde.N` property of the passed SDE.
43
+
44
+ Returns:
45
+ A sampling function that returns samples and the number of function evaluations during sampling.
46
+ """
47
+ predictor_cls = PredictorRegistry.get_by_name(predictor_name)
48
+ corrector_cls = CorrectorRegistry.get_by_name(corrector_name)
49
+ predictor = predictor_cls(sde, score_fn, probability_flow=probability_flow)
50
+ corrector = corrector_cls(sde, score_fn, snr=snr, n_steps=corrector_steps)
51
+
52
+ def pc_sampler():
53
+ """The PC sampler function."""
54
+ with torch.no_grad():
55
+ xt = sde.prior_sampling(y.shape, y).to(y.device)
56
+ timesteps = torch.linspace(sde.T, eps, sde.N, device=y.device)
57
+ for i in range(sde.N):
58
+ t = timesteps[i]
59
+ if i != len(timesteps) - 1:
60
+ stepsize = t - timesteps[i+1]
61
+ else:
62
+ stepsize = timesteps[-1] # from eps to 0
63
+ vec_t = torch.ones(y.shape[0], device=y.device) * t
64
+ xt, xt_mean = corrector.update_fn(xt, y, vec_t)
65
+ xt, xt_mean = predictor.update_fn(xt, y, vec_t, stepsize)
66
+ x_result = xt_mean if denoise else xt
67
+ ns = sde.N * (corrector.n_steps + 1)
68
+ return x_result, ns
69
+
70
+ return pc_sampler
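+ # Usage sketch (argument values mirror those in ScoreModel.enhance):
+ #   sampler = get_pc_sampler('reverse_diffusion', 'ald', sde=sde, score_fn=model,
+ #                            y=Y, corrector_steps=1, snr=0.5)
+ #   x, nfe = sampler()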
71
+
72
+
73
+ def get_ode_sampler(
74
+ sde, score_fn, y, inverse_scaler=None,
75
+ denoise=True, rtol=1e-5, atol=1e-5,
76
+ method='RK45', eps=3e-2, device='cuda', **kwargs
77
+ ):
78
+ """Probability flow ODE sampler with the black-box ODE solver.
79
+
80
+ Args:
81
+ sde: An `sdes.SDE` object representing the forward SDE.
82
+ score_fn: A function (typically learned model) that predicts the score.
83
+ y: A `torch.Tensor`, representing the (non-white-)noisy starting point(s) to condition the prior on.
84
+ inverse_scaler: The inverse data normalizer.
85
+ denoise: If `True`, add one-step denoising to final samples.
86
+ rtol: A `float` number. The relative tolerance level of the ODE solver.
87
+ atol: A `float` number. The absolute tolerance level of the ODE solver.
88
+ method: A `str`. The algorithm used for the black-box ODE solver.
89
+ See the documentation of `scipy.integrate.solve_ivp`.
90
+ eps: A `float` number. The reverse-time SDE/ODE will be integrated to `eps` for numerical stability.
91
+ device: PyTorch device.
92
+
93
+ Returns:
94
+ A sampling function that returns samples and the number of function evaluations during sampling.
95
+ """
96
+ predictor = ReverseDiffusionPredictor(sde, score_fn, probability_flow=False)
97
+ rsde = sde.reverse(score_fn, probability_flow=True)
98
+
99
+ def denoise_update_fn(x):
100
+ vec_eps = torch.ones(x.shape[0], device=x.device) * eps
101
+ _, x = predictor.update_fn(x, y, vec_eps)
102
+ return x
103
+
104
+ def drift_fn(x, y, t):
105
+ """Get the drift function of the reverse-time SDE."""
106
+ return rsde.sde(x, y, t)[0]
107
+
108
+ def ode_sampler(z=None, **kwargs):
109
+ """The probability flow ODE sampler with black-box ODE solver.
110
+
111
+ Args:
112
+ z: If present, generate samples from latent code `z` (currently unused:
113
+ the initial state is always drawn from the SDE prior).
114
+ Returns:
115
+ samples, number of function evaluations.
116
+ """
117
+ with torch.no_grad():
118
+ # Sample the initial state from the prior distribution of the SDE.
119
+ x = sde.prior_sampling(y.shape, y).to(device)
120
+
121
+ def ode_func(t, x):
122
+ x = from_flattened_numpy(x, y.shape).to(device).type(torch.complex64)
123
+ vec_t = torch.ones(y.shape[0], device=x.device) * t
124
+ drift = drift_fn(x, y, vec_t)
125
+ return to_flattened_numpy(drift)
126
+
127
+ # Black-box ODE solver for the probability flow ODE
128
+ solution = integrate.solve_ivp(
129
+ ode_func, (sde.T, eps), to_flattened_numpy(x),
130
+ rtol=rtol, atol=atol, method=method, **kwargs
131
+ )
132
+ nfe = solution.nfev
133
+ x = torch.tensor(solution.y[:, -1]).reshape(y.shape).to(device).type(torch.complex64)
134
+
135
+ # Denoising is equivalent to running one predictor step without adding noise
136
+ if denoise:
137
+ x = denoise_update_fn(x)
138
+
139
+ if inverse_scaler is not None:
140
+ x = inverse_scaler(x)
141
+ return x, nfe
142
+
143
+ return ode_sampler
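+ # Note: the returned nfe counts only the ODE solver's function evaluations
+ # (solution.nfev); the optional final denoising step adds one more model call.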
144
+
145
+ def get_sb_sampler(sde, model, y, eps=1e-4, n_steps=50, sampler_type="ode", **kwargs):
146
+ # adapted from https://github.com/NVIDIA/NeMo/blob/78357ae99ff2cf9f179f53fbcb02c88a5a67defb/nemo/collections/audio/parts/submodules/schroedinger_bridge.py#L382
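+ # Both samplers below integrate the bridge from t=T down to t=eps on sde.N
+ # uniform steps; each first-order update combines the previous state, the
+ # current network estimate, and either fresh Gaussian noise (SDE variant) or
+ # the prior mean y (ODE variant).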
147
+ def sde_sampler():
148
+ """The SB-SDE sampler function."""
149
+ with torch.no_grad():
150
+ xt = y[:, [0], :, :] # special case for storm_2ch
151
+ time_steps = torch.linspace(sde.T, eps, sde.N + 1, device=y.device)
152
+
153
+ # Initial values
154
+ time_prev = time_steps[0] * torch.ones(xt.shape[0], device=xt.device)
155
+ sigma_prev, sigma_T, sigma_bar_prev, alpha_prev, alpha_T, alpha_bar_prev = sde._sigmas_alphas(time_prev)
156
+
157
+ for t in time_steps[1:]:
158
+ # Prepare time steps for the whole batch
159
+ time = t * torch.ones(xt.shape[0], device=xt.device)
160
+
161
+ # Get noise schedule for current time
162
+ sigma_t, sigma_T, sigma_bart, alpha_t, alpha_T, alpha_bart = sde._sigmas_alphas(time)
163
+
164
+ # Run DNN
165
+ current_estimate = model(xt, y, time)
166
+
167
+ # Calculate scaling for the first-order discretization from the paper
168
+ weight_prev = alpha_t * sigma_t**2 / (alpha_prev * sigma_prev**2 + sde.eps)
169
+ tmp = 1 - sigma_t**2 / (sigma_prev**2 + sde.eps)
170
+ weight_estimate = alpha_t * tmp
171
+ weight_z = alpha_t * sigma_t * torch.sqrt(tmp)
172
+
173
+ # View as [B, C, D, T]
174
+ weight_prev = weight_prev[:, None, None, None]
175
+ weight_estimate = weight_estimate[:, None, None, None]
176
+ weight_z = weight_z[:, None, None, None]
177
+
178
+ # Random sample
179
+ z_norm = torch.randn_like(xt)
180
+
181
+ if t == time_steps[-1]:
182
+ weight_z = 0.0
183
+
184
+ # Update state: weighted sum of previous state, current estimate and noise
185
+ xt = weight_prev * xt + weight_estimate * current_estimate + weight_z * z_norm
186
+
187
+ # Save previous values
188
+ time_prev = time
189
+ alpha_prev = alpha_t
190
+ sigma_prev = sigma_t
191
+ sigma_bar_prev = sigma_bart
192
+
193
+ return xt, sde.N # one model evaluation per step
194
+
195
+ def ode_sampler():
196
+ """The SB-ODE sampler function."""
197
+ with torch.no_grad():
198
+ xt = y
199
+ time_steps = torch.linspace(sde.T, eps, sde.N + 1, device=y.device)
200
+
201
+ # Initial values
202
+ time_prev = time_steps[0] * torch.ones(xt.shape[0], device=xt.device)
203
+ sigma_prev, sigma_T, sigma_bar_prev, alpha_prev, alpha_T, alpha_bar_prev = sde._sigmas_alphas(time_prev)
204
+
205
+ for t in time_steps[1:]:
206
+ # Prepare time steps for the whole batch
207
+ time = t * torch.ones(xt.shape[0], device=xt.device)
208
+
209
+ # Get noise schedule for current time
210
+ sigma_t, sigma_T, sigma_bart, alpha_t, alpha_T, alpha_bart = sde._sigmas_alphas(time)
211
+
212
+ # Run DNN
213
+ current_estimate = model(xt, y, time)
214
+
215
+ # Calculate scaling for the first-order discretization from the paper
216
+ weight_prev = alpha_t * sigma_t * sigma_bart / (alpha_prev * sigma_prev * sigma_bar_prev + sde.eps)
217
+ weight_estimate = (
218
+ alpha_t
219
+ / (sigma_T**2 + sde.eps)
220
+ * (sigma_bart**2 - sigma_bar_prev * sigma_t * sigma_bart / (sigma_prev + sde.eps))
221
+ )
222
+ weight_prior_mean = (
223
+ alpha_t
224
+ / (alpha_T * sigma_T**2 + sde.eps)
225
+ * (sigma_t**2 - sigma_prev * sigma_t * sigma_bart / (sigma_bar_prev + sde.eps))
226
+ )
227
+
228
+ # View as [B, C, D, T]
229
+ weight_prev = weight_prev[:, None, None, None]
230
+ weight_estimate = weight_estimate[:, None, None, None]
231
+ weight_prior_mean = weight_prior_mean[:, None, None, None]
232
+
233
+ # Update state: weighted sum of previous state, current estimate and prior
234
+ xt = weight_prev * xt + weight_estimate * current_estimate + weight_prior_mean * y
235
+
236
+ # Save previous values
237
+ time_prev = time
238
+ alpha_prev = alpha_t
239
+ sigma_prev = sigma_t
240
+ sigma_bar_prev = sigma_bart
241
+
242
+ return xt, sde.N # one model evaluation per step
243
+
244
+ if sampler_type == "sde":
245
+ return sde_sampler
246
+ elif sampler_type == "ode":
247
+ return ode_sampler
248
+ else:
249
+ raise ValueError("Invalid type. Choose 'ode' or 'sde'.")
sgmse/sampling/correctors.py CHANGED
@@ -1,94 +1,96 @@
1
- import abc
2
- import torch
3
-
4
- from sgmse import sdes
5
- from sgmse.util.registry import Registry
6
-
7
-
8
- CorrectorRegistry = Registry("Corrector")
9
-
10
-
11
- class Corrector(abc.ABC):
12
- """The abstract class for a corrector algorithm."""
13
-
14
- def __init__(self, sde, score_fn, snr, n_steps):
15
- super().__init__()
16
- self.rsde = sde.reverse(score_fn)
17
- self.score_fn = score_fn
18
- self.snr = snr
19
- self.n_steps = n_steps
20
-
21
- @abc.abstractmethod
22
- def update_fn(self, x, y, t, *args):
23
- """One update of the corrector.
24
-
25
- Args:
26
- x: A PyTorch tensor representing the current state
27
- t: A PyTorch tensor representing the current time step.
28
- *args: Possibly additional arguments, in particular `y` for OU processes
29
-
30
- Returns:
31
- x: A PyTorch tensor of the next state.
32
- x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
33
- """
34
- pass
35
-
36
-
37
- @CorrectorRegistry.register(name='langevin')
38
- class LangevinCorrector(Corrector):
39
- def __init__(self, sde, score_fn, snr, n_steps):
40
- super().__init__(sde, score_fn, snr, n_steps)
41
- self.score_fn = score_fn
42
- self.n_steps = n_steps
43
- self.snr = snr
44
-
45
- def update_fn(self, x, y, t, *args):
46
- target_snr = self.snr
47
- for _ in range(self.n_steps):
48
- grad = self.score_fn(x, y, t, *args)
49
- noise = torch.randn_like(x)
50
- grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
51
- noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
52
- step_size = ((target_snr * noise_norm / grad_norm) ** 2 * 2).unsqueeze(0)
53
- x_mean = x + step_size[:, None, None, None] * grad
54
- x = x_mean + noise * torch.sqrt(step_size * 2)[:, None, None, None]
55
-
56
- return x, x_mean
57
-
58
-
59
- @CorrectorRegistry.register(name='ald')
60
- class AnnealedLangevinDynamics(Corrector):
61
- """The original annealed Langevin dynamics predictor in NCSN/NCSNv2."""
62
- def __init__(self, sde, score_fn, snr, n_steps):
63
- super().__init__(sde, score_fn, snr, n_steps)
64
- self.sde = sde
65
- self.score_fn = score_fn
66
- self.snr = snr
67
- self.n_steps = n_steps
68
-
69
- def update_fn(self, x, y, t, *args):
70
- n_steps = self.n_steps
71
- target_snr = self.snr
72
- std = self.sde.marginal_prob(x, y, t, *args)[1]
73
-
74
- for _ in range(n_steps):
75
- grad = self.score_fn(x, y, t, *args)
76
- noise = torch.randn_like(x)
77
- step_size = (target_snr * std) ** 2 * 2
78
- x_mean = x + step_size[:, None, None, None] * grad
79
- x = x_mean + noise * torch.sqrt(step_size * 2)[:, None, None, None]
80
-
81
- return x, x_mean
82
-
83
-
84
- @CorrectorRegistry.register(name='none')
85
- class NoneCorrector(Corrector):
86
- """An empty corrector that does nothing."""
87
-
88
- def __init__(self, *args, **kwargs):
89
- self.snr = 0
90
- self.n_steps = 0
92
-
93
- def update_fn(self, x, t, *args):
94
- return x, x
 
 
 
1
+ import abc
2
+ import torch
3
+
4
+ from sgmse import sdes
5
+ from sgmse.util.registry import Registry
6
+
7
+
8
+ CorrectorRegistry = Registry("Corrector")
9
+
10
+
11
+ class Corrector(abc.ABC):
12
+ """The abstract class for a corrector algorithm."""
13
+
14
+ def __init__(self, sde, score_fn, snr, n_steps):
15
+ super().__init__()
16
+ self.rsde = sde.reverse(score_fn)
17
+ self.score_fn = score_fn
18
+ self.snr = snr
19
+ self.n_steps = n_steps
20
+
21
+ @abc.abstractmethod
22
+ def update_fn(self, x, t, *args):
23
+ """One update of the corrector.
24
+
25
+ Args:
26
+ x: A PyTorch tensor representing the current state
27
+ t: A PyTorch tensor representing the current time step.
28
+ *args: Possibly additional arguments, in particular `y` for OU processes
29
+
30
+ Returns:
31
+ x: A PyTorch tensor of the next state.
32
+ x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
33
+ """
34
+ pass
35
+
36
+
37
+ @CorrectorRegistry.register(name='langevin')
38
+ class LangevinCorrector(Corrector):
39
+ def __init__(self, sde, score_fn, snr, n_steps):
40
+ super().__init__(sde, score_fn, snr, n_steps)
41
+ self.score_fn = score_fn
42
+ self.n_steps = n_steps
43
+ self.snr = snr
44
+
45
+ def update_fn(self, x, t, *args):
46
+ target_snr = self.snr
47
+ for _ in range(self.n_steps):
48
+ grad = self.score_fn(x, t, *args)
49
+ noise = torch.randn_like(x)
50
+ grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
51
+ noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
52
+ step_size = ((target_snr * noise_norm / grad_norm) ** 2 * 2).unsqueeze(0)
53
+ x_mean = x + step_size[:, None, None, None] * grad
54
+ x = x_mean + noise * torch.sqrt(step_size * 2)[:, None, None, None]
55
+
56
+ return x, x_mean
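+ # The step size above follows the SNR heuristic from the score_sde codebase:
+ # with r = snr, step_size = 2 * (r * ||noise|| / ||grad||)^2, which keeps the
+ # injected noise at the requested signal-to-noise ratio relative to the
+ # score-driven drift.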
57
+
58
+
59
+ @CorrectorRegistry.register(name='ald')
60
+ class AnnealedLangevinDynamics(Corrector):
61
+ """The original annealed Langevin dynamics predictor in NCSN/NCSNv2."""
62
+ def __init__(self, sde, score_fn, snr, n_steps):
63
+ super().__init__(sde, score_fn, snr, n_steps)
64
+ if not isinstance(sde, (sdes.OUVESDE,)):
65
+ raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")
66
+ self.sde = sde
67
+ self.score_fn = score_fn
68
+ self.snr = snr
69
+ self.n_steps = n_steps
70
+
71
+ def update_fn(self, x, t, *args):
72
+ n_steps = self.n_steps
73
+ target_snr = self.snr
74
+ std = self.sde.marginal_prob(x, t, *args)[1]
75
+
76
+ for _ in range(n_steps):
77
+ grad = self.score_fn(x, t, *args)
78
+ noise = torch.randn_like(x)
79
+ step_size = (target_snr * std) ** 2 * 2
80
+ x_mean = x + step_size[:, None, None, None] * grad
81
+ x = x_mean + noise * torch.sqrt(step_size * 2)[:, None, None, None]
82
+
83
+ return x, x_mean
84
+
85
+
86
+ @CorrectorRegistry.register(name='none')
87
+ class NoneCorrector(Corrector):
88
+ """An empty corrector that does nothing."""
89
+
90
+ def __init__(self, *args, **kwargs):
91
+ self.snr = 0
92
+ self.n_steps = 0
94
+
95
+ def update_fn(self, x, t, *args):
96
+ return x, x
sgmse/sampling/predictors.py CHANGED
@@ -1,76 +1,76 @@
1
- import abc
2
-
3
- import torch
4
- import numpy as np
5
-
6
- from sgmse.util.registry import Registry
7
-
8
-
9
- PredictorRegistry = Registry("Predictor")
10
-
11
-
12
- class Predictor(abc.ABC):
13
- """The abstract class for a predictor algorithm."""
14
-
15
- def __init__(self, sde, score_fn, probability_flow=False):
16
- super().__init__()
17
- self.sde = sde
18
- self.rsde = sde.reverse(score_fn)
19
- self.score_fn = score_fn
20
- self.probability_flow = probability_flow
21
-
22
- @abc.abstractmethod
23
- def update_fn(self, x, t, *args):
24
- """One update of the predictor.
25
-
26
- Args:
27
- x: A PyTorch tensor representing the current state
28
- t: A PyTorch tensor representing the current time step.
29
- *args: Possibly additional arguments, in particular `y` for OU processes
30
-
31
- Returns:
32
- x: A PyTorch tensor of the next state.
33
- x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
34
- """
35
- pass
36
-
37
- def debug_update_fn(self, x, t, *args):
38
- raise NotImplementedError(f"Debug update function not implemented for predictor {self}.")
39
-
40
-
41
- @PredictorRegistry.register('euler_maruyama')
42
- class EulerMaruyamaPredictor(Predictor):
43
- def __init__(self, sde, score_fn, probability_flow=False):
44
- super().__init__(sde, score_fn, probability_flow=probability_flow)
45
-
46
- def update_fn(self, x, y, t, *args):
47
- dt = -1. / self.rsde.N
48
- z = torch.randn_like(x)
49
- f, g = self.rsde.sde(x, y, t, *args)
50
- x_mean = x + f * dt
51
- x = x_mean + g[:, None, None, None] * np.sqrt(-dt) * z
52
- return x, x_mean
53
-
54
-
55
- @PredictorRegistry.register('reverse_diffusion')
56
- class ReverseDiffusionPredictor(Predictor):
57
- def __init__(self, sde, score_fn, probability_flow=False):
58
- super().__init__(sde, score_fn, probability_flow=probability_flow)
59
-
60
- def update_fn(self, x, y, t, stepsize):
61
- f, g = self.rsde.discretize(x, y, t, stepsize)
62
- z = torch.randn_like(x)
63
- x_mean = x - f
64
- x = x_mean + g[:, None, None, None] * z
65
- return x, x_mean
66
-
67
-
68
- @PredictorRegistry.register('none')
69
- class NonePredictor(Predictor):
70
- """An empty predictor that does nothing."""
71
-
72
- def __init__(self, *args, **kwargs):
73
- pass
74
-
75
- def update_fn(self, x, y, t, *args):
76
- return x, x
 
1
+ import abc
2
+
3
+ import torch
4
+ import numpy as np
5
+
6
+ from sgmse.util.registry import Registry
7
+
8
+
9
+ PredictorRegistry = Registry("Predictor")
10
+
11
+
12
+ class Predictor(abc.ABC):
13
+ """The abstract class for a predictor algorithm."""
14
+
15
+ def __init__(self, sde, score_fn, probability_flow=False):
16
+ super().__init__()
17
+ self.sde = sde
18
+ self.rsde = sde.reverse(score_fn)
19
+ self.score_fn = score_fn
20
+ self.probability_flow = probability_flow
21
+
22
+ @abc.abstractmethod
23
+ def update_fn(self, x, t, *args):
24
+ """One update of the predictor.
25
+
26
+ Args:
27
+ x: A PyTorch tensor representing the current state
28
+ t: A PyTorch tensor representing the current time step.
29
+ *args: Possibly additional arguments, in particular `y` for OU processes
30
+
31
+ Returns:
32
+ x: A PyTorch tensor of the next state.
33
+ x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
34
+ """
35
+ pass
36
+
37
+ def debug_update_fn(self, x, t, *args):
38
+ raise NotImplementedError(f"Debug update function not implemented for predictor {self}.")
39
+
40
+
41
+ @PredictorRegistry.register('euler_maruyama')
42
+ class EulerMaruyamaPredictor(Predictor):
43
+ def __init__(self, sde, score_fn, probability_flow=False):
44
+ super().__init__(sde, score_fn, probability_flow=probability_flow)
45
+
46
+ def update_fn(self, x, t, *args):
47
+ dt = -1. / self.rsde.N
48
+ z = torch.randn_like(x)
49
+ f, g = self.rsde.sde(x, t, *args)
50
+ x_mean = x + f * dt
51
+ x = x_mean + g[:, None, None, None] * np.sqrt(-dt) * z
52
+ return x, x_mean
53
+
54
+
55
+ @PredictorRegistry.register('reverse_diffusion')
56
+ class ReverseDiffusionPredictor(Predictor):
57
+ def __init__(self, sde, score_fn, probability_flow=False):
58
+ super().__init__(sde, score_fn, probability_flow=probability_flow)
59
+
60
+ def update_fn(self, x, t, y, stepsize):
61
+ f, g = self.rsde.discretize(x, t, y, stepsize)
62
+ z = torch.randn_like(x)
63
+ x_mean = x - f
64
+ x = x_mean + g[:, None, None, None] * z
65
+ return x, x_mean
66
+
67
+
68
+ @PredictorRegistry.register('none')
69
+ class NonePredictor(Predictor):
70
+ """An empty predictor that does nothing."""
71
+
72
+ def __init__(self, *args, **kwargs):
73
+ pass
74
+
75
+ def update_fn(self, x, t, *args):
76
+ return x, x
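
For orientation, here is a minimal sketch of how these predictors are driven, assuming the `(x, t, y)` argument order established above and a zero placeholder in place of a trained score model; the tensor shapes are invented for illustration:

import torch
from sgmse.sdes import OUVESDE
from sgmse.sampling.predictors import PredictorRegistry

sde = OUVESDE(theta=1.5, sigma_min=0.05, sigma_max=0.5, N=30)
y = torch.randn(2, 1, 64, 64)                    # stands in for the noisy conditioning signal
x = sde.prior_sampling(y.shape, y)               # start the reverse process at p_T(x|y)

score_fn = lambda x, t, y: torch.zeros_like(x)   # placeholder score, not a trained model
predictor = PredictorRegistry.get_by_name('euler_maruyama')(sde, score_fn)

t = torch.ones(y.shape[0]) * sde.T               # all batch elements start at t = T
x, x_mean = predictor.update_fn(x, t, y)         # one reverse Euler-Maruyama step
print(x.shape, x_mean.shape)                     # torch.Size([2, 1, 64, 64]) twice

Looping this update from t = T down to 0 over N steps, with the trained score model in place of the placeholder, is essentially what the PC samplers do.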
sgmse/sdes.py CHANGED
@@ -1,313 +1,392 @@
"""
Abstract SDE classes, Reverse SDE, and VE/VP SDEs.

Taken and adapted from https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/sde_lib.py
"""
import abc
import warnings

import numpy as np
from sgmse.util.tensors import batch_broadcast
import torch

from sgmse.util.registry import Registry


SDERegistry = Registry("SDE")


class SDE(abc.ABC):
    """SDE abstract class. Functions are designed for a mini-batch of inputs."""

    def __init__(self, N):
        """Construct an SDE.

        Args:
            N: number of discretization time steps.
        """
        super().__init__()
        self.N = N

    @property
    @abc.abstractmethod
    def T(self):
        """End time of the SDE."""
        pass

    @abc.abstractmethod
-    def sde(self, x, y, t, *args):
+    def sde(self, x, t, *args):
        pass

    @abc.abstractmethod
-    def marginal_prob(self, x, y, t, *args):
+    def marginal_prob(self, x, t, *args):
        """Parameters to determine the marginal distribution of the SDE, $p_t(x|args)$."""
        pass

    @abc.abstractmethod
    def prior_sampling(self, shape, *args):
        """Generate one sample from the prior distribution, $p_T(x|args)$ with shape `shape`."""
        pass

    @abc.abstractmethod
    def prior_logp(self, z):
        """Compute log-density of the prior distribution.

        Useful for computing the log-likelihood via probability flow ODE.

        Args:
            z: latent code
        Returns:
            log probability density
        """
        pass

    @staticmethod
    @abc.abstractmethod
    def add_argparse_args(parent_parser):
        """
        Add the necessary arguments for instantiation of this SDE class to an argparse ArgumentParser.
        """
        pass

-    def discretize(self, x, y, t, stepsize):
+    def discretize(self, x, t, y, stepsize):
        """Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.

        Useful for reverse diffusion sampling and probability flow sampling.
        Defaults to Euler-Maruyama discretization.

        Args:
            x: a torch tensor
            t: a torch float representing the time step (from 0 to `self.T`)
            y: a torch tensor representing the conditioning signal
            stepsize: the discretization step size

        Returns:
            f, G
        """
        dt = stepsize
-        drift, diffusion = self.sde(x, y, t)
+        drift, diffusion = self.sde(x, t, y)
        f = drift * dt
        G = diffusion * torch.sqrt(dt)
        return f, G

    def reverse(oself, score_model, probability_flow=False):
        """Create the reverse-time SDE/ODE.

        Args:
            score_model: A function that takes x, t and y and returns the score.
            probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling.
        """
        N = oself.N
        T = oself.T
        sde_fn = oself.sde
        discretize_fn = oself.discretize

        # Build the class for reverse-time SDE.
        class RSDE(oself.__class__):
            def __init__(self):
                self.N = N
                self.probability_flow = probability_flow

            @property
            def T(self):
                return T

-            def sde(self, x, y, t, *args):
+            def sde(self, x, t, *args):
                """Create the drift and diffusion functions for the reverse SDE/ODE."""
-                rsde_parts = self.rsde_parts(x, y, t, *args)
+                rsde_parts = self.rsde_parts(x, t, *args)
                total_drift, diffusion = rsde_parts["total_drift"], rsde_parts["diffusion"]
                return total_drift, diffusion

-            def rsde_parts(self, x, y, t, *args):
-                sde_drift, sde_diffusion = sde_fn(x, y, t, *args)
-                score = score_model(x, y, t, *args)
+            def rsde_parts(self, x, t, *args):
+                sde_drift, sde_diffusion = sde_fn(x, t, *args)
+                score = score_model(x, t, *args)
                score_drift = -sde_diffusion[:, None, None, None]**2 * score * (0.5 if self.probability_flow else 1.)
                diffusion = torch.zeros_like(sde_diffusion) if self.probability_flow else sde_diffusion
                total_drift = sde_drift + score_drift
                return {
                    'total_drift': total_drift, 'diffusion': diffusion, 'sde_drift': sde_drift,
                    'sde_diffusion': sde_diffusion, 'score_drift': score_drift, 'score': score,
                }

-            def discretize(self, x, y, t, stepsize):
+            def discretize(self, x, t, y, stepsize):
                """Create discretized iteration rules for the reverse diffusion sampler."""
-                f, G = discretize_fn(x, y, t, stepsize)
-                rev_f = f - G[:, None, None, None] ** 2 * score_model(x, y, t) * (0.5 if self.probability_flow else 1.)
+                f, G = discretize_fn(x, t, y, stepsize)
+                rev_f = f - G[:, None, None, None] ** 2 * score_model(x, t, y) * (0.5 if self.probability_flow else 1.)
                rev_G = torch.zeros_like(G) if self.probability_flow else G
                return rev_f, rev_G

        return RSDE()

    @abc.abstractmethod
    def copy(self):
        pass


@SDERegistry.register("ouve")
class OUVESDE(SDE):
    @staticmethod
    def add_argparse_args(parser):
        parser.add_argument("--theta", type=float, default=1.5, help="The constant stiffness of the Ornstein-Uhlenbeck process. 1.5 by default.")
        parser.add_argument("--sigma-min", type=float, default=0.05, help="The minimum sigma to use. 0.05 by default.")
        parser.add_argument("--sigma-max", type=float, default=0.5, help="The maximum sigma to use. 0.5 by default.")
        parser.add_argument("--N", type=int, default=30, help="The number of timesteps in the SDE discretization. 30 by default.")
        parser.add_argument("--sampler_type", type=str, default="pc", help="Type of sampler to use. 'pc' by default.")
        return parser

    def __init__(self, theta, sigma_min, sigma_max, N=30, sampler_type="pc", **ignored_kwargs):
        """Construct an Ornstein-Uhlenbeck Variance Exploding SDE.

        Note that the "steady-state mean" `y` is not provided at construction, but must rather be given as an argument
        to the methods which require it (e.g., `sde` or `marginal_prob`).

        dx = theta (y-x) dt + sigma(t) dw

        with

        sigma(t) = sigma_min (sigma_max/sigma_min)^t * sqrt(2 log(sigma_max/sigma_min))

        Args:
            theta: stiffness parameter.
            sigma_min: smallest sigma.
            sigma_max: largest sigma.
            N: number of discretization steps.
        """
        super().__init__(N)
        self.theta = theta
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.logsig = np.log(self.sigma_max / self.sigma_min)
        self.N = N
        self.sampler_type = sampler_type

    def copy(self):
        return OUVESDE(self.theta, self.sigma_min, self.sigma_max, N=self.N, sampler_type=self.sampler_type)

    @property
    def T(self):
        return 1

-    def sde(self, x, y, t):
+    def sde(self, x, t, y):  # argument order matches the base class's `self.sde(x, t, y)` call
        drift = self.theta * (y - x)
        # the sqrt(2*logsig) factor is required here so that logsig does not in the end affect the perturbation kernel
        # standard deviation. this can be understood from solving the integral of [exp(2s) * g(s)^2] from s=0 to t
        # with g(t) = sigma(t) as defined here, and seeing that `logsig` remains in the integral solution
        # unless this sqrt(2*logsig) factor is included.
        sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
        diffusion = sigma * np.sqrt(2 * self.logsig)
        return drift, diffusion

    def _mean(self, x0, y, t):
        theta = self.theta
        exp_interp = torch.exp(-theta * t)[:, None, None, None]
        return exp_interp * x0 + (1 - exp_interp) * y

    def alpha(self, t):
        return torch.exp(-self.theta * t)

    def _std(self, t):
        # This is a full solution to the ODE for P(t) in our derivations, after choosing g(s) as in self.sde()
        sigma_min, theta, logsig = self.sigma_min, self.theta, self.logsig
        # could maybe replace the two torch.exp(... * t) terms here by cached values **t
        return torch.sqrt(
            (
                sigma_min**2
                * torch.exp(-2 * theta * t)
                * (torch.exp(2 * (theta + logsig) * t) - 1)
                * logsig
            )
            /
            (theta + logsig)
        )

    def marginal_prob(self, x0, y, t):
        return self._mean(x0, y, t), self._std(t)

    def prior_sampling(self, shape, y):
        if shape != y.shape:
            warnings.warn(f"Target shape {shape} does not match shape of y {y.shape}! Ignoring target shape.")
        std = self._std(torch.ones((y.shape[0],), device=y.device))
        x_T = y + torch.randn_like(y) * std[:, None, None, None]
        return x_T

    def prior_logp(self, z):
        raise NotImplementedError("prior_logp for OU SDE not yet implemented!")


+@SDERegistry.register("ouvp")
+class OUVPSDE(SDE):
+    # !!! We do not utilize this SDE in our works due to observed instabilities around t=0.2. !!!
+    @staticmethod
+    def add_argparse_args(parser):
+        parser.add_argument("--sde-n", type=int, default=1000,
+                            help="The number of timesteps in the SDE discretization. 1000 by default.")
+        parser.add_argument("--beta-min", type=float, required=True,
+                            help="The minimum beta to use.")
+        parser.add_argument("--beta-max", type=float, required=True,
+                            help="The maximum beta to use.")
+        parser.add_argument("--stiffness", type=float, default=1,
+                            help="The stiffness factor for the drift, to be multiplied by 0.5*beta(t). 1 by default.")
+        return parser
+
+    def __init__(self, beta_min, beta_max, stiffness=1, N=1000, **ignored_kwargs):
+        """
+        !!! We do not utilize this SDE in our works due to observed instabilities around t=0.2. !!!
+
+        Construct an Ornstein-Uhlenbeck Variance Preserving SDE:
+
+        dx = 1/2 * beta(t) * stiffness * (y-x) dt + sqrt(beta(t)) * dw
+
+        with
+
+        beta(t) = beta_min + t(beta_max - beta_min)
+
+        Note that the "steady-state mean" `y` is not provided at construction, but must rather be given as an argument
+        to the methods which require it (e.g., `sde` or `marginal_prob`).
+
+        Args:
+            beta_min: smallest beta.
+            beta_max: largest beta.
+            stiffness: stiffness factor of the drift. 1 by default.
+            N: number of discretization steps.
+        """
+        super().__init__(N)
+        self.beta_min = beta_min
+        self.beta_max = beta_max
+        self.stiffness = stiffness
+        self.N = N
+
+    def copy(self):
+        return OUVPSDE(self.beta_min, self.beta_max, self.stiffness, N=self.N)
+
+    @property
+    def T(self):
+        return 1
+
+    def _beta(self, t):
+        return self.beta_min + t * (self.beta_max - self.beta_min)
+
+    def sde(self, x, t, y):
+        drift = 0.5 * self.stiffness * batch_broadcast(self._beta(t), y) * (y - x)
+        diffusion = torch.sqrt(self._beta(t))
+        return drift, diffusion
+
+    def _mean(self, x0, t, y):
+        b0, b1, s = self.beta_min, self.beta_max, self.stiffness
+        x0y_fac = torch.exp(-0.25 * s * t * (t * (b1-b0) + 2 * b0))[:, None, None, None]
+        return y + x0y_fac * (x0 - y)
+
+    def _std(self, t):
+        b0, b1, s = self.beta_min, self.beta_max, self.stiffness
+        return (1 - torch.exp(-0.5 * s * t * (t * (b1-b0) + 2 * b0))) / s
+
+    def marginal_prob(self, x0, t, y):
+        return self._mean(x0, t, y), self._std(t)
+
+    def prior_sampling(self, shape, y):
+        if shape != y.shape:
+            warnings.warn(f"Target shape {shape} does not match shape of y {y.shape}! Ignoring target shape.")
+        std = self._std(torch.ones((y.shape[0],), device=y.device))
+        x_T = y + torch.randn_like(y) * std[:, None, None, None]
+        return x_T
+
+    def prior_logp(self, z):
+        raise NotImplementedError("prior_logp for OU SDE not yet implemented!")
+
+
@SDERegistry.register("sbve")
class SBVESDE(SDE):
    @staticmethod
    def add_argparse_args(parser):
        parser.add_argument("--N", type=int, default=50, help="The number of timesteps in the SDE discretization. 50 by default.")
        parser.add_argument("--k", type=float, default=2.6, help="Parameter of the diffusion coefficient. 2.6 by default.")
        parser.add_argument("--c", type=float, default=0.4, help="Parameter of the diffusion coefficient. 0.4 by default.")
        parser.add_argument("--eps", type=float, default=1e-8, help="Small constant to avoid numerical instability. 1e-8 by default.")
        parser.add_argument("--sampler_type", type=str, default="ode")
        return parser

    def __init__(self, k, c, N=50, eps=1e-8, sampler_type="ode", **ignored_kwargs):
        """Construct a Schrödinger Bridge Variance Exploding SDE.

        As described in Jukić et al., "Schrödinger Bridge for Generative Speech Enhancement", 2024.

        Args:
            k: stiffness parameter.
            c: diffusion parameter.
            N: number of discretization steps.
        """
        super().__init__(N)
        self.k = k
        self.c = c
        self.N = N
        self.eps = eps
        self.sampler_type = sampler_type

    def copy(self):
        return SBVESDE(self.k, self.c, N=self.N, eps=self.eps, sampler_type=self.sampler_type)

    @property
    def T(self):
        return 1

-    def sde(self, x, y, t):
+    def sde(self, x, t, y):  # argument order matches the base class's `self.sde(x, t, y)` call
        f = 0.0  # Table 1
        g = torch.sqrt(torch.tensor(self.c)) * self.k**(t)  # Table 1
        return f, g

    def _sigmas_alphas(self, t):
        alpha_t = torch.ones_like(t)
        alpha_T = torch.ones_like(t)
        sigma_t = torch.sqrt((self.c*(self.k**(2*t)-1.0))
                             / (2*torch.log(torch.tensor(self.k))))  # Table 1
        sigma_T = torch.sqrt((self.c*(self.k**(2*self.T)-1.0))
                             / (2*torch.log(torch.tensor(self.k))))  # Table 1

        alpha_bart = alpha_t / (alpha_T + self.eps)  # below Eq. (9)
        sigma_bart = torch.sqrt(sigma_T**2 - sigma_t**2 + self.eps)  # below Eq. (9)

        return sigma_t, sigma_T, sigma_bart, alpha_t, alpha_T, alpha_bart

    def _mean(self, x0, y, t):
        sigma_t, sigma_T, sigma_bart, alpha_t, alpha_T, alpha_bart = self._sigmas_alphas(t)

        w_xt = alpha_t * sigma_bart**2 / (sigma_T**2 + self.eps)  # below Eq. (11)
        w_yt = alpha_bart * sigma_t**2 / (sigma_T**2 + self.eps)  # below Eq. (11)

        mu = w_xt[:, None, None, None] * x0 + w_yt[:, None, None, None] * y  # Eq. (11)
        return mu

    def _std(self, t):
        sigma_t, sigma_T, sigma_bart, alpha_t, alpha_T, alpha_bart = self._sigmas_alphas(t)

        sigma_xt = (alpha_t * sigma_bart * sigma_t) / (sigma_T + self.eps)
        return sigma_xt

    def marginal_prob(self, x0, y, t):
        return self._mean(x0, y, t), self._std(t)

    def prior_sampling(self, shape, y):
        if shape != y.shape:
            warnings.warn(f"Target shape {shape} does not match shape of y {y.shape}! Ignoring target shape.")
        x_T = y
        return x_T

    def prior_logp(self, z):
        raise NotImplementedError("prior_logp for SBVE SDE not yet implemented!")
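
As a quick check of the OUVE perturbation kernel: the closed-form `_std` above simplifies to std(t)^2 = sigma_min^2 * logsig * ((sigma_max/sigma_min)^(2t) - exp(-2*theta*t)) / (theta + logsig), so the variance grows from 0 at t=0 towards its maximum at t=T. A minimal sketch of `marginal_prob` and `prior_sampling`, assuming the repo is importable and with made-up tensor shapes:

import torch
from sgmse.sdes import SDERegistry

sde = SDERegistry.get_by_name("ouve")(theta=1.5, sigma_min=0.05, sigma_max=0.5, N=30)
x0 = torch.randn(2, 1, 64, 64)   # stands in for clean speech spectrograms
y = torch.randn(2, 1, 64, 64)    # stands in for noisy speech spectrograms

for t_val in (0.05, 0.5, 1.0):
    t = torch.full((2,), t_val)
    mean, std = sde.marginal_prob(x0, y, t)
    # the mean interpolates from x0 (t near 0) towards y, while std grows with t
    print(f"t={t_val}: dist to y = {(mean - y).norm():.1f}, std = {std[0]:.3f}")

x_T = sde.prior_sampling(y.shape, y)   # one sample from p_T(x|y)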
sgmse/util/inference.py CHANGED
@@ -1,64 +1,64 @@
import torch
from torchaudio import load

from pesq import pesq
from pystoi import stoi

from .other import si_sdr, pad_spec

# Settings
sr = 16000
snr = 0.5
N = 30
corrector_steps = 1


def evaluate_model(model, num_eval_files):
    clean_files = model.data_module.valid_set.clean_files
    noisy_files = model.data_module.valid_set.noisy_files

    # Select test files uniformly across validation files
    total_num_files = len(clean_files)
    indices = torch.linspace(0, total_num_files-1, num_eval_files, dtype=torch.int)
    clean_files = list(clean_files[i] for i in indices)
    noisy_files = list(noisy_files[i] for i in indices)

    _pesq = 0
    _si_sdr = 0
    _estoi = 0
    # Iterate over files
    for (clean_file, noisy_file) in zip(clean_files, noisy_files):
        # Load wavs
        x, _ = load(clean_file)
        y, _ = load(noisy_file)
        T_orig = x.size(1)

        # Normalize per utterance
        norm_factor = y.abs().max()
        y = y / norm_factor

        # Prepare DNN input
        Y = torch.unsqueeze(model._forward_transform(model._stft(y.cuda())), 0)
        Y = pad_spec(Y)
        y = y * norm_factor  # undo normalization before computing metrics

        # Reverse sampling
        sampler = model.get_pc_sampler(
            'reverse_diffusion', 'ald', Y.cuda(), N=N,
            corrector_steps=corrector_steps, snr=snr)
        sample, _ = sampler()

        x_hat = model.to_audio(sample.squeeze(), T_orig)
        x_hat = x_hat * norm_factor

        x_hat = x_hat.squeeze().cpu().numpy()
        x = x.squeeze().cpu().numpy()
        y = y.squeeze().cpu().numpy()

        _si_sdr += si_sdr(x, x_hat)
        _pesq += pesq(sr, x, x_hat, 'wb')
        _estoi += stoi(x, x_hat, sr, extended=True)

    return _pesq/num_eval_files, _si_sdr/num_eval_files, _estoi/num_eval_files
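
The uniform file selection above is easy to sanity-check in isolation; a small sketch with a hypothetical file list:

import torch

# Hypothetical stand-in for the validation file list
clean_files = [f"clean_{i:03d}.wav" for i in range(100)]

num_eval_files = 5
indices = torch.linspace(0, len(clean_files) - 1, num_eval_files, dtype=torch.int)
print(indices.tolist())                   # [0, 24, 49, 74, 99]
print([clean_files[i] for i in indices])  # evenly spaced across the whole set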
sgmse/util/other.py CHANGED
@@ -1,141 +1,141 @@
import os
import torch
import numpy as np
import scipy.stats
from scipy.signal import butter, sosfilt

from pesq import pesq
from pystoi import stoi


def si_sdr_components(s_hat, s, n):
    # s_target
    alpha_s = np.dot(s_hat, s) / np.linalg.norm(s)**2
    s_target = alpha_s * s

    # e_noise
    alpha_n = np.dot(s_hat, n) / np.linalg.norm(n)**2
    e_noise = alpha_n * n

    # e_art
    e_art = s_hat - s_target - e_noise

    return s_target, e_noise, e_art

def energy_ratios(s_hat, s, n):
    s_target, e_noise, e_art = si_sdr_components(s_hat, s, n)

    si_sdr = 10*np.log10(np.linalg.norm(s_target)**2 / np.linalg.norm(e_noise + e_art)**2)
    si_sir = 10*np.log10(np.linalg.norm(s_target)**2 / np.linalg.norm(e_noise)**2)
    si_sar = 10*np.log10(np.linalg.norm(s_target)**2 / np.linalg.norm(e_art)**2)

    return si_sdr, si_sir, si_sar

def mean_conf_int(data, confidence=0.95):
    a = 1.0 * np.array(data)
    n = len(a)
    m, se = np.mean(a), scipy.stats.sem(a)
    h = se * scipy.stats.t.ppf((1 + confidence) / 2., n-1)
    return m, h

class Method():
    def __init__(self, name, base_dir, metrics):
        self.name = name
        self.base_dir = base_dir
        self.metrics = {metric: [] for metric in metrics}

    def append(self, metric, value):
        self.metrics[metric].append(value)

    def get_mean_ci(self, metric):
        return mean_conf_int(np.array(self.metrics[metric]))

def hp_filter(signal, cut_off=80, order=10, sr=16000):
    factor = cut_off / sr * 2
    sos = butter(order, factor, 'hp', output='sos')
    filtered = sosfilt(sos, signal)
    return filtered

def si_sdr(s, s_hat):
    alpha = np.dot(s_hat, s) / np.linalg.norm(s)**2
    sdr = 10*np.log10(np.linalg.norm(alpha*s)**2 / np.linalg.norm(alpha*s - s_hat)**2)
    return sdr

def snr_dB(s, n):
    s_power = 1/len(s)*np.sum(s**2)
    n_power = 1/len(n)*np.sum(n**2)
    snr_dB = 10*np.log10(s_power/n_power)
    return snr_dB

def pad_spec(Y, mode="zero_pad"):
    # Pad the last (time) dimension up to the next multiple of 64
    T = Y.size(3)
    if T % 64 != 0:
        num_pad = 64 - T % 64
    else:
        num_pad = 0
    if mode == "zero_pad":
        pad2d = torch.nn.ZeroPad2d((0, num_pad, 0, 0))
    elif mode == "reflection":
        pad2d = torch.nn.ReflectionPad2d((0, num_pad, 0, 0))
    elif mode == "replication":
        pad2d = torch.nn.ReplicationPad2d((0, num_pad, 0, 0))
    else:
        raise NotImplementedError(f"Unknown padding mode: {mode}")
    return pad2d(Y)

def ensure_dir(file_path):
    directory = file_path
    if not os.path.exists(directory):
        os.makedirs(directory)


def print_metrics(x, y, x_hat_list, labels, sr=16000):
    _si_sdr_mix = si_sdr(x, y)
    _pesq_mix = pesq(sr, x, y, 'wb')
    _estoi_mix = stoi(x, y, sr, extended=True)
    print(f'Mixture: PESQ: {_pesq_mix:.2f}, ESTOI: {_estoi_mix:.2f}, SI-SDR: {_si_sdr_mix:.2f}')
    for i, x_hat in enumerate(x_hat_list):
        _si_sdr = si_sdr(x, x_hat)
        _pesq = pesq(sr, x, x_hat, 'wb')
        _estoi = stoi(x, x_hat, sr, extended=True)
        print(f'{labels[i]}: PESQ: {_pesq:.2f}, ESTOI: {_estoi:.2f}, SI-SDR: {_si_sdr:.2f}')

def mean_std(data):
    data = data[~np.isnan(data)]
    mean = np.mean(data)
    std = np.std(data)
    return mean, std

def print_mean_std(data, decimal=2):
    data = np.array(data)
    data = data[~np.isnan(data)]
    mean = np.mean(data)
    std = np.std(data)
    if decimal == 2:
        string = f'{mean:.2f} ± {std:.2f}'
    elif decimal == 1:
        string = f'{mean:.1f} ± {std:.1f}'
    else:
        raise ValueError("Only decimal=1 or decimal=2 are supported.")
    return string

def set_torch_cuda_arch_list():
    if not torch.cuda.is_available():
        print("CUDA is not available. No GPUs found.")
        return

    num_gpus = torch.cuda.device_count()
    compute_capabilities = []

    for i in range(num_gpus):
        cc_major, cc_minor = torch.cuda.get_device_capability(i)
        cc = f"{cc_major}.{cc_minor}"
        compute_capabilities.append(cc)

    cc_string = ";".join(compute_capabilities)
    os.environ['TORCH_CUDA_ARCH_LIST'] = cc_string
    print(f"Set TORCH_CUDA_ARCH_LIST to: {cc_string}")
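
A quick sanity check of `si_sdr` and `pad_spec` on synthetic data; the signal length and spectrogram shape are invented for illustration:

import numpy as np
import torch
from sgmse.util.other import si_sdr, pad_spec

# SI-SDR of an estimate with 10% additive noise should land near 20 dB
s = np.random.randn(16000)
s_hat = s + 0.1 * np.random.randn(16000)
print(f"SI-SDR: {si_sdr(s, s_hat):.2f} dB")

# pad_spec pads the last (time) dimension to the next multiple of 64
Y = torch.randn(1, 1, 256, 100)
print(pad_spec(Y).shape)   # torch.Size([1, 1, 256, 128])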
sgmse/util/registry.py CHANGED
@@ -1,34 +1,34 @@
import warnings
from typing import Callable


class Registry:
    def __init__(self, managed_thing: str):
        """
        Create a new registry.

        Args:
            managed_thing: A string describing what type of thing is managed by this registry. Will be used for
                warnings and errors, so it's a good idea to keep this string globally unique and easily understood.
        """
        self.managed_thing = managed_thing
        self._registry = {}

    def register(self, name: str) -> Callable:
        def inner_wrapper(wrapped_class) -> Callable:
            if name in self._registry:
                warnings.warn(f"{self.managed_thing} with name '{name}' doubly registered, old class will be replaced.")
            self._registry[name] = wrapped_class
            return wrapped_class
        return inner_wrapper

    def get_by_name(self, name: str):
        """Get a managed thing by name."""
        if name in self._registry:
            return self._registry[name]
        else:
            raise ValueError(f"{self.managed_thing} with name '{name}' unknown.")

    def get_all_names(self):
        """Get the list of things' names registered to this registry."""
        return list(self._registry.keys())
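
The registry is self-contained, so its contract is easy to demonstrate; a toy sketch mirroring how `SDERegistry` and `PredictorRegistry` are used above (the "Sampler" registry and `DummySampler` here are hypothetical):

from sgmse.util.registry import Registry

SamplerRegistry = Registry("Sampler")

@SamplerRegistry.register("dummy")
class DummySampler:
    def __call__(self):
        return "sampled!"

cls = SamplerRegistry.get_by_name("dummy")   # -> DummySampler
print(cls()())                               # "sampled!"
print(SamplerRegistry.get_all_names())       # ['dummy']
# SamplerRegistry.get_by_name("nope")        # would raise ValueError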
sgmse/util/tensors.py CHANGED
@@ -1,16 +1,16 @@
def batch_broadcast(a, x):
    """Broadcasts a over all dimensions of x, except the batch dimension, which must match."""

    if len(a.shape) != 1:
        a = a.squeeze()
        if len(a.shape) != 1:
            raise ValueError(
                f"Don't know how to batch-broadcast tensor `a` with more than one effective dimension (shape {a.shape})"
            )

    if a.shape[0] != x.shape[0] and a.shape[0] != 1:
        raise ValueError(
            f"Don't know how to batch-broadcast shape {a.shape} over {x.shape} as the batch dimension is not matching")

    out = a.view((x.shape[0], *(1 for _ in range(len(x.shape)-1))))
    return out
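
A short usage sketch for `batch_broadcast` (shapes invented): it reshapes a per-batch scalar so it broadcasts over the remaining dimensions, which is what the SDE and predictor classes do manually with `[:, None, None, None]`.

import torch
from sgmse.util.tensors import batch_broadcast

t = torch.rand(4)                 # one scalar per batch element, shape (4,)
x = torch.randn(4, 2, 64, 64)     # batch of spectrogram-like tensors

t_b = batch_broadcast(t, x)       # shape (4, 1, 1, 1), broadcastable against x
print(t_b.shape)
print((t_b * x).shape)            # torch.Size([4, 2, 64, 64])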