|
|
|
|
|
from typing import Any, Dict, List, Optional |
|
|
|
|
|
import numpy as np |
|
|
|
|
|
from .core import ResponsePreprocessor |
|
|
|
|
|
|
|
|
class GroundingMixin:
    """Mixin that samples (query, response) prompt templates for grounding tasks.

    Set ``task_type`` to a key of ``_grounding_prompts``:

    * ``'grounding'``: the query mentions ``<ref-object>`` and the response is ``<bbox>``.
    * ``'caption'``: the reverse; the query mentions ``<bbox>`` and the response is ``<ref-object>``.

    The placeholders are left in the returned templates for the caller to fill.
    """
    # Key into ``_grounding_prompts``; must be set before calling
    # ``construct_grounding_prompt``.
    task_type: Optional[str] = None

    # Sampling probabilities for the template languages ('en', 'zh'), in that order.
    # Subclasses/instances may override this to change the language mix.
    _grounding_language_mixin = [0.8, 0.2]

    _grounding_prompts = {
        'grounding': {
            'en': [('<ref-object>', '<bbox>'), ('The positions of <ref-object> is', '<bbox>'),
                   ('Find the positions of <ref-object>', '<bbox>'), ('Where is <ref-object>', '<bbox>'),
                   ('Find <ref-object>', '<bbox>'), ('Show me <ref-object>', '<bbox>'),
                   ('Detect <ref-object>', '<bbox>'), ('Locate <ref-object>', '<bbox>'),
                   ('Tell me the location of <ref-object>', '<bbox>'), ('Give the location of <ref-object>', '<bbox>'),
                   ('Provide the bounding box coordinate of <ref-object>', '<bbox>')],
            'zh': [('<ref-object>', '<bbox>'), ('<ref-object>的位置在图片中', '<bbox>'), ('<ref-object>在图片中', '<bbox>'),
                   ('<ref-object>在', '<bbox>'), ('找到<ref-object>的位置', '<bbox>'), ('<ref-object>在哪里', '<bbox>'),
                   ('提供<ref-object>的坐标位置', '<bbox>')]
        },
        'caption': {
            'en': [
                ('<bbox>', '<ref-object>'),
                ('The object at position <bbox>', '<ref-object>'),
                ('This <bbox> is', '<ref-object>'),
                ('What is the object at <bbox>', '<ref-object>'),
                ('Describe <bbox>', '<ref-object>'),
                ('<bbox> is', '<ref-object>'),
                ('The bounding box coordinate <bbox> contains', '<ref-object>'),
            ],
            'zh': [
                ('<bbox>', '<ref-object>'),
                ('<bbox>是什么', '<ref-object>'),
                ('<bbox>的位置包含', '<ref-object>'),
                ('描述<bbox>', '<ref-object>'),
                ('<bbox>中是', '<ref-object>'),
                ('坐标<bbox>描述了什么', '<ref-object>'),
                ('描述<bbox>中的事物', '<ref-object>'),
            ]
        },
    }

    def construct_grounding_prompt(self):
        """Randomly pick a (query, response) template pair for ``self.task_type``.

        The language is drawn from ``('en', 'zh')`` with probabilities
        ``self._grounding_language_mixin``. A template is then drawn uniformly
        from that language's list. Randomness comes from NumPy's global RNG.

        Returns:
            Tuple ``(query, response)`` of template strings that still contain
            the ``<ref-object>`` / ``<bbox>`` placeholders.

        Raises:
            KeyError: if ``self.task_type`` is not a key of ``_grounding_prompts``.
        """
        task_prompts = self._grounding_prompts.get(self.task_type)
        if task_prompts is None:
            raise KeyError(f'Unsupported grounding task_type {self.task_type!r}; '
                           f'expected one of {list(self._grounding_prompts)}')
        # Use the declared class-level weights rather than a duplicated literal,
        # so overriding ``_grounding_language_mixin`` actually takes effect.
        lang = np.random.choice(['en', 'zh'], p=self._grounding_language_mixin)
        prompts = task_prompts[lang]
        query, response = prompts[np.random.randint(len(prompts))]
        return query, response
|
|
|
|
|
|
|
|
class TextGenerationPreprocessor(ResponsePreprocessor):
    """Wrap each row's ``query`` in a fixed prompt template.

    Every occurrence of ``query_tag`` inside ``prompt`` is replaced with the
    row's original query text, then the row is handed to
    ``ResponsePreprocessor.preprocess``.
    """

    def __init__(self,
                 *,
                 prompt: str,
                 query_tag: str = '{{QUERY}}',
                 columns: Optional[Dict[str, str]] = None,
                 **kwargs) -> None:
        # Template attributes are stored before the base initializer runs.
        self.prompt = prompt
        self.query_tag = query_tag
        super().__init__(columns=columns, **kwargs)

    def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]:
        """Substitute the row's query into the template, then defer to the base class."""
        original_query = row['query']
        template = self.prompt
        row['query'] = template.replace(self.query_tag, original_query)
        return super().preprocess(row)
|
|
|
|
|
|
|
|
class ClsGenerationPreprocessor(ResponsePreprocessor):
    """Turn sentence-classification rows into generative (query, response) pairs.

    The row's sentence(s) are formatted into a fixed prompt that names the task
    and lists the candidate ``labels``. The response is the label name selected
    by the row's integer ``label`` index.
    """

    def __init__(self,
                 labels: List[str],
                 *,
                 task: str,
                 is_pair_seq: bool = False,
                 columns: Optional[Dict[str, str]] = None,
                 **kwargs) -> None:
        """
        Args:
            labels: Label names; a row's ``label`` value is used as an index into this list.
            task: Task description written after ``Task:`` in the prompt.
            is_pair_seq: If True, rows carry ``sentence1``/``sentence2``; otherwise a single ``sentence``.
            columns: Forwarded to ``ResponsePreprocessor.__init__``.
            **kwargs: Forwarded to ``ResponsePreprocessor.__init__``.
        """
        self.labels = labels
        self.task = task
        self.is_pair_seq = is_pair_seq

        category = ', '.join(labels)
        self.sentence2_key = 'sentence2'
        self.label_key = 'label'
        if is_pair_seq:
            self.sentence_key = 'sentence1'
            inputs = 'Sentence1: {sentence1}\nSentence2: {sentence2}'
        else:
            self.sentence_key = 'sentence'
            inputs = 'Sentence: {sentence}'

        def _escape_braces(text: str) -> str:
            # ``self.prompt`` is later filled via ``str.format``; literal braces in
            # user-supplied text must be doubled so they are not parsed as fields.
            return text.replace('{', '{{').replace('}', '}}')

        self.prompt = (f'Task: {_escape_braces(task)}\n'
                       f'{inputs}\n'
                       f'Category: {_escape_braces(category)}\n'
                       'Output:')
        super().__init__(columns=columns, **kwargs)

    def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Build ``query``/``response`` for one row, or return None to drop it.

        The label and sentence columns are popped from ``row``. Rows without a
        label, or whose label is a negative index, are dropped.
        """
        label = row.pop(self.label_key, None)
        if label is None:
            return None
        label_idx = int(label)
        if label_idx < 0:
            # A negative index would silently wrap to the last label via Python
            # negative indexing and yield a wrong response; drop the row instead.
            # NOTE(review): presumably -1 marks unlabeled rows in some datasets.
            return None

        if self.is_pair_seq:
            query = self.prompt.format(sentence1=row.pop(self.sentence_key), sentence2=row.pop(self.sentence2_key))
        else:
            query = self.prompt.format(sentence=row.pop(self.sentence_key))
        row['query'] = query
        row['response'] = self.labels[label_idx]
        return super().preprocess(row)
|
|
|