Mateusz Mróz commited on
Commit
bfcd54f
·
1 Parent(s): cd77b9d

Add Magiv2 model configuration, processing, and utility functions

Browse files

- Implement Magiv2Config class for model configuration, supporting detection, OCR, and crop embeddings.
- Create Magiv2Processor class for preprocessing inputs for detection, OCR, and crop embeddings.
- Add utility functions for handling bounding boxes, including cropping, sorting, and visualizing predictions.
- Introduce UnionFind class for managing connected components in bounding box graphs.
- Implement functions for converting annotation formats and managing text-to-panel mappings.

Files changed (5) hide show
  1. configuration_magiv2_PRE.py +131 -0
  2. processing_magiv2.py +364 -65
  3. processing_magiv2_PRE.py +225 -0
  4. utils.py +867 -152
  5. utils_PRE.py +456 -0
configuration_magiv2_PRE.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig, VisionEncoderDecoderConfig
2
+ from typing import Any, Optional
3
+
4
+
5
+ class Magiv2Config(PretrainedConfig):
6
+ """
7
+ Klasa konfiguracyjna dla modelu Magiv2.
8
+
9
+ Magiv2Config dziedziczy po PretrainedConfig z biblioteki transformers i definiuje
10
+ kompletną konfigurację dla modelu wizyjnego składającego się z trzech głównych komponentów:
11
+ - Model detekcji obiektów (detection)
12
+ - Model OCR (rozpoznawanie tekstu)
13
+ - Model embedowania wyciętych fragmentów obrazu (crop embeddings)
14
+
15
+ Attributes:
16
+ model_type: Identyfikator typu modelu dla biblioteki transformers
17
+ disable_ocr: Flaga wyłączająca moduł OCR
18
+ disable_crop_embeddings: Flaga wyłączająca moduł embedowania wyciętych fragmentów
19
+ disable_detections: Flaga wyłączająca moduł detekcji obiektów
20
+ detection_model_config: Konfiguracja modelu detekcji (po deserializacji)
21
+ ocr_model_config: Konfiguracja modelu OCR (po deserializacji)
22
+ crop_embedding_model_config: Konfiguracja modelu embedowania (po deserializacji)
23
+ detection_image_preprocessing_config: Parametry przetwarzania obrazu dla detekcji
24
+ ocr_pretrained_processor_path: Ścieżka do wytrenowanego procesora OCR
25
+ crop_embedding_image_preprocessing_config: Parametry przetwarzania obrazu dla embedowania
26
+ """
27
+
28
+ # Identyfikator typu modelu używany przez bibliotekę transformers
29
+ model_type: str = "magiv2"
30
+
31
+ def __init__(
32
+ self,
33
+ disable_ocr: bool = False,
34
+ disable_crop_embeddings: bool = False,
35
+ disable_detections: bool = False,
36
+ detection_model_config: Optional[dict[str, Any]] = None,
37
+ ocr_model_config: Optional[dict[str, Any]] = None,
38
+ crop_embedding_model_config: Optional[dict[str, Any]] = None,
39
+ detection_image_preprocessing_config: Optional[dict[str, Any]] = None,
40
+ ocr_pretrained_processor_path: Optional[str] = None,
41
+ crop_embedding_image_preprocessing_config: Optional[dict[str, Any]] = None,
42
+ **kwargs: Any,
43
+ ) -> None:
44
+ """
45
+ Inicjalizuje konfigurację modelu Magiv2.
46
+
47
+ Konstruktor przyjmuje parametry kontrolujące które moduły modelu są aktywne,
48
+ oraz konfiguracje dla poszczególnych komponentów. Konfiguracje przekazane jako
49
+ słowniki są deserializowane do odpowiednich obiektów Config z transformers.
50
+
51
+ Args:
52
+ disable_ocr: Czy wyłączyć moduł rozpoznawania tekstu (OCR).
53
+ Domyślnie False - OCR jest aktywne.
54
+ disable_crop_embeddings: Czy wyłączyć moduł tworzenia embeddingów dla wyciętych
55
+ fragmentów obrazu. Domyślnie False - embedowanie aktywne.
56
+ disable_detections: Czy wyłączyć moduł detekcji obiektów na obrazie.
57
+ Domyślnie False - detekcja aktywna.
58
+ detection_model_config: Słownik z konfiguracją modelu detekcji obiektów.
59
+ Jeśli podany, zostanie zdeserializowany do PretrainedConfig.
60
+ ocr_model_config: Słownik z konfiguracją modelu OCR (encoder-decoder).
61
+ Jeśli podany, zostanie zdeserializowany do VisionEncoderDecoderConfig.
62
+ crop_embedding_model_config: Słownik z konfiguracją modelu embedowania wyciętych
63
+ fragmentów. Jeśli podany, zostanie zdeserializowany
64
+ do PretrainedConfig.
65
+ detection_image_preprocessing_config: Słownik z parametrami preprocessingu obrazu
66
+ dla modułu detekcji (np. rozmiar, normalizacja).
67
+ ocr_pretrained_processor_path: Ścieżka do katalogu lub Hub ID z wytrenowanym
68
+ procesorem obrazu dla modułu OCR.
69
+ crop_embedding_image_preprocessing_config: Słownik z parametrami preprocessingu
70
+ obrazu dla modułu embedowania.
71
+ **kwargs: Dodatkowe argumenty przekazywane do klasy bazowej PretrainedConfig.
72
+
73
+ Returns:
74
+ None
75
+
76
+ Note:
77
+ - Konfiguracje modeli są deserializowane z dict do obiektów Config tylko wtedy,
78
+ gdy zostały przekazane (nie są None)
79
+ - Flagi disable_* pozwalają na selektywne wyłączanie poszczególnych modułów
80
+ - Wszystkie dodatkowe kwargs są przekazywane do klasy bazowej PretrainedConfig
81
+ """
82
+ # Przechowywanie flag wyłączających poszczególne moduły
83
+ self.disable_ocr: bool = disable_ocr
84
+ self.disable_crop_embeddings: bool = disable_crop_embeddings
85
+ self.disable_detections: bool = disable_detections
86
+
87
+ # Przechowywanie dodatkowych argumentów przekazanych do konstruktora
88
+ self.kwargs: dict[str, Any] = kwargs
89
+
90
+ # Inicjalizacja atrybutów konfiguracji modeli jako None
91
+ # (mog�� zostać zdeserializowane poniżej jeśli parametry nie są None)
92
+ self.detection_model_config: Optional[PretrainedConfig] = None
93
+ self.ocr_model_config: Optional[VisionEncoderDecoderConfig] = None
94
+ self.crop_embedding_model_config: Optional[PretrainedConfig] = None
95
+
96
+ # Deserializacja konfiguracji modelu detekcji ze słownika do obiektu PretrainedConfig
97
+ if detection_model_config is not None:
98
+ self.detection_model_config = PretrainedConfig.from_dict(
99
+ detection_model_config
100
+ )
101
+
102
+ # Deserializacja konfiguracji modelu OCR ze słownika do obiektu VisionEncoderDecoderConfig
103
+ # OCR wykorzystuje architekturę encoder-decoder (vision encoder + text decoder)
104
+ if ocr_model_config is not None:
105
+ self.ocr_model_config = VisionEncoderDecoderConfig.from_dict(
106
+ ocr_model_config
107
+ )
108
+
109
+ # Deserializacja konfiguracji modelu embedowania ze słownika do obiektu PretrainedConfig
110
+ if crop_embedding_model_config is not None:
111
+ self.crop_embedding_model_config = PretrainedConfig.from_dict(
112
+ crop_embedding_model_config
113
+ )
114
+
115
+ # Przechowywanie konfiguracji preprocessingu obrazu dla modułu detekcji
116
+ # (np. docelowy rozmiar obrazu, parametry normalizacji, augmentacje)
117
+ self.detection_image_preprocessing_config: Optional[dict[str, Any]] = (
118
+ detection_image_preprocessing_config
119
+ )
120
+
121
+ # Ścieżka do wytrenowanego procesora OCR (może być lokalna lub z Hugging Face Hub)
122
+ self.ocr_pretrained_processor_path: Optional[str] = ocr_pretrained_processor_path
123
+
124
+ # Przechowywanie konfiguracji preprocessingu obrazu dla modułu embedowania
125
+ # (np. docelowy rozmiar wycięć, parametry normalizacji)
126
+ self.crop_embedding_image_preprocessing_config: Optional[dict[str, Any]] = (
127
+ crop_embedding_image_preprocessing_config
128
+ )
129
+
130
+ # Wywołanie konstruktora klasy bazowej PretrainedConfig z dodatkowymi kwargs
131
+ super().__init__(**kwargs)
processing_magiv2.py CHANGED
@@ -1,118 +1,325 @@
1
  from transformers import ConditionalDetrImageProcessor, TrOCRProcessor, ViTImageProcessor
2
  import torch
3
- from typing import List
4
  from shapely.geometry import box
 
5
  from .utils import x1y1x2y2_to_xywh
6
  import numpy as np
 
7
 
8
 
9
  class Magiv2Processor():
10
- def __init__(self, config):
11
- self.config = config
12
- self.detection_image_preprocessor = None
13
- self.ocr_preprocessor = None
14
- self.crop_embedding_image_preprocessor = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  if not config.disable_detections:
16
  assert config.detection_image_preprocessing_config is not None
17
  self.detection_image_preprocessor = ConditionalDetrImageProcessor.from_dict(
18
  config.detection_image_preprocessing_config)
 
 
19
  if not config.disable_ocr:
20
  assert config.ocr_pretrained_processor_path is not None
21
  self.ocr_preprocessor = TrOCRProcessor.from_pretrained(
22
  config.ocr_pretrained_processor_path)
 
 
23
  if not config.disable_crop_embeddings:
24
  assert config.crop_embedding_image_preprocessing_config is not None
25
  self.crop_embedding_image_preprocessor = ViTImageProcessor.from_dict(
26
  config.crop_embedding_image_preprocessing_config)
27
 
28
- def preprocess_inputs_for_detection(self, images, annotations=None):
29
- images = list(images)
30
- assert isinstance(images[0], np.ndarray)
31
- annotations = self._convert_annotations_to_coco_format(annotations)
32
- inputs = self.detection_image_preprocessor(
33
- images, annotations=annotations, return_tensors="pt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  return inputs
35
 
36
- def preprocess_inputs_for_ocr(self, images):
37
- images = list(images)
38
- assert isinstance(images[0], np.ndarray)
39
- return self.ocr_preprocessor(images, return_tensors="pt").pixel_values
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
- def preprocess_inputs_for_crop_embeddings(self, images):
42
- images = list(images)
43
- assert isinstance(images[0], np.ndarray)
44
- return self.crop_embedding_image_preprocessor(images, return_tensors="pt").pixel_values
 
 
 
 
 
 
 
45
 
46
- def postprocess_ocr_tokens(self, generated_ids, skip_special_tokens=True):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  return self.ocr_preprocessor.batch_decode(generated_ids, skip_special_tokens=skip_special_tokens)
48
 
49
- def crop_image(self, image, bboxes):
50
- crops_for_image = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  for bbox in bboxes:
 
 
 
 
52
  x1, y1, x2, y2 = bbox
53
 
54
- # fix the bounding box in case it is out of bounds or too small
 
55
  x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
56
- x1, y1, x2, y2 = min(x1, x2), min(y1, y2), max(
57
- x1, x2), max(y1, y2) # just incase
 
58
  x1, y1 = max(0, x1), max(0, y1)
59
  x1, y1 = min(image.shape[1], x1), min(image.shape[0], y1)
 
60
  x2, y2 = max(0, x2), max(0, y2)
61
  x2, y2 = min(image.shape[1], x2), min(image.shape[0], y2)
 
 
62
  if x2 - x1 < 10:
63
  if image.shape[1] - x1 > 10:
64
  x2 = x1 + 10
65
  else:
66
  x1 = x2 - 10
 
 
67
  if y2 - y1 < 10:
68
  if image.shape[0] - y1 > 10:
69
  y2 = y1 + 10
70
  else:
71
  y1 = y2 - 10
72
 
73
- crop = image[y1:y2, x1:x2]
 
74
  crops_for_image.append(crop)
75
  return crops_for_image
76
 
77
- def _get_indices_of_characters_to_keep(self, batch_scores, batch_labels, batch_bboxes, character_detection_threshold):
78
- indices_of_characters_to_keep = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  for scores, labels, _ in zip(batch_scores, batch_labels, batch_bboxes):
80
- indices = torch.where((labels == 0) & (
 
81
  scores > character_detection_threshold))[0]
82
  indices_of_characters_to_keep.append(indices)
83
  return indices_of_characters_to_keep
84
 
85
- def _get_indices_of_panels_to_keep(self, batch_scores, batch_labels, batch_bboxes, panel_detection_threshold):
86
- indices_of_panels_to_keep = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  for scores, labels, bboxes in zip(batch_scores, batch_labels, batch_bboxes):
88
- indices = torch.where(labels == 2)[0]
 
89
  bboxes = bboxes[indices]
90
  scores = scores[indices]
91
  labels = labels[indices]
92
  if len(indices) == 0:
93
  indices_of_panels_to_keep.append([])
94
  continue
 
 
95
  scores, labels, indices, bboxes = zip(
96
  *sorted(zip(scores, labels, indices, bboxes), reverse=True))
97
- panels_to_keep = []
98
- union_of_panels_so_far = box(0, 0, 0, 0)
 
 
 
 
99
  for ps, pb, pl, pi in zip(scores, bboxes, labels, indices):
100
- panel_polygon = box(pb[0], pb[1], pb[2], pb[3])
 
 
 
101
  if ps < panel_detection_threshold:
102
  continue
 
 
103
  if union_of_panels_so_far.intersection(panel_polygon).area / panel_polygon.area > 0.5:
104
  continue
 
 
105
  panels_to_keep.append((ps, pl, pb, pi))
 
106
  union_of_panels_so_far = union_of_panels_so_far.union(
107
  panel_polygon)
 
 
108
  indices_of_panels_to_keep.append(
109
  [p[3].item() for p in panels_to_keep])
110
  return indices_of_panels_to_keep
111
 
112
- def _get_indices_of_texts_to_keep(self, batch_scores, batch_labels, batch_bboxes, text_detection_threshold):
113
- indices_of_texts_to_keep = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  for scores, labels, bboxes in zip(batch_scores, batch_labels, batch_bboxes):
115
- indices = torch.where((labels == 1) & (
 
116
  scores > text_detection_threshold))[0]
117
  bboxes = bboxes[indices]
118
  scores = scores[indices]
@@ -120,74 +327,159 @@ class Magiv2Processor():
120
  if len(indices) == 0:
121
  indices_of_texts_to_keep.append([])
122
  continue
 
 
123
  scores, labels, indices, bboxes = zip(
124
  *sorted(zip(scores, labels, indices, bboxes), reverse=True))
125
- texts_to_keep = []
126
- texts_to_keep_as_shapely_objects = []
 
 
 
 
127
  for ts, tb, tl, ti in zip(scores, bboxes, labels, indices):
128
- text_polygon = box(tb[0], tb[1], tb[2], tb[3])
129
- should_append = True
 
 
 
130
  for t in texts_to_keep_as_shapely_objects:
 
131
  if t.intersection(text_polygon).area / t.union(text_polygon).area > 0.5:
132
  should_append = False
133
  break
 
134
  if should_append:
135
  texts_to_keep.append((ts, tl, tb, ti))
136
  texts_to_keep_as_shapely_objects.append(text_polygon)
 
 
137
  indices_of_texts_to_keep.append(
138
  [t[3].item() for t in texts_to_keep])
139
  return indices_of_texts_to_keep
140
 
141
- def _get_indices_of_tails_to_keep(self, batch_scores, batch_labels, batch_bboxes, text_detection_threshold):
142
- indices_of_texts_to_keep = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  for scores, labels, bboxes in zip(batch_scores, batch_labels, batch_bboxes):
144
- indices = torch.where((labels == 3) & (
 
145
  scores > text_detection_threshold))[0]
146
  bboxes = bboxes[indices]
147
  scores = scores[indices]
148
  labels = labels[indices]
149
  if len(indices) == 0:
150
- indices_of_texts_to_keep.append([])
151
  continue
 
 
152
  scores, labels, indices, bboxes = zip(
153
  *sorted(zip(scores, labels, indices, bboxes), reverse=True))
154
- texts_to_keep = []
155
- texts_to_keep_as_shapely_objects = []
 
 
 
 
156
  for ts, tb, tl, ti in zip(scores, bboxes, labels, indices):
157
- text_polygon = box(tb[0], tb[1], tb[2], tb[3])
158
- should_append = True
159
- for t in texts_to_keep_as_shapely_objects:
160
- if t.intersection(text_polygon).area / t.union(text_polygon).area > 0.5:
 
 
 
 
161
  should_append = False
162
  break
 
163
  if should_append:
164
- texts_to_keep.append((ts, tl, tb, ti))
165
- texts_to_keep_as_shapely_objects.append(text_polygon)
166
- indices_of_texts_to_keep.append(
167
- [t[3].item() for t in texts_to_keep])
168
- return indices_of_texts_to_keep
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
- def _convert_annotations_to_coco_format(self, annotations):
 
 
 
 
 
 
 
171
  if annotations is None:
172
  return None
 
173
  self._verify_annotations_are_in_correct_format(annotations)
174
- coco_annotations = []
 
175
  for annotation in annotations:
176
- coco_annotation = {
177
  "image_id": annotation["image_id"],
178
  "annotations": [],
179
  }
 
180
  for bbox, label in zip(annotation["bboxes_as_x1y1x2y2"], annotation["labels"]):
181
  coco_annotation["annotations"].append({
 
182
  "bbox": x1y1x2y2_to_xywh(bbox),
183
  "category_id": label,
 
184
  "area": (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]),
185
  })
186
  coco_annotations.append(coco_annotation)
187
  return coco_annotations
188
 
189
- def _verify_annotations_are_in_correct_format(self, annotations):
190
- error_msg = """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
  Annotations must be in the following format:
192
  [
193
  {
@@ -197,20 +489,27 @@ class Magiv2Processor():
197
  },
198
  ...
199
  ]
200
- Labels: 0 for characters, 1 for text, 2 for panels.
201
  """
202
  if annotations is None:
203
  return
 
 
204
  if not isinstance(annotations, List) and not isinstance(annotations, tuple):
205
  raise ValueError(
206
  f"{error_msg} Expected a List/Tuple, found {type(annotations)}."
207
  )
 
208
  if len(annotations) == 0:
209
  return
 
 
210
  if not isinstance(annotations[0], dict):
211
  raise ValueError(
212
- f"{error_msg} Expected a List[Dicct], found {type(annotations[0])}."
213
  )
 
 
214
  if "image_id" not in annotations[0]:
215
  raise ValueError(
216
  f"{error_msg} Dict must contain 'image_id'."
 
1
  from transformers import ConditionalDetrImageProcessor, TrOCRProcessor, ViTImageProcessor
2
  import torch
3
+ from typing import List, Dict, Any, Optional, Tuple
4
  from shapely.geometry import box
5
+ from shapely.geometry.polygon import Polygon
6
  from .utils import x1y1x2y2_to_xywh
7
  import numpy as np
8
+ from numpy.typing import NDArray
9
 
10
 
11
  class Magiv2Processor():
12
+ """
13
+ Procesor danych dla modelu Magiv2 - obsługuje preprocessing i postprocessing.
14
+
15
+ Klasa odpowiedzialna za przygotowanie danych wejściowych dla różnych modułów
16
+ Magiv2 (detekcja, OCR, embeddingi) oraz przetwarzanie outputów. Zawiera również
17
+ metody pomocnicze do filtrowania detekcji i konwersji formatów anotacji.
18
+
19
+ Attributes:
20
+ config: Konfiguracja modelu Magiv2
21
+ detection_image_preprocessor: Preprocessor dla obrazów do detekcji obiektów
22
+ ocr_preprocessor: Preprocessor dla obrazów do OCR
23
+ crop_embedding_image_preprocessor: Preprocessor dla wyciętych fragmentów obrazu
24
+ """
25
+
26
+ def __init__(self, config: Any) -> None:
27
+ """
28
+ Inicjalizuje procesor z podaną konfiguracją.
29
+
30
+ Tworzy preprocessory dla modułów, które są aktywne zgodnie z konfiguracją:
31
+ - Detekcja obiektów: ConditionalDetrImageProcessor
32
+ - OCR: TrOCRProcessor
33
+ - Embeddingi crops: ViTImageProcessor
34
+
35
+ Args:
36
+ config: Obiekt konfiguracji Magiv2Config z parametrami preprocessingu
37
+ """
38
+ self.config: Any = config
39
+ self.detection_image_preprocessor: Optional[ConditionalDetrImageProcessor] = None
40
+ self.ocr_preprocessor: Optional[TrOCRProcessor] = None
41
+ self.crop_embedding_image_preprocessor: Optional[ViTImageProcessor] = None
42
+
43
+ # Inicjalizacja preprocessora dla detekcji obiektów (jeśli aktywny)
44
  if not config.disable_detections:
45
  assert config.detection_image_preprocessing_config is not None
46
  self.detection_image_preprocessor = ConditionalDetrImageProcessor.from_dict(
47
  config.detection_image_preprocessing_config)
48
+
49
+ # Inicjalizacja preprocessora dla OCR (jeśli aktywny)
50
  if not config.disable_ocr:
51
  assert config.ocr_pretrained_processor_path is not None
52
  self.ocr_preprocessor = TrOCRProcessor.from_pretrained(
53
  config.ocr_pretrained_processor_path)
54
+
55
+ # Inicjalizacja preprocessora dla embeddingów crops (jeśli aktywny)
56
  if not config.disable_crop_embeddings:
57
  assert config.crop_embedding_image_preprocessing_config is not None
58
  self.crop_embedding_image_preprocessor = ViTImageProcessor.from_dict(
59
  config.crop_embedding_image_preprocessing_config)
60
 
61
+ def preprocess_inputs_for_detection(
62
+ self,
63
+ images: List[NDArray[np.uint8]],
64
+ annotations: Optional[List[Dict[str, Any]]] = None
65
+ ) -> Dict[str, torch.Tensor]:
66
+ """
67
+ Preprocessuje obrazy do formatu wymaganego przez moduł detekcji obiektów.
68
+
69
+ Wykonuje normalizację, resize i padding obrazów. Jeśli podano anotacje,
70
+ konwertuje je do formatu COCO i skaluje współrzędnie bbox zgodnie z resize.
71
+
72
+ Args:
73
+ images: Lista obrazów jako numpy arrays (format HWC)
74
+ annotations: Opcjonalne anotacje ground truth w formacie:
75
+ [{"image_id": int, "bboxes_as_x1y1x2y2": List, "labels": List}]
76
+
77
+ Returns:
78
+ Słownik z kluczami:
79
+ - "pixel_values": torch.Tensor z preprocessowanymi obrazami
80
+ - "pixel_mask": torch.Tensor z maską paddingu
81
+ - "labels": List[Dict] z przetworzonymi anotacjami (jeśli podano)
82
+ """
83
+ images_list: List[NDArray[np.uint8]] = list(images)
84
+ assert isinstance(images_list[0], np.ndarray)
85
+ # Konwersja anotacji do formatu COCO (bbox w formacie xywh zamiast x1y1x2y2)
86
+ coco_annotations: Optional[List[Dict[str, Any]]
87
+ ] = self._convert_annotations_to_coco_format(annotations)
88
+ # Preprocessing obrazów i anotacji
89
+ inputs: Dict[str, torch.Tensor] = self.detection_image_preprocessor(
90
+ images_list, annotations=coco_annotations, return_tensors="pt")
91
  return inputs
92
 
93
+ def preprocess_inputs_for_ocr(self, images: List[NDArray[np.uint8]]) -> torch.Tensor:
94
+ """
95
+ Preprocessuje obrazy do formatu wymaganego przez moduł OCR.
96
+
97
+ Wykonuje normalizację i resize obrazów tekstowych dla modelu TrOCR.
98
+
99
+ Args:
100
+ images: Lista obrazów jako numpy arrays (fragmenty z tekstem)
101
+
102
+ Returns:
103
+ Tensor z preprocessowanymi obrazami [batch, channels, height, width]
104
+ """
105
+ images_list: List[NDArray[np.uint8]] = list(images)
106
+ assert isinstance(images_list[0], np.ndarray)
107
+ return self.ocr_preprocessor(images_list, return_tensors="pt").pixel_values
108
+
109
+ def preprocess_inputs_for_crop_embeddings(self, images: List[NDArray[np.uint8]]) -> torch.Tensor:
110
+ """
111
+ Preprocessuje wycięte fragmenty obrazów dla modułu embeddingów.
112
 
113
+ Wykonuje normalizację i resize crops dla modelu ViT-MAE.
114
+
115
+ Args:
116
+ images: Lista wyciętych fragmentów obrazów jako numpy arrays
117
+
118
+ Returns:
119
+ Tensor z preprocessowanymi crops [batch, channels, height, width]
120
+ """
121
+ images_list: List[NDArray[np.uint8]] = list(images)
122
+ assert isinstance(images_list[0], np.ndarray)
123
+ return self.crop_embedding_image_preprocessor(images_list, return_tensors="pt").pixel_values
124
 
125
+ def postprocess_ocr_tokens(
126
+ self,
127
+ generated_ids: torch.Tensor,
128
+ skip_special_tokens: bool = True
129
+ ) -> List[str]:
130
+ """
131
+ Dekoduje tokeny wygenerowane przez model OCR na tekst.
132
+
133
+ Args:
134
+ generated_ids: Tensor z ID tokenów wygenerowanych przez decoder OCR
135
+ skip_special_tokens: Czy pomijać specjalne tokeny (PAD, BOS, EOS) w wyniku
136
+
137
+ Returns:
138
+ Lista stringów z rozpoznanym tekstem
139
+ """
140
  return self.ocr_preprocessor.batch_decode(generated_ids, skip_special_tokens=skip_special_tokens)
141
 
142
+ def crop_image(
143
+ self,
144
+ image: NDArray[np.uint8],
145
+ bboxes: List[List[float]]
146
+ ) -> List[NDArray[np.uint8]]:
147
+ """
148
+ Wycina fragmenty obrazu zgodnie z podanymi bounding boxami.
149
+
150
+ Metoda automatycznie naprawia nieprawidłowe bounding boxy:
151
+ - Ogranicza współrzędne do granic obrazu
152
+ - Zapewnia minimalny rozmiar 10x10 pikseli
153
+ - Zamienia współrzędne jeśli są w nieprawidłowej kolejności
154
+
155
+ Args:
156
+ image: Obraz źródłowy jako numpy array (format HWC)
157
+ bboxes: Lista bounding boxów w formacie [x1, y1, x2, y2]
158
+
159
+ Returns:
160
+ Lista wyciętych fragmentów obrazu (każdy jako numpy array)
161
+ """
162
+ crops_for_image: List[NDArray[np.uint8]] = []
163
  for bbox in bboxes:
164
+ x1: float
165
+ y1: float
166
+ x2: float
167
+ y2: float
168
  x1, y1, x2, y2 = bbox
169
 
170
+ # Naprawa bounding boxa w przypadku gdy jest poza granicami lub za mały
171
+ # Konwersja do int
172
  x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
173
+ # Upewnienie się że x1<x2 i y1<y2 (na wypadek odwróconej kolejności)
174
+ x1, y1, x2, y2 = min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)
175
+ # Ograniczenie do granic obrazu (minimum)
176
  x1, y1 = max(0, x1), max(0, y1)
177
  x1, y1 = min(image.shape[1], x1), min(image.shape[0], y1)
178
+ # Ograniczenie do granic obrazu (maksimum)
179
  x2, y2 = max(0, x2), max(0, y2)
180
  x2, y2 = min(image.shape[1], x2), min(image.shape[0], y2)
181
+
182
+ # Zapewnienie minimalnej szerokości 10 pikseli
183
  if x2 - x1 < 10:
184
  if image.shape[1] - x1 > 10:
185
  x2 = x1 + 10
186
  else:
187
  x1 = x2 - 10
188
+
189
+ # Zapewnienie minimalnej wysokości 10 pikseli
190
  if y2 - y1 < 10:
191
  if image.shape[0] - y1 > 10:
192
  y2 = y1 + 10
193
  else:
194
  y1 = y2 - 10
195
 
196
+ # Wycięcie fragmentu obrazu
197
+ crop: NDArray[np.uint8] = image[y1:y2, x1:x2]
198
  crops_for_image.append(crop)
199
  return crops_for_image
200
 
201
+ def _get_indices_of_characters_to_keep(
202
+ self,
203
+ batch_scores: torch.Tensor,
204
+ batch_labels: torch.Tensor,
205
+ batch_bboxes: torch.Tensor,
206
+ character_detection_threshold: float
207
+ ) -> List[torch.Tensor]:
208
+ """
209
+ Filtruje detekcje postaci na podstawie progu prawdopodobieństwa.
210
+
211
+ Zachowuje tylko detekcje z etykietą 0 (postać) i score powyżej progu.
212
+
213
+ Args:
214
+ batch_scores: Tensor ze scorami prawdopodobieństwa [batch, num_queries]
215
+ batch_labels: Tensor z etykietami klas [batch, num_queries]
216
+ batch_bboxes: Tensor z bounding boxami [batch, num_queries, 4]
217
+ character_detection_threshold: Minimalny score do zachowania detekcji (0-1)
218
+
219
+ Returns:
220
+ Lista tensorów z indeksami postaci do zachowania dla każdego obrazu
221
+ """
222
+ indices_of_characters_to_keep: List[torch.Tensor] = []
223
  for scores, labels, _ in zip(batch_scores, batch_labels, batch_bboxes):
224
+ # Filtrowanie: label=0 (postać) AND score > próg
225
+ indices: torch.Tensor = torch.where((labels == 0) & (
226
  scores > character_detection_threshold))[0]
227
  indices_of_characters_to_keep.append(indices)
228
  return indices_of_characters_to_keep
229
 
230
+ def _get_indices_of_panels_to_keep(
231
+ self,
232
+ batch_scores: torch.Tensor,
233
+ batch_labels: torch.Tensor,
234
+ batch_bboxes: torch.Tensor,
235
+ panel_detection_threshold: float
236
+ ) -> List[List[int]]:
237
+ """
238
+ Filtruje detekcje paneli z zastosowaniem NMS (Non-Maximum Suppression).
239
+
240
+ Zachowuje tylko panele z etykietą 2 i score powyżej progu. Dodatkowo
241
+ stosuje NMS aby usunąć nakładające się panele - jeśli nowy panel
242
+ pokrywa się w >50% z już zaakceptowanymi panelami, jest odrzucany.
243
+
244
+ Args:
245
+ batch_scores: Tensor ze scorami [batch, num_queries]
246
+ batch_labels: Tensor z etykietami [batch, num_queries]
247
+ batch_bboxes: Tensor z bboxami [batch, num_queries, 4]
248
+ panel_detection_threshold: Minimalny score do zachowania panelu
249
+
250
+ Returns:
251
+ Lista list indeksów paneli do zachowania (po NMS) dla każdego obrazu
252
+ """
253
+ indices_of_panels_to_keep: List[List[int]] = []
254
  for scores, labels, bboxes in zip(batch_scores, batch_labels, batch_bboxes):
255
+ # Wybranie tylko detekcji z label=2 (panel)
256
+ indices: torch.Tensor = torch.where(labels == 2)[0]
257
  bboxes = bboxes[indices]
258
  scores = scores[indices]
259
  labels = labels[indices]
260
  if len(indices) == 0:
261
  indices_of_panels_to_keep.append([])
262
  continue
263
+
264
+ # Sortowanie paneli malejąco po score (najlepsze pierwsze)
265
  scores, labels, indices, bboxes = zip(
266
  *sorted(zip(scores, labels, indices, bboxes), reverse=True))
267
+
268
+ panels_to_keep: List[Tuple[torch.Tensor,
269
+ torch.Tensor, torch.Tensor, torch.Tensor]] = []
270
+ # Unia wszystkich zaakceptowanych paneli (do sprawdzania nakładania)
271
+ union_of_panels_so_far: Polygon = box(0, 0, 0, 0)
272
+
273
  for ps, pb, pl, pi in zip(scores, bboxes, labels, indices):
274
+ # Konwersja bbox na polygon Shapely
275
+ panel_polygon: Polygon = box(pb[0], pb[1], pb[2], pb[3])
276
+
277
+ # Odrzuć jeśli score poniżej progu
278
  if ps < panel_detection_threshold:
279
  continue
280
+
281
+ # Odrzuć jeśli panel nakłada się >50% z już zaakceptowanymi panelami (NMS)
282
  if union_of_panels_so_far.intersection(panel_polygon).area / panel_polygon.area > 0.5:
283
  continue
284
+
285
+ # Zaakceptuj panel
286
  panels_to_keep.append((ps, pl, pb, pi))
287
+ # Dodaj do unii zaakceptowanych paneli
288
  union_of_panels_so_far = union_of_panels_so_far.union(
289
  panel_polygon)
290
+
291
+ # Wyciągnięcie indeksów zaakceptowanych paneli
292
  indices_of_panels_to_keep.append(
293
  [p[3].item() for p in panels_to_keep])
294
  return indices_of_panels_to_keep
295
 
296
+ def _get_indices_of_texts_to_keep(
297
+ self,
298
+ batch_scores: torch.Tensor,
299
+ batch_labels: torch.Tensor,
300
+ batch_bboxes: torch.Tensor,
301
+ text_detection_threshold: float
302
+ ) -> List[List[int]]:
303
+ """
304
+ Filtruje detekcje tekstu z zastosowaniem NMS (Non-Maximum Suppression).
305
+
306
+ Zachowuje tylko tekst z etykietą 1 i score powyżej progu. Stosuje NMS
307
+ aby usunąć duplikaty - jeśli nowy tekst ma IoU >0.5 z już zaakceptowanym
308
+ tekstem, jest odrzucany.
309
+
310
+ Args:
311
+ batch_scores: Tensor ze scorami [batch, num_queries]
312
+ batch_labels: Tensor z etykietami [batch, num_queries]
313
+ batch_bboxes: Tensor z bboxami [batch, num_queries, 4]
314
+ text_detection_threshold: Minimalny score do zachowania tekstu
315
+
316
+ Returns:
317
+ Lista list indeksów tekstów do zachowania (po NMS) dla każdego obrazu
318
+ """
319
+ indices_of_texts_to_keep: List[List[int]] = []
320
  for scores, labels, bboxes in zip(batch_scores, batch_labels, batch_bboxes):
321
+ # Filtrowanie: label=1 (tekst) AND score > próg
322
+ indices: torch.Tensor = torch.where((labels == 1) & (
323
  scores > text_detection_threshold))[0]
324
  bboxes = bboxes[indices]
325
  scores = scores[indices]
 
327
  if len(indices) == 0:
328
  indices_of_texts_to_keep.append([])
329
  continue
330
+
331
+ # Sortowanie tekstów malejąco po score (najlepsze pierwsze)
332
  scores, labels, indices, bboxes = zip(
333
  *sorted(zip(scores, labels, indices, bboxes), reverse=True))
334
+
335
+ texts_to_keep: List[Tuple[torch.Tensor,
336
+ torch.Tensor, torch.Tensor, torch.Tensor]] = []
337
+ # Lista polygonów zaakceptowanych tekstów (do sprawdzania nakładania)
338
+ texts_to_keep_as_shapely_objects: List[Polygon] = []
339
+
340
  for ts, tb, tl, ti in zip(scores, bboxes, labels, indices):
341
+ # Konwersja bbox na polygon Shapely
342
+ text_polygon: Polygon = box(tb[0], tb[1], tb[2], tb[3])
343
+ should_append: bool = True
344
+
345
+ # Sprawdź nakładanie z już zaakceptowanymi tekstami
346
  for t in texts_to_keep_as_shapely_objects:
347
+ # Jeśli IoU > 0.5, odrzuć (to duplikat)
348
  if t.intersection(text_polygon).area / t.union(text_polygon).area > 0.5:
349
  should_append = False
350
  break
351
+
352
  if should_append:
353
  texts_to_keep.append((ts, tl, tb, ti))
354
  texts_to_keep_as_shapely_objects.append(text_polygon)
355
+
356
+ # Wyciągnięcie indeksów zaakceptowanych tekstów
357
  indices_of_texts_to_keep.append(
358
  [t[3].item() for t in texts_to_keep])
359
  return indices_of_texts_to_keep
360
 
361
+ def _get_indices_of_tails_to_keep(
362
+ self,
363
+ batch_scores: torch.Tensor,
364
+ batch_labels: torch.Tensor,
365
+ batch_bboxes: torch.Tensor,
366
+ text_detection_threshold: float
367
+ ) -> List[List[int]]:
368
+ """
369
+ Filtruje detekcje ogonów dymków z zastosowaniem NMS (Non-Maximum Suppression).
370
+
371
+ Zachowuje tylko ogony z etykietą 3 i score powyżej progu. Stosuje NMS
372
+ aby usunąć duplikaty - jeśli nowy ogon ma IoU >0.5 z już zaakceptowanym
373
+ ogonem, jest odrzucany.
374
+
375
+ Args:
376
+ batch_scores: Tensor ze scorami [batch, num_queries]
377
+ batch_labels: Tensor z etykietami [batch, num_queries]
378
+ batch_bboxes: Tensor z bboxami [batch, num_queries, 4]
379
+ text_detection_threshold: Minimalny score do zachowania ogona
380
+
381
+ Returns:
382
+ Lista list indeksów ogonów do zachowania (po NMS) dla każdego obrazu
383
+ """
384
+ indices_of_tails_to_keep: List[List[int]] = []
385
  for scores, labels, bboxes in zip(batch_scores, batch_labels, batch_bboxes):
386
+ # Filtrowanie: label=3 (ogon dymku) AND score > próg
387
+ indices: torch.Tensor = torch.where((labels == 3) & (
388
  scores > text_detection_threshold))[0]
389
  bboxes = bboxes[indices]
390
  scores = scores[indices]
391
  labels = labels[indices]
392
  if len(indices) == 0:
393
+ indices_of_tails_to_keep.append([])
394
  continue
395
+
396
+ # Sortowanie ogonów malejąco po score (najlepsze pierwsze)
397
  scores, labels, indices, bboxes = zip(
398
  *sorted(zip(scores, labels, indices, bboxes), reverse=True))
399
+
400
+ tails_to_keep: List[Tuple[torch.Tensor,
401
+ torch.Tensor, torch.Tensor, torch.Tensor]] = []
402
+ # Lista polygonów zaakceptowanych ogonów (do sprawdzania nakładania)
403
+ tails_to_keep_as_shapely_objects: List[Polygon] = []
404
+
405
  for ts, tb, tl, ti in zip(scores, bboxes, labels, indices):
406
+ # Konwersja bbox na polygon Shapely
407
+ tail_polygon: Polygon = box(tb[0], tb[1], tb[2], tb[3])
408
+ should_append: bool = True
409
+
410
+ # Sprawdź nakładanie z już zaakceptowanymi ogonami
411
+ for t in tails_to_keep_as_shapely_objects:
412
+ # Jeśli IoU > 0.5, odrzuć (to duplikat)
413
+ if t.intersection(tail_polygon).area / t.union(tail_polygon).area > 0.5:
414
  should_append = False
415
  break
416
+
417
  if should_append:
418
+ tails_to_keep.append((ts, tl, tb, ti))
419
+ tails_to_keep_as_shapely_objects.append(tail_polygon)
420
+
421
+ # Wyciągnięcie indeksów zaakceptowanych ogonów
422
+ indices_of_tails_to_keep.append(
423
+ [t[3].item() for t in tails_to_keep])
424
+ return indices_of_tails_to_keep
425
+
426
+ def _convert_annotations_to_coco_format(
427
+ self,
428
+ annotations: Optional[List[Dict[str, Any]]]
429
+ ) -> Optional[List[Dict[str, Any]]]:
430
+ """
431
+ Konwertuje anotacje z formatu x1y1x2y2 do formatu COCO (xywh).
432
+
433
+ Format COCO używa bbox jako [x, y, width, height] zamiast [x1, y1, x2, y2].
434
+ Dodatkowo oblicza pole powierzchni dla każdego bbox.
435
 
436
+ Args:
437
+ annotations: Lista anotacji w formacie:
438
+ [{"image_id": int, "bboxes_as_x1y1x2y2": List, "labels": List}]
439
+ lub None
440
+
441
+ Returns:
442
+ Lista anotacji w formacie COCO lub None jeśli input był None
443
+ """
444
  if annotations is None:
445
  return None
446
+ # Weryfikacja poprawności formatu anotacji
447
  self._verify_annotations_are_in_correct_format(annotations)
448
+
449
+ coco_annotations: List[Dict[str, Any]] = []
450
  for annotation in annotations:
451
+ coco_annotation: Dict[str, Any] = {
452
  "image_id": annotation["image_id"],
453
  "annotations": [],
454
  }
455
+ # Konwersja każdego bbox z x1y1x2y2 na xywh
456
  for bbox, label in zip(annotation["bboxes_as_x1y1x2y2"], annotation["labels"]):
457
  coco_annotation["annotations"].append({
458
+ # [x1,y1,x2,y2] -> [x,y,w,h]
459
  "bbox": x1y1x2y2_to_xywh(bbox),
460
  "category_id": label,
461
+ # width * height
462
  "area": (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]),
463
  })
464
  coco_annotations.append(coco_annotation)
465
  return coco_annotations
466
 
467
+ def _verify_annotations_are_in_correct_format(self, annotations: Optional[List[Dict[str, Any]]]) -> None:
468
+ """
469
+ Weryfikuje poprawność formatu anotacji.
470
+
471
+ Sprawdza czy anotacje są w oczekiwanym formacie:
472
+ - Lista/tupla słowników
473
+ - Każdy słownik zawiera klucze: "image_id", "bboxes_as_x1y1x2y2", "labels"
474
+ - Labels: 0=postać, 1=tekst, 2=panel, 3=ogon
475
+
476
+ Args:
477
+ annotations: Anotacje do weryfikacji lub None
478
+
479
+ Raises:
480
+ ValueError: Jeśli format anotacji jest nieprawidłowy
481
+ """
482
+ error_msg: str = """
483
  Annotations must be in the following format:
484
  [
485
  {
 
489
  },
490
  ...
491
  ]
492
+ Labels: 0 for characters, 1 for text, 2 for panels, 3 for tails.
493
  """
494
  if annotations is None:
495
  return
496
+
497
+ # Sprawdzenie czy to lista lub tupla
498
  if not isinstance(annotations, List) and not isinstance(annotations, tuple):
499
  raise ValueError(
500
  f"{error_msg} Expected a List/Tuple, found {type(annotations)}."
501
  )
502
+
503
  if len(annotations) == 0:
504
  return
505
+
506
+ # Sprawdzenie czy elementy to słowniki
507
  if not isinstance(annotations[0], dict):
508
  raise ValueError(
509
+ f"{error_msg} Expected a List[Dict], found {type(annotations[0])}."
510
  )
511
+
512
+ # Sprawdzenie wymaganych kluczy w słowniku
513
  if "image_id" not in annotations[0]:
514
  raise ValueError(
515
  f"{error_msg} Dict must contain 'image_id'."
processing_magiv2_PRE.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import ConditionalDetrImageProcessor, TrOCRProcessor, ViTImageProcessor
2
+ import torch
3
+ from typing import List
4
+ from shapely.geometry import box
5
+ from .utils import x1y1x2y2_to_xywh
6
+ import numpy as np
7
+
8
+
9
+ class Magiv2Processor():
10
+ def __init__(self, config):
11
+ self.config = config
12
+ self.detection_image_preprocessor = None
13
+ self.ocr_preprocessor = None
14
+ self.crop_embedding_image_preprocessor = None
15
+ if not config.disable_detections:
16
+ assert config.detection_image_preprocessing_config is not None
17
+ self.detection_image_preprocessor = ConditionalDetrImageProcessor.from_dict(
18
+ config.detection_image_preprocessing_config)
19
+ if not config.disable_ocr:
20
+ assert config.ocr_pretrained_processor_path is not None
21
+ self.ocr_preprocessor = TrOCRProcessor.from_pretrained(
22
+ config.ocr_pretrained_processor_path)
23
+ if not config.disable_crop_embeddings:
24
+ assert config.crop_embedding_image_preprocessing_config is not None
25
+ self.crop_embedding_image_preprocessor = ViTImageProcessor.from_dict(
26
+ config.crop_embedding_image_preprocessing_config)
27
+
28
+ def preprocess_inputs_for_detection(self, images, annotations=None):
29
+ images = list(images)
30
+ assert isinstance(images[0], np.ndarray)
31
+ annotations = self._convert_annotations_to_coco_format(annotations)
32
+ inputs = self.detection_image_preprocessor(
33
+ images, annotations=annotations, return_tensors="pt")
34
+ return inputs
35
+
36
+ def preprocess_inputs_for_ocr(self, images):
37
+ images = list(images)
38
+ assert isinstance(images[0], np.ndarray)
39
+ return self.ocr_preprocessor(images, return_tensors="pt").pixel_values
40
+
41
+ def preprocess_inputs_for_crop_embeddings(self, images):
42
+ images = list(images)
43
+ assert isinstance(images[0], np.ndarray)
44
+ return self.crop_embedding_image_preprocessor(images, return_tensors="pt").pixel_values
45
+
46
+ def postprocess_ocr_tokens(self, generated_ids, skip_special_tokens=True):
47
+ return self.ocr_preprocessor.batch_decode(generated_ids, skip_special_tokens=skip_special_tokens)
48
+
49
+ def crop_image(self, image, bboxes):
50
+ crops_for_image = []
51
+ for bbox in bboxes:
52
+ x1, y1, x2, y2 = bbox
53
+
54
+ # fix the bounding box in case it is out of bounds or too small
55
+ x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
56
+ x1, y1, x2, y2 = min(x1, x2), min(y1, y2), max(
57
+ x1, x2), max(y1, y2) # just incase
58
+ x1, y1 = max(0, x1), max(0, y1)
59
+ x1, y1 = min(image.shape[1], x1), min(image.shape[0], y1)
60
+ x2, y2 = max(0, x2), max(0, y2)
61
+ x2, y2 = min(image.shape[1], x2), min(image.shape[0], y2)
62
+ if x2 - x1 < 10:
63
+ if image.shape[1] - x1 > 10:
64
+ x2 = x1 + 10
65
+ else:
66
+ x1 = x2 - 10
67
+ if y2 - y1 < 10:
68
+ if image.shape[0] - y1 > 10:
69
+ y2 = y1 + 10
70
+ else:
71
+ y1 = y2 - 10
72
+
73
+ crop = image[y1:y2, x1:x2]
74
+ crops_for_image.append(crop)
75
+ return crops_for_image
76
+
77
+ def _get_indices_of_characters_to_keep(self, batch_scores, batch_labels, batch_bboxes, character_detection_threshold):
78
+ indices_of_characters_to_keep = []
79
+ for scores, labels, _ in zip(batch_scores, batch_labels, batch_bboxes):
80
+ indices = torch.where((labels == 0) & (
81
+ scores > character_detection_threshold))[0]
82
+ indices_of_characters_to_keep.append(indices)
83
+ return indices_of_characters_to_keep
84
+
85
+ def _get_indices_of_panels_to_keep(self, batch_scores, batch_labels, batch_bboxes, panel_detection_threshold):
86
+ indices_of_panels_to_keep = []
87
+ for scores, labels, bboxes in zip(batch_scores, batch_labels, batch_bboxes):
88
+ indices = torch.where(labels == 2)[0]
89
+ bboxes = bboxes[indices]
90
+ scores = scores[indices]
91
+ labels = labels[indices]
92
+ if len(indices) == 0:
93
+ indices_of_panels_to_keep.append([])
94
+ continue
95
+ scores, labels, indices, bboxes = zip(
96
+ *sorted(zip(scores, labels, indices, bboxes), reverse=True))
97
+ panels_to_keep = []
98
+ union_of_panels_so_far = box(0, 0, 0, 0)
99
+ for ps, pb, pl, pi in zip(scores, bboxes, labels, indices):
100
+ panel_polygon = box(pb[0], pb[1], pb[2], pb[3])
101
+ if ps < panel_detection_threshold:
102
+ continue
103
+ if union_of_panels_so_far.intersection(panel_polygon).area / panel_polygon.area > 0.5:
104
+ continue
105
+ panels_to_keep.append((ps, pl, pb, pi))
106
+ union_of_panels_so_far = union_of_panels_so_far.union(
107
+ panel_polygon)
108
+ indices_of_panels_to_keep.append(
109
+ [p[3].item() for p in panels_to_keep])
110
+ return indices_of_panels_to_keep
111
+
112
+ def _get_indices_of_texts_to_keep(self, batch_scores, batch_labels, batch_bboxes, text_detection_threshold):
113
+ indices_of_texts_to_keep = []
114
+ for scores, labels, bboxes in zip(batch_scores, batch_labels, batch_bboxes):
115
+ indices = torch.where((labels == 1) & (
116
+ scores > text_detection_threshold))[0]
117
+ bboxes = bboxes[indices]
118
+ scores = scores[indices]
119
+ labels = labels[indices]
120
+ if len(indices) == 0:
121
+ indices_of_texts_to_keep.append([])
122
+ continue
123
+ scores, labels, indices, bboxes = zip(
124
+ *sorted(zip(scores, labels, indices, bboxes), reverse=True))
125
+ texts_to_keep = []
126
+ texts_to_keep_as_shapely_objects = []
127
+ for ts, tb, tl, ti in zip(scores, bboxes, labels, indices):
128
+ text_polygon = box(tb[0], tb[1], tb[2], tb[3])
129
+ should_append = True
130
+ for t in texts_to_keep_as_shapely_objects:
131
+ if t.intersection(text_polygon).area / t.union(text_polygon).area > 0.5:
132
+ should_append = False
133
+ break
134
+ if should_append:
135
+ texts_to_keep.append((ts, tl, tb, ti))
136
+ texts_to_keep_as_shapely_objects.append(text_polygon)
137
+ indices_of_texts_to_keep.append(
138
+ [t[3].item() for t in texts_to_keep])
139
+ return indices_of_texts_to_keep
140
+
141
+ def _get_indices_of_tails_to_keep(self, batch_scores, batch_labels, batch_bboxes, text_detection_threshold):
142
+ indices_of_texts_to_keep = []
143
+ for scores, labels, bboxes in zip(batch_scores, batch_labels, batch_bboxes):
144
+ indices = torch.where((labels == 3) & (
145
+ scores > text_detection_threshold))[0]
146
+ bboxes = bboxes[indices]
147
+ scores = scores[indices]
148
+ labels = labels[indices]
149
+ if len(indices) == 0:
150
+ indices_of_texts_to_keep.append([])
151
+ continue
152
+ scores, labels, indices, bboxes = zip(
153
+ *sorted(zip(scores, labels, indices, bboxes), reverse=True))
154
+ texts_to_keep = []
155
+ texts_to_keep_as_shapely_objects = []
156
+ for ts, tb, tl, ti in zip(scores, bboxes, labels, indices):
157
+ text_polygon = box(tb[0], tb[1], tb[2], tb[3])
158
+ should_append = True
159
+ for t in texts_to_keep_as_shapely_objects:
160
+ if t.intersection(text_polygon).area / t.union(text_polygon).area > 0.5:
161
+ should_append = False
162
+ break
163
+ if should_append:
164
+ texts_to_keep.append((ts, tl, tb, ti))
165
+ texts_to_keep_as_shapely_objects.append(text_polygon)
166
+ indices_of_texts_to_keep.append(
167
+ [t[3].item() for t in texts_to_keep])
168
+ return indices_of_texts_to_keep
169
+
170
+ def _convert_annotations_to_coco_format(self, annotations):
171
+ if annotations is None:
172
+ return None
173
+ self._verify_annotations_are_in_correct_format(annotations)
174
+ coco_annotations = []
175
+ for annotation in annotations:
176
+ coco_annotation = {
177
+ "image_id": annotation["image_id"],
178
+ "annotations": [],
179
+ }
180
+ for bbox, label in zip(annotation["bboxes_as_x1y1x2y2"], annotation["labels"]):
181
+ coco_annotation["annotations"].append({
182
+ "bbox": x1y1x2y2_to_xywh(bbox),
183
+ "category_id": label,
184
+ "area": (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]),
185
+ })
186
+ coco_annotations.append(coco_annotation)
187
+ return coco_annotations
188
+
189
+ def _verify_annotations_are_in_correct_format(self, annotations):
190
+ error_msg = """
191
+ Annotations must be in the following format:
192
+ [
193
+ {
194
+ "image_id": 0,
195
+ "bboxes_as_x1y1x2y2": [[0, 0, 10, 10], [10, 10, 20, 20], [20, 20, 30, 30]],
196
+ "labels": [0, 1, 2],
197
+ },
198
+ ...
199
+ ]
200
+ Labels: 0 for characters, 1 for text, 2 for panels.
201
+ """
202
+ if annotations is None:
203
+ return
204
+ if not isinstance(annotations, List) and not isinstance(annotations, tuple):
205
+ raise ValueError(
206
+ f"{error_msg} Expected a List/Tuple, found {type(annotations)}."
207
+ )
208
+ if len(annotations) == 0:
209
+ return
210
+ if not isinstance(annotations[0], dict):
211
+ raise ValueError(
212
+ f"{error_msg} Expected a List[Dicct], found {type(annotations[0])}."
213
+ )
214
+ if "image_id" not in annotations[0]:
215
+ raise ValueError(
216
+ f"{error_msg} Dict must contain 'image_id'."
217
+ )
218
+ if "bboxes_as_x1y1x2y2" not in annotations[0]:
219
+ raise ValueError(
220
+ f"{error_msg} Dict must contain 'bboxes_as_x1y1x2y2'."
221
+ )
222
+ if "labels" not in annotations[0]:
223
+ raise ValueError(
224
+ f"{error_msg} Dict must contain 'labels'."
225
+ )
utils.py CHANGED
@@ -1,36 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  import numpy as np
3
  import random
4
  import matplotlib.pyplot as plt
5
  import matplotlib.patches as patches
6
  from shapely.geometry import Point, box
 
7
  import networkx as nx
8
  from copy import deepcopy
9
  from itertools import groupby
 
 
 
 
 
 
 
10
 
 
 
 
 
 
 
11
 
12
- def move_to_device(inputs, device):
 
 
 
 
 
 
13
  if hasattr(inputs, "keys"):
 
14
  return {k: move_to_device(v, device) for k, v in inputs.items()}
15
  elif isinstance(inputs, list):
 
16
  return [move_to_device(v, device) for v in inputs]
17
  elif isinstance(inputs, tuple):
 
18
  return tuple([move_to_device(v, device) for v in inputs])
19
  elif isinstance(inputs, np.ndarray):
 
20
  return torch.from_numpy(inputs).to(device)
21
  else:
 
22
  return inputs.to(device)
23
 
24
 
25
  class UnionFind:
26
- def __init__(self, n):
27
- self.parent = list(range(n))
28
- self.size = [1] * n
29
- self.num_components = n
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  @classmethod
32
- def from_adj_matrix(cls, adj_matrix):
33
- ufds = cls(adj_matrix.shape[0])
 
 
 
 
 
 
 
 
 
 
 
34
  for i in range(adj_matrix.shape[0]):
35
  for j in range(adj_matrix.shape[1]):
36
  if adj_matrix[i, j] > 0:
@@ -38,229 +242,482 @@ class UnionFind:
38
  return ufds
39
 
40
  @classmethod
41
- def from_adj_list(cls, adj_list):
42
- ufds = cls(len(adj_list))
 
 
 
 
 
 
 
 
 
43
  for i in range(len(adj_list)):
44
  for j in adj_list[i]:
45
  ufds.unite(i, j)
46
  return ufds
47
 
48
  @classmethod
49
- def from_edge_list(cls, edge_list, num_nodes):
50
- ufds = cls(num_nodes)
 
 
 
 
 
 
 
 
 
 
51
  for edge in edge_list:
52
  ufds.unite(edge[0], edge[1])
53
  return ufds
54
 
55
- def find(self, x):
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  if self.parent[x] == x:
57
  return x
 
58
  self.parent[x] = self.find(self.parent[x])
59
  return self.parent[x]
60
 
61
- def unite(self, x, y):
 
 
 
 
 
 
 
 
 
 
62
  x = self.find(x)
63
  y = self.find(y)
64
  if x != y:
 
65
  if self.size[x] < self.size[y]:
66
  x, y = y, x
67
  self.parent[y] = x
68
  self.size[x] += self.size[y]
69
  self.num_components -= 1
70
 
71
- def get_components_of(self, x):
 
 
 
 
 
 
 
 
 
72
  x = self.find(x)
73
  return [i for i in range(len(self.parent)) if self.find(i) == x]
74
 
75
- def are_connected(self, x, y):
 
 
 
 
 
 
 
 
 
 
76
  return self.find(x) == self.find(y)
77
 
78
- def get_size(self, x):
 
 
 
 
 
 
 
 
 
79
  return self.size[self.find(x)]
80
 
81
- def get_num_components(self):
 
 
 
 
 
 
82
  return self.num_components
83
 
84
- def get_labels_for_connected_components(self):
85
- map_parent_to_label = {}
86
- labels = []
 
 
 
 
 
 
 
 
 
87
  for i in range(len(self.parent)):
88
- parent = self.find(i)
89
  if parent not in map_parent_to_label:
90
  map_parent_to_label[parent] = len(map_parent_to_label)
91
  labels.append(map_parent_to_label[parent])
92
  return labels
93
 
94
 
95
- def visualise_single_image_prediction(image_as_np_array, predictions, filename):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  figure, subplot = plt.subplots(1, 1, figsize=(10, 10))
97
  subplot.imshow(image_as_np_array)
 
 
98
  plot_bboxes(subplot, predictions["panels"], color="green")
99
  plot_bboxes(subplot, predictions["texts"], color="red",
100
  visibility=predictions["is_essential_text"])
101
  plot_bboxes(subplot, predictions["characters"], color="blue")
102
  plot_bboxes(subplot, predictions["tails"], color="purple")
103
 
 
104
  for i, name in enumerate(predictions["character_names"]):
105
- char_bbox = predictions["characters"][i]
 
 
 
 
106
  x1, y1, x2, y2 = char_bbox
107
  subplot.text(x1, y1 - 2, name,
108
  verticalalignment='bottom', horizontalalignment='left',
109
- # Background settings
110
  bbox=dict(facecolor='blue', alpha=1, edgecolor='none'),
111
  color='white', fontsize=8)
112
 
113
- COLOURS = [
114
- "#b7ff51", # green
115
- "#f50a8f", # pink
116
- "#4b13b6", # purple
117
- "#ddaa34", # orange
118
- "#bea2a2", # brown
 
119
  ]
120
- colour_index = 0
121
- character_cluster_labels = predictions["character_cluster_labels"]
122
- unique_label_sorted_by_frequency = sorted(list(set(
 
123
  character_cluster_labels)), key=lambda x: character_cluster_labels.count(x), reverse=True)
 
 
124
  for label in unique_label_sorted_by_frequency:
125
- root = None
126
- others = []
 
127
  for i in range(len(predictions["characters"])):
128
  if character_cluster_labels[i] == label:
129
  if root is None:
130
- root = i
131
  else:
132
- others.append(i)
 
 
133
  if colour_index >= len(COLOURS):
134
- random_colour = COLOURS[0]
 
135
  while random_colour in COLOURS:
136
  random_colour = "#" + \
137
  "".join([random.choice("0123456789ABCDEF")
138
  for j in range(6)])
139
  else:
140
- random_colour = COLOURS[colour_index]
141
  colour_index += 1
142
- bbox_i = predictions["characters"][root]
143
- x1 = bbox_i[0] + (bbox_i[2] - bbox_i[0]) / 2
144
- y1 = bbox_i[1] + (bbox_i[3] - bbox_i[1]) / 2
 
 
 
145
  subplot.plot([x1], [y1], color=random_colour, marker="o", markersize=5)
 
 
146
  for j in others:
147
- # draw line from centre of bbox i to centre of bbox j
148
- bbox_j = predictions["characters"][j]
149
  x1 = bbox_i[0] + (bbox_i[2] - bbox_i[0]) / 2
150
  y1 = bbox_i[1] + (bbox_i[3] - bbox_i[1]) / 2
151
- x2 = bbox_j[0] + (bbox_j[2] - bbox_j[0]) / 2
152
- y2 = bbox_j[1] + (bbox_j[3] - bbox_j[1]) / 2
 
153
  subplot.plot([x1, x2], [y1, y2], color=random_colour, linewidth=2)
 
154
  subplot.plot([x2], [y2], color=random_colour,
155
  marker="o", markersize=5)
156
 
 
157
  for (i, j) in predictions["text_character_associations"]:
158
- bbox_i = predictions["texts"][i]
159
- bbox_j = predictions["characters"][j]
 
160
  if not predictions["is_essential_text"][i]:
161
  continue
162
- x1 = bbox_i[0] + (bbox_i[2] - bbox_i[0]) / 2
163
- y1 = bbox_i[1] + (bbox_i[3] - bbox_i[1]) / 2
164
- x2 = bbox_j[0] + (bbox_j[2] - bbox_j[0]) / 2
165
- y2 = bbox_j[1] + (bbox_j[3] - bbox_j[1]) / 2
 
 
166
  subplot.plot([x1, x2], [y1, y2], color="red",
167
  linewidth=2, linestyle="dashed")
168
 
 
169
  for (i, j) in predictions["text_tail_associations"]:
170
- bbox_i = predictions["texts"][i]
171
- bbox_j = predictions["tails"][j]
172
- x1 = bbox_i[0] + (bbox_i[2] - bbox_i[0]) / 2
173
- y1 = bbox_i[1] + (bbox_i[3] - bbox_i[1]) / 2
174
- x2 = bbox_j[0] + (bbox_j[2] - bbox_j[0]) / 2
175
- y2 = bbox_j[1] + (bbox_j[3] - bbox_j[1]) / 2
 
 
176
  subplot.plot([x1, x2], [y1, y2], color="purple",
177
  linewidth=2, linestyle="dashed")
178
 
 
179
  subplot.axis("off")
 
180
  if filename is not None:
181
  plt.savefig(filename, bbox_inches="tight", pad_inches=0)
182
 
 
183
  figure.canvas.draw()
184
- image = np.array(figure.canvas.renderer._renderer)
185
  plt.close()
186
  return image
187
 
188
 
189
- def plot_bboxes(subplot, bboxes, color="red", visibility=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  if visibility is None:
191
  visibility = [1] * len(bboxes)
 
192
  for id, bbox in enumerate(bboxes):
 
193
  if visibility[id] == 0:
194
  continue
195
- w = bbox[2] - bbox[0]
196
- h = bbox[3] - bbox[1]
197
- rect = patches.Rectangle(
 
 
198
  bbox[:2], w, h, linewidth=1, edgecolor=color, facecolor="none", linestyle="solid"
199
  )
200
  subplot.add_patch(rect)
201
 
202
 
203
- def sort_panels(rects):
204
- before_rects = convert_to_list_of_lists(rects)
205
- # slightly erode all rectangles initially to account for imperfect detections
206
- rects = [erode_rectangle(rect, 0.05) for rect in before_rects]
207
- G = nx.DiGraph()
208
- G.add_nodes_from(range(len(rects)))
209
- for i in range(len(rects)):
210
- for j in range(len(rects)):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  if i == j:
212
  continue
213
- if is_there_a_directed_edge(i, j, rects):
214
- G.add_edge(i, j, weight=get_distance(rects[i], rects[j]))
 
 
215
  else:
216
- G.add_edge(j, i, weight=get_distance(rects[i], rects[j]))
 
 
 
217
  while True:
218
- cycles = sorted(nx.simple_cycles(G))
219
  cycles = [cycle for cycle in cycles if len(cycle) > 1]
220
  if len(cycles) == 0:
221
  break
222
- cycle = cycles[0]
223
- edges = [e for e in zip(cycle, cycle[1:] + cycle[:1])]
224
- max_cyclic_edge = max(edges, key=lambda x: G.edges[x]["weight"])
 
 
 
 
 
225
  G.remove_edge(*max_cyclic_edge)
 
 
226
  return list(nx.topological_sort(G))
227
 
228
 
229
- def is_strictly_above(rectA, rectB):
 
 
 
 
 
230
  x1A, y1A, x2A, y2A = rectA
 
 
 
 
231
  x1B, y1B, x2B, y2B = rectB
232
  return y2A < y1B
233
 
234
 
235
- def is_strictly_below(rectA, rectB):
 
 
 
 
 
236
  x1A, y1A, x2A, y2A = rectA
 
 
 
 
237
  x1B, y1B, x2B, y2B = rectB
238
  return y2B < y1A
239
 
240
 
241
- def is_strictly_left_of(rectA, rectB):
 
 
 
 
 
242
  x1A, y1A, x2A, y2A = rectA
 
 
 
 
243
  x1B, y1B, x2B, y2B = rectB
244
  return x2A < x1B
245
 
246
 
247
- def is_strictly_right_of(rectA, rectB):
 
 
 
 
 
248
  x1A, y1A, x2A, y2A = rectA
 
 
 
 
249
  x1B, y1B, x2B, y2B = rectB
250
  return x2B < x1A
251
 
252
 
253
- def intersects(rectA, rectB):
 
254
  return box(*rectA).intersects(box(*rectB))
255
 
256
 
257
- def is_there_a_directed_edge(a, b, rects):
258
- rectA = rects[a]
259
- rectB = rects[b]
260
- centre_of_A = [rectA[0] + (rectA[2] - rectA[0]) / 2,
261
- rectA[1] + (rectA[3] - rectA[1]) / 2]
262
- centre_of_B = [rectB[0] + (rectB[2] - rectB[0]) / 2,
263
- rectB[1] + (rectB[3] - rectB[1]) / 2]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
  if np.allclose(np.array(centre_of_A), np.array(centre_of_B)):
265
  return box(*rectA).area > (box(*rectB)).area
266
  copy_A = [rectA[0], rectA[1], rectA[2], rectA[3]]
@@ -283,172 +740,430 @@ def is_there_a_directed_edge(a, b, rects):
283
  copy_B = erode_rectangle(copy_B, 0.05)
284
 
285
 
286
- def get_distance(rectA, rectB):
 
 
 
 
 
 
 
 
 
 
287
  return box(rectA[0], rectA[1], rectA[2], rectA[3]).distance(box(rectB[0], rectB[1], rectB[2], rectB[3]))
288
 
289
 
290
- def use_cuts_to_determine_edge_from_a_to_b(a, b, rects):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
  rects = deepcopy(rects)
 
292
  while True:
 
 
 
 
 
293
  xmin, ymin, xmax, ymax = min(rects[a][0], rects[b][0]), min(
294
  rects[a][1], rects[b][1]), max(rects[a][2], rects[b][2]), max(rects[a][3], rects[b][3])
295
- rect_index = [i for i in range(len(rects)) if intersects(
 
 
296
  rects[i], [xmin, ymin, xmax, ymax])]
297
- rects_copy = [rect for rect in rects if intersects(
 
298
  rect, [xmin, ymin, xmax, ymax])]
299
 
300
- # try to split the panels using a "horizontal" lines
301
- overlapping_y_ranges = merge_overlapping_ranges(
 
302
  [(y1, y2) for x1, y1, x2, y2 in rects_copy])
303
- panel_index_to_split = {}
 
 
304
  for split_index, (y1, y2) in enumerate(overlapping_y_ranges):
305
  for i, index in enumerate(rect_index):
 
306
  if y1 <= rects_copy[i][1] <= rects_copy[i][3] <= y2:
307
  panel_index_to_split[index] = split_index
308
 
 
309
  if panel_index_to_split[a] != panel_index_to_split[b]:
310
  return panel_index_to_split[a] < panel_index_to_split[b]
311
 
312
- # try to split the panels using a "vertical" lines
313
- overlapping_x_ranges = merge_overlapping_ranges(
 
314
  [(x1, x2) for x1, y1, x2, y2 in rects_copy])
315
- panel_index_to_split = {}
 
 
 
316
  for split_index, (x1, x2) in enumerate(overlapping_x_ranges[::-1]):
317
  for i, index in enumerate(rect_index):
 
318
  if x1 <= rects_copy[i][0] <= rects_copy[i][2] <= x2:
319
  panel_index_to_split[index] = split_index
 
 
320
  if panel_index_to_split[a] != panel_index_to_split[b]:
321
  return panel_index_to_split[a] < panel_index_to_split[b]
322
 
323
- # otherwise, erode the rectangles and try again
 
324
  rects = [erode_rectangle(rect, 0.05) for rect in rects]
325
 
326
 
327
- def erode_rectangle(bbox, erosion_factor):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
328
  x1, y1, x2, y2 = bbox
 
 
329
  w, h = x2 - x1, y2 - y1
 
 
 
330
  cx, cy = x1 + w / 2, y1 + h / 2
 
 
331
  if w < h:
332
- aspect_ratio = w / h
333
- erosion_factor_width = erosion_factor * aspect_ratio
334
- erosion_factor_height = erosion_factor
335
  else:
336
- aspect_ratio = h / w
337
- erosion_factor_width = erosion_factor
338
- erosion_factor_height = erosion_factor * aspect_ratio
 
 
339
  w = w - w * erosion_factor_width
340
  h = h - h * erosion_factor_height
 
341
  x1, y1, x2, y2 = cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2
342
  return [x1, y1, x2, y2]
343
 
344
 
345
- def merge_overlapping_ranges(ranges):
346
  """
347
- ranges: list of tuples (x1, x2)
 
 
 
 
 
 
 
 
 
348
  """
349
  if len(ranges) == 0:
350
  return []
351
- ranges = sorted(ranges, key=lambda x: x[0])
352
- merged_ranges = []
353
- for i, r in enumerate(ranges):
 
 
 
 
354
  if i == 0:
355
  prev_x1, prev_x2 = r
356
  continue
 
 
357
  x1, x2 = r
 
358
  if x1 > prev_x2:
359
  merged_ranges.append((prev_x1, prev_x2))
360
  prev_x1, prev_x2 = x1, x2
361
  else:
 
362
  prev_x2 = max(prev_x2, x2)
 
363
  merged_ranges.append((prev_x1, prev_x2))
364
  return merged_ranges
365
 
366
 
367
- def sort_text_boxes_in_reading_order(text_bboxes, sorted_panel_bboxes):
368
- text_bboxes = convert_to_list_of_lists(text_bboxes)
369
- sorted_panel_bboxes = convert_to_list_of_lists(sorted_panel_bboxes)
 
 
 
 
 
 
 
 
 
 
 
 
370
 
371
- if len(text_bboxes) == 0:
 
 
 
 
 
 
 
372
  return []
373
 
374
- def indices_of_same_elements(nums):
 
375
  groups = groupby(range(len(nums)), key=lambda i: nums[i])
376
  return [list(indices) for _, indices in groups]
377
 
378
- panel_id_for_text = get_text_to_panel_mapping(
379
- text_bboxes, sorted_panel_bboxes)
380
- indices_of_texts = list(range(len(text_bboxes)))
 
 
381
  indices_of_texts, panel_id_for_text = zip(
382
  *sorted(zip(indices_of_texts, panel_id_for_text), key=lambda x: x[1]))
383
  indices_of_texts = list(indices_of_texts)
384
- grouped_indices = indices_of_same_elements(panel_id_for_text)
 
 
 
385
  for group in grouped_indices:
386
- subset_of_text_indices = [indices_of_texts[i] for i in group]
387
- text_bboxes_of_subset = [text_bboxes[i]
388
- for i in subset_of_text_indices]
389
- sorted_subset_indices = sort_texts_within_panel(text_bboxes_of_subset)
 
 
 
390
  indices_of_texts[group[0]: group[-1] + 1] = [subset_of_text_indices[i]
391
  for i in sorted_subset_indices]
392
  return indices_of_texts
393
 
394
 
395
- def get_text_to_panel_mapping(text_bboxes, sorted_panel_bboxes):
396
- text_to_panel_mapping = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
397
  for text_bbox in text_bboxes:
398
- shapely_text_polygon = box(*text_bbox)
399
- all_intersections = []
400
- all_distances = []
 
 
 
 
401
  if len(sorted_panel_bboxes) == 0:
402
  text_to_panel_mapping.append(-1)
403
  continue
 
 
404
  for j, annotation in enumerate(sorted_panel_bboxes):
405
- shapely_annotation_polygon = box(*annotation)
 
 
 
406
  if shapely_text_polygon.intersects(shapely_annotation_polygon):
407
- all_intersections.append(
408
- (shapely_text_polygon.intersection(shapely_annotation_polygon).area, j))
409
- all_distances.append(
410
- (shapely_text_polygon.distance(shapely_annotation_polygon), j))
 
 
 
 
 
 
411
  if len(all_intersections) == 0:
412
- text_to_panel_mapping.append(
413
- min(all_distances, key=lambda x: x[0])[1])
 
 
414
  else:
415
- text_to_panel_mapping.append(
416
- max(all_intersections, key=lambda x: x[0])[1])
 
 
 
417
  return text_to_panel_mapping
418
 
419
 
420
- def sort_texts_within_panel(rects):
421
- smallest_y = float("inf")
422
- greatest_x = float("-inf")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
423
  for i, rect in enumerate(rects):
 
 
 
 
424
  x1, y1, x2, y2 = rect
425
- smallest_y = min(smallest_y, y1)
426
- greatest_x = max(greatest_x, x2)
427
 
428
- reference_point = Point(greatest_x, smallest_y)
 
429
 
430
- polygons_and_index = []
 
431
  for i, rect in enumerate(rects):
 
 
 
 
432
  x1, y1, x2, y2 = rect
433
  polygons_and_index.append((box(x1, y1, x2, y2), i))
434
- # sort points by closest to reference point
 
435
  polygons_and_index = sorted(
436
  polygons_and_index, key=lambda x: reference_point.distance(x[0]))
437
- indices = [x[1] for x in polygons_and_index]
 
 
438
  return indices
439
 
440
 
441
- def x1y1wh_to_x1y1x2y2(bbox):
 
 
 
 
 
 
 
 
 
 
 
 
 
442
  x1, y1, w, h = bbox
443
  return [x1, y1, x1 + w, y1 + h]
444
 
445
 
446
- def x1y1x2y2_to_xywh(bbox):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
447
  x1, y1, x2, y2 = bbox
448
  return [x1, y1, x2 - x1, y2 - y1]
449
 
450
 
451
- def convert_to_list_of_lists(rects):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
452
  if isinstance(rects, torch.Tensor):
453
  return rects.tolist()
454
  if isinstance(rects, np.ndarray):
 
1
+ """
2
+ Funkcje pomocnicze dla modelu Magiv2.
3
+
4
+ ═══════════════════════════════════════════════════════════════════════════════
5
+ STRESZCZENIE ZAWARTOŚCI PLIKU
6
+ ═══════════════════════════════════════════════════════════════════════════════
7
+
8
+ Ten moduł zawiera narzędzia pomocnicze do przetwarzania i wizualizacji wyników
9
+ modelu Magiv2 dla analizy komiksów/mangi. Plik składa się z 5 głównych kategorii:
10
+
11
+ 1. ZARZĄDZANIE URZĄDZENIAMI
12
+ ├─ move_to_device() - Rekurencyjne przenoszenie danych między CPU/GPU
13
+ │ Obsługuje: dict, list, tuple, numpy.ndarray, torch.Tensor
14
+ └─ Używane przy każdym wywołaniu modelu do przeniesienia danych na właściwe urządzenie
15
+
16
+ 2. STRUKTURA UNION-FIND DO KLASTROWANIA (linie ~53-190)
17
+ ├─ class UnionFind - Disjoint Set Union z kompresją ścieżki i union by size
18
+ │ ├─ __init__(n) - Inicjalizacja n rozłącznych elementów
19
+ │ ├─ from_adj_matrix() - Tworzenie z macierzy sąsiedztwa
20
+ │ ├─ from_adj_list() - Tworzenie z listy sąsiedztwa
21
+ │ ├─ from_edge_list() - Tworzenie z listy krawędzi
22
+ │ ├─ find(x) - Znajdowanie korzenia zbioru (z path compression)
23
+ │ ├─ unite(x, y) - Łączenie zbiorów (z union by size)
24
+ │ ├─ get_components_of(x) - Wszystkie elementy w zbiorze x
25
+ │ ├─ are_connected(x, y) - Sprawdzanie czy x i y w tym samym zbiorze
26
+ │ ├─ get_size(x) - Rozmiar zbioru zawierającego x
27
+ │ ├─ get_num_components() - Liczba rozłącznych zbiorów
28
+ │ └─ get_labels_for_connected_components() - Generowanie etykiet klastrów
29
+ └─ Używane do grupowania postaci na podstawie macierzy podobieństwa
30
+
31
+ 3. WIZUALIZACJA WYNIKÓW
32
+ ├─ visualise_single_image_prediction() - Główna funkcja wizualizacji
33
+ │ ├─ Rysuje bounding boxy: panele (zielone), tekst (czerwone),
34
+ │ │ postaci (niebieskie), ogony dymków (fioletowe)
35
+ │ ├─ Wyświetla imiona postaci nad ich bounding boxami
36
+ │ ├─ Rysuje klastry postaci (ta sama osoba) jako kolorowe linie w układzie gwiazdki
37
+ │ ├─ Pokazuje asocjacje tekst-postać (kto mówi) - czerwone przerywane linie
38
+ │ ├─ Pokazuje asocjacje tekst-ogon - fioletowe przerywane linie
39
+ │ └─ Zwraca obraz jako numpy array lub zapisuje do pliku
40
+ └─ plot_bboxes() - Pomocnicza funkcja do rysowania prostokątów
41
+
42
+ 4. SORTOWANIE PANELI I TEKSTÓW W KOLEJNOŚCI CZYTANIA
43
+
44
+ A. Sortowanie paneli (manga: prawo->lewo, góra->dół):
45
+ ├─ sort_panels() - Główny algorytm sortowania paneli
46
+ │ ├─ Buduje skierowany graf kolejności czytania
47
+ │ ├─ Używa erozji paneli (5%) do obsługi niedokładnych detekcji
48
+ │ ├─ Usuwa cykle przez eliminację najdłuższych krawędzi
49
+ │ └─ Zwraca sortowanie topologiczne (kolejność czytania)
50
+
51
+ ├─ is_there_a_directed_edge() - Określa czy panel A jest przed B
52
+ │ ├─ Reguły mangi: prawo ma priorytet nad górą
53
+ │ ├─ Obsługuje nakładające się panele przez erozję
54
+ │ └─ Używa heurystyk cięć (cuts) dla skomplikowanych układów
55
+
56
+ ├─ use_cuts_to_determine_edge_from_a_to_b() - Zaawansowane heurystyki
57
+ │ ├─ Dzieli panele na "wiersze" (overlapping Y ranges)
58
+ │ ├─ Dzieli panele na "kolumny" (overlapping X ranges)
59
+ │ └─ Iteracyjna erozja gdy nie można określić kolejności
60
+
61
+ └─ Funkcje pomocnicze geometrii:
62
+ ├─ is_strictly_above/below/left_of/right_of() - Relacje przestrzenne
63
+ ├─ intersects() - Sprawdzanie przecięcia prostokątów (Shapely)
64
+ ├─ get_distance() - Odległość euklidesowa między prostokątami
65
+ ├─ erode_rectangle() - Zmniejszanie prostokąta z zachowaniem aspect ratio
66
+ └─ merge_overlapping_ranges() - Scalanie nakładających się zakresów 1D
67
+
68
+ B. Sortowanie tekstów:
69
+ ├─ sort_text_boxes_in_reading_order() - Sortuje teksty według paneli
70
+ │ ├─ Przypisuje każdy tekst do najbliższego panelu
71
+ │ ├─ Sortuje teksty według kolejności paneli
72
+ │ └─ W każdym panelu sortuje według odległości od prawego górnego rogu
73
+
74
+ ├─ get_text_to_panel_mapping() - Przypisanie tekst->panel
75
+ │ ├─ Preferuje nakładanie się (intersection area)
76
+ │ └─ Fallback: najbliższy panel (distance)
77
+
78
+ └─ sort_texts_within_panel() - Sortowanie w obrębie jednego panelu
79
+ └─ Sortuje według odległości od prawego górnego rogu panelu
80
+
81
+ 5. KONWERSJE FORMATÓW BOUNDING BOXÓW
82
+ ├─ x1y1wh_to_x1y1x2y2() - (x, y, width, height) -> (x1, y1, x2, y2)
83
+ ├─ x1y1x2y2_to_xywh() - (x1, y1, x2, y2) -> (x, y, width, height)
84
+ │ └─ Format COCO używa xywh zamiast corners
85
+ └─ convert_to_list_of_lists() - Uniwersalna konwersja torch/numpy/list
86
+
87
+ ═══════════════════════════════════════════════════════════════════════════════
88
+ KLUCZOWE ALGORYTMY
89
+ ═══════════════════════════════════════════════════════════════════════════════
90
+
91
+ 1. UNION-FIND (O(α(n)) - prawie stała):
92
+ - Path compression: podczas find() ustawiamy rodzica bezpośrednio na korzeń
93
+ - Union by size: mniejszy zbiór dołączamy do większego dla zbalansowania
94
+ - Używane do klastrowania postaci z macierzy podobieństwa
95
+
96
+ 2. SORTOWANIE PANELI (O(n² log n)):
97
+ - Graf skierowany gdzie krawędź A->B = "A przed B"
98
+ - Reguły: prawo > góra (manga) lub lewo > góra (komiks zachodni)
99
+ - Usuwanie cykli przez eliminację najdłuższych krawędzi
100
+ - Sortowanie topologiczne DAG dla finalnej kolejności
101
+ - Erozja progresywna (5% na iterację) dla nakładających się paneli
102
+
103
+ 3. SORTOWANIE TEKSTÓW (O(n log n)):
104
+ - Przypisanie do paneli: max(intersection_area) lub min(distance)
105
+ - Sortowanie według ID panelu (panele już posortowane)
106
+ - W panelu: sortowanie według distance od prawego górnego rogu
107
+ - Odległość w Shapely: shortest distance między geometriami
108
+
109
+ ═══════════════════════════════════════════════════════════════════════════════
110
+ ZALEŻNOŚCI ZEWNĘTRZNE
111
+ ═══════════════════════════════════════════════════════════════════════════════
112
+
113
+ - torch: Tensory GPU/CPU, operacje na urządzeniach
114
+ - numpy: Operacje na tablicach, NDArray typing
115
+ - matplotlib: Wizualizacja (pyplot, patches)
116
+ - shapely: Geometria 2D (Point, box, Polygon) - przecięcia, odległości
117
+ - networkx: Grafy (DiGraph, topological_sort, simple_cycles)
118
+ - typing: Type hints (Any, Dict, List, Tuple, Union, Optional)
119
+
120
+ ═══════════════════════════════════════════════════════════════════════════════
121
+ TYPOWE UŻYCIE
122
+ ═══════════════════════════════════════════════════════════════════════════════
123
+
124
+ # 1. Przeniesienie danych na GPU
125
+ inputs = move_to_device({"images": np_array, "labels": [0, 1, 2]}, device)
126
+
127
+ # 2. Klastrowanie postaci z macierzy podobieństwa
128
+ uf = UnionFind.from_adj_matrix(similarity_matrix > threshold)
129
+ cluster_labels = uf.get_labels_for_connected_components()
130
+
131
+ # 3. Sortowanie paneli w kolejności czytania
132
+ sorted_panel_indices = sort_panels(panel_bboxes)
133
+
134
+ # 4. Sortowanie tekstów
135
+ sorted_text_indices = sort_text_boxes_in_reading_order(
136
+ text_bboxes, sorted_panel_bboxes
137
+ )
138
+
139
+ # 5. Wizualizacja wyników
140
+ image = visualise_single_image_prediction(
141
+ image_array, predictions, filename="output.png"
142
+ )
143
+
144
+ # 6. Konwersja formatów bbox
145
+ coco_bbox = x1y1x2y2_to_xywh([10, 20, 30, 40]) # -> [10, 20, 20, 20]
146
+
147
+ ═══════════════════════════════════════════════════════════════════════════════
148
+ """
149
+
150
  import torch
151
  import numpy as np
152
  import random
153
  import matplotlib.pyplot as plt
154
  import matplotlib.patches as patches
155
  from shapely.geometry import Point, box
156
+ from shapely.geometry.polygon import Polygon
157
  import networkx as nx
158
  from copy import deepcopy
159
  from itertools import groupby
160
+ from typing import Any, Dict, List, Tuple, Union, Optional
161
+ from numpy.typing import NDArray
162
+
163
+
164
+ def move_to_device(inputs: Any, device: torch.device) -> Any:
165
+ """
166
+ Rekurencyjnie przenosi dane na określone urządzenie (CPU/GPU).
167
 
168
+ Obsługuje różne typy danych:
169
+ - Słowniki: przenosi każdy klucz-wartość rekurencyjnie
170
+ - Listy: przenosi każdy element rekurencyjnie
171
+ - Tuple: przenosi każdy element rekurencyjnie
172
+ - numpy.ndarray: konwertuje na torch.Tensor i przenosi
173
+ - torch.Tensor: przenosi bezpośrednio
174
 
175
+ Args:
176
+ inputs: Dane do przeniesienia (dict, list, tuple, array, tensor)
177
+ device: Docelowe urządzenie torch (torch.device)
178
+
179
+ Returns:
180
+ Dane przeniesione na docelowe urządzenie (ten sam typ co input)
181
+ """
182
  if hasattr(inputs, "keys"):
183
+ # Słownik - przenoś każdą wartość rekurencyjnie
184
  return {k: move_to_device(v, device) for k, v in inputs.items()}
185
  elif isinstance(inputs, list):
186
+ # Lista - przenoś każdy element rekurencyjnie
187
  return [move_to_device(v, device) for v in inputs]
188
  elif isinstance(inputs, tuple):
189
+ # Tuple - przenoś każdy element rekurencyjnie
190
  return tuple([move_to_device(v, device) for v in inputs])
191
  elif isinstance(inputs, np.ndarray):
192
+ # NumPy array - konwertuj na tensor i przenieś
193
  return torch.from_numpy(inputs).to(device)
194
  else:
195
+ # Tensor - przenieś bezpośrednio
196
  return inputs.to(device)
197
 
198
 
199
  class UnionFind:
200
+ """
201
+ Union-Find (Disjoint Set Union) - struktura danych do klastrowania.
202
+
203
+ Używana do grupowania postaci na podstawie macierzy podobieństwa.
204
+ Implementuje algorytm z kompresją ścieżki (path compression) i
205
+ łączeniem według rozmiaru (union by size) dla optymalnej wydajności.
206
+
207
+ Attributes:
208
+ parent: Lista rodziców dla każdego węzła (indeks -> rodzic)
209
+ size: Rozmiary poddrzew dla każdego korzenia
210
+ num_components: Liczba rozłącznych komponentów (klastrów)
211
+ """
212
+
213
+ def __init__(self, n: int) -> None:
214
+ """
215
+ Inicjalizuje Union-Find z n rozłącznymi elementami.
216
+
217
+ Args:
218
+ n: Liczba elementów (węzłów) w strukturze
219
+ """
220
+ self.parent: List[int] = list(range(n))
221
+ self.size: List[int] = [1] * n
222
+ self.num_components: int = n
223
 
224
  @classmethod
225
+ def from_adj_matrix(cls, adj_matrix: torch.Tensor) -> 'UnionFind':
226
+ """
227
+ Tworzy Union-Find z macierzy sąsiedztwa (adjacency matrix).
228
+
229
+ Łączy węzły i,j jeśli adj_matrix[i,j] > 0 (są połączone krawędzią).
230
+
231
+ Args:
232
+ adj_matrix: Macierz sąsiedztwa [n, n] (1 = połączone, 0 = rozłączone)
233
+
234
+ Returns:
235
+ Nowa instancja UnionFind z połączonymi węzłami
236
+ """
237
+ ufds: UnionFind = cls(adj_matrix.shape[0])
238
  for i in range(adj_matrix.shape[0]):
239
  for j in range(adj_matrix.shape[1]):
240
  if adj_matrix[i, j] > 0:
 
242
  return ufds
243
 
244
  @classmethod
245
+ def from_adj_list(cls, adj_list: List[List[int]]) -> 'UnionFind':
246
+ """
247
+ Tworzy Union-Find z listy sąsiedztwa (adjacency list).
248
+
249
+ Args:
250
+ adj_list: Lista list, gdzie adj_list[i] zawiera sąsiadów węzła i
251
+
252
+ Returns:
253
+ Nowa instancja UnionFind z połączonymi węzłami
254
+ """
255
+ ufds: UnionFind = cls(len(adj_list))
256
  for i in range(len(adj_list)):
257
  for j in adj_list[i]:
258
  ufds.unite(i, j)
259
  return ufds
260
 
261
  @classmethod
262
+ def from_edge_list(cls, edge_list: List[Tuple[int, int]], num_nodes: int) -> 'UnionFind':
263
+ """
264
+ Tworzy Union-Find z listy krawędzi.
265
+
266
+ Args:
267
+ edge_list: Lista krotek (i, j) reprezentujących krawędzie
268
+ num_nodes: Całkowita liczba węzłów w grafie
269
+
270
+ Returns:
271
+ Nowa instancja UnionFind z połączonymi węzłami
272
+ """
273
+ ufds: UnionFind = cls(num_nodes)
274
  for edge in edge_list:
275
  ufds.unite(edge[0], edge[1])
276
  return ufds
277
 
278
+ def find(self, x: int) -> int:
279
+ """
280
+ Znajduje korzeń (reprezentanta) zbioru zawierającego x.
281
+
282
+ Implementuje kompresję ścieżki (path compression) - podczas
283
+ przechodzenia do korzenia, ustawia rodzica każdego węzła
284
+ bezpośrednio na korzeń dla przyszłych szybszych zapytań.
285
+
286
+ Args:
287
+ x: Indeks węzła
288
+
289
+ Returns:
290
+ Indeks korzenia zbioru zawierającego x
291
+ """
292
  if self.parent[x] == x:
293
  return x
294
+ # Kompresja ścieżki - ustawiamy rodzica na korzeń
295
  self.parent[x] = self.find(self.parent[x])
296
  return self.parent[x]
297
 
298
+ def unite(self, x: int, y: int) -> None:
299
+ """
300
+ Łączy zbiory zawierające x i y.
301
+
302
+ Implementuje union by size - mniejszy zbiór jest dołączany
303
+ do większego dla utrzymania zbalansowanego drzewa.
304
+
305
+ Args:
306
+ x: Indeks pierwszego węzła
307
+ y: Indeks drugiego węzła
308
+ """
309
  x = self.find(x)
310
  y = self.find(y)
311
  if x != y:
312
+ # Łączenie według rozmiaru - mniejszy do większego
313
  if self.size[x] < self.size[y]:
314
  x, y = y, x
315
  self.parent[y] = x
316
  self.size[x] += self.size[y]
317
  self.num_components -= 1
318
 
319
+ def get_components_of(self, x: int) -> List[int]:
320
+ """
321
+ Zwraca wszystkie węzły w tym samym zbiorze co x.
322
+
323
+ Args:
324
+ x: Indeks węzła
325
+
326
+ Returns:
327
+ Lista indeksów wszystkich węzłów w zbiorze zawierającym x
328
+ """
329
  x = self.find(x)
330
  return [i for i in range(len(self.parent)) if self.find(i) == x]
331
 
332
+ def are_connected(self, x: int, y: int) -> bool:
333
+ """
334
+ Sprawdza czy x i y są w tym samym zbiorze.
335
+
336
+ Args:
337
+ x: Indeks pierwszego węzła
338
+ y: Indeks drugiego węzła
339
+
340
+ Returns:
341
+ True jeśli x i y są w tym samym zbiorze, False w przeciwnym razie
342
+ """
343
  return self.find(x) == self.find(y)
344
 
345
+ def get_size(self, x: int) -> int:
346
+ """
347
+ Zwraca rozmiar zbioru zawierającego x.
348
+
349
+ Args:
350
+ x: Indeks węzła
351
+
352
+ Returns:
353
+ Liczba węzłów w zbiorze zawierającym x
354
+ """
355
  return self.size[self.find(x)]
356
 
357
+ def get_num_components(self) -> int:
358
+ """
359
+ Zwraca liczbę rozłącznych zbiorów (komponentów).
360
+
361
+ Returns:
362
+ Liczba rozłącznych zbiorów w strukturze
363
+ """
364
  return self.num_components
365
 
366
+ def get_labels_for_connected_components(self) -> List[int]:
367
+ """
368
+ Generuje etykiety klastrów dla wszystkich węzłów.
369
+
370
+ Węzły w tym samym zbiorze otrzymują tę samą etykietę (0, 1, 2, ...).
371
+ Etykiety są przypisywane w kolejności pierwszego napotkania korzenia.
372
+
373
+ Returns:
374
+ Lista etykiet klastrów (długość n), gdzie labels[i] to klaster węzła i
375
+ """
376
+ map_parent_to_label: Dict[int, int] = {}
377
+ labels: List[int] = []
378
  for i in range(len(self.parent)):
379
+ parent: int = self.find(i)
380
  if parent not in map_parent_to_label:
381
  map_parent_to_label[parent] = len(map_parent_to_label)
382
  labels.append(map_parent_to_label[parent])
383
  return labels
384
 
385
 
386
+ def visualise_single_image_prediction(
387
+ image_as_np_array: NDArray[np.uint8],
388
+ predictions: Dict[str, Any],
389
+ filename: Optional[str]
390
+ ) -> NDArray[np.uint8]:
391
+ """
392
+ Wizualizuje wyniki predykcji modelu na obrazie strony mangi/komiksu.
393
+
394
+ Rysuje:
395
+ - Zielone prostokąty: panele
396
+ - Czerwone prostokąty: tekst (tylko essential_text, tj. dialogi)
397
+ - Niebieskie prostokąty: postaci
398
+ - Fioletowe prostokąty: ogony dymków
399
+ - Niebieskie etykiety: imiona postaci
400
+ - Kolorowe linie: klastry postaci (ta sama osoba)
401
+ - Czerwone przerywane linie: asocjacje tekst-postać (kto mówi)
402
+ - Fioletowe przerywane linie: asocjacje tekst-ogon
403
+
404
+ Args:
405
+ image_as_np_array: Obraz strony jako numpy array [H, W, C]
406
+ predictions: Słownik z wynikami zawierający klucze:
407
+ - "panels", "texts", "characters", "tails": bounding boxy
408
+ - "character_names": imiona postaci
409
+ - "character_cluster_labels": etykiety klastrów postaci
410
+ - "text_character_associations": pary (idx_tekstu, idx_postaci)
411
+ - "text_tail_associations": pary (idx_tekstu, idx_ogona)
412
+ - "is_essential_text": flagi czy tekst to dialog
413
+ filename: Opcjonalna ścieżka do zapisu wizualizacji (lub None)
414
+
415
+ Returns:
416
+ Obraz wizualizacji jako numpy array [H, W, C]
417
+ """
418
  figure, subplot = plt.subplots(1, 1, figsize=(10, 10))
419
  subplot.imshow(image_as_np_array)
420
+
421
+ # Rysowanie bounding boxów dla każdego typu obiektu
422
  plot_bboxes(subplot, predictions["panels"], color="green")
423
  plot_bboxes(subplot, predictions["texts"], color="red",
424
  visibility=predictions["is_essential_text"])
425
  plot_bboxes(subplot, predictions["characters"], color="blue")
426
  plot_bboxes(subplot, predictions["tails"], color="purple")
427
 
428
+ # Rysowanie imion postaci nad bounding boxami
429
  for i, name in enumerate(predictions["character_names"]):
430
+ char_bbox: List[float] = predictions["characters"][i]
431
+ x1: float
432
+ y1: float
433
+ x2: float
434
+ y2: float
435
  x1, y1, x2, y2 = char_bbox
436
  subplot.text(x1, y1 - 2, name,
437
  verticalalignment='bottom', horizontalalignment='left',
438
+ # Tło etykiety (niebieski prostokąt)
439
  bbox=dict(facecolor='blue', alpha=1, edgecolor='none'),
440
  color='white', fontsize=8)
441
 
442
+ # Paleta kolorów dla klastrów postaci
443
+ COLOURS: List[str] = [
444
+ "#b7ff51", # zielony
445
+ "#f50a8f", # różowy
446
+ "#4b13b6", # fioletowy
447
+ "#ddaa34", # pomarańczowy
448
+ "#bea2a2", # brązowy
449
  ]
450
+ colour_index: int = 0
451
+ character_cluster_labels: List[int] = predictions["character_cluster_labels"]
452
+ # Sortowanie etykiet klastrów według częstości (najczęstsze pierwsze)
453
+ unique_label_sorted_by_frequency: List[int] = sorted(list(set(
454
  character_cluster_labels)), key=lambda x: character_cluster_labels.count(x), reverse=True)
455
+
456
+ # Rysowanie linii łączących postaci w tym samym klastrze (ta sama osoba)
457
  for label in unique_label_sorted_by_frequency:
458
+ root: Optional[int] = None
459
+ others: List[int] = []
460
+ # Znajdź wszystkie postaci z tym samym labelem klastra
461
  for i in range(len(predictions["characters"])):
462
  if character_cluster_labels[i] == label:
463
  if root is None:
464
+ root = i # Pierwszy jako korzeń (centrum gwiazdki)
465
  else:
466
+ others.append(i) # Pozostałe jako liście
467
+
468
+ # Wybór koloru dla tego klastra
469
  if colour_index >= len(COLOURS):
470
+ # Jeśli zabrakło predefiniowanych kolorów, generuj losowy
471
+ random_colour: str = COLOURS[0]
472
  while random_colour in COLOURS:
473
  random_colour = "#" + \
474
  "".join([random.choice("0123456789ABCDEF")
475
  for j in range(6)])
476
  else:
477
+ random_colour: str = COLOURS[colour_index]
478
  colour_index += 1
479
+
480
+ # Oblicz centrum bbox korzenia
481
+ bbox_i: List[float] = predictions["characters"][root]
482
+ x1: float = bbox_i[0] + (bbox_i[2] - bbox_i[0]) / 2
483
+ y1: float = bbox_i[1] + (bbox_i[3] - bbox_i[1]) / 2
484
+ # Rysuj punkt w centrum korzenia
485
  subplot.plot([x1], [y1], color=random_colour, marker="o", markersize=5)
486
+
487
+ # Rysuj linie od korzenia do wszystkich innych postaci w klastrze
488
  for j in others:
489
+ bbox_j: List[float] = predictions["characters"][j]
 
490
  x1 = bbox_i[0] + (bbox_i[2] - bbox_i[0]) / 2
491
  y1 = bbox_i[1] + (bbox_i[3] - bbox_i[1]) / 2
492
+ x2: float = bbox_j[0] + (bbox_j[2] - bbox_j[0]) / 2
493
+ y2: float = bbox_j[1] + (bbox_j[3] - bbox_j[1]) / 2
494
+ # Linia od korzenia do liścia
495
  subplot.plot([x1, x2], [y1, y2], color=random_colour, linewidth=2)
496
+ # Punkt w centrum liścia
497
  subplot.plot([x2], [y2], color=random_colour,
498
  marker="o", markersize=5)
499
 
500
+ # Rysowanie asocjacji tekst-postać (kto mówi - czerwone przerywane linie)
501
  for (i, j) in predictions["text_character_associations"]:
502
+ bbox_i: List[float] = predictions["texts"][i]
503
+ bbox_j: List[float] = predictions["characters"][j]
504
+ # Pomiń jeśli tekst nie jest dialogiem
505
  if not predictions["is_essential_text"][i]:
506
  continue
507
+ # Oblicz centra bounding boxów
508
+ x1: float = bbox_i[0] + (bbox_i[2] - bbox_i[0]) / 2
509
+ y1: float = bbox_i[1] + (bbox_i[3] - bbox_i[1]) / 2
510
+ x2: float = bbox_j[0] + (bbox_j[2] - bbox_j[0]) / 2
511
+ y2: float = bbox_j[1] + (bbox_j[3] - bbox_j[1]) / 2
512
+ # Rysuj linię od tekstu do postaci
513
  subplot.plot([x1, x2], [y1, y2], color="red",
514
  linewidth=2, linestyle="dashed")
515
 
516
+ # Rysowanie asocjacji tekst-ogon (fioletowe przerywane linie)
517
  for (i, j) in predictions["text_tail_associations"]:
518
+ bbox_i: List[float] = predictions["texts"][i]
519
+ bbox_j: List[float] = predictions["tails"][j]
520
+ # Oblicz centra bounding boxów
521
+ x1: float = bbox_i[0] + (bbox_i[2] - bbox_i[0]) / 2
522
+ y1: float = bbox_i[1] + (bbox_i[3] - bbox_i[1]) / 2
523
+ x2: float = bbox_j[0] + (bbox_j[2] - bbox_j[0]) / 2
524
+ y2: float = bbox_j[1] + (bbox_j[3] - bbox_j[1]) / 2
525
+ # Rysuj linię od tekstu do ogona
526
  subplot.plot([x1, x2], [y1, y2], color="purple",
527
  linewidth=2, linestyle="dashed")
528
 
529
+ # Ukryj osie wykresu
530
  subplot.axis("off")
531
+ # Zapisz do pliku jeśli podano ścieżkę
532
  if filename is not None:
533
  plt.savefig(filename, bbox_inches="tight", pad_inches=0)
534
 
535
+ # Konwertuj figure matplotlib na numpy array
536
  figure.canvas.draw()
537
+ image: NDArray[np.uint8] = np.array(figure.canvas.renderer._renderer)
538
  plt.close()
539
  return image
540
 
541
 
542
+ def plot_bboxes(
543
+ subplot: Any,
544
+ bboxes: List[List[float]],
545
+ color: str = "red",
546
+ visibility: Optional[List[int]] = None
547
+ ) -> None:
548
+ """
549
+ Rysuje bounding boxy na subplocie matplotlib.
550
+
551
+ Args:
552
+ subplot: Subplot matplotlib do rysowania
553
+ bboxes: Lista bounding boxów w formacie [x1, y1, x2, y2]
554
+ color: Kolor krawędzi prostokątów (domyślnie "red")
555
+ visibility: Opcjonalna lista flag (1=widoczny, 0=ukryty).
556
+ Jeśli None, wszystkie boxy są widoczne
557
+ """
558
  if visibility is None:
559
  visibility = [1] * len(bboxes)
560
+
561
  for id, bbox in enumerate(bboxes):
562
+ # Pomiń niewidoczne boxy
563
  if visibility[id] == 0:
564
  continue
565
+ # Oblicz szerokość i wysokość
566
+ w: float = bbox[2] - bbox[0]
567
+ h: float = bbox[3] - bbox[1]
568
+ # Utwórz prostokąt
569
+ rect: patches.Rectangle = patches.Rectangle(
570
  bbox[:2], w, h, linewidth=1, edgecolor=color, facecolor="none", linestyle="solid"
571
  )
572
  subplot.add_patch(rect)
573
 
574
 
575
+ def sort_panels(rects: Union[torch.Tensor, NDArray, List[List[float]]]) -> List[int]:
576
+ """
577
+ Sortuje panele w kolejności czytania (prawo->lewo, góra->dół dla mangi).
578
+
579
+ Algorytm:
580
+ 1. Lekka erozja paneli aby obsłużyć niedokładne detekcje
581
+ 2. Budowa grafu skierowanego z krawędziami reprezentującymi kolejność czytania
582
+ 3. Usunięcie cykli przez eliminację najdłuższych krawędzi w każdym cyklu
583
+ 4. Sortowanie topologiczne grafu acyklicznego
584
+
585
+ Args:
586
+ rects: Bounding boxy paneli [x1, y1, x2, y2]
587
+
588
+ Returns:
589
+ Lista indeksów paneli w kolejności czytania
590
+ """
591
+ before_rects: List[List[float]] = convert_to_list_of_lists(rects)
592
+ # Lekka erozja prostokątów (5%) aby obsłużyć niedokładne detekcje
593
+ rects_eroded: List[List[float]] = [
594
+ erode_rectangle(rect, 0.05) for rect in before_rects]
595
+
596
+ # Budowa skierowanego grafu kolejności czytania
597
+ G: nx.DiGraph = nx.DiGraph()
598
+ G.add_nodes_from(range(len(rects_eroded)))
599
+ for i in range(len(rects_eroded)):
600
+ for j in range(len(rects_eroded)):
601
  if i == j:
602
  continue
603
+ # Sprawdź czy istnieje krawędź i->j (i jest przed j w kolejności czytania)
604
+ if is_there_a_directed_edge(i, j, rects_eroded):
605
+ G.add_edge(i, j, weight=get_distance(
606
+ rects_eroded[i], rects_eroded[j]))
607
  else:
608
+ G.add_edge(j, i, weight=get_distance(
609
+ rects_eroded[i], rects_eroded[j]))
610
+
611
+ # Usuwanie cykli przez eliminację najdłuższych krawędzi
612
  while True:
613
+ cycles: List[List[int]] = sorted(nx.simple_cycles(G))
614
  cycles = [cycle for cycle in cycles if len(cycle) > 1]
615
  if len(cycles) == 0:
616
  break
617
+ # Weź pierwszy cykl
618
+ cycle: List[int] = cycles[0]
619
+ # Znajdź wszystkie krawędzie w cyklu
620
+ edges: List[Tuple[int, int]] = [
621
+ e for e in zip(cycle, cycle[1:] + cycle[:1])]
622
+ # Usuń najdłuższą krawędź (najmniej pewną)
623
+ max_cyclic_edge: Tuple[int, int] = max(
624
+ edges, key=lambda x: G.edges[x]["weight"])
625
  G.remove_edge(*max_cyclic_edge)
626
+
627
+ # Sortowanie topologiczne grafu acyklicznego daje kolejność czytania
628
  return list(nx.topological_sort(G))
629
 
630
 
631
+ def is_strictly_above(rectA: List[float], rectB: List[float]) -> bool:
632
+ """Sprawdza czy rectA jest całkowicie nad rectB (dolna krawędź A < górna krawędź B)."""
633
+ x1A: float
634
+ y1A: float
635
+ x2A: float
636
+ y2A: float
637
  x1A, y1A, x2A, y2A = rectA
638
+ x1B: float
639
+ y1B: float
640
+ x2B: float
641
+ y2B: float
642
  x1B, y1B, x2B, y2B = rectB
643
  return y2A < y1B
644
 
645
 
646
+ def is_strictly_below(rectA: List[float], rectB: List[float]) -> bool:
647
+ """Sprawdza czy rectA jest całkowicie pod rectB (dolna krawędź B < górna krawędź A)."""
648
+ x1A: float
649
+ y1A: float
650
+ x2A: float
651
+ y2A: float
652
  x1A, y1A, x2A, y2A = rectA
653
+ x1B: float
654
+ y1B: float
655
+ x2B: float
656
+ y2B: float
657
  x1B, y1B, x2B, y2B = rectB
658
  return y2B < y1A
659
 
660
 
661
+ def is_strictly_left_of(rectA: List[float], rectB: List[float]) -> bool:
662
+ """Sprawdza czy rectA jest całkowicie na lewo od rectB (prawa krawędź A < lewa krawędź B)."""
663
+ x1A: float
664
+ y1A: float
665
+ x2A: float
666
+ y2A: float
667
  x1A, y1A, x2A, y2A = rectA
668
+ x1B: float
669
+ y1B: float
670
+ x2B: float
671
+ y2B: float
672
  x1B, y1B, x2B, y2B = rectB
673
  return x2A < x1B
674
 
675
 
676
+ def is_strictly_right_of(rectA: List[float], rectB: List[float]) -> bool:
677
+ """Sprawdza czy rectA jest całkowicie na prawo od rectB (prawa krawędź B < lewa krawędź A)."""
678
+ x1A: float
679
+ y1A: float
680
+ x2A: float
681
+ y2A: float
682
  x1A, y1A, x2A, y2A = rectA
683
+ x1B: float
684
+ y1B: float
685
+ x2B: float
686
+ y2B: float
687
  x1B, y1B, x2B, y2B = rectB
688
  return x2B < x1A
689
 
690
 
691
+ def intersects(rectA: List[float], rectB: List[float]) -> bool:
692
+ """Sprawdza czy dwa prostokąty się przecinają używając Shapely."""
693
  return box(*rectA).intersects(box(*rectB))
694
 
695
 
696
+ def is_there_a_directed_edge(a: int, b: int, rects: List[List[float]]) -> bool:
697
+ """
698
+ Określa czy panel 'a' powinien być czytany przed panelem 'b'.
699
+
700
+ Używa reguł kolejności czytania mangi (prawo->lewo, góra->dół):
701
+ - Jeśli A jest na prawo i nie poniżej B -> A przed B
702
+ - Jeśli A jest nad i nie na lewo od B -> A przed B
703
+ - Dla nakładających się paneli używa erozji i heurystyk
704
+
705
+ Args:
706
+ a: Indeks pierwszego panelu
707
+ b: Indeks drugiego panelu
708
+ rects: Lista bounding boxów paneli
709
+
710
+ Returns:
711
+ True jeśli istnieje krawędź a->b (a przed b), False w przeciwnym razie
712
+ """
713
+ rectA: List[float] = rects[a]
714
+ rectB: List[float] = rects[b]
715
+ # Oblicz centra prostokątów
716
+ centre_of_A: List[float] = [rectA[0] + (rectA[2] - rectA[0]) / 2,
717
+ rectA[1] + (rectA[3] - rectA[1]) / 2]
718
+ centre_of_B: List[float] = [rectB[0] + (rectB[2] - rectB[0]) / 2,
719
+ rectB[1] + (rectB[3] - rectB[1]) / 2]
720
+ # Jeśli centra są w tym samym miejscu, większy panel jest pierwszy
721
  if np.allclose(np.array(centre_of_A), np.array(centre_of_B)):
722
  return box(*rectA).area > (box(*rectB)).area
723
  copy_A = [rectA[0], rectA[1], rectA[2], rectA[3]]
 
740
  copy_B = erode_rectangle(copy_B, 0.05)
741
 
742
 
743
+ def get_distance(rectA: List[float], rectB: List[float]) -> float:
744
+ """
745
+ Oblicza odległość euklidesową między dwoma prostokątami.
746
+
747
+ Args:
748
+ rectA: Pierwszy prostokąt [x1, y1, x2, y2]
749
+ rectB: Drugi prostokąt [x1, y1, x2, y2]
750
+
751
+ Returns:
752
+ Odległość między prostokątami (0 jeśli się przecinają)
753
+ """
754
  return box(rectA[0], rectA[1], rectA[2], rectA[3]).distance(box(rectB[0], rectB[1], rectB[2], rectB[3]))
755
 
756
 
757
+ def use_cuts_to_determine_edge_from_a_to_b(a: int, b: int, rects: List[List[float]]) -> bool:
758
+ """
759
+ Używa zaawansowanych heurystyk "cięć" do określenia kolejności czytania paneli.
760
+
761
+ Gdy standardowe reguły przestrzenne (prawo/lewo/góra/dół) nie mogą jednoznacznie
762
+ określić kolejności między dwoma panelami, ta funkcja stosuje algorytm dzielenia
763
+ przestrzeni na "wiersze" i "kolumny" aby ustalić która z tych paneli jest pierwsza.
764
+
765
+ Algorytm:
766
+ 1. Wyznacza minimalny prostokąt otaczający oba panele (a i b)
767
+ 2. Znajduje wszystkie panele przecinające ten obszar
768
+ 3. KROK POZIOMY: Dzieli panele na "wiersze" (overlapping Y ranges)
769
+ - Scala nakładające się zakresy Y w nieprzekrywające się poziomy
770
+ - Jeśli a i b są w różnych poziomach -> wyższy poziom jest pierwszy
771
+ 4. KROK PIONOWY: Dzieli panele na "kolumny" (overlapping X ranges, odwrócone)
772
+ - Scala nakładające się zakresy X w nieprzekrywające się kolumny
773
+ - Kolumny są odwrócone (prawo->lewo) dla mangi
774
+ - Jeśli a i b są w różnych kolumnach -> prawa kolumna jest pierwsza
775
+ 5. EROZJA: Jeśli nadal nie można określić, zmniejsz panele o 5% i powtórz
776
+
777
+ Ta funkcja jest wywoływana tylko dla skomplikowanych układów paneli,
778
+ gdzie panele są częściowo nakładające się lub ułożone nieregularnie.
779
+
780
+ Args:
781
+ a: Indeks pierwszego panelu
782
+ b: Indeks drugiego panelu
783
+ rects: Lista wszystkich bounding boxów paneli [x1, y1, x2, y2]
784
+
785
+ Returns:
786
+ True jeśli panel 'a' powinien być czytany przed panelem 'b', False w przeciwnym razie
787
+ """
788
+ # Kopia głęboka aby nie modyfikować oryginalnych prostokątów
789
  rects = deepcopy(rects)
790
+
791
  while True:
792
+ # Oblicz minimalny prostokąt otaczający oba panele a i b
793
+ xmin: float
794
+ ymin: float
795
+ xmax: float
796
+ ymax: float
797
  xmin, ymin, xmax, ymax = min(rects[a][0], rects[b][0]), min(
798
  rects[a][1], rects[b][1]), max(rects[a][2], rects[b][2]), max(rects[a][3], rects[b][3])
799
+
800
+ # Znajdź indeksy wszystkich paneli przecinających otaczający prostokąt
801
+ rect_index: List[int] = [i for i in range(len(rects)) if intersects(
802
  rects[i], [xmin, ymin, xmax, ymax])]
803
+ # Pobierz bounding boxy tych paneli
804
+ rects_copy: List[List[float]] = [rect for rect in rects if intersects(
805
  rect, [xmin, ymin, xmax, ymax])]
806
 
807
+ # PRÓBA 1: Podziel panele używając "poziomych" linii (wiersze)
808
+ # Scal nakładające się zakresy Y aby uzyskać nieprzekrywające się poziomy
809
+ overlapping_y_ranges: List[Tuple[float, float]] = merge_overlapping_ranges(
810
  [(y1, y2) for x1, y1, x2, y2 in rects_copy])
811
+ panel_index_to_split: Dict[int, int] = {}
812
+
813
+ # Przypisz każdy panel do poziomu (split_index)
814
  for split_index, (y1, y2) in enumerate(overlapping_y_ranges):
815
  for i, index in enumerate(rect_index):
816
+ # Jeśli panel całkowicie mieści się w tym poziomie Y
817
  if y1 <= rects_copy[i][1] <= rects_copy[i][3] <= y2:
818
  panel_index_to_split[index] = split_index
819
 
820
+ # Jeśli a i b są w różnych poziomach -> wyższy (mniejszy Y) jest pierwszy
821
  if panel_index_to_split[a] != panel_index_to_split[b]:
822
  return panel_index_to_split[a] < panel_index_to_split[b]
823
 
824
+ # PRÓBA 2: Podziel panele używając "pionowych" linii (kolumny)
825
+ # Scal nakładające się zakresy X aby uzyskać nieprzekrywające się kolumny
826
+ overlapping_x_ranges: List[Tuple[float, float]] = merge_overlapping_ranges(
827
  [(x1, x2) for x1, y1, x2, y2 in rects_copy])
828
+ panel_index_to_split: Dict[int, int] = {}
829
+
830
+ # Przypisz każdy panel do kolumny (split_index)
831
+ # [::-1] odwraca kolejność dla mangi (prawo->lewo)
832
  for split_index, (x1, x2) in enumerate(overlapping_x_ranges[::-1]):
833
  for i, index in enumerate(rect_index):
834
+ # Jeśli panel całkowicie mieści się w tej kolumnie X
835
  if x1 <= rects_copy[i][0] <= rects_copy[i][2] <= x2:
836
  panel_index_to_split[index] = split_index
837
+
838
+ # Jeśli a i b są w różnych kolumnach -> prawa (mniejszy index po odwróceniu) jest pierwsza
839
  if panel_index_to_split[a] != panel_index_to_split[b]:
840
  return panel_index_to_split[a] < panel_index_to_split[b]
841
 
842
+ # PRÓBA 3: Erozja - zmniejsz prostokąty o 5% i spróbuj ponownie
843
+ # To pomaga gdy panele są bardzo blisko siebie lub lekko nakładające się
844
  rects = [erode_rectangle(rect, 0.05) for rect in rects]
845
 
846
 
847
+ def erode_rectangle(bbox: List[float], erosion_factor: float) -> List[float]:
848
+ """
849
+ Zmniejsza prostokąt proporcjonalnie zachowując aspect ratio.
850
+
851
+ Erozja jest stosowana względem krótszego boku aby zachować kształt.
852
+ Używane do obsługi niedokładnych detekcji paneli.
853
+
854
+ Args:
855
+ bbox: Bounding box [x1, y1, x2, y2]
856
+ erosion_factor: Współczynnik erozji (0-1), np. 0.05 = 5% redukcja
857
+
858
+ Returns:
859
+ Zmniejszony bounding box [x1, y1, x2, y2]
860
+ """
861
+ x1: float
862
+ y1: float
863
+ x2: float
864
+ y2: float
865
  x1, y1, x2, y2 = bbox
866
+ w: float
867
+ h: float
868
  w, h = x2 - x1, y2 - y1
869
+ # Oblicz centrum
870
+ cx: float
871
+ cy: float
872
  cx, cy = x1 + w / 2, y1 + h / 2
873
+
874
+ # Oblicz współczynniki erozji względem aspect ratio
875
  if w < h:
876
+ aspect_ratio: float = w / h
877
+ erosion_factor_width: float = erosion_factor * aspect_ratio
878
+ erosion_factor_height: float = erosion_factor
879
  else:
880
+ aspect_ratio: float = h / w
881
+ erosion_factor_width: float = erosion_factor
882
+ erosion_factor_height: float = erosion_factor * aspect_ratio
883
+
884
+ # Zmniejsz wymiary
885
  w = w - w * erosion_factor_width
886
  h = h - h * erosion_factor_height
887
+ # Oblicz nowe współrzędne względem centrum
888
  x1, y1, x2, y2 = cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2
889
  return [x1, y1, x2, y2]
890
 
891
 
892
+ def merge_overlapping_ranges(ranges: List[Tuple[float, float]]) -> List[Tuple[float, float]]:
893
  """
894
+ Scala nakładające się zakresy 1D w nieprzekrywające się zakresy.
895
+
896
+ Używane do dzielenia paneli na "wiersze" lub "kolumny" dla określenia
897
+ kolejności czytania gdy panele są ułożone nieregularnie.
898
+
899
+ Args:
900
+ ranges: Lista krotek (początek, koniec) reprezentujących zakresy
901
+
902
+ Returns:
903
+ Lista scalonych nieprzekrywających się zakresów, posortowana
904
  """
905
  if len(ranges) == 0:
906
  return []
907
+ # Sortuj zakresy według początku
908
+ ranges_sorted: List[Tuple[float, float]] = sorted(
909
+ ranges, key=lambda x: x[0])
910
+ merged_ranges: List[Tuple[float, float]] = []
911
+ prev_x1: float
912
+ prev_x2: float
913
+ for i, r in enumerate(ranges_sorted):
914
  if i == 0:
915
  prev_x1, prev_x2 = r
916
  continue
917
+ x1: float
918
+ x2: float
919
  x1, x2 = r
920
+ # Jeśli zakres nie nakłada się z poprzednim, dodaj poprzedni
921
  if x1 > prev_x2:
922
  merged_ranges.append((prev_x1, prev_x2))
923
  prev_x1, prev_x2 = x1, x2
924
  else:
925
+ # Nakładają się - scal przez rozszerzenie poprzedniego
926
  prev_x2 = max(prev_x2, x2)
927
+ # Dodaj ostatni zakres
928
  merged_ranges.append((prev_x1, prev_x2))
929
  return merged_ranges
930
 
931
 
932
+ def sort_text_boxes_in_reading_order(
933
+ text_bboxes: Union[torch.Tensor, NDArray, List[List[float]]],
934
+ sorted_panel_bboxes: Union[torch.Tensor, NDArray, List[List[float]]]
935
+ ) -> List[int]:
936
+ """
937
+ Sortuje teksty w kolejności czytania, grupując według paneli.
938
+
939
+ Algorytm:
940
+ 1. Przypisz każdy tekst do najbliższego/najbardziej nakładającego się panelu
941
+ 2. Sortuj teksty według ID panelu (panele już są w kolejności czytania)
942
+ 3. W obrębie każdego panelu, sortuj teksty według odległości od prawego górnego rogu
943
+
944
+ Args:
945
+ text_bboxes: Bounding boxy tekstów [x1, y1, x2, y2]
946
+ sorted_panel_bboxes: Bounding boxy paneli już posortowane w kolejności czytania
947
 
948
+ Returns:
949
+ Lista indeksów tekstów w kolejności czytania
950
+ """
951
+ text_bboxes_list: List[List[float]] = convert_to_list_of_lists(text_bboxes)
952
+ sorted_panel_bboxes_list: List[List[float]] = convert_to_list_of_lists(
953
+ sorted_panel_bboxes)
954
+
955
+ if len(text_bboxes_list) == 0:
956
  return []
957
 
958
+ def indices_of_same_elements(nums: List[int]) -> List[List[int]]:
959
+ """Grupuje indeksy według wartości (elementy z tą samą wartością w jednej grupie)."""
960
  groups = groupby(range(len(nums)), key=lambda i: nums[i])
961
  return [list(indices) for _, indices in groups]
962
 
963
+ # Przypisz każdy tekst do panelu
964
+ panel_id_for_text: List[int] = get_text_to_panel_mapping(
965
+ text_bboxes_list, sorted_panel_bboxes_list)
966
+ # Sortuj teksty według ID panelu
967
+ indices_of_texts: List[int] = list(range(len(text_bboxes_list)))
968
  indices_of_texts, panel_id_for_text = zip(
969
  *sorted(zip(indices_of_texts, panel_id_for_text), key=lambda x: x[1]))
970
  indices_of_texts = list(indices_of_texts)
971
+
972
+ # Dla każdej grupy tekstów w tym samym panelu, sortuj wewnątrz panelu
973
+ grouped_indices: List[List[int]] = indices_of_same_elements(
974
+ panel_id_for_text)
975
  for group in grouped_indices:
976
+ subset_of_text_indices: List[int] = [
977
+ indices_of_texts[i] for i in group]
978
+ text_bboxes_of_subset: List[List[float]] = [text_bboxes_list[i]
979
+ for i in subset_of_text_indices]
980
+ # Sortuj teksty w obrębie panelu (według odległości od prawego górnego rogu)
981
+ sorted_subset_indices: List[int] = sort_texts_within_panel(
982
+ text_bboxes_of_subset)
983
  indices_of_texts[group[0]: group[-1] + 1] = [subset_of_text_indices[i]
984
  for i in sorted_subset_indices]
985
  return indices_of_texts
986
 
987
 
988
+ def get_text_to_panel_mapping(
989
+ text_bboxes: List[List[float]],
990
+ sorted_panel_bboxes: List[List[float]]
991
+ ) -> List[int]:
992
+ """
993
+ Przypisuje każdy tekst do najbliższego/najbardziej nakładającego się panelu.
994
+
995
+ Algorytm priorytetów:
996
+ 1. PRIORYTET 1 - Przecięcie (intersection): Jeśli tekst przecina się z jakimś panelem,
997
+ wybierz panel z największą powierzchnią przecięcia (tekst "w środku" panelu)
998
+ 2. PRIORYTET 2 - Odległość (distance): Jeśli tekst nie przecina się z żadnym panelem,
999
+ wybierz najbliższy panel (tekst "obok" panelu)
1000
+ 3. BRAK PANELI: Jeśli nie ma żadnych paneli, przypisz -1 (brak przypisania)
1001
+
1002
+ Ta funkcja jest kluczowa dla sortowania tekstów w kolejności czytania,
1003
+ ponieważ teksty są grupowane według paneli, a panele są już posortowane.
1004
+
1005
+ Args:
1006
+ text_bboxes: Lista bounding boxów tekstów [x1, y1, x2, y2]
1007
+ sorted_panel_bboxes: Lista bounding boxów paneli [x1, y1, x2, y2],
1008
+ już posortowana w kolejności czytania
1009
+
1010
+ Returns:
1011
+ Lista indeksów paneli dla każdego tekstu (długość = len(text_bboxes)).
1012
+ Wartość -1 oznacza brak przypisania (gdy nie ma żadnych paneli).
1013
+ """
1014
+ text_to_panel_mapping: List[int] = []
1015
+
1016
  for text_bbox in text_bboxes:
1017
+ # Konwertuj bbox tekstu na polygon Shapely
1018
+ shapely_text_polygon: Polygon = box(*text_bbox)
1019
+ all_intersections: List[Tuple[float, int]] = [] # (area, panel_index)
1020
+ # (distance, panel_index)
1021
+ all_distances: List[Tuple[float, int]] = []
1022
+
1023
+ # Brak paneli - przypisz -1
1024
  if len(sorted_panel_bboxes) == 0:
1025
  text_to_panel_mapping.append(-1)
1026
  continue
1027
+
1028
+ # Sprawdź wszystkie panele
1029
  for j, annotation in enumerate(sorted_panel_bboxes):
1030
+ # Konwertuj bbox panelu na polygon Shapely
1031
+ shapely_annotation_polygon: Polygon = box(*annotation)
1032
+
1033
+ # Jeśli tekst przecina się z panelem, zapisz powierzchnię przecięcia
1034
  if shapely_text_polygon.intersects(shapely_annotation_polygon):
1035
+ intersection_area: float = shapely_text_polygon.intersection(
1036
+ shapely_annotation_polygon).area
1037
+ all_intersections.append((intersection_area, j))
1038
+
1039
+ # Zawsze oblicz odległość (fallback jeśli brak przecięć)
1040
+ distance: float = shapely_text_polygon.distance(
1041
+ shapely_annotation_polygon)
1042
+ all_distances.append((distance, j))
1043
+
1044
+ # DECYZJA: Czy są przecięcia?
1045
  if len(all_intersections) == 0:
1046
+ # Brak przecięć -> wybierz najbliższy panel (minimalna odległość)
1047
+ closest_panel_index: int = min(
1048
+ all_distances, key=lambda x: x[0])[1]
1049
+ text_to_panel_mapping.append(closest_panel_index)
1050
  else:
1051
+ # Są przecięcia -> wybierz panel z największą powierzchnią przecięcia
1052
+ best_panel_index: int = max(
1053
+ all_intersections, key=lambda x: x[0])[1]
1054
+ text_to_panel_mapping.append(best_panel_index)
1055
+
1056
  return text_to_panel_mapping
1057
 
1058
 
1059
+ def sort_texts_within_panel(rects: List[List[float]]) -> List[int]:
1060
+ """
1061
+ Sortuje teksty w obrębie jednego panelu według odległości od prawego górnego rogu.
1062
+
1063
+ Dla mangi (czytanej prawo->lewo, góra->dół), teksty są czytane od prawego
1064
+ górnego rogu. Algorytm:
1065
+ 1. Znajdź prawy górny róg panelu (max(X), min(Y) ze wszystkich tekstów)
1066
+ 2. Oblicz odległość każdego tekstu od tego punktu odniesienia
1067
+ 3. Sortuj teksty według odległości (najbliższe pierwsze)
1068
+
1069
+ Tekst najbliższy prawego górnego rogu jest czytany jako pierwszy,
1070
+ następnie kolejne w dół i w lewo.
1071
+
1072
+ Args:
1073
+ rects: Lista bounding boxów tekstów w jednym panelu [x1, y1, x2, y2]
1074
+
1075
+ Returns:
1076
+ Lista indeksów tekstów posortowana według kolejności czytania
1077
+ (indeks 0 = pierwszy tekst do przeczytania)
1078
+ """
1079
+ # Znajdź prawy górny róg obszaru (punkt odniesienia dla mangi)
1080
+ smallest_y: float = float("inf") # Najmniejszy Y = najwyższy punkt
1081
+ greatest_x: float = float("-inf") # Największy X = najbardziej prawy punkt
1082
+
1083
  for i, rect in enumerate(rects):
1084
+ x1: float
1085
+ y1: float
1086
+ x2: float
1087
+ y2: float
1088
  x1, y1, x2, y2 = rect
1089
+ smallest_y = min(smallest_y, y1) # Szukaj najwyższego punktu
1090
+ greatest_x = max(greatest_x, x2) # Szukaj najbardziej prawego punktu
1091
 
1092
+ # Punkt odniesienia - prawy górny róg panelu
1093
+ reference_point: Point = Point(greatest_x, smallest_y)
1094
 
1095
+ # Konwertuj prostokąty na polygony Shapely wraz z ich indeksami
1096
+ polygons_and_index: List[Tuple[Polygon, int]] = []
1097
  for i, rect in enumerate(rects):
1098
+ x1: float
1099
+ y1: float
1100
+ x2: float
1101
+ y2: float
1102
  x1, y1, x2, y2 = rect
1103
  polygons_and_index.append((box(x1, y1, x2, y2), i))
1104
+
1105
+ # Sortuj według odległości od punktu odniesienia (najmniejsza odległość pierwsza)
1106
  polygons_and_index = sorted(
1107
  polygons_and_index, key=lambda x: reference_point.distance(x[0]))
1108
+
1109
+ # Wyciągnij tylko indeksy (porzuć polygony)
1110
+ indices: List[int] = [x[1] for x in polygons_and_index]
1111
  return indices
1112
 
1113
 
1114
+ def x1y1wh_to_x1y1x2y2(bbox: List[float]) -> List[float]:
1115
+ """
1116
+ Konwertuje bbox z formatu (x1, y1, width, height) na (x1, y1, x2, y2).
1117
+
1118
+ Args:
1119
+ bbox: Bounding box [x1, y1, width, height]
1120
+
1121
+ Returns:
1122
+ Bounding box [x1, y1, x2, y2] (corners format)
1123
+ """
1124
+ x1: float
1125
+ y1: float
1126
+ w: float
1127
+ h: float
1128
  x1, y1, w, h = bbox
1129
  return [x1, y1, x1 + w, y1 + h]
1130
 
1131
 
1132
+ def x1y1x2y2_to_xywh(bbox: List[float]) -> List[float]:
1133
+ """
1134
+ Konwertuje bbox z formatu (x1, y1, x2, y2) na (x, y, width, height).
1135
+
1136
+ Format COCO używa (x, y, w, h) zamiast corners.
1137
+
1138
+ Args:
1139
+ bbox: Bounding box [x1, y1, x2, y2] (corners format)
1140
+
1141
+ Returns:
1142
+ Bounding box [x, y, width, height] (COCO format)
1143
+ """
1144
+ x1: float
1145
+ y1: float
1146
+ x2: float
1147
+ y2: float
1148
  x1, y1, x2, y2 = bbox
1149
  return [x1, y1, x2 - x1, y2 - y1]
1150
 
1151
 
1152
+ def convert_to_list_of_lists(rects: Union[torch.Tensor, NDArray, List]) -> List[List[float]]:
1153
+ """
1154
+ Konwertuje różne formaty bounding boxów na List[List[float]].
1155
+
1156
+ Obsługuje:
1157
+ - torch.Tensor -> list
1158
+ - numpy.ndarray -> list
1159
+ - iterable -> list of lists
1160
+
1161
+ Args:
1162
+ rects: Bounding boxy w dowolnym formacie
1163
+
1164
+ Returns:
1165
+ Lista list [[x1, y1, x2, y2], ...]
1166
+ """
1167
  if isinstance(rects, torch.Tensor):
1168
  return rects.tolist()
1169
  if isinstance(rects, np.ndarray):
utils_PRE.py ADDED
@@ -0,0 +1,456 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import random
4
+ import matplotlib.pyplot as plt
5
+ import matplotlib.patches as patches
6
+ from shapely.geometry import Point, box
7
+ import networkx as nx
8
+ from copy import deepcopy
9
+ from itertools import groupby
10
+
11
+
12
+ def move_to_device(inputs, device):
13
+ if hasattr(inputs, "keys"):
14
+ return {k: move_to_device(v, device) for k, v in inputs.items()}
15
+ elif isinstance(inputs, list):
16
+ return [move_to_device(v, device) for v in inputs]
17
+ elif isinstance(inputs, tuple):
18
+ return tuple([move_to_device(v, device) for v in inputs])
19
+ elif isinstance(inputs, np.ndarray):
20
+ return torch.from_numpy(inputs).to(device)
21
+ else:
22
+ return inputs.to(device)
23
+
24
+
25
+ class UnionFind:
26
+ def __init__(self, n):
27
+ self.parent = list(range(n))
28
+ self.size = [1] * n
29
+ self.num_components = n
30
+
31
+ @classmethod
32
+ def from_adj_matrix(cls, adj_matrix):
33
+ ufds = cls(adj_matrix.shape[0])
34
+ for i in range(adj_matrix.shape[0]):
35
+ for j in range(adj_matrix.shape[1]):
36
+ if adj_matrix[i, j] > 0:
37
+ ufds.unite(i, j)
38
+ return ufds
39
+
40
+ @classmethod
41
+ def from_adj_list(cls, adj_list):
42
+ ufds = cls(len(adj_list))
43
+ for i in range(len(adj_list)):
44
+ for j in adj_list[i]:
45
+ ufds.unite(i, j)
46
+ return ufds
47
+
48
+ @classmethod
49
+ def from_edge_list(cls, edge_list, num_nodes):
50
+ ufds = cls(num_nodes)
51
+ for edge in edge_list:
52
+ ufds.unite(edge[0], edge[1])
53
+ return ufds
54
+
55
+ def find(self, x):
56
+ if self.parent[x] == x:
57
+ return x
58
+ self.parent[x] = self.find(self.parent[x])
59
+ return self.parent[x]
60
+
61
+ def unite(self, x, y):
62
+ x = self.find(x)
63
+ y = self.find(y)
64
+ if x != y:
65
+ if self.size[x] < self.size[y]:
66
+ x, y = y, x
67
+ self.parent[y] = x
68
+ self.size[x] += self.size[y]
69
+ self.num_components -= 1
70
+
71
+ def get_components_of(self, x):
72
+ x = self.find(x)
73
+ return [i for i in range(len(self.parent)) if self.find(i) == x]
74
+
75
+ def are_connected(self, x, y):
76
+ return self.find(x) == self.find(y)
77
+
78
+ def get_size(self, x):
79
+ return self.size[self.find(x)]
80
+
81
+ def get_num_components(self):
82
+ return self.num_components
83
+
84
+ def get_labels_for_connected_components(self):
85
+ map_parent_to_label = {}
86
+ labels = []
87
+ for i in range(len(self.parent)):
88
+ parent = self.find(i)
89
+ if parent not in map_parent_to_label:
90
+ map_parent_to_label[parent] = len(map_parent_to_label)
91
+ labels.append(map_parent_to_label[parent])
92
+ return labels
93
+
94
+
95
+ def visualise_single_image_prediction(image_as_np_array, predictions, filename):
96
+ figure, subplot = plt.subplots(1, 1, figsize=(10, 10))
97
+ subplot.imshow(image_as_np_array)
98
+ plot_bboxes(subplot, predictions["panels"], color="green")
99
+ plot_bboxes(subplot, predictions["texts"], color="red",
100
+ visibility=predictions["is_essential_text"])
101
+ plot_bboxes(subplot, predictions["characters"], color="blue")
102
+ plot_bboxes(subplot, predictions["tails"], color="purple")
103
+
104
+ for i, name in enumerate(predictions["character_names"]):
105
+ char_bbox = predictions["characters"][i]
106
+ x1, y1, x2, y2 = char_bbox
107
+ subplot.text(x1, y1 - 2, name,
108
+ verticalalignment='bottom', horizontalalignment='left',
109
+ # Background settings
110
+ bbox=dict(facecolor='blue', alpha=1, edgecolor='none'),
111
+ color='white', fontsize=8)
112
+
113
+ COLOURS = [
114
+ "#b7ff51", # green
115
+ "#f50a8f", # pink
116
+ "#4b13b6", # purple
117
+ "#ddaa34", # orange
118
+ "#bea2a2", # brown
119
+ ]
120
+ colour_index = 0
121
+ character_cluster_labels = predictions["character_cluster_labels"]
122
+ unique_label_sorted_by_frequency = sorted(list(set(
123
+ character_cluster_labels)), key=lambda x: character_cluster_labels.count(x), reverse=True)
124
+ for label in unique_label_sorted_by_frequency:
125
+ root = None
126
+ others = []
127
+ for i in range(len(predictions["characters"])):
128
+ if character_cluster_labels[i] == label:
129
+ if root is None:
130
+ root = i
131
+ else:
132
+ others.append(i)
133
+ if colour_index >= len(COLOURS):
134
+ random_colour = COLOURS[0]
135
+ while random_colour in COLOURS:
136
+ random_colour = "#" + \
137
+ "".join([random.choice("0123456789ABCDEF")
138
+ for j in range(6)])
139
+ else:
140
+ random_colour = COLOURS[colour_index]
141
+ colour_index += 1
142
+ bbox_i = predictions["characters"][root]
143
+ x1 = bbox_i[0] + (bbox_i[2] - bbox_i[0]) / 2
144
+ y1 = bbox_i[1] + (bbox_i[3] - bbox_i[1]) / 2
145
+ subplot.plot([x1], [y1], color=random_colour, marker="o", markersize=5)
146
+ for j in others:
147
+ # draw line from centre of bbox i to centre of bbox j
148
+ bbox_j = predictions["characters"][j]
149
+ x1 = bbox_i[0] + (bbox_i[2] - bbox_i[0]) / 2
150
+ y1 = bbox_i[1] + (bbox_i[3] - bbox_i[1]) / 2
151
+ x2 = bbox_j[0] + (bbox_j[2] - bbox_j[0]) / 2
152
+ y2 = bbox_j[1] + (bbox_j[3] - bbox_j[1]) / 2
153
+ subplot.plot([x1, x2], [y1, y2], color=random_colour, linewidth=2)
154
+ subplot.plot([x2], [y2], color=random_colour,
155
+ marker="o", markersize=5)
156
+
157
+ for (i, j) in predictions["text_character_associations"]:
158
+ bbox_i = predictions["texts"][i]
159
+ bbox_j = predictions["characters"][j]
160
+ if not predictions["is_essential_text"][i]:
161
+ continue
162
+ x1 = bbox_i[0] + (bbox_i[2] - bbox_i[0]) / 2
163
+ y1 = bbox_i[1] + (bbox_i[3] - bbox_i[1]) / 2
164
+ x2 = bbox_j[0] + (bbox_j[2] - bbox_j[0]) / 2
165
+ y2 = bbox_j[1] + (bbox_j[3] - bbox_j[1]) / 2
166
+ subplot.plot([x1, x2], [y1, y2], color="red",
167
+ linewidth=2, linestyle="dashed")
168
+
169
+ for (i, j) in predictions["text_tail_associations"]:
170
+ bbox_i = predictions["texts"][i]
171
+ bbox_j = predictions["tails"][j]
172
+ x1 = bbox_i[0] + (bbox_i[2] - bbox_i[0]) / 2
173
+ y1 = bbox_i[1] + (bbox_i[3] - bbox_i[1]) / 2
174
+ x2 = bbox_j[0] + (bbox_j[2] - bbox_j[0]) / 2
175
+ y2 = bbox_j[1] + (bbox_j[3] - bbox_j[1]) / 2
176
+ subplot.plot([x1, x2], [y1, y2], color="purple",
177
+ linewidth=2, linestyle="dashed")
178
+
179
+ subplot.axis("off")
180
+ if filename is not None:
181
+ plt.savefig(filename, bbox_inches="tight", pad_inches=0)
182
+
183
+ figure.canvas.draw()
184
+ image = np.array(figure.canvas.renderer._renderer)
185
+ plt.close()
186
+ return image
187
+
188
+
189
+ def plot_bboxes(subplot, bboxes, color="red", visibility=None):
190
+ if visibility is None:
191
+ visibility = [1] * len(bboxes)
192
+ for id, bbox in enumerate(bboxes):
193
+ if visibility[id] == 0:
194
+ continue
195
+ w = bbox[2] - bbox[0]
196
+ h = bbox[3] - bbox[1]
197
+ rect = patches.Rectangle(
198
+ bbox[:2], w, h, linewidth=1, edgecolor=color, facecolor="none", linestyle="solid"
199
+ )
200
+ subplot.add_patch(rect)
201
+
202
+
203
+ def sort_panels(rects):
204
+ before_rects = convert_to_list_of_lists(rects)
205
+ # slightly erode all rectangles initially to account for imperfect detections
206
+ rects = [erode_rectangle(rect, 0.05) for rect in before_rects]
207
+ G = nx.DiGraph()
208
+ G.add_nodes_from(range(len(rects)))
209
+ for i in range(len(rects)):
210
+ for j in range(len(rects)):
211
+ if i == j:
212
+ continue
213
+ if is_there_a_directed_edge(i, j, rects):
214
+ G.add_edge(i, j, weight=get_distance(rects[i], rects[j]))
215
+ else:
216
+ G.add_edge(j, i, weight=get_distance(rects[i], rects[j]))
217
+ while True:
218
+ cycles = sorted(nx.simple_cycles(G))
219
+ cycles = [cycle for cycle in cycles if len(cycle) > 1]
220
+ if len(cycles) == 0:
221
+ break
222
+ cycle = cycles[0]
223
+ edges = [e for e in zip(cycle, cycle[1:] + cycle[:1])]
224
+ max_cyclic_edge = max(edges, key=lambda x: G.edges[x]["weight"])
225
+ G.remove_edge(*max_cyclic_edge)
226
+ return list(nx.topological_sort(G))
227
+
228
+
229
+ def is_strictly_above(rectA, rectB):
230
+ x1A, y1A, x2A, y2A = rectA
231
+ x1B, y1B, x2B, y2B = rectB
232
+ return y2A < y1B
233
+
234
+
235
+ def is_strictly_below(rectA, rectB):
236
+ x1A, y1A, x2A, y2A = rectA
237
+ x1B, y1B, x2B, y2B = rectB
238
+ return y2B < y1A
239
+
240
+
241
+ def is_strictly_left_of(rectA, rectB):
242
+ x1A, y1A, x2A, y2A = rectA
243
+ x1B, y1B, x2B, y2B = rectB
244
+ return x2A < x1B
245
+
246
+
247
+ def is_strictly_right_of(rectA, rectB):
248
+ x1A, y1A, x2A, y2A = rectA
249
+ x1B, y1B, x2B, y2B = rectB
250
+ return x2B < x1A
251
+
252
+
253
+ def intersects(rectA, rectB):
254
+ return box(*rectA).intersects(box(*rectB))
255
+
256
+
257
+ def is_there_a_directed_edge(a, b, rects):
258
+ rectA = rects[a]
259
+ rectB = rects[b]
260
+ centre_of_A = [rectA[0] + (rectA[2] - rectA[0]) / 2,
261
+ rectA[1] + (rectA[3] - rectA[1]) / 2]
262
+ centre_of_B = [rectB[0] + (rectB[2] - rectB[0]) / 2,
263
+ rectB[1] + (rectB[3] - rectB[1]) / 2]
264
+ if np.allclose(np.array(centre_of_A), np.array(centre_of_B)):
265
+ return box(*rectA).area > (box(*rectB)).area
266
+ copy_A = [rectA[0], rectA[1], rectA[2], rectA[3]]
267
+ copy_B = [rectB[0], rectB[1], rectB[2], rectB[3]]
268
+ while True:
269
+ if is_strictly_above(copy_A, copy_B) and not is_strictly_left_of(copy_A, copy_B):
270
+ return 1
271
+ if is_strictly_above(copy_B, copy_A) and not is_strictly_left_of(copy_B, copy_A):
272
+ return 0
273
+ if is_strictly_right_of(copy_A, copy_B) and not is_strictly_below(copy_A, copy_B):
274
+ return 1
275
+ if is_strictly_right_of(copy_B, copy_A) and not is_strictly_below(copy_B, copy_A):
276
+ return 0
277
+ if is_strictly_below(copy_A, copy_B) and is_strictly_right_of(copy_A, copy_B):
278
+ return use_cuts_to_determine_edge_from_a_to_b(a, b, rects)
279
+ if is_strictly_below(copy_B, copy_A) and is_strictly_right_of(copy_B, copy_A):
280
+ return use_cuts_to_determine_edge_from_a_to_b(a, b, rects)
281
+ # otherwise they intersect
282
+ copy_A = erode_rectangle(copy_A, 0.05)
283
+ copy_B = erode_rectangle(copy_B, 0.05)
284
+
285
+
286
+ def get_distance(rectA, rectB):
287
+ return box(rectA[0], rectA[1], rectA[2], rectA[3]).distance(box(rectB[0], rectB[1], rectB[2], rectB[3]))
288
+
289
+
290
+ def use_cuts_to_determine_edge_from_a_to_b(a, b, rects):
291
+ rects = deepcopy(rects)
292
+ while True:
293
+ xmin, ymin, xmax, ymax = min(rects[a][0], rects[b][0]), min(
294
+ rects[a][1], rects[b][1]), max(rects[a][2], rects[b][2]), max(rects[a][3], rects[b][3])
295
+ rect_index = [i for i in range(len(rects)) if intersects(
296
+ rects[i], [xmin, ymin, xmax, ymax])]
297
+ rects_copy = [rect for rect in rects if intersects(
298
+ rect, [xmin, ymin, xmax, ymax])]
299
+
300
+ # try to split the panels using a "horizontal" lines
301
+ overlapping_y_ranges = merge_overlapping_ranges(
302
+ [(y1, y2) for x1, y1, x2, y2 in rects_copy])
303
+ panel_index_to_split = {}
304
+ for split_index, (y1, y2) in enumerate(overlapping_y_ranges):
305
+ for i, index in enumerate(rect_index):
306
+ if y1 <= rects_copy[i][1] <= rects_copy[i][3] <= y2:
307
+ panel_index_to_split[index] = split_index
308
+
309
+ if panel_index_to_split[a] != panel_index_to_split[b]:
310
+ return panel_index_to_split[a] < panel_index_to_split[b]
311
+
312
+ # try to split the panels using a "vertical" lines
313
+ overlapping_x_ranges = merge_overlapping_ranges(
314
+ [(x1, x2) for x1, y1, x2, y2 in rects_copy])
315
+ panel_index_to_split = {}
316
+ for split_index, (x1, x2) in enumerate(overlapping_x_ranges[::-1]):
317
+ for i, index in enumerate(rect_index):
318
+ if x1 <= rects_copy[i][0] <= rects_copy[i][2] <= x2:
319
+ panel_index_to_split[index] = split_index
320
+ if panel_index_to_split[a] != panel_index_to_split[b]:
321
+ return panel_index_to_split[a] < panel_index_to_split[b]
322
+
323
+ # otherwise, erode the rectangles and try again
324
+ rects = [erode_rectangle(rect, 0.05) for rect in rects]
325
+
326
+
327
+ def erode_rectangle(bbox, erosion_factor):
328
+ x1, y1, x2, y2 = bbox
329
+ w, h = x2 - x1, y2 - y1
330
+ cx, cy = x1 + w / 2, y1 + h / 2
331
+ if w < h:
332
+ aspect_ratio = w / h
333
+ erosion_factor_width = erosion_factor * aspect_ratio
334
+ erosion_factor_height = erosion_factor
335
+ else:
336
+ aspect_ratio = h / w
337
+ erosion_factor_width = erosion_factor
338
+ erosion_factor_height = erosion_factor * aspect_ratio
339
+ w = w - w * erosion_factor_width
340
+ h = h - h * erosion_factor_height
341
+ x1, y1, x2, y2 = cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2
342
+ return [x1, y1, x2, y2]
343
+
344
+
345
+ def merge_overlapping_ranges(ranges):
346
+ """
347
+ ranges: list of tuples (x1, x2)
348
+ """
349
+ if len(ranges) == 0:
350
+ return []
351
+ ranges = sorted(ranges, key=lambda x: x[0])
352
+ merged_ranges = []
353
+ for i, r in enumerate(ranges):
354
+ if i == 0:
355
+ prev_x1, prev_x2 = r
356
+ continue
357
+ x1, x2 = r
358
+ if x1 > prev_x2:
359
+ merged_ranges.append((prev_x1, prev_x2))
360
+ prev_x1, prev_x2 = x1, x2
361
+ else:
362
+ prev_x2 = max(prev_x2, x2)
363
+ merged_ranges.append((prev_x1, prev_x2))
364
+ return merged_ranges
365
+
366
+
367
+ def sort_text_boxes_in_reading_order(text_bboxes, sorted_panel_bboxes):
368
+ text_bboxes = convert_to_list_of_lists(text_bboxes)
369
+ sorted_panel_bboxes = convert_to_list_of_lists(sorted_panel_bboxes)
370
+
371
+ if len(text_bboxes) == 0:
372
+ return []
373
+
374
+ def indices_of_same_elements(nums):
375
+ groups = groupby(range(len(nums)), key=lambda i: nums[i])
376
+ return [list(indices) for _, indices in groups]
377
+
378
+ panel_id_for_text = get_text_to_panel_mapping(
379
+ text_bboxes, sorted_panel_bboxes)
380
+ indices_of_texts = list(range(len(text_bboxes)))
381
+ indices_of_texts, panel_id_for_text = zip(
382
+ *sorted(zip(indices_of_texts, panel_id_for_text), key=lambda x: x[1]))
383
+ indices_of_texts = list(indices_of_texts)
384
+ grouped_indices = indices_of_same_elements(panel_id_for_text)
385
+ for group in grouped_indices:
386
+ subset_of_text_indices = [indices_of_texts[i] for i in group]
387
+ text_bboxes_of_subset = [text_bboxes[i]
388
+ for i in subset_of_text_indices]
389
+ sorted_subset_indices = sort_texts_within_panel(text_bboxes_of_subset)
390
+ indices_of_texts[group[0]: group[-1] + 1] = [subset_of_text_indices[i]
391
+ for i in sorted_subset_indices]
392
+ return indices_of_texts
393
+
394
+
395
+ def get_text_to_panel_mapping(text_bboxes, sorted_panel_bboxes):
396
+ text_to_panel_mapping = []
397
+ for text_bbox in text_bboxes:
398
+ shapely_text_polygon = box(*text_bbox)
399
+ all_intersections = []
400
+ all_distances = []
401
+ if len(sorted_panel_bboxes) == 0:
402
+ text_to_panel_mapping.append(-1)
403
+ continue
404
+ for j, annotation in enumerate(sorted_panel_bboxes):
405
+ shapely_annotation_polygon = box(*annotation)
406
+ if shapely_text_polygon.intersects(shapely_annotation_polygon):
407
+ all_intersections.append(
408
+ (shapely_text_polygon.intersection(shapely_annotation_polygon).area, j))
409
+ all_distances.append(
410
+ (shapely_text_polygon.distance(shapely_annotation_polygon), j))
411
+ if len(all_intersections) == 0:
412
+ text_to_panel_mapping.append(
413
+ min(all_distances, key=lambda x: x[0])[1])
414
+ else:
415
+ text_to_panel_mapping.append(
416
+ max(all_intersections, key=lambda x: x[0])[1])
417
+ return text_to_panel_mapping
418
+
419
+
420
+ def sort_texts_within_panel(rects):
421
+ smallest_y = float("inf")
422
+ greatest_x = float("-inf")
423
+ for i, rect in enumerate(rects):
424
+ x1, y1, x2, y2 = rect
425
+ smallest_y = min(smallest_y, y1)
426
+ greatest_x = max(greatest_x, x2)
427
+
428
+ reference_point = Point(greatest_x, smallest_y)
429
+
430
+ polygons_and_index = []
431
+ for i, rect in enumerate(rects):
432
+ x1, y1, x2, y2 = rect
433
+ polygons_and_index.append((box(x1, y1, x2, y2), i))
434
+ # sort points by closest to reference point
435
+ polygons_and_index = sorted(
436
+ polygons_and_index, key=lambda x: reference_point.distance(x[0]))
437
+ indices = [x[1] for x in polygons_and_index]
438
+ return indices
439
+
440
+
441
+ def x1y1wh_to_x1y1x2y2(bbox):
442
+ x1, y1, w, h = bbox
443
+ return [x1, y1, x1 + w, y1 + h]
444
+
445
+
446
+ def x1y1x2y2_to_xywh(bbox):
447
+ x1, y1, x2, y2 = bbox
448
+ return [x1, y1, x2 - x1, y2 - y1]
449
+
450
+
451
+ def convert_to_list_of_lists(rects):
452
+ if isinstance(rects, torch.Tensor):
453
+ return rects.tolist()
454
+ if isinstance(rects, np.ndarray):
455
+ return rects.tolist()
456
+ return [[a, b, c, d] for a, b, c, d in rects]