# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict, List, Optional

import numpy as np

from .core import ResponsePreprocessor


def _escape_braces(text: str) -> str:
    """Double every ``{`` and ``}`` so *text* passes through ``str.format`` verbatim."""
    return text.replace('{', '{{').replace('}', '}}')


class GroundingMixin:
    """Mixin offering randomly sampled prompt templates for the grounding task.

    ``task_type`` selects the template table to draw from; the visible tables
    are ``'grounding'`` and ``'caption'``, each with ``'en'`` and ``'zh'``
    variants stored as ``(query, response)`` tuples.
    """
    task_type: Optional[str] = None

    # Sampling probabilities for the languages ['en', 'zh'], in that order.
    _grounding_language_mixin = [0.8, 0.2]
    # NOTE(review): several templates read as if a placeholder token belongs
    # inside them (e.g. 'The positions of is') and every response is empty --
    # confirm the table against its upstream definition.
    _grounding_prompts = {
        'grounding': {
            'en': [
                ('', ''),
                ('The positions of is', ''),
                ('Find the positions of ', ''),
                ('Where is ', ''),
                ('Find ', ''),
                ('Show me ', ''),
                ('Detect ', ''),
                ('Locate ', ''),
                ('Tell me the location of ', ''),
                ('Give the location of ', ''),
                ('Provide the bounding box coordinate of ', ''),
            ],
            'zh': [
                ('', ''),
                ('的位置在图片中', ''),
                ('在图片中', ''),
                ('在', ''),
                ('找到的位置', ''),
                ('在哪里', ''),
                ('提供的坐标位置', ''),
            ],
        },
        'caption': {
            'en': [
                ('', ''),
                ('The object at position ', ''),
                ('This is', ''),
                ('What is the object at ', ''),
                ('Describe ', ''),
                (' is', ''),
                ('The bounding box coordinate contains', ''),
            ],
            'zh': [
                ('', ''),
                ('是什么', ''),
                ('的位置包含', ''),
                ('描述', ''),
                ('中是', ''),
                ('坐标描述了什么', ''),
                ('描述中的事物', ''),
            ],
        },
    }

    def construct_grounding_prompt(self):
        """Sample one ``(query, response)`` template pair for ``self.task_type``.

        The language is drawn first, using ``_grounding_language_mixin`` as the
        probabilities of ``['en', 'zh']``; a template is then drawn uniformly
        from that language's list.

        Returns:
            A ``(query, response)`` tuple of template strings.

        Raises:
            KeyError: If ``self.task_type`` is not a key of ``_grounding_prompts``
                (this includes the default ``None``).
        """
        # TODO Only support one bbox to one object
        # Use the class-level mix rather than repeating the literal here, so
        # the declared probabilities are the ones actually applied.
        lang = np.random.choice(['en', 'zh'], p=self._grounding_language_mixin)
        prompts = GroundingMixin._grounding_prompts[self.task_type][lang]
        # choice(n) samples an index as if from arange(n).
        query, response = prompts[np.random.choice(len(prompts))]
        return query, response


class TextGenerationPreprocessor(ResponsePreprocessor):
    """Preprocessor that embeds each row's query into a fixed prompt template."""

    def __init__(self,
                 *,
                 prompt: str,
                 query_tag: str = '{{QUERY}}',
                 columns: Optional[Dict[str, str]] = None,
                 **kwargs) -> None:
        """Store the template and forward the remaining options to the parent.

        Args:
            prompt: Template text containing ``query_tag`` where the query goes.
            query_tag: Marker inside ``prompt`` that is replaced by the query.
            columns: Passed through to the parent constructor.
            **kwargs: Passed through to the parent constructor.
        """
        self.query_tag = query_tag
        self.prompt = prompt
        super().__init__(columns=columns, **kwargs)

    def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]:
        """Replace ``row['query']`` with the filled template, then delegate to the parent.

        Every occurrence of ``query_tag`` in the template is substituted. The
        row is modified in place; a missing ``'query'`` key raises ``KeyError``.
        """
        row['query'] = self.prompt.replace(self.query_tag, row['query'])
        return super().preprocess(row)


class ClsGenerationPreprocessor(ResponsePreprocessor):
    """Preprocessor that recasts a classification row as a generation example.

    The query is a prompt listing the task, the input sentence(s) and the
    candidate categories; the response is the label name looked up in ``labels``.
    """

    def __init__(self,
                 labels: List[str],
                 *,
                 task: str,
                 is_pair_seq: bool = False,
                 columns: Optional[Dict[str, str]] = None,
                 **kwargs) -> None:
        """Build the prompt template and forward the remaining options to the parent.

        Args:
            labels: Category names; a row's label is used as an index into this list.
            task: Task description placed at the start of the prompt.
            is_pair_seq: When true, rows carry ``sentence1``/``sentence2`` instead
                of a single ``sentence``.
            columns: Passed through to the parent constructor.
            **kwargs: Passed through to the parent constructor.
        """
        self.labels = labels
        self.task = task
        self.is_pair_seq = is_pair_seq

        self.sentence2_key = 'sentence2'
        self.label_key = 'label'
        if is_pair_seq:
            self.sentence_key = 'sentence1'
            inputs = 'Sentence1: {sentence1}\nSentence2: {sentence2}'
        else:
            self.sentence_key = 'sentence'
            inputs = 'Sentence: {sentence}'

        # ``self.prompt`` is later run through ``str.format``, so only the
        # sentence placeholders may remain as live fields. Literal braces in the
        # task text or label names are escaped; otherwise they would be parsed
        # as replacement fields and make ``preprocess`` fail.
        task_text = _escape_braces(str(task))
        category = _escape_braces(', '.join(labels))
        # NOTE(review): the sections are separated by single spaces here; if the
        # template was meant to span several lines, confirm the separators.
        self.prompt = f'Task: {task_text} {inputs} Category: {category} Output:'
        super().__init__(columns=columns, **kwargs)

    def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Turn a labelled row into a query/response pair and delegate to the parent.

        The label and sentence keys are popped from ``row`` (it is modified in
        place) and ``'query'``/``'response'`` are written back.

        Returns:
            ``None`` when the row has no label (missing or ``None``); otherwise
            whatever the parent ``preprocess`` returns.

        Raises:
            KeyError: If a required sentence key is absent from ``row``.
            IndexError: If the label is out of range for ``labels``.
        """
        label = row.pop(self.label_key, None)
        if label is None:
            return None

        if self.is_pair_seq:
            query = self.prompt.format(
                sentence1=row.pop(self.sentence_key), sentence2=row.pop(self.sentence2_key))
        else:
            query = self.prompt.format(sentence=row.pop(self.sentence_key))
        row['query'] = query
        # NOTE(review): a negative label indexes from the end of ``labels``;
        # confirm whether sentinel labels such as -1 can reach this point.
        row['response'] = self.labels[int(label)]
        return super().preprocess(row)