Update modelling_magiv2.py
Browse files- modelling_magiv2.py +149 -2
modelling_magiv2.py
CHANGED
|
@@ -2,7 +2,6 @@ from transformers import PreTrainedModel, VisionEncoderDecoderModel, ViTMAEModel
|
|
| 2 |
from transformers.models.conditional_detr.modeling_conditional_detr import (
|
| 3 |
ConditionalDetrMLPPredictionHead,
|
| 4 |
ConditionalDetrModelOutput,
|
| 5 |
-
ConditionalDetrHungarianMatcher,
|
| 6 |
inverse_sigmoid,
|
| 7 |
)
|
| 8 |
from .configuration_magiv2 import Magiv2Config
|
|
@@ -17,6 +16,7 @@ from .utils import UnionFind, sort_panels, sort_text_boxes_in_reading_order
|
|
| 17 |
import pulp
|
| 18 |
import scipy
|
| 19 |
import numpy as np
|
|
|
|
| 20 |
|
| 21 |
class Magiv2Model(PreTrainedModel):
|
| 22 |
config_class = Magiv2Config
|
|
@@ -611,4 +611,151 @@ class Magiv2Model(PreTrainedModel):
|
|
| 611 |
if apply_sigmoid:
|
| 612 |
text_tail_affinities = text_tail_affinities.sigmoid()
|
| 613 |
affinity_matrices.append(text_tail_affinities)
|
| 614 |
-
return affinity_matrices
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
from transformers.models.conditional_detr.modeling_conditional_detr import (
|
| 3 |
ConditionalDetrMLPPredictionHead,
|
| 4 |
ConditionalDetrModelOutput,
|
|
|
|
| 5 |
inverse_sigmoid,
|
| 6 |
)
|
| 7 |
from .configuration_magiv2 import Magiv2Config
|
|
|
|
| 16 |
import pulp
|
| 17 |
import scipy
|
| 18 |
import numpy as np
|
| 19 |
+
from scipy.optimize import linear_sum_assignment
|
| 20 |
|
| 21 |
class Magiv2Model(PreTrainedModel):
|
| 22 |
config_class = Magiv2Config
|
|
|
|
| 611 |
if apply_sigmoid:
|
| 612 |
text_tail_affinities = text_tail_affinities.sigmoid()
|
| 613 |
affinity_matrices.append(text_tail_affinities)
|
| 614 |
+
return affinity_matrices
|
| 615 |
+
|
| 616 |
+
# Copied from transformers.models.detr.modeling_detr._upcast
|
| 617 |
+
def _upcast(t):
|
| 618 |
+
# Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
|
| 619 |
+
if t.is_floating_point():
|
| 620 |
+
return t if t.dtype in (torch.float32, torch.float64) else t.float()
|
| 621 |
+
else:
|
| 622 |
+
return t if t.dtype in (torch.int32, torch.int64) else t.int()
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
# Copied from transformers.models.detr.modeling_detr.box_area
|
| 626 |
+
def box_area(boxes):
|
| 627 |
+
"""
|
| 628 |
+
Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates.
|
| 629 |
+
|
| 630 |
+
Args:
|
| 631 |
+
boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`):
|
| 632 |
+
Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1
|
| 633 |
+
< x2` and `0 <= y1 < y2`.
|
| 634 |
+
|
| 635 |
+
Returns:
|
| 636 |
+
`torch.FloatTensor`: a tensor containing the area for each box.
|
| 637 |
+
"""
|
| 638 |
+
boxes = _upcast(boxes)
|
| 639 |
+
return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
|
| 640 |
+
|
| 641 |
+
|
| 642 |
+
# Copied from transformers.models.detr.modeling_detr.box_iou
|
| 643 |
+
def box_iou(boxes1, boxes2):
|
| 644 |
+
area1 = box_area(boxes1)
|
| 645 |
+
area2 = box_area(boxes2)
|
| 646 |
+
|
| 647 |
+
left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
|
| 648 |
+
right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
|
| 649 |
+
|
| 650 |
+
width_height = (right_bottom - left_top).clamp(min=0) # [N,M,2]
|
| 651 |
+
inter = width_height[:, :, 0] * width_height[:, :, 1] # [N,M]
|
| 652 |
+
|
| 653 |
+
union = area1[:, None] + area2 - inter
|
| 654 |
+
|
| 655 |
+
iou = inter / union
|
| 656 |
+
return iou, union
|
| 657 |
+
|
| 658 |
+
|
| 659 |
+
# Copied from transformers.models.detr.modeling_detr.generalized_box_iou
|
| 660 |
+
def generalized_box_iou(boxes1, boxes2):
|
| 661 |
+
"""
|
| 662 |
+
Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format.
|
| 663 |
+
|
| 664 |
+
Returns:
|
| 665 |
+
`torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2)
|
| 666 |
+
"""
|
| 667 |
+
# degenerate boxes gives inf / nan results
|
| 668 |
+
# so do an early check
|
| 669 |
+
if not (boxes1[:, 2:] >= boxes1[:, :2]).all():
|
| 670 |
+
raise ValueError(f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}")
|
| 671 |
+
if not (boxes2[:, 2:] >= boxes2[:, :2]).all():
|
| 672 |
+
raise ValueError(f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}")
|
| 673 |
+
iou, union = box_iou(boxes1, boxes2)
|
| 674 |
+
|
| 675 |
+
top_left = torch.min(boxes1[:, None, :2], boxes2[:, :2])
|
| 676 |
+
bottom_right = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
|
| 677 |
+
|
| 678 |
+
width_height = (bottom_right - top_left).clamp(min=0) # [N,M,2]
|
| 679 |
+
area = width_height[:, :, 0] * width_height[:, :, 1]
|
| 680 |
+
|
| 681 |
+
return iou - (area - union) / area
|
| 682 |
+
|
| 683 |
+
|
| 684 |
+
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrHungarianMatcher with DeformableDetr->ConditionalDetr
|
| 685 |
+
class ConditionalDetrHungarianMatcher(nn.Module):
|
| 686 |
+
"""
|
| 687 |
+
This class computes an assignment between the targets and the predictions of the network.
|
| 688 |
+
|
| 689 |
+
For efficiency reasons, the targets don't include the no_object. Because of this, in general, there are more
|
| 690 |
+
predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, while the others are
|
| 691 |
+
un-matched (and thus treated as non-objects).
|
| 692 |
+
|
| 693 |
+
Args:
|
| 694 |
+
class_cost:
|
| 695 |
+
The relative weight of the classification error in the matching cost.
|
| 696 |
+
bbox_cost:
|
| 697 |
+
The relative weight of the L1 error of the bounding box coordinates in the matching cost.
|
| 698 |
+
giou_cost:
|
| 699 |
+
The relative weight of the giou loss of the bounding box in the matching cost.
|
| 700 |
+
"""
|
| 701 |
+
|
| 702 |
+
def __init__(self, class_cost: float = 1, bbox_cost: float = 1, giou_cost: float = 1):
|
| 703 |
+
super().__init__()
|
| 704 |
+
|
| 705 |
+
self.class_cost = class_cost
|
| 706 |
+
self.bbox_cost = bbox_cost
|
| 707 |
+
self.giou_cost = giou_cost
|
| 708 |
+
if class_cost == 0 and bbox_cost == 0 and giou_cost == 0:
|
| 709 |
+
raise ValueError("All costs of the Matcher can't be 0")
|
| 710 |
+
|
| 711 |
+
@torch.no_grad()
|
| 712 |
+
def forward(self, outputs, targets):
|
| 713 |
+
"""
|
| 714 |
+
Args:
|
| 715 |
+
outputs (`dict`):
|
| 716 |
+
A dictionary that contains at least these entries:
|
| 717 |
+
* "logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
|
| 718 |
+
* "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates.
|
| 719 |
+
targets (`List[dict]`):
|
| 720 |
+
A list of targets (len(targets) = batch_size), where each target is a dict containing:
|
| 721 |
+
* "class_labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of
|
| 722 |
+
ground-truth
|
| 723 |
+
objects in the target) containing the class labels
|
| 724 |
+
* "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates.
|
| 725 |
+
|
| 726 |
+
Returns:
|
| 727 |
+
`List[Tuple]`: A list of size `batch_size`, containing tuples of (index_i, index_j) where:
|
| 728 |
+
- index_i is the indices of the selected predictions (in order)
|
| 729 |
+
- index_j is the indices of the corresponding selected targets (in order)
|
| 730 |
+
For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
|
| 731 |
+
"""
|
| 732 |
+
batch_size, num_queries = outputs["logits"].shape[:2]
|
| 733 |
+
|
| 734 |
+
# We flatten to compute the cost matrices in a batch
|
| 735 |
+
out_prob = outputs["logits"].flatten(0, 1).sigmoid() # [batch_size * num_queries, num_classes]
|
| 736 |
+
out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4]
|
| 737 |
+
|
| 738 |
+
# Also concat the target labels and boxes
|
| 739 |
+
target_ids = torch.cat([v["class_labels"] for v in targets])
|
| 740 |
+
target_bbox = torch.cat([v["boxes"] for v in targets])
|
| 741 |
+
|
| 742 |
+
# Compute the classification cost.
|
| 743 |
+
alpha = 0.25
|
| 744 |
+
gamma = 2.0
|
| 745 |
+
neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log())
|
| 746 |
+
pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
|
| 747 |
+
class_cost = pos_cost_class[:, target_ids] - neg_cost_class[:, target_ids]
|
| 748 |
+
|
| 749 |
+
# Compute the L1 cost between boxes
|
| 750 |
+
bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)
|
| 751 |
+
|
| 752 |
+
# Compute the giou cost between boxes
|
| 753 |
+
giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox))
|
| 754 |
+
|
| 755 |
+
# Final cost matrix
|
| 756 |
+
cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost
|
| 757 |
+
cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu()
|
| 758 |
+
|
| 759 |
+
sizes = [len(v["boxes"]) for v in targets]
|
| 760 |
+
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(sizes, -1))]
|
| 761 |
+
return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
|