code compatibility with python3.9
Browse files- processing_maira2.py +48 -48
processing_maira2.py
CHANGED
|
@@ -3,7 +3,7 @@
|
|
| 3 |
|
| 4 |
|
| 5 |
import re
|
| 6 |
-
from typing import Any,
|
| 7 |
|
| 8 |
import numpy as np
|
| 9 |
from PIL import Image
|
|
@@ -14,9 +14,9 @@ from transformers.image_utils import ImageInput, get_image_size, to_numpy_array
|
|
| 14 |
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
|
| 15 |
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
| 16 |
|
| 17 |
-
SingleChatMessageType: TypeAlias = dict[str, str | int | None]
|
| 18 |
-
ChatMessageListType: TypeAlias = list[dict[str, str | list[SingleChatMessageType]]]
|
| 19 |
-
BoxType: TypeAlias = tuple[float, float, float, float]
|
| 20 |
|
| 21 |
|
| 22 |
class Maira2Processor(LlavaProcessor):
|
|
@@ -55,9 +55,9 @@ class Maira2Processor(LlavaProcessor):
|
|
| 55 |
self,
|
| 56 |
image_processor: BaseImageProcessor = None,
|
| 57 |
tokenizer: PreTrainedTokenizer = None,
|
| 58 |
-
patch_size
|
| 59 |
-
vision_feature_select_strategy
|
| 60 |
-
chat_template
|
| 61 |
image_token: str = "<image>",
|
| 62 |
phrase_start_token: str = "<obj>",
|
| 63 |
phrase_end_token: str = "</obj>",
|
|
@@ -106,9 +106,9 @@ class Maira2Processor(LlavaProcessor):
|
|
| 106 |
def _normalize_and_stack_images(
|
| 107 |
self,
|
| 108 |
current_frontal: Image.Image,
|
| 109 |
-
current_lateral: Image.Image
|
| 110 |
-
prior_frontal: Image.Image
|
| 111 |
-
)
|
| 112 |
"""
|
| 113 |
This function normalizes the input images and stacks them together. The images are stacked in the order of
|
| 114 |
current_frontal, current_lateral, and prior_frontal. The order of images is important, since it must match the
|
|
@@ -133,7 +133,7 @@ class Maira2Processor(LlavaProcessor):
|
|
| 133 |
return images
|
| 134 |
|
| 135 |
@staticmethod
|
| 136 |
-
def _get_section_text_or_missing_text(section: str
|
| 137 |
"""
|
| 138 |
This function returns the input section text if it is not None and not empty, otherwise it returns a missing
|
| 139 |
section text "N/A".
|
|
@@ -151,7 +151,7 @@ class Maira2Processor(LlavaProcessor):
|
|
| 151 |
return section
|
| 152 |
|
| 153 |
@staticmethod
|
| 154 |
-
def _construct_image_chat_messages_for_reporting(has_prior: bool, has_lateral: bool)
|
| 155 |
"""
|
| 156 |
This function constructs user chat messages based on the presence of the prior and lateral images.
|
| 157 |
|
|
@@ -187,7 +187,7 @@ class Maira2Processor(LlavaProcessor):
|
|
| 187 |
]
|
| 188 |
)
|
| 189 |
|
| 190 |
-
image_prompt
|
| 191 |
image_index = 0
|
| 192 |
if not has_prior and not has_lateral:
|
| 193 |
_add_single_image_to_chat_messages("Given the current frontal image only", image_index)
|
|
@@ -208,13 +208,13 @@ class Maira2Processor(LlavaProcessor):
|
|
| 208 |
self,
|
| 209 |
has_prior: bool,
|
| 210 |
has_lateral: bool,
|
| 211 |
-
indication: str
|
| 212 |
-
technique: str
|
| 213 |
-
comparison: str
|
| 214 |
-
prior_report: str
|
| 215 |
get_grounding: bool = False,
|
| 216 |
-
assistant_text: str
|
| 217 |
-
)
|
| 218 |
"""
|
| 219 |
This function constructs the chat messages for reporting used in the grounded and non-grounded reporting tasks.
|
| 220 |
|
|
@@ -299,14 +299,14 @@ class Maira2Processor(LlavaProcessor):
|
|
| 299 |
"type": "text",
|
| 300 |
}
|
| 301 |
)
|
| 302 |
-
messages
|
| 303 |
if assistant_text is not None:
|
| 304 |
messages.append({"content": [{"index": None, "text": assistant_text, "type": "text"}], "role": "assistant"})
|
| 305 |
return messages
|
| 306 |
|
| 307 |
def _construct_chat_messages_phrase_grounding(
|
| 308 |
-
self, phrase: str, assistant_text: str
|
| 309 |
-
)
|
| 310 |
"""
|
| 311 |
This function constructs the chat messages for phrase grounding used in the phrase grounding task.
|
| 312 |
|
|
@@ -319,7 +319,7 @@ class Maira2Processor(LlavaProcessor):
|
|
| 319 |
Returns:
|
| 320 |
ChatMessageListType: The chat messages for phrase grounding in the form of a list of dictionaries.
|
| 321 |
"""
|
| 322 |
-
prompt
|
| 323 |
{"index": None, "text": "Given the current frontal image", "type": "text"},
|
| 324 |
{"index": 0, "text": None, "type": "image"},
|
| 325 |
{
|
|
@@ -329,7 +329,7 @@ class Maira2Processor(LlavaProcessor):
|
|
| 329 |
"type": "text",
|
| 330 |
},
|
| 331 |
]
|
| 332 |
-
messages
|
| 333 |
if assistant_text is not None:
|
| 334 |
messages.append({"content": [{"index": None, "text": assistant_text, "type": "text"}], "role": "assistant"})
|
| 335 |
return messages
|
|
@@ -337,15 +337,15 @@ class Maira2Processor(LlavaProcessor):
|
|
| 337 |
def format_reporting_input(
|
| 338 |
self,
|
| 339 |
current_frontal: Image.Image,
|
| 340 |
-
current_lateral: Image.Image
|
| 341 |
-
prior_frontal: Image.Image
|
| 342 |
-
indication: str
|
| 343 |
-
technique: str
|
| 344 |
-
comparison: str
|
| 345 |
-
prior_report: str
|
| 346 |
get_grounding: bool = False,
|
| 347 |
-
assistant_text: str
|
| 348 |
-
)
|
| 349 |
"""
|
| 350 |
This function formats the reporting prompt for the grounded and non-grounded reporting tasks from the given
|
| 351 |
input images and text sections. The images are normalized and stacked together in the right order.
|
|
@@ -395,8 +395,8 @@ class Maira2Processor(LlavaProcessor):
|
|
| 395 |
self,
|
| 396 |
frontal_image: Image.Image,
|
| 397 |
phrase: str,
|
| 398 |
-
assistant_text: str
|
| 399 |
-
)
|
| 400 |
"""
|
| 401 |
This function formats the phrase grounding prompt for the phrase grounding task from the given input
|
| 402 |
image and phrase.
|
|
@@ -425,14 +425,14 @@ class Maira2Processor(LlavaProcessor):
|
|
| 425 |
def format_and_preprocess_reporting_input(
|
| 426 |
self,
|
| 427 |
current_frontal: Image.Image,
|
| 428 |
-
current_lateral: Image.Image
|
| 429 |
-
prior_frontal: Image.Image
|
| 430 |
-
indication: str
|
| 431 |
-
technique: str
|
| 432 |
-
comparison: str
|
| 433 |
-
prior_report: str
|
| 434 |
get_grounding: bool = False,
|
| 435 |
-
assistant_text: str
|
| 436 |
**kwargs: Any,
|
| 437 |
) -> BatchFeature:
|
| 438 |
"""
|
|
@@ -481,7 +481,7 @@ class Maira2Processor(LlavaProcessor):
|
|
| 481 |
self,
|
| 482 |
frontal_image: Image.Image,
|
| 483 |
phrase: str,
|
| 484 |
-
assistant_text: str
|
| 485 |
**kwargs: Any,
|
| 486 |
) -> BatchFeature:
|
| 487 |
"""
|
|
@@ -507,7 +507,7 @@ class Maira2Processor(LlavaProcessor):
|
|
| 507 |
)
|
| 508 |
return self(text=text, images=images, **kwargs)
|
| 509 |
|
| 510 |
-
def _get_text_between_delimiters(self, text: str, begin_token: str, end_token: str)
|
| 511 |
"""
|
| 512 |
This function splits the input text into a list of substrings based on the given begin and end tokens.
|
| 513 |
|
|
@@ -544,7 +544,7 @@ class Maira2Processor(LlavaProcessor):
|
|
| 544 |
|
| 545 |
def convert_output_to_plaintext_or_grounded_sequence(
|
| 546 |
self, text: str
|
| 547 |
-
)
|
| 548 |
"""
|
| 549 |
This function converts the input text to a grounded sequence by extracting the grounded phrases and bounding
|
| 550 |
boxes from the text. If the text is plaintext without any grounded phrases, it returns the text as is.
|
|
@@ -584,7 +584,7 @@ class Maira2Processor(LlavaProcessor):
|
|
| 584 |
|
| 585 |
# One or more grounded phrases
|
| 586 |
grounded_phrase_texts = self._get_text_between_delimiters(text, self.phrase_start_token, self.phrase_end_token)
|
| 587 |
-
grounded_phrases
|
| 588 |
for grounded_phrase_text in grounded_phrase_texts:
|
| 589 |
if self.box_start_token in grounded_phrase_text or self.box_end_token in grounded_phrase_text:
|
| 590 |
first_box_start_index = grounded_phrase_text.find(self.box_start_token)
|
|
@@ -593,14 +593,14 @@ class Maira2Processor(LlavaProcessor):
|
|
| 593 |
boxes_text_list = self._get_text_between_delimiters(
|
| 594 |
boxes_text, self.box_start_token, self.box_end_token
|
| 595 |
)
|
| 596 |
-
boxes
|
| 597 |
for box_text in boxes_text_list:
|
| 598 |
# extract from <x_><y_><x_><y_>
|
| 599 |
regex = r"<x(\d+?)><y(\d+?)><x(\d+?)><y(\d+?)>"
|
| 600 |
match = re.search(regex, box_text)
|
| 601 |
if match:
|
| 602 |
x_min, y_min, x_max, y_max = match.groups()
|
| 603 |
-
box
|
| 604 |
(int(coord) + 0.5) / self.num_box_coord_bins for coord in (x_min, y_min, x_max, y_max)
|
| 605 |
)
|
| 606 |
assert all(0 <= coord <= 1 for coord in box), f"Invalid box coordinates: {box}"
|
|
@@ -613,7 +613,7 @@ class Maira2Processor(LlavaProcessor):
|
|
| 613 |
return grounded_phrases
|
| 614 |
|
| 615 |
@staticmethod
|
| 616 |
-
def adjust_box_for_original_image_size(box
|
| 617 |
"""
|
| 618 |
This function adjusts the bounding boxes from the MAIRA-2 model output to account for the image processor
|
| 619 |
cropping the image to be square prior to the model forward pass. The box coordinates are adjusted to be
|
|
|
|
| 3 |
|
| 4 |
|
| 5 |
import re
|
| 6 |
+
from typing import Any, Union, List
|
| 7 |
|
| 8 |
import numpy as np
|
| 9 |
from PIL import Image
|
|
|
|
| 14 |
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
|
| 15 |
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
| 16 |
|
| 17 |
+
# SingleChatMessageType: TypeAlias = dict[str, str | int | None]
|
| 18 |
+
# ChatMessageListType: TypeAlias = list[dict[str, str | list[SingleChatMessageType]]]
|
| 19 |
+
# BoxType: TypeAlias = tuple[float, float, float, float]
|
| 20 |
|
| 21 |
|
| 22 |
class Maira2Processor(LlavaProcessor):
|
|
|
|
| 55 |
self,
|
| 56 |
image_processor: BaseImageProcessor = None,
|
| 57 |
tokenizer: PreTrainedTokenizer = None,
|
| 58 |
+
patch_size = None,
|
| 59 |
+
vision_feature_select_strategy = None,
|
| 60 |
+
chat_template = None,
|
| 61 |
image_token: str = "<image>",
|
| 62 |
phrase_start_token: str = "<obj>",
|
| 63 |
phrase_end_token: str = "</obj>",
|
|
|
|
| 106 |
def _normalize_and_stack_images(
|
| 107 |
self,
|
| 108 |
current_frontal: Image.Image,
|
| 109 |
+
current_lateral: Image.Image,
|
| 110 |
+
prior_frontal: Image.Image,
|
| 111 |
+
):
|
| 112 |
"""
|
| 113 |
This function normalizes the input images and stacks them together. The images are stacked in the order of
|
| 114 |
current_frontal, current_lateral, and prior_frontal. The order of images is important, since it must match the
|
|
|
|
| 133 |
return images
|
| 134 |
|
| 135 |
@staticmethod
|
| 136 |
+
def _get_section_text_or_missing_text(section: str) -> str:
|
| 137 |
"""
|
| 138 |
This function returns the input section text if it is not None and not empty, otherwise it returns a missing
|
| 139 |
section text "N/A".
|
|
|
|
| 151 |
return section
|
| 152 |
|
| 153 |
@staticmethod
|
| 154 |
+
def _construct_image_chat_messages_for_reporting(has_prior: bool, has_lateral: bool):
|
| 155 |
"""
|
| 156 |
This function constructs user chat messages based on the presence of the prior and lateral images.
|
| 157 |
|
|
|
|
| 187 |
]
|
| 188 |
)
|
| 189 |
|
| 190 |
+
image_prompt = []
|
| 191 |
image_index = 0
|
| 192 |
if not has_prior and not has_lateral:
|
| 193 |
_add_single_image_to_chat_messages("Given the current frontal image only", image_index)
|
|
|
|
| 208 |
self,
|
| 209 |
has_prior: bool,
|
| 210 |
has_lateral: bool,
|
| 211 |
+
indication: str,
|
| 212 |
+
technique: str,
|
| 213 |
+
comparison: str,
|
| 214 |
+
prior_report: str,
|
| 215 |
get_grounding: bool = False,
|
| 216 |
+
assistant_text: str = None,
|
| 217 |
+
):
|
| 218 |
"""
|
| 219 |
This function constructs the chat messages for reporting used in the grounded and non-grounded reporting tasks.
|
| 220 |
|
|
|
|
| 299 |
"type": "text",
|
| 300 |
}
|
| 301 |
)
|
| 302 |
+
messages = [{"content": prompt, "role": "user"}]
|
| 303 |
if assistant_text is not None:
|
| 304 |
messages.append({"content": [{"index": None, "text": assistant_text, "type": "text"}], "role": "assistant"})
|
| 305 |
return messages
|
| 306 |
|
| 307 |
def _construct_chat_messages_phrase_grounding(
|
| 308 |
+
self, phrase: str, assistant_text: str = None
|
| 309 |
+
):
|
| 310 |
"""
|
| 311 |
This function constructs the chat messages for phrase grounding used in the phrase grounding task.
|
| 312 |
|
|
|
|
| 319 |
Returns:
|
| 320 |
ChatMessageListType: The chat messages for phrase grounding in the form of a list of dictionaries.
|
| 321 |
"""
|
| 322 |
+
prompt = [
|
| 323 |
{"index": None, "text": "Given the current frontal image", "type": "text"},
|
| 324 |
{"index": 0, "text": None, "type": "image"},
|
| 325 |
{
|
|
|
|
| 329 |
"type": "text",
|
| 330 |
},
|
| 331 |
]
|
| 332 |
+
messages = [{"content": prompt, "role": "user"}]
|
| 333 |
if assistant_text is not None:
|
| 334 |
messages.append({"content": [{"index": None, "text": assistant_text, "type": "text"}], "role": "assistant"})
|
| 335 |
return messages
|
|
|
|
| 337 |
def format_reporting_input(
|
| 338 |
self,
|
| 339 |
current_frontal: Image.Image,
|
| 340 |
+
current_lateral: Image.Image,
|
| 341 |
+
prior_frontal: Image.Image,
|
| 342 |
+
indication: str,
|
| 343 |
+
technique: str,
|
| 344 |
+
comparison: str,
|
| 345 |
+
prior_report: str,
|
| 346 |
get_grounding: bool = False,
|
| 347 |
+
assistant_text: str = None,
|
| 348 |
+
):
|
| 349 |
"""
|
| 350 |
This function formats the reporting prompt for the grounded and non-grounded reporting tasks from the given
|
| 351 |
input images and text sections. The images are normalized and stacked together in the right order.
|
|
|
|
| 395 |
self,
|
| 396 |
frontal_image: Image.Image,
|
| 397 |
phrase: str,
|
| 398 |
+
assistant_text: str = None,
|
| 399 |
+
):
|
| 400 |
"""
|
| 401 |
This function formats the phrase grounding prompt for the phrase grounding task from the given input
|
| 402 |
image and phrase.
|
|
|
|
| 425 |
def format_and_preprocess_reporting_input(
|
| 426 |
self,
|
| 427 |
current_frontal: Image.Image,
|
| 428 |
+
current_lateral: Image.Image,
|
| 429 |
+
prior_frontal: Image.Image,
|
| 430 |
+
indication: str,
|
| 431 |
+
technique: str,
|
| 432 |
+
comparison: str,
|
| 433 |
+
prior_report: str,
|
| 434 |
get_grounding: bool = False,
|
| 435 |
+
assistant_text: str = None,
|
| 436 |
**kwargs: Any,
|
| 437 |
) -> BatchFeature:
|
| 438 |
"""
|
|
|
|
| 481 |
self,
|
| 482 |
frontal_image: Image.Image,
|
| 483 |
phrase: str,
|
| 484 |
+
assistant_text: str = None,
|
| 485 |
**kwargs: Any,
|
| 486 |
) -> BatchFeature:
|
| 487 |
"""
|
|
|
|
| 507 |
)
|
| 508 |
return self(text=text, images=images, **kwargs)
|
| 509 |
|
| 510 |
+
def _get_text_between_delimiters(self, text: str, begin_token: str, end_token: str):
|
| 511 |
"""
|
| 512 |
This function splits the input text into a list of substrings based on the given begin and end tokens.
|
| 513 |
|
|
|
|
| 544 |
|
| 545 |
def convert_output_to_plaintext_or_grounded_sequence(
|
| 546 |
self, text: str
|
| 547 |
+
):
|
| 548 |
"""
|
| 549 |
This function converts the input text to a grounded sequence by extracting the grounded phrases and bounding
|
| 550 |
boxes from the text. If the text is plaintext without any grounded phrases, it returns the text as is.
|
|
|
|
| 584 |
|
| 585 |
# One or more grounded phrases
|
| 586 |
grounded_phrase_texts = self._get_text_between_delimiters(text, self.phrase_start_token, self.phrase_end_token)
|
| 587 |
+
grounded_phrases = []
|
| 588 |
for grounded_phrase_text in grounded_phrase_texts:
|
| 589 |
if self.box_start_token in grounded_phrase_text or self.box_end_token in grounded_phrase_text:
|
| 590 |
first_box_start_index = grounded_phrase_text.find(self.box_start_token)
|
|
|
|
| 593 |
boxes_text_list = self._get_text_between_delimiters(
|
| 594 |
boxes_text, self.box_start_token, self.box_end_token
|
| 595 |
)
|
| 596 |
+
boxes = []
|
| 597 |
for box_text in boxes_text_list:
|
| 598 |
# extract from <x_><y_><x_><y_>
|
| 599 |
regex = r"<x(\d+?)><y(\d+?)><x(\d+?)><y(\d+?)>"
|
| 600 |
match = re.search(regex, box_text)
|
| 601 |
if match:
|
| 602 |
x_min, y_min, x_max, y_max = match.groups()
|
| 603 |
+
box = tuple( # type: ignore[assignment]
|
| 604 |
(int(coord) + 0.5) / self.num_box_coord_bins for coord in (x_min, y_min, x_max, y_max)
|
| 605 |
)
|
| 606 |
assert all(0 <= coord <= 1 for coord in box), f"Invalid box coordinates: {box}"
|
|
|
|
| 613 |
return grounded_phrases
|
| 614 |
|
| 615 |
@staticmethod
|
| 616 |
+
def adjust_box_for_original_image_size(box, width: int, height: int):
|
| 617 |
"""
|
| 618 |
This function adjusts the bounding boxes from the MAIRA-2 model output to account for the image processor
|
| 619 |
cropping the image to be square prior to the model forward pass. The box coordinates are adjusted to be
|