Fix bug; tested locally.
Changed files:
- utils/model.py +211 -262
- utils/predict.py +1 -1
utils/model.py CHANGED
@@ -157,7 +157,7 @@ class OwlViTForClassification(nn.Module):
     config_class = OwlViTConfig

     def __init__(self, owlvit_det_model, num_classes, weight_dict, device, freeze_box_heads=False, train_box_heads_only=False, network_type=None, logits_from_teacher=False, finetuned: bool = False, custom_box_head: bool = False):
-        super(OwlViTForClassification, self).__init__()
+        super().__init__()

         self.config = owlvit_det_model.config
         self.num_classes = num_classes
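(Inside the class body, the zero-argument `super().__init__()` is equivalent to the older two-argument form, so this hunk is a modernization with no behavior change.)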
@@ -202,12 +202,12 @@ class OwlViTForClassification(nn.Module):
         losses += ["boxes"] if weight_dict["loss_bbox"] > 0 else []
         losses += ["labels"] if weight_dict["loss_ce"] > 0 else []

-        self.criterion = DetrLoss(
-            matcher=None,
-            num_parts=self.num_parts,
-            eos_coef=0.1,  # Following facebook/detr-resnet-50
-            losses=losses,
-        )
+        # self.criterion = DetrLoss(
+        #     matcher=None,
+        #     num_parts=self.num_parts,
+        #     eos_coef=0.1,  # Following facebook/detr-resnet-50
+        #     losses=losses,
+        # )

         self.freeze_parameters(freeze_box_heads, train_box_heads_only)
         del owlvit_det_model
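Disabling the `DetrLoss` criterion here pairs with the forward-pass hunk further down: once `self.criterion` is never constructed, every loss computation that referenced it has to go too, otherwise `forward()` would raise an `AttributeError`. The class definition itself is commented out at the bottom of the file rather than deleted.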
@@ -417,22 +417,7 @@ class OwlViTForClassification(nn.Module):
             topk_scores, topk_idxs = torch.topk(teacher_boxes_logits, k=1, dim=1)

         else:
-
-            print(f"text_inputs_parts - input_ids: {text_inputs_parts['input_ids'].shape}. attention_mask : {text_inputs_parts['attention_mask'].shape}")
-            seq_length = text_inputs_parts['input_ids'].shape[-1]
-            position_ids = self.owlvit.text_model.embeddings.position_ids[:, :seq_length]
-            txt_embeds = self.owlvit.text_model.embeddings.token_embedding(text_inputs_parts['input_ids'])
-            print(f"position_embedding: {self.owlvit.text_model.embeddings.position_embedding(position_ids).shape}")
-            print(f"text_embeds: {txt_embeds.shape}")
-
-            device_ = txt_embeds.device
-            position_ids = position_ids.to(device_)
-            txt_embeds_size_0 = text_embeds.size(0)
-            position_embedding = position_ids.cpu().repeat(txt_embeds_size_0, 1, 1)
-            text_inputs_parts["position_ids"] = position_ids
-            print(f"position_embedding : {position_embedding.shape}")
-            print(f"pos + emb: {(txt_embeds.cpu() + position_embedding).shape}")
-            text_embeds_parts = self.owlvit.text_model.get_text_features(**text_inputs_parts)
+            text_embeds_parts = self.owlvit.get_text_features(**text_inputs_parts)

         # # Embed images and text queries
         query_mask, text_embeds_parts = self._get_text_query_mask(text_inputs_parts, text_embeds_parts, batch_size)
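This hunk is the actual bug fix: the removed branch rebuilt position embeddings by hand and then called `get_text_features` on the inner text tower (`self.owlvit.text_model`), which does not define that method in stock `transformers`; the replacement calls it on the top-level OWL-ViT model, which builds position ids internally. A minimal standalone sketch of the corrected call path (the checkpoint name and query strings are illustrative):

import torch
from transformers import OwlViTProcessor, OwlViTModel

processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
model = OwlViTModel.from_pretrained("google/owlvit-base-patch32")

# The processor produces the same kind of dict as text_inputs_parts:
# input_ids plus attention_mask, one row per part descriptor.
text_inputs = processor(text=["bird wing", "bird beak"], return_tensors="pt")

with torch.no_grad():
    # Pooled, projected text embeddings; position ids are handled
    # inside the call, so no manual embedding arithmetic is needed.
    text_embeds = model.get_text_features(**text_inputs)

print(text_embeds.shape)  # (num_queries, projection_dim)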
@@ -460,46 +445,10 @@ class OwlViTForClassification(nn.Module):
         outputs_loss["logits"] = pred_logits_parts
         outputs_loss["pred_boxes"] = pred_boxes

-        # Compute box + class losses
-        loss_dict = self.criterion(outputs_loss, targets, mapping_indices)
-
-        # Compute symmetric loss to get rid of the teacher model
-        logits_per_image = torch.softmax(pred_logits_parts, dim=1)
-        logits_per_text = torch.softmax(pred_logits_parts, dim=-1)
-
-        # For getting rid of the teacher model
-        if self.weight_dict["loss_sym_box_label"] > 0:
-            sym_loss_box_label = self.loss_symmetric(logits_per_image, logits_per_text, teacher_boxes_logits)
-            loss_dict["loss_sym_box_label"] = sym_loss_box_label
-        # ----------------------------------------------------------------------------------------
-
-        #DEBUG:
-        print(f"im_features size: {image_feats.shape}, text_embeds size: {text_embeds.shape}")
-        print(f"im_features sum: {image_feats.sum().item()}, text_embeds sum: {text_embeds.sum().item()}")
         # Predict image-level classes (batch_size, num_patches, num_queries)
         image_text_logits, pred_logits, part_logits = self.cls_head(image_feats, text_embeds, topk_idxs)
-
-
-        print(f"image_text_logits sum: {image_text_logits.sum().item()}")
-
-        if self.weight_dict["loss_xclip"] > 0:
-            targets_cls = torch.tensor([target["targets_cls"] for target in targets]).unsqueeze(1).to(self.device)
-            if self.network_type == "classification":
-                one_hot = torch.zeros_like(pred_logits).scatter(1, targets_cls, 1).to(self.device)
-                cls_loss = self.ce_loss(pred_logits, one_hot)
-                loss_dict["loss_xclip"] = cls_loss
-            else:
-                # TODO: Need a linear classifier for this approach
-                # Compute symmetric loss for part-descriptor contrastive learning
-                logits_per_image = torch.softmax(image_text_logits, dim=0)
-                logits_per_text = torch.softmax(image_text_logits, dim=-1)
-                sym_loss = self.loss_symmetric(logits_per_image, logits_per_text, targets_cls)
-                loss_dict["loss_xclip"] = sym_loss
-
-        #DEBUG:
-        print(f"pred_logits size: {part_logits.shape}, pred_logits size: {part_logits.shape}")
-        print(f"part_logits sum: {pred_logits.sum().item()}, part_logits sum: {pred_logits.sum().item()}")
-        return pred_logits, part_logits, loss_dict
+
+        return pred_logits, part_logits

     def loss_symmetric(self, text_logits: torch.Tensor, image_logits: torch.Tensor, targets: torch.Tensor, box_labels: torch.Tensor = None) -> torch.Tensor:
         # text/image logits (batch_size*num_boxes, num_classes*num_descs): The logits that softmax over text descriptors or boxes
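With the criterion call, the symmetric losses, and the debug prints stripped out, `forward()` is now inference-only and returns a two-tuple. Every caller that previously unpacked `pred_logits, part_logits, loss_dict` therefore has to drop the third value; that is exactly the one-line companion fix in utils/predict.py at the end of this commit.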
@@ -537,204 +486,204 @@ class OwlViTForClassification(nn.Module):

         return sym_loss

-class DetrLoss(nn.Module):
-    ... (lines 541-740: class body, collapsed in the diff view; identical to the commented-out copy added below, without the leading "# ")
+# class DetrLoss(nn.Module):
+#     """
+#     This class computes the losses for DetrForObjectDetection/DetrForSegmentation. The process happens in two steps: 1)
+#     we compute hungarian assignment between ground truth boxes and the outputs of the model 2) we supervise each pair
+#     of matched ground-truth / prediction (supervise class and box).
+
+#     A note on the `num_classes` argument (copied from original repo in detr.py): "the naming of the `num_classes`
+#     parameter of the criterion is somewhat misleading. It indeed corresponds to `max_obj_id` + 1, where `max_obj_id` is
+#     the maximum id for a class in your dataset. For example, COCO has a `max_obj_id` of 90, so we pass `num_classes` to
+#     be 91. As another example, for a dataset that has a single class with `id` 1, you should pass `num_classes` to be 2
+#     (`max_obj_id` + 1). For more details on this, check the following discussion
+#     https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223"
+
+
+#     Args:
+#         matcher (`DetrHungarianMatcher`):
+#             Module able to compute a matching between targets and proposals.
+#         num_parts (`int`):
+#             Number of object categories, omitting the special no-object category.
+#         eos_coef (`float`):
+#             Relative classification weight applied to the no-object category.
+#         losses (`List[str]`):
+#             List of all the losses to be applied. See `get_loss` for a list of all available losses.
+#     """
+
+#     def __init__(self, matcher, num_parts, eos_coef, losses):
+#         super().__init__()
+#         self.matcher = matcher
+#         self.num_parts = num_parts
+#         self.eos_coef = eos_coef
+#         self.losses = losses
+
+#         # empty_weight = torch.ones(self.num_parts + 1)
+#         empty_weight = torch.ones(self.num_parts)
+#         empty_weight[-1] = self.eos_coef
+#         self.register_buffer("empty_weight", empty_weight)
+
+#     # removed logging parameter, which was part of the original implementation
+#     def loss_labels(self, outputs, targets, indices, num_boxes):
+#         """
+#         Classification loss (NLL) targets dicts must contain the key "class_labels" containing a tensor of dim
+#         [nb_target_boxes]
+#         """
+#         if "logits" not in outputs:
+#             raise KeyError("No logits were found in the outputs")
+#         source_logits = outputs["logits"]
+
+#         idx = self._get_source_permutation_idx(indices)
+#         # target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)])
+#         # target_classes = torch.full(source_logits.shape[:2], self.num_parts, dtype=torch.int64, device=source_logits.device)
+#         # target_classes[idx] = target_classes_o
+
+#         source_logits = source_logits[idx].view(len(indices), -1, self.num_parts)
+#         target_classes = torch.stack([t["class_labels"][J] for t, (_, J) in zip(targets, indices)], dim=0)
+
+#         loss_ce = nn.functional.cross_entropy(source_logits.transpose(1, 2), target_classes, self.empty_weight)
+#         losses = {"loss_ce": loss_ce}
+
+#         return losses
+
+#     @torch.no_grad()
+#     def loss_cardinality(self, outputs, targets, indices, num_boxes):
+#         """
+#         Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes.
+
+#         This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients.
+#         """
+#         logits = outputs["logits"]
+#         device = logits.device
+#         target_lengths = torch.as_tensor([len(v["class_labels"]) for v in targets], device=device)
+#         # Count the number of predictions that are NOT "no-object" (which is the last class)
+#         card_pred = (logits.argmax(-1) != logits.shape[-1] - 1).sum(1)
+#         card_err = nn.functional.l1_loss(card_pred.float(), target_lengths.float())
+#         losses = {"cardinality_error": card_err}
+#         return losses
+
+#     def loss_boxes(self, outputs, targets, indices, num_boxes):
+#         """
+#         Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss.
+
+#         Targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]. The target boxes
+#         are expected in format (center_x, center_y, w, h), normalized by the image size.
+#         """
+#         if "pred_boxes" not in outputs:
+#             raise KeyError("No predicted boxes found in outputs")
+
+#         idx = self._get_source_permutation_idx(indices)
+#         source_boxes = outputs["pred_boxes"][idx]
+#         target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
+
+#         losses = {}
+
+#         loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none")
+#         losses["loss_bbox"] = loss_bbox.sum() / num_boxes
+
+#         loss_giou = 1 - torch.diag(generalized_box_iou(center_to_corners_format(source_boxes), center_to_corners_format(target_boxes)))
+#         losses["loss_giou"] = loss_giou.sum() / num_boxes
+
+#         return losses
+
+#     def loss_masks(self, outputs, targets, indices, num_boxes):
+#         """
+#         Compute the losses related to the masks: the focal loss and the dice loss.
+
+#         Targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w].
+#         """
+#         if "pred_masks" not in outputs:
+#             raise KeyError("No predicted masks found in outputs")
+
+#         source_idx = self._get_source_permutation_idx(indices)
+#         target_idx = self._get_target_permutation_idx(indices)
+#         source_masks = outputs["pred_masks"]
+#         source_masks = source_masks[source_idx]
+#         masks = [t["masks"] for t in targets]
+
+#         # TODO use valid to mask invalid areas due to padding in loss
+#         target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
+#         target_masks = target_masks.to(source_masks)
+#         target_masks = target_masks[target_idx]
+
+#         # upsample predictions to the target size
+#         source_masks = nn.functional.interpolate(
+#             source_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
+#         )
+#         source_masks = source_masks[:, 0].flatten(1)
+
+#         target_masks = target_masks.flatten(1)
+#         target_masks = target_masks.view(source_masks.shape)
+#         losses = {
+#             "loss_mask": sigmoid_focal_loss(source_masks, target_masks, num_boxes),
+#             "loss_dice": dice_loss(source_masks, target_masks, num_boxes),
+#         }
+#         return losses
+
+#     def _get_source_permutation_idx(self, indices):
+#         # permute predictions following indices
+#         batch_idx = torch.cat([torch.full_like(source, i) for i, (source, _) in enumerate(indices)])
+#         source_idx = torch.cat([source for (source, _) in indices])
+#         return batch_idx, source_idx
+
+#     def _get_target_permutation_idx(self, indices):
+#         # permute targets following indices
+#         batch_idx = torch.cat([torch.full_like(target, i) for i, (_, target) in enumerate(indices)])
+#         target_idx = torch.cat([target for (_, target) in indices])
+#         return batch_idx, target_idx
+
+#     def get_loss(self, loss, outputs, targets, indices, num_boxes):
+#         loss_map = {
+#             "labels": self.loss_labels,
+#             "cardinality": self.loss_cardinality,
+#             "boxes": self.loss_boxes,
+#             "masks": self.loss_masks,
+#         }
+#         if loss not in loss_map:
+#             raise ValueError(f"Loss {loss} not supported")
+#         return loss_map[loss](outputs, targets, indices, num_boxes)
+
+#     def forward(self, outputs, targets, indices):
+#         """
+#         This performs the loss computation.
+
+#         Args:
+#             outputs (`dict`, *optional*):
+#                 Dictionary of tensors, see the output specification of the model for the format.
+#             targets (`List[dict]`, *optional*):
+#                 List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
+#                 losses applied, see each loss' doc.
+#         """
+#         outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}
+
+#         # ThangPM: Do NOT use bipartite matching --> Use the boxes selected by argmax for computing symmetric loss
+#         # Retrieve the matching between the outputs of the last layer and the targets
+#         # indices = self.matcher(outputs_without_aux, targets)
+
+#         # Compute the average number of target boxes across all nodes, for normalization purposes
+#         num_boxes = sum(len(t["class_labels"]) for t in targets)
+#         num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
+#         # (Niels): comment out function below, distributed training to be added
+#         # if is_dist_avail_and_initialized():
+#         #     torch.distributed.all_reduce(num_boxes)
+#         # (Niels) in original implementation, num_boxes is divided by get_world_size()
+#         num_boxes = torch.clamp(num_boxes, min=1).item()
+
+#         # Compute all the requested losses
+#         losses = {}
+#         for loss in self.losses:
+#             losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))
+
+#         # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
+#         if "auxiliary_outputs" in outputs:
+#             for i, auxiliary_outputs in enumerate(outputs["auxiliary_outputs"]):
+#                 # indices = self.matcher(auxiliary_outputs, targets)
+#                 for loss in self.losses:
+#                     if loss == "masks":
+#                         # Intermediate masks losses are too costly to compute, we ignore them.
+#                         continue
+#                     l_dict = self.get_loss(loss, auxiliary_outputs, targets, indices, num_boxes)
+#                     l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
+#                     losses.update(l_dict)
+
+#         return losses
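For reference while the class is parked as comments: the box term that `loss_boxes` computes is a normalized L1 plus a GIoU penalty over matched box pairs. A self-contained sketch using `torchvision.ops` in place of the `transformers` helpers (`center_to_corners_format`, `generalized_box_iou`) that the commented code relies on:

import torch
from torchvision.ops import box_convert, generalized_box_iou

def detr_box_loss(src_boxes: torch.Tensor, tgt_boxes: torch.Tensor, num_boxes: int):
    # src_boxes/tgt_boxes: matched pairs in normalized (cx, cy, w, h) format.
    loss_bbox = torch.nn.functional.l1_loss(src_boxes, tgt_boxes, reduction="none")
    giou = generalized_box_iou(
        box_convert(src_boxes, in_fmt="cxcywh", out_fmt="xyxy"),
        box_convert(tgt_boxes, in_fmt="cxcywh", out_fmt="xyxy"),
    )
    # Diagonal = GIoU of each pair with its own match.
    loss_giou = 1 - torch.diag(giou)
    return loss_bbox.sum() / num_boxes, loss_giou.sum() / num_boxes

# Example with two matched boxes:
src = torch.tensor([[0.5, 0.5, 0.2, 0.2], [0.3, 0.3, 0.1, 0.1]])
tgt = torch.tensor([[0.5, 0.5, 0.25, 0.25], [0.3, 0.3, 0.1, 0.1]])
print(detr_box_loss(src, tgt, num_boxes=2))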
utils/predict.py CHANGED
@@ -112,7 +112,7 @@ def xclip_pred(new_desc: dict,
    image_input = owlvit_processor(images=image, return_tensors='pt').to(device)
    image_embeds, _ = model.image_embedder(pixel_values = image_input['pixel_values'])

-    pred_logits, part_logits, loss_dict = model(image_embeds, part_embeds, query_embeds, None)
+    pred_logits, part_logits = model(image_embeds, part_embeds, query_embeds, None)

    b, c, n = part_logits.shape
    mask = torch.tensor(desc_mask, dtype=float).unsqueeze(0).unsqueeze(0).repeat(b, c, 1).to(device)
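Downstream of the fixed call, `xclip_pred` masks the per-descriptor logits. A sketch of how that step plausibly behaves, with hypothetical shapes matching the `(b, c, n)` unpack in the hunk (`desc_mask` assumed to be a 0/1 list over descriptor slots; the averaging at the end is illustrative, not taken from the source):

import torch

b, c, n = 2, 200, 12                    # batch, classes, descriptor slots (illustrative)
part_logits = torch.randn(b, c, n)
desc_mask = [1] * 10 + [0] * 2          # hypothetical: 10 valid descriptors, 2 padding slots

# Same construction as the hunk: broadcast the per-slot mask over batch and classes.
mask = torch.tensor(desc_mask, dtype=torch.float).unsqueeze(0).unsqueeze(0).repeat(b, c, 1)

# One plausible consumption: zero out padded slots, then average over valid ones.
masked_mean = (part_logits * mask).sum(dim=-1) / mask.sum(dim=-1).clamp(min=1)
print(masked_mean.shape)  # (b, c)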