Mateusz Mróz committed on
Commit
bfabb11
·
1 Parent(s): 848f81c

Refactor Florence2Processor and utils.py for improved readability and maintainability; add colab setup script with necessary patches and dependencies for model execution.

Browse files
Files changed (5) hide show
  1. colab copy.py +187 -0
  2. colab.py +52 -16
  3. modeling_florence2.py +382 -209
  4. processing_florence2.py +68 -45
  5. utils.py +87 -42
colab copy.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @title Ostateczne Uruchomienie `magiv3` z Wymaganymi Poprawkami
2
+ # To nie działa, ale pobiera jakieś zależności, by działało colab.py
3
+ import os
4
+ import sys
5
+ import re
6
+ import requests
7
+ import torch
8
+ from PIL import Image, ImageDraw, ImageFont
9
+ import numpy as np
10
+ import json
11
+ from IPython.display import display
12
+ import warnings
13
+
14
+ # --- KROK 1: PRZYGOTOWANIE ŚRODOWISKA Z POPRAWKAMI ---
15
+
16
+ print("--- KROK 1: PRZYGOTOWANIE ŚRODOWISKA Z POPRAWKAMI ---")
17
+
18
+ # --- Instalacja zależności ---
19
+ print("⏳ Instaluję `uv` i wszystkie potrzebne pakiety...")
20
+ !curl -LsSf https://astral.sh/uv/install.sh | sh
21
+ os.environ['PATH'] = f"/root/.local/bin:{os.environ['PATH']}"
22
+ !uv pip install --quiet transformers accelerate einops timm scipy tokenizers pulp torch pytorch-metric-learning Pillow requests shapely
23
+ print("✅ Zależności zainstalowane.")
24
+
25
+ # --- Klonowanie repozytorium ---
26
+ repo_path = "/content/magiv3"
27
+ print(f"\n⏳ Klonuję repozytorium do folderu `{repo_path}`...")
28
+ if os.path.exists(repo_path):
29
+ !rm -rf {repo_path}
30
+ !git clone https://huggingface.co/ragavsachdeva/magiv3 {repo_path}
31
+ print("✅ Repozytorium sklonowane.")
32
+
33
+ # --- OSTATECZNA, KOMPLEKSOWA POPRAWKA KODU ---
34
+ file_to_patch = os.path.join(repo_path, "modeling_florence2.py")
35
+ print(f"\n⏳ Nanoszę wszystkie wymagane poprawki na plik `{file_to_patch}`...")
36
+
37
+ try:
38
+ with open(file_to_patch, 'r', encoding='utf-8') as f:
39
+ content = f.read()
40
+
41
+ # Poprawka 1: Dodanie importu GenerationMixin
42
+ if "from transformers.generation.utils import GenerationMixin" not in content:
43
+ content = content.replace(
44
+ "from transformers.modeling_utils import PreTrainedModel",
45
+ "from transformers.generation.utils import GenerationMixin\nfrom transformers.modeling_utils import PreTrainedModel"
46
+ )
47
+ print("PATCH 1: Dodano import `GenerationMixin`.")
48
+
49
+ # Poprawka 2: Naprawa klasy bazowej modelu językowego (dla metody .generate)
50
+ original_lang_class = "class Florence2LanguagePreTrainedModel(PreTrainedModel):"
51
+ patched_lang_class = "class Florence2LanguagePreTrainedModel(GenerationMixin, PreTrainedModel):"
52
+ if original_lang_class in content:
53
+ content = content.replace(original_lang_class, patched_lang_class)
54
+ print("PATCH 2: Poprawiono dziedziczenie `Florence2LanguagePreTrainedModel`.")
55
+
56
+ # Poprawka 3: Usunięcie wadliwych właściwości, które powodują błąd inicjalizacji
57
+ faulty_property_block = r"""
58
+ @property
59
+ def _supports_flash_attn_2\(self\):.*?return self.language_model._supports_flash_attn_2
60
+
61
+ @property
62
+ def _supports_sdpa\(self\):.*?return self.language_model._supports_sdpa"""
63
+
64
+ if re.search(faulty_property_block, content, flags=re.DOTALL):
65
+ content = re.sub(faulty_property_block, "", content, flags=re.DOTALL)
66
+ print("PATCH 3: Usunięto wadliwe właściwości `@property`.")
67
+
68
+ # Poprawka 4: Naprawa błędu nadpisywania modelu w __init__
69
+ faulty_init_block = r""" language_model = Florence2LanguageForConditionalGeneration\(config=config.text_config\)
70
+
71
+ if language_model._tied_weights_keys is not None:
72
+ self._tied_weights_keys = \[f"language_model.{k}" for k in language_model._tied_weights_keys\]
73
+ self.language_model = language_model"""
74
+
75
+ correct_init_block = r""" # This line is intentionally left blank.
76
+ # The language_model is already initialized by the parent class.
77
+ # The original code had a bug here that overwrote the pretrained language model.
78
+ if self.language_model._tied_weights_keys is not None:
79
+ self._tied_weights_keys = [f"language_model.{k}" for k in self.language_model._tied_weights_keys]"""
80
+
81
+ if re.search(faulty_init_block, content, flags=re.DOTALL):
82
+ content = re.sub(faulty_init_block, correct_init_block, content, flags=re.DOTALL)
83
+ print("PATCH 4: Naprawiono błąd nadpisywania modelu w `__init__`.")
84
+
85
+ with open(file_to_patch, 'w', encoding='utf-8') as f:
86
+ f.write(content)
87
+ print("\n✅ Wszystkie poprawki zostały pomyślnie naniesione!")
88
+
89
+ except Exception as e:
90
+ print(f"❌ Wystąpił krytyczny błąd podczas patchowania pliku: {e}")
91
+ sys.exit()
92
+
93
+
94
+ # --- KROK 2: POBRANIE OBRAZKA TESTOWEGO ---
95
+
96
+ print("\n--- KROK 2: POBRANIE OBRAZKA TESTOWEGO ---")
97
+ IMAGE_URL = "https://raw.githubusercontent.com/MattyMroz/Manga_Whisperer/refs/heads/main/input/raw/04.jpg"
98
+ IMAGE_PATH = "/content/test_image.jpg"
99
+
100
+ try:
101
+ response = requests.get(IMAGE_URL)
102
+ response.raise_for_status()
103
+ with open(IMAGE_PATH, 'wb') as f:
104
+ f.write(response.content)
105
+ print(f"✅ Obrazek testowy został pomyślnie pobrany i zapisany jako `{IMAGE_PATH}`.")
106
+ display(Image.open(IMAGE_PATH).resize((300, 400)))
107
+ except Exception as e:
108
+ print(f"❌ Nie udało się pobrać obrazka. Błąd: {e}")
109
+ sys.exit()
110
+
111
+
112
+ # --- KROK 3: URUCHOMIENIE POPRAWIONEGO MODELU ---
113
+
114
+ print("\n--- KROK 3: URUCHOMIENIE POPRAWIONEGO MODELU ---")
115
+ warnings.filterwarnings("ignore", category=FutureWarning)
116
+
117
+ # Dodajemy poprawiony kod do ścieżki Pythona
118
+ if repo_path not in sys.path:
119
+ sys.path.insert(0, repo_path)
120
+
121
+ # Importujemy klasy z naszego poprawionego kodu
122
+ from magiv3.modeling_florence2 import Florence2ForConditionalGeneration
123
+ from transformers import AutoProcessor
124
+
125
+ model = None
126
+ processor = None
127
+
128
+ try:
129
+ print(f"⏳ Ładowanie modelu i procesora (z użyciem poprawionego kodu)...")
130
+ model_id = "ragavsachdeva/magiv3"
131
+ device = "cuda" if torch.cuda.is_available() else "cpu"
132
+
133
+ # Używamy AutoProcessor, ale dla modelu musimy wskazać naszą poprawioną klasę
134
+ processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
135
+ model = Florence2ForConditionalGeneration.from_pretrained(
136
+ model_id,
137
+ torch_dtype=torch.float16,
138
+ trust_remote_code=True
139
+ ).to(device).eval()
140
+
141
+ print("✅ Model i procesor załadowane pomyślnie.")
142
+
143
+ except Exception as e:
144
+ print(f"\n❌ Wystąpił błąd podczas ładowania modelu, nawet po poprawkach: {e}")
145
+ sys.exit()
146
+
147
+ # Uruchamiamy predykcję
148
+ if model and processor:
149
+ try:
150
+ print("\n⏳ Przygotowuję dane wejściowe...")
151
+ images = [Image.open(IMAGE_PATH).convert("RGB")]
152
+ np_images = [np.array(img) for img in images]
153
+ print("✅ Dane wejściowe gotowe.")
154
+
155
+ print("\n⏳ Uruchamiam `predict_detections_and_associations`...")
156
+ with torch.no_grad():
157
+ results = model.predict_detections_and_associations(np_images, processor)
158
+
159
+ print("✅ Przetwarzanie zakończone pomyślnie!")
160
+ print("\n--- WYNIKI ---")
161
+
162
+ # Funkcja do wizualizacji
163
+ def visualize_results(image, results):
164
+ colors = {"panels": "red", "texts": "blue", "characters": "green", "tails": "yellow"}
165
+ draw_image = image.copy()
166
+ draw = ImageDraw.Draw(draw_image)
167
+ try:
168
+ font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 15)
169
+ except IOError:
170
+ font = ImageFont.load_default()
171
+
172
+ for category, bboxes in results.items():
173
+ if category not in colors: continue
174
+ for i, box in enumerate(bboxes):
175
+ draw.rectangle(box, outline=colors[category], width=3)
176
+ draw.text((box[0], box[1]), f"{category}_{i}", fill=colors[category], font=font)
177
+ return draw_image
178
+
179
+ visualized_image = visualize_results(images[0], results[0])
180
+ display(visualized_image)
181
+
182
+ serializable_results = {k: (v.tolist() if isinstance(v, torch.Tensor) else v) for k, v in results[0].items()}
183
+ print(json.dumps(serializable_results, indent=2))
184
+
185
+ except Exception as e:
186
+ print(f"\n❌ WYSTĄPIŁ KRYTYCZNY BŁĄD PODCZAS PRZETWARZANIA:")
187
+ print(f"Błąd: {e}")
colab.py CHANGED
@@ -1,20 +1,20 @@
1
  # ==============================================================================
2
  # 1) INSTALACJA PAKIETÓW
3
  # ==============================================================================
 
 
 
 
 
 
 
 
 
4
  !pip -q install -U "transformers" "huggingface_hub" "accelerate" "timm" "sentencepiece" "safensors" "pillow" "einops" "pytorch_metric_learning"
5
 
6
  # ==============================================================================
7
  # 2) IMPORTY
8
  # ==============================================================================
9
- import re
10
- import requests
11
- import torch
12
- import json
13
- import math
14
- from PIL import Image, ImageDraw
15
- from io import BytesIO
16
- from IPython.display import display
17
- from transformers import AutoProcessor, AutoModelForCausalLM, AutoConfig
18
 
19
  # ==============================================================================
20
  # 3) POBRANIE OBRAZU
@@ -62,9 +62,15 @@ print("Model i procesor załadowane pomyślnie.")
62
  # ==============================================================================
63
 
64
 
65
- def create_visualization(image, data):
66
  """
67
  Rysuje zaawansowaną wizualizację detekcji i asocjacji na obrazie.
 
 
 
 
 
 
68
  """
69
  img_draw = image.copy()
70
  draw = ImageDraw.Draw(img_draw)
@@ -77,8 +83,10 @@ def create_visualization(image, data):
77
  "tails": "purple",
78
  "cluster_colors": ["#f50a8f", "#4b13b6", "#ddaa34", "#b7ff51", "#bea2a2"],
79
  "speaker_line": "magenta",
 
 
80
  }
81
- line_widths = {"panels": 2, "texts": 1, "characters": 2, "tails": 1}
82
 
83
  def get_box_center(box):
84
  x1, y1, x2, y2 = box
@@ -131,10 +139,35 @@ def create_visualization(image, data):
131
  draw_dashed_line(
132
  draw, p1, p2, fill=colors["speaker_line"], width=1)
133
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  return img_draw
135
 
136
 
137
- def process_image(image, caption_for_grounding="elf girl"):
 
 
 
 
 
 
 
 
 
138
  print("\n--- Rozpoczynanie przetwarzania obrazu ---")
139
  images = [image]
140
  captions = [caption_for_grounding]
@@ -157,8 +190,9 @@ def process_image(image, caption_for_grounding="elf girl"):
157
  "grounding": [{"phrase": grounding_results.get("grounded_caption", "")[start:end], "boxes": boxes} for boxes, (start, end) in zip(grounding_results.get("bboxes", []), grounding_results.get("indices_of_bboxes_in_caption", []))]
158
  }
159
 
160
- print("Tworzenie zaawansowanej wizualizacji...")
161
- visualization_image = create_visualization(image, final_json)
 
162
 
163
  print("--- Zakończono przetwarzanie ---")
164
  return final_json, visualization_image
@@ -167,9 +201,11 @@ def process_image(image, caption_for_grounding="elf girl"):
167
  # 6) URUCHOMIENIE I WYŚWIETLENIE WYNIKÓW
168
  # ==============================================================================
169
 
170
-
 
 
171
  json_output, image_output = process_image(
172
- pil_img, caption_for_grounding="elf girl")
173
 
174
  print("\n\n===== WYNIKI W FORMACIE JSON (przed filtrowaniem) =====")
175
  print(json.dumps(json_output, indent=2))
 
1
  # ==============================================================================
2
  # 1) INSTALACJA PAKIETÓW
3
  # ==============================================================================
4
+ from transformers import AutoProcessor, AutoModelForCausalLM, AutoConfig
5
+ from IPython.display import display
6
+ from io import BytesIO
7
+ from PIL import Image, ImageDraw
8
+ import math
9
+ import json
10
+ import torch
11
+ import requests
12
+ import re
13
  !pip -q install -U "transformers" "huggingface_hub" "accelerate" "timm" "sentencepiece" "safensors" "pillow" "einops" "pytorch_metric_learning"
14
 
15
  # ==============================================================================
16
  # 2) IMPORTY
17
  # ==============================================================================
 
 
 
 
 
 
 
 
 
18
 
19
  # ==============================================================================
20
  # 3) POBRANIE OBRAZU
 
62
  # ==============================================================================
63
 
64
 
65
+ def create_visualization(image, data, detailed_mode=False):
66
  """
67
  Rysuje zaawansowaną wizualizację detekcji i asocjacji na obrazie.
68
+
69
+ Args:
70
+ image: Obraz wejściowy
71
+ data: Dane JSON z wynikami
72
+ detailed_mode: Jeśli True, rysuje wszystko z JSON (OCR, grounding).
73
+ Jeśli False (domyślnie), rysuje tylko detekcje i asocjacje.
74
  """
75
  img_draw = image.copy()
76
  draw = ImageDraw.Draw(img_draw)
 
83
  "tails": "purple",
84
  "cluster_colors": ["#f50a8f", "#4b13b6", "#ddaa34", "#b7ff51", "#bea2a2"],
85
  "speaker_line": "magenta",
86
+ "ocr": "orange",
87
+ "grounding": "cyan",
88
  }
89
+ line_widths = {"panels": 2, "texts": 1, "characters": 2, "tails": 1, "ocr": 2, "grounding": 2}
90
 
91
  def get_box_center(box):
92
  x1, y1, x2, y2 = box
 
139
  draw_dashed_line(
140
  draw, p1, p2, fill=colors["speaker_line"], width=1)
141
 
142
+ # Tryb wybredny - rysowanie dodatkowych elementów z JSON
143
+ if detailed_mode:
144
+ # Rysowanie OCR boxes
145
+ ocr_data = data.get("ocr", [])
146
+ for ocr_item in ocr_data:
147
+ box = ocr_item.get("box")
148
+ if box:
149
+ draw.rectangle(box, outline=colors["ocr"], width=line_widths["ocr"])
150
+
151
+ # Rysowanie Grounding boxes
152
+ grounding_data = data.get("grounding", [])
153
+ for grounding_item in grounding_data:
154
+ boxes = grounding_item.get("boxes", [])
155
+ for box in boxes:
156
+ draw.rectangle(box, outline=colors["grounding"], width=line_widths["grounding"])
157
+
158
  return img_draw
159
 
160
 
161
+ def process_image(image, caption_for_grounding="elf girl", detailed_mode=False):
162
+ """
163
+ Przetwarza obraz i tworzy wizualizację.
164
+
165
+ Args:
166
+ image: Obraz wejściowy
167
+ caption_for_grounding: Caption dla character grounding
168
+ detailed_mode: Jeśli True, wizualizacja zawiera wszystko z JSON (OCR, grounding).
169
+ Jeśli False (domyślnie), tylko detekcje i asocjacje.
170
+ """
171
  print("\n--- Rozpoczynanie przetwarzania obrazu ---")
172
  images = [image]
173
  captions = [caption_for_grounding]
 
190
  "grounding": [{"phrase": grounding_results.get("grounded_caption", "")[start:end], "boxes": boxes} for boxes, (start, end) in zip(grounding_results.get("bboxes", []), grounding_results.get("indices_of_bboxes_in_caption", []))]
191
  }
192
 
193
+ mode_text = "wybredny (wszystkie elementy)" if detailed_mode else "domyślny (detekcje i asocjacje)"
194
+ print(f"Tworzenie wizualizacji w trybie: {mode_text}")
195
+ visualization_image = create_visualization(image, final_json, detailed_mode=detailed_mode)
196
 
197
  print("--- Zakończono przetwarzanie ---")
198
  return final_json, visualization_image
 
201
  # 6) URUCHOMIENIE I WYŚWIETLENIE WYNIKÓW
202
  # ==============================================================================
203
 
204
+ # Tryb wizualizacji:
205
+ # detailed_mode=False (domyślny) - rysuje tylko detekcje i asocjacje (obecne kolory)
206
+ # detailed_mode=True (wybredny) - rysuje wszystko z JSON: OCR (pomarańczowy), grounding (cyjan)
207
  json_output, image_output = process_image(
208
+ pil_img, caption_for_grounding="elf girl", detailed_mode=True)
209
 
210
  print("\n\n===== WYNIKI W FORMACIE JSON (przed filtrowaniem) =====")
211
  print(json.dumps(json_output, indent=2))
modeling_florence2.py CHANGED
@@ -23,7 +23,7 @@ import torch.utils.checkpoint
23
  from torch import nn
24
  import torch.nn.functional as F
25
  import torch.utils.checkpoint as checkpoint
26
- from torch.nn import CrossEntropyLoss
27
  from collections import OrderedDict
28
  from einops import rearrange, repeat
29
  from timm.models.layers import DropPath, trunc_normal_
@@ -41,7 +41,7 @@ from transformers.utils import (
41
  is_flash_attn_2_available,
42
  is_flash_attn_greater_or_equal_2_10,
43
  )
44
- from .configuration_florence2 import Florence2Config
45
  from .configuration_florence2 import Florence2LanguageConfig
46
  from .configuration_florence2 import Florence2VisionConfig
47
  from pytorch_metric_learning.utils.loss_and_miner_utils import get_all_pairs_indices
@@ -72,6 +72,7 @@ logger = logging.get_logger(__name__)
72
 
73
  _CONFIG_FOR_DOC = "Florence2Config"
74
 
 
75
  class LearnedAbsolutePositionEmbedding2D(nn.Module):
76
  """
77
  This module learns positional embeddings up to a fixed maximum size.
@@ -80,7 +81,8 @@ class LearnedAbsolutePositionEmbedding2D(nn.Module):
80
  def __init__(self, embedding_dim=256, num_pos=50):
81
  super().__init__()
82
  self.row_embeddings = nn.Embedding(num_pos, embedding_dim // 2)
83
- self.column_embeddings = nn.Embedding(num_pos, embedding_dim - (embedding_dim // 2))
 
84
 
85
  def forward(self, pixel_values):
86
  """
@@ -95,7 +97,8 @@ class LearnedAbsolutePositionEmbedding2D(nn.Module):
95
  x_emb = self.column_embeddings(width_values)
96
  y_emb = self.row_embeddings(height_values)
97
  # (height, width, embedding_dim * 2)
98
- pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1)
 
99
  # (embedding_dim * 2, height, width)
100
  pos = pos.permute(2, 0, 1)
101
  pos = pos.unsqueeze(0)
@@ -105,6 +108,7 @@ class LearnedAbsolutePositionEmbedding2D(nn.Module):
105
  pos = pos.permute(0, 2, 3, 1)
106
  return pos
107
 
 
108
  class PositionalEmbeddingCosine1D(nn.Module):
109
  """
110
  This class implements a very simple positional encoding. It follows closely
@@ -116,6 +120,7 @@ class PositionalEmbeddingCosine1D(nn.Module):
116
  dropout_prob: The dropout probability.
117
  max_seq_len: The maximum length to precompute the positional encodings.
118
  """
 
119
  def __init__(
120
  self,
121
  embed_dim: int = 512,
@@ -171,6 +176,7 @@ class LearnedAbsolutePositionEmbedding1D(nn.Module):
171
  embed_dim: The dimension of the embeddings.
172
  max_seq_len: The maximum length to precompute the positional encodings.
173
  """
 
174
  def __init__(
175
  self,
176
  embedding_dim: int = 512,
@@ -196,7 +202,8 @@ class LearnedAbsolutePositionEmbedding1D(nn.Module):
196
  len_seq = seq_embeds.size(-2)
197
  assert len_seq <= self.num_pos
198
  # [T, D]
199
- pos_embeds = self.embeddings(torch.arange(len_seq).to(seq_embeds.device))
 
200
  # Adapt pre-computed positional embeddings to the input.
201
  if shape_len == 3:
202
  pos_embeds = pos_embeds.view(
@@ -204,7 +211,6 @@ class LearnedAbsolutePositionEmbedding1D(nn.Module):
204
  return pos_embeds
205
 
206
 
207
-
208
  class MySequential(nn.Sequential):
209
  def forward(self, *inputs):
210
  for module in self._modules.values():
@@ -349,7 +355,8 @@ class ChannelAttention(nn.Module):
349
  def forward(self, x, size):
350
  B, N, C = x.shape
351
 
352
- qkv = self.qkv(x).reshape(B, N, 3, self.groups, C // self.groups).permute(2, 0, 3, 1, 4)
 
353
  q, k, v = qkv[0], qkv[1], qkv[2]
354
 
355
  q = q * (float(N) ** -0.5)
@@ -368,18 +375,22 @@ class ChannelBlock(nn.Module):
368
  conv_at_attn=True, conv_at_ffn=True):
369
  super().__init__()
370
 
371
- drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
 
372
 
373
- self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None
 
374
  self.channel_attn = PreNorm(
375
  norm_layer(dim),
376
  ChannelAttention(dim, groups=groups, qkv_bias=qkv_bias),
377
  drop_path
378
  )
379
- self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None
 
380
  self.ffn = PreNorm(
381
  norm_layer(dim),
382
- Mlp(in_features=dim, hidden_features=int(dim*mlp_ratio), act_layer=act_layer),
 
383
  drop_path
384
  )
385
 
@@ -397,16 +408,19 @@ class ChannelBlock(nn.Module):
397
 
398
  def window_partition(x, window_size: int):
399
  B, H, W, C = x.shape
400
- x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
401
- windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
 
 
402
  return windows
403
 
404
 
405
  def window_reverse(windows, batch_size: int, window_size: int, H: int, W: int):
406
- B = batch_size
407
  # this will cause onnx conversion failed for dynamic axis, because treated as constant
408
- # int(windows.shape[0] / (H * W / window_size / window_size))
409
- x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
 
410
  x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
411
  return x
412
 
@@ -447,7 +461,8 @@ class WindowAttention(nn.Module):
447
  # attn_windows = self.attn(x_windows)
448
 
449
  B_, N, C = x.shape
450
- qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
 
451
  q, k, v = qkv[0], qkv[1], qkv[2]
452
 
453
  q = q * self.scale
@@ -478,18 +493,22 @@ class SpatialBlock(nn.Module):
478
  norm_layer=nn.LayerNorm, conv_at_attn=True, conv_at_ffn=True):
479
  super().__init__()
480
 
481
- drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
 
482
 
483
- self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None
 
484
  self.window_attn = PreNorm(
485
  norm_layer(dim),
486
  WindowAttention(dim, num_heads, window_size, qkv_bias=qkv_bias),
487
  drop_path
488
  )
489
- self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None
 
490
  self.ffn = PreNorm(
491
  norm_layer(dim),
492
- Mlp(in_features=dim, hidden_features=int(dim*mlp_ratio), act_layer=act_layer),
 
493
  drop_path
494
  )
495
 
@@ -547,7 +566,7 @@ class DaViT(nn.Module):
547
  enable_checkpoint=True,
548
  conv_at_attn=True,
549
  conv_at_ffn=True,
550
- ):
551
  super().__init__()
552
 
553
  self.num_classes = num_classes
@@ -559,7 +578,8 @@ class DaViT(nn.Module):
559
  assert self.num_stages == len(self.num_heads) == len(self.num_groups)
560
 
561
  num_stages = len(embed_dims)
562
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)*2)]
 
563
 
564
  depth_offset = 0
565
  convs = []
@@ -613,7 +633,8 @@ class DaViT(nn.Module):
613
 
614
  self.norms = norm_layer(self.embed_dims[-1])
615
  self.avgpool = nn.AdaptiveAvgPool1d(1)
616
- self.head = nn.Linear(self.embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
 
617
 
618
  self.apply(self._init_weights)
619
 
@@ -648,7 +669,8 @@ class DaViT(nn.Module):
648
  for conv, block in zip(self.convs, self.blocks):
649
  x, input_size = conv(x, input_size)
650
  if self.enable_checkpoint:
651
- x, input_size = checkpoint.checkpoint(block, x, input_size, use_reentrant=True)
 
652
  else:
653
  x, input_size = block(x, input_size)
654
  return x
@@ -668,7 +690,7 @@ class DaViT(nn.Module):
668
  x = self.forward_features(x)
669
  x = self.head(x)
670
  return x
671
-
672
  @classmethod
673
  def from_config(cls, config):
674
  return cls(
@@ -685,18 +707,19 @@ class DaViT(nn.Module):
685
  )
686
 
687
 
688
-
689
-
690
  if is_flash_attn_2_available():
691
  from flash_attn import flash_attn_func, flash_attn_varlen_func
692
  from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
693
 
694
  # Copied from transformers.models.llama.modeling_llama._get_unpad_data
 
 
695
  def _get_unpad_data(attention_mask):
696
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
697
  indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
698
  max_seqlen_in_batch = seqlens_in_batch.max().item()
699
- cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
 
700
  return (
701
  indices,
702
  cu_seqlens,
@@ -834,7 +857,8 @@ class Florence2Attention(nn.Module):
834
  if past_key_value[0] is not None:
835
  key_states = torch.cat([past_key_value[0], key_states], dim=2)
836
  if past_key_value[1] is not None:
837
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
 
838
  else:
839
  # self_attention
840
  key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
@@ -851,7 +875,8 @@ class Florence2Attention(nn.Module):
851
  past_key_value = (key_states, value_states)
852
 
853
  proj_shape = (bsz * self.num_heads, -1, self.head_dim)
854
- query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
 
855
  key_states = key_states.reshape(*proj_shape)
856
  value_states = value_states.reshape(*proj_shape)
857
 
@@ -869,8 +894,10 @@ class Florence2Attention(nn.Module):
869
  raise ValueError(
870
  f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
871
  )
872
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
873
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
 
874
 
875
  attn_weights = nn.functional.softmax(attn_weights, dim=-1)
876
 
@@ -880,20 +907,25 @@ class Florence2Attention(nn.Module):
880
  f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
881
  f" {layer_head_mask.size()}"
882
  )
883
- attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
884
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
 
 
885
 
886
  if output_attentions:
887
  # this operation is a bit awkward, but it's required to
888
  # make sure that attn_weights keeps its gradient.
889
  # In order to do so, attn_weights have to be reshaped
890
  # twice and have to be reused in the following
891
- attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
892
- attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
 
 
893
  else:
894
  attn_weights_reshaped = None
895
 
896
- attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
 
897
 
898
  attn_output = torch.bmm(attn_probs, value_states)
899
 
@@ -903,7 +935,8 @@ class Florence2Attention(nn.Module):
903
  f" {attn_output.size()}"
904
  )
905
 
906
- attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
 
907
  attn_output = attn_output.transpose(1, 2)
908
 
909
  # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
@@ -945,7 +978,8 @@ class Florence2FlashAttention2(Florence2Attention):
945
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
946
  # Florence2FlashAttention2 attention does not support output_attentions
947
  if output_attentions:
948
- raise ValueError("Florence2FlashAttention2 attention does not support output_attentions")
 
949
 
950
  # if key_value_states are provided this layer is used as a cross-attention layer
951
  # for the decoder
@@ -970,13 +1004,16 @@ class Florence2FlashAttention2(Florence2Attention):
970
  elif is_cross_attention:
971
  # cross_attentions
972
  key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
973
- value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
 
974
  elif past_key_value is not None:
975
  # reuse k, v, self_attention
976
  key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
977
  value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
978
- key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
979
- value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
 
 
980
  else:
981
  # self_attention
982
  key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
@@ -990,7 +1027,8 @@ class Florence2FlashAttention2(Florence2Attention):
990
  # all previous decoder key/value_states. Further calls to uni-directional self-attention
991
  # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
992
  # if encoder bi-directional self-attention `past_key_value` is always `None`
993
- past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
 
994
 
995
  kv_seq_len = key_states.shape[-2]
996
  if past_key_value is not None:
@@ -1086,7 +1124,8 @@ class Florence2FlashAttention2(Florence2Attention):
1086
  causal=causal,
1087
  )
1088
 
1089
- attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
 
1090
  else:
1091
  attn_output = flash_attn_func(
1092
  query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
@@ -1096,18 +1135,22 @@ class Florence2FlashAttention2(Florence2Attention):
1096
 
1097
  # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
1098
  def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
1099
- indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
 
1100
  batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
1101
 
1102
  key_layer = index_first_axis(
1103
- key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
 
1104
  )
1105
  value_layer = index_first_axis(
1106
- value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
 
1107
  )
1108
  if query_length == kv_seq_len:
1109
  query_layer = index_first_axis(
1110
- query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
 
1111
  )
1112
  cu_seqlens_q = cu_seqlens_k
1113
  max_seqlen_in_batch_q = max_seqlen_in_batch_k
@@ -1122,7 +1165,8 @@ class Florence2FlashAttention2(Florence2Attention):
1122
  else:
1123
  # The -q_len: slice assumes left padding.
1124
  attention_mask = attention_mask[:, -query_length:]
1125
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
 
1126
 
1127
  return (
1128
  query_layer,
@@ -1192,7 +1236,8 @@ class Florence2SdpaAttention(Florence2Attention):
1192
  if past_key_value[0] is not None:
1193
  key_states = torch.cat([past_key_value[0], key_states], dim=2)
1194
  if past_key_value[1] is not None:
1195
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
 
1196
  else:
1197
  # self_attention
1198
  key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
@@ -1294,23 +1339,28 @@ class Florence2EncoderLayer(nn.Module):
1294
  layer_head_mask=layer_head_mask,
1295
  output_attentions=output_attentions,
1296
  )
1297
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 
1298
  hidden_states = residual + hidden_states
1299
  hidden_states = self.self_attn_layer_norm(hidden_states)
1300
 
1301
  residual = hidden_states
1302
  hidden_states = self.activation_fn(self.fc1(hidden_states))
1303
- hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
 
1304
  hidden_states = self.fc2(hidden_states)
1305
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 
1306
  hidden_states = residual + hidden_states
1307
  hidden_states = self.final_layer_norm(hidden_states)
1308
 
1309
  if hidden_states.dtype == torch.float16 and (
1310
- torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
 
1311
  ):
1312
  clamp_value = torch.finfo(hidden_states.dtype).max - 1000
1313
- hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
 
1314
 
1315
  outputs = (hidden_states,)
1316
 
@@ -1384,7 +1434,8 @@ class Florence2DecoderLayer(nn.Module):
1384
 
1385
  # Self Attention
1386
  # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
1387
- self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
 
1388
  # add present self-attn cache to positions 1,2 of present_key_value tuple
1389
  hidden_states, self_attn_weights, present_key_value = self.self_attn(
1390
  hidden_states=hidden_states,
@@ -1393,7 +1444,8 @@ class Florence2DecoderLayer(nn.Module):
1393
  layer_head_mask=layer_head_mask,
1394
  output_attentions=output_attentions,
1395
  )
1396
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 
1397
  hidden_states = residual + hidden_states
1398
  hidden_states = self.self_attn_layer_norm(hidden_states)
1399
 
@@ -1404,7 +1456,8 @@ class Florence2DecoderLayer(nn.Module):
1404
  residual = hidden_states
1405
 
1406
  # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
1407
- cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
 
1408
  hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
1409
  hidden_states=hidden_states,
1410
  key_value_states=encoder_hidden_states,
@@ -1413,7 +1466,8 @@ class Florence2DecoderLayer(nn.Module):
1413
  past_key_value=cross_attn_past_key_value,
1414
  output_attentions=output_attentions,
1415
  )
1416
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 
1417
  hidden_states = residual + hidden_states
1418
  hidden_states = self.encoder_attn_layer_norm(hidden_states)
1419
 
@@ -1423,9 +1477,11 @@ class Florence2DecoderLayer(nn.Module):
1423
  # Fully Connected
1424
  residual = hidden_states
1425
  hidden_states = self.activation_fn(self.fc1(hidden_states))
1426
- hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
 
1427
  hidden_states = self.fc2(hidden_states)
1428
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 
1429
  hidden_states = residual + hidden_states
1430
  hidden_states = self.final_layer_norm(hidden_states)
1431
 
@@ -1440,7 +1496,6 @@ class Florence2DecoderLayer(nn.Module):
1440
  return outputs
1441
 
1442
 
1443
-
1444
  class Florence2LanguagePreTrainedModel(PreTrainedModel):
1445
  config_class = Florence2LanguageConfig
1446
  base_model_prefix = "model"
@@ -1465,7 +1520,8 @@ class Florence2LanguagePreTrainedModel(PreTrainedModel):
1465
  @property
1466
  def dummy_inputs(self):
1467
  pad_token = self.config.pad_token_id
1468
- input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
 
1469
  dummy_inputs = {
1470
  "attention_mask": input_ids.ne(pad_token),
1471
  "input_ids": input_ids,
@@ -1505,7 +1561,8 @@ class Florence2Encoder(Florence2LanguagePreTrainedModel):
1505
  config.max_position_embeddings,
1506
  embed_dim,
1507
  )
1508
- self.layers = nn.ModuleList([Florence2EncoderLayer(config) for _ in range(config.encoder_layers)])
 
1509
  self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
1510
  self._use_sdpa = config._attn_implementation == "sdpa"
1511
  self.layernorm_embedding = nn.LayerNorm(embed_dim)
@@ -1574,14 +1631,16 @@ class Florence2Encoder(Florence2LanguagePreTrainedModel):
1574
 
1575
  # retrieve input_ids and inputs_embeds
1576
  if input_ids is not None and inputs_embeds is not None:
1577
- raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
 
1578
  elif input_ids is not None:
1579
  input = input_ids
1580
  input_ids = input_ids.view(-1, input_ids.shape[-1])
1581
  elif inputs_embeds is not None:
1582
  input = inputs_embeds[:, :, -1]
1583
  else:
1584
- raise ValueError("You have to specify either input_ids or inputs_embeds")
 
1585
 
1586
  if inputs_embeds is None:
1587
  inputs_embeds = self.embed_tokens(input_ids)
@@ -1591,7 +1650,8 @@ class Florence2Encoder(Florence2LanguagePreTrainedModel):
1591
 
1592
  hidden_states = inputs_embeds + embed_pos
1593
  hidden_states = self.layernorm_embedding(hidden_states)
1594
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 
1595
 
1596
  # expand attention_mask
1597
  if attention_mask is not None:
@@ -1601,10 +1661,12 @@ class Florence2Encoder(Florence2LanguagePreTrainedModel):
1601
  # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
1602
  # the manual implementation that requires a 4D causal mask in all cases.
1603
  # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1604
- attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
 
1605
  else:
1606
  # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1607
- attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
 
1608
 
1609
  encoder_states = () if output_hidden_states else None
1610
  all_attentions = () if output_attentions else None
@@ -1642,7 +1704,8 @@ class Florence2Encoder(Florence2LanguagePreTrainedModel):
1642
  layer_outputs = encoder_layer(
1643
  hidden_states,
1644
  attention_mask,
1645
- layer_head_mask=(head_mask[idx] if head_mask is not None else None),
 
1646
  output_attentions=output_attentions,
1647
  )
1648
 
@@ -1676,7 +1739,8 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
1676
  self.layerdrop = config.decoder_layerdrop
1677
  self.padding_idx = config.pad_token_id
1678
  self.max_target_positions = config.max_position_embeddings
1679
- embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
 
1680
 
1681
  self.embed_tokens = Florence2ScaledWordEmbedding(
1682
  config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
@@ -1689,7 +1753,8 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
1689
  config.max_position_embeddings,
1690
  config.d_model,
1691
  )
1692
- self.layers = nn.ModuleList([Florence2DecoderLayer(config) for _ in range(config.decoder_layers)])
 
1693
  self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
1694
  self._use_sdpa = config._attn_implementation == "sdpa"
1695
 
@@ -1794,7 +1859,8 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
1794
 
1795
  # retrieve input_ids and inputs_embeds
1796
  if input_ids is not None and inputs_embeds is not None:
1797
- raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
 
1798
  elif input_ids is not None:
1799
  input = input_ids
1800
  input_shape = input.shape
@@ -1803,17 +1869,20 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
1803
  input_shape = inputs_embeds.size()[:-1]
1804
  input = inputs_embeds[:, :, -1]
1805
  else:
1806
- raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
 
1807
 
1808
  # past_key_values_length
1809
- past_key_values_length = past_key_values[0][0].shape[2] if past_key_values and past_key_values[0] and past_key_values[0][0] is not None else 0
 
1810
 
1811
  if inputs_embeds is None:
1812
  inputs_embeds = self.embed_tokens(input)
1813
 
1814
  if self._use_flash_attention_2:
1815
  # 2d mask is passed through the layers
1816
- attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
 
1817
  elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None:
1818
  # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
1819
  # the manual implementation that requires a 4D causal mask in all cases.
@@ -1855,7 +1924,8 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
1855
  hidden_states = inputs_embeds + positions
1856
  hidden_states = self.layernorm_embedding(hidden_states)
1857
 
1858
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 
1859
 
1860
  if self.gradient_checkpointing and self.training:
1861
  if use_cache:
@@ -1867,7 +1937,8 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
1867
  # decoder layers
1868
  all_hidden_states = () if output_hidden_states else None
1869
  all_self_attns = () if output_attentions else None
1870
- all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
 
1871
  next_decoder_cache = () if use_cache else None
1872
 
1873
  # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
@@ -1909,7 +1980,8 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
1909
  attention_mask=attention_mask,
1910
  encoder_hidden_states=encoder_hidden_states,
1911
  encoder_attention_mask=encoder_attention_mask,
1912
- layer_head_mask=(head_mask[idx] if head_mask is not None else None),
 
1913
  cross_attn_layer_head_mask=(
1914
  cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
1915
  ),
@@ -1920,7 +1992,8 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
1920
  hidden_states = layer_outputs[0]
1921
 
1922
  if use_cache:
1923
- next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
 
1924
 
1925
  if output_attentions:
1926
  all_self_attns += (layer_outputs[1],)
@@ -1949,7 +2022,8 @@ class Florence2Decoder(Florence2LanguagePreTrainedModel):
1949
 
1950
 
1951
  class Florence2LanguageModel(Florence2LanguagePreTrainedModel):
1952
- _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
 
1953
 
1954
  def __init__(self, config: Florence2LanguageConfig):
1955
  super().__init__(config)
@@ -2035,8 +2109,10 @@ class Florence2LanguageModel(Florence2LanguagePreTrainedModel):
2035
  elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
2036
  encoder_outputs = BaseModelOutput(
2037
  last_hidden_state=encoder_outputs[0],
2038
- hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
2039
- attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
 
 
2040
  )
2041
 
2042
  # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
@@ -2072,14 +2148,17 @@ class Florence2LanguageModel(Florence2LanguagePreTrainedModel):
2072
 
2073
  class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel, GenerationMixin):
2074
  base_model_prefix = "model"
2075
- _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
 
2076
  _keys_to_ignore_on_load_missing = ["final_logits_bias"]
2077
 
2078
  def __init__(self, config: Florence2LanguageConfig):
2079
  super().__init__(config)
2080
  self.model = Florence2LanguageModel(config)
2081
- self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
2082
- self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
 
 
2083
 
2084
  # Initialize weights and apply final processing
2085
  self.post_init()
@@ -2091,7 +2170,8 @@ class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel
2091
  return self.model.get_decoder()
2092
 
2093
  def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
2094
- new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
 
2095
  self._resize_final_logits_bias(new_embeddings.weight.shape[0])
2096
  return new_embeddings
2097
 
@@ -2100,7 +2180,8 @@ class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel
2100
  if new_num_tokens <= old_num_tokens:
2101
  new_bias = self.final_logits_bias[:, :new_num_tokens]
2102
  else:
2103
- extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
 
2104
  new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
2105
  self.register_buffer("final_logits_bias", new_bias)
2106
 
@@ -2141,7 +2222,8 @@ class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel
2141
 
2142
  if labels is not None:
2143
  if use_cache:
2144
- logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
 
2145
  use_cache = False
2146
  if decoder_input_ids is None and decoder_inputs_embeds is None:
2147
  decoder_input_ids = shift_tokens_right(
@@ -2173,7 +2255,8 @@ class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel
2173
  if labels is not None:
2174
  labels = labels.to(lm_logits.device)
2175
  loss_fct = CrossEntropyLoss()
2176
- masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
 
2177
 
2178
  if not return_dict:
2179
  output = (lm_logits,) + outputs[1:]
@@ -2227,7 +2310,8 @@ class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel
2227
  "head_mask": head_mask,
2228
  "decoder_head_mask": decoder_head_mask,
2229
  "cross_attn_head_mask": cross_attn_head_mask,
2230
- "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
 
2231
  }
2232
 
2233
  def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
@@ -2239,11 +2323,13 @@ class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel
2239
  for layer_past in past_key_values:
2240
  # cached cross_attention states don't have to be reordered -> they are always the same
2241
  reordered_past += (
2242
- tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
 
2243
  + layer_past[2:],
2244
  )
2245
  return reordered_past
2246
 
 
2247
  @dataclass
2248
  class Florence2Seq2SeqLMOutput(ModelOutput):
2249
  """
@@ -2429,6 +2515,7 @@ FLORENCE2_INPUTS_DOCSTRING = r"""
2429
  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
2430
  """
2431
 
 
2432
  @add_start_docstrings(
2433
  """The FLORENCE2 vision model without any head""",
2434
  FLORENCE2_START_DOCSTRING,
@@ -2436,11 +2523,12 @@ FLORENCE2_INPUTS_DOCSTRING = r"""
2436
  class Florence2VisionModel(Florence2PreTrainedModel):
2437
  def __init__(self, config: Florence2VisionConfig):
2438
  super().__init__(config)
2439
- assert config.model_type in ['davit', ""], 'only DaViT is supported for now'
 
2440
  self.vision_tower = DaViT.from_config(config=config)
2441
 
2442
  self.post_init()
2443
-
2444
  def forward(self, pixel_values):
2445
  if len(pixel_values.shape) == 4:
2446
  x = self.vision_tower.forward_features_unpool(pixel_values)
@@ -2456,13 +2544,14 @@ class Florence2VisionModel(Florence2PreTrainedModel):
2456
  class Florence2VisionModelWithProjection(Florence2PreTrainedModel):
2457
  def __init__(self, config: Florence2VisionConfig):
2458
  super().__init__(config)
2459
- assert config.model_type in ['davit', ''], 'only DaViT is supported for now'
 
2460
  self.vision_tower = DaViT.from_config(config=config)
2461
 
2462
  self._build_image_projection_layers(config)
2463
 
2464
  self.post_init()
2465
-
2466
  def _build_image_projection_layers(self, config):
2467
  image_dim_out = config.dim_embed[-1]
2468
  dim_projection = config.projection_dim
@@ -2498,7 +2587,7 @@ class Florence2VisionModelWithProjection(Florence2PreTrainedModel):
2498
  x = self.vision_tower.forward_features_unpool(pixel_values)
2499
  else:
2500
  raise ValueError(f'invalid image shape {pixel_values.shape}')
2501
-
2502
  if self.image_pos_embed is not None:
2503
  x = x.view(batch_size * T, -1, x.shape[-1])
2504
  num_tokens = x.shape[-2]
@@ -2510,15 +2599,18 @@ class Florence2VisionModelWithProjection(Florence2PreTrainedModel):
2510
  x = x.view(batch_size, T * h*w, x.shape[-1])
2511
 
2512
  if self.visual_temporal_embed is not None:
2513
- visual_temporal_embed = self.visual_temporal_embed(x.view(batch_size, T, -1, x.shape[-1])[:, :, 0])
2514
- x = x.view(batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(1, T, 1, x.shape[-1])
 
 
2515
 
2516
  x_feat_dict = {}
2517
 
2518
  spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2)
2519
  x_feat_dict['spatial_avg_pool'] = spatial_avg_pool_x
2520
 
2521
- temporal_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=1)
 
2522
  x_feat_dict['temporal_avg_pool'] = temporal_avg_pool_x
2523
 
2524
  x = x.view(batch_size, T, -1, x.shape[-1])[:, -1]
@@ -2527,7 +2619,8 @@ class Florence2VisionModelWithProjection(Florence2PreTrainedModel):
2527
  new_x = []
2528
  for _image_feature_source in self.image_feature_source:
2529
  if _image_feature_source not in x_feat_dict:
2530
- raise ValueError('invalid image feature source: {}'.format(_image_feature_source))
 
2531
  new_x.append(x_feat_dict[_image_feature_source])
2532
 
2533
  x = torch.cat(new_x, dim=1)
@@ -2535,11 +2628,9 @@ class Florence2VisionModelWithProjection(Florence2PreTrainedModel):
2535
  x = x @ self.image_projection
2536
  x = self.image_proj_norm(x)
2537
 
2538
-
2539
  return x
2540
 
2541
 
2542
-
2543
  @add_start_docstrings(
2544
  """The FLORENCE2 model which consists of a vision backbone and a language model.""",
2545
  FLORENCE2_START_DOCSTRING,
@@ -2547,9 +2638,10 @@ class Florence2VisionModelWithProjection(Florence2PreTrainedModel):
2547
  class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2548
  def __init__(self, config: Florence2Config):
2549
  super().__init__(config)
2550
- assert config.vision_config.model_type in ['davit', ''], 'only DaViT is supported for now'
 
2551
  self.vision_tower = DaViT.from_config(config=config.vision_config)
2552
- # remove unused layers
2553
  del self.vision_tower.head
2554
  del self.vision_tower.norms
2555
 
@@ -2557,10 +2649,12 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2557
  self._attn_implementation = config._attn_implementation
2558
  self._build_image_projection_layers(config)
2559
 
2560
- language_model = Florence2LanguageForConditionalGeneration(config=config.text_config)
 
2561
 
2562
  if language_model._tied_weights_keys is not None:
2563
- self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
 
2564
  self.language_model = language_model
2565
  self.character_character_matching_head = nn.Sequential(
2566
  nn.Linear(2 * 768, config.projection_dim),
@@ -2584,16 +2678,17 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2584
  nn.Linear(config.projection_dim, 1)
2585
  )
2586
  self.text_classification_head = nn.Linear(config.projection_dim, 1)
2587
- self.character_embedding_projection = nn.Linear(config.projection_dim, 768)
 
2588
 
2589
  self._init_weights(self.character_character_matching_head)
2590
  self._init_weights(self.text_character_matching_head)
2591
  self._init_weights(self.text_tail_matching_head)
2592
  self._init_weights(self.text_classification_head)
2593
-
2594
  self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
2595
  self.post_init()
2596
-
2597
  def _init_weights(self, m):
2598
  if isinstance(m, nn.Linear):
2599
  trunc_normal_(m.weight, std=0.02)
@@ -2613,7 +2708,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2613
  elif isinstance(m, nn.Sequential):
2614
  for layer in m:
2615
  self._init_weights(layer)
2616
-
2617
  def _build_image_projection_layers(self, config):
2618
  image_dim_out = config.vision_config.dim_embed[-1]
2619
  dim_projection = config.vision_config.projection_dim
@@ -2652,13 +2747,14 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2652
  return self.language_model.get_input_embeddings()
2653
 
2654
  def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
2655
- model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
 
2656
  # update vocab size
2657
  self.config.text_config.vocab_size = model_embeds.num_embeddings
2658
  self.config.vocab_size = model_embeds.num_embeddings
2659
  self.vocab_size = model_embeds.num_embeddings
2660
  return model_embeds
2661
-
2662
  def _encode_image(self, pixel_values):
2663
  if len(pixel_values.shape) == 4:
2664
  batch_size, C, H, W = pixel_values.shape
@@ -2666,7 +2762,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2666
  x = self.vision_tower.forward_features_unpool(pixel_values)
2667
  else:
2668
  raise ValueError(f'invalid image shape {pixel_values.shape}')
2669
-
2670
  if self.image_pos_embed is not None:
2671
  x = x.view(batch_size * T, -1, x.shape[-1])
2672
  num_tokens = x.shape[-2]
@@ -2678,15 +2774,18 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2678
  x = x.view(batch_size, T * h*w, x.shape[-1])
2679
 
2680
  if self.visual_temporal_embed is not None:
2681
- visual_temporal_embed = self.visual_temporal_embed(x.view(batch_size, T, -1, x.shape[-1])[:, :, 0])
2682
- x = x.view(batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(1, T, 1, x.shape[-1])
 
 
2683
 
2684
  x_feat_dict = {}
2685
 
2686
  spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2)
2687
  x_feat_dict['spatial_avg_pool'] = spatial_avg_pool_x
2688
 
2689
- temporal_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=1)
 
2690
  x_feat_dict['temporal_avg_pool'] = temporal_avg_pool_x
2691
 
2692
  x = x.view(batch_size, T, -1, x.shape[-1])[:, -1]
@@ -2695,7 +2794,8 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2695
  new_x = []
2696
  for _image_feature_source in self.image_feature_source:
2697
  if _image_feature_source not in x_feat_dict:
2698
- raise ValueError('invalid image feature source: {}'.format(_image_feature_source))
 
2699
  new_x.append(x_feat_dict[_image_feature_source])
2700
 
2701
  x = torch.cat(new_x, dim=1)
@@ -2703,14 +2803,15 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2703
  x = x @ self.image_projection
2704
  x = self.image_proj_norm(x)
2705
 
2706
- return x
2707
 
2708
  def _merge_input_ids_with_image_features(
2709
- self, image_features, inputs_embeds
2710
  ):
2711
  batch_size, image_token_length = image_features.size()[:-1]
2712
  device = image_features.device
2713
- image_attention_mask = torch.ones(batch_size, image_token_length, device=device)
 
2714
 
2715
  # task_prefix_embeds: [batch_size, padded_context_length, hidden_size]
2716
  # task_prefix_attention_mask: [batch_size, context_length]
@@ -2718,17 +2819,19 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2718
  return image_features, image_attention_mask
2719
 
2720
  task_prefix_embeds = inputs_embeds
2721
- task_prefix_attention_mask = torch.ones(batch_size, task_prefix_embeds.size(1), device=device)
 
2722
 
2723
  if len(task_prefix_attention_mask.shape) == 3:
2724
  task_prefix_attention_mask = task_prefix_attention_mask[:, 0]
2725
 
2726
  # concat [image embeds, task prefix embeds]
2727
  inputs_embeds = torch.cat([image_features, task_prefix_embeds], dim=1)
2728
- attention_mask = torch.cat([image_attention_mask, task_prefix_attention_mask], dim=1)
 
2729
 
2730
  return inputs_embeds, attention_mask
2731
-
2732
  @torch.no_grad()
2733
  def predict_detections_and_associations(
2734
  self,
@@ -2740,7 +2843,8 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2740
  essential_text_threshold=0.8,
2741
  ):
2742
  batch_inputs = processor(
2743
- batch_input_text=["Find all panels, texts, characters, and speech-bubble tails in the image."] * len(images),
 
2744
  batch_input_list_of_list_of_bboxes=[[]] * len(images),
2745
  batch_images=images,
2746
  padding=True,
@@ -2758,13 +2862,16 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2758
  do_sample=False,
2759
  num_beams=3,
2760
  )
2761
- generated_texts, list_of_list_of_list_of_bboxes, batch_indices_of_bboxes_in_generated_text = processor.postprocess_output(generated_ids, images)
2762
- map_to_category = {"<pa": "panels", "<te": "texts", "<ch": "characters", "<ta": "tails"}
 
 
2763
 
2764
  results = []
2765
 
2766
  for generated_text, batch_indices_of_bboxes_in_generated_text, list_of_list_of_bboxes in zip(generated_texts, batch_indices_of_bboxes_in_generated_text, list_of_list_of_list_of_bboxes):
2767
- categories = [map_to_category.get(generated_text[j:j+3], None) for i, j in batch_indices_of_bboxes_in_generated_text]
 
2768
  result_for_this_image = {
2769
  "panels": [],
2770
  "texts": [],
@@ -2779,9 +2886,11 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2779
 
2780
  cleaned_generated_ids = []
2781
  for generated_id in generated_ids:
2782
- index_of_last_bos = torch.where(generated_id == processor.tokenizer.bos_token_id)[0][-1].item()
 
2783
  cleaned_generated_ids.append(generated_id[index_of_last_bos:])
2784
- cleaned_generated_ids = pad_sequence(cleaned_generated_ids, batch_first=True, padding_value=processor.tokenizer.pad_token_id)
 
2785
  association_outputs = self(
2786
  input_ids=batch_inputs["input_ids"],
2787
  pixel_values=batch_inputs["pixel_values"],
@@ -2790,17 +2899,21 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2790
  )
2791
 
2792
  for img_idx in range(len(results)):
2793
- character_cluster_labels = UnionFind.from_adj_matrix(association_outputs.character_character_affinity_matrices[img_idx] > character_character_association_threshold).get_labels_for_connected_components()
2794
- text_character_association = torch.nonzero(association_outputs.text_character_association_matrices[img_idx] > text_character_association_threshold).tolist()
2795
- text_tail_association = torch.nonzero(association_outputs.text_tail_association_matrices[img_idx] > text_tail_association_threshold).tolist()
2796
- essential_text_logits = (association_outputs.essential_text_logits[img_idx] > essential_text_threshold).tolist()
 
 
 
 
2797
  results[img_idx]["character_cluster_labels"] = character_cluster_labels
2798
  results[img_idx]["text_character_associations"] = text_character_association
2799
  results[img_idx]["text_tail_associations"] = text_tail_association
2800
  results[img_idx]["is_essential_text"] = essential_text_logits
2801
 
2802
  return results
2803
-
2804
  @torch.no_grad()
2805
  def predict_ocr(
2806
  self,
@@ -2808,7 +2921,8 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2808
  processor,
2809
  ):
2810
  batch_inputs = processor(
2811
- batch_input_text=["What is the text in the image, with regions?"] * len(images),
 
2812
  batch_input_list_of_list_of_bboxes=[[]] * len(images),
2813
  batch_images=images,
2814
  padding=True,
@@ -2826,7 +2940,8 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2826
  do_sample=False,
2827
  num_beams=3,
2828
  )
2829
- generated_texts, list_of_list_of_list_of_bboxes, batch_indices_of_bboxes_in_generated_text = processor.postprocess_output(generated_ids, images)
 
2830
  results = []
2831
  for generated_text, batch_indices_of_bboxes_in_generated_text, list_of_list_of_bboxes in zip(generated_texts, batch_indices_of_bboxes_in_generated_text, list_of_list_of_list_of_bboxes):
2832
  ocr_texts = []
@@ -2850,9 +2965,10 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2850
  ):
2851
  def convert_caption_to_instruction(caption):
2852
  return "Locate the phrases in the caption: " + caption
2853
-
2854
  batch_inputs = processor(
2855
- batch_input_text=[convert_caption_to_instruction(caption) for caption in captions],
 
2856
  batch_input_list_of_list_of_bboxes=[[]] * len(images),
2857
  batch_images=images,
2858
  padding=True,
@@ -2880,7 +2996,8 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2880
  do_sample=False,
2881
  num_beams=3,
2882
  )
2883
- generated_texts, list_of_list_of_list_of_bboxes, batch_indices_of_bboxes_in_generated_text = processor.postprocess_output(generated_ids, images)
 
2884
  return [
2885
  {
2886
  "grounded_caption": generated_text,
@@ -2909,7 +3026,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2909
  output_attentions: Optional[bool] = None,
2910
  output_hidden_states: Optional[bool] = True,
2911
  return_dict: Optional[bool] = None,
2912
- tokenizer = None,
2913
  ) -> Union[Tuple, Florence2Seq2SeqLMOutput]:
2914
  assert output_hidden_states, "output_hidden_states must be True"
2915
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -2927,7 +3044,8 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2927
  if pixel_values is not None:
2928
  # (batch_size, num_image_tokens, hidden_size)
2929
  image_features = self._encode_image(pixel_values)
2930
- inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds)
 
2931
 
2932
  if inputs_embeds is not None:
2933
  attention_mask = attention_mask.to(inputs_embeds.dtype)
@@ -2949,10 +3067,14 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2949
  return_dict=return_dict,
2950
  )
2951
 
2952
- character_character_affinity_matrices = self.get_character_character_affinity_matrices(outputs.decoder_hidden_states[-1], decoder_input_ids, tokenizer, apply_sigmoid=True)
2953
- text_character_association_matrices = self.get_text_character_association_matrices(outputs.decoder_hidden_states[-1], decoder_input_ids, tokenizer, apply_sigmoid=True)
2954
- text_tail_association_matrices = self.get_text_tail_association_matrices(outputs.decoder_hidden_states[-1], decoder_input_ids, tokenizer, apply_sigmoid=True)
2955
- essential_text_logits = self.get_essential_text_logits(outputs.decoder_hidden_states[-1], decoder_input_ids, tokenizer, apply_sigmoid=True)
 
 
 
 
2956
 
2957
  return Florence2Seq2SeqLMOutput(
2958
  character_character_affinity_matrices=character_character_affinity_matrices,
@@ -2963,11 +3085,11 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2963
 
2964
  def generate(
2965
  self,
2966
- input_ids,
2967
  inputs_embeds=None,
2968
  pixel_values=None,
2969
  **kwargs
2970
- ):
2971
 
2972
  if inputs_embeds is None:
2973
  # 1. Extra the input embeddings
@@ -2976,14 +3098,15 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2976
  # 2. Merge text and images
2977
  if pixel_values is not None:
2978
  image_features = self._encode_image(pixel_values)
2979
- inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds)
2980
-
 
2981
  return self.language_model.generate(
2982
  input_ids=None,
2983
  inputs_embeds=inputs_embeds,
2984
  **kwargs
2985
  )
2986
-
2987
  def slowly_generate_grounded_caption(
2988
  self,
2989
  input_ids,
@@ -2999,9 +3122,12 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2999
  """
3000
  input_embeds = self.get_input_embeddings()(input_ids)
3001
  image_features = self._encode_image(pixel_values)
3002
- inputs_embeds, _ = self._merge_input_ids_with_image_features(image_features, input_embeds)
3003
- decoder_input_ids = processor.tokenizer(captions, return_tensors="pt", truncation=False, padding=True)["input_ids"].to(self.device)
3004
- running_indices = torch.zeros(decoder_input_ids.shape[0], dtype=torch.long, device=self.device)
 
 
 
3005
  running_decoder_input_ids = decoder_input_ids[:, :1]
3006
  num_tokens_generated = 1
3007
  CHUNK_SIZE = 8
@@ -3019,19 +3145,22 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
3019
  )[:, -(CHUNK_SIZE+1):-1]
3020
 
3021
  what_should_be_the_next_tokens = torch.stack([
3022
- decoder_input_ids[i, running_indices[i]+1:running_indices[i]+CHUNK_SIZE+1]
 
3023
  for i in range(decoder_input_ids.shape[0])
3024
  ])
3025
  # if the entire predicted chunk matches the next chunk, then we can saved some time and "jump" to the next chunk
3026
  if predicted_next_tokens.shape[1] == what_should_be_the_next_tokens.shape[1] and torch.all(predicted_next_tokens == what_should_be_the_next_tokens):
3027
  running_indices += CHUNK_SIZE
3028
- running_decoder_input_ids = torch.cat([running_decoder_input_ids, what_should_be_the_next_tokens], dim=-1)
 
3029
  continue
3030
-
3031
  # if, however, there is a deviation find the maximum prefix that matches in the batch
3032
 
3033
  predicted_next_tokens = predicted_next_tokens[:, 0]
3034
- predicted_next_token_strings = processor.batch_decode(predicted_next_tokens)
 
3035
  next_tokens_to_concat = []
3036
  for i, (pnts, pnt) in enumerate(zip(predicted_next_token_strings, predicted_next_tokens)):
3037
  if (pnts.startswith("<loc_") or pnts in ["<s>", "<pad>", "</s>"]) and running_indices[i] < decoder_input_ids.shape[1] - 1:
@@ -3039,15 +3168,18 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
3039
  else:
3040
  running_indices[i] += 1
3041
  if running_indices[i] >= decoder_input_ids.shape[1]:
3042
- next_tokens_to_concat.append(torch.tensor(processor.tokenizer.eos_token_id, device=self.device))
 
3043
  # elif "’" in pnts: # this is an annoying character which looks like ' (apostrophe) but isn't.
3044
  # import pdb; pdb.set_trace()
3045
  else:
3046
- next_tokens_to_concat.append(decoder_input_ids[i, running_indices[i]])
 
3047
  next_tokens_to_concat = torch.stack(next_tokens_to_concat)[:, None]
3048
  if (next_tokens_to_concat == processor.tokenizer.eos_token_id).all():
3049
  break
3050
- running_decoder_input_ids = torch.cat([running_decoder_input_ids, next_tokens_to_concat], dim=-1)
 
3051
  if num_tokens_generated >= 1024:
3052
  break
3053
  return running_decoder_input_ids
@@ -3078,7 +3210,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
3078
  remove_prefix_length = decoder_input_ids.shape[1] - 1
3079
 
3080
  decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
3081
-
3082
  return {
3083
  "input_ids": None, # encoder_outputs is defined. input_ids not needed
3084
  "encoder_outputs": encoder_outputs,
@@ -3090,105 +3222,146 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
3090
  "head_mask": head_mask,
3091
  "decoder_head_mask": decoder_head_mask,
3092
  "cross_attn_head_mask": cross_attn_head_mask,
3093
- "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
 
3094
  }
3095
-
3096
  def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
3097
  return self.language_model.shift_tokens_right(labels)
3098
 
3099
  def _reorder_cache(self, *args, **kwargs):
3100
  return self.language_model._reorder_cache(*args, **kwargs)
3101
-
3102
  def get_character_character_affinity_matrices(self, decoder_hidden_states, decoder_input_ids, tokenizer, apply_sigmoid=False):
3103
  character_character_affinity_matrices = []
3104
  for index in range(len(decoder_hidden_states)):
3105
- character_embedding_indices = (decoder_input_ids[index] == tokenizer.convert_tokens_to_ids('<character>')).nonzero().squeeze(-1)
 
3106
  character_embeddings = decoder_hidden_states[index][character_embedding_indices]
3107
  if character_embeddings.shape[0] == 0:
3108
- character_character_affinity_matrices.append(torch.zeros(0, 0).type_as(character_embeddings))
 
3109
  continue
3110
- character_embeddings = self.character_embedding_projection(character_embeddings)
3111
- char_i = repeat(character_embeddings, "i d -> i repeat d", repeat=character_embeddings.shape[0])
3112
- char_j = repeat(character_embeddings, "j d -> repeat j d", repeat=character_embeddings.shape[0])
 
 
 
3113
  char_ij = rearrange([char_i, char_j], "two i j d -> (i j) (two d)")
3114
- character_character_affinities = self.character_character_matching_head(char_ij)
3115
- character_character_affinities = rearrange(character_character_affinities, "(i j) 1 -> i j", i=char_i.shape[0])
3116
- character_character_affinities = (character_character_affinities + character_character_affinities.T) / 2
 
 
 
3117
  if apply_sigmoid:
3118
- character_character_affinities = torch.sigmoid(character_character_affinities)
3119
- character_character_affinity_matrices.append(character_character_affinities)
 
 
3120
  return character_character_affinity_matrices
3121
 
3122
  def get_text_character_association_matrices(self, decoder_hidden_states, decoder_input_ids, tokenizer, apply_sigmoid=False):
3123
  text_character_association_matrices = []
3124
  for index in range(len(decoder_hidden_states)):
3125
- text_embedding_indices = (decoder_input_ids[index] == tokenizer.convert_tokens_to_ids('<text>')).nonzero().squeeze(-1)
 
3126
  text_embeddings = decoder_hidden_states[index][text_embedding_indices]
3127
- character_embedding_indices = (decoder_input_ids[index] == tokenizer.convert_tokens_to_ids('<character>')).nonzero().squeeze(-1)
 
3128
  character_embeddings = decoder_hidden_states[index][character_embedding_indices]
3129
  if character_embeddings.shape[0] == 0 or text_embeddings.shape[0] == 0:
3130
- text_character_association_matrices.append(torch.zeros(text_embeddings.shape[0], character_embeddings.shape[0]).type_as(text_embeddings))
 
3131
  continue
3132
- text_i = repeat(text_embeddings, "i d -> i repeat d", repeat=character_embeddings.shape[0])
3133
- char_j = repeat(character_embeddings, "j d -> repeat j d", repeat=text_embeddings.shape[0])
3134
- text_char_ij = rearrange([text_i, char_j], "two i j d -> (i j) (two d)")
3135
- text_character_affinities = self.text_character_matching_head(text_char_ij)
3136
- text_character_affinities = rearrange(text_character_affinities, "(i j) 1 -> i j", i=text_i.shape[0])
 
 
 
 
 
3137
  if apply_sigmoid:
3138
- text_character_affinities = torch.sigmoid(text_character_affinities)
3139
- text_character_association_matrices.append(text_character_affinities)
 
 
3140
  return text_character_association_matrices
3141
 
3142
  def get_text_tail_association_matrices(self, decoder_hidden_states, decoder_input_ids, tokenizer, apply_sigmoid=False):
3143
  text_tail_association_matrices = []
3144
  for index in range(len(decoder_hidden_states)):
3145
- text_embedding_indices = (decoder_input_ids[index] == tokenizer.convert_tokens_to_ids('<text>')).nonzero().squeeze(-1)
 
3146
  text_embeddings = decoder_hidden_states[index][text_embedding_indices]
3147
- tail_embedding_indices = (decoder_input_ids[index] == tokenizer.convert_tokens_to_ids('<tail>')).nonzero().squeeze(-1)
 
3148
  tail_embeddings = decoder_hidden_states[index][tail_embedding_indices]
3149
  if tail_embeddings.shape[0] == 0 or text_embeddings.shape[0] == 0:
3150
- text_tail_association_matrices.append(torch.zeros(text_embeddings.shape[0], tail_embeddings.shape[0]).type_as(text_embeddings))
 
3151
  continue
3152
- text_i = repeat(text_embeddings, "i d -> i repeat d", repeat=tail_embeddings.shape[0])
3153
- tail_j = repeat(tail_embeddings, "j d -> repeat j d", repeat=text_embeddings.shape[0])
3154
- text_tail_ij = rearrange([text_i, tail_j], "two i j d -> (i j) (two d)")
 
 
 
3155
  text_tail_affinities = self.text_tail_matching_head(text_tail_ij)
3156
- text_tail_affinities = rearrange(text_tail_affinities, "(i j) 1 -> i j", i=text_i.shape[0])
 
3157
  if apply_sigmoid:
3158
  text_tail_affinities = torch.sigmoid(text_tail_affinities)
3159
  text_tail_association_matrices.append(text_tail_affinities)
3160
  return text_tail_association_matrices
3161
-
3162
  def get_tail_character_association_matrices(self, decoder_hidden_states, decoder_input_ids, tokenizer, apply_sigmoid=False):
3163
  tail_character_association_matrices = []
3164
  for index in range(len(decoder_hidden_states)):
3165
- tail_embedding_indices = (decoder_input_ids[index] == tokenizer.convert_tokens_to_ids('<tail>')).nonzero().squeeze(-1)
 
3166
  tail_embeddings = decoder_hidden_states[index][tail_embedding_indices]
3167
- character_embedding_indices = (decoder_input_ids[index] == tokenizer.convert_tokens_to_ids('<character>')).nonzero().squeeze(-1)
 
3168
  character_embeddings = decoder_hidden_states[index][character_embedding_indices]
3169
  if character_embeddings.shape[0] == 0 or tail_embeddings.shape[0] == 0:
3170
- tail_character_association_matrices.append(torch.zeros(tail_embeddings.shape[0], character_embeddings.shape[0]).type_as(tail_embeddings))
 
3171
  continue
3172
- tail_i = repeat(tail_embeddings, "i d -> i repeat d", repeat=character_embeddings.shape[0])
3173
- char_j = repeat(character_embeddings, "j d -> repeat j d", repeat=tail_embeddings.shape[0])
3174
- tail_char_ij = rearrange([tail_i, char_j], "two i j d -> (i j) (two d)")
3175
- tail_character_affinities = self.tail_character_matching_head(tail_char_ij)
3176
- tail_character_affinities = rearrange(tail_character_affinities, "(i j) 1 -> i j", i=tail_i.shape[0])
 
 
 
 
 
3177
  if apply_sigmoid:
3178
- tail_character_affinities = torch.sigmoid(tail_character_affinities)
3179
- tail_character_association_matrices.append(tail_character_affinities)
 
 
3180
  return tail_character_association_matrices
3181
-
3182
  def get_essential_text_logits(self, decoder_hidden_states, decoder_input_ids, tokenizer, apply_sigmoid=False):
3183
  essential_text_logits = []
3184
  for index in range(len(decoder_hidden_states)):
3185
- text_embedding_indices = (decoder_input_ids[index] == tokenizer.convert_tokens_to_ids('<text>')).nonzero().squeeze(-1)
 
3186
  text_embeddings = decoder_hidden_states[index][text_embedding_indices]
3187
  if text_embeddings.shape[0] == 0:
3188
- essential_text_logits.append(torch.zeros(0).type_as(text_embeddings))
 
3189
  continue
3190
- text_logits = rearrange(self.text_classification_head(text_embeddings), "i 1 -> i")
 
3191
  if apply_sigmoid:
3192
  text_logits = torch.sigmoid(text_logits)
3193
  essential_text_logits.append(text_logits)
3194
- return essential_text_logits
 
23
  from torch import nn
24
  import torch.nn.functional as F
25
  import torch.utils.checkpoint as checkpoint
26
+ from torch.nn import CrossEntropyLoss
27
  from collections import OrderedDict
28
  from einops import rearrange, repeat
29
  from timm.models.layers import DropPath, trunc_normal_
 
41
  is_flash_attn_2_available,
42
  is_flash_attn_greater_or_equal_2_10,
43
  )
44
+ from .configuration_florence2 import Florence2Config
45
  from .configuration_florence2 import Florence2LanguageConfig
46
  from .configuration_florence2 import Florence2VisionConfig
47
  from pytorch_metric_learning.utils.loss_and_miner_utils import get_all_pairs_indices
 
72
 
73
  _CONFIG_FOR_DOC = "Florence2Config"
74
 
75
+
76
  class LearnedAbsolutePositionEmbedding2D(nn.Module):
77
  """
78
  This module learns positional embeddings up to a fixed maximum size.
 
81
  def __init__(self, embedding_dim=256, num_pos=50):
82
  super().__init__()
83
  self.row_embeddings = nn.Embedding(num_pos, embedding_dim // 2)
84
+ self.column_embeddings = nn.Embedding(
85
+ num_pos, embedding_dim - (embedding_dim // 2))
86
 
87
  def forward(self, pixel_values):
88
  """
 
97
  x_emb = self.column_embeddings(width_values)
98
  y_emb = self.row_embeddings(height_values)
99
  # (height, width, embedding_dim * 2)
100
+ pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1),
101
+ y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1)
102
  # (embedding_dim * 2, height, width)
103
  pos = pos.permute(2, 0, 1)
104
  pos = pos.unsqueeze(0)
 
108
  pos = pos.permute(0, 2, 3, 1)
109
  return pos
110
 
111
+
112
  class PositionalEmbeddingCosine1D(nn.Module):
113
  """
114
  This class implements a very simple positional encoding. It follows closely
 
120
  dropout_prob: The dropout probability.
121
  max_seq_len: The maximum length to precompute the positional encodings.
122
  """
123
+
124
  def __init__(
125
  self,
126
  embed_dim: int = 512,
 
176
  embed_dim: The dimension of the embeddings.
177
  max_seq_len: The maximum length to precompute the positional encodings.
178
  """
179
+
180
  def __init__(
181
  self,
182
  embedding_dim: int = 512,
 
202
  len_seq = seq_embeds.size(-2)
203
  assert len_seq <= self.num_pos
204
  # [T, D]
205
+ pos_embeds = self.embeddings(
206
+ torch.arange(len_seq).to(seq_embeds.device))
207
  # Adapt pre-computed positional embeddings to the input.
208
  if shape_len == 3:
209
  pos_embeds = pos_embeds.view(
 
211
  return pos_embeds
212
 
213
 
 
214
  class MySequential(nn.Sequential):
215
  def forward(self, *inputs):
216
  for module in self._modules.values():
 
355
  def forward(self, x, size):
356
  B, N, C = x.shape
357
 
358
+ qkv = self.qkv(x).reshape(B, N, 3, self.groups, C //
359
+ self.groups).permute(2, 0, 3, 1, 4)
360
  q, k, v = qkv[0], qkv[1], qkv[2]
361
 
362
  q = q * (float(N) ** -0.5)
 
375
  conv_at_attn=True, conv_at_ffn=True):
376
  super().__init__()
377
 
378
+ drop_path = DropPath(
379
+ drop_path_rate) if drop_path_rate > 0. else nn.Identity()
380
 
381
+ self.conv1 = PreNorm(None, DepthWiseConv2d(
382
+ dim, 3, 1, 1)) if conv_at_attn else None
383
  self.channel_attn = PreNorm(
384
  norm_layer(dim),
385
  ChannelAttention(dim, groups=groups, qkv_bias=qkv_bias),
386
  drop_path
387
  )
388
+ self.conv2 = PreNorm(None, DepthWiseConv2d(
389
+ dim, 3, 1, 1)) if conv_at_ffn else None
390
  self.ffn = PreNorm(
391
  norm_layer(dim),
392
+ Mlp(in_features=dim, hidden_features=int(
393
+ dim*mlp_ratio), act_layer=act_layer),
394
  drop_path
395
  )
396
 
 
408
 
409
  def window_partition(x, window_size: int):
410
  B, H, W, C = x.shape
411
+ x = x.view(B, H // window_size, window_size,
412
+ W // window_size, window_size, C)
413
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous(
414
+ ).view(-1, window_size, window_size, C)
415
  return windows
416
 
417
 
418
  def window_reverse(windows, batch_size: int, window_size: int, H: int, W: int):
419
+ B = batch_size
420
  # this will cause onnx conversion failed for dynamic axis, because treated as constant
421
+ # int(windows.shape[0] / (H * W / window_size / window_size))
422
+ x = windows.view(B, H // window_size, W // window_size,
423
+ window_size, window_size, -1)
424
  x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
425
  return x
426
 
 
461
  # attn_windows = self.attn(x_windows)
462
 
463
  B_, N, C = x.shape
464
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C //
465
+ self.num_heads).permute(2, 0, 3, 1, 4)
466
  q, k, v = qkv[0], qkv[1], qkv[2]
467
 
468
  q = q * self.scale
 
493
  norm_layer=nn.LayerNorm, conv_at_attn=True, conv_at_ffn=True):
494
  super().__init__()
495
 
496
+ drop_path = DropPath(
497
+ drop_path_rate) if drop_path_rate > 0. else nn.Identity()
498
 
499
+ self.conv1 = PreNorm(None, DepthWiseConv2d(
500
+ dim, 3, 1, 1)) if conv_at_attn else None
501
  self.window_attn = PreNorm(
502
  norm_layer(dim),
503
  WindowAttention(dim, num_heads, window_size, qkv_bias=qkv_bias),
504
  drop_path
505
  )
506
+ self.conv2 = PreNorm(None, DepthWiseConv2d(
507
+ dim, 3, 1, 1)) if conv_at_ffn else None
508
  self.ffn = PreNorm(
509
  norm_layer(dim),
510
+ Mlp(in_features=dim, hidden_features=int(
511
+ dim*mlp_ratio), act_layer=act_layer),
512
  drop_path
513
  )
514
 
 
566
  enable_checkpoint=True,
567
  conv_at_attn=True,
568
  conv_at_ffn=True,
569
+ ):
570
  super().__init__()
571
 
572
  self.num_classes = num_classes
 
578
  assert self.num_stages == len(self.num_heads) == len(self.num_groups)
579
 
580
  num_stages = len(embed_dims)
581
+ dpr = [x.item() for x in torch.linspace(
582
+ 0, drop_path_rate, sum(depths)*2)]
583
 
584
  depth_offset = 0
585
  convs = []
 
633
 
634
  self.norms = norm_layer(self.embed_dims[-1])
635
  self.avgpool = nn.AdaptiveAvgPool1d(1)
636
+ self.head = nn.Linear(
637
+ self.embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
638
 
639
  self.apply(self._init_weights)
640
 
 
669
  for conv, block in zip(self.convs, self.blocks):
670
  x, input_size = conv(x, input_size)
671
  if self.enable_checkpoint:
672
+ x, input_size = checkpoint.checkpoint(
673
+ block, x, input_size, use_reentrant=True)
674
  else:
675
  x, input_size = block(x, input_size)
676
  return x
 
690
  x = self.forward_features(x)
691
  x = self.head(x)
692
  return x
693
+
694
  @classmethod
695
  def from_config(cls, config):
696
  return cls(
 
707
  )
708
 
709
 
 
 
710
  if is_flash_attn_2_available():
711
  from flash_attn import flash_attn_func, flash_attn_varlen_func
712
  from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
713
 
714
  # Copied from transformers.models.llama.modeling_llama._get_unpad_data
715
+
716
+
717
  def _get_unpad_data(attention_mask):
718
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
719
  indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
720
  max_seqlen_in_batch = seqlens_in_batch.max().item()
721
+ cu_seqlens = F.pad(torch.cumsum(
722
+ seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
723
  return (
724
  indices,
725
  cu_seqlens,
 
857
  if past_key_value[0] is not None:
858
  key_states = torch.cat([past_key_value[0], key_states], dim=2)
859
  if past_key_value[1] is not None:
860
+ value_states = torch.cat(
861
+ [past_key_value[1], value_states], dim=2)
862
  else:
863
  # self_attention
864
  key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
 
875
  past_key_value = (key_states, value_states)
876
 
877
  proj_shape = (bsz * self.num_heads, -1, self.head_dim)
878
+ query_states = self._shape(
879
+ query_states, tgt_len, bsz).view(*proj_shape)
880
  key_states = key_states.reshape(*proj_shape)
881
  value_states = value_states.reshape(*proj_shape)
882
 
 
894
  raise ValueError(
895
  f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
896
  )
897
+ attn_weights = attn_weights.view(
898
+ bsz, self.num_heads, tgt_len, src_len) + attention_mask
899
+ attn_weights = attn_weights.view(
900
+ bsz * self.num_heads, tgt_len, src_len)
901
 
902
  attn_weights = nn.functional.softmax(attn_weights, dim=-1)
903
 
 
907
  f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
908
  f" {layer_head_mask.size()}"
909
  )
910
+ attn_weights = layer_head_mask.view(
911
+ 1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
912
+ attn_weights = attn_weights.view(
913
+ bsz * self.num_heads, tgt_len, src_len)
914
 
915
  if output_attentions:
916
  # this operation is a bit awkward, but it's required to
917
  # make sure that attn_weights keeps its gradient.
918
  # In order to do so, attn_weights have to be reshaped
919
  # twice and have to be reused in the following
920
+ attn_weights_reshaped = attn_weights.view(
921
+ bsz, self.num_heads, tgt_len, src_len)
922
+ attn_weights = attn_weights_reshaped.view(
923
+ bsz * self.num_heads, tgt_len, src_len)
924
  else:
925
  attn_weights_reshaped = None
926
 
927
+ attn_probs = nn.functional.dropout(
928
+ attn_weights, p=self.dropout, training=self.training)
929
 
930
  attn_output = torch.bmm(attn_probs, value_states)
931
 
 
935
  f" {attn_output.size()}"
936
  )
937
 
938
+ attn_output = attn_output.view(
939
+ bsz, self.num_heads, tgt_len, self.head_dim)
940
  attn_output = attn_output.transpose(1, 2)
941
 
942
  # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
 
978
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
979
  # Florence2FlashAttention2 attention does not support output_attentions
980
  if output_attentions:
981
+ raise ValueError(
982
+ "Florence2FlashAttention2 attention does not support output_attentions")
983
 
984
  # if key_value_states are provided this layer is used as a cross-attention layer
985
  # for the decoder
 
1004
  elif is_cross_attention:
1005
  # cross_attentions
1006
  key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
1007
+ value_states = self._reshape(
1008
+ self.v_proj(key_value_states), -1, bsz)
1009
  elif past_key_value is not None:
1010
  # reuse k, v, self_attention
1011
  key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
1012
  value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
1013
+ key_states = torch.cat(
1014
+ [past_key_value[0].transpose(1, 2), key_states], dim=1)
1015
+ value_states = torch.cat(
1016
+ [past_key_value[1].transpose(1, 2), value_states], dim=1)
1017
  else:
1018
  # self_attention
1019
  key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
 
1027
  # all previous decoder key/value_states. Further calls to uni-directional self-attention
1028
  # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
1029
  # if encoder bi-directional self-attention `past_key_value` is always `None`
1030
+ past_key_value = (key_states.transpose(
1031
+ 1, 2), value_states.transpose(1, 2))
1032
 
1033
  kv_seq_len = key_states.shape[-2]
1034
  if past_key_value is not None:
 
1124
  causal=causal,
1125
  )
1126
 
1127
+ attn_output = pad_input(
1128
+ attn_output_unpad, indices_q, batch_size, query_length)
1129
  else:
1130
  attn_output = flash_attn_func(
1131
  query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
 
1135
 
1136
  # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
1137
  def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
1138
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
1139
+ attention_mask)
1140
  batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
1141
 
1142
  key_layer = index_first_axis(
1143
+ key_layer.reshape(batch_size * kv_seq_len,
1144
+ num_key_value_heads, head_dim), indices_k
1145
  )
1146
  value_layer = index_first_axis(
1147
+ value_layer.reshape(batch_size * kv_seq_len,
1148
+ num_key_value_heads, head_dim), indices_k
1149
  )
1150
  if query_length == kv_seq_len:
1151
  query_layer = index_first_axis(
1152
+ query_layer.reshape(batch_size * kv_seq_len,
1153
+ self.num_heads, head_dim), indices_k
1154
  )
1155
  cu_seqlens_q = cu_seqlens_k
1156
  max_seqlen_in_batch_q = max_seqlen_in_batch_k
 
1165
  else:
1166
  # The -q_len: slice assumes left padding.
1167
  attention_mask = attention_mask[:, -query_length:]
1168
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
1169
+ query_layer, attention_mask)
1170
 
1171
  return (
1172
  query_layer,
 
1236
  if past_key_value[0] is not None:
1237
  key_states = torch.cat([past_key_value[0], key_states], dim=2)
1238
  if past_key_value[1] is not None:
1239
+ value_states = torch.cat(
1240
+ [past_key_value[1], value_states], dim=2)
1241
  else:
1242
  # self_attention
1243
  key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
 
1339
  layer_head_mask=layer_head_mask,
1340
  output_attentions=output_attentions,
1341
  )
1342
+ hidden_states = nn.functional.dropout(
1343
+ hidden_states, p=self.dropout, training=self.training)
1344
  hidden_states = residual + hidden_states
1345
  hidden_states = self.self_attn_layer_norm(hidden_states)
1346
 
1347
  residual = hidden_states
1348
  hidden_states = self.activation_fn(self.fc1(hidden_states))
1349
+ hidden_states = nn.functional.dropout(
1350
+ hidden_states, p=self.activation_dropout, training=self.training)
1351
  hidden_states = self.fc2(hidden_states)
1352
+ hidden_states = nn.functional.dropout(
1353
+ hidden_states, p=self.dropout, training=self.training)
1354
  hidden_states = residual + hidden_states
1355
  hidden_states = self.final_layer_norm(hidden_states)
1356
 
1357
  if hidden_states.dtype == torch.float16 and (
1358
+ torch.isinf(hidden_states).any() or torch.isnan(
1359
+ hidden_states).any()
1360
  ):
1361
  clamp_value = torch.finfo(hidden_states.dtype).max - 1000
1362
+ hidden_states = torch.clamp(
1363
+ hidden_states, min=-clamp_value, max=clamp_value)
1364
 
1365
  outputs = (hidden_states,)
1366
 
 
1434
 
1435
  # Self Attention
1436
  # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
1437
+ self_attn_past_key_value = past_key_value[:
1438
+ 2] if past_key_value is not None else None
1439
  # add present self-attn cache to positions 1,2 of present_key_value tuple
1440
  hidden_states, self_attn_weights, present_key_value = self.self_attn(
1441
  hidden_states=hidden_states,
 
1444
  layer_head_mask=layer_head_mask,
1445
  output_attentions=output_attentions,
1446
  )
1447
+ hidden_states = nn.functional.dropout(
1448
+ hidden_states, p=self.dropout, training=self.training)
1449
  hidden_states = residual + hidden_states
1450
  hidden_states = self.self_attn_layer_norm(hidden_states)
1451
 
 
1456
  residual = hidden_states
1457
 
1458
  # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
1459
+ cross_attn_past_key_value = past_key_value[-2:
1460
+ ] if past_key_value is not None else None
1461
  hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
1462
  hidden_states=hidden_states,
1463
  key_value_states=encoder_hidden_states,
 
1466
  past_key_value=cross_attn_past_key_value,
1467
  output_attentions=output_attentions,
1468
  )
1469
+ hidden_states = nn.functional.dropout(
1470
+ hidden_states, p=self.dropout, training=self.training)
1471
  hidden_states = residual + hidden_states
1472
  hidden_states = self.encoder_attn_layer_norm(hidden_states)
1473
 
 
1477
  # Fully Connected
1478
  residual = hidden_states
1479
  hidden_states = self.activation_fn(self.fc1(hidden_states))
1480
+ hidden_states = nn.functional.dropout(
1481
+ hidden_states, p=self.activation_dropout, training=self.training)
1482
  hidden_states = self.fc2(hidden_states)
1483
+ hidden_states = nn.functional.dropout(
1484
+ hidden_states, p=self.dropout, training=self.training)
1485
  hidden_states = residual + hidden_states
1486
  hidden_states = self.final_layer_norm(hidden_states)
1487
 
 
1496
  return outputs
1497
 
1498
 
 
1499
  class Florence2LanguagePreTrainedModel(PreTrainedModel):
1500
  config_class = Florence2LanguageConfig
1501
  base_model_prefix = "model"
 
1520
  @property
1521
  def dummy_inputs(self):
1522
  pad_token = self.config.pad_token_id
1523
+ input_ids = torch.tensor(
1524
+ [[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
1525
  dummy_inputs = {
1526
  "attention_mask": input_ids.ne(pad_token),
1527
  "input_ids": input_ids,
 
1561
  config.max_position_embeddings,
1562
  embed_dim,
1563
  )
1564
+ self.layers = nn.ModuleList([Florence2EncoderLayer(
1565
+ config) for _ in range(config.encoder_layers)])
1566
  self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
1567
  self._use_sdpa = config._attn_implementation == "sdpa"
1568
  self.layernorm_embedding = nn.LayerNorm(embed_dim)
 
1631
 
1632
  # retrieve input_ids and inputs_embeds
1633
  if input_ids is not None and inputs_embeds is not None:
1634
+ raise ValueError(
1635
+ "You cannot specify both input_ids and inputs_embeds at the same time")
1636
  elif input_ids is not None:
1637
  input = input_ids
1638
  input_ids = input_ids.view(-1, input_ids.shape[-1])
1639
  elif inputs_embeds is not None:
1640
  input = inputs_embeds[:, :, -1]
1641
  else:
1642
+ raise ValueError(
1643
+ "You have to specify either input_ids or inputs_embeds")
1644
 
1645
  if inputs_embeds is None:
1646
  inputs_embeds = self.embed_tokens(input_ids)
 
1650
 
1651
  hidden_states = inputs_embeds + embed_pos
1652
  hidden_states = self.layernorm_embedding(hidden_states)
1653
+ hidden_states = nn.functional.dropout(
1654
+ hidden_states, p=self.dropout, training=self.training)
1655
 
1656
  # expand attention_mask
1657
  if attention_mask is not None:
 
1661
  # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
1662
  # the manual implementation that requires a 4D causal mask in all cases.
1663
  # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1664
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(
1665
+ attention_mask, inputs_embeds.dtype)
1666
  else:
1667
  # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1668
+ attention_mask = _prepare_4d_attention_mask(
1669
+ attention_mask, inputs_embeds.dtype)
1670
 
1671
  encoder_states = () if output_hidden_states else None
1672
  all_attentions = () if output_attentions else None
 
1704
  layer_outputs = encoder_layer(
1705
  hidden_states,
1706
  attention_mask,
1707
+ layer_head_mask=(
1708
+ head_mask[idx] if head_mask is not None else None),
1709
  output_attentions=output_attentions,
1710
  )
1711
 
 
1739
  self.layerdrop = config.decoder_layerdrop
1740
  self.padding_idx = config.pad_token_id
1741
  self.max_target_positions = config.max_position_embeddings
1742
+ embed_scale = math.sqrt(
1743
+ config.d_model) if config.scale_embedding else 1.0
1744
 
1745
  self.embed_tokens = Florence2ScaledWordEmbedding(
1746
  config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
 
1753
  config.max_position_embeddings,
1754
  config.d_model,
1755
  )
1756
+ self.layers = nn.ModuleList([Florence2DecoderLayer(
1757
+ config) for _ in range(config.decoder_layers)])
1758
  self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
1759
  self._use_sdpa = config._attn_implementation == "sdpa"
1760
 
 
1859
 
1860
  # retrieve input_ids and inputs_embeds
1861
  if input_ids is not None and inputs_embeds is not None:
1862
+ raise ValueError(
1863
+ "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
1864
  elif input_ids is not None:
1865
  input = input_ids
1866
  input_shape = input.shape
 
1869
  input_shape = inputs_embeds.size()[:-1]
1870
  input = inputs_embeds[:, :, -1]
1871
  else:
1872
+ raise ValueError(
1873
+ "You have to specify either decoder_input_ids or decoder_inputs_embeds")
1874
 
1875
  # past_key_values_length
1876
+ past_key_values_length = past_key_values[0][0].shape[
1877
+ 2] if past_key_values and past_key_values[0] and past_key_values[0][0] is not None else 0
1878
 
1879
  if inputs_embeds is None:
1880
  inputs_embeds = self.embed_tokens(input)
1881
 
1882
  if self._use_flash_attention_2:
1883
  # 2d mask is passed through the layers
1884
+ attention_mask = attention_mask if (
1885
+ attention_mask is not None and 0 in attention_mask) else None
1886
  elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None:
1887
  # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
1888
  # the manual implementation that requires a 4D causal mask in all cases.
 
1924
  hidden_states = inputs_embeds + positions
1925
  hidden_states = self.layernorm_embedding(hidden_states)
1926
 
1927
+ hidden_states = nn.functional.dropout(
1928
+ hidden_states, p=self.dropout, training=self.training)
1929
 
1930
  if self.gradient_checkpointing and self.training:
1931
  if use_cache:
 
1937
  # decoder layers
1938
  all_hidden_states = () if output_hidden_states else None
1939
  all_self_attns = () if output_attentions else None
1940
+ all_cross_attentions = () if (
1941
+ output_attentions and encoder_hidden_states is not None) else None
1942
  next_decoder_cache = () if use_cache else None
1943
 
1944
  # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
 
1980
  attention_mask=attention_mask,
1981
  encoder_hidden_states=encoder_hidden_states,
1982
  encoder_attention_mask=encoder_attention_mask,
1983
+ layer_head_mask=(
1984
+ head_mask[idx] if head_mask is not None else None),
1985
  cross_attn_layer_head_mask=(
1986
  cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
1987
  ),
 
1992
  hidden_states = layer_outputs[0]
1993
 
1994
  if use_cache:
1995
+ next_decoder_cache += (
1996
+ layer_outputs[3 if output_attentions else 1],)
1997
 
1998
  if output_attentions:
1999
  all_self_attns += (layer_outputs[1],)
 
2022
 
2023
 
2024
  class Florence2LanguageModel(Florence2LanguagePreTrainedModel):
2025
+ _tied_weights_keys = ["encoder.embed_tokens.weight",
2026
+ "decoder.embed_tokens.weight"]
2027
 
2028
  def __init__(self, config: Florence2LanguageConfig):
2029
  super().__init__(config)
 
2109
  elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
2110
  encoder_outputs = BaseModelOutput(
2111
  last_hidden_state=encoder_outputs[0],
2112
+ hidden_states=encoder_outputs[1] if len(
2113
+ encoder_outputs) > 1 else None,
2114
+ attentions=encoder_outputs[2] if len(
2115
+ encoder_outputs) > 2 else None,
2116
  )
2117
 
2118
  # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
 
2148
 
2149
  class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel, GenerationMixin):
2150
  base_model_prefix = "model"
2151
+ _tied_weights_keys = ["encoder.embed_tokens.weight",
2152
+ "decoder.embed_tokens.weight", "lm_head.weight"]
2153
  _keys_to_ignore_on_load_missing = ["final_logits_bias"]
2154
 
2155
  def __init__(self, config: Florence2LanguageConfig):
2156
  super().__init__(config)
2157
  self.model = Florence2LanguageModel(config)
2158
+ self.register_buffer("final_logits_bias", torch.zeros(
2159
+ (1, self.model.shared.num_embeddings)))
2160
+ self.lm_head = nn.Linear(
2161
+ config.d_model, self.model.shared.num_embeddings, bias=False)
2162
 
2163
  # Initialize weights and apply final processing
2164
  self.post_init()
 
2170
  return self.model.get_decoder()
2171
 
2172
  def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
2173
+ new_embeddings = super().resize_token_embeddings(
2174
+ new_num_tokens, pad_to_multiple_of)
2175
  self._resize_final_logits_bias(new_embeddings.weight.shape[0])
2176
  return new_embeddings
2177
 
 
2180
  if new_num_tokens <= old_num_tokens:
2181
  new_bias = self.final_logits_bias[:, :new_num_tokens]
2182
  else:
2183
+ extra_bias = torch.zeros(
2184
+ (1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
2185
  new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
2186
  self.register_buffer("final_logits_bias", new_bias)
2187
 
 
2222
 
2223
  if labels is not None:
2224
  if use_cache:
2225
+ logger.warning(
2226
+ "The `use_cache` argument is changed to `False` since `labels` is provided.")
2227
  use_cache = False
2228
  if decoder_input_ids is None and decoder_inputs_embeds is None:
2229
  decoder_input_ids = shift_tokens_right(
 
2255
  if labels is not None:
2256
  labels = labels.to(lm_logits.device)
2257
  loss_fct = CrossEntropyLoss()
2258
+ masked_lm_loss = loss_fct(
2259
+ lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
2260
 
2261
  if not return_dict:
2262
  output = (lm_logits,) + outputs[1:]
 
2310
  "head_mask": head_mask,
2311
  "decoder_head_mask": decoder_head_mask,
2312
  "cross_attn_head_mask": cross_attn_head_mask,
2313
+ # change this to avoid caching (presumably for debugging)
2314
+ "use_cache": use_cache,
2315
  }
2316
 
2317
  def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
 
2323
  for layer_past in past_key_values:
2324
  # cached cross_attention states don't have to be reordered -> they are always the same
2325
  reordered_past += (
2326
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device))
2327
+ for past_state in layer_past[:2])
2328
  + layer_past[2:],
2329
  )
2330
  return reordered_past
2331
 
2332
+
2333
  @dataclass
2334
  class Florence2Seq2SeqLMOutput(ModelOutput):
2335
  """
 
2515
  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
2516
  """
2517
 
2518
+
2519
  @add_start_docstrings(
2520
  """The FLORENCE2 vision model without any head""",
2521
  FLORENCE2_START_DOCSTRING,
 
2523
  class Florence2VisionModel(Florence2PreTrainedModel):
2524
  def __init__(self, config: Florence2VisionConfig):
2525
  super().__init__(config)
2526
+ assert config.model_type in [
2527
+ 'davit', ""], 'only DaViT is supported for now'
2528
  self.vision_tower = DaViT.from_config(config=config)
2529
 
2530
  self.post_init()
2531
+
2532
  def forward(self, pixel_values):
2533
  if len(pixel_values.shape) == 4:
2534
  x = self.vision_tower.forward_features_unpool(pixel_values)
 
2544
  class Florence2VisionModelWithProjection(Florence2PreTrainedModel):
2545
  def __init__(self, config: Florence2VisionConfig):
2546
  super().__init__(config)
2547
+ assert config.model_type in [
2548
+ 'davit', ''], 'only DaViT is supported for now'
2549
  self.vision_tower = DaViT.from_config(config=config)
2550
 
2551
  self._build_image_projection_layers(config)
2552
 
2553
  self.post_init()
2554
+
2555
  def _build_image_projection_layers(self, config):
2556
  image_dim_out = config.dim_embed[-1]
2557
  dim_projection = config.projection_dim
 
2587
  x = self.vision_tower.forward_features_unpool(pixel_values)
2588
  else:
2589
  raise ValueError(f'invalid image shape {pixel_values.shape}')
2590
+
2591
  if self.image_pos_embed is not None:
2592
  x = x.view(batch_size * T, -1, x.shape[-1])
2593
  num_tokens = x.shape[-2]
 
2599
  x = x.view(batch_size, T * h*w, x.shape[-1])
2600
 
2601
  if self.visual_temporal_embed is not None:
2602
+ visual_temporal_embed = self.visual_temporal_embed(
2603
+ x.view(batch_size, T, -1, x.shape[-1])[:, :, 0])
2604
+ x = x.view(
2605
+ batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(1, T, 1, x.shape[-1])
2606
 
2607
  x_feat_dict = {}
2608
 
2609
  spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2)
2610
  x_feat_dict['spatial_avg_pool'] = spatial_avg_pool_x
2611
 
2612
+ temporal_avg_pool_x = x.view(
2613
+ batch_size, T, -1, x.shape[-1]).mean(dim=1)
2614
  x_feat_dict['temporal_avg_pool'] = temporal_avg_pool_x
2615
 
2616
  x = x.view(batch_size, T, -1, x.shape[-1])[:, -1]
 
2619
  new_x = []
2620
  for _image_feature_source in self.image_feature_source:
2621
  if _image_feature_source not in x_feat_dict:
2622
+ raise ValueError(
2623
+ 'invalid image feature source: {}'.format(_image_feature_source))
2624
  new_x.append(x_feat_dict[_image_feature_source])
2625
 
2626
  x = torch.cat(new_x, dim=1)
 
2628
  x = x @ self.image_projection
2629
  x = self.image_proj_norm(x)
2630
 
 
2631
  return x
2632
 
2633
 
 
2634
  @add_start_docstrings(
2635
  """The FLORENCE2 model which consists of a vision backbone and a language model.""",
2636
  FLORENCE2_START_DOCSTRING,
 
2638
  class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2639
  def __init__(self, config: Florence2Config):
2640
  super().__init__(config)
2641
+ assert config.vision_config.model_type in [
2642
+ 'davit', ''], 'only DaViT is supported for now'
2643
  self.vision_tower = DaViT.from_config(config=config.vision_config)
2644
+ # remove unused layers
2645
  del self.vision_tower.head
2646
  del self.vision_tower.norms
2647
 
 
2649
  self._attn_implementation = config._attn_implementation
2650
  self._build_image_projection_layers(config)
2651
 
2652
+ language_model = Florence2LanguageForConditionalGeneration(
2653
+ config=config.text_config)
2654
 
2655
  if language_model._tied_weights_keys is not None:
2656
+ self._tied_weights_keys = [
2657
+ f"language_model.{k}" for k in language_model._tied_weights_keys]
2658
  self.language_model = language_model
2659
  self.character_character_matching_head = nn.Sequential(
2660
  nn.Linear(2 * 768, config.projection_dim),
 
2678
  nn.Linear(config.projection_dim, 1)
2679
  )
2680
  self.text_classification_head = nn.Linear(config.projection_dim, 1)
2681
+ self.character_embedding_projection = nn.Linear(
2682
+ config.projection_dim, 768)
2683
 
2684
  self._init_weights(self.character_character_matching_head)
2685
  self._init_weights(self.text_character_matching_head)
2686
  self._init_weights(self.text_tail_matching_head)
2687
  self._init_weights(self.text_classification_head)
2688
+
2689
  self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
2690
  self.post_init()
2691
+
2692
  def _init_weights(self, m):
2693
  if isinstance(m, nn.Linear):
2694
  trunc_normal_(m.weight, std=0.02)
 
2708
  elif isinstance(m, nn.Sequential):
2709
  for layer in m:
2710
  self._init_weights(layer)
2711
+
2712
  def _build_image_projection_layers(self, config):
2713
  image_dim_out = config.vision_config.dim_embed[-1]
2714
  dim_projection = config.vision_config.projection_dim
 
2747
  return self.language_model.get_input_embeddings()
2748
 
2749
  def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
2750
+ model_embeds = self.language_model.resize_token_embeddings(
2751
+ new_num_tokens, pad_to_multiple_of)
2752
  # update vocab size
2753
  self.config.text_config.vocab_size = model_embeds.num_embeddings
2754
  self.config.vocab_size = model_embeds.num_embeddings
2755
  self.vocab_size = model_embeds.num_embeddings
2756
  return model_embeds
2757
+
2758
  def _encode_image(self, pixel_values):
2759
  if len(pixel_values.shape) == 4:
2760
  batch_size, C, H, W = pixel_values.shape
 
2762
  x = self.vision_tower.forward_features_unpool(pixel_values)
2763
  else:
2764
  raise ValueError(f'invalid image shape {pixel_values.shape}')
2765
+
2766
  if self.image_pos_embed is not None:
2767
  x = x.view(batch_size * T, -1, x.shape[-1])
2768
  num_tokens = x.shape[-2]
 
2774
  x = x.view(batch_size, T * h*w, x.shape[-1])
2775
 
2776
  if self.visual_temporal_embed is not None:
2777
+ visual_temporal_embed = self.visual_temporal_embed(
2778
+ x.view(batch_size, T, -1, x.shape[-1])[:, :, 0])
2779
+ x = x.view(
2780
+ batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(1, T, 1, x.shape[-1])
2781
 
2782
  x_feat_dict = {}
2783
 
2784
  spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2)
2785
  x_feat_dict['spatial_avg_pool'] = spatial_avg_pool_x
2786
 
2787
+ temporal_avg_pool_x = x.view(
2788
+ batch_size, T, -1, x.shape[-1]).mean(dim=1)
2789
  x_feat_dict['temporal_avg_pool'] = temporal_avg_pool_x
2790
 
2791
  x = x.view(batch_size, T, -1, x.shape[-1])[:, -1]
 
2794
  new_x = []
2795
  for _image_feature_source in self.image_feature_source:
2796
  if _image_feature_source not in x_feat_dict:
2797
+ raise ValueError(
2798
+ 'invalid image feature source: {}'.format(_image_feature_source))
2799
  new_x.append(x_feat_dict[_image_feature_source])
2800
 
2801
  x = torch.cat(new_x, dim=1)
 
2803
  x = x @ self.image_projection
2804
  x = self.image_proj_norm(x)
2805
 
2806
+ return x
2807
 
2808
  def _merge_input_ids_with_image_features(
2809
+ self, image_features, inputs_embeds
2810
  ):
2811
  batch_size, image_token_length = image_features.size()[:-1]
2812
  device = image_features.device
2813
+ image_attention_mask = torch.ones(
2814
+ batch_size, image_token_length, device=device)
2815
 
2816
  # task_prefix_embeds: [batch_size, padded_context_length, hidden_size]
2817
  # task_prefix_attention_mask: [batch_size, context_length]
 
2819
  return image_features, image_attention_mask
2820
 
2821
  task_prefix_embeds = inputs_embeds
2822
+ task_prefix_attention_mask = torch.ones(
2823
+ batch_size, task_prefix_embeds.size(1), device=device)
2824
 
2825
  if len(task_prefix_attention_mask.shape) == 3:
2826
  task_prefix_attention_mask = task_prefix_attention_mask[:, 0]
2827
 
2828
  # concat [image embeds, task prefix embeds]
2829
  inputs_embeds = torch.cat([image_features, task_prefix_embeds], dim=1)
2830
+ attention_mask = torch.cat(
2831
+ [image_attention_mask, task_prefix_attention_mask], dim=1)
2832
 
2833
  return inputs_embeds, attention_mask
2834
+
2835
  @torch.no_grad()
2836
  def predict_detections_and_associations(
2837
  self,
 
2843
  essential_text_threshold=0.8,
2844
  ):
2845
  batch_inputs = processor(
2846
+ batch_input_text=[
2847
+ "Find all panels, texts, characters, and speech-bubble tails in the image."] * len(images),
2848
  batch_input_list_of_list_of_bboxes=[[]] * len(images),
2849
  batch_images=images,
2850
  padding=True,
 
2862
  do_sample=False,
2863
  num_beams=3,
2864
  )
2865
+ generated_texts, list_of_list_of_list_of_bboxes, batch_indices_of_bboxes_in_generated_text = processor.postprocess_output(
2866
+ generated_ids, images)
2867
+ map_to_category = {"<pa": "panels", "<te": "texts",
2868
+ "<ch": "characters", "<ta": "tails"}
2869
 
2870
  results = []
2871
 
2872
  for generated_text, batch_indices_of_bboxes_in_generated_text, list_of_list_of_bboxes in zip(generated_texts, batch_indices_of_bboxes_in_generated_text, list_of_list_of_list_of_bboxes):
2873
+ categories = [map_to_category.get(
2874
+ generated_text[j:j+3], None) for i, j in batch_indices_of_bboxes_in_generated_text]
2875
  result_for_this_image = {
2876
  "panels": [],
2877
  "texts": [],
 
2886
 
2887
  cleaned_generated_ids = []
2888
  for generated_id in generated_ids:
2889
+ index_of_last_bos = torch.where(
2890
+ generated_id == processor.tokenizer.bos_token_id)[0][-1].item()
2891
  cleaned_generated_ids.append(generated_id[index_of_last_bos:])
2892
+ cleaned_generated_ids = pad_sequence(
2893
+ cleaned_generated_ids, batch_first=True, padding_value=processor.tokenizer.pad_token_id)
2894
  association_outputs = self(
2895
  input_ids=batch_inputs["input_ids"],
2896
  pixel_values=batch_inputs["pixel_values"],
 
2899
  )
2900
 
2901
  for img_idx in range(len(results)):
2902
+ character_cluster_labels = UnionFind.from_adj_matrix(
2903
+ association_outputs.character_character_affinity_matrices[img_idx] > character_character_association_threshold).get_labels_for_connected_components()
2904
+ text_character_association = torch.nonzero(
2905
+ association_outputs.text_character_association_matrices[img_idx] > text_character_association_threshold).tolist()
2906
+ text_tail_association = torch.nonzero(
2907
+ association_outputs.text_tail_association_matrices[img_idx] > text_tail_association_threshold).tolist()
2908
+ essential_text_logits = (
2909
+ association_outputs.essential_text_logits[img_idx] > essential_text_threshold).tolist()
2910
  results[img_idx]["character_cluster_labels"] = character_cluster_labels
2911
  results[img_idx]["text_character_associations"] = text_character_association
2912
  results[img_idx]["text_tail_associations"] = text_tail_association
2913
  results[img_idx]["is_essential_text"] = essential_text_logits
2914
 
2915
  return results
2916
+
2917
  @torch.no_grad()
2918
  def predict_ocr(
2919
  self,
 
2921
  processor,
2922
  ):
2923
  batch_inputs = processor(
2924
+ batch_input_text=[
2925
+ "What is the text in the image, with regions?"] * len(images),
2926
  batch_input_list_of_list_of_bboxes=[[]] * len(images),
2927
  batch_images=images,
2928
  padding=True,
 
2940
  do_sample=False,
2941
  num_beams=3,
2942
  )
2943
+ generated_texts, list_of_list_of_list_of_bboxes, batch_indices_of_bboxes_in_generated_text = processor.postprocess_output(
2944
+ generated_ids, images)
2945
  results = []
2946
  for generated_text, batch_indices_of_bboxes_in_generated_text, list_of_list_of_bboxes in zip(generated_texts, batch_indices_of_bboxes_in_generated_text, list_of_list_of_list_of_bboxes):
2947
  ocr_texts = []
 
2965
  ):
2966
  def convert_caption_to_instruction(caption):
2967
  return "Locate the phrases in the caption: " + caption
2968
+
2969
  batch_inputs = processor(
2970
+ batch_input_text=[convert_caption_to_instruction(
2971
+ caption) for caption in captions],
2972
  batch_input_list_of_list_of_bboxes=[[]] * len(images),
2973
  batch_images=images,
2974
  padding=True,
 
2996
  do_sample=False,
2997
  num_beams=3,
2998
  )
2999
+ generated_texts, list_of_list_of_list_of_bboxes, batch_indices_of_bboxes_in_generated_text = processor.postprocess_output(
3000
+ generated_ids, images)
3001
  return [
3002
  {
3003
  "grounded_caption": generated_text,
 
3026
  output_attentions: Optional[bool] = None,
3027
  output_hidden_states: Optional[bool] = True,
3028
  return_dict: Optional[bool] = None,
3029
+ tokenizer=None,
3030
  ) -> Union[Tuple, Florence2Seq2SeqLMOutput]:
3031
  assert output_hidden_states, "output_hidden_states must be True"
3032
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 
3044
  if pixel_values is not None:
3045
  # (batch_size, num_image_tokens, hidden_size)
3046
  image_features = self._encode_image(pixel_values)
3047
+ inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(
3048
+ image_features, inputs_embeds)
3049
 
3050
  if inputs_embeds is not None:
3051
  attention_mask = attention_mask.to(inputs_embeds.dtype)
 
3067
  return_dict=return_dict,
3068
  )
3069
 
3070
+ character_character_affinity_matrices = self.get_character_character_affinity_matrices(
3071
+ outputs.decoder_hidden_states[-1], decoder_input_ids, tokenizer, apply_sigmoid=True)
3072
+ text_character_association_matrices = self.get_text_character_association_matrices(
3073
+ outputs.decoder_hidden_states[-1], decoder_input_ids, tokenizer, apply_sigmoid=True)
3074
+ text_tail_association_matrices = self.get_text_tail_association_matrices(
3075
+ outputs.decoder_hidden_states[-1], decoder_input_ids, tokenizer, apply_sigmoid=True)
3076
+ essential_text_logits = self.get_essential_text_logits(
3077
+ outputs.decoder_hidden_states[-1], decoder_input_ids, tokenizer, apply_sigmoid=True)
3078
 
3079
  return Florence2Seq2SeqLMOutput(
3080
  character_character_affinity_matrices=character_character_affinity_matrices,
 
3085
 
3086
  def generate(
3087
  self,
3088
+ input_ids,
3089
  inputs_embeds=None,
3090
  pixel_values=None,
3091
  **kwargs
3092
+ ):
3093
 
3094
  if inputs_embeds is None:
3095
  # 1. Extra the input embeddings
 
3098
  # 2. Merge text and images
3099
  if pixel_values is not None:
3100
  image_features = self._encode_image(pixel_values)
3101
+ inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(
3102
+ image_features, inputs_embeds)
3103
+
3104
  return self.language_model.generate(
3105
  input_ids=None,
3106
  inputs_embeds=inputs_embeds,
3107
  **kwargs
3108
  )
3109
+
3110
  def slowly_generate_grounded_caption(
3111
  self,
3112
  input_ids,
 
3122
  """
3123
  input_embeds = self.get_input_embeddings()(input_ids)
3124
  image_features = self._encode_image(pixel_values)
3125
+ inputs_embeds, _ = self._merge_input_ids_with_image_features(
3126
+ image_features, input_embeds)
3127
+ decoder_input_ids = processor.tokenizer(
3128
+ captions, return_tensors="pt", truncation=False, padding=True)["input_ids"].to(self.device)
3129
+ running_indices = torch.zeros(
3130
+ decoder_input_ids.shape[0], dtype=torch.long, device=self.device)
3131
  running_decoder_input_ids = decoder_input_ids[:, :1]
3132
  num_tokens_generated = 1
3133
  CHUNK_SIZE = 8
 
3145
  )[:, -(CHUNK_SIZE+1):-1]
3146
 
3147
  what_should_be_the_next_tokens = torch.stack([
3148
+ decoder_input_ids[i, running_indices[i] +
3149
+ 1:running_indices[i]+CHUNK_SIZE+1]
3150
  for i in range(decoder_input_ids.shape[0])
3151
  ])
3152
  # if the entire predicted chunk matches the next chunk, then we can saved some time and "jump" to the next chunk
3153
  if predicted_next_tokens.shape[1] == what_should_be_the_next_tokens.shape[1] and torch.all(predicted_next_tokens == what_should_be_the_next_tokens):
3154
  running_indices += CHUNK_SIZE
3155
+ running_decoder_input_ids = torch.cat(
3156
+ [running_decoder_input_ids, what_should_be_the_next_tokens], dim=-1)
3157
  continue
3158
+
3159
  # if, however, there is a deviation find the maximum prefix that matches in the batch
3160
 
3161
  predicted_next_tokens = predicted_next_tokens[:, 0]
3162
+ predicted_next_token_strings = processor.batch_decode(
3163
+ predicted_next_tokens)
3164
  next_tokens_to_concat = []
3165
  for i, (pnts, pnt) in enumerate(zip(predicted_next_token_strings, predicted_next_tokens)):
3166
  if (pnts.startswith("<loc_") or pnts in ["<s>", "<pad>", "</s>"]) and running_indices[i] < decoder_input_ids.shape[1] - 1:
 
3168
  else:
3169
  running_indices[i] += 1
3170
  if running_indices[i] >= decoder_input_ids.shape[1]:
3171
+ next_tokens_to_concat.append(torch.tensor(
3172
+ processor.tokenizer.eos_token_id, device=self.device))
3173
  # elif "’" in pnts: # this is an annoying character which looks like ' (apostrophe) but isn't.
3174
  # import pdb; pdb.set_trace()
3175
  else:
3176
+ next_tokens_to_concat.append(
3177
+ decoder_input_ids[i, running_indices[i]])
3178
  next_tokens_to_concat = torch.stack(next_tokens_to_concat)[:, None]
3179
  if (next_tokens_to_concat == processor.tokenizer.eos_token_id).all():
3180
  break
3181
+ running_decoder_input_ids = torch.cat(
3182
+ [running_decoder_input_ids, next_tokens_to_concat], dim=-1)
3183
  if num_tokens_generated >= 1024:
3184
  break
3185
  return running_decoder_input_ids
 
3210
  remove_prefix_length = decoder_input_ids.shape[1] - 1
3211
 
3212
  decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
3213
+
3214
  return {
3215
  "input_ids": None, # encoder_outputs is defined. input_ids not needed
3216
  "encoder_outputs": encoder_outputs,
 
3222
  "head_mask": head_mask,
3223
  "decoder_head_mask": decoder_head_mask,
3224
  "cross_attn_head_mask": cross_attn_head_mask,
3225
+ # change this to avoid caching (presumably for debugging)
3226
+ "use_cache": use_cache,
3227
  }
3228
+
3229
  def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
3230
  return self.language_model.shift_tokens_right(labels)
3231
 
3232
  def _reorder_cache(self, *args, **kwargs):
3233
  return self.language_model._reorder_cache(*args, **kwargs)
3234
+
3235
  def get_character_character_affinity_matrices(self, decoder_hidden_states, decoder_input_ids, tokenizer, apply_sigmoid=False):
3236
  character_character_affinity_matrices = []
3237
  for index in range(len(decoder_hidden_states)):
3238
+ character_embedding_indices = (decoder_input_ids[index] == tokenizer.convert_tokens_to_ids(
3239
+ '<character>')).nonzero().squeeze(-1)
3240
  character_embeddings = decoder_hidden_states[index][character_embedding_indices]
3241
  if character_embeddings.shape[0] == 0:
3242
+ character_character_affinity_matrices.append(
3243
+ torch.zeros(0, 0).type_as(character_embeddings))
3244
  continue
3245
+ character_embeddings = self.character_embedding_projection(
3246
+ character_embeddings)
3247
+ char_i = repeat(character_embeddings, "i d -> i repeat d",
3248
+ repeat=character_embeddings.shape[0])
3249
+ char_j = repeat(character_embeddings, "j d -> repeat j d",
3250
+ repeat=character_embeddings.shape[0])
3251
  char_ij = rearrange([char_i, char_j], "two i j d -> (i j) (two d)")
3252
+ character_character_affinities = self.character_character_matching_head(
3253
+ char_ij)
3254
+ character_character_affinities = rearrange(
3255
+ character_character_affinities, "(i j) 1 -> i j", i=char_i.shape[0])
3256
+ character_character_affinities = (
3257
+ character_character_affinities + character_character_affinities.T) / 2
3258
  if apply_sigmoid:
3259
+ character_character_affinities = torch.sigmoid(
3260
+ character_character_affinities)
3261
+ character_character_affinity_matrices.append(
3262
+ character_character_affinities)
3263
  return character_character_affinity_matrices
3264
 
3265
  def get_text_character_association_matrices(self, decoder_hidden_states, decoder_input_ids, tokenizer, apply_sigmoid=False):
3266
  text_character_association_matrices = []
3267
  for index in range(len(decoder_hidden_states)):
3268
+ text_embedding_indices = (decoder_input_ids[index] == tokenizer.convert_tokens_to_ids(
3269
+ '<text>')).nonzero().squeeze(-1)
3270
  text_embeddings = decoder_hidden_states[index][text_embedding_indices]
3271
+ character_embedding_indices = (decoder_input_ids[index] == tokenizer.convert_tokens_to_ids(
3272
+ '<character>')).nonzero().squeeze(-1)
3273
  character_embeddings = decoder_hidden_states[index][character_embedding_indices]
3274
  if character_embeddings.shape[0] == 0 or text_embeddings.shape[0] == 0:
3275
+ text_character_association_matrices.append(torch.zeros(
3276
+ text_embeddings.shape[0], character_embeddings.shape[0]).type_as(text_embeddings))
3277
  continue
3278
+ text_i = repeat(text_embeddings, "i d -> i repeat d",
3279
+ repeat=character_embeddings.shape[0])
3280
+ char_j = repeat(character_embeddings, "j d -> repeat j d",
3281
+ repeat=text_embeddings.shape[0])
3282
+ text_char_ij = rearrange(
3283
+ [text_i, char_j], "two i j d -> (i j) (two d)")
3284
+ text_character_affinities = self.text_character_matching_head(
3285
+ text_char_ij)
3286
+ text_character_affinities = rearrange(
3287
+ text_character_affinities, "(i j) 1 -> i j", i=text_i.shape[0])
3288
  if apply_sigmoid:
3289
+ text_character_affinities = torch.sigmoid(
3290
+ text_character_affinities)
3291
+ text_character_association_matrices.append(
3292
+ text_character_affinities)
3293
  return text_character_association_matrices
3294
 
3295
  def get_text_tail_association_matrices(self, decoder_hidden_states, decoder_input_ids, tokenizer, apply_sigmoid=False):
3296
  text_tail_association_matrices = []
3297
  for index in range(len(decoder_hidden_states)):
3298
+ text_embedding_indices = (decoder_input_ids[index] == tokenizer.convert_tokens_to_ids(
3299
+ '<text>')).nonzero().squeeze(-1)
3300
  text_embeddings = decoder_hidden_states[index][text_embedding_indices]
3301
+ tail_embedding_indices = (decoder_input_ids[index] == tokenizer.convert_tokens_to_ids(
3302
+ '<tail>')).nonzero().squeeze(-1)
3303
  tail_embeddings = decoder_hidden_states[index][tail_embedding_indices]
3304
  if tail_embeddings.shape[0] == 0 or text_embeddings.shape[0] == 0:
3305
+ text_tail_association_matrices.append(torch.zeros(
3306
+ text_embeddings.shape[0], tail_embeddings.shape[0]).type_as(text_embeddings))
3307
  continue
3308
+ text_i = repeat(text_embeddings, "i d -> i repeat d",
3309
+ repeat=tail_embeddings.shape[0])
3310
+ tail_j = repeat(tail_embeddings, "j d -> repeat j d",
3311
+ repeat=text_embeddings.shape[0])
3312
+ text_tail_ij = rearrange(
3313
+ [text_i, tail_j], "two i j d -> (i j) (two d)")
3314
  text_tail_affinities = self.text_tail_matching_head(text_tail_ij)
3315
+ text_tail_affinities = rearrange(
3316
+ text_tail_affinities, "(i j) 1 -> i j", i=text_i.shape[0])
3317
  if apply_sigmoid:
3318
  text_tail_affinities = torch.sigmoid(text_tail_affinities)
3319
  text_tail_association_matrices.append(text_tail_affinities)
3320
  return text_tail_association_matrices
3321
+
3322
  def get_tail_character_association_matrices(self, decoder_hidden_states, decoder_input_ids, tokenizer, apply_sigmoid=False):
3323
  tail_character_association_matrices = []
3324
  for index in range(len(decoder_hidden_states)):
3325
+ tail_embedding_indices = (decoder_input_ids[index] == tokenizer.convert_tokens_to_ids(
3326
+ '<tail>')).nonzero().squeeze(-1)
3327
  tail_embeddings = decoder_hidden_states[index][tail_embedding_indices]
3328
+ character_embedding_indices = (decoder_input_ids[index] == tokenizer.convert_tokens_to_ids(
3329
+ '<character>')).nonzero().squeeze(-1)
3330
  character_embeddings = decoder_hidden_states[index][character_embedding_indices]
3331
  if character_embeddings.shape[0] == 0 or tail_embeddings.shape[0] == 0:
3332
+ tail_character_association_matrices.append(torch.zeros(
3333
+ tail_embeddings.shape[0], character_embeddings.shape[0]).type_as(tail_embeddings))
3334
  continue
3335
+ tail_i = repeat(tail_embeddings, "i d -> i repeat d",
3336
+ repeat=character_embeddings.shape[0])
3337
+ char_j = repeat(character_embeddings, "j d -> repeat j d",
3338
+ repeat=tail_embeddings.shape[0])
3339
+ tail_char_ij = rearrange(
3340
+ [tail_i, char_j], "two i j d -> (i j) (two d)")
3341
+ tail_character_affinities = self.tail_character_matching_head(
3342
+ tail_char_ij)
3343
+ tail_character_affinities = rearrange(
3344
+ tail_character_affinities, "(i j) 1 -> i j", i=tail_i.shape[0])
3345
  if apply_sigmoid:
3346
+ tail_character_affinities = torch.sigmoid(
3347
+ tail_character_affinities)
3348
+ tail_character_association_matrices.append(
3349
+ tail_character_affinities)
3350
  return tail_character_association_matrices
3351
+
3352
  def get_essential_text_logits(self, decoder_hidden_states, decoder_input_ids, tokenizer, apply_sigmoid=False):
3353
  essential_text_logits = []
3354
  for index in range(len(decoder_hidden_states)):
3355
+ text_embedding_indices = (decoder_input_ids[index] == tokenizer.convert_tokens_to_ids(
3356
+ '<text>')).nonzero().squeeze(-1)
3357
  text_embeddings = decoder_hidden_states[index][text_embedding_indices]
3358
  if text_embeddings.shape[0] == 0:
3359
+ essential_text_logits.append(
3360
+ torch.zeros(0).type_as(text_embeddings))
3361
  continue
3362
+ text_logits = rearrange(
3363
+ self.text_classification_head(text_embeddings), "i 1 -> i")
3364
  if apply_sigmoid:
3365
  text_logits = torch.sigmoid(text_logits)
3366
  essential_text_logits.append(text_logits)
3367
+ return essential_text_logits
processing_florence2.py CHANGED
@@ -53,18 +53,19 @@ class Florence2Processor(ProcessorMixin):
53
  if tokenizer is None:
54
  raise ValueError("You need to specify a `tokenizer`.")
55
  if not hasattr(image_processor, "image_seq_length"):
56
- raise ValueError("Image processor is missing an `image_seq_length` attribute.")
 
57
 
58
  self.image_seq_length = image_processor.image_seq_length
59
 
60
  tokens_to_add = {
61
- 'additional_special_tokens': \
62
- tokenizer.additional_special_tokens + \
63
- ['<od>', '</od>', '<ocr>', '</ocr>'] + \
64
- [f'<loc_{x}>' for x in range(1000)] + \
65
- ['<cap>', '</cap>', '<ncap>', '</ncap>','<dcap>', '</dcap>', '<grounding>', '</grounding>', '<seg>', '</seg>', '<sep>', '<region_cap>', '</region_cap>', '<region_to_desciption>', '</region_to_desciption>', '<proposal>', '</proposal>', '<poly>', '</poly>', '<and>'] + \
66
- ['<panel>', '<text>', '<character>', '<tail>']
67
- }
68
  tokenizer.add_special_tokens(tokens_to_add)
69
  self.decoder_start_token_id = 2
70
 
@@ -74,7 +75,7 @@ class Florence2Processor(ProcessorMixin):
74
  )
75
 
76
  super().__init__(image_processor, tokenizer)
77
-
78
  def __call__(
79
  self,
80
  batch_input_text: List[TextInput] = None,
@@ -82,11 +83,11 @@ class Florence2Processor(ProcessorMixin):
82
  batch_output_text: List[TextInput] = None,
83
  batch_output_list_of_list_of_bboxes: List[List[List[List[float]]]] = None,
84
  batch_images: ImageInput = None,
85
- batch_character_cluster_labels = None,
86
- batch_text_character_association_labels = None,
87
- batch_text_tail_association_labels = None,
88
- batch_is_essential_text_labels = None,
89
- batch_tail_character_association_labels = None,
90
  padding: Union[bool, str, PaddingStrategy] = None,
91
  truncation: Union[bool, str, TruncationStrategy] = None,
92
  max_input_length_including_image_tokens=None,
@@ -109,17 +110,23 @@ class Florence2Processor(ProcessorMixin):
109
  assert batch_images is not None, "`batch_images` are expected as arguments to a `Florence2Processor` instance."
110
  assert batch_input_text is not None, "`batch_input_text` are expected as arguments to a `Florence2Processor` instance."
111
  if batch_input_list_of_list_of_bboxes is None:
112
- batch_input_list_of_list_of_bboxes = [[] for _ in range(len(batch_input_text))]
113
- assert len(batch_input_text) == len(batch_input_list_of_list_of_bboxes) == len(batch_images), "`batch_input_text`, `batch_input_list_of_list_of_bboxes` and `batch_images` have different lengths."
 
 
114
  if batch_output_text is None:
115
  assert batch_output_list_of_list_of_bboxes is None, "`batch_output_text` and `batch_output_list_of_list_of_bboxes` should be provided together."
116
  else:
117
  if batch_output_list_of_list_of_bboxes is None:
118
- batch_output_list_of_list_of_bboxes = [[] for _ in range(len(batch_output_text))]
119
- assert len(batch_output_text) == len(batch_output_list_of_list_of_bboxes) == len(batch_images), "`batch_output_text`, `batch_output_list_of_list_of_bboxes` and `batch_images` have different lengths."
120
-
121
- max_input_length = max_input_length_including_image_tokens - self.image_seq_length if max_input_length_including_image_tokens is not None else None
122
- batch_input_texts = [self._format_text_with_bboxes(text, list_of_list_of_bboxes, image) for text, list_of_list_of_bboxes, image in zip(batch_input_text, batch_input_list_of_list_of_bboxes, batch_images)]
 
 
 
 
123
  inputs = self.tokenizer(
124
  batch_input_texts,
125
  return_tensors=return_tensors,
@@ -130,9 +137,10 @@ class Florence2Processor(ProcessorMixin):
130
  if inputs["input_ids"].shape[1] > max_input_length:
131
  inputs["input_ids"] = inputs["input_ids"][:, :max_input_length]
132
  inputs["attention_mask"] = inputs["attention_mask"][:, :max_input_length]
133
-
134
  if batch_output_text is not None:
135
- batch_output_texts = [self._format_text_with_bboxes(text, list_of_list_of_bboxes, image) for text, list_of_list_of_bboxes, image in zip(batch_output_text, batch_output_list_of_list_of_bboxes, batch_images)]
 
136
  decoder_inputs = self.tokenizer(
137
  batch_output_texts,
138
  return_tensors=return_tensors,
@@ -141,9 +149,10 @@ class Florence2Processor(ProcessorMixin):
141
  )
142
  # Truncating manually because I don't want </s> token at the end of truncated sequences, which is the default behavior
143
  if decoder_inputs["input_ids"].shape[1] > max_output_length:
144
- decoder_inputs["input_ids"] = decoder_inputs["input_ids"][:, :max_output_length]
145
- decoder_inputs["attention_mask"] = decoder_inputs["attention_mask"][:, :max_output_length]
146
-
 
147
 
148
  pixel_values = self.image_processor(
149
  batch_images,
@@ -160,7 +169,7 @@ class Florence2Processor(ProcessorMixin):
160
 
161
  if dtype is not None:
162
  pixel_values = pixel_values.to(dtype)
163
-
164
  return_data = {**inputs, "pixel_values": pixel_values}
165
 
166
  if batch_output_text is not None:
@@ -168,8 +177,10 @@ class Florence2Processor(ProcessorMixin):
168
  decoder_input_ids = labels.new_zeros(labels.shape)
169
  decoder_input_ids[:, 1:] = labels[:, :-1].clone()
170
  decoder_input_ids[:, 0] = self.decoder_start_token_id
171
- decoder_attention_mask = decoder_inputs["attention_mask"].new_ones(decoder_input_ids.shape)
172
- decoder_attention_mask[:, 1:] = decoder_inputs["attention_mask"][:, :-1].clone()
 
 
173
  # Mask fill labels to replace pad token ID with -100
174
  labels.masked_fill_(labels == self.tokenizer.pad_token_id, -100)
175
  return_data.update({
@@ -177,7 +188,7 @@ class Florence2Processor(ProcessorMixin):
177
  "decoder_input_ids": decoder_input_ids,
178
  "decoder_attention_mask": decoder_attention_mask,
179
  })
180
-
181
  if device is not None:
182
  for key, value in return_data.items():
183
  if isinstance(value, torch.Tensor):
@@ -201,25 +212,32 @@ class Florence2Processor(ProcessorMixin):
201
  return generated_text.replace("<s>", "").replace("</s>", "").replace("<pad>", "")
202
 
203
  def postprocess_output(self, generated_ids, images):
204
- generated_ids.masked_fill_(generated_ids == -100, self.tokenizer.pad_token_id) # only for some testing purposes
205
- batch_decoded_texts = self.batch_decode(generated_ids, skip_special_tokens=False)
206
- batch_decoded_texts = [self.cleanup_generated_text(text) for text in batch_decoded_texts]
 
 
 
 
207
  batch_list_of_list_of_bboxes = []
208
  batch_indices_of_bboxes_in_new_string = []
209
  batch_new_texts = []
210
  for text, image in zip(batch_decoded_texts, images):
211
  size_wh = self._get_image_size_wh(image)
212
- parsed_text, list_of_stringified_bboxes, start_end_in_new_string = self._parse_text_with_bboxes(text)
213
- list_of_list_of_bboxes = [self.box_quantizer.dequantize_from_stringified_bboxes(stringified_bbox, size_wh) for stringified_bbox in list_of_stringified_bboxes]
 
 
214
  batch_list_of_list_of_bboxes.append(list_of_list_of_bboxes)
215
- batch_indices_of_bboxes_in_new_string.append(start_end_in_new_string)
 
216
  batch_new_texts.append(parsed_text)
217
  return batch_new_texts, batch_list_of_list_of_bboxes, batch_indices_of_bboxes_in_new_string
218
 
219
  def _parse_text_with_bboxes(self, text):
220
  loc_pattern = r'((?:<loc_\d+>){4}(?:,(?:<loc_\d+>){4})*)'
221
  grounding_pattern = r'<grounding>(.*?)</grounding>' + loc_pattern
222
-
223
  list_of_stringified_bboxes = []
224
  start_end_in_new_string = []
225
  new_text = ""
@@ -237,7 +255,8 @@ class Florence2Processor(ProcessorMixin):
237
  locs = match.group(2)
238
  new_text += grounding_text
239
  list_of_stringified_bboxes.append(locs)
240
- start_end_in_new_string.append((new_pos, new_pos + len(grounding_text)))
 
241
  new_pos += len(grounding_text)
242
  else:
243
  # Handle loc pattern
@@ -245,7 +264,8 @@ class Florence2Processor(ProcessorMixin):
245
  replacement = ""
246
  new_text += replacement
247
  list_of_stringified_bboxes.append(locs)
248
- start_end_in_new_string.append((new_pos, new_pos + len(replacement)))
 
249
  new_pos += len(replacement)
250
 
251
  original_pos = match.end()
@@ -254,19 +274,21 @@ class Florence2Processor(ProcessorMixin):
254
  new_text += text[original_pos:]
255
 
256
  return new_text, list_of_stringified_bboxes, start_end_in_new_string
257
-
258
  def _format_text_with_bboxes(self, text, list_of_list_of_bboxes, image):
259
  size_wh = self._get_image_size_wh(image)
260
  quantized_bbox_lists = []
261
- for list_of_bboxes in list_of_list_of_bboxes:
262
- quantized_bboxes = self.box_quantizer.quantize(list_of_bboxes, size_wh=size_wh)
263
- stringified_bboxes = [f"<loc_{x1}><loc_{y1}><loc_{x2}><loc_{y2}>" for x1, y1, x2, y2 in quantized_bboxes]
 
 
264
  stringified_bboxes = ",".join(stringified_bboxes)
265
  quantized_bbox_lists.append(stringified_bboxes)
266
  return text.format(*quantized_bbox_lists)
267
 
268
  def _get_image_size_wh(self, image):
269
- # Get size_wh from image based on its type
270
  if isinstance(image, torch.Tensor):
271
  # For PyTorch tensor
272
  if image.dim() == 3:
@@ -313,6 +335,7 @@ class Florence2Processor(ProcessorMixin):
313
  image_processor_input_names = self.image_processor.model_input_names
314
  return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
315
 
 
316
  class BoxQuantizer(object):
317
  def __init__(self, mode, bins):
318
  self.mode = mode
@@ -390,4 +413,4 @@ class BoxQuantizer(object):
390
  dequantized_xmax, dequantized_ymax), dim=-1
391
  )
392
 
393
- return dequantized_boxes
 
53
  if tokenizer is None:
54
  raise ValueError("You need to specify a `tokenizer`.")
55
  if not hasattr(image_processor, "image_seq_length"):
56
+ raise ValueError(
57
+ "Image processor is missing an `image_seq_length` attribute.")
58
 
59
  self.image_seq_length = image_processor.image_seq_length
60
 
61
  tokens_to_add = {
62
+ 'additional_special_tokens':
63
+ tokenizer.additional_special_tokens +
64
+ ['<od>', '</od>', '<ocr>', '</ocr>'] +
65
+ [f'<loc_{x}>' for x in range(1000)] +
66
+ ['<cap>', '</cap>', '<ncap>', '</ncap>', '<dcap>', '</dcap>', '<grounding>', '</grounding>', '<seg>', '</seg>', '<sep>', '<region_cap>', '</region_cap>', '<region_to_desciption>', '</region_to_desciption>', '<proposal>', '</proposal>', '<poly>', '</poly>', '<and>'] +
67
+ ['<panel>', '<text>', '<character>', '<tail>']
68
+ }
69
  tokenizer.add_special_tokens(tokens_to_add)
70
  self.decoder_start_token_id = 2
71
 
 
75
  )
76
 
77
  super().__init__(image_processor, tokenizer)
78
+
79
  def __call__(
80
  self,
81
  batch_input_text: List[TextInput] = None,
 
83
  batch_output_text: List[TextInput] = None,
84
  batch_output_list_of_list_of_bboxes: List[List[List[List[float]]]] = None,
85
  batch_images: ImageInput = None,
86
+ batch_character_cluster_labels=None,
87
+ batch_text_character_association_labels=None,
88
+ batch_text_tail_association_labels=None,
89
+ batch_is_essential_text_labels=None,
90
+ batch_tail_character_association_labels=None,
91
  padding: Union[bool, str, PaddingStrategy] = None,
92
  truncation: Union[bool, str, TruncationStrategy] = None,
93
  max_input_length_including_image_tokens=None,
 
110
  assert batch_images is not None, "`batch_images` are expected as arguments to a `Florence2Processor` instance."
111
  assert batch_input_text is not None, "`batch_input_text` are expected as arguments to a `Florence2Processor` instance."
112
  if batch_input_list_of_list_of_bboxes is None:
113
+ batch_input_list_of_list_of_bboxes = [
114
+ [] for _ in range(len(batch_input_text))]
115
+ assert len(batch_input_text) == len(batch_input_list_of_list_of_bboxes) == len(
116
+ batch_images), "`batch_input_text`, `batch_input_list_of_list_of_bboxes` and `batch_images` have different lengths."
117
  if batch_output_text is None:
118
  assert batch_output_list_of_list_of_bboxes is None, "`batch_output_text` and `batch_output_list_of_list_of_bboxes` should be provided together."
119
  else:
120
  if batch_output_list_of_list_of_bboxes is None:
121
+ batch_output_list_of_list_of_bboxes = [
122
+ [] for _ in range(len(batch_output_text))]
123
+ assert len(batch_output_text) == len(batch_output_list_of_list_of_bboxes) == len(
124
+ batch_images), "`batch_output_text`, `batch_output_list_of_list_of_bboxes` and `batch_images` have different lengths."
125
+
126
+ max_input_length = max_input_length_including_image_tokens - \
127
+ self.image_seq_length if max_input_length_including_image_tokens is not None else None
128
+ batch_input_texts = [self._format_text_with_bboxes(text, list_of_list_of_bboxes, image) for text, list_of_list_of_bboxes, image in zip(
129
+ batch_input_text, batch_input_list_of_list_of_bboxes, batch_images)]
130
  inputs = self.tokenizer(
131
  batch_input_texts,
132
  return_tensors=return_tensors,
 
137
  if inputs["input_ids"].shape[1] > max_input_length:
138
  inputs["input_ids"] = inputs["input_ids"][:, :max_input_length]
139
  inputs["attention_mask"] = inputs["attention_mask"][:, :max_input_length]
140
+
141
  if batch_output_text is not None:
142
+ batch_output_texts = [self._format_text_with_bboxes(text, list_of_list_of_bboxes, image) for text, list_of_list_of_bboxes, image in zip(
143
+ batch_output_text, batch_output_list_of_list_of_bboxes, batch_images)]
144
  decoder_inputs = self.tokenizer(
145
  batch_output_texts,
146
  return_tensors=return_tensors,
 
149
  )
150
  # Truncating manually because I don't want </s> token at the end of truncated sequences, which is the default behavior
151
  if decoder_inputs["input_ids"].shape[1] > max_output_length:
152
+ decoder_inputs["input_ids"] = decoder_inputs["input_ids"][:,
153
+ :max_output_length]
154
+ decoder_inputs["attention_mask"] = decoder_inputs["attention_mask"][:,
155
+ :max_output_length]
156
 
157
  pixel_values = self.image_processor(
158
  batch_images,
 
169
 
170
  if dtype is not None:
171
  pixel_values = pixel_values.to(dtype)
172
+
173
  return_data = {**inputs, "pixel_values": pixel_values}
174
 
175
  if batch_output_text is not None:
 
177
  decoder_input_ids = labels.new_zeros(labels.shape)
178
  decoder_input_ids[:, 1:] = labels[:, :-1].clone()
179
  decoder_input_ids[:, 0] = self.decoder_start_token_id
180
+ decoder_attention_mask = decoder_inputs["attention_mask"].new_ones(
181
+ decoder_input_ids.shape)
182
+ decoder_attention_mask[:,
183
+ 1:] = decoder_inputs["attention_mask"][:, :-1].clone()
184
  # Mask fill labels to replace pad token ID with -100
185
  labels.masked_fill_(labels == self.tokenizer.pad_token_id, -100)
186
  return_data.update({
 
188
  "decoder_input_ids": decoder_input_ids,
189
  "decoder_attention_mask": decoder_attention_mask,
190
  })
191
+
192
  if device is not None:
193
  for key, value in return_data.items():
194
  if isinstance(value, torch.Tensor):
 
212
  return generated_text.replace("<s>", "").replace("</s>", "").replace("<pad>", "")
213
 
214
  def postprocess_output(self, generated_ids, images):
215
+ # only for some testing purposes
216
+ generated_ids.masked_fill_(
217
+ generated_ids == -100, self.tokenizer.pad_token_id)
218
+ batch_decoded_texts = self.batch_decode(
219
+ generated_ids, skip_special_tokens=False)
220
+ batch_decoded_texts = [self.cleanup_generated_text(
221
+ text) for text in batch_decoded_texts]
222
  batch_list_of_list_of_bboxes = []
223
  batch_indices_of_bboxes_in_new_string = []
224
  batch_new_texts = []
225
  for text, image in zip(batch_decoded_texts, images):
226
  size_wh = self._get_image_size_wh(image)
227
+ parsed_text, list_of_stringified_bboxes, start_end_in_new_string = self._parse_text_with_bboxes(
228
+ text)
229
+ list_of_list_of_bboxes = [self.box_quantizer.dequantize_from_stringified_bboxes(
230
+ stringified_bbox, size_wh) for stringified_bbox in list_of_stringified_bboxes]
231
  batch_list_of_list_of_bboxes.append(list_of_list_of_bboxes)
232
+ batch_indices_of_bboxes_in_new_string.append(
233
+ start_end_in_new_string)
234
  batch_new_texts.append(parsed_text)
235
  return batch_new_texts, batch_list_of_list_of_bboxes, batch_indices_of_bboxes_in_new_string
236
 
237
  def _parse_text_with_bboxes(self, text):
238
  loc_pattern = r'((?:<loc_\d+>){4}(?:,(?:<loc_\d+>){4})*)'
239
  grounding_pattern = r'<grounding>(.*?)</grounding>' + loc_pattern
240
+
241
  list_of_stringified_bboxes = []
242
  start_end_in_new_string = []
243
  new_text = ""
 
255
  locs = match.group(2)
256
  new_text += grounding_text
257
  list_of_stringified_bboxes.append(locs)
258
+ start_end_in_new_string.append(
259
+ (new_pos, new_pos + len(grounding_text)))
260
  new_pos += len(grounding_text)
261
  else:
262
  # Handle loc pattern
 
264
  replacement = ""
265
  new_text += replacement
266
  list_of_stringified_bboxes.append(locs)
267
+ start_end_in_new_string.append(
268
+ (new_pos, new_pos + len(replacement)))
269
  new_pos += len(replacement)
270
 
271
  original_pos = match.end()
 
274
  new_text += text[original_pos:]
275
 
276
  return new_text, list_of_stringified_bboxes, start_end_in_new_string
277
+
278
  def _format_text_with_bboxes(self, text, list_of_list_of_bboxes, image):
279
  size_wh = self._get_image_size_wh(image)
280
  quantized_bbox_lists = []
281
+ for list_of_bboxes in list_of_list_of_bboxes:
282
+ quantized_bboxes = self.box_quantizer.quantize(
283
+ list_of_bboxes, size_wh=size_wh)
284
+ stringified_bboxes = [
285
+ f"<loc_{x1}><loc_{y1}><loc_{x2}><loc_{y2}>" for x1, y1, x2, y2 in quantized_bboxes]
286
  stringified_bboxes = ",".join(stringified_bboxes)
287
  quantized_bbox_lists.append(stringified_bboxes)
288
  return text.format(*quantized_bbox_lists)
289
 
290
  def _get_image_size_wh(self, image):
291
+ # Get size_wh from image based on its type
292
  if isinstance(image, torch.Tensor):
293
  # For PyTorch tensor
294
  if image.dim() == 3:
 
335
  image_processor_input_names = self.image_processor.model_input_names
336
  return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
337
 
338
+
339
  class BoxQuantizer(object):
340
  def __init__(self, mode, bins):
341
  self.mode = mode
 
413
  dequantized_xmax, dequantized_ymax), dim=-1
414
  )
415
 
416
+ return dequantized_boxes
utils.py CHANGED
@@ -9,6 +9,7 @@ from copy import deepcopy
9
  from itertools import groupby
10
  from concurrent.futures import ThreadPoolExecutor, TimeoutError
11
 
 
12
  def move_to_device(inputs, device):
13
  if hasattr(inputs, "keys"):
14
  return {k: move_to_device(v, device) for k, v in inputs.items()}
@@ -21,6 +22,7 @@ def move_to_device(inputs, device):
21
  else:
22
  return inputs.to(device)
23
 
 
24
  class UnionFind:
25
  def __init__(self, n):
26
  self.parent = list(range(n))
@@ -35,7 +37,7 @@ class UnionFind:
35
  if adj_matrix[i, j] > 0:
36
  ufds.unite(i, j)
37
  return ufds
38
-
39
  @classmethod
40
  def from_adj_list(cls, adj_list):
41
  ufds = cls(len(adj_list))
@@ -43,7 +45,7 @@ class UnionFind:
43
  for j in adj_list[i]:
44
  ufds.unite(i, j)
45
  return ufds
46
-
47
  @classmethod
48
  def from_edge_list(cls, edge_list, num_nodes):
49
  ufds = cls(num_nodes)
@@ -66,11 +68,11 @@ class UnionFind:
66
  self.parent[y] = x
67
  self.size[x] += self.size[y]
68
  self.num_components -= 1
69
-
70
  def get_components_of(self, x):
71
  x = self.find(x)
72
  return [i for i in range(len(self.parent)) if self.find(i) == x]
73
-
74
  def are_connected(self, x, y):
75
  return self.find(x) == self.find(y)
76
 
@@ -79,7 +81,7 @@ class UnionFind:
79
 
80
  def get_num_components(self):
81
  return self.num_components
82
-
83
  def get_labels_for_connected_components(self):
84
  map_parent_to_label = {}
85
  labels = []
@@ -90,6 +92,7 @@ class UnionFind:
90
  labels.append(map_parent_to_label[parent])
91
  return labels
92
 
 
93
  def visualise_single_image_prediction(image_as_np_array, predictions, filename):
94
  h, w = image_as_np_array.shape[:2]
95
  if h > w:
@@ -102,15 +105,16 @@ def visualise_single_image_prediction(image_as_np_array, predictions, filename):
102
  plot_bboxes(subplot, predictions["characters"], color="blue")
103
 
104
  COLOURS = [
105
- "#b7ff51", # green
106
- "#f50a8f", # pink
107
- "#4b13b6", # purple
108
- "#ddaa34", # orange
109
- "#bea2a2", # brown
110
  ]
111
  colour_index = 0
112
  character_cluster_labels = predictions["character_cluster_labels"]
113
- unique_label_sorted_by_frequency = sorted(list(set(character_cluster_labels)), key=lambda x: character_cluster_labels.count(x), reverse=True)
 
114
  for label in unique_label_sorted_by_frequency:
115
  root = None
116
  others = []
@@ -123,7 +127,9 @@ def visualise_single_image_prediction(image_as_np_array, predictions, filename):
123
  if colour_index >= len(COLOURS):
124
  random_colour = COLOURS[0]
125
  while random_colour in COLOURS:
126
- random_colour = "#" + "".join([random.choice("0123456789ABCDEF") for j in range(6)])
 
 
127
  else:
128
  random_colour = COLOURS[colour_index]
129
  colour_index += 1
@@ -139,8 +145,9 @@ def visualise_single_image_prediction(image_as_np_array, predictions, filename):
139
  x2 = bbox_j[0] + (bbox_j[2] - bbox_j[0]) / 2
140
  y2 = bbox_j[1] + (bbox_j[3] - bbox_j[1]) / 2
141
  subplot.plot([x1, x2], [y1, y2], color=random_colour, linewidth=2)
142
- subplot.plot([x2], [y2], color=random_colour, marker="o", markersize=5)
143
-
 
144
  for (i, j) in predictions["text_character_associations"]:
145
  score = predictions["dialog_confidences"][i]
146
  bbox_i = predictions["texts"][i]
@@ -149,7 +156,8 @@ def visualise_single_image_prediction(image_as_np_array, predictions, filename):
149
  y1 = bbox_i[1] + (bbox_i[3] - bbox_i[1]) / 2
150
  x2 = bbox_j[0] + (bbox_j[2] - bbox_j[0]) / 2
151
  y2 = bbox_j[1] + (bbox_j[3] - bbox_j[1]) / 2
152
- subplot.plot([x1, x2], [y1, y2], color="red", linewidth=2, linestyle="dashed", alpha=score)
 
153
 
154
  subplot.axis("off")
155
  if filename is not None:
@@ -160,6 +168,7 @@ def visualise_single_image_prediction(image_as_np_array, predictions, filename):
160
  plt.close()
161
  return image
162
 
 
163
  def plot_bboxes(subplot, bboxes, color="red", add_index=False):
164
  for id, bbox in enumerate(bboxes):
165
  w = bbox[2] - bbox[0]
@@ -170,7 +179,9 @@ def plot_bboxes(subplot, bboxes, color="red", add_index=False):
170
  subplot.add_patch(rect)
171
  if add_index:
172
  cx, cy = bbox[0] + w / 2, bbox[1] + h / 2
173
- subplot.text(cx, cy, str(id), color=color, fontsize=10, ha="center", va="center")
 
 
174
 
175
  def sort_panels(rects):
176
  before_rects = convert_to_list_of_lists(rects)
@@ -203,34 +214,42 @@ def sort_panels(rects):
203
  G.remove_edge(*max_cyclic_edge)
204
  return list(nx.topological_sort(G))
205
 
 
206
  def is_strictly_above(rectA, rectB):
207
  x1A, y1A, x2A, y2A = rectA
208
  x1B, y1B, x2B, y2B = rectB
209
  return y2A < y1B
210
 
 
211
  def is_strictly_below(rectA, rectB):
212
  x1A, y1A, x2A, y2A = rectA
213
  x1B, y1B, x2B, y2B = rectB
214
  return y2B < y1A
215
 
 
216
  def is_strictly_left_of(rectA, rectB):
217
  x1A, y1A, x2A, y2A = rectA
218
  x1B, y1B, x2B, y2B = rectB
219
  return x2A < x1B
220
 
 
221
  def is_strictly_right_of(rectA, rectB):
222
  x1A, y1A, x2A, y2A = rectA
223
  x1B, y1B, x2B, y2B = rectB
224
  return x2B < x1A
225
 
 
226
  def intersects(rectA, rectB):
227
  return box(*rectA).intersects(box(*rectB))
228
 
 
229
  def is_there_a_directed_edge(a, b, rects):
230
  rectA = rects[a]
231
  rectB = rects[b]
232
- centre_of_A = [rectA[0] + (rectA[2] - rectA[0]) / 2, rectA[1] + (rectA[3] - rectA[1]) / 2]
233
- centre_of_B = [rectB[0] + (rectB[2] - rectB[0]) / 2, rectB[1] + (rectB[3] - rectB[1]) / 2]
 
 
234
  if np.allclose(np.array(centre_of_A), np.array(centre_of_B)):
235
  return box(*rectA).area > (box(*rectB)).area
236
  copy_A = [rectA[0], rectA[1], rectA[2], rectA[3]]
@@ -247,34 +266,41 @@ def is_there_a_directed_edge(a, b, rects):
247
  if is_strictly_below(copy_A, copy_B) and is_strictly_right_of(copy_A, copy_B):
248
  return use_cuts_to_determine_edge_from_a_to_b(a, b, rects)
249
  if is_strictly_below(copy_B, copy_A) and is_strictly_right_of(copy_B, copy_A):
250
- return use_cuts_to_determine_edge_from_a_to_b(a, b, rects)
251
  # otherwise they intersect
252
  copy_A = erode_rectangle(copy_A, 0.05)
253
  copy_B = erode_rectangle(copy_B, 0.05)
254
-
 
255
  def get_distance(rectA, rectB):
256
  return box(rectA[0], rectA[1], rectA[2], rectA[3]).distance(box(rectB[0], rectB[1], rectB[2], rectB[3]))
257
 
 
258
  def use_cuts_to_determine_edge_from_a_to_b(a, b, rects):
259
  rects = deepcopy(rects)
260
  while True:
261
- xmin, ymin, xmax, ymax = min(rects[a][0], rects[b][0]), min(rects[a][1], rects[b][1]), max(rects[a][2], rects[b][2]), max(rects[a][3], rects[b][3])
262
- rect_index = [i for i in range(len(rects)) if intersects(rects[i], [xmin, ymin, xmax, ymax])]
263
- rects_copy = [rect for rect in rects if intersects(rect, [xmin, ymin, xmax, ymax])]
264
-
 
 
 
265
  # try to split the panels using a "horizontal" lines
266
- overlapping_y_ranges = merge_overlapping_ranges([(y1, y2) for x1, y1, x2, y2 in rects_copy])
 
267
  panel_index_to_split = {}
268
  for split_index, (y1, y2) in enumerate(overlapping_y_ranges):
269
  for i, index in enumerate(rect_index):
270
  if y1 <= rects_copy[i][1] <= rects_copy[i][3] <= y2:
271
  panel_index_to_split[index] = split_index
272
-
273
  if panel_index_to_split[a] != panel_index_to_split[b]:
274
  return panel_index_to_split[a] < panel_index_to_split[b]
275
-
276
  # try to split the panels using a "vertical" lines
277
- overlapping_x_ranges = merge_overlapping_ranges([(x1, x2) for x1, y1, x2, y2 in rects_copy])
 
278
  panel_index_to_split = {}
279
  for split_index, (x1, x2) in enumerate(overlapping_x_ranges[::-1]):
280
  for i, index in enumerate(rect_index):
@@ -282,10 +308,11 @@ def use_cuts_to_determine_edge_from_a_to_b(a, b, rects):
282
  panel_index_to_split[index] = split_index
283
  if panel_index_to_split[a] != panel_index_to_split[b]:
284
  return panel_index_to_split[a] < panel_index_to_split[b]
285
-
286
  # otherwise, erode the rectangles and try again
287
  rects = [erode_rectangle(rect, 0.05) for rect in rects]
288
 
 
289
  def erode_rectangle(bbox, erosion_factor):
290
  x1, y1, x2, y2 = bbox
291
  w, h = x2 - x1, y2 - y1
@@ -303,6 +330,7 @@ def erode_rectangle(bbox, erosion_factor):
303
  x1, y1, x2, y2 = cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2
304
  return [x1, y1, x2, y2]
305
 
 
306
  def merge_overlapping_ranges(ranges):
307
  """
308
  ranges: list of tuples (x1, x2)
@@ -324,6 +352,7 @@ def merge_overlapping_ranges(ranges):
324
  merged_ranges.append((prev_x1, prev_x2))
325
  return merged_ranges
326
 
 
327
  def sort_text_boxes_in_reading_order(text_bboxes, sorted_panel_bboxes):
328
  text_bboxes = convert_to_list_of_lists(text_bboxes)
329
  sorted_panel_bboxes = convert_to_list_of_lists(sorted_panel_bboxes)
@@ -335,18 +364,23 @@ def sort_text_boxes_in_reading_order(text_bboxes, sorted_panel_bboxes):
335
  groups = groupby(range(len(nums)), key=lambda i: nums[i])
336
  return [list(indices) for _, indices in groups]
337
 
338
- panel_id_for_text = get_text_to_panel_mapping(text_bboxes, sorted_panel_bboxes)
 
339
  indices_of_texts = list(range(len(text_bboxes)))
340
- indices_of_texts, panel_id_for_text = zip(*sorted(zip(indices_of_texts, panel_id_for_text), key=lambda x: x[1]))
 
341
  indices_of_texts = list(indices_of_texts)
342
  grouped_indices = indices_of_same_elements(panel_id_for_text)
343
  for group in grouped_indices:
344
  subset_of_text_indices = [indices_of_texts[i] for i in group]
345
- text_bboxes_of_subset = [text_bboxes[i] for i in subset_of_text_indices]
 
346
  sorted_subset_indices = sort_texts_within_panel(text_bboxes_of_subset)
347
- indices_of_texts[group[0] : group[-1] + 1] = [subset_of_text_indices[i] for i in sorted_subset_indices]
 
348
  return indices_of_texts
349
 
 
350
  def get_text_to_panel_mapping(text_bboxes, sorted_panel_bboxes):
351
  text_to_panel_mapping = []
352
  for text_bbox in text_bboxes:
@@ -359,14 +393,19 @@ def get_text_to_panel_mapping(text_bboxes, sorted_panel_bboxes):
359
  for j, annotation in enumerate(sorted_panel_bboxes):
360
  shapely_annotation_polygon = box(*annotation)
361
  if shapely_text_polygon.intersects(shapely_annotation_polygon):
362
- all_intersections.append((shapely_text_polygon.intersection(shapely_annotation_polygon).area, j))
363
- all_distances.append((shapely_text_polygon.distance(shapely_annotation_polygon), j))
 
 
364
  if len(all_intersections) == 0:
365
- text_to_panel_mapping.append(min(all_distances, key=lambda x: x[0])[1])
 
366
  else:
367
- text_to_panel_mapping.append(max(all_intersections, key=lambda x: x[0])[1])
 
368
  return text_to_panel_mapping
369
 
 
370
  def sort_texts_within_panel(rects):
371
  smallest_y = float("inf")
372
  greatest_x = float("-inf")
@@ -374,18 +413,20 @@ def sort_texts_within_panel(rects):
374
  x1, y1, x2, y2 = rect
375
  smallest_y = min(smallest_y, y1)
376
  greatest_x = max(greatest_x, x2)
377
-
378
  reference_point = Point(greatest_x, smallest_y)
379
 
380
  polygons_and_index = []
381
  for i, rect in enumerate(rects):
382
  x1, y1, x2, y2 = rect
383
- polygons_and_index.append((box(x1,y1,x2,y2), i))
384
  # sort points by closest to reference point
385
- polygons_and_index = sorted(polygons_and_index, key=lambda x: reference_point.distance(x[0]))
 
386
  indices = [x[1] for x in polygons_and_index]
387
  return indices
388
 
 
389
  def force_to_be_valid_bboxes(bboxes):
390
  if len(bboxes) == 0:
391
  return bboxes
@@ -394,20 +435,24 @@ def force_to_be_valid_bboxes(bboxes):
394
  bboxes_as_xywh[:, 2] = torch.clamp(bboxes_as_xywh[:, 2], min=1)
395
  bboxes_as_xywh[:, 3] = torch.clamp(bboxes_as_xywh[:, 3], min=1)
396
  bboxes_as_xywh = bboxes_as_xywh.tolist()
397
- bboxes_as_xyxy = [[x1, y1, x1 + w, y1 + h] for x1, y1, w, h in bboxes_as_xywh]
 
398
  return bboxes_as_xyxy
399
 
 
400
  def x1y1wh_to_x1y1x2y2(bbox):
401
  x1, y1, w, h = bbox
402
  return [x1, y1, x1 + w, y1 + h]
403
 
 
404
  def x1y1x2y2_to_xywh(bbox):
405
  x1, y1, x2, y2 = bbox
406
  return [x1, y1, x2 - x1, y2 - y1]
407
 
 
408
  def convert_to_list_of_lists(rects):
409
  if isinstance(rects, torch.Tensor):
410
  return rects.tolist()
411
  if isinstance(rects, np.ndarray):
412
  return rects.tolist()
413
- return [[a, b, c, d] for a, b, c, d in rects]
 
9
  from itertools import groupby
10
  from concurrent.futures import ThreadPoolExecutor, TimeoutError
11
 
12
+
13
  def move_to_device(inputs, device):
14
  if hasattr(inputs, "keys"):
15
  return {k: move_to_device(v, device) for k, v in inputs.items()}
 
22
  else:
23
  return inputs.to(device)
24
 
25
+
26
  class UnionFind:
27
  def __init__(self, n):
28
  self.parent = list(range(n))
 
37
  if adj_matrix[i, j] > 0:
38
  ufds.unite(i, j)
39
  return ufds
40
+
41
  @classmethod
42
  def from_adj_list(cls, adj_list):
43
  ufds = cls(len(adj_list))
 
45
  for j in adj_list[i]:
46
  ufds.unite(i, j)
47
  return ufds
48
+
49
  @classmethod
50
  def from_edge_list(cls, edge_list, num_nodes):
51
  ufds = cls(num_nodes)
 
68
  self.parent[y] = x
69
  self.size[x] += self.size[y]
70
  self.num_components -= 1
71
+
72
  def get_components_of(self, x):
73
  x = self.find(x)
74
  return [i for i in range(len(self.parent)) if self.find(i) == x]
75
+
76
  def are_connected(self, x, y):
77
  return self.find(x) == self.find(y)
78
 
 
81
 
82
  def get_num_components(self):
83
  return self.num_components
84
+
85
  def get_labels_for_connected_components(self):
86
  map_parent_to_label = {}
87
  labels = []
 
92
  labels.append(map_parent_to_label[parent])
93
  return labels
94
 
95
+
96
  def visualise_single_image_prediction(image_as_np_array, predictions, filename):
97
  h, w = image_as_np_array.shape[:2]
98
  if h > w:
 
105
  plot_bboxes(subplot, predictions["characters"], color="blue")
106
 
107
  COLOURS = [
108
+ "#b7ff51", # green
109
+ "#f50a8f", # pink
110
+ "#4b13b6", # purple
111
+ "#ddaa34", # orange
112
+ "#bea2a2", # brown
113
  ]
114
  colour_index = 0
115
  character_cluster_labels = predictions["character_cluster_labels"]
116
+ unique_label_sorted_by_frequency = sorted(list(set(
117
+ character_cluster_labels)), key=lambda x: character_cluster_labels.count(x), reverse=True)
118
  for label in unique_label_sorted_by_frequency:
119
  root = None
120
  others = []
 
127
  if colour_index >= len(COLOURS):
128
  random_colour = COLOURS[0]
129
  while random_colour in COLOURS:
130
+ random_colour = "#" + \
131
+ "".join([random.choice("0123456789ABCDEF")
132
+ for j in range(6)])
133
  else:
134
  random_colour = COLOURS[colour_index]
135
  colour_index += 1
 
145
  x2 = bbox_j[0] + (bbox_j[2] - bbox_j[0]) / 2
146
  y2 = bbox_j[1] + (bbox_j[3] - bbox_j[1]) / 2
147
  subplot.plot([x1, x2], [y1, y2], color=random_colour, linewidth=2)
148
+ subplot.plot([x2], [y2], color=random_colour,
149
+ marker="o", markersize=5)
150
+
151
  for (i, j) in predictions["text_character_associations"]:
152
  score = predictions["dialog_confidences"][i]
153
  bbox_i = predictions["texts"][i]
 
156
  y1 = bbox_i[1] + (bbox_i[3] - bbox_i[1]) / 2
157
  x2 = bbox_j[0] + (bbox_j[2] - bbox_j[0]) / 2
158
  y2 = bbox_j[1] + (bbox_j[3] - bbox_j[1]) / 2
159
+ subplot.plot([x1, x2], [y1, y2], color="red",
160
+ linewidth=2, linestyle="dashed", alpha=score)
161
 
162
  subplot.axis("off")
163
  if filename is not None:
 
168
  plt.close()
169
  return image
170
 
171
+
172
  def plot_bboxes(subplot, bboxes, color="red", add_index=False):
173
  for id, bbox in enumerate(bboxes):
174
  w = bbox[2] - bbox[0]
 
179
  subplot.add_patch(rect)
180
  if add_index:
181
  cx, cy = bbox[0] + w / 2, bbox[1] + h / 2
182
+ subplot.text(cx, cy, str(id), color=color,
183
+ fontsize=10, ha="center", va="center")
184
+
185
 
186
  def sort_panels(rects):
187
  before_rects = convert_to_list_of_lists(rects)
 
214
  G.remove_edge(*max_cyclic_edge)
215
  return list(nx.topological_sort(G))
216
 
217
+
218
  def is_strictly_above(rectA, rectB):
219
  x1A, y1A, x2A, y2A = rectA
220
  x1B, y1B, x2B, y2B = rectB
221
  return y2A < y1B
222
 
223
+
224
  def is_strictly_below(rectA, rectB):
225
  x1A, y1A, x2A, y2A = rectA
226
  x1B, y1B, x2B, y2B = rectB
227
  return y2B < y1A
228
 
229
+
230
  def is_strictly_left_of(rectA, rectB):
231
  x1A, y1A, x2A, y2A = rectA
232
  x1B, y1B, x2B, y2B = rectB
233
  return x2A < x1B
234
 
235
+
236
  def is_strictly_right_of(rectA, rectB):
237
  x1A, y1A, x2A, y2A = rectA
238
  x1B, y1B, x2B, y2B = rectB
239
  return x2B < x1A
240
 
241
+
242
  def intersects(rectA, rectB):
243
  return box(*rectA).intersects(box(*rectB))
244
 
245
+
246
  def is_there_a_directed_edge(a, b, rects):
247
  rectA = rects[a]
248
  rectB = rects[b]
249
+ centre_of_A = [rectA[0] + (rectA[2] - rectA[0]) / 2,
250
+ rectA[1] + (rectA[3] - rectA[1]) / 2]
251
+ centre_of_B = [rectB[0] + (rectB[2] - rectB[0]) / 2,
252
+ rectB[1] + (rectB[3] - rectB[1]) / 2]
253
  if np.allclose(np.array(centre_of_A), np.array(centre_of_B)):
254
  return box(*rectA).area > (box(*rectB)).area
255
  copy_A = [rectA[0], rectA[1], rectA[2], rectA[3]]
 
266
  if is_strictly_below(copy_A, copy_B) and is_strictly_right_of(copy_A, copy_B):
267
  return use_cuts_to_determine_edge_from_a_to_b(a, b, rects)
268
  if is_strictly_below(copy_B, copy_A) and is_strictly_right_of(copy_B, copy_A):
269
+ return use_cuts_to_determine_edge_from_a_to_b(a, b, rects)
270
  # otherwise they intersect
271
  copy_A = erode_rectangle(copy_A, 0.05)
272
  copy_B = erode_rectangle(copy_B, 0.05)
273
+
274
+
275
  def get_distance(rectA, rectB):
276
  return box(rectA[0], rectA[1], rectA[2], rectA[3]).distance(box(rectB[0], rectB[1], rectB[2], rectB[3]))
277
 
278
+
279
  def use_cuts_to_determine_edge_from_a_to_b(a, b, rects):
280
  rects = deepcopy(rects)
281
  while True:
282
+ xmin, ymin, xmax, ymax = min(rects[a][0], rects[b][0]), min(
283
+ rects[a][1], rects[b][1]), max(rects[a][2], rects[b][2]), max(rects[a][3], rects[b][3])
284
+ rect_index = [i for i in range(len(rects)) if intersects(
285
+ rects[i], [xmin, ymin, xmax, ymax])]
286
+ rects_copy = [rect for rect in rects if intersects(
287
+ rect, [xmin, ymin, xmax, ymax])]
288
+
289
  # try to split the panels using a "horizontal" lines
290
+ overlapping_y_ranges = merge_overlapping_ranges(
291
+ [(y1, y2) for x1, y1, x2, y2 in rects_copy])
292
  panel_index_to_split = {}
293
  for split_index, (y1, y2) in enumerate(overlapping_y_ranges):
294
  for i, index in enumerate(rect_index):
295
  if y1 <= rects_copy[i][1] <= rects_copy[i][3] <= y2:
296
  panel_index_to_split[index] = split_index
297
+
298
  if panel_index_to_split[a] != panel_index_to_split[b]:
299
  return panel_index_to_split[a] < panel_index_to_split[b]
300
+
301
  # try to split the panels using a "vertical" lines
302
+ overlapping_x_ranges = merge_overlapping_ranges(
303
+ [(x1, x2) for x1, y1, x2, y2 in rects_copy])
304
  panel_index_to_split = {}
305
  for split_index, (x1, x2) in enumerate(overlapping_x_ranges[::-1]):
306
  for i, index in enumerate(rect_index):
 
308
  panel_index_to_split[index] = split_index
309
  if panel_index_to_split[a] != panel_index_to_split[b]:
310
  return panel_index_to_split[a] < panel_index_to_split[b]
311
+
312
  # otherwise, erode the rectangles and try again
313
  rects = [erode_rectangle(rect, 0.05) for rect in rects]
314
 
315
+
316
  def erode_rectangle(bbox, erosion_factor):
317
  x1, y1, x2, y2 = bbox
318
  w, h = x2 - x1, y2 - y1
 
330
  x1, y1, x2, y2 = cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2
331
  return [x1, y1, x2, y2]
332
 
333
+
334
  def merge_overlapping_ranges(ranges):
335
  """
336
  ranges: list of tuples (x1, x2)
 
352
  merged_ranges.append((prev_x1, prev_x2))
353
  return merged_ranges
354
 
355
+
356
  def sort_text_boxes_in_reading_order(text_bboxes, sorted_panel_bboxes):
357
  text_bboxes = convert_to_list_of_lists(text_bboxes)
358
  sorted_panel_bboxes = convert_to_list_of_lists(sorted_panel_bboxes)
 
364
  groups = groupby(range(len(nums)), key=lambda i: nums[i])
365
  return [list(indices) for _, indices in groups]
366
 
367
+ panel_id_for_text = get_text_to_panel_mapping(
368
+ text_bboxes, sorted_panel_bboxes)
369
  indices_of_texts = list(range(len(text_bboxes)))
370
+ indices_of_texts, panel_id_for_text = zip(
371
+ *sorted(zip(indices_of_texts, panel_id_for_text), key=lambda x: x[1]))
372
  indices_of_texts = list(indices_of_texts)
373
  grouped_indices = indices_of_same_elements(panel_id_for_text)
374
  for group in grouped_indices:
375
  subset_of_text_indices = [indices_of_texts[i] for i in group]
376
+ text_bboxes_of_subset = [text_bboxes[i]
377
+ for i in subset_of_text_indices]
378
  sorted_subset_indices = sort_texts_within_panel(text_bboxes_of_subset)
379
+ indices_of_texts[group[0]: group[-1] + 1] = [subset_of_text_indices[i]
380
+ for i in sorted_subset_indices]
381
  return indices_of_texts
382
 
383
+
384
  def get_text_to_panel_mapping(text_bboxes, sorted_panel_bboxes):
385
  text_to_panel_mapping = []
386
  for text_bbox in text_bboxes:
 
393
  for j, annotation in enumerate(sorted_panel_bboxes):
394
  shapely_annotation_polygon = box(*annotation)
395
  if shapely_text_polygon.intersects(shapely_annotation_polygon):
396
+ all_intersections.append(
397
+ (shapely_text_polygon.intersection(shapely_annotation_polygon).area, j))
398
+ all_distances.append(
399
+ (shapely_text_polygon.distance(shapely_annotation_polygon), j))
400
  if len(all_intersections) == 0:
401
+ text_to_panel_mapping.append(
402
+ min(all_distances, key=lambda x: x[0])[1])
403
  else:
404
+ text_to_panel_mapping.append(
405
+ max(all_intersections, key=lambda x: x[0])[1])
406
  return text_to_panel_mapping
407
 
408
+
409
  def sort_texts_within_panel(rects):
410
  smallest_y = float("inf")
411
  greatest_x = float("-inf")
 
413
  x1, y1, x2, y2 = rect
414
  smallest_y = min(smallest_y, y1)
415
  greatest_x = max(greatest_x, x2)
416
+
417
  reference_point = Point(greatest_x, smallest_y)
418
 
419
  polygons_and_index = []
420
  for i, rect in enumerate(rects):
421
  x1, y1, x2, y2 = rect
422
+ polygons_and_index.append((box(x1, y1, x2, y2), i))
423
  # sort points by closest to reference point
424
+ polygons_and_index = sorted(
425
+ polygons_and_index, key=lambda x: reference_point.distance(x[0]))
426
  indices = [x[1] for x in polygons_and_index]
427
  return indices
428
 
429
+
430
  def force_to_be_valid_bboxes(bboxes):
431
  if len(bboxes) == 0:
432
  return bboxes
 
435
  bboxes_as_xywh[:, 2] = torch.clamp(bboxes_as_xywh[:, 2], min=1)
436
  bboxes_as_xywh[:, 3] = torch.clamp(bboxes_as_xywh[:, 3], min=1)
437
  bboxes_as_xywh = bboxes_as_xywh.tolist()
438
+ bboxes_as_xyxy = [[x1, y1, x1 + w, y1 + h]
439
+ for x1, y1, w, h in bboxes_as_xywh]
440
  return bboxes_as_xyxy
441
 
442
+
443
  def x1y1wh_to_x1y1x2y2(bbox):
444
  x1, y1, w, h = bbox
445
  return [x1, y1, x1 + w, y1 + h]
446
 
447
+
448
  def x1y1x2y2_to_xywh(bbox):
449
  x1, y1, x2, y2 = bbox
450
  return [x1, y1, x2 - x1, y2 - y1]
451
 
452
+
453
  def convert_to_list_of_lists(rects):
454
  if isinstance(rects, torch.Tensor):
455
  return rects.tolist()
456
  if isinstance(rects, np.ndarray):
457
  return rects.tolist()
458
+ return [[a, b, c, d] for a, b, c, d in rects]