Spaces:
Runtime error
Runtime error
| import copy | |
| import math | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from torch import nn | |
| from datasets.poly_data import TokenType | |
| from util.misc import NestedTensor, nested_tensor_from_tensor_list | |
| from .backbone import build_backbone | |
| from .deformable_transformer_v2 import build_deforamble_transformer | |
| from .label_smoothing_loss import label_smoothed_nll_loss | |
| from .losses import MaskRasterizationLoss | |
| def _get_clones(module, N): | |
| return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) | |
| class Raster2Seq(nn.Module): | |
| """This is the RoomFormer module that performs floorplan reconstruction""" | |
| def __init__( | |
| self, | |
| backbone, | |
| transformer, | |
| num_classes, | |
| num_queries, | |
| num_polys, | |
| num_feature_levels, | |
| aux_loss=True, | |
| with_poly_refine=False, | |
| masked_attn=False, | |
| semantic_classes=-1, | |
| seq_len=1024, | |
| tokenizer=None, | |
| use_anchor=False, | |
| patch_size=1, | |
| freeze_anchor=False, | |
| inject_cls_embed=False, | |
| ): | |
| """Initializes the model. | |
| Parameters: | |
| backbone: torch module of the backbone to be used. See backbone.py | |
| transformer: torch module of the transformer architecture. See transformer.py | |
| num_classes: number of object classes | |
| num_queries: number of object queries, ie detection slot. This is the maximal number of possible corners | |
| in a single image. | |
| num_polys: maximal number of possible polygons in a single image. | |
| num_queries/num_polys would be the maximal number of possible corners in a single polygon. | |
| aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. | |
| with_poly_refine: iterative polygon refinement | |
| """ | |
| super().__init__() | |
| self.num_queries = num_queries | |
| self.num_polys = num_polys | |
| assert num_queries % num_polys == 0 | |
| self.transformer = transformer | |
| hidden_dim = transformer.d_model | |
| self.num_classes = num_classes | |
| self.class_embed = nn.Linear(hidden_dim, num_classes) | |
| self.coords_embed = MLP(hidden_dim, hidden_dim, 2, 3) | |
| self.num_feature_levels = num_feature_levels | |
| self.tokenizer = tokenizer | |
| self.seq_len = seq_len | |
| self.patch_size = patch_size | |
| self.inject_cls_embed = inject_cls_embed | |
| # self.tgt_embed = nn.Embedding(num_queries, hidden_dim) | |
| if num_feature_levels > 1: | |
| num_backbone_outs = len(backbone.strides) | |
| input_proj_list = [] | |
| for _ in range(num_backbone_outs): | |
| in_channels = backbone.num_channels[_] | |
| input_proj_list.append( | |
| nn.Sequential( | |
| nn.Conv2d(in_channels, hidden_dim, kernel_size=patch_size, stride=patch_size, padding=0), | |
| nn.GroupNorm(32, hidden_dim), | |
| ) | |
| ) | |
| for _ in range(num_feature_levels - num_backbone_outs): | |
| if patch_size == 1: | |
| input_proj_list.append( | |
| nn.Sequential( | |
| nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1), | |
| nn.GroupNorm(32, hidden_dim), | |
| ) | |
| ) | |
| else: | |
| input_proj_list.append( | |
| nn.Sequential( | |
| nn.Conv2d( | |
| in_channels, hidden_dim, kernel_size=2 * patch_size, stride=2 * patch_size, padding=0 | |
| ), | |
| nn.GroupNorm(32, hidden_dim), | |
| ) | |
| ) | |
| in_channels = hidden_dim | |
| self.input_proj = nn.ModuleList(input_proj_list) | |
| else: | |
| self.input_proj = nn.ModuleList( | |
| [ | |
| nn.Sequential( | |
| nn.Conv2d(backbone.num_channels[0], hidden_dim, kernel_size=1), | |
| nn.GroupNorm(32, hidden_dim), | |
| ) | |
| ] | |
| ) | |
| self.backbone = backbone | |
| self.aux_loss = aux_loss | |
| self.with_poly_refine = with_poly_refine | |
| prior_prob = 0.01 | |
| bias_value = -math.log((1 - prior_prob) / prior_prob) | |
| self.class_embed.bias.data = torch.ones(num_classes) * bias_value | |
| nn.init.constant_(self.coords_embed.layers[-1].weight.data, 0) | |
| nn.init.constant_(self.coords_embed.layers[-1].bias.data, 0) | |
| for proj in self.input_proj: | |
| nn.init.xavier_uniform_(proj[0].weight, gain=1) | |
| nn.init.constant_(proj[0].bias, 0) | |
| num_pred = transformer.decoder.num_layers | |
| if with_poly_refine: | |
| self.class_embed = _get_clones(self.class_embed, num_pred) | |
| self.coords_embed = _get_clones(self.coords_embed, num_pred) | |
| nn.init.constant_(self.coords_embed[0].layers[-1].bias.data[2:], -2.0) | |
| else: | |
| nn.init.constant_(self.coords_embed.layers[-1].bias.data[2:], -2.0) | |
| self.class_embed = nn.ModuleList([self.class_embed for _ in range(num_pred)]) | |
| self.coords_embed = nn.ModuleList([self.coords_embed for _ in range(num_pred)]) | |
| if use_anchor or with_poly_refine: | |
| self.query_embed = nn.Embedding(seq_len, 2) | |
| self.query_embed.weight.requires_grad = not freeze_anchor | |
| else: | |
| self.query_embed = None | |
| self.transformer.decoder.coords_embed = self.coords_embed | |
| self.transformer.decoder.class_embed = self.class_embed | |
| # Semantically-rich floorplan | |
| self.room_class_embed = None | |
| if semantic_classes > 0: | |
| self.room_class_embed = nn.Linear(hidden_dim, semantic_classes) | |
| if self.inject_cls_embed: | |
| self.transformer.decoder.room_class_embed = self.room_class_embed | |
| # self.num_queries_per_poly = num_queries // num_polys | |
| # # The attention mask is used to prevent object queries in one polygon attending to another polygon, default false | |
| # if masked_attn: | |
| # self.attention_mask = torch.ones((num_queries, num_queries), dtype=torch.bool) | |
| # for i in range(num_polys): | |
| # self.attention_mask[i * self.num_queries_per_poly:(i + 1) * self.num_queries_per_poly, | |
| # i * self.num_queries_per_poly:(i + 1) * self.num_queries_per_poly] = False | |
| # else: | |
| # self.attention_mask = None | |
| self.register_buffer("attention_mask", self._create_causal_attention_mask(seq_len)) | |
| def _create_causal_attention_mask(self, seq_len): | |
| """ | |
| Creates a causal attention mask for a sequence of length `seq_len`. | |
| """ | |
| # Create an upper triangular matrix with 1s above the diagonal | |
| mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) | |
| # Invert the mask: 1 -> -inf (masked), 0 -> 0 (unmasked) | |
| causal_mask = mask.masked_fill(mask == 1, float("-inf")).masked_fill(mask == 0, 0.0) | |
| return causal_mask | |
| def forward(self, samples: NestedTensor, seq_kwargs=None): | |
| """The forward expects a NestedTensor, which consists of: | |
| - samples.tensors: batched images, of shape [batch_size x C x H x W] | |
| - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels | |
| It returns a dict with the following elements: | |
| - "pred_logits": the classification logits (including no-object) for all queries. | |
| Shape= [batch_size x num_queries x (num_classes + 1)] | |
| - "pred_coords": The normalized corner coordinates for all queries, represented as | |
| (x, y). These values are normalized in [0, 1], | |
| relative to the size of each individual image (disregarding possible padding). | |
| - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of | |
| dictionnaries containing the two above keys for each decoder layer. | |
| """ | |
| if not isinstance(samples, NestedTensor): | |
| samples = nested_tensor_from_tensor_list(samples) | |
| features, pos = self.backbone(samples) | |
| srcs = [] | |
| masks = [] | |
| for l, feat in enumerate(features): | |
| src, mask = feat.decompose() | |
| src = self.input_proj[l](src) | |
| srcs.append(src) | |
| if self.patch_size != 1: | |
| mask = F.interpolate(mask[None].float(), size=src.shape[-2:]).to(torch.bool)[0] | |
| pos[l] = self.backbone[1](NestedTensor(src, mask)).to(src.dtype) | |
| masks.append(mask) | |
| assert mask is not None | |
| if self.num_feature_levels > len(srcs): | |
| _len_srcs = len(srcs) | |
| for l in range(_len_srcs, self.num_feature_levels): | |
| if l == _len_srcs: | |
| src = self.input_proj[l](features[-1].tensors) | |
| else: | |
| src = self.input_proj[l](srcs[-1]) | |
| m = samples.mask | |
| mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0] | |
| pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype) | |
| srcs.append(src) | |
| masks.append(mask) | |
| pos.append(pos_l) | |
| query_embeds = None if self.query_embed is None else self.query_embed.weight | |
| tgt_embeds = None | |
| hs, init_reference, inter_references, inter_classes = self.transformer( | |
| srcs, masks, pos, query_embeds, tgt_embeds, self.attention_mask, seq_kwargs | |
| ) | |
| outputs_class = inter_classes | |
| outputs_coord = inter_references | |
| out = {"pred_logits": outputs_class[-1], "pred_coords": outputs_coord[-1]} | |
| if self.room_class_embed is not None: | |
| outputs_room_class = self.room_class_embed(hs[-1]) | |
| out = { | |
| "pred_logits": outputs_class[-1], | |
| "pred_coords": outputs_coord[-1], | |
| "pred_room_logits": outputs_room_class, | |
| } | |
| if self.aux_loss: | |
| out["aux_outputs"] = self._set_aux_loss(outputs_class, outputs_coord) | |
| return out | |
| def _prepare_sequences(self, b): | |
| prev_output_token_11 = [[self.tokenizer.bos] for _ in range(b)] | |
| prev_output_token_12 = [[self.tokenizer.bos] for _ in range(b)] | |
| prev_output_token_21 = [[self.tokenizer.bos] for _ in range(b)] | |
| prev_output_token_22 = [[self.tokenizer.bos] for _ in range(b)] | |
| delta_x1 = [[0] for _ in range(b)] | |
| delta_y1 = [[0] for _ in range(b)] | |
| delta_x2 = [[1] for _ in range(b)] | |
| delta_y2 = [[1] for _ in range(b)] | |
| gen_out = [[] for _ in range(b)] | |
| if self.inject_cls_embed: | |
| input_polygon_labels = [[self.semantic_classes - 1] for _ in range(b)] | |
| else: | |
| input_polygon_labels = [[-1] for _ in range(b)] # dummies values, not used in inference | |
| return ( | |
| prev_output_token_11, | |
| prev_output_token_12, | |
| prev_output_token_21, | |
| prev_output_token_22, | |
| delta_x1, | |
| delta_x2, | |
| delta_y1, | |
| delta_y2, | |
| gen_out, | |
| input_polygon_labels, | |
| ) | |
| def forward_inference(self, samples: NestedTensor, use_cache=True): | |
| """The forward expects a NestedTensor, which consists of: | |
| - samples.tensors: batched images, of shape [batch_size x C x H x W] | |
| - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels | |
| It returns a dict with the following elements: | |
| - "pred_logits": the classification logits (including no-object) for all queries. | |
| Shape= [batch_size x num_queries x (num_classes + 1)] | |
| - "pred_coords": The normalized corner coordinates for all queries, represented as | |
| (x, y). These values are normalized in [0, 1], | |
| relative to the size of each individual image (disregarding possible padding). | |
| - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of | |
| dictionnaries containing the two above keys for each decoder layer. | |
| """ | |
| if not isinstance(samples, NestedTensor): | |
| samples = nested_tensor_from_tensor_list(samples) | |
| features, pos = self.backbone(samples) | |
| bs = samples.tensors.shape[0] | |
| srcs = [] | |
| masks = [] | |
| for l, feat in enumerate(features): | |
| src, mask = feat.decompose() | |
| src = self.input_proj[l](src) | |
| srcs.append(src) | |
| if self.patch_size != 1: | |
| mask = F.interpolate(mask[None].float(), size=src.shape[-2:]).to(torch.bool)[0] | |
| pos[l] = self.backbone[1](NestedTensor(src, mask)).to(src.dtype) | |
| masks.append(mask) | |
| assert mask is not None | |
| if self.num_feature_levels > len(srcs): | |
| _len_srcs = len(srcs) | |
| for l in range(_len_srcs, self.num_feature_levels): | |
| if l == _len_srcs: | |
| src = self.input_proj[l](features[-1].tensors) | |
| else: | |
| src = self.input_proj[l](srcs[-1]) | |
| m = samples.mask | |
| mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0] | |
| pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype) | |
| srcs.append(src) | |
| masks.append(mask) | |
| pos.append(pos_l) | |
| ##### decoder part | |
| if use_cache: | |
| # kv cache for faster inference | |
| max_src_len = sum([x.size(2) * x.size(3) for x in srcs]) # 1360 | |
| self._setup_caches(bs, max_src_len) | |
| ( | |
| prev_output_token_11, | |
| prev_output_token_12, | |
| prev_output_token_21, | |
| prev_output_token_22, | |
| delta_x1, | |
| delta_x2, | |
| delta_y1, | |
| delta_y2, | |
| gen_out, | |
| input_polygon_labels, | |
| ) = self._prepare_sequences(bs) | |
| query_embeds = None if self.query_embed is None else self.query_embed.weight | |
| # tgt_embeds = self.tgt_embed.weight | |
| tgt_embeds = None | |
| enc_cache = None | |
| device = samples.tensors.device | |
| num_bins = self.tokenizer.num_bins | |
| min_len = 6 | |
| max_len = self.tokenizer.seq_len | |
| unfinish_flag = np.ones(bs) | |
| i = 0 | |
| output_hs_list = [] | |
| while i < max_len and unfinish_flag.any(): | |
| prev_output_tokens_11_tensor = torch.tensor(np.array(prev_output_token_11)[:, i : i + 1]).to(device).long() | |
| prev_output_tokens_12_tensor = torch.tensor(np.array(prev_output_token_12)[:, i : i + 1]).to(device).long() | |
| prev_output_tokens_21_tensor = torch.tensor(np.array(prev_output_token_21)[:, i : i + 1]).to(device).long() | |
| prev_output_tokens_22_tensor = torch.tensor(np.array(prev_output_token_22)[:, i : i + 1]).to(device).long() | |
| delta_x1_tensor = torch.tensor(np.array(delta_x1)[:, i : i + 1], dtype=torch.float32).to(device) | |
| delta_x2_tensor = torch.tensor(np.array(delta_x2)[:, i : i + 1], dtype=torch.float32).to(device) | |
| delta_y1_tensor = torch.tensor(np.array(delta_y1)[:, i : i + 1], dtype=torch.float32).to(device) | |
| delta_y2_tensor = torch.tensor(np.array(delta_y2)[:, i : i + 1], dtype=torch.float32).to(device) | |
| input_polygon_labels_tensor = torch.tensor( | |
| np.array(input_polygon_labels)[:, i : i + 1], dtype=torch.long | |
| ).to(device) | |
| seq_kwargs = { | |
| "seq11": prev_output_tokens_11_tensor, | |
| "seq12": prev_output_tokens_12_tensor, | |
| "seq21": prev_output_tokens_21_tensor, | |
| "seq22": prev_output_tokens_22_tensor, | |
| "delta_x1": delta_x1_tensor, | |
| "delta_x2": delta_x2_tensor, | |
| "delta_y1": delta_y1_tensor, | |
| "delta_y2": delta_y2_tensor, | |
| "input_polygon_labels": input_polygon_labels_tensor, | |
| } | |
| if not use_cache: | |
| hs, _, reg_output, cls_output = self.transformer( | |
| srcs, | |
| masks, | |
| pos, | |
| query_embeds, | |
| tgt_embeds, | |
| None, | |
| seq_kwargs, | |
| force_simple_returns=True, | |
| return_enc_cache=use_cache, | |
| enc_cache=None, | |
| decode_token_pos=None, | |
| ) | |
| output_hs_list.append(hs[:, i : i + 1]) | |
| else: | |
| decode_token_pos = torch.tensor([i], device=device, dtype=torch.long) | |
| hs, _, reg_output, cls_output, enc_cache = self.transformer( | |
| srcs, | |
| masks, | |
| pos, | |
| query_embeds, | |
| tgt_embeds, | |
| None, | |
| seq_kwargs, | |
| force_simple_returns=True, | |
| return_enc_cache=use_cache, | |
| enc_cache=enc_cache, | |
| decode_token_pos=decode_token_pos, | |
| ) | |
| output_hs_list.append(hs) | |
| cls_type = torch.argmax(cls_output, 2) | |
| # print(cls_type, torch.softmax(cls_output, dim=2)[:, :, cls_type], torch.topk(torch.softmax(cls_output, dim=2), k=3)) | |
| for j in range(bs): | |
| if unfinish_flag[j] == 1: # prediction is not finished | |
| cls_j = cls_type[j, 0].item() | |
| if cls_j == TokenType.coord.value or (cls_j == TokenType.eos.value and i < min_len): | |
| output_j_x, output_j_y = reg_output[j, 0].detach().cpu().numpy() | |
| output_j_x = min(output_j_x, 1) | |
| output_j_y = min(output_j_y, 1) | |
| gen_out[j].append([output_j_x, output_j_y]) | |
| output_j_x = output_j_x * (num_bins - 1) | |
| output_j_y = output_j_y * (num_bins - 1) | |
| output_j_x_floor = math.floor(output_j_x) | |
| output_j_y_floor = math.floor(output_j_y) | |
| output_j_x_ceil = math.ceil(output_j_x) | |
| output_j_y_ceil = math.ceil(output_j_y) | |
| # tokenization | |
| prev_output_token_11[j].append(output_j_x_floor * num_bins + output_j_y_floor) | |
| prev_output_token_12[j].append(output_j_x_floor * num_bins + output_j_y_ceil) | |
| prev_output_token_21[j].append(output_j_x_ceil * num_bins + output_j_y_floor) | |
| prev_output_token_22[j].append(output_j_x_ceil * num_bins + output_j_y_ceil) | |
| delta_x = output_j_x - output_j_x_floor | |
| delta_y = output_j_y - output_j_y_floor | |
| elif cls_j == TokenType.sep.value: | |
| gen_out[j].append(2) | |
| prev_output_token_11[j].append(self.tokenizer.sep) | |
| prev_output_token_12[j].append(self.tokenizer.sep) | |
| prev_output_token_21[j].append(self.tokenizer.sep) | |
| prev_output_token_22[j].append(self.tokenizer.sep) | |
| delta_x = 0 | |
| delta_y = 0 | |
| elif cls_j == TokenType.cls.value: | |
| gen_out[j].append(-1) | |
| prev_output_token_11[j].append(self.tokenizer.cls) | |
| prev_output_token_12[j].append(self.tokenizer.cls) | |
| prev_output_token_21[j].append(self.tokenizer.cls) | |
| prev_output_token_22[j].append(self.tokenizer.cls) | |
| delta_x = 0 | |
| delta_y = 0 | |
| else: # eos is predicted and i >= min_len | |
| unfinish_flag[j] = 0 | |
| gen_out[j].append(-1) | |
| prev_output_token_11[j].append(self.tokenizer.eos) | |
| prev_output_token_12[j].append(self.tokenizer.eos) | |
| prev_output_token_21[j].append(self.tokenizer.eos) | |
| prev_output_token_22[j].append(self.tokenizer.eos) | |
| delta_x = 0 | |
| delta_y = 0 | |
| else: # prediction is finished | |
| gen_out[j].append(-1) | |
| prev_output_token_11[j].append(self.tokenizer.pad) | |
| prev_output_token_12[j].append(self.tokenizer.pad) | |
| prev_output_token_21[j].append(self.tokenizer.pad) | |
| prev_output_token_22[j].append(self.tokenizer.pad) | |
| delta_x = 0 | |
| delta_y = 0 | |
| delta_x1[j].append(delta_x) | |
| delta_y1[j].append(delta_y) | |
| delta_x2[j].append(1 - delta_x) | |
| delta_y2[j].append(1 - delta_y) | |
| i += 1 | |
| out = {"pred_logits": cls_output, "pred_coords": reg_output, "gen_out": gen_out} | |
| # hack implementation of room label prediction, not compatible with auxiliary loss | |
| if self.room_class_embed is not None: | |
| hs = torch.cat(output_hs_list, dim=1) | |
| outputs_room_class = self.room_class_embed(hs) | |
| out = { | |
| "pred_logits": cls_output, | |
| "pred_coords": reg_output, | |
| "pred_room_logits": outputs_room_class, | |
| "gen_out": gen_out, | |
| "anchors": query_embeds.detach(), | |
| } | |
| return out | |
| def _set_aux_loss(self, outputs_class, outputs_coord): | |
| return [{"pred_logits": a, "pred_coords": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])] | |
| def _setup_caches(self, max_bs, max_src_len): | |
| self.transformer._setup_caches( | |
| max_bs, | |
| self.seq_len, | |
| max_src_len, | |
| self.transformer.d_model, | |
| self.transformer.nhead, | |
| self.transformer.level_embed.dtype, | |
| device=self.transformer.level_embed.device, | |
| ) | |
| class SemHead(nn.Module): | |
| def __init__(self, hidden_dim, num_classes): | |
| super().__init__() | |
| self.shared_layer = nn.Linear(hidden_dim, hidden_dim) | |
| self.room_embed = nn.Linear(hidden_dim, num_classes - 2) | |
| self.num_classes = num_classes | |
| self.window_door_embed = nn.Linear(hidden_dim, 2) | |
| def forward(self, x): | |
| x = F.normalize(torch.relu(self.shared_layer(x)), p=2, dim=-1, eps=1e-12) | |
| room_out = self.room_embed(x) | |
| window_door_out = self.window_door_embed(x) | |
| out = torch.cat([room_out[:, :, :-1], window_door_out, room_out[:, :, -1:]], dim=-1) | |
| return out.contiguous() | |
| class SetCriterion(nn.Module): | |
| """This class computes the loss for multiple polygons. | |
| The process happens in two steps: | |
| 1) we compute hungarian assignment between ground truth polygons and the outputs of the model | |
| 2) we supervise each pair of matched ground-truth / prediction (supervise class and coords) | |
| """ | |
| def __init__( | |
| self, | |
| num_classes, | |
| semantic_classes, | |
| matcher, | |
| weight_dict, | |
| losses, | |
| label_smoothing=0.0, | |
| per_token_sem_loss=False, | |
| ): | |
| """Create the criterion. | |
| Parameters: | |
| num_classes: number of classes for corner validity (binary) | |
| semantic_classes: number of semantic classes for polygon (room type, door, window) | |
| matcher: module able to compute a matching between targets and proposals | |
| weight_dict: dict containing as key the names of the losses and as values their relative weight. | |
| losses: list of all the losses to be applied. See get_loss for list of available losses. | |
| """ | |
| super().__init__() | |
| self.num_classes = num_classes | |
| self.semantic_classes = semantic_classes | |
| self.matcher = matcher | |
| self.weight_dict = weight_dict | |
| self.losses = losses | |
| self.label_smoothing = label_smoothing | |
| self.per_token_sem_loss = per_token_sem_loss | |
| if "loss_raster" in self.weight_dict: | |
| self.raster_loss = MaskRasterizationLoss(None) | |
| def _update_ce_coeff(self, loss_ce_coeff): | |
| self.weight_dict["loss_ce"] = loss_ce_coeff | |
| def loss_labels(self, outputs, targets, indices): | |
| """Classification loss (NLL) | |
| targets dicts must contain the key "labels" | |
| """ | |
| assert "pred_logits" in outputs | |
| src_logits = outputs["pred_logits"] | |
| target_classes = targets["token_labels"].to(src_logits.device) | |
| mask = (target_classes != -1).bool() | |
| loss_ce = label_smoothed_nll_loss( | |
| src_logits[mask], target_classes[mask], epsilon=self.label_smoothing, reduction="mean" | |
| ) | |
| losses = {"loss_ce": loss_ce} | |
| if "pred_room_logits" in outputs: | |
| room_src_logits = outputs["pred_room_logits"] | |
| if not self.per_token_sem_loss: | |
| mask = target_classes == 3 # cls token | |
| room_target_classes = targets["target_polygon_labels"].to(room_src_logits.device) | |
| loss_ce_room = label_smoothed_nll_loss( | |
| room_src_logits[mask], | |
| room_target_classes[room_target_classes != -1], | |
| epsilon=self.label_smoothing, | |
| reduction="mean", | |
| ) | |
| else: | |
| room_target_classes = targets["target_polygon_labels"].to(room_src_logits.device) | |
| loss_ce_room = label_smoothed_nll_loss( | |
| room_src_logits[room_target_classes != -1], | |
| room_target_classes[room_target_classes != -1], | |
| epsilon=self.label_smoothing, | |
| reduction="mean", | |
| ) | |
| losses = {"loss_ce": loss_ce, "loss_ce_room": loss_ce_room} | |
| return losses | |
| def loss_cardinality(self, outputs, targets, indices): | |
| """Compute the cardinality error, ie the absolute error in the number of predicted non-empty corners | |
| This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients | |
| """ | |
| losses = {"cardinality_error": 0.0} | |
| return losses | |
| def _extract_polygons(self, sequence, token_labels): | |
| # sequence: [B, N, 2], token_labels: [B, N] | |
| B, N = token_labels.shape | |
| polygons = [] | |
| for b in range(B): | |
| labels = token_labels[b] # [N] | |
| coords = sequence[b] # [N, 2] | |
| # Find separator and EOS positions | |
| sep_eos_mask = (labels == 1) | (labels == 2) | |
| split_indices = torch.nonzero(sep_eos_mask, as_tuple=False).squeeze(-1) | |
| # Handle empty case | |
| if len(split_indices) == 0: | |
| # No separators found, treat entire sequence as one polygon | |
| corner_mask = labels == 0 | |
| if corner_mask.any(): | |
| polygons.append(coords[corner_mask]) | |
| continue | |
| # Create start and end indices | |
| device = labels.device | |
| starts = torch.cat([torch.tensor([0], device=device), split_indices[:-1] + 1]) | |
| ends = split_indices | |
| # Extract polygons between separators | |
| for s, e in zip(starts, ends): | |
| if s < e: # Valid range | |
| segment_labels = labels[s:e] | |
| segment_coords = coords[s:e] | |
| corner_mask = segment_labels == 0 | |
| if corner_mask.any(): | |
| polygons.append(segment_coords[corner_mask]) | |
| return polygons | |
| def loss_polys(self, outputs, targets, indices): | |
| """Compute the losses related to the polygons: | |
| 1. L1 loss for polygon coordinates | |
| 2. Dice loss for polygon rasterizated binary masks | |
| """ | |
| assert "pred_coords" in outputs | |
| src_poly = outputs["pred_coords"] | |
| device = src_poly.device | |
| token_labels = targets["token_labels"].to(device) | |
| mask = (token_labels == 0).bool() | |
| target_polys = targets["target_seq"].to(device) | |
| loss_coords = F.l1_loss(src_poly[mask], target_polys[mask]) | |
| losses = {} | |
| losses["loss_coords"] = loss_coords | |
| # omit the rasterization loss for semantically-rich floorplan | |
| if self.weight_dict.get("loss_raster", 0) > 0: | |
| pred_poly_list = self._extract_polygons(src_poly, token_labels) | |
| target_poly_list = self._extract_polygons(target_polys, token_labels) | |
| loss_raster_mask = self.raster_loss( | |
| pred_poly_list, | |
| target_poly_list, | |
| [len(x) for x in target_poly_list], | |
| ) | |
| losses["loss_raster"] = loss_raster_mask | |
| return losses | |
| def _get_src_permutation_idx(self, indices): | |
| # permute predictions following indices | |
| batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) | |
| src_idx = torch.cat([src for (src, _) in indices]) | |
| return batch_idx, src_idx | |
| def _get_tgt_permutation_idx(self, indices): | |
| # permute targets following indices | |
| batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) | |
| tgt_idx = torch.cat([tgt for (_, tgt) in indices]) | |
| return batch_idx, tgt_idx | |
| def get_loss(self, loss, outputs, targets, indices, **kwargs): | |
| loss_map = {"labels": self.loss_labels, "cardinality": self.loss_cardinality, "polys": self.loss_polys} | |
| assert loss in loss_map, f"do you really want to compute {loss} loss?" | |
| return loss_map[loss](outputs, targets, indices, **kwargs) | |
| def forward(self, outputs, targets): | |
| """This performs the loss computation. | |
| Parameters: | |
| outputs: dict of tensors, see the output specification of the model for the format | |
| targets: list of dicts, such that len(targets) == batch_size. | |
| The expected keys in each dict depends on the losses applied, see each loss' doc | |
| """ | |
| indices = None | |
| # Compute all the requested losses | |
| losses = {} | |
| for loss in self.losses: | |
| kwargs = {} | |
| losses.update(self.get_loss(loss, outputs, targets, indices, **kwargs)) | |
| # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. | |
| if "aux_outputs" in outputs: | |
| for i, aux_outputs in enumerate(outputs["aux_outputs"]): | |
| for loss in self.losses: | |
| l_dict = self.get_loss(loss, aux_outputs, targets, indices) | |
| l_dict = {k + f"_{i}": v for k, v in l_dict.items()} | |
| losses.update(l_dict) | |
| if "enc_outputs" in outputs: | |
| enc_outputs = outputs["enc_outputs"] | |
| indices = self.matcher(enc_outputs, targets) | |
| for loss in self.losses: | |
| l_dict = self.get_loss(loss, enc_outputs, targets, indices) | |
| l_dict = {k + "_enc": v for k, v in l_dict.items()} | |
| losses.update(l_dict) | |
| return losses | |
| class MLP(nn.Module): | |
| """Very simple multi-layer perceptron (also called FFN)""" | |
| def __init__(self, input_dim, hidden_dim, output_dim, num_layers): | |
| super().__init__() | |
| self.num_layers = num_layers | |
| h = [hidden_dim] * (num_layers - 1) | |
| self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) | |
| def forward(self, x): | |
| for i, layer in enumerate(self.layers): | |
| x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) | |
| return x | |
| def build(args, train=True, tokenizer=None): | |
| num_classes = 3 if not args.add_cls_token else 4 # <coord> <sep> <eos> <cls> | |
| if tokenizer is not None: | |
| pad_idx = tokenizer.pad | |
| backbone = build_backbone(args) | |
| transformer = build_deforamble_transformer(args, pad_idx=pad_idx) | |
| model = Raster2Seq( | |
| backbone, | |
| transformer, | |
| num_classes=num_classes, | |
| num_queries=args.num_queries, | |
| num_polys=args.num_polys, | |
| num_feature_levels=args.num_feature_levels, | |
| aux_loss=args.aux_loss, | |
| with_poly_refine=args.with_poly_refine, | |
| masked_attn=args.masked_attn, | |
| semantic_classes=args.semantic_classes, | |
| seq_len=args.seq_len, | |
| tokenizer=tokenizer, | |
| use_anchor=args.use_anchor, | |
| patch_size=[1, 2][args.image_size == 512], # 1 for 256x256, 2 for 512x512 | |
| freeze_anchor=getattr(args, "freeze_anchor", False), | |
| inject_cls_embed=getattr(args, "inject_cls_embed", False), | |
| ) | |
| if not train: | |
| return model | |
| device = torch.device(args.device) | |
| matcher = None # build_matcher(args) | |
| weight_dict = { | |
| "loss_ce": args.cls_loss_coef, | |
| "loss_ce_room": args.room_cls_loss_coef, | |
| "loss_coords": args.coords_loss_coef, | |
| } | |
| if args.raster_loss_coef > 0: | |
| weight_dict["loss_raster"] = args.raster_loss_coef | |
| weight_dict["loss_dir"] = 1 | |
| enc_weight_dict = {} | |
| enc_weight_dict.update({k + "_enc": v for k, v in weight_dict.items()}) | |
| weight_dict.update(enc_weight_dict) | |
| # TODO this is a hack | |
| if args.aux_loss: | |
| aux_weight_dict = {} | |
| for i in range(args.dec_layers - 1): | |
| aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) | |
| aux_weight_dict.update({k + "_enc": v for k, v in weight_dict.items()}) | |
| weight_dict.update(aux_weight_dict) | |
| losses = ["labels", "polys", "cardinality"] | |
| # num_classes, matcher, weight_dict, losses | |
| criterion = SetCriterion( | |
| num_classes, | |
| args.semantic_classes, | |
| matcher, | |
| weight_dict, | |
| losses, | |
| label_smoothing=args.label_smoothing, | |
| per_token_sem_loss=args.per_token_sem_loss, | |
| ) | |
| criterion.to(device) | |
| return model, criterion | |