YuanGao-YG committed on
Commit ff06f7b · verified · 1 Parent(s): ecf41a8

Upload 4 files
train_base_model.py ADDED
@@ -0,0 +1,506 @@
import os
import sys
import time
import h5py
import json
import torch
import pickle
import logging
import argparse
import cProfile
import numpy as np
# import matplotlib.pyplot as plt
from icecream import ic
from shutil import copyfile
from collections import OrderedDict
import torchvision
import torch.nn as nn
import torch.cuda.amp as amp
import torch.distributed as dist
from torchsummary import summary
from torchvision.utils import save_image
from torch.nn.parallel import DistributedDataParallel

from my_utils import logging_utils
logging_utils.config_logger()
from my_utils.YParams import YParams
from my_utils.darcy_loss import LossScaler, LpLoss, channel_wise_LpLoss
from my_utils.data_loader import get_data_loader

from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap as ruamelDict
import torch.utils.checkpoint as checkpoint
import gc


class Trainer():
    def count_parameters(self):
        return sum(p.numel() for p in self.model.parameters() if p.requires_grad)

    def __init__(self, params, world_rank):
        self.params = params
        self.world_rank = world_rank

        # Init GPU: bind this process to its local device
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(local_rank)
        self.device = torch.device('cuda', local_rank)
        logging.info('device: %s' % self.device)

        # Load data
        logging.info('rank %d, begin data loader init' % world_rank)
        self.train_data_loader, self.train_dataset, self.train_sampler = get_data_loader(
            params,
            params.train_data_path,
            dist.is_initialized(),
            train=True)
        self.valid_data_loader, self.valid_dataset, self.valid_sampler = get_data_loader(
            params,
            params.valid_data_path,
            dist.is_initialized(),
            train=True)

        if params.loss_channel_wise:
            self.loss_obj = channel_wise_LpLoss(scale=params.loss_scale)
        else:
            # fall back to the plain Lp loss when channel-wise weighting is disabled
            self.loss_obj = LpLoss()

        # loss scaler
        self.mse_loss_scaler = LossScaler()

        logging.info('rank %d, data loader initialized' % world_rank)

        # Load model
        if params.nettype == 'NeuralOM':
            from networks.MIGNN1 import MIGraph as model
        else:
            raise Exception("not implemented")

        self.model = model(params).to(self.device)

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=params.lr)

        if params.enable_amp:
            self.gscaler = amp.GradScaler()

        if dist.is_initialized():
            self.model = DistributedDataParallel(
                self.model,
                device_ids=[params.local_rank],
                output_device=params.local_rank,  # expects an int, not a list
                find_unused_parameters=False
            )

        self.iters = 0
        self.startEpoch = 0

        if (params.multi_steps_finetune == 1) and (params.resuming):
            logging.info("Loading checkpoint %s" % params.checkpoint_path)
            self.restore_checkpoint(params.checkpoint_path)

        if params.multi_steps_finetune > 1:
            logging.info("Starting from pretrained one-step model at %s" % params.pretrained_ckpt_path)
            self.restore_checkpoint(params.pretrained_ckpt_path)
            self.iters = 0
            self.startEpoch = 0
            logging.info("Adding %d epochs specified in config file for refining pretrained model" % params.finetune_max_epochs)
            params['max_epochs'] = params.finetune_max_epochs

        self.epoch = self.startEpoch

        if params.scheduler == 'CosineAnnealingLR':
            self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer,
                T_max=params.max_epochs,
                last_epoch=self.startEpoch - 1
            )
        else:
            self.scheduler = None

        if params.log_to_screen:
            logging.info("Number of trainable model parameters: {}".format(self.count_parameters()))

    def switch_off_grad(self, model):
        for param in model.parameters():
            param.requires_grad = False

    def train(self):
        if self.params.log_to_screen:
            logging.info("Starting Training Loop...")

        best_valid_loss = 1.e6
        for epoch in range(self.startEpoch, self.params.max_epochs):
            if dist.is_initialized():
                self.train_sampler.set_epoch(epoch)
                self.valid_sampler.set_epoch(epoch)

            start = time.time()
            tr_time, data_time, step_time, train_logs = self.train_one_epoch()
            valid_time, valid_logs = self.validate_one_epoch()

            if self.world_rank == 0:
                if self.params.save_checkpoint:
                    # checkpoint at the end of every epoch
                    self.save_checkpoint(self.params.checkpoint_path)
                    if valid_logs['valid_loss'] <= best_valid_loss:
                        logging.info('Val loss improved from {} to {}'.format(best_valid_loss, valid_logs['valid_loss']))
                        self.save_checkpoint(self.params.best_checkpoint_path)
                        best_valid_loss = valid_logs['valid_loss']

            if self.params.log_to_screen:
                logging.info('Time taken for epoch {} is {} sec'.format(epoch + 1, time.time() - start))
                logging.info('lr for epoch {} is {}'.format(epoch + 1, self.optimizer.param_groups[0]['lr']))
                logging.info('train data time={}, train per epoch time={}, train per step time={}, valid time={}'.format(data_time, tr_time, step_time, valid_time))
                logging.info('Train loss: {}. Valid loss: {}'.format(train_logs['train_loss'], valid_logs['valid_loss']))

            if self.params.scheduler == 'CosineAnnealingLR':
                self.scheduler.step()

            torch.cuda.empty_cache()
            gc.collect()

    def land_mask_func(self, x, y, land_mask_path):
        # mask convention: 0 = land, 1 = ocean
        with h5py.File(land_mask_path, 'r') as _f:
            # logging.info(f"Loading land mask data from {self.params.land_mask_path}")
            mask_data = torch.as_tensor(_f['fields'])
            # ic(mask_data.shape)
            mask_data = mask_data[0, self.params.out_channels].to(x.device, dtype=torch.bool)
        # ic(mask_data.shape, x.shape, y.shape)
        x = torch.masked_fill(input=x, mask=~mask_data, value=0)
        y = torch.masked_fill(input=y, mask=~mask_data, value=0)
        return x, y

    def train_one_epoch(self):
        self.epoch += 1
        tr_time = 0
        data_time = 0
        self.model.train()

        steps_in_one_epoch = 0
        for i, data in enumerate(self.train_data_loader, 0):
            self.iters += 1
            steps_in_one_epoch += 1

            data_start = time.time()

            (inp, tar) = data

            if self.params.orography and self.params.multi_steps_finetune > 1:
                orog = torch.unsqueeze(inp[:, -1], dim=1)

            data_time += time.time() - data_start

            tr_start = time.time()
            self.model.zero_grad()

            num_steps = self.params.multi_steps_finetune

            with amp.autocast(self.params.enable_amp):
                gen_prev = None
                loss = 0.0
                cw_loss = 0.0

                for step_idx in range(num_steps):
                    if step_idx == 0:
                        inp_step_1 = inp.to(self.device, dtype=torch.float32)
                        if self.params.multi_steps_finetune == 1:
                            gen_cur = self.model(inp_step_1)
                        else:
                            # recompute activations in backward to fit long rollouts in memory
                            gen_cur = checkpoint.checkpoint(self.model, inp_step_1, use_reentrant=False)
                    else:
                        # teacher-forced atmospheric forcings from the targets at t-1 and t
                        atmos_force0 = tar[:, step_idx - 1, self.params.atmos_channels].to(self.device, dtype=torch.float)
                        atmos_force1 = tar[:, step_idx, self.params.atmos_channels].to(self.device, dtype=torch.float)
                        gen_prev = torch.cat((gen_prev, atmos_force0, atmos_force1), dim=1).to(self.device, dtype=torch.float32)
                        gen_cur = checkpoint.checkpoint(self.model, gen_prev, use_reentrant=False)

                    if self.params.multi_steps_finetune == 1:
                        tar_step = tar[:, self.params.out_channels].to(self.device, dtype=torch.float)
                    else:
                        tar_step = tar[:, step_idx, self.params.out_channels].to(self.device, dtype=torch.float)

                    if self.params.land_mask:
                        gen_cur, tar_step = self.land_mask_func(gen_cur, tar_step, self.params.land_mask_path)

                    loss_step, cw_loss_step = self.loss_obj(gen_cur, tar_step)

                    loss += loss_step
                    cw_loss += cw_loss_step
                    if step_idx == 0:
                        del inp
                        mse1 = torch.mean((gen_cur - tar_step) ** 2).item()

                    gen_prev = gen_cur

                    del tar_step, gen_cur
                del gen_prev

            if self.params.enable_amp:
                self.gscaler.scale(loss).backward()
                self.gscaler.step(self.optimizer)
            else:
                loss.backward()
                self.optimizer.step()
            # print('1_step_mse:', mse1)

            if self.params.enable_amp:
                self.gscaler.update()

        tr_time += time.time() - tr_start

        logs = {'train_loss': loss}

        for vi, v in enumerate(self.params.out_variables):
            logs[f'{v}_train_loss'] = cw_loss[vi]

        if dist.is_initialized():
            for key in sorted(logs.keys()):
                dist.all_reduce(logs[key].detach())
                logs[key] = float(logs[key] / dist.get_world_size())

        # average time of one step over the epoch
        step_time = tr_time / steps_in_one_epoch

        return tr_time, data_time, step_time, logs

    def validate_one_epoch(self):
        logging.info('validating...')
        self.model.eval()

        valid_buff = torch.zeros((3 + self.params.N_out_channels), dtype=torch.float32, device=self.device)
        valid_loss = valid_buff[0].view(-1)
        valid_l1 = valid_buff[1].view(-1)
        valid_steps = valid_buff[-1].view(-1)

        valid_start = time.time()
        sample_idx = np.random.randint(len(self.valid_data_loader))
        with torch.no_grad():
            for i, data in enumerate(self.valid_data_loader, 0):
                inp, tar = map(lambda x: x.to(self.device, dtype=torch.float), data)
                num_steps = self.params.multi_steps_finetune
                for step_idx in range(num_steps):
                    if step_idx == 0:
                        inp_step_1 = inp.to(self.device, dtype=torch.float32)
                        gen_cur = self.model(inp_step_1)
                    else:
                        atmos_force0 = tar[:, step_idx - 1, self.params.atmos_channels].to(self.device, dtype=torch.float)
                        atmos_force1 = tar[:, step_idx, self.params.atmos_channels].to(self.device, dtype=torch.float)
                        gen_prev = torch.cat((gen_prev, atmos_force0, atmos_force1), dim=1).to(self.device, dtype=torch.float32)
                        gen_cur = self.model(gen_prev)

                    if self.params.multi_steps_finetune == 1:
                        tar_step = tar[:, self.params.out_channels].to(self.device, dtype=torch.float)
                    else:
                        tar_step = tar[:, step_idx, self.params.out_channels].to(self.device, dtype=torch.float)
                    if self.params.land_mask:
                        gen_cur, tar_step = self.land_mask_func(gen_cur, tar_step, self.params.land_mask_path)
                    if step_idx == 0:
                        del inp_step_1
                    gen_prev = gen_cur

                    if step_idx == self.params.multi_steps_finetune - 1:
                        gen, tar = gen_cur, tar_step

                    del tar_step, gen_cur
                del gen_prev

                gen = gen.to(self.device, dtype=torch.float)  # assign the cast result

                if self.params.land_mask:
                    gen, tar = self.land_mask_func(gen, tar, self.params.land_mask_path)

                _, cw_valid_loss = self.loss_obj(gen, tar)
                valid_loss_ = torch.mean((gen - tar) ** 2).item()
                valid_loss += valid_loss_
                valid_l1 += nn.functional.l1_loss(gen, tar)

                for vi, v in enumerate(self.params.out_variables):
                    valid_buff[vi + 2] += cw_valid_loss[vi]

                valid_steps += 1.

                # save fields for vis before log norm
                os.makedirs(os.path.join(self.params['experiment_dir'], str(i)), exist_ok=True)

                del gen, tar

        if dist.is_initialized():
            dist.all_reduce(valid_buff)

        # divide the accumulated loss and l1 by the number of steps
        valid_buff[0:-1] = valid_buff[0:-1] / valid_buff[-1]
        valid_buff_cpu = valid_buff.detach().cpu().numpy()

        valid_time = time.time() - valid_start

        logs = {'valid_loss': valid_buff_cpu[0],
                'valid_l1': valid_buff_cpu[1]}
        for vi, v in enumerate(self.params.out_variables):
            logs[f'{v}_valid_loss'] = valid_buff_cpu[vi + 2]

        return valid_time, logs

    def load_model(self, model_path):
        if self.params.log_to_screen:
            logging.info('Loading the model weights from {}'.format(model_path))

        checkpoint = torch.load(model_path, map_location='cuda:{}'.format(self.params.local_rank))

        if dist.is_initialized():
            self.model.load_state_dict(checkpoint['model_state'])
        else:
            new_model_state = OrderedDict()
            model_key = 'model_state' if 'model_state' in checkpoint else 'state_dict'
            for key in checkpoint[model_key].keys():
                if 'module.' in key:  # model was stored using DDP, which prepends 'module.'
                    name = str(key[7:])
                    new_model_state[name] = checkpoint[model_key][key]
                else:
                    new_model_state[key] = checkpoint[model_key][key]
            self.model.load_state_dict(new_model_state)

        self.model.eval()

    def save_checkpoint(self, checkpoint_path, model=None):
        """ We intentionally require a checkpoint_path to be passed
        in order to allow Ray Tune to use this function """
        if not model:
            model = self.model

        torch.save({'iters': self.iters, 'epoch': self.epoch, 'model_state': model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict()}, checkpoint_path)

    def restore_checkpoint(self, checkpoint_path):
        """ We intentionally require a checkpoint_path to be passed
        in order to allow Ray Tune to use this function """
        checkpoint = torch.load(checkpoint_path, map_location='cuda:{}'.format(self.params.local_rank))
        try:
            self.model.load_state_dict(checkpoint['model_state'])
        except Exception:
            # strip the 'module.' prefix left by DDP
            new_state_dict = OrderedDict()
            for key, val in checkpoint['model_state'].items():
                name = key[7:]
                new_state_dict[name] = val
            self.model.load_state_dict(new_state_dict)
        self.iters = checkpoint['iters']
        self.startEpoch = checkpoint['epoch']
        if self.params.resuming and (self.params.multi_steps_finetune == 1):
            # restore_checkpoint is used for finetuning as well as resuming.
            # When finetuning (i.e., not resuming), it does not load the optimizer state
            # and instead uses the learning rate specified in the config.
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--run_num", default='00', type=str)
    parser.add_argument("--yaml_config", default='./config/Model.yaml', type=str)
    parser.add_argument("--multi_steps_finetune", default=1, type=int)
    parser.add_argument("--finetune_max_epochs", default=50, type=int)
    parser.add_argument("--batch_size", default=16, type=int)
    parser.add_argument("--config", default='MIGraph', type=str)
    parser.add_argument("--enable_amp", action='store_true')
    parser.add_argument("--epsilon_factor", default=0, type=float)
    parser.add_argument("--local_rank", default=-1, type=int, help='node rank for distributed training')
    args = parser.parse_args()

    params = YParams(os.path.abspath(args.yaml_config), args.config, True)
    params['epsilon_factor'] = args.epsilon_factor
    params['multi_steps_finetune'] = args.multi_steps_finetune
    params['finetune_max_epochs'] = args.finetune_max_epochs

    params['world_size'] = 1
    if 'WORLD_SIZE' in os.environ:
        params['world_size'] = int(os.environ['WORLD_SIZE'])
    print('world_size :', params['world_size'])

    print('Initialize distributed process group...')
    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    params['local_rank'] = local_rank  # GPU ID

    torch.backends.cudnn.benchmark = True
    world_rank = dist.get_rank()

    params['global_batch_size'] = args.batch_size
    params['batch_size'] = int(args.batch_size // params['world_size'])  # the global batch size must be divisible by the number of GPUs
    params['enable_amp'] = args.enable_amp  # automatic mixed-precision training

    # Set up directories
    if params['multi_steps_finetune'] > 1:
        pretrained_expDir = os.path.join(params.exp_dir, args.config, str(args.run_num))
        multi_steps = params['multi_steps_finetune']
        if params['multi_steps_finetune'] > 2:
            params['pretrained_ckpt_path'] = os.path.join(pretrained_expDir, f'{multi_steps-1}_steps_finetune/training_checkpoints/best_ckpt.tar')
        else:
            params['pretrained_ckpt_path'] = os.path.join(pretrained_expDir, 'training_checkpoints/best_ckpt.tar')

        expDir = os.path.join(pretrained_expDir, f'{multi_steps}_steps_finetune')
        if world_rank == 0:
            os.makedirs(expDir, exist_ok=True)
            os.makedirs(os.path.join(expDir, 'training_checkpoints/'), exist_ok=True)

        params['experiment_dir'] = os.path.abspath(expDir)
        params['checkpoint_path'] = os.path.join(expDir, 'training_checkpoints/ckpt.tar')
        params['best_checkpoint_path'] = os.path.join(expDir, 'training_checkpoints/best_ckpt.tar')

        params['resuming'] = True
    else:
        expDir = os.path.join(params.exp_dir, args.config, str(args.run_num))
        if world_rank == 0:
            os.makedirs(expDir, exist_ok=True)
            os.makedirs(os.path.join(expDir, 'training_checkpoints/'), exist_ok=True)
            copyfile(os.path.abspath(args.yaml_config), os.path.join(expDir, 'config.yaml'))

        params['experiment_dir'] = os.path.abspath(expDir)
        params['checkpoint_path'] = os.path.join(expDir, 'training_checkpoints/ckpt.tar')
        params['best_checkpoint_path'] = os.path.join(expDir, 'training_checkpoints/best_ckpt.tar')

        # Do not comment this line out please:
        args.resuming = True if os.path.isfile(params.checkpoint_path) else False
        params['resuming'] = args.resuming

    if world_rank == 0:
        logging_utils.log_to_file(logger_name=None, log_filename=os.path.join(expDir, 'train.log'))
        logging_utils.log_versions()
        params.log()

    params['log_to_screen'] = (world_rank == 0) and params['log_to_screen']

    params['in_channels'] = np.array(params['in_channels'])
    params['out_channels'] = np.array(params['out_channels'])
    params['N_out_channels'] = len(params['out_channels'])
    if params.orography:
        params['N_in_channels'] = len(params['in_channels']) + 1
    else:
        params['N_in_channels'] = len(params['in_channels'])

    if world_rank == 0:
        hparams = ruamelDict()
        yaml = YAML()
        for key, value in params.params.items():
            hparams[str(key)] = str(value)
        with open(os.path.join(expDir, 'hyperparams.yaml'), 'w') as hpfile:
            yaml.dump(hparams, hpfile)

    trainer = Trainer(params, world_rank)
    trainer.train()
    logging.info('DONE ---- rank %d' % world_rank)
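
For reference, the multi-step finetuning rollout in train_one_epoch feeds each prediction back as the next input, while the atmospheric forcing channels for the next step are spliced in from the ground-truth targets (teacher forcing). A minimal, self-contained sketch of that pattern, using toy shapes and a 1x1 convolution as a stand-in for the actual MIGraph network (all channel counts and indices below are illustrative assumptions, not the real configuration):

import torch

# Toy setup: predict C_out ocean channels from the previous ocean state plus
# two consecutive atmospheric forcing snapshots taken from the targets.
C_out, C_atm, H, W = 4, 2, 8, 8
model = torch.nn.Conv2d(C_out + 2 * C_atm, C_out, kernel_size=1)  # stand-in for MIGraph
num_steps = 3                                         # multi_steps_finetune > 1

inp = torch.randn(1, C_out + 2 * C_atm, H, W)         # step-0 input (state + forcings)
tar = torch.randn(1, num_steps, C_out + C_atm, H, W)  # per-step targets
atmos_idx = list(range(C_out, C_out + C_atm))         # hypothetical atmos channel indices

loss, gen_prev = 0.0, None
for step in range(num_steps):
    if step == 0:
        gen_cur = model(inp)
    else:
        f0 = tar[:, step - 1, atmos_idx]              # forcing at t-1 (teacher-forced)
        f1 = tar[:, step, atmos_idx]                  # forcing at t
        gen_cur = model(torch.cat((gen_prev, f0, f1), dim=1))
    loss = loss + torch.mean((gen_cur - tar[:, step, :C_out]) ** 2)
    gen_prev = gen_cur
loss.backward()                                       # gradients flow through all steps

Backpropagating through the whole rollout is what makes the per-step activation checkpointing in the real script necessary.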
train_base_model.sh ADDED
@@ -0,0 +1,76 @@
wandb_group='NeuralOM'
yaml_config='config/Model.yaml'
config='NeuralOM'
batch_size=16
run_num=$(date "+%Y%m%d-%H%M%S")
# run_num='20250501-000000'
multi_steps_finetune=1
finetune_max_epochs=0

TRAIN_DIR=$(dirname $(realpath train_base_model.py))

export MASTER_ADDR=30.207.97.183  # IP address or hostname of the master node
export MASTER_PORT=31317
export WORLD_SIZE=16
export NODE_RANK=0

source ~/.bashrc
conda activate triton_v2
export NCCL_IB_GID_INDEX=3
export NCCL_IB_SL=3
export NCCL_CHECK_DISABLE=1
export NCCL_P2P_DISABLE=0
export NCCL_IB_DISABLE=0
export NCCL_LL_THRESHOLD=16384
export NCCL_IB_CUDA_SUPPORT=1
export NCCL_TOPO_AFFINITY=0
export NCCL_IB_HCA=mlx5_bond_1,mlx5_bond_5,mlx5_bond_3,mlx5_bond_7,mlx5_bond_4,mlx5_bond_8,mlx5_bond_2,mlx5_bond_6
export NCCL_COLLNET_ENABLE=0
export SHARP_COLL_ENABLE_SAT=0
export NCCL_NET_GDR_LEVEL=2
export NCCL_IB_QPS_PER_CONNECTION=4
export NCCL_IB_TC=160
export NCCL_PXN_DISABLE=0
export NCCL_DEBUG=WARN
export TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC=2400
export NCCL_SOCKET_IFNAME=bond1

export TORCH_NCCL_BLOCKING_WAIT=1
export TORCH_NCCL_ASYNC_ERROR_HANDLING=1

export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
nohup torchrun --nproc_per_node=8 --nnodes=2 --node_rank=$NODE_RANK --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT $TRAIN_DIR/train_base_model.py \
    --yaml_config=$yaml_config --config=$config --run_num=$run_num --batch_size=$batch_size --multi_steps_finetune=$multi_steps_finetune --finetune_max_epochs=$finetune_max_epochs \
    >> ./logs/${config}_${wandb_group}_rank0_${SLURM_JOB_ID}_${run_num}.log 2>&1 &

# Launch rank 1 on the second node over ssh with the same environment
ssh root@30.207.98.235 "
source ~/.bashrc; \
conda activate triton_v2; \

export NCCL_IB_GID_INDEX=3
export NCCL_IB_SL=3
export NCCL_CHECK_DISABLE=1
export NCCL_P2P_DISABLE=0
export NCCL_IB_DISABLE=0
export NCCL_LL_THRESHOLD=16384
export NCCL_IB_CUDA_SUPPORT=1
export NCCL_TOPO_AFFINITY=0
export NCCL_IB_HCA=mlx5_bond_1,mlx5_bond_5,mlx5_bond_3,mlx5_bond_7,mlx5_bond_4,mlx5_bond_8,mlx5_bond_2,mlx5_bond_6
export NCCL_COLLNET_ENABLE=0
export SHARP_COLL_ENABLE_SAT=0
export NCCL_NET_GDR_LEVEL=2
export NCCL_IB_QPS_PER_CONNECTION=4
export NCCL_IB_TC=160
export NCCL_PXN_DISABLE=0
export NCCL_DEBUG=WARN
export TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC=2400
export NCCL_SOCKET_IFNAME=bond1

export TORCH_NCCL_BLOCKING_WAIT=1
export TORCH_NCCL_ASYNC_ERROR_HANDLING=1

export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7; \
export MASTER_ADDR=$MASTER_ADDR; export MASTER_PORT=$MASTER_PORT; export WORLD_SIZE=16; export NODE_RANK=1; \
nohup torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT $TRAIN_DIR/train_base_model.py \
    --yaml_config=$yaml_config --config=$config --run_num=$run_num --batch_size=$batch_size --multi_steps_finetune=$multi_steps_finetune --finetune_max_epochs=$finetune_max_epochs \
    >> $TRAIN_DIR/logs/${config}_${wandb_group}_rank1_${SLURM_JOB_ID}_${run_num}.log 2>&1 &"
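
Note that --batch_size here is the global batch size: the Python entry point divides it by WORLD_SIZE, so this 2-node, 16-GPU launch ends up with one sample per GPU per step. A minimal sketch of that arithmetic, with values taken from the script above (the assert is an added sanity check, not part of the original code):

# How the entry point derives the per-rank batch size
global_batch_size = 16                  # --batch_size in this script
world_size = 16                         # WORLD_SIZE: 2 nodes x 8 GPUs
assert global_batch_size % world_size == 0, "global batch must divide evenly across ranks"
per_rank_batch = global_batch_size // world_size
print(per_rank_batch)                   # -> 1 sample per GPU per optimizer step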
train_residual_model.py ADDED
@@ -0,0 +1,510 @@
import os
import sys
import time
import h5py
import json
import torch
import pickle
import logging
import argparse
import cProfile
import numpy as np
# import matplotlib.pyplot as plt
from icecream import ic
from shutil import copyfile
from collections import OrderedDict
import torchvision
import torch.nn as nn
import torch.cuda.amp as amp
import torch.distributed as dist
from torchvision.utils import save_image
from torch.nn.parallel import DistributedDataParallel

from my_utils import logging_utils
logging_utils.config_logger()
from my_utils.YParams import YParams
from my_utils.darcy_loss import LossScaler, LpLoss, channel_wise_LpLoss
from my_utils.data_loader import get_data_loader

from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap as ruamelDict
import torch.utils.checkpoint as checkpoint
import gc


class Trainer():
    def count_parameters(self):
        return sum(p.numel() for p in self.model2.parameters() if p.requires_grad)

    def __init__(self, params, world_rank):
        self.params = params
        self.world_rank = world_rank

        # Init GPU: bind this process to its local device
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(local_rank)
        self.device = torch.device('cuda', local_rank)
        logging.info('device: %s' % self.device)

        # Resolve data paths relative to this script
        script_dir = os.path.dirname(os.path.abspath(__file__))
        train_data_path = os.path.join(script_dir, params.train_data_path)
        valid_data_path = os.path.join(script_dir, params.valid_data_path)
        land_mask_path = os.path.join(script_dir, params.land_mask_path)

        # mask convention: 0 = land, 1 = ocean
        with h5py.File(land_mask_path, 'r') as _f:
            self.mask_data = torch.as_tensor(_f['fields'])[0, self.params.out_channels].to(self.device, dtype=torch.bool)

        # Load data
        logging.info('rank %d, begin data loader init' % world_rank)
        self.train_data_loader, self.train_dataset, self.train_sampler = get_data_loader(
            params,
            train_data_path,
            dist.is_initialized(),
            train=True)
        self.valid_data_loader, self.valid_dataset, self.valid_sampler = get_data_loader(
            params,
            valid_data_path,
            dist.is_initialized(),
            train=True)

        if params.loss_channel_wise:
            self.loss_obj = channel_wise_LpLoss(scale=params.loss_scale)
        else:
            self.loss_obj = LpLoss()

        # loss scaler
        self.mse_loss_scaler = LossScaler()

        logging.info('rank %d, data loader initialized' % world_rank)

        # Load models: stage 1 (frozen simulator) and stage 2 (residual corrector)
        if params.nettype == 'NeuralOM':
            from networks.MIGNN1 import MIGraph as model
            from networks.MIGNN2 import MIGraph_stage2 as model2
        else:
            raise Exception("not implemented")

        self.model = model(params).to(self.device)
        self.model2 = model2(params).to(self.device)

        # only the stage-2 corrector is optimized
        self.optimizer = torch.optim.Adam(self.model2.parameters(), lr=params.lr)

        if params.enable_amp:
            self.gscaler = amp.GradScaler()

        if dist.is_initialized():
            self.model = DistributedDataParallel(
                self.model,
                device_ids=[params.local_rank],
                output_device=params.local_rank,  # expects an int, not a list
                find_unused_parameters=False
            )

        self.switch_off_grad(self.model)

        if dist.is_initialized():
            self.model2 = DistributedDataParallel(
                self.model2,
                device_ids=[params.local_rank],
                output_device=params.local_rank,
                find_unused_parameters=False
            )

        self.iters = 0
        self.startEpoch = 0

        if params.multi_steps_finetune > 1:
            logging.info("Starting from pretrained one-step model at %s" % params.pretrained_ckpt_path)
            self.restore_checkpoint(params.pretrained_ckpt_path)
            self.iters = 0
            self.startEpoch = 0
            logging.info("Adding %d epochs specified in config file for refining pretrained model" % params.finetune_max_epochs)
            params['max_epochs'] = params.finetune_max_epochs

        self.epoch = self.startEpoch

        if params.scheduler == 'CosineAnnealingLR':
            self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer,
                T_max=params.max_epochs,
                last_epoch=self.startEpoch - 1
            )
        else:
            self.scheduler = None

        if params.log_to_screen:
            logging.info("Number of trainable model parameters: {}".format(self.count_parameters()))

    def switch_off_grad(self, model):
        for param in model.parameters():
            param.requires_grad = False

    def train(self):
        if self.params.log_to_screen:
            logging.info("Starting Training Loop...")

        best_valid_loss = 1.e6
        for epoch in range(self.startEpoch, self.params.max_epochs):
            if dist.is_initialized():
                self.train_sampler.set_epoch(epoch)
                self.valid_sampler.set_epoch(epoch)

            start = time.time()
            tr_time, data_time, step_time, train_logs = self.train_one_epoch()
            valid_time, valid_logs = self.validate_one_epoch()

            if self.world_rank == 0:
                if self.params.save_checkpoint:
                    # checkpoint at the end of every epoch
                    self.save_checkpoint(self.params.checkpoint_path, self.model2)
                    if valid_logs['valid_loss'] <= best_valid_loss:
                        logging.info('Val loss improved from {} to {}'.format(best_valid_loss, valid_logs['valid_loss']))
                        self.save_checkpoint(self.params.best_checkpoint_path, self.model2)
                        best_valid_loss = valid_logs['valid_loss']

            if self.params.log_to_screen:
                logging.info('Time taken for epoch {} is {} sec'.format(epoch + 1, time.time() - start))
                logging.info('lr for epoch {} is {}'.format(epoch + 1, self.optimizer.param_groups[0]['lr']))
                logging.info('train data time={}, train per epoch time={}, train per step time={}, valid time={}'.format(data_time, tr_time, step_time, valid_time))
                logging.info('Train loss: {}. Valid loss: {}'.format(train_logs['train_loss'], valid_logs['valid_loss']))

            if self.params.scheduler == 'CosineAnnealingLR':
                self.scheduler.step()

            torch.cuda.empty_cache()
            gc.collect()

    def land_mask_func(self, x, y):
        x = torch.masked_fill(input=x, mask=~self.mask_data, value=0)
        y = torch.masked_fill(input=y, mask=~self.mask_data, value=0)
        return x, y

    def land_mask_func_single(self, x):
        x = torch.masked_fill(input=x, mask=~self.mask_data, value=0)
        return x

    def train_one_epoch(self):
        self.epoch += 1
        tr_time = 0
        data_time = 0
        self.model.eval()     # the frozen stage-1 simulator stays in eval mode
        self.model2.train()   # only the stage-2 corrector is trained

        steps_in_one_epoch = 0
        for i, data in enumerate(self.train_data_loader, 0):
            self.iters += 1
            steps_in_one_epoch += 1

            data_start = time.time()

            (inp, tar) = data

            data_time += time.time() - data_start

            tr_start = time.time()
            self.model2.zero_grad()

            num_steps = self.params.multi_steps_finetune

            with amp.autocast(self.params.enable_amp):
                gen_prev = None
                loss = 0.0
                cw_loss = 0.0

                for step_idx in range(num_steps):
                    if step_idx == 0:
                        inp_step_1 = inp.to(self.device, dtype=torch.float32)
                        with torch.no_grad():
                            gen_model1 = self.model(inp_step_1)
                        gen_model1 = self.land_mask_func_single(gen_model1)
                        gen_cur = checkpoint.checkpoint(self.model2, gen_model1, use_reentrant=False) + gen_model1
                    else:
                        # teacher-forced atmospheric forcings from the targets at t-1 and t
                        atmos_force0 = tar[:, step_idx - 1, self.params.atmos_channels].to(self.device, dtype=torch.float)
                        atmos_force1 = tar[:, step_idx, self.params.atmos_channels].to(self.device, dtype=torch.float)
                        gen_prev = torch.cat((gen_prev, atmos_force0, atmos_force1), dim=1).to(self.device, dtype=torch.float32)
                        with torch.no_grad():
                            gen_model1 = self.model(gen_prev)
                        gen_model1 = self.land_mask_func_single(gen_model1)
                        gen_cur = checkpoint.checkpoint(self.model2, gen_model1, use_reentrant=False) + gen_model1

                    if self.params.multi_steps_finetune == 1:
                        tar_step = tar[:, self.params.out_channels].to(self.device, dtype=torch.float)
                    else:
                        tar_step = tar[:, step_idx, self.params.out_channels].to(self.device, dtype=torch.float)

                    gen_cur, tar_step = self.land_mask_func(gen_cur, tar_step)

                    loss_step, cw_loss_step = self.loss_obj(gen_cur, tar_step)

                    loss += loss_step
                    cw_loss += cw_loss_step
                    if step_idx == 0:
                        del inp
                        mse1 = torch.mean((gen_cur - tar_step) ** 2).item()

                    gen_prev = gen_cur

                    del tar_step, gen_cur
                del gen_prev

            if self.params.enable_amp:
                self.gscaler.scale(loss).backward()
                self.gscaler.step(self.optimizer)
            else:
                loss.backward()
                self.optimizer.step()
            print('1_step_mse:', mse1)

            if self.params.enable_amp:
                self.gscaler.update()

        tr_time += time.time() - tr_start

        logs = {'train_loss': loss}

        for vi, v in enumerate(self.params.out_variables):
            logs[f'{v}_train_loss'] = cw_loss[vi]

        if dist.is_initialized():
            for key in sorted(logs.keys()):
                dist.all_reduce(logs[key].detach())
                logs[key] = float(logs[key] / dist.get_world_size())

        # average time of one step over the epoch
        step_time = tr_time / steps_in_one_epoch

        return tr_time, data_time, step_time, logs

    def validate_one_epoch(self):
        logging.info('validating...')
        self.model.eval()
        self.model2.eval()  # put the stage-2 corrector in eval mode as well

        valid_buff = torch.zeros((3 + self.params.N_out_channels), dtype=torch.float32, device=self.device)
        valid_loss = valid_buff[0].view(-1)
        valid_l1 = valid_buff[1].view(-1)
        valid_steps = valid_buff[-1].view(-1)

        valid_start = time.time()
        sample_idx = np.random.randint(len(self.valid_data_loader))
        with torch.no_grad():
            for i, data in enumerate(self.valid_data_loader, 0):
                inp, tar = map(lambda x: x.to(self.device, dtype=torch.float), data)
                num_steps = self.params.multi_steps_finetune
                for step_idx in range(num_steps):
                    if step_idx == 0:
                        inp_step_1 = inp.to(self.device, dtype=torch.float32)
                        gen_model1 = self.model(inp_step_1)
                        gen_model1 = self.land_mask_func_single(gen_model1)
                        gen_cur = self.model2(gen_model1) + gen_model1
                    else:
                        atmos_force0 = tar[:, step_idx - 1, self.params.atmos_channels].to(self.device, dtype=torch.float)
                        atmos_force1 = tar[:, step_idx, self.params.atmos_channels].to(self.device, dtype=torch.float)
                        gen_prev = torch.cat((gen_prev, atmos_force0, atmos_force1), dim=1).to(self.device, dtype=torch.float32)
                        gen_model1 = self.model(gen_prev)
                        gen_model1 = self.land_mask_func_single(gen_model1)
                        gen_cur = self.model2(gen_model1) + gen_model1

                    if self.params.multi_steps_finetune == 1:
                        tar_step = tar[:, self.params.out_channels].to(self.device, dtype=torch.float)
                    else:
                        tar_step = tar[:, step_idx, self.params.out_channels].to(self.device, dtype=torch.float)
                    if self.params.land_mask:
                        gen_cur, tar_step = self.land_mask_func(gen_cur, tar_step)
                    if step_idx == 0:
                        del inp_step_1
                    gen_prev = gen_cur

                    if step_idx == self.params.multi_steps_finetune - 1:
                        gen, tar = gen_cur, tar_step

                    del tar_step, gen_cur
                del gen_prev

                gen = gen.to(self.device, dtype=torch.float)  # assign the cast result

                if self.params.land_mask:
                    gen, tar = self.land_mask_func(gen, tar)

                _, cw_valid_loss = self.loss_obj(gen, tar)
                valid_loss_ = torch.mean((gen - tar) ** 2).item()
                valid_loss += valid_loss_
                valid_l1 += nn.functional.l1_loss(gen, tar)

                for vi, v in enumerate(self.params.out_variables):
                    valid_buff[vi + 2] += cw_valid_loss[vi]

                valid_steps += 1.

                # save fields for vis before log norm
                os.makedirs(os.path.join(self.params['experiment_dir'], str(i)), exist_ok=True)

                del gen, tar

        if dist.is_initialized():
            dist.all_reduce(valid_buff)

        # divide the accumulated loss and l1 by the number of steps
        valid_buff[0:-1] = valid_buff[0:-1] / valid_buff[-1]
        valid_buff_cpu = valid_buff.detach().cpu().numpy()

        valid_time = time.time() - valid_start

        logs = {'valid_loss': valid_buff_cpu[0],
                'valid_l1': valid_buff_cpu[1]}
        for vi, v in enumerate(self.params.out_variables):
            logs[f'{v}_valid_loss'] = valid_buff_cpu[vi + 2]

        return valid_time, logs

    def load_model(self, model_path):
        if self.params.log_to_screen:
            logging.info('Loading the model weights from {}'.format(model_path))

        checkpoint = torch.load(model_path, map_location='cuda:{}'.format(self.params.local_rank))

        if dist.is_initialized():
            self.model.load_state_dict(checkpoint['model_state'])
        else:
            new_model_state = OrderedDict()
            model_key = 'model_state' if 'model_state' in checkpoint else 'state_dict'
            for key in checkpoint[model_key].keys():
                if 'module.' in key:  # model was stored using DDP, which prepends 'module.'
                    name = str(key[7:])
                    new_model_state[name] = checkpoint[model_key][key]
                else:
                    new_model_state[key] = checkpoint[model_key][key]
            self.model.load_state_dict(new_model_state)

        self.model.eval()

    def save_checkpoint(self, checkpoint_path, model):
        """ We intentionally require a checkpoint_path to be passed
        in order to allow Ray Tune to use this function """
        torch.save({'iters': self.iters, 'epoch': self.epoch, 'model_state': model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict()}, checkpoint_path)

    def restore_checkpoint(self, checkpoint_path):
        """ We intentionally require a checkpoint_path to be passed
        in order to allow Ray Tune to use this function """
        checkpoint = torch.load(checkpoint_path, map_location='cuda:{}'.format(self.params.local_rank))
        try:
            self.model.load_state_dict(checkpoint['model_state'])
        except Exception:
            # strip the 'module.' prefix left by DDP
            new_state_dict = OrderedDict()
            for key, val in checkpoint['model_state'].items():
                name = key[7:]
                new_state_dict[name] = val
            self.model.load_state_dict(new_state_dict)
        self.iters = checkpoint['iters']
        self.startEpoch = checkpoint['epoch']
        # if self.params.resuming and (self.params.multi_steps_finetune == 1):
        #     # restore_checkpoint is used for finetuning as well as resuming.
        #     # When finetuning (i.e., not resuming), it does not load the optimizer state
        #     # and instead uses the learning rate specified in the config.
        #     self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--run_num", default='00', type=str)
    parser.add_argument("--yaml_config", default='./config/Model.yaml', type=str)
    parser.add_argument("--multi_steps_finetune", default=1, type=int)
    parser.add_argument("--multi_stages", default=1, type=int)
    parser.add_argument("--finetune_max_epochs", default=50, type=int)
    parser.add_argument("--batch_size", default=16, type=int)
    parser.add_argument("--config", default='NeuralOM', type=str)
    parser.add_argument("--enable_amp", action='store_true')
    parser.add_argument("--epsilon_factor", default=0, type=float)
    parser.add_argument("--local_rank", default=-1, type=int, help='node rank for distributed training')
    args = parser.parse_args()

    script_dir = os.path.dirname(os.path.abspath(__file__))
    yaml_path = os.path.join(script_dir, args.yaml_config)
    params = YParams(os.path.abspath(yaml_path), args.config, True)
    params['epsilon_factor'] = args.epsilon_factor
    params['multi_steps_finetune'] = args.multi_steps_finetune
    params['multi_stages'] = args.multi_stages
    params['finetune_max_epochs'] = args.finetune_max_epochs

    params['world_size'] = 1
    if 'WORLD_SIZE' in os.environ:
        params['world_size'] = int(os.environ['WORLD_SIZE'])
    print('world_size :', params['world_size'])

    print('Initialize distributed process group...')
    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    params['local_rank'] = local_rank  # GPU ID

    torch.backends.cudnn.benchmark = True
    world_rank = dist.get_rank()
    params['global_batch_size'] = args.batch_size
    params['batch_size'] = int(args.batch_size // params['world_size'])  # the global batch size must be divisible by the number of GPUs
    params['enable_amp'] = args.enable_amp  # automatic mixed-precision training

    # Set up directories: stage 2 builds on the 6-step-finetuned stage-1 checkpoint
    exp_dir_path = os.path.join(script_dir, params.exp_dir)
    pretrained_expDir = os.path.join(exp_dir_path, args.config, str(args.run_num))
    multi_steps = params['multi_steps_finetune']
    multi_stages = params['multi_stages']

    params['pretrained_ckpt_path'] = os.path.join(pretrained_expDir, '6_steps_finetune/training_checkpoints/best_ckpt.tar')

    expDir = os.path.join(pretrained_expDir, f'6_steps_finetune/{multi_stages}_stages_finetune/{multi_steps}_steps_finetune')
    if world_rank == 0:
        os.makedirs(expDir, exist_ok=True)
        os.makedirs(os.path.join(expDir, 'training_checkpoints/'), exist_ok=True)

    params['experiment_dir'] = os.path.abspath(expDir)
    params['checkpoint_path'] = os.path.join(expDir, 'training_checkpoints/ckpt.tar')
    params['best_checkpoint_path'] = os.path.join(expDir, 'training_checkpoints/best_ckpt.tar')

    params['resuming'] = True

    if world_rank == 0:
        logging_utils.log_to_file(logger_name=None, log_filename=os.path.join(expDir, 'train.log'))
        logging_utils.log_versions()
        params.log()

    params['log_to_screen'] = (world_rank == 0) and params['log_to_screen']

    params['in_channels'] = np.array(params['in_channels'])
    params['out_channels'] = np.array(params['out_channels'])
    params['N_out_channels'] = len(params['out_channels'])
    if params.orography:
        params['N_in_channels'] = len(params['in_channels']) + 1
    else:
        params['N_in_channels'] = len(params['in_channels'])

    if world_rank == 0:
        hparams = ruamelDict()
        yaml = YAML()
        for key, value in params.params.items():
            hparams[str(key)] = str(value)
        with open(os.path.join(expDir, 'hyperparams.yaml'), 'w') as hpfile:
            yaml.dump(hparams, hpfile)

    trainer = Trainer(params, world_rank)
    trainer.train()
    logging.info('DONE ---- rank %d' % world_rank)
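
The core of this trainer is a two-stage residual correction: the stage-1 simulator is frozen (switch_off_grad plus torch.no_grad) and the stage-2 network learns only the correction that is added back onto the stage-1 output. A minimal sketch of one training step, with 1x1 convolutions standing in for MIGraph and MIGraph_stage2 (shapes are illustrative):

import torch

C, H, W = 4, 8, 8
stage1 = torch.nn.Conv2d(C, C, kernel_size=1)   # frozen simulator (MIGraph stand-in)
stage2 = torch.nn.Conv2d(C, C, kernel_size=1)   # trainable corrector (MIGraph_stage2 stand-in)
for p in stage1.parameters():                   # equivalent of switch_off_grad
    p.requires_grad = False

x = torch.randn(1, C, H, W)
tar = torch.randn(1, C, H, W)

with torch.no_grad():                           # stage 1 runs inference-only
    coarse = stage1(x)
refined = stage2(coarse) + coarse               # corrector predicts a residual

loss = torch.mean((refined - tar) ** 2)
loss.backward()                                 # gradients reach stage2 only
assert all(p.grad is None for p in stage1.parameters())

Because the residual is added to the already masked stage-1 field, a zero-output corrector reproduces stage 1 exactly, which makes the second stage safe to train from scratch.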
train_residual_model.sh ADDED
@@ -0,0 +1,77 @@
wandb_group='NeuralOM'
yaml_config='config/Model.yaml'
config='NeuralOM'
batch_size=16
# run_num=$(date "+%Y%m%d-%H%M%S")
run_num='20250501-000000'
multi_steps_finetune=10
multi_stages=2
finetune_max_epochs=200

TRAIN_DIR=$(dirname $(realpath train_residual_model.py))

export MASTER_ADDR=30.207.97.183  # IP address or hostname of the master node
export MASTER_PORT=31319
export WORLD_SIZE=16
export NODE_RANK=0

source ~/.bashrc
conda activate triton_v2
export NCCL_IB_GID_INDEX=3
export NCCL_IB_SL=3
export NCCL_CHECK_DISABLE=1
export NCCL_P2P_DISABLE=0
export NCCL_IB_DISABLE=0
export NCCL_LL_THRESHOLD=16384
export NCCL_IB_CUDA_SUPPORT=1
export NCCL_TOPO_AFFINITY=0
export NCCL_IB_HCA=mlx5_bond_1,mlx5_bond_5,mlx5_bond_3,mlx5_bond_7,mlx5_bond_4,mlx5_bond_8,mlx5_bond_2,mlx5_bond_6
export NCCL_COLLNET_ENABLE=0
export SHARP_COLL_ENABLE_SAT=0
export NCCL_NET_GDR_LEVEL=2
export NCCL_IB_QPS_PER_CONNECTION=4
export NCCL_IB_TC=160
export NCCL_PXN_DISABLE=0
export NCCL_DEBUG=WARN
export TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC=2400
export NCCL_SOCKET_IFNAME=bond1

export TORCH_NCCL_BLOCKING_WAIT=1
export TORCH_NCCL_ASYNC_ERROR_HANDLING=1

export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# pass --multi_stages explicitly; it is set above but the Python default is 1
nohup torchrun --nproc_per_node=8 --nnodes=2 --node_rank=$NODE_RANK --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT $TRAIN_DIR/train_residual_model.py \
    --yaml_config=$yaml_config --config=$config --run_num=$run_num --batch_size=$batch_size --multi_steps_finetune=$multi_steps_finetune --multi_stages=$multi_stages --finetune_max_epochs=$finetune_max_epochs \
    >> ./logs/${config}_${wandb_group}_rank0_${SLURM_JOB_ID}_${run_num}.log 2>&1 &

# Launch rank 1 on the second node over ssh with the same environment
ssh root@30.207.98.235 "
source ~/.bashrc; \
conda activate triton_v2; \

export NCCL_IB_GID_INDEX=3
export NCCL_IB_SL=3
export NCCL_CHECK_DISABLE=1
export NCCL_P2P_DISABLE=0
export NCCL_IB_DISABLE=0
export NCCL_LL_THRESHOLD=16384
export NCCL_IB_CUDA_SUPPORT=1
export NCCL_TOPO_AFFINITY=0
export NCCL_IB_HCA=mlx5_bond_1,mlx5_bond_5,mlx5_bond_3,mlx5_bond_7,mlx5_bond_4,mlx5_bond_8,mlx5_bond_2,mlx5_bond_6
export NCCL_COLLNET_ENABLE=0
export SHARP_COLL_ENABLE_SAT=0
export NCCL_NET_GDR_LEVEL=2
export NCCL_IB_QPS_PER_CONNECTION=4
export NCCL_IB_TC=160
export NCCL_PXN_DISABLE=0
export NCCL_DEBUG=WARN
export TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC=2400
export NCCL_SOCKET_IFNAME=bond1

export TORCH_NCCL_BLOCKING_WAIT=1
export TORCH_NCCL_ASYNC_ERROR_HANDLING=1

export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7; \
export MASTER_ADDR=$MASTER_ADDR; export MASTER_PORT=$MASTER_PORT; export WORLD_SIZE=16; export NODE_RANK=1; \
nohup torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT $TRAIN_DIR/train_residual_model.py \
    --yaml_config=$yaml_config --config=$config --run_num=$run_num --batch_size=$batch_size --multi_steps_finetune=$multi_steps_finetune --multi_stages=$multi_stages --finetune_max_epochs=$finetune_max_epochs \
    >> $TRAIN_DIR/logs/${config}_${wandb_group}_rank1_${SLURM_JOB_ID}_${run_num}.log 2>&1 &"
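
Both trainers wrap the model call in torch.utils.checkpoint during multi-step rollouts (use_reentrant=False), so per-step activations are recomputed in the backward pass rather than stored; that is what keeps a 10-step rollout like the one configured above within GPU memory. A minimal sketch of the pattern on a toy network (not the actual model):

import torch
import torch.utils.checkpoint as checkpoint

net = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(),
                          torch.nn.Linear(32, 32))
x = torch.randn(8, 32, requires_grad=True)

y = x
for _ in range(10):                              # e.g. a 10-step rollout
    # this call stores no intermediate activations in the forward pass;
    # they are recomputed when backward() reaches it
    y = checkpoint.checkpoint(net, y, use_reentrant=False)
y.sum().backward()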