Spaces:
Runtime error
Runtime error
| import argparse | |
| import copy | |
| import datetime | |
| import json | |
| import os | |
| import random | |
| import time | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| import torch.distributed as dist | |
| from torch.nn import functional as F | |
| from torch.nn.parallel import DistributedDataParallel as DDP | |
| from torch.utils.data import DataLoader | |
| from torch.utils.data.distributed import DistributedSampler | |
| import util.misc as utils | |
| import wandb | |
| from datasets import build_dataset | |
| from engine import evaluate, train_one_epoch | |
| from models import build_model | |
def get_args_parser():
    """Build the command-line parser for the Raster2Seq training script.

    The parser is created with ``add_help=False`` so it can be passed as a
    ``parents=[...]`` entry to another ``ArgumentParser``.

    Note that ``--lr_drop`` is kept as a raw string here (comma-separated
    epochs); the caller is responsible for converting it to a list of ints.

    Returns:
        argparse.ArgumentParser: parser with all training options registered.
    """
    parser = argparse.ArgumentParser("Raster2Seq training script", add_help=False)

    # optimisation
    parser.add_argument("--lr", default=2e-4, type=float)
    parser.add_argument("--lr_backbone_names", default=["backbone.0"], type=str, nargs="+")
    parser.add_argument("--lr_backbone", default=2e-5, type=float)
    parser.add_argument("--lr_linear_proj_names", default=["sampling_offsets"], type=str, nargs="+")
    parser.add_argument("--lr_linear_proj_mult", default=0.1, type=float)
    parser.add_argument("--batch_size", default=10, type=int)
    parser.add_argument("--weight_decay", default=1e-4, type=float)
    parser.add_argument("--epochs", default=500, type=int)
    parser.add_argument("--lr_drop", default="400", type=str)
    parser.add_argument("--clip_max_norm", default=0.1, type=float, help="gradient clipping max norm")
    parser.add_argument("--sgd", action="store_true")

    # general training / data options
    parser.add_argument("--input_channels", default=1, type=int)
    parser.add_argument(
        "--start_from_checkpoint",
        default="",
        help="initialize model weights from this checkpoint (optimizer and epoch are not restored)",
    )
    parser.add_argument("--image_norm", action="store_true")
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--eval_every_epoch", type=int, default=20)
    parser.add_argument("--ckpt_every_epoch", type=int, default=20)
    parser.add_argument("--label_smoothing", type=float, default=0.0)
    parser.add_argument("--ignore_index", type=int, default=-1)
    parser.add_argument("--image_size", type=int, default=256)
    parser.add_argument("--ema4eval", action="store_true")
    parser.add_argument("--increase_cls_loss_coef", default=1.0, type=float)
    parser.add_argument("--increase_cls_loss_coef_epoch_ratio", default=-1, type=float)
    parser.add_argument("--use_anchor", action="store_true")
    parser.add_argument("--disable_wd_as_line", action="store_true")
    parser.add_argument("--wd_only", action="store_true")
    parser.add_argument("--converter_version", type=str, default="v1")
    parser.add_argument("--freeze_anchor", action="store_true")
    parser.add_argument("--inject_cls_embed", action="store_true")
    parser.add_argument(
        "--random_drop_rate", type=float, default=0.0, help="randomly drop some polygons during training"
    )

    # raster2seq
    parser.add_argument("--poly2seq", action="store_true")
    parser.add_argument("--seq_len", type=int, default=1024)
    parser.add_argument("--num_bins", type=int, default=64)
    parser.add_argument("--pre_decoder_pos_embed", action="store_true")
    parser.add_argument("--learnable_dec_pe", action="store_true")
    parser.add_argument("--dec_qkv_proj", action="store_true")
    parser.add_argument("--dec_attn_concat_src", action="store_true")
    parser.add_argument("--per_token_sem_loss", action="store_true")
    parser.add_argument("--add_cls_token", action="store_true")
    parser.add_argument("--jointly_train", action="store_true")

    # backbone
    parser.add_argument("--backbone", default="resnet50", type=str, help="Name of the convolutional backbone to use")
    parser.add_argument(
        "--dilation",
        action="store_true",
        help="If true, we replace stride with dilation in the last convolutional block (DC5)",
    )
    parser.add_argument(
        "--position_embedding",
        default="sine",
        type=str,
        choices=("sine", "learned"),
        help="Type of positional embedding to use on top of the image features",
    )
    parser.add_argument("--position_embedding_scale", default=2 * np.pi, type=float, help="position / size * scale")
    parser.add_argument("--num_feature_levels", default=4, type=int, help="number of feature levels")

    # Transformer
    parser.add_argument("--enc_layers", default=6, type=int, help="Number of encoding layers in the transformer")
    parser.add_argument("--dec_layers", default=6, type=int, help="Number of decoding layers in the transformer")
    parser.add_argument(
        "--dim_feedforward",
        default=1024,
        type=int,
        help="Intermediate size of the feedforward layers in the transformer blocks",
    )
    parser.add_argument(
        "--hidden_dim", default=256, type=int, help="Size of the embeddings (dimension of the transformer)"
    )
    parser.add_argument("--dropout", default=0.1, type=float, help="Dropout applied in the transformer")
    parser.add_argument(
        "--nheads", default=8, type=int, help="Number of attention heads inside the transformer's attentions"
    )
    parser.add_argument(
        "--num_queries",
        default=800,
        type=int,
        help="Number of query slots (num_polys * max. number of corner per poly)",
    )
    parser.add_argument("--num_polys", default=20, type=int, help="Number of maximum number of room polygons")
    parser.add_argument("--dec_n_points", default=4, type=int)
    parser.add_argument("--enc_n_points", default=4, type=int)
    parser.add_argument(
        "--query_pos_type",
        default="sine",
        type=str,
        choices=("static", "sine", "none"),
        # Implicit concatenation instead of backslash continuations, which
        # would embed the source indentation into the help text.
        help=(
            "Type of query pos in decoder - "
            "1. static: same setting with DETR and Deformable-DETR, the query_pos is the same for all layers "
            "2. sine: sine embedding from reference points (so if reference points update, query_pos also updates) "
            "3. none: remove query_pos"
        ),
    )
    # With default=True and store_true this flag can never turn refinement
    # off from the command line; --disable_poly_refine below is the off-switch.
    parser.add_argument(
        "--with_poly_refine",
        default=True,
        action="store_true",
        help="iteratively refine reference points (i.e. positional part of polygon queries)",
    )
    parser.add_argument(
        "--masked_attn",
        default=False,
        action="store_true",
        help="if true, the query in one room will not be allowed to attend other room",
    )
    parser.add_argument(
        "--semantic_classes",
        default=-1,
        type=int,
        help=(
            "Number of classes for semantically-rich floorplan: "
            "1. default -1 means non-semantic floorplan "
            "2. 19 for Structured3D: 16 room types + 1 door + 1 window + 1 empty"
        ),
    )
    parser.add_argument(
        "--disable_poly_refine",
        action="store_true",
        help="disable iterative refinement of reference points (turns --with_poly_refine off)",
    )

    # loss
    # NOTE(review): as written, aux_loss defaults to False and passing
    # --no_aux_loss sets it to True, which is the opposite of what the flag
    # name and help text say. Behaviour is left unchanged here because
    # flipping it would alter the defaults of existing training runs; confirm
    # the intended semantics before changing the action to "store_false".
    parser.add_argument(
        "--no_aux_loss",
        dest="aux_loss",
        action="store_true",
        help="Disables auxiliary decoding losses (loss at each layer)",
    )

    # matcher
    parser.add_argument("--set_cost_class", default=2, type=float, help="Class coefficient in the matching cost")
    parser.add_argument("--set_cost_coords", default=5, type=float, help="L1 coords coefficient in the matching cost")

    # loss coefficients
    parser.add_argument("--cls_loss_coef", default=2, type=float)
    parser.add_argument("--room_cls_loss_coef", default=0.2, type=float)
    parser.add_argument("--coords_loss_coef", default=5, type=float)
    parser.add_argument("--raster_loss_coef", default=0, type=float)

    # dataset parameters
    parser.add_argument("--dataset_name", default="stru3d")
    parser.add_argument("--dataset_root", default="data/stru3d", type=str)
    parser.add_argument("--output_dir", default="output", help="path where to save, empty for no saving")
    parser.add_argument("--device", default="cuda", help="device to use for training / testing")
    parser.add_argument("--seed", default=42, type=int)
    parser.add_argument("--resume", default="", help="resume from checkpoint")
    parser.add_argument("--start_epoch", default=0, type=int, metavar="N", help="start epoch")
    parser.add_argument("--num_workers", default=2, type=int)
    parser.add_argument("--job_name", default="train_stru3d", type=str)
    return parser
def main(args):
    """Run distributed (DDP) training with periodic checkpointing and evaluation.

    The function initialises an NCCL process group, builds the train/val
    datasets and loaders, wraps the model in ``DistributedDataParallel``,
    optionally restores state from ``args.resume`` or initialises weights from
    ``args.start_from_checkpoint``, and then loops over epochs calling
    ``train_one_epoch`` and, every ``args.eval_every_epoch`` epochs, ``evaluate``.
    Rank 0 additionally logs to wandb, writes checkpoints and appends to
    ``log.txt`` in ``args.output_dir``.

    Args:
        args: parsed command-line namespace. Besides the options registered by
            the argument parser, this function reads ``args.run_name`` and
            expects ``args.lr_drop`` to already be a list of ints. It also
            sets ``args.vocab_size`` (when ``args.poly2seq``),
            ``args.override_resumed_lr_drop`` and ``args.start_epoch`` (on
            resume).
    """
    print("git:\n {}\n".format(utils.get_sha()))
    print(args)

    # Setup DDP:
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = rank % torch.cuda.device_count()  # local CUDA index for this process
    seed = args.seed * world_size + rank

    # fix the seed for reproducibility
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={world_size}.")

    # setup wandb for logging
    if rank == 0:
        utils.setup_wandb()
        wandb.init(project="Raster2Seq", resume="allow", id=args.run_name, dir="./wandb")

    # build dataset and dataloader
    dataset_train = build_dataset(image_set="train", args=args)
    dataset_val = build_dataset(image_set="val", args=args)

    tokenizer = None
    if args.poly2seq:
        args.vocab_size = dataset_train.get_vocab_size()
        tokenizer = dataset_train.get_tokenizer()

    # overfit one sample
    if args.debug:
        dataset_val = torch.utils.data.Subset(copy.deepcopy(dataset_val), [0])
        dataset_train = copy.deepcopy(dataset_val)

    sampler_train = DistributedSampler(
        dataset_train, num_replicas=world_size, rank=rank, shuffle=True, seed=args.seed
    )
    sampler_val = DistributedSampler(
        dataset_val, num_replicas=world_size, rank=rank, shuffle=False, seed=args.seed
    )

    # Per-sample tensors that are stacked into batch tensors by the collator.
    stacked_keys = (
        "delta_x1",
        "delta_x2",
        "delta_y1",
        "delta_y2",
        "seq11",
        "seq21",
        "seq12",
        "seq22",
        "target_seq",
        "token_labels",
        "mask",
        "target_polygon_labels",
    )

    def trivial_batch_collator(batch):
        """Keep the samples as a list; stack sequence tensors when present.

        Returns ``(batch, None)`` when the first sample has no ``"target_seq"``
        entry. Otherwise every key in ``stacked_keys`` is stacked along a new
        leading dimension, removed from the individual samples (in place), and
        returned as a dict next to the remaining per-sample dicts.
        """
        if "target_seq" not in batch[0]:
            return batch, None
        stacked = {key: torch.stack([item[key] for item in batch], dim=0) for key in stacked_keys}
        for item in batch:
            for key in stacked_keys:
                del item[key]
        return batch, stacked

    data_loader_train = DataLoader(
        dataset_train,
        args.batch_size,
        shuffle=False,
        sampler=sampler_train,
        num_workers=args.num_workers,
        collate_fn=trivial_batch_collator,
        pin_memory=True,
        drop_last=True,
    )
    data_loader_val = DataLoader(
        dataset_val,
        args.batch_size,
        shuffle=False,
        sampler=sampler_val,
        collate_fn=trivial_batch_collator,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
    )

    # build model
    model, criterion = build_model(args, tokenizer=tokenizer)
    ema = copy.deepcopy(model).to(device)
    utils.requires_grad(ema, False)
    # device_ids must hold the *local* CUDA index. The global rank only
    # coincides with it while world_size <= GPUs per node.
    model = DDP(model.to(device), device_ids=[device], find_unused_parameters=True)
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of params:", n_parameters)

    def match_name_keywords(n, name_keywords):
        """Return True if any keyword is a substring of the parameter name ``n``."""
        return any(keyword in n for keyword in name_keywords)

    for n, p in model.named_parameters():
        print(n)

    if args.per_token_sem_loss and not args.jointly_train:
        # disable gradient for model, except new classifier
        for n, p in model.named_parameters():
            p.requires_grad = "room_class_embed" in n

    # Three parameter groups: default lr, backbone lr, scaled lr for the
    # linear projection layers. Frozen parameters are excluded from all groups.
    param_dicts = [
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if not match_name_keywords(n, args.lr_backbone_names)
                and not match_name_keywords(n, args.lr_linear_proj_names)
                and p.requires_grad
            ],
            "lr": args.lr,
        },
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if match_name_keywords(n, args.lr_backbone_names) and p.requires_grad
            ],
            "lr": args.lr_backbone,
        },
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if match_name_keywords(n, args.lr_linear_proj_names) and p.requires_grad
            ],
            "lr": args.lr * args.lr_linear_proj_mult,
        },
    ]
    print(f"Rank {rank}: Model has {sum(p.numel() for p in model.parameters())} parameters")

    if args.sgd:
        optimizer = torch.optim.SGD(param_dicts, lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, weight_decay=args.weight_decay)

    # args.lr_drop is a (possibly empty) list of milestone epochs.
    if args.lr_drop:
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lr_drop)
    else:
        lr_scheduler = None

    output_dir = Path(args.output_dir)

    # A non-existent resume path is skipped silently (training starts fresh).
    if args.resume and os.path.exists(args.resume):
        checkpoint = torch.load(args.resume, map_location="cpu")

        # Strip a DDP "module." prefix from the saved weights. The keys are
        # collected first so the dict is not resized while being iterated, and
        # the rename is applied to checkpoint["model"] itself (the weights that
        # are loaded below), not to the outer checkpoint dict.
        model_state = checkpoint["model"]
        prefix = "module."
        for key in [k for k in model_state if k.startswith(prefix)]:
            model_state[key[len(prefix):]] = model_state.pop(key)

        missing_keys, unexpected_keys = model.module.load_state_dict(model_state, strict=False)
        unexpected_keys = [k for k in unexpected_keys if not (k.endswith("total_params") or k.endswith("total_ops"))]
        if len(missing_keys) > 0:
            print("Missing Keys: {}".format(missing_keys))
            raise ValueError("Missing keys in state_dict")
        if len(unexpected_keys) > 0:
            print("Unexpected Keys: {}".format(unexpected_keys))

        if "optimizer" in checkpoint and "lr_scheduler" in checkpoint and "epoch" in checkpoint:
            # Remember the freshly built lr settings so the values from the
            # current command line win over the ones stored in the checkpoint.
            p_groups = copy.deepcopy(optimizer.param_groups)
            optimizer.load_state_dict(checkpoint["optimizer"])
            for pg, pg_old in zip(optimizer.param_groups, p_groups):
                pg["lr"] = pg_old["lr"]
                if "initial_lr" in pg_old:
                    pg["initial_lr"] = pg_old["initial_lr"]
            if lr_scheduler is not None and checkpoint["lr_scheduler"] is not None:
                lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])

            # todo: this is a hack for doing experiment that resume from checkpoint and also
            # modify lr scheduler (e.g., decrease lr in advance). Currently hard-wired off.
            args.override_resumed_lr_drop = False
            if args.override_resumed_lr_drop:
                print(
                    "Warning: (hack) args.override_resumed_lr_drop is set to True, so args.lr_drop would override lr_drop in resumed lr_scheduler."
                )
                lr_scheduler.step_size = args.lr_drop

            # NOTE(review): the nesting of the base_lrs sync was reconstructed
            # from a whitespace-stripped source; it is placed outside the
            # override branch here - confirm against the original file.
            if lr_scheduler is not None:
                lr_scheduler.base_lrs = list(map(lambda group: group["initial_lr"], optimizer.param_groups))
                lr_scheduler.step(lr_scheduler.last_epoch)
            args.start_epoch = checkpoint["epoch"] + 1

    # Reached by every rank regardless of whether a checkpoint was resumed.
    # NOTE(review): barrier placement reconstructed from a whitespace-stripped source.
    dist.barrier()

    if args.start_from_checkpoint:
        # Initialise from another model's weights, adapting tensors whose
        # shape differs from the current architecture.
        checkpoint = torch.load(args.start_from_checkpoint, map_location="cpu")["model"]
        transformer = model.module.transformer
        for key in list(checkpoint):
            value = checkpoint[key]
            if key.startswith("class_embed"):
                if value.size(0) != model.module.num_classes:
                    # append one zero-initialised output row / bias entry
                    if "weight" in key:
                        checkpoint[key] = torch.cat(
                            [value, torch.zeros((1, value.size(1)), dtype=torch.float)], dim=0
                        )
                    else:
                        checkpoint[key] = torch.cat([value, torch.zeros([1], dtype=torch.float)], dim=0)
            elif "token_embed" in key:
                if value.size(0) != transformer.decoder.token_embed.weight.size(0):
                    # append one zero-initialised embedding row
                    checkpoint[key] = torch.cat(
                        [value, torch.zeros((1, value.size(1)), dtype=torch.float)], dim=0
                    )
            elif "pos_embed" in key and value.shape[1] != transformer.pos_embed.shape[1]:
                # incompatible length: keep the current model's tensor
                checkpoint[key] = transformer.pos_embed
            elif "attention_mask" in key and value.shape[0] != model.module.attention_mask.shape[0]:
                checkpoint[key] = model.module.attention_mask
            elif key.startswith("input_proj") and key.endswith("weight"):
                # only modify the conv layer
                lidx, sub_lidx = int(key.split(".")[1]), int(key.split(".")[2])
                if sub_lidx != 0:
                    continue
                tgt_size = model.module.input_proj[lidx][0].weight.size(2)
                if tgt_size != value.size(2):
                    checkpoint[key] = F.interpolate(
                        value, size=(tgt_size, tgt_size), mode="bilinear", align_corners=False
                    )
            elif "sampling_offsets" in key:
                diff_scale = transformer.encoder.layers[0].self_attn.sampling_offsets.weight.size(0) // value.size(0)
                if diff_scale > 1:
                    # tile the saved tensor along dim 0 to reach the larger output size
                    if ".weight" in key:
                        checkpoint[key] = value.repeat((diff_scale, 1))
                    else:
                        checkpoint[key] = value.repeat((diff_scale,))
            elif "attention_weights" in key:
                diff_scale = transformer.encoder.layers[0].self_attn.attention_weights.weight.size(0) // value.size(0)
                if diff_scale > 1:
                    if ".weight" in key:
                        checkpoint[key] = value.repeat((diff_scale, 1))
                    else:
                        checkpoint[key] = value.repeat((diff_scale,))

        missing_keys, unexpected_keys = model.module.load_state_dict(checkpoint, strict=False)
        unexpected_keys = [k for k in unexpected_keys if not (k.endswith("total_params") or k.endswith("total_ops"))]
        if len(missing_keys) > 0:
            print("Missing Keys: {}".format(missing_keys))
        if len(unexpected_keys) > 0:
            print("Unexpected Keys: {}".format(unexpected_keys))

    # NOTE(review): barrier placement reconstructed from a whitespace-stripped source.
    dist.barrier()

    # Prepare models for training:
    utils.update_ema(ema, model.module, decay=0)  # Ensure EMA is initialized with synced weights
    ema.eval()

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        sampler_train.set_epoch(epoch)
        train_stats = train_one_epoch(
            model,
            criterion,
            data_loader_train,
            optimizer,
            device,
            epoch,
            args.clip_max_norm,
            args.poly2seq,
            ema_model=ema,
            drop_rate=args.random_drop_rate,
        )
        if lr_scheduler is not None:
            lr_scheduler.step()

        if epoch > int(args.increase_cls_loss_coef_epoch_ratio * args.epochs) and args.increase_cls_loss_coef > 1.0:
            criterion._update_ce_coeff(args.increase_cls_loss_coef * args.cls_loss_coef)

        if (epoch + 1) in args.lr_drop or (epoch + 1) % args.ckpt_every_epoch == 0 or (epoch + 1) == args.epochs:
            if rank == 0:
                # rolling "latest" checkpoint plus a per-epoch snapshot
                checkpoint_paths = [
                    output_dir / "checkpoint.pth",
                    output_dir / f"checkpoint{epoch:04}.pth",
                ]
                for checkpoint_path in checkpoint_paths:
                    torch.save(
                        {
                            "model": model.module.state_dict(),
                            "ema": ema.state_dict(),
                            "optimizer": optimizer.state_dict(),
                            "lr_scheduler": None if lr_scheduler is None else lr_scheduler.state_dict(),
                            "epoch": epoch,
                            "args": args,
                        },
                        checkpoint_path,
                    )
            # every rank evaluates the same condition, so all of them wait here
            dist.barrier()

        log_stats = {**{f"train_{k}": v for k, v in train_stats.items()}, "epoch": epoch, "n_parameters": n_parameters}

        if rank == 0:
            wandb.log({"epoch": epoch})
            wandb.log({"lr_rate": train_stats["lr"]})

        train_log_dict = {
            "train/loss": train_stats["loss"],
            "train/loss_ce": train_stats["loss_ce"],
            "train/loss_coords": train_stats["loss_coords"],
            "train/loss_coords_unscaled": train_stats["loss_coords_unscaled"],
            "train/cardinality_error": train_stats["cardinality_error_unscaled"],
        }
        if args.semantic_classes > 0:
            # need to log additional metrics for semantically-rich floorplans
            train_log_dict["train/loss_ce_room"] = train_stats["loss_ce_room"]
        elif "loss_raster" in train_stats:
            # only apply the rasterization loss for non-semantic floorplans
            train_log_dict["train/loss_raster"] = train_stats["loss_raster"]
        if rank == 0:
            wandb.log(train_log_dict)

        # eval every args.eval_every_epoch epochs
        if (epoch + 1) % args.eval_every_epoch == 0:
            eval_model = ema if args.ema4eval else model
            test_stats = evaluate(
                eval_model,
                criterion,
                args.dataset_name,
                data_loader_val,
                device,
                plot_density=True,
                output_dir=output_dir,
                epoch=epoch,
                poly2seq=args.poly2seq,
                add_cls_token=args.add_cls_token,
                per_token_sem_loss=args.per_token_sem_loss,
                wd_as_line=not args.disable_wd_as_line,
            )
            log_stats.update(**{f"test_{k}": v for k, v in test_stats.items()})

            val_log_dict = {
                "val/loss": test_stats["loss"],
                "val/loss_ce": test_stats["loss_ce"],
                "val/loss_coords": test_stats["loss_coords"],
                "val/loss_coords_unscaled": test_stats["loss_coords_unscaled"],
                "val/cardinality_error": test_stats["cardinality_error_unscaled"],
                "val_metrics/room_prec": test_stats["room_prec"],
                "val_metrics/room_rec": test_stats["room_rec"],
                "val_metrics/corner_prec": test_stats["corner_prec"],
                "val_metrics/corner_rec": test_stats["corner_rec"],
                "val_metrics/angles_prec": test_stats["angles_prec"],
                "val_metrics/angles_rec": test_stats["angles_rec"],
            }
            if args.semantic_classes > 0:
                # need to log additional metrics for semantically-rich floorplans
                val_log_dict["val/loss_ce_room"] = test_stats["loss_ce_room"]
                val_log_dict["val_metrics/room_sem_prec"] = test_stats["room_sem_prec"]
                val_log_dict["val_metrics/room_sem_rec"] = test_stats["room_sem_rec"]
                if "window_door_prec" in test_stats:
                    val_log_dict["val_metrics/window_door_prec"] = test_stats["window_door_prec"]
                    val_log_dict["val_metrics/window_door_rec"] = test_stats["window_door_rec"]
            else:
                if "loss_raster" in test_stats:
                    # only apply the rasterization loss for non-semantic floorplans
                    val_log_dict["val/loss_raster"] = test_stats["loss_raster"]
                # NOTE(review): nesting reconstructed from a whitespace-stripped
                # source; room_iou is logged for non-semantic runs only here.
                if "room_iou" in test_stats:
                    val_log_dict["val_metrics/room_iou"] = test_stats["room_iou"]
            if rank == 0:
                wandb.log(val_log_dict)

        # Only rank 0 appends to the shared log file; otherwise every process
        # would write its own copy of the same line each epoch.
        if args.output_dir and rank == 0:
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("Training time {}".format(total_time_str))
    dist.destroy_process_group()
if __name__ == "__main__":
    parser = argparse.ArgumentParser("Raster2Seq training script", parents=[get_args_parser()])
    args = parser.parse_args()

    # The run name is the job name as-is; it is used as the wandb run id and
    # as the sub-directory of --output_dir for this run.
    args.run_name = args.job_name
    args.output_dir = os.path.join(args.output_dir, args.run_name)

    # --lr_drop arrives as a comma-separated string of epochs. Blank entries
    # are skipped, so an empty string yields [] and a trailing comma no longer
    # raises ValueError.
    args.lr_drop = [int(x) for x in args.lr_drop.split(",") if x.strip()]

    if args.debug:
        # debug mode trains on a single sample, so use a batch size of 1
        args.batch_size = 1
    if args.disable_poly_refine:
        args.with_poly_refine = False
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    main(args)