import os
import cv2
import time
import random
import datetime
import argparse
import numpy as np
from itertools import cycle

import torch
import torch.nn as nn
from torch.utils import data
# Removed DDP and DistributedSampler imports

from utils import dict2string, mkdir, get_lr, torch2cvimg, second2hours
# Assumed 'loaders' and 'models' modules are available
from loaders import docres_loader
from models import restormer_arch

# --- Optional: Import for TensorBoard (uncomment if you have it installed) ---
# from torch.utils.tensorboard import SummaryWriter


def seed_torch(seed=1029):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Removed CUDA-specific seeding
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def getBasecoord(h, w):
    base_coord0 = np.tile(np.arange(h).reshape(h, 1), (1, w)).astype(np.float32)
    base_coord1 = np.tile(np.arange(w).reshape(1, w), (h, 1)).astype(np.float32)
    base_coord = np.concatenate((np.expand_dims(base_coord1, -1), np.expand_dims(base_coord0, -1)), -1)
    return base_coord


def train(args):
    # --- CPU/Single-Process Setup ---
    # Set device to CPU
    device = torch.device('cpu')
    print(f"Training on device: {device}")

    # Seed all RNGs so CPU runs are reproducible
    seed_torch()

    ### Log file
    mkdir(args.logdir)
    mkdir(os.path.join(args.logdir, args.experiment_name))
    log_file_path = os.path.join(args.logdir, args.experiment_name, 'log.txt')
    with open(log_file_path, 'a') as log_file:
        log_file.write('\n--------------- ' + args.experiment_name + ' ---------------\n')

    ### Setup TensorBoard for visualization
    # Note: TensorBoard setup is commented out for robust CPU execution.
    # if args.tboard:
    #     try:
    #         writer = SummaryWriter(os.path.join(args.logdir, args.experiment_name, 'runs'), args.experiment_name)
    #     except NameError:
    #         print("Warning: TensorBoard not imported. Skipping logging to SummaryWriter.")
    #         args.tboard = False

    ### Setup dataloaders
    # NOTE: You MUST update these paths to match your system setup.
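    # Each entry below defines one restoration task: 'ratio' is the task's
    # sampling weight in the training loop, 'im_path' is the image root, and
    # 'json_paths' lists the JSON split files handed to DocResTrainDataset.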
    datasets_setting = [
        {'task': 'deblurring',   'ratio': 1, 'im_path': '/home/jiaxin/Training_Data/DocRes_data/train/deblurring/',   'json_paths': ['/home/jiaxin/Training_Data/DocRes_data/train/deblurring/tdd/train.json']},
        {'task': 'dewarping',    'ratio': 1, 'im_path': '/home/jiaxin/Training_Data/DocRes_data/train/dewarping/',    'json_paths': ['/home/jiaxin/Training_Data/DocRes_data/train/dewarping/doc3d/train_1_19.json']},
        {'task': 'binarization', 'ratio': 1, 'im_path': '/home/jiaxin/Training_Data/DocRes_data/train/binarization/', 'json_paths': ['/home/jiaxin/Training_Data/DocRes_data/train/binarization/train.json']},
        {'task': 'deshadowing',  'ratio': 1, 'im_path': '/home/jiaxin/Training_Data/DocRes_data/train/deshadowing/',  'json_paths': ['/home/jiaxin/Training_Data/DocRes_data/train/deshadowing/train.json']},
        {'task': 'appearance',   'ratio': 1, 'im_path': '/home/jiaxin/Training_Data/DocRes_data/train/appearance/',   'json_paths': ['/home/jiaxin/Training_Data/DocRes_data/train/appearance/trainv2.json']},
    ]
    ratios = [dataset_setting['ratio'] for dataset_setting in datasets_setting]
    datasets = [docres_loader.DocResTrainDataset(dataset=dataset_setting, img_size=args.im_size)
                for dataset_setting in datasets_setting]

    # Standard DataLoader is used instead of DistributedSampler. Each loader is
    # built once and shared with its cycling iterator.
    trainloaders = []
    for i in range(len(datasets)):
        loader = data.DataLoader(dataset=datasets[i], batch_size=args.batch_size,
                                 num_workers=0, pin_memory=False, drop_last=True)
        trainloaders.append({'task': datasets_setting[i], 'loader': loader, 'iter_loader': iter(loader)})

    ### Setup model
    model = restormer_arch.Restormer(
        inp_channels=6,
        out_channels=3,
        dim=48,
        num_blocks=[2, 3, 3, 4],
        num_refinement_blocks=4,
        heads=[1, 2, 4, 8],
        ffn_expansion_factor=2.66,
        bias=False,
        LayerNorm_type='WithBias',
        dual_pixel_task=True
    )
    # Move model to CPU
    model.to(device)

    ### Optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.l_rate, weight_decay=5e-4)

    ### LR scheduler
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.total_iter, eta_min=1e-6, last_epoch=-1)

    ### Load checkpoint
    iter_start = 0
    if args.resume is not None:
        print("Loading model and optimizer from checkpoint '{}'".format(args.resume))
        # Ensure checkpoint is loaded to CPU
        checkpoint = torch.load(args.resume, map_location=device)
        model.load_state_dict(checkpoint['model_state'], strict=False)
        if 'optimizer_state' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state'])
        iter_start = checkpoint['iter']
        print("Loaded checkpoint '{}' (iter {})".format(args.resume, iter_start))

    ### ----------------------------------------- Training -----------------------------------------
    ## Initialize
    # Removed GradScaler for AMP
    loss_dict = {}
    total_step = 0
    l2 = nn.MSELoss()
    l1 = nn.L1Loss()
    ce = nn.CrossEntropyLoss()
    bce = nn.BCEWithLogitsLoss()
    m = nn.Sigmoid()
    best = 0
    best_ce = 999

    ## Total steps
    for iters in range(iter_start, args.total_iter):
        start_time = time.time()

        # Pick a task at random, weighted by its 'ratio'; restart a loader's
        # iterator whenever it is exhausted.
        loader_index = random.choices(list(range(len(trainloaders))), ratios)[0]
        try:
            in_im, gt_im = next(trainloaders[loader_index]['iter_loader'])
        except StopIteration:
            trainloaders[loader_index]['iter_loader'] = iter(trainloaders[loader_index]['loader'])
            in_im, gt_im = next(trainloaders[loader_index]['iter_loader'])

        # Move data to CPU
        in_im = in_im.float().to(device)
        gt_im = gt_im.float().to(device)

        binarization_loss, appearance_loss, dewarping_loss, deblurring_loss, deshadowing_loss = 0, 0, 0, 0, 0

        # Removed torch.cuda.amp.autocast() block
        pred_im = model(in_im, trainloaders[loader_index]['task']['task'])
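        # Per-task losses: binarization is a 2-class cross-entropy over the
        # first two output channels; dewarping is an L1 loss on the 2-channel
        # coordinate map; appearance, deblurring, and deshadowing are plain L1
        # losses against the RGB ground truth.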
        if trainloaders[loader_index]['task']['task'] == 'binarization':
            gt_im = gt_im.long()
            binarization_loss = ce(pred_im[:, :2, :, :], gt_im[:, 0, :, :])
            loss = binarization_loss
        elif trainloaders[loader_index]['task']['task'] == 'dewarping':
            dewarping_loss = l1(pred_im[:, :2, :, :], gt_im[:, :2, :, :])
            loss = dewarping_loss
        elif trainloaders[loader_index]['task']['task'] == 'appearance':
            appearance_loss = l1(pred_im, gt_im)
            loss = appearance_loss
        elif trainloaders[loader_index]['task']['task'] == 'deblurring':
            deblurring_loss = l1(pred_im, gt_im)
            loss = deblurring_loss
        elif trainloaders[loader_index]['task']['task'] == 'deshadowing':
            deshadowing_loss = l1(pred_im, gt_im)
            loss = deshadowing_loss

        optimizer.zero_grad()
        # Standard backward pass (removed scaler)
        loss.backward()
        optimizer.step()

        loss_dict['dew_loss'] = dewarping_loss.item() if isinstance(dewarping_loss, torch.Tensor) else 0
        loss_dict['app_loss'] = appearance_loss.item() if isinstance(appearance_loss, torch.Tensor) else 0
        loss_dict['des_loss'] = deshadowing_loss.item() if isinstance(deshadowing_loss, torch.Tensor) else 0
        loss_dict['deb_loss'] = deblurring_loss.item() if isinstance(deblurring_loss, torch.Tensor) else 0
        loss_dict['bin_loss'] = binarization_loss.item() if isinstance(binarization_loss, torch.Tensor) else 0

        end_time = time.time()
        duration = end_time - start_time

        ## Log
        if (iters + 1) % 10 == 0:
            log_line = ('iters [{}/{}] -- '.format(iters + 1, args.total_iter) + dict2string(loss_dict)
                        + ' --lr {:.6f}'.format(get_lr(optimizer))
                        + ' -- time {}'.format(second2hours(duration * (args.total_iter - iters))))
            ## Print
            print(log_line)
            ## TensorBoard
            # if args.tboard:
            #     for key, value in loss_dict.items():
            #         writer.add_scalar('Train ' + key + '/Iterations', value, total_step)
            ## Log file
            with open(log_file_path, 'a') as f:
                f.write(log_line + '\n')

        if (iters + 1) % 5000 == 0:
            # The 'iter' key must match the one read by the resume logic above.
            state = {'iter': iters + 1,
                     'model_state': model.state_dict(),
                     'optimizer_state': optimizer.state_dict()}
            mkdir(os.path.join(args.logdir, args.experiment_name))
            # Save checkpoint without DDP rank check
            torch.save(state, os.path.join(args.logdir, args.experiment_name, "{}.pkl".format(iters + 1)))

        sched.step()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Hyperparams')
    parser.add_argument('--im_size', nargs='?', type=int, default=256,
                        help='Height/width of the (square) input image')
    parser.add_argument('--total_iter', nargs='?', type=int, default=100000,
                        help='Total number of training iterations')
    parser.add_argument('--batch_size', nargs='?', type=int, default=10,
                        help='Batch size')
    parser.add_argument('--l_rate', nargs='?', type=float, default=2e-4,
                        help='Learning rate')
    parser.add_argument('--resume', nargs='?', type=str, default=None,
                        help='Path to a previously saved model to restart from')
    parser.add_argument('--logdir', nargs='?', type=str, default='./checkpoints/',
                        help='Path to store the loss logs')
    parser.add_argument('--tboard', dest='tboard', action='store_true',
                        help='Enable visualization(s) on TensorBoard | False by default')
    # Removed local_rank argument as it's not needed for single-process CPU
    parser.add_argument('--experiment_name', nargs='?', type=str, default='experiment_name',
                        help='The name of this experiment')
    parser.set_defaults(tboard=False)
    args = parser.parse_args()

    # Note: Using a low batch size (e.g., 2) is recommended for initial CPU testing.
    # args.batch_size = 2  # Uncomment for quick testing
    train(args)
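# Example invocation -- a minimal smoke test, not a full training run. The
# script filename 'train_cpu.py' is an assumption; substitute whatever this
# file is saved as, and make sure the paths in datasets_setting exist:
#
#   python train_cpu.py --experiment_name docres_cpu_smoke --batch_size 2 --total_iter 100
#
# Checkpoints are written to ./checkpoints/<experiment_name>/ every 5000 iters,
# so very short runs will only produce the log file.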