|
|
|
|
|
import ast |
|
|
import os |
|
|
from collections import Counter |
|
|
from contextlib import contextmanager |
|
|
from typing import Any, Callable, Dict, List, Optional, Union |
|
|
|
|
|
import numpy as np |
|
|
from datasets import Dataset as HfDataset |
|
|
from datasets import Image |
|
|
from datasets import IterableDataset as HfIterableDataset |
|
|
from datasets import Sequence, Value |
|
|
|
|
|
from swift.llm import history_to_messages |
|
|
from swift.utils import get_logger, is_dist, is_master, safe_ddp_context |
|
|
|
|
|
DATASET_TYPE = Union[HfDataset, HfIterableDataset] |
|
|
|
|
|
logger = get_logger() |
|
|
|
|
|
|
|
|
class RowPreprocessor:
    """Base class that maps dataset rows into the standard ms-swift row schema.

    Subclasses override `preprocess` (one source row in; a standard row, a list
    of rows, or None-to-drop out). `__call__` drives the batched
    `datasets.map` pipeline around it: sample -> rename columns -> prepare ->
    map(`batched_preprocess`).
    """

    # The only columns of the standard row schema; `remove_useless_columns`
    # drops everything else.
    standard_keys = ['messages', 'rejected_response', 'label', 'images', 'videos', 'audios', 'tools', 'objects']

    def __init__(self,
                 *,
                 columns: Optional[Dict[str, str]] = None,
                 dataset_sample: Optional[int] = None,
                 random_state: Union[np.random.RandomState, int, None] = 42,
                 traceback_limit: int = 10) -> None:
        """
        Args:
            columns: Mapping of source column name -> standard column name,
                applied (case-insensitively) before preprocessing.
            dataset_sample: If set, sample this many rows before mapping.
            random_state: Seed or `RandomState` used for sampling.
            traceback_limit: Max number of per-row error tracebacks to log.
        """
        self.columns = columns or {}
        # Keep a copy of the caller-supplied renames: `_rename_columns` applies
        # them first, before the built-in multimodal aliases added below.
        self.origin_columns = self.columns.copy()
        images_keys = ['images', 'image']
        audios_keys = ['audios', 'audio']
        videos_keys = ['videos', 'video']
        for mm_type in ['images', 'audios', 'videos']:
            # Look up the alias list defined just above by its variable name
            # (e.g. 'images' -> images_keys).
            keys = locals()[f'{mm_type}_keys']
            for key in keys:
                self.columns[key] = mm_type

        self.traceback_limit = traceback_limit
        self._traceback_counter = 0
        self.dataset_sample = dataset_sample
        if not isinstance(random_state, np.random.RandomState):
            random_state = np.random.RandomState(random_state)
        self.random_state = random_state

    @staticmethod
    def _check_messages(row: Dict[str, Any]) -> None:
        """Validate `row['messages']` in place: known roles, non-None content."""
        if 'messages' not in row:
            return
        messages = row['messages']
        assert len(messages) > 0, f'messages: {messages}'

        # Strip any extra keys so each message is exactly {'role', 'content'}.
        for message in messages:
            keys = set(message.keys()) - {'role', 'content'}
            for key in keys:
                message.pop(key)

        for message in messages:
            role, content = message['role'], message['content']

            assert role in {'system', 'user', 'tool_call', 'tool_response', 'tool', 'assistant'}, f'message: {message}'
            assert content is not None, f'message: {message}'

    @staticmethod
    def _cast_images(row: Dict[str, Any]) -> None:
        """Coerce `row['images']` into a list of {'bytes', 'path'} dicts in place."""
        images = row.get('images')

        # A bare path string, or a non-empty list of path strings.
        if isinstance(images, str) or isinstance(images, list) and images and isinstance(images[0], str):
            if isinstance(images, str):
                images = [images]
            for i, image in enumerate(images):
                images[i] = {'bytes': None, 'path': image}
            row['images'] = images
        elif isinstance(images, dict):
            # A single already-encoded image dict: wrap it in a list.
            row['images'] = [images]

    @staticmethod
    def _check_rejected_response(row: Dict[str, Any]) -> None:
        """Fold `rejected_messages` into `messages` + `rejected_response`, then validate."""
        if 'rejected_messages' in row:
            chosen_messages = row['messages']
            rejected_messages = row['rejected_messages']
            messages = []
            rejected_response = None
            # Walk both conversations as (user, assistant) pairs; user turns
            # must match, and the last differing assistant turn supplies the
            # rejected response.
            for chosen_user, chosen_assistant, rejected_user, rejected_assistant in zip(
                    chosen_messages[::2], chosen_messages[1::2], rejected_messages[::2], rejected_messages[1::2]):
                assert chosen_user == rejected_user
                messages.append(chosen_user)
                messages.append(chosen_assistant)
                if chosen_assistant != rejected_assistant:
                    rejected_response = rejected_assistant['content']
            row['messages'] = messages
            row['rejected_response'] = rejected_response

        if 'rejected_response' in row:
            messages = row['messages']
            rejected_response = row['rejected_response']
            # A usable rejected response must exist and differ from the chosen
            # (final assistant) response.
            if rejected_response is None or rejected_response == messages[-1]['content']:
                raise ValueError(f'rejected_response: {rejected_response}')

    def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Map one source row to a standard row (or a list of rows, or None to drop)."""
        raise NotImplementedError

    def prepare_dataset(self, dataset: DATASET_TYPE) -> DATASET_TYPE:
        """Hook: transform the dataset after column renaming, before mapping."""
        return dataset

    @staticmethod
    def batched_to_rows(batched_row: Dict[str, Any]):
        """Split a columnar batch dict into a list of per-row dicts."""
        keys = list(batched_row.keys())
        batch_size = len(batched_row[keys[0]])
        return [{key: batched_row[key][i] for key in keys} for i in range(batch_size)]

    @staticmethod
    def rows_to_batched(rows: List[Dict[str, Any]]):
        """Inverse of `batched_to_rows`; pads keys missing from some rows with None."""
        batched = {}
        for i, row in enumerate(rows):
            for k, v in row.items():
                if k not in batched:
                    # Key first seen at row i: backfill the earlier rows.
                    batched[k] = [None] * i
                batched[k].append(v)

            # Pad keys absent from this row so all columns stay aligned.
            for k in set(batched.keys()) - set(row.keys()):
                batched[k].append(None)
        return batched

    @staticmethod
    def _remove_prefix_keys(row, prefix: str):
        # Strip `prefix` off matching keys in place; on a name clash the
        # existing (unprefixed) value wins.
        for k in list(row.keys()):
            if k.startswith(prefix):
                new_k = k[len(prefix):]
                new_v = row.pop(k)
                if new_k not in row:
                    row[new_k] = new_v

    @staticmethod
    def _check_objects(row):
        """Normalize grounding `objects` and sanitize bbox coordinate order in place."""
        objects = row.get('objects')
        if objects is None:
            return
        new_objects = {}

        # Keep only the recognized object fields.
        for k in ['ref', 'bbox', 'bbox_type', 'image_id']:
            if k in objects.keys():
                new_objects[k] = objects[k]
        row['objects'] = new_objects
        # NOTE(review): raises KeyError when 'bbox' is absent; the caller
        # (`batched_preprocess`) catches this and drops the row.
        bbox = new_objects['bbox']

        for box in bbox:
            # A box is either a point (x, y) or a rectangle (x1, y1, x2, y2).
            assert len(box) in {2, 4}, f'len(box): {len(box)}'
            if len(box) == 2:
                continue
            # Ensure (x1, y1) is the top-left corner: swap reversed coordinates.
            if box[0] > box[2]:
                box[0], box[2] = box[2], box[0]
            if box[1] > box[3]:
                box[1], box[3] = box[3], box[1]

    def batched_preprocess(self, batched_row: Dict[str, Any], *, strict: bool,
                           ignore_max_length_error: bool) -> Dict[str, Any]:
        """Apply `preprocess` row-by-row to one `datasets.map(batched=True)` batch.

        Rows that raise are dropped unless `strict`; `preprocess` may return
        None (drop), a dict (one row), or a list of dicts (expand).
        """
        from ...template import MaxLengthError
        batched_row = dict(batched_row)
        assert len(batched_row) > 0
        # '__@' prefixes were added by `_rename_columns` for iterable datasets;
        # strip them back off before preprocessing.
        self._remove_prefix_keys(batched_row, '__@')
        rows = self.batched_to_rows(batched_row)

        new_rows = []
        for row in rows:
            try:
                row = self.preprocess(row)

                # Normalize the return value to a list of rows.
                if row is None:
                    row = []
                if isinstance(row, dict):
                    row = [row]
                for r in row:
                    self._check_objects(r)
                    self._check_messages(r)
                    self._check_rejected_response(r)
                    self._cast_images(r)
            except Exception as e:
                if strict:
                    logger.warning('To avoid errors, you can pass `strict=False`.')
                    raise
                if isinstance(e, MaxLengthError) and ignore_max_length_error:
                    # Over-length rows are silently dropped in this mode.
                    pass
                elif self.traceback_limit is not None and self._traceback_counter < self.traceback_limit:
                    import traceback
                    logger.info(traceback.format_exc())
                    logger.warning('👆👆👆There are errors in the dataset, the data will be deleted')
                    self._traceback_counter += 1
                row = []
            new_rows += row
        res = self.rows_to_batched(new_rows)
        # Restore columns protected with the '__#' prefix (e.g. '__#solution').
        self._remove_prefix_keys(res, '__#')
        if len(res) == 0:
            # Keep a valid (empty) column so `map` still sees a batch dict.
            res['messages'] = []

        return res

    @staticmethod
    def get_features_dataset(dataset: DATASET_TYPE) -> DATASET_TYPE:
        """Ensure `dataset.features` is populated (iterable datasets may lack it)."""
        if dataset.features is None:
            assert isinstance(dataset, HfIterableDataset)
            dataset = dataset._resolve_features()
        return dataset

    @staticmethod
    def safe_rename_columns(dataset, columns):
        """Case-insensitively rename the entries of `columns` that exist in `dataset`.

        When several source columns map to the same target name, none of them
        is renamed (avoids a rename collision); identity renames are skipped.
        """
        dataset = RowPreprocessor.get_features_dataset(dataset)
        columns_keys = {k.lower(): k for k in dataset.features.keys()}  # lowercase -> original
        safe_columns = {columns_keys[k.lower()]: v for k, v in columns.items() if k.lower() in columns_keys}

        counter = Counter(safe_columns.values())
        for k, new_k in list(safe_columns.items()):
            if counter[new_k] > 1:
                # Multiple source columns target the same new name: drop them all.

                safe_columns.pop(k)
                continue

        # Skip no-op renames (column already has the target name).

        safe_columns = {k: v for k, v in safe_columns.items() if k != v}
        if safe_columns:
            dataset = dataset.rename_columns(safe_columns)

        return dataset

    def _rename_columns(self, dataset: DATASET_TYPE) -> DATASET_TYPE:
        # Caller-supplied renames first, then the built-in aliases.
        dataset = self.safe_rename_columns(dataset, self.origin_columns)
        dataset = self.safe_rename_columns(dataset, self.columns)
        if isinstance(dataset, HfIterableDataset):
            # Shield standard columns with a '__@' prefix; `batched_preprocess`
            # strips it again before calling `preprocess`.
            columns = {k: f'__@{k}' for k in RowPreprocessor.standard_keys if k in dataset.features}
            if columns:
                dataset = dataset.rename_columns(columns)
        return dataset

    @staticmethod
    def remove_useless_columns(dataset: DATASET_TYPE) -> DATASET_TYPE:
        """Drop every column that is not one of `standard_keys`."""
        dataset = RowPreprocessor.get_features_dataset(dataset)
        features = dataset.features
        k_list = [k for k in RowPreprocessor.standard_keys if k in features]
        if len(k_list) != len(features):
            dataset = dataset.select_columns(k_list)
        return dataset

    @staticmethod
    @contextmanager
    def _patch_arrow_writer():
        """Temporarily force a fixed Arrow schema for messages/images/objects.

        Presumably keeps the written schema consistent across shards/batches
        regardless of per-batch inference — TODO confirm.
        """

        from datasets.arrow_writer import ArrowWriter

        def _new_init(self, schema=None, features=None, *args, **kwargs):

            if features is not None:
                features['messages'] = [{'role': Value(dtype='string'), 'content': Value(dtype='string')}]
                features['images'] = [{'bytes': Value(dtype='binary'), 'path': Value(dtype='string')}]
                features['objects'] = {
                    'ref': Sequence(feature=Value(dtype='string'), length=-1),
                    'bbox': Sequence(feature=Sequence(feature=Value(dtype='float64'), length=-1), length=-1)
                }
            ArrowWriter.__origin_init__(self, schema, features, *args, **kwargs)

        ArrowWriter.__origin_init__ = ArrowWriter.__init__
        ArrowWriter.__init__ = _new_init
        try:
            yield
        finally:
            # Always restore the original __init__, even on error.
            ArrowWriter.__init__ = ArrowWriter.__origin_init__
            del ArrowWriter.__origin_init__

    def _cast_pil_image(self, dataset):
        # Disable eager PIL decoding so the 'images' column stays as raw
        # {'bytes', 'path'} payloads during preprocessing.
        features = dataset.features
        if 'images' in features and isinstance(features['images'], Image) and features['images'].decode:
            dataset = dataset.cast_column('images', Image(decode=False))
        return dataset

    def __call__(
        self,
        dataset: DATASET_TYPE,
        *,
        num_proc: int = 1,
        load_from_cache_file: bool = True,
        strict: bool = False,
        batch_size: Optional[int] = None,
    ) -> DATASET_TYPE:
        """Run the full pipeline and return the preprocessed dataset.

        Args:
            num_proc: worker processes for `datasets.map` (map-style datasets only).
            load_from_cache_file: reuse the datasets cache; forced True on
                non-master ranks under distributed training.
            strict: re-raise per-row errors instead of dropping bad rows.
            batch_size: map batch size; defaults to 1000 (map-style) or 16 (iterable).
        """
        from ..utils import sample_dataset
        if batch_size is None:
            # Iterable datasets stream; use a much smaller batch there.
            batch_size = 1000 if isinstance(dataset, HfDataset) else 16
        if self.dataset_sample is not None:
            dataset = sample_dataset(dataset, self.dataset_sample, True, self.random_state)

        map_kwargs = {'batched': True, 'batch_size': batch_size}
        if isinstance(dataset, HfDataset):
            # Non-master ranks read the master's cache rather than re-mapping.
            if not load_from_cache_file and is_dist() and not is_master():
                load_from_cache_file = True
            map_kwargs.update({
                'num_proc': num_proc,
                'load_from_cache_file': load_from_cache_file,
            })

        dataset = RowPreprocessor.get_features_dataset(dataset)
        if 'solution' in dataset.features:
            # Duplicate 'solution' under a '__#' prefix so it survives the
            # preprocessing map (restored in `batched_preprocess`).
            with safe_ddp_context(None, True):
                dataset = dataset.map(lambda x: {'__#solution': x['solution']}, **map_kwargs)
        dataset = self._rename_columns(dataset)
        dataset = self.prepare_dataset(dataset)
        dataset = self._cast_pil_image(dataset)

        # With multiple workers, over-length rows are dropped rather than raised.
        ignore_max_length_error = True if isinstance(dataset, HfDataset) and num_proc > 1 else False
        with self._patch_arrow_writer(), safe_ddp_context(None, True):
            try:
                dataset_mapped = dataset.map(
                    self.batched_preprocess,
                    fn_kwargs={
                        'strict': strict,
                        'ignore_max_length_error': ignore_max_length_error
                    },
                    remove_columns=list(dataset.features.keys()),
                    **map_kwargs)
            except NotImplementedError:
                # NOTE(review): if this fires, `dataset_mapped` is never bound
                # and the lines below raise NameError — confirm this path is
                # intentionally unreachable.
                pass
        if isinstance(dataset_mapped, HfDataset) and len(dataset) != len(dataset_mapped):
            logger.info(
                f'Dataset filtered, origin length: {len(dataset)}, filtered dataset length: {len(dataset_mapped)}')

        return dataset_mapped
|
|
|
|
|
|
|
|
class ResponsePreprocessor(RowPreprocessor):
    """Dataset compatible with older versions of ms-swift"""

    def __init__(self, *, columns: Optional[Dict[str, str]] = None, **kwargs) -> None:
        super().__init__(columns=columns, **kwargs)
        # Alias tables: any of these source column names is renamed to the
        # corresponding standard column before preprocessing.
        alias_groups = {
            'system': ['system', 'system_prompt'],
            'query': ['query', 'prompt', 'input', 'instruction', 'question', 'problem'],
            'response': [
                'response', 'answer', 'output', 'targets', 'target', 'answer_key', 'answers', 'solution', 'text',
                'completion', 'content'
            ],
        }
        for std_name, aliases in alias_groups.items():
            for alias in aliases:
                self.columns[alias] = std_name

    def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Build standard `messages` from query/response/history/system columns."""
        answer = row.pop('response', None)
        if answer is not None:
            if isinstance(answer, (list, tuple)):
                from transformers.utils import strtobool

                # Several candidate responses: pick one at random (default) or
                # always take the first, controlled by RANDOM_DATASET_RESPONSE.
                pick_random = strtobool(os.environ.get('RANDOM_DATASET_RESPONSE', 'True'))
                answer = self.random_state.choice(answer) if pick_random else answer[0]
        conversation = row.pop('history', None) or []
        query = row.pop('query', None)
        system = row.pop('system', None)
        if isinstance(conversation, str):
            # History serialized as a Python literal string.
            conversation = ast.literal_eval(conversation)
        conversation.append([query, answer])

        row.update({'messages': history_to_messages(conversation, system)})
        return row
|
|
|
|
|
|
|
|
class AlpacaPreprocessor(ResponsePreprocessor):
    """Alpaca-format rows: merge `instruction`/`input` into a query, map `output` to response."""

    @classmethod
    def concat_inst_input(cls, instruction, input_):
        """Join instruction and input with a newline; fall back to whichever is non-empty."""
        query = f'{instruction}\n{input_}' if instruction and input_ else (instruction or input_)
        assert isinstance(query, str), f'query: {query}'
        return query

    def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        instruction = row.pop('instruction', None)
        raw_input = row.pop('input', None)
        answer = row.pop('output', None)
        if answer is not None:
            row['response'] = answer
        row['query'] = self.concat_inst_input(instruction, raw_input)
        return super().preprocess(row)
|
|
|
|
|
|
|
|
def default_repair_messages(s: Union[str, Any]) -> Any:
    """Parse a stringified Python literal back into an object; pass anything else through."""
    return ast.literal_eval(s) if isinstance(s, str) else s
|
|
|
|
|
|
|
|
class MessagesPreprocessor(RowPreprocessor):
    """Normalizes conversation-style rows (messages / ShareGPT) into standard `messages`."""

    def __init__(
            self,
            *,
            # Keys used inside each message dict by the source dataset.
            role_key: Optional[str] = None,
            content_key: Optional[str] = None,
            # Role values used by the source dataset.
            user_role: Optional[str] = None,
            assistant_role: Optional[str] = None,
            system_role: str = 'system',
            # Column renames, forwarded to RowPreprocessor.
            columns: Optional[Dict[str, str]] = None,
            repair_messages: Callable[[Union[str, List[Dict[str, str]]]],
                                      Optional[List[Dict[str, str]]]] = default_repair_messages,
            inner_key: Optional[str] = None,
            **kwargs):
        """
        Args:
            role_key / content_key: custom message-dict keys (defaults try
                'role'/'from' and 'content'/'value').
            user_role / assistant_role: custom role values (defaults try
                'user'/'human' and 'assistant'/'gpt'/'bot').
            system_role: role value that marks a leading system message.
            repair_messages: hook that fixes up / parses the raw messages value.
            inner_key: if set, the message list lives under this key inside the
                'messages' column value.
        """
        super().__init__(columns=columns, **kwargs)
        self.role_keys = ['role', 'from'] if role_key is None else [role_key]
        self.content_keys = ['content', 'value'] if content_key is None else [content_key]
        self.user_roles = ['user', 'human'] if user_role is None else [user_role]
        self.assistant_roles = ['assistant', 'gpt', 'bot'] if assistant_role is None else [assistant_role]
        self.tool_call_roles = ['function_call']
        self.tool_response_roles = ['function_response', 'observation', 'observations']

        self.system_role = system_role
        self.repair_messages = repair_messages
        self.inner_key = inner_key

        # Conversation columns are all normalized to 'messages'.
        message_keys = ['messages', 'conversation', 'conversations']
        for key in message_keys:
            self.columns[key] = 'messages'

        # System-prompt columns (including a custom system_role name) -> 'system'.
        system_keys = ['system', 'system_prompt']
        if system_role not in system_keys:
            system_keys.append(system_role)
        for key in system_keys:
            self.columns[key] = 'system'

    @staticmethod
    def _is_sharegpt_format(message: Dict[str, str]) -> bool:
        # ShareGPT packs a whole turn into one dict keyed by role name
        # (e.g. {'user': ..., 'assistant': ...}) rather than {'role', 'content'}.
        if 'role' in message or 'content' in message:
            return False
        return True

    def sharegpt_to_messages(self, messages: List[Dict[str, str]], system: Optional[str]) -> List[Dict[str, str]]:
        """Expand ShareGPT turn-dicts into alternating user/assistant messages."""
        self._to_std_key(messages, 'user', self.user_roles)
        self._to_std_key(messages, 'assistant', self.assistant_roles)
        new_messages = []
        if system is not None:
            new_messages.append({'role': 'system', 'content': system})
        for message in messages:
            user_message = {'role': 'user', 'content': message['user']}
            assistant_message = {'role': 'assistant', 'content': message['assistant']}
            new_messages.append(user_message)
            new_messages.append(assistant_message)
        return new_messages

    def to_std_messages(self, messages: List[Dict[str, str]], system: Optional[str]) -> None:
        """In place, map source role names onto the standard role vocabulary."""
        # A leading system_role message becomes 'system'; otherwise inject the
        # separate `system` column (if any) as the first message.
        if messages[0]['role'] == self.system_role:
            messages[0]['role'] = 'system'
        elif system is not None:
            messages.insert(0, {'role': 'system', 'content': system})
        for message in messages:
            role = message['role']
            if role in self.user_roles:
                message['role'] = 'user'
            elif role in self.assistant_roles:
                message['role'] = 'assistant'
            elif role.replace('-', '_') in self.tool_call_roles:
                # Also accepts dashed variants, e.g. 'function-call'.
                message['role'] = 'tool_call'
            elif role.replace('-', '_') in self.tool_response_roles:
                message['role'] = 'tool_response'

    @staticmethod
    def _to_std_key(messages: List[Dict[str, str]], std_key: str, optional_keys: List[str]) -> None:
        # Rename any of `optional_keys` found in each message to `std_key` (in place).
        for message in messages:
            for key in optional_keys:
                if key in message:
                    message[std_key] = message.pop(key)

    def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Normalize `row['messages']`; returns None to drop unparseable rows."""
        if 'rejected_messages' in row:
            # Recurse so the rejected conversation is normalized the same way.
            row['rejected_messages'] = MessagesPreprocessor.preprocess(
                self, {'messages': row['rejected_messages']})['messages']
        messages = row['messages']
        if self.inner_key is not None:
            messages = messages[self.inner_key]
        messages: Optional[List[Dict[str, str]]] = self.repair_messages(messages)
        # Empty or still-a-string messages cannot be normalized: drop the row.
        if not messages or isinstance(messages, str):
            return
        self._to_std_key(messages, 'role', self.role_keys)
        self._to_std_key(messages, 'content', self.content_keys)
        system = row.pop('system', None)
        if self._is_sharegpt_format(messages[0]):
            messages = self.sharegpt_to_messages(messages, system)
        else:
            self.to_std_messages(messages, system)
        row['messages'] = messages
        return row
|
|
|
|
|
|
|
|
class ClsPreprocessor(ResponsePreprocessor):
    """Classification rows: standard response preprocessing plus an integer `label`."""

    def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        processed = super().preprocess(row)
        processed['label'] = int(processed['label'])
        return processed
|
|
|
|
|
|
|
|
class AutoPreprocessor:
    """Picks the right RowPreprocessor subclass based on the dataset's columns."""

    def __init__(self, *, columns: Optional[Dict[str, str]] = None, **kwargs) -> None:
        self.columns = columns or {}
        # Extra kwargs are forwarded to whichever preprocessor gets chosen.
        self.kwargs = kwargs

    def _get_preprocessor(self, dataset: DATASET_TYPE) -> RowPreprocessor:
        """Dispatch: conversation columns -> Messages; instruction+input -> Alpaca; else Response."""
        features = dataset.features
        if any(name in features for name in ('conversation', 'conversations', 'messages')):
            return MessagesPreprocessor(**self.kwargs)
        if 'instruction' in features and 'input' in features:
            return AlpacaPreprocessor(**self.kwargs)
        return ResponsePreprocessor(**self.kwargs)

    def __call__(
        self,
        dataset: DATASET_TYPE,
        *,
        num_proc: int = 1,
        load_from_cache_file: bool = True,
        strict: bool = False,
    ) -> DATASET_TYPE:
        renamed = RowPreprocessor.safe_rename_columns(dataset, self.columns)
        preprocessor = self._get_preprocessor(renamed)
        return preprocessor(renamed, num_proc=num_proc, load_from_cache_file=load_from_cache_file, strict=strict)
|
|
|