E:\ComfyUI-aki-v1.3\models\LLM
#29
by
wingman212
- opened
- configuration_florence2.py +2 -2
- model.safetensors +0 -3
- modeling_florence2.py +26 -24
- processing_florence2.py +24 -84
configuration_florence2.py
CHANGED
|
@@ -77,7 +77,7 @@ class Florence2VisionConfig(PretrainedConfig):
|
|
| 77 |
>>> configuration = model.config
|
| 78 |
```"""
|
| 79 |
|
| 80 |
-
model_type = "
|
| 81 |
keys_to_ignore_at_inference = ["past_key_values"]
|
| 82 |
|
| 83 |
def __init__(
|
|
@@ -327,7 +327,7 @@ class Florence2Config(PretrainedConfig):
|
|
| 327 |
self.vocab_size = vocab_size
|
| 328 |
self.projection_dim = projection_dim
|
| 329 |
if vision_config is not None:
|
| 330 |
-
vision_config =
|
| 331 |
self.vision_config = vision_config
|
| 332 |
self.vocab_size = self.vocab_size
|
| 333 |
|
|
|
|
| 77 |
>>> configuration = model.config
|
| 78 |
```"""
|
| 79 |
|
| 80 |
+
model_type = "florence2_vision"
|
| 81 |
keys_to_ignore_at_inference = ["past_key_values"]
|
| 82 |
|
| 83 |
def __init__(
|
|
|
|
| 327 |
self.vocab_size = vocab_size
|
| 328 |
self.projection_dim = projection_dim
|
| 329 |
if vision_config is not None:
|
| 330 |
+
vision_config = PretrainedConfig(**vision_config)
|
| 331 |
self.vision_config = vision_config
|
| 332 |
self.vocab_size = self.vocab_size
|
| 333 |
|
model.safetensors
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:03075d2d2d2bbd3e180b9ba0afae4aa8563226e2d32911656966e05b2f2ee060
|
| 3 |
-
size 463221266
|
|
|
|
|
|
|
|
|
|
|
|
modeling_florence2.py
CHANGED
|
@@ -26,7 +26,7 @@ import torch.utils.checkpoint as checkpoint
|
|
| 26 |
from torch.nn import CrossEntropyLoss
|
| 27 |
from collections import OrderedDict
|
| 28 |
from einops import rearrange
|
| 29 |
-
from timm.layers import DropPath, trunc_normal_
|
| 30 |
|
| 31 |
from transformers.modeling_utils import PreTrainedModel
|
| 32 |
from transformers.generation.utils import GenerationMixin
|
|
@@ -610,10 +610,29 @@ class DaViT(nn.Module):
|
|
| 610 |
self.avgpool = nn.AdaptiveAvgPool1d(1)
|
| 611 |
self.head = nn.Linear(self.embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
|
| 612 |
|
|
|
|
|
|
|
| 613 |
@property
|
| 614 |
def dim_out(self):
|
| 615 |
return self.embed_dims[-1]
|
| 616 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 617 |
def forward_features_unpool(self, x):
|
| 618 |
"""
|
| 619 |
forward until avg pooling
|
|
@@ -1432,17 +1451,6 @@ class Florence2LanguagePreTrainedModel(PreTrainedModel):
|
|
| 1432 |
module.weight.data.normal_(mean=0.0, std=std)
|
| 1433 |
if module.padding_idx is not None:
|
| 1434 |
module.weight.data[module.padding_idx].zero_()
|
| 1435 |
-
elif isinstance(module, nn.Conv2d):
|
| 1436 |
-
nn.init.normal_(module.weight, std=0.02)
|
| 1437 |
-
for name, _ in module.named_parameters():
|
| 1438 |
-
if name == "bias":
|
| 1439 |
-
nn.init.constant_(module.bias, 0)
|
| 1440 |
-
elif isinstance(module, nn.LayerNorm):
|
| 1441 |
-
nn.init.constant_(module.weight, 1.0)
|
| 1442 |
-
nn.init.constant_(module.bias, 0)
|
| 1443 |
-
elif isinstance(module, nn.BatchNorm2d):
|
| 1444 |
-
nn.init.constant_(module.weight, 1.0)
|
| 1445 |
-
nn.init.constant_(module.bias, 0)
|
| 1446 |
|
| 1447 |
@property
|
| 1448 |
def dummy_inputs(self):
|
|
@@ -2066,20 +2074,14 @@ class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel
|
|
| 2066 |
# Initialize weights and apply final processing
|
| 2067 |
self.post_init()
|
| 2068 |
|
| 2069 |
-
def _tie_weights(self):
|
| 2070 |
-
if self.config.tie_word_embeddings:
|
| 2071 |
-
self._tie_or_clone_weights(self.model.encoder.embed_tokens, self.model.shared)
|
| 2072 |
-
self._tie_or_clone_weights(self.model.decoder.embed_tokens, self.model.shared)
|
| 2073 |
-
self._tie_or_clone_weights(self.lm_head, self.model.shared)
|
| 2074 |
-
|
| 2075 |
def get_encoder(self):
|
| 2076 |
return self.model.get_encoder()
|
| 2077 |
|
| 2078 |
def get_decoder(self):
|
| 2079 |
return self.model.get_decoder()
|
| 2080 |
|
| 2081 |
-
def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None
|
| 2082 |
-
new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of
|
| 2083 |
self._resize_final_logits_bias(new_embeddings.weight.shape[0])
|
| 2084 |
return new_embeddings
|
| 2085 |
|
|
@@ -2529,8 +2531,6 @@ class Florence2VisionModelWithProjection(Florence2PreTrainedModel):
|
|
| 2529 |
FLORENCE2_START_DOCSTRING,
|
| 2530 |
)
|
| 2531 |
class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
| 2532 |
-
_tied_weights_keys = ["language_model.encoder.embed_tokens.weight", "language_model.decoder.embed_tokens.weight", "language_model.lm_head.weight"]
|
| 2533 |
-
|
| 2534 |
def __init__(self, config: Florence2Config):
|
| 2535 |
super().__init__(config)
|
| 2536 |
assert config.vision_config.model_type == 'davit', 'only DaViT is supported for now'
|
|
@@ -2545,6 +2545,8 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
| 2545 |
|
| 2546 |
language_model = Florence2LanguageForConditionalGeneration(config=config.text_config)
|
| 2547 |
|
|
|
|
|
|
|
| 2548 |
self.language_model = language_model
|
| 2549 |
|
| 2550 |
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
|
|
@@ -2587,8 +2589,8 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
| 2587 |
def get_input_embeddings(self):
|
| 2588 |
return self.language_model.get_input_embeddings()
|
| 2589 |
|
| 2590 |
-
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None
|
| 2591 |
-
model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of
|
| 2592 |
# update vocab size
|
| 2593 |
self.config.text_config.vocab_size = model_embeds.num_embeddings
|
| 2594 |
self.config.vocab_size = model_embeds.num_embeddings
|
|
|
|
| 26 |
from torch.nn import CrossEntropyLoss
|
| 27 |
from collections import OrderedDict
|
| 28 |
from einops import rearrange
|
| 29 |
+
from timm.models.layers import DropPath, trunc_normal_
|
| 30 |
|
| 31 |
from transformers.modeling_utils import PreTrainedModel
|
| 32 |
from transformers.generation.utils import GenerationMixin
|
|
|
|
| 610 |
self.avgpool = nn.AdaptiveAvgPool1d(1)
|
| 611 |
self.head = nn.Linear(self.embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
|
| 612 |
|
| 613 |
+
self.apply(self._init_weights)
|
| 614 |
+
|
| 615 |
@property
|
| 616 |
def dim_out(self):
|
| 617 |
return self.embed_dims[-1]
|
| 618 |
|
| 619 |
+
def _init_weights(self, m):
|
| 620 |
+
if isinstance(m, nn.Linear):
|
| 621 |
+
trunc_normal_(m.weight, std=0.02)
|
| 622 |
+
if m.bias is not None:
|
| 623 |
+
nn.init.constant_(m.bias, 0)
|
| 624 |
+
elif isinstance(m, nn.Conv2d):
|
| 625 |
+
nn.init.normal_(m.weight, std=0.02)
|
| 626 |
+
for name, _ in m.named_parameters():
|
| 627 |
+
if name in ['bias']:
|
| 628 |
+
nn.init.constant_(m.bias, 0)
|
| 629 |
+
elif isinstance(m, nn.LayerNorm):
|
| 630 |
+
nn.init.constant_(m.weight, 1.0)
|
| 631 |
+
nn.init.constant_(m.bias, 0)
|
| 632 |
+
elif isinstance(m, nn.BatchNorm2d):
|
| 633 |
+
nn.init.constant_(m.weight, 1.0)
|
| 634 |
+
nn.init.constant_(m.bias, 0)
|
| 635 |
+
|
| 636 |
def forward_features_unpool(self, x):
|
| 637 |
"""
|
| 638 |
forward until avg pooling
|
|
|
|
| 1451 |
module.weight.data.normal_(mean=0.0, std=std)
|
| 1452 |
if module.padding_idx is not None:
|
| 1453 |
module.weight.data[module.padding_idx].zero_()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1454 |
|
| 1455 |
@property
|
| 1456 |
def dummy_inputs(self):
|
|
|
|
| 2074 |
# Initialize weights and apply final processing
|
| 2075 |
self.post_init()
|
| 2076 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2077 |
def get_encoder(self):
|
| 2078 |
return self.model.get_encoder()
|
| 2079 |
|
| 2080 |
def get_decoder(self):
|
| 2081 |
return self.model.get_decoder()
|
| 2082 |
|
| 2083 |
+
def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
|
| 2084 |
+
new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
|
| 2085 |
self._resize_final_logits_bias(new_embeddings.weight.shape[0])
|
| 2086 |
return new_embeddings
|
| 2087 |
|
|
|
|
| 2531 |
FLORENCE2_START_DOCSTRING,
|
| 2532 |
)
|
| 2533 |
class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
|
|
|
|
|
|
|
| 2534 |
def __init__(self, config: Florence2Config):
|
| 2535 |
super().__init__(config)
|
| 2536 |
assert config.vision_config.model_type == 'davit', 'only DaViT is supported for now'
|
|
|
|
| 2545 |
|
| 2546 |
language_model = Florence2LanguageForConditionalGeneration(config=config.text_config)
|
| 2547 |
|
| 2548 |
+
if language_model._tied_weights_keys is not None:
|
| 2549 |
+
self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
|
| 2550 |
self.language_model = language_model
|
| 2551 |
|
| 2552 |
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
|
|
|
|
| 2589 |
def get_input_embeddings(self):
|
| 2590 |
return self.language_model.get_input_embeddings()
|
| 2591 |
|
| 2592 |
+
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
|
| 2593 |
+
model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
|
| 2594 |
# update vocab size
|
| 2595 |
self.config.text_config.vocab_size = model_embeds.num_embeddings
|
| 2596 |
self.config.vocab_size = model_embeds.num_embeddings
|
processing_florence2.py
CHANGED
|
@@ -20,7 +20,6 @@ import re
|
|
| 20 |
import logging
|
| 21 |
from typing import List, Optional, Union
|
| 22 |
import numpy as np
|
| 23 |
-
import math
|
| 24 |
|
| 25 |
import torch
|
| 26 |
|
|
@@ -33,7 +32,6 @@ from transformers.tokenization_utils_base import (
|
|
| 33 |
TextInput,
|
| 34 |
TruncationStrategy,
|
| 35 |
)
|
| 36 |
-
from transformers import BartTokenizer, BartTokenizerFast
|
| 37 |
from transformers.utils import TensorType
|
| 38 |
|
| 39 |
|
|
@@ -306,7 +304,7 @@ class Florence2Processor(ProcessorMixin):
|
|
| 306 |
image_processor_input_names = self.image_processor.model_input_names
|
| 307 |
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
| 308 |
|
| 309 |
-
def post_process_generation(self, text
|
| 310 |
"""
|
| 311 |
Post-process the output of the model to each of the task outputs.
|
| 312 |
|
|
@@ -319,8 +317,6 @@ class Florence2Processor(ProcessorMixin):
|
|
| 319 |
task_answer_post_processing_type = self.tasks_answer_post_processing_type.get(task, 'pure_text')
|
| 320 |
task_answer = self.post_processor(
|
| 321 |
text=text,
|
| 322 |
-
sequence=sequence,
|
| 323 |
-
transition_beam_score=transition_beam_score,
|
| 324 |
image_size=image_size,
|
| 325 |
parse_tasks=task_answer_post_processing_type,
|
| 326 |
)[task_answer_post_processing_type]
|
|
@@ -334,9 +330,6 @@ class Florence2Processor(ProcessorMixin):
|
|
| 334 |
bboxes_od = [_od_instance['bbox'] for _od_instance in od_instances]
|
| 335 |
labels_od = [str(_od_instance['cat_name']) for _od_instance in od_instances]
|
| 336 |
final_answer = {'bboxes': bboxes_od, 'labels': labels_od}
|
| 337 |
-
if len(od_instances) and 'score' in od_instances[0]:
|
| 338 |
-
scores_od = [_od_instance['score'] for _od_instance in od_instances]
|
| 339 |
-
final_answer['scores'] = scores_od
|
| 340 |
elif task_answer_post_processing_type in ['ocr']:
|
| 341 |
bboxes = [_od_instance['quad_box'] for _od_instance in task_answer]
|
| 342 |
labels = [str(_od_instance['text']) for _od_instance in task_answer]
|
|
@@ -503,7 +496,7 @@ class CoordinatesQuantizer(object):
|
|
| 503 |
|
| 504 |
|
| 505 |
class Florence2PostProcesser(object):
|
| 506 |
-
|
| 507 |
Florence-2 post process for converting text prediction to various tasks results.
|
| 508 |
|
| 509 |
Args:
|
|
@@ -598,8 +591,7 @@ class Florence2PostProcesser(object):
|
|
| 598 |
'PARSE_TASKS': [
|
| 599 |
{
|
| 600 |
'TASK_NAME': 'od',
|
| 601 |
-
'PATTERN': r'([a-zA-Z0-9 ]+)<loc_(\\d+)><loc_(\\d+)><loc_(\\d+)><loc_(\\d+)>'
|
| 602 |
-
'SCORE_MODE': 'avg_loc_scores'
|
| 603 |
},
|
| 604 |
{
|
| 605 |
'TASK_NAME': 'ocr',
|
|
@@ -615,7 +607,6 @@ class Florence2PostProcesser(object):
|
|
| 615 |
},
|
| 616 |
{
|
| 617 |
'TASK_NAME': 'description_with_bboxes',
|
| 618 |
-
'SCORE_MODE': 'avg_loc_scores'
|
| 619 |
},
|
| 620 |
{
|
| 621 |
'TASK_NAME': 'description_with_polygons',
|
|
@@ -657,6 +648,9 @@ class Florence2PostProcesser(object):
|
|
| 657 |
token_ids, skip_special_tokens=False)
|
| 658 |
assert len(filtered_tokens) == len(token_ids)
|
| 659 |
|
|
|
|
|
|
|
|
|
|
| 660 |
sub_texts = []
|
| 661 |
for token in filtered_tokens:
|
| 662 |
if token in self.all_special_tokens:
|
|
@@ -664,6 +658,10 @@ class Florence2PostProcesser(object):
|
|
| 664 |
else:
|
| 665 |
if isinstance(tokenizer, (BartTokenizer, BartTokenizerFast)):
|
| 666 |
sub_text = tokenizer.convert_tokens_to_string([token])
|
|
|
|
|
|
|
|
|
|
|
|
|
| 667 |
else:
|
| 668 |
raise ValueError(f'type {type(tokenizer)} not supported')
|
| 669 |
sub_texts.append(sub_text)
|
|
@@ -675,6 +673,13 @@ class Florence2PostProcesser(object):
|
|
| 675 |
text += sub_text
|
| 676 |
spans.append(span)
|
| 677 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 678 |
return text, spans
|
| 679 |
|
| 680 |
def parse_od_from_text_and_spans(
|
|
@@ -709,7 +714,7 @@ class Florence2PostProcesser(object):
|
|
| 709 |
return instances
|
| 710 |
|
| 711 |
def parse_ocr_from_text_and_spans(self,
|
| 712 |
-
|
| 713 |
pattern,
|
| 714 |
image_size,
|
| 715 |
area_threshold=-1.0,
|
|
@@ -813,26 +818,9 @@ class Florence2PostProcesser(object):
|
|
| 813 |
|
| 814 |
return instances
|
| 815 |
|
| 816 |
-
def parse_description_with_bboxes_from_text_and_spans(
|
| 817 |
-
|
| 818 |
-
|
| 819 |
-
spans=None,
|
| 820 |
-
scores=None,
|
| 821 |
-
score_mode=None,
|
| 822 |
-
pattern=None,
|
| 823 |
-
image_size=None,
|
| 824 |
-
allow_empty_phrase=False
|
| 825 |
-
):
|
| 826 |
-
def find_matched_token_indices(cur_span, token_spans):
|
| 827 |
-
inds = []
|
| 828 |
-
for i, token_span in enumerate(token_spans):
|
| 829 |
-
if not (token_span[1] <= cur_span[0] or token_span[0] >= cur_span[1]):
|
| 830 |
-
inds.append(i)
|
| 831 |
-
return inds
|
| 832 |
-
|
| 833 |
-
cur_span = 0
|
| 834 |
-
if text.startswith('<s>'):
|
| 835 |
-
cur_span += 3
|
| 836 |
|
| 837 |
text = text.replace('<s>', '')
|
| 838 |
text = text.replace('</s>', '')
|
|
@@ -854,16 +842,13 @@ class Florence2PostProcesser(object):
|
|
| 854 |
phrase_text_strip = pharse_text.replace('<obj>', '', 1)
|
| 855 |
|
| 856 |
if phrase_text_strip == '' and not allow_empty_phrase:
|
| 857 |
-
cur_span += len(pharse_text)
|
| 858 |
continue
|
| 859 |
|
| 860 |
# parse phrase, get string
|
| 861 |
phrase = re.search(pattern, phrase_text_strip)
|
| 862 |
if phrase is None:
|
| 863 |
-
cur_span += len(pharse_text)
|
| 864 |
continue
|
| 865 |
|
| 866 |
-
phrase_span = phrase.span()
|
| 867 |
phrase = phrase.group()
|
| 868 |
# remove leading and trailing spaces
|
| 869 |
phrase = phrase.strip()
|
|
@@ -871,7 +856,6 @@ class Florence2PostProcesser(object):
|
|
| 871 |
# parse bboxes by box_pattern
|
| 872 |
bboxes_parsed = list(re.finditer(box_pattern, pharse_text))
|
| 873 |
if len(bboxes_parsed) == 0:
|
| 874 |
-
cur_span += len(pharse_text)
|
| 875 |
continue
|
| 876 |
|
| 877 |
# a list of list
|
|
@@ -882,42 +866,14 @@ class Florence2PostProcesser(object):
|
|
| 882 |
size=image_size
|
| 883 |
).tolist()
|
| 884 |
|
| 885 |
-
if score_mode == 'avg_loc_scores':
|
| 886 |
-
if spans is None or scores is None:
|
| 887 |
-
all_scores = None
|
| 888 |
-
else:
|
| 889 |
-
bbox_end_spans = [_bboxes_parsed.span(0) for _bboxes_parsed in bboxes_parsed]
|
| 890 |
-
all_scores = []
|
| 891 |
-
for _spans in bbox_end_spans:
|
| 892 |
-
token_inds = find_matched_token_indices((_spans[0] + cur_span, _spans[1]+ cur_span), spans)
|
| 893 |
-
loc_scores = [scores[token_i] for token_i in token_inds]
|
| 894 |
-
score = sum(loc_scores) / len(loc_scores)
|
| 895 |
-
all_scores.append(score)
|
| 896 |
-
elif score_mode == 'avg_cat_name_scores':
|
| 897 |
-
if spans is None or scores is None:
|
| 898 |
-
all_scores = None
|
| 899 |
-
else:
|
| 900 |
-
cat_name_token_inds = find_matched_token_indices((phrase_span[0] + cur_span, phrase_span[1]+cur_span), spans)
|
| 901 |
-
cat_name_scores = [scores[token_i] for token_i in cat_name_token_inds]
|
| 902 |
-
score = sum(cat_name_scores) / len(cat_name_scores)
|
| 903 |
-
all_scores = [score] * len(bboxes)
|
| 904 |
-
elif score_mode is None:
|
| 905 |
-
all_scores = None
|
| 906 |
-
else:
|
| 907 |
-
raise ValueError('Unknown score mode: {}'.format(score_mode))
|
| 908 |
-
|
| 909 |
phrase = phrase.encode('ascii',errors='ignore').decode('ascii')
|
| 910 |
-
for
|
| 911 |
# Prepare instance.
|
| 912 |
instance = {}
|
| 913 |
instance['bbox'] = _bboxes
|
| 914 |
# exclude non-ascii characters
|
| 915 |
instance['cat_name'] = phrase
|
| 916 |
-
if all_scores is not None:
|
| 917 |
-
instance['score'] = math.exp(all_scores[_idx])
|
| 918 |
instances.append(instance)
|
| 919 |
-
|
| 920 |
-
cur_span += len(pharse_text)
|
| 921 |
|
| 922 |
return instances
|
| 923 |
|
|
@@ -1035,8 +991,6 @@ class Florence2PostProcesser(object):
|
|
| 1035 |
def __call__(
|
| 1036 |
self,
|
| 1037 |
text=None,
|
| 1038 |
-
sequence=None,
|
| 1039 |
-
transition_beam_score=None,
|
| 1040 |
image_size=None,
|
| 1041 |
parse_tasks=None,
|
| 1042 |
):
|
|
@@ -1045,6 +999,7 @@ class Florence2PostProcesser(object):
|
|
| 1045 |
text: model outputs
|
| 1046 |
image_size: (width, height)
|
| 1047 |
parse_tasks: a list of tasks to parse, if None, parse all tasks.
|
|
|
|
| 1048 |
"""
|
| 1049 |
if parse_tasks is not None:
|
| 1050 |
if isinstance(parse_tasks, str):
|
|
@@ -1053,18 +1008,7 @@ class Florence2PostProcesser(object):
|
|
| 1053 |
assert _parse_task in self.parse_tasks, f'parse task {_parse_task} not supported'
|
| 1054 |
|
| 1055 |
# sequence or text should be provided
|
| 1056 |
-
assert
|
| 1057 |
-
assert sequence is None or text is None, 'only one of sequence and text should be provided'
|
| 1058 |
-
|
| 1059 |
-
if sequence is not None:
|
| 1060 |
-
sequence = sequence.tolist()[1:]
|
| 1061 |
-
text, spans = self.decode_with_spans(self.tokenizer, sequence)
|
| 1062 |
-
if transition_beam_score is not None:
|
| 1063 |
-
transition_beam_score = transition_beam_score.tolist()
|
| 1064 |
-
assert len(sequence) == len(transition_beam_score)
|
| 1065 |
-
else:
|
| 1066 |
-
spans = None
|
| 1067 |
-
transition_beam_score = None
|
| 1068 |
|
| 1069 |
parsed_dict = {
|
| 1070 |
'text': text
|
|
@@ -1075,7 +1019,6 @@ class Florence2PostProcesser(object):
|
|
| 1075 |
continue
|
| 1076 |
|
| 1077 |
pattern = self.parse_tasks_configs[task].get('PATTERN', None)
|
| 1078 |
-
score_mode = self.parse_tasks_configs[task].get('SCORE_MODE', None)
|
| 1079 |
|
| 1080 |
if task == 'ocr':
|
| 1081 |
instances = self.parse_ocr_from_text_and_spans(
|
|
@@ -1097,9 +1040,6 @@ class Florence2PostProcesser(object):
|
|
| 1097 |
elif task == 'description_with_bboxes':
|
| 1098 |
instances = self.parse_description_with_bboxes_from_text_and_spans(
|
| 1099 |
text,
|
| 1100 |
-
spans=spans,
|
| 1101 |
-
scores=transition_beam_score,
|
| 1102 |
-
score_mode=score_mode,
|
| 1103 |
pattern=pattern,
|
| 1104 |
image_size=image_size,
|
| 1105 |
)
|
|
|
|
| 20 |
import logging
|
| 21 |
from typing import List, Optional, Union
|
| 22 |
import numpy as np
|
|
|
|
| 23 |
|
| 24 |
import torch
|
| 25 |
|
|
|
|
| 32 |
TextInput,
|
| 33 |
TruncationStrategy,
|
| 34 |
)
|
|
|
|
| 35 |
from transformers.utils import TensorType
|
| 36 |
|
| 37 |
|
|
|
|
| 304 |
image_processor_input_names = self.image_processor.model_input_names
|
| 305 |
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
| 306 |
|
| 307 |
+
def post_process_generation(self, text, task, image_size):
|
| 308 |
"""
|
| 309 |
Post-process the output of the model to each of the task outputs.
|
| 310 |
|
|
|
|
| 317 |
task_answer_post_processing_type = self.tasks_answer_post_processing_type.get(task, 'pure_text')
|
| 318 |
task_answer = self.post_processor(
|
| 319 |
text=text,
|
|
|
|
|
|
|
| 320 |
image_size=image_size,
|
| 321 |
parse_tasks=task_answer_post_processing_type,
|
| 322 |
)[task_answer_post_processing_type]
|
|
|
|
| 330 |
bboxes_od = [_od_instance['bbox'] for _od_instance in od_instances]
|
| 331 |
labels_od = [str(_od_instance['cat_name']) for _od_instance in od_instances]
|
| 332 |
final_answer = {'bboxes': bboxes_od, 'labels': labels_od}
|
|
|
|
|
|
|
|
|
|
| 333 |
elif task_answer_post_processing_type in ['ocr']:
|
| 334 |
bboxes = [_od_instance['quad_box'] for _od_instance in task_answer]
|
| 335 |
labels = [str(_od_instance['text']) for _od_instance in task_answer]
|
|
|
|
| 496 |
|
| 497 |
|
| 498 |
class Florence2PostProcesser(object):
|
| 499 |
+
"""
|
| 500 |
Florence-2 post process for converting text prediction to various tasks results.
|
| 501 |
|
| 502 |
Args:
|
|
|
|
| 591 |
'PARSE_TASKS': [
|
| 592 |
{
|
| 593 |
'TASK_NAME': 'od',
|
| 594 |
+
'PATTERN': r'([a-zA-Z0-9 ]+)<loc_(\\d+)><loc_(\\d+)><loc_(\\d+)><loc_(\\d+)>'
|
|
|
|
| 595 |
},
|
| 596 |
{
|
| 597 |
'TASK_NAME': 'ocr',
|
|
|
|
| 607 |
},
|
| 608 |
{
|
| 609 |
'TASK_NAME': 'description_with_bboxes',
|
|
|
|
| 610 |
},
|
| 611 |
{
|
| 612 |
'TASK_NAME': 'description_with_polygons',
|
|
|
|
| 648 |
token_ids, skip_special_tokens=False)
|
| 649 |
assert len(filtered_tokens) == len(token_ids)
|
| 650 |
|
| 651 |
+
# To avoid mixing byte-level and unicode for byte-level BPT
|
| 652 |
+
# we need to build string separately for added tokens and byte-level tokens
|
| 653 |
+
# cf. https://github.com/huggingface/transformers/issues/1133
|
| 654 |
sub_texts = []
|
| 655 |
for token in filtered_tokens:
|
| 656 |
if token in self.all_special_tokens:
|
|
|
|
| 658 |
else:
|
| 659 |
if isinstance(tokenizer, (BartTokenizer, BartTokenizerFast)):
|
| 660 |
sub_text = tokenizer.convert_tokens_to_string([token])
|
| 661 |
+
elif isinstance(tokenizer, (T5Tokenizer, T5TokenizerFast)):
|
| 662 |
+
# Ref: https://github.com/google/sentencepiece#whitespace-is-treated-as-a-basic-symbol
|
| 663 |
+
# Note: Do not strip sub_text as it may have functional whitespace
|
| 664 |
+
sub_text = token.replace('▁', ' ')
|
| 665 |
else:
|
| 666 |
raise ValueError(f'type {type(tokenizer)} not supported')
|
| 667 |
sub_texts.append(sub_text)
|
|
|
|
| 673 |
text += sub_text
|
| 674 |
spans.append(span)
|
| 675 |
|
| 676 |
+
# Text format:
|
| 677 |
+
# 1. T5Tokenizer/T5TokenizerFast:
|
| 678 |
+
# "<loc_1><loc_2><loc_3><loc_4> transplanting dog<loc_1><loc_2><loc_3><loc_4> cat</s>"
|
| 679 |
+
# Equivalent to t5_tokenizer.decode(input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False, spaces_between_special_tokens=False)
|
| 680 |
+
# 2. BartTokenizer (need to double check):
|
| 681 |
+
# "<s><loc_1><loc_2><loc_3><loc_4>transplanting dog<loc_1><loc_2><loc_3><loc_4>cat</s>"
|
| 682 |
+
# Equivalent to bart_tokenizer.decode(input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False, spaces_between_special_tokens=False)
|
| 683 |
return text, spans
|
| 684 |
|
| 685 |
def parse_od_from_text_and_spans(
|
|
|
|
| 714 |
return instances
|
| 715 |
|
| 716 |
def parse_ocr_from_text_and_spans(self,
|
| 717 |
+
text,
|
| 718 |
pattern,
|
| 719 |
image_size,
|
| 720 |
area_threshold=-1.0,
|
|
|
|
| 818 |
|
| 819 |
return instances
|
| 820 |
|
| 821 |
+
def parse_description_with_bboxes_from_text_and_spans(self, text, pattern, image_size, allow_empty_phrase=False):
|
| 822 |
+
# temporary parse solution, split by '.'
|
| 823 |
+
# ignore <s> </s> and <pad>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 824 |
|
| 825 |
text = text.replace('<s>', '')
|
| 826 |
text = text.replace('</s>', '')
|
|
|
|
| 842 |
phrase_text_strip = pharse_text.replace('<obj>', '', 1)
|
| 843 |
|
| 844 |
if phrase_text_strip == '' and not allow_empty_phrase:
|
|
|
|
| 845 |
continue
|
| 846 |
|
| 847 |
# parse phrase, get string
|
| 848 |
phrase = re.search(pattern, phrase_text_strip)
|
| 849 |
if phrase is None:
|
|
|
|
| 850 |
continue
|
| 851 |
|
|
|
|
| 852 |
phrase = phrase.group()
|
| 853 |
# remove leading and trailing spaces
|
| 854 |
phrase = phrase.strip()
|
|
|
|
| 856 |
# parse bboxes by box_pattern
|
| 857 |
bboxes_parsed = list(re.finditer(box_pattern, pharse_text))
|
| 858 |
if len(bboxes_parsed) == 0:
|
|
|
|
| 859 |
continue
|
| 860 |
|
| 861 |
# a list of list
|
|
|
|
| 866 |
size=image_size
|
| 867 |
).tolist()
|
| 868 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 869 |
phrase = phrase.encode('ascii',errors='ignore').decode('ascii')
|
| 870 |
+
for _bboxes in bboxes:
|
| 871 |
# Prepare instance.
|
| 872 |
instance = {}
|
| 873 |
instance['bbox'] = _bboxes
|
| 874 |
# exclude non-ascii characters
|
| 875 |
instance['cat_name'] = phrase
|
|
|
|
|
|
|
| 876 |
instances.append(instance)
|
|
|
|
|
|
|
| 877 |
|
| 878 |
return instances
|
| 879 |
|
|
|
|
| 991 |
def __call__(
|
| 992 |
self,
|
| 993 |
text=None,
|
|
|
|
|
|
|
| 994 |
image_size=None,
|
| 995 |
parse_tasks=None,
|
| 996 |
):
|
|
|
|
| 999 |
text: model outputs
|
| 1000 |
image_size: (width, height)
|
| 1001 |
parse_tasks: a list of tasks to parse, if None, parse all tasks.
|
| 1002 |
+
|
| 1003 |
"""
|
| 1004 |
if parse_tasks is not None:
|
| 1005 |
if isinstance(parse_tasks, str):
|
|
|
|
| 1008 |
assert _parse_task in self.parse_tasks, f'parse task {_parse_task} not supported'
|
| 1009 |
|
| 1010 |
# sequence or text should be provided
|
| 1011 |
+
assert text is not None, 'text should be provided'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1012 |
|
| 1013 |
parsed_dict = {
|
| 1014 |
'text': text
|
|
|
|
| 1019 |
continue
|
| 1020 |
|
| 1021 |
pattern = self.parse_tasks_configs[task].get('PATTERN', None)
|
|
|
|
| 1022 |
|
| 1023 |
if task == 'ocr':
|
| 1024 |
instances = self.parse_ocr_from_text_and_spans(
|
|
|
|
| 1040 |
elif task == 'description_with_bboxes':
|
| 1041 |
instances = self.parse_description_with_bboxes_from_text_and_spans(
|
| 1042 |
text,
|
|
|
|
|
|
|
|
|
|
| 1043 |
pattern=pattern,
|
| 1044 |
image_size=image_size,
|
| 1045 |
)
|