"""Utility functions and rank-aware logging helpers for stereo / optical-flow
training: input padding, flow warping, coordinate grids, slant-disparity
computation, and file/tensorboard loggers for distributed runs."""
import os
import sys
import time
import shutil
import logging
from datetime import datetime

import numpy as np
from scipy import interpolate
import torch
import torch.nn.functional as F

try:
    # tensorboard is an optional dependency; LoggerTraining raises a clear
    # error at construction time instead of failing at module import.
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    SummaryWriter = None


class InputPadder:
    """ Pads images such that dimensions are divisible by 8 """
    def __init__(self, dims, mode='sintel', divis_by=8):
        self.ht, self.wd = dims[-2:]
        # Smallest non-negative padding that makes each dim divisible.
        pad_ht = (((self.ht // divis_by) + 1) * divis_by - self.ht) % divis_by
        pad_wd = (((self.wd // divis_by) + 1) * divis_by - self.wd) % divis_by
        if mode == 'sintel':
            # Pad symmetrically on all four sides (left, right, top, bottom).
            self._pad = [pad_wd // 2, pad_wd - pad_wd // 2,
                         pad_ht // 2, pad_ht - pad_ht // 2]
        else:
            # Pad width symmetrically, height only at the bottom.
            self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]

    def pad(self, *inputs):
        """Replicate-pad each (B,C,H,W) tensor; returns a list of tensors."""
        assert all((x.ndim == 4) for x in inputs)
        return [F.pad(x, self._pad, mode='replicate') for x in inputs]

    def pad_intrinsics(self, intrinsic):
        """Shift the principal point by the left/top padding.

        NOTE: mutates `intrinsic` in place and returns it.  Assumes columns
        2 and 3 hold cx and cy — TODO confirm against caller's layout.
        """
        intrinsic[:, 2] += self._pad[0]
        intrinsic[:, 3] += self._pad[2]
        return intrinsic

    def unpad(self, x):
        """Crop a padded (B,C,H,W) tensor back to the original size."""
        assert x.ndim == 4
        ht, wd = x.shape[-2:]
        c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
        return x[..., c[0]:c[1], c[2]:c[3]]


def forward_interpolate(flow):
    """Forward-warp a (2,H,W) flow field back onto the regular pixel grid.

    Each pixel is advected by its own flow vector; the scattered samples are
    then resampled at integer coordinates with nearest-neighbour griddata.
    Returns a float CPU tensor of shape (2,H,W).
    """
    flow = flow.detach().cpu().numpy()
    dx, dy = flow[0], flow[1]

    ht, wd = dx.shape
    x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))

    x1 = x0 + dx
    y1 = y0 + dy

    x1 = x1.reshape(-1)
    y1 = y1.reshape(-1)
    dx = dx.reshape(-1)
    dy = dy.reshape(-1)

    # Keep only samples that land strictly inside the image.
    valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
    x1 = x1[valid]
    y1 = y1[valid]
    dx = dx[valid]
    dy = dy[valid]

    flow_x = interpolate.griddata(
        (x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
    flow_y = interpolate.griddata(
        (x1, y1), dy, (x0, y0), method='nearest', fill_value=0)

    flow = np.stack([flow_x, flow_y], axis=0)
    return torch.from_numpy(flow).float()


def bilinear_sampler(img, coords, mode='bilinear', mask=False):
    """ Wrapper for grid_sample, uses pixel coordinates """
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1, 1], dim=-1)
    # Rescale pixel coordinates to the [-1, 1] range grid_sample expects.
    xgrid = 2 * xgrid / (W - 1) - 1
    if H > 1:
        ygrid = 2 * ygrid / (H - 1) - 1

    grid = torch.cat([xgrid, ygrid], dim=-1)
    # FIX: `mode` was previously accepted but never forwarded to grid_sample,
    # so nearest/bicubic requests were silently sampled bilinearly.
    img = F.grid_sample(img, grid, mode=mode, align_corners=True)

    if mask:
        # Mark samples that fell strictly inside the image bounds.
        mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
        return img, mask.float()

    return img


def coords_grid(batch, ht, wd):
    """Return a (batch, 2, ht, wd) grid of (x, y) pixel coordinates."""
    # meshgrid defaults to 'ij' indexing here; kept for older-torch compat.
    coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
    coords = torch.stack(coords[::-1], dim=0).float()  # reversed -> (x, y)
    return coords[None].repeat(batch, 1, 1, 1)


def hor_coords_grid(batch, ht, wd):
    """Return a (batch, 1, ht, wd) grid of horizontal (x) coordinates."""
    hor_coords = torch.arange(wd).float().repeat(batch, 1, ht, 1)
    return hor_coords


def upflow8(flow, mode='bilinear'):
    """Upsample a flow field 8x spatially and scale its values by 8."""
    new_size = (8 * flow.shape[2], 8 * flow.shape[3])
    return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)


def gauss_blur(input, N=5, std=1):
    """Blur a (B,D,H,W) tensor with an NxN Gaussian kernel of stddev `std`.

    Each of the B*D channels is convolved independently with the same
    normalized kernel; output has the same shape as the input.
    """
    B, D, H, W = input.shape
    x, y = torch.meshgrid(torch.arange(N).float() - N // 2,
                          torch.arange(N).float() - N // 2)
    unnormalized_gaussian = torch.exp(-(x.pow(2) + y.pow(2)) / (2 * std ** 2))
    # clamp guards against division by ~0 for degenerate (tiny-std) kernels.
    weights = unnormalized_gaussian / unnormalized_gaussian.sum().clamp(min=1e-4)
    weights = weights.view(1, 1, N, N).to(input)
    output = F.conv2d(input.reshape(B * D, 1, H, W), weights, padding=N // 2)
    return output.view(B, D, H, W)


def disparity_computation(params, slant=None, slant_norm=False, coords0=None):
    """Convert per-pixel plane parameters into a disparity offset.

    args:
        params: (B,C,...), C is the type of parameters.
        coords0: (B,C,...), C is the number of coordinates' axis.
    With slant="slant", disparity is the plane d = a*u + b*v + c where
    (a, b, c) = params[:, 0..2]; with slant_norm the coordinates are first
    normalized by image width/height.  slant=None returns params unchanged.
    """
    if slant is None or len(slant) == 0:
        offset = params
    elif slant == "slant":
        # d = a*u + b*v + c
        H, W = coords0.shape[-2], coords0.shape[-1]
        if slant_norm:
            norm_range = torch.Tensor([W, H])[None, :, None, None].float().to(coords0.device)
            offset = params[:, 0] * coords0[:, 0] / norm_range[:, 0] + \
                     params[:, 1] * coords0[:, 1] / norm_range[:, 1] + \
                     params[:, 2]
        else:
            offset = params[:, 0] * coords0[:, 0] + \
                     params[:, 1] * coords0[:, 1] + \
                     params[:, 2]
    elif slant == "slant_local":
        raise Exception("slant_local is not supported")
    else:
        raise Exception(f"{slant} is not supported")
    return offset


def sv_intermediate_results(data, name, sv_path):
    """Save a tensor to <sv_path>/data/<name>.npy for offline inspection."""
    try:
        sv_path = os.path.join(sv_path, "data")
        os.makedirs(sv_path, exist_ok=True)
        data_numpy = data.cpu().data.numpy()
        np.save(os.path.join(sv_path, name + ".npy"), data_numpy)
        # print("saving to {}".format( os.path.join(sv_path, name+".npy") ))
    except Exception as err:
        # Re-raise with context about what we were trying to save.
        raise Exception(err, data.shape, name, sv_path) from err


def load_intermediate_results(name, sv_path):
    """Load an array previously saved by sv_intermediate_results."""
    sv_path = os.path.join(sv_path, "data")
    data = np.load(os.path.join(sv_path, name + ".npy"))
    return data


def rescale_modulation(itr, iters, modulation_alg, modulation_ratio):
    """Ramp the modulation strength up over refinement iterations.

    Modulation should have less effect in the first iterations, when the
    disparity estimate (and the local LBP disparity) is still unreliable.
    """
    if modulation_alg == "linear":
        ratio = modulation_ratio * itr / iters
    elif modulation_alg == "sigmoid":
        # Sigmoid centred at iteration 5 with slope 2.
        ratio = modulation_ratio * 1 / (1 + np.exp(-2 * (itr - 5)))
    else:
        raise Exception("Not supported modulation_alg: {}".format(modulation_alg))
    return ratio


# Distributed-training environment; only rank (0, 0) writes logs.
NODE_RANK = os.getenv('NODE_RANK', default=0)
LOCAL_RANK = os.getenv("LOCAL_RANK", default=0)
LOG_ROOT = os.getenv('LOG_ROOT', default="logs")
TB_ROOT = os.getenv('TB_ROOT', default="runs")


class LoggerCommon:
    """Rank-aware logger writing to a timestamped file and the console.

    Only the process with LOCAL_RANK == 0 and NODE_RANK == 0 emits records;
    other ranks silently drop all messages.
    """

    def __init__(self, name):
        self.name = name
        self.log_name = '{}-{}.log'.format(name, datetime.now().strftime("%y%m%d_%H%M%S"))
        self.log_path = os.path.join(LOG_ROOT, self.log_name)
        if int(LOCAL_RANK) == 0 and int(NODE_RANK) == 0:
            os.makedirs(LOG_ROOT, exist_ok=True)
            logging.basicConfig(level=logging.INFO,
                                format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s',
                                handlers=[logging.FileHandler(self.log_path),
                                          logging.StreamHandler()])
        self.logger = logging.getLogger(name)
        # FIX: a second FileHandler used to be attached here on top of the
        # one installed by basicConfig(), so every record appeared twice in
        # the log file.  The duplicate handler has been removed.

    def _set_handlers(self):
        # Clear all previously-installed handlers.
        self.logger.handlers.clear()
        # Set up file and console handlers.
        file_handler = logging.FileHandler(self.log_path)
        console_handler = logging.StreamHandler()
        # Apply a common formatter to both handlers.
        formatter = logging.Formatter('%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s')
        file_handler.setFormatter(formatter)
        console_handler.setFormatter(formatter)
        # Attach handlers.
        self.logger.addHandler(file_handler)
        self.logger.addHandler(console_handler)
        # FIX: stop propagation so records are not also emitted by the root
        # handlers installed in __init__ (which would duplicate console
        # output and keep writing to the stale log file).
        self.logger.propagate = False

    def set_log_path(self, new_log_path, name=None):
        # Remove the old log file if it exists.
        if os.path.exists(self.log_path):
            os.remove(self.log_path)
        # Update the path (and optionally the name) and reinstall handlers.
        if name is not None:
            self.name = name
        self.log_name = '{}-{}.log'.format(self.name, datetime.now().strftime("%y%m%d_%H%M%S"))
        self.log_path = os.path.join(new_log_path, self.log_name)
        os.makedirs(new_log_path, exist_ok=True)
        self._set_handlers()

    def info(self, message):
        if int(LOCAL_RANK) == 0 and int(NODE_RANK) == 0:
            self.logger.info(message)

    def warning(self, message):
        if int(LOCAL_RANK) == 0 and int(NODE_RANK) == 0:
            self.logger.warning(message)

    def error(self, message):
        if int(LOCAL_RANK) == 0 and int(NODE_RANK) == 0:
            self.logger.error(message)

    def exception(self, message):
        if int(LOCAL_RANK) == 0 and int(NODE_RANK) == 0:
            self.logger.exception(message)

    def print_args(self, args):
        """Log every attribute of an argparse namespace, name-aligned."""
        msg = ""
        args_dict = vars(args)
        # default=0 keeps this from raising on an argument-less namespace.
        max_arg_length = max((len(arg_name) for arg_name in args_dict.keys()), default=0)
        for arg_name, arg_value in args_dict.items():
            arg_name_padded = arg_name.ljust(max_arg_length)
            msg += f"{arg_name_padded}: {arg_value}\r\n"
        self.info(msg)


class LoggerTraining(LoggerCommon):
    """Training logger: accumulates metrics and reports every SUM_FREQ steps
    to both the text log and a tensorboard SummaryWriter."""
    SUM_FREQ = 100

    def __init__(self, name, model=None, scheduler=None):
        super(LoggerTraining, self).__init__(name)
        if SummaryWriter is None:
            raise ImportError(
                "tensorboard is required for LoggerTraining but is not installed")
        if int(LOCAL_RANK) == 0 and int(NODE_RANK) == 0:
            os.makedirs(TB_ROOT, exist_ok=True)
        self.model = model
        self.scheduler = scheduler
        self.silence = False
        self.total_steps = 0
        self.running_loss = {}
        self.writer = SummaryWriter(log_dir=TB_ROOT)

    def set_training(self, model, scheduler):
        self.model = model
        self.scheduler = scheduler

    def _print_training_status(self):
        """Log averaged running metrics and push them to tensorboard."""
        metrics_data = [self.running_loss[k] / LoggerTraining.SUM_FREQ
                        for k in sorted(self.running_loss.keys())]
        training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps + 1,
                                                   self.scheduler.get_last_lr()[0])
        metrics_str = ("{:10.4f}, " * len(metrics_data)).format(*metrics_data)

        # print the training status
        self.info(f"Training Metrics ({self.total_steps}): {training_str + metrics_str}")

        if self.writer is None:
            self.writer = SummaryWriter(log_dir=TB_ROOT)

        for k in self.running_loss:
            self.writer.add_scalar(k, self.running_loss[k] / LoggerTraining.SUM_FREQ,
                                   self.total_steps)
            self.running_loss[k] = 0.0

    def push(self, metrics):
        """Accumulate one step of metrics; report every SUM_FREQ steps."""
        self.total_steps += 1

        for key in metrics:
            if key not in self.running_loss:
                self.running_loss[key] = 0.0
            self.running_loss[key] += metrics[key]

        if self.total_steps % LoggerTraining.SUM_FREQ == LoggerTraining.SUM_FREQ - 1:
            self._print_training_status()
            self.running_loss = {}

    def write_dict(self, results):
        """Write arbitrary scalar results to tensorboard at the current step."""
        if self.writer is None:
            self.writer = SummaryWriter(log_dir=TB_ROOT)

        for key in results:
            self.writer.add_scalar(key, results[key], self.total_steps)

    def close(self):
        self.writer.close()


def init_directories(directories):
    """Create each directory (rank (0,0) only)."""
    if int(LOCAL_RANK) == 0 and int(NODE_RANK) == 0:
        for directory in directories:
            os.makedirs(directory, exist_ok=True)


def delete_directories_if_static(directories):
    """Delete the given directories, but only if no file in them is still
    being written to (rank (0,0) only)."""
    if int(LOCAL_RANK) == 0 and int(NODE_RANK) == 0:
        # Abort the deletion if any file size is still changing.
        if not is_any_folder_static(directories):
            print("File sizes are changing in one of the directories {}.".format(directories) + \
                  "No directories will be deleted.")
            return
        # All files are static; remove the directories.
        for directory in directories:
            if os.path.exists(directory):
                shutil.rmtree(directory)
                print(f"Directory {directory} deleted")


def get_file_sizes(directories):
    """Return a {filepath: size} dict for every file under the directories."""
    file_sizes = {}
    for directory in directories:
        if os.path.exists(directory):
            for root, dirs, files in os.walk(directory):
                for file in files:
                    filepath = os.path.join(root, file)
                    file_sizes[filepath] = os.path.getsize(filepath)
    return file_sizes


def is_any_folder_static(directories, check_interval=2):
    """Return True if no file under the directories changed size (or was
    removed) during `check_interval` seconds."""
    # Snapshot initial sizes.
    initial_sizes = get_file_sizes(directories)
    time.sleep(check_interval)  # wait and observe changes
    final_sizes = {filepath: os.path.getsize(filepath)
                   for filepath in initial_sizes if os.path.exists(filepath)}
    # All files are static iff the two snapshots agree exactly.
    return initial_sizes == final_sizes