Code for the [paper](https://arxiv.org/pdf/2403.03542) *DPOT: Auto-Regressive Denoising Operator Transformer for Large-Scale PDE Pre-Training* (ICML 2024). DPOT pretrains neural operator transformers (from **7M** to **1B** parameters) on multiple PDE datasets. We will release the pre-trained weights soon.
Our pre-trained DPOT achieves state-of-the-art performance on multiple PDE datasets and can be fine-tuned for different types of downstream PDE problems.

### Pre-trained Model Configuration

We have five pre-trained checkpoints of different sizes.

| Size   | Attention dim | MLP dim | Layers | Heads | Model size |
| ------ | ------------- | ------- | ------ | ----- | ---------- |
| Tiny   | 512           | 512     | 4      | 4     | 7M         |
| Small  | 1024          | 1024    | 6      | 8     | 30M        |
| Medium | 1024          | 4096    | 12     | 8     | 122M       |
| Large  | 1536          | 6144    | 24     | 16    | 509M       |
| Huge   | 2048          | 8092    | 27     | 8     | 1.03B      |

##### Dataset Protocol

All datasets are stored in HDF5 format and contain a `data` field. Some datasets are stored as individual HDF5 files, while others are packed into a single HDF5 file.

Download the original files from the sources below and preprocess them into the `data/` folder.

| Dataset       | Link                                                          |
| ------------- | ------------------------------------------------------------- |
| PDEArena data | [Here](https://microsoft.github.io/pdearena/datadownload/)    |
| CFDbench data | [Here](https://cloud.tsinghua.edu.cn/d/435413b55dea434297d1/) |

In `data_generation/preprocess.py`, we have the script for preprocessing the data.
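As a minimal sketch of this protocol (the file name and array shape below are assumptions, not a real dataset's layout), a `data` field can be written and read with `h5py`:

```python
import h5py
import numpy as np

# Write a toy file following the protocol above: a single `data` field.
# The file name and shape are illustrative only.
demo = np.random.rand(4, 16, 16, 10, 1).astype(np.float32)  # (n, x, y, t, c)
with h5py.File("demo.hdf5", "w") as f:
    f.create_dataset("data", data=demo)

# Read it back the way a data loader would.
with h5py.File("demo.hdf5", "r") as f:
    data = f["data"][:]
print(data.shape)
```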
All dataset configurations live in `utils/make_master_file.py`. When a new dataset is merged, add a configuration dict for it there. The file stores only relative paths, so the code can run from any location.

Create the data directory first:

```bash
mkdir data
```

##### Single GPU Pre-training

`train_temporal.py` is the single-GPU pre-training script. Start a training run with

```bash
python train_temporal.py --model DPOT --train_paths ns2d_fno_1e-5 --test_paths ns2d_fno_1e-5 --gpu 0
```

Alternatively, write a configuration file such as `configs/ns2d.yaml` and launch the run (free GPUs are selected automatically) with

```bash
python trainer.py --config_file ns2d.yaml
```

##### Multiple GPU Pre-training

For multi-GPU pre-training, use `parallel_trainer.py` with a parallel configuration file:

```bash
python parallel_trainer.py --config_file ns2d_parallel.yaml
```

##### Configuration file

Configurations are YAML files that specify the command-line arguments. To run multiple tasks from one file, move the varying parameters into the `tasks` field:

```yaml
model: DPOT
width: 512
tasks:
  lr: [0.001, 0.0001]
  batch_size: [256, 32]
```

Submitting this configuration to `trainer.py` starts two tasks: one with `lr=0.001, batch_size=256` and one with `lr=0.0001, batch_size=32`.
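The expansion can be sketched as follows, assuming `trainer.py` pairs the `tasks` lists by position (which matches the two-task behaviour described above):

```python
# Expand a config with a `tasks` field into per-task configs.
# Assumes list entries are paired by position (zipped); the real
# trainer.py implementation may differ.
config = {
    "model": "DPOT",
    "width": 512,
    "tasks": {"lr": [0.001, 0.0001], "batch_size": [256, 32]},
}

base = {k: v for k, v in config.items() if k != "tasks"}
keys = list(config["tasks"])
runs = [dict(base, **dict(zip(keys, values)))
        for values in zip(*config["tasks"].values())]
print(len(runs))  # 2
```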

##### Requirements

Install the following packages via conda:

```bash
conda install pytorch==2.0.0 torchvision==0.15.0 torchaudio==2.0.0 pytorch-cuda=11.7 -c pytorch -c nvidia
conda install matplotlib scikit-learn scipy pandas h5py -c conda-forge
conda install timm einops tensorboard -c conda-forge
```

### Code Structure

- `README.md`
- `train_temporal.py`: main script for single-GPU pre-training of the auto-regressive model
- `trainer.py`: framework for automatically scheduling training tasks for parameter tuning
- `utils/`
  - `criterion.py`: relative-error loss functions
  - `griddataset.py`: dataset for a mixture of temporal uniform-grid datasets
  - `make_master_file.py`: dataset configuration file
  - `normalizer`: normalization methods (TODO: implement instance reversible norm)
  - `optimizer`: Adam/AdamW/Lamb optimizers supporting complex numbers
  - `utilities.py`: other auxiliary functions
- `configs/`: configuration files for pre-training or fine-tuning
- `models/`
  - `dpot.py`: DPOT model
  - `fno.py`: FNO with group normalization
  - `mlp.py`
- `data_generation/`: code for preprocessing data (ask hzk if you want to use them)
  - `darcy/`
  - `ns2d/`
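For reference, the relative-error loss provided by `utils/criterion.py` is typically of the following form; this is a sketch, and the repo's exact reduction and edge-case handling are assumptions:

```python
import torch

def rel_l2_error(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Per-sample relative L2 error ||pred - target|| / ||target||,
    # averaged over the batch. A sketch of the kind of loss in
    # utils/criterion.py; the actual implementation may differ.
    b = pred.shape[0]
    num = (pred - target).reshape(b, -1).norm(dim=1)
    den = target.reshape(b, -1).norm(dim=1)
    return (num / den).mean()

x = torch.ones(2, 8, 8)
print(rel_l2_error(2 * x, x).item())  # 1.0
```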