mohamed12ahmed commited on
Commit
0ea9dda
·
verified ·
1 Parent(s): 161f282

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +95 -87
train.py CHANGED
@@ -5,35 +5,30 @@ import random
5
  import datetime
6
  import argparse
7
  import numpy as np
8
- from tqdm import tqdm
9
- from piq import ssim,psnr
10
  from itertools import cycle
11
 
12
  import torch
13
  import torch.nn as nn
14
  from torch.utils import data
15
- import torch.distributed as dist
16
- from torch.utils.data.distributed import DistributedSampler
17
- from torch.nn.parallel import DistributedDataParallel as DDP
18
 
 
19
 
20
  from utils import dict2string,mkdir,get_lr,torch2cvimg,second2hours
21
- from loaders import docres_loader
22
- from models import restormer_arch
 
23
 
 
 
24
 
25
  def seed_torch(seed=1029):
26
  random.seed(seed)
27
  os.environ['PYTHONHASHSEED'] = str(seed)
28
  np.random.seed(seed)
29
  torch.manual_seed(seed)
30
- torch.cuda.manual_seed(seed)
31
- torch.cuda.manual_seed_all(seed)
32
  torch.backends.cudnn.benchmark = False
33
  torch.backends.cudnn.deterministic = True
34
- #torch.use_deterministic_algorithms(True)
35
- # seed_torch()
36
-
37
 
38
  def getBasecoord(h,w):
39
  base_coord0 = np.tile(np.arange(h).reshape(h,1),(1,w)).astype(np.float32)
@@ -42,13 +37,11 @@ def getBasecoord(h,w):
42
  return base_coord
43
 
44
  def train(args):
45
-
46
- ## DDP init
47
- dist.init_process_group(backend='nccl',init_method='env://',timeout=datetime.timedelta(seconds=36000))
48
- torch.cuda.set_device(args.local_rank)
49
- device = torch.device('cuda',args.local_rank)
50
- torch.cuda.manual_seed_all(42)
51
-
52
  ### Log file:
53
  mkdir(args.logdir)
54
  mkdir(os.path.join(args.logdir,args.experiment_name))
@@ -58,10 +51,17 @@ def train(args):
58
  log_file.close()
59
 
60
  ### Setup tensorboard for visualization
61
- if args.tboard:
62
- writer = SummaryWriter(os.path.join(args.logdir,args.experiment_name,'runs'),args.experiment_name)
 
 
 
 
 
 
63
 
64
  ### Setup Dataloader
 
65
  datasets_setting = [
66
  {'task':'deblurring','ratio':1,'im_path':'/home/jiaxin/Training_Data/DocRes_data/train/deblurring/','json_paths':['/home/jiaxin/Training_Data/DocRes_data/train/deblurring/tdd/train.json']},
67
  {'task':'dewarping','ratio':1,'im_path':'/home/jiaxin/Training_Data/DocRes_data/train/dewarping/','json_paths':['/home/jiaxin/Training_Data/DocRes_data/train/dewarping/doc3d/train_1_19.json']},
@@ -69,33 +69,30 @@ def train(args):
69
  {'task':'deshadowing','ratio':1,'im_path':'/home/jiaxin/Training_Data/DocRes_data/train/deshadowing/','json_paths':['/home/jiaxin/Training_Data/DocRes_data/train/deshadowing/train.json']},
70
  {'task':'appearance','ratio':1,'im_path':'/home/jiaxin/Training_Data/DocRes_data/train/appearance/','json_paths':['/home/jiaxin/Training_Data/DocRes_data/train/appearance/trainv2.json']}
71
  ]
72
-
73
-
74
  ratios = [dataset_setting['ratio'] for dataset_setting in datasets_setting]
75
  datasets = [docres_loader.DocResTrainDataset(dataset=dataset_setting,img_size=args.im_size) for dataset_setting in datasets_setting]
76
- trainloaders = [{'task':datasets_setting[i],'loader':data.DataLoader(dataset=datasets[i], sampler=DistributedSampler(datasets[i]), batch_size=args.batch_size, num_workers=2, pin_memory=True,drop_last=True),'iter_loader':iter(data.DataLoader(dataset=datasets[i], sampler=DistributedSampler(datasets[i]), batch_size=args.batch_size, num_workers=2, pin_memory=True,drop_last=True))} for i in range(len(datasets))]
77
-
78
-
79
- ### test loader
80
- # for i in tqdm(range(args.total_iter)):
81
- # loader_index = random.choices(list(range(len(trainloaders))),ratios)[0]
82
- # in_im,gt_im = next(trainloaders[loader_index]['iter_loader'])
83
-
84
-
85
  ### Setup Model
86
- model = restormer_arch.Restormer(
87
- inp_channels=6,
88
- out_channels=3,
89
- dim = 48,
90
- num_blocks = [2,3,3,4],
91
- num_refinement_blocks = 4,
92
  heads = [1,2,4,8],
93
  ffn_expansion_factor = 2.66,
94
  bias = False,
95
- LayerNorm_type = 'WithBias',
96
- dual_pixel_task = True
97
- )
98
- model=DDP(model.cuda(),device_ids=[args.local_rank],output_device=args.local_rank)
 
99
 
100
  ### Optimizer
101
  optimizer= torch.optim.AdamW(model.parameters(),lr=args.l_rate,weight_decay=5e-4)
@@ -105,8 +102,10 @@ def train(args):
105
 
106
  ### load checkpoint
107
  iter_start=0
108
- if args.resume is not None:
109
  print("Loading model and optimizer from checkpoint '{}'".format(args.resume))
 
 
110
  x = checkpoint['model_state']
111
  model.load_state_dict(x,strict=False)
112
  iter_start=checkpoint['iter']
@@ -114,7 +113,7 @@ def train(args):
114
 
115
  ###-----------------------------------------Training-----------------------------------------
116
  ##initialize
117
- scaler = torch.cuda.amp.GradScaler()
118
  loss_dict = {}
119
  total_step = 0
120
  l2 = nn.MSELoss()
@@ -128,74 +127,80 @@ def train(args):
128
  ## total_steps
129
  for iters in range(iter_start,args.total_iter):
130
  start_time = time.time()
 
131
  loader_index = random.choices(list(range(len(trainloaders))),ratios)[0]
132
-
133
  try:
134
  in_im,gt_im = next(trainloaders[loader_index]['iter_loader'])
135
  except StopIteration:
136
  trainloaders[loader_index]['iter_loader']=iter(trainloaders[loader_index]['loader'])
137
  in_im,gt_im = next(trainloaders[loader_index]['iter_loader'])
138
- in_im = in_im.float().cuda()
139
- gt_im = gt_im.float().cuda()
140
-
 
 
141
  binarization_loss,appearance_loss,dewarping_loss,deblurring_loss,deshadowing_loss = 0,0,0,0,0
142
- with torch.cuda.amp.autocast():
143
- pred_im = model(in_im,trainloaders[loader_index]['task']['task'])
144
- if trainloaders[loader_index]['task']['task'] == 'binarization':
145
- gt_im = gt_im.long()
146
- binarization_loss = ce(pred_im[:,:2,:,:], gt_im[:,0,:,:])
147
- loss = binarization_loss
148
- elif trainloaders[loader_index]['task']['task'] == 'dewarping':
149
- dewarping_loss = l1(pred_im[:,:2,:,:], gt_im[:,:2,:,:])
150
- loss = dewarping_loss
151
- elif trainloaders[loader_index]['task']['task'] == 'appearance':
152
- appearance_loss = l1(pred_im, gt_im)
153
- loss = appearance_loss
154
- elif trainloaders[loader_index]['task']['task'] == 'deblurring':
155
- deblurring_loss = l1(pred_im, gt_im)
156
- loss = deblurring_loss
157
- elif trainloaders[loader_index]['task']['task'] == 'deshadowing':
158
- deshadowing_loss = l1(pred_im, gt_im)
159
- loss = deshadowing_loss
160
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  optimizer.zero_grad()
162
- scaler.scale(loss).backward()
163
- scaler.step(optimizer)
164
- scaler.update()
165
-
166
  loss_dict['dew_loss']=dewarping_loss.item() if isinstance(dewarping_loss,torch.Tensor) else 0
167
  loss_dict['app_loss']=appearance_loss.item() if isinstance(appearance_loss,torch.Tensor) else 0
168
  loss_dict['des_loss']=deshadowing_loss.item() if isinstance(deshadowing_loss,torch.Tensor) else 0
169
  loss_dict['deb_loss']=deblurring_loss.item() if isinstance(deblurring_loss,torch.Tensor) else 0
170
  loss_dict['bin_loss']=binarization_loss.item() if isinstance(binarization_loss,torch.Tensor) else 0
 
171
  end_time = time.time()
172
  duration = end_time-start_time
 
173
  ## log
174
  if (iters+1) % 10 == 0:
175
  ## print
176
  print('iters [{}/{}] -- '.format(iters+1,args.total_iter)+dict2string(loss_dict)+' --lr {:6f}'.format(get_lr(optimizer))+' -- time {}'.format(second2hours(duration*(args.total_iter-iters))))
 
177
  ## tbord
178
- if args.tboard:
179
- for key,value in loss_dict.items():
180
- writer.add_scalar('Train '+key+'/Iterations', value, total_step)
 
181
  ## logfile
182
  with open(log_file_path,'a') as f:
183
  f.write('iters [{}/{}] -- '.format(iters+1,args.total_iter)+dict2string(loss_dict)+' --lr {:6f}'.format(get_lr(optimizer))+' -- time {}'.format(second2hours(duration*(args.total_iter-iters)))+'\n')
184
 
185
-
186
  if (iters+1) % 5000 == 0:
187
  state = {'iters': iters+1,
188
  'model_state': model.state_dict(),
189
  'optimizer_state' : optimizer.state_dict(),}
190
- if not os.path.exists(os.path.join(args.logdir,args.experiment_name)):
191
  os.system('mkdir ' + os.path.join(args.logdir,args.experiment_name))
192
- if torch.distributed.get_rank()==0:
193
- torch.save(state, os.path.join(args.logdir,args.experiment_name,"{}.pkl".format(iters+1)))
194
-
 
195
  sched.step()
196
 
197
-
198
-
199
  if __name__ == '__main__':
200
  parser = argparse.ArgumentParser(description='Hyperparams')
201
  parser.add_argument('--im_size', nargs='?', type=int, default=256,
@@ -206,16 +211,19 @@ if __name__ == '__main__':
206
  help='Batch Size')
207
  parser.add_argument('--l_rate', nargs='?', type=float, default=2e-4,
208
  help='Learning Rate')
209
- parser.add_argument('--resume', nargs='?', type=str, default=None,
210
- help='Path to previous saved model to restart from')
211
- parser.add_argument('--logdir', nargs='?', type=str, default='./checkpoints/',
212
- help='Path to store the loss logs')
213
  parser.add_argument('--tboard', dest='tboard', action='store_true',
214
  help='Enable visualization(s) on tensorboard | False by default')
215
- parser.add_argument('--local_rank',type=int,default=0,metavar='N')
216
  parser.add_argument('--experiment_name', nargs='?', type=str,default='experiment_name',
217
  help='the name of this experiment')
218
  parser.set_defaults(tboard=False)
219
  args = parser.parse_args()
220
-
221
- train(args)
 
 
 
 
5
  import datetime
6
  import argparse
7
  import numpy as np
 
 
8
  from itertools import cycle
9
 
10
  import torch
11
  import torch.nn as nn
12
  from torch.utils import data
 
 
 
13
 
14
+ # Removed DDP and DistributedSampler imports
15
 
16
  from utils import dict2string,mkdir,get_lr,torch2cvimg,second2hours
17
+ # Assumed 'loaders' and 'models' modules are available
18
+ from loaders import docres_loader
19
+ from models import restormer_arch
20
 
21
+ # --- Optional: Import for TensorBoard (uncomment if you have it installed) ---
22
+ # from torch.utils.tensorboard import SummaryWriter
23
 
24
  def seed_torch(seed=1029):
25
  random.seed(seed)
26
  os.environ['PYTHONHASHSEED'] = str(seed)
27
  np.random.seed(seed)
28
  torch.manual_seed(seed)
29
+ # Removed CUDA-specific seeding
 
30
  torch.backends.cudnn.benchmark = False
31
  torch.backends.cudnn.deterministic = True
 
 
 
32
 
33
  def getBasecoord(h,w):
34
  base_coord0 = np.tile(np.arange(h).reshape(h,1),(1,w)).astype(np.float32)
 
37
  return base_coord
38
 
39
  def train(args):
40
+ # --- CPU/Single-Process Setup ---
41
+ # Set device to CPU
42
+ device = torch.device('cpu')
43
+ print(f"Training on device: {device}")
44
+
 
 
45
  ### Log file:
46
  mkdir(args.logdir)
47
  mkdir(os.path.join(args.logdir,args.experiment_name))
 
51
  log_file.close()
52
 
53
  ### Setup tensorboard for visualization
54
+ # Note: TensorBoard setup is commented out for robust CPU execution.
55
+ # if args.tboard:
56
+ # try:
57
+ # writer = SummaryWriter(os.path.join(args.logdir,args.experiment_name,'runs'),args.experiment_name)
58
+ # except NameError:
59
+ # print("Warning: TensorBoard not imported. Skipping logging to SummaryWriter.")
60
+ # args.tboard = False
61
+
62
 
63
  ### Setup Dataloader
64
+ # NOTE: You MUST update these paths to match your system setup.
65
  datasets_setting = [
66
  {'task':'deblurring','ratio':1,'im_path':'/home/jiaxin/Training_Data/DocRes_data/train/deblurring/','json_paths':['/home/jiaxin/Training_Data/DocRes_data/train/deblurring/tdd/train.json']},
67
  {'task':'dewarping','ratio':1,'im_path':'/home/jiaxin/Training_Data/DocRes_data/train/dewarping/','json_paths':['/home/jiaxin/Training_Data/DocRes_data/train/dewarping/doc3d/train_1_19.json']},
 
69
  {'task':'deshadowing','ratio':1,'im_path':'/home/jiaxin/Training_Data/DocRes_data/train/deshadowing/','json_paths':['/home/jiaxin/Training_Data/DocRes_data/train/deshadowing/train.json']},
70
  {'task':'appearance','ratio':1,'im_path':'/home/jiaxin/Training_Data/DocRes_data/train/appearance/','json_paths':['/home/jiaxin/Training_Data/DocRes_data/train/appearance/trainv2.json']}
71
  ]
 
 
72
  ratios = [dataset_setting['ratio'] for dataset_setting in datasets_setting]
73
  datasets = [docres_loader.DocResTrainDataset(dataset=dataset_setting,img_size=args.im_size) for dataset_setting in datasets_setting]
74
+
75
+ # Standard DataLoader is used instead of DistributedSampler
76
+ trainloaders = [{'task':datasets_setting[i],
77
+ 'loader':data.DataLoader(dataset=datasets[i], batch_size=args.batch_size, num_workers=0, pin_memory=False, drop_last=True),
78
+ 'iter_loader':iter(data.DataLoader(dataset=datasets[i], batch_size=args.batch_size, num_workers=0, pin_memory=False, drop_last=True))}
79
+ for i in range(len(datasets))]
80
+
 
 
81
  ### Setup Model
82
+ model = restormer_arch.Restormer(
83
+ inp_channels=6,
84
+ out_channels=3,
85
+ dim = 48,
86
+ num_blocks = [2,3,3,4],
87
+ num_refinement_blocks = 4,
88
  heads = [1,2,4,8],
89
  ffn_expansion_factor = 2.66,
90
  bias = False,
91
+ LayerNorm_type = 'WithBias',
92
+ dual_pixel_task = True
93
+ )
94
+ # Move model to CPU
95
+ model.to(device)
96
 
97
  ### Optimizer
98
  optimizer= torch.optim.AdamW(model.parameters(),lr=args.l_rate,weight_decay=5e-4)
 
102
 
103
  ### load checkpoint
104
  iter_start=0
105
+ if args.resume is not None:
106
  print("Loading model and optimizer from checkpoint '{}'".format(args.resume))
107
+ # Ensure checkpoint is loaded to CPU
108
+ checkpoint = torch.load(args.resume, map_location=device)
109
  x = checkpoint['model_state']
110
  model.load_state_dict(x,strict=False)
111
  iter_start=checkpoint['iter']
 
113
 
114
  ###-----------------------------------------Training-----------------------------------------
115
  ##initialize
116
+ # Removed GradScaler for AMP
117
  loss_dict = {}
118
  total_step = 0
119
  l2 = nn.MSELoss()
 
127
  ## total_steps
128
  for iters in range(iter_start,args.total_iter):
129
  start_time = time.time()
130
+
131
  loader_index = random.choices(list(range(len(trainloaders))),ratios)[0]
 
132
  try:
133
  in_im,gt_im = next(trainloaders[loader_index]['iter_loader'])
134
  except StopIteration:
135
  trainloaders[loader_index]['iter_loader']=iter(trainloaders[loader_index]['loader'])
136
  in_im,gt_im = next(trainloaders[loader_index]['iter_loader'])
137
+
138
+ # Move data to CPU
139
+ in_im = in_im.float().to(device)
140
+ gt_im = gt_im.float().to(device)
141
+
142
  binarization_loss,appearance_loss,dewarping_loss,deblurring_loss,deshadowing_loss = 0,0,0,0,0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
+ # Removed torch.cuda.amp.autocast() block
145
+ pred_im = model(in_im,trainloaders[loader_index]['task']['task'])
146
+
147
+ if trainloaders[loader_index]['task']['task'] == 'binarization':
148
+ gt_im = gt_im.long()
149
+ binarization_loss = ce(pred_im[:,:2,:,:], gt_im[:,0,:,:])
150
+ loss = binarization_loss
151
+ elif trainloaders[loader_index]['task']['task'] == 'dewarping':
152
+ dewarping_loss = l1(pred_im[:,:2,:,:], gt_im[:,:2,:,:])
153
+ loss = dewarping_loss
154
+ elif trainloaders[loader_index]['task']['task'] == 'appearance':
155
+ appearance_loss = l1(pred_im, gt_im)
156
+ loss = appearance_loss
157
+ elif trainloaders[loader_index]['task']['task'] == 'deblurring':
158
+ deblurring_loss = l1(pred_im, gt_im)
159
+ loss = deblurring_loss
160
+ elif trainloaders[loader_index]['task']['task'] == 'deshadowing':
161
+ deshadowing_loss = l1(pred_im, gt_im)
162
+ loss = deshadowing_loss
163
+
164
  optimizer.zero_grad()
165
+ # Standard backward pass (removed scaler)
166
+ loss.backward()
167
+ optimizer.step()
168
+
169
  loss_dict['dew_loss']=dewarping_loss.item() if isinstance(dewarping_loss,torch.Tensor) else 0
170
  loss_dict['app_loss']=appearance_loss.item() if isinstance(appearance_loss,torch.Tensor) else 0
171
  loss_dict['des_loss']=deshadowing_loss.item() if isinstance(deshadowing_loss,torch.Tensor) else 0
172
  loss_dict['deb_loss']=deblurring_loss.item() if isinstance(deblurring_loss,torch.Tensor) else 0
173
  loss_dict['bin_loss']=binarization_loss.item() if isinstance(binarization_loss,torch.Tensor) else 0
174
+
175
  end_time = time.time()
176
  duration = end_time-start_time
177
+
178
  ## log
179
  if (iters+1) % 10 == 0:
180
  ## print
181
  print('iters [{}/{}] -- '.format(iters+1,args.total_iter)+dict2string(loss_dict)+' --lr {:6f}'.format(get_lr(optimizer))+' -- time {}'.format(second2hours(duration*(args.total_iter-iters))))
182
+
183
  ## tbord
184
+ # if args.tboard:
185
+ # for key,value in loss_dict.items():
186
+ # writer.add_scalar('Train '+key+'/Iterations', value, total_step)
187
+
188
  ## logfile
189
  with open(log_file_path,'a') as f:
190
  f.write('iters [{}/{}] -- '.format(iters+1,args.total_iter)+dict2string(loss_dict)+' --lr {:6f}'.format(get_lr(optimizer))+' -- time {}'.format(second2hours(duration*(args.total_iter-iters)))+'\n')
191
 
 
192
  if (iters+1) % 5000 == 0:
193
  state = {'iters': iters+1,
194
  'model_state': model.state_dict(),
195
  'optimizer_state' : optimizer.state_dict(),}
196
+ if not os.path.exists(os.path.join(args.logdir,args.experiment_name)):
197
  os.system('mkdir ' + os.path.join(args.logdir,args.experiment_name))
198
+
199
+ # Save checkpoint without DDP rank check
200
+ torch.save(state, os.path.join(args.logdir,args.experiment_name,"{}.pkl".format(iters+1)))
201
+
202
  sched.step()
203
 
 
 
204
  if __name__ == '__main__':
205
  parser = argparse.ArgumentParser(description='Hyperparams')
206
  parser.add_argument('--im_size', nargs='?', type=int, default=256,
 
211
  help='Batch Size')
212
  parser.add_argument('--l_rate', nargs='?', type=float, default=2e-4,
213
  help='Learning Rate')
214
+ parser.add_argument('--resume', nargs='?', type=str, default=None,
215
+ help='Path to previous saved model to restart from')
216
+ parser.add_argument('--logdir', nargs='?', type=str, default='./checkpoints/',
217
+ help='Path to store the loss logs')
218
  parser.add_argument('--tboard', dest='tboard', action='store_true',
219
  help='Enable visualization(s) on tensorboard | False by default')
220
+ # Removed local_rank argument as it's not needed for single-process CPU
221
  parser.add_argument('--experiment_name', nargs='?', type=str,default='experiment_name',
222
  help='the name of this experiment')
223
  parser.set_defaults(tboard=False)
224
  args = parser.parse_args()
225
+
226
+ # Note: Using a low batch size (e.g., 2) is recommended for initial CPU testing.
227
+ # args.batch_size = 2 # Uncomment for quick testing
228
+
229
+ train(args)