File size: 10,682 Bytes
c509e76 0ea9dda c509e76 0ea9dda c509e76 0ea9dda c509e76 0ea9dda c509e76 0ea9dda c509e76 0ea9dda c509e76 0ea9dda c509e76 0ea9dda c509e76 0ea9dda c509e76 0ea9dda c509e76 0ea9dda c509e76 0ea9dda c509e76 0ea9dda c509e76 0ea9dda c509e76 0ea9dda c509e76 0ea9dda c509e76 0ea9dda c509e76 0ea9dda c509e76 0ea9dda c509e76 0ea9dda c509e76 0ea9dda c509e76 0ea9dda c509e76 0ea9dda c509e76 0ea9dda c509e76 0ea9dda c509e76 0ea9dda |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 |
import os
import cv2
import time
import random
import datetime
import argparse
import numpy as np
from itertools import cycle
import torch
import torch.nn as nn
from torch.utils import data
# Removed DDP and DistributedSampler imports
from utils import dict2string,mkdir,get_lr,torch2cvimg,second2hours
# Assumed 'loaders' and 'models' modules are available
from loaders import docres_loader
from models import restormer_arch
# --- Optional: Import for TensorBoard (uncomment if you have it installed) ---
# from torch.utils.tensorboard import SummaryWriter
def seed_torch(seed=1029):
    """Seed the Python, NumPy and PyTorch RNGs and pin cuDNN to deterministic mode.

    Args:
        seed: Value handed to ``random.seed``, ``np.random.seed`` and
            ``torch.manual_seed``; its string form is also stored in the
            ``PYTHONHASHSEED`` environment variable.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # The three generators are independent, so they can be seeded in any order.
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed):
        apply_seed(seed)

    # Trade cuDNN autotuning for reproducible kernel selection.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def getBasecoord(h,w):
    """Return the pixel-coordinate grid of an ``h`` x ``w`` image.

    Args:
        h: Number of rows of the grid.
        w: Number of columns of the grid.

    Returns:
        A ``float32`` array of shape ``(h, w, 2)``. Channel 0 holds the column
        index of each cell and channel 1 holds its row index.
    """
    col_axis = np.arange(w).astype(np.float32)
    row_axis = np.arange(h).astype(np.float32)
    # Default 'xy' indexing yields (h, w) grids: cols varies along axis 1,
    # rows varies along axis 0.
    cols, rows = np.meshgrid(col_axis, row_axis)
    return np.stack((cols, rows), axis=-1)
# Task name -> key under which its loss is reported. The insertion order is
# the order in which the entries appear in every log line.
_LOSS_LOG_KEYS = {
    'dewarping': 'dew_loss',
    'appearance': 'app_loss',
    'deshadowing': 'des_loss',
    'deblurring': 'deb_loss',
    'binarization': 'bin_loss',
}


def _open_tboard_writer(args, run_dir):
    """Return a TensorBoard ``SummaryWriter`` when requested and importable, else ``None``."""
    # getattr keeps programmatic callers working whose namespace has no `tboard`.
    if not getattr(args, 'tboard', False):
        return None
    try:
        from torch.utils.tensorboard import SummaryWriter
    except ImportError:
        print("Warning: TensorBoard not imported. Skipping logging to SummaryWriter.")
        return None
    return SummaryWriter(os.path.join(run_dir, 'runs'), args.experiment_name)


def _build_train_loaders(args):
    """Build one dataset/DataLoader entry per restoration task.

    Returns:
        A list of dicts with keys ``'task'`` (the dataset setting dict),
        ``'loader'`` (the DataLoader) and ``'iter_loader'`` (its live iterator).
    """
    # NOTE: You MUST update this path to match your system setup.
    data_root = '/home/jiaxin/Training_Data/DocRes_data/train'
    # (task name, annotation json relative to the task's image directory)
    task_json_files = (
        ('deblurring', 'tdd/train.json'),
        ('dewarping', 'doc3d/train_1_19.json'),
        ('binarization', 'train.json'),
        ('deshadowing', 'train.json'),
        ('appearance', 'trainv2.json'),
    )

    trainloaders = []
    for task, json_file in task_json_files:
        im_path = '{}/{}/'.format(data_root, task)
        setting = {
            'task': task,
            'ratio': 1,
            'im_path': im_path,
            'json_paths': [im_path + json_file],
        }
        dataset = docres_loader.DocResTrainDataset(dataset=setting, img_size=args.im_size)
        # shuffle=True replaces the shuffling that the removed
        # DistributedSampler used to provide; without it every pass visits the
        # samples in the same order.
        loader = data.DataLoader(
            dataset=dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=0,
            pin_memory=False,
            drop_last=True,
        )
        trainloaders.append({'task': setting, 'loader': loader, 'iter_loader': iter(loader)})
    return trainloaders


def _next_batch(entry):
    """Return the next batch of a loader entry, restarting its iterator when exhausted."""
    try:
        return next(entry['iter_loader'])
    except StopIteration:
        entry['iter_loader'] = iter(entry['loader'])
    try:
        return next(entry['iter_loader'])
    except StopIteration:
        # With drop_last=True a dataset smaller than one batch yields nothing.
        raise RuntimeError(
            "DataLoader for task '{}' produced no batches".format(entry['task']['task'])
        ) from None


def _build_model():
    """Instantiate the Restormer network with the configuration used by this script."""
    return restormer_arch.Restormer(
        inp_channels=6,
        out_channels=3,
        dim=48,
        num_blocks=[2, 3, 3, 4],
        num_refinement_blocks=4,
        heads=[1, 2, 4, 8],
        ffn_expansion_factor=2.66,
        bias=False,
        LayerNorm_type='WithBias',
        dual_pixel_task=True,
    )


def _task_loss(task, pred_im, gt_im, l1, ce):
    """Compute the training loss of ``task`` for one batch.

    Raises:
        ValueError: If ``task`` is not one of the supported task names.
    """
    if task == 'binarization':
        # Cross-entropy over the first two output channels against the class
        # indices stored in channel 0 of the target.
        return ce(pred_im[:, :2, :, :], gt_im.long()[:, 0, :, :])
    if task == 'dewarping':
        # Only the first two channels of prediction and target are compared.
        return l1(pred_im[:, :2, :, :], gt_im[:, :2, :, :])
    if task in ('appearance', 'deblurring', 'deshadowing'):
        return l1(pred_im, gt_im)
    raise ValueError("Unsupported task '{}'".format(task))


def train(args):
    """Run single-process CPU training of the multi-task Restormer model.

    Each iteration draws one task at random (weighted by its ``ratio``), takes
    one batch from that task's loader and performs one AdamW step; the cosine
    learning-rate schedule advances once per iteration. Losses are printed and
    appended to ``<logdir>/<experiment_name>/log.txt`` every 10 iterations and
    a checkpoint ``<iteration>.pkl`` is written every 5000 iterations.

    Args:
        args: Namespace providing ``im_size``, ``batch_size``, ``l_rate``,
            ``total_iter``, ``resume`` (checkpoint path or ``None``),
            ``logdir`` and ``experiment_name``. An optional ``tboard``
            attribute enables TensorBoard scalars when the package is
            importable.

    Raises:
        KeyError: If a resumed checkpoint lacks ``'model_state'`` or an
            iteration counter (``'iters'`` or legacy ``'iter'``).
        RuntimeError: If a task's loader yields no batches at all.
    """
    # --- CPU/Single-Process Setup ---
    device = torch.device('cpu')
    print(f"Training on device: {device}")

    ### Log file
    run_dir = os.path.join(args.logdir, args.experiment_name)
    mkdir(args.logdir)
    mkdir(run_dir)
    log_file_path = os.path.join(run_dir, 'log.txt')
    with open(log_file_path, 'a') as log_file:
        log_file.write('\n--------------- ' + args.experiment_name + ' ---------------\n')

    ### Optional TensorBoard writer (None when disabled or unavailable)
    writer = _open_tboard_writer(args, run_dir)

    ### Dataloaders
    trainloaders = _build_train_loaders(args)
    ratios = [entry['task']['ratio'] for entry in trainloaders]

    ### Model, optimizer and LR scheduler
    model = _build_model()
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.l_rate, weight_decay=5e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.total_iter, eta_min=1e-6, last_epoch=-1)

    ### Load checkpoint
    iter_start = 0
    if args.resume is not None:
        print("Loading model and optimizer from checkpoint '{}'".format(args.resume))
        # Ensure the checkpoint is loaded to CPU.
        checkpoint = torch.load(args.resume, map_location=device)
        model.load_state_dict(checkpoint['model_state'], strict=False)
        # Checkpoints written by this script store the counter under 'iters';
        # 'iter' is still accepted for files that use the older key.
        iter_start = checkpoint['iters'] if 'iters' in checkpoint else checkpoint['iter']
        # Optimizer and scheduler are only restored together: loading one
        # without the other would leave the learning rate out of step with
        # the schedule position. Checkpoints without scheduler state keep the
        # previous behaviour (fresh optimizer, schedule from the start).
        # NOTE: a restored scheduler keeps the settings stored in the
        # checkpoint, including the schedule length it was created with.
        if 'optimizer_state' in checkpoint and 'scheduler_state' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            sched.load_state_dict(checkpoint['scheduler_state'])
        print("Loaded checkpoint '{}' (iter {})".format(args.resume, iter_start))

    ###-----------------------------------------Training-----------------------------------------
    l1 = nn.L1Loss()
    ce = nn.CrossEntropyLoss()

    for iters in range(iter_start, args.total_iter):
        start_time = time.time()

        # Pick the task for this step, weighted by the per-task ratios.
        entry = random.choices(trainloaders, weights=ratios)[0]
        task = entry['task']['task']
        in_im, gt_im = _next_batch(entry)
        in_im = in_im.float().to(device)
        gt_im = gt_im.float().to(device)

        pred_im = model(in_im, task)
        loss = _task_loss(task, pred_im, gt_im, l1, ce)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Every task keeps a slot in the log line; inactive ones report 0.
        loss_dict = dict.fromkeys(_LOSS_LOG_KEYS.values(), 0)
        loss_dict[_LOSS_LOG_KEYS[task]] = loss.item()

        duration = time.time() - start_time

        ## log
        if (iters+1) % 10 == 0:
            message = ('iters [{}/{}] -- '.format(iters+1, args.total_iter)
                       + dict2string(loss_dict)
                       + ' --lr {:6f}'.format(get_lr(optimizer))
                       + ' -- time {}'.format(second2hours(duration*(args.total_iter-iters))))
            print(message)
            if writer is not None:
                for key, value in loss_dict.items():
                    writer.add_scalar('Train '+key+'/Iterations', value, iters+1)
            with open(log_file_path, 'a') as f:
                f.write(message + '\n')

        # Advance the schedule before checkpointing so the saved optimizer and
        # scheduler state describe iteration `iters+1`, where a resume starts.
        sched.step()

        if (iters+1) % 5000 == 0:
            state = {'iters': iters+1,
                     'model_state': model.state_dict(),
                     'optimizer_state': optimizer.state_dict(),
                     'scheduler_state': sched.state_dict()}
            os.makedirs(run_dir, exist_ok=True)
            torch.save(state, os.path.join(run_dir, "{}.pkl".format(iters+1)))

    if writer is not None:
        writer.close()
# Script entry point: parse the training hyper-parameters and start training.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Hyperparams')
    parser.add_argument('--im_size', nargs='?', type=int, default=256,
                        help='Height of the input image')
    # The value bounds the training loop and the LR schedule length, i.e. it
    # counts optimizer iterations, not epochs.
    parser.add_argument('--total_iter', nargs='?', type=int, default=100000,
                        help='Total number of training iterations')
    parser.add_argument('--batch_size', nargs='?', type=int, default=10,
                        help='Batch Size')
    parser.add_argument('--l_rate', nargs='?', type=float, default=2e-4,
                        help='Learning Rate')
    parser.add_argument('--resume', nargs='?', type=str, default=None,
                        help='Path to previous saved model to restart from')
    parser.add_argument('--logdir', nargs='?', type=str, default='./checkpoints/',
                        help='Path to store the loss logs')
    # store_true already defaults to False, so no explicit set_defaults is needed.
    parser.add_argument('--tboard', dest='tboard', action='store_true',
                        help='Enable visualization(s) on tensorboard | False by default')
    parser.add_argument('--experiment_name', nargs='?', type=str, default='experiment_name',
                        help='the name of this experiment')
    args = parser.parse_args()
    # Note: a low batch size (e.g. --batch_size 2) is recommended for initial CPU testing.
    train(args)
|