E:\ComfyUI-aki-v1.3\models\LLM

#29
by wingman212 - opened
configuration_florence2.py CHANGED
@@ -77,7 +77,7 @@ class Florence2VisionConfig(PretrainedConfig):
77
  >>> configuration = model.config
78
  ```"""
79
 
80
- model_type = "davit"
81
  keys_to_ignore_at_inference = ["past_key_values"]
82
 
83
  def __init__(
@@ -327,7 +327,7 @@ class Florence2Config(PretrainedConfig):
327
  self.vocab_size = vocab_size
328
  self.projection_dim = projection_dim
329
  if vision_config is not None:
330
- vision_config = Florence2VisionConfig(**vision_config)
331
  self.vision_config = vision_config
332
  self.vocab_size = self.vocab_size
333
 
 
77
  >>> configuration = model.config
78
  ```"""
79
 
80
+ model_type = "florence2_vision"
81
  keys_to_ignore_at_inference = ["past_key_values"]
82
 
83
  def __init__(
 
327
  self.vocab_size = vocab_size
328
  self.projection_dim = projection_dim
329
  if vision_config is not None:
330
+ vision_config = PretrainedConfig(**vision_config)
331
  self.vision_config = vision_config
332
  self.vocab_size = self.vocab_size
333
 
model.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:03075d2d2d2bbd3e180b9ba0afae4aa8563226e2d32911656966e05b2f2ee060
3
- size 463221266
 
 
 
 
modeling_florence2.py CHANGED
@@ -26,7 +26,7 @@ import torch.utils.checkpoint as checkpoint
26
  from torch.nn import CrossEntropyLoss
27
  from collections import OrderedDict
28
  from einops import rearrange
29
- from timm.layers import DropPath, trunc_normal_
30
 
31
  from transformers.modeling_utils import PreTrainedModel
32
  from transformers.generation.utils import GenerationMixin
@@ -610,10 +610,29 @@ class DaViT(nn.Module):
610
  self.avgpool = nn.AdaptiveAvgPool1d(1)
611
  self.head = nn.Linear(self.embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
612
 
 
 
613
  @property
614
  def dim_out(self):
615
  return self.embed_dims[-1]
616
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
617
  def forward_features_unpool(self, x):
618
  """
619
  forward until avg pooling
@@ -1432,17 +1451,6 @@ class Florence2LanguagePreTrainedModel(PreTrainedModel):
1432
  module.weight.data.normal_(mean=0.0, std=std)
1433
  if module.padding_idx is not None:
1434
  module.weight.data[module.padding_idx].zero_()
1435
- elif isinstance(module, nn.Conv2d):
1436
- nn.init.normal_(module.weight, std=0.02)
1437
- for name, _ in module.named_parameters():
1438
- if name == "bias":
1439
- nn.init.constant_(module.bias, 0)
1440
- elif isinstance(module, nn.LayerNorm):
1441
- nn.init.constant_(module.weight, 1.0)
1442
- nn.init.constant_(module.bias, 0)
1443
- elif isinstance(module, nn.BatchNorm2d):
1444
- nn.init.constant_(module.weight, 1.0)
1445
- nn.init.constant_(module.bias, 0)
1446
 
1447
  @property
1448
  def dummy_inputs(self):
@@ -2066,20 +2074,14 @@ class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel
2066
  # Initialize weights and apply final processing
2067
  self.post_init()
2068
 
2069
- def _tie_weights(self):
2070
- if self.config.tie_word_embeddings:
2071
- self._tie_or_clone_weights(self.model.encoder.embed_tokens, self.model.shared)
2072
- self._tie_or_clone_weights(self.model.decoder.embed_tokens, self.model.shared)
2073
- self._tie_or_clone_weights(self.lm_head, self.model.shared)
2074
-
2075
  def get_encoder(self):
2076
  return self.model.get_encoder()
2077
 
2078
  def get_decoder(self):
2079
  return self.model.get_decoder()
2080
 
2081
- def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, **kwargs) -> nn.Embedding:
2082
- new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, **kwargs)
2083
  self._resize_final_logits_bias(new_embeddings.weight.shape[0])
2084
  return new_embeddings
2085
 
@@ -2529,8 +2531,6 @@ class Florence2VisionModelWithProjection(Florence2PreTrainedModel):
2529
  FLORENCE2_START_DOCSTRING,
2530
  )
2531
  class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2532
- _tied_weights_keys = ["language_model.encoder.embed_tokens.weight", "language_model.decoder.embed_tokens.weight", "language_model.lm_head.weight"]
2533
-
2534
  def __init__(self, config: Florence2Config):
2535
  super().__init__(config)
2536
  assert config.vision_config.model_type == 'davit', 'only DaViT is supported for now'
@@ -2545,6 +2545,8 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2545
 
2546
  language_model = Florence2LanguageForConditionalGeneration(config=config.text_config)
2547
 
 
 
2548
  self.language_model = language_model
2549
 
2550
  self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
@@ -2587,8 +2589,8 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
2587
  def get_input_embeddings(self):
2588
  return self.language_model.get_input_embeddings()
2589
 
2590
- def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None, **kwargs) -> nn.Embedding:
2591
- model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of, **kwargs)
2592
  # update vocab size
2593
  self.config.text_config.vocab_size = model_embeds.num_embeddings
2594
  self.config.vocab_size = model_embeds.num_embeddings
 
26
  from torch.nn import CrossEntropyLoss
27
  from collections import OrderedDict
28
  from einops import rearrange
29
+ from timm.models.layers import DropPath, trunc_normal_
30
 
31
  from transformers.modeling_utils import PreTrainedModel
32
  from transformers.generation.utils import GenerationMixin
 
610
  self.avgpool = nn.AdaptiveAvgPool1d(1)
611
  self.head = nn.Linear(self.embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
612
 
613
+ self.apply(self._init_weights)
614
+
615
  @property
616
  def dim_out(self):
617
  return self.embed_dims[-1]
618
 
619
+ def _init_weights(self, m):
620
+ if isinstance(m, nn.Linear):
621
+ trunc_normal_(m.weight, std=0.02)
622
+ if m.bias is not None:
623
+ nn.init.constant_(m.bias, 0)
624
+ elif isinstance(m, nn.Conv2d):
625
+ nn.init.normal_(m.weight, std=0.02)
626
+ for name, _ in m.named_parameters():
627
+ if name in ['bias']:
628
+ nn.init.constant_(m.bias, 0)
629
+ elif isinstance(m, nn.LayerNorm):
630
+ nn.init.constant_(m.weight, 1.0)
631
+ nn.init.constant_(m.bias, 0)
632
+ elif isinstance(m, nn.BatchNorm2d):
633
+ nn.init.constant_(m.weight, 1.0)
634
+ nn.init.constant_(m.bias, 0)
635
+
636
  def forward_features_unpool(self, x):
637
  """
638
  forward until avg pooling
 
1451
  module.weight.data.normal_(mean=0.0, std=std)
1452
  if module.padding_idx is not None:
1453
  module.weight.data[module.padding_idx].zero_()
 
 
 
 
 
 
 
 
 
 
 
1454
 
1455
  @property
1456
  def dummy_inputs(self):
 
2074
  # Initialize weights and apply final processing
2075
  self.post_init()
2076
 
 
 
 
 
 
 
2077
  def get_encoder(self):
2078
  return self.model.get_encoder()
2079
 
2080
  def get_decoder(self):
2081
  return self.model.get_decoder()
2082
 
2083
+ def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
2084
+ new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
2085
  self._resize_final_logits_bias(new_embeddings.weight.shape[0])
2086
  return new_embeddings
2087
 
 
2531
  FLORENCE2_START_DOCSTRING,
2532
  )
2533
  class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
 
 
2534
  def __init__(self, config: Florence2Config):
2535
  super().__init__(config)
2536
  assert config.vision_config.model_type == 'davit', 'only DaViT is supported for now'
 
2545
 
2546
  language_model = Florence2LanguageForConditionalGeneration(config=config.text_config)
2547
 
2548
+ if language_model._tied_weights_keys is not None:
2549
+ self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
2550
  self.language_model = language_model
2551
 
2552
  self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
 
2589
  def get_input_embeddings(self):
2590
  return self.language_model.get_input_embeddings()
2591
 
2592
+ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
2593
+ model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
2594
  # update vocab size
2595
  self.config.text_config.vocab_size = model_embeds.num_embeddings
2596
  self.config.vocab_size = model_embeds.num_embeddings
processing_florence2.py CHANGED
@@ -20,7 +20,6 @@ import re
20
  import logging
21
  from typing import List, Optional, Union
22
  import numpy as np
23
- import math
24
 
25
  import torch
26
 
@@ -33,7 +32,6 @@ from transformers.tokenization_utils_base import (
33
  TextInput,
34
  TruncationStrategy,
35
  )
36
- from transformers import BartTokenizer, BartTokenizerFast
37
  from transformers.utils import TensorType
38
 
39
 
@@ -306,7 +304,7 @@ class Florence2Processor(ProcessorMixin):
306
  image_processor_input_names = self.image_processor.model_input_names
307
  return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
308
 
309
- def post_process_generation(self, text=None, sequence=None, transition_beam_score=None, task=None, image_size=None):
310
  """
311
  Post-process the output of the model to each of the task outputs.
312
 
@@ -319,8 +317,6 @@ class Florence2Processor(ProcessorMixin):
319
  task_answer_post_processing_type = self.tasks_answer_post_processing_type.get(task, 'pure_text')
320
  task_answer = self.post_processor(
321
  text=text,
322
- sequence=sequence,
323
- transition_beam_score=transition_beam_score,
324
  image_size=image_size,
325
  parse_tasks=task_answer_post_processing_type,
326
  )[task_answer_post_processing_type]
@@ -334,9 +330,6 @@ class Florence2Processor(ProcessorMixin):
334
  bboxes_od = [_od_instance['bbox'] for _od_instance in od_instances]
335
  labels_od = [str(_od_instance['cat_name']) for _od_instance in od_instances]
336
  final_answer = {'bboxes': bboxes_od, 'labels': labels_od}
337
- if len(od_instances) and 'score' in od_instances[0]:
338
- scores_od = [_od_instance['score'] for _od_instance in od_instances]
339
- final_answer['scores'] = scores_od
340
  elif task_answer_post_processing_type in ['ocr']:
341
  bboxes = [_od_instance['quad_box'] for _od_instance in task_answer]
342
  labels = [str(_od_instance['text']) for _od_instance in task_answer]
@@ -503,7 +496,7 @@ class CoordinatesQuantizer(object):
503
 
504
 
505
  class Florence2PostProcesser(object):
506
- r"""
507
  Florence-2 post process for converting text prediction to various tasks results.
508
 
509
  Args:
@@ -598,8 +591,7 @@ class Florence2PostProcesser(object):
598
  'PARSE_TASKS': [
599
  {
600
  'TASK_NAME': 'od',
601
- 'PATTERN': r'([a-zA-Z0-9 ]+)<loc_(\\d+)><loc_(\\d+)><loc_(\\d+)><loc_(\\d+)>',
602
- 'SCORE_MODE': 'avg_loc_scores'
603
  },
604
  {
605
  'TASK_NAME': 'ocr',
@@ -615,7 +607,6 @@ class Florence2PostProcesser(object):
615
  },
616
  {
617
  'TASK_NAME': 'description_with_bboxes',
618
- 'SCORE_MODE': 'avg_loc_scores'
619
  },
620
  {
621
  'TASK_NAME': 'description_with_polygons',
@@ -657,6 +648,9 @@ class Florence2PostProcesser(object):
657
  token_ids, skip_special_tokens=False)
658
  assert len(filtered_tokens) == len(token_ids)
659
 
 
 
 
660
  sub_texts = []
661
  for token in filtered_tokens:
662
  if token in self.all_special_tokens:
@@ -664,6 +658,10 @@ class Florence2PostProcesser(object):
664
  else:
665
  if isinstance(tokenizer, (BartTokenizer, BartTokenizerFast)):
666
  sub_text = tokenizer.convert_tokens_to_string([token])
 
 
 
 
667
  else:
668
  raise ValueError(f'type {type(tokenizer)} not supported')
669
  sub_texts.append(sub_text)
@@ -675,6 +673,13 @@ class Florence2PostProcesser(object):
675
  text += sub_text
676
  spans.append(span)
677
 
 
 
 
 
 
 
 
678
  return text, spans
679
 
680
  def parse_od_from_text_and_spans(
@@ -709,7 +714,7 @@ class Florence2PostProcesser(object):
709
  return instances
710
 
711
  def parse_ocr_from_text_and_spans(self,
712
- text,
713
  pattern,
714
  image_size,
715
  area_threshold=-1.0,
@@ -813,26 +818,9 @@ class Florence2PostProcesser(object):
813
 
814
  return instances
815
 
816
- def parse_description_with_bboxes_from_text_and_spans(
817
- self,
818
- text,
819
- spans=None,
820
- scores=None,
821
- score_mode=None,
822
- pattern=None,
823
- image_size=None,
824
- allow_empty_phrase=False
825
- ):
826
- def find_matched_token_indices(cur_span, token_spans):
827
- inds = []
828
- for i, token_span in enumerate(token_spans):
829
- if not (token_span[1] <= cur_span[0] or token_span[0] >= cur_span[1]):
830
- inds.append(i)
831
- return inds
832
-
833
- cur_span = 0
834
- if text.startswith('<s>'):
835
- cur_span += 3
836
 
837
  text = text.replace('<s>', '')
838
  text = text.replace('</s>', '')
@@ -854,16 +842,13 @@ class Florence2PostProcesser(object):
854
  phrase_text_strip = pharse_text.replace('<obj>', '', 1)
855
 
856
  if phrase_text_strip == '' and not allow_empty_phrase:
857
- cur_span += len(pharse_text)
858
  continue
859
 
860
  # parse phrase, get string
861
  phrase = re.search(pattern, phrase_text_strip)
862
  if phrase is None:
863
- cur_span += len(pharse_text)
864
  continue
865
 
866
- phrase_span = phrase.span()
867
  phrase = phrase.group()
868
  # remove leading and trailing spaces
869
  phrase = phrase.strip()
@@ -871,7 +856,6 @@ class Florence2PostProcesser(object):
871
  # parse bboxes by box_pattern
872
  bboxes_parsed = list(re.finditer(box_pattern, pharse_text))
873
  if len(bboxes_parsed) == 0:
874
- cur_span += len(pharse_text)
875
  continue
876
 
877
  # a list of list
@@ -882,42 +866,14 @@ class Florence2PostProcesser(object):
882
  size=image_size
883
  ).tolist()
884
 
885
- if score_mode == 'avg_loc_scores':
886
- if spans is None or scores is None:
887
- all_scores = None
888
- else:
889
- bbox_end_spans = [_bboxes_parsed.span(0) for _bboxes_parsed in bboxes_parsed]
890
- all_scores = []
891
- for _spans in bbox_end_spans:
892
- token_inds = find_matched_token_indices((_spans[0] + cur_span, _spans[1]+ cur_span), spans)
893
- loc_scores = [scores[token_i] for token_i in token_inds]
894
- score = sum(loc_scores) / len(loc_scores)
895
- all_scores.append(score)
896
- elif score_mode == 'avg_cat_name_scores':
897
- if spans is None or scores is None:
898
- all_scores = None
899
- else:
900
- cat_name_token_inds = find_matched_token_indices((phrase_span[0] + cur_span, phrase_span[1]+cur_span), spans)
901
- cat_name_scores = [scores[token_i] for token_i in cat_name_token_inds]
902
- score = sum(cat_name_scores) / len(cat_name_scores)
903
- all_scores = [score] * len(bboxes)
904
- elif score_mode is None:
905
- all_scores = None
906
- else:
907
- raise ValueError('Unknown score mode: {}'.format(score_mode))
908
-
909
  phrase = phrase.encode('ascii',errors='ignore').decode('ascii')
910
- for _idx, _bboxes in enumerate(bboxes):
911
  # Prepare instance.
912
  instance = {}
913
  instance['bbox'] = _bboxes
914
  # exclude non-ascii characters
915
  instance['cat_name'] = phrase
916
- if all_scores is not None:
917
- instance['score'] = math.exp(all_scores[_idx])
918
  instances.append(instance)
919
-
920
- cur_span += len(pharse_text)
921
 
922
  return instances
923
 
@@ -1035,8 +991,6 @@ class Florence2PostProcesser(object):
1035
  def __call__(
1036
  self,
1037
  text=None,
1038
- sequence=None,
1039
- transition_beam_score=None,
1040
  image_size=None,
1041
  parse_tasks=None,
1042
  ):
@@ -1045,6 +999,7 @@ class Florence2PostProcesser(object):
1045
  text: model outputs
1046
  image_size: (width, height)
1047
  parse_tasks: a list of tasks to parse, if None, parse all tasks.
 
1048
  """
1049
  if parse_tasks is not None:
1050
  if isinstance(parse_tasks, str):
@@ -1053,18 +1008,7 @@ class Florence2PostProcesser(object):
1053
  assert _parse_task in self.parse_tasks, f'parse task {_parse_task} not supported'
1054
 
1055
  # sequence or text should be provided
1056
- assert sequence is not None or text is not None, 'sequence or text should be provided'
1057
- assert sequence is None or text is None, 'only one of sequence and text should be provided'
1058
-
1059
- if sequence is not None:
1060
- sequence = sequence.tolist()[1:]
1061
- text, spans = self.decode_with_spans(self.tokenizer, sequence)
1062
- if transition_beam_score is not None:
1063
- transition_beam_score = transition_beam_score.tolist()
1064
- assert len(sequence) == len(transition_beam_score)
1065
- else:
1066
- spans = None
1067
- transition_beam_score = None
1068
 
1069
  parsed_dict = {
1070
  'text': text
@@ -1075,7 +1019,6 @@ class Florence2PostProcesser(object):
1075
  continue
1076
 
1077
  pattern = self.parse_tasks_configs[task].get('PATTERN', None)
1078
- score_mode = self.parse_tasks_configs[task].get('SCORE_MODE', None)
1079
 
1080
  if task == 'ocr':
1081
  instances = self.parse_ocr_from_text_and_spans(
@@ -1097,9 +1040,6 @@ class Florence2PostProcesser(object):
1097
  elif task == 'description_with_bboxes':
1098
  instances = self.parse_description_with_bboxes_from_text_and_spans(
1099
  text,
1100
- spans=spans,
1101
- scores=transition_beam_score,
1102
- score_mode=score_mode,
1103
  pattern=pattern,
1104
  image_size=image_size,
1105
  )
 
20
  import logging
21
  from typing import List, Optional, Union
22
  import numpy as np
 
23
 
24
  import torch
25
 
 
32
  TextInput,
33
  TruncationStrategy,
34
  )
 
35
  from transformers.utils import TensorType
36
 
37
 
 
304
  image_processor_input_names = self.image_processor.model_input_names
305
  return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
306
 
307
+ def post_process_generation(self, text, task, image_size):
308
  """
309
  Post-process the output of the model to each of the task outputs.
310
 
 
317
  task_answer_post_processing_type = self.tasks_answer_post_processing_type.get(task, 'pure_text')
318
  task_answer = self.post_processor(
319
  text=text,
 
 
320
  image_size=image_size,
321
  parse_tasks=task_answer_post_processing_type,
322
  )[task_answer_post_processing_type]
 
330
  bboxes_od = [_od_instance['bbox'] for _od_instance in od_instances]
331
  labels_od = [str(_od_instance['cat_name']) for _od_instance in od_instances]
332
  final_answer = {'bboxes': bboxes_od, 'labels': labels_od}
 
 
 
333
  elif task_answer_post_processing_type in ['ocr']:
334
  bboxes = [_od_instance['quad_box'] for _od_instance in task_answer]
335
  labels = [str(_od_instance['text']) for _od_instance in task_answer]
 
496
 
497
 
498
  class Florence2PostProcesser(object):
499
+ """
500
  Florence-2 post process for converting text prediction to various tasks results.
501
 
502
  Args:
 
591
  'PARSE_TASKS': [
592
  {
593
  'TASK_NAME': 'od',
594
+ 'PATTERN': r'([a-zA-Z0-9 ]+)<loc_(\\d+)><loc_(\\d+)><loc_(\\d+)><loc_(\\d+)>'
 
595
  },
596
  {
597
  'TASK_NAME': 'ocr',
 
607
  },
608
  {
609
  'TASK_NAME': 'description_with_bboxes',
 
610
  },
611
  {
612
  'TASK_NAME': 'description_with_polygons',
 
648
  token_ids, skip_special_tokens=False)
649
  assert len(filtered_tokens) == len(token_ids)
650
 
651
+ # To avoid mixing byte-level and unicode for byte-level BPE
652
+ # we need to build string separately for added tokens and byte-level tokens
653
+ # cf. https://github.com/huggingface/transformers/issues/1133
654
  sub_texts = []
655
  for token in filtered_tokens:
656
  if token in self.all_special_tokens:
 
658
  else:
659
  if isinstance(tokenizer, (BartTokenizer, BartTokenizerFast)):
660
  sub_text = tokenizer.convert_tokens_to_string([token])
661
+ elif isinstance(tokenizer, (T5Tokenizer, T5TokenizerFast)):
662
+ # Ref: https://github.com/google/sentencepiece#whitespace-is-treated-as-a-basic-symbol
663
+ # Note: Do not strip sub_text as it may have functional whitespace
664
+ sub_text = token.replace('▁', ' ')
665
  else:
666
  raise ValueError(f'type {type(tokenizer)} not supported')
667
  sub_texts.append(sub_text)
 
673
  text += sub_text
674
  spans.append(span)
675
 
676
+ # Text format:
677
+ # 1. T5Tokenizer/T5TokenizerFast:
678
+ # "<loc_1><loc_2><loc_3><loc_4> transplanting dog<loc_1><loc_2><loc_3><loc_4> cat</s>"
679
+ # Equivalent to t5_tokenizer.decode(input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False, spaces_between_special_tokens=False)
680
+ # 2. BartTokenizer (need to double check):
681
+ # "<s><loc_1><loc_2><loc_3><loc_4>transplanting dog<loc_1><loc_2><loc_3><loc_4>cat</s>"
682
+ # Equivalent to bart_tokenizer.decode(input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False, spaces_between_special_tokens=False)
683
  return text, spans
684
 
685
  def parse_od_from_text_and_spans(
 
714
  return instances
715
 
716
  def parse_ocr_from_text_and_spans(self,
717
+ text,
718
  pattern,
719
  image_size,
720
  area_threshold=-1.0,
 
818
 
819
  return instances
820
 
821
+ def parse_description_with_bboxes_from_text_and_spans(self, text, pattern, image_size, allow_empty_phrase=False):
822
+ # temporary parse solution, split by '.'
823
+ # ignore <s> </s> and <pad>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
824
 
825
  text = text.replace('<s>', '')
826
  text = text.replace('</s>', '')
 
842
  phrase_text_strip = pharse_text.replace('<obj>', '', 1)
843
 
844
  if phrase_text_strip == '' and not allow_empty_phrase:
 
845
  continue
846
 
847
  # parse phrase, get string
848
  phrase = re.search(pattern, phrase_text_strip)
849
  if phrase is None:
 
850
  continue
851
 
 
852
  phrase = phrase.group()
853
  # remove leading and trailing spaces
854
  phrase = phrase.strip()
 
856
  # parse bboxes by box_pattern
857
  bboxes_parsed = list(re.finditer(box_pattern, pharse_text))
858
  if len(bboxes_parsed) == 0:
 
859
  continue
860
 
861
  # a list of list
 
866
  size=image_size
867
  ).tolist()
868
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
869
  phrase = phrase.encode('ascii',errors='ignore').decode('ascii')
870
+ for _bboxes in bboxes:
871
  # Prepare instance.
872
  instance = {}
873
  instance['bbox'] = _bboxes
874
  # exclude non-ascii characters
875
  instance['cat_name'] = phrase
 
 
876
  instances.append(instance)
 
 
877
 
878
  return instances
879
 
 
991
  def __call__(
992
  self,
993
  text=None,
 
 
994
  image_size=None,
995
  parse_tasks=None,
996
  ):
 
999
  text: model outputs
1000
  image_size: (width, height)
1001
  parse_tasks: a list of tasks to parse, if None, parse all tasks.
1002
+
1003
  """
1004
  if parse_tasks is not None:
1005
  if isinstance(parse_tasks, str):
 
1008
  assert _parse_task in self.parse_tasks, f'parse task {_parse_task} not supported'
1009
 
1010
  # text should be provided
1011
+ assert text is not None, 'text should be provided'
 
 
 
 
 
 
 
 
 
 
 
1012
 
1013
  parsed_dict = {
1014
  'text': text
 
1019
  continue
1020
 
1021
  pattern = self.parse_tasks_configs[task].get('PATTERN', None)
 
1022
 
1023
  if task == 'ocr':
1024
  instances = self.parse_ocr_from_text_and_spans(
 
1040
  elif task == 'description_with_bboxes':
1041
  instances = self.parse_description_with_bboxes_from_text_and_spans(
1042
  text,
 
 
 
1043
  pattern=pattern,
1044
  image_size=image_size,
1045
  )