thefynnbe committed on
Commit
9af5e69
·
verified ·
1 Parent(s): 595b472

Upload 1.3 with bioimageio.spec 0.5.7.1

Files changed (5)
  1. README.md +16 -14
  2. package/README.md +6 -428
  3. package/bioimageio.yaml +1 -5
  4. package/environment.yaml +0 -10
  5. package/model.py +951 -289
README.md CHANGED
@@ -10,7 +10,6 @@ HyLFM-Net trained on static images of arrested medaka hatchling hearts. The netw
10
  - [Bias, Risks, and Limitations](#bias-risks-and-limitations)
11
  - [How to Get Started with the Model](#how-to-get-started-with-the-model)
12
  - [Training Details](#training-details)
13
- - [Evaluation](#evaluation)
14
  - [Environmental Impact](#environmental-impact)
15
  - [Technical Specifications](#technical-specifications)
16
 
@@ -19,7 +18,7 @@ HyLFM-Net trained on static images of arrested medaka hatchling hearts. The netw
19
 
20
  ## Model Description
21
 
22
- - **model version:** 1.2
23
  - **Additional model documentation:** [package/README.md](package/README.md)
24
  - **Developed by:**
25
  - Wagner, N., Beuttenmueller, F., Norlin, N. et al. Deep learning-enhanced light-field imaging with continuous validation. Nat Methods 18, 557–563 (2021): https://www.doi.org/10.1038/s41592-021-01136-0
@@ -42,7 +41,15 @@ HyLFM-Net trained on static images of arrested medaka hatchling hearts. The netw
42
 
43
  This model is compatible with the bioimageio.spec Python package (version >= 0.5.7.1) and the bioimageio.core Python package, which supports model inference in Python code or via the `bioimageio` CLI.
44
45
46
 
47
  ## Downstream Use
48
 
@@ -88,7 +95,7 @@ Users (both direct and downstream) should be made aware of the risks, biases and
88
 
89
  # How to Get Started with the Model
90
 
91
- You can use "huggingface/thefynnbe/ambitious-sloth/v1.2" as the resource identifier to load this model directly from the Hugging Face Hub using bioimageio.spec or bioimageio.core.
92
 
93
  See [bioimageio.core documentation: Get started](https://bioimage-io.github.io/core-bioimage-io-python/latest/get-started) for instructions on how to load and run this model using the `bioimageio.core` Python package or the bioimageio CLI.
94
 
@@ -109,20 +116,13 @@ This model was trained on `10.5281/zenodo.7612115`.
109
  - **Model size:** 234.44 MB
110
 
111
 
112
- # Evaluation
113
-
114
- missing
115
- ### Validation on External Data
116
-
117
- missing
118
-
119
  # Environmental Impact
120
 
121
  - **Hardware Type:** GTX 2080 Ti
122
  - **Hours used:** 10.0
123
  - **Cloud Provider:** EMBL Heidelberg
124
  - **Compute Region:** Germany
125
- - **Carbon Emitted:** 0.54
126
 
127
 
128
 
@@ -138,7 +138,8 @@ missing
138
  - Axes: `batch, channel, y, x`
139
  - Shape: `1 × 1 × 1235 × 1425`
140
  - Data type: `float32`
141
- - Values: 1.0 arbitrary unit with offset: None in range (None, None)
 
142
  - example
143
  ![lf sample](images/input_lf_sample.png)
144
 
@@ -147,7 +148,8 @@ missing
147
  - Axes: `batch, channel, z, y, x`
148
  - Shape: `1 × 1 × 49 × 244 × 284`
149
  - Data type: `float32`
150
- - Values: 1.0 arbitrary unit with offset: None in range (None, None)
 
151
  - example
152
  ![prediction sample](images/output_prediction_sample.png)
153
 
@@ -162,7 +164,7 @@ missing
162
  ### Software
163
 
164
  - **Framework:** ONNX (opset version 15), Pytorch State Dict (1.13), or TorchScript (1.13)
165
- - **Libraries:** Dependencies for Pytorch State dict weights are listed in [environment.yaml](package/environment.yaml).
166
  - **BioImage.IO partner compatibility:** [Compatibility Reports](https://bioimage-io.github.io/collection/latest/compatibility/#compatibility-by-resource)
167
 
168
  ---
 
10
  - [Bias, Risks, and Limitations](#bias-risks-and-limitations)
11
  - [How to Get Started with the Model](#how-to-get-started-with-the-model)
12
  - [Training Details](#training-details)
 
13
  - [Environmental Impact](#environmental-impact)
14
  - [Technical Specifications](#technical-specifications)
15
 
 
18
 
19
  ## Model Description
20
 
21
+ - **model version:** 1.3
22
  - **Additional model documentation:** [package/README.md](package/README.md)
23
  - **Developed by:**
24
  - Wagner, N., Beuttenmueller, F., Norlin, N. et al. Deep learning-enhanced light-field imaging with continuous validation. Nat Methods 18, 557–563 (2021): https://www.doi.org/10.1038/s41592-021-01136-0
 
41
 
42
  This model is compatible with the bioimageio.spec Python package (version >= 0.5.7.1) and the bioimageio.core Python package, which supports model inference in Python code or via the `bioimageio` CLI.
43
 
44
+ ```python
45
+ from bioimageio.core import predict
46
 
47
+ output_sample = predict("huggingface/thefynnbe/ambitious-sloth/1.3", inputs={'lf': '<path or tensor>'})
48
+
49
+ output_tensor = output_sample.members["prediction"]
50
+ xarray_dataarray = output_tensor.data
51
+ numpy_ndarray = output_tensor.data.to_numpy()
52
+ ```
53
 
54
  ## Downstream Use
55
 
 
95
 
96
  # How to Get Started with the Model
97
 
98
+ You can use "huggingface/thefynnbe/ambitious-sloth/1.3" as the resource identifier to load this model directly from the Hugging Face Hub using bioimageio.spec or bioimageio.core.
99
 
100
  See [bioimageio.core documentation: Get started](https://bioimage-io.github.io/core-bioimage-io-python/latest/get-started) for instructions on how to load and run this model using the `bioimageio.core` Python package or the bioimageio CLI.
101
 
 
116
  - **Model size:** 234.44 MB
117
 
118
 
119
  # Environmental Impact
120
 
121
  - **Hardware Type:** GTX 2080 Ti
122
  - **Hours used:** 10.0
123
  - **Cloud Provider:** EMBL Heidelberg
124
  - **Compute Region:** Germany
125
+ - **Carbon Emitted:** 0.54 kg CO2e
126
 
127
 
128
 
 
138
  - Axes: `batch, channel, y, x`
139
  - Shape: `1 × 1 × 1235 × 1425`
140
  - Data type: `float32`
141
+ - Value unit: arbitrary unit
142
+ - Value scale factor: 1.0
143
  - example
144
  ![lf sample](images/input_lf_sample.png)
145
 
 
148
  - Axes: `batch, channel, z, y, x`
149
  - Shape: `1 × 1 × 49 × 244 × 284`
150
  - Data type: `float32`
151
+ - Value unit: arbitrary unit
152
+ - Value scale factor: 1.0
153
  - example
154
  ![prediction sample](images/output_prediction_sample.png)
155
 
 
164
  ### Software
165
 
166
  - **Framework:** ONNX (opset version 15), Pytorch State Dict (1.13), or TorchScript (1.13)
167
+ - **Libraries:** None beyond the respective framework library.
168
  - **BioImage.IO partner compatibility:** [Compatibility Reports](https://bioimage-io.github.io/collection/latest/compatibility/#compatibility-by-resource)
169
 
170
  ---
package/README.md CHANGED
@@ -1,431 +1,9 @@
1
- ![License](https://img.shields.io/github/license/bioimage-io/spec-bioimage-io.svg)
2
- ![PyPI](https://img.shields.io/pypi/v/bioimageio-spec.svg?style=popout)
3
- ![conda-version](https://anaconda.org/conda-forge/bioimageio.spec/badges/version.svg)
4
 
5
- # Specifications for bioimage.io
 
 
6
 
7
- This repository contains specifications defined by the bioimage.io community. These specifications are used for defining fields in YAML 1.2 files which should be named `rdf.yaml`. Such an `rdf.yaml` --- along with files referenced in it --- can be downloaded from or uploaded to the [bioimage.io website](https://bioimage.io) and may be produced or consumed by bioimage.io-compatible consumers (e.g. image analysis software like ilastik).
8
 
9
- bioimage.io-compatible resources must fulfill the following rules:
10
-
11
- Note that the Python package PyYAML does not support YAML 1.2.
12
- We therefore use and recommend [ruyaml](https://ruyaml.readthedocs.io/en/latest/).
13
- For differences see <https://ruamelyaml.readthedocs.io/en/latest/pyyaml>.
14
-
15
- Please also note that the best way to check whether your `rdf.yaml` file is bioimage.io-compliant is to call `bioimageio.core.validate` from the [bioimageio.core](https://github.com/bioimage-io/core-bioimage-io-python) Python package.
16
- The [bioimageio.core](https://github.com/bioimage-io/core-bioimage-io-python) Python package also provides the bioimageio command line interface (CLI) with the `validate` command:
17
-
18
- ```terminal
19
- bioimageio validate path/to/your/rdf.yaml
20
- ```
21
-
22
- ## Format version overview
23
-
24
- All bioimage.io description formats are defined as [Pydantic models](https://docs.pydantic.dev/latest/).
25
-
26
- | type | format version | documentation |
27
- | --- | --- | --- |
28
- | model | 0.5 </br> 0.4 | [model_descr_v0-5.md](https://github.com/bioimage-io/spec-bioimage-io/blob/gh-pages/user_docs/model_descr_v0-5.md) </br> [model_descr_v0-4.md](https://github.com/bioimage-io/spec-bioimage-io/blob/gh-pages/user_docs/model_descr_v0-4.md) |
29
- | dataset | 0.3 </br> 0.2 | [dataset_descr_v0-3.md](https://github.com/bioimage-io/spec-bioimage-io/blob/gh-pages/user_docs/dataset_descr_v0-3.md) </br> [dataset_descr_v0-2.md](https://github.com/bioimage-io/spec-bioimage-io/blob/gh-pages/user_docs/dataset_descr_v0-2.md) |
30
- | notebook | 0.3 </br> 0.2 | [notebook_descr_v0-3.md](https://github.com/bioimage-io/spec-bioimage-io/blob/gh-pages/user_docs/notebook_descr_v0-3.md) </br> [notebook_descr_v0-2.md](https://github.com/bioimage-io/spec-bioimage-io/blob/gh-pages/user_docs/notebook_descr_v0-2.md) |
31
- | application | 0.3 </br> 0.2 | [application_descr_v0-3.md](https://github.com/bioimage-io/spec-bioimage-io/blob/gh-pages/user_docs/application_descr_v0-3.md) </br> [application_descr_v0-2.md](https://github.com/bioimage-io/spec-bioimage-io/blob/gh-pages/user_docs/application_descr_v0-2.md) |
32
- | collection | 0.3 </br> 0.2 | [collection_descr_v0-3.md](https://github.com/bioimage-io/spec-bioimage-io/blob/gh-pages/user_docs/collection_descr_v0-3.md) </br> [collection_descr_v0-2.md](https://github.com/bioimage-io/spec-bioimage-io/blob/gh-pages/user_docs/collection_descr_v0-2.md) |
33
- | generic | 0.3 </br> 0.2 | [generic_descr_v0-3.md](https://github.com/bioimage-io/spec-bioimage-io/blob/gh-pages/user_docs/generic_descr_v0-3.md) </br> [generic_descr_v0-2.md](https://github.com/bioimage-io/spec-bioimage-io/blob/gh-pages/user_docs/generic_descr_v0-2.md) |
34
-
35
- ## JSON schema
36
-
37
- Simplified descriptions are available as [JSON schema](https://json-schema.org/):
38
-
39
- | bioimageio.spec version | JSON schema |
40
- | --- | --- |
41
- | latest | [bioimageio_schema_latest.json](https://github.com/bioimage-io/spec-bioimage-io/blob/gh-pages/bioimageio_schema_latest.json) |
42
- | 0.5 | [bioimageio_schema_v0-5.json](https://github.com/bioimage-io/spec-bioimage-io/blob/gh-pages/bioimageio_schema_v0-5.json) |
43
-
44
- These are primarily intended for syntax highlighting and form generation.
45
-
46
- ## Examples
47
-
48
- We provide some [examples for using rdf.yaml files to describe models, applications, notebooks and datasets](https://github.com/bioimage-io/spec-bioimage-io/blob/main/example_descriptions/examples.md).
49
-
50
- ## 💁 Recommendations
51
-
52
- * Due to the limitations of storage services such as Zenodo, which does not support subfolders, it is recommended to place other files in the same directory level of the `rdf.yaml` file and try to avoid using subdirectories.
53
- * Use the [bioimageio.core Python package](https://github.com/bioimage-io/core-bioimage-io-python) to validate your `rdf.yaml` file.
54
- * bioimageio.spec keeps evolving. Try to use and upgrade to the most current format version!
55
-
56
- ## ⌨ bioimageio command-line interface (CLI)
57
-
58
- The bioimageio CLI has moved entirely to [bioimageio.core](https://github.com/bioimage-io/core-bioimage-io-python).
59
-
60
- ## 🖥 Installation
61
-
62
- bioimageio.spec can be installed with either `conda` or `pip`; however, we recommend installing `bioimageio.core` instead:
63
-
64
- ```console
65
- conda install -c conda-forge bioimageio.core
66
- ```
67
-
68
- or
69
-
70
- ```console
71
- pip install -U bioimageio.core
72
- ```
73
-
74
- ## 🏞 Environment variables
75
-
76
- TODO: link to settings in dev docs
77
-
78
- ## 🤝 How to contribute
79
-
80
- ## ♥ Contributors
81
-
82
- <a href="https://github.com/bioimage-io/spec-bioimage-io/graphs/contributors">
83
- <img alt="bioimageio.spec contributors" src="https://contrib.rocks/image?repo=bioimage-io/spec-bioimage-io" />
84
- </a>
85
-
86
- Made with [contrib.rocks](https://contrib.rocks).
87
-
88
- ## Δ Changelog
89
-
90
- ### bioimageio.spec Python package
91
-
92
- #### bioimageio.spec 0.5.2post1
93
-
94
- * fix model packaging with weights format priority
95
-
96
- #### bioimageio.spec 0.5.2
97
-
98
- * new patch version model 0.5.2
99
-
100
- #### bioimageio.spec 0.5.1
101
-
102
- * new patch version model 0.5.1
103
-
104
- #### bioimageio.spec 0.5.0post2
105
-
106
- * don't fail if CI env var is a string
107
-
108
- #### bioimageio.spec 0.5.0post1
109
-
110
- * fix `_internal.io_utils.identify_bioimageio_yaml_file()`
111
-
112
- #### bioimageio.spec 0.5.0
113
-
114
- * new description formats: [generic 0.3, application 0.3, collection 0.3, dataset 0.3, notebook 0.3](generic-030--application-030--collection-030--dataset-030--notebook-030) and [model 0.5](model-050).
115
- * various API changes, most important functions:
116
- * `bioimageio.spec.load_description` (replaces `load_raw_resource_description`, interface changed)
117
- * `bioimageio.spec.validate_format` (new)
118
- * `bioimageio.spec.dump_description` (replaces `serialize_raw_resource_description_to_dict`, interface changed)
119
- * `bioimageio.spec.update_format` (interface changed)
120
- * switch from Marshmallow to Pydantic
121
- * extended validation
122
- * one joint, more precise JSON schema
123
-
124
- #### bioimageio.spec 0.4.9
125
-
126
- * small bugfixes
127
- * better type hints
128
- * improved tests
129
-
130
- #### bioimageio.spec 0.4.8post1
131
-
132
- * add `axes` and `eps` to `scale_mean_var`
133
-
134
- #### bioimageio.spec 0.4.7post1
135
-
136
- * add simple forward compatibility by treating future format versions as latest known (for the respective resource type)
137
-
138
- #### bioimageio.spec 0.4.6post3
139
-
140
- * Make CLI output more readable
141
-
142
- * find redirected URLs when checking for URL availability
143
-
144
- #### bioimageio.spec 0.4.6post2
145
-
146
- * Improve error message for non-existing RDF file path given as string
147
-
148
- * Improve documentation for model description's `documentation` field
149
-
150
- #### bioimageio.spec 0.4.6post1
151
-
152
- * fix enrich_partial_rdf_with_imjoy_plugin (see <https://github.com/bioimage-io/spec-bioimage-io/pull/452>)
153
-
154
- #### bioimageio.spec 0.4.5post16
155
-
156
- * fix rdf_update of entries in `resolve_collection_entries()`
157
-
158
- #### bioimageio.spec 0.4.5post15
159
-
160
- * pass root to `enrich_partial_rdf` arg of `resolve_collection_entries()`
161
-
162
- #### bioimageio.spec 0.4.5post14
163
-
164
- * keep `ResourceDescrption.root_path` as URI for remote resources. This fixes the collection description as the collection entries are resolved after the collection description has been loaded.
165
-
166
- #### bioimageio.spec 0.4.5post13
167
-
168
- * new bioimageio.spec.partner module adding validate-partner-collection command if optional 'lxml' dependency is available
169
-
170
- #### bioimageio.spec 0.4.5post12
171
-
172
- * new env var `BIOIMAGEIO_CACHE_WARNINGS_LIMIT` (default: 3) to avoid spam from cache hit warnings
173
-
174
- * more robust conversion of ImportableSourceFile for absolute paths to relative paths (don't fail on non-path source file)
175
-
176
- #### bioimageio.spec 0.4.5post11
177
-
178
- * resolve symlinks when transforming absolute to relative paths during serialization; see [#438](https://github.com/bioimage-io/spec-bioimage-io/pull/438)
179
-
180
- #### bioimageio.spec 0.4.5post10
181
-
182
- * fix loading of collection description with id (id used to be ignored)
183
-
184
- #### bioimageio.spec 0.4.5post9
185
-
186
- * support loading bioimageio resources by their animal nickname (currently only models have nicknames).
187
-
188
- #### bioimageio.spec 0.4.5post8
189
-
190
- * any field previously expecting a local relative path is now also accepting an absolute path
191
-
192
- * load_raw_resource_description returns a raw resource description which has no relative paths (any relative paths are converted to absolute paths).
193
-
194
- #### bioimageio.spec 0.4.4post7
195
-
196
- * add command `commands.update_rdf()`/`update-rdf`(cli)
197
-
198
- #### bioimageio.spec 0.4.4post2
199
-
200
- * fix unresolved ImportableSourceFile
201
-
202
- #### bioimageio.spec 0.4.4post1
203
-
204
- * fix collection description conversion for type field
205
-
206
- #### bioimageio.spec 0.4.3post1
207
-
208
- * fix to shape validation for model description 0.4: output shape now needs to be bigger than halo
209
-
210
- * moved objects from bioimageio.spec.shared.utils to bioimageio.spec.shared\[.node_transformer\]
211
- * additional keys to validation summary: bioimageio_spec_version, status
212
-
213
- #### bioimageio.spec 0.4.2post4
214
-
215
- * fixes to generic description:
216
- * ignore value of field `root_path` if present in yaml. This field is used internally and always present in RDF nodes.
217
-
218
- #### bioimageio.spec 0.4.1.post5
219
-
220
- * fixes to collection description:
221
- * RDFs specified directly in collection description are validated correctly even if their source field does not point to an RDF.
222
- * nesting of collection description allowed
223
-
224
- #### bioimageio.spec 0.4.1.post4
225
-
226
- * fixed missing field `icon` in generic description's raw node
227
-
228
- * fixes to collection description:
229
- * RDFs specified directly in collection description are validated correctly
230
- * no nesting of collection description allowed for now
231
- * `links` is no longer an explicit collection entry field ("moved" to unknown)
232
-
233
- #### bioimageio.spec 0.4.1.post0
234
-
235
- * new model spec 0.3.5 and 0.4.1
236
-
237
- #### bioimageio.spec 0.4.0.post3
238
-
239
- * `load_raw_resource_description` no longer accepts `update_to_current_format` kwarg (use `update_to_format` instead)
240
-
241
- #### bioimageio.spec 0.4.0.post2
242
-
243
- * `load_raw_resource_description` accepts `update_to_format` kwarg
244
-
245
- ### Resource Description Format Versions
246
-
247
- #### model 0.5.2
248
-
249
- * Non-breaking changes
250
- * added `concatenable` flag to index, time and space input axes
251
-
252
- #### model 0.5.1
253
-
254
- * Non-breaking changes
255
- * added `DataDependentSize` for `outputs.i.size` to specify an output shape that is not known before inference is run.
256
- * added optional `inputs.i.optional` field to indicate that a tensor may be `None`
257
- * made data type assumptions in `preprocessing` and `postprocessing` explicit by adding `'ensure_dtype'` operations per default.
258
- * allow to specify multiple thresholds (along an `axis`) in a 'binarize' processing step
259
-
260
- #### generic 0.3.0 / application 0.3.0 / collection 0.3.0 / dataset 0.3.0 / notebook 0.3.0
261
-
262
- * Breaking changes that are fully auto-convertible
263
- * dropped `download_url`
264
- * dropped non-file attachments
265
- * `attachments.files` moved to `attachments.i.source`
266
- * Non-breaking changes
267
- * added optional `parent` field
268
-
269
- #### model 0.5.0
270
-
271
- all generic 0.3.0 changes (except models already have the `parent` field) plus:
272
-
273
- * Breaking changes that are partially auto-convertible
274
- * `inputs.i.axes` are now defined in more detail (same for `outputs.i.axes`)
275
- * `inputs.i.shape` moved per axes to `inputs.i.axes.size` (same for `outputs.i.shape`)
276
- * new pre-/postprocessing 'fixed_zero_mean_unit_variance' separated from 'zero_mean_unit_variance', where `mode=fixed` is no longer valid.
277
- (for scalar values this is auto-convertible.)
278
- * Breaking changes that are fully auto-convertible
279
- * changes in `weights.pytorch_state_dict.architecture`
280
- * renamed `weights.pytorch_state_dict.architecture.source_file` to `...architecture.source`
281
- * changes in `weights.pytorch_state_dict.dependencies`
282
- * only conda environment allowed and specified by `weights.pytorch_state_dict.dependencies.source`
283
- * new optional field `weights.pytorch_state_dict.dependencies.sha256`
284
- * changes in `weights.tensorflow_model_bundle.dependencies`
285
- * same as changes in `weights.pytorch_state_dict.dependencies`
286
- * moved `test_inputs` to `inputs.i.test_tensor`
287
- * moved `test_outputs` to `outputs.i.test_tensor`
288
- * moved `sample_inputs` to `inputs.i.sample_tensor`
289
- * moved `sample_outputs` to `outputs.i.sample_tensor`
290
- * renamed `inputs.i.name` to `inputs.i.id`
291
- * renamed `outputs.i.name` to `outputs.i.id`
292
- * renamed `inputs.i.preprocessing.name` to `inputs.i.preprocessing.id`
293
- * renamed `outputs.i.postprocessing.name` to `outputs.i.postprocessing.id`
294
- * Non-breaking changes:
295
- * new pre-/postprocessing: `id`='ensure_dtype' with kwarg `dtype`
296
-
297
- #### generic 0.2.4 and model 0.4.10
298
-
299
- * Breaking changes that are fully auto-convertible
300
- * `id` overwritten with value from `config.bioimageio.nickname` if available
301
- * Non-breaking changes
302
- * `version_number` is a new, optional field indicating that an RDF is the nth published version with a given `id`
303
- * `id_emoji` is a new, optional field (set from `config.bioimageio.nickname_icon` if available)
304
- * `uploader` is a new, optional field with `email` and an optional `name` subfields
305
-
306
- #### model 0.4.9
307
-
308
- * Non-breaking changes
309
- * make pre-/postprocessing kwargs `mode` and `axes` always optional for model description 0.3 and 0.4
310
-
311
- #### model 0.4.8
312
-
313
- * Non-breaking changes
314
- * `cite` field is now optional
315
-
316
- #### generic 0.2.2 and model 0.4.7
317
-
318
- * Breaking changes that are fully auto-convertible
319
- * name field may not include '/' or '\' (conversion removes these)
320
-
321
- #### model 0.4.6
322
-
323
- * Non-breaking changes
324
- * Implicit output shape can be expanded by inserting `null` into `shape:scale` and indicating length of new dimension D in the `offset` field. Keep in mind that `D=2*'offset'`.
325
-
326
- #### model 0.4.5
327
-
328
- * Breaking changes that are fully auto-convertible
329
- * `parent` field changed to hold a string that is a bioimage.io ID, a URL or a local relative path (and not subfields `uri` and `sha256`)
330
-
331
- #### model 0.4.4
332
-
333
- * Non-breaking changes
334
- * new optional field `training_data`
335
-
336
- #### dataset 0.2.2
337
-
338
- * Non-breaking changes
339
- * explicitly define and document dataset description (for now, clone of generic description with type="dataset")
340
-
341
- #### model 0.4.3
342
-
343
- * Non-breaking changes
344
- * add optional field `download_url`
345
- * add optional field `dependencies` to all weight formats (not only pytorch_state_dict)
346
- * add optional `pytorch_version` to the pytorch_state_dict and torchscript weight formats
347
-
348
- #### model 0.4.2
349
-
350
- * Bug fixes:
351
- * in a `pytorch_state_dict` weight entry `architecture` is no longer optional.
352
-
353
- #### collection 0.2.2
354
-
355
- * Non-breaking changes
356
- * make `authors`, `cite`, `documentation` and `tags` optional
357
-
358
- * Breaking changes that are fully auto-convertible
359
- * Simplifies collection description 0.2.1 by merging resource type fields together to a `collection` field,
360
- holding a list of all resources in the specified collection.
361
-
362
- #### generic 0.2.2 / model 0.3.6 / model 0.4.2
363
-
364
- * Non-breaking changes
365
- * `rdf_source` new optional field
366
- * `id` new optional field
367
-
368
- #### collection 0.2.1
369
-
370
- * First official release, extends generic description with fields `application`, `model`, `dataset`, `notebook` and (nested)
371
- `collection`, which hold lists linking to respective resources.
372
-
373
- #### generic 0.2.1
374
-
375
- * Non-breaking changes
376
- * add optional `email` and `github_user` fields to entries in `authors`
377
- * add optional `maintainers` field (entries like in `authors` but `github_user` is required (and `name` is not))
378
-
379
- #### model 0.4.1
380
-
381
- * Breaking changes that are fully auto-convertible
382
- * moved field `dependencies` to `weights:pytorch_state_dict:dependencies`
383
-
384
- * Non-breaking changes
385
- * `documentation` field accepts URLs as well
386
-
387
- #### model 0.3.5
388
-
389
- * Non-breaking changes
390
- * `documentation` field accepts URLs as well
391
-
392
- #### model 0.4.0
393
-
394
- * Breaking changes
395
- * model inputs and outputs may not use duplicated names.
396
- * model field `sha256` is required if `pytorch_state_dict` weights are defined.
397
- and is now moved to the `pytorch_state_dict` entry as `architecture_sha256`.
398
-
399
- * Breaking changes that are fully auto-convertible
400
- * model fields language and framework are removed.
401
- * model field `source` is renamed `architecture` and is moved together with `kwargs` to the `pytorch_state_dict`
402
- weights entry (if it exists, otherwise they are removed).
403
- * the weight format `pytorch_script` was renamed to `torchscript`.
404
- * Other changes
405
- * model inputs (like outputs) may be defined by `scale`ing and `offset`ing a `reference_tensor`
406
- * a `maintainers` field was added to the model description.
407
- * the entries in the `authors` field may now additionally contain `email` or `github_user`.
408
- * the summary returned by the `validate` command now also contains a list of warnings.
409
- * an `update_format` command was added to aid with updating older RDFs by applying auto-conversion.
410
-
411
- #### model 0.3.4
412
-
413
- * Non-breaking changes
414
- * Add optional parameter `eps` to `scale_range` postprocessing.
415
-
416
- #### model 0.3.3
417
-
418
- * Breaking changes that are fully auto-convertible
419
- * `reference_input` for implicit output tensor shape was renamed to `reference_tensor`
420
-
421
- #### model 0.3.2
422
-
423
- * Breaking changes
424
- * The RDF file name in a package should be `rdf.yaml` for all the RDF (not `model.yaml`);
425
- * Change `authors` and `packaged_by` fields from List[str] to List[Author] with Author consisting of a dictionary `{name: '<Full name>', affiliation: '<Affiliation>', orcid: 'optional orcid id'}`;
426
- * Add a mandatory `type` field to comply with the generic description. Only valid value is 'model' for model description;
427
- * Only allow `license` identifier from the [SPDX license list](https://spdx.org/licenses/);
428
-
429
- * Non-breaking changes
430
- * Add optional `version` field (default 0.1.0) to keep track of model changes;
431
- * Allow the values in the `attachments` list to be any values besides URI;
 
1
+ # HyLFM-Net Example
 
 
2
 
3
+ Reference example for a HyLFM-Net developed at [kreshuklab/hylfm-net](https://github.com/kreshuklab/hylfm-net).
4
+ This network is not expected to generalize to other microscopy light field datasets.
5
+ See [Deep learning-enhanced light-field imaging with continuous validation](https://rdcu.be/cktHs) for details.
6
 
7
+ ## Validation
8
 
9
+ HyLFM-Net reconstructions should be validated using light sheet ground truth acquired with the same HyLFM.
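One common way to quantify such a comparison is the peak signal-to-noise ratio (PSNR). The helper below is a hypothetical pure-Python sketch for illustration only; it is not part of this package, and the choice of PSNR here is an assumption, not the authors' stated validation protocol.

```python
import math

def psnr(pred, gt, data_range=1.0):
    """Peak signal-to-noise ratio between two equally shaped 2-D images,
    given as nested lists of floats (hypothetical helper, illustration only)."""
    flat_pred = [v for row in pred for v in row]
    flat_gt = [v for row in gt for v in row]
    mse = sum((p - g) ** 2 for p, g in zip(flat_pred, flat_gt)) / len(flat_pred)
    if mse == 0.0:
        return float("inf")  # identical images
    return 10.0 * math.log10(data_range ** 2 / mse)
```

Higher values indicate a reconstruction closer to the light sheet ground truth; identical images yield infinity.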
package/bioimageio.yaml CHANGED
@@ -35,7 +35,7 @@ tags:
35
  - image-reconstruction
36
  - nuclei
37
  - hylfm
38
- version: 1.2
39
  format_version: 0.5.7
40
  type: model
41
  id: ambitious-sloth
@@ -138,7 +138,6 @@ weights:
138
  sha256: 461f1151d7fea5857ce8f9ceaf9cdf08b5f78ce41785725e39a77d154ccea90a
139
  architecture:
140
  source: model.py
141
- sha256: 7fbc9010a764a89e1bb6c162fc9df16eadb95d63bf3a1233cbcb61d82e3bab07
142
  callable: HyLFM_Net
143
  kwargs:
144
  c_in_3d: 64
@@ -162,9 +161,6 @@ weights:
162
  nnum: 19
163
  z_out: 49
164
  pytorch_version: 1.13
165
- dependencies:
166
- source: environment.yaml
167
- sha256: e0c059d829fa03193eede76961746f464ac9b07d072b1e6ee62395d5c03c8606
168
  torchscript:
169
  source: weights_torchscript.pt
170
  sha256: ec01e0c212b5eb422dda208af004665799637a2f2729d0ebf2e884e5d9966fc2
 
35
  - image-reconstruction
36
  - nuclei
37
  - hylfm
38
+ version: 1.3
39
  format_version: 0.5.7
40
  type: model
41
  id: ambitious-sloth
 
138
  sha256: 461f1151d7fea5857ce8f9ceaf9cdf08b5f78ce41785725e39a77d154ccea90a
139
  architecture:
140
  source: model.py
 
141
  callable: HyLFM_Net
142
  kwargs:
143
  c_in_3d: 64
 
161
  nnum: 19
162
  z_out: 49
163
  pytorch_version: 1.13
164
  torchscript:
165
  source: weights_torchscript.pt
166
  sha256: ec01e0c212b5eb422dda208af004665799637a2f2729d0ebf2e884e5d9966fc2
package/environment.yaml DELETED
@@ -1,10 +0,0 @@
1
- name: hylfm
2
-
3
- channels:
4
- - conda-forge
5
-
6
- dependencies:
7
- - python=3.8.*
8
- - pytorch=1.9.*
9
- - torchvision=0.10.*
10
- - inferno=v0.4.2
package/model.py CHANGED
@@ -1,289 +1,951 @@
1
- import collections
2
- import inspect
3
- from enum import Enum
4
- from functools import partial
5
- from typing import List, Optional, Sequence, Tuple, Union
6
-
7
- import torch.nn as nn
8
- from inferno.extensions.initializers import (
9
- Constant,
10
- Initialization,
11
- KaimingNormalWeightsZeroBias,
12
- )
13
- from inferno.extensions.layers import convolutional as inferno_convolutional
14
-
15
- Conv2D = inferno_convolutional.Conv2D
16
- ValidConv3D = inferno_convolutional.ValidConv3D
17
-
18
-
19
- class Crop(nn.Module):
20
- def __init__(self, *slices: slice):
21
- super().__init__()
22
- self.slices = slices
23
-
24
- def extra_repr(self):
25
- return str(self.slices)
26
-
27
- def forward(self, input):
28
- return input[self.slices]
29
-
30
-
31
- class ChannelFromLightField(nn.Module):
32
- def __init__(self, nnum: int):
33
- super().__init__()
34
- self.nnum = nnum
35
-
36
- def forward(self, tensor):
37
- assert len(tensor.shape) == 4, tensor.shape
38
- b, c, x, y = tensor.shape
39
- assert c == 1
40
- assert x % self.nnum == 0, (x, self.nnum)
41
- assert y % self.nnum == 0, (y, self.nnum)
42
- return (
43
- tensor.reshape(b, x // self.nnum, self.nnum, y // self.nnum, self.nnum)
44
- .transpose(1, 2)
45
- .transpose(2, 4)
46
- .transpose(3, 4)
47
- .reshape(b, self.nnum**2, x // self.nnum, y // self.nnum)
48
- )
49
-
50
-
51
- class ResnetBlock(nn.Module):
52
- def __init__(
53
- self,
54
- in_n_filters,
55
- n_filters,
56
- kernel_size=(3, 3),
57
- batch_norm=False,
58
- conv_per_block=2,
59
- valid: bool = False,
60
- activation: str = "ReLU",
61
- ):
62
- super().__init__()
63
- if batch_norm and activation != "ReLU":
64
- raise NotImplementedError("batch_norm with non ReLU activation")
65
-
66
- assert isinstance(kernel_size, tuple), kernel_size
67
- assert conv_per_block >= 2
68
- self.debug = False # sys.gettrace() is not None
69
-
70
- Conv = getattr(
71
- inferno_convolutional,
72
- f"{'BNReLU' if batch_norm else ''}{'Valid' if valid else ''}Conv{'' if batch_norm else activation}{len(kernel_size)}D",
73
- )
74
- FinalConv = getattr(
75
- inferno_convolutional, f"{'BNReLU' if batch_norm else ''}{'Valid' if valid else ''}Conv{len(kernel_size)}D"
76
- )
77
-
78
- layers = []
79
- layers.append(Conv(in_channels=in_n_filters, out_channels=n_filters, kernel_size=kernel_size))
80
-
81
- for _ in range(conv_per_block - 2):
82
- layers.append(Conv(n_filters, n_filters, kernel_size))
83
-
84
- layers.append(FinalConv(n_filters, n_filters, kernel_size))
85
-
86
- self.block = nn.Sequential(*layers)
87
-
88
- if n_filters != in_n_filters:
89
- ProjConv = getattr(inferno_convolutional, f"Conv{len(kernel_size)}D")
90
- self.projection_layer = ProjConv(in_n_filters, n_filters, kernel_size=1)
91
- else:
92
- self.projection_layer = None
93
-
94
- if valid:
95
- crop_each_side = [conv_per_block * (ks // 2) for ks in kernel_size]
96
- self.crop = Crop(..., *[slice(c, -c) for c in crop_each_side])
97
- else:
98
- self.crop = None
99
-
100
- self.relu = nn.ReLU()
101
-
102
- # determine shrinkage
103
- # self.shrinkage = (1, 1) + tuple([conv_per_block * (ks - 1) for ks in kernel_size])
104
-
105
- def forward(self, input):
106
- x = self.block(input)
107
- if self.crop is not None:
108
- input = self.crop(input)
109
-
110
- if self.projection_layer is None:
111
- x = x + input
112
- else:
113
- projected = self.projection_layer(input)
114
- x = x + projected
115
-
116
- x = self.relu(x)
117
- return x
118
-
119
-
120
- class HyLFM_Net(nn.Module):
121
- class InitName(str, Enum):
122
- uniform_ = "uniform"
123
- normal_ = "normal"
124
- constant_ = "constant"
125
- eye_ = "eye"
126
- dirac_ = "dirac"
127
- xavier_uniform_ = "xavier_uniform"
128
- xavier_normal_ = "xavier_normal"
129
- kaiming_uniform_ = "kaiming_uniform"
130
- kaiming_normal_ = "kaiming_normal"
131
- orthogonal_ = "orthogonal"
132
- sparse_ = "sparse"
133
-
134
- def __init__(
135
- self,
136
- *,
137
- z_out: int,
138
- nnum: int,
139
- kernel2d: int = 3,
140
- conv_per_block2d: int = 2,
141
- c_res2d: Sequence[Union[int, str]] = (488, 488, "u244", 244),
142
- last_kernel2d: int = 1,
143
- c_in_3d: int = 7,
144
- kernel3d: int = 3,
145
- conv_per_block3d: int = 2,
146
- c_res3d: Sequence[str] = (7, "u7", 7, 7),
147
- init_fn: Union[InitName, str] = InitName.xavier_uniform_.value,
148
- final_activation: Optional[str] = None,
149
- ):
150
- super().__init__()
151
- self.channel_from_lf = ChannelFromLightField(nnum=nnum)
152
- init_fn = self.InitName(init_fn)
153
-
154
- init_fn = getattr(nn.init, init_fn.value)
155
- self.c_res2d = list(c_res2d)
156
- self.c_res3d = list(c_res3d)
157
- c_res3d = c_res3d
158
- self.nnum = nnum
159
- self.z_out = z_out
160
- if kernel3d != 3:
161
- raise NotImplementedError("z_out expansion for other res3d kernel")
162
-
163
- dz = 2 * conv_per_block3d * (kernel3d // 2)
164
- for c in c_res3d:
165
- if isinstance(c, int) or not c.startswith("u"):
166
- z_out += dz
167
-
168
- # z_out += 4 * (len(c_res3d) - 2 * sum([layer == "u" for layer in c_res3d])) # add z_out for valid 3d convs
169
-
170
- assert c_res2d[-1] != "u", "missing # output channels for upsampling in 'c_res2d'"
171
- assert c_res3d[-1] != "u", "missing # output channels for upsampling in 'c_res3d'"
172
-
173
- res2d = []
174
- c_in = nnum**2
175
- c_out = c_in
176
- for i in range(len(c_res2d)):
177
- if not isinstance(c_res2d[i], int) and c_res2d[i].startswith("u"):
178
- c_out = int(c_res2d[i][1:])
179
- res2d.append(
180
- nn.ConvTranspose2d(
181
- in_channels=c_in, out_channels=c_out, kernel_size=2, stride=2, padding=0, output_padding=0
182
- )
183
- )
184
- else:
185
- c_out = int(c_res2d[i])
186
- res2d.append(
187
- ResnetBlock(
188
- in_n_filters=c_in,
189
- n_filters=c_out,
190
- kernel_size=(kernel2d, kernel2d),
191
- valid=False,
192
- conv_per_block=conv_per_block2d,
193
- )
194
- )
195
-
196
- c_in = c_out
197
-
198
- self.res2d = nn.Sequential(*res2d)
199
-
200
- if "gain" in inspect.signature(init_fn).parameters:
201
- init_fn_conv2d = partial(init_fn, gain=nn.init.calculate_gain("relu"))
202
- else:
203
- init_fn_conv2d = init_fn
204
-
205
- init = Initialization(weight_initializer=init_fn_conv2d, bias_initializer=Constant(0.0))
206
- self.conv2d = Conv2D(c_out, z_out * c_in_3d, last_kernel2d, activation="ReLU", initialization=init)
207
-
208
- self.c2z = lambda ipt, ip3=c_in_3d: ipt.view(ipt.shape[0], ip3, z_out, *ipt.shape[2:])
209
-
210
- res3d = []
211
- c_in = c_in_3d
212
- c_out = c_in
213
- for i in range(len(c_res3d)):
214
- if not isinstance(c_res3d[i], int) and c_res3d[i].startswith("u"):
215
- c_out = int(c_res3d[i][1:])
216
- res3d.append(
217
- nn.ConvTranspose3d(
218
- in_channels=c_in,
219
- out_channels=c_out,
220
- kernel_size=(3, 2, 2),
221
- stride=(1, 2, 2),
222
- padding=(1, 0, 0),
223
- output_padding=0,
224
- )
225
- )
226
- else:
227
- c_out = int(c_res3d[i])
228
- res3d.append(
229
- ResnetBlock(
230
- in_n_filters=c_in,
231
- n_filters=c_out,
232
- kernel_size=(kernel3d, kernel3d, kernel3d),
233
- valid=True,
234
- conv_per_block=conv_per_block3d,
235
- )
236
- )
237
-
238
- c_in = c_out
239
-
240
- self.res3d = nn.Sequential(*res3d)
241
-
242
- if "gain" in inspect.signature(init_fn).parameters:
243
- init_fn_conv3d = partial(init_fn, gain=nn.init.calculate_gain("linear"))
244
- else:
245
- init_fn_conv3d = init_fn
246
-
247
- init = Initialization(weight_initializer=init_fn_conv3d, bias_initializer=Constant(0.0))
248
- self.conv3d = ValidConv3D(c_out, 1, (1, 1, 1), initialization=init)
249
-
250
- if final_activation is None:
251
- self.final_activation = None
252
- elif final_activation == "sigmoid":
253
- self.final_activation = nn.Sigmoid()
254
- else:
255
- raise NotImplementedError(final_activation)
256
-
257
- def forward(self, x):
258
- x = self.channel_from_lf(x)
259
- x = self.res2d(x)
260
- x = self.conv2d(x)
261
- x = self.c2z(x)
262
- x = self.res3d(x)
263
- x = self.conv3d(x)
264
-
265
- if self.final_activation is not None:
266
- x = self.final_activation(x)
267
-
268
- return x
269
-
270
- def get_scale(self, ipt_shape: Optional[Tuple[int, int]] = None) -> int:
271
- s = max(1, 2 * sum(isinstance(res2d, str) and res2d.startswith("u") for res2d in self.c_res2d)) * max(
272
- 1, 2 * sum(isinstance(res3d, str) and res3d.startswith("u") for res3d in self.c_res3d)
273
- )
274
- return s
275
-
276
- def get_shrink(self, ipt_shape: Optional[Tuple[int, int]] = None) -> int:
277
- s = 0
278
- for res in self.c_res3d:
279
- if isinstance(res, str) and res.startswith("u"):
280
- s *= 2
281
- else:
282
- s += 2
283
-
284
- return s
285
-
286
- def get_output_shape(self, ipt_shape: Tuple[int, int]) -> Tuple[int, int, int]:
287
- scale = self.get_scaling(ipt_shape)
288
- shrink = self.get_shrink(ipt_shape)
289
- return (self.z_out,) + tuple(i * scale - 2 * shrink for i in ipt_shape)
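
The `ChannelFromLightField` rearrangement above splits each spatial axis of the raw light-field image into a lenslet index and a sub-aperture offset, and moves the `nnum**2` offsets into the channel dimension. A minimal NumPy stand-in for the torch reshape/transpose chain (illustration only, not part of the model file):

```python
import numpy as np

def channels_from_lightfield(lf: np.ndarray, nnum: int) -> np.ndarray:
    """Rearrange (b, 1, x, y) light-field pixels into (b, nnum**2, x//nnum, y//nnum)."""
    b, c, x, y = lf.shape
    assert c == 1 and x % nnum == 0 and y % nnum == 0
    # split each spatial axis into (lenslet index, sub-aperture offset)
    a = lf.reshape(b, x // nnum, nnum, y // nnum, nnum)
    # move the two offset axes ahead of the lenslet axes, then merge them into channels
    return a.transpose(0, 2, 4, 1, 3).reshape(b, nnum * nnum, x // nnum, y // nnum)

nnum = 3
lf = np.arange(36).reshape(1, 1, 6, 6)
out = channels_from_lightfield(lf, nnum)
assert out.shape == (1, 9, 2, 2)
# channel u*nnum + v at lenslet (X, Y) holds raw pixel (X*nnum + u, Y*nnum + v)
assert out[0, 1 * nnum + 2, 1, 0] == lf[0, 0, 1 * nnum + 1, 0 * nnum + 2]
```

The `transpose(0, 2, 4, 1, 3)` is equivalent to the three pairwise `.transpose(...)` calls in the torch module.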
 
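With the default `c_res2d = (488, 488, "u244", 244)` and `c_res3d = (7, "u7", 7, 7)` (one "u" upsampling stage each), `get_scale` returns 2 * 2 = 4, and `get_shrink` walks `c_res3d` as s = 2 → 4 → 6 → 8, so a lateral input of shape (x, y) maps to (z_out, 4x − 16, 4y − 16). A pure-Python sketch of that bookkeeping (mirrors `get_scale`/`get_shrink`/`get_output_shape`; the example input size is hypothetical):

```python
def get_scale(c_res2d, c_res3d):
    # each "u..." entry doubles lateral resolution; 2d and 3d stages multiply
    ups2d = sum(isinstance(c, str) and c.startswith("u") for c in c_res2d)
    ups3d = sum(isinstance(c, str) and c.startswith("u") for c in c_res3d)
    return max(1, 2 * ups2d) * max(1, 2 * ups3d)

def get_shrink(c_res3d):
    # each valid 3d block crops 2 px per side; upsampling doubles the accumulated crop
    s = 0
    for c in c_res3d:
        if isinstance(c, str) and c.startswith("u"):
            s *= 2
        else:
            s += 2
    return s

def output_shape(z_out, ipt_shape,
                 c_res2d=(488, 488, "u244", 244), c_res3d=(7, "u7", 7, 7)):
    scale = get_scale(c_res2d, c_res3d)
    shrink = get_shrink(c_res3d)
    return (z_out,) + tuple(i * scale - 2 * shrink for i in ipt_shape)

print(output_shape(49, (19, 19)))  # hypothetical 19x19 lenslet input -> (49, 60, 60)
```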
1
+ # type: ignore
2
+ import inspect
3
+ from enum import Enum
4
+ from functools import partial
5
+ from typing import Optional, Sequence, Tuple, Union
6
+
7
+ import numpy as np
8
+ import torch.nn as nn
9
+
10
+ ### Inferno parts (adapted from inferno 0.4.2)
11
+
12
+
13
+ def assert_(condition, message="", exception_type=AssertionError):
14
+ """Like assert, but with arbitrary exception types."""
15
+ if not condition:
16
+ raise exception_type(message)
17
+
18
+
19
+ # proxy for generated classes in inferno
20
+ generated_inferno_classes = {}
21
+
22
+
23
+ def partial_cls(base_cls, name, fix=None, default=None):
24
+
25
+ # helper function
26
+ def insert_if_not_present(dict_a, dict_b):
27
+ for kw, val in dict_b.items():
28
+ if kw not in dict_a:
29
+ dict_a[kw] = val
30
+ return dict_a
31
+
32
+ # helper function
33
+ def insert_call_if_present(dict_a, dict_b, callback):
34
+ for kw, val in dict_b.items():
35
+ if kw not in dict_a:
36
+ dict_a[kw] = val
37
+ else:
38
+ callback(kw)
39
+ return dict_a
40
+
41
+ # helper class
42
+ class PartialCls(object):
43
+ def __init__(self, base_cls, name, fix=None, default=None):
44
+
45
+ self.base_cls = base_cls
46
+ self.name = name
47
+ self.fix = [fix, {}][fix is None]
48
+ self.default = [default, {}][default is None]
49
+
50
+ if self.fix.keys() & self.default.keys():
51
+ raise TypeError("fix and default share keys")
52
+
53
+ # remove binded kw
54
+ self._allowed_kw = self._get_allowed_kw()
55
+
56
+ def _get_allowed_kw(self):
57
+
58
+ argspec = inspect.getfullargspec(base_cls.__init__)
59
+ args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations = (
60
+ argspec
61
+ )
62
+
63
+ if varargs is not None:
64
+ raise TypeError(
65
+ "partial_cls can only be used if __init__ has no varargs"
66
+ )
67
+
68
+ if varkw is not None:
69
+ raise TypeError("partial_cls can only be used if __init__ has no varkw")
70
+
71
+ if kwonlyargs is not None and kwonlyargs != []:
72
+ raise TypeError("partial_cls can only be used without kwonlyargs")
73
+
74
+ if args is None or len(args) < 1:
75
+ raise TypeError("seems like self is missing")
76
+
77
+ return [kw for kw in args[1:] if kw not in self.fix]
78
+
79
+ def _build_kw(self, args, kwargs):
80
+ # handle *args
81
+ if len(args) > len(self._allowed_kw):
82
+ raise TypeError("to many arguments")
83
+
84
+ all_args = {}
85
+ for arg, akw in zip(args, self._allowed_kw):
86
+ all_args[akw] = arg
87
+
88
+ # handle **kwargs
89
+ intersection = self.fix.keys() & kwargs.keys()
90
+ if len(intersection) >= 1:
91
+ kw = intersection.pop()
92
+ raise TypeError(
93
+ "`{}.__init__` got unexpected keyword argument '{}'".format(
94
+ name, kw
95
+ )
96
+ )
97
+
98
+ def raise_cb(kw):
99
+ raise TypeError(
100
+ "{}.__init__ got multiple values for argument '{}'".format(name, kw)
101
+ )
102
+
103
+ all_args = insert_call_if_present(all_args, kwargs, raise_cb)
104
+
105
+ # handle fixed arguments
106
+ def raise_cb(kw):
107
+ raise TypeError()
108
+
109
+ all_args = insert_call_if_present(all_args, self.fix, raise_cb)
110
+
111
+ # handle defaults
112
+ all_args = insert_if_not_present(all_args, self.default)
113
+
114
+ # handle fixed
115
+ all_args.update(self.fix)
116
+
117
+ return all_args
118
+
119
+ def build_cls(self):
120
+
121
+ def new_init(self_of_new_cls, *args, **kwargs):
122
+ combined_args = self._build_kw(args=args, kwargs=kwargs)
123
+
124
+ # call base cls init
125
+ super(self_of_new_cls.__class__, self_of_new_cls).__init__(
126
+ **combined_args
127
+ )
128
+
129
+ return type(name, (self.base_cls,), {"__init__": new_init})
130
+
131
+ return PartialCls(
132
+ base_cls=base_cls, name=name, fix=fix, default=default
133
+ ).build_cls()
134
+
135
+
136
+ def register_partial_cls(base_cls, name, fix=None, default=None):
137
+ generatedClass = partial_cls(base_cls=base_cls, name=name, fix=fix, default=default)
138
+ generated_inferno_classes[generatedClass.__name__] = generatedClass
139
+
140
+
141
+ class Initializer(object):
142
+ """
143
+ Base class for all initializers.
144
+ """
145
+
146
+ # TODO Support LSTMs and GRUs
147
+ VALID_LAYERS = {
148
+ "Conv1d",
149
+ "Conv2d",
150
+ "Conv3d",
151
+ "ConvTranspose1d",
152
+ "ConvTranspose2d",
153
+ "ConvTranspose3d",
154
+ "Linear",
155
+ "Bilinear",
156
+ "Embedding",
157
+ }
158
+
159
+ def __call__(self, module):
160
+ module_class_name = module.__class__.__name__
161
+ if module_class_name in self.VALID_LAYERS:
162
+ # Apply to weight and bias
163
+ try:
164
+ if hasattr(module, "weight"):
165
+ self.call_on_weight(module.weight.data)
166
+ except NotImplementedError:
167
+ # Don't cry if it's not implemented
168
+ pass
169
+
170
+ try:
171
+ if hasattr(module, "bias"):
172
+ self.call_on_bias(module.bias.data)
173
+ except NotImplementedError:
174
+ pass
175
+
176
+ return module
177
+
178
+ def call_on_bias(self, tensor):
179
+ return self.call_on_tensor(tensor)
180
+
181
+ def call_on_weight(self, tensor):
182
+ return self.call_on_tensor(tensor)
183
+
184
+ def call_on_tensor(self, tensor):
185
+ raise NotImplementedError
186
+
187
+ @classmethod
188
+ def initializes_weight(cls):
189
+ return "call_on_tensor" in cls.__dict__ or "call_on_weight" in cls.__dict__
190
+
191
+ @classmethod
192
+ def initializes_bias(cls):
193
+ return "call_on_tensor" in cls.__dict__ or "call_on_bias" in cls.__dict__
194
+
195
+
196
+ class Initialization(Initializer):
197
+ def __init__(self, weight_initializer=None, bias_initializer=None):
198
+ if weight_initializer is None:
199
+ self.weight_initializer = Initializer()
200
+ else:
201
+ if isinstance(weight_initializer, Initializer):
202
+ assert weight_initializer.initializes_weight()
203
+ self.weight_initializer = weight_initializer
204
+ elif isinstance(weight_initializer, str):
205
+ init_function = getattr(nn.init, weight_initializer, None)
206
+ assert init_function is not None
207
+ self.weight_initializer = WeightInitFunction(
208
+ init_function=init_function
209
+ )
210
+ else:
211
+ # Provison for weight_initializer to be a function
212
+ assert callable(weight_initializer)
213
+ self.weight_initializer = WeightInitFunction(
214
+ init_function=weight_initializer
215
+ )
216
+
217
+ if bias_initializer is None:
218
+ self.bias_initializer = Initializer()
219
+ else:
220
+ if isinstance(bias_initializer, Initializer):
221
+ assert bias_initializer.initializes_bias
222
+ self.bias_initializer = bias_initializer
223
+ elif isinstance(bias_initializer, str):
224
+ init_function = getattr(nn.init, bias_initializer, None)
225
+ assert init_function is not None
226
+ self.bias_initializer = BiasInitFunction(init_function=init_function)
227
+ else:
228
+ assert callable(bias_initializer)
229
+ self.bias_initializer = BiasInitFunction(init_function=bias_initializer)
230
+
231
+ def call_on_weight(self, tensor):
232
+ return self.weight_initializer.call_on_weight(tensor)
233
+
234
+ def call_on_bias(self, tensor):
235
+ return self.bias_initializer.call_on_bias(tensor)
236
+
237
+
238
+ class WeightInitFunction(Initializer):
239
+ def __init__(self, init_function, *init_function_args, **init_function_kwargs):
240
+ super(WeightInitFunction, self).__init__()
241
+ assert callable(init_function)
242
+ self.init_function = init_function
243
+ self.init_function_args = init_function_args
244
+ self.init_function_kwargs = init_function_kwargs
245
+
246
+ def call_on_weight(self, tensor):
247
+ return self.init_function(
248
+ tensor, *self.init_function_args, **self.init_function_kwargs
249
+ )
250
+
251
+
252
+ class BiasInitFunction(Initializer):
253
+ def __init__(self, init_function, *init_function_args, **init_function_kwargs):
254
+ super(BiasInitFunction, self).__init__()
255
+ assert callable(init_function)
256
+ self.init_function = init_function
257
+ self.init_function_args = init_function_args
258
+ self.init_function_kwargs = init_function_kwargs
259
+
260
+ def call_on_bias(self, tensor):
261
+ return self.init_function(
262
+ tensor, *self.init_function_args, **self.init_function_kwargs
263
+ )
264
+
265
+
266
+ class TensorInitFunction(Initializer):
267
+ def __init__(self, init_function, *init_function_args, **init_function_kwargs):
268
+ super(TensorInitFunction, self).__init__()
269
+ assert callable(init_function)
270
+ self.init_function = init_function
271
+ self.init_function_args = init_function_args
272
+ self.init_function_kwargs = init_function_kwargs
273
+
274
+ def call_on_tensor(self, tensor):
275
+ return self.init_function(
276
+ tensor, *self.init_function_args, **self.init_function_kwargs
277
+ )
278
+
279
+
280
+ class Constant(Initializer):
281
+ """Initialize with a constant."""
282
+
283
+ def __init__(self, constant):
284
+ self.constant = constant
285
+
286
+ def call_on_tensor(self, tensor):
287
+ tensor.fill_(self.constant)
288
+ return tensor
289
+
290
+
291
+ class NormalWeights(Initializer):
292
+ """
293
+ Initialize weights with random numbers drawn from the normal distribution at
294
+ `mean` and `stddev`.
295
+ """
296
+
297
+ def __init__(self, mean=0.0, stddev=1.0, sqrt_gain_over_fan_in=None):
298
+ self.mean = mean
299
+ self.stddev = stddev
300
+ self.sqrt_gain_over_fan_in = sqrt_gain_over_fan_in
301
+
302
+ def compute_fan_in(self, tensor):
303
+ if tensor.dim() == 2:
304
+ return tensor.size(1)
305
+ else:
306
+ return np.prod(list(tensor.size())[1:])
307
+
308
+ def call_on_weight(self, tensor):
309
+ # Compute stddev if required
310
+ if self.sqrt_gain_over_fan_in is not None:
311
+ stddev = self.stddev * np.sqrt(
312
+ self.sqrt_gain_over_fan_in / self.compute_fan_in(tensor)
313
+ )
314
+ else:
315
+ stddev = self.stddev
316
+ # Init
317
+ tensor.normal_(self.mean, stddev)
318
+
319
+
320
+ class OrthogonalWeightsZeroBias(Initialization):
321
+ def __init__(self, orthogonal_gain=1.0):
322
+ # This prevents a deprecated warning in Pytorch 0.4+
323
+ orthogonal = getattr(nn.init, "orthogonal_", nn.init.orthogonal)
324
+ super(OrthogonalWeightsZeroBias, self).__init__(
325
+ weight_initializer=partial(orthogonal, gain=orthogonal_gain),
326
+ bias_initializer=Constant(0.0),
327
+ )
328
+
329
+
330
+ class KaimingNormalWeightsZeroBias(Initialization):
331
+ def __init__(self, relu_leakage=0):
332
+ # This prevents a deprecated warning in Pytorch 0.4+
333
+ kaiming_normal = getattr(nn.init, "kaiming_normal_", nn.init.kaiming_normal)
334
+ super(KaimingNormalWeightsZeroBias, self).__init__(
335
+ weight_initializer=partial(kaiming_normal, a=relu_leakage),
336
+ bias_initializer=Constant(0.0),
337
+ )
338
+
339
+
340
+ class SELUWeightsZeroBias(Initialization):
341
+ def __init__(self):
342
+ super(SELUWeightsZeroBias, self).__init__(
343
+ weight_initializer=NormalWeights(sqrt_gain_over_fan_in=1.0),
344
+ bias_initializer=Constant(0.0),
345
+ )
346
+
347
+
348
+ class ELUWeightsZeroBias(Initialization):
349
+ def __init__(self):
350
+ super(ELUWeightsZeroBias, self).__init__(
351
+ weight_initializer=NormalWeights(sqrt_gain_over_fan_in=1.5505188080679277),
352
+ bias_initializer=Constant(0.0),
353
+ )
354
+
355
+
356
+ class BatchNormND(nn.Module):
357
+ def __init__(
358
+ self,
359
+ dim,
360
+ num_features,
361
+ eps=1e-5,
362
+ momentum=0.1,
363
+ affine=True,
364
+ track_running_stats=True,
365
+ ):
366
+ super(BatchNormND, self).__init__()
367
+ assert dim in [1, 2, 3]
368
+ self.bn = getattr(nn, "BatchNorm{}d".format(dim))(
369
+ num_features=num_features,
370
+ eps=eps,
371
+ momentum=momentum,
372
+ affine=affine,
373
+ track_running_stats=track_running_stats,
374
+ )
375
+
376
+ def forward(self, x):
377
+ return self.bn(x)
378
+
379
+
380
+ class ConvActivation(nn.Module):
381
+ """Convolutional layer with 'SAME' padding by default followed by an activation."""
382
+
383
+ def __init__(
384
+ self,
385
+ in_channels,
386
+ out_channels,
387
+ kernel_size,
388
+ dim,
389
+ activation,
390
+ stride=1,
391
+ dilation=1,
392
+ groups=None,
393
+ depthwise=False,
394
+ bias=True,
395
+ deconv=False,
396
+ initialization=None,
397
+ valid_conv=False,
398
+ ):
399
+ super(ConvActivation, self).__init__()
400
+ # Validate dim
401
+ assert_(
402
+ dim in [1, 2, 3],
403
+ "`dim` must be one of [1, 2, 3], got {}.".format(dim),
404
+ )
405
+ self.dim = dim
406
+ # Check if depthwise
407
+ if depthwise:
408
+
409
+ # We know that in_channels == out_channels, but we also want a consistent API.
410
+ # As a compromise, we allow that out_channels be None or 'auto'.
411
+ out_channels = (
412
+ in_channels if out_channels in [None, "auto"] else out_channels
413
+ )
414
+ assert_(
415
+ in_channels == out_channels,
416
+ "For depthwise convolutions, number of input channels (given: {}) "
417
+ "must equal the number of output channels (given {}).".format(
418
+ in_channels, out_channels
419
+ ),
420
+ ValueError,
421
+ )
422
+ assert_(
423
+ groups is None or groups == in_channels,
424
+ "For depthwise convolutions, groups (given: {}) must "
425
+ "equal the number of channels (given: {}).".format(groups, in_channels),
426
+ )
427
+ groups = in_channels
428
+ else:
429
+ groups = 1 if groups is None else groups
430
+ self.depthwise = depthwise
431
+ if valid_conv:
432
+ self.conv = getattr(nn, "Conv{}d".format(self.dim))(
433
+ in_channels=in_channels,
434
+ out_channels=out_channels,
435
+ kernel_size=kernel_size,
436
+ stride=stride,
437
+ dilation=dilation,
438
+ groups=groups,
439
+ bias=bias,
440
+ )
441
+ elif not deconv:
442
+ # Get padding
443
+ padding = self.get_padding(kernel_size, dilation)
444
+ self.conv = getattr(nn, "Conv{}d".format(self.dim))(
445
+ in_channels=in_channels,
446
+ out_channels=out_channels,
447
+ kernel_size=kernel_size,
448
+ padding=padding,
449
+ stride=stride,
450
+ dilation=dilation,
451
+ groups=groups,
452
+ bias=bias,
453
+ )
454
+ else:
455
+ self.conv = getattr(nn, "ConvTranspose{}d".format(self.dim))(
456
+ in_channels=in_channels,
457
+ out_channels=out_channels,
458
+ kernel_size=kernel_size,
459
+ stride=stride,
460
+ dilation=dilation,
461
+ groups=groups,
462
+ bias=bias,
463
+ )
464
+ if initialization is None:
465
+ pass
466
+ elif isinstance(initialization, Initializer):
467
+ self.conv.apply(initialization)
468
+ else:
469
+ raise NotImplementedError
470
+
471
+ if isinstance(activation, str):
472
+ self.activation = getattr(nn, activation)()
473
+ elif isinstance(activation, nn.Module):
474
+ self.activation = activation
475
+ elif activation is None:
476
+ self.activation = None
477
+ else:
478
+ raise NotImplementedError
479
+
480
+ def forward(self, input):
481
+ conved = self.conv(input)
482
+ if self.activation is not None:
483
+ activated = self.activation(conved)
484
+ else:
485
+ # No activation
486
+ activated = conved
487
+ return activated
488
+
489
+ def _pair_or_triplet(self, object_):
490
+ if isinstance(object_, (list, tuple)):
491
+ assert len(object_) == self.dim
492
+ return object_
493
+ else:
494
+ object_ = [object_] * self.dim
495
+ return object_
496
+
497
+ def _get_padding(self, _kernel_size, _dilation):
498
+ assert isinstance(_kernel_size, int)
499
+ assert isinstance(_dilation, int)
500
+ assert _kernel_size % 2 == 1
501
+ return ((_kernel_size - 1) // 2) * _dilation
502
+
503
+ def get_padding(self, kernel_size, dilation):
504
+ kernel_size = self._pair_or_triplet(kernel_size)
505
+ dilation = self._pair_or_triplet(dilation)
506
+ padding = [
507
+ self._get_padding(_kernel_size, _dilation)
508
+ for _kernel_size, _dilation in zip(kernel_size, dilation)
509
+ ]
510
+ return tuple(padding)
511
+
512
+
513
+ # for consistency
514
+ ConvActivationND = ConvActivation
515
+
516
+
517
+ class _BNReLUSomeConv(object):
518
+ def forward(self, input):
519
+ normed = self.batchnorm(input)
520
+ activated = self.activation(normed)
521
+ conved = self.conv(activated)
522
+ return conved
523
+
524
+
525
+ class BNReLUConvBaseND(_BNReLUSomeConv, ConvActivation):
526
+ def __init__(
527
+ self,
528
+ in_channels,
529
+ out_channels,
530
+ kernel_size,
531
+ dim,
532
+ stride=1,
533
+ dilation=1,
534
+ deconv=False,
535
+ ):
536
+
537
+ super(BNReLUConvBaseND, self).__init__(
538
+ in_channels=in_channels,
539
+ out_channels=out_channels,
540
+ kernel_size=kernel_size,
541
+ dim=dim,
542
+ stride=stride,
543
+ activation=nn.ReLU(inplace=True),
544
+ dilation=dilation,
545
+ deconv=deconv,
546
+ initialization=KaimingNormalWeightsZeroBias(0),
547
+ )
548
+ self.batchnorm = BatchNormND(dim, in_channels)
549
+
550
+
551
+ def _register_bnr_conv_cls(conv_name, fix=None, default=None):
552
+ if fix is None:
553
+ fix = {}
554
+ if default is None:
555
+ default = {}
556
+ for dim in [1, 2, 3]:
557
+
558
+ cls_name = "BNReLU{}ND".format(conv_name)
559
+ register_partial_cls(BNReLUConvBaseND, cls_name, fix=fix, default=default)
560
+
561
+ for dim in [1, 2, 3]:
562
+ cls_name = "BNReLU{}{}D".format(conv_name, dim)
563
+
564
+ register_partial_cls(
565
+ BNReLUConvBaseND, cls_name, fix={**fix, "dim": dim}, default=default
566
+ )
567
+
568
+
569
+ def _register_conv_cls(conv_name, fix=None, default=None):
570
+ if fix is None:
571
+ fix = {}
572
+ if default is None:
573
+ default = {}
574
+
575
+ # simple conv activation
576
+ activations = ["ReLU", "ELU", "Sigmoid", "SELU", ""]
577
+ init_map = {"ReLU": KaimingNormalWeightsZeroBias, "SELU": SELUWeightsZeroBias}
578
+ for activation_str in activations:
579
+ cls_name = cls_name = "{}{}ND".format(conv_name, activation_str)
580
+ initialization_cls = init_map.get(activation_str, OrthogonalWeightsZeroBias)
581
+ if activation_str == "":
582
+ activation = None
583
+ _fix = {**fix}
584
+ _default = {"activation": None}
585
+ elif activation_str == "SELU":
586
+ activation = nn.SELU(inplace=True)
587
+ _fix = {**fix, "activation": activation}
588
+ _default = {**default}
589
+ else:
590
+ activation = activation_str
591
+ _fix = {**fix, "activation": activation}
592
+ _default = {**default}
593
+
594
+ register_partial_cls(
595
+ ConvActivation,
596
+ cls_name,
597
+ fix=_fix,
598
+ default={**_default, "initialization": initialization_cls()},
599
+ )
600
+ for dim in [1, 2, 3]:
601
+ cls_name = "{}{}{}D".format(conv_name, activation_str, dim)
602
+ register_partial_cls(
603
+ ConvActivation,
604
+ cls_name,
605
+ fix={**_fix, "dim": dim},
606
+ default={**_default, "initialization": initialization_cls()},
607
+ )
608
+
609
+
610
+ _register_conv_cls("Conv")
611
+ _register_conv_cls("ValidConv", fix=dict(valid_conv=True))
612
+
613
+ Conv2D = generated_inferno_classes["Conv2D"]
614
+ ValidConv3D = generated_inferno_classes["ValidConv3D"]
615
+
616
+
617
+ ### HyLFM architecture
618
+ class Crop(nn.Module):
619
+ def __init__(self, *slices: slice):
620
+ super().__init__()
621
+ self.slices = slices
622
+
623
+ def extra_repr(self):
624
+ return str(self.slices)
625
+
626
+ def forward(self, input):
627
+ return input[self.slices]
628
+
629
+
630
+ class ChannelFromLightField(nn.Module):
631
+ def __init__(self, nnum: int):
632
+ super().__init__()
633
+ self.nnum = nnum
634
+
635
+ def forward(self, tensor):
636
+ assert len(tensor.shape) == 4, tensor.shape
637
+ b, c, x, y = tensor.shape
638
+ assert c == 1
639
+ assert x % self.nnum == 0, (x, self.nnum)
640
+ assert y % self.nnum == 0, (y, self.nnum)
641
+ return (
642
+ tensor.reshape(b, x // self.nnum, self.nnum, y // self.nnum, self.nnum)
643
+ .transpose(1, 2)
644
+ .transpose(2, 4)
645
+ .transpose(3, 4)
646
+ .reshape(b, self.nnum**2, x // self.nnum, y // self.nnum)
647
+ )
648
+
649
+
650
+ class ResnetBlock(nn.Module):
651
+ def __init__(
652
+ self,
653
+ in_n_filters,
654
+ n_filters,
655
+ kernel_size=(3, 3),
656
+ batch_norm=False,
657
+ conv_per_block=2,
658
+ valid: bool = False,
659
+ activation: str = "ReLU",
660
+ ):
661
+ super().__init__()
662
+ if batch_norm and activation != "ReLU":
663
+ raise NotImplementedError("batch_norm with non ReLU activation")
664
+
665
+ assert isinstance(kernel_size, tuple), kernel_size
666
+ assert conv_per_block >= 2
667
+ self.debug = False # sys.gettrace() is not None
668
+
669
+ Conv = generated_inferno_classes[
670
+ f"{'BNReLU' if batch_norm else ''}{'Valid' if valid else ''}Conv{'' if batch_norm else activation}{len(kernel_size)}D"
671
+ ]
672
+ FinalConv = generated_inferno_classes[
673
+ f"{'BNReLU' if batch_norm else ''}{'Valid' if valid else ''}Conv{len(kernel_size)}D"
674
+ ]
675
+
676
+ layers = []
677
+ layers.append(
678
+ Conv(
679
+ in_channels=in_n_filters,
680
+ out_channels=n_filters,
681
+ kernel_size=kernel_size,
682
+ )
683
+ )
684
+
685
+ for _ in range(conv_per_block - 2):
686
+ layers.append(Conv(n_filters, n_filters, kernel_size))
687
+
688
+ layers.append(FinalConv(n_filters, n_filters, kernel_size))
689
+
690
+ self.block = nn.Sequential(*layers)
691
+
692
+ if n_filters != in_n_filters:
693
+ ProjConv = generated_inferno_classes[f"Conv{len(kernel_size)}D"]
694
+ self.projection_layer = ProjConv(in_n_filters, n_filters, kernel_size=1)
695
+ else:
696
+ self.projection_layer = None
697
+
698
+ if valid:
699
+ crop_each_side = [conv_per_block * (ks // 2) for ks in kernel_size]
700
+ self.crop = Crop(..., *[slice(c, -c) for c in crop_each_side])
701
+ else:
702
+ self.crop = None
703
+
704
+ self.relu = nn.ReLU()
705
+
706
+ # determine shrinkage
707
+ # self.shrinkage = (1, 1) + tuple([conv_per_block * (ks - 1) for ks in kernel_size])
708
+
709
+ def forward(self, input):
710
+ x = self.block(input)
711
+ if self.crop is not None:
712
+ input = self.crop(input)
713
+
714
+ if self.projection_layer is None:
715
+ x = x + input
716
+ else:
717
+ projected = self.projection_layer(input)
718
+ x = x + projected
719
+
720
+ x = self.relu(x)
721
+ return x
722
+
723
+
724
+ class HyLFM_Net(nn.Module):
725
+ class InitName(str, Enum):
726
+ uniform_ = "uniform"
727
+ normal_ = "normal"
728
+ constant_ = "constant"
729
+ eye_ = "eye"
730
+ dirac_ = "dirac"
731
+ xavier_uniform_ = "xavier_uniform"
732
+ xavier_normal_ = "xavier_normal"
733
+ kaiming_uniform_ = "kaiming_uniform"
734
+ kaiming_normal_ = "kaiming_normal"
735
+ orthogonal_ = "orthogonal"
736
+ sparse_ = "sparse"
737
+
738
+ def __init__(
739
+ self,
740
+ *,
741
+ z_out: int,
742
+ nnum: int,
743
+ kernel2d: int = 3,
744
+ conv_per_block2d: int = 2,
745
+ c_res2d: Sequence[Union[int, str]] = (488, 488, "u244", 244),
746
+ last_kernel2d: int = 1,
747
+ c_in_3d: int = 7,
748
+ kernel3d: int = 3,
749
+ conv_per_block3d: int = 2,
750
+ c_res3d: Sequence[str] = (7, "u7", 7, 7),
751
+ init_fn: Union[InitName, str] = InitName.xavier_uniform_.value,
752
+ final_activation: Optional[str] = None,
753
+ ):
754
+ super().__init__()
755
+ self.channel_from_lf = ChannelFromLightField(nnum=nnum)
756
+ init_fn = self.InitName(init_fn)
757
+
758
+ if hasattr(nn.init, f"{init_fn.value}_"):
759
+ # prevents deprecation warning
760
+ init_fn = getattr(nn.init, f"{init_fn.value}_")
761
+ else:
762
+ init_fn = getattr(nn.init, init_fn.value)
763
+
764
+ self.c_res2d = list(c_res2d)
765
+ self.c_res3d = list(c_res3d)
766
+ c_res3d = c_res3d
767
+ self.nnum = nnum
768
+ self.z_out = z_out
769
+ if kernel3d != 3:
770
+ raise NotImplementedError("z_out expansion for other res3d kernel")
771
+
772
+ dz = 2 * conv_per_block3d * (kernel3d // 2)
773
+ for c in c_res3d:
774
+ if isinstance(c, int) or not c.startswith("u"):
775
+ z_out += dz
776
+
777
+ # z_out += 4 * (len(c_res3d) - 2 * sum([layer == "u" for layer in c_res3d])) # add z_out for valid 3d convs
778
+
779
+ assert (
780
+ c_res2d[-1] != "u"
781
+ ), "missing # output channels for upsampling in 'c_res2d'"
782
+ assert (
783
+ c_res3d[-1] != "u"
784
+ ), "missing # output channels for upsampling in 'c_res3d'"
785
+
786
+        res2d = []
+        c_in = nnum**2
+        c_out = c_in
+        for i in range(len(c_res2d)):
+            if not isinstance(c_res2d[i], int) and c_res2d[i].startswith("u"):
+                c_out = int(c_res2d[i][1:])
+                res2d.append(
+                    nn.ConvTranspose2d(
+                        in_channels=c_in,
+                        out_channels=c_out,
+                        kernel_size=2,
+                        stride=2,
+                        padding=0,
+                        output_padding=0,
+                    )
+                )
+            else:
+                c_out = int(c_res2d[i])
+                res2d.append(
+                    ResnetBlock(
+                        in_n_filters=c_in,
+                        n_filters=c_out,
+                        kernel_size=(kernel2d, kernel2d),
+                        valid=False,
+                        conv_per_block=conv_per_block2d,
+                    )
+                )
+
+            c_in = c_out
+
+        self.res2d = nn.Sequential(*res2d)
+
+        if "gain" in inspect.signature(init_fn).parameters:
+            init_fn_conv2d = partial(init_fn, gain=nn.init.calculate_gain("relu"))
+        else:
+            init_fn_conv2d = init_fn
+
+        init = Initialization(
+            weight_initializer=init_fn_conv2d, bias_initializer=Constant(0.0)
+        )
+        self.conv2d = Conv2D(
+            c_out,
+            z_out * c_in_3d,
+            last_kernel2d,
+            activation="ReLU",
+            initialization=init,
+        )
+
+        self.c2z = lambda ipt, ip3=c_in_3d: ipt.view(
+            ipt.shape[0], ip3, z_out, *ipt.shape[2:]
+        )
+
+        res3d = []
+        c_in = c_in_3d
+        c_out = c_in
+        for i in range(len(c_res3d)):
+            if not isinstance(c_res3d[i], int) and c_res3d[i].startswith("u"):
+                c_out = int(c_res3d[i][1:])
+                res3d.append(
+                    nn.ConvTranspose3d(
+                        in_channels=c_in,
+                        out_channels=c_out,
+                        kernel_size=(3, 2, 2),
+                        stride=(1, 2, 2),
+                        padding=(1, 0, 0),
+                        output_padding=0,
+                    )
+                )
+            else:
+                c_out = int(c_res3d[i])
+                res3d.append(
+                    ResnetBlock(
+                        in_n_filters=c_in,
+                        n_filters=c_out,
+                        kernel_size=(kernel3d, kernel3d, kernel3d),
+                        valid=True,
+                        conv_per_block=conv_per_block3d,
+                    )
+                )
+
+            c_in = c_out
+
+        self.res3d = nn.Sequential(*res3d)
+
+        if "gain" in inspect.signature(init_fn).parameters:
+            init_fn_conv3d = partial(init_fn, gain=nn.init.calculate_gain("linear"))
+        else:
+            init_fn_conv3d = init_fn
+
+        init = Initialization(
+            weight_initializer=init_fn_conv3d, bias_initializer=Constant(0.0)
+        )
+        self.conv3d = ValidConv3D(c_out, 1, (1, 1, 1), initialization=init)
+
+        if final_activation is None:
+            self.final_activation = None
+        elif final_activation == "sigmoid":
+            self.final_activation = nn.Sigmoid()
+        else:
+            raise NotImplementedError(final_activation)
+
+    def forward(self, x):
+        x = self.channel_from_lf(x)
+        x = self.res2d(x)
+        x = self.conv2d(x)
+        x = self.c2z(x)
+        x = self.res3d(x)
+        x = self.conv3d(x)
+
+        if self.final_activation is not None:
+            x = self.final_activation(x)
+
+        return x
+
+    def get_scale(self, ipt_shape: Optional[Tuple[int, int]] = None) -> int:
+        s = max(
+            1,
+            2
+            * sum(
+                isinstance(res2d, str) and res2d.startswith("u")
+                for res2d in self.c_res2d
+            ),
+        ) * max(
+            1,
+            2
+            * sum(
+                isinstance(res3d, str) and res3d.startswith("u")
+                for res3d in self.c_res3d
+            ),
+        )
+        return s
+
+    def get_shrink(self, ipt_shape: Optional[Tuple[int, int]] = None) -> int:
+        s = 0
+        for res in self.c_res3d:
+            if isinstance(res, str) and res.startswith("u"):
+                s *= 2
+            else:
+                s += 2
+
+        return s
+
+    def get_output_shape(self, ipt_shape: Tuple[int, int]) -> Tuple[int, int, int]:
+        scale = self.get_scale(ipt_shape)
+        shrink = self.get_shrink(ipt_shape)
+        return (self.z_out,) + tuple(i * scale - 2 * shrink for i in ipt_shape)
+
+
+if __name__ == "__main__":
+    # Example usage
+    model = HyLFM_Net(
+        z_out=9,
+        nnum=5,
+        kernel2d=3,
+        conv_per_block2d=2,
+        c_res2d=(12, 14, "u14", 8),
+        last_kernel2d=1,
+        c_in_3d=7,
+        kernel3d=3,
+        conv_per_block3d=2,
+        c_res3d=(7, "u7", 7, 7),
+        init_fn="xavier_uniform",
+        final_activation="sigmoid",
+    )
+    print(model)
+    print(model.get_output_shape((64, 64)))
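
The shape bookkeeping in `get_scale`/`get_shrink`/`get_output_shape` can be sanity-checked without torch. The sketch below mirrors that logic as a standalone function (the helper name `output_shape` and its free-function form are ours, not part of `model.py`), using the same configuration as the `__main__` example:

```python
def output_shape(z_out, c_res2d, c_res3d, ipt_shape):
    """Mirror of HyLFM_Net's get_output_shape for a given layer config."""
    ups2d = sum(isinstance(c, str) and c.startswith("u") for c in c_res2d)
    ups3d = sum(isinstance(c, str) and c.startswith("u") for c in c_res3d)
    # each "u..." entry is a stride-2 transposed conv doubling lateral resolution
    scale = max(1, 2 * ups2d) * max(1, 2 * ups3d)
    # each valid 3d ResnetBlock trims the lateral borders; shrink accumulated
    # before an upsampling doubles, since it is then measured on the finer grid
    shrink = 0
    for c in c_res3d:
        if isinstance(c, str) and c.startswith("u"):
            shrink *= 2
        else:
            shrink += 2
    return (z_out,) + tuple(i * scale - 2 * shrink for i in ipt_shape)

print(output_shape(9, (12, 14, "u14", 8), (7, "u7", 7, 7), (64, 64)))
# → (9, 240, 240)
```

With one upsampling in each of `c_res2d` and `c_res3d`, a 64×64 light-field patch is scaled by 4 laterally and trimmed by 2·8 voxels per axis by the valid 3D blocks, giving a 9×240×240 output stack.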