Mateusz Mróz committed
Commit cd77b9d
1 Parent(s): b3f1331

Implement Magiv2Model with detection, OCR, and character association capabilities


- Added Magiv2Model class inheriting from PreTrainedModel.
- Integrated VisionEncoderDecoderModel for OCR and ViTMAEModel for crop embeddings.
- Implemented ConditionalDetrModel for object detection with associated prediction heads.
- Developed methods for chapter-wide predictions, character name assignments, and affinity matrix calculations.
- Included utility functions for bounding box operations and Hungarian matching for object assignments.
- Added support for processing images in batches and handling various detection thresholds.
- Implemented visualization and prediction methods for single images.
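
A minimal usage sketch of the API this commit adds. This is a sketch under assumptions: the Hub repo id and image file names below are placeholders, and loading via AutoModel with trust_remote_code is assumed; the method name, its defaults, and the result keys are taken from modelling_magiv2_pre.py below.

import numpy as np
from PIL import Image
from transformers import AutoModel

# Hypothetical repo id - substitute the Hub repo that actually hosts this code.
model = AutoModel.from_pretrained("user/magiv2", trust_remote_code=True).eval()

def load_page(path):
    # The processors assert np.ndarray inputs, so decode pages to arrays first.
    return np.array(Image.open(path).convert("RGB"))

pages = [load_page(p) for p in ["page_01.png", "page_02.png"]]
character_bank = {
    "images": [load_page("reference_character.png")],  # one reference crop per known character
    "names": ["Protagonist"],
}
per_page = model.do_chapter_wide_prediction(pages, character_bank, eta=0.75, batch_size=8, do_ocr=True)
print(per_page[0]["character_names"], per_page[0]["ocr"])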

config.json CHANGED
@@ -487,4 +487,4 @@
   "ocr_pretrained_processor_path": "microsoft/trocr-base-printed",
   "torch_dtype": "float32",
   "transformers_version": "4.34.0.dev0"
-}
+}
configuration_magiv2.py CHANGED
@@ -1,38 +1,131 @@
 from transformers import PretrainedConfig, VisionEncoderDecoderConfig
-from typing import List
+from typing import Any, Optional
 
 
 class Magiv2Config(PretrainedConfig):
-    model_type = "magiv2"
+    """
+    Configuration class for the Magiv2 model.
+
+    Magiv2Config inherits from the transformers PretrainedConfig and defines the
+    complete configuration for a vision model composed of three main components:
+    - an object detection model (detection)
+    - an OCR model (text recognition)
+    - an embedding model for cropped image fragments (crop embeddings)
+
+    Attributes:
+        model_type: Model type identifier for the transformers library
+        disable_ocr: Flag that disables the OCR module
+        disable_crop_embeddings: Flag that disables the crop embedding module
+        disable_detections: Flag that disables the object detection module
+        detection_model_config: Detection model configuration (after deserialization)
+        ocr_model_config: OCR model configuration (after deserialization)
+        crop_embedding_model_config: Embedding model configuration (after deserialization)
+        detection_image_preprocessing_config: Image preprocessing parameters for detection
+        ocr_pretrained_processor_path: Path to the pretrained OCR processor
+        crop_embedding_image_preprocessing_config: Image preprocessing parameters for embeddings
+    """
+
+    # Model type identifier used by the transformers library
+    model_type: str = "magiv2"
 
     def __init__(
         self,
         disable_ocr: bool = False,
         disable_crop_embeddings: bool = False,
         disable_detections: bool = False,
-        detection_model_config: dict = None,
-        ocr_model_config: dict = None,
-        crop_embedding_model_config: dict = None,
-        detection_image_preprocessing_config: dict = None,
-        ocr_pretrained_processor_path: str = None,
-        crop_embedding_image_preprocessing_config: dict = None,
-        **kwargs,
-    ):
-        self.disable_ocr = disable_ocr
-        self.disable_crop_embeddings = disable_crop_embeddings
-        self.disable_detections = disable_detections
-        self.kwargs = kwargs
-        self.detection_model_config = None
-        self.ocr_model_config = None
-        self.crop_embedding_model_config = None
+        detection_model_config: Optional[dict[str, Any]] = None,
+        ocr_model_config: Optional[dict[str, Any]] = None,
+        crop_embedding_model_config: Optional[dict[str, Any]] = None,
+        detection_image_preprocessing_config: Optional[dict[str, Any]] = None,
+        ocr_pretrained_processor_path: Optional[str] = None,
+        crop_embedding_image_preprocessing_config: Optional[dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> None:
+        """
+        Initializes the Magiv2 model configuration.
+
+        The constructor takes parameters that control which model modules are active,
+        plus the configurations for the individual components. Configurations passed
+        as dicts are deserialized into the corresponding transformers Config objects.
+
+        Args:
+            disable_ocr: Whether to disable the text recognition (OCR) module.
+                Defaults to False - OCR is active.
+            disable_crop_embeddings: Whether to disable the module that creates
+                embeddings for cropped image fragments. Defaults to False - embedding is active.
+            disable_detections: Whether to disable the object detection module.
+                Defaults to False - detection is active.
+            detection_model_config: Dict with the object detection model configuration.
+                If given, it is deserialized into a PretrainedConfig.
+            ocr_model_config: Dict with the OCR (encoder-decoder) model configuration.
+                If given, it is deserialized into a VisionEncoderDecoderConfig.
+            crop_embedding_model_config: Dict with the configuration of the crop
+                embedding model. If given, it is deserialized into a PretrainedConfig.
+            detection_image_preprocessing_config: Dict with the image preprocessing
+                parameters for the detection module (e.g. size, normalization).
+            ocr_pretrained_processor_path: Path to a directory or Hub ID with the
+                pretrained image processor for the OCR module.
+            crop_embedding_image_preprocessing_config: Dict with the image preprocessing
+                parameters for the embedding module.
+            **kwargs: Additional arguments forwarded to the PretrainedConfig base class.
+
+        Returns:
+            None
+
+        Note:
+            - Model configurations are deserialized from dict to Config objects only
+              when they are provided (not None)
+            - The disable_* flags allow selectively switching off individual modules
+            - All additional kwargs are forwarded to the PretrainedConfig base class
+        """
+        # Store the flags that disable individual modules
+        self.disable_ocr: bool = disable_ocr
+        self.disable_crop_embeddings: bool = disable_crop_embeddings
+        self.disable_detections: bool = disable_detections
+
+        # Store the extra arguments passed to the constructor
+        self.kwargs: dict[str, Any] = kwargs
+
+        # Initialize the model config attributes as None
+        # (they may be deserialized below if the parameters are not None)
+        self.detection_model_config: Optional[PretrainedConfig] = None
+        self.ocr_model_config: Optional[VisionEncoderDecoderConfig] = None
+        self.crop_embedding_model_config: Optional[PretrainedConfig] = None
+
+        # Deserialize the detection model config from a dict into a PretrainedConfig object
         if detection_model_config is not None:
-            self.detection_model_config = PretrainedConfig.from_dict(detection_model_config)
+            self.detection_model_config = PretrainedConfig.from_dict(
+                detection_model_config
+            )
+
+        # Deserialize the OCR model config from a dict into a VisionEncoderDecoderConfig object
+        # OCR uses an encoder-decoder architecture (vision encoder + text decoder)
         if ocr_model_config is not None:
-            self.ocr_model_config = VisionEncoderDecoderConfig.from_dict(ocr_model_config)
+            self.ocr_model_config = VisionEncoderDecoderConfig.from_dict(
+                ocr_model_config
+            )
+
+        # Deserialize the crop embedding model config from a dict into a PretrainedConfig object
         if crop_embedding_model_config is not None:
-            self.crop_embedding_model_config = PretrainedConfig.from_dict(crop_embedding_model_config)
-
-        self.detection_image_preprocessing_config = detection_image_preprocessing_config
-        self.ocr_pretrained_processor_path = ocr_pretrained_processor_path
-        self.crop_embedding_image_preprocessing_config = crop_embedding_image_preprocessing_config
+            self.crop_embedding_model_config = PretrainedConfig.from_dict(
+                crop_embedding_model_config
+            )
+
+        # Store the image preprocessing config for the detection module
+        # (e.g. target image size, normalization parameters, augmentations)
+        self.detection_image_preprocessing_config: Optional[dict[str, Any]] = (
+            detection_image_preprocessing_config
+        )
+
+        # Path to the pretrained OCR processor (local or from the Hugging Face Hub)
+        self.ocr_pretrained_processor_path: Optional[str] = ocr_pretrained_processor_path
+
+        # Store the image preprocessing config for the embedding module
+        # (e.g. target crop size, normalization parameters)
+        self.crop_embedding_image_preprocessing_config: Optional[dict[str, Any]] = (
+            crop_embedding_image_preprocessing_config
+        )
+
+        # Call the PretrainedConfig base class constructor with the extra kwargs
         super().__init__(**kwargs)
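
A minimal sketch of the deserialization behaviour documented above. The dict values here are toy placeholders, not the real detection config, and the import path is illustrative:

from configuration_magiv2 import Magiv2Config  # import path depends on how the repo is used

cfg = Magiv2Config(
    disable_ocr=True,
    detection_model_config={"model_type": "conditional_detr", "d_model": 256},  # toy dict
)
print(type(cfg.detection_model_config).__name__)  # PretrainedConfig - deserialized from the dict
print(cfg.ocr_model_config)                       # None - no dict was passed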
modelling_magiv2.py CHANGED
The diff for this file is too large to render. See raw diff
 
modelling_magiv2_pre.py ADDED
@@ -0,0 +1,877 @@
from transformers import PreTrainedModel, VisionEncoderDecoderModel, ViTMAEModel, ConditionalDetrModel
from transformers.models.conditional_detr.modeling_conditional_detr import (
    ConditionalDetrMLPPredictionHead,
    ConditionalDetrModelOutput,
    inverse_sigmoid,
)
from .configuration_magiv2 import Magiv2Config
from .processing_magiv2 import Magiv2Processor
from torch import nn
from typing import Optional, List
import torch
from einops import rearrange, repeat
from .utils import move_to_device, visualise_single_image_prediction, sort_panels, sort_text_boxes_in_reading_order
from transformers.image_transforms import center_to_corners_format
from .utils import UnionFind, sort_panels, sort_text_boxes_in_reading_order
import pulp
import scipy
import numpy as np
from scipy.optimize import linear_sum_assignment


class Magiv2Model(PreTrainedModel):
    config_class = Magiv2Config

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.processor = Magiv2Processor(config)
        if not config.disable_ocr:
            self.ocr_model = VisionEncoderDecoderModel(config.ocr_model_config)
        if not config.disable_crop_embeddings:
            self.crop_embedding_model = ViTMAEModel(
                config.crop_embedding_model_config)
        if not config.disable_detections:
            self.num_non_obj_tokens = 5
            self.detection_transformer = ConditionalDetrModel(
                config.detection_model_config)
            self.bbox_predictor = ConditionalDetrMLPPredictionHead(
                input_dim=config.detection_model_config.d_model,
                hidden_dim=config.detection_model_config.d_model,
                output_dim=4, num_layers=3
            )
            self.character_character_matching_head = ConditionalDetrMLPPredictionHead(
                input_dim=3 * config.detection_model_config.d_model +
                (2 * config.crop_embedding_model_config.hidden_size if not config.disable_crop_embeddings else 0),
                hidden_dim=config.detection_model_config.d_model,
                output_dim=1, num_layers=3
            )
            self.text_character_matching_head = ConditionalDetrMLPPredictionHead(
                input_dim=3 * config.detection_model_config.d_model,
                hidden_dim=config.detection_model_config.d_model,
                output_dim=1, num_layers=3
            )
            self.text_tail_matching_head = ConditionalDetrMLPPredictionHead(
                input_dim=2 * config.detection_model_config.d_model,
                hidden_dim=config.detection_model_config.d_model,
                output_dim=1, num_layers=3
            )
            self.class_labels_classifier = nn.Linear(
                config.detection_model_config.d_model, config.detection_model_config.num_labels
            )
            self.is_this_text_a_dialogue = nn.Linear(
                config.detection_model_config.d_model, 1
            )
            self.matcher = ConditionalDetrHungarianMatcher(
                class_cost=config.detection_model_config.class_cost,
                bbox_cost=config.detection_model_config.bbox_cost,
                giou_cost=config.detection_model_config.giou_cost
            )

    def move_to_device(self, input):
        return move_to_device(input, self.device)

    @torch.no_grad()
    def do_chapter_wide_prediction(self, pages_in_order, character_bank, eta=0.75, batch_size=8, use_tqdm=False, do_ocr=True):
        texts = []
        characters = []
        character_clusters = []
        if use_tqdm:
            from tqdm import tqdm
            iterator = tqdm(range(0, len(pages_in_order), batch_size))
        else:
            iterator = range(0, len(pages_in_order), batch_size)
        per_page_results = []
        for i in iterator:
            pages = pages_in_order[i:i+batch_size]
            results = self.predict_detections_and_associations(pages)
            per_page_results.extend([result for result in results])

        texts = [result["texts"] for result in per_page_results]
        characters = [result["characters"] for result in per_page_results]
        character_clusters = [result["character_cluster_labels"]
                              for result in per_page_results]
        assigned_character_names = self.assign_names_to_characters(
            pages_in_order, characters, character_bank, character_clusters, eta=eta)
        if do_ocr:
            ocr = self.predict_ocr(pages_in_order, texts, use_tqdm=use_tqdm)
        offset_characters = 0
        iteration_over = zip(
            per_page_results, ocr) if do_ocr else per_page_results
        for iter in iteration_over:
            if do_ocr:
                result, ocr_for_page = iter
                result["ocr"] = ocr_for_page
            else:
                result = iter
            result["character_names"] = assigned_character_names[
                offset_characters:offset_characters + len(result["characters"])]
            offset_characters += len(result["characters"])
        return per_page_results

    def assign_names_to_characters(self, images, character_bboxes, character_bank, character_clusters, eta=0.75):
        if len(character_bank["images"]) == 0:
            return ["Other" for bboxes_for_image in character_bboxes for bbox in bboxes_for_image]
        chapter_wide_char_embeddings = self.predict_crop_embeddings(
            images, character_bboxes)
        chapter_wide_char_embeddings = torch.cat(
            chapter_wide_char_embeddings, dim=0)
        chapter_wide_char_embeddings = torch.nn.functional.normalize(
            chapter_wide_char_embeddings, p=2, dim=1).cpu().numpy()
        # create must-link and cannot link constraints from character_clusters
        must_link = []
        cannot_link = []
        offset = 0
        for clusters_per_image in character_clusters:
            for i in range(len(clusters_per_image)):
                for j in range(i+1, len(clusters_per_image)):
                    if clusters_per_image[i] == clusters_per_image[j]:
                        must_link.append((offset + i, offset + j))
                    else:
                        cannot_link.append((offset + i, offset + j))
            offset += len(clusters_per_image)
        character_bank_for_this_chapter = self.predict_crop_embeddings(
            character_bank["images"], [[[0, 0, x.shape[1], x.shape[0]]] for x in character_bank["images"]])
        character_bank_for_this_chapter = torch.cat(
            character_bank_for_this_chapter, dim=0)
        character_bank_for_this_chapter = torch.nn.functional.normalize(
            character_bank_for_this_chapter, p=2, dim=1).cpu().numpy()
        costs = scipy.spatial.distance.cdist(
            chapter_wide_char_embeddings, character_bank_for_this_chapter)
        none_of_the_above = eta * np.ones((costs.shape[0], 1))
        costs = np.concatenate([costs, none_of_the_above], axis=1)
        sense = pulp.LpMinimize
        num_supply, num_demand = costs.shape
        problem = pulp.LpProblem("Optimal_Transport_Problem", sense)
        x = pulp.LpVariable.dicts("x", ((i, j) for i in range(
            num_supply) for j in range(num_demand)), cat='Binary')
        # Objective Function to minimize
        problem += pulp.lpSum([costs[i][j] * x[(i, j)]
                               for i in range(num_supply) for j in range(num_demand)])
        # each crop must be assigned to exactly one character
        for i in range(num_supply):
            problem += pulp.lpSum([x[(i, j)] for j in range(num_demand)]
                                  ) == 1, f"Supply_{i}_Total_Assignment"
        # cannot link constraints
        for j in range(num_demand-1):
            for (s1, s2) in cannot_link:
                problem += x[(s1, j)] + x[(s2, j)] <= 1, f"Exclusion_{s1}_{s2}_Demand_{j}"
        # must link constraints
        for j in range(num_demand):
            for (s1, s2) in must_link:
                problem += x[(s1, j)] - x[(s2, j)] == 0, f"Inclusion_{s1}_{s2}_Demand_{j}"
        problem.solve()
        assignments = []
        for v in problem.variables():
            if v.varValue is not None and v.varValue > 0:
                index, assignment = v.name.split(
                    "(")[1].split(")")[0].split(",")
                assignment = assignment[1:]
                assignments.append((int(index), int(assignment)))

        labels = np.zeros(num_supply)
        for i, j in assignments:
            labels[i] = j

        return [character_bank["names"][int(i)] if i < len(character_bank["names"]) else "Other" for i in labels]

    def predict_detections_and_associations(
        self,
        images,
        move_to_device_fn=None,
        character_detection_threshold=0.3,
        panel_detection_threshold=0.2,
        text_detection_threshold=0.3,
        tail_detection_threshold=0.34,
        character_character_matching_threshold=0.65,
        text_character_matching_threshold=0.35,
        text_tail_matching_threshold=0.3,
        text_classification_threshold=0.5,
    ):
        assert not self.config.disable_detections
        move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn

        inputs_to_detection_transformer = self.processor.preprocess_inputs_for_detection(
            images)
        inputs_to_detection_transformer = move_to_device_fn(
            inputs_to_detection_transformer)

        detection_transformer_output = self._get_detection_transformer_output(
            **inputs_to_detection_transformer)
        predicted_class_scores, predicted_bboxes = self._get_predicted_bboxes_and_classes(
            detection_transformer_output)

        original_image_sizes = torch.stack([torch.tensor(
            img.shape[:2]) for img in images], dim=0).to(predicted_bboxes.device)

        batch_scores, batch_labels = predicted_class_scores.max(-1)
        batch_scores = batch_scores.sigmoid()
        batch_labels = batch_labels.long()
        batch_bboxes = center_to_corners_format(predicted_bboxes)

        # scale the bboxes back to the original image size
        if isinstance(original_image_sizes, List):
            img_h = torch.Tensor([i[0] for i in original_image_sizes])
            img_w = torch.Tensor([i[1] for i in original_image_sizes])
        else:
            img_h, img_w = original_image_sizes.unbind(1)
        scale_fct = torch.stack(
            [img_w, img_h, img_w, img_h], dim=1).to(batch_bboxes.device)
        batch_bboxes = batch_bboxes * scale_fct[:, None, :]

        batch_panel_indices = self.processor._get_indices_of_panels_to_keep(
            batch_scores, batch_labels, batch_bboxes, panel_detection_threshold)
        batch_character_indices = self.processor._get_indices_of_characters_to_keep(
            batch_scores, batch_labels, batch_bboxes, character_detection_threshold)
        batch_text_indices = self.processor._get_indices_of_texts_to_keep(
            batch_scores, batch_labels, batch_bboxes, text_detection_threshold)
        batch_tail_indices = self.processor._get_indices_of_tails_to_keep(
            batch_scores, batch_labels, batch_bboxes, tail_detection_threshold)

        predicted_obj_tokens_for_batch = self._get_predicted_obj_tokens(
            detection_transformer_output)
        predicted_t2c_tokens_for_batch = self._get_predicted_t2c_tokens(
            detection_transformer_output)
        predicted_c2c_tokens_for_batch = self._get_predicted_c2c_tokens(
            detection_transformer_output)

        text_character_affinity_matrices = self._get_text_character_affinity_matrices(
            character_obj_tokens_for_batch=[x[i] for x, i in zip(
                predicted_obj_tokens_for_batch, batch_character_indices)],
            text_obj_tokens_for_this_batch=[x[i] for x, i in zip(
                predicted_obj_tokens_for_batch, batch_text_indices)],
            t2c_tokens_for_batch=predicted_t2c_tokens_for_batch,
            apply_sigmoid=True,
        )

        character_bboxes_in_batch = [batch_bboxes[i][j]
                                     for i, j in enumerate(batch_character_indices)]
        character_character_affinity_matrices = self._get_character_character_affinity_matrices(
            character_obj_tokens_for_batch=[x[i] for x, i in zip(
                predicted_obj_tokens_for_batch, batch_character_indices)],
            crop_embeddings_for_batch=self.predict_crop_embeddings(
                images, character_bboxes_in_batch, move_to_device_fn),
            c2c_tokens_for_batch=predicted_c2c_tokens_for_batch,
            apply_sigmoid=True,
        )

        text_tail_affinity_matrices = self._get_text_tail_affinity_matrices(
            text_obj_tokens_for_this_batch=[x[i] for x, i in zip(
                predicted_obj_tokens_for_batch, batch_text_indices)],
            tail_obj_tokens_for_batch=[x[i] for x, i in zip(
                predicted_obj_tokens_for_batch, batch_tail_indices)],
            apply_sigmoid=True,
        )

        is_this_text_a_dialogue = self._get_text_classification(
            [x[i] for x, i in zip(predicted_obj_tokens_for_batch, batch_text_indices)])

        results = []
        for batch_index in range(len(batch_scores)):
            panel_indices = batch_panel_indices[batch_index]
            character_indices = batch_character_indices[batch_index]
            text_indices = batch_text_indices[batch_index]
            tail_indices = batch_tail_indices[batch_index]

            character_bboxes = batch_bboxes[batch_index][character_indices]
            panel_bboxes = batch_bboxes[batch_index][panel_indices]
            text_bboxes = batch_bboxes[batch_index][text_indices]
            tail_bboxes = batch_bboxes[batch_index][tail_indices]

            local_sorted_panel_indices = sort_panels(panel_bboxes)
            panel_bboxes = panel_bboxes[local_sorted_panel_indices]
            local_sorted_text_indices = sort_text_boxes_in_reading_order(
                text_bboxes, panel_bboxes)
            text_bboxes = text_bboxes[local_sorted_text_indices]

            character_character_matching_scores = character_character_affinity_matrices[batch_index]
            text_character_matching_scores = text_character_affinity_matrices[
                batch_index][local_sorted_text_indices]
            text_tail_matching_scores = text_tail_affinity_matrices[
                batch_index][local_sorted_text_indices]

            is_essential_text = is_this_text_a_dialogue[batch_index][
                local_sorted_text_indices] > text_classification_threshold
            character_cluster_labels = UnionFind.from_adj_matrix(
                character_character_matching_scores > character_character_matching_threshold
            ).get_labels_for_connected_components()

            if 0 in text_character_matching_scores.shape:
                text_character_associations = torch.zeros(
                    (0, 2), dtype=torch.long)
            else:
                most_likely_speaker_for_each_text = torch.argmax(
                    text_character_matching_scores, dim=1)
                text_indices = torch.arange(len(text_bboxes)).type_as(
                    most_likely_speaker_for_each_text)
                text_character_associations = torch.stack(
                    [text_indices, most_likely_speaker_for_each_text], dim=1)
                to_keep = text_character_matching_scores.max(
                    dim=1).values > text_character_matching_threshold
                text_character_associations = text_character_associations[to_keep]

            if 0 in text_tail_matching_scores.shape:
                text_tail_associations = torch.zeros((0, 2), dtype=torch.long)
            else:
                most_likely_tail_for_each_text = torch.argmax(
                    text_tail_matching_scores, dim=1)
                text_indices = torch.arange(len(text_bboxes)).type_as(
                    most_likely_tail_for_each_text)
                text_tail_associations = torch.stack(
                    [text_indices, most_likely_tail_for_each_text], dim=1)
                to_keep = text_tail_matching_scores.max(
                    dim=1).values > text_tail_matching_threshold
                text_tail_associations = text_tail_associations[to_keep]

            results.append({
                "panels": panel_bboxes.tolist(),
                "texts": text_bboxes.tolist(),
                "characters": character_bboxes.tolist(),
                "tails": tail_bboxes.tolist(),
                "text_character_associations": text_character_associations.tolist(),
                "text_tail_associations": text_tail_associations.tolist(),
                "character_cluster_labels": character_cluster_labels,
                "is_essential_text": is_essential_text.tolist(),
            })

        return results

    def get_affinity_matrices_given_annotations(
        self, images, annotations, move_to_device_fn=None, apply_sigmoid=True
    ):
        assert not self.config.disable_detections
        move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn

        character_bboxes_in_batch = [[bbox for bbox, label in zip(
            a["bboxes_as_x1y1x2y2"], a["labels"]) if label == 0] for a in annotations]
        crop_embeddings_for_batch = self.predict_crop_embeddings(
            images, character_bboxes_in_batch, move_to_device_fn)

        inputs_to_detection_transformer = self.processor.preprocess_inputs_for_detection(
            images, annotations)
        inputs_to_detection_transformer = move_to_device_fn(
            inputs_to_detection_transformer)
        processed_targets = inputs_to_detection_transformer.pop("labels")

        detection_transformer_output = self._get_detection_transformer_output(
            **inputs_to_detection_transformer)
        predicted_obj_tokens_for_batch = self._get_predicted_obj_tokens(
            detection_transformer_output)
        predicted_t2c_tokens_for_batch = self._get_predicted_t2c_tokens(
            detection_transformer_output)
        predicted_c2c_tokens_for_batch = self._get_predicted_c2c_tokens(
            detection_transformer_output)

        predicted_class_scores, predicted_bboxes = self._get_predicted_bboxes_and_classes(
            detection_transformer_output)
        matching_dict = {
            "logits": predicted_class_scores,
            "pred_boxes": predicted_bboxes,
        }
        indices = self.matcher(matching_dict, processed_targets)

        matched_char_obj_tokens_for_batch = []
        matched_text_obj_tokens_for_batch = []
        matched_tail_obj_tokens_for_batch = []
        t2c_tokens_for_batch = []
        c2c_tokens_for_batch = []

        for j, (pred_idx, tgt_idx) in enumerate(indices):
            target_idx_to_pred_idx = {tgt.item(): pred.item()
                                      for pred, tgt in zip(pred_idx, tgt_idx)}
            targets_for_this_image = processed_targets[j]
            indices_of_text_boxes_in_annotation = [i for i, label in enumerate(
                targets_for_this_image["class_labels"]) if label == 1]
            indices_of_char_boxes_in_annotation = [i for i, label in enumerate(
                targets_for_this_image["class_labels"]) if label == 0]
            indices_of_tail_boxes_in_annotation = [i for i, label in enumerate(
                targets_for_this_image["class_labels"]) if label == 3]
            predicted_text_indices = [target_idx_to_pred_idx[i]
                                      for i in indices_of_text_boxes_in_annotation]
            predicted_char_indices = [target_idx_to_pred_idx[i]
                                      for i in indices_of_char_boxes_in_annotation]
            predicted_tail_indices = [target_idx_to_pred_idx[i]
                                      for i in indices_of_tail_boxes_in_annotation]
            matched_char_obj_tokens_for_batch.append(
                predicted_obj_tokens_for_batch[j][predicted_char_indices])
            matched_text_obj_tokens_for_batch.append(
                predicted_obj_tokens_for_batch[j][predicted_text_indices])
            matched_tail_obj_tokens_for_batch.append(
                predicted_obj_tokens_for_batch[j][predicted_tail_indices])
            t2c_tokens_for_batch.append(predicted_t2c_tokens_for_batch[j])
            c2c_tokens_for_batch.append(predicted_c2c_tokens_for_batch[j])

        text_character_affinity_matrices = self._get_text_character_affinity_matrices(
            character_obj_tokens_for_batch=matched_char_obj_tokens_for_batch,
            text_obj_tokens_for_this_batch=matched_text_obj_tokens_for_batch,
            t2c_tokens_for_batch=t2c_tokens_for_batch,
            apply_sigmoid=apply_sigmoid,
        )

        character_character_affinity_matrices = self._get_character_character_affinity_matrices(
            character_obj_tokens_for_batch=matched_char_obj_tokens_for_batch,
            crop_embeddings_for_batch=crop_embeddings_for_batch,
            c2c_tokens_for_batch=c2c_tokens_for_batch,
            apply_sigmoid=apply_sigmoid,
        )

        character_character_affinity_matrices_crop_only = self._get_character_character_affinity_matrices(
            character_obj_tokens_for_batch=matched_char_obj_tokens_for_batch,
            crop_embeddings_for_batch=crop_embeddings_for_batch,
            c2c_tokens_for_batch=c2c_tokens_for_batch,
            crop_only=True,
            apply_sigmoid=apply_sigmoid,
        )

        text_tail_affinity_matrices = self._get_text_tail_affinity_matrices(
            text_obj_tokens_for_this_batch=matched_text_obj_tokens_for_batch,
            tail_obj_tokens_for_batch=matched_tail_obj_tokens_for_batch,
            apply_sigmoid=apply_sigmoid,
        )

        is_this_text_a_dialogue = self._get_text_classification(
            matched_text_obj_tokens_for_batch, apply_sigmoid=apply_sigmoid)

        return {
            "text_character_affinity_matrices": text_character_affinity_matrices,
            "character_character_affinity_matrices": character_character_affinity_matrices,
            "character_character_affinity_matrices_crop_only": character_character_affinity_matrices_crop_only,
            "text_tail_affinity_matrices": text_tail_affinity_matrices,
            "is_this_text_a_dialogue": is_this_text_a_dialogue,
        }

    def predict_crop_embeddings(self, images, crop_bboxes, move_to_device_fn=None, mask_ratio=0.0, batch_size=256):
        if self.config.disable_crop_embeddings:
            return None

        assert isinstance(
            crop_bboxes, List), "please provide a list of bboxes for each image to get embeddings for"

        move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn

        # temporarily change the mask ratio from default to the one specified
        old_mask_ratio = self.crop_embedding_model.embeddings.config.mask_ratio
        self.crop_embedding_model.embeddings.config.mask_ratio = mask_ratio

        crops_per_image = []
        num_crops_per_batch = [len(bboxes) for bboxes in crop_bboxes]
        for image, bboxes, num_crops in zip(images, crop_bboxes, num_crops_per_batch):
            crops = self.processor.crop_image(image, bboxes)
            assert len(crops) == num_crops
            crops_per_image.extend(crops)

        if len(crops_per_image) == 0:
            return [move_to_device_fn(torch.zeros(0, self.config.crop_embedding_model_config.hidden_size)) for _ in crop_bboxes]

        crops_per_image = self.processor.preprocess_inputs_for_crop_embeddings(
            crops_per_image)
        crops_per_image = move_to_device_fn(crops_per_image)

        # process the crops in batches to avoid OOM
        embeddings = []
        for i in range(0, len(crops_per_image), batch_size):
            crops = crops_per_image[i:i+batch_size]
            embeddings_per_batch = self.crop_embedding_model(
                crops).last_hidden_state[:, 0]
            embeddings.append(embeddings_per_batch)
        embeddings = torch.cat(embeddings, dim=0)

        crop_embeddings_for_batch = []
        for num_crops in num_crops_per_batch:
            crop_embeddings_for_batch.append(embeddings[:num_crops])
            embeddings = embeddings[num_crops:]

        # restore the mask ratio to the default
        self.crop_embedding_model.embeddings.config.mask_ratio = old_mask_ratio

        return crop_embeddings_for_batch

    def predict_ocr(self, images, crop_bboxes, move_to_device_fn=None, use_tqdm=False, batch_size=32, max_new_tokens=64):
        assert not self.config.disable_ocr
        move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn

        crops_per_image = []
        num_crops_per_batch = [len(bboxes) for bboxes in crop_bboxes]
        for image, bboxes, num_crops in zip(images, crop_bboxes, num_crops_per_batch):
            crops = self.processor.crop_image(image, bboxes)
            assert len(crops) == num_crops
            crops_per_image.extend(crops)

        if len(crops_per_image) == 0:
            return [[] for _ in crop_bboxes]

        crops_per_image = self.processor.preprocess_inputs_for_ocr(
            crops_per_image)
        crops_per_image = move_to_device_fn(crops_per_image)

        # process the crops in batches to avoid OOM
        all_generated_texts = []
        if use_tqdm:
            from tqdm import tqdm
            pbar = tqdm(range(0, len(crops_per_image), batch_size))
        else:
            pbar = range(0, len(crops_per_image), batch_size)
        for i in pbar:
            crops = crops_per_image[i:i+batch_size]
            generated_ids = self.ocr_model.generate(
                crops, max_new_tokens=max_new_tokens)
            generated_texts = self.processor.postprocess_ocr_tokens(
                generated_ids)
            all_generated_texts.extend(generated_texts)

        texts_for_images = []
        for num_crops in num_crops_per_batch:
            texts_for_images.append([x.replace("\n", "")
                                     for x in all_generated_texts[:num_crops]])
            all_generated_texts = all_generated_texts[num_crops:]

        return texts_for_images

    def visualise_single_image_prediction(
        self, image_as_np_array, predictions, filename=None
    ):
        return visualise_single_image_prediction(image_as_np_array, predictions, filename)

    @torch.no_grad()
    def _get_detection_transformer_output(
        self,
        pixel_values: torch.FloatTensor,
        pixel_mask: Optional[torch.LongTensor] = None
    ):
        if self.config.disable_detections:
            raise ValueError(
                "Detection model is disabled. Set disable_detections=False in the config.")
        return self.detection_transformer(
            pixel_values=pixel_values,
            pixel_mask=pixel_mask,
            return_dict=True
        )

    def _get_predicted_obj_tokens(
        self,
        detection_transformer_output: ConditionalDetrModelOutput
    ):
        return detection_transformer_output.last_hidden_state[:, :-self.num_non_obj_tokens]

    def _get_predicted_c2c_tokens(
        self,
        detection_transformer_output: ConditionalDetrModelOutput
    ):
        return detection_transformer_output.last_hidden_state[:, -self.num_non_obj_tokens]

    def _get_predicted_t2c_tokens(
        self,
        detection_transformer_output: ConditionalDetrModelOutput
    ):
        return detection_transformer_output.last_hidden_state[:, -self.num_non_obj_tokens+1]

    def _get_predicted_bboxes_and_classes(
        self,
        detection_transformer_output: ConditionalDetrModelOutput,
    ):
        if self.config.disable_detections:
            raise ValueError(
                "Detection model is disabled. Set disable_detections=False in the config.")

        obj = self._get_predicted_obj_tokens(detection_transformer_output)

        predicted_class_scores = self.class_labels_classifier(obj)
        reference = detection_transformer_output.reference_points[:-self.num_non_obj_tokens]
        reference_before_sigmoid = inverse_sigmoid(reference).transpose(0, 1)
        predicted_boxes = self.bbox_predictor(obj)
        predicted_boxes[..., :2] += reference_before_sigmoid
        predicted_boxes = predicted_boxes.sigmoid()

        return predicted_class_scores, predicted_boxes

    def _get_text_classification(
        self,
        text_obj_tokens_for_batch: List[torch.FloatTensor],
        apply_sigmoid=False,
    ):
        assert not self.config.disable_detections
        is_this_text_a_dialogue = []
        for text_obj_tokens in text_obj_tokens_for_batch:
            if text_obj_tokens.shape[0] == 0:
                is_this_text_a_dialogue.append(
                    torch.tensor([], dtype=torch.bool))
                continue
            classification = self.is_this_text_a_dialogue(
                text_obj_tokens).squeeze(-1)
            if apply_sigmoid:
                classification = classification.sigmoid()
            is_this_text_a_dialogue.append(classification)
        return is_this_text_a_dialogue

    def _get_character_character_affinity_matrices(
        self,
        character_obj_tokens_for_batch: List[torch.FloatTensor] = None,
        crop_embeddings_for_batch: List[torch.FloatTensor] = None,
        c2c_tokens_for_batch: List[torch.FloatTensor] = None,
        crop_only=False,
        apply_sigmoid=True,
    ):
        assert self.config.disable_detections or (
            character_obj_tokens_for_batch is not None and c2c_tokens_for_batch is not None)
        assert self.config.disable_crop_embeddings or crop_embeddings_for_batch is not None
        assert not self.config.disable_detections or not self.config.disable_crop_embeddings

        if crop_only:
            affinity_matrices = []
            for crop_embeddings in crop_embeddings_for_batch:
                crop_embeddings = crop_embeddings / \
                    crop_embeddings.norm(dim=-1, keepdim=True)
                affinity_matrix = crop_embeddings @ crop_embeddings.T
                affinity_matrices.append(affinity_matrix)
            return affinity_matrices
        affinity_matrices = []
        for batch_index, (character_obj_tokens, c2c) in enumerate(zip(character_obj_tokens_for_batch, c2c_tokens_for_batch)):
            if character_obj_tokens.shape[0] == 0:
                affinity_matrices.append(torch.zeros(
                    0, 0).type_as(character_obj_tokens))
                continue
            if not self.config.disable_crop_embeddings:
                crop_embeddings = crop_embeddings_for_batch[batch_index]
                assert character_obj_tokens.shape[0] == crop_embeddings.shape[0]
                character_obj_tokens = torch.cat(
                    [character_obj_tokens, crop_embeddings], dim=-1)
            char_i = repeat(character_obj_tokens, "i d -> i repeat d",
                            repeat=character_obj_tokens.shape[0])
            char_j = repeat(character_obj_tokens, "j d -> repeat j d",
                            repeat=character_obj_tokens.shape[0])
            char_ij = rearrange([char_i, char_j], "two i j d -> (i j) (two d)")
            c2c = repeat(c2c, "d -> repeat d", repeat=char_ij.shape[0])
            char_ij_c2c = torch.cat([char_ij, c2c], dim=-1)
            character_character_affinities = self.character_character_matching_head(
                char_ij_c2c)
            character_character_affinities = rearrange(
                character_character_affinities, "(i j) 1 -> i j", i=char_i.shape[0])
            character_character_affinities = (
                character_character_affinities + character_character_affinities.T) / 2
            if apply_sigmoid:
                character_character_affinities = character_character_affinities.sigmoid()
            affinity_matrices.append(character_character_affinities)
        return affinity_matrices

    def _get_text_character_affinity_matrices(
        self,
        character_obj_tokens_for_batch: List[torch.FloatTensor] = None,
        text_obj_tokens_for_this_batch: List[torch.FloatTensor] = None,
        t2c_tokens_for_batch: List[torch.FloatTensor] = None,
        apply_sigmoid=True,
    ):
        assert not self.config.disable_detections
        assert character_obj_tokens_for_batch is not None and text_obj_tokens_for_this_batch is not None and t2c_tokens_for_batch is not None
        affinity_matrices = []
        for character_obj_tokens, text_obj_tokens, t2c in zip(character_obj_tokens_for_batch, text_obj_tokens_for_this_batch, t2c_tokens_for_batch):
            if character_obj_tokens.shape[0] == 0 or text_obj_tokens.shape[0] == 0:
                affinity_matrices.append(torch.zeros(
                    text_obj_tokens.shape[0], character_obj_tokens.shape[0]).type_as(character_obj_tokens))
                continue
            text_i = repeat(text_obj_tokens, "i d -> i repeat d",
                            repeat=character_obj_tokens.shape[0])
            char_j = repeat(character_obj_tokens, "j d -> repeat j d",
                            repeat=text_obj_tokens.shape[0])
            text_char = rearrange(
                [text_i, char_j], "two i j d -> (i j) (two d)")
            t2c = repeat(t2c, "d -> repeat d", repeat=text_char.shape[0])
            text_char_t2c = torch.cat([text_char, t2c], dim=-1)
            text_character_affinities = self.text_character_matching_head(
                text_char_t2c)
            text_character_affinities = rearrange(
                text_character_affinities, "(i j) 1 -> i j", i=text_i.shape[0])
            if apply_sigmoid:
                text_character_affinities = text_character_affinities.sigmoid()
            affinity_matrices.append(text_character_affinities)
        return affinity_matrices

    def _get_text_tail_affinity_matrices(
        self,
        text_obj_tokens_for_this_batch: List[torch.FloatTensor] = None,
        tail_obj_tokens_for_batch: List[torch.FloatTensor] = None,
        apply_sigmoid=True,
    ):
        assert not self.config.disable_detections
        assert tail_obj_tokens_for_batch is not None and text_obj_tokens_for_this_batch is not None
        affinity_matrices = []
        for tail_obj_tokens, text_obj_tokens in zip(tail_obj_tokens_for_batch, text_obj_tokens_for_this_batch):
            if tail_obj_tokens.shape[0] == 0 or text_obj_tokens.shape[0] == 0:
                affinity_matrices.append(torch.zeros(
                    text_obj_tokens.shape[0], tail_obj_tokens.shape[0]).type_as(tail_obj_tokens))
                continue
            text_i = repeat(text_obj_tokens, "i d -> i repeat d",
                            repeat=tail_obj_tokens.shape[0])
            tail_j = repeat(tail_obj_tokens, "j d -> repeat j d",
                            repeat=text_obj_tokens.shape[0])
            text_tail = rearrange(
                [text_i, tail_j], "two i j d -> (i j) (two d)")
            text_tail_affinities = self.text_tail_matching_head(text_tail)
            text_tail_affinities = rearrange(
                text_tail_affinities, "(i j) 1 -> i j", i=text_i.shape[0])
            if apply_sigmoid:
                text_tail_affinities = text_tail_affinities.sigmoid()
            affinity_matrices.append(text_tail_affinities)
        return affinity_matrices


# Copied from transformers.models.detr.modeling_detr._upcast
def _upcast(t):
    # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
    if t.is_floating_point():
        return t if t.dtype in (torch.float32, torch.float64) else t.float()
    else:
        return t if t.dtype in (torch.int32, torch.int64) else t.int()


# Copied from transformers.models.detr.modeling_detr.box_area
def box_area(boxes):
    """
    Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates.

    Args:
        boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`):
            Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1
            < x2` and `0 <= y1 < y2`.

    Returns:
        `torch.FloatTensor`: a tensor containing the area for each box.
    """
    boxes = _upcast(boxes)
    return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])


# Copied from transformers.models.detr.modeling_detr.box_iou
def box_iou(boxes1, boxes2):
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    width_height = (right_bottom - left_top).clamp(min=0)  # [N,M,2]
    inter = width_height[:, :, 0] * width_height[:, :, 1]  # [N,M]

    union = area1[:, None] + area2 - inter

    iou = inter / union
    return iou, union


# Copied from transformers.models.detr.modeling_detr.generalized_box_iou
def generalized_box_iou(boxes1, boxes2):
    """
    Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format.

    Returns:
        `torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2)
    """
    # degenerate boxes gives inf / nan results
    # so do an early check
    if not (boxes1[:, 2:] >= boxes1[:, :2]).all():
        raise ValueError(
            f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}")
    if not (boxes2[:, 2:] >= boxes2[:, :2]).all():
        raise ValueError(
            f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}")
    iou, union = box_iou(boxes1, boxes2)

    top_left = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    bottom_right = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

    width_height = (bottom_right - top_left).clamp(min=0)  # [N,M,2]
    area = width_height[:, :, 0] * width_height[:, :, 1]

    return iou - (area - union) / area


# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrHungarianMatcher with DeformableDetr->ConditionalDetr
class ConditionalDetrHungarianMatcher(nn.Module):
    """
    This class computes an assignment between the targets and the predictions of the network.

    For efficiency reasons, the targets don't include the no_object. Because of this, in general, there are more
    predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, while the others are
    un-matched (and thus treated as non-objects).

    Args:
        class_cost:
            The relative weight of the classification error in the matching cost.
        bbox_cost:
            The relative weight of the L1 error of the bounding box coordinates in the matching cost.
        giou_cost:
            The relative weight of the giou loss of the bounding box in the matching cost.
    """

    def __init__(self, class_cost: float = 1, bbox_cost: float = 1, giou_cost: float = 1):
        super().__init__()

        self.class_cost = class_cost
        self.bbox_cost = bbox_cost
        self.giou_cost = giou_cost
        if class_cost == 0 and bbox_cost == 0 and giou_cost == 0:
            raise ValueError("All costs of the Matcher can't be 0")

    @torch.no_grad()
    def forward(self, outputs, targets):
        """
        Args:
            outputs (`dict`):
                A dictionary that contains at least these entries:
                * "logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                * "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates.
            targets (`List[dict]`):
                A list of targets (len(targets) = batch_size), where each target is a dict containing:
                * "class_labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of
                  ground-truth objects in the target) containing the class labels
                * "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates.

        Returns:
            `List[Tuple]`: A list of size `batch_size`, containing tuples of (index_i, index_j) where:
            - index_i is the indices of the selected predictions (in order)
            - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        """
        batch_size, num_queries = outputs["logits"].shape[:2]

        # We flatten to compute the cost matrices in a batch
        # [batch_size * num_queries, num_classes]
        out_prob = outputs["logits"].flatten(0, 1).sigmoid()
        out_bbox = outputs["pred_boxes"].flatten(
            0, 1)  # [batch_size * num_queries, 4]

        # Also concat the target labels and boxes
        target_ids = torch.cat([v["class_labels"] for v in targets])
        target_bbox = torch.cat([v["boxes"] for v in targets])

        # Compute the classification cost.
        alpha = 0.25
        gamma = 2.0
        neg_cost_class = (1 - alpha) * (out_prob**gamma) * \
            (-(1 - out_prob + 1e-8).log())
        pos_cost_class = alpha * \
            ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
        class_cost = pos_cost_class[:, target_ids] - \
            neg_cost_class[:, target_ids]

        # Compute the L1 cost between boxes
        bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)

        # Compute the giou cost between boxes
        giou_cost = -generalized_box_iou(center_to_corners_format(
            out_bbox), center_to_corners_format(target_bbox))

        # Final cost matrix
        cost_matrix = self.bbox_cost * bbox_cost + \
            self.class_cost * class_cost + self.giou_cost * giou_cost
        cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu()

        sizes = [len(v["boxes"]) for v in targets]
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(
            cost_matrix.split(sizes, -1))]
        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
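
A standalone sketch (toy sizes, a sum standing in for the trained MLP head) of the pairwise-token construction that the three affinity heads above share: tile the N object tokens into all N x N ordered pairs with einops, score each concatenated pair, and reshape the scores back into an N x N affinity matrix.

import torch
from einops import rearrange, repeat

tokens = torch.randn(4, 8)  # 4 detected objects with 8-dim tokens (toy sizes)
t_i = repeat(tokens, "i d -> i repeat d", repeat=tokens.shape[0])
t_j = repeat(tokens, "j d -> repeat j d", repeat=tokens.shape[0])
pairs = rearrange([t_i, t_j], "two i j d -> (i j) (two d)")  # (16, 16): every ordered pair, concatenated
scores = pairs.sum(dim=-1, keepdim=True)  # stand-in for e.g. character_character_matching_head
affinity = rearrange(scores, "(i j) 1 -> i j", i=tokens.shape[0])
print(affinity.shape)  # torch.Size([4, 4])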
processing_magiv2.py CHANGED
@@ -5,6 +5,7 @@ from shapely.geometry import box
5
  from .utils import x1y1x2y2_to_xywh
6
  import numpy as np
7
 
 
8
  class Magiv2Processor():
9
  def __init__(self, config):
10
  self.config = config
@@ -13,34 +14,38 @@ class Magiv2Processor():
13
  self.crop_embedding_image_preprocessor = None
14
  if not config.disable_detections:
15
  assert config.detection_image_preprocessing_config is not None
16
- self.detection_image_preprocessor = ConditionalDetrImageProcessor.from_dict(config.detection_image_preprocessing_config)
 
17
  if not config.disable_ocr:
18
  assert config.ocr_pretrained_processor_path is not None
19
- self.ocr_preprocessor = TrOCRProcessor.from_pretrained(config.ocr_pretrained_processor_path)
 
20
  if not config.disable_crop_embeddings:
21
  assert config.crop_embedding_image_preprocessing_config is not None
22
- self.crop_embedding_image_preprocessor = ViTImageProcessor.from_dict(config.crop_embedding_image_preprocessing_config)
23
-
 
24
  def preprocess_inputs_for_detection(self, images, annotations=None):
25
  images = list(images)
26
  assert isinstance(images[0], np.ndarray)
27
  annotations = self._convert_annotations_to_coco_format(annotations)
28
- inputs = self.detection_image_preprocessor(images, annotations=annotations, return_tensors="pt")
 
29
  return inputs
30
 
31
  def preprocess_inputs_for_ocr(self, images):
32
  images = list(images)
33
  assert isinstance(images[0], np.ndarray)
34
  return self.ocr_preprocessor(images, return_tensors="pt").pixel_values
35
-
36
  def preprocess_inputs_for_crop_embeddings(self, images):
37
  images = list(images)
38
  assert isinstance(images[0], np.ndarray)
39
  return self.crop_embedding_image_preprocessor(images, return_tensors="pt").pixel_values
40
-
41
  def postprocess_ocr_tokens(self, generated_ids, skip_special_tokens=True):
42
  return self.ocr_preprocessor.batch_decode(generated_ids, skip_special_tokens=skip_special_tokens)
43
-
44
  def crop_image(self, image, bboxes):
45
  crops_for_image = []
46
  for bbox in bboxes:
@@ -48,7 +53,8 @@ class Magiv2Processor():
48
 
49
  # fix the bounding box in case it is out of bounds or too small
50
  x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
51
- x1, y1, x2, y2 = min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2) # just incase
 
52
  x1, y1 = max(0, x1), max(0, y1)
53
  x1, y1 = min(image.shape[1], x1), min(image.shape[0], y1)
54
  x2, y2 = max(0, x2), max(0, y2)
@@ -71,10 +77,11 @@ class Magiv2Processor():
71
  def _get_indices_of_characters_to_keep(self, batch_scores, batch_labels, batch_bboxes, character_detection_threshold):
72
  indices_of_characters_to_keep = []
73
  for scores, labels, _ in zip(batch_scores, batch_labels, batch_bboxes):
74
- indices = torch.where((labels == 0) & (scores > character_detection_threshold))[0]
 
75
  indices_of_characters_to_keep.append(indices)
76
  return indices_of_characters_to_keep
77
-
78
  def _get_indices_of_panels_to_keep(self, batch_scores, batch_labels, batch_bboxes, panel_detection_threshold):
79
  indices_of_panels_to_keep = []
80
  for scores, labels, bboxes in zip(batch_scores, batch_labels, batch_bboxes):
@@ -85,7 +92,8 @@ class Magiv2Processor():
85
  if len(indices) == 0:
86
  indices_of_panels_to_keep.append([])
87
  continue
88
- scores, labels, indices, bboxes = zip(*sorted(zip(scores, labels, indices, bboxes), reverse=True))
 
89
  panels_to_keep = []
90
  union_of_panels_so_far = box(0, 0, 0, 0)
91
  for ps, pb, pl, pi in zip(scores, bboxes, labels, indices):
@@ -95,21 +103,25 @@ class Magiv2Processor():
95
  if union_of_panels_so_far.intersection(panel_polygon).area / panel_polygon.area > 0.5:
96
  continue
97
  panels_to_keep.append((ps, pl, pb, pi))
98
- union_of_panels_so_far = union_of_panels_so_far.union(panel_polygon)
99
- indices_of_panels_to_keep.append([p[3].item() for p in panels_to_keep])
 
 
100
  return indices_of_panels_to_keep
101
-
102
  def _get_indices_of_texts_to_keep(self, batch_scores, batch_labels, batch_bboxes, text_detection_threshold):
103
  indices_of_texts_to_keep = []
104
  for scores, labels, bboxes in zip(batch_scores, batch_labels, batch_bboxes):
105
- indices = torch.where((labels == 1) & (scores > text_detection_threshold))[0]
 
106
  bboxes = bboxes[indices]
107
  scores = scores[indices]
108
  labels = labels[indices]
109
  if len(indices) == 0:
110
  indices_of_texts_to_keep.append([])
111
  continue
112
- scores, labels, indices, bboxes = zip(*sorted(zip(scores, labels, indices, bboxes), reverse=True))
 
113
  texts_to_keep = []
114
  texts_to_keep_as_shapely_objects = []
115
  for ts, tb, tl, ti in zip(scores, bboxes, labels, indices):
@@ -122,20 +134,23 @@ class Magiv2Processor():
122
  if should_append:
123
  texts_to_keep.append((ts, tl, tb, ti))
124
  texts_to_keep_as_shapely_objects.append(text_polygon)
125
- indices_of_texts_to_keep.append([t[3].item() for t in texts_to_keep])
 
126
  return indices_of_texts_to_keep
127
-
128
  def _get_indices_of_tails_to_keep(self, batch_scores, batch_labels, batch_bboxes, text_detection_threshold):
129
  indices_of_texts_to_keep = []
130
  for scores, labels, bboxes in zip(batch_scores, batch_labels, batch_bboxes):
131
- indices = torch.where((labels == 3) & (scores > text_detection_threshold))[0]
 
132
  bboxes = bboxes[indices]
133
  scores = scores[indices]
134
  labels = labels[indices]
135
  if len(indices) == 0:
136
  indices_of_texts_to_keep.append([])
137
  continue
138
- scores, labels, indices, bboxes = zip(*sorted(zip(scores, labels, indices, bboxes), reverse=True))
 
139
  texts_to_keep = []
140
  texts_to_keep_as_shapely_objects = []
141
  for ts, tb, tl, ti in zip(scores, bboxes, labels, indices):
@@ -148,9 +163,10 @@ class Magiv2Processor():
148
  if should_append:
149
  texts_to_keep.append((ts, tl, tb, ti))
150
  texts_to_keep_as_shapely_objects.append(text_polygon)
151
- indices_of_texts_to_keep.append([t[3].item() for t in texts_to_keep])
 
152
  return indices_of_texts_to_keep
153
-
154
  def _convert_annotations_to_coco_format(self, annotations):
155
  if annotations is None:
156
  return None
@@ -169,7 +185,7 @@ class Magiv2Processor():
            })
            coco_annotations.append(coco_annotation)
        return coco_annotations
-
+
    def _verify_annotations_are_in_correct_format(self, annotations):
        error_msg = """
        Annotations must be in the following format:
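For reference, the panel-filtering hunks above implement a simple coverage-based suppression: panels are visited in descending score order, and a panel is dropped once more than half of its area is already covered by the union of the panels kept so far. A minimal standalone sketch of that rule (the keep_panels helper is hypothetical, not part of the repo; it uses shapely the same way the processor does):

    from shapely.geometry import box

    def keep_panels(bboxes, scores, coverage_threshold=0.5):
        # visit panels from highest to lowest score
        order = sorted(range(len(bboxes)), key=lambda i: scores[i], reverse=True)
        union_so_far = box(0, 0, 0, 0)
        kept = []
        for i in order:
            panel = box(*bboxes[i])
            # fraction of this panel already covered by higher-scoring panels
            if union_so_far.intersection(panel).area / panel.area > coverage_threshold:
                continue
            kept.append(i)
            union_so_far = union_so_far.union(panel)
        return kept

    # the second box lies entirely inside the first, so only index 0 survives:
    print(keep_panels([[0, 0, 100, 100], [10, 10, 60, 100]], [0.9, 0.8]))  # [0]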
 
 from .utils import x1y1x2y2_to_xywh
 import numpy as np
 
+
 class Magiv2Processor():
     def __init__(self, config):
         self.config = config

         self.crop_embedding_image_preprocessor = None
         if not config.disable_detections:
             assert config.detection_image_preprocessing_config is not None
+            self.detection_image_preprocessor = ConditionalDetrImageProcessor.from_dict(
+                config.detection_image_preprocessing_config)
         if not config.disable_ocr:
             assert config.ocr_pretrained_processor_path is not None
+            self.ocr_preprocessor = TrOCRProcessor.from_pretrained(
+                config.ocr_pretrained_processor_path)
         if not config.disable_crop_embeddings:
             assert config.crop_embedding_image_preprocessing_config is not None
+            self.crop_embedding_image_preprocessor = ViTImageProcessor.from_dict(
+                config.crop_embedding_image_preprocessing_config)
+
     def preprocess_inputs_for_detection(self, images, annotations=None):
         images = list(images)
         assert isinstance(images[0], np.ndarray)
         annotations = self._convert_annotations_to_coco_format(annotations)
+        inputs = self.detection_image_preprocessor(
+            images, annotations=annotations, return_tensors="pt")
         return inputs
 
     def preprocess_inputs_for_ocr(self, images):
         images = list(images)
         assert isinstance(images[0], np.ndarray)
         return self.ocr_preprocessor(images, return_tensors="pt").pixel_values
+
     def preprocess_inputs_for_crop_embeddings(self, images):
         images = list(images)
         assert isinstance(images[0], np.ndarray)
         return self.crop_embedding_image_preprocessor(images, return_tensors="pt").pixel_values
+
     def postprocess_ocr_tokens(self, generated_ids, skip_special_tokens=True):
         return self.ocr_preprocessor.batch_decode(generated_ids, skip_special_tokens=skip_special_tokens)
+
     def crop_image(self, image, bboxes):
         crops_for_image = []
         for bbox in bboxes:

         # fix the bounding box in case it is out of bounds or too small
         x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
+        x1, y1, x2, y2 = min(x1, x2), min(y1, y2), max(
+            x1, x2), max(y1, y2)  # just incase
         x1, y1 = max(0, x1), max(0, y1)
         x1, y1 = min(image.shape[1], x1), min(image.shape[0], y1)
         x2, y2 = max(0, x2), max(0, y2)
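The crop_image fix-up above first reorders swapped corners, then clamps both corners into the image. A standalone sketch of the same fix-up (the fix_bbox helper is hypothetical; the final clamp of x2, y2 to the image size is assumed to continue past the lines shown in this hunk):

    def fix_bbox(bbox, image_height, image_width):
        x1, y1, x2, y2 = (int(v) for v in bbox)
        # reorder the corners in case they arrived swapped, just in case
        x1, y1, x2, y2 = min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)
        # clamp both corners into the image bounds
        x1, y1 = max(0, x1), max(0, y1)
        x1, y1 = min(image_width, x1), min(image_height, y1)
        x2, y2 = max(0, x2), max(0, y2)
        x2, y2 = min(image_width, x2), min(image_height, y2)
        return x1, y1, x2, y2

    print(fix_bbox([-5, 30, 20, 10], image_height=100, image_width=50))  # (0, 10, 20, 30)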
 
utils.py CHANGED
@@ -8,6 +8,7 @@ import networkx as nx
 from copy import deepcopy
 from itertools import groupby
 
+
 def move_to_device(inputs, device):
     if hasattr(inputs, "keys"):
         return {k: move_to_device(v, device) for k, v in inputs.items()}
@@ -20,6 +21,7 @@ def move_to_device(inputs, device):
     else:
         return inputs.to(device)
 
+
 class UnionFind:
     def __init__(self, n):
         self.parent = list(range(n))
@@ -34,7 +36,7 @@ class UnionFind:
            if adj_matrix[i, j] > 0:
                ufds.unite(i, j)
        return ufds
-
+
    @classmethod
    def from_adj_list(cls, adj_list):
        ufds = cls(len(adj_list))
@@ -42,7 +44,7 @@ class UnionFind:
            for j in adj_list[i]:
                ufds.unite(i, j)
        return ufds
-
+
    @classmethod
    def from_edge_list(cls, edge_list, num_nodes):
        ufds = cls(num_nodes)
@@ -65,11 +67,11 @@ class UnionFind:
        self.parent[y] = x
        self.size[x] += self.size[y]
        self.num_components -= 1
-
+
    def get_components_of(self, x):
        x = self.find(x)
        return [i for i in range(len(self.parent)) if self.find(i) == x]
-
+
    def are_connected(self, x, y):
        return self.find(x) == self.find(y)
 
@@ -78,7 +80,7 @@ class UnionFind:
 
    def get_num_components(self):
        return self.num_components
-
+
    def get_labels_for_connected_components(self):
        map_parent_to_label = {}
        labels = []
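A quick usage sketch for the UnionFind class in the hunks above (the from_adj_matrix constructor name is inferred from the classmethod body shown; the exact component labels may differ):

    import numpy as np

    adj = np.array([
        [0, 1, 0, 0],
        [1, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
    ])
    ufds = UnionFind.from_adj_matrix(adj)              # nodes 0 and 1 get united
    print(ufds.are_connected(0, 1))                    # True
    print(ufds.get_num_components())                   # 3 components: {0, 1}, {2}, {3}
    print(ufds.get_labels_for_connected_components())  # e.g. [0, 0, 1, 2]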
@@ -89,32 +91,36 @@ class UnionFind:
        labels.append(map_parent_to_label[parent])
    return labels
 
+
def visualise_single_image_prediction(image_as_np_array, predictions, filename):
    figure, subplot = plt.subplots(1, 1, figsize=(10, 10))
    subplot.imshow(image_as_np_array)
    plot_bboxes(subplot, predictions["panels"], color="green")
-   plot_bboxes(subplot, predictions["texts"], color="red", visibility=predictions["is_essential_text"])
+   plot_bboxes(subplot, predictions["texts"], color="red",
+               visibility=predictions["is_essential_text"])
    plot_bboxes(subplot, predictions["characters"], color="blue")
    plot_bboxes(subplot, predictions["tails"], color="purple")
 
    for i, name in enumerate(predictions["character_names"]):
        char_bbox = predictions["characters"][i]
        x1, y1, x2, y2 = char_bbox
-       subplot.text(x1, y1 - 2, name,
-                    verticalalignment='bottom', horizontalalignment='left',
-                    bbox=dict(facecolor='blue', alpha=1, edgecolor='none'),  # Background settings
-                    color='white', fontsize=8)
+       subplot.text(x1, y1 - 2, name,
+                    verticalalignment='bottom', horizontalalignment='left',
+                    # Background settings
+                    bbox=dict(facecolor='blue', alpha=1, edgecolor='none'),
+                    color='white', fontsize=8)
 
    COLOURS = [
-       "#b7ff51", # green
-       "#f50a8f", # pink
-       "#4b13b6", # purple
-       "#ddaa34", # orange
-       "#bea2a2", # brown
+       "#b7ff51",  # green
+       "#f50a8f",  # pink
+       "#4b13b6",  # purple
+       "#ddaa34",  # orange
+       "#bea2a2",  # brown
    ]
    colour_index = 0
    character_cluster_labels = predictions["character_cluster_labels"]
-   unique_label_sorted_by_frequency = sorted(list(set(character_cluster_labels)), key=lambda x: character_cluster_labels.count(x), reverse=True)
+   unique_label_sorted_by_frequency = sorted(list(set(
+       character_cluster_labels)), key=lambda x: character_cluster_labels.count(x), reverse=True)
    for label in unique_label_sorted_by_frequency:
        root = None
        others = []
@@ -127,7 +133,9 @@ def visualise_single_image_prediction(image_as_np_array, predictions, filename):
        if colour_index >= len(COLOURS):
            random_colour = COLOURS[0]
            while random_colour in COLOURS:
-               random_colour = "#" + "".join([random.choice("0123456789ABCDEF") for j in range(6)])
+               random_colour = "#" + \
+                   "".join([random.choice("0123456789ABCDEF")
+                            for j in range(6)])
        else:
            random_colour = COLOURS[colour_index]
            colour_index += 1
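The re-wrapped sorting expression above orders the cluster labels by how often they occur, so the largest character cluster is coloured first. A small worked example of that same expression:

    character_cluster_labels = [2, 0, 2, 1, 2, 0]
    unique_label_sorted_by_frequency = sorted(
        set(character_cluster_labels),
        key=lambda x: character_cluster_labels.count(x),
        reverse=True)
    print(unique_label_sorted_by_frequency)  # [2, 0, 1]: label 2 occurs 3x, 0 occurs 2x, 1 once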
@@ -143,8 +151,9 @@ def visualise_single_image_prediction(image_as_np_array, predictions, filename):
        x2 = bbox_j[0] + (bbox_j[2] - bbox_j[0]) / 2
        y2 = bbox_j[1] + (bbox_j[3] - bbox_j[1]) / 2
        subplot.plot([x1, x2], [y1, y2], color=random_colour, linewidth=2)
-       subplot.plot([x2], [y2], color=random_colour, marker="o", markersize=5)
-
+       subplot.plot([x2], [y2], color=random_colour,
+                    marker="o", markersize=5)
+
    for (i, j) in predictions["text_character_associations"]:
        bbox_i = predictions["texts"][i]
        bbox_j = predictions["characters"][j]
@@ -154,8 +163,9 @@ def visualise_single_image_prediction(image_as_np_array, predictions, filename):
        y1 = bbox_i[1] + (bbox_i[3] - bbox_i[1]) / 2
        x2 = bbox_j[0] + (bbox_j[2] - bbox_j[0]) / 2
        y2 = bbox_j[1] + (bbox_j[3] - bbox_j[1]) / 2
-       subplot.plot([x1, x2], [y1, y2], color="red", linewidth=2, linestyle="dashed")
-
+       subplot.plot([x1, x2], [y1, y2], color="red",
+                    linewidth=2, linestyle="dashed")
+
    for (i, j) in predictions["text_tail_associations"]:
        bbox_i = predictions["texts"][i]
        bbox_j = predictions["tails"][j]
@@ -163,7 +173,8 @@ def visualise_single_image_prediction(image_as_np_array, predictions, filename):
        y1 = bbox_i[1] + (bbox_i[3] - bbox_i[1]) / 2
        x2 = bbox_j[0] + (bbox_j[2] - bbox_j[0]) / 2
        y2 = bbox_j[1] + (bbox_j[3] - bbox_j[1]) / 2
-       subplot.plot([x1, x2], [y1, y2], color="purple", linewidth=2, linestyle="dashed")
+       subplot.plot([x1, x2], [y1, y2], color="purple",
+                    linewidth=2, linestyle="dashed")
 
    subplot.axis("off")
    if filename is not None:
@@ -174,6 +185,7 @@ def visualise_single_image_prediction(image_as_np_array, predictions, filename):
    plt.close()
    return image
 
+
def plot_bboxes(subplot, bboxes, color="red", visibility=None):
    if visibility is None:
        visibility = [1] * len(bboxes)
@@ -187,6 +199,7 @@ def plot_bboxes(subplot, bboxes, color="red", visibility=None):
        )
        subplot.add_patch(rect)
 
+
def sort_panels(rects):
    before_rects = convert_to_list_of_lists(rects)
    # slightly erode all rectangles initially to account for imperfect detections
@@ -212,34 +225,42 @@ def sort_panels(rects):
        G.remove_edge(*max_cyclic_edge)
    return list(nx.topological_sort(G))
 
+
def is_strictly_above(rectA, rectB):
    x1A, y1A, x2A, y2A = rectA
    x1B, y1B, x2B, y2B = rectB
    return y2A < y1B
 
+
def is_strictly_below(rectA, rectB):
    x1A, y1A, x2A, y2A = rectA
    x1B, y1B, x2B, y2B = rectB
    return y2B < y1A
 
+
def is_strictly_left_of(rectA, rectB):
    x1A, y1A, x2A, y2A = rectA
    x1B, y1B, x2B, y2B = rectB
    return x2A < x1B
 
+
def is_strictly_right_of(rectA, rectB):
    x1A, y1A, x2A, y2A = rectA
    x1B, y1B, x2B, y2B = rectB
    return x2B < x1A
 
+
def intersects(rectA, rectB):
    return box(*rectA).intersects(box(*rectB))
 
+
def is_there_a_directed_edge(a, b, rects):
    rectA = rects[a]
    rectB = rects[b]
-   centre_of_A = [rectA[0] + (rectA[2] - rectA[0]) / 2, rectA[1] + (rectA[3] - rectA[1]) / 2]
-   centre_of_B = [rectB[0] + (rectB[2] - rectB[0]) / 2, rectB[1] + (rectB[3] - rectB[1]) / 2]
+   centre_of_A = [rectA[0] + (rectA[2] - rectA[0]) / 2,
+                  rectA[1] + (rectA[3] - rectA[1]) / 2]
+   centre_of_B = [rectB[0] + (rectB[2] - rectB[0]) / 2,
+                  rectB[1] + (rectB[3] - rectB[1]) / 2]
    if np.allclose(np.array(centre_of_A), np.array(centre_of_B)):
        return box(*rectA).area > (box(*rectB)).area
    copy_A = [rectA[0], rectA[1], rectA[2], rectA[3]]
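A quick sanity check for the geometric predicates defined in the hunk above, with rect_a sitting fully above and to the left of rect_b (both boxes in x1, y1, x2, y2 form):

    rect_a = [0, 0, 10, 10]
    rect_b = [20, 20, 30, 30]
    print(is_strictly_above(rect_a, rect_b))    # True:  a's bottom edge (10) is above b's top edge (20)
    print(is_strictly_left_of(rect_a, rect_b))  # True:  a's right edge (10) is left of b's left edge (20)
    print(is_strictly_below(rect_a, rect_b))    # False
    print(intersects(rect_a, rect_b))           # False: the boxes are disjoint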
@@ -256,34 +277,41 @@ def is_there_a_directed_edge(a, b, rects):
    if is_strictly_below(copy_A, copy_B) and is_strictly_right_of(copy_A, copy_B):
        return use_cuts_to_determine_edge_from_a_to_b(a, b, rects)
    if is_strictly_below(copy_B, copy_A) and is_strictly_right_of(copy_B, copy_A):
-       return use_cuts_to_determine_edge_from_a_to_b(a, b, rects)
+       return use_cuts_to_determine_edge_from_a_to_b(a, b, rects)
    # otherwise they intersect
    copy_A = erode_rectangle(copy_A, 0.05)
    copy_B = erode_rectangle(copy_B, 0.05)
-
+
+
def get_distance(rectA, rectB):
    return box(rectA[0], rectA[1], rectA[2], rectA[3]).distance(box(rectB[0], rectB[1], rectB[2], rectB[3]))
 
+
def use_cuts_to_determine_edge_from_a_to_b(a, b, rects):
    rects = deepcopy(rects)
    while True:
-       xmin, ymin, xmax, ymax = min(rects[a][0], rects[b][0]), min(rects[a][1], rects[b][1]), max(rects[a][2], rects[b][2]), max(rects[a][3], rects[b][3])
-       rect_index = [i for i in range(len(rects)) if intersects(rects[i], [xmin, ymin, xmax, ymax])]
-       rects_copy = [rect for rect in rects if intersects(rect, [xmin, ymin, xmax, ymax])]
-
+       xmin, ymin, xmax, ymax = min(rects[a][0], rects[b][0]), min(
+           rects[a][1], rects[b][1]), max(rects[a][2], rects[b][2]), max(rects[a][3], rects[b][3])
+       rect_index = [i for i in range(len(rects)) if intersects(
+           rects[i], [xmin, ymin, xmax, ymax])]
+       rects_copy = [rect for rect in rects if intersects(
+           rect, [xmin, ymin, xmax, ymax])]
+
        # try to split the panels using a "horizontal" lines
-       overlapping_y_ranges = merge_overlapping_ranges([(y1, y2) for x1, y1, x2, y2 in rects_copy])
+       overlapping_y_ranges = merge_overlapping_ranges(
+           [(y1, y2) for x1, y1, x2, y2 in rects_copy])
        panel_index_to_split = {}
        for split_index, (y1, y2) in enumerate(overlapping_y_ranges):
            for i, index in enumerate(rect_index):
                if y1 <= rects_copy[i][1] <= rects_copy[i][3] <= y2:
                    panel_index_to_split[index] = split_index
-
+
        if panel_index_to_split[a] != panel_index_to_split[b]:
            return panel_index_to_split[a] < panel_index_to_split[b]
-
+
        # try to split the panels using a "vertical" lines
-       overlapping_x_ranges = merge_overlapping_ranges([(x1, x2) for x1, y1, x2, y2 in rects_copy])
+       overlapping_x_ranges = merge_overlapping_ranges(
+           [(x1, x2) for x1, y1, x2, y2 in rects_copy])
        panel_index_to_split = {}
        for split_index, (x1, x2) in enumerate(overlapping_x_ranges[::-1]):
            for i, index in enumerate(rect_index):
@@ -291,10 +319,11 @@ def use_cuts_to_determine_edge_from_a_to_b(a, b, rects):
                panel_index_to_split[index] = split_index
        if panel_index_to_split[a] != panel_index_to_split[b]:
            return panel_index_to_split[a] < panel_index_to_split[b]
-
+
        # otherwise, erode the rectangles and try again
        rects = [erode_rectangle(rect, 0.05) for rect in rects]
 
+
def erode_rectangle(bbox, erosion_factor):
    x1, y1, x2, y2 = bbox
    w, h = x2 - x1, y2 - y1
@@ -312,6 +341,7 @@ def erode_rectangle(bbox, erosion_factor):
    x1, y1, x2, y2 = cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2
    return [x1, y1, x2, y2]
 
+
def merge_overlapping_ranges(ranges):
    """
    ranges: list of tuples (x1, x2)
@@ -333,6 +363,7 @@ def merge_overlapping_ranges(ranges):
        merged_ranges.append((prev_x1, prev_x2))
    return merged_ranges
 
+
def sort_text_boxes_in_reading_order(text_bboxes, sorted_panel_bboxes):
    text_bboxes = convert_to_list_of_lists(text_bboxes)
    sorted_panel_bboxes = convert_to_list_of_lists(sorted_panel_bboxes)
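A worked example for merge_overlapping_ranges above (the body is elided in this diff; the example assumes the usual sweep that merges ranges whose intervals touch, which is what the surrounding panel-splitting code relies on):

    print(merge_overlapping_ranges([(1, 4), (3, 6), (8, 9)]))  # [(1, 6), (8, 9)]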
@@ -344,18 +375,23 @@ def sort_text_boxes_in_reading_order(text_bboxes, sorted_panel_bboxes):
        groups = groupby(range(len(nums)), key=lambda i: nums[i])
        return [list(indices) for _, indices in groups]
 
-   panel_id_for_text = get_text_to_panel_mapping(text_bboxes, sorted_panel_bboxes)
+   panel_id_for_text = get_text_to_panel_mapping(
+       text_bboxes, sorted_panel_bboxes)
    indices_of_texts = list(range(len(text_bboxes)))
-   indices_of_texts, panel_id_for_text = zip(*sorted(zip(indices_of_texts, panel_id_for_text), key=lambda x: x[1]))
+   indices_of_texts, panel_id_for_text = zip(
+       *sorted(zip(indices_of_texts, panel_id_for_text), key=lambda x: x[1]))
    indices_of_texts = list(indices_of_texts)
    grouped_indices = indices_of_same_elements(panel_id_for_text)
    for group in grouped_indices:
        subset_of_text_indices = [indices_of_texts[i] for i in group]
-       text_bboxes_of_subset = [text_bboxes[i] for i in subset_of_text_indices]
+       text_bboxes_of_subset = [text_bboxes[i]
+                                for i in subset_of_text_indices]
        sorted_subset_indices = sort_texts_within_panel(text_bboxes_of_subset)
-       indices_of_texts[group[0] : group[-1] + 1] = [subset_of_text_indices[i] for i in sorted_subset_indices]
+       indices_of_texts[group[0]: group[-1] + 1] = [subset_of_text_indices[i]
+                                                    for i in sorted_subset_indices]
    return indices_of_texts
 
+
def get_text_to_panel_mapping(text_bboxes, sorted_panel_bboxes):
    text_to_panel_mapping = []
    for text_bbox in text_bboxes:
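The zip(*sorted(zip(...), key=...)) idiom re-wrapped in the hunk above co-sorts two parallel lists by the second one (the panel id), keeping text indices aligned with their panels. A minimal illustration of the idiom:

    indices_of_texts = [0, 1, 2, 3]
    panel_id_for_text = [1, 0, 1, 0]
    indices_of_texts, panel_id_for_text = zip(
        *sorted(zip(indices_of_texts, panel_id_for_text), key=lambda x: x[1]))
    print(indices_of_texts)   # (1, 3, 0, 2)
    print(panel_id_for_text)  # (0, 0, 1, 1)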
@@ -368,14 +404,19 @@ def get_text_to_panel_mapping(text_bboxes, sorted_panel_bboxes):
        for j, annotation in enumerate(sorted_panel_bboxes):
            shapely_annotation_polygon = box(*annotation)
            if shapely_text_polygon.intersects(shapely_annotation_polygon):
-               all_intersections.append((shapely_text_polygon.intersection(shapely_annotation_polygon).area, j))
-           all_distances.append((shapely_text_polygon.distance(shapely_annotation_polygon), j))
+               all_intersections.append(
+                   (shapely_text_polygon.intersection(shapely_annotation_polygon).area, j))
+           all_distances.append(
+               (shapely_text_polygon.distance(shapely_annotation_polygon), j))
        if len(all_intersections) == 0:
-           text_to_panel_mapping.append(min(all_distances, key=lambda x: x[0])[1])
+           text_to_panel_mapping.append(
+               min(all_distances, key=lambda x: x[0])[1])
        else:
-           text_to_panel_mapping.append(max(all_intersections, key=lambda x: x[0])[1])
+           text_to_panel_mapping.append(
+               max(all_intersections, key=lambda x: x[0])[1])
    return text_to_panel_mapping
 
+
def sort_texts_within_panel(rects):
    smallest_y = float("inf")
    greatest_x = float("-inf")
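The assignment rule in get_text_to_panel_mapping above is: a text box goes to the panel it overlaps most, falling back to the nearest panel by shapely distance when it overlaps none. A standalone sketch of the overlap branch:

    from shapely.geometry import box

    text = box(5, 5, 15, 15)
    panels = [box(0, 0, 10, 10), box(12, 0, 30, 30), box(100, 100, 110, 110)]
    overlaps = [(text.intersection(p).area, j)
                for j, p in enumerate(panels) if text.intersects(p)]
    print(max(overlaps, key=lambda x: x[0])[1])  # 1: an overlap of 30 with panel 1 beats 25 with panel 0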
@@ -383,29 +424,33 @@ def sort_texts_within_panel(rects):
        x1, y1, x2, y2 = rect
        smallest_y = min(smallest_y, y1)
        greatest_x = max(greatest_x, x2)
-
+
    reference_point = Point(greatest_x, smallest_y)
 
    polygons_and_index = []
    for i, rect in enumerate(rects):
        x1, y1, x2, y2 = rect
-       polygons_and_index.append((box(x1,y1,x2,y2), i))
+       polygons_and_index.append((box(x1, y1, x2, y2), i))
    # sort points by closest to reference point
-   polygons_and_index = sorted(polygons_and_index, key=lambda x: reference_point.distance(x[0]))
+   polygons_and_index = sorted(
+       polygons_and_index, key=lambda x: reference_point.distance(x[0]))
    indices = [x[1] for x in polygons_and_index]
    return indices
 
+
def x1y1wh_to_x1y1x2y2(bbox):
    x1, y1, w, h = bbox
    return [x1, y1, x1 + w, y1 + h]
 
+
def x1y1x2y2_to_xywh(bbox):
    x1, y1, x2, y2 = bbox
    return [x1, y1, x2 - x1, y2 - y1]
 
+
def convert_to_list_of_lists(rects):
    if isinstance(rects, torch.Tensor):
        return rects.tolist()
    if isinstance(rects, np.ndarray):
        return rects.tolist()
-   return [[a, b, c, d] for a, b, c, d in rects]
+   return [[a, b, c, d] for a, b, c, d in rects]
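A round-trip check for the bbox-format helpers above, converting corner form (x1, y1, x2, y2) to corner-plus-size form (x1, y1, w, h) and back:

    bbox_xyxy = [10, 20, 50, 80]
    bbox_xywh = x1y1x2y2_to_xywh(bbox_xyxy)
    print(bbox_xywh)                      # [10, 20, 40, 60]
    print(x1y1wh_to_x1y1x2y2(bbox_xywh))  # [10, 20, 50, 80]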
 