Mateusz Mróz
committed on
Commit · cd77b9d
1 Parent(s): b3f1331
Implement Magiv2Model with detection, OCR, and character association capabilities
- Added Magiv2Model class inheriting from PreTrainedModel.
- Integrated VisionEncoderDecoderModel for OCR and ViTMAEModel for crop embeddings.
- Implemented ConditionalDetrModel for object detection with associated prediction heads.
- Developed methods for chapter-wide predictions, character name assignments, and affinity matrix calculations.
- Included utility functions for bounding box operations and Hungarian matching for object assignments.
- Added support for processing images in batches and handling various detection thresholds.
- Implemented visualization and prediction methods for single images.
- config.json +1 -1
- configuration_magiv2.py +117 -24
- modelling_magiv2.py +0 -0
- modelling_magiv2_pre.py +877 -0
- processing_magiv2.py +40 -24
- utils.py +92 -47
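As a rough end-to-end illustration of the capabilities listed above, the following is a minimal usage sketch, assuming the model is published as a Hub repository with custom code; the repository id, image file names, and character bank contents are placeholders, not part of this commit.

import numpy as np
from PIL import Image
from transformers import AutoModel

# Placeholder repo id; trust_remote_code pulls in the modelling/configuration files from this commit.
model = AutoModel.from_pretrained("user/magiv2", trust_remote_code=True).eval()

def load_image(path):
    # Pages and character references are passed as RGB numpy arrays of shape (H, W, C).
    return np.array(Image.open(path).convert("RGB"))

pages = [load_image(p) for p in ["page_001.png", "page_002.png"]]
character_bank = {
    "images": [load_image("character_a.png")],  # reference crops of known characters
    "names": ["Character A"],
}

# Runs detection, character clustering, name assignment and OCR for the whole chapter.
results = model.do_chapter_wide_prediction(pages, character_bank, use_tqdm=True, do_ocr=True)
for page in results:
    print(len(page["panels"]), "panels,", len(page["texts"]), "text boxes,",
          len(page["characters"]), "characters")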
config.json
CHANGED
@@ -487,4 +487,4 @@
   "ocr_pretrained_processor_path": "microsoft/trocr-base-printed",
   "torch_dtype": "float32",
   "transformers_version": "4.34.0.dev0"
-}
+}
|
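The keys in this hunk are the serialized fields of Magiv2Config (defined in the next file). A minimal sketch of reading them back, assuming the repository id is a placeholder and the repo registers its custom configuration class for auto loading:

from transformers import AutoConfig

# Loads config.json and rebuilds Magiv2Config via the custom code shipped with the repo.
config = AutoConfig.from_pretrained("user/magiv2", trust_remote_code=True)
print(config.ocr_pretrained_processor_path)  # "microsoft/trocr-base-printed"
print(config.torch_dtype)                    # torch.float32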
configuration_magiv2.py
CHANGED
@@ -1,38 +1,131 @@
 from transformers import PretrainedConfig, VisionEncoderDecoderConfig
-from typing import
+from typing import Any, Optional


 class Magiv2Config(PretrainedConfig):
+    """
+    Configuration class for the Magiv2 model.
+
+    Magiv2Config inherits from PretrainedConfig in the transformers library and defines
+    the complete configuration of a vision model made up of three main components:
+    - an object detection model (detection)
+    - an OCR model (text recognition)
+    - an embedding model for cropped image fragments (crop embeddings)
+
+    Attributes:
+        model_type: Model type identifier for the transformers library
+        disable_ocr: Flag that disables the OCR module
+        disable_crop_embeddings: Flag that disables the crop embedding module
+        disable_detections: Flag that disables the object detection module
+        detection_model_config: Detection model configuration (after deserialization)
+        ocr_model_config: OCR model configuration (after deserialization)
+        crop_embedding_model_config: Crop embedding model configuration (after deserialization)
+        detection_image_preprocessing_config: Image preprocessing parameters for detection
+        ocr_pretrained_processor_path: Path to the pretrained OCR processor
+        crop_embedding_image_preprocessing_config: Image preprocessing parameters for crop embeddings
+    """
+
+    # Model type identifier used by the transformers library
+    model_type: str = "magiv2"

     def __init__(
         self,
         disable_ocr: bool = False,
         disable_crop_embeddings: bool = False,
         disable_detections: bool = False,
-        detection_model_config: dict = None,
-        ocr_model_config: dict = None,
-        crop_embedding_model_config: dict = None,
-        detection_image_preprocessing_config: dict = None,
-        ocr_pretrained_processor_path: str = None,
-        crop_embedding_image_preprocessing_config: dict = None,
-        **kwargs,
-    ):
+        detection_model_config: Optional[dict[str, Any]] = None,
+        ocr_model_config: Optional[dict[str, Any]] = None,
+        crop_embedding_model_config: Optional[dict[str, Any]] = None,
+        detection_image_preprocessing_config: Optional[dict[str, Any]] = None,
+        ocr_pretrained_processor_path: Optional[str] = None,
+        crop_embedding_image_preprocessing_config: Optional[dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> None:
+        """
+        Initializes the Magiv2 model configuration.
+
+        The constructor takes parameters that control which model modules are active,
+        together with configurations for the individual components. Configurations
+        passed as dicts are deserialized into the corresponding transformers Config objects.
+
+        Args:
+            disable_ocr: Whether to disable the text recognition (OCR) module.
+                Defaults to False - OCR is active.
+            disable_crop_embeddings: Whether to disable the module that creates embeddings
+                for cropped image fragments. Defaults to False - embedding is active.
+            disable_detections: Whether to disable the object detection module.
+                Defaults to False - detection is active.
+            detection_model_config: Dict with the configuration of the object detection model.
+                If provided, it is deserialized into a PretrainedConfig.
+            ocr_model_config: Dict with the configuration of the OCR model (encoder-decoder).
+                If provided, it is deserialized into a VisionEncoderDecoderConfig.
+            crop_embedding_model_config: Dict with the configuration of the crop embedding
+                model. If provided, it is deserialized into a PretrainedConfig.
+            detection_image_preprocessing_config: Dict with image preprocessing parameters
+                for the detection module (e.g. size, normalization).
+            ocr_pretrained_processor_path: Path to a directory or Hub ID with the pretrained
+                image processor for the OCR module.
+            crop_embedding_image_preprocessing_config: Dict with image preprocessing
+                parameters for the embedding module.
+            **kwargs: Additional arguments passed to the PretrainedConfig base class.
+
+        Returns:
+            None
+
+        Note:
+            - Model configurations are deserialized from dict to Config objects only
+              when they are provided (not None)
+            - The disable_* flags allow individual modules to be switched off selectively
+            - All additional kwargs are forwarded to the PretrainedConfig base class
+        """
+        # Store the flags that disable the individual modules
+        self.disable_ocr: bool = disable_ocr
+        self.disable_crop_embeddings: bool = disable_crop_embeddings
+        self.disable_detections: bool = disable_detections
+
+        # Store the additional arguments passed to the constructor
+        self.kwargs: dict[str, Any] = kwargs
+
+        # Initialize the model configuration attributes as None
+        # (they may be deserialized below if the parameters are not None)
+        self.detection_model_config: Optional[PretrainedConfig] = None
+        self.ocr_model_config: Optional[VisionEncoderDecoderConfig] = None
+        self.crop_embedding_model_config: Optional[PretrainedConfig] = None
+
+        # Deserialize the detection model configuration from a dict into a PretrainedConfig object
         if detection_model_config is not None:
+            self.detection_model_config = PretrainedConfig.from_dict(
+                detection_model_config
+            )
+
+        # Deserialize the OCR model configuration from a dict into a VisionEncoderDecoderConfig object
+        # OCR uses an encoder-decoder architecture (vision encoder + text decoder)
         if ocr_model_config is not None:
+            self.ocr_model_config = VisionEncoderDecoderConfig.from_dict(
+                ocr_model_config
+            )
+
+        # Deserialize the crop embedding model configuration from a dict into a PretrainedConfig object
         if crop_embedding_model_config is not None:
+            self.crop_embedding_model_config = PretrainedConfig.from_dict(
+                crop_embedding_model_config
+            )
+
+        # Store the image preprocessing configuration for the detection module
+        # (e.g. target image size, normalization parameters, augmentations)
+        self.detection_image_preprocessing_config: Optional[dict[str, Any]] = (
+            detection_image_preprocessing_config
+        )
+
+        # Path to the pretrained OCR processor (local or from the Hugging Face Hub)
+        self.ocr_pretrained_processor_path: Optional[str] = ocr_pretrained_processor_path
+
+        # Store the image preprocessing configuration for the embedding module
+        # (e.g. target crop size, normalization parameters)
+        self.crop_embedding_image_preprocessing_config: Optional[dict[str, Any]] = (
+            crop_embedding_image_preprocessing_config
+        )
+
+        # Call the PretrainedConfig base-class constructor with the additional kwargs
         super().__init__(**kwargs)
|
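To make the deserialization behaviour documented above concrete, here is a minimal sketch that builds a Magiv2Config directly from plain dicts. The sub-config values are made-up toy values rather than the ones shipped in config.json, and the import assumes configuration_magiv2.py is on the Python path:

from configuration_magiv2 import Magiv2Config

config = Magiv2Config(
    disable_ocr=True,  # e.g. run detection and character matching only
    detection_model_config={"model_type": "conditional_detr", "d_model": 256},
    ocr_pretrained_processor_path="microsoft/trocr-base-printed",
)

# The dict was deserialized into a transformers Config object by __init__.
print(type(config.detection_model_config).__name__)  # PretrainedConfig
print(config.ocr_model_config)                        # None - no dict was passed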
modelling_magiv2.py
CHANGED
|
The diff for this file is too large to render.
See raw diff
modelling_magiv2_pre.py
ADDED
|
@@ -0,0 +1,877 @@
| 1 |
+
from transformers import PreTrainedModel, VisionEncoderDecoderModel, ViTMAEModel, ConditionalDetrModel
|
| 2 |
+
from transformers.models.conditional_detr.modeling_conditional_detr import (
|
| 3 |
+
ConditionalDetrMLPPredictionHead,
|
| 4 |
+
ConditionalDetrModelOutput,
|
| 5 |
+
inverse_sigmoid,
|
| 6 |
+
)
|
| 7 |
+
from .configuration_magiv2 import Magiv2Config
|
| 8 |
+
from .processing_magiv2 import Magiv2Processor
|
| 9 |
+
from torch import nn
|
| 10 |
+
from typing import Optional, List
|
| 11 |
+
import torch
|
| 12 |
+
from einops import rearrange, repeat
|
| 13 |
+
from .utils import move_to_device, visualise_single_image_prediction, sort_panels, sort_text_boxes_in_reading_order
|
| 14 |
+
from transformers.image_transforms import center_to_corners_format
|
| 15 |
+
from .utils import UnionFind, sort_panels, sort_text_boxes_in_reading_order
|
| 16 |
+
import pulp
|
| 17 |
+
import scipy
|
| 18 |
+
import numpy as np
|
| 19 |
+
from scipy.optimize import linear_sum_assignment
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class Magiv2Model(PreTrainedModel):
|
| 23 |
+
config_class = Magiv2Config
|
| 24 |
+
|
| 25 |
+
def __init__(self, config):
|
| 26 |
+
super().__init__(config)
|
| 27 |
+
self.config = config
|
| 28 |
+
self.processor = Magiv2Processor(config)
|
| 29 |
+
if not config.disable_ocr:
|
| 30 |
+
self.ocr_model = VisionEncoderDecoderModel(config.ocr_model_config)
|
| 31 |
+
if not config.disable_crop_embeddings:
|
| 32 |
+
self.crop_embedding_model = ViTMAEModel(
|
| 33 |
+
config.crop_embedding_model_config)
|
| 34 |
+
if not config.disable_detections:
|
| 35 |
+
self.num_non_obj_tokens = 5
|
| 36 |
+
self.detection_transformer = ConditionalDetrModel(
|
| 37 |
+
config.detection_model_config)
|
| 38 |
+
self.bbox_predictor = ConditionalDetrMLPPredictionHead(
|
| 39 |
+
input_dim=config.detection_model_config.d_model,
|
| 40 |
+
hidden_dim=config.detection_model_config.d_model,
|
| 41 |
+
output_dim=4, num_layers=3
|
| 42 |
+
)
|
| 43 |
+
self.character_character_matching_head = ConditionalDetrMLPPredictionHead(
|
| 44 |
+
input_dim=3 * config.detection_model_config.d_model +
|
| 45 |
+
(2 * config.crop_embedding_model_config.hidden_size if not config.disable_crop_embeddings else 0),
|
| 46 |
+
hidden_dim=config.detection_model_config.d_model,
|
| 47 |
+
output_dim=1, num_layers=3
|
| 48 |
+
)
|
| 49 |
+
self.text_character_matching_head = ConditionalDetrMLPPredictionHead(
|
| 50 |
+
input_dim=3 * config.detection_model_config.d_model,
|
| 51 |
+
hidden_dim=config.detection_model_config.d_model,
|
| 52 |
+
output_dim=1, num_layers=3
|
| 53 |
+
)
|
| 54 |
+
self.text_tail_matching_head = ConditionalDetrMLPPredictionHead(
|
| 55 |
+
input_dim=2 * config.detection_model_config.d_model,
|
| 56 |
+
hidden_dim=config.detection_model_config.d_model,
|
| 57 |
+
output_dim=1, num_layers=3
|
| 58 |
+
)
|
| 59 |
+
self.class_labels_classifier = nn.Linear(
|
| 60 |
+
config.detection_model_config.d_model, config.detection_model_config.num_labels
|
| 61 |
+
)
|
| 62 |
+
self.is_this_text_a_dialogue = nn.Linear(
|
| 63 |
+
config.detection_model_config.d_model, 1
|
| 64 |
+
)
|
| 65 |
+
self.matcher = ConditionalDetrHungarianMatcher(
|
| 66 |
+
class_cost=config.detection_model_config.class_cost,
|
| 67 |
+
bbox_cost=config.detection_model_config.bbox_cost,
|
| 68 |
+
giou_cost=config.detection_model_config.giou_cost
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
def move_to_device(self, input):
|
| 72 |
+
return move_to_device(input, self.device)
|
| 73 |
+
|
| 74 |
+
@torch.no_grad()
|
| 75 |
+
def do_chapter_wide_prediction(self, pages_in_order, character_bank, eta=0.75, batch_size=8, use_tqdm=False, do_ocr=True):
|
| 76 |
+
texts = []
|
| 77 |
+
characters = []
|
| 78 |
+
character_clusters = []
|
| 79 |
+
if use_tqdm:
|
| 80 |
+
from tqdm import tqdm
|
| 81 |
+
iterator = tqdm(range(0, len(pages_in_order), batch_size))
|
| 82 |
+
else:
|
| 83 |
+
iterator = range(0, len(pages_in_order), batch_size)
|
| 84 |
+
per_page_results = []
|
| 85 |
+
for i in iterator:
|
| 86 |
+
pages = pages_in_order[i:i+batch_size]
|
| 87 |
+
results = self.predict_detections_and_associations(pages)
|
| 88 |
+
per_page_results.extend([result for result in results])
|
| 89 |
+
|
| 90 |
+
texts = [result["texts"] for result in per_page_results]
|
| 91 |
+
characters = [result["characters"] for result in per_page_results]
|
| 92 |
+
character_clusters = [result["character_cluster_labels"]
|
| 93 |
+
for result in per_page_results]
|
| 94 |
+
assigned_character_names = self.assign_names_to_characters(
|
| 95 |
+
pages_in_order, characters, character_bank, character_clusters, eta=eta)
|
| 96 |
+
if do_ocr:
|
| 97 |
+
ocr = self.predict_ocr(pages_in_order, texts, use_tqdm=use_tqdm)
|
| 98 |
+
offset_characters = 0
|
| 99 |
+
iteration_over = zip(
|
| 100 |
+
per_page_results, ocr) if do_ocr else per_page_results
|
| 101 |
+
for iter in iteration_over:
|
| 102 |
+
if do_ocr:
|
| 103 |
+
result, ocr_for_page = iter
|
| 104 |
+
result["ocr"] = ocr_for_page
|
| 105 |
+
else:
|
| 106 |
+
result = iter
|
| 107 |
+
result["character_names"] = assigned_character_names[offset_characters:
|
| 108 |
+
offset_characters + len(result["characters"])]
|
| 109 |
+
offset_characters += len(result["characters"])
|
| 110 |
+
return per_page_results
|
| 111 |
+
|
| 112 |
+
def assign_names_to_characters(self, images, character_bboxes, character_bank, character_clusters, eta=0.75):
|
| 113 |
+
if len(character_bank["images"]) == 0:
|
| 114 |
+
return ["Other" for bboxes_for_image in character_bboxes for bbox in bboxes_for_image]
|
| 115 |
+
chapter_wide_char_embeddings = self.predict_crop_embeddings(
|
| 116 |
+
images, character_bboxes)
|
| 117 |
+
chapter_wide_char_embeddings = torch.cat(
|
| 118 |
+
chapter_wide_char_embeddings, dim=0)
|
| 119 |
+
chapter_wide_char_embeddings = torch.nn.functional.normalize(
|
| 120 |
+
chapter_wide_char_embeddings, p=2, dim=1).cpu().numpy()
|
| 121 |
+
# create must-link and cannot link constraints from character_clusters
|
| 122 |
+
must_link = []
|
| 123 |
+
cannot_link = []
|
| 124 |
+
offset = 0
|
| 125 |
+
for clusters_per_image in character_clusters:
|
| 126 |
+
for i in range(len(clusters_per_image)):
|
| 127 |
+
for j in range(i+1, len(clusters_per_image)):
|
| 128 |
+
if clusters_per_image[i] == clusters_per_image[j]:
|
| 129 |
+
must_link.append((offset + i, offset + j))
|
| 130 |
+
else:
|
| 131 |
+
cannot_link.append((offset + i, offset + j))
|
| 132 |
+
offset += len(clusters_per_image)
|
| 133 |
+
character_bank_for_this_chapter = self.predict_crop_embeddings(
|
| 134 |
+
character_bank["images"], [[[0, 0, x.shape[1], x.shape[0]]] for x in character_bank["images"]])
|
| 135 |
+
character_bank_for_this_chapter = torch.cat(
|
| 136 |
+
character_bank_for_this_chapter, dim=0)
|
| 137 |
+
character_bank_for_this_chapter = torch.nn.functional.normalize(
|
| 138 |
+
character_bank_for_this_chapter, p=2, dim=1).cpu().numpy()
|
| 139 |
+
costs = scipy.spatial.distance.cdist(
|
| 140 |
+
chapter_wide_char_embeddings, character_bank_for_this_chapter)
|
| 141 |
+
none_of_the_above = eta * np.ones((costs.shape[0], 1))
|
| 142 |
+
costs = np.concatenate([costs, none_of_the_above], axis=1)
|
| 143 |
+
sense = pulp.LpMinimize
|
| 144 |
+
num_supply, num_demand = costs.shape
|
| 145 |
+
problem = pulp.LpProblem("Optimal_Transport_Problem", sense)
|
| 146 |
+
x = pulp.LpVariable.dicts("x", ((i, j) for i in range(
|
| 147 |
+
num_supply) for j in range(num_demand)), cat='Binary')
|
| 148 |
+
# Objective Function to minimize
|
| 149 |
+
problem += pulp.lpSum([costs[i][j] * x[(i, j)]
|
| 150 |
+
for i in range(num_supply) for j in range(num_demand)])
|
| 151 |
+
# each crop must be assigned to exactly one character
|
| 152 |
+
for i in range(num_supply):
|
| 153 |
+
problem += pulp.lpSum([x[(i, j)] for j in range(num_demand)]
|
| 154 |
+
) == 1, f"Supply_{i}_Total_Assignment"
|
| 155 |
+
# cannot link constraints
|
| 156 |
+
for j in range(num_demand-1):
|
| 157 |
+
for (s1, s2) in cannot_link:
|
| 158 |
+
problem += x[(s1, j)] + x[(s2, j)
|
| 159 |
+
] <= 1, f"Exclusion_{s1}_{s2}_Demand_{j}"
|
| 160 |
+
# must link constraints
|
| 161 |
+
for j in range(num_demand):
|
| 162 |
+
for (s1, s2) in must_link:
|
| 163 |
+
problem += x[(s1, j)] - x[(s2, j)
|
| 164 |
+
] == 0, f"Inclusion_{s1}_{s2}_Demand_{j}"
|
| 165 |
+
problem.solve()
|
| 166 |
+
assignments = []
|
| 167 |
+
for v in problem.variables():
|
| 168 |
+
if v.varValue is not None and v.varValue > 0:
|
| 169 |
+
index, assignment = v.name.split(
|
| 170 |
+
"(")[1].split(")")[0].split(",")
|
| 171 |
+
assignment = assignment[1:]
|
| 172 |
+
assignments.append((int(index), int(assignment)))
|
| 173 |
+
|
| 174 |
+
labels = np.zeros(num_supply)
|
| 175 |
+
for i, j in assignments:
|
| 176 |
+
labels[i] = j
|
| 177 |
+
|
| 178 |
+
return [character_bank["names"][int(i)] if i < len(character_bank["names"]) else "Other" for i in labels]
|
| 179 |
+
|
| 180 |
+
def predict_detections_and_associations(
|
| 181 |
+
self,
|
| 182 |
+
images,
|
| 183 |
+
move_to_device_fn=None,
|
| 184 |
+
character_detection_threshold=0.3,
|
| 185 |
+
panel_detection_threshold=0.2,
|
| 186 |
+
text_detection_threshold=0.3,
|
| 187 |
+
tail_detection_threshold=0.34,
|
| 188 |
+
character_character_matching_threshold=0.65,
|
| 189 |
+
text_character_matching_threshold=0.35,
|
| 190 |
+
text_tail_matching_threshold=0.3,
|
| 191 |
+
text_classification_threshold=0.5,
|
| 192 |
+
):
|
| 193 |
+
assert not self.config.disable_detections
|
| 194 |
+
move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn
|
| 195 |
+
|
| 196 |
+
inputs_to_detection_transformer = self.processor.preprocess_inputs_for_detection(
|
| 197 |
+
images)
|
| 198 |
+
inputs_to_detection_transformer = move_to_device_fn(
|
| 199 |
+
inputs_to_detection_transformer)
|
| 200 |
+
|
| 201 |
+
detection_transformer_output = self._get_detection_transformer_output(
|
| 202 |
+
**inputs_to_detection_transformer)
|
| 203 |
+
predicted_class_scores, predicted_bboxes = self._get_predicted_bboxes_and_classes(
|
| 204 |
+
detection_transformer_output)
|
| 205 |
+
|
| 206 |
+
original_image_sizes = torch.stack([torch.tensor(
|
| 207 |
+
img.shape[:2]) for img in images], dim=0).to(predicted_bboxes.device)
|
| 208 |
+
|
| 209 |
+
batch_scores, batch_labels = predicted_class_scores.max(-1)
|
| 210 |
+
batch_scores = batch_scores.sigmoid()
|
| 211 |
+
batch_labels = batch_labels.long()
|
| 212 |
+
batch_bboxes = center_to_corners_format(predicted_bboxes)
|
| 213 |
+
|
| 214 |
+
# scale the bboxes back to the original image size
|
| 215 |
+
if isinstance(original_image_sizes, List):
|
| 216 |
+
img_h = torch.Tensor([i[0] for i in original_image_sizes])
|
| 217 |
+
img_w = torch.Tensor([i[1] for i in original_image_sizes])
|
| 218 |
+
else:
|
| 219 |
+
img_h, img_w = original_image_sizes.unbind(1)
|
| 220 |
+
scale_fct = torch.stack(
|
| 221 |
+
[img_w, img_h, img_w, img_h], dim=1).to(batch_bboxes.device)
|
| 222 |
+
batch_bboxes = batch_bboxes * scale_fct[:, None, :]
|
| 223 |
+
|
| 224 |
+
batch_panel_indices = self.processor._get_indices_of_panels_to_keep(
|
| 225 |
+
batch_scores, batch_labels, batch_bboxes, panel_detection_threshold)
|
| 226 |
+
batch_character_indices = self.processor._get_indices_of_characters_to_keep(
|
| 227 |
+
batch_scores, batch_labels, batch_bboxes, character_detection_threshold)
|
| 228 |
+
batch_text_indices = self.processor._get_indices_of_texts_to_keep(
|
| 229 |
+
batch_scores, batch_labels, batch_bboxes, text_detection_threshold)
|
| 230 |
+
batch_tail_indices = self.processor._get_indices_of_tails_to_keep(
|
| 231 |
+
batch_scores, batch_labels, batch_bboxes, tail_detection_threshold)
|
| 232 |
+
|
| 233 |
+
predicted_obj_tokens_for_batch = self._get_predicted_obj_tokens(
|
| 234 |
+
detection_transformer_output)
|
| 235 |
+
predicted_t2c_tokens_for_batch = self._get_predicted_t2c_tokens(
|
| 236 |
+
detection_transformer_output)
|
| 237 |
+
predicted_c2c_tokens_for_batch = self._get_predicted_c2c_tokens(
|
| 238 |
+
detection_transformer_output)
|
| 239 |
+
|
| 240 |
+
text_character_affinity_matrices = self._get_text_character_affinity_matrices(
|
| 241 |
+
character_obj_tokens_for_batch=[x[i] for x, i in zip(
|
| 242 |
+
predicted_obj_tokens_for_batch, batch_character_indices)],
|
| 243 |
+
text_obj_tokens_for_this_batch=[x[i] for x, i in zip(
|
| 244 |
+
predicted_obj_tokens_for_batch, batch_text_indices)],
|
| 245 |
+
t2c_tokens_for_batch=predicted_t2c_tokens_for_batch,
|
| 246 |
+
apply_sigmoid=True,
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
character_bboxes_in_batch = [batch_bboxes[i][j]
|
| 250 |
+
for i, j in enumerate(batch_character_indices)]
|
| 251 |
+
character_character_affinity_matrices = self._get_character_character_affinity_matrices(
|
| 252 |
+
character_obj_tokens_for_batch=[x[i] for x, i in zip(
|
| 253 |
+
predicted_obj_tokens_for_batch, batch_character_indices)],
|
| 254 |
+
crop_embeddings_for_batch=self.predict_crop_embeddings(
|
| 255 |
+
images, character_bboxes_in_batch, move_to_device_fn),
|
| 256 |
+
c2c_tokens_for_batch=predicted_c2c_tokens_for_batch,
|
| 257 |
+
apply_sigmoid=True,
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
text_tail_affinity_matrices = self._get_text_tail_affinity_matrices(
|
| 261 |
+
text_obj_tokens_for_this_batch=[x[i] for x, i in zip(
|
| 262 |
+
predicted_obj_tokens_for_batch, batch_text_indices)],
|
| 263 |
+
tail_obj_tokens_for_batch=[x[i] for x, i in zip(
|
| 264 |
+
predicted_obj_tokens_for_batch, batch_tail_indices)],
|
| 265 |
+
apply_sigmoid=True,
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
is_this_text_a_dialogue = self._get_text_classification(
|
| 269 |
+
[x[i] for x, i in zip(predicted_obj_tokens_for_batch, batch_text_indices)])
|
| 270 |
+
|
| 271 |
+
results = []
|
| 272 |
+
for batch_index in range(len(batch_scores)):
|
| 273 |
+
panel_indices = batch_panel_indices[batch_index]
|
| 274 |
+
character_indices = batch_character_indices[batch_index]
|
| 275 |
+
text_indices = batch_text_indices[batch_index]
|
| 276 |
+
tail_indices = batch_tail_indices[batch_index]
|
| 277 |
+
|
| 278 |
+
character_bboxes = batch_bboxes[batch_index][character_indices]
|
| 279 |
+
panel_bboxes = batch_bboxes[batch_index][panel_indices]
|
| 280 |
+
text_bboxes = batch_bboxes[batch_index][text_indices]
|
| 281 |
+
tail_bboxes = batch_bboxes[batch_index][tail_indices]
|
| 282 |
+
|
| 283 |
+
local_sorted_panel_indices = sort_panels(panel_bboxes)
|
| 284 |
+
panel_bboxes = panel_bboxes[local_sorted_panel_indices]
|
| 285 |
+
local_sorted_text_indices = sort_text_boxes_in_reading_order(
|
| 286 |
+
text_bboxes, panel_bboxes)
|
| 287 |
+
text_bboxes = text_bboxes[local_sorted_text_indices]
|
| 288 |
+
|
| 289 |
+
character_character_matching_scores = character_character_affinity_matrices[
|
| 290 |
+
batch_index]
|
| 291 |
+
text_character_matching_scores = text_character_affinity_matrices[
|
| 292 |
+
batch_index][local_sorted_text_indices]
|
| 293 |
+
text_tail_matching_scores = text_tail_affinity_matrices[
|
| 294 |
+
batch_index][local_sorted_text_indices]
|
| 295 |
+
|
| 296 |
+
is_essential_text = is_this_text_a_dialogue[batch_index][
|
| 297 |
+
local_sorted_text_indices] > text_classification_threshold
|
| 298 |
+
character_cluster_labels = UnionFind.from_adj_matrix(
|
| 299 |
+
character_character_matching_scores > character_character_matching_threshold
|
| 300 |
+
).get_labels_for_connected_components()
|
| 301 |
+
|
| 302 |
+
if 0 in text_character_matching_scores.shape:
|
| 303 |
+
text_character_associations = torch.zeros(
|
| 304 |
+
(0, 2), dtype=torch.long)
|
| 305 |
+
else:
|
| 306 |
+
most_likely_speaker_for_each_text = torch.argmax(
|
| 307 |
+
text_character_matching_scores, dim=1)
|
| 308 |
+
text_indices = torch.arange(len(text_bboxes)).type_as(
|
| 309 |
+
most_likely_speaker_for_each_text)
|
| 310 |
+
text_character_associations = torch.stack(
|
| 311 |
+
[text_indices, most_likely_speaker_for_each_text], dim=1)
|
| 312 |
+
to_keep = text_character_matching_scores.max(
|
| 313 |
+
dim=1).values > text_character_matching_threshold
|
| 314 |
+
text_character_associations = text_character_associations[to_keep]
|
| 315 |
+
|
| 316 |
+
if 0 in text_tail_matching_scores.shape:
|
| 317 |
+
text_tail_associations = torch.zeros((0, 2), dtype=torch.long)
|
| 318 |
+
else:
|
| 319 |
+
most_likely_tail_for_each_text = torch.argmax(
|
| 320 |
+
text_tail_matching_scores, dim=1)
|
| 321 |
+
text_indices = torch.arange(len(text_bboxes)).type_as(
|
| 322 |
+
most_likely_tail_for_each_text)
|
| 323 |
+
text_tail_associations = torch.stack(
|
| 324 |
+
[text_indices, most_likely_tail_for_each_text], dim=1)
|
| 325 |
+
to_keep = text_tail_matching_scores.max(
|
| 326 |
+
dim=1).values > text_tail_matching_threshold
|
| 327 |
+
text_tail_associations = text_tail_associations[to_keep]
|
| 328 |
+
|
| 329 |
+
results.append({
|
| 330 |
+
"panels": panel_bboxes.tolist(),
|
| 331 |
+
"texts": text_bboxes.tolist(),
|
| 332 |
+
"characters": character_bboxes.tolist(),
|
| 333 |
+
"tails": tail_bboxes.tolist(),
|
| 334 |
+
"text_character_associations": text_character_associations.tolist(),
|
| 335 |
+
"text_tail_associations": text_tail_associations.tolist(),
|
| 336 |
+
"character_cluster_labels": character_cluster_labels,
|
| 337 |
+
"is_essential_text": is_essential_text.tolist(),
|
| 338 |
+
})
|
| 339 |
+
|
| 340 |
+
return results
|
| 341 |
+
|
| 342 |
+
def get_affinity_matrices_given_annotations(
|
| 343 |
+
self, images, annotations, move_to_device_fn=None, apply_sigmoid=True
|
| 344 |
+
):
|
| 345 |
+
assert not self.config.disable_detections
|
| 346 |
+
move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn
|
| 347 |
+
|
| 348 |
+
character_bboxes_in_batch = [[bbox for bbox, label in zip(
|
| 349 |
+
a["bboxes_as_x1y1x2y2"], a["labels"]) if label == 0] for a in annotations]
|
| 350 |
+
crop_embeddings_for_batch = self.predict_crop_embeddings(
|
| 351 |
+
images, character_bboxes_in_batch, move_to_device_fn)
|
| 352 |
+
|
| 353 |
+
inputs_to_detection_transformer = self.processor.preprocess_inputs_for_detection(
|
| 354 |
+
images, annotations)
|
| 355 |
+
inputs_to_detection_transformer = move_to_device_fn(
|
| 356 |
+
inputs_to_detection_transformer)
|
| 357 |
+
processed_targets = inputs_to_detection_transformer.pop("labels")
|
| 358 |
+
|
| 359 |
+
detection_transformer_output = self._get_detection_transformer_output(
|
| 360 |
+
**inputs_to_detection_transformer)
|
| 361 |
+
predicted_obj_tokens_for_batch = self._get_predicted_obj_tokens(
|
| 362 |
+
detection_transformer_output)
|
| 363 |
+
predicted_t2c_tokens_for_batch = self._get_predicted_t2c_tokens(
|
| 364 |
+
detection_transformer_output)
|
| 365 |
+
predicted_c2c_tokens_for_batch = self._get_predicted_c2c_tokens(
|
| 366 |
+
detection_transformer_output)
|
| 367 |
+
|
| 368 |
+
predicted_class_scores, predicted_bboxes = self._get_predicted_bboxes_and_classes(
|
| 369 |
+
detection_transformer_output)
|
| 370 |
+
matching_dict = {
|
| 371 |
+
"logits": predicted_class_scores,
|
| 372 |
+
"pred_boxes": predicted_bboxes,
|
| 373 |
+
}
|
| 374 |
+
indices = self.matcher(matching_dict, processed_targets)
|
| 375 |
+
|
| 376 |
+
matched_char_obj_tokens_for_batch = []
|
| 377 |
+
matched_text_obj_tokens_for_batch = []
|
| 378 |
+
matched_tail_obj_tokens_for_batch = []
|
| 379 |
+
t2c_tokens_for_batch = []
|
| 380 |
+
c2c_tokens_for_batch = []
|
| 381 |
+
|
| 382 |
+
for j, (pred_idx, tgt_idx) in enumerate(indices):
|
| 383 |
+
target_idx_to_pred_idx = {tgt.item(): pred.item()
|
| 384 |
+
for pred, tgt in zip(pred_idx, tgt_idx)}
|
| 385 |
+
targets_for_this_image = processed_targets[j]
|
| 386 |
+
indices_of_text_boxes_in_annotation = [i for i, label in enumerate(
|
| 387 |
+
targets_for_this_image["class_labels"]) if label == 1]
|
| 388 |
+
indices_of_char_boxes_in_annotation = [i for i, label in enumerate(
|
| 389 |
+
targets_for_this_image["class_labels"]) if label == 0]
|
| 390 |
+
indices_of_tail_boxes_in_annotation = [i for i, label in enumerate(
|
| 391 |
+
targets_for_this_image["class_labels"]) if label == 3]
|
| 392 |
+
predicted_text_indices = [target_idx_to_pred_idx[i]
|
| 393 |
+
for i in indices_of_text_boxes_in_annotation]
|
| 394 |
+
predicted_char_indices = [target_idx_to_pred_idx[i]
|
| 395 |
+
for i in indices_of_char_boxes_in_annotation]
|
| 396 |
+
predicted_tail_indices = [target_idx_to_pred_idx[i]
|
| 397 |
+
for i in indices_of_tail_boxes_in_annotation]
|
| 398 |
+
matched_char_obj_tokens_for_batch.append(
|
| 399 |
+
predicted_obj_tokens_for_batch[j][predicted_char_indices])
|
| 400 |
+
matched_text_obj_tokens_for_batch.append(
|
| 401 |
+
predicted_obj_tokens_for_batch[j][predicted_text_indices])
|
| 402 |
+
matched_tail_obj_tokens_for_batch.append(
|
| 403 |
+
predicted_obj_tokens_for_batch[j][predicted_tail_indices])
|
| 404 |
+
t2c_tokens_for_batch.append(predicted_t2c_tokens_for_batch[j])
|
| 405 |
+
c2c_tokens_for_batch.append(predicted_c2c_tokens_for_batch[j])
|
| 406 |
+
|
| 407 |
+
text_character_affinity_matrices = self._get_text_character_affinity_matrices(
|
| 408 |
+
character_obj_tokens_for_batch=matched_char_obj_tokens_for_batch,
|
| 409 |
+
text_obj_tokens_for_this_batch=matched_text_obj_tokens_for_batch,
|
| 410 |
+
t2c_tokens_for_batch=t2c_tokens_for_batch,
|
| 411 |
+
apply_sigmoid=apply_sigmoid,
|
| 412 |
+
)
|
| 413 |
+
|
| 414 |
+
character_character_affinity_matrices = self._get_character_character_affinity_matrices(
|
| 415 |
+
character_obj_tokens_for_batch=matched_char_obj_tokens_for_batch,
|
| 416 |
+
crop_embeddings_for_batch=crop_embeddings_for_batch,
|
| 417 |
+
c2c_tokens_for_batch=c2c_tokens_for_batch,
|
| 418 |
+
apply_sigmoid=apply_sigmoid,
|
| 419 |
+
)
|
| 420 |
+
|
| 421 |
+
character_character_affinity_matrices_crop_only = self._get_character_character_affinity_matrices(
|
| 422 |
+
character_obj_tokens_for_batch=matched_char_obj_tokens_for_batch,
|
| 423 |
+
crop_embeddings_for_batch=crop_embeddings_for_batch,
|
| 424 |
+
c2c_tokens_for_batch=c2c_tokens_for_batch,
|
| 425 |
+
crop_only=True,
|
| 426 |
+
apply_sigmoid=apply_sigmoid,
|
| 427 |
+
)
|
| 428 |
+
|
| 429 |
+
text_tail_affinity_matrices = self._get_text_tail_affinity_matrices(
|
| 430 |
+
text_obj_tokens_for_this_batch=matched_text_obj_tokens_for_batch,
|
| 431 |
+
tail_obj_tokens_for_batch=matched_tail_obj_tokens_for_batch,
|
| 432 |
+
apply_sigmoid=apply_sigmoid,
|
| 433 |
+
)
|
| 434 |
+
|
| 435 |
+
is_this_text_a_dialogue = self._get_text_classification(
|
| 436 |
+
matched_text_obj_tokens_for_batch, apply_sigmoid=apply_sigmoid)
|
| 437 |
+
|
| 438 |
+
return {
|
| 439 |
+
"text_character_affinity_matrices": text_character_affinity_matrices,
|
| 440 |
+
"character_character_affinity_matrices": character_character_affinity_matrices,
|
| 441 |
+
"character_character_affinity_matrices_crop_only": character_character_affinity_matrices_crop_only,
|
| 442 |
+
"text_tail_affinity_matrices": text_tail_affinity_matrices,
|
| 443 |
+
"is_this_text_a_dialogue": is_this_text_a_dialogue,
|
| 444 |
+
}
|
| 445 |
+
|
| 446 |
+
def predict_crop_embeddings(self, images, crop_bboxes, move_to_device_fn=None, mask_ratio=0.0, batch_size=256):
|
| 447 |
+
if self.config.disable_crop_embeddings:
|
| 448 |
+
return None
|
| 449 |
+
|
| 450 |
+
assert isinstance(
|
| 451 |
+
crop_bboxes, List), "please provide a list of bboxes for each image to get embeddings for"
|
| 452 |
+
|
| 453 |
+
move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn
|
| 454 |
+
|
| 455 |
+
# temporarily change the mask ratio from default to the one specified
|
| 456 |
+
old_mask_ratio = self.crop_embedding_model.embeddings.config.mask_ratio
|
| 457 |
+
self.crop_embedding_model.embeddings.config.mask_ratio = mask_ratio
|
| 458 |
+
|
| 459 |
+
crops_per_image = []
|
| 460 |
+
num_crops_per_batch = [len(bboxes) for bboxes in crop_bboxes]
|
| 461 |
+
for image, bboxes, num_crops in zip(images, crop_bboxes, num_crops_per_batch):
|
| 462 |
+
crops = self.processor.crop_image(image, bboxes)
|
| 463 |
+
assert len(crops) == num_crops
|
| 464 |
+
crops_per_image.extend(crops)
|
| 465 |
+
|
| 466 |
+
if len(crops_per_image) == 0:
|
| 467 |
+
return [move_to_device_fn(torch.zeros(0, self.config.crop_embedding_model_config.hidden_size)) for _ in crop_bboxes]
|
| 468 |
+
|
| 469 |
+
crops_per_image = self.processor.preprocess_inputs_for_crop_embeddings(
|
| 470 |
+
crops_per_image)
|
| 471 |
+
crops_per_image = move_to_device_fn(crops_per_image)
|
| 472 |
+
|
| 473 |
+
# process the crops in batches to avoid OOM
|
| 474 |
+
embeddings = []
|
| 475 |
+
for i in range(0, len(crops_per_image), batch_size):
|
| 476 |
+
crops = crops_per_image[i:i+batch_size]
|
| 477 |
+
embeddings_per_batch = self.crop_embedding_model(
|
| 478 |
+
crops).last_hidden_state[:, 0]
|
| 479 |
+
embeddings.append(embeddings_per_batch)
|
| 480 |
+
embeddings = torch.cat(embeddings, dim=0)
|
| 481 |
+
|
| 482 |
+
crop_embeddings_for_batch = []
|
| 483 |
+
for num_crops in num_crops_per_batch:
|
| 484 |
+
crop_embeddings_for_batch.append(embeddings[:num_crops])
|
| 485 |
+
embeddings = embeddings[num_crops:]
|
| 486 |
+
|
| 487 |
+
# restore the mask ratio to the default
|
| 488 |
+
self.crop_embedding_model.embeddings.config.mask_ratio = old_mask_ratio
|
| 489 |
+
|
| 490 |
+
return crop_embeddings_for_batch
|
| 491 |
+
|
| 492 |
+
def predict_ocr(self, images, crop_bboxes, move_to_device_fn=None, use_tqdm=False, batch_size=32, max_new_tokens=64):
|
| 493 |
+
assert not self.config.disable_ocr
|
| 494 |
+
move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn
|
| 495 |
+
|
| 496 |
+
crops_per_image = []
|
| 497 |
+
num_crops_per_batch = [len(bboxes) for bboxes in crop_bboxes]
|
| 498 |
+
for image, bboxes, num_crops in zip(images, crop_bboxes, num_crops_per_batch):
|
| 499 |
+
crops = self.processor.crop_image(image, bboxes)
|
| 500 |
+
assert len(crops) == num_crops
|
| 501 |
+
crops_per_image.extend(crops)
|
| 502 |
+
|
| 503 |
+
if len(crops_per_image) == 0:
|
| 504 |
+
return [[] for _ in crop_bboxes]
|
| 505 |
+
|
| 506 |
+
crops_per_image = self.processor.preprocess_inputs_for_ocr(
|
| 507 |
+
crops_per_image)
|
| 508 |
+
crops_per_image = move_to_device_fn(crops_per_image)
|
| 509 |
+
|
| 510 |
+
# process the crops in batches to avoid OOM
|
| 511 |
+
all_generated_texts = []
|
| 512 |
+
if use_tqdm:
|
| 513 |
+
from tqdm import tqdm
|
| 514 |
+
pbar = tqdm(range(0, len(crops_per_image), batch_size))
|
| 515 |
+
else:
|
| 516 |
+
pbar = range(0, len(crops_per_image), batch_size)
|
| 517 |
+
for i in pbar:
|
| 518 |
+
crops = crops_per_image[i:i+batch_size]
|
| 519 |
+
generated_ids = self.ocr_model.generate(
|
| 520 |
+
crops, max_new_tokens=max_new_tokens)
|
| 521 |
+
generated_texts = self.processor.postprocess_ocr_tokens(
|
| 522 |
+
generated_ids)
|
| 523 |
+
all_generated_texts.extend(generated_texts)
|
| 524 |
+
|
| 525 |
+
texts_for_images = []
|
| 526 |
+
for num_crops in num_crops_per_batch:
|
| 527 |
+
texts_for_images.append([x.replace("\n", "")
|
| 528 |
+
for x in all_generated_texts[:num_crops]])
|
| 529 |
+
all_generated_texts = all_generated_texts[num_crops:]
|
| 530 |
+
|
| 531 |
+
return texts_for_images
|
| 532 |
+
|
| 533 |
+
def visualise_single_image_prediction(
|
| 534 |
+
self, image_as_np_array, predictions, filename=None
|
| 535 |
+
):
|
| 536 |
+
return visualise_single_image_prediction(image_as_np_array, predictions, filename)
|
| 537 |
+
|
| 538 |
+
@torch.no_grad()
|
| 539 |
+
def _get_detection_transformer_output(
|
| 540 |
+
self,
|
| 541 |
+
pixel_values: torch.FloatTensor,
|
| 542 |
+
pixel_mask: Optional[torch.LongTensor] = None
|
| 543 |
+
):
|
| 544 |
+
if self.config.disable_detections:
|
| 545 |
+
raise ValueError(
|
| 546 |
+
"Detection model is disabled. Set disable_detections=False in the config.")
|
| 547 |
+
return self.detection_transformer(
|
| 548 |
+
pixel_values=pixel_values,
|
| 549 |
+
pixel_mask=pixel_mask,
|
| 550 |
+
return_dict=True
|
| 551 |
+
)
|
| 552 |
+
|
| 553 |
+
def _get_predicted_obj_tokens(
|
| 554 |
+
self,
|
| 555 |
+
detection_transformer_output: ConditionalDetrModelOutput
|
| 556 |
+
):
|
| 557 |
+
return detection_transformer_output.last_hidden_state[:, :-self.num_non_obj_tokens]
|
| 558 |
+
|
| 559 |
+
def _get_predicted_c2c_tokens(
|
| 560 |
+
self,
|
| 561 |
+
detection_transformer_output: ConditionalDetrModelOutput
|
| 562 |
+
):
|
| 563 |
+
return detection_transformer_output.last_hidden_state[:, -self.num_non_obj_tokens]
|
| 564 |
+
|
| 565 |
+
def _get_predicted_t2c_tokens(
|
| 566 |
+
self,
|
| 567 |
+
detection_transformer_output: ConditionalDetrModelOutput
|
| 568 |
+
):
|
| 569 |
+
return detection_transformer_output.last_hidden_state[:, -self.num_non_obj_tokens+1]
|
| 570 |
+
|
| 571 |
+
def _get_predicted_bboxes_and_classes(
|
| 572 |
+
self,
|
| 573 |
+
detection_transformer_output: ConditionalDetrModelOutput,
|
| 574 |
+
):
|
| 575 |
+
if self.config.disable_detections:
|
| 576 |
+
raise ValueError(
|
| 577 |
+
"Detection model is disabled. Set disable_detections=False in the config.")
|
| 578 |
+
|
| 579 |
+
obj = self._get_predicted_obj_tokens(detection_transformer_output)
|
| 580 |
+
|
| 581 |
+
predicted_class_scores = self.class_labels_classifier(obj)
|
| 582 |
+
reference = detection_transformer_output.reference_points[:-
|
| 583 |
+
self.num_non_obj_tokens]
|
| 584 |
+
reference_before_sigmoid = inverse_sigmoid(reference).transpose(0, 1)
|
| 585 |
+
predicted_boxes = self.bbox_predictor(obj)
|
| 586 |
+
predicted_boxes[..., :2] += reference_before_sigmoid
|
| 587 |
+
predicted_boxes = predicted_boxes.sigmoid()
|
| 588 |
+
|
| 589 |
+
return predicted_class_scores, predicted_boxes
|
| 590 |
+
|
| 591 |
+
def _get_text_classification(
|
| 592 |
+
self,
|
| 593 |
+
text_obj_tokens_for_batch: List[torch.FloatTensor],
|
| 594 |
+
apply_sigmoid=False,
|
| 595 |
+
):
|
| 596 |
+
assert not self.config.disable_detections
|
| 597 |
+
is_this_text_a_dialogue = []
|
| 598 |
+
for text_obj_tokens in text_obj_tokens_for_batch:
|
| 599 |
+
if text_obj_tokens.shape[0] == 0:
|
| 600 |
+
is_this_text_a_dialogue.append(
|
| 601 |
+
torch.tensor([], dtype=torch.bool))
|
| 602 |
+
continue
|
| 603 |
+
classification = self.is_this_text_a_dialogue(
|
| 604 |
+
text_obj_tokens).squeeze(-1)
|
| 605 |
+
if apply_sigmoid:
|
| 606 |
+
classification = classification.sigmoid()
|
| 607 |
+
is_this_text_a_dialogue.append(classification)
|
| 608 |
+
return is_this_text_a_dialogue
|
| 609 |
+
|
| 610 |
+
def _get_character_character_affinity_matrices(
|
| 611 |
+
self,
|
| 612 |
+
character_obj_tokens_for_batch: List[torch.FloatTensor] = None,
|
| 613 |
+
crop_embeddings_for_batch: List[torch.FloatTensor] = None,
|
| 614 |
+
c2c_tokens_for_batch: List[torch.FloatTensor] = None,
|
| 615 |
+
crop_only=False,
|
| 616 |
+
apply_sigmoid=True,
|
| 617 |
+
):
|
| 618 |
+
assert self.config.disable_detections or (
|
| 619 |
+
character_obj_tokens_for_batch is not None and c2c_tokens_for_batch is not None)
|
| 620 |
+
assert self.config.disable_crop_embeddings or crop_embeddings_for_batch is not None
|
| 621 |
+
assert not self.config.disable_detections or not self.config.disable_crop_embeddings
|
| 622 |
+
|
| 623 |
+
if crop_only:
|
| 624 |
+
affinity_matrices = []
|
| 625 |
+
for crop_embeddings in crop_embeddings_for_batch:
|
| 626 |
+
crop_embeddings = crop_embeddings / \
|
| 627 |
+
crop_embeddings.norm(dim=-1, keepdim=True)
|
| 628 |
+
affinity_matrix = crop_embeddings @ crop_embeddings.T
|
| 629 |
+
affinity_matrices.append(affinity_matrix)
|
| 630 |
+
return affinity_matrices
|
| 631 |
+
affinity_matrices = []
|
| 632 |
+
for batch_index, (character_obj_tokens, c2c) in enumerate(zip(character_obj_tokens_for_batch, c2c_tokens_for_batch)):
|
| 633 |
+
if character_obj_tokens.shape[0] == 0:
|
| 634 |
+
affinity_matrices.append(torch.zeros(
|
| 635 |
+
0, 0).type_as(character_obj_tokens))
|
| 636 |
+
continue
|
| 637 |
+
if not self.config.disable_crop_embeddings:
|
| 638 |
+
crop_embeddings = crop_embeddings_for_batch[batch_index]
|
| 639 |
+
assert character_obj_tokens.shape[0] == crop_embeddings.shape[0]
|
| 640 |
+
character_obj_tokens = torch.cat(
|
| 641 |
+
[character_obj_tokens, crop_embeddings], dim=-1)
|
| 642 |
+
char_i = repeat(character_obj_tokens, "i d -> i repeat d",
|
| 643 |
+
repeat=character_obj_tokens.shape[0])
|
| 644 |
+
char_j = repeat(character_obj_tokens, "j d -> repeat j d",
|
| 645 |
+
repeat=character_obj_tokens.shape[0])
|
| 646 |
+
char_ij = rearrange([char_i, char_j], "two i j d -> (i j) (two d)")
|
| 647 |
+
c2c = repeat(c2c, "d -> repeat d", repeat=char_ij.shape[0])
|
| 648 |
+
char_ij_c2c = torch.cat([char_ij, c2c], dim=-1)
|
| 649 |
+
character_character_affinities = self.character_character_matching_head(
|
| 650 |
+
char_ij_c2c)
|
| 651 |
+
character_character_affinities = rearrange(
|
| 652 |
+
character_character_affinities, "(i j) 1 -> i j", i=char_i.shape[0])
|
| 653 |
+
character_character_affinities = (
|
| 654 |
+
character_character_affinities + character_character_affinities.T) / 2
|
| 655 |
+
if apply_sigmoid:
|
| 656 |
+
character_character_affinities = character_character_affinities.sigmoid()
|
| 657 |
+
affinity_matrices.append(character_character_affinities)
|
| 658 |
+
return affinity_matrices
|
| 659 |
+
|
| 660 |
+
def _get_text_character_affinity_matrices(
|
| 661 |
+
self,
|
| 662 |
+
character_obj_tokens_for_batch: List[torch.FloatTensor] = None,
|
| 663 |
+
text_obj_tokens_for_this_batch: List[torch.FloatTensor] = None,
|
| 664 |
+
t2c_tokens_for_batch: List[torch.FloatTensor] = None,
|
| 665 |
+
apply_sigmoid=True,
|
| 666 |
+
):
|
| 667 |
+
assert not self.config.disable_detections
|
| 668 |
+
assert character_obj_tokens_for_batch is not None and text_obj_tokens_for_this_batch is not None and t2c_tokens_for_batch is not None
|
| 669 |
+
affinity_matrices = []
|
| 670 |
+
for character_obj_tokens, text_obj_tokens, t2c in zip(character_obj_tokens_for_batch, text_obj_tokens_for_this_batch, t2c_tokens_for_batch):
|
| 671 |
+
if character_obj_tokens.shape[0] == 0 or text_obj_tokens.shape[0] == 0:
|
| 672 |
+
affinity_matrices.append(torch.zeros(
|
| 673 |
+
text_obj_tokens.shape[0], character_obj_tokens.shape[0]).type_as(character_obj_tokens))
|
| 674 |
+
continue
|
| 675 |
+
text_i = repeat(text_obj_tokens, "i d -> i repeat d",
|
| 676 |
+
repeat=character_obj_tokens.shape[0])
|
| 677 |
+
char_j = repeat(character_obj_tokens, "j d -> repeat j d",
|
| 678 |
+
repeat=text_obj_tokens.shape[0])
|
| 679 |
+
text_char = rearrange(
|
| 680 |
+
[text_i, char_j], "two i j d -> (i j) (two d)")
|
| 681 |
+
t2c = repeat(t2c, "d -> repeat d", repeat=text_char.shape[0])
|
| 682 |
+
text_char_t2c = torch.cat([text_char, t2c], dim=-1)
|
| 683 |
+
text_character_affinities = self.text_character_matching_head(
|
| 684 |
+
text_char_t2c)
|
| 685 |
+
text_character_affinities = rearrange(
|
| 686 |
+
text_character_affinities, "(i j) 1 -> i j", i=text_i.shape[0])
|
| 687 |
+
if apply_sigmoid:
|
| 688 |
+
text_character_affinities = text_character_affinities.sigmoid()
|
| 689 |
+
affinity_matrices.append(text_character_affinities)
|
| 690 |
+
return affinity_matrices
|
| 691 |
+
|
| 692 |
+
def _get_text_tail_affinity_matrices(
|
| 693 |
+
self,
|
| 694 |
+
text_obj_tokens_for_this_batch: List[torch.FloatTensor] = None,
|
| 695 |
+
tail_obj_tokens_for_batch: List[torch.FloatTensor] = None,
|
| 696 |
+
apply_sigmoid=True,
|
| 697 |
+
):
|
| 698 |
+
assert not self.config.disable_detections
|
| 699 |
+
assert tail_obj_tokens_for_batch is not None and text_obj_tokens_for_this_batch is not None
|
| 700 |
+
affinity_matrices = []
|
| 701 |
+
for tail_obj_tokens, text_obj_tokens in zip(tail_obj_tokens_for_batch, text_obj_tokens_for_this_batch):
|
| 702 |
+
if tail_obj_tokens.shape[0] == 0 or text_obj_tokens.shape[0] == 0:
|
| 703 |
+
affinity_matrices.append(torch.zeros(
|
| 704 |
+
text_obj_tokens.shape[0], tail_obj_tokens.shape[0]).type_as(tail_obj_tokens))
|
| 705 |
+
continue
|
| 706 |
+
text_i = repeat(text_obj_tokens, "i d -> i repeat d",
|
| 707 |
+
repeat=tail_obj_tokens.shape[0])
|
| 708 |
+
tail_j = repeat(tail_obj_tokens, "j d -> repeat j d",
|
| 709 |
+
repeat=text_obj_tokens.shape[0])
|
| 710 |
+
text_tail = rearrange(
|
| 711 |
+
[text_i, tail_j], "two i j d -> (i j) (two d)")
|
| 712 |
+
text_tail_affinities = self.text_tail_matching_head(text_tail)
|
| 713 |
+
text_tail_affinities = rearrange(
|
| 714 |
+
text_tail_affinities, "(i j) 1 -> i j", i=text_i.shape[0])
|
| 715 |
+
if apply_sigmoid:
|
| 716 |
+
text_tail_affinities = text_tail_affinities.sigmoid()
|
| 717 |
+
affinity_matrices.append(text_tail_affinities)
|
| 718 |
+
return affinity_matrices
|
| 719 |
+
|
| 720 |
+
# Copied from transformers.models.detr.modeling_detr._upcast
|
| 721 |
+
|
| 722 |
+
|
| 723 |
+
def _upcast(t):
|
| 724 |
+
# Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
|
| 725 |
+
if t.is_floating_point():
|
| 726 |
+
return t if t.dtype in (torch.float32, torch.float64) else t.float()
|
| 727 |
+
else:
|
| 728 |
+
return t if t.dtype in (torch.int32, torch.int64) else t.int()
|
| 729 |
+
|
| 730 |
+
|
| 731 |
+
# Copied from transformers.models.detr.modeling_detr.box_area
|
| 732 |
+
def box_area(boxes):
|
| 733 |
+
"""
|
| 734 |
+
Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates.
|
| 735 |
+
|
| 736 |
+
Args:
|
| 737 |
+
boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`):
|
| 738 |
+
Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1
|
| 739 |
+
< x2` and `0 <= y1 < y2`.
|
| 740 |
+
|
| 741 |
+
Returns:
|
| 742 |
+
`torch.FloatTensor`: a tensor containing the area for each box.
|
| 743 |
+
"""
|
| 744 |
+
boxes = _upcast(boxes)
|
| 745 |
+
return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
|
| 746 |
+
|
| 747 |
+
|
| 748 |
+
# Copied from transformers.models.detr.modeling_detr.box_iou
|
| 749 |
+
def box_iou(boxes1, boxes2):
|
| 750 |
+
area1 = box_area(boxes1)
|
| 751 |
+
area2 = box_area(boxes2)
|
| 752 |
+
|
| 753 |
+
left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
|
| 754 |
+
right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
|
| 755 |
+
|
| 756 |
+
width_height = (right_bottom - left_top).clamp(min=0) # [N,M,2]
|
| 757 |
+
inter = width_height[:, :, 0] * width_height[:, :, 1] # [N,M]
|
| 758 |
+
|
| 759 |
+
union = area1[:, None] + area2 - inter
|
| 760 |
+
|
| 761 |
+
iou = inter / union
|
| 762 |
+
return iou, union
|
| 763 |
+
|
| 764 |
+
|
| 765 |
+
# Copied from transformers.models.detr.modeling_detr.generalized_box_iou
|
| 766 |
+
def generalized_box_iou(boxes1, boxes2):
|
| 767 |
+
"""
|
| 768 |
+
Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format.
|
| 769 |
+
|
| 770 |
+
Returns:
|
| 771 |
+
`torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2)
|
| 772 |
+
"""
|
| 773 |
+
# degenerate boxes gives inf / nan results
|
| 774 |
+
# so do an early check
|
| 775 |
+
if not (boxes1[:, 2:] >= boxes1[:, :2]).all():
|
| 776 |
+
raise ValueError(
|
| 777 |
+
f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}")
|
| 778 |
+
if not (boxes2[:, 2:] >= boxes2[:, :2]).all():
|
| 779 |
+
raise ValueError(
|
| 780 |
+
f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}")
|
| 781 |
+
iou, union = box_iou(boxes1, boxes2)
|
| 782 |
+
|
| 783 |
+
top_left = torch.min(boxes1[:, None, :2], boxes2[:, :2])
|
| 784 |
+
bottom_right = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
|
| 785 |
+
|
| 786 |
+
width_height = (bottom_right - top_left).clamp(min=0) # [N,M,2]
|
| 787 |
+
area = width_height[:, :, 0] * width_height[:, :, 1]
|
| 788 |
+
|
| 789 |
+
return iou - (area - union) / area
|
| 790 |
+
|
| 791 |
+
|
+# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrHungarianMatcher with DeformableDetr->ConditionalDetr
+class ConditionalDetrHungarianMatcher(nn.Module):
+    """
+    This class computes an assignment between the targets and the predictions of the network.
+
+    For efficiency reasons, the targets don't include the no_object. Because of this, in general, there are more
+    predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, while the others are
+    un-matched (and thus treated as non-objects).
+
+    Args:
+        class_cost:
+            The relative weight of the classification error in the matching cost.
+        bbox_cost:
+            The relative weight of the L1 error of the bounding box coordinates in the matching cost.
+        giou_cost:
+            The relative weight of the giou loss of the bounding box in the matching cost.
+    """
+
+    def __init__(self, class_cost: float = 1, bbox_cost: float = 1, giou_cost: float = 1):
+        super().__init__()
+
+        self.class_cost = class_cost
+        self.bbox_cost = bbox_cost
+        self.giou_cost = giou_cost
+        if class_cost == 0 and bbox_cost == 0 and giou_cost == 0:
+            raise ValueError("All costs of the Matcher can't be 0")
+
+    @torch.no_grad()
+    def forward(self, outputs, targets):
+        """
+        Args:
+            outputs (`dict`):
+                A dictionary that contains at least these entries:
+                * "logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
+                * "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates.
+            targets (`List[dict]`):
+                A list of targets (len(targets) = batch_size), where each target is a dict containing:
+                * "class_labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of
+                  ground-truth
+                  objects in the target) containing the class labels
+                * "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates.
+
+        Returns:
+            `List[Tuple]`: A list of size `batch_size`, containing tuples of (index_i, index_j) where:
+            - index_i is the indices of the selected predictions (in order)
+            - index_j is the indices of the corresponding selected targets (in order)
+            For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
+        """
+        batch_size, num_queries = outputs["logits"].shape[:2]
+
+        # We flatten to compute the cost matrices in a batch
+        # [batch_size * num_queries, num_classes]
+        out_prob = outputs["logits"].flatten(0, 1).sigmoid()
+        out_bbox = outputs["pred_boxes"].flatten(
+            0, 1)  # [batch_size * num_queries, 4]
+
+        # Also concat the target labels and boxes
+        target_ids = torch.cat([v["class_labels"] for v in targets])
+        target_bbox = torch.cat([v["boxes"] for v in targets])
+
+        # Compute the classification cost.
+        alpha = 0.25
+        gamma = 2.0
+        neg_cost_class = (1 - alpha) * (out_prob**gamma) * \
+            (-(1 - out_prob + 1e-8).log())
+        pos_cost_class = alpha * \
+            ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
+        class_cost = pos_cost_class[:, target_ids] - \
+            neg_cost_class[:, target_ids]
+
+        # Compute the L1 cost between boxes
+        bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)
+
+        # Compute the giou cost between boxes
+        giou_cost = -generalized_box_iou(center_to_corners_format(
+            out_bbox), center_to_corners_format(target_bbox))
+
+        # Final cost matrix
+        cost_matrix = self.bbox_cost * bbox_cost + \
+            self.class_cost * class_cost + self.giou_cost * giou_cost
+        cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu()
+
+        sizes = [len(v["boxes"]) for v in targets]
+        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(
+            cost_matrix.split(sizes, -1))]
+        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
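
To make the matching step concrete, here is a small, hypothetical sketch (not part of the commit) of the final step of ConditionalDetrHungarianMatcher.forward: a query-by-target cost matrix is handed to scipy's linear_sum_assignment. Only the L1 box term is used below to keep it short; the real forward also adds the focal-style class cost and the GIoU cost with the configured weights.

import torch
from scipy.optimize import linear_sum_assignment

# Made-up predictions and targets in (cx, cy, w, h) format: 3 queries, 2 ground-truth boxes.
pred_boxes = torch.tensor([[0.5, 0.5, 0.2, 0.2],
                           [0.1, 0.1, 0.1, 0.1],
                           [0.9, 0.9, 0.1, 0.1]])
target_boxes = torch.tensor([[0.12, 0.1, 0.1, 0.1],
                             [0.88, 0.9, 0.1, 0.1]])

# L1 cost between every query and every target, then a 1-to-1 Hungarian assignment.
cost = torch.cdist(pred_boxes, target_boxes, p=1)
row_ind, col_ind = linear_sum_assignment(cost.numpy())
print(list(zip(row_ind, col_ind)))  # e.g. [(1, 0), (2, 1)]: query 0 stays unmatched
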
processing_magiv2.py
CHANGED
@@ -5,6 +5,7 @@ from shapely.geometry import box
 from .utils import x1y1x2y2_to_xywh
 import numpy as np

+
 class Magiv2Processor():
     def __init__(self, config):
         self.config = config
@@ -13,34 +14,38 @@ class Magiv2Processor():
         self.crop_embedding_image_preprocessor = None
         if not config.disable_detections:
             assert config.detection_image_preprocessing_config is not None
+            self.detection_image_preprocessor = ConditionalDetrImageProcessor.from_dict(
+                config.detection_image_preprocessing_config)
         if not config.disable_ocr:
             assert config.ocr_pretrained_processor_path is not None
+            self.ocr_preprocessor = TrOCRProcessor.from_pretrained(
+                config.ocr_pretrained_processor_path)
         if not config.disable_crop_embeddings:
             assert config.crop_embedding_image_preprocessing_config is not None
+            self.crop_embedding_image_preprocessor = ViTImageProcessor.from_dict(
+                config.crop_embedding_image_preprocessing_config)
+
     def preprocess_inputs_for_detection(self, images, annotations=None):
         images = list(images)
         assert isinstance(images[0], np.ndarray)
         annotations = self._convert_annotations_to_coco_format(annotations)
+        inputs = self.detection_image_preprocessor(
+            images, annotations=annotations, return_tensors="pt")
         return inputs

     def preprocess_inputs_for_ocr(self, images):
         images = list(images)
         assert isinstance(images[0], np.ndarray)
         return self.ocr_preprocessor(images, return_tensors="pt").pixel_values
+
     def preprocess_inputs_for_crop_embeddings(self, images):
         images = list(images)
         assert isinstance(images[0], np.ndarray)
         return self.crop_embedding_image_preprocessor(images, return_tensors="pt").pixel_values
+
     def postprocess_ocr_tokens(self, generated_ids, skip_special_tokens=True):
         return self.ocr_preprocessor.batch_decode(generated_ids, skip_special_tokens=skip_special_tokens)
+
     def crop_image(self, image, bboxes):
         crops_for_image = []
         for bbox in bboxes:
@@ -48,7 +53,8 @@ class Magiv2Processor():

             # fix the bounding box in case it is out of bounds or too small
             x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
+            x1, y1, x2, y2 = min(x1, x2), min(y1, y2), max(
+                x1, x2), max(y1, y2)  # just incase
             x1, y1 = max(0, x1), max(0, y1)
             x1, y1 = min(image.shape[1], x1), min(image.shape[0], y1)
             x2, y2 = max(0, x2), max(0, y2)
@@ -71,10 +77,11 @@ class Magiv2Processor():
     def _get_indices_of_characters_to_keep(self, batch_scores, batch_labels, batch_bboxes, character_detection_threshold):
         indices_of_characters_to_keep = []
         for scores, labels, _ in zip(batch_scores, batch_labels, batch_bboxes):
+            indices = torch.where((labels == 0) & (
+                scores > character_detection_threshold))[0]
             indices_of_characters_to_keep.append(indices)
         return indices_of_characters_to_keep
+
     def _get_indices_of_panels_to_keep(self, batch_scores, batch_labels, batch_bboxes, panel_detection_threshold):
         indices_of_panels_to_keep = []
         for scores, labels, bboxes in zip(batch_scores, batch_labels, batch_bboxes):
@@ -85,7 +92,8 @@ class Magiv2Processor():
             if len(indices) == 0:
                 indices_of_panels_to_keep.append([])
                 continue
+            scores, labels, indices, bboxes = zip(
+                *sorted(zip(scores, labels, indices, bboxes), reverse=True))
             panels_to_keep = []
             union_of_panels_so_far = box(0, 0, 0, 0)
             for ps, pb, pl, pi in zip(scores, bboxes, labels, indices):
@@ -95,21 +103,25 @@ class Magiv2Processor():
                 if union_of_panels_so_far.intersection(panel_polygon).area / panel_polygon.area > 0.5:
                     continue
                 panels_to_keep.append((ps, pl, pb, pi))
+                union_of_panels_so_far = union_of_panels_so_far.union(
+                    panel_polygon)
+            indices_of_panels_to_keep.append(
+                [p[3].item() for p in panels_to_keep])
         return indices_of_panels_to_keep
+
     def _get_indices_of_texts_to_keep(self, batch_scores, batch_labels, batch_bboxes, text_detection_threshold):
         indices_of_texts_to_keep = []
         for scores, labels, bboxes in zip(batch_scores, batch_labels, batch_bboxes):
+            indices = torch.where((labels == 1) & (
+                scores > text_detection_threshold))[0]
             bboxes = bboxes[indices]
             scores = scores[indices]
             labels = labels[indices]
             if len(indices) == 0:
                 indices_of_texts_to_keep.append([])
                 continue
+            scores, labels, indices, bboxes = zip(
+                *sorted(zip(scores, labels, indices, bboxes), reverse=True))
             texts_to_keep = []
             texts_to_keep_as_shapely_objects = []
             for ts, tb, tl, ti in zip(scores, bboxes, labels, indices):
@@ -122,20 +134,23 @@ class Magiv2Processor():
                 if should_append:
                     texts_to_keep.append((ts, tl, tb, ti))
                     texts_to_keep_as_shapely_objects.append(text_polygon)
+            indices_of_texts_to_keep.append(
+                [t[3].item() for t in texts_to_keep])
         return indices_of_texts_to_keep
+
     def _get_indices_of_tails_to_keep(self, batch_scores, batch_labels, batch_bboxes, text_detection_threshold):
         indices_of_texts_to_keep = []
         for scores, labels, bboxes in zip(batch_scores, batch_labels, batch_bboxes):
+            indices = torch.where((labels == 3) & (
+                scores > text_detection_threshold))[0]
             bboxes = bboxes[indices]
             scores = scores[indices]
             labels = labels[indices]
             if len(indices) == 0:
                 indices_of_texts_to_keep.append([])
                 continue
+            scores, labels, indices, bboxes = zip(
+                *sorted(zip(scores, labels, indices, bboxes), reverse=True))
             texts_to_keep = []
             texts_to_keep_as_shapely_objects = []
             for ts, tb, tl, ti in zip(scores, bboxes, labels, indices):
@@ -148,9 +163,10 @@ class Magiv2Processor():
                 if should_append:
                     texts_to_keep.append((ts, tl, tb, ti))
                     texts_to_keep_as_shapely_objects.append(text_polygon)
+            indices_of_texts_to_keep.append(
+                [t[3].item() for t in texts_to_keep])
         return indices_of_texts_to_keep
+
     def _convert_annotations_to_coco_format(self, annotations):
         if annotations is None:
             return None
@@ -169,7 +185,7 @@ class Magiv2Processor():
             })
             coco_annotations.append(coco_annotation)
         return coco_annotations
+
     def _verify_annotations_are_in_correct_format(self, annotations):
         error_msg = """
         Annotations must be in the following format:
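
For intuition, a condensed, hypothetical sketch (not part of the commit) of the rule that _get_indices_of_panels_to_keep applies above: detections are filtered by score, visited in descending score order, and a panel is suppressed once more than half of it is already covered by the union of previously kept panels. The boxes and threshold below are made up.

from shapely.geometry import box

# Made-up panel detections: (score, [x1, y1, x2, y2]); class labels are omitted here.
detections = [
    (0.95, [0, 0, 100, 100]),
    (0.90, [120, 0, 220, 100]),
    (0.40, [10, 10, 90, 90]),   # mostly covered by the first panel
]
panel_detection_threshold = 0.2

keep = []
union_of_panels_so_far = box(0, 0, 0, 0)
# Highest-scoring panels first, mirroring sorted(..., reverse=True) in the diff.
for score, bbox in sorted(detections, reverse=True):
    if score <= panel_detection_threshold:
        continue
    panel_polygon = box(*bbox)
    # Suppress a panel when >50% of it is already covered by kept panels.
    if union_of_panels_so_far.intersection(panel_polygon).area / panel_polygon.area > 0.5:
        continue
    keep.append(bbox)
    union_of_panels_so_far = union_of_panels_so_far.union(panel_polygon)

print(len(keep))  # 2 -- the third, nested box is dropped
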
utils.py
CHANGED
@@ -8,6 +8,7 @@ import networkx as nx
 from copy import deepcopy
 from itertools import groupby

+
 def move_to_device(inputs, device):
     if hasattr(inputs, "keys"):
         return {k: move_to_device(v, device) for k, v in inputs.items()}
@@ -20,6 +21,7 @@ def move_to_device(inputs, device):
     else:
         return inputs.to(device)

+
 class UnionFind:
     def __init__(self, n):
         self.parent = list(range(n))
@@ -34,7 +36,7 @@ class UnionFind:
                 if adj_matrix[i, j] > 0:
                     ufds.unite(i, j)
         return ufds
+
     @classmethod
     def from_adj_list(cls, adj_list):
         ufds = cls(len(adj_list))
@@ -42,7 +44,7 @@ class UnionFind:
             for j in adj_list[i]:
                 ufds.unite(i, j)
         return ufds
+
     @classmethod
     def from_edge_list(cls, edge_list, num_nodes):
         ufds = cls(num_nodes)
@@ -65,11 +67,11 @@ class UnionFind:
             self.parent[y] = x
             self.size[x] += self.size[y]
             self.num_components -= 1
+
     def get_components_of(self, x):
         x = self.find(x)
         return [i for i in range(len(self.parent)) if self.find(i) == x]
+
     def are_connected(self, x, y):
         return self.find(x) == self.find(y)

@@ -78,7 +80,7 @@ class UnionFind:

     def get_num_components(self):
         return self.num_components
+
     def get_labels_for_connected_components(self):
         map_parent_to_label = {}
         labels = []
@@ -89,32 +91,36 @@ class UnionFind:
             labels.append(map_parent_to_label[parent])
         return labels

+
 def visualise_single_image_prediction(image_as_np_array, predictions, filename):
     figure, subplot = plt.subplots(1, 1, figsize=(10, 10))
     subplot.imshow(image_as_np_array)
     plot_bboxes(subplot, predictions["panels"], color="green")
+    plot_bboxes(subplot, predictions["texts"], color="red",
+                visibility=predictions["is_essential_text"])
     plot_bboxes(subplot, predictions["characters"], color="blue")
     plot_bboxes(subplot, predictions["tails"], color="purple")

     for i, name in enumerate(predictions["character_names"]):
         char_bbox = predictions["characters"][i]
         x1, y1, x2, y2 = char_bbox
+        subplot.text(x1, y1 - 2, name,
+                     verticalalignment='bottom', horizontalalignment='left',
+                     # Background settings
+                     bbox=dict(facecolor='blue', alpha=1, edgecolor='none'),
+                     color='white', fontsize=8)

     COLOURS = [
+        "#b7ff51",  # green
+        "#f50a8f",  # pink
+        "#4b13b6",  # purple
+        "#ddaa34",  # orange
+        "#bea2a2",  # brown
     ]
     colour_index = 0
     character_cluster_labels = predictions["character_cluster_labels"]
+    unique_label_sorted_by_frequency = sorted(list(set(
+        character_cluster_labels)), key=lambda x: character_cluster_labels.count(x), reverse=True)
     for label in unique_label_sorted_by_frequency:
         root = None
         others = []
@@ -127,7 +133,9 @@ def visualise_single_image_prediction(image_as_np_array, predictions, filename):
         if colour_index >= len(COLOURS):
             random_colour = COLOURS[0]
             while random_colour in COLOURS:
+                random_colour = "#" + \
+                    "".join([random.choice("0123456789ABCDEF")
+                             for j in range(6)])
         else:
             random_colour = COLOURS[colour_index]
             colour_index += 1
@@ -143,8 +151,9 @@ def visualise_single_image_prediction(image_as_np_array, predictions, filename):
             x2 = bbox_j[0] + (bbox_j[2] - bbox_j[0]) / 2
             y2 = bbox_j[1] + (bbox_j[3] - bbox_j[1]) / 2
             subplot.plot([x1, x2], [y1, y2], color=random_colour, linewidth=2)
+            subplot.plot([x2], [y2], color=random_colour,
+                         marker="o", markersize=5)
+
     for (i, j) in predictions["text_character_associations"]:
         bbox_i = predictions["texts"][i]
         bbox_j = predictions["characters"][j]
@@ -154,8 +163,9 @@ def visualise_single_image_prediction(image_as_np_array, predictions, filename):
         y1 = bbox_i[1] + (bbox_i[3] - bbox_i[1]) / 2
         x2 = bbox_j[0] + (bbox_j[2] - bbox_j[0]) / 2
         y2 = bbox_j[1] + (bbox_j[3] - bbox_j[1]) / 2
+        subplot.plot([x1, x2], [y1, y2], color="red",
+                     linewidth=2, linestyle="dashed")
+
     for (i, j) in predictions["text_tail_associations"]:
         bbox_i = predictions["texts"][i]
         bbox_j = predictions["tails"][j]
@@ -163,7 +173,8 @@ def visualise_single_image_prediction(image_as_np_array, predictions, filename):
         y1 = bbox_i[1] + (bbox_i[3] - bbox_i[1]) / 2
         x2 = bbox_j[0] + (bbox_j[2] - bbox_j[0]) / 2
         y2 = bbox_j[1] + (bbox_j[3] - bbox_j[1]) / 2
+        subplot.plot([x1, x2], [y1, y2], color="purple",
+                     linewidth=2, linestyle="dashed")

     subplot.axis("off")
     if filename is not None:
@@ -174,6 +185,7 @@ def visualise_single_image_prediction(image_as_np_array, predictions, filename):
     plt.close()
     return image

+
 def plot_bboxes(subplot, bboxes, color="red", visibility=None):
     if visibility is None:
         visibility = [1] * len(bboxes)
@@ -187,6 +199,7 @@ def plot_bboxes(subplot, bboxes, color="red", visibility=None):
     )
     subplot.add_patch(rect)

+
 def sort_panels(rects):
     before_rects = convert_to_list_of_lists(rects)
     # slightly erode all rectangles initially to account for imperfect detections
@@ -212,34 +225,42 @@ def sort_panels(rects):
         G.remove_edge(*max_cyclic_edge)
     return list(nx.topological_sort(G))

+
 def is_strictly_above(rectA, rectB):
     x1A, y1A, x2A, y2A = rectA
     x1B, y1B, x2B, y2B = rectB
     return y2A < y1B

+
 def is_strictly_below(rectA, rectB):
     x1A, y1A, x2A, y2A = rectA
     x1B, y1B, x2B, y2B = rectB
     return y2B < y1A

+
 def is_strictly_left_of(rectA, rectB):
     x1A, y1A, x2A, y2A = rectA
     x1B, y1B, x2B, y2B = rectB
     return x2A < x1B

+
 def is_strictly_right_of(rectA, rectB):
     x1A, y1A, x2A, y2A = rectA
     x1B, y1B, x2B, y2B = rectB
     return x2B < x1A

+
 def intersects(rectA, rectB):
     return box(*rectA).intersects(box(*rectB))

+
 def is_there_a_directed_edge(a, b, rects):
     rectA = rects[a]
     rectB = rects[b]
+    centre_of_A = [rectA[0] + (rectA[2] - rectA[0]) / 2,
+                   rectA[1] + (rectA[3] - rectA[1]) / 2]
+    centre_of_B = [rectB[0] + (rectB[2] - rectB[0]) / 2,
+                   rectB[1] + (rectB[3] - rectB[1]) / 2]
     if np.allclose(np.array(centre_of_A), np.array(centre_of_B)):
         return box(*rectA).area > (box(*rectB)).area
     copy_A = [rectA[0], rectA[1], rectA[2], rectA[3]]
@@ -256,34 +277,41 @@ def is_there_a_directed_edge(a, b, rects):
     if is_strictly_below(copy_A, copy_B) and is_strictly_right_of(copy_A, copy_B):
         return use_cuts_to_determine_edge_from_a_to_b(a, b, rects)
     if is_strictly_below(copy_B, copy_A) and is_strictly_right_of(copy_B, copy_A):
+        return use_cuts_to_determine_edge_from_a_to_b(a, b, rects)
     # otherwise they intersect
     copy_A = erode_rectangle(copy_A, 0.05)
     copy_B = erode_rectangle(copy_B, 0.05)
+
+
 def get_distance(rectA, rectB):
     return box(rectA[0], rectA[1], rectA[2], rectA[3]).distance(box(rectB[0], rectB[1], rectB[2], rectB[3]))

+
 def use_cuts_to_determine_edge_from_a_to_b(a, b, rects):
     rects = deepcopy(rects)
     while True:
+        xmin, ymin, xmax, ymax = min(rects[a][0], rects[b][0]), min(
+            rects[a][1], rects[b][1]), max(rects[a][2], rects[b][2]), max(rects[a][3], rects[b][3])
+        rect_index = [i for i in range(len(rects)) if intersects(
+            rects[i], [xmin, ymin, xmax, ymax])]
+        rects_copy = [rect for rect in rects if intersects(
+            rect, [xmin, ymin, xmax, ymax])]
+
         # try to split the panels using a "horizontal" lines
+        overlapping_y_ranges = merge_overlapping_ranges(
+            [(y1, y2) for x1, y1, x2, y2 in rects_copy])
         panel_index_to_split = {}
         for split_index, (y1, y2) in enumerate(overlapping_y_ranges):
             for i, index in enumerate(rect_index):
                 if y1 <= rects_copy[i][1] <= rects_copy[i][3] <= y2:
                     panel_index_to_split[index] = split_index
+
         if panel_index_to_split[a] != panel_index_to_split[b]:
             return panel_index_to_split[a] < panel_index_to_split[b]
+
         # try to split the panels using a "vertical" lines
+        overlapping_x_ranges = merge_overlapping_ranges(
+            [(x1, x2) for x1, y1, x2, y2 in rects_copy])
         panel_index_to_split = {}
         for split_index, (x1, x2) in enumerate(overlapping_x_ranges[::-1]):
             for i, index in enumerate(rect_index):
@@ -291,10 +319,11 @@ def use_cuts_to_determine_edge_from_a_to_b(a, b, rects):
                 panel_index_to_split[index] = split_index
         if panel_index_to_split[a] != panel_index_to_split[b]:
             return panel_index_to_split[a] < panel_index_to_split[b]
+
         # otherwise, erode the rectangles and try again
         rects = [erode_rectangle(rect, 0.05) for rect in rects]

+
 def erode_rectangle(bbox, erosion_factor):
     x1, y1, x2, y2 = bbox
     w, h = x2 - x1, y2 - y1
@@ -312,6 +341,7 @@ def erode_rectangle(bbox, erosion_factor):
     x1, y1, x2, y2 = cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2
     return [x1, y1, x2, y2]

+
 def merge_overlapping_ranges(ranges):
     """
     ranges: list of tuples (x1, x2)
@@ -333,6 +363,7 @@ def merge_overlapping_ranges(ranges):
         merged_ranges.append((prev_x1, prev_x2))
     return merged_ranges

+
 def sort_text_boxes_in_reading_order(text_bboxes, sorted_panel_bboxes):
     text_bboxes = convert_to_list_of_lists(text_bboxes)
     sorted_panel_bboxes = convert_to_list_of_lists(sorted_panel_bboxes)
@@ -344,18 +375,23 @@ def sort_text_boxes_in_reading_order(text_bboxes, sorted_panel_bboxes):
         groups = groupby(range(len(nums)), key=lambda i: nums[i])
         return [list(indices) for _, indices in groups]

+    panel_id_for_text = get_text_to_panel_mapping(
+        text_bboxes, sorted_panel_bboxes)
     indices_of_texts = list(range(len(text_bboxes)))
+    indices_of_texts, panel_id_for_text = zip(
+        *sorted(zip(indices_of_texts, panel_id_for_text), key=lambda x: x[1]))
     indices_of_texts = list(indices_of_texts)
     grouped_indices = indices_of_same_elements(panel_id_for_text)
     for group in grouped_indices:
         subset_of_text_indices = [indices_of_texts[i] for i in group]
+        text_bboxes_of_subset = [text_bboxes[i]
+                                 for i in subset_of_text_indices]
         sorted_subset_indices = sort_texts_within_panel(text_bboxes_of_subset)
+        indices_of_texts[group[0]: group[-1] + 1] = [subset_of_text_indices[i]
+                                                     for i in sorted_subset_indices]
     return indices_of_texts

+
 def get_text_to_panel_mapping(text_bboxes, sorted_panel_bboxes):
     text_to_panel_mapping = []
     for text_bbox in text_bboxes:
@@ -368,14 +404,19 @@ def get_text_to_panel_mapping(text_bboxes, sorted_panel_bboxes):
         for j, annotation in enumerate(sorted_panel_bboxes):
             shapely_annotation_polygon = box(*annotation)
             if shapely_text_polygon.intersects(shapely_annotation_polygon):
+                all_intersections.append(
+                    (shapely_text_polygon.intersection(shapely_annotation_polygon).area, j))
+            all_distances.append(
+                (shapely_text_polygon.distance(shapely_annotation_polygon), j))
         if len(all_intersections) == 0:
+            text_to_panel_mapping.append(
+                min(all_distances, key=lambda x: x[0])[1])
         else:
+            text_to_panel_mapping.append(
+                max(all_intersections, key=lambda x: x[0])[1])
     return text_to_panel_mapping

+
 def sort_texts_within_panel(rects):
     smallest_y = float("inf")
     greatest_x = float("-inf")
@@ -383,29 +424,33 @@ def sort_texts_within_panel(rects):
         x1, y1, x2, y2 = rect
         smallest_y = min(smallest_y, y1)
         greatest_x = max(greatest_x, x2)
+
     reference_point = Point(greatest_x, smallest_y)

     polygons_and_index = []
     for i, rect in enumerate(rects):
         x1, y1, x2, y2 = rect
+        polygons_and_index.append((box(x1, y1, x2, y2), i))
     # sort points by closest to reference point
+    polygons_and_index = sorted(
+        polygons_and_index, key=lambda x: reference_point.distance(x[0]))
     indices = [x[1] for x in polygons_and_index]
     return indices

+
 def x1y1wh_to_x1y1x2y2(bbox):
     x1, y1, w, h = bbox
     return [x1, y1, x1 + w, y1 + h]

+
 def x1y1x2y2_to_xywh(bbox):
     x1, y1, x2, y2 = bbox
     return [x1, y1, x2 - x1, y2 - y1]

+
 def convert_to_list_of_lists(rects):
     if isinstance(rects, torch.Tensor):
         return rects.tolist()
     if isinstance(rects, np.ndarray):
         return rects.tolist()
+    return [[a, b, c, d] for a, b, c, d in rects]
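
For orientation, a small hypothetical sketch (not part of the commit, and assuming the UnionFind helper from utils.py above is importable) of how "same character" matches can be turned into the character_cluster_labels consumed by visualise_single_image_prediction. The adjacency list is made up, and the exact label values depend on traversal order.

# Hypothetical adjacency list over 4 character crops: crop 0 matches crop 1,
# crop 2 matches crop 3 (e.g. pairs whose affinity cleared some threshold).
adj_list = [[1], [0], [3], [2]]

ufds = UnionFind.from_adj_list(adj_list)
labels = ufds.get_labels_for_connected_components()
print(labels)  # e.g. [0, 0, 1, 1] -> two distinct characters, two crops each
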