| | |
| | import copy |
| | import re |
| | import warnings |
| | from typing import Optional, Tuple, Union |
| |
|
| | import torch |
| | from torch import Tensor |
| |
|
| | from mmdet.registry import MODELS |
| | from mmdet.structures import SampleList |
| | from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig |
| | from .single_stage import SingleStageDetector |
| |
|
| |
|
def find_noun_phrases(caption: str) -> list:
    """Find noun phrases in a caption using nltk.

    Args:
        caption (str): The caption to analyze.

    Returns:
        list: List of noun phrases found in the caption.

    Examples:
        >>> caption = 'There is two cat and a remote in the picture'
        >>> find_noun_phrases(caption)  # ['cat', 'a remote', 'the picture']
    """
    import os
    try:
        import nltk
        # nltk does not expand '~' itself, which would create a literal
        # './~' directory; expand it explicitly. quiet=True avoids
        # printing download status on every call.
        download_dir = os.path.expanduser('~/nltk_data')
        nltk.download('punkt', download_dir=download_dir, quiet=True)
        nltk.download(
            'averaged_perceptron_tagger',
            download_dir=download_dir,
            quiet=True)
    except ImportError:
        raise RuntimeError('nltk is not installed, please install it by: '
                           'pip install nltk.')

    caption = caption.lower()
    tokens = nltk.word_tokenize(caption)
    pos_tags = nltk.pos_tag(tokens)

    # Chunk grammar: optional determiner, any adjectives, >=1 nouns.
    grammar = 'NP: {<DT>?<JJ.*>*<NN.*>+}'
    cp = nltk.RegexpParser(grammar)
    result = cp.parse(pos_tags)

    noun_phrases = []
    for subtree in result.subtrees():
        if subtree.label() == 'NP':
            noun_phrases.append(' '.join(t[0] for t in subtree.leaves()))

    return noun_phrases
| |
|
| |
|
def remove_punctuation(text: str) -> str:
    """Remove punctuation from a text.

    Args:
        text (str): The input text.

    Returns:
        str: The text with punctuation removed and surrounding
        whitespace stripped.
    """
    # One C-level translate pass instead of a chain of str.replace calls.
    punctuation = '|:;@()[]{}^\'"’`?$%#!&*+,.'
    return text.translate(str.maketrans('', '', punctuation)).strip()
| |
|
| |
|
def run_ner(caption: str) -> Tuple[list, list]:
    """Run NER on a caption and return the tokens and noun phrases.

    Args:
        caption (str): The input caption.

    Returns:
        Tuple[List, List]: A tuple containing the tokens and noun phrases.
            - tokens_positive (List): A list of token positions, one
              ``[[start, end]]`` character span per occurrence of each
              noun phrase in the caption.
            - noun_phrases (List): A list of noun phrases.
    """
    noun_phrases = find_noun_phrases(caption)
    noun_phrases = [remove_punctuation(phrase) for phrase in noun_phrases]
    noun_phrases = [phrase for phrase in noun_phrases if phrase != '']
    print('noun_phrases:', noun_phrases)
    relevant_phrases = noun_phrases
    labels = noun_phrases

    tokens_positive = []
    for entity, label in zip(relevant_phrases, labels):
        try:
            # Escape the entity so any residual regex metacharacters
            # (remove_punctuation does not strip every special char)
            # are matched literally instead of being interpreted.
            for m in re.finditer(re.escape(entity), caption.lower()):
                tokens_positive.append([[m.start(), m.end()]])
        except Exception:
            print('noun entities:', noun_phrases)
            print('entity:', entity)
            print('caption:', caption.lower())
    return tokens_positive, noun_phrases
| |
|
| |
|
def create_positive_map(tokenized,
                        tokens_positive: list,
                        max_num_entities: int = 256) -> Tensor:
    """construct a map such that positive_map[i,j] = True
    if box i is associated to token j

    Args:
        tokenized: The tokenized input.
        tokens_positive (list): A list of token ranges
            associated with positive boxes.
        max_num_entities (int, optional): The maximum number of entities.
            Defaults to 256.

    Returns:
        torch.Tensor: The positive map.

    Raises:
        Exception: If an error occurs during token-to-char mapping.
    """
    positive_map = torch.zeros((len(tokens_positive), max_num_entities),
                               dtype=torch.float)

    for j, tok_list in enumerate(tokens_positive):
        for (beg, end) in tok_list:
            try:
                # Map the character span [beg, end) to token indices.
                beg_pos = tokenized.char_to_token(beg)
                end_pos = tokenized.char_to_token(end - 1)
            except Exception as e:
                print('beg:', beg, 'end:', end)
                print('token_positive:', tokens_positive)
                raise e
            if beg_pos is None:
                # char_to_token returns None for characters the tokenizer
                # did not map to a token; nudge forward up to two chars.
                try:
                    beg_pos = tokenized.char_to_token(beg + 1)
                    if beg_pos is None:
                        beg_pos = tokenized.char_to_token(beg + 2)
                except Exception:
                    beg_pos = None
            if end_pos is None:
                # Same fallback, walking backwards from the span end.
                try:
                    end_pos = tokenized.char_to_token(end - 2)
                    if end_pos is None:
                        end_pos = tokenized.char_to_token(end - 3)
                except Exception:
                    end_pos = None
            if beg_pos is None or end_pos is None:
                # Skip this span rather than writing a bogus token range.
                continue

            assert beg_pos is not None and end_pos is not None
            positive_map[j, beg_pos:end_pos + 1].fill_(1)
    # Normalize each row to sum to ~1; the epsilon avoids division by
    # zero for rows where no token was marked.
    return positive_map / (positive_map.sum(-1)[:, None] + 1e-6)
| |
|
| |
|
def create_positive_map_label_to_token(positive_map: Tensor,
                                       plus: int = 0) -> dict:
    """Create a dictionary mapping the label to the token.

    Args:
        positive_map (Tensor): The positive map tensor.
        plus (int, optional): Value added to the label for indexing.
            Defaults to 0.

    Returns:
        dict: The dictionary mapping the label to the token.
    """
    return {
        label + plus: torch.nonzero(row, as_tuple=True)[0].tolist()
        for label, row in enumerate(positive_map)
    }
| |
|
| |
|
def clean_label_name(name: str) -> str:
    """Normalize a category name for use in a text prompt.

    Removes any parenthesized qualifier, turns underscores into spaces
    and collapses runs of spaces into a single space.

    Args:
        name (str): Raw category name, e.g. ``'sea_lion(animal)'``.

    Returns:
        str: The cleaned name.
    """
    name = re.sub(r'\(.*\)', '', name)
    name = re.sub(r'_', ' ', name)
    # ' +' collapses repeated spaces; the previous pattern replaced a
    # single space with a space, which was a no-op.
    name = re.sub(r' +', ' ', name)
    return name
| |
|
| |
|
def chunks(lst: list, n: int) -> list:
    """Split ``lst`` into successive chunks of at most ``n`` elements."""
    chunked = [lst[start:start + n] for start in range(0, len(lst), n)]
    # Sanity check: no element was dropped or duplicated.
    assert sum(len(chunk) for chunk in chunked) == len(lst)

    return chunked
| |
|
| |
|
| | @MODELS.register_module() |
| | class GLIP(SingleStageDetector): |
| | """Implementation of `GLIP <https://arxiv.org/abs/2112.03857>`_ |
| | Args: |
| | backbone (:obj:`ConfigDict` or dict): The backbone config. |
| | neck (:obj:`ConfigDict` or dict): The neck config. |
| | bbox_head (:obj:`ConfigDict` or dict): The bbox head config. |
| | language_model (:obj:`ConfigDict` or dict): The language model config. |
| | train_cfg (:obj:`ConfigDict` or dict, optional): The training config |
| | of GLIP. Defaults to None. |
| | test_cfg (:obj:`ConfigDict` or dict, optional): The testing config |
| | of GLIP. Defaults to None. |
| | data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of |
| | :class:`DetDataPreprocessor` to process the input data. |
| | Defaults to None. |
| | init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or |
| | list[dict], optional): Initialization config dict. |
| | Defaults to None. |
| | """ |
| |
|
    def __init__(self,
                 backbone: ConfigType,
                 neck: ConfigType,
                 bbox_head: ConfigType,
                 language_model: ConfigType,
                 train_cfg: OptConfigType = None,
                 test_cfg: OptConfigType = None,
                 data_preprocessor: OptConfigType = None,
                 init_cfg: OptMultiConfig = None) -> None:
        """Initialize GLIP; see the class docstring for argument details."""
        super().__init__(
            backbone=backbone,
            neck=neck,
            bbox_head=bbox_head,
            train_cfg=train_cfg,
            test_cfg=test_cfg,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)
        # Text encoder built from the registry on top of the standard
        # single-stage components handled by the parent class.
        self.language_model = MODELS.build(language_model)

        # Separator inserted between class names when composing the text
        # prompt, e.g. 'cat. dog'.
        self._special_tokens = '. '
| |
|
    def to_enhance_text_prompts(self, original_caption, enhanced_text_prompts):
        """Compose a caption from class names, substituting enhanced
        prompts (prefix/name/suffix) for the words that define them.

        Args:
            original_caption (list): List of class-name words.
            enhanced_text_prompts (dict): Maps a word to a dict that may
                contain 'prefix', 'name' and 'suffix' strings.

        Returns:
            tuple: ``(caption_string, tokens_positive)`` where
            ``tokens_positive`` holds one ``[[start, end]]`` character
            span per word. For enhanced words the span covers only the
            'name' (or the word itself) — prefix/suffix text is excluded.
        """
        caption_string = ''
        tokens_positive = []
        for idx, word in enumerate(original_caption):
            if word in enhanced_text_prompts:
                enhanced_text_dict = enhanced_text_prompts[word]
                if 'prefix' in enhanced_text_dict:
                    caption_string += enhanced_text_dict['prefix']
                # Span start is recorded after the prefix so only the
                # name part is associated with the label.
                start_i = len(caption_string)
                if 'name' in enhanced_text_dict:
                    caption_string += enhanced_text_dict['name']
                else:
                    caption_string += word
                end_i = len(caption_string)
                tokens_positive.append([[start_i, end_i]])

                if 'suffix' in enhanced_text_dict:
                    caption_string += enhanced_text_dict['suffix']
            else:
                # Plain word: the span covers the word exactly.
                tokens_positive.append(
                    [[len(caption_string),
                      len(caption_string) + len(word)]])
                caption_string += word

            # Join words with the separator; no trailing separator.
            if idx != len(original_caption) - 1:
                caption_string += self._special_tokens
        return caption_string, tokens_positive
| |
|
| | def to_plain_text_prompts(self, original_caption): |
| | caption_string = '' |
| | tokens_positive = [] |
| | for idx, word in enumerate(original_caption): |
| | tokens_positive.append( |
| | [[len(caption_string), |
| | len(caption_string) + len(word)]]) |
| | caption_string += word |
| | if idx != len(original_caption) - 1: |
| | caption_string += self._special_tokens |
| | return caption_string, tokens_positive |
| |
|
    def get_tokens_and_prompts(
            self,
            original_caption: Union[str, list, tuple],
            custom_entities: bool = False,
            enhanced_text_prompts: Optional[ConfigType] = None
    ) -> Tuple[dict, str, list, list]:
        """Get the tokens positive and prompts for the caption.

        Two modes:
        - entity mode (list/tuple caption, or ``custom_entities=True``):
          the caption is a set of class names; spans are computed while
          building the prompt string.
        - free-text mode: the caption is tokenized as-is and noun phrases
          found by ``run_ner`` become the entities.

        Returns:
            Tuple[dict, str, list, list]: tokenized output, the caption
            string fed to the tokenizer, per-entity character spans, and
            the entity names.
        """
        if isinstance(original_caption, (list, tuple)) or custom_entities:
            if custom_entities and isinstance(original_caption, str):
                # Split a '. '-separated string of class names into a list.
                original_caption = original_caption.strip(self._special_tokens)
                original_caption = original_caption.split(self._special_tokens)
                original_caption = list(
                    filter(lambda x: len(x) > 0, original_caption))

            original_caption = [clean_label_name(i) for i in original_caption]

            if custom_entities and enhanced_text_prompts is not None:
                caption_string, tokens_positive = self.to_enhance_text_prompts(
                    original_caption, enhanced_text_prompts)
            else:
                caption_string, tokens_positive = self.to_plain_text_prompts(
                    original_caption)

            tokenized = self.language_model.tokenizer([caption_string],
                                                      return_tensors='pt')
            entities = original_caption
        else:
            # Free-form caption: entities come from noun-phrase extraction.
            original_caption = original_caption.strip(self._special_tokens)
            tokenized = self.language_model.tokenizer([original_caption],
                                                      return_tensors='pt')
            tokens_positive, noun_phrases = run_ner(original_caption)
            entities = noun_phrases
            caption_string = original_caption

        return tokenized, caption_string, tokens_positive, entities
| |
|
| | def get_positive_map(self, tokenized, tokens_positive): |
| | positive_map = create_positive_map(tokenized, tokens_positive) |
| | positive_map_label_to_token = create_positive_map_label_to_token( |
| | positive_map, plus=1) |
| | return positive_map_label_to_token, positive_map |
| |
|
    def get_tokens_positive_and_prompts(
        self,
        original_caption: Union[str, list, tuple],
        custom_entities: bool = False,
        enhanced_text_prompt: Optional[ConfigType] = None,
        tokens_positive: Optional[list] = None,
    ) -> Tuple[dict, str, Tensor, list]:
        """Resolve the caption into (label-to-token map, prompt string,
        positive map, entities), honoring user-provided spans and chunked
        inference.

        Args:
            original_caption (str | list | tuple): Caption or class names.
            custom_entities (bool): Whether the caption is a set of class
                names rather than free text. Defaults to False.
            enhanced_text_prompt (ConfigType, optional): Per-word
                prefix/name/suffix prompt enhancements.
            tokens_positive (list, optional): Pre-computed character
                spans. ``-1`` means "use the whole caption"; ``None``
                means spans are derived from the caption.
        """
        if tokens_positive is not None:
            if tokens_positive == -1:
                # -1: treat the whole caption as the (single) entity.
                if not original_caption.endswith('.'):
                    original_caption = original_caption + self._special_tokens
                return None, original_caption, None, original_caption
            else:
                # Spans provided by the caller: tokenize and use them as-is.
                if not original_caption.endswith('.'):
                    original_caption = original_caption + self._special_tokens
                tokenized = self.language_model.tokenizer([original_caption],
                                                          return_tensors='pt')
                positive_map_label_to_token, positive_map = \
                    self.get_positive_map(tokenized, tokens_positive)

                # Entity names are sliced out of the caption; multiple
                # spans per label are joined with ' / '.
                entities = []
                for token_positive in tokens_positive:
                    instance_entities = []
                    for t in token_positive:
                        instance_entities.append(original_caption[t[0]:t[1]])
                    entities.append(' / '.join(instance_entities))
                return positive_map_label_to_token, original_caption, \
                    positive_map, entities

        chunked_size = self.test_cfg.get('chunked_size', -1)
        if not self.training and chunked_size > 0:
            # Inference with many categories: process them in chunks so a
            # single prompt does not exceed the language model's limit.
            assert isinstance(original_caption,
                              (list, tuple)) or custom_entities is True
            all_output = self.get_tokens_positive_and_prompts_chunked(
                original_caption, enhanced_text_prompt)
            positive_map_label_to_token, \
                caption_string, \
                positive_map, \
                entities = all_output
        else:
            tokenized, caption_string, tokens_positive, entities = \
                self.get_tokens_and_prompts(
                    original_caption, custom_entities, enhanced_text_prompt)
            positive_map_label_to_token, positive_map = self.get_positive_map(
                tokenized, tokens_positive)
            if tokenized.input_ids.shape[1] > self.language_model.max_tokens:
                warnings.warn('Inputting a text that is too long will result '
                              'in poor prediction performance. '
                              'Please reduce the text length.')
        return positive_map_label_to_token, caption_string, \
            positive_map, entities
| |
|
    def get_tokens_positive_and_prompts_chunked(
            self,
            original_caption: Union[list, tuple],
            enhanced_text_prompts: Optional[ConfigType] = None):
        """Split the class-name list into ``chunked_size`` pieces and build
        a prompt + positive map per chunk.

        Args:
            original_caption (list | tuple): List of class names.
            enhanced_text_prompts (ConfigType, optional): Per-word
                prefix/name/suffix prompt enhancements.

        Returns:
            tuple of four parallel lists (one entry per chunk):
            label-to-token maps, caption strings, positive maps, and the
            entity names of each chunk.
        """
        chunked_size = self.test_cfg.get('chunked_size', -1)
        original_caption = [clean_label_name(i) for i in original_caption]

        original_caption_chunked = chunks(original_caption, chunked_size)
        # 1-based global label ids, chunked the same way as the captions
        # so per-chunk labels can be mapped back to global ones.
        ids_chunked = chunks(
            list(range(1,
                       len(original_caption) + 1)), chunked_size)

        positive_map_label_to_token_chunked = []
        caption_string_chunked = []
        positive_map_chunked = []
        entities_chunked = []

        for i in range(len(ids_chunked)):
            if enhanced_text_prompts is not None:
                caption_string, tokens_positive = self.to_enhance_text_prompts(
                    original_caption_chunked[i], enhanced_text_prompts)
            else:
                caption_string, tokens_positive = self.to_plain_text_prompts(
                    original_caption_chunked[i])
            tokenized = self.language_model.tokenizer([caption_string],
                                                      return_tensors='pt')
            if tokenized.input_ids.shape[1] > self.language_model.max_tokens:
                warnings.warn('Inputting a text that is too long will result '
                              'in poor prediction performance. '
                              'Please reduce the --chunked-size.')
            positive_map_label_to_token, positive_map = self.get_positive_map(
                tokenized, tokens_positive)

            caption_string_chunked.append(caption_string)
            positive_map_label_to_token_chunked.append(
                positive_map_label_to_token)
            positive_map_chunked.append(positive_map)
            entities_chunked.append(original_caption_chunked[i])

        return positive_map_label_to_token_chunked, \
            caption_string_chunked, \
            positive_map_chunked, \
            entities_chunked
| |
|
    def loss(self, batch_inputs: Tensor,
             batch_data_samples: SampleList) -> Union[dict, list]:
        """Compute losses from a batch of inputs and data samples.

        Args:
            batch_inputs (Tensor): Input images of shape (N, C, H, W).
            batch_data_samples (SampleList): Data samples carrying `text`
                and `gt_instances` annotations.

        Returns:
            Union[dict, list]: A dictionary of loss components.
        """
        text_prompts = [
            data_samples.text for data_samples in batch_data_samples
        ]

        gt_labels = [
            data_samples.gt_instances.labels
            for data_samples in batch_data_samples
        ]

        new_text_prompts = []
        positive_maps = []
        if len(set(text_prompts)) == 1:
            # All samples share one prompt: tokenize it once and reuse the
            # token spans for every sample in the batch.
            tokenized, caption_string, tokens_positive, _ = \
                self.get_tokens_and_prompts(
                    text_prompts[0], True)
            new_text_prompts = [caption_string] * len(batch_inputs)
            for gt_label in gt_labels:
                # Select the span of each GT label (one entry per box).
                new_tokens_positive = [
                    tokens_positive[label] for label in gt_label
                ]
                _, positive_map = self.get_positive_map(
                    tokenized, new_tokens_positive)
                positive_maps.append(positive_map)
        else:
            # Heterogeneous prompts: tokenize per sample.
            for text_prompt, gt_label in zip(text_prompts, gt_labels):
                tokenized, caption_string, tokens_positive, _ = \
                    self.get_tokens_and_prompts(
                        text_prompt, True)
                new_tokens_positive = [
                    tokens_positive[label] for label in gt_label
                ]
                _, positive_map = self.get_positive_map(
                    tokenized, new_tokens_positive)
                positive_maps.append(positive_map)
                new_text_prompts.append(caption_string)

        language_dict_features = self.language_model(new_text_prompts)
        for i, data_samples in enumerate(batch_data_samples):
            # Binarize (bool) then cast back to float for the head.
            positive_map = positive_maps[i].to(
                batch_inputs.device).bool().float()
            data_samples.gt_instances.positive_maps = positive_map

        visual_features = self.extract_feat(batch_inputs)

        losses = self.bbox_head.loss(visual_features, language_dict_features,
                                     batch_data_samples)
        return losses
| |
|
    def predict(self,
                batch_inputs: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = True) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            batch_inputs (Tensor): Inputs with shape (N, C, H, W).
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool): Whether to rescale the results.
                Defaults to True.

        Returns:
            list[:obj:`DetDataSample`]: Detection results of the
            input images. Each DetDataSample usually contain
            'pred_instances'. And the ``pred_instances`` usually
            contains following keys.

            - scores (Tensor): Classification scores, has a shape
                (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
                (num_instances, ).
            - label_names (List[str]): Label names of bboxes.
            - bboxes (Tensor): Has a shape (num_instances, 4),
                the last dimension 4 arrange as (x1, y1, x2, y2).
        """
        # Collect per-sample prompts and optional overrides.
        text_prompts = []
        enhanced_text_prompts = []
        tokens_positives = []
        for data_samples in batch_data_samples:
            text_prompts.append(data_samples.text)
            if 'caption_prompt' in data_samples:
                enhanced_text_prompts.append(data_samples.caption_prompt)
            else:
                enhanced_text_prompts.append(None)
            tokens_positives.append(data_samples.get('tokens_positive', None))

        if 'custom_entities' in batch_data_samples[0]:
            # Assumed identical for the whole batch; only sample 0 is read.
            custom_entities = batch_data_samples[0].custom_entities
        else:
            custom_entities = False

        if len(set(text_prompts)) == 1:
            # All samples share the prompt: compute once and replicate.
            _positive_maps_and_prompts = [
                self.get_tokens_positive_and_prompts(
                    text_prompts[0], custom_entities, enhanced_text_prompts[0],
                    tokens_positives[0])
            ] * len(batch_inputs)
        else:
            _positive_maps_and_prompts = [
                self.get_tokens_positive_and_prompts(text_prompt,
                                                     custom_entities,
                                                     enhanced_text_prompt,
                                                     tokens_positive)
                for text_prompt, enhanced_text_prompt, tokens_positive in zip(
                    text_prompts, enhanced_text_prompts, tokens_positives)
            ]

        token_positive_maps, text_prompts, _, entities = zip(
            *_positive_maps_and_prompts)

        visual_features = self.extract_feat(batch_inputs)

        if isinstance(text_prompts[0], list):
            # Chunked inference path: the prompt was split into several
            # sub-prompts (see get_tokens_positive_and_prompts_chunked);
            # only batch size 1 is supported here.
            assert len(batch_inputs) == 1
            count = 0
            results_list = []

            # Flatten per-chunk entity lists into one global label list.
            entities = [[item for lst in entities[0] for item in lst]]

            for b in range(len(text_prompts[0])):
                text_prompts_once = [text_prompts[0][b]]
                token_positive_maps_once = token_positive_maps[0][b]
                language_dict_features = self.language_model(text_prompts_once)
                batch_data_samples[
                    0].token_positive_map = token_positive_maps_once

                # visual_features is deep-copied because the head may
                # modify it; each chunk must see the original features.
                pred_instances = self.bbox_head.predict(
                    copy.deepcopy(visual_features),
                    language_dict_features,
                    batch_data_samples,
                    rescale=rescale)[0]

                if len(pred_instances) > 0:
                    # Shift chunk-local labels to global label indices.
                    pred_instances.labels += count
                count += len(token_positive_maps_once)
                results_list.append(pred_instances)
            # Merge the per-chunk predictions into one instance set.
            results_list = [results_list[0].cat(results_list)]
        else:
            language_dict_features = self.language_model(list(text_prompts))

            for i, data_samples in enumerate(batch_data_samples):
                data_samples.token_positive_map = token_positive_maps[i]

            results_list = self.bbox_head.predict(
                visual_features,
                language_dict_features,
                batch_data_samples,
                rescale=rescale)

        # Attach human-readable label names to the predictions.
        for data_sample, pred_instances, entity in zip(batch_data_samples,
                                                       results_list, entities):
            if len(pred_instances) > 0:
                label_names = []
                for labels in pred_instances.labels:
                    if labels >= len(entity):
                        warnings.warn(
                            'The unexpected output indicates an issue with '
                            'named entity recognition. You can try '
                            'setting custom_entities=True and running '
                            'again to see if it helps.')
                        label_names.append('unobject')
                    else:
                        label_names.append(entity[labels])

                pred_instances.label_names = label_names
            data_sample.pred_instances = pred_instances
        return batch_data_samples
| |
|