32353912
Browse files- diffusion.py +7 -6
- perlmutter_diffusion.sbatch +2 -2
- tensorboard.ipynb +8 -26
diffusion.py
CHANGED
|
@@ -234,7 +234,7 @@ class TrainConfig:
|
|
| 234 |
###########################
|
| 235 |
## hardcoding these here ##
|
| 236 |
###########################
|
| 237 |
-
push_to_hub = True
|
| 238 |
hub_model_id = "Xsmos/ml21cm"
|
| 239 |
hub_private_repo = False
|
| 240 |
dataset_name = "/storage/home/hcoda1/3/bxia34/scratch/LEN128-DIM64-CUB8.h5"
|
|
@@ -438,13 +438,13 @@ class DDPM21CM:
|
|
| 438 |
|
| 439 |
del dataset
|
| 440 |
|
| 441 |
-
def transform(self, img, idx):
|
| 442 |
#flip along x or y or both
|
| 443 |
-
flip_xy = [i+
|
| 444 |
img[idx] = torch.flip(img[idx], dims=flip_xy)
|
| 445 |
# flip diagonally
|
| 446 |
if getrandbits(1):
|
| 447 |
-
img
|
| 448 |
#print(f"transform: img.shape={img.shape}, idx={idx}, flip_xy={flip_xy}, w/ transpose")
|
| 449 |
#else:
|
| 450 |
#print(f"transform: img.shape={img.shape}, idx={idx}, flip_xy={flip_xy}, w/o transpose")
|
|
@@ -499,8 +499,9 @@ class DDPM21CM:
|
|
| 499 |
epoch_start = time()
|
| 500 |
for i, (x, c) in enumerate(self.dataloader):
|
| 501 |
if self.config.dim == 3:
|
| 502 |
-
|
| 503 |
-
|
|
|
|
| 504 |
|
| 505 |
x = x.to(self.config.device)#.to(self.config.dtype)
|
| 506 |
# autocast forward propagation
|
|
|
|
| 234 |
###########################
|
| 235 |
## hardcoding these here ##
|
| 236 |
###########################
|
| 237 |
+
push_to_hub = False #True
|
| 238 |
hub_model_id = "Xsmos/ml21cm"
|
| 239 |
hub_private_repo = False
|
| 240 |
dataset_name = "/storage/home/hcoda1/3/bxia34/scratch/LEN128-DIM64-CUB8.h5"
|
|
|
|
| 438 |
|
| 439 |
del dataset
|
| 440 |
|
| 441 |
+
def transform(self, img, idx=0):
|
| 442 |
#flip along x or y or both
|
| 443 |
+
flip_xy = [i+2 for i in range(2) if getrandbits(1)]
|
| 444 |
img[idx] = torch.flip(img[idx], dims=flip_xy)
|
| 445 |
# flip diagonally
|
| 446 |
if getrandbits(1):
|
| 447 |
+
img = img.transpose(2,3)
|
| 448 |
#print(f"transform: img.shape={img.shape}, idx={idx}, flip_xy={flip_xy}, w/ transpose")
|
| 449 |
#else:
|
| 450 |
#print(f"transform: img.shape={img.shape}, idx={idx}, flip_xy={flip_xy}, w/o transpose")
|
|
|
|
| 499 |
epoch_start = time()
|
| 500 |
for i, (x, c) in enumerate(self.dataloader):
|
| 501 |
if self.config.dim == 3:
|
| 502 |
+
x = self.transform(x)
|
| 503 |
+
#for idx in range(len(x)):
|
| 504 |
+
# x = self.transform(x, idx)
|
| 505 |
|
| 506 |
x = x.to(self.config.device)#.to(self.config.dtype)
|
| 507 |
# autocast forward propagation
|
perlmutter_diffusion.sbatch
CHANGED
|
@@ -3,9 +3,9 @@
|
|
| 3 |
#SBATCH -J diffusion
|
| 4 |
#SBATCH -C gpu&hbm80g
|
| 5 |
#SBATCH -q regular #shared
|
| 6 |
-
#SBATCH -
|
| 7 |
#SBATCH --gpus-per-node=4
|
| 8 |
-
#SBATCH -t 16:00
|
| 9 |
#SBATCH --ntasks-per-node=1
|
| 10 |
#SBATCH -oReport-%j
|
| 11 |
#SBATCH --mail-type=BEGIN,END,FAIL
|
|
|
|
| 3 |
#SBATCH -J diffusion
|
| 4 |
#SBATCH -C gpu&hbm80g
|
| 5 |
#SBATCH -q regular #shared
|
| 6 |
+
#SBATCH -N1
|
| 7 |
#SBATCH --gpus-per-node=4
|
| 8 |
+
#SBATCH -t 16:00
|
| 9 |
#SBATCH --ntasks-per-node=1
|
| 10 |
#SBATCH -oReport-%j
|
| 11 |
#SBATCH --mail-type=BEGIN,END,FAIL
|
tensorboard.ipynb
CHANGED
|
@@ -2,21 +2,12 @@
|
|
| 2 |
"cells": [
|
| 3 |
{
|
| 4 |
"cell_type": "code",
|
| 5 |
-
"execution_count":
|
| 6 |
"id": "ae45e44e-a11c-43ef-b830-c7a58a72f51e",
|
| 7 |
"metadata": {
|
| 8 |
"tags": []
|
| 9 |
},
|
| 10 |
-
"outputs": [
|
| 11 |
-
{
|
| 12 |
-
"name": "stdout",
|
| 13 |
-
"output_type": "stream",
|
| 14 |
-
"text": [
|
| 15 |
-
"The tensorboard extension is already loaded. To reload it, use:\n",
|
| 16 |
-
" %reload_ext tensorboard\n"
|
| 17 |
-
]
|
| 18 |
-
}
|
| 19 |
-
],
|
| 20 |
"source": [
|
| 21 |
"import nersc_tensorboard_helper\n",
|
| 22 |
"%load_ext tensorboard"
|
|
@@ -24,30 +15,21 @@
|
|
| 24 |
},
|
| 25 |
{
|
| 26 |
"cell_type": "code",
|
| 27 |
-
"execution_count":
|
| 28 |
"id": "a5c088b8-5051-402f-b4ec-2b684ad5a952",
|
| 29 |
"metadata": {},
|
| 30 |
"outputs": [
|
| 31 |
-
{
|
| 32 |
-
"data": {
|
| 33 |
-
"text/plain": [
|
| 34 |
-
"Reusing TensorBoard on port 33249 (pid 774528), started 4:01:53 ago. (Use '!kill 774528' to kill it.)"
|
| 35 |
-
]
|
| 36 |
-
},
|
| 37 |
-
"metadata": {},
|
| 38 |
-
"output_type": "display_data"
|
| 39 |
-
},
|
| 40 |
{
|
| 41 |
"data": {
|
| 42 |
"text/html": [
|
| 43 |
"\n",
|
| 44 |
-
" <iframe id=\"tensorboard-frame-
|
| 45 |
" </iframe>\n",
|
| 46 |
" <script>\n",
|
| 47 |
" (function() {\n",
|
| 48 |
-
" const frame = document.getElementById(\"tensorboard-frame-
|
| 49 |
" const url = new URL(\"/\", window.location);\n",
|
| 50 |
-
" const port =
|
| 51 |
" if (port) {\n",
|
| 52 |
" url.port = port;\n",
|
| 53 |
" }\n",
|
|
@@ -70,14 +52,14 @@
|
|
| 70 |
},
|
| 71 |
{
|
| 72 |
"cell_type": "code",
|
| 73 |
-
"execution_count":
|
| 74 |
"id": "2f76c0a9-2218-4073-86aa-f4f655d7642f",
|
| 75 |
"metadata": {},
|
| 76 |
"outputs": [
|
| 77 |
{
|
| 78 |
"data": {
|
| 79 |
"text/html": [
|
| 80 |
-
"<a href=\"https://jupyter.nersc.gov/user/binxia/perlmutter-login-node-base/proxy/
|
| 81 |
],
|
| 82 |
"text/plain": [
|
| 83 |
"<IPython.core.display.HTML object>"
|
|
|
|
| 2 |
"cells": [
|
| 3 |
{
|
| 4 |
"cell_type": "code",
|
| 5 |
+
"execution_count": 3,
|
| 6 |
"id": "ae45e44e-a11c-43ef-b830-c7a58a72f51e",
|
| 7 |
"metadata": {
|
| 8 |
"tags": []
|
| 9 |
},
|
| 10 |
+
"outputs": [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
"source": [
|
| 12 |
"import nersc_tensorboard_helper\n",
|
| 13 |
"%load_ext tensorboard"
|
|
|
|
| 15 |
},
|
| 16 |
{
|
| 17 |
"cell_type": "code",
|
| 18 |
+
"execution_count": 4,
|
| 19 |
"id": "a5c088b8-5051-402f-b4ec-2b684ad5a952",
|
| 20 |
"metadata": {},
|
| 21 |
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
{
|
| 23 |
"data": {
|
| 24 |
"text/html": [
|
| 25 |
"\n",
|
| 26 |
+
" <iframe id=\"tensorboard-frame-7f9f2d9b643f0d7b\" width=\"100%\" height=\"800\" frameborder=\"0\">\n",
|
| 27 |
" </iframe>\n",
|
| 28 |
" <script>\n",
|
| 29 |
" (function() {\n",
|
| 30 |
+
" const frame = document.getElementById(\"tensorboard-frame-7f9f2d9b643f0d7b\");\n",
|
| 31 |
" const url = new URL(\"/\", window.location);\n",
|
| 32 |
+
" const port = 40725;\n",
|
| 33 |
" if (port) {\n",
|
| 34 |
" url.port = port;\n",
|
| 35 |
" }\n",
|
|
|
|
| 52 |
},
|
| 53 |
{
|
| 54 |
"cell_type": "code",
|
| 55 |
+
"execution_count": 5,
|
| 56 |
"id": "2f76c0a9-2218-4073-86aa-f4f655d7642f",
|
| 57 |
"metadata": {},
|
| 58 |
"outputs": [
|
| 59 |
{
|
| 60 |
"data": {
|
| 61 |
"text/html": [
|
| 62 |
+
"<a href=\"https://jupyter.nersc.gov/user/binxia/perlmutter-login-node-base/proxy/40725/\">https://jupyter.nersc.gov/user/binxia/perlmutter-login-node-base/proxy/40725/</a>"
|
| 63 |
],
|
| 64 |
"text/plain": [
|
| 65 |
"<IPython.core.display.HTML object>"
|