Mateusz Mróz committed on
Commit ·
bfabb11
1
Parent(s): 848f81c
Refactor Florence2Processor and utils.py for improved readability and maintainability; add colab setup script with necessary patches and dependencies for model execution.
Browse files
- colab copy.py +187 -0
- colab.py +52 -16
- modeling_florence2.py +382 -209
- processing_florence2.py +68 -45
- utils.py +87 -42
colab copy.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @title Ostateczne Uruchomienie `magiv3` z Wymaganymi Poprawkami
|
| 2 |
+
# TO nie działa, ale pobiera jakieś zależności, by działało colab.py
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import re
|
| 6 |
+
import requests
|
| 7 |
+
import torch
|
| 8 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 9 |
+
import numpy as np
|
| 10 |
+
import json
|
| 11 |
+
from IPython.display import display
|
| 12 |
+
import warnings
|
| 13 |
+
|
| 14 |
+
# --- KROK 1: PRZYGOTOWANIE ŚRODOWISKA Z POPRAWKAMI ---
|
| 15 |
+
|
| 16 |
+
print("--- KROK 1: PRZYGOTOWANIE ŚRODOWISKA Z POPRAWKAMI ---")
|
| 17 |
+
|
| 18 |
+
# --- Instalacja zależności ---
|
| 19 |
+
print("⏳ Instaluję `uv` i wszystkie potrzebne pakiety...")
|
| 20 |
+
!curl -LsSf https://astral.sh/uv/install.sh | sh
|
| 21 |
+
os.environ['PATH'] = f"/root/.local/bin:{os.environ['PATH']}"
|
| 22 |
+
!uv pip install --quiet transformers accelerate einops timm scipy tokenizers pulp torch pytorch-metric-learning Pillow requests shapely
|
| 23 |
+
print("✅ Zależności zainstalowane.")
|
| 24 |
+
|
| 25 |
+
# --- Klonowanie repozytorium ---
|
| 26 |
+
repo_path = "/content/magiv3"
|
| 27 |
+
print(f"\n⏳ Klonuję repozytorium do folderu `{repo_path}`...")
|
| 28 |
+
if os.path.exists(repo_path):
|
| 29 |
+
!rm -rf {repo_path}
|
| 30 |
+
!git clone https://huggingface.co/ragavsachdeva/magiv3 {repo_path}
|
| 31 |
+
print("✅ Repozytorium sklonowane.")
|
| 32 |
+
|
| 33 |
+
# --- OSTATECZNA, KOMPLEKSOWA POPRAWKA KODU ---
|
| 34 |
+
file_to_patch = os.path.join(repo_path, "modeling_florence2.py")
|
| 35 |
+
print(f"\n⏳ Nanoszę wszystkie wymagane poprawki na plik `{file_to_patch}`...")
|
| 36 |
+
|
| 37 |
+
try:
|
| 38 |
+
with open(file_to_patch, 'r', encoding='utf-8') as f:
|
| 39 |
+
content = f.read()
|
| 40 |
+
|
| 41 |
+
# Poprawka 1: Dodanie importu GenerationMixin
|
| 42 |
+
if "from transformers.generation.utils import GenerationMixin" not in content:
|
| 43 |
+
content = content.replace(
|
| 44 |
+
"from transformers.modeling_utils import PreTrainedModel",
|
| 45 |
+
"from transformers.generation.utils import GenerationMixin\nfrom transformers.modeling_utils import PreTrainedModel"
|
| 46 |
+
)
|
| 47 |
+
print("PATCH 1: Dodano import `GenerationMixin`.")
|
| 48 |
+
|
| 49 |
+
# Poprawka 2: Naprawa klasy bazowej modelu językowego (dla metody .generate)
|
| 50 |
+
original_lang_class = "class Florence2LanguagePreTrainedModel(PreTrainedModel):"
|
| 51 |
+
patched_lang_class = "class Florence2LanguagePreTrainedModel(GenerationMixin, PreTrainedModel):"
|
| 52 |
+
if original_lang_class in content:
|
| 53 |
+
content = content.replace(original_lang_class, patched_lang_class)
|
| 54 |
+
print("PATCH 2: Poprawiono dziedziczenie `Florence2LanguagePreTrainedModel`.")
|
| 55 |
+
|
| 56 |
+
# Poprawka 3: Usunięcie wadliwych właściwości, które powodują błąd inicjalizacji
|
| 57 |
+
faulty_property_block = r"""
|
| 58 |
+
@property
|
| 59 |
+
def _supports_flash_attn_2\(self\):.*?return self.language_model._supports_flash_attn_2
|
| 60 |
+
|
| 61 |
+
@property
|
| 62 |
+
def _supports_sdpa\(self\):.*?return self.language_model._supports_sdpa"""
|
| 63 |
+
|
| 64 |
+
if re.search(faulty_property_block, content, flags=re.DOTALL):
|
| 65 |
+
content = re.sub(faulty_property_block, "", content, flags=re.DOTALL)
|
| 66 |
+
print("PATCH 3: Usunięto wadliwe właściwości `@property`.")
|
| 67 |
+
|
| 68 |
+
# Poprawka 4: Naprawa błędu nadpisywania modelu w __init__
|
| 69 |
+
faulty_init_block = r""" language_model = Florence2LanguageForConditionalGeneration\(config=config.text_config\)
|
| 70 |
+
|
| 71 |
+
if language_model._tied_weights_keys is not None:
|
| 72 |
+
self._tied_weights_keys = \[f"language_model.{k}" for k in language_model._tied_weights_keys\]
|
| 73 |
+
self.language_model = language_model"""
|
| 74 |
+
|
| 75 |
+
correct_init_block = r""" # This line is intentionally left blank.
|
| 76 |
+
# The language_model is already initialized by the parent class.
|
| 77 |
+
# The original code had a bug here that overwrote the pretrained language model.
|
| 78 |
+
if self.language_model._tied_weights_keys is not None:
|
| 79 |
+
self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]"""
|
| 80 |
+
|
| 81 |
+
if re.search(faulty_init_block, content, flags=re.DOTALL):
|
| 82 |
+
content = re.sub(faulty_init_block, correct_init_block, content, flags=re.DOTALL)
|
| 83 |
+
print("PATCH 4: Naprawiono błąd nadpisywania modelu w `__init__`.")
|
| 84 |
+
|
| 85 |
+
with open(file_to_patch, 'w', encoding='utf-8') as f:
|
| 86 |
+
f.write(content)
|
| 87 |
+
print("\n✅ Wszystkie poprawki zostały pomyślnie naniesione!")
|
| 88 |
+
|
| 89 |
+
except Exception as e:
|
| 90 |
+
print(f"❌ Wystąpił krytyczny błąd podczas patchowania pliku: {e}")
|
| 91 |
+
sys.exit()
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
# --- KROK 2: POBRANIE OBRAZKA TESTOWEGO ---
|
| 95 |
+
|
| 96 |
+
print("\n--- KROK 2: POBRANIE OBRAZKA TESTOWEGO ---")
|
| 97 |
+
IMAGE_URL = "https://raw.githubusercontent.com/MattyMroz/Manga_Whisperer/refs/heads/main/input/raw/04.jpg"
|
| 98 |
+
IMAGE_PATH = "/content/test_image.jpg"
|
| 99 |
+
|
| 100 |
+
try:
|
| 101 |
+
response = requests.get(IMAGE_URL)
|
| 102 |
+
response.raise_for_status()
|
| 103 |
+
with open(IMAGE_PATH, 'wb') as f:
|
| 104 |
+
f.write(response.content)
|
| 105 |
+
print(f"✅ Obrazek testowy został pomyślnie pobrany i zapisany jako `{IMAGE_PATH}`.")
|
| 106 |
+
display(Image.open(IMAGE_PATH).resize((300, 400)))
|
| 107 |
+
except Exception as e:
|
| 108 |
+
print(f"❌ Nie udało się pobrać obrazka. Błąd: {e}")
|
| 109 |
+
sys.exit()
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
# --- KROK 3: URUCHOMIENIE POPRAWIONEGO MODELU ---
|
| 113 |
+
|
| 114 |
+
print("\n--- KROK 3: URUCHOMIENIE POPRAWIONEGO MODELU ---")
|
| 115 |
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
| 116 |
+
|
| 117 |
+
# Dodajemy poprawiony kod do ścieżki Pythona
|
| 118 |
+
if repo_path not in sys.path:
|
| 119 |
+
sys.path.insert(0, repo_path)
|
| 120 |
+
|
| 121 |
+
# Importujemy klasy z naszego poprawionego kodu
|
| 122 |
+
from magiv3.modeling_florence2 import Florence2ForConditionalGeneration
|
| 123 |
+
from transformers import AutoProcessor
|
| 124 |
+
|
| 125 |
+
model = None
|
| 126 |
+
processor = None
|
| 127 |
+
|
| 128 |
+
try:
|
| 129 |
+
print(f"⏳ Ładowanie modelu i procesora (z użyciem poprawionego kodu)...")
|
| 130 |
+
model_id = "ragavsachdeva/magiv3"
|
| 131 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 132 |
+
|
| 133 |
+
# Używamy AutoProcessor, ale dla modelu musimy wskazać naszą poprawioną klasę
|
| 134 |
+
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
|
| 135 |
+
model = Florence2ForConditionalGeneration.from_pretrained(
|
| 136 |
+
model_id,
|
| 137 |
+
torch_dtype=torch.float16,
|
| 138 |
+
trust_remote_code=True
|
| 139 |
+
).to(device).eval()
|
| 140 |
+
|
| 141 |
+
print("✅ Model i procesor załadowane pomyślnie.")
|
| 142 |
+
|
| 143 |
+
except Exception as e:
|
| 144 |
+
print(f"\n❌ Wystąpił błąd podczas ładowania modelu, nawet po poprawkach: {e}")
|
| 145 |
+
sys.exit()
|
| 146 |
+
|
| 147 |
+
# Uruchamiamy predykcję
|
| 148 |
+
if model and processor:
|
| 149 |
+
try:
|
| 150 |
+
print("\n⏳ Przygotowuję dane wejściowe...")
|
| 151 |
+
images = [Image.open(IMAGE_PATH).convert("RGB")]
|
| 152 |
+
np_images = [np.array(img) for img in images]
|
| 153 |
+
print("✅ Dane wejściowe gotowe.")
|
| 154 |
+
|
| 155 |
+
print("\n⏳ Uruchamiam `predict_detections_and_associations`...")
|
| 156 |
+
with torch.no_grad():
|
| 157 |
+
results = model.predict_detections_and_associations(np_images, processor)
|
| 158 |
+
|
| 159 |
+
print("✅ Przetwarzanie zakończone pomyślnie!")
|
| 160 |
+
print("\n--- WYNIKI ---")
|
| 161 |
+
|
| 162 |
+
# Funkcja do wizualizacji
|
| 163 |
+
def visualize_results(image, results):
|
| 164 |
+
colors = {"panels": "red", "texts": "blue", "characters": "green", "tails": "yellow"}
|
| 165 |
+
draw_image = image.copy()
|
| 166 |
+
draw = ImageDraw.Draw(draw_image)
|
| 167 |
+
try:
|
| 168 |
+
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 15)
|
| 169 |
+
except IOError:
|
| 170 |
+
font = ImageFont.load_default()
|
| 171 |
+
|
| 172 |
+
for category, bboxes in results.items():
|
| 173 |
+
if category not in colors: continue
|
| 174 |
+
for i, box in enumerate(bboxes):
|
| 175 |
+
draw.rectangle(box, outline=colors[category], width=3)
|
| 176 |
+
draw.text((box[0], box[1]), f"{category}_{i}", fill=colors[category], font=font)
|
| 177 |
+
return draw_image
|
| 178 |
+
|
| 179 |
+
visualized_image = visualize_results(images[0], results[0])
|
| 180 |
+
display(visualized_image)
|
| 181 |
+
|
| 182 |
+
serializable_results = {k: (v.tolist() if isinstance(v, torch.Tensor) else v) for k, v in results[0].items()}
|
| 183 |
+
print(json.dumps(serializable_results, indent=2))
|
| 184 |
+
|
| 185 |
+
except Exception as e:
|
| 186 |
+
print(f"\n❌ WYSTĄPIŁ KRYTYCZNY BŁĄD PODCZAS PRZETWARZANIA:")
|
| 187 |
+
print(f"Błąd: {e}")
|
colab.py
CHANGED
|
@@ -1,20 +1,20 @@
|
|
| 1 |
# ==============================================================================
|
| 2 |
# 1) INSTALACJA PAKIETÓW
|
| 3 |
# ==============================================================================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
!pip -q install -U "transformers" "huggingface_hub" "accelerate" "timm" "sentencepiece" "safensors" "pillow" "einops" "pytorch_metric_learning"
|
| 5 |
|
| 6 |
# ==============================================================================
|
| 7 |
# 2) IMPORTY
|
| 8 |
# ==============================================================================
|
| 9 |
-
import re
|
| 10 |
-
import requests
|
| 11 |
-
import torch
|
| 12 |
-
import json
|
| 13 |
-
import math
|
| 14 |
-
from PIL import Image, ImageDraw
|
| 15 |
-
from io import BytesIO
|
| 16 |
-
from IPython.display import display
|
| 17 |
-
from transformers import AutoProcessor, AutoModelForCausalLM, AutoConfig
|
| 18 |
|
| 19 |
# ==============================================================================
|
| 20 |
# 3) POBRANIE OBRAZU
|
|
@@ -62,9 +62,15 @@ print("Model i procesor załadowane pomyślnie.")
|
|
| 62 |
# ==============================================================================
|
| 63 |
|
| 64 |
|
| 65 |
-
def create_visualization(image, data):
|
| 66 |
"""
|
| 67 |
Rysuje zaawansowaną wizualizację detekcji i asocjacji na obrazie.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
"""
|
| 69 |
img_draw = image.copy()
|
| 70 |
draw = ImageDraw.Draw(img_draw)
|
|
@@ -77,8 +83,10 @@ def create_visualization(image, data):
|
|
| 77 |
"tails": "purple",
|
| 78 |
"cluster_colors": ["#f50a8f", "#4b13b6", "#ddaa34", "#b7ff51", "#bea2a2"],
|
| 79 |
"speaker_line": "magenta",
|
|
|
|
|
|
|
| 80 |
}
|
| 81 |
-
line_widths = {"panels": 2, "texts": 1, "characters": 2, "tails": 1}
|
| 82 |
|
| 83 |
def get_box_center(box):
|
| 84 |
x1, y1, x2, y2 = box
|
|
@@ -131,10 +139,35 @@ def create_visualization(image, data):
|
|
| 131 |
draw_dashed_line(
|
| 132 |
draw, p1, p2, fill=colors["speaker_line"], width=1)
|
| 133 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
return img_draw
|
| 135 |
|
| 136 |
|
| 137 |
-
def process_image(image, caption_for_grounding="elf girl"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
print("\n--- Rozpoczynanie przetwarzania obrazu ---")
|
| 139 |
images = [image]
|
| 140 |
captions = [caption_for_grounding]
|
|
@@ -157,8 +190,9 @@ def process_image(image, caption_for_grounding="elf girl"):
|
|
| 157 |
"grounding": [{"phrase": grounding_results.get("grounded_caption", "")[start:end], "boxes": boxes} for boxes, (start, end) in zip(grounding_results.get("bboxes", []), grounding_results.get("indices_of_bboxes_in_caption", []))]
|
| 158 |
}
|
| 159 |
|
| 160 |
-
|
| 161 |
-
|
|
|
|
| 162 |
|
| 163 |
print("--- Zakończono przetwarzanie ---")
|
| 164 |
return final_json, visualization_image
|
|
@@ -167,9 +201,11 @@ def process_image(image, caption_for_grounding="elf girl"):
|
|
| 167 |
# 6) URUCHOMIENIE I WYŚWIETLENIE WYNIKÓW
|
| 168 |
# ==============================================================================
|
| 169 |
|
| 170 |
-
|
|
|
|
|
|
|
| 171 |
json_output, image_output = process_image(
|
| 172 |
-
pil_img, caption_for_grounding="elf girl")
|
| 173 |
|
| 174 |
print("\n\n===== WYNIKI W FORMACIE JSON (przed filtrowaniem) =====")
|
| 175 |
print(json.dumps(json_output, indent=2))
|
|
|
|
| 1 |
# ==============================================================================
|
| 2 |
# 1) INSTALACJA PAKIETÓW
|
| 3 |
# ==============================================================================
|
| 4 |
+
from transformers import AutoProcessor, AutoModelForCausalLM, AutoConfig
|
| 5 |
+
from IPython.display import display
|
| 6 |
+
from io import BytesIO
|
| 7 |
+
from PIL import Image, ImageDraw
|
| 8 |
+
import math
|
| 9 |
+
import json
|
| 10 |
+
import torch
|
| 11 |
+
import requests
|
| 12 |
+
import re
|
| 13 |
!pip -q install -U "transformers" "huggingface_hub" "accelerate" "timm" "sentencepiece" "safensors" "pillow" "einops" "pytorch_metric_learning"
|
| 14 |
|
| 15 |
# ==============================================================================
|
| 16 |
# 2) IMPORTY
|
| 17 |
# ==============================================================================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
# ==============================================================================
|
| 20 |
# 3) POBRANIE OBRAZU
|
|
|
|
| 62 |
# ==============================================================================
|
| 63 |
|
| 64 |
|
| 65 |
+
def create_visualization(image, data, detailed_mode=False):
|
| 66 |
"""
|
| 67 |
Rysuje zaawansowaną wizualizację detekcji i asocjacji na obrazie.
|
| 68 |
+
|
| 69 |
+
Args:
|
| 70 |
+
image: Obraz wejściowy
|
| 71 |
+
data: Dane JSON z wynikami
|
| 72 |
+
detailed_mode: Jeśli True, rysuje wszystko z JSON (OCR, grounding).
|
| 73 |
+
Jeśli False (domyślnie), rysuje tylko detekcje i asocjacje.
|
| 74 |
"""
|
| 75 |
img_draw = image.copy()
|
| 76 |
draw = ImageDraw.Draw(img_draw)
|
|
|
|
| 83 |
"tails": "purple",
|
| 84 |
"cluster_colors": ["#f50a8f", "#4b13b6", "#ddaa34", "#b7ff51", "#bea2a2"],
|
| 85 |
"speaker_line": "magenta",
|
| 86 |
+
"ocr": "orange",
|
| 87 |
+
"grounding": "cyan",
|
| 88 |
}
|
| 89 |
+
line_widths = {"panels": 2, "texts": 1, "characters": 2, "tails": 1, "ocr": 2, "grounding": 2}
|
| 90 |
|
| 91 |
def get_box_center(box):
|
| 92 |
x1, y1, x2, y2 = box
|
|
|
|
| 139 |
draw_dashed_line(
|
| 140 |
draw, p1, p2, fill=colors["speaker_line"], width=1)
|
| 141 |
|
| 142 |
+
# Tryb wybredny - rysowanie dodatkowych elementów z JSON
|
| 143 |
+
if detailed_mode:
|
| 144 |
+
# Rysowanie OCR boxes
|
| 145 |
+
ocr_data = data.get("ocr", [])
|
| 146 |
+
for ocr_item in ocr_data:
|
| 147 |
+
box = ocr_item.get("box")
|
| 148 |
+
if box:
|
| 149 |
+
draw.rectangle(box, outline=colors["ocr"], width=line_widths["ocr"])
|
| 150 |
+
|
| 151 |
+
# Rysowanie Grounding boxes
|
| 152 |
+
grounding_data = data.get("grounding", [])
|
| 153 |
+
for grounding_item in grounding_data:
|
| 154 |
+
boxes = grounding_item.get("boxes", [])
|
| 155 |
+
for box in boxes:
|
| 156 |
+
draw.rectangle(box, outline=colors["grounding"], width=line_widths["grounding"])
|
| 157 |
+
|
| 158 |
return img_draw
|
| 159 |
|
| 160 |
|
| 161 |
+
def process_image(image, caption_for_grounding="elf girl", detailed_mode=False):
|
| 162 |
+
"""
|
| 163 |
+
Przetwarza obraz i tworzy wizualizację.
|
| 164 |
+
|
| 165 |
+
Args:
|
| 166 |
+
image: Obraz wejściowy
|
| 167 |
+
caption_for_grounding: Caption dla character grounding
|
| 168 |
+
detailed_mode: Jeśli True, wizualizacja zawiera wszystko z JSON (OCR, grounding).
|
| 169 |
+
Jeśli False (domyślnie), tylko detekcje i asocjacje.
|
| 170 |
+
"""
|
| 171 |
print("\n--- Rozpoczynanie przetwarzania obrazu ---")
|
| 172 |
images = [image]
|
| 173 |
captions = [caption_for_grounding]
|
|
|
|
| 190 |
"grounding": [{"phrase": grounding_results.get("grounded_caption", "")[start:end], "boxes": boxes} for boxes, (start, end) in zip(grounding_results.get("bboxes", []), grounding_results.get("indices_of_bboxes_in_caption", []))]
|
| 191 |
}
|
| 192 |
|
| 193 |
+
mode_text = "wybredny (wszystkie elementy)" if detailed_mode else "domyślny (detekcje i asocjacje)"
|
| 194 |
+
print(f"Tworzenie wizualizacji w trybie: {mode_text}")
|
| 195 |
+
visualization_image = create_visualization(image, final_json, detailed_mode=detailed_mode)
|
| 196 |
|
| 197 |
print("--- Zakończono przetwarzanie ---")
|
| 198 |
return final_json, visualization_image
|
|
|
|
| 201 |
# 6) URUCHOMIENIE I WYŚWIETLENIE WYNIKÓW
|
| 202 |
# ==============================================================================
|
| 203 |
|
| 204 |
+
# Tryb wizualizacji:
|
| 205 |
+
# detailed_mode=False (domyślny) - rysuje tylko detekcje i asocjacje (obecne kolory)
|
| 206 |
+
# detailed_mode=True (wybredny) - rysuje wszystko z JSON: OCR (pomarańczowy), grounding (cyjan)
|
| 207 |
json_output, image_output = process_image(
|
| 208 |
+
pil_img, caption_for_grounding="elf girl", detailed_mode=True)
|
| 209 |
|
| 210 |
print("\n\n===== WYNIKI W FORMACIE JSON (przed filtrowaniem) =====")
|
| 211 |
print(json.dumps(json_output, indent=2))
|
modeling_florence2.py
CHANGED
|
@@ -23,7 +23,7 @@ import torch.utils.checkpoint
|
|
| 23 |
from torch import nn
|
| 24 |
import torch.nn.functional as F
|
| 25 |
import torch.utils.checkpoint as checkpoint
|
| 26 |
-
from torch.nn import CrossEntropyLoss
|
| 27 |
from collections import OrderedDict
|
| 28 |
from einops import rearrange, repeat
|
| 29 |
from timm.models.layers import DropPath, trunc_normal_
|
|
@@ -41,7 +41,7 @@ from transformers.utils import (
|
|
| 41 |
is_flash_attn_2_available,
|
| 42 |
is_flash_attn_greater_or_equal_2_10,
|
| 43 |
)
|
| 44 |
-
from .configuration_florence2 import Florence2Config
|
| 45 |
from .configuration_florence2 import Florence2LanguageConfig
|
| 46 |
from .configuration_florence2 import Florence2VisionConfig
|
| 47 |
from pytorch_metric_learning.utils.loss_and_miner_utils import get_all_pairs_indices
|
|
@@ -72,6 +72,7 @@ logger = logging.get_logger(__name__)
|
|
| 72 |
|
| 73 |
_CONFIG_FOR_DOC = "Florence2Config"
|
| 74 |
|
|
|
|
| 75 |
class LearnedAbsolutePositionEmbedding2D(nn.Module):
|
| 76 |
"""
|
| 77 |
This module learns positional embeddings up to a fixed maximum size.
|
|
@@ -80,7 +81,8 @@ class LearnedAbsolutePositionEmbedding2D(nn.Module):
|
|
| 80 |
def __init__(self, embedding_dim=256, num_pos=50):
|
| 81 |
super().__init__()
|
| 82 |
self.row_embeddings = nn.Embedding(num_pos, embedding_dim // 2)
|
| 83 |
-
self.column_embeddings = nn.Embedding(
|
|
|
|
| 84 |
|
| 85 |
def forward(self, pixel_values):
|
| 86 |
"""
|
|
@@ -95,7 +97,8 @@ class LearnedAbsolutePositionEmbedding2D(nn.Module):
|
|
| 95 |
x_emb = self.column_embeddings(width_values)
|
| 96 |
y_emb = self.row_embeddings(height_values)
|
| 97 |
# (height, width, embedding_dim * 2)
|
| 98 |
-
pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1),
|
|
|
|
| 99 |
# (embedding_dim * 2, height, width)
|
| 100 |
pos = pos.permute(2, 0, 1)
|
| 101 |
pos = pos.unsqueeze(0)
|
|
@@ -105,6 +108,7 @@ class LearnedAbsolutePositionEmbedding2D(nn.Module):
|
|
| 105 |
pos = pos.permute(0, 2, 3, 1)
|
| 106 |
return pos
|
| 107 |
|
|
|
|
| 108 |
class PositionalEmbeddingCosine1D(nn.Module):
|
| 109 |
"""
|
| 110 |
This class implements a very simple positional encoding. It follows closely
|
|
@@ -116,6 +120,7 @@ class PositionalEmbeddingCosine1D(nn.Module):
|
|
| 116 |
dropout_prob: The dropout probability.
|
| 117 |
max_seq_len: The maximum length to precompute the positional encodings.
|
| 118 |
"""
|
|
|
|
| 119 |
def __init__(
|
| 120 |
self,
|
| 121 |
embed_dim: int = 512,
|
|
@@ -171,6 +176,7 @@ class LearnedAbsolutePositionEmbedding1D(nn.Module):
|
|
| 171 |
embed_dim: The dimension of the embeddings.
|
| 172 |
max_seq_len: The maximum length to precompute the positional encodings.
|
| 173 |
"""
|
|
|
|
| 174 |
def __init__(
|
| 175 |
self,
|
| 176 |
embedding_dim: int = 512,
|
|
@@ -196,7 +202,8 @@ class LearnedAbsolutePositionEmbedding1D(nn.Module):
|
|
| 196 |
len_seq = seq_embeds.size(-2)
|
| 197 |
assert len_seq <= self.num_pos
|
| 198 |
# [T, D]
|
| 199 |
-
pos_embeds = self.embeddings(
|
|
|
|
| 200 |
# Adapt pre-computed positional embeddings to the input.
|
| 201 |
if shape_len == 3:
|
| 202 |
pos_embeds = pos_embeds.view(
|
|
@@ -204,7 +211,6 @@ class LearnedAbsolutePositionEmbedding1D(nn.Module):
|
|
| 204 |
return pos_embeds
|
| 205 |
|
| 206 |
|
| 207 |
-
|
| 208 |
class MySequential(nn.Sequential):
|
| 209 |
def forward(self, *inputs):
|
| 210 |
for module in self._modules.values():
|
|
@@ -349,7 +355,8 @@ class ChannelAttention(nn.Module):
|
|
| 349 |
def forward(self, x, size):
|
| 350 |
B, N, C = x.shape
|
| 351 |
|
| 352 |
-
qkv = self.qkv(x).reshape(B, N, 3, self.groups, C //
|
|
|
|
| 353 |
q, k, v = qkv[0], qkv[1], qkv[2]
|
| 354 |
|
| 355 |
q = q * (float(N) ** -0.5)
|
|
@@ -368,18 +375,22 @@ class ChannelBlock(nn.Module):
|
|
| 368 |
conv_at_attn=True, conv_at_ffn=True):
|
| 369 |
super().__init__()
|
| 370 |
|
| 371 |
-
drop_path = DropPath(
|
|
|
|
| 372 |
|
| 373 |
-
self.conv1 = PreNorm(None, DepthWiseConv2d(
|
|
|
|
| 374 |
self.channel_attn = PreNorm(
|
| 375 |
norm_layer(dim),
|
| 376 |
ChannelAttention(dim, groups=groups, qkv_bias=qkv_bias),
|
| 377 |
drop_path
|
| 378 |
)
|
| 379 |
-
self.conv2 = PreNorm(None, DepthWiseConv2d(
|
|
|
|
| 380 |
self.ffn = PreNorm(
|
| 381 |
norm_layer(dim),
|
| 382 |
-
Mlp(in_features=dim, hidden_features=int(
|
|
|
|
| 383 |
drop_path
|
| 384 |
)
|
| 385 |
|
|
@@ -397,16 +408,19 @@ class ChannelBlock(nn.Module):
|
|
| 397 |
|
| 398 |
def window_partition(x, window_size: int):
|
| 399 |
B, H, W, C = x.shape
|
| 400 |
-
x = x.view(B, H // window_size, window_size,
|
| 401 |
-
|
|
|
|
|
|
|
| 402 |
return windows
|
| 403 |
|
| 404 |
|
| 405 |
def window_reverse(windows, batch_size: int, window_size: int, H: int, W: int):
|
| 406 |
-
B = batch_size
|
| 407 |
# this will cause onnx conversion failed for dynamic axis, because treated as constant
|
| 408 |
-
# int(windows.shape[0] / (H * W / window_size / window_size))
|
| 409 |
-
x = windows.view(B, H // window_size, W // window_size,
|
|
|
|
| 410 |
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
| 411 |
return x
|
| 412 |
|
|
@@ -447,7 +461,8 @@ class WindowAttention(nn.Module):
|
|
| 447 |
# attn_windows = self.attn(x_windows)
|
| 448 |
|
| 449 |
B_, N, C = x.shape
|
| 450 |
-
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C //
|
|
|
|
| 451 |
q, k, v = qkv[0], qkv[1], qkv[2]
|
| 452 |
|
| 453 |
q = q * self.scale
|
|
@@ -478,18 +493,22 @@ class SpatialBlock(nn.Module):
|
|
| 478 |
norm_layer=nn.LayerNorm, conv_at_attn=True, conv_at_ffn=True):
|
| 479 |
super().__init__()
|
| 480 |
|
| 481 |
-
drop_path = DropPath(
|
|
|
|
| 482 |
|
| 483 |
-
self.conv1 = PreNorm(None, DepthWiseConv2d(
|
|
|
|
| 484 |
self.window_attn = PreNorm(
|
| 485 |
norm_layer(dim),
|
| 486 |
WindowAttention(dim, num_heads, window_size, qkv_bias=qkv_bias),
|
| 487 |
drop_path
|
| 488 |
)
|
| 489 |
-
self.conv2 = PreNorm(None, DepthWiseConv2d(
|
|
|
|
| 490 |
self.ffn = PreNorm(
|
| 491 |
norm_layer(dim),
|
| 492 |
-
Mlp(in_features=dim, hidden_features=int(
|
|
|
|
| 493 |
drop_path
|
| 494 |
)
|
| 495 |
|
|
@@ -547,7 +566,7 @@ class DaViT(nn.Module):
|
|
| 547 |
enable_checkpoint=True,
|
| 548 |
conv_at_attn=True,
|
| 549 |
conv_at_ffn=True,
|
| 550 |
-
|
| 551 |
super().__init__()
|
| 552 |
|
| 553 |
self.num_classes = num_classes
|
|
@@ -559,7 +578,8 @@ class DaViT(nn.Module):
|
|
| 559 |
assert self.num_stages == len(self.num_heads) == len(self.num_groups)
|
| 560 |
|
| 561 |
num_stages = len(embed_dims)
|
| 562 |
-
dpr = [x.item() for x in torch.linspace(
|
|
|
|
| 563 |
|
| 564 |
depth_offset = 0
|
| 565 |
convs = []
|
|
@@ -613,7 +633,8 @@ class DaViT(nn.Module):
|
|
| 613 |
|
| 614 |
self.norms = norm_layer(self.embed_dims[-1])
|
| 615 |
self.avgpool = nn.AdaptiveAvgPool1d(1)
|
| 616 |
-
self.head = nn.Linear(
|
|
|
|
| 617 |
|
| 618 |
self.apply(self._init_weights)
|
| 619 |
|
|
@@ -648,7 +669,8 @@ class DaViT(nn.Module):
|
|
| 648 |
for conv, block in zip(self.convs, self.blocks):
|
| 649 |
x, input_size = conv(x, input_size)
|
| 650 |
if self.enable_checkpoint:
|
| 651 |
-
x, input_size = checkpoint.checkpoint(
|
|
|
|
| 652 |
else:
|
| 653 |
x, input_size = block(x, input_size)
|
| 654 |
return x
|
|
@@ -668,7 +690,7 @@ class DaViT(nn.Module):
|
|
| 668 |
x = self.forward_features(x)
|
| 669 |
x = self.head(x)
|
| 670 |
return x
|
| 671 |
-
|
| 672 |
@classmethod
|
| 673 |
def from_config(cls, config):
|
| 674 |
return cls(
|
|
@@ -685,18 +707,19 @@ class DaViT(nn.Module):
|
|
| 685 |
)
|
| 686 |
|
| 687 |
|
| 688 |
-
|
| 689 |
-
|
| 690 |
if is_flash_attn_2_available():
|
| 691 |
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
| 692 |
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
| 693 |
|
| 694 |
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
|
|
|
|
|
|
| 695 |
def _get_unpad_data(attention_mask):
|
| 696 |
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
| 697 |
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
| 698 |
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
| 699 |
-
cu_seqlens = F.pad(torch.cumsum(
|
|
|
|
| 700 |
return (
|
| 701 |
indices,
|
| 702 |
cu_seqlens,
|
|
@@ -834,7 +857,8 @@ class Florence2Attention(nn.Module):
|
|
| 834 |
if past_key_value[0] is not None:
|
| 835 |
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
| 836 |
if past_key_value[1] is not None:
|
| 837 |
-
value_states = torch.cat(
|
|
|
|
| 838 |
else:
|
| 839 |
# self_attention
|
| 840 |
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
|
@@ -851,7 +875,8 @@ class Florence2Attention(nn.Module):
|
|
| 851 |
past_key_value = (key_states, value_states)
|
| 852 |
|
| 853 |
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
| 854 |
-
query_states = self._shape(
|
|
|
|
| 855 |
key_states = key_states.reshape(*proj_shape)
|
| 856 |
value_states = value_states.reshape(*proj_shape)
|
| 857 |
|
|
@@ -869,8 +894,10 @@ class Florence2Attention(nn.Module):
|
|
| 869 |
raise ValueError(
|
| 870 |
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
| 871 |
)
|
| 872 |
-
attn_weights = attn_weights.view(
|
| 873 |
-
|
|
|
|
|
|
|
| 874 |
|
| 875 |
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
| 876 |
|
|
@@ -880,20 +907,25 @@ class Florence2Attention(nn.Module):
|
|
| 880 |
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
|
| 881 |
f" {layer_head_mask.size()}"
|
| 882 |
)
|
| 883 |
-
attn_weights = layer_head_mask.view(
|
| 884 |
-
|
|
|
|
|
|
|
| 885 |
|
| 886 |
if output_attentions:
|
| 887 |
# this operation is a bit awkward, but it's required to
|
| 888 |
# make sure that attn_weights keeps its gradient.
|
| 889 |
# In order to do so, attn_weights have to be reshaped
|
| 890 |
# twice and have to be reused in the following
|
| 891 |
-
attn_weights_reshaped = attn_weights.view(
|
| 892 |
-
|
|
|
|
|
|
|
| 893 |
else:
|
| 894 |
attn_weights_reshaped = None
|
| 895 |
|
| 896 |
-
attn_probs = nn.functional.dropout(
|
|
|
|
| 897 |
|
| 898 |
attn_output = torch.bmm(attn_probs, value_states)
|
| 899 |
|
|
@@ -903,7 +935,8 @@ class Florence2Attention(nn.Module):
|
|
| 903 |
f" {attn_output.size()}"
|
| 904 |
)
|
| 905 |
|
| 906 |
-
attn_output = attn_output.view(
|
|
|
|
| 907 |
attn_output = attn_output.transpose(1, 2)
|
| 908 |
|
| 909 |
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
|
|
@@ -945,7 +978,8 @@ class Florence2FlashAttention2(Florence2Attention):
|
|
| 945 |
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 946 |
# Florence2FlashAttention2 attention does not support output_attentions
|
| 947 |
if output_attentions:
|
| 948 |
-
raise ValueError(
|
|
|
|
| 949 |
|
| 950 |
# if key_value_states are provided this layer is used as a cross-attention layer
|
| 951 |
# for the decoder
|
|
@@ -970,13 +1004,16 @@ class Florence2FlashAttention2(Florence2Attention):
|
|
| 970 |
elif is_cross_attention:
|
| 971 |
# cross_attentions
|
| 972 |
key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
|
| 973 |
-
value_states = self._reshape(
|
|
|
|
| 974 |
elif past_key_value is not None:
|
| 975 |
# reuse k, v, self_attention
|
| 976 |
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
|
| 977 |
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
|
| 978 |
-
key_states = torch.cat(
|
| 979 |
-
|
|
|
|
|
|
|
| 980 |
else:
|
| 981 |
# self_attention
|
| 982 |
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
|
|
@@ -990,7 +1027,8 @@ class Florence2FlashAttention2(Florence2Attention):
|
|
| 990 |
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
| 991 |
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
| 992 |
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
| 993 |
-
past_key_value = (key_states.transpose(
|
|
|
|
| 994 |
|
| 995 |
kv_seq_len = key_states.shape[-2]
|
| 996 |
if past_key_value is not None:
|
|
@@ -1086,7 +1124,8 @@ class Florence2FlashAttention2(Florence2Attention):
|
|
| 1086 |
causal=causal,
|
| 1087 |
)
|
| 1088 |
|
| 1089 |
-
attn_output = pad_input(
|
|
|
|
| 1090 |
else:
|
| 1091 |
attn_output = flash_attn_func(
|
| 1092 |
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
|
|
@@ -1096,18 +1135,22 @@ class Florence2FlashAttention2(Florence2Attention):
|
|
| 1096 |
|
| 1097 |
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
|
| 1098 |
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
|
| 1099 |
-
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
|
|
|
|
| 1100 |
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
|
| 1101 |
|
| 1102 |
key_layer = index_first_axis(
|
| 1103 |
-
key_layer.reshape(batch_size * kv_seq_len,
|
|
|
|
| 1104 |
)
|
| 1105 |
value_layer = index_first_axis(
|
| 1106 |
-
value_layer.reshape(batch_size * kv_seq_len,
|
|
|
|
| 1107 |
)
|
| 1108 |
if query_length == kv_seq_len:
|
| 1109 |
query_layer = index_first_axis(
|
| 1110 |
-
query_layer.reshape(batch_size * kv_seq_len,
|
|
|
|
| 1111 |
)
|
| 1112 |
cu_seqlens_q = cu_seqlens_k
|
| 1113 |
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
|
@@ -1122,7 +1165,8 @@ class Florence2FlashAttention2(Florence2Attention):
|
|
| 1122 |
else:
|
| 1123 |
# The -q_len: slice assumes left padding.
|
| 1124 |
attention_mask = attention_mask[:, -query_length:]
|
| 1125 |
-
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
|
|
|
|
| 1126 |
|
| 1127 |
return (
|
| 1128 |
query_layer,
|
|
@@ -1192,7 +1236,8 @@ class Florence2SdpaAttention(Florence2Attention):
|
|
| 1192 |
if past_key_value[0] is not None:
|
| 1193 |
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
| 1194 |
if past_key_value[1] is not None:
|
| 1195 |
-
value_states = torch.cat(
|
|
|
|
| 1196 |
else:
|
| 1197 |
# self_attention
|
| 1198 |
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
|
@@ -1294,23 +1339,28 @@ class Florence2EncoderLayer(nn.Module):
|
|
| 1294 |
layer_head_mask=layer_head_mask,
|
| 1295 |
output_attentions=output_attentions,
|
| 1296 |
)
|
| 1297 |
-
hidden_states = nn.functional.dropout(
|
|
|
|
| 1298 |
hidden_states = residual + hidden_states
|
| 1299 |
hidden_states = self.self_attn_layer_norm(hidden_states)
|
| 1300 |
|
| 1301 |
residual = hidden_states
|
| 1302 |
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
| 1303 |
-
hidden_states = nn.functional.dropout(
|
|
|
|
| 1304 |
hidden_states = self.fc2(hidden_states)
|
| 1305 |
-
hidden_states = nn.functional.dropout(
|
|
|
|
| 1306 |
hidden_states = residual + hidden_states
|
| 1307 |
hidden_states = self.final_layer_norm(hidden_states)
|
| 1308 |
|
| 1309 |
if hidden_states.dtype == torch.float16 and (
|
| 1310 |
-
torch.isinf(hidden_states).any() or torch.isnan(
|
|
|
|
| 1311 |
):
|
| 1312 |
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
| 1313 |
-
hidden_states = torch.clamp(
|
|
|
|
| 1314 |
|
| 1315 |
outputs = (hidden_states,)
|
| 1316 |
|
|
@@ -1384,7 +1434,8 @@ class Florence2DecoderLayer(nn.Module):
|
|
| 1384 |
|
| 1385 |
# Self Attention
|
| 1386 |
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
| 1387 |
-
self_attn_past_key_value = past_key_value[:
|
|
|
|
| 1388 |
# add present self-attn cache to positions 1,2 of present_key_value tuple
|
| 1389 |
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
| 1390 |
hidden_states=hidden_states,
|
|
@@ -1393,7 +1444,8 @@ class Florence2DecoderLayer(nn.Module):
|
|
| 1393 |
layer_head_mask=layer_head_mask,
|
| 1394 |
output_attentions=output_attentions,
|
| 1395 |
)
|
| 1396 |
-
hidden_states = nn.functional.dropout(
|
|
|
|
| 1397 |
hidden_states = residual + hidden_states
|
| 1398 |
hidden_states = self.self_attn_layer_norm(hidden_states)
|
| 1399 |
|
|
@@ -1404,7 +1456,8 @@ class Florence2DecoderLayer(nn.Module):
|
|
| 1404 |
residual = hidden_states
|
| 1405 |
|
| 1406 |
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
|
| 1407 |
-
cross_attn_past_key_value = past_key_value[-2:
|
|
|
|
| 1408 |
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
|
| 1409 |
hidden_states=hidden_states,
|
| 1410 |
key_value_states=encoder_hidden_states,
|
|
@@ -1413,7 +1466,8 @@ class Florence2DecoderLayer(nn.Module):
|
|
| 1413 |
past_key_value=cross_attn_past_key_value,
|
| 1414 |
output_attentions=output_attentions,
|
| 1415 |
)
|
| 1416 |
-
hidden_states = nn.functional.dropout(
|
|
|
|
| 1417 |
hidden_states = residual + hidden_states
|
| 1418 |
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
| 1419 |
|
|
@@ -1423,9 +1477,11 @@ class Florence2DecoderLayer(nn.Module):
|
|
| 1423 |
# Fully Connected
|
| 1424 |
residual = hidden_states
|
| 1425 |
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
| 1426 |
-
hidden_states = nn.functional.dropout(
|
|
|
|
| 1427 |
hidden_states = self.fc2(hidden_states)
|
| 1428 |
-
hidden_states = nn.functional.dropout(
|
|
|
|
| 1429 |
hidden_states = residual + hidden_states
|
| 1430 |
hidden_states = self.final_layer_norm(hidden_states)
|
| 1431 |
|
|
@@ -1440,7 +1496,6 @@ class Florence2DecoderLayer(nn.Module):
|
|
| 1440 |
return outputs
|
| 1441 |
|
| 1442 |
|
| 1443 |
-
|
| 1444 |
class Florence2LanguagePreTrainedModel(PreTrainedModel):
|
| 1445 |
config_class = Florence2LanguageConfig
|
| 1446 |
base_model_prefix = "model"
|
|
@@ -1465,7 +1520,8 @@ class Florence2LanguagePreTrainedModel(PreTrainedModel):
|
|
| 1465 |
@property
|
| 1466 |
def dummy_inputs(self):
|
| 1467 |
pad_token = self.config.pad_token_id
|
| 1468 |
-
input_ids = torch.tensor(
|
|
|
|
| 1469 |
dummy_inputs = {
|
| 1470 |
"attention_mask": input_ids.ne(pad_token),
|
| 1471 |
"input_ids": input_ids,
|
|
@@ -1505,7 +1561,8 @@ class Florence2Encoder(Florence2LanguagePreTrainedModel):
|
|
| 1505 |
config.max_position_embeddings,
|
| 1506 |
embed_dim,
|
| 1507 |
)
|
| 1508 |
-
self.layers = nn.ModuleList([Florence2EncoderLayer(
|
|
|
|
| 1509 |
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
| 1510 |
self._use_sdpa = config._attn_implementation == "sdpa"
|
| 1511 |
self.layernorm_embedding = nn.LayerNorm(embed_dim)
|
|
@@ -1574,14 +1631,16 @@ class Florence2Encoder(Florence2LanguagePreTrainedModel):
|
|
| 1574 |
|
| 1575 |
# retrieve input_ids and inputs_embeds
|
| 1576 |
if input_ids is not None and inputs_embeds is not None:
|
| 1577 |
-
raise ValueError(
|
|
|
|
| 1578 |
elif input_ids is not None:
|
| 1579 |
input = input_ids
|
| 1580 |
input_ids = input_ids.view(-1, input_ids.shape[-1])
|
| 1581 |
elif inputs_embeds is not None:
|
| 1582 |
input = inputs_embeds[:, :, -1]
|
| 1583 |
else:
|
| 1584 |
-
raise ValueError(
|
|
|
|
| 1585 |
|
| 1586 |
if inputs_embeds is None:
|
| 1587 |
inputs_embeds = self.embed_tokens(input_ids)
|
|
@@ -1591,7 +1650,8 @@ class Florence2Encoder(Florence2LanguagePreTrainedModel):
|
|
| 1591 |
|
| 1592 |
hidden_states = inputs_embeds + embed_pos
|
| 1593 |
hidden_states = self.layernorm_embedding(hidden_states)
|
| 1594 |
-
hidden_states = nn.functional.dropout(
|
|
|
|
| 1595 |
|
| 1596 |
# expand attention_mask
|
| 1597 |
if attention_mask is not None:
|
|
@@ -1601,10 +1661,12 @@ class Florence2Encoder(Florence2LanguagePreTrainedModel):
|
|
| 1601 |
# output_attentions=True & head_mask can not be supported when using SDPA, fall back to
|
| 1602 |
# the manual implementation that requires a 4D causal mask in all cases.
|
| 1603 |
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
| 1604 |
-
attention_mask = _prepare_4d_attention_mask_for_sdpa(
|
|
|
|
| 1605 |
else:
|
| 1606 |
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
| 1607 |
-
attention_mask = _prepare_4d_attention_mask(
|
|
|
|
| 1608 |
|
| 1609 |
encoder_states = () if output_hidden_states else None
|
| 1610 |
all_attentions = () if output_attentions else None
|
|
@@ -1642,7 +1704,8 @@ class Florence2Encoder(Florence2LanguagePreTrainedModel):
|
|
| 1642 |
layer_outputs = encoder_layer(
|
| 1643 |
hidden_states,
|
| 1644 |
attention_mask,
|
| 1645 |
-
layer_head_mask=(
|
|
|
|
| 1646 |
output_attentions=output_attentions,
|
| 1647 |
)
|
| 1648 |
|
|
@@ -1676,7 +1739,8 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
|
|
| 1676 |
self.layerdrop = config.decoder_layerdrop
|
| 1677 |
self.padding_idx = config.pad_token_id
|
| 1678 |
self.max_target_positions = config.max_position_embeddings
|
| 1679 |
-
embed_scale = math.sqrt(
|
|
|
|
| 1680 |
|
| 1681 |
self.embed_tokens = Florence2ScaledWordEmbedding(
|
| 1682 |
config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
|
|
@@ -1689,7 +1753,8 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
|
|
| 1689 |
config.max_position_embeddings,
|
| 1690 |
config.d_model,
|
| 1691 |
)
|
| 1692 |
-
self.layers = nn.ModuleList([Florence2DecoderLayer(
|
|
|
|
| 1693 |
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
| 1694 |
self._use_sdpa = config._attn_implementation == "sdpa"
|
| 1695 |
|
|
@@ -1794,7 +1859,8 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
|
|
| 1794 |
|
| 1795 |
# retrieve input_ids and inputs_embeds
|
| 1796 |
if input_ids is not None and inputs_embeds is not None:
|
| 1797 |
-
raise ValueError(
|
|
|
|
| 1798 |
elif input_ids is not None:
|
| 1799 |
input = input_ids
|
| 1800 |
input_shape = input.shape
|
|
@@ -1803,17 +1869,20 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
|
|
| 1803 |
input_shape = inputs_embeds.size()[:-1]
|
| 1804 |
input = inputs_embeds[:, :, -1]
|
| 1805 |
else:
|
| 1806 |
-
raise ValueError(
|
|
|
|
| 1807 |
|
| 1808 |
# past_key_values_length
|
| 1809 |
-
past_key_values_length = past_key_values[0][0].shape[
|
|
|
|
| 1810 |
|
| 1811 |
if inputs_embeds is None:
|
| 1812 |
inputs_embeds = self.embed_tokens(input)
|
| 1813 |
|
| 1814 |
if self._use_flash_attention_2:
|
| 1815 |
# 2d mask is passed through the layers
|
| 1816 |
-
attention_mask = attention_mask if (
|
|
|
|
| 1817 |
elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None:
|
| 1818 |
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
|
| 1819 |
# the manual implementation that requires a 4D causal mask in all cases.
|
|
@@ -1855,7 +1924,8 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
|
|
| 1855 |
hidden_states = inputs_embeds + positions
|
| 1856 |
hidden_states = self.layernorm_embedding(hidden_states)
|
| 1857 |
|
| 1858 |
-
hidden_states = nn.functional.dropout(
|
|
|
|
| 1859 |
|
| 1860 |
if self.gradient_checkpointing and self.training:
|
| 1861 |
if use_cache:
|
|
@@ -1867,7 +1937,8 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
|
|
| 1867 |
# decoder layers
|
| 1868 |
all_hidden_states = () if output_hidden_states else None
|
| 1869 |
all_self_attns = () if output_attentions else None
|
| 1870 |
-
all_cross_attentions = () if (
|
|
|
|
| 1871 |
next_decoder_cache = () if use_cache else None
|
| 1872 |
|
| 1873 |
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
|
|
@@ -1909,7 +1980,8 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
|
|
| 1909 |
attention_mask=attention_mask,
|
| 1910 |
encoder_hidden_states=encoder_hidden_states,
|
| 1911 |
encoder_attention_mask=encoder_attention_mask,
|
| 1912 |
-
layer_head_mask=(
|
|
|
|
| 1913 |
cross_attn_layer_head_mask=(
|
| 1914 |
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
|
| 1915 |
),
|
|
@@ -1920,7 +1992,8 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
|
|
| 1920 |
hidden_states = layer_outputs[0]
|
| 1921 |
|
| 1922 |
if use_cache:
|
| 1923 |
-
next_decoder_cache += (
|
|
|
|
| 1924 |
|
| 1925 |
if output_attentions:
|
| 1926 |
all_self_attns += (layer_outputs[1],)
|
|
@@ -1949,7 +2022,8 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
|
|
| 1949 |
|
| 1950 |
|
| 1951 |
class Florence2LanguageModel(Florence2LanguagePreTrainedModel):
|
| 1952 |
-
_tied_weights_keys = ["encoder.embed_tokens.weight",
|
|
|
|
| 1953 |
|
| 1954 |
def __init__(self, config: Florence2LanguageConfig):
|
| 1955 |
super().__init__(config)
|
|
@@ -2035,8 +2109,10 @@ class Florence2LanguageModel(Florence2LanguagePreTrainedModel):
|
|
| 2035 |
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
|
| 2036 |
encoder_outputs = BaseModelOutput(
|
| 2037 |
last_hidden_state=encoder_outputs[0],
|
| 2038 |
-
hidden_states=encoder_outputs[1] if len(
|
| 2039 |
-
|
|
|
|
|
|
|
| 2040 |
)
|
| 2041 |
|
| 2042 |
# decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
|
|
@@ -2072,14 +2148,17 @@ class Florence2LanguageModel(Florence2LanguagePreTrainedModel):
|
|
| 2072 |
|
| 2073 |
class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel, GenerationMixin):
|
| 2074 |
base_model_prefix = "model"
|
| 2075 |
-
_tied_weights_keys = ["encoder.embed_tokens.weight",
|
|
|
|
| 2076 |
_keys_to_ignore_on_load_missing = ["final_logits_bias"]
|
| 2077 |
|
| 2078 |
def __init__(self, config: Florence2LanguageConfig):
|
| 2079 |
super().__init__(config)
|
| 2080 |
self.model = Florence2LanguageModel(config)
|
| 2081 |
-
self.register_buffer("final_logits_bias", torch.zeros(
|
| 2082 |
-
|
|
|
|
|
|
|
| 2083 |
|
| 2084 |
# Initialize weights and apply final processing
|
| 2085 |
self.post_init()
|
|
@@ -2091,7 +2170,8 @@ class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel
|
|
| 2091 |
return self.model.get_decoder()
|
| 2092 |
|
| 2093 |
def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
|
| 2094 |
-
new_embeddings = super().resize_token_embeddings(
|
|
|
|
| 2095 |
self._resize_final_logits_bias(new_embeddings.weight.shape[0])
|
| 2096 |
return new_embeddings
|
| 2097 |
|
|
@@ -2100,7 +2180,8 @@ class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel
|
|
| 2100 |
if new_num_tokens <= old_num_tokens:
|
| 2101 |
new_bias = self.final_logits_bias[:, :new_num_tokens]
|
| 2102 |
else:
|
| 2103 |
-
extra_bias = torch.zeros(
|
|
|
|
| 2104 |
new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
|
| 2105 |
self.register_buffer("final_logits_bias", new_bias)
|
| 2106 |
|
|
@@ -2141,7 +2222,8 @@ class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel
|
|
| 2141 |
|
| 2142 |
if labels is not None:
|
| 2143 |
if use_cache:
|
| 2144 |
-
logger.warning(
|
|
|
|
| 2145 |
use_cache = False
|
| 2146 |
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
| 2147 |
decoder_input_ids = shift_tokens_right(
|
|
@@ -2173,7 +2255,8 @@ class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel
|
|
| 2173 |
if labels is not None:
|
| 2174 |
labels = labels.to(lm_logits.device)
|
| 2175 |
loss_fct = CrossEntropyLoss()
|
| 2176 |
-
masked_lm_loss = loss_fct(
|
|
|
|
| 2177 |
|
| 2178 |
if not return_dict:
|
| 2179 |
output = (lm_logits,) + outputs[1:]
|
|
@@ -2227,7 +2310,8 @@ class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel
|
|
| 2227 |
"head_mask": head_mask,
|
| 2228 |
"decoder_head_mask": decoder_head_mask,
|
| 2229 |
"cross_attn_head_mask": cross_attn_head_mask,
|
| 2230 |
-
|
|
|
|
| 2231 |
}
|
| 2232 |
|
| 2233 |
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
|
|
@@ -2239,11 +2323,13 @@ class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel
|
|
| 2239 |
for layer_past in past_key_values:
|
| 2240 |
# cached cross_attention states don't have to be reordered -> they are always the same
|
| 2241 |
reordered_past += (
|
| 2242 |
-
tuple(past_state.index_select(0, beam_idx.to(past_state.device))
|
|
|
|
| 2243 |
+ layer_past[2:],
|
| 2244 |
)
|
| 2245 |
return reordered_past
|
| 2246 |
|
|
|
|
| 2247 |
@dataclass
|
| 2248 |
class Florence2Seq2SeqLMOutput(ModelOutput):
|
| 2249 |
"""
|
|
@@ -2429,6 +2515,7 @@ FLORENCE2_INPUTS_DOCSTRING = r"""
|
|
| 2429 |
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 2430 |
"""
|
| 2431 |
|
|
|
|
| 2432 |
@add_start_docstrings(
|
| 2433 |
"""The FLORENCE2 vision model without any head""",
|
| 2434 |
FLORENCE2_START_DOCSTRING,
|
|
@@ -2436,11 +2523,12 @@ FLORENCE2_INPUTS_DOCSTRING = r"""
|
|
| 2436 |
class Florence2VisionModel(Florence2PreTrainedModel):
|
| 2437 |
def __init__(self, config: Florence2VisionConfig):
|
| 2438 |
super().__init__(config)
|
| 2439 |
-
assert config.model_type in [
|
|
|
|
| 2440 |
self.vision_tower = DaViT.from_config(config=config)
|
| 2441 |
|
| 2442 |
self.post_init()
|
| 2443 |
-
|
| 2444 |
def forward(self, pixel_values):
|
| 2445 |
if len(pixel_values.shape) == 4:
|
| 2446 |
x = self.vision_tower.forward_features_unpool(pixel_values)
|
|
@@ -2456,13 +2544,14 @@ class Florence2VisionModel(Florence2PreTrainedModel):
|
|
| 2456 |
class Florence2VisionModelWithProjection(Florence2PreTrainedModel):
|
| 2457 |
def __init__(self, config: Florence2VisionConfig):
|
| 2458 |
super().__init__(config)
|
| 2459 |
-
assert config.model_type in [
|
|
|
|
| 2460 |
self.vision_tower = DaViT.from_config(config=config)
|
| 2461 |
|
| 2462 |
self._build_image_projection_layers(config)
|
| 2463 |
|
| 2464 |
self.post_init()
|
| 2465 |
-
|
| 2466 |
def _build_image_projection_layers(self, config):
|
| 2467 |
image_dim_out = config.dim_embed[-1]
|
| 2468 |
dim_projection = config.projection_dim
|
|
@@ -2498,7 +2587,7 @@ class Florence2VisionModelWithProjection(Florence2PreTrainedModel):
|
|
| 2498 |
x = self.vision_tower.forward_features_unpool(pixel_values)
|
| 2499 |
else:
|
| 2500 |
raise ValueError(f'invalid image shape {pixel_values.shape}')
|
| 2501 |
-
|
| 2502 |
if self.image_pos_embed is not None:
|
| 2503 |
x = x.view(batch_size * T, -1, x.shape[-1])
|
| 2504 |
num_tokens = x.shape[-2]
|
|
@@ -2510,15 +2599,18 @@ class Florence2VisionModelWithProjection(Florence2PreTrainedModel):
|
|
| 2510 |
x = x.view(batch_size, T * h*w, x.shape[-1])
|
| 2511 |
|
| 2512 |
if self.visual_temporal_embed is not None:
|
| 2513 |
-
visual_temporal_embed = self.visual_temporal_embed(
|
| 2514 |
-
|
|
|
|
|
|
|
| 2515 |
|
| 2516 |
x_feat_dict = {}
|
| 2517 |
|
| 2518 |
spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2)
|
| 2519 |
x_feat_dict['spatial_avg_pool'] = spatial_avg_pool_x
|
| 2520 |
|
| 2521 |
-
temporal_avg_pool_x = x.view(
|
|
|
|
| 2522 |
x_feat_dict['temporal_avg_pool'] = temporal_avg_pool_x
|
| 2523 |
|
| 2524 |
x = x.view(batch_size, T, -1, x.shape[-1])[:, -1]
|
|
@@ -2527,7 +2619,8 @@ class Florence2VisionModelWithProjection(Florence2PreTrainedModel):
|
|
| 2527 |
new_x = []
|
| 2528 |
for _image_feature_source in self.image_feature_source:
|
| 2529 |
if _image_feature_source not in x_feat_dict:
|
| 2530 |
-
raise ValueError(
|
|
|
|
| 2531 |
new_x.append(x_feat_dict[_image_feature_source])
|
| 2532 |
|
| 2533 |
x = torch.cat(new_x, dim=1)
|
|
@@ -2535,11 +2628,9 @@ class Florence2VisionModelWithProjection(Florence2PreTrainedModel):
|
|
| 2535 |
x = x @ self.image_projection
|
| 2536 |
x = self.image_proj_norm(x)
|
| 2537 |
|
| 2538 |
-
|
| 2539 |
return x
|
| 2540 |
|
| 2541 |
|
| 2542 |
-
|
| 2543 |
@add_start_docstrings(
|
| 2544 |
"""The FLORENCE2 model which consists of a vision backbone and a language model.""",
|
| 2545 |
FLORENCE2_START_DOCSTRING,
|
|
@@ -2547,9 +2638,10 @@ class Florence2VisionModelWithProjection(Florence2PreTrainedModel):
|
|
| 2547 |
class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
| 2548 |
def __init__(self, config: Florence2Config):
|
| 2549 |
super().__init__(config)
|
| 2550 |
-
assert config.vision_config.model_type in [
|
|
|
|
| 2551 |
self.vision_tower = DaViT.from_config(config=config.vision_config)
|
| 2552 |
-
# remove unused layers
|
| 2553 |
del self.vision_tower.head
|
| 2554 |
del self.vision_tower.norms
|
| 2555 |
|
|
@@ -2557,10 +2649,12 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
| 2557 |
self._attn_implementation = config._attn_implementation
|
| 2558 |
self._build_image_projection_layers(config)
|
| 2559 |
|
| 2560 |
-
language_model = Florence2LanguageForConditionalGeneration(
|
|
|
|
| 2561 |
|
| 2562 |
if language_model._tied_weights_keys is not None:
|
| 2563 |
-
self._tied_weights_keys = [
|
|
|
|
| 2564 |
self.language_model = language_model
|
| 2565 |
self.character_character_matching_head = nn.Sequential(
|
| 2566 |
nn.Linear(2 * 768, config.projection_dim),
|
|
@@ -2584,16 +2678,17 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
| 2584 |
nn.Linear(config.projection_dim, 1)
|
| 2585 |
)
|
| 2586 |
self.text_classification_head = nn.Linear(config.projection_dim, 1)
|
| 2587 |
-
self.character_embedding_projection = nn.Linear(
|
|
|
|
| 2588 |
|
| 2589 |
self._init_weights(self.character_character_matching_head)
|
| 2590 |
self._init_weights(self.text_character_matching_head)
|
| 2591 |
self._init_weights(self.text_tail_matching_head)
|
| 2592 |
self._init_weights(self.text_classification_head)
|
| 2593 |
-
|
| 2594 |
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
|
| 2595 |
self.post_init()
|
| 2596 |
-
|
| 2597 |
def _init_weights(self, m):
|
| 2598 |
if isinstance(m, nn.Linear):
|
| 2599 |
trunc_normal_(m.weight, std=0.02)
|
|
@@ -2613,7 +2708,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
| 2613 |
elif isinstance(m, nn.Sequential):
|
| 2614 |
for layer in m:
|
| 2615 |
self._init_weights(layer)
|
| 2616 |
-
|
| 2617 |
def _build_image_projection_layers(self, config):
|
| 2618 |
image_dim_out = config.vision_config.dim_embed[-1]
|
| 2619 |
dim_projection = config.vision_config.projection_dim
|
|
@@ -2652,13 +2747,14 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
| 2652 |
return self.language_model.get_input_embeddings()
|
| 2653 |
|
| 2654 |
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
|
| 2655 |
-
model_embeds = self.language_model.resize_token_embeddings(
|
|
|
|
| 2656 |
# update vocab size
|
| 2657 |
self.config.text_config.vocab_size = model_embeds.num_embeddings
|
| 2658 |
self.config.vocab_size = model_embeds.num_embeddings
|
| 2659 |
self.vocab_size = model_embeds.num_embeddings
|
| 2660 |
return model_embeds
|
| 2661 |
-
|
| 2662 |
def _encode_image(self, pixel_values):
|
| 2663 |
if len(pixel_values.shape) == 4:
|
| 2664 |
batch_size, C, H, W = pixel_values.shape
|
|
@@ -2666,7 +2762,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
| 2666 |
x = self.vision_tower.forward_features_unpool(pixel_values)
|
| 2667 |
else:
|
| 2668 |
raise ValueError(f'invalid image shape {pixel_values.shape}')
|
| 2669 |
-
|
| 2670 |
if self.image_pos_embed is not None:
|
| 2671 |
x = x.view(batch_size * T, -1, x.shape[-1])
|
| 2672 |
num_tokens = x.shape[-2]
|
|
@@ -2678,15 +2774,18 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
| 2678 |
x = x.view(batch_size, T * h*w, x.shape[-1])
|
| 2679 |
|
| 2680 |
if self.visual_temporal_embed is not None:
|
| 2681 |
-
visual_temporal_embed = self.visual_temporal_embed(
|
| 2682 |
-
|
|
|
|
|
|
|
| 2683 |
|
| 2684 |
x_feat_dict = {}
|
| 2685 |
|
| 2686 |
spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2)
|
| 2687 |
x_feat_dict['spatial_avg_pool'] = spatial_avg_pool_x
|
| 2688 |
|
| 2689 |
-
temporal_avg_pool_x = x.view(
|
|
|
|
| 2690 |
x_feat_dict['temporal_avg_pool'] = temporal_avg_pool_x
|
| 2691 |
|
| 2692 |
x = x.view(batch_size, T, -1, x.shape[-1])[:, -1]
|
|
@@ -2695,7 +2794,8 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
| 2695 |
new_x = []
|
| 2696 |
for _image_feature_source in self.image_feature_source:
|
| 2697 |
if _image_feature_source not in x_feat_dict:
|
| 2698 |
-
raise ValueError(
|
|
|
|
| 2699 |
new_x.append(x_feat_dict[_image_feature_source])
|
| 2700 |
|
| 2701 |
x = torch.cat(new_x, dim=1)
|
|
@@ -2703,14 +2803,15 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
| 2703 |
x = x @ self.image_projection
|
| 2704 |
x = self.image_proj_norm(x)
|
| 2705 |
|
| 2706 |
-
return x
|
| 2707 |
|
| 2708 |
def _merge_input_ids_with_image_features(
|
| 2709 |
-
self, image_features, inputs_embeds
|
| 2710 |
):
|
| 2711 |
batch_size, image_token_length = image_features.size()[:-1]
|
| 2712 |
device = image_features.device
|
| 2713 |
-
image_attention_mask = torch.ones(
|
|
|
|
| 2714 |
|
| 2715 |
# task_prefix_embeds: [batch_size, padded_context_length, hidden_size]
|
| 2716 |
# task_prefix_attention_mask: [batch_size, context_length]
|
|
@@ -2718,17 +2819,19 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
| 2718 |
return image_features, image_attention_mask
|
| 2719 |
|
| 2720 |
task_prefix_embeds = inputs_embeds
|
| 2721 |
-
task_prefix_attention_mask = torch.ones(
|
|
|
|
| 2722 |
|
| 2723 |
if len(task_prefix_attention_mask.shape) == 3:
|
| 2724 |
task_prefix_attention_mask = task_prefix_attention_mask[:, 0]
|
| 2725 |
|
| 2726 |
# concat [image embeds, task prefix embeds]
|
| 2727 |
inputs_embeds = torch.cat([image_features, task_prefix_embeds], dim=1)
|
| 2728 |
-
attention_mask = torch.cat(
|
|
|
|
| 2729 |
|
| 2730 |
return inputs_embeds, attention_mask
|
| 2731 |
-
|
| 2732 |
@torch.no_grad()
|
| 2733 |
def predict_detections_and_associations(
|
| 2734 |
self,
|
|
@@ -2740,7 +2843,8 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
| 2740 |
essential_text_threshold=0.8,
|
| 2741 |
):
|
| 2742 |
batch_inputs = processor(
|
| 2743 |
-
batch_input_text=[
|
|
|
|
| 2744 |
batch_input_list_of_list_of_bboxes=[[]] * len(images),
|
| 2745 |
batch_images=images,
|
| 2746 |
padding=True,
|
|
@@ -2758,13 +2862,16 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
| 2758 |
do_sample=False,
|
| 2759 |
num_beams=3,
|
| 2760 |
)
|
| 2761 |
-
generated_texts, list_of_list_of_list_of_bboxes, batch_indices_of_bboxes_in_generated_text = processor.postprocess_output(
|
| 2762 |
-
|
|
|
|
|
|
|
| 2763 |
|
| 2764 |
results = []
|
| 2765 |
|
| 2766 |
for generated_text, batch_indices_of_bboxes_in_generated_text, list_of_list_of_bboxes in zip(generated_texts, batch_indices_of_bboxes_in_generated_text, list_of_list_of_list_of_bboxes):
|
| 2767 |
-
categories = [map_to_category.get(
|
|
|
|
| 2768 |
result_for_this_image = {
|
| 2769 |
"panels": [],
|
| 2770 |
"texts": [],
|
|
@@ -2779,9 +2886,11 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
| 2779 |
|
| 2780 |
cleaned_generated_ids = []
|
| 2781 |
for generated_id in generated_ids:
|
| 2782 |
-
index_of_last_bos = torch.where(
|
|
|
|
| 2783 |
cleaned_generated_ids.append(generated_id[index_of_last_bos:])
|
| 2784 |
-
cleaned_generated_ids = pad_sequence(
|
|
|
|
| 2785 |
association_outputs = self(
|
| 2786 |
input_ids=batch_inputs["input_ids"],
|
| 2787 |
pixel_values=batch_inputs["pixel_values"],
|
|
@@ -2790,17 +2899,21 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
| 2790 |
)
|
| 2791 |
|
| 2792 |
for img_idx in range(len(results)):
|
| 2793 |
-
character_cluster_labels = UnionFind.from_adj_matrix(
|
| 2794 |
-
|
| 2795 |
-
|
| 2796 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2797 |
results[img_idx]["character_cluster_labels"] = character_cluster_labels
|
| 2798 |
results[img_idx]["text_character_associations"] = text_character_association
|
| 2799 |
results[img_idx]["text_tail_associations"] = text_tail_association
|
| 2800 |
results[img_idx]["is_essential_text"] = essential_text_logits
|
| 2801 |
|
| 2802 |
return results
|
| 2803 |
-
|
| 2804 |
@torch.no_grad()
|
| 2805 |
def predict_ocr(
|
| 2806 |
self,
|
|
@@ -2808,7 +2921,8 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
| 2808 |
processor,
|
| 2809 |
):
|
| 2810 |
batch_inputs = processor(
|
| 2811 |
-
batch_input_text=[
|
|
|
|
| 2812 |
batch_input_list_of_list_of_bboxes=[[]] * len(images),
|
| 2813 |
batch_images=images,
|
| 2814 |
padding=True,
|
|
@@ -2826,7 +2940,8 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
| 2826 |
do_sample=False,
|
| 2827 |
num_beams=3,
|
| 2828 |
)
|
| 2829 |
-
generated_texts, list_of_list_of_list_of_bboxes, batch_indices_of_bboxes_in_generated_text = processor.postprocess_output(
|
|
|
|
| 2830 |
results = []
|
| 2831 |
for generated_text, batch_indices_of_bboxes_in_generated_text, list_of_list_of_bboxes in zip(generated_texts, batch_indices_of_bboxes_in_generated_text, list_of_list_of_list_of_bboxes):
|
| 2832 |
ocr_texts = []
|
|
@@ -2850,9 +2965,10 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
| 2850 |
):
|
| 2851 |
def convert_caption_to_instruction(caption):
|
| 2852 |
return "Locate the phrases in the caption: " + caption
|
| 2853 |
-
|
| 2854 |
batch_inputs = processor(
|
| 2855 |
-
batch_input_text=[convert_caption_to_instruction(
|
|
|
|
| 2856 |
batch_input_list_of_list_of_bboxes=[[]] * len(images),
|
| 2857 |
batch_images=images,
|
| 2858 |
padding=True,
|
|
@@ -2880,7 +2996,8 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
| 2880 |
do_sample=False,
|
| 2881 |
num_beams=3,
|
| 2882 |
)
|
| 2883 |
-
generated_texts, list_of_list_of_list_of_bboxes, batch_indices_of_bboxes_in_generated_text = processor.postprocess_output(
|
|
|
|
| 2884 |
return [
|
| 2885 |
{
|
| 2886 |
"grounded_caption": generated_text,
|
|
@@ -2909,7 +3026,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
| 2909 |
output_attentions: Optional[bool] = None,
|
| 2910 |
output_hidden_states: Optional[bool] = True,
|
| 2911 |
return_dict: Optional[bool] = None,
|
| 2912 |
-
tokenizer
|
| 2913 |
) -> Union[Tuple, Florence2Seq2SeqLMOutput]:
|
| 2914 |
assert output_hidden_states, "output_hidden_states must be True"
|
| 2915 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
@@ -2927,7 +3044,8 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
| 2927 |
if pixel_values is not None:
|
| 2928 |
# (batch_size, num_image_tokens, hidden_size)
|
| 2929 |
image_features = self._encode_image(pixel_values)
|
| 2930 |
-
inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(
|
|
|
|
| 2931 |
|
| 2932 |
if inputs_embeds is not None:
|
| 2933 |
attention_mask = attention_mask.to(inputs_embeds.dtype)
|
|
@@ -2949,10 +3067,14 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
| 2949 |
return_dict=return_dict,
|
| 2950 |
)
|
| 2951 |
|
| 2952 |
-
character_character_affinity_matrices = self.get_character_character_affinity_matrices(
|
| 2953 |
-
|
| 2954 |
-
|
| 2955 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2956 |
|
| 2957 |
return Florence2Seq2SeqLMOutput(
|
| 2958 |
character_character_affinity_matrices=character_character_affinity_matrices,
|
|
@@ -2963,11 +3085,11 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
| 2963 |
|
| 2964 |
def generate(
|
| 2965 |
self,
|
| 2966 |
-
input_ids,
|
| 2967 |
inputs_embeds=None,
|
| 2968 |
pixel_values=None,
|
| 2969 |
**kwargs
|
| 2970 |
-
|
| 2971 |
|
| 2972 |
if inputs_embeds is None:
|
| 2973 |
# 1. Extract the input embeddings
|
|
@@ -2976,14 +3098,15 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
| 2976 |
# 2. Merge text and images
|
| 2977 |
if pixel_values is not None:
|
| 2978 |
image_features = self._encode_image(pixel_values)
|
| 2979 |
-
inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(
|
| 2980 |
-
|
|
|
|
| 2981 |
return self.language_model.generate(
|
| 2982 |
input_ids=None,
|
| 2983 |
inputs_embeds=inputs_embeds,
|
| 2984 |
**kwargs
|
| 2985 |
)
|
| 2986 |
-
|
| 2987 |
def slowly_generate_grounded_caption(
|
| 2988 |
self,
|
| 2989 |
input_ids,
|
|
@@ -2999,9 +3122,12 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
| 2999 |
"""
|
| 3000 |
input_embeds = self.get_input_embeddings()(input_ids)
|
| 3001 |
image_features = self._encode_image(pixel_values)
|
| 3002 |
-
inputs_embeds, _ = self._merge_input_ids_with_image_features(
|
| 3003 |
-
|
| 3004 |
-
|
|
|
|
|
|
|
|
|
|
| 3005 |
running_decoder_input_ids = decoder_input_ids[:, :1]
|
| 3006 |
num_tokens_generated = 1
|
| 3007 |
CHUNK_SIZE = 8
|
|
@@ -3019,19 +3145,22 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
| 3019 |
)[:, -(CHUNK_SIZE+1):-1]
|
| 3020 |
|
| 3021 |
what_should_be_the_next_tokens = torch.stack([
|
| 3022 |
-
decoder_input_ids[i, running_indices[i]+
|
|
|
|
| 3023 |
for i in range(decoder_input_ids.shape[0])
|
| 3024 |
])
|
| 3025 |
# if the entire predicted chunk matches the next chunk, then we can saved some time and "jump" to the next chunk
|
| 3026 |
if predicted_next_tokens.shape[1] == what_should_be_the_next_tokens.shape[1] and torch.all(predicted_next_tokens == what_should_be_the_next_tokens):
|
| 3027 |
running_indices += CHUNK_SIZE
|
| 3028 |
-
running_decoder_input_ids = torch.cat(
|
|
|
|
| 3029 |
continue
|
| 3030 |
-
|
| 3031 |
# if, however, there is a deviation find the maximum prefix that matches in the batch
|
| 3032 |
|
| 3033 |
predicted_next_tokens = predicted_next_tokens[:, 0]
|
| 3034 |
-
predicted_next_token_strings = processor.batch_decode(
|
|
|
|
| 3035 |
next_tokens_to_concat = []
|
| 3036 |
for i, (pnts, pnt) in enumerate(zip(predicted_next_token_strings, predicted_next_tokens)):
|
| 3037 |
if (pnts.startswith("<loc_") or pnts in ["<s>", "<pad>", "</s>"]) and running_indices[i] < decoder_input_ids.shape[1] - 1:
|
|
@@ -3039,15 +3168,18 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
| 3039 |
else:
|
| 3040 |
running_indices[i] += 1
|
| 3041 |
if running_indices[i] >= decoder_input_ids.shape[1]:
|
| 3042 |
-
next_tokens_to_concat.append(torch.tensor(
|
|
|
|
| 3043 |
# elif "’" in pnts: # this is an annoying character which looks like ' (apostrophe) but isn't.
|
| 3044 |
# import pdb; pdb.set_trace()
|
| 3045 |
else:
|
| 3046 |
-
next_tokens_to_concat.append(
|
|
|
|
| 3047 |
next_tokens_to_concat = torch.stack(next_tokens_to_concat)[:, None]
|
| 3048 |
if (next_tokens_to_concat == processor.tokenizer.eos_token_id).all():
|
| 3049 |
break
|
| 3050 |
-
running_decoder_input_ids = torch.cat(
|
|
|
|
| 3051 |
if num_tokens_generated >= 1024:
|
| 3052 |
break
|
| 3053 |
return running_decoder_input_ids
|
|
@@ -3078,7 +3210,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
| 3078 |
remove_prefix_length = decoder_input_ids.shape[1] - 1
|
| 3079 |
|
| 3080 |
decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
|
| 3081 |
-
|
| 3082 |
return {
|
| 3083 |
"input_ids": None, # encoder_outputs is defined. input_ids not needed
|
| 3084 |
"encoder_outputs": encoder_outputs,
|
|
@@ -3090,105 +3222,146 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
| 3090 |
"head_mask": head_mask,
|
| 3091 |
"decoder_head_mask": decoder_head_mask,
|
| 3092 |
"cross_attn_head_mask": cross_attn_head_mask,
|
| 3093 |
-
|
|
|
|
| 3094 |
}
|
| 3095 |
-
|
| 3096 |
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
|
| 3097 |
return self.language_model.shift_tokens_right(labels)
|
| 3098 |
|
| 3099 |
def _reorder_cache(self, *args, **kwargs):
|
| 3100 |
return self.language_model._reorder_cache(*args, **kwargs)
|
| 3101 |
-
|
| 3102 |
def get_character_character_affinity_matrices(self, decoder_hidden_states, decoder_input_ids, tokenizer, apply_sigmoid=False):
|
| 3103 |
character_character_affinity_matrices = []
|
| 3104 |
for index in range(len(decoder_hidden_states)):
|
| 3105 |
-
character_embedding_indices = (decoder_input_ids[index] == tokenizer.convert_tokens_to_ids(
|
|
|
|
| 3106 |
character_embeddings = decoder_hidden_states[index][character_embedding_indices]
|
| 3107 |
if character_embeddings.shape[0] == 0:
|
| 3108 |
-
character_character_affinity_matrices.append(
|
|
|
|
| 3109 |
continue
|
| 3110 |
-
character_embeddings = self.character_embedding_projection(
|
| 3111 |
-
|
| 3112 |
-
|
|
|
|
|
|
|
|
|
|
| 3113 |
char_ij = rearrange([char_i, char_j], "two i j d -> (i j) (two d)")
|
| 3114 |
-
character_character_affinities = self.character_character_matching_head(
|
| 3115 |
-
|
| 3116 |
-
character_character_affinities = (
|
|
|
|
|
|
|
|
|
|
| 3117 |
if apply_sigmoid:
|
| 3118 |
-
character_character_affinities = torch.sigmoid(
|
| 3119 |
-
|
|
|
|
|
|
|
| 3120 |
return character_character_affinity_matrices
|
| 3121 |
|
| 3122 |
def get_text_character_association_matrices(self, decoder_hidden_states, decoder_input_ids, tokenizer, apply_sigmoid=False):
|
| 3123 |
text_character_association_matrices = []
|
| 3124 |
for index in range(len(decoder_hidden_states)):
|
| 3125 |
-
text_embedding_indices = (decoder_input_ids[index] == tokenizer.convert_tokens_to_ids(
|
|
|
|
| 3126 |
text_embeddings = decoder_hidden_states[index][text_embedding_indices]
|
| 3127 |
-
character_embedding_indices = (decoder_input_ids[index] == tokenizer.convert_tokens_to_ids(
|
|
|
|
| 3128 |
character_embeddings = decoder_hidden_states[index][character_embedding_indices]
|
| 3129 |
if character_embeddings.shape[0] == 0 or text_embeddings.shape[0] == 0:
|
| 3130 |
-
text_character_association_matrices.append(torch.zeros(
|
|
|
|
| 3131 |
continue
|
| 3132 |
-
text_i = repeat(text_embeddings, "i d -> i repeat d",
|
| 3133 |
-
|
| 3134 |
-
|
| 3135 |
-
|
| 3136 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3137 |
if apply_sigmoid:
|
| 3138 |
-
text_character_affinities = torch.sigmoid(
|
| 3139 |
-
|
|
|
|
|
|
|
| 3140 |
return text_character_association_matrices
|
| 3141 |
|
| 3142 |
def get_text_tail_association_matrices(self, decoder_hidden_states, decoder_input_ids, tokenizer, apply_sigmoid=False):
|
| 3143 |
text_tail_association_matrices = []
|
| 3144 |
for index in range(len(decoder_hidden_states)):
|
| 3145 |
-
text_embedding_indices = (decoder_input_ids[index] == tokenizer.convert_tokens_to_ids(
|
|
|
|
| 3146 |
text_embeddings = decoder_hidden_states[index][text_embedding_indices]
|
| 3147 |
-
tail_embedding_indices = (decoder_input_ids[index] == tokenizer.convert_tokens_to_ids(
|
|
|
|
| 3148 |
tail_embeddings = decoder_hidden_states[index][tail_embedding_indices]
|
| 3149 |
if tail_embeddings.shape[0] == 0 or text_embeddings.shape[0] == 0:
|
| 3150 |
-
text_tail_association_matrices.append(torch.zeros(
|
|
|
|
| 3151 |
continue
|
| 3152 |
-
text_i = repeat(text_embeddings, "i d -> i repeat d",
|
| 3153 |
-
|
| 3154 |
-
|
|
|
|
|
|
|
|
|
|
| 3155 |
text_tail_affinities = self.text_tail_matching_head(text_tail_ij)
|
| 3156 |
-
text_tail_affinities = rearrange(
|
|
|
|
| 3157 |
if apply_sigmoid:
|
| 3158 |
text_tail_affinities = torch.sigmoid(text_tail_affinities)
|
| 3159 |
text_tail_association_matrices.append(text_tail_affinities)
|
| 3160 |
return text_tail_association_matrices
|
| 3161 |
-
|
| 3162 |
def get_tail_character_association_matrices(self, decoder_hidden_states, decoder_input_ids, tokenizer, apply_sigmoid=False):
|
| 3163 |
tail_character_association_matrices = []
|
| 3164 |
for index in range(len(decoder_hidden_states)):
|
| 3165 |
-
tail_embedding_indices = (decoder_input_ids[index] == tokenizer.convert_tokens_to_ids(
|
|
|
|
| 3166 |
tail_embeddings = decoder_hidden_states[index][tail_embedding_indices]
|
| 3167 |
-
character_embedding_indices = (decoder_input_ids[index] == tokenizer.convert_tokens_to_ids(
|
|
|
|
| 3168 |
character_embeddings = decoder_hidden_states[index][character_embedding_indices]
|
| 3169 |
if character_embeddings.shape[0] == 0 or tail_embeddings.shape[0] == 0:
|
| 3170 |
-
tail_character_association_matrices.append(torch.zeros(
|
|
|
|
| 3171 |
continue
|
| 3172 |
-
tail_i = repeat(tail_embeddings, "i d -> i repeat d",
|
| 3173 |
-
|
| 3174 |
-
|
| 3175 |
-
|
| 3176 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3177 |
if apply_sigmoid:
|
| 3178 |
-
tail_character_affinities = torch.sigmoid(
|
| 3179 |
-
|
|
|
|
|
|
|
| 3180 |
return tail_character_association_matrices
|
| 3181 |
-
|
| 3182 |
def get_essential_text_logits(self, decoder_hidden_states, decoder_input_ids, tokenizer, apply_sigmoid=False):
|
| 3183 |
essential_text_logits = []
|
| 3184 |
for index in range(len(decoder_hidden_states)):
|
| 3185 |
-
text_embedding_indices = (decoder_input_ids[index] == tokenizer.convert_tokens_to_ids(
|
|
|
|
| 3186 |
text_embeddings = decoder_hidden_states[index][text_embedding_indices]
|
| 3187 |
if text_embeddings.shape[0] == 0:
|
| 3188 |
-
essential_text_logits.append(
|
|
|
|
| 3189 |
continue
|
| 3190 |
-
text_logits = rearrange(
|
|
|
|
| 3191 |
if apply_sigmoid:
|
| 3192 |
text_logits = torch.sigmoid(text_logits)
|
| 3193 |
essential_text_logits.append(text_logits)
|
| 3194 |
-
return essential_text_logits
|
|
|
|
| 23 |
from torch import nn
|
| 24 |
import torch.nn.functional as F
|
| 25 |
import torch.utils.checkpoint as checkpoint
|
| 26 |
+
from torch.nn import CrossEntropyLoss
|
| 27 |
from collections import OrderedDict
|
| 28 |
from einops import rearrange, repeat
|
| 29 |
from timm.models.layers import DropPath, trunc_normal_
|
|
|
|
| 41 |
is_flash_attn_2_available,
|
| 42 |
is_flash_attn_greater_or_equal_2_10,
|
| 43 |
)
|
| 44 |
+
from .configuration_florence2 import Florence2Config
|
| 45 |
from .configuration_florence2 import Florence2LanguageConfig
|
| 46 |
from .configuration_florence2 import Florence2VisionConfig
|
| 47 |
from pytorch_metric_learning.utils.loss_and_miner_utils import get_all_pairs_indices
|
|
|
|
| 72 |
|
| 73 |
_CONFIG_FOR_DOC = "Florence2Config"
|
| 74 |
|
| 75 |
+
|
| 76 |
class LearnedAbsolutePositionEmbedding2D(nn.Module):
|
| 77 |
"""
|
| 78 |
This module learns positional embeddings up to a fixed maximum size.
|
|
|
|
| 81 |
def __init__(self, embedding_dim=256, num_pos=50):
|
| 82 |
super().__init__()
|
| 83 |
self.row_embeddings = nn.Embedding(num_pos, embedding_dim // 2)
|
| 84 |
+
self.column_embeddings = nn.Embedding(
|
| 85 |
+
num_pos, embedding_dim - (embedding_dim // 2))
|
| 86 |
|
| 87 |
def forward(self, pixel_values):
|
| 88 |
"""
|
|
|
|
| 97 |
x_emb = self.column_embeddings(width_values)
|
| 98 |
y_emb = self.row_embeddings(height_values)
|
| 99 |
# (height, width, embedding_dim * 2)
|
| 100 |
+
pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1),
|
| 101 |
+
y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1)
|
| 102 |
# (embedding_dim * 2, height, width)
|
| 103 |
pos = pos.permute(2, 0, 1)
|
| 104 |
pos = pos.unsqueeze(0)
|
|
|
|
| 108 |
pos = pos.permute(0, 2, 3, 1)
|
| 109 |
return pos
|
| 110 |
|
| 111 |
+
|
| 112 |
class PositionalEmbeddingCosine1D(nn.Module):
|
| 113 |
"""
|
| 114 |
This class implements a very simple positional encoding. It follows closely
|
|
|
|
| 120 |
dropout_prob: The dropout probability.
|
| 121 |
max_seq_len: The maximum length to precompute the positional encodings.
|
| 122 |
"""
|
| 123 |
+
|
| 124 |
def __init__(
|
| 125 |
self,
|
| 126 |
embed_dim: int = 512,
|
|
|
|
| 176 |
embed_dim: The dimension of the embeddings.
|
| 177 |
max_seq_len: The maximum length to precompute the positional encodings.
|
| 178 |
"""
|
| 179 |
+
|
| 180 |
def __init__(
|
| 181 |
self,
|
| 182 |
embedding_dim: int = 512,
|
|
|
|
| 202 |
len_seq = seq_embeds.size(-2)
|
| 203 |
assert len_seq <= self.num_pos
|
| 204 |
# [T, D]
|
| 205 |
+
pos_embeds = self.embeddings(
|
| 206 |
+
torch.arange(len_seq).to(seq_embeds.device))
|
| 207 |
# Adapt pre-computed positional embeddings to the input.
|
| 208 |
if shape_len == 3:
|
| 209 |
pos_embeds = pos_embeds.view(
|
|
|
|
| 211 |
return pos_embeds
|
| 212 |
|
| 213 |
|
|
|
|
| 214 |
class MySequential(nn.Sequential):
|
| 215 |
def forward(self, *inputs):
|
| 216 |
for module in self._modules.values():
|
|
|
|
| 355 |
def forward(self, x, size):
|
| 356 |
B, N, C = x.shape
|
| 357 |
|
| 358 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.groups, C //
|
| 359 |
+
self.groups).permute(2, 0, 3, 1, 4)
|
| 360 |
q, k, v = qkv[0], qkv[1], qkv[2]
|
| 361 |
|
| 362 |
q = q * (float(N) ** -0.5)
|
|
|
|
| 375 |
conv_at_attn=True, conv_at_ffn=True):
|
| 376 |
super().__init__()
|
| 377 |
|
| 378 |
+
drop_path = DropPath(
|
| 379 |
+
drop_path_rate) if drop_path_rate > 0. else nn.Identity()
|
| 380 |
|
| 381 |
+
self.conv1 = PreNorm(None, DepthWiseConv2d(
|
| 382 |
+
dim, 3, 1, 1)) if conv_at_attn else None
|
| 383 |
self.channel_attn = PreNorm(
|
| 384 |
norm_layer(dim),
|
| 385 |
ChannelAttention(dim, groups=groups, qkv_bias=qkv_bias),
|
| 386 |
drop_path
|
| 387 |
)
|
| 388 |
+
self.conv2 = PreNorm(None, DepthWiseConv2d(
|
| 389 |
+
dim, 3, 1, 1)) if conv_at_ffn else None
|
| 390 |
self.ffn = PreNorm(
|
| 391 |
norm_layer(dim),
|
| 392 |
+
Mlp(in_features=dim, hidden_features=int(
|
| 393 |
+
dim*mlp_ratio), act_layer=act_layer),
|
| 394 |
drop_path
|
| 395 |
)
|
| 396 |
|
|
|
|
| 408 |
|
| 409 |
def window_partition(x, window_size: int):
|
| 410 |
B, H, W, C = x.shape
|
| 411 |
+
x = x.view(B, H // window_size, window_size,
|
| 412 |
+
W // window_size, window_size, C)
|
| 413 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous(
|
| 414 |
+
).view(-1, window_size, window_size, C)
|
| 415 |
return windows
|
| 416 |
|
| 417 |
|
| 418 |
def window_reverse(windows, batch_size: int, window_size: int, H: int, W: int):
|
| 419 |
+
B = batch_size
|
| 420 |
# this will cause onnx conversion failed for dynamic axis, because treated as constant
|
| 421 |
+
# int(windows.shape[0] / (H * W / window_size / window_size))
|
| 422 |
+
x = windows.view(B, H // window_size, W // window_size,
|
| 423 |
+
window_size, window_size, -1)
|
| 424 |
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
| 425 |
return x
|
| 426 |
|
|
|
|
| 461 |
# attn_windows = self.attn(x_windows)
|
| 462 |
|
| 463 |
B_, N, C = x.shape
|
| 464 |
+
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C //
|
| 465 |
+
self.num_heads).permute(2, 0, 3, 1, 4)
|
| 466 |
q, k, v = qkv[0], qkv[1], qkv[2]
|
| 467 |
|
| 468 |
q = q * self.scale
|
|
|
|
| 493 |
norm_layer=nn.LayerNorm, conv_at_attn=True, conv_at_ffn=True):
|
| 494 |
super().__init__()
|
| 495 |
|
| 496 |
+
drop_path = DropPath(
|
| 497 |
+
drop_path_rate) if drop_path_rate > 0. else nn.Identity()
|
| 498 |
|
| 499 |
+
self.conv1 = PreNorm(None, DepthWiseConv2d(
|
| 500 |
+
dim, 3, 1, 1)) if conv_at_attn else None
|
| 501 |
self.window_attn = PreNorm(
|
| 502 |
norm_layer(dim),
|
| 503 |
WindowAttention(dim, num_heads, window_size, qkv_bias=qkv_bias),
|
| 504 |
drop_path
|
| 505 |
)
|
| 506 |
+
self.conv2 = PreNorm(None, DepthWiseConv2d(
|
| 507 |
+
dim, 3, 1, 1)) if conv_at_ffn else None
|
| 508 |
self.ffn = PreNorm(
|
| 509 |
norm_layer(dim),
|
| 510 |
+
Mlp(in_features=dim, hidden_features=int(
|
| 511 |
+
dim*mlp_ratio), act_layer=act_layer),
|
| 512 |
drop_path
|
| 513 |
)
|
| 514 |
|
|
|
|
| 566 |
enable_checkpoint=True,
|
| 567 |
conv_at_attn=True,
|
| 568 |
conv_at_ffn=True,
|
| 569 |
+
):
|
| 570 |
super().__init__()
|
| 571 |
|
| 572 |
self.num_classes = num_classes
|
|
|
|
| 578 |
assert self.num_stages == len(self.num_heads) == len(self.num_groups)
|
| 579 |
|
| 580 |
num_stages = len(embed_dims)
|
| 581 |
+
dpr = [x.item() for x in torch.linspace(
|
| 582 |
+
0, drop_path_rate, sum(depths)*2)]
|
| 583 |
|
| 584 |
depth_offset = 0
|
| 585 |
convs = []
|
|
|
|
| 633 |
|
| 634 |
self.norms = norm_layer(self.embed_dims[-1])
|
| 635 |
self.avgpool = nn.AdaptiveAvgPool1d(1)
|
| 636 |
+
self.head = nn.Linear(
|
| 637 |
+
self.embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
|
| 638 |
|
| 639 |
self.apply(self._init_weights)
|
| 640 |
|
|
|
|
| 669 |
for conv, block in zip(self.convs, self.blocks):
|
| 670 |
x, input_size = conv(x, input_size)
|
| 671 |
if self.enable_checkpoint:
|
| 672 |
+
x, input_size = checkpoint.checkpoint(
|
| 673 |
+
block, x, input_size, use_reentrant=True)
|
| 674 |
else:
|
| 675 |
x, input_size = block(x, input_size)
|
| 676 |
return x
|
|
|
|
| 690 |
x = self.forward_features(x)
|
| 691 |
x = self.head(x)
|
| 692 |
return x
|
| 693 |
+
|
| 694 |
@classmethod
|
| 695 |
def from_config(cls, config):
|
| 696 |
return cls(
|
|
|
|
| 707 |
)
|
| 708 |
|
| 709 |
|
|
|
|
|
|
|
| 710 |
if is_flash_attn_2_available():
|
| 711 |
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
| 712 |
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
| 713 |
|
| 714 |
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
| 715 |
+
|
| 716 |
+
|
| 717 |
def _get_unpad_data(attention_mask):
|
| 718 |
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
| 719 |
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
| 720 |
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
| 721 |
+
cu_seqlens = F.pad(torch.cumsum(
|
| 722 |
+
seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
| 723 |
return (
|
| 724 |
indices,
|
| 725 |
cu_seqlens,
|
|
|
|
| 857 |
if past_key_value[0] is not None:
|
| 858 |
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
| 859 |
if past_key_value[1] is not None:
|
| 860 |
+
value_states = torch.cat(
|
| 861 |
+
[past_key_value[1], value_states], dim=2)
|
| 862 |
else:
|
| 863 |
# self_attention
|
| 864 |
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
|
|
|
| 875 |
past_key_value = (key_states, value_states)
|
| 876 |
|
| 877 |
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
| 878 |
+
query_states = self._shape(
|
| 879 |
+
query_states, tgt_len, bsz).view(*proj_shape)
|
| 880 |
key_states = key_states.reshape(*proj_shape)
|
| 881 |
value_states = value_states.reshape(*proj_shape)
|
| 882 |
|
|
|
|
| 894 |
raise ValueError(
|
| 895 |
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
| 896 |
)
|
| 897 |
+
attn_weights = attn_weights.view(
|
| 898 |
+
bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
| 899 |
+
attn_weights = attn_weights.view(
|
| 900 |
+
bsz * self.num_heads, tgt_len, src_len)
|
| 901 |
|
| 902 |
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
| 903 |
|
|
|
|
| 907 |
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
|
| 908 |
f" {layer_head_mask.size()}"
|
| 909 |
)
|
| 910 |
+
attn_weights = layer_head_mask.view(
|
| 911 |
+
1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
| 912 |
+
attn_weights = attn_weights.view(
|
| 913 |
+
bsz * self.num_heads, tgt_len, src_len)
|
| 914 |
|
| 915 |
if output_attentions:
|
| 916 |
# this operation is a bit awkward, but it's required to
|
| 917 |
# make sure that attn_weights keeps its gradient.
|
| 918 |
# In order to do so, attn_weights have to be reshaped
|
| 919 |
# twice and have to be reused in the following
|
| 920 |
+
attn_weights_reshaped = attn_weights.view(
|
| 921 |
+
bsz, self.num_heads, tgt_len, src_len)
|
| 922 |
+
attn_weights = attn_weights_reshaped.view(
|
| 923 |
+
bsz * self.num_heads, tgt_len, src_len)
|
| 924 |
else:
|
| 925 |
attn_weights_reshaped = None
|
| 926 |
|
| 927 |
+
attn_probs = nn.functional.dropout(
|
| 928 |
+
attn_weights, p=self.dropout, training=self.training)
|
| 929 |
|
| 930 |
attn_output = torch.bmm(attn_probs, value_states)
|
| 931 |
|
|
|
|
| 935 |
f" {attn_output.size()}"
|
| 936 |
)
|
| 937 |
|
| 938 |
+
attn_output = attn_output.view(
|
| 939 |
+
bsz, self.num_heads, tgt_len, self.head_dim)
|
| 940 |
attn_output = attn_output.transpose(1, 2)
|
| 941 |
|
| 942 |
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
|
|
|
|
| 978 |
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 979 |
# Florence2FlashAttention2 attention does not support output_attentions
|
| 980 |
if output_attentions:
|
| 981 |
+
raise ValueError(
|
| 982 |
+
"Florence2FlashAttention2 attention does not support output_attentions")
|
| 983 |
|
| 984 |
# if key_value_states are provided this layer is used as a cross-attention layer
|
| 985 |
# for the decoder
|
|
|
|
| 1004 |
elif is_cross_attention:
|
| 1005 |
# cross_attentions
|
| 1006 |
key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
|
| 1007 |
+
value_states = self._reshape(
|
| 1008 |
+
self.v_proj(key_value_states), -1, bsz)
|
| 1009 |
elif past_key_value is not None:
|
| 1010 |
# reuse k, v, self_attention
|
| 1011 |
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
|
| 1012 |
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
|
| 1013 |
+
key_states = torch.cat(
|
| 1014 |
+
[past_key_value[0].transpose(1, 2), key_states], dim=1)
|
| 1015 |
+
value_states = torch.cat(
|
| 1016 |
+
[past_key_value[1].transpose(1, 2), value_states], dim=1)
|
| 1017 |
else:
|
| 1018 |
# self_attention
|
| 1019 |
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
|
|
|
|
| 1027 |
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
| 1028 |
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
| 1029 |
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
| 1030 |
+
past_key_value = (key_states.transpose(
|
| 1031 |
+
1, 2), value_states.transpose(1, 2))
|
| 1032 |
|
| 1033 |
kv_seq_len = key_states.shape[-2]
|
| 1034 |
if past_key_value is not None:
|
|
|
|
| 1124 |
causal=causal,
|
| 1125 |
)
|
| 1126 |
|
| 1127 |
+
attn_output = pad_input(
|
| 1128 |
+
attn_output_unpad, indices_q, batch_size, query_length)
|
| 1129 |
else:
|
| 1130 |
attn_output = flash_attn_func(
|
| 1131 |
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
|
|
|
|
| 1135 |
|
| 1136 |
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
|
| 1137 |
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
|
| 1138 |
+
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
|
| 1139 |
+
attention_mask)
|
| 1140 |
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
|
| 1141 |
|
| 1142 |
key_layer = index_first_axis(
|
| 1143 |
+
key_layer.reshape(batch_size * kv_seq_len,
|
| 1144 |
+
num_key_value_heads, head_dim), indices_k
|
| 1145 |
)
|
| 1146 |
value_layer = index_first_axis(
|
| 1147 |
+
value_layer.reshape(batch_size * kv_seq_len,
|
| 1148 |
+
num_key_value_heads, head_dim), indices_k
|
| 1149 |
)
|
| 1150 |
if query_length == kv_seq_len:
|
| 1151 |
query_layer = index_first_axis(
|
| 1152 |
+
query_layer.reshape(batch_size * kv_seq_len,
|
| 1153 |
+
self.num_heads, head_dim), indices_k
|
| 1154 |
)
|
| 1155 |
cu_seqlens_q = cu_seqlens_k
|
| 1156 |
max_seqlen_in_batch_q = max_seqlen_in_batch_k
|
|
|
|
| 1165 |
else:
|
| 1166 |
# The -q_len: slice assumes left padding.
|
| 1167 |
attention_mask = attention_mask[:, -query_length:]
|
| 1168 |
+
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
|
| 1169 |
+
query_layer, attention_mask)
|
| 1170 |
|
| 1171 |
return (
|
| 1172 |
query_layer,
|
|
|
|
| 1236 |
if past_key_value[0] is not None:
|
| 1237 |
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
| 1238 |
if past_key_value[1] is not None:
|
| 1239 |
+
value_states = torch.cat(
|
| 1240 |
+
[past_key_value[1], value_states], dim=2)
|
| 1241 |
else:
|
| 1242 |
# self_attention
|
| 1243 |
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
|
|
|
| 1339 |
layer_head_mask=layer_head_mask,
|
| 1340 |
output_attentions=output_attentions,
|
| 1341 |
)
|
| 1342 |
+
hidden_states = nn.functional.dropout(
|
| 1343 |
+
hidden_states, p=self.dropout, training=self.training)
|
| 1344 |
hidden_states = residual + hidden_states
|
| 1345 |
hidden_states = self.self_attn_layer_norm(hidden_states)
|
| 1346 |
|
| 1347 |
residual = hidden_states
|
| 1348 |
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
| 1349 |
+
hidden_states = nn.functional.dropout(
|
| 1350 |
+
hidden_states, p=self.activation_dropout, training=self.training)
|
| 1351 |
hidden_states = self.fc2(hidden_states)
|
| 1352 |
+
hidden_states = nn.functional.dropout(
|
| 1353 |
+
hidden_states, p=self.dropout, training=self.training)
|
| 1354 |
hidden_states = residual + hidden_states
|
| 1355 |
hidden_states = self.final_layer_norm(hidden_states)
|
| 1356 |
|
| 1357 |
if hidden_states.dtype == torch.float16 and (
|
| 1358 |
+
torch.isinf(hidden_states).any() or torch.isnan(
|
| 1359 |
+
hidden_states).any()
|
| 1360 |
):
|
| 1361 |
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
|
| 1362 |
+
hidden_states = torch.clamp(
|
| 1363 |
+
hidden_states, min=-clamp_value, max=clamp_value)
|
| 1364 |
|
| 1365 |
outputs = (hidden_states,)
|
| 1366 |
|
|
|
|
| 1434 |
|
| 1435 |
# Self Attention
|
| 1436 |
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
| 1437 |
+
self_attn_past_key_value = past_key_value[:
|
| 1438 |
+
2] if past_key_value is not None else None
|
| 1439 |
# add present self-attn cache to positions 1,2 of present_key_value tuple
|
| 1440 |
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
| 1441 |
hidden_states=hidden_states,
|
|
|
|
| 1444 |
layer_head_mask=layer_head_mask,
|
| 1445 |
output_attentions=output_attentions,
|
| 1446 |
)
|
| 1447 |
+
hidden_states = nn.functional.dropout(
|
| 1448 |
+
hidden_states, p=self.dropout, training=self.training)
|
| 1449 |
hidden_states = residual + hidden_states
|
| 1450 |
hidden_states = self.self_attn_layer_norm(hidden_states)
|
| 1451 |
|
|
|
|
| 1456 |
residual = hidden_states
|
| 1457 |
|
| 1458 |
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
|
| 1459 |
+
cross_attn_past_key_value = past_key_value[-2:
|
| 1460 |
+
] if past_key_value is not None else None
|
| 1461 |
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
|
| 1462 |
hidden_states=hidden_states,
|
| 1463 |
key_value_states=encoder_hidden_states,
|
|
|
|
| 1466 |
past_key_value=cross_attn_past_key_value,
|
| 1467 |
output_attentions=output_attentions,
|
| 1468 |
)
|
| 1469 |
+
hidden_states = nn.functional.dropout(
|
| 1470 |
+
hidden_states, p=self.dropout, training=self.training)
|
| 1471 |
hidden_states = residual + hidden_states
|
| 1472 |
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
| 1473 |
|
|
|
|
| 1477 |
# Fully Connected
|
| 1478 |
residual = hidden_states
|
| 1479 |
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
| 1480 |
+
hidden_states = nn.functional.dropout(
|
| 1481 |
+
hidden_states, p=self.activation_dropout, training=self.training)
|
| 1482 |
hidden_states = self.fc2(hidden_states)
|
| 1483 |
+
hidden_states = nn.functional.dropout(
|
| 1484 |
+
hidden_states, p=self.dropout, training=self.training)
|
| 1485 |
hidden_states = residual + hidden_states
|
| 1486 |
hidden_states = self.final_layer_norm(hidden_states)
|
| 1487 |
|
|
|
|
| 1496 |
return outputs
|
| 1497 |
|
| 1498 |
|
|
|
|
| 1499 |
class Florence2LanguagePreTrainedModel(PreTrainedModel):
|
| 1500 |
config_class = Florence2LanguageConfig
|
| 1501 |
base_model_prefix = "model"
|
|
|
|
| 1520 |
@property
|
| 1521 |
def dummy_inputs(self):
|
| 1522 |
pad_token = self.config.pad_token_id
|
| 1523 |
+
input_ids = torch.tensor(
|
| 1524 |
+
[[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
|
| 1525 |
dummy_inputs = {
|
| 1526 |
"attention_mask": input_ids.ne(pad_token),
|
| 1527 |
"input_ids": input_ids,
|
|
|
|
| 1561 |
config.max_position_embeddings,
|
| 1562 |
embed_dim,
|
| 1563 |
)
|
| 1564 |
+
self.layers = nn.ModuleList([Florence2EncoderLayer(
|
| 1565 |
+
config) for _ in range(config.encoder_layers)])
|
| 1566 |
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
| 1567 |
self._use_sdpa = config._attn_implementation == "sdpa"
|
| 1568 |
self.layernorm_embedding = nn.LayerNorm(embed_dim)
|
|
|
|
| 1631 |
|
| 1632 |
# retrieve input_ids and inputs_embeds
|
| 1633 |
if input_ids is not None and inputs_embeds is not None:
|
| 1634 |
+
raise ValueError(
|
| 1635 |
+
"You cannot specify both input_ids and inputs_embeds at the same time")
|
| 1636 |
elif input_ids is not None:
|
| 1637 |
input = input_ids
|
| 1638 |
input_ids = input_ids.view(-1, input_ids.shape[-1])
|
| 1639 |
elif inputs_embeds is not None:
|
| 1640 |
input = inputs_embeds[:, :, -1]
|
| 1641 |
else:
|
| 1642 |
+
raise ValueError(
|
| 1643 |
+
"You have to specify either input_ids or inputs_embeds")
|
| 1644 |
|
| 1645 |
if inputs_embeds is None:
|
| 1646 |
inputs_embeds = self.embed_tokens(input_ids)
|
|
|
|
| 1650 |
|
| 1651 |
hidden_states = inputs_embeds + embed_pos
|
| 1652 |
hidden_states = self.layernorm_embedding(hidden_states)
|
| 1653 |
+
hidden_states = nn.functional.dropout(
|
| 1654 |
+
hidden_states, p=self.dropout, training=self.training)
|
| 1655 |
|
| 1656 |
# expand attention_mask
|
| 1657 |
if attention_mask is not None:
|
|
|
|
| 1661 |
# output_attentions=True & head_mask can not be supported when using SDPA, fall back to
|
| 1662 |
# the manual implementation that requires a 4D causal mask in all cases.
|
| 1663 |
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
| 1664 |
+
attention_mask = _prepare_4d_attention_mask_for_sdpa(
|
| 1665 |
+
attention_mask, inputs_embeds.dtype)
|
| 1666 |
else:
|
| 1667 |
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
| 1668 |
+
attention_mask = _prepare_4d_attention_mask(
|
| 1669 |
+
attention_mask, inputs_embeds.dtype)
|
| 1670 |
|
| 1671 |
encoder_states = () if output_hidden_states else None
|
| 1672 |
all_attentions = () if output_attentions else None
|
|
|
|
| 1704 |
layer_outputs = encoder_layer(
|
| 1705 |
hidden_states,
|
| 1706 |
attention_mask,
|
| 1707 |
+
layer_head_mask=(
|
| 1708 |
+
head_mask[idx] if head_mask is not None else None),
|
| 1709 |
output_attentions=output_attentions,
|
| 1710 |
)
|
| 1711 |
|
|
|
|
| 1739 |
self.layerdrop = config.decoder_layerdrop
|
| 1740 |
self.padding_idx = config.pad_token_id
|
| 1741 |
self.max_target_positions = config.max_position_embeddings
|
| 1742 |
+
embed_scale = math.sqrt(
|
| 1743 |
+
config.d_model) if config.scale_embedding else 1.0
|
| 1744 |
|
| 1745 |
self.embed_tokens = Florence2ScaledWordEmbedding(
|
| 1746 |
config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
|
|
|
|
| 1753 |
config.max_position_embeddings,
|
| 1754 |
config.d_model,
|
| 1755 |
)
|
| 1756 |
+
self.layers = nn.ModuleList([Florence2DecoderLayer(
|
| 1757 |
+
config) for _ in range(config.decoder_layers)])
|
| 1758 |
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
| 1759 |
self._use_sdpa = config._attn_implementation == "sdpa"
|
| 1760 |
|
|
|
|
| 1859 |
|
| 1860 |
# retrieve input_ids and inputs_embeds
|
| 1861 |
if input_ids is not None and inputs_embeds is not None:
|
| 1862 |
+
raise ValueError(
|
| 1863 |
+
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
| 1864 |
elif input_ids is not None:
|
| 1865 |
input = input_ids
|
| 1866 |
input_shape = input.shape
|
|
|
|
| 1869 |
input_shape = inputs_embeds.size()[:-1]
|
| 1870 |
input = inputs_embeds[:, :, -1]
|
| 1871 |
else:
|
| 1872 |
+
raise ValueError(
|
| 1873 |
+
"You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
| 1874 |
|
| 1875 |
# past_key_values_length
|
| 1876 |
+
past_key_values_length = past_key_values[0][0].shape[
|
| 1877 |
+
2] if past_key_values and past_key_values[0] and past_key_values[0][0] is not None else 0
|
| 1878 |
|
| 1879 |
if inputs_embeds is None:
|
| 1880 |
inputs_embeds = self.embed_tokens(input)
|
| 1881 |
|
| 1882 |
if self._use_flash_attention_2:
|
| 1883 |
# 2d mask is passed through the layers
|
| 1884 |
+
attention_mask = attention_mask if (
|
| 1885 |
+
attention_mask is not None and 0 in attention_mask) else None
|
| 1886 |
elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None:
|
| 1887 |
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
|
| 1888 |
# the manual implementation that requires a 4D causal mask in all cases.
|
|
|
|
| 1924 |
hidden_states = inputs_embeds + positions
|
| 1925 |
hidden_states = self.layernorm_embedding(hidden_states)
|
| 1926 |
|
| 1927 |
+
hidden_states = nn.functional.dropout(
|
| 1928 |
+
hidden_states, p=self.dropout, training=self.training)
|
| 1929 |
|
| 1930 |
if self.gradient_checkpointing and self.training:
|
| 1931 |
if use_cache:
|
|
|
|
| 1937 |
# decoder layers
|
| 1938 |
all_hidden_states = () if output_hidden_states else None
|
| 1939 |
all_self_attns = () if output_attentions else None
|
| 1940 |
+
all_cross_attentions = () if (
|
| 1941 |
+
output_attentions and encoder_hidden_states is not None) else None
|
| 1942 |
next_decoder_cache = () if use_cache else None
|
| 1943 |
|
| 1944 |
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
|
|
|
|
| 1980 |
attention_mask=attention_mask,
|
| 1981 |
encoder_hidden_states=encoder_hidden_states,
|
| 1982 |
encoder_attention_mask=encoder_attention_mask,
|
| 1983 |
+
layer_head_mask=(
|
| 1984 |
+
head_mask[idx] if head_mask is not None else None),
|
| 1985 |
cross_attn_layer_head_mask=(
|
| 1986 |
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
|
| 1987 |
),
|
|
|
|
| 1992 |
hidden_states = layer_outputs[0]
|
| 1993 |
|
| 1994 |
if use_cache:
|
| 1995 |
+
next_decoder_cache += (
|
| 1996 |
+
layer_outputs[3 if output_attentions else 1],)
|
| 1997 |
|
| 1998 |
if output_attentions:
|
| 1999 |
all_self_attns += (layer_outputs[1],)
|
|
|
|
| 2022 |
|
| 2023 |
|
| 2024 |
class Florence2LanguageModel(Florence2LanguagePreTrainedModel):
|
| 2025 |
+
_tied_weights_keys = ["encoder.embed_tokens.weight",
|
| 2026 |
+
"decoder.embed_tokens.weight"]
|
| 2027 |
|
| 2028 |
def __init__(self, config: Florence2LanguageConfig):
|
| 2029 |
super().__init__(config)
|
|
|
|
| 2109 |
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
|
| 2110 |
encoder_outputs = BaseModelOutput(
|
| 2111 |
last_hidden_state=encoder_outputs[0],
|
| 2112 |
+
hidden_states=encoder_outputs[1] if len(
|
| 2113 |
+
encoder_outputs) > 1 else None,
|
| 2114 |
+
attentions=encoder_outputs[2] if len(
|
| 2115 |
+
encoder_outputs) > 2 else None,
|
| 2116 |
)
|
| 2117 |
|
| 2118 |
# decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
|
|
|
|
| 2148 |
|
| 2149 |
class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel, GenerationMixin):
|
| 2150 |
base_model_prefix = "model"
|
| 2151 |
+
_tied_weights_keys = ["encoder.embed_tokens.weight",
|
| 2152 |
+
"decoder.embed_tokens.weight", "lm_head.weight"]
|
| 2153 |
_keys_to_ignore_on_load_missing = ["final_logits_bias"]
|
| 2154 |
|
| 2155 |
def __init__(self, config: Florence2LanguageConfig):
|
| 2156 |
super().__init__(config)
|
| 2157 |
self.model = Florence2LanguageModel(config)
|
| 2158 |
+
self.register_buffer("final_logits_bias", torch.zeros(
|
| 2159 |
+
(1, self.model.shared.num_embeddings)))
|
| 2160 |
+
self.lm_head = nn.Linear(
|
| 2161 |
+
config.d_model, self.model.shared.num_embeddings, bias=False)
|
| 2162 |
|
| 2163 |
# Initialize weights and apply final processing
|
| 2164 |
self.post_init()
|
|
|
|
| 2170 |
return self.model.get_decoder()
|
| 2171 |
|
| 2172 |
def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
|
| 2173 |
+
new_embeddings = super().resize_token_embeddings(
|
| 2174 |
+
new_num_tokens, pad_to_multiple_of)
|
| 2175 |
self._resize_final_logits_bias(new_embeddings.weight.shape[0])
|
| 2176 |
return new_embeddings
|
| 2177 |
|
|
|
|
| 2180 |
if new_num_tokens <= old_num_tokens:
|
| 2181 |
new_bias = self.final_logits_bias[:, :new_num_tokens]
|
| 2182 |
else:
|
| 2183 |
+
extra_bias = torch.zeros(
|
| 2184 |
+
(1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
|
| 2185 |
new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
|
| 2186 |
self.register_buffer("final_logits_bias", new_bias)
|
| 2187 |
|
|
|
|
| 2222 |
|
| 2223 |
if labels is not None:
|
| 2224 |
if use_cache:
|
| 2225 |
+
logger.warning(
|
| 2226 |
+
"The `use_cache` argument is changed to `False` since `labels` is provided.")
|
| 2227 |
use_cache = False
|
| 2228 |
if decoder_input_ids is None and decoder_inputs_embeds is None:
|
| 2229 |
decoder_input_ids = shift_tokens_right(
|
|
|
|
| 2255 |
if labels is not None:
|
| 2256 |
labels = labels.to(lm_logits.device)
|
| 2257 |
loss_fct = CrossEntropyLoss()
|
| 2258 |
+
masked_lm_loss = loss_fct(
|
| 2259 |
+
lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
|
| 2260 |
|
| 2261 |
if not return_dict:
|
| 2262 |
output = (lm_logits,) + outputs[1:]
|
|
|
|
| 2310 |
"head_mask": head_mask,
|
| 2311 |
"decoder_head_mask": decoder_head_mask,
|
| 2312 |
"cross_attn_head_mask": cross_attn_head_mask,
|
| 2313 |
+
# change this to avoid caching (presumably for debugging)
|
| 2314 |
+
"use_cache": use_cache,
|
| 2315 |
}
|
| 2316 |
|
| 2317 |
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
|
|
|
|
| 2323 |
for layer_past in past_key_values:
|
| 2324 |
# cached cross_attention states don't have to be reordered -> they are always the same
|
| 2325 |
reordered_past += (
|
| 2326 |
+
tuple(past_state.index_select(0, beam_idx.to(past_state.device))
|
| 2327 |
+
for past_state in layer_past[:2])
|
| 2328 |
+ layer_past[2:],
|
| 2329 |
)
|
| 2330 |
return reordered_past
|
| 2331 |
|
| 2332 |
+
|
| 2333 |
@dataclass
|
| 2334 |
class Florence2Seq2SeqLMOutput(ModelOutput):
|
| 2335 |
"""
|
|
|
|
| 2515 |
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 2516 |
"""
|
| 2517 |
|
| 2518 |
+
|
| 2519 |
@add_start_docstrings(
|
| 2520 |
"""The FLORENCE2 vision model without any head""",
|
| 2521 |
FLORENCE2_START_DOCSTRING,
|
|
|
|
| 2523 |
class Florence2VisionModel(Florence2PreTrainedModel):
|
| 2524 |
def __init__(self, config: Florence2VisionConfig):
|
| 2525 |
super().__init__(config)
|
| 2526 |
+
assert config.model_type in [
|
| 2527 |
+
'davit', ""], 'only DaViT is supported for now'
|
| 2528 |
self.vision_tower = DaViT.from_config(config=config)
|
| 2529 |
|
| 2530 |
self.post_init()
|
| 2531 |
+
|
| 2532 |
def forward(self, pixel_values):
|
| 2533 |
if len(pixel_values.shape) == 4:
|
| 2534 |
x = self.vision_tower.forward_features_unpool(pixel_values)
|
|
|
|
| 2544 |
class Florence2VisionModelWithProjection(Florence2PreTrainedModel):
|
| 2545 |
def __init__(self, config: Florence2VisionConfig):
|
| 2546 |
super().__init__(config)
|
| 2547 |
+
assert config.model_type in [
|
| 2548 |
+
'davit', ''], 'only DaViT is supported for now'
|
| 2549 |
self.vision_tower = DaViT.from_config(config=config)
|
| 2550 |
|
| 2551 |
self._build_image_projection_layers(config)
|
| 2552 |
|
| 2553 |
self.post_init()
|
| 2554 |
+
|
| 2555 |
def _build_image_projection_layers(self, config):
|
| 2556 |
image_dim_out = config.dim_embed[-1]
|
| 2557 |
dim_projection = config.projection_dim
|
|
|
|
| 2587 |
x = self.vision_tower.forward_features_unpool(pixel_values)
|
| 2588 |
else:
|
| 2589 |
raise ValueError(f'invalid image shape {pixel_values.shape}')
|
| 2590 |
+
|
| 2591 |
if self.image_pos_embed is not None:
|
| 2592 |
x = x.view(batch_size * T, -1, x.shape[-1])
|
| 2593 |
num_tokens = x.shape[-2]
|
|
|
|
| 2599 |
x = x.view(batch_size, T * h*w, x.shape[-1])
|
| 2600 |
|
| 2601 |
if self.visual_temporal_embed is not None:
|
| 2602 |
+
visual_temporal_embed = self.visual_temporal_embed(
|
| 2603 |
+
x.view(batch_size, T, -1, x.shape[-1])[:, :, 0])
|
| 2604 |
+
x = x.view(
|
| 2605 |
+
batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(1, T, 1, x.shape[-1])
|
| 2606 |
|
| 2607 |
x_feat_dict = {}
|
| 2608 |
|
| 2609 |
spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2)
|
| 2610 |
x_feat_dict['spatial_avg_pool'] = spatial_avg_pool_x
|
| 2611 |
|
| 2612 |
+
temporal_avg_pool_x = x.view(
|
| 2613 |
+
batch_size, T, -1, x.shape[-1]).mean(dim=1)
|
| 2614 |
x_feat_dict['temporal_avg_pool'] = temporal_avg_pool_x
|
| 2615 |
|
| 2616 |
x = x.view(batch_size, T, -1, x.shape[-1])[:, -1]
|
|
|
|
| 2619 |
new_x = []
|
| 2620 |
for _image_feature_source in self.image_feature_source:
|
| 2621 |
if _image_feature_source not in x_feat_dict:
|
| 2622 |
+
raise ValueError(
|
| 2623 |
+
'invalid image feature source: {}'.format(_image_feature_source))
|
| 2624 |
new_x.append(x_feat_dict[_image_feature_source])
|
| 2625 |
|
| 2626 |
x = torch.cat(new_x, dim=1)
|
|
|
|
| 2628 |
x = x @ self.image_projection
|
| 2629 |
x = self.image_proj_norm(x)
|
| 2630 |
|
|
|
|
| 2631 |
return x
|
| 2632 |
|
| 2633 |
|
|
|
|
| 2634 |
@add_start_docstrings(
|
| 2635 |
"""The FLORENCE2 model which consists of a vision backbone and a language model.""",
|
| 2636 |
FLORENCE2_START_DOCSTRING,
|
|
|
|
| 2638 |
class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
| 2639 |
def __init__(self, config: Florence2Config):
|
| 2640 |
super().__init__(config)
|
| 2641 |
+
assert config.vision_config.model_type in [
|
| 2642 |
+
'davit', ''], 'only DaViT is supported for now'
|
| 2643 |
self.vision_tower = DaViT.from_config(config=config.vision_config)
|
| 2644 |
+
# remove unused layers
|
| 2645 |
del self.vision_tower.head
|
| 2646 |
del self.vision_tower.norms
|
| 2647 |
|
|
|
|
| 2649 |
self._attn_implementation = config._attn_implementation
|
| 2650 |
self._build_image_projection_layers(config)
|
| 2651 |
|
| 2652 |
+
language_model = Florence2LanguageForConditionalGeneration(
|
| 2653 |
+
config=config.text_config)
|
| 2654 |
|
| 2655 |
if language_model._tied_weights_keys is not None:
|
| 2656 |
+
self._tied_weights_keys = [
|
| 2657 |
+
f"language_model.{k}" for k in language_model._tied_weights_keys]
|
| 2658 |
self.language_model = language_model
|
| 2659 |
self.character_character_matching_head = nn.Sequential(
|
| 2660 |
nn.Linear(2 * 768, config.projection_dim),
|
|
|
|
| 2678 |
nn.Linear(config.projection_dim, 1)
|
| 2679 |
)
|
| 2680 |
self.text_classification_head = nn.Linear(config.projection_dim, 1)
|
| 2681 |
+
self.character_embedding_projection = nn.Linear(
|
| 2682 |
+
config.projection_dim, 768)
|
| 2683 |
|
| 2684 |
self._init_weights(self.character_character_matching_head)
|
| 2685 |
self._init_weights(self.text_character_matching_head)
|
| 2686 |
self._init_weights(self.text_tail_matching_head)
|
| 2687 |
self._init_weights(self.text_classification_head)
|
| 2688 |
+
|
| 2689 |
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
|
| 2690 |
self.post_init()
|
| 2691 |
+
|
| 2692 |
def _init_weights(self, m):
|
| 2693 |
if isinstance(m, nn.Linear):
|
| 2694 |
trunc_normal_(m.weight, std=0.02)
|
|
|
|
| 2708 |
elif isinstance(m, nn.Sequential):
|
| 2709 |
for layer in m:
|
| 2710 |
self._init_weights(layer)
|
| 2711 |
+
|
| 2712 |
def _build_image_projection_layers(self, config):
|
| 2713 |
image_dim_out = config.vision_config.dim_embed[-1]
|
| 2714 |
dim_projection = config.vision_config.projection_dim
|
|
|
|
| 2747 |
return self.language_model.get_input_embeddings()
|
| 2748 |
|
| 2749 |
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
|
| 2750 |
+
model_embeds = self.language_model.resize_token_embeddings(
|
| 2751 |
+
new_num_tokens, pad_to_multiple_of)
|
| 2752 |
# update vocab size
|
| 2753 |
self.config.text_config.vocab_size = model_embeds.num_embeddings
|
| 2754 |
self.config.vocab_size = model_embeds.num_embeddings
|
| 2755 |
self.vocab_size = model_embeds.num_embeddings
|
| 2756 |
return model_embeds
|
| 2757 |
+
|
| 2758 |
def _encode_image(self, pixel_values):
|
| 2759 |
if len(pixel_values.shape) == 4:
|
| 2760 |
batch_size, C, H, W = pixel_values.shape
|
|
|
|
| 2762 |
x = self.vision_tower.forward_features_unpool(pixel_values)
|
| 2763 |
else:
|
| 2764 |
raise ValueError(f'invalid image shape {pixel_values.shape}')
|
| 2765 |
+
|
| 2766 |
if self.image_pos_embed is not None:
|
| 2767 |
x = x.view(batch_size * T, -1, x.shape[-1])
|
| 2768 |
num_tokens = x.shape[-2]
|
|
|
|
| 2774 |
x = x.view(batch_size, T * h*w, x.shape[-1])
|
| 2775 |
|
| 2776 |
if self.visual_temporal_embed is not None:
|
| 2777 |
+
visual_temporal_embed = self.visual_temporal_embed(
|
| 2778 |
+
x.view(batch_size, T, -1, x.shape[-1])[:, :, 0])
|
| 2779 |
+
x = x.view(
|
| 2780 |
+
batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(1, T, 1, x.shape[-1])
|
| 2781 |
|
| 2782 |
x_feat_dict = {}
|
| 2783 |
|
| 2784 |
spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2)
|
| 2785 |
x_feat_dict['spatial_avg_pool'] = spatial_avg_pool_x
|
| 2786 |
|
| 2787 |
+
temporal_avg_pool_x = x.view(
|
| 2788 |
+
batch_size, T, -1, x.shape[-1]).mean(dim=1)
|
| 2789 |
x_feat_dict['temporal_avg_pool'] = temporal_avg_pool_x
|
| 2790 |
|
| 2791 |
x = x.view(batch_size, T, -1, x.shape[-1])[:, -1]
|
|
|
|
| 2794 |
new_x = []
|
| 2795 |
for _image_feature_source in self.image_feature_source:
|
| 2796 |
if _image_feature_source not in x_feat_dict:
|
| 2797 |
+
raise ValueError(
|
| 2798 |
+
'invalid image feature source: {}'.format(_image_feature_source))
|
| 2799 |
new_x.append(x_feat_dict[_image_feature_source])
|
| 2800 |
|
| 2801 |
x = torch.cat(new_x, dim=1)
|
|
|
|
| 2803 |
x = x @ self.image_projection
|
| 2804 |
x = self.image_proj_norm(x)
|
| 2805 |
|
| 2806 |
+
return x
|
| 2807 |
|
| 2808 |
def _merge_input_ids_with_image_features(
|
| 2809 |
+
self, image_features, inputs_embeds
|
| 2810 |
):
|
| 2811 |
batch_size, image_token_length = image_features.size()[:-1]
|
| 2812 |
device = image_features.device
|
| 2813 |
+
image_attention_mask = torch.ones(
|
| 2814 |
+
batch_size, image_token_length, device=device)
|
| 2815 |
|
| 2816 |
# task_prefix_embeds: [batch_size, padded_context_length, hidden_size]
|
| 2817 |
# task_prefix_attention_mask: [batch_size, context_length]
|
|
|
|
| 2819 |
return image_features, image_attention_mask
|
| 2820 |
|
| 2821 |
task_prefix_embeds = inputs_embeds
|
| 2822 |
+
task_prefix_attention_mask = torch.ones(
|
| 2823 |
+
batch_size, task_prefix_embeds.size(1), device=device)
|
| 2824 |
|
| 2825 |
if len(task_prefix_attention_mask.shape) == 3:
|
| 2826 |
task_prefix_attention_mask = task_prefix_attention_mask[:, 0]
|
| 2827 |
|
| 2828 |
# concat [image embeds, task prefix embeds]
|
| 2829 |
inputs_embeds = torch.cat([image_features, task_prefix_embeds], dim=1)
|
| 2830 |
+
attention_mask = torch.cat(
|
| 2831 |
+
[image_attention_mask, task_prefix_attention_mask], dim=1)
|
| 2832 |
|
| 2833 |
return inputs_embeds, attention_mask
|
| 2834 |
+
|
| 2835 |
@torch.no_grad()
|
| 2836 |
def predict_detections_and_associations(
|
| 2837 |
self,
|
|
|
|
| 2843 |
essential_text_threshold=0.8,
|
| 2844 |
):
|
| 2845 |
batch_inputs = processor(
|
| 2846 |
+
batch_input_text=[
|
| 2847 |
+
"Find all panels, texts, characters, and speech-bubble tails in the image."] * len(images),
|
| 2848 |
batch_input_list_of_list_of_bboxes=[[]] * len(images),
|
| 2849 |
batch_images=images,
|
| 2850 |
padding=True,
|
|
|
|
| 2862 |
do_sample=False,
|
| 2863 |
num_beams=3,
|
| 2864 |
)
|
| 2865 |
+
generated_texts, list_of_list_of_list_of_bboxes, batch_indices_of_bboxes_in_generated_text = processor.postprocess_output(
|
| 2866 |
+
generated_ids, images)
|
| 2867 |
+
map_to_category = {"<pa": "panels", "<te": "texts",
|
| 2868 |
+
"<ch": "characters", "<ta": "tails"}
|
| 2869 |
|
| 2870 |
results = []
|
| 2871 |
|
| 2872 |
for generated_text, batch_indices_of_bboxes_in_generated_text, list_of_list_of_bboxes in zip(generated_texts, batch_indices_of_bboxes_in_generated_text, list_of_list_of_list_of_bboxes):
|
| 2873 |
+
categories = [map_to_category.get(
|
| 2874 |
+
generated_text[j:j+3], None) for i, j in batch_indices_of_bboxes_in_generated_text]
|
| 2875 |
result_for_this_image = {
|
| 2876 |
"panels": [],
|
| 2877 |
"texts": [],
|
|
|
|
| 2886 |
|
| 2887 |
cleaned_generated_ids = []
|
| 2888 |
for generated_id in generated_ids:
|
| 2889 |
+
index_of_last_bos = torch.where(
|
| 2890 |
+
generated_id == processor.tokenizer.bos_token_id)[0][-1].item()
|
| 2891 |
cleaned_generated_ids.append(generated_id[index_of_last_bos:])
|
| 2892 |
+
cleaned_generated_ids = pad_sequence(
|
| 2893 |
+
cleaned_generated_ids, batch_first=True, padding_value=processor.tokenizer.pad_token_id)
|
| 2894 |
association_outputs = self(
|
| 2895 |
input_ids=batch_inputs["input_ids"],
|
| 2896 |
pixel_values=batch_inputs["pixel_values"],
|
|
|
|
| 2899 |
)
|
| 2900 |
|
| 2901 |
for img_idx in range(len(results)):
|
| 2902 |
+
character_cluster_labels = UnionFind.from_adj_matrix(
|
| 2903 |
+
association_outputs.character_character_affinity_matrices[img_idx] > character_character_association_threshold).get_labels_for_connected_components()
|
| 2904 |
+
text_character_association = torch.nonzero(
|
| 2905 |
+
association_outputs.text_character_association_matrices[img_idx] > text_character_association_threshold).tolist()
|
| 2906 |
+
text_tail_association = torch.nonzero(
|
| 2907 |
+
association_outputs.text_tail_association_matrices[img_idx] > text_tail_association_threshold).tolist()
|
| 2908 |
+
essential_text_logits = (
|
| 2909 |
+
association_outputs.essential_text_logits[img_idx] > essential_text_threshold).tolist()
|
| 2910 |
results[img_idx]["character_cluster_labels"] = character_cluster_labels
|
| 2911 |
results[img_idx]["text_character_associations"] = text_character_association
|
| 2912 |
results[img_idx]["text_tail_associations"] = text_tail_association
|
| 2913 |
results[img_idx]["is_essential_text"] = essential_text_logits
|
| 2914 |
|
| 2915 |
return results
|
| 2916 |
+
|
| 2917 |
@torch.no_grad()
|
| 2918 |
def predict_ocr(
|
| 2919 |
self,
|
|
|
|
| 2921 |
processor,
|
| 2922 |
):
|
| 2923 |
batch_inputs = processor(
|
| 2924 |
+
batch_input_text=[
|
| 2925 |
+
"What is the text in the image, with regions?"] * len(images),
|
| 2926 |
batch_input_list_of_list_of_bboxes=[[]] * len(images),
|
| 2927 |
batch_images=images,
|
| 2928 |
padding=True,
|
|
|
|
| 2940 |
do_sample=False,
|
| 2941 |
num_beams=3,
|
| 2942 |
)
|
| 2943 |
+
generated_texts, list_of_list_of_list_of_bboxes, batch_indices_of_bboxes_in_generated_text = processor.postprocess_output(
|
| 2944 |
+
generated_ids, images)
|
| 2945 |
results = []
|
| 2946 |
for generated_text, batch_indices_of_bboxes_in_generated_text, list_of_list_of_bboxes in zip(generated_texts, batch_indices_of_bboxes_in_generated_text, list_of_list_of_list_of_bboxes):
|
| 2947 |
ocr_texts = []
|
|
|
|
| 2965 |
):
|
| 2966 |
def convert_caption_to_instruction(caption):
|
| 2967 |
return "Locate the phrases in the caption: " + caption
|
| 2968 |
+
|
| 2969 |
batch_inputs = processor(
|
| 2970 |
+
batch_input_text=[convert_caption_to_instruction(
|
| 2971 |
+
caption) for caption in captions],
|
| 2972 |
batch_input_list_of_list_of_bboxes=[[]] * len(images),
|
| 2973 |
batch_images=images,
|
| 2974 |
padding=True,
|
|
|
|
| 2996 |
do_sample=False,
|
| 2997 |
num_beams=3,
|
| 2998 |
)
|
| 2999 |
+
generated_texts, list_of_list_of_list_of_bboxes, batch_indices_of_bboxes_in_generated_text = processor.postprocess_output(
|
| 3000 |
+
generated_ids, images)
|
| 3001 |
return [
|
| 3002 |
{
|
| 3003 |
"grounded_caption": generated_text,
|
|
|
|
| 3026 |
output_attentions: Optional[bool] = None,
|
| 3027 |
output_hidden_states: Optional[bool] = True,
|
| 3028 |
return_dict: Optional[bool] = None,
|
| 3029 |
+
tokenizer=None,
|
| 3030 |
) -> Union[Tuple, Florence2Seq2SeqLMOutput]:
|
| 3031 |
assert output_hidden_states, "output_hidden_states must be True"
|
| 3032 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
|
|
| 3044 |
if pixel_values is not None:
|
| 3045 |
# (batch_size, num_image_tokens, hidden_size)
|
| 3046 |
image_features = self._encode_image(pixel_values)
|
| 3047 |
+
inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(
|
| 3048 |
+
image_features, inputs_embeds)
|
| 3049 |
|
| 3050 |
if inputs_embeds is not None:
|
| 3051 |
attention_mask = attention_mask.to(inputs_embeds.dtype)
|
|
|
|
| 3067 |
return_dict=return_dict,
|
| 3068 |
)
|
| 3069 |
|
| 3070 |
+
character_character_affinity_matrices = self.get_character_character_affinity_matrices(
|
| 3071 |
+
outputs.decoder_hidden_states[-1], decoder_input_ids, tokenizer, apply_sigmoid=True)
|
| 3072 |
+
text_character_association_matrices = self.get_text_character_association_matrices(
|
| 3073 |
+
outputs.decoder_hidden_states[-1], decoder_input_ids, tokenizer, apply_sigmoid=True)
|
| 3074 |
+
text_tail_association_matrices = self.get_text_tail_association_matrices(
|
| 3075 |
+
outputs.decoder_hidden_states[-1], decoder_input_ids, tokenizer, apply_sigmoid=True)
|
| 3076 |
+
essential_text_logits = self.get_essential_text_logits(
|
| 3077 |
+
outputs.decoder_hidden_states[-1], decoder_input_ids, tokenizer, apply_sigmoid=True)
|
| 3078 |
|
| 3079 |
return Florence2Seq2SeqLMOutput(
|
| 3080 |
character_character_affinity_matrices=character_character_affinity_matrices,
|
|
|
|
| 3085 |
|
| 3086 |
def generate(
|
| 3087 |
self,
|
| 3088 |
+
input_ids,
|
| 3089 |
inputs_embeds=None,
|
| 3090 |
pixel_values=None,
|
| 3091 |
**kwargs
|
| 3092 |
+
):
|
| 3093 |
|
| 3094 |
if inputs_embeds is None:
|
| 3095 |
# 1. Extra the input embeddings
|
|
|
|
| 3098 |
# 2. Merge text and images
|
| 3099 |
if pixel_values is not None:
|
| 3100 |
image_features = self._encode_image(pixel_values)
|
| 3101 |
+
inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(
|
| 3102 |
+
image_features, inputs_embeds)
|
| 3103 |
+
|
| 3104 |
return self.language_model.generate(
|
| 3105 |
input_ids=None,
|
| 3106 |
inputs_embeds=inputs_embeds,
|
| 3107 |
**kwargs
|
| 3108 |
)
|
| 3109 |
+
|
| 3110 |
def slowly_generate_grounded_caption(
|
| 3111 |
self,
|
| 3112 |
input_ids,
|
|
|
|
| 3122 |
"""
|
| 3123 |
input_embeds = self.get_input_embeddings()(input_ids)
|
| 3124 |
image_features = self._encode_image(pixel_values)
|
| 3125 |
+
inputs_embeds, _ = self._merge_input_ids_with_image_features(
|
| 3126 |
+
image_features, input_embeds)
|
| 3127 |
+
decoder_input_ids = processor.tokenizer(
|
| 3128 |
+
captions, return_tensors="pt", truncation=False, padding=True)["input_ids"].to(self.device)
|
| 3129 |
+
running_indices = torch.zeros(
|
| 3130 |
+
decoder_input_ids.shape[0], dtype=torch.long, device=self.device)
|
| 3131 |
running_decoder_input_ids = decoder_input_ids[:, :1]
|
| 3132 |
num_tokens_generated = 1
|
| 3133 |
CHUNK_SIZE = 8
|
|
|
|
| 3145 |
)[:, -(CHUNK_SIZE+1):-1]
|
| 3146 |
|
| 3147 |
what_should_be_the_next_tokens = torch.stack([
|
| 3148 |
+
decoder_input_ids[i, running_indices[i] +
|
| 3149 |
+
1:running_indices[i]+CHUNK_SIZE+1]
|
| 3150 |
for i in range(decoder_input_ids.shape[0])
|
| 3151 |
])
|
| 3152 |
# if the entire predicted chunk matches the next chunk, then we can saved some time and "jump" to the next chunk
|
| 3153 |
if predicted_next_tokens.shape[1] == what_should_be_the_next_tokens.shape[1] and torch.all(predicted_next_tokens == what_should_be_the_next_tokens):
|
| 3154 |
running_indices += CHUNK_SIZE
|
| 3155 |
+
running_decoder_input_ids = torch.cat(
|
| 3156 |
+
[running_decoder_input_ids, what_should_be_the_next_tokens], dim=-1)
|
| 3157 |
continue
|
| 3158 |
+
|
| 3159 |
# if, however, there is a deviation find the maximum prefix that matches in the batch
|
| 3160 |
|
| 3161 |
predicted_next_tokens = predicted_next_tokens[:, 0]
|
| 3162 |
+
predicted_next_token_strings = processor.batch_decode(
|
| 3163 |
+
predicted_next_tokens)
|
| 3164 |
next_tokens_to_concat = []
|
| 3165 |
for i, (pnts, pnt) in enumerate(zip(predicted_next_token_strings, predicted_next_tokens)):
|
| 3166 |
if (pnts.startswith("<loc_") or pnts in ["<s>", "<pad>", "</s>"]) and running_indices[i] < decoder_input_ids.shape[1] - 1:
|
|
|
|
| 3168 |
else:
|
| 3169 |
running_indices[i] += 1
|
| 3170 |
if running_indices[i] >= decoder_input_ids.shape[1]:
|
| 3171 |
+
next_tokens_to_concat.append(torch.tensor(
|
| 3172 |
+
processor.tokenizer.eos_token_id, device=self.device))
|
| 3173 |
# elif "’" in pnts: # this is an annoying character which looks like ' (apostrophe) but isn't.
|
| 3174 |
# import pdb; pdb.set_trace()
|
| 3175 |
else:
|
| 3176 |
+
next_tokens_to_concat.append(
|
| 3177 |
+
decoder_input_ids[i, running_indices[i]])
|
| 3178 |
next_tokens_to_concat = torch.stack(next_tokens_to_concat)[:, None]
|
| 3179 |
if (next_tokens_to_concat == processor.tokenizer.eos_token_id).all():
|
| 3180 |
break
|
| 3181 |
+
running_decoder_input_ids = torch.cat(
|
| 3182 |
+
[running_decoder_input_ids, next_tokens_to_concat], dim=-1)
|
| 3183 |
if num_tokens_generated >= 1024:
|
| 3184 |
break
|
| 3185 |
return running_decoder_input_ids
|
|
|
|
| 3210 |
remove_prefix_length = decoder_input_ids.shape[1] - 1
|
| 3211 |
|
| 3212 |
decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
|
| 3213 |
+
|
| 3214 |
return {
|
| 3215 |
"input_ids": None, # encoder_outputs is defined. input_ids not needed
|
| 3216 |
"encoder_outputs": encoder_outputs,
|
|
|
|
| 3222 |
"head_mask": head_mask,
|
| 3223 |
"decoder_head_mask": decoder_head_mask,
|
| 3224 |
"cross_attn_head_mask": cross_attn_head_mask,
|
| 3225 |
+
# change this to avoid caching (presumably for debugging)
|
| 3226 |
+
"use_cache": use_cache,
|
| 3227 |
}
|
| 3228 |
+
|
| 3229 |
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
|
| 3230 |
return self.language_model.shift_tokens_right(labels)
|
| 3231 |
|
| 3232 |
def _reorder_cache(self, *args, **kwargs):
|
| 3233 |
return self.language_model._reorder_cache(*args, **kwargs)
|
| 3234 |
+
|
| 3235 |
def get_character_character_affinity_matrices(self, decoder_hidden_states, decoder_input_ids, tokenizer, apply_sigmoid=False):
|
| 3236 |
character_character_affinity_matrices = []
|
| 3237 |
for index in range(len(decoder_hidden_states)):
|
| 3238 |
+
character_embedding_indices = (decoder_input_ids[index] == tokenizer.convert_tokens_to_ids(
|
| 3239 |
+
'<character>')).nonzero().squeeze(-1)
|
| 3240 |
character_embeddings = decoder_hidden_states[index][character_embedding_indices]
|
| 3241 |
if character_embeddings.shape[0] == 0:
|
| 3242 |
+
character_character_affinity_matrices.append(
|
| 3243 |
+
torch.zeros(0, 0).type_as(character_embeddings))
|
| 3244 |
continue
|
| 3245 |
+
character_embeddings = self.character_embedding_projection(
|
| 3246 |
+
character_embeddings)
|
| 3247 |
+
char_i = repeat(character_embeddings, "i d -> i repeat d",
|
| 3248 |
+
repeat=character_embeddings.shape[0])
|
| 3249 |
+
char_j = repeat(character_embeddings, "j d -> repeat j d",
|
| 3250 |
+
repeat=character_embeddings.shape[0])
|
| 3251 |
char_ij = rearrange([char_i, char_j], "two i j d -> (i j) (two d)")
|
| 3252 |
+
character_character_affinities = self.character_character_matching_head(
|
| 3253 |
+
char_ij)
|
| 3254 |
+
character_character_affinities = rearrange(
|
| 3255 |
+
character_character_affinities, "(i j) 1 -> i j", i=char_i.shape[0])
|
| 3256 |
+
character_character_affinities = (
|
| 3257 |
+
character_character_affinities + character_character_affinities.T) / 2
|
| 3258 |
if apply_sigmoid:
|
| 3259 |
+
character_character_affinities = torch.sigmoid(
|
| 3260 |
+
character_character_affinities)
|
| 3261 |
+
character_character_affinity_matrices.append(
|
| 3262 |
+
character_character_affinities)
|
| 3263 |
return character_character_affinity_matrices
|
| 3264 |
|
| 3265 |
def get_text_character_association_matrices(self, decoder_hidden_states, decoder_input_ids, tokenizer, apply_sigmoid=False):
|
| 3266 |
text_character_association_matrices = []
|
| 3267 |
for index in range(len(decoder_hidden_states)):
|
| 3268 |
+
text_embedding_indices = (decoder_input_ids[index] == tokenizer.convert_tokens_to_ids(
|
| 3269 |
+
'<text>')).nonzero().squeeze(-1)
|
| 3270 |
text_embeddings = decoder_hidden_states[index][text_embedding_indices]
|
| 3271 |
+
character_embedding_indices = (decoder_input_ids[index] == tokenizer.convert_tokens_to_ids(
|
| 3272 |
+
'<character>')).nonzero().squeeze(-1)
|
| 3273 |
character_embeddings = decoder_hidden_states[index][character_embedding_indices]
|
| 3274 |
if character_embeddings.shape[0] == 0 or text_embeddings.shape[0] == 0:
|
| 3275 |
+
text_character_association_matrices.append(torch.zeros(
|
| 3276 |
+
text_embeddings.shape[0], character_embeddings.shape[0]).type_as(text_embeddings))
|
| 3277 |
continue
|
| 3278 |
+
text_i = repeat(text_embeddings, "i d -> i repeat d",
|
| 3279 |
+
repeat=character_embeddings.shape[0])
|
| 3280 |
+
char_j = repeat(character_embeddings, "j d -> repeat j d",
|
| 3281 |
+
repeat=text_embeddings.shape[0])
|
| 3282 |
+
text_char_ij = rearrange(
|
| 3283 |
+
[text_i, char_j], "two i j d -> (i j) (two d)")
|
| 3284 |
+
text_character_affinities = self.text_character_matching_head(
|
| 3285 |
+
text_char_ij)
|
| 3286 |
+
text_character_affinities = rearrange(
|
| 3287 |
+
text_character_affinities, "(i j) 1 -> i j", i=text_i.shape[0])
|
| 3288 |
if apply_sigmoid:
|
| 3289 |
+
text_character_affinities = torch.sigmoid(
|
| 3290 |
+
text_character_affinities)
|
| 3291 |
+
text_character_association_matrices.append(
|
| 3292 |
+
text_character_affinities)
|
| 3293 |
return text_character_association_matrices
|
| 3294 |
|
| 3295 |
def get_text_tail_association_matrices(self, decoder_hidden_states, decoder_input_ids, tokenizer, apply_sigmoid=False):
|
| 3296 |
text_tail_association_matrices = []
|
| 3297 |
for index in range(len(decoder_hidden_states)):
|
| 3298 |
+
text_embedding_indices = (decoder_input_ids[index] == tokenizer.convert_tokens_to_ids(
|
| 3299 |
+
'<text>')).nonzero().squeeze(-1)
|
| 3300 |
text_embeddings = decoder_hidden_states[index][text_embedding_indices]
|
| 3301 |
+
tail_embedding_indices = (decoder_input_ids[index] == tokenizer.convert_tokens_to_ids(
|
| 3302 |
+
'<tail>')).nonzero().squeeze(-1)
|
| 3303 |
tail_embeddings = decoder_hidden_states[index][tail_embedding_indices]
|
| 3304 |
if tail_embeddings.shape[0] == 0 or text_embeddings.shape[0] == 0:
|
| 3305 |
+
text_tail_association_matrices.append(torch.zeros(
|
| 3306 |
+
text_embeddings.shape[0], tail_embeddings.shape[0]).type_as(text_embeddings))
|
| 3307 |
continue
|
| 3308 |
+
text_i = repeat(text_embeddings, "i d -> i repeat d",
|
| 3309 |
+
repeat=tail_embeddings.shape[0])
|
| 3310 |
+
tail_j = repeat(tail_embeddings, "j d -> repeat j d",
|
| 3311 |
+
repeat=text_embeddings.shape[0])
|
| 3312 |
+
text_tail_ij = rearrange(
|
| 3313 |
+
[text_i, tail_j], "two i j d -> (i j) (two d)")
|
| 3314 |
text_tail_affinities = self.text_tail_matching_head(text_tail_ij)
|
| 3315 |
+
text_tail_affinities = rearrange(
|
| 3316 |
+
text_tail_affinities, "(i j) 1 -> i j", i=text_i.shape[0])
|
| 3317 |
if apply_sigmoid:
|
| 3318 |
text_tail_affinities = torch.sigmoid(text_tail_affinities)
|
| 3319 |
text_tail_association_matrices.append(text_tail_affinities)
|
| 3320 |
return text_tail_association_matrices
|
| 3321 |
+
|
| 3322 |
def get_tail_character_association_matrices(self, decoder_hidden_states, decoder_input_ids, tokenizer, apply_sigmoid=False):
|
| 3323 |
tail_character_association_matrices = []
|
| 3324 |
for index in range(len(decoder_hidden_states)):
|
| 3325 |
+
tail_embedding_indices = (decoder_input_ids[index] == tokenizer.convert_tokens_to_ids(
|
| 3326 |
+
'<tail>')).nonzero().squeeze(-1)
|
| 3327 |
tail_embeddings = decoder_hidden_states[index][tail_embedding_indices]
|
| 3328 |
+
character_embedding_indices = (decoder_input_ids[index] == tokenizer.convert_tokens_to_ids(
|
| 3329 |
+
'<character>')).nonzero().squeeze(-1)
|
| 3330 |
character_embeddings = decoder_hidden_states[index][character_embedding_indices]
|
| 3331 |
if character_embeddings.shape[0] == 0 or tail_embeddings.shape[0] == 0:
|
| 3332 |
+
tail_character_association_matrices.append(torch.zeros(
|
| 3333 |
+
tail_embeddings.shape[0], character_embeddings.shape[0]).type_as(tail_embeddings))
|
| 3334 |
continue
|
| 3335 |
+
tail_i = repeat(tail_embeddings, "i d -> i repeat d",
|
| 3336 |
+
repeat=character_embeddings.shape[0])
|
| 3337 |
+
char_j = repeat(character_embeddings, "j d -> repeat j d",
|
| 3338 |
+
repeat=tail_embeddings.shape[0])
|
| 3339 |
+
tail_char_ij = rearrange(
|
| 3340 |
+
[tail_i, char_j], "two i j d -> (i j) (two d)")
|
| 3341 |
+
tail_character_affinities = self.tail_character_matching_head(
|
| 3342 |
+
tail_char_ij)
|
| 3343 |
+
tail_character_affinities = rearrange(
|
| 3344 |
+
tail_character_affinities, "(i j) 1 -> i j", i=tail_i.shape[0])
|
| 3345 |
if apply_sigmoid:
|
| 3346 |
+
tail_character_affinities = torch.sigmoid(
|
| 3347 |
+
tail_character_affinities)
|
| 3348 |
+
tail_character_association_matrices.append(
|
| 3349 |
+
tail_character_affinities)
|
| 3350 |
return tail_character_association_matrices
|
| 3351 |
+
|
| 3352 |
def get_essential_text_logits(self, decoder_hidden_states, decoder_input_ids, tokenizer, apply_sigmoid=False):
|
| 3353 |
essential_text_logits = []
|
| 3354 |
for index in range(len(decoder_hidden_states)):
|
| 3355 |
+
text_embedding_indices = (decoder_input_ids[index] == tokenizer.convert_tokens_to_ids(
|
| 3356 |
+
'<text>')).nonzero().squeeze(-1)
|
| 3357 |
text_embeddings = decoder_hidden_states[index][text_embedding_indices]
|
| 3358 |
if text_embeddings.shape[0] == 0:
|
| 3359 |
+
essential_text_logits.append(
|
| 3360 |
+
torch.zeros(0).type_as(text_embeddings))
|
| 3361 |
continue
|
| 3362 |
+
text_logits = rearrange(
|
| 3363 |
+
self.text_classification_head(text_embeddings), "i 1 -> i")
|
| 3364 |
if apply_sigmoid:
|
| 3365 |
text_logits = torch.sigmoid(text_logits)
|
| 3366 |
essential_text_logits.append(text_logits)
|
| 3367 |
+
return essential_text_logits
|
processing_florence2.py
CHANGED
|
@@ -53,18 +53,19 @@ class Florence2Processor(ProcessorMixin):
|
|
| 53 |
if tokenizer is None:
|
| 54 |
raise ValueError("You need to specify a `tokenizer`.")
|
| 55 |
if not hasattr(image_processor, "image_seq_length"):
|
| 56 |
-
raise ValueError(
|
|
|
|
| 57 |
|
| 58 |
self.image_seq_length = image_processor.image_seq_length
|
| 59 |
|
| 60 |
tokens_to_add = {
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
tokenizer.add_special_tokens(tokens_to_add)
|
| 69 |
self.decoder_start_token_id = 2
|
| 70 |
|
|
@@ -74,7 +75,7 @@ class Florence2Processor(ProcessorMixin):
|
|
| 74 |
)
|
| 75 |
|
| 76 |
super().__init__(image_processor, tokenizer)
|
| 77 |
-
|
| 78 |
def __call__(
|
| 79 |
self,
|
| 80 |
batch_input_text: List[TextInput] = None,
|
|
@@ -82,11 +83,11 @@ class Florence2Processor(ProcessorMixin):
|
|
| 82 |
batch_output_text: List[TextInput] = None,
|
| 83 |
batch_output_list_of_list_of_bboxes: List[List[List[List[float]]]] = None,
|
| 84 |
batch_images: ImageInput = None,
|
| 85 |
-
batch_character_cluster_labels
|
| 86 |
-
batch_text_character_association_labels
|
| 87 |
-
batch_text_tail_association_labels
|
| 88 |
-
batch_is_essential_text_labels
|
| 89 |
-
batch_tail_character_association_labels
|
| 90 |
padding: Union[bool, str, PaddingStrategy] = None,
|
| 91 |
truncation: Union[bool, str, TruncationStrategy] = None,
|
| 92 |
max_input_length_including_image_tokens=None,
|
|
@@ -109,17 +110,23 @@ class Florence2Processor(ProcessorMixin):
|
|
| 109 |
assert batch_images is not None, "`batch_images` are expected as arguments to a `Florence2Processor` instance."
|
| 110 |
assert batch_input_text is not None, "`batch_input_text` are expected as arguments to a `Florence2Processor` instance."
|
| 111 |
if batch_input_list_of_list_of_bboxes is None:
|
| 112 |
-
batch_input_list_of_list_of_bboxes = [
|
| 113 |
-
|
|
|
|
|
|
|
| 114 |
if batch_output_text is None:
|
| 115 |
assert batch_output_list_of_list_of_bboxes is None, "`batch_output_text` and `batch_output_list_of_list_of_bboxes` should be provided together."
|
| 116 |
else:
|
| 117 |
if batch_output_list_of_list_of_bboxes is None:
|
| 118 |
-
batch_output_list_of_list_of_bboxes = [
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
inputs = self.tokenizer(
|
| 124 |
batch_input_texts,
|
| 125 |
return_tensors=return_tensors,
|
|
@@ -130,9 +137,10 @@ class Florence2Processor(ProcessorMixin):
|
|
| 130 |
if inputs["input_ids"].shape[1] > max_input_length:
|
| 131 |
inputs["input_ids"] = inputs["input_ids"][:, :max_input_length]
|
| 132 |
inputs["attention_mask"] = inputs["attention_mask"][:, :max_input_length]
|
| 133 |
-
|
| 134 |
if batch_output_text is not None:
|
| 135 |
-
batch_output_texts = [self._format_text_with_bboxes(text, list_of_list_of_bboxes, image) for text, list_of_list_of_bboxes, image in zip(
|
|
|
|
| 136 |
decoder_inputs = self.tokenizer(
|
| 137 |
batch_output_texts,
|
| 138 |
return_tensors=return_tensors,
|
|
@@ -141,9 +149,10 @@ class Florence2Processor(ProcessorMixin):
|
|
| 141 |
)
|
| 142 |
# Truncating manually because I don't want </s> token at the end of truncated sequences, which is the default behavior
|
| 143 |
if decoder_inputs["input_ids"].shape[1] > max_output_length:
|
| 144 |
-
decoder_inputs["input_ids"] = decoder_inputs["input_ids"][:,
|
| 145 |
-
|
| 146 |
-
|
|
|
|
| 147 |
|
| 148 |
pixel_values = self.image_processor(
|
| 149 |
batch_images,
|
|
@@ -160,7 +169,7 @@ class Florence2Processor(ProcessorMixin):
|
|
| 160 |
|
| 161 |
if dtype is not None:
|
| 162 |
pixel_values = pixel_values.to(dtype)
|
| 163 |
-
|
| 164 |
return_data = {**inputs, "pixel_values": pixel_values}
|
| 165 |
|
| 166 |
if batch_output_text is not None:
|
|
@@ -168,8 +177,10 @@ class Florence2Processor(ProcessorMixin):
|
|
| 168 |
decoder_input_ids = labels.new_zeros(labels.shape)
|
| 169 |
decoder_input_ids[:, 1:] = labels[:, :-1].clone()
|
| 170 |
decoder_input_ids[:, 0] = self.decoder_start_token_id
|
| 171 |
-
decoder_attention_mask = decoder_inputs["attention_mask"].new_ones(
|
| 172 |
-
|
|
|
|
|
|
|
| 173 |
# Mask fill labels to replace pad token ID with -100
|
| 174 |
labels.masked_fill_(labels == self.tokenizer.pad_token_id, -100)
|
| 175 |
return_data.update({
|
|
@@ -177,7 +188,7 @@ class Florence2Processor(ProcessorMixin):
|
|
| 177 |
"decoder_input_ids": decoder_input_ids,
|
| 178 |
"decoder_attention_mask": decoder_attention_mask,
|
| 179 |
})
|
| 180 |
-
|
| 181 |
if device is not None:
|
| 182 |
for key, value in return_data.items():
|
| 183 |
if isinstance(value, torch.Tensor):
|
|
@@ -201,25 +212,32 @@ class Florence2Processor(ProcessorMixin):
|
|
| 201 |
return generated_text.replace("<s>", "").replace("</s>", "").replace("<pad>", "")
|
| 202 |
|
| 203 |
def postprocess_output(self, generated_ids, images):
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
batch_list_of_list_of_bboxes = []
|
| 208 |
batch_indices_of_bboxes_in_new_string = []
|
| 209 |
batch_new_texts = []
|
| 210 |
for text, image in zip(batch_decoded_texts, images):
|
| 211 |
size_wh = self._get_image_size_wh(image)
|
| 212 |
-
parsed_text, list_of_stringified_bboxes, start_end_in_new_string = self._parse_text_with_bboxes(
|
| 213 |
-
|
|
|
|
|
|
|
| 214 |
batch_list_of_list_of_bboxes.append(list_of_list_of_bboxes)
|
| 215 |
-
batch_indices_of_bboxes_in_new_string.append(
|
|
|
|
| 216 |
batch_new_texts.append(parsed_text)
|
| 217 |
return batch_new_texts, batch_list_of_list_of_bboxes, batch_indices_of_bboxes_in_new_string
|
| 218 |
|
| 219 |
def _parse_text_with_bboxes(self, text):
|
| 220 |
loc_pattern = r'((?:<loc_\d+>){4}(?:,(?:<loc_\d+>){4})*)'
|
| 221 |
grounding_pattern = r'<grounding>(.*?)</grounding>' + loc_pattern
|
| 222 |
-
|
| 223 |
list_of_stringified_bboxes = []
|
| 224 |
start_end_in_new_string = []
|
| 225 |
new_text = ""
|
|
@@ -237,7 +255,8 @@ class Florence2Processor(ProcessorMixin):
|
|
| 237 |
locs = match.group(2)
|
| 238 |
new_text += grounding_text
|
| 239 |
list_of_stringified_bboxes.append(locs)
|
| 240 |
-
start_end_in_new_string.append(
|
|
|
|
| 241 |
new_pos += len(grounding_text)
|
| 242 |
else:
|
| 243 |
# Handle loc pattern
|
|
@@ -245,7 +264,8 @@ class Florence2Processor(ProcessorMixin):
|
|
| 245 |
replacement = ""
|
| 246 |
new_text += replacement
|
| 247 |
list_of_stringified_bboxes.append(locs)
|
| 248 |
-
start_end_in_new_string.append(
|
|
|
|
| 249 |
new_pos += len(replacement)
|
| 250 |
|
| 251 |
original_pos = match.end()
|
|
@@ -254,19 +274,21 @@ class Florence2Processor(ProcessorMixin):
|
|
| 254 |
new_text += text[original_pos:]
|
| 255 |
|
| 256 |
return new_text, list_of_stringified_bboxes, start_end_in_new_string
|
| 257 |
-
|
| 258 |
def _format_text_with_bboxes(self, text, list_of_list_of_bboxes, image):
|
| 259 |
size_wh = self._get_image_size_wh(image)
|
| 260 |
quantized_bbox_lists = []
|
| 261 |
-
for list_of_bboxes in list_of_list_of_bboxes:
|
| 262 |
-
quantized_bboxes = self.box_quantizer.quantize(
|
| 263 |
-
|
|
|
|
|
|
|
| 264 |
stringified_bboxes = ",".join(stringified_bboxes)
|
| 265 |
quantized_bbox_lists.append(stringified_bboxes)
|
| 266 |
return text.format(*quantized_bbox_lists)
|
| 267 |
|
| 268 |
def _get_image_size_wh(self, image):
|
| 269 |
-
|
| 270 |
if isinstance(image, torch.Tensor):
|
| 271 |
# For PyTorch tensor
|
| 272 |
if image.dim() == 3:
|
|
@@ -313,6 +335,7 @@ class Florence2Processor(ProcessorMixin):
|
|
| 313 |
image_processor_input_names = self.image_processor.model_input_names
|
| 314 |
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
| 315 |
|
|
|
|
| 316 |
class BoxQuantizer(object):
|
| 317 |
def __init__(self, mode, bins):
|
| 318 |
self.mode = mode
|
|
@@ -390,4 +413,4 @@ class BoxQuantizer(object):
|
|
| 390 |
dequantized_xmax, dequantized_ymax), dim=-1
|
| 391 |
)
|
| 392 |
|
| 393 |
-
return dequantized_boxes
|
|
|
|
| 53 |
if tokenizer is None:
|
| 54 |
raise ValueError("You need to specify a `tokenizer`.")
|
| 55 |
if not hasattr(image_processor, "image_seq_length"):
|
| 56 |
+
raise ValueError(
|
| 57 |
+
"Image processor is missing an `image_seq_length` attribute.")
|
| 58 |
|
| 59 |
self.image_seq_length = image_processor.image_seq_length
|
| 60 |
|
| 61 |
tokens_to_add = {
|
| 62 |
+
'additional_special_tokens':
|
| 63 |
+
tokenizer.additional_special_tokens +
|
| 64 |
+
['<od>', '</od>', '<ocr>', '</ocr>'] +
|
| 65 |
+
[f'<loc_{x}>' for x in range(1000)] +
|
| 66 |
+
['<cap>', '</cap>', '<ncap>', '</ncap>', '<dcap>', '</dcap>', '<grounding>', '</grounding>', '<seg>', '</seg>', '<sep>', '<region_cap>', '</region_cap>', '<region_to_desciption>', '</region_to_desciption>', '<proposal>', '</proposal>', '<poly>', '</poly>', '<and>'] +
|
| 67 |
+
['<panel>', '<text>', '<character>', '<tail>']
|
| 68 |
+
}
|
| 69 |
tokenizer.add_special_tokens(tokens_to_add)
|
| 70 |
self.decoder_start_token_id = 2
|
| 71 |
|
|
|
|
| 75 |
)
|
| 76 |
|
| 77 |
super().__init__(image_processor, tokenizer)
|
| 78 |
+
|
| 79 |
def __call__(
|
| 80 |
self,
|
| 81 |
batch_input_text: List[TextInput] = None,
|
|
|
|
| 83 |
batch_output_text: List[TextInput] = None,
|
| 84 |
batch_output_list_of_list_of_bboxes: List[List[List[List[float]]]] = None,
|
| 85 |
batch_images: ImageInput = None,
|
| 86 |
+
batch_character_cluster_labels=None,
|
| 87 |
+
batch_text_character_association_labels=None,
|
| 88 |
+
batch_text_tail_association_labels=None,
|
| 89 |
+
batch_is_essential_text_labels=None,
|
| 90 |
+
batch_tail_character_association_labels=None,
|
| 91 |
padding: Union[bool, str, PaddingStrategy] = None,
|
| 92 |
truncation: Union[bool, str, TruncationStrategy] = None,
|
| 93 |
max_input_length_including_image_tokens=None,
|
|
|
|
| 110 |
assert batch_images is not None, "`batch_images` are expected as arguments to a `Florence2Processor` instance."
|
| 111 |
assert batch_input_text is not None, "`batch_input_text` are expected as arguments to a `Florence2Processor` instance."
|
| 112 |
if batch_input_list_of_list_of_bboxes is None:
|
| 113 |
+
batch_input_list_of_list_of_bboxes = [
|
| 114 |
+
[] for _ in range(len(batch_input_text))]
|
| 115 |
+
assert len(batch_input_text) == len(batch_input_list_of_list_of_bboxes) == len(
|
| 116 |
+
batch_images), "`batch_input_text`, `batch_input_list_of_list_of_bboxes` and `batch_images` have different lengths."
|
| 117 |
if batch_output_text is None:
|
| 118 |
assert batch_output_list_of_list_of_bboxes is None, "`batch_output_text` and `batch_output_list_of_list_of_bboxes` should be provided together."
|
| 119 |
else:
|
| 120 |
if batch_output_list_of_list_of_bboxes is None:
|
| 121 |
+
batch_output_list_of_list_of_bboxes = [
|
| 122 |
+
[] for _ in range(len(batch_output_text))]
|
| 123 |
+
assert len(batch_output_text) == len(batch_output_list_of_list_of_bboxes) == len(
|
| 124 |
+
batch_images), "`batch_output_text`, `batch_output_list_of_list_of_bboxes` and `batch_images` have different lengths."
|
| 125 |
+
|
| 126 |
+
max_input_length = max_input_length_including_image_tokens - \
|
| 127 |
+
self.image_seq_length if max_input_length_including_image_tokens is not None else None
|
| 128 |
+
batch_input_texts = [self._format_text_with_bboxes(text, list_of_list_of_bboxes, image) for text, list_of_list_of_bboxes, image in zip(
|
| 129 |
+
batch_input_text, batch_input_list_of_list_of_bboxes, batch_images)]
|
| 130 |
inputs = self.tokenizer(
|
| 131 |
batch_input_texts,
|
| 132 |
return_tensors=return_tensors,
|
|
|
|
| 137 |
if inputs["input_ids"].shape[1] > max_input_length:
|
| 138 |
inputs["input_ids"] = inputs["input_ids"][:, :max_input_length]
|
| 139 |
inputs["attention_mask"] = inputs["attention_mask"][:, :max_input_length]
|
| 140 |
+
|
| 141 |
if batch_output_text is not None:
|
| 142 |
+
batch_output_texts = [self._format_text_with_bboxes(text, list_of_list_of_bboxes, image) for text, list_of_list_of_bboxes, image in zip(
|
| 143 |
+
batch_output_text, batch_output_list_of_list_of_bboxes, batch_images)]
|
| 144 |
decoder_inputs = self.tokenizer(
|
| 145 |
batch_output_texts,
|
| 146 |
return_tensors=return_tensors,
|
|
|
|
| 149 |
)
|
| 150 |
# Truncating manually because I don't want </s> token at the end of truncated sequences, which is the default behavior
|
| 151 |
if decoder_inputs["input_ids"].shape[1] > max_output_length:
|
| 152 |
+
decoder_inputs["input_ids"] = decoder_inputs["input_ids"][:,
|
| 153 |
+
:max_output_length]
|
| 154 |
+
decoder_inputs["attention_mask"] = decoder_inputs["attention_mask"][:,
|
| 155 |
+
:max_output_length]
|
| 156 |
|
| 157 |
pixel_values = self.image_processor(
|
| 158 |
batch_images,
|
|
|
|
| 169 |
|
| 170 |
if dtype is not None:
|
| 171 |
pixel_values = pixel_values.to(dtype)
|
| 172 |
+
|
| 173 |
return_data = {**inputs, "pixel_values": pixel_values}
|
| 174 |
|
| 175 |
if batch_output_text is not None:
|
|
|
|
| 177 |
decoder_input_ids = labels.new_zeros(labels.shape)
|
| 178 |
decoder_input_ids[:, 1:] = labels[:, :-1].clone()
|
| 179 |
decoder_input_ids[:, 0] = self.decoder_start_token_id
|
| 180 |
+
decoder_attention_mask = decoder_inputs["attention_mask"].new_ones(
|
| 181 |
+
decoder_input_ids.shape)
|
| 182 |
+
decoder_attention_mask[:,
|
| 183 |
+
1:] = decoder_inputs["attention_mask"][:, :-1].clone()
|
| 184 |
# Mask fill labels to replace pad token ID with -100
|
| 185 |
labels.masked_fill_(labels == self.tokenizer.pad_token_id, -100)
|
| 186 |
return_data.update({
|
|
|
|
| 188 |
"decoder_input_ids": decoder_input_ids,
|
| 189 |
"decoder_attention_mask": decoder_attention_mask,
|
| 190 |
})
|
| 191 |
+
|
| 192 |
if device is not None:
|
| 193 |
for key, value in return_data.items():
|
| 194 |
if isinstance(value, torch.Tensor):
|
|
|
|
| 212 |
return generated_text.replace("<s>", "").replace("</s>", "").replace("<pad>", "")
|
| 213 |
|
| 214 |
def postprocess_output(self, generated_ids, images):
    """Decode generated token ids and recover the bounding boxes they encode.

    Args:
        generated_ids: Tensor of generated token ids. Entries equal to -100
            are replaced by the tokenizer's pad token id before decoding;
            the caller's tensor is left untouched.
        images: Images paired one-to-one (via `zip`) with the decoded texts;
            each is only used to obtain its size for dequantisation.

    Returns:
        A tuple `(texts, bboxes, indices)` of three lists with one entry per
        decoded text: the parsed text, the dequantised boxes for every
        stringified box group found in it, and the start/end positions
        reported by `_parse_text_with_bboxes`.
    """
    # -100 only shows up in some testing scenarios. Use the out-of-place
    # masked_fill so the caller's tensor is not silently rewritten.
    decodable_ids = generated_ids.masked_fill(
        generated_ids == -100, self.tokenizer.pad_token_id)
    batch_decoded_texts = [
        self.cleanup_generated_text(text)
        for text in self.batch_decode(decodable_ids, skip_special_tokens=False)
    ]

    batch_new_texts = []
    batch_list_of_list_of_bboxes = []
    batch_indices_of_bboxes_in_new_string = []
    for text, image in zip(batch_decoded_texts, images):
        size_wh = self._get_image_size_wh(image)
        parsed_text, list_of_stringified_bboxes, start_end_in_new_string = \
            self._parse_text_with_bboxes(text)
        batch_list_of_list_of_bboxes.append([
            self.box_quantizer.dequantize_from_stringified_bboxes(
                stringified_bbox, size_wh)
            for stringified_bbox in list_of_stringified_bboxes
        ])
        batch_indices_of_bboxes_in_new_string.append(start_end_in_new_string)
        batch_new_texts.append(parsed_text)
    return batch_new_texts, batch_list_of_list_of_bboxes, batch_indices_of_bboxes_in_new_string
|
| 236 |
|
| 237 |
def _parse_text_with_bboxes(self, text):
|
| 238 |
loc_pattern = r'((?:<loc_\d+>){4}(?:,(?:<loc_\d+>){4})*)'
|
| 239 |
grounding_pattern = r'<grounding>(.*?)</grounding>' + loc_pattern
|
| 240 |
+
|
| 241 |
list_of_stringified_bboxes = []
|
| 242 |
start_end_in_new_string = []
|
| 243 |
new_text = ""
|
|
|
|
| 255 |
locs = match.group(2)
|
| 256 |
new_text += grounding_text
|
| 257 |
list_of_stringified_bboxes.append(locs)
|
| 258 |
+
start_end_in_new_string.append(
|
| 259 |
+
(new_pos, new_pos + len(grounding_text)))
|
| 260 |
new_pos += len(grounding_text)
|
| 261 |
else:
|
| 262 |
# Handle loc pattern
|
|
|
|
| 264 |
replacement = ""
|
| 265 |
new_text += replacement
|
| 266 |
list_of_stringified_bboxes.append(locs)
|
| 267 |
+
start_end_in_new_string.append(
|
| 268 |
+
(new_pos, new_pos + len(replacement)))
|
| 269 |
new_pos += len(replacement)
|
| 270 |
|
| 271 |
original_pos = match.end()
|
|
|
|
| 274 |
new_text += text[original_pos:]
|
| 275 |
|
| 276 |
return new_text, list_of_stringified_bboxes, start_end_in_new_string
|
| 277 |
+
|
| 278 |
def _format_text_with_bboxes(self, text, list_of_list_of_bboxes, image):
|
| 279 |
size_wh = self._get_image_size_wh(image)
|
| 280 |
quantized_bbox_lists = []
|
| 281 |
+
for list_of_bboxes in list_of_list_of_bboxes:
|
| 282 |
+
quantized_bboxes = self.box_quantizer.quantize(
|
| 283 |
+
list_of_bboxes, size_wh=size_wh)
|
| 284 |
+
stringified_bboxes = [
|
| 285 |
+
f"<loc_{x1}><loc_{y1}><loc_{x2}><loc_{y2}>" for x1, y1, x2, y2 in quantized_bboxes]
|
| 286 |
stringified_bboxes = ",".join(stringified_bboxes)
|
| 287 |
quantized_bbox_lists.append(stringified_bboxes)
|
| 288 |
return text.format(*quantized_bbox_lists)
|
| 289 |
|
| 290 |
def _get_image_size_wh(self, image):
|
| 291 |
+
# Get size_wh from image based on its type
|
| 292 |
if isinstance(image, torch.Tensor):
|
| 293 |
# For PyTorch tensor
|
| 294 |
if image.dim() == 3:
|
|
|
|
| 335 |
image_processor_input_names = self.image_processor.model_input_names
|
| 336 |
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
| 337 |
|
| 338 |
+
|
| 339 |
class BoxQuantizer(object):
|
| 340 |
def __init__(self, mode, bins):
|
| 341 |
self.mode = mode
|
|
|
|
| 413 |
dequantized_xmax, dequantized_ymax), dim=-1
|
| 414 |
)
|
| 415 |
|
| 416 |
+
return dequantized_boxes
|
utils.py
CHANGED
|
@@ -9,6 +9,7 @@ from copy import deepcopy
|
|
| 9 |
from itertools import groupby
|
| 10 |
from concurrent.futures import ThreadPoolExecutor, TimeoutError
|
| 11 |
|
|
|
|
| 12 |
def move_to_device(inputs, device):
|
| 13 |
if hasattr(inputs, "keys"):
|
| 14 |
return {k: move_to_device(v, device) for k, v in inputs.items()}
|
|
@@ -21,6 +22,7 @@ def move_to_device(inputs, device):
|
|
| 21 |
else:
|
| 22 |
return inputs.to(device)
|
| 23 |
|
|
|
|
| 24 |
class UnionFind:
|
| 25 |
def __init__(self, n):
|
| 26 |
self.parent = list(range(n))
|
|
@@ -35,7 +37,7 @@ class UnionFind:
|
|
| 35 |
if adj_matrix[i, j] > 0:
|
| 36 |
ufds.unite(i, j)
|
| 37 |
return ufds
|
| 38 |
-
|
| 39 |
@classmethod
|
| 40 |
def from_adj_list(cls, adj_list):
|
| 41 |
ufds = cls(len(adj_list))
|
|
@@ -43,7 +45,7 @@ class UnionFind:
|
|
| 43 |
for j in adj_list[i]:
|
| 44 |
ufds.unite(i, j)
|
| 45 |
return ufds
|
| 46 |
-
|
| 47 |
@classmethod
|
| 48 |
def from_edge_list(cls, edge_list, num_nodes):
|
| 49 |
ufds = cls(num_nodes)
|
|
@@ -66,11 +68,11 @@ class UnionFind:
|
|
| 66 |
self.parent[y] = x
|
| 67 |
self.size[x] += self.size[y]
|
| 68 |
self.num_components -= 1
|
| 69 |
-
|
| 70 |
def get_components_of(self, x):
    """Return the indices of all nodes whose `find` result matches that of `x`."""
    root = self.find(x)
    return [node for node in range(len(self.parent)) if self.find(node) == root]
|
| 73 |
-
|
| 74 |
def are_connected(self, x, y):
    """Return True when `find` yields the same result for `x` and `y`."""
    return self.find(x) == self.find(y)
|
| 76 |
|
|
@@ -79,7 +81,7 @@ class UnionFind:
|
|
| 79 |
|
| 80 |
def get_num_components(self):
    """Return the stored `num_components` counter."""
    return self.num_components
|
| 82 |
-
|
| 83 |
def get_labels_for_connected_components(self):
|
| 84 |
map_parent_to_label = {}
|
| 85 |
labels = []
|
|
@@ -90,6 +92,7 @@ class UnionFind:
|
|
| 90 |
labels.append(map_parent_to_label[parent])
|
| 91 |
return labels
|
| 92 |
|
|
|
|
| 93 |
def visualise_single_image_prediction(image_as_np_array, predictions, filename):
|
| 94 |
h, w = image_as_np_array.shape[:2]
|
| 95 |
if h > w:
|
|
@@ -102,15 +105,16 @@ def visualise_single_image_prediction(image_as_np_array, predictions, filename):
|
|
| 102 |
plot_bboxes(subplot, predictions["characters"], color="blue")
|
| 103 |
|
| 104 |
COLOURS = [
|
| 105 |
-
"#b7ff51",
|
| 106 |
-
"#f50a8f",
|
| 107 |
-
"#4b13b6",
|
| 108 |
-
"#ddaa34",
|
| 109 |
-
"#bea2a2",
|
| 110 |
]
|
| 111 |
colour_index = 0
|
| 112 |
character_cluster_labels = predictions["character_cluster_labels"]
|
| 113 |
-
unique_label_sorted_by_frequency = sorted(list(set(
|
|
|
|
| 114 |
for label in unique_label_sorted_by_frequency:
|
| 115 |
root = None
|
| 116 |
others = []
|
|
@@ -123,7 +127,9 @@ def visualise_single_image_prediction(image_as_np_array, predictions, filename):
|
|
| 123 |
if colour_index >= len(COLOURS):
|
| 124 |
random_colour = COLOURS[0]
|
| 125 |
while random_colour in COLOURS:
|
| 126 |
-
random_colour = "#" +
|
|
|
|
|
|
|
| 127 |
else:
|
| 128 |
random_colour = COLOURS[colour_index]
|
| 129 |
colour_index += 1
|
|
@@ -139,8 +145,9 @@ def visualise_single_image_prediction(image_as_np_array, predictions, filename):
|
|
| 139 |
x2 = bbox_j[0] + (bbox_j[2] - bbox_j[0]) / 2
|
| 140 |
y2 = bbox_j[1] + (bbox_j[3] - bbox_j[1]) / 2
|
| 141 |
subplot.plot([x1, x2], [y1, y2], color=random_colour, linewidth=2)
|
| 142 |
-
subplot.plot([x2], [y2], color=random_colour,
|
| 143 |
-
|
|
|
|
| 144 |
for (i, j) in predictions["text_character_associations"]:
|
| 145 |
score = predictions["dialog_confidences"][i]
|
| 146 |
bbox_i = predictions["texts"][i]
|
|
@@ -149,7 +156,8 @@ def visualise_single_image_prediction(image_as_np_array, predictions, filename):
|
|
| 149 |
y1 = bbox_i[1] + (bbox_i[3] - bbox_i[1]) / 2
|
| 150 |
x2 = bbox_j[0] + (bbox_j[2] - bbox_j[0]) / 2
|
| 151 |
y2 = bbox_j[1] + (bbox_j[3] - bbox_j[1]) / 2
|
| 152 |
-
subplot.plot([x1, x2], [y1, y2], color="red",
|
|
|
|
| 153 |
|
| 154 |
subplot.axis("off")
|
| 155 |
if filename is not None:
|
|
@@ -160,6 +168,7 @@ def visualise_single_image_prediction(image_as_np_array, predictions, filename):
|
|
| 160 |
plt.close()
|
| 161 |
return image
|
| 162 |
|
|
|
|
| 163 |
def plot_bboxes(subplot, bboxes, color="red", add_index=False):
|
| 164 |
for id, bbox in enumerate(bboxes):
|
| 165 |
w = bbox[2] - bbox[0]
|
|
@@ -170,7 +179,9 @@ def plot_bboxes(subplot, bboxes, color="red", add_index=False):
|
|
| 170 |
subplot.add_patch(rect)
|
| 171 |
if add_index:
|
| 172 |
cx, cy = bbox[0] + w / 2, bbox[1] + h / 2
|
| 173 |
-
subplot.text(cx, cy, str(id), color=color,
|
|
|
|
|
|
|
| 174 |
|
| 175 |
def sort_panels(rects):
|
| 176 |
before_rects = convert_to_list_of_lists(rects)
|
|
@@ -203,34 +214,42 @@ def sort_panels(rects):
|
|
| 203 |
G.remove_edge(*max_cyclic_edge)
|
| 204 |
return list(nx.topological_sort(G))
|
| 205 |
|
|
|
|
| 206 |
def is_strictly_above(rectA, rectB):
    """Return True when rectA's y2 is strictly smaller than rectB's y1.

    Both rectangles are 4-element `(x1, y1, x2, y2)` sequences. With y growing
    downwards (presumably image coordinates -- confirm with callers) this
    means A ends before B begins vertically; touching edges do not count.
    """
    # Full unpacking is kept so malformed rectangles still raise.
    _, _, _, y2A = rectA
    _, y1B, _, _ = rectB
    return y2A < y1B
|
| 210 |
|
|
|
|
| 211 |
def is_strictly_below(rectA, rectB):
    """Return True when rectB's y2 is strictly smaller than rectA's y1.

    Both rectangles are 4-element `(x1, y1, x2, y2)` sequences; touching
    edges do not count.
    """
    # Full unpacking is kept so malformed rectangles still raise.
    _, y1A, _, _ = rectA
    _, _, _, y2B = rectB
    return y2B < y1A
|
| 215 |
|
|
|
|
| 216 |
def is_strictly_left_of(rectA, rectB):
    """Return True when rectA's x2 is strictly smaller than rectB's x1.

    Both rectangles are 4-element `(x1, y1, x2, y2)` sequences; touching
    edges do not count.
    """
    # Full unpacking is kept so malformed rectangles still raise.
    _, _, x2A, _ = rectA
    x1B, _, _, _ = rectB
    return x2A < x1B
|
| 220 |
|
|
|
|
| 221 |
def is_strictly_right_of(rectA, rectB):
    """Return True when rectB's x2 is strictly smaller than rectA's x1.

    Both rectangles are 4-element `(x1, y1, x2, y2)` sequences; touching
    edges do not count.
    """
    # Full unpacking is kept so malformed rectangles still raise.
    x1A, _, _, _ = rectA
    _, _, x2B, _ = rectB
    return x2B < x1A
|
| 225 |
|
|
|
|
| 226 |
def intersects(rectA, rectB):
    """Return whether the `box` polygons built from the two rectangles intersect.

    Each rectangle is unpacked positionally into `box`
    (presumably shapely's `box(minx, miny, maxx, maxy)` -- confirm import).
    """
    return box(*rectA).intersects(box(*rectB))
|
| 228 |
|
|
|
|
| 229 |
def is_there_a_directed_edge(a, b, rects):
|
| 230 |
rectA = rects[a]
|
| 231 |
rectB = rects[b]
|
| 232 |
-
centre_of_A = [rectA[0] + (rectA[2] - rectA[0]) / 2,
|
| 233 |
-
|
|
|
|
|
|
|
| 234 |
if np.allclose(np.array(centre_of_A), np.array(centre_of_B)):
|
| 235 |
return box(*rectA).area > (box(*rectB)).area
|
| 236 |
copy_A = [rectA[0], rectA[1], rectA[2], rectA[3]]
|
|
@@ -247,34 +266,41 @@ def is_there_a_directed_edge(a, b, rects):
|
|
| 247 |
if is_strictly_below(copy_A, copy_B) and is_strictly_right_of(copy_A, copy_B):
|
| 248 |
return use_cuts_to_determine_edge_from_a_to_b(a, b, rects)
|
| 249 |
if is_strictly_below(copy_B, copy_A) and is_strictly_right_of(copy_B, copy_A):
|
| 250 |
-
|
| 251 |
# otherwise they intersect
|
| 252 |
copy_A = erode_rectangle(copy_A, 0.05)
|
| 253 |
copy_B = erode_rectangle(copy_B, 0.05)
|
| 254 |
-
|
|
|
|
| 255 |
def get_distance(rectA, rectB):
    """Return the distance between the `box` polygons of two rectangles.

    Only the first four entries of each rectangle are used, in order, as the
    arguments to `box` (presumably shapely's `box` -- confirm import).
    """
    polygon_a = box(rectA[0], rectA[1], rectA[2], rectA[3])
    polygon_b = box(rectB[0], rectB[1], rectB[2], rectB[3])
    return polygon_a.distance(polygon_b)
|
| 257 |
|
|
|
|
| 258 |
def use_cuts_to_determine_edge_from_a_to_b(a, b, rects):
|
| 259 |
rects = deepcopy(rects)
|
| 260 |
while True:
|
| 261 |
-
xmin, ymin, xmax, ymax = min(rects[a][0], rects[b][0]), min(
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
|
|
|
|
|
|
|
|
|
| 265 |
# try to split the panels using a "horizontal" lines
|
| 266 |
-
overlapping_y_ranges = merge_overlapping_ranges(
|
|
|
|
| 267 |
panel_index_to_split = {}
|
| 268 |
for split_index, (y1, y2) in enumerate(overlapping_y_ranges):
|
| 269 |
for i, index in enumerate(rect_index):
|
| 270 |
if y1 <= rects_copy[i][1] <= rects_copy[i][3] <= y2:
|
| 271 |
panel_index_to_split[index] = split_index
|
| 272 |
-
|
| 273 |
if panel_index_to_split[a] != panel_index_to_split[b]:
|
| 274 |
return panel_index_to_split[a] < panel_index_to_split[b]
|
| 275 |
-
|
| 276 |
# try to split the panels using a "vertical" lines
|
| 277 |
-
overlapping_x_ranges = merge_overlapping_ranges(
|
|
|
|
| 278 |
panel_index_to_split = {}
|
| 279 |
for split_index, (x1, x2) in enumerate(overlapping_x_ranges[::-1]):
|
| 280 |
for i, index in enumerate(rect_index):
|
|
@@ -282,10 +308,11 @@ def use_cuts_to_determine_edge_from_a_to_b(a, b, rects):
|
|
| 282 |
panel_index_to_split[index] = split_index
|
| 283 |
if panel_index_to_split[a] != panel_index_to_split[b]:
|
| 284 |
return panel_index_to_split[a] < panel_index_to_split[b]
|
| 285 |
-
|
| 286 |
# otherwise, erode the rectangles and try again
|
| 287 |
rects = [erode_rectangle(rect, 0.05) for rect in rects]
|
| 288 |
|
|
|
|
| 289 |
def erode_rectangle(bbox, erosion_factor):
|
| 290 |
x1, y1, x2, y2 = bbox
|
| 291 |
w, h = x2 - x1, y2 - y1
|
|
@@ -303,6 +330,7 @@ def erode_rectangle(bbox, erosion_factor):
|
|
| 303 |
x1, y1, x2, y2 = cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2
|
| 304 |
return [x1, y1, x2, y2]
|
| 305 |
|
|
|
|
| 306 |
def merge_overlapping_ranges(ranges):
|
| 307 |
"""
|
| 308 |
ranges: list of tuples (x1, x2)
|
|
@@ -324,6 +352,7 @@ def merge_overlapping_ranges(ranges):
|
|
| 324 |
merged_ranges.append((prev_x1, prev_x2))
|
| 325 |
return merged_ranges
|
| 326 |
|
|
|
|
| 327 |
def sort_text_boxes_in_reading_order(text_bboxes, sorted_panel_bboxes):
|
| 328 |
text_bboxes = convert_to_list_of_lists(text_bboxes)
|
| 329 |
sorted_panel_bboxes = convert_to_list_of_lists(sorted_panel_bboxes)
|
|
@@ -335,18 +364,23 @@ def sort_text_boxes_in_reading_order(text_bboxes, sorted_panel_bboxes):
|
|
| 335 |
groups = groupby(range(len(nums)), key=lambda i: nums[i])
|
| 336 |
return [list(indices) for _, indices in groups]
|
| 337 |
|
| 338 |
-
panel_id_for_text = get_text_to_panel_mapping(
|
|
|
|
| 339 |
indices_of_texts = list(range(len(text_bboxes)))
|
| 340 |
-
indices_of_texts, panel_id_for_text = zip(
|
|
|
|
| 341 |
indices_of_texts = list(indices_of_texts)
|
| 342 |
grouped_indices = indices_of_same_elements(panel_id_for_text)
|
| 343 |
for group in grouped_indices:
|
| 344 |
subset_of_text_indices = [indices_of_texts[i] for i in group]
|
| 345 |
-
text_bboxes_of_subset = [text_bboxes[i]
|
|
|
|
| 346 |
sorted_subset_indices = sort_texts_within_panel(text_bboxes_of_subset)
|
| 347 |
-
indices_of_texts[group[0]
|
|
|
|
| 348 |
return indices_of_texts
|
| 349 |
|
|
|
|
| 350 |
def get_text_to_panel_mapping(text_bboxes, sorted_panel_bboxes):
|
| 351 |
text_to_panel_mapping = []
|
| 352 |
for text_bbox in text_bboxes:
|
|
@@ -359,14 +393,19 @@ def get_text_to_panel_mapping(text_bboxes, sorted_panel_bboxes):
|
|
| 359 |
for j, annotation in enumerate(sorted_panel_bboxes):
|
| 360 |
shapely_annotation_polygon = box(*annotation)
|
| 361 |
if shapely_text_polygon.intersects(shapely_annotation_polygon):
|
| 362 |
-
all_intersections.append(
|
| 363 |
-
|
|
|
|
|
|
|
| 364 |
if len(all_intersections) == 0:
|
| 365 |
-
text_to_panel_mapping.append(
|
|
|
|
| 366 |
else:
|
| 367 |
-
text_to_panel_mapping.append(
|
|
|
|
| 368 |
return text_to_panel_mapping
|
| 369 |
|
|
|
|
| 370 |
def sort_texts_within_panel(rects):
|
| 371 |
smallest_y = float("inf")
|
| 372 |
greatest_x = float("-inf")
|
|
@@ -374,18 +413,20 @@ def sort_texts_within_panel(rects):
|
|
| 374 |
x1, y1, x2, y2 = rect
|
| 375 |
smallest_y = min(smallest_y, y1)
|
| 376 |
greatest_x = max(greatest_x, x2)
|
| 377 |
-
|
| 378 |
reference_point = Point(greatest_x, smallest_y)
|
| 379 |
|
| 380 |
polygons_and_index = []
|
| 381 |
for i, rect in enumerate(rects):
|
| 382 |
x1, y1, x2, y2 = rect
|
| 383 |
-
polygons_and_index.append((box(x1,y1,x2,y2), i))
|
| 384 |
# sort points by closest to reference point
|
| 385 |
-
polygons_and_index = sorted(
|
|
|
|
| 386 |
indices = [x[1] for x in polygons_and_index]
|
| 387 |
return indices
|
| 388 |
|
|
|
|
| 389 |
def force_to_be_valid_bboxes(bboxes):
|
| 390 |
if len(bboxes) == 0:
|
| 391 |
return bboxes
|
|
@@ -394,20 +435,24 @@ def force_to_be_valid_bboxes(bboxes):
|
|
| 394 |
bboxes_as_xywh[:, 2] = torch.clamp(bboxes_as_xywh[:, 2], min=1)
|
| 395 |
bboxes_as_xywh[:, 3] = torch.clamp(bboxes_as_xywh[:, 3], min=1)
|
| 396 |
bboxes_as_xywh = bboxes_as_xywh.tolist()
|
| 397 |
-
bboxes_as_xyxy = [[x1, y1, x1 + w, y1 + h]
|
|
|
|
| 398 |
return bboxes_as_xyxy
|
| 399 |
|
|
|
|
| 400 |
def x1y1wh_to_x1y1x2y2(bbox):
    """Convert an `(x1, y1, w, h)` box into an `[x1, y1, x2, y2]` list."""
    x1, y1, w, h = bbox
    return [x1, y1, x1 + w, y1 + h]
|
| 403 |
|
|
|
|
| 404 |
def x1y1x2y2_to_xywh(bbox):
    """Convert an `(x1, y1, x2, y2)` box into an `[x1, y1, w, h]` list."""
    x1, y1, x2, y2 = bbox
    return [x1, y1, x2 - x1, y2 - y1]
|
| 407 |
|
|
|
|
| 408 |
def convert_to_list_of_lists(rects):
    """Return `rects` as nested Python lists.

    Torch tensors and NumPy arrays are converted with `.tolist()`. Any other
    iterable must yield 4-element items, each of which becomes a 4-item list.
    """
    if isinstance(rects, (torch.Tensor, np.ndarray)):
        return rects.tolist()
    return [[a, b, c, d] for a, b, c, d in rects]
|
|
|
|
| 9 |
from itertools import groupby
|
| 10 |
from concurrent.futures import ThreadPoolExecutor, TimeoutError
|
| 11 |
|
| 12 |
+
|
| 13 |
def move_to_device(inputs, device):
|
| 14 |
if hasattr(inputs, "keys"):
|
| 15 |
return {k: move_to_device(v, device) for k, v in inputs.items()}
|
|
|
|
| 22 |
else:
|
| 23 |
return inputs.to(device)
|
| 24 |
|
| 25 |
+
|
| 26 |
class UnionFind:
|
| 27 |
def __init__(self, n):
|
| 28 |
self.parent = list(range(n))
|
|
|
|
| 37 |
if adj_matrix[i, j] > 0:
|
| 38 |
ufds.unite(i, j)
|
| 39 |
return ufds
|
| 40 |
+
|
| 41 |
@classmethod
|
| 42 |
def from_adj_list(cls, adj_list):
|
| 43 |
ufds = cls(len(adj_list))
|
|
|
|
| 45 |
for j in adj_list[i]:
|
| 46 |
ufds.unite(i, j)
|
| 47 |
return ufds
|
| 48 |
+
|
| 49 |
@classmethod
|
| 50 |
def from_edge_list(cls, edge_list, num_nodes):
|
| 51 |
ufds = cls(num_nodes)
|
|
|
|
| 68 |
self.parent[y] = x
|
| 69 |
self.size[x] += self.size[y]
|
| 70 |
self.num_components -= 1
|
| 71 |
+
|
| 72 |
def get_components_of(self, x):
    """Return the indices of all nodes whose `find` result matches that of `x`."""
    root = self.find(x)
    return [node for node in range(len(self.parent)) if self.find(node) == root]
|
| 75 |
+
|
| 76 |
def are_connected(self, x, y):
    """Return True when `find` yields the same result for `x` and `y`."""
    return self.find(x) == self.find(y)
|
| 78 |
|
|
|
|
| 81 |
|
| 82 |
def get_num_components(self):
    """Return the stored `num_components` counter."""
    return self.num_components
|
| 84 |
+
|
| 85 |
def get_labels_for_connected_components(self):
|
| 86 |
map_parent_to_label = {}
|
| 87 |
labels = []
|
|
|
|
| 92 |
labels.append(map_parent_to_label[parent])
|
| 93 |
return labels
|
| 94 |
|
| 95 |
+
|
| 96 |
def visualise_single_image_prediction(image_as_np_array, predictions, filename):
|
| 97 |
h, w = image_as_np_array.shape[:2]
|
| 98 |
if h > w:
|
|
|
|
| 105 |
plot_bboxes(subplot, predictions["characters"], color="blue")
|
| 106 |
|
| 107 |
COLOURS = [
|
| 108 |
+
"#b7ff51", # green
|
| 109 |
+
"#f50a8f", # pink
|
| 110 |
+
"#4b13b6", # purple
|
| 111 |
+
"#ddaa34", # orange
|
| 112 |
+
"#bea2a2", # brown
|
| 113 |
]
|
| 114 |
colour_index = 0
|
| 115 |
character_cluster_labels = predictions["character_cluster_labels"]
|
| 116 |
+
unique_label_sorted_by_frequency = sorted(list(set(
|
| 117 |
+
character_cluster_labels)), key=lambda x: character_cluster_labels.count(x), reverse=True)
|
| 118 |
for label in unique_label_sorted_by_frequency:
|
| 119 |
root = None
|
| 120 |
others = []
|
|
|
|
| 127 |
if colour_index >= len(COLOURS):
|
| 128 |
random_colour = COLOURS[0]
|
| 129 |
while random_colour in COLOURS:
|
| 130 |
+
random_colour = "#" + \
|
| 131 |
+
"".join([random.choice("0123456789ABCDEF")
|
| 132 |
+
for j in range(6)])
|
| 133 |
else:
|
| 134 |
random_colour = COLOURS[colour_index]
|
| 135 |
colour_index += 1
|
|
|
|
| 145 |
x2 = bbox_j[0] + (bbox_j[2] - bbox_j[0]) / 2
|
| 146 |
y2 = bbox_j[1] + (bbox_j[3] - bbox_j[1]) / 2
|
| 147 |
subplot.plot([x1, x2], [y1, y2], color=random_colour, linewidth=2)
|
| 148 |
+
subplot.plot([x2], [y2], color=random_colour,
|
| 149 |
+
marker="o", markersize=5)
|
| 150 |
+
|
| 151 |
for (i, j) in predictions["text_character_associations"]:
|
| 152 |
score = predictions["dialog_confidences"][i]
|
| 153 |
bbox_i = predictions["texts"][i]
|
|
|
|
| 156 |
y1 = bbox_i[1] + (bbox_i[3] - bbox_i[1]) / 2
|
| 157 |
x2 = bbox_j[0] + (bbox_j[2] - bbox_j[0]) / 2
|
| 158 |
y2 = bbox_j[1] + (bbox_j[3] - bbox_j[1]) / 2
|
| 159 |
+
subplot.plot([x1, x2], [y1, y2], color="red",
|
| 160 |
+
linewidth=2, linestyle="dashed", alpha=score)
|
| 161 |
|
| 162 |
subplot.axis("off")
|
| 163 |
if filename is not None:
|
|
|
|
| 168 |
plt.close()
|
| 169 |
return image
|
| 170 |
|
| 171 |
+
|
| 172 |
def plot_bboxes(subplot, bboxes, color="red", add_index=False):
|
| 173 |
for id, bbox in enumerate(bboxes):
|
| 174 |
w = bbox[2] - bbox[0]
|
|
|
|
| 179 |
subplot.add_patch(rect)
|
| 180 |
if add_index:
|
| 181 |
cx, cy = bbox[0] + w / 2, bbox[1] + h / 2
|
| 182 |
+
subplot.text(cx, cy, str(id), color=color,
|
| 183 |
+
fontsize=10, ha="center", va="center")
|
| 184 |
+
|
| 185 |
|
| 186 |
def sort_panels(rects):
|
| 187 |
before_rects = convert_to_list_of_lists(rects)
|
|
|
|
| 214 |
G.remove_edge(*max_cyclic_edge)
|
| 215 |
return list(nx.topological_sort(G))
|
| 216 |
|
| 217 |
+
|
| 218 |
def is_strictly_above(rectA, rectB):
    """Tell whether ``rectA`` ends before ``rectB`` starts along the y axis.

    Both arguments unpack into ``(x1, y1, x2, y2)``. The comparison is
    strict, so rectangles whose edges merely touch do not qualify.
    "Above" presumably assumes image coordinates where y grows downward.

    Args:
        rectA: First rectangle.
        rectB: Second rectangle.

    Returns:
        The result of ``rectA``'s ``y2`` ``<`` ``rectB``'s ``y1``.

    Raises:
        ValueError: If either rectangle does not unpack into four values.
    """
    # Full four-way unpacking is kept so malformed rectangles still fail
    # loudly; only the coordinates that matter are given names.
    _, _, _, y2A = rectA
    _, y1B, _, _ = rectB
    return y2A < y1B
|
| 222 |
|
| 223 |
+
|
| 224 |
def is_strictly_below(rectA, rectB):
    """Tell whether ``rectA`` starts after ``rectB`` ends along the y axis.

    Both arguments unpack into ``(x1, y1, x2, y2)``. The comparison is
    strict, so rectangles whose edges merely touch do not qualify.
    "Below" presumably assumes image coordinates where y grows downward.

    Args:
        rectA: First rectangle.
        rectB: Second rectangle.

    Returns:
        The result of ``rectB``'s ``y2`` ``<`` ``rectA``'s ``y1``.

    Raises:
        ValueError: If either rectangle does not unpack into four values.
    """
    # Full four-way unpacking is kept so malformed rectangles still fail
    # loudly; only the coordinates that matter are given names.
    _, y1A, _, _ = rectA
    _, _, _, y2B = rectB
    return y2B < y1A
|
| 228 |
|
| 229 |
+
|
| 230 |
def is_strictly_left_of(rectA, rectB):
    """Tell whether ``rectA`` ends before ``rectB`` starts along the x axis.

    Both arguments unpack into ``(x1, y1, x2, y2)``. The comparison is
    strict, so rectangles whose edges merely touch do not qualify.

    Args:
        rectA: First rectangle.
        rectB: Second rectangle.

    Returns:
        The result of ``rectA``'s ``x2`` ``<`` ``rectB``'s ``x1``.

    Raises:
        ValueError: If either rectangle does not unpack into four values.
    """
    # Full four-way unpacking is kept so malformed rectangles still fail
    # loudly; only the coordinates that matter are given names.
    _, _, x2A, _ = rectA
    x1B, _, _, _ = rectB
    return x2A < x1B
|
| 234 |
|
| 235 |
+
|
| 236 |
def is_strictly_right_of(rectA, rectB):
    """Tell whether ``rectA`` starts after ``rectB`` ends along the x axis.

    Both arguments unpack into ``(x1, y1, x2, y2)``. The comparison is
    strict, so rectangles whose edges merely touch do not qualify.

    Args:
        rectA: First rectangle.
        rectB: Second rectangle.

    Returns:
        The result of ``rectB``'s ``x2`` ``<`` ``rectA``'s ``x1``.

    Raises:
        ValueError: If either rectangle does not unpack into four values.
    """
    # Full four-way unpacking is kept so malformed rectangles still fail
    # loudly; only the coordinates that matter are given names.
    x1A, _, _, _ = rectA
    _, _, x2B, _ = rectB
    return x2B < x1A
|
| 240 |
|
| 241 |
+
|
| 242 |
def intersects(rectA, rectB):
    """Report whether two rectangles intersect.

    Each rectangle is splatted into ``box(...)`` and the decision is
    delegated to the resulting geometry's ``intersects`` method.
    NOTE(review): ``box`` is presumably ``shapely.geometry.box`` imported at
    the top of the file, in which case touching edges count as an
    intersection -- confirm.

    Args:
        rectA: Sequence of positional arguments for ``box``
            (``x1, y1, x2, y2`` elsewhere in this file).
        rectB: Same as ``rectA`` for the second rectangle.

    Returns:
        Whatever ``box(*rectA).intersects(box(*rectB))`` returns.
    """
    return box(*rectA).intersects(box(*rectB))
|
| 244 |
|
| 245 |
+
|
| 246 |
def is_there_a_directed_edge(a, b, rects):
|
| 247 |
rectA = rects[a]
|
| 248 |
rectB = rects[b]
|
| 249 |
+
centre_of_A = [rectA[0] + (rectA[2] - rectA[0]) / 2,
|
| 250 |
+
rectA[1] + (rectA[3] - rectA[1]) / 2]
|
| 251 |
+
centre_of_B = [rectB[0] + (rectB[2] - rectB[0]) / 2,
|
| 252 |
+
rectB[1] + (rectB[3] - rectB[1]) / 2]
|
| 253 |
if np.allclose(np.array(centre_of_A), np.array(centre_of_B)):
|
| 254 |
return box(*rectA).area > (box(*rectB)).area
|
| 255 |
copy_A = [rectA[0], rectA[1], rectA[2], rectA[3]]
|
|
|
|
| 266 |
if is_strictly_below(copy_A, copy_B) and is_strictly_right_of(copy_A, copy_B):
|
| 267 |
return use_cuts_to_determine_edge_from_a_to_b(a, b, rects)
|
| 268 |
if is_strictly_below(copy_B, copy_A) and is_strictly_right_of(copy_B, copy_A):
|
| 269 |
+
return use_cuts_to_determine_edge_from_a_to_b(a, b, rects)
|
| 270 |
# otherwise they intersect
|
| 271 |
copy_A = erode_rectangle(copy_A, 0.05)
|
| 272 |
copy_B = erode_rectangle(copy_B, 0.05)
|
| 273 |
+
|
| 274 |
+
|
| 275 |
def get_distance(rectA, rectB):
    """Return the distance between two rectangles.

    The first four entries of each rectangle are passed to ``box(...)`` and
    the result is whatever the geometry's ``distance`` method reports.
    NOTE(review): ``box`` is presumably ``shapely.geometry.box``, whose
    ``distance`` is the minimum separation (0 when the shapes touch or
    overlap) -- confirm.

    Args:
        rectA: Indexable with at least four entries ``(x1, y1, x2, y2)``.
        rectB: Same as ``rectA`` for the second rectangle.

    Returns:
        The value of ``box(...).distance(box(...))`` for the two rectangles.
    """
    # Indexing (rather than splatting) means any entries past the fourth
    # are ignored instead of being forwarded to ``box``.
    polygon_a = box(rectA[0], rectA[1], rectA[2], rectA[3])
    polygon_b = box(rectB[0], rectB[1], rectB[2], rectB[3])
    return polygon_a.distance(polygon_b)
|
| 277 |
|
| 278 |
+
|
| 279 |
def use_cuts_to_determine_edge_from_a_to_b(a, b, rects):
|
| 280 |
rects = deepcopy(rects)
|
| 281 |
while True:
|
| 282 |
+
xmin, ymin, xmax, ymax = min(rects[a][0], rects[b][0]), min(
|
| 283 |
+
rects[a][1], rects[b][1]), max(rects[a][2], rects[b][2]), max(rects[a][3], rects[b][3])
|
| 284 |
+
rect_index = [i for i in range(len(rects)) if intersects(
|
| 285 |
+
rects[i], [xmin, ymin, xmax, ymax])]
|
| 286 |
+
rects_copy = [rect for rect in rects if intersects(
|
| 287 |
+
rect, [xmin, ymin, xmax, ymax])]
|
| 288 |
+
|
| 289 |
# try to split the panels using a "horizontal" lines
|
| 290 |
+
overlapping_y_ranges = merge_overlapping_ranges(
|
| 291 |
+
[(y1, y2) for x1, y1, x2, y2 in rects_copy])
|
| 292 |
panel_index_to_split = {}
|
| 293 |
for split_index, (y1, y2) in enumerate(overlapping_y_ranges):
|
| 294 |
for i, index in enumerate(rect_index):
|
| 295 |
if y1 <= rects_copy[i][1] <= rects_copy[i][3] <= y2:
|
| 296 |
panel_index_to_split[index] = split_index
|
| 297 |
+
|
| 298 |
if panel_index_to_split[a] != panel_index_to_split[b]:
|
| 299 |
return panel_index_to_split[a] < panel_index_to_split[b]
|
| 300 |
+
|
| 301 |
# try to split the panels using a "vertical" lines
|
| 302 |
+
overlapping_x_ranges = merge_overlapping_ranges(
|
| 303 |
+
[(x1, x2) for x1, y1, x2, y2 in rects_copy])
|
| 304 |
panel_index_to_split = {}
|
| 305 |
for split_index, (x1, x2) in enumerate(overlapping_x_ranges[::-1]):
|
| 306 |
for i, index in enumerate(rect_index):
|
|
|
|
| 308 |
panel_index_to_split[index] = split_index
|
| 309 |
if panel_index_to_split[a] != panel_index_to_split[b]:
|
| 310 |
return panel_index_to_split[a] < panel_index_to_split[b]
|
| 311 |
+
|
| 312 |
# otherwise, erode the rectangles and try again
|
| 313 |
rects = [erode_rectangle(rect, 0.05) for rect in rects]
|
| 314 |
|
| 315 |
+
|
| 316 |
def erode_rectangle(bbox, erosion_factor):
|
| 317 |
x1, y1, x2, y2 = bbox
|
| 318 |
w, h = x2 - x1, y2 - y1
|
|
|
|
| 330 |
x1, y1, x2, y2 = cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2
|
| 331 |
return [x1, y1, x2, y2]
|
| 332 |
|
| 333 |
+
|
| 334 |
def merge_overlapping_ranges(ranges):
|
| 335 |
"""
|
| 336 |
ranges: list of tuples (x1, x2)
|
|
|
|
| 352 |
merged_ranges.append((prev_x1, prev_x2))
|
| 353 |
return merged_ranges
|
| 354 |
|
| 355 |
+
|
| 356 |
def sort_text_boxes_in_reading_order(text_bboxes, sorted_panel_bboxes):
|
| 357 |
text_bboxes = convert_to_list_of_lists(text_bboxes)
|
| 358 |
sorted_panel_bboxes = convert_to_list_of_lists(sorted_panel_bboxes)
|
|
|
|
| 364 |
groups = groupby(range(len(nums)), key=lambda i: nums[i])
|
| 365 |
return [list(indices) for _, indices in groups]
|
| 366 |
|
| 367 |
+
panel_id_for_text = get_text_to_panel_mapping(
|
| 368 |
+
text_bboxes, sorted_panel_bboxes)
|
| 369 |
indices_of_texts = list(range(len(text_bboxes)))
|
| 370 |
+
indices_of_texts, panel_id_for_text = zip(
|
| 371 |
+
*sorted(zip(indices_of_texts, panel_id_for_text), key=lambda x: x[1]))
|
| 372 |
indices_of_texts = list(indices_of_texts)
|
| 373 |
grouped_indices = indices_of_same_elements(panel_id_for_text)
|
| 374 |
for group in grouped_indices:
|
| 375 |
subset_of_text_indices = [indices_of_texts[i] for i in group]
|
| 376 |
+
text_bboxes_of_subset = [text_bboxes[i]
|
| 377 |
+
for i in subset_of_text_indices]
|
| 378 |
sorted_subset_indices = sort_texts_within_panel(text_bboxes_of_subset)
|
| 379 |
+
indices_of_texts[group[0]: group[-1] + 1] = [subset_of_text_indices[i]
|
| 380 |
+
for i in sorted_subset_indices]
|
| 381 |
return indices_of_texts
|
| 382 |
|
| 383 |
+
|
| 384 |
def get_text_to_panel_mapping(text_bboxes, sorted_panel_bboxes):
|
| 385 |
text_to_panel_mapping = []
|
| 386 |
for text_bbox in text_bboxes:
|
|
|
|
| 393 |
for j, annotation in enumerate(sorted_panel_bboxes):
|
| 394 |
shapely_annotation_polygon = box(*annotation)
|
| 395 |
if shapely_text_polygon.intersects(shapely_annotation_polygon):
|
| 396 |
+
all_intersections.append(
|
| 397 |
+
(shapely_text_polygon.intersection(shapely_annotation_polygon).area, j))
|
| 398 |
+
all_distances.append(
|
| 399 |
+
(shapely_text_polygon.distance(shapely_annotation_polygon), j))
|
| 400 |
if len(all_intersections) == 0:
|
| 401 |
+
text_to_panel_mapping.append(
|
| 402 |
+
min(all_distances, key=lambda x: x[0])[1])
|
| 403 |
else:
|
| 404 |
+
text_to_panel_mapping.append(
|
| 405 |
+
max(all_intersections, key=lambda x: x[0])[1])
|
| 406 |
return text_to_panel_mapping
|
| 407 |
|
| 408 |
+
|
| 409 |
def sort_texts_within_panel(rects):
|
| 410 |
smallest_y = float("inf")
|
| 411 |
greatest_x = float("-inf")
|
|
|
|
| 413 |
x1, y1, x2, y2 = rect
|
| 414 |
smallest_y = min(smallest_y, y1)
|
| 415 |
greatest_x = max(greatest_x, x2)
|
| 416 |
+
|
| 417 |
reference_point = Point(greatest_x, smallest_y)
|
| 418 |
|
| 419 |
polygons_and_index = []
|
| 420 |
for i, rect in enumerate(rects):
|
| 421 |
x1, y1, x2, y2 = rect
|
| 422 |
+
polygons_and_index.append((box(x1, y1, x2, y2), i))
|
| 423 |
# sort points by closest to reference point
|
| 424 |
+
polygons_and_index = sorted(
|
| 425 |
+
polygons_and_index, key=lambda x: reference_point.distance(x[0]))
|
| 426 |
indices = [x[1] for x in polygons_and_index]
|
| 427 |
return indices
|
| 428 |
|
| 429 |
+
|
| 430 |
def force_to_be_valid_bboxes(bboxes):
|
| 431 |
if len(bboxes) == 0:
|
| 432 |
return bboxes
|
|
|
|
| 435 |
bboxes_as_xywh[:, 2] = torch.clamp(bboxes_as_xywh[:, 2], min=1)
|
| 436 |
bboxes_as_xywh[:, 3] = torch.clamp(bboxes_as_xywh[:, 3], min=1)
|
| 437 |
bboxes_as_xywh = bboxes_as_xywh.tolist()
|
| 438 |
+
bboxes_as_xyxy = [[x1, y1, x1 + w, y1 + h]
|
| 439 |
+
for x1, y1, w, h in bboxes_as_xywh]
|
| 440 |
return bboxes_as_xyxy
|
| 441 |
|
| 442 |
+
|
| 443 |
def x1y1wh_to_x1y1x2y2(bbox):
    """Convert an origin-plus-size box into corner format.

    Args:
        bbox: Iterable that unpacks into exactly four values
            ``(x1, y1, w, h)``.

    Returns:
        A new list ``[x1, y1, x1 + w, y1 + h]``.

    Raises:
        ValueError: If ``bbox`` does not unpack into exactly four values.
    """
    x1, y1, w, h = bbox
    return [x1, y1, x1 + w, y1 + h]
|
| 446 |
|
| 447 |
+
|
| 448 |
def x1y1x2y2_to_xywh(bbox):
    """Convert a corner-format box into origin-plus-size format.

    Args:
        bbox: Iterable that unpacks into exactly four values
            ``(x1, y1, x2, y2)``.

    Returns:
        A new list ``[x1, y1, x2 - x1, y2 - y1]``.

    Raises:
        ValueError: If ``bbox`` does not unpack into exactly four values.
    """
    x1, y1, x2, y2 = bbox
    return [x1, y1, x2 - x1, y2 - y1]
|
| 451 |
|
| 452 |
+
|
| 453 |
def convert_to_list_of_lists(rects):
    """Return ``rects`` as a plain nested Python list.

    Args:
        rects: A ``torch.Tensor`` or ``numpy.ndarray`` (converted with its
            own ``.tolist()``), or any iterable whose items each unpack
            into exactly four values.

    Returns:
        For tensors and arrays, whatever ``.tolist()`` yields. For any
        other iterable, a new list holding one fresh four-element list per
        item.

    Raises:
        ValueError: If ``rects`` is a generic iterable and one of its items
            does not unpack into exactly four values.
    """
    # Tensors and arrays share the same conversion, so test for both at once.
    if isinstance(rects, (torch.Tensor, np.ndarray)):
        return rects.tolist()
    # Re-pack each row so the result never aliases the caller's rows and
    # tuple rows come back as lists.
    return [[a, b, c, d] for a, b, c, d in rects]
|