# train.py — DocRes training script (CPU / single-process variant).
# Exported from a HuggingFace file viewer (commit 0ea9dda, ~10.7 kB);
# the viewer chrome ("raw / history blame") was removed from this header
# so the file parses as Python.
import os
import cv2
import time
import random
import datetime
import argparse
import numpy as np
from itertools import cycle
import torch
import torch.nn as nn
from torch.utils import data
# Removed DDP and DistributedSampler imports
from utils import dict2string,mkdir,get_lr,torch2cvimg,second2hours
# Assumed 'loaders' and 'models' modules are available
from loaders import docres_loader
from models import restormer_arch
# --- Optional: Import for TensorBoard (uncomment if you have it installed) ---
# from torch.utils.tensorboard import SummaryWriter
def seed_torch(seed=1029):
    """Seed every RNG in play (hash, python, numpy, torch) for reproducible runs."""
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Removed CUDA-specific seeding
    # Deterministic cuDNN settings; harmless no-ops on CPU-only builds.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def getBasecoord(h, w):
    """Return an (h, w, 2) float32 grid where [..., 0] is the x (column)
    coordinate and [..., 1] is the y (row) coordinate of each pixel."""
    xs, ys = np.meshgrid(np.arange(w), np.arange(h))
    return np.stack((xs, ys), axis=-1).astype(np.float32)
def train(args):
    """Single-process CPU training loop for the multi-task DocRes model.

    Builds one DataLoader per restoration task, samples a task each step
    weighted by its ``ratio``, optimizes a shared Restormer backbone for
    ``args.total_iter`` steps, logs every 10 steps, and checkpoints every
    5000 steps under ``args.logdir/args.experiment_name``.

    Args:
        args: argparse.Namespace with im_size, total_iter, batch_size,
            l_rate, resume, logdir, tboard, experiment_name.
    """
    # --- CPU/Single-Process Setup ---
    device = torch.device('cpu')
    print(f"Training on device: {device}")

    ### Log file:
    mkdir(args.logdir)
    mkdir(os.path.join(args.logdir, args.experiment_name))
    log_file_path = os.path.join(args.logdir, args.experiment_name, 'log.txt')
    # 'with' guarantees the header is flushed/closed even on error.
    with open(log_file_path, 'a') as log_file:
        log_file.write('\n--------------- ' + args.experiment_name + ' ---------------\n')

    ### Setup tensorboard for visualization
    # Note: TensorBoard setup is commented out for robust CPU execution.
    # if args.tboard:
    #     try:
    #         writer = SummaryWriter(os.path.join(args.logdir, args.experiment_name, 'runs'), args.experiment_name)
    #     except NameError:
    #         print("Warning: TensorBoard not imported. Skipping logging to SummaryWriter.")
    #         args.tboard = False

    ### Setup Dataloader
    # NOTE: You MUST update these paths to match your system setup.
    datasets_setting = [
        {'task':'deblurring','ratio':1,'im_path':'/home/jiaxin/Training_Data/DocRes_data/train/deblurring/','json_paths':['/home/jiaxin/Training_Data/DocRes_data/train/deblurring/tdd/train.json']},
        {'task':'dewarping','ratio':1,'im_path':'/home/jiaxin/Training_Data/DocRes_data/train/dewarping/','json_paths':['/home/jiaxin/Training_Data/DocRes_data/train/dewarping/doc3d/train_1_19.json']},
        {'task':'binarization','ratio':1,'im_path':'/home/jiaxin/Training_Data/DocRes_data/train/binarization/','json_paths':['/home/jiaxin/Training_Data/DocRes_data/train/binarization/train.json']},
        {'task':'deshadowing','ratio':1,'im_path':'/home/jiaxin/Training_Data/DocRes_data/train/deshadowing/','json_paths':['/home/jiaxin/Training_Data/DocRes_data/train/deshadowing/train.json']},
        {'task':'appearance','ratio':1,'im_path':'/home/jiaxin/Training_Data/DocRes_data/train/appearance/','json_paths':['/home/jiaxin/Training_Data/DocRes_data/train/appearance/trainv2.json']}
    ]
    ratios = [dataset_setting['ratio'] for dataset_setting in datasets_setting]
    datasets = [docres_loader.DocResTrainDataset(dataset=dataset_setting, img_size=args.im_size) for dataset_setting in datasets_setting]
    # Build each DataLoader ONCE and reuse it for the iterator; the original
    # constructed an identical second DataLoader just to call iter() on it.
    trainloaders = []
    for i in range(len(datasets)):
        loader = data.DataLoader(dataset=datasets[i], batch_size=args.batch_size,
                                 num_workers=0, pin_memory=False, drop_last=True)
        trainloaders.append({'task': datasets_setting[i],
                             'loader': loader,
                             'iter_loader': iter(loader)})

    ### Setup Model
    model = restormer_arch.Restormer(
        inp_channels=6,
        out_channels=3,
        dim=48,
        num_blocks=[2, 3, 3, 4],
        num_refinement_blocks=4,
        heads=[1, 2, 4, 8],
        ffn_expansion_factor=2.66,
        bias=False,
        LayerNorm_type='WithBias',
        dual_pixel_task=True
    )
    # Move model to CPU
    model.to(device)

    ### Optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.l_rate, weight_decay=5e-4)

    ### LR Scheduler
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.total_iter, eta_min=1e-6, last_epoch=-1)

    ### load checkpoint
    iter_start = 0
    if args.resume is not None:
        print("Loading model and optimizer from checkpoint '{}'".format(args.resume))
        # Ensure checkpoint is loaded to CPU
        checkpoint = torch.load(args.resume, map_location=device)
        model.load_state_dict(checkpoint['model_state'], strict=False)
        # BUGFIX: the original printed "model and optimizer" but never restored
        # the optimizer state this function saves below.
        if 'optimizer_state' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state'])
        # BUGFIX: checkpoints written below store the step under 'iters', but
        # the original resume code read 'iter' and would KeyError on its own
        # checkpoints. Accept either key.
        iter_start = checkpoint.get('iter', checkpoint.get('iters', 0))
        print("Loaded checkpoint '{}' (iter {})".format(args.resume, iter_start))

    ###-----------------------------------------Training-----------------------------------------
    ## initialize (GradScaler/AMP removed for CPU execution)
    loss_dict = {}
    total_step = 0
    l2 = nn.MSELoss()
    l1 = nn.L1Loss()
    ce = nn.CrossEntropyLoss()
    bce = nn.BCEWithLogitsLoss()
    m = nn.Sigmoid()
    best = 0
    best_ce = 999

    ## total_steps
    for iters in range(iter_start, args.total_iter):
        start_time = time.time()
        # Sample which task to train this step, weighted by dataset 'ratio'.
        loader_index = random.choices(list(range(len(trainloaders))), ratios)[0]
        try:
            in_im, gt_im = next(trainloaders[loader_index]['iter_loader'])
        except StopIteration:
            # This task's loader is exhausted: restart it for another pass.
            trainloaders[loader_index]['iter_loader'] = iter(trainloaders[loader_index]['loader'])
            in_im, gt_im = next(trainloaders[loader_index]['iter_loader'])
        # Move data to CPU
        in_im = in_im.float().to(device)
        gt_im = gt_im.float().to(device)

        binarization_loss, appearance_loss, dewarping_loss, deblurring_loss, deshadowing_loss = 0, 0, 0, 0, 0
        task = trainloaders[loader_index]['task']['task']
        pred_im = model(in_im, task)
        if task == 'binarization':
            # Two-class logits against an integer mask (channel 0 of gt).
            gt_im = gt_im.long()
            binarization_loss = ce(pred_im[:, :2, :, :], gt_im[:, 0, :, :])
            loss = binarization_loss
        elif task == 'dewarping':
            # Only the first two channels carry the coordinate map.
            dewarping_loss = l1(pred_im[:, :2, :, :], gt_im[:, :2, :, :])
            loss = dewarping_loss
        elif task == 'appearance':
            appearance_loss = l1(pred_im, gt_im)
            loss = appearance_loss
        elif task == 'deblurring':
            deblurring_loss = l1(pred_im, gt_im)
            loss = deblurring_loss
        elif task == 'deshadowing':
            deshadowing_loss = l1(pred_im, gt_im)
            loss = deshadowing_loss

        optimizer.zero_grad()
        # Standard backward pass (no AMP scaler on CPU).
        loss.backward()
        optimizer.step()

        # Tasks not trained this step report 0.
        loss_dict['dew_loss'] = dewarping_loss.item() if isinstance(dewarping_loss, torch.Tensor) else 0
        loss_dict['app_loss'] = appearance_loss.item() if isinstance(appearance_loss, torch.Tensor) else 0
        loss_dict['des_loss'] = deshadowing_loss.item() if isinstance(deshadowing_loss, torch.Tensor) else 0
        loss_dict['deb_loss'] = deblurring_loss.item() if isinstance(deblurring_loss, torch.Tensor) else 0
        loss_dict['bin_loss'] = binarization_loss.item() if isinstance(binarization_loss, torch.Tensor) else 0

        end_time = time.time()
        duration = end_time - start_time

        ## log
        if (iters + 1) % 10 == 0:
            # Build the line once; print it and append it to the log file.
            log_line = ('iters [{}/{}] -- '.format(iters + 1, args.total_iter)
                        + dict2string(loss_dict)
                        + ' --lr {:6f}'.format(get_lr(optimizer))
                        + ' -- time {}'.format(second2hours(duration * (args.total_iter - iters))))
            print(log_line)
            ## tbord
            # if args.tboard:
            #     for key, value in loss_dict.items():
            #         writer.add_scalar('Train ' + key + '/Iterations', value, total_step)
            ## logfile
            with open(log_file_path, 'a') as f:
                f.write(log_line + '\n')

        if (iters + 1) % 5000 == 0:
            state = {'iters': iters + 1,
                     'model_state': model.state_dict(),
                     'optimizer_state': optimizer.state_dict()}
            # Portable directory creation instead of the original
            # os.system('mkdir ...') shell call.
            os.makedirs(os.path.join(args.logdir, args.experiment_name), exist_ok=True)
            # Save checkpoint (no DDP rank check needed single-process).
            torch.save(state, os.path.join(args.logdir, args.experiment_name, "{}.pkl".format(iters + 1)))

        sched.step()
if __name__ == '__main__':
    # Command-line entry point: parse hyperparameters and start training.
    parser = argparse.ArgumentParser(description='Hyperparams')
    parser.add_argument('--im_size', nargs='?', type=int, default=256,
                        help='Height of the input image')
    # BUGFIX: help text said '# of the epochs' but this is the total number
    # of training iterations (see train()'s main loop).
    parser.add_argument('--total_iter', nargs='?', type=int, default=100000,
                        help='Total number of training iterations')
    parser.add_argument('--batch_size', nargs='?', type=int, default=10,
                        help='Batch Size')
    parser.add_argument('--l_rate', nargs='?', type=float, default=2e-4,
                        help='Learning Rate')
    parser.add_argument('--resume', nargs='?', type=str, default=None,
                        help='Path to previous saved model to restart from')
    parser.add_argument('--logdir', nargs='?', type=str, default='./checkpoints/',
                        help='Path to store the loss logs')
    parser.add_argument('--tboard', dest='tboard', action='store_true',
                        help='Enable visualization(s) on tensorboard | False by default')
    # local_rank argument removed: not needed for single-process CPU runs.
    parser.add_argument('--experiment_name', nargs='?', type=str, default='experiment_name',
                        help='the name of this experiment')
    parser.set_defaults(tboard=False)
    args = parser.parse_args()
    # Note: Using a low batch size (e.g., 2) is recommended for initial CPU testing.
    # args.batch_size = 2  # Uncomment for quick testing
    train(args)