# raster2seq / main_ddp.py
# anas
# Initial deployment of Raster2Seq floor plan vectorization API
# fadb92b
import argparse
import copy
import datetime
import json
import os
import random
import time
from pathlib import Path
import numpy as np
import torch
import torch.distributed as dist
from torch.nn import functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import util.misc as utils
import wandb
from datasets import build_dataset
from engine import evaluate, train_one_epoch
from models import build_model
def get_args_parser():
    """Build the command-line parser for the Raster2Seq training script.

    The parser is created with ``add_help=False`` so that it can be passed as
    a ``parents=[...]`` entry to another ``ArgumentParser`` without clashing
    on ``-h/--help``.

    Returns:
        argparse.ArgumentParser: parser holding every training option.
    """
    parser = argparse.ArgumentParser("Raster2Seq training script", add_help=False)

    # optimisation
    parser.add_argument("--lr", default=2e-4, type=float)
    parser.add_argument("--lr_backbone_names", default=["backbone.0"], type=str, nargs="+")
    parser.add_argument("--lr_backbone", default=2e-5, type=float)
    parser.add_argument("--lr_linear_proj_names", default=["sampling_offsets"], type=str, nargs="+")
    parser.add_argument("--lr_linear_proj_mult", default=0.1, type=float)
    parser.add_argument("--batch_size", default=10, type=int)
    parser.add_argument("--weight_decay", default=1e-4, type=float)
    parser.add_argument("--epochs", default=500, type=int)
    # Kept as a raw comma-separated string here; the script entry point turns
    # it into a list of integer epochs.
    parser.add_argument("--lr_drop", default="400", type=str)
    parser.add_argument("--clip_max_norm", default=0.1, type=float, help="gradient clipping max norm")
    parser.add_argument("--sgd", action="store_true")
    parser.add_argument("--input_channels", default=1, type=int)
    parser.add_argument(
        "--start_from_checkpoint",
        default="",
        help="initialise model weights from this checkpoint (optimizer, scheduler and epoch are not restored)",
    )
    parser.add_argument("--image_norm", action="store_true")
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--eval_every_epoch", type=int, default=20)
    parser.add_argument("--ckpt_every_epoch", type=int, default=20)
    parser.add_argument("--label_smoothing", type=float, default=0.0)
    parser.add_argument("--ignore_index", type=int, default=-1)
    parser.add_argument("--image_size", type=int, default=256)
    parser.add_argument("--ema4eval", action="store_true")
    parser.add_argument("--increase_cls_loss_coef", default=1.0, type=float)
    parser.add_argument("--increase_cls_loss_coef_epoch_ratio", default=-1, type=float)
    parser.add_argument("--use_anchor", action="store_true")
    parser.add_argument("--disable_wd_as_line", action="store_true")
    parser.add_argument("--wd_only", action="store_true")
    parser.add_argument("--converter_version", type=str, default="v1")
    parser.add_argument("--freeze_anchor", action="store_true")
    parser.add_argument("--inject_cls_embed", action="store_true")
    parser.add_argument(
        "--random_drop_rate", type=float, default=0.0, help="randomly drop some polygons during training"
    )

    # raster2seq
    parser.add_argument("--poly2seq", action="store_true")
    parser.add_argument("--seq_len", type=int, default=1024)
    parser.add_argument("--num_bins", type=int, default=64)
    parser.add_argument("--pre_decoder_pos_embed", action="store_true")
    parser.add_argument("--learnable_dec_pe", action="store_true")
    parser.add_argument("--dec_qkv_proj", action="store_true")
    parser.add_argument("--dec_attn_concat_src", action="store_true")
    parser.add_argument("--per_token_sem_loss", action="store_true")
    parser.add_argument("--add_cls_token", action="store_true")
    parser.add_argument("--jointly_train", action="store_true")

    # backbone
    parser.add_argument("--backbone", default="resnet50", type=str, help="Name of the convolutional backbone to use")
    parser.add_argument(
        "--dilation",
        action="store_true",
        help="If true, we replace stride with dilation in the last convolutional block (DC5)",
    )
    parser.add_argument(
        "--position_embedding",
        default="sine",
        type=str,
        choices=("sine", "learned"),
        help="Type of positional embedding to use on top of the image features",
    )
    parser.add_argument("--position_embedding_scale", default=2 * np.pi, type=float, help="position / size * scale")
    parser.add_argument("--num_feature_levels", default=4, type=int, help="number of feature levels")

    # Transformer
    parser.add_argument("--enc_layers", default=6, type=int, help="Number of encoding layers in the transformer")
    parser.add_argument("--dec_layers", default=6, type=int, help="Number of decoding layers in the transformer")
    parser.add_argument(
        "--dim_feedforward",
        default=1024,
        type=int,
        help="Intermediate size of the feedforward layers in the transformer blocks",
    )
    parser.add_argument(
        "--hidden_dim", default=256, type=int, help="Size of the embeddings (dimension of the transformer)"
    )
    parser.add_argument("--dropout", default=0.1, type=float, help="Dropout applied in the transformer")
    parser.add_argument(
        "--nheads", default=8, type=int, help="Number of attention heads inside the transformer's attentions"
    )
    parser.add_argument(
        "--num_queries",
        default=800,
        type=int,
        help="Number of query slots (num_polys * max. number of corner per poly)",
    )
    parser.add_argument("--num_polys", default=20, type=int, help="Number of maximum number of room polygons")
    parser.add_argument("--dec_n_points", default=4, type=int)
    parser.add_argument("--enc_n_points", default=4, type=int)
    parser.add_argument(
        "--query_pos_type",
        default="sine",
        type=str,
        choices=("static", "sine", "none"),
        help="Type of query pos in decoder - \
            1. static: same setting with DETR and Deformable-DETR, the query_pos is the same for all layers \
            2. sine: sine embedding from reference points (so if reference points update, query_pos also updates) \
            3. none: remove query_pos",
    )
    # With default=True and action="store_true" this flag can never turn the
    # option off; use --disable_poly_refine for that.
    parser.add_argument(
        "--with_poly_refine",
        default=True,
        action="store_true",
        help="iteratively refine reference points (i.e. positional part of polygon queries)",
    )
    parser.add_argument(
        "--masked_attn",
        default=False,
        action="store_true",
        help="if true, the query in one room will not be allowed to attend other room",
    )
    parser.add_argument(
        "--semantic_classes",
        default=-1,
        type=int,
        help="Number of classes for semantically-rich floorplan: \
            1. default -1 means non-semantic floorplan \
            2. 19 for Structured3D: 16 room types + 1 door + 1 window + 1 empty",
    )
    parser.add_argument(
        "--disable_poly_refine",
        action="store_true",
        help="disable the iterative refinement of reference points (overrides --with_poly_refine)",
    )

    # loss
    # NOTE(review): action="store_true" with dest="aux_loss" means aux_loss
    # defaults to False and passing --no_aux_loss sets it to True, which is the
    # opposite of what the flag name and help text say. Left untouched because
    # existing launch scripts may rely on the current behaviour - confirm
    # before switching to action="store_false".
    parser.add_argument(
        "--no_aux_loss",
        dest="aux_loss",
        action="store_true",
        help="Disables auxiliary decoding losses (loss at each layer)",
    )

    # matcher
    parser.add_argument("--set_cost_class", default=2, type=float, help="Class coefficient in the matching cost")
    parser.add_argument("--set_cost_coords", default=5, type=float, help="L1 coords coefficient in the matching cost")

    # loss coefficients
    parser.add_argument("--cls_loss_coef", default=2, type=float)
    parser.add_argument("--room_cls_loss_coef", default=0.2, type=float)
    parser.add_argument("--coords_loss_coef", default=5, type=float)
    parser.add_argument("--raster_loss_coef", default=0, type=float)

    # dataset parameters
    parser.add_argument("--dataset_name", default="stru3d")
    parser.add_argument("--dataset_root", default="data/stru3d", type=str)
    parser.add_argument("--output_dir", default="output", help="path where to save, empty for no saving")
    parser.add_argument("--device", default="cuda", help="device to use for training / testing")
    parser.add_argument("--seed", default=42, type=int)
    parser.add_argument("--resume", default="", help="resume from checkpoint")
    parser.add_argument("--start_epoch", default=0, type=int, metavar="N", help="start epoch")
    parser.add_argument("--num_workers", default=2, type=int)
    parser.add_argument("--job_name", default="train_stru3d", type=str)
    return parser
# Per-sample entries that the collator stacks into batch tensors whenever an
# item carries a "target_seq" entry. Order matches the returned dict.
_SEQ_BATCH_KEYS = (
    "delta_x1",
    "delta_x2",
    "delta_y1",
    "delta_y2",
    "seq11",
    "seq21",
    "seq12",
    "seq22",
    "target_seq",
    "token_labels",
    "mask",
    "target_polygon_labels",
)


def _trivial_batch_collator(batch):
    """Collate a list of sample dicts into ``(batch, stacked)``.

    If the first sample has a ``"target_seq"`` entry, every key listed in
    ``_SEQ_BATCH_KEYS`` is stacked along a new leading dimension and removed
    from the individual sample dicts (the samples are mutated in place).
    Otherwise the samples are returned untouched and ``stacked`` is ``None``.
    """
    if "target_seq" not in batch[0]:
        return batch, None
    stacked = {key: torch.stack([item[key] for item in batch], dim=0) for key in _SEQ_BATCH_KEYS}
    for item in batch:
        for key in _SEQ_BATCH_KEYS:
            del item[key]
    return batch, stacked


def _match_name_keywords(name, name_keywords):
    """Return True when any of ``name_keywords`` is a substring of ``name``."""
    return any(keyword in name for keyword in name_keywords)


def _strip_module_prefix(state_dict):
    """Return a copy of ``state_dict`` with a leading ``"module."`` removed from every key."""
    prefix = "module."
    return {(key[len(prefix):] if key.startswith(prefix) else key): value for key, value in state_dict.items()}


def _filter_profiler_keys(keys):
    """Drop the ``total_params`` / ``total_ops`` entries from a list of state-dict keys."""
    return [k for k in keys if not (k.endswith("total_params") or k.endswith("total_ops"))]


def _resume_training_state(args, model, optimizer, lr_scheduler):
    """Restore model, optimizer and scheduler state from ``args.resume``.

    ``model`` is the DDP wrapper; weights are loaded into ``model.module``.
    When the checkpoint carries optimizer, scheduler and epoch entries,
    ``args.start_epoch`` is set to the epoch after the saved one.

    Raises:
        ValueError: if the checkpoint lacks parameters the model expects.
    """
    # NOTE: torch.load unpickles arbitrary Python objects (the checkpoints
    # written by this script also store ``args``); only resume from trusted files.
    checkpoint = torch.load(args.resume, map_location="cpu")
    # Bug fix: the prefix has to be stripped inside checkpoint["model"]. The
    # previous code wrote to / deleted from the outer checkpoint dict, which
    # raised KeyError as soon as a "module."-prefixed key was encountered.
    checkpoint["model"] = _strip_module_prefix(checkpoint["model"])
    missing_keys, unexpected_keys = model.module.load_state_dict(checkpoint["model"], strict=False)
    unexpected_keys = _filter_profiler_keys(unexpected_keys)
    if len(missing_keys) > 0:
        print("Missing Keys: {}".format(missing_keys))
        raise ValueError("Missing keys in state_dict")
    if len(unexpected_keys) > 0:
        print("Unexpected Keys: {}".format(unexpected_keys))

    if not ("optimizer" in checkpoint and "lr_scheduler" in checkpoint and "epoch" in checkpoint):
        return

    # Keep the learning rates requested on the command line: remember the
    # freshly built groups, load the saved optimizer state, then copy the
    # current lr values back over the restored ones.
    fresh_groups = copy.deepcopy(optimizer.param_groups)
    optimizer.load_state_dict(checkpoint["optimizer"])
    for pg, pg_old in zip(optimizer.param_groups, fresh_groups):
        pg["lr"] = pg_old["lr"]
        if "initial_lr" in pg_old:
            pg["initial_lr"] = pg_old["initial_lr"]
    if lr_scheduler is not None and checkpoint["lr_scheduler"] is not None:
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
    # todo: this is a hack for doing experiment that resume from checkpoint and also modify lr scheduler (e.g., decrease lr in advance).
    # The switch is hard-coded to False, so the branch below is currently disabled.
    # NOTE(review): it assigns ``step_size``, which MultiStepLR does not appear
    # to read (it is built from milestones) - revisit before enabling.
    args.override_resumed_lr_drop = False
    if args.override_resumed_lr_drop:
        print(
            "Warning: (hack) args.override_resumed_lr_drop is set to True, so args.lr_drop would override lr_drop in resumed lr_scheduler."
        )
        if lr_scheduler is not None:
            lr_scheduler.step_size = args.lr_drop
            lr_scheduler.base_lrs = list(map(lambda group: group["initial_lr"], optimizer.param_groups))
    if lr_scheduler is not None:
        lr_scheduler.step(lr_scheduler.last_epoch)
    args.start_epoch = checkpoint["epoch"] + 1


def _adapt_pretrained_state_dict(state_dict, net):
    """Reshape entries of ``state_dict`` in place so they fit ``net``.

    Handles classifier / token-embedding tables that are one row short,
    positional embeddings and attention masks of a different size, input
    projection convolutions with a different kernel size, and deformable
    attention projections with fewer outputs than the target model.
    """
    # Only existing keys are reassigned, so iterating the dict while writing is safe.
    for key in state_dict:
        tensor = state_dict[key]
        if key.startswith("class_embed"):
            if tensor.size(0) != net.num_classes:
                # Append a single zero row / bias entry.
                if "weight" in key:
                    pad = torch.zeros((1, tensor.size(1)), dtype=torch.float)
                else:
                    pad = torch.zeros([1], dtype=torch.float)
                state_dict[key] = torch.cat([tensor, pad], dim=0)
        elif "token_embed" in key:
            if tensor.size(0) != net.transformer.decoder.token_embed.weight.size(0):
                pad = torch.zeros((1, tensor.size(1)), dtype=torch.float)
                state_dict[key] = torch.cat([tensor, pad], dim=0)
        elif "pos_embed" in key and tensor.shape[1] != net.transformer.pos_embed.shape[1]:
            # Size mismatch: keep the freshly initialised value of the model.
            state_dict[key] = net.transformer.pos_embed
        elif "attention_mask" in key and tensor.shape[0] != net.attention_mask.shape[0]:
            state_dict[key] = net.attention_mask
        elif key.startswith("input_proj") and key.endswith("weight"):
            # only modify the conv layer (sub-module index 0 of each projection)
            parts = key.split(".")
            lidx, sub_lidx = int(parts[1]), int(parts[2])
            if sub_lidx != 0:
                continue
            tgt_size = net.input_proj[lidx][0].weight.size(2)
            if tgt_size != tensor.size(2):
                state_dict[key] = F.interpolate(
                    tensor, size=(tgt_size, tgt_size), mode="bilinear", align_corners=False
                )
        elif "sampling_offsets" in key:
            reference = net.transformer.encoder.layers[0].self_attn.sampling_offsets.weight
            diff_scale = reference.size(0) // tensor.size(0)
            if diff_scale > 1:
                # Tile the saved tensor along dim 0 to reach the larger output size.
                if ".weight" in key:
                    state_dict[key] = tensor.repeat((diff_scale, 1))
                else:
                    state_dict[key] = tensor.repeat((diff_scale,))
        elif "attention_weights" in key:
            reference = net.transformer.encoder.layers[0].self_attn.attention_weights.weight
            diff_scale = reference.size(0) // tensor.size(0)
            if diff_scale > 1:
                if ".weight" in key:
                    state_dict[key] = tensor.repeat((diff_scale, 1))
                else:
                    state_dict[key] = tensor.repeat((diff_scale,))


def _load_pretrained_weights(path, net):
    """Initialise ``net`` from the ``"model"`` entry of the checkpoint at ``path`` (non-strict)."""
    # NOTE: torch.load unpickles arbitrary objects; only use trusted checkpoints.
    state_dict = torch.load(path, map_location="cpu")["model"]
    _adapt_pretrained_state_dict(state_dict, net)
    missing_keys, unexpected_keys = net.load_state_dict(state_dict, strict=False)
    unexpected_keys = _filter_profiler_keys(unexpected_keys)
    if len(missing_keys) > 0:
        print("Missing Keys: {}".format(missing_keys))
    if len(unexpected_keys) > 0:
        print("Unexpected Keys: {}".format(unexpected_keys))


def main(args):
    """Train Raster2Seq with DistributedDataParallel.

    Initialises the NCCL process group, builds datasets / loaders / model /
    optimizer, optionally resumes or warm-starts from a checkpoint, then runs
    the epoch loop with periodic checkpointing, evaluation and wandb logging
    (wandb and file logging happen on rank 0 only).

    Args:
        args: namespace produced by ``get_args_parser``. In addition this
            function reads ``args.run_name`` and expects ``args.lr_drop`` to
            already be a list of integer epochs. ``args.vocab_size``,
            ``args.start_epoch`` and ``args.override_resumed_lr_drop`` may be
            written.
    """
    print("git:\n {}\n".format(utils.get_sha()))
    print(args)

    # Setup DDP:
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = rank % torch.cuda.device_count()
    seed = args.seed * world_size + rank  # distinct RNG stream per rank

    # fix the seed for reproducibility
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")

    # setup wandb for logging
    if rank == 0:
        utils.setup_wandb()
        wandb.init(project="Raster2Seq", resume="allow", id=args.run_name, dir="./wandb")

    # build dataset and dataloader
    dataset_train = build_dataset(image_set="train", args=args)
    dataset_val = build_dataset(image_set="val", args=args)

    tokenizer = None
    if args.poly2seq:
        args.vocab_size = dataset_train.get_vocab_size()
        tokenizer = dataset_train.get_tokenizer()

    # overfit one sample
    if args.debug:
        dataset_val = torch.utils.data.Subset(copy.deepcopy(dataset_val), [0])
        dataset_train = copy.deepcopy(dataset_val)

    sampler_train = DistributedSampler(
        dataset_train, num_replicas=dist.get_world_size(), rank=rank, shuffle=True, seed=args.seed
    )
    sampler_val = DistributedSampler(
        dataset_val, num_replicas=dist.get_world_size(), rank=rank, shuffle=False, seed=args.seed
    )

    data_loader_train = DataLoader(
        dataset_train,
        args.batch_size,
        shuffle=False,  # shuffling is delegated to the sampler
        sampler=sampler_train,
        num_workers=args.num_workers,
        collate_fn=_trivial_batch_collator,
        pin_memory=True,
        drop_last=True,
    )
    data_loader_val = DataLoader(
        dataset_val,
        args.batch_size,
        shuffle=False,
        sampler=sampler_val,
        collate_fn=_trivial_batch_collator,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
    )

    # build model
    model, criterion = build_model(args, tokenizer=tokenizer)
    ema = copy.deepcopy(model).to(device)
    utils.requires_grad(ema, False)
    # Bug fix: bind DDP to the local CUDA device (rank % device_count), not to
    # the global rank, which is not a valid device index on multi-node jobs.
    model = DDP(model.to(device), device_ids=[device], find_unused_parameters=True)
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of params:", n_parameters)

    for n, p in model.named_parameters():
        print(n)

    if args.per_token_sem_loss and not args.jointly_train:
        # disable gradient for model, except new classifier
        for n, p in model.named_parameters():
            p.requires_grad = "room_class_embed" in n

    # Three learning-rate groups: everything else, backbone, linear projections.
    param_dicts = [
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if not _match_name_keywords(n, args.lr_backbone_names)
                and not _match_name_keywords(n, args.lr_linear_proj_names)
                and p.requires_grad
            ],
            "lr": args.lr,
        },
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if _match_name_keywords(n, args.lr_backbone_names) and p.requires_grad
            ],
            "lr": args.lr_backbone,
        },
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if _match_name_keywords(n, args.lr_linear_proj_names) and p.requires_grad
            ],
            "lr": args.lr * args.lr_linear_proj_mult,
        },
    ]
    print(f"Rank {dist.get_rank()}: Model has {sum(p.numel() for p in model.parameters())} parameters")

    if args.sgd:
        optimizer = torch.optim.SGD(param_dicts, lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, weight_decay=args.weight_decay)

    # An empty lr_drop list means a constant learning rate.
    if args.lr_drop:
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lr_drop)
    else:
        lr_scheduler = None

    output_dir = Path(args.output_dir)

    # A missing resume file is not an error: training simply starts fresh.
    if args.resume and os.path.exists(args.resume):
        _resume_training_state(args, model, optimizer, lr_scheduler)
        dist.barrier()

    if args.start_from_checkpoint:
        _load_pretrained_weights(args.start_from_checkpoint, model.module)
        dist.barrier()

    # Prepare models for training:
    # NOTE(review): this runs after resuming as well, so the "ema" weights
    # stored in checkpoints are never loaded back - confirm that is intended.
    utils.update_ema(ema, model.module, decay=0)  # Ensure EMA is initialized with synced weights
    ema.eval()

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        sampler_train.set_epoch(epoch)
        train_stats = train_one_epoch(
            model,
            criterion,
            data_loader_train,
            optimizer,
            device,
            epoch,
            args.clip_max_norm,
            args.poly2seq,
            ema_model=ema,
            drop_rate=args.random_drop_rate,
        )
        if lr_scheduler is not None:
            lr_scheduler.step()

        if epoch > int(args.increase_cls_loss_coef_epoch_ratio * args.epochs) and args.increase_cls_loss_coef > 1.0:
            criterion._update_ce_coeff(args.increase_cls_loss_coef * args.cls_loss_coef)

        is_last_epoch = (epoch + 1) == args.epochs
        if (epoch + 1) in args.lr_drop or (epoch + 1) % args.ckpt_every_epoch == 0 or is_last_epoch:
            if rank == 0:
                # Rolling "latest" checkpoint plus a per-epoch snapshot.
                checkpoint_paths = [output_dir / "checkpoint.pth", output_dir / f"checkpoint{epoch:04}.pth"]
                for checkpoint_path in checkpoint_paths:
                    torch.save(
                        {
                            "model": model.module.state_dict(),
                            "ema": ema.state_dict(),
                            "optimizer": optimizer.state_dict(),
                            "lr_scheduler": None if lr_scheduler is None else lr_scheduler.state_dict(),
                            "epoch": epoch,
                            "args": args,
                        },
                        checkpoint_path,
                    )
            dist.barrier()

        log_stats = {**{f"train_{k}": v for k, v in train_stats.items()}, "epoch": epoch, "n_parameters": n_parameters}

        if rank == 0:
            wandb.log({"epoch": epoch})
            wandb.log({"lr_rate": train_stats["lr"]})
            train_log_dict = {
                "train/loss": train_stats["loss"],
                "train/loss_ce": train_stats["loss_ce"],
                "train/loss_coords": train_stats["loss_coords"],
                "train/loss_coords_unscaled": train_stats["loss_coords_unscaled"],
                "train/cardinality_error": train_stats["cardinality_error_unscaled"],
            }
            if args.semantic_classes > 0:
                # need to log additional metrics for semantically-rich floorplans
                train_log_dict["train/loss_ce_room"] = train_stats["loss_ce_room"]
            elif "loss_raster" in train_stats:
                # only apply the rasterization loss for non-semantic floorplans
                train_log_dict["train/loss_raster"] = train_stats["loss_raster"]
            wandb.log(train_log_dict)

        # periodic evaluation
        if (epoch + 1) % args.eval_every_epoch == 0:
            eval_model = model if not args.ema4eval else ema
            test_stats = evaluate(
                eval_model,
                criterion,
                args.dataset_name,
                data_loader_val,
                device,
                plot_density=True,
                output_dir=output_dir,
                epoch=epoch,
                poly2seq=args.poly2seq,
                add_cls_token=args.add_cls_token,
                per_token_sem_loss=args.per_token_sem_loss,
                wd_as_line=not args.disable_wd_as_line,
            )
            log_stats.update(**{f"test_{k}": v for k, v in test_stats.items()})

            val_log_dict = {
                "val/loss": test_stats["loss"],
                "val/loss_ce": test_stats["loss_ce"],
                "val/loss_coords": test_stats["loss_coords"],
                "val/loss_coords_unscaled": test_stats["loss_coords_unscaled"],
                "val/cardinality_error": test_stats["cardinality_error_unscaled"],
                "val_metrics/room_prec": test_stats["room_prec"],
                "val_metrics/room_rec": test_stats["room_rec"],
                "val_metrics/corner_prec": test_stats["corner_prec"],
                "val_metrics/corner_rec": test_stats["corner_rec"],
                "val_metrics/angles_prec": test_stats["angles_prec"],
                "val_metrics/angles_rec": test_stats["angles_rec"],
            }
            if args.semantic_classes > 0:
                # need to log additional metrics for semantically-rich floorplans
                val_log_dict["val/loss_ce_room"] = test_stats["loss_ce_room"]
                val_log_dict["val_metrics/room_sem_prec"] = test_stats["room_sem_prec"]
                val_log_dict["val_metrics/room_sem_rec"] = test_stats["room_sem_rec"]
                if "window_door_prec" in test_stats:
                    val_log_dict["val_metrics/window_door_prec"] = test_stats["window_door_prec"]
                    val_log_dict["val_metrics/window_door_rec"] = test_stats["window_door_rec"]
            elif "loss_raster" in test_stats:
                # only apply the rasterization loss for non-semantic floorplans
                val_log_dict["val/loss_raster"] = test_stats["loss_raster"]
            if "room_iou" in test_stats:
                val_log_dict["val_metrics/room_iou"] = test_stats["room_iou"]
            if rank == 0:
                wandb.log(val_log_dict)

        # Only rank 0 appends to the shared log file; otherwise every process
        # would write its own line for each epoch into the same file.
        if args.output_dir and rank == 0:
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("Training time {}".format(total_time_str))
    dist.destroy_process_group()
if __name__ == "__main__":
    # get_args_parser() is built without -h/--help; wrapping it as a parent
    # parser adds the help option here.
    parser = argparse.ArgumentParser("Raster2Seq training script", parents=[get_args_parser()])
    args = parser.parse_args()

    # The job name is used as the run name and as the output sub-directory.
    args.run_name = args.job_name
    args.output_dir = os.path.join(args.output_dir, args.run_name)

    # "--lr_drop 300,400" -> [300, 400]; an empty string yields an empty list.
    # Empty tokens (e.g. from a trailing comma) are skipped instead of
    # crashing in int("").
    args.lr_drop = [int(token) for token in args.lr_drop.split(",") if token.strip()]

    if args.debug:
        args.batch_size = 1
    if args.disable_poly_refine:
        args.with_poly_refine = False
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    main(args)