Student0809's picture
Add files using upload-large-folder tool
636c5b4 verified
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict, List, Optional
import numpy as np
from .core import ResponsePreprocessor
class GroundingMixin:
"""This class offers prompts to the grounding task"""
task_type: Optional[str] = None
_grounding_language_mixin = [0.8, 0.2]
_grounding_prompts = {
'grounding': {
'en': [('<ref-object>', '<bbox>'), ('The positions of <ref-object> is', '<bbox>'),
('Find the positions of <ref-object>', '<bbox>'), ('Where is <ref-object>', '<bbox>'),
('Find <ref-object>', '<bbox>'), ('Show me <ref-object>', '<bbox>'),
('Detect <ref-object>', '<bbox>'), ('Locate <ref-object>', '<bbox>'),
('Tell me the location of <ref-object>', '<bbox>'), ('Give the location of <ref-object>', '<bbox>'),
('Provide the bounding box coordinate of <ref-object>', '<bbox>')],
'zh': [('<ref-object>', '<bbox>'), ('<ref-object>的位置在图片中', '<bbox>'), ('<ref-object>在图片中', '<bbox>'),
('<ref-object>在', '<bbox>'), ('找到<ref-object>的位置', '<bbox>'), ('<ref-object>在哪里', '<bbox>'),
('提供<ref-object>的坐标位置', '<bbox>')]
},
'caption': {
'en': [
('<bbox>', '<ref-object>'),
('The object at position <bbox>', '<ref-object>'),
('This <bbox> is', '<ref-object>'),
('What is the object at <bbox>', '<ref-object>'),
('Describe <bbox>', '<ref-object>'),
('<bbox> is', '<ref-object>'),
('The bounding box coordinate <bbox> contains', '<ref-object>'),
],
'zh': [
('<bbox>', '<ref-object>'),
('<bbox>是什么', '<ref-object>'),
('<bbox>的位置包含', '<ref-object>'),
('描述<bbox>', '<ref-object>'),
('<bbox>中是', '<ref-object>'),
('坐标<bbox>描述了什么', '<ref-object>'),
('描述<bbox>中的事物', '<ref-object>'),
]
},
}
def construct_grounding_prompt(self):
# TODO Only support one bbox to one object
lang = np.random.choice(['en', 'zh'], p=[0.8, 0.2])
prompts = GroundingMixin._grounding_prompts[self.task_type][lang]
query, response = prompts[np.random.choice(range(len(prompts)))]
return query, response
class TextGenerationPreprocessor(ResponsePreprocessor):
def __init__(self,
*,
prompt: str,
query_tag: str = '{{QUERY}}',
columns: Optional[Dict[str, str]] = None,
**kwargs) -> None:
self.query_tag = query_tag
self.prompt = prompt
super().__init__(columns=columns, **kwargs)
def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]:
row['query'] = self.prompt.replace(self.query_tag, row['query'])
return super().preprocess(row)
class ClsGenerationPreprocessor(ResponsePreprocessor):
def __init__(self,
labels: List[str],
*,
task: str,
is_pair_seq: bool = False,
columns: Optional[Dict[str, str]] = None,
**kwargs) -> None:
self.labels = labels
self.task = task
self.is_pair_seq = is_pair_seq
category = ', '.join(labels)
self.sentence2_key = 'sentence2'
self.label_key = 'label'
if is_pair_seq:
self.sentence_key = 'sentence1'
inputs = 'Sentence1: {sentence1}\nSentence2: {sentence2}'
else:
self.sentence_key = 'sentence'
inputs = 'Sentence: {sentence}'
self.prompt = f"""Task: {task}
{inputs}
Category: {category}
Output:"""
super().__init__(columns=columns, **kwargs)
def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
label = row.pop(self.label_key, None)
if label is None:
return
if self.is_pair_seq:
query = self.prompt.format(sentence1=row.pop(self.sentence_key), sentence2=row.pop(self.sentence2_key))
else:
query = self.prompt.format(sentence=row.pop(self.sentence_key))
row['query'] = query
row['response'] = self.labels[int(label)]
return super().preprocess(row)