| | |
| | import hashlib |
| | import inspect |
| | import math |
| | import os |
| | import re |
| | from contextlib import contextmanager, nullcontext |
| | from copy import deepcopy |
| | from dataclasses import asdict |
| | from functools import partial, wraps |
| | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from modelscope.hub.utils.utils import get_cache_dir |
| | from peft import PeftModel |
| | from PIL import Image |
| | from torch.nn.utils.rnn import pad_sequence |
| | from transformers import StoppingCriteriaList |
| | from transformers.integrations import is_deepspeed_zero3_enabled |
| | from transformers.utils import strtobool |
| |
|
| | from swift.utils import get_dist_setting, get_env_args, get_logger, use_torchacc |
| | from ..utils import Processor, ProcessorMixin |
| | from .template_inputs import InferRequest, StdTemplateInputs, TemplateInputs |
| | from .utils import Context, ContextType, StopWordsCriteria, fetch_one, findall, split_str_parts_by |
| | from .vision_utils import load_audio, load_batch, load_image, rescale_image |
| |
|
# Module-level logger shared by every template in this file.
logger = get_logger()
if TYPE_CHECKING:
    # Type-checking-only import; avoids a circular import at runtime.
    from .template_meta import TemplateMeta
| |
|
| |
|
class MaxLengthError(ValueError):
    """Raised when an encoded sequence exceeds ``max_length`` and the truncation strategy is 'raise'."""
| |
|
| |
|
class Template(ProcessorMixin):
    # Standard tags recognized inside message content; split out before tokenization.
    special_tokens = ['<image>', '<video>', '<audio>', '<bbox>', '<ref-object>', '<cot-process>', '<start-image>']
    # StdTemplateInputs keys that may carry multimodal payloads.
    special_keys = ['images', 'videos', 'audios', 'objects']

    # Default per-modality replacements for the standard tags; subclasses override
    # these with model-specific placeholder tokens.
    image_placeholder = ['<image>']
    video_placeholder = ['<video>']
    audio_placeholder = ['<audio>']
    cot_process_placeholder = ['ки']
    # Placeholder tokens (str or token id); strings are converted to ids in __init__.
    placeholder_tokens = []
    # Whether media entries are loaded into PIL images (vs kept as paths/URLs).
    load_images = True
    # Whether prompt tokens are skipped (labels=-100) during training.
    skip_prompt = True
    use_model = False
    # Default bbox normalization scheme; may be overridden per instance in __init__.
    norm_bbox = 'norm1000'

    is_encoder_decoder = False
| |
|
    def __init__(
        self,
        processor: Processor,
        template_meta: 'TemplateMeta',
        default_system: Optional[str] = None,
        max_length: Optional[int] = None,
        *,
        use_chat_template: bool = True,
        truncation_strategy: Literal['raise', 'left', 'right'] = 'raise',
        max_pixels: Optional[int] = None,
        agent_template: Optional[str] = None,
        norm_bbox: Literal['norm1000', 'none', None] = None,
        response_prefix: Optional[str] = None,
        padding_side: Literal['left', 'right'] = 'right',
        loss_scale: str = 'default',
        sequence_parallel_size: int = 1,
        template_backend: Literal['swift', 'jinja'] = 'swift',
    ) -> None:
        """
        default_system: Override the default_system in the template.
        max_length: Max length of the sequence
        truncation_strategy: The truncation strategy
        max_pixels: Rescale image to reduce memory usage, default `None` means no limitation.
            e.g. 512 * 512 (H*W)
        agent_template: Name of the agent template (looked up in ``swift.plugin.agent_templates``);
            defaults to the one declared by ``template_meta``.
        norm_bbox: Bbox normalization scheme; ``None`` keeps the class default.
        response_prefix: Override the response prefix of the template.
        padding_side: The padding_side when the training batch_size >= 2
        loss_scale: The loss scale function to use
        """
        from .template_meta import TemplateMeta
        from swift.plugin import agent_templates

        self.processor = processor
        self.model_info = processor.model_info
        self.config = self.model_info.config
        self.model_meta = processor.model_meta
        if max_length is None:
            # Fall back to the model's maximum context length.
            max_length = self.model_info.max_model_len
        tokenizer = self.tokenizer

        if not use_chat_template:
            template_meta = template_meta.to_generate_template_meta()
        else:
            # Copy so the per-instance overrides below do not mutate the shared meta.
            template_meta = deepcopy(template_meta)
        template_meta.check_system(default_system)
        if default_system is not None:
            template_meta.default_system = default_system
        if response_prefix is not None:
            template_meta.response_prefix = response_prefix
        logger.info(f'default_system: {repr(template_meta.default_system)}')
        logger.info(f'response_prefix: {repr(template_meta.response_prefix)}')

        # Convert string placeholder tokens to token ids up front.
        # NOTE(review): placeholder_tokens is a class attribute, so this writes into the
        # list shared by all instances of the subclass — confirm this is intended.
        for i, token in enumerate(self.placeholder_tokens):
            if isinstance(token, str):
                self.placeholder_tokens[i] = tokenizer.convert_tokens_to_ids(token)
        template_meta.init(tokenizer)

        self.template_meta: TemplateMeta = template_meta
        self.use_chat_template = use_chat_template
        self.template_backend = template_backend
        self.max_length = max_length
        self.truncation_strategy = truncation_strategy
        self.loss_scale = loss_scale
        self.max_pixels = max_pixels
        self.padding_side = padding_side
        self.sequence_parallel_size = sequence_parallel_size
        agent_template = agent_template or template_meta.agent_template
        logger.info(f'agent_template: {agent_template}')
        self.agent_template = agent_templates[agent_template]()
        self.norm_bbox = norm_bbox or self.norm_bbox
        logger.info(f'max_length: {self.max_length}')
        logger.info(f'norm_bbox: {self.norm_bbox}')
        if self.is_encoder_decoder:
            # Encoder-decoder models keep the prompt in the encoder input, so it is not skipped.
            self.skip_prompt = False
        # Current operating mode; switched by the infer/train engines.
        self.mode: Literal['pt', 'vllm', 'lmdeploy',
                           'train', 'rlhf', 'kto',
                           'seq_cls', 'embedding', 'prm'] = 'pt'
        self._packing = False
        self.use_megatron = False
        if self.model_info.task_type != 'causal_lm':
            self.mode = self.model_info.task_type
        self._handles = []
        self._deepspeed_initialize = None
| |
|
| | @staticmethod |
| | def _load_image(image, load_images: bool): |
| | if load_images: |
| | if isinstance(image, dict) and 'bytes' in image: |
| | image = image['bytes'] or image['path'] |
| | image = load_image(image) |
| | else: |
| | if isinstance(image, dict): |
| | path = image['path'] |
| | if path and (path.startswith('http') or os.path.exists(path)): |
| | image = path |
| | else: |
| | image = load_image(image['bytes']) |
| | elif not isinstance(image, str): |
| | image = load_image(image) |
| | return image |
| |
|
| | @staticmethod |
| | def _get_height_width(inputs: StdTemplateInputs) -> None: |
| | width = [] |
| | height = [] |
| | for image in inputs.images: |
| | width.append(image.width) |
| | height.append(image.height) |
| | inputs.objects['width'] = width |
| | inputs.objects['height'] = height |
| |
|
| | def normalize_bbox(self, inputs: StdTemplateInputs) -> None: |
| | objects = inputs.objects |
| | bbox_list = objects['bbox'] |
| | width_list = objects['width'] |
| | height_list = objects['height'] |
| | bbox_type = objects.pop('bbox_type', None) or 'real' |
| | image_id_list = objects.pop('image_id', None) or [] |
| | image_id_list += [0] * (len(bbox_list) - len(image_id_list)) |
| | for bbox, image_id in zip(bbox_list, image_id_list): |
| | if bbox_type == 'norm1': |
| | width, height = 1, 1 |
| | else: |
| | width, height = width_list[image_id], height_list[image_id] |
| | for i, (x, y) in enumerate(zip(bbox[::2], bbox[1::2])): |
| | if self.norm_bbox == 'norm1000': |
| | norm_width, norm_height = 1000, 1000 |
| | elif self.norm_bbox == 'none': |
| | image = inputs.images[image_id] |
| | norm_width, norm_height = image.width, image.height |
| | bbox[2 * i] = int(round(x / width * norm_width)) |
| | bbox[2 * i + 1] = int(round(y / height * norm_height)) |
| |
|
| | def _preprocess_function_call(self, inputs: StdTemplateInputs) -> None: |
| | agent_template = self.agent_template |
| | agent_template.template_meta = self.template_meta |
| | if inputs.tools: |
| | if isinstance(inputs.tools, str): |
| | inputs.tools = agent_template._parse_json(inputs.tools) |
| | if not isinstance(inputs.tools, (list, tuple)): |
| | inputs.tools = [inputs.tools] |
| | elif isinstance(inputs.tools, (list, tuple)): |
| | inputs.tools = [agent_template._parse_json(tool) for tool in inputs.tools] |
| | else: |
| | raise ValueError(f'inputs.tools: {inputs.tools}') |
| | for i, tool in enumerate(inputs.tools): |
| | inputs.tools[i] = agent_template.wrap_tool(tool) |
| | i = 0 |
| | messages = inputs.messages |
| | while i < len(messages): |
| | if messages[i]['role'] == 'tool_call': |
| | i_start = i |
| | while i + 1 < len(messages) and messages[i + 1]['role'] == 'tool_call': |
| | i += 1 |
| | tool_content = self.agent_template._format_tool_calls(messages[i_start:i + 1]) |
| | messages[i_start:i + 1] = [{'role': 'assistant', 'content': tool_content}] |
| | i = i_start + 1 |
| | else: |
| | i += 1 |
| |
|
    def _preprocess_inputs(
        self,
        inputs: StdTemplateInputs,
    ) -> None:
        # Normalize tools / tool_call messages first.
        self._preprocess_function_call(inputs)
        if self.model_meta.is_multimodal:
            self._replace_image_tags(inputs)
            self._replace_start_image_tags(inputs)
        images = inputs.images
        load_images = self.load_images or self.mode in {'vllm', 'lmdeploy'}
        load_images_origin = load_images
        if self.max_pixels is not None or inputs.objects:
            # Rescaling and bbox normalization require real PIL images.
            load_images = True
        if images:
            for i, image in enumerate(images):
                images[i] = self._load_image(images[i], load_images)
            if inputs.objects:
                self._get_height_width(inputs)
            if self.max_pixels is not None:
                # Scale images down proportionally so H*W <= max_pixels.
                images = [rescale_image(img, self.max_pixels) for img in images]
            if images and not load_images_origin:
                # Caller wanted paths but we had to load: persist loaded images back to files.
                for i, image in enumerate(images):
                    if isinstance(image, Image.Image):
                        images[i] = self._save_pil_image(image)
            inputs.images = images

        if self.mode == 'vllm' and inputs.audios:
            sampling_rate = get_env_args('sampling_rate', int, None)
            inputs.audios = load_batch(
                inputs.audios, load_func=partial(load_audio, sampling_rate=sampling_rate, return_sr=True))

        if inputs.is_multimodal:
            self._add_default_tags(inputs)
| |
|
| | @staticmethod |
| | def _replace_image_tags(inputs: StdTemplateInputs): |
| | |
| | if inputs.images: |
| | return |
| | images = [] |
| | pattern = r'<img>(.+?)</img>' |
| | for message in inputs.messages: |
| | content = message['content'] |
| | if not isinstance(content, str): |
| | continue |
| | for image in re.findall(pattern, content): |
| | |
| | if os.path.isfile(image): |
| | images.append(image) |
| | else: |
| | logger.warning_once(f'Failed to parse image path: `{content}`.', hash_id='<img></img>') |
| | message['content'] = re.sub(pattern, '<image>', content) |
| | inputs.images = images |
| |
|
| | @staticmethod |
| | def _replace_start_image_tags(inputs: StdTemplateInputs): |
| | |
| | generate_mode = False |
| | message = inputs.messages[-1] |
| | content = message['content'] |
| | if message['role'] == 'user' and content.endswith('<start-image>'): |
| | generate_mode = True |
| | message['content'] = message['content'][:-len('<start-image>')] |
| | inputs.generate_mode = generate_mode |
| |
|
| | @staticmethod |
| | def _extend_tokens(input_ids: List[int], labels: Optional[List[int]], replace_idx_list: List[int], |
| | get_new_tokens: Callable[[int], List[int]]) -> Tuple[List[int], Optional[List[int]]]: |
| | added_tokens_len = 0 |
| | for i, idx in enumerate(replace_idx_list): |
| | new_tokens = get_new_tokens(i) |
| | token_len = len(new_tokens) |
| | input_ids = input_ids[:idx + added_tokens_len] + new_tokens + input_ids[added_tokens_len + idx + 1:] |
| | if labels: |
| | labels = labels[:idx + added_tokens_len] + [-100] * token_len + labels[added_tokens_len + idx + 1:] |
| | added_tokens_len += token_len - 1 |
| | return input_ids, labels |
| |
|
| | def compute_loss_context(self, model, inputs): |
| | return nullcontext() |
| |
|
| | def _rlhf_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: |
| | chosen_inputs, rejected_inputs = inputs, deepcopy(inputs) |
| | assert chosen_inputs.rejected_response is not None, f'inputs: {inputs}' |
| | rejected_inputs.messages[-1]['content'] = chosen_inputs.rejected_response |
| | chosen_encoded = self._encode_truncated(chosen_inputs) |
| | rejected_encoded = self._encode_truncated(rejected_inputs) |
| |
|
| | encoded = {} |
| | for prefix in ['chosen', 'rejected']: |
| | data = locals()[f'{prefix}_encoded'] |
| | for k, v in data.items(): |
| | encoded[f'{prefix}_{k}'] = v |
| | return encoded |
| |
|
| | def _kto_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: |
| | label, inputs.label = inputs.label, None |
| | encoded = self._rlhf_encode(inputs) |
| | encoded['label'] = bool(label) |
| | return encoded |
| |
|
    def _embedding_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        """Encode anchor/positive/negative samples for embedding training.

        Produces keys 'anchor_*', 'positive_*', 'negative{i}_*' plus a 'labels' list:
        the positive's similarity target (inputs.label or 1.0) followed by 0.0 per negative.
        """
        _encoded = {}
        labels = []

        def split_multi_medias(_inputs):
            # Give the sub-sample only the medias referenced by its own query text,
            # consuming them in order from the shared `inputs` media lists.
            _content = _inputs.messages[-2]['content']
            image_size = len(re.findall('<image>', _content))
            video_size = len(re.findall('<video>', _content))
            audio_size = len(re.findall('<audio>', _content))
            _inputs.images = inputs.images[:image_size]
            assert len(_inputs.images) == image_size
            inputs.images = inputs.images[image_size:]
            _inputs.videos = inputs.videos[:video_size]
            assert len(_inputs.videos) == video_size
            inputs.videos = inputs.videos[video_size:]
            _inputs.audios = inputs.audios[:audio_size]
            assert len(_inputs.audios) == audio_size
            inputs.audios = inputs.audios[audio_size:]

        # Anchor: the original query with an empty response.
        anchor = deepcopy(inputs)
        anchor.messages[-1]['content'] = ''
        anchor.rejected_response = []
        split_multi_medias(anchor)
        anchor_encoded = self._encode_truncated(anchor)
        for key in anchor_encoded:
            _encoded[f'anchor_{key}'] = anchor_encoded[key]

        # Positive: the response text moved into the query position.
        positive = deepcopy(inputs)
        positive.messages[-2]['content'] = positive.messages[-1]['content']
        positive.messages[-1]['content'] = ''
        positive.rejected_response = []
        split_multi_medias(positive)
        positive_encoded = self._encode_truncated(positive)
        for key in positive_encoded:
            _encoded[f'positive_{key}'] = positive_encoded[key]
        labels.append(float(inputs.label) if inputs.label is not None else 1.0)

        # Negatives: one sample per rejected response.
        rejected_len = len(inputs.rejected_response) if inputs.rejected_response else 0
        for i in range(rejected_len):
            negative = deepcopy(inputs)
            negative.messages[-2]['content'] = negative.rejected_response[i]
            negative.messages[-1]['content'] = ''
            negative.rejected_response = []
            split_multi_medias(negative)
            negative_encoded = self._encode_truncated(negative)
            for key in negative_encoded:
                _encoded[f'negative{i}_{key}'] = negative_encoded[key]
            labels.append(0.0)

        _encoded['labels'] = labels
        return _encoded
| |
|
| | def _seq_cls_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: |
| | encoded = self._encode_truncated(inputs) |
| | encoded.pop('labels', None) |
| | if inputs.label is not None: |
| | labels = inputs.label |
| | problem_type = self._get_problem_type(self.config, labels=labels) |
| | if problem_type == 'single_label_classification': |
| | labels = int(labels) |
| | encoded['labels'] = labels |
| | return encoded |
| |
|
| | @torch.inference_mode() |
| | def encode(self, |
| | inputs: Union[TemplateInputs, Dict[str, Any], InferRequest], |
| | return_template_inputs: bool = False) -> Dict[str, Any]: |
| | """The entrance method of Template! |
| | |
| | Returns: |
| | return {'input_ids': List[int], 'labels': Optional[List[int]], ...} |
| | """ |
| | if isinstance(inputs, (InferRequest, TemplateInputs)): |
| | inputs = asdict(inputs) |
| |
|
| | if isinstance(inputs, dict): |
| | inputs = deepcopy(inputs) |
| | if not self.is_training: |
| | InferRequest.remove_response(inputs['messages']) |
| | inputs = StdTemplateInputs.from_dict(inputs) |
| | elif isinstance(inputs, StdTemplateInputs): |
| | inputs = deepcopy(inputs) |
| | assert isinstance(inputs, StdTemplateInputs) |
| | self._preprocess_inputs(inputs) |
| |
|
| | if self.mode in {'pt', 'train', 'prm', 'vllm', 'lmdeploy'}: |
| | encoded = self._encode_truncated(inputs) |
| | elif self.mode == 'seq_cls': |
| | encoded = self._seq_cls_encode(inputs) |
| | elif self.mode == 'rlhf': |
| | encoded = self._rlhf_encode(inputs) |
| | elif self.mode == 'kto': |
| | encoded = self._kto_encode(inputs) |
| | elif self.mode == 'embedding': |
| | encoded = self._embedding_encode(inputs) |
| | for key in list(encoded.keys()): |
| | if encoded[key] is None: |
| | encoded.pop(key) |
| | if return_template_inputs: |
| | encoded['template_inputs'] = inputs |
| | return encoded |
| |
|
| | def packing_row(self, row: List[Tuple[Dict[str, Any], int]]) -> Dict[str, Any]: |
| | packed = {} |
| | keys = set() |
| | for r in row: |
| | keys.update(r[0].keys()) |
| | for key in keys: |
| | if key in {'input_ids', 'labels', 'loss_scale'}: |
| | packed[key] = sum((x[0][key] for x in row), start=[]) |
| | if 'position_ids' not in packed: |
| | packed['position_ids'] = sum((list(range(x[1])) for x in row), start=[]) |
| |
|
| | packed.update(self._data_collator_mm_data([r[0] for r in row])) |
| | return packed |
| |
|
| | def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]: |
| | return inputs |
| |
|
| | @staticmethod |
| | def _skip_stop_tokens(generate_ids: List[int], stop_tokens: List[int], is_finished: bool) -> List[int]: |
| | len_tokens = len(stop_tokens) |
| | if is_finished and generate_ids[-len_tokens:] == stop_tokens: |
| | return generate_ids[:-len_tokens] |
| | if not is_finished: |
| | for i in range(len_tokens, 0, -1): |
| | if generate_ids[-i:] == stop_tokens[:i]: |
| | return generate_ids[:-i] |
| | return generate_ids |
| |
|
| | @staticmethod |
| | def _get_seq_cls_logprobs(pred: int, logprobs: torch.Tensor, top_logprobs: int): |
| | idxs = logprobs.argsort(descending=True, dim=-1)[:top_logprobs].tolist() |
| | logprobs = logprobs.tolist() |
| | return { |
| | 'content': [{ |
| | 'index': pred, |
| | 'logprobs': [logprobs[p] for p in pred] if isinstance(pred, (list, tuple)) else logprobs[pred], |
| | 'top_logprobs': [{ |
| | 'index': idx, |
| | 'logprob': logprobs[idx] |
| | } for idx in idxs] |
| | }] |
| | } |
| |
|
| | @staticmethod |
| | def _get_problem_type(config, labels=None, logits=None) -> str: |
| | problem_type = config.problem_type |
| | if problem_type is not None: |
| | return problem_type |
| | if labels is not None: |
| | if isinstance(labels, (list, tuple)): |
| | if labels and isinstance(labels[0], float): |
| | problem_type = 'regression' |
| | else: |
| | problem_type = 'multi_label_classification' |
| | else: |
| | problem_type = 'single_label_classification' |
| | assert config.num_labels >= labels + 1 |
| | if logits is not None: |
| | if logits.shape[-1] == 1: |
| | problem_type = 'regression' |
| | else: |
| | problem_type = 'single_label_classification' |
| | assert problem_type is not None |
| | config.problem_type = problem_type |
| | return problem_type |
| |
|
| | def decode_seq_cls(self, logits: torch.Tensor, top_logprobs: int): |
| | assert isinstance(logits, torch.Tensor) |
| | problem_type = self._get_problem_type(self.config, logits=logits) |
| | if problem_type == 'regression': |
| | preds = logits.squeeze(dim=-1).tolist() |
| | logprobs = [None] * len(preds) |
| | else: |
| | if problem_type == 'single_label_classification': |
| | preds = torch.argmax(logits, dim=-1).tolist() |
| | logprobs = torch.log_softmax(logits, -1) |
| | else: |
| | preds = [(logprob >= 0.5).nonzero(as_tuple=True)[0].tolist() for logprob in torch.sigmoid(logits)] |
| | logprobs = F.logsigmoid(logits) |
| | logprobs = [self._get_seq_cls_logprobs(pred, logprobs[i], top_logprobs) for i, pred in enumerate(preds)] |
| | return preds, logprobs |
| |
|
| | def decode(self, |
| | generate_ids: List[int], |
| | *, |
| | is_finished: bool = True, |
| | tokenizer_kwargs=None, |
| | first_token=True, |
| | **kwargs) -> Any: |
| | tokenizer_kwargs = tokenizer_kwargs or {} |
| | response = self._skip_stop_decode(generate_ids, is_finished, **tokenizer_kwargs) |
| | if first_token and self.template_meta.response_prefix: |
| | response = self.template_meta.response_prefix + response |
| | return response |
| |
|
| | def decode_prm(self, input_ids: torch.Tensor, logits: torch.Tensor) -> Any: |
| | raise NotImplementedError |
| |
|
| | def generate(self, model, *args, **kwargs): |
| | if isinstance(model, PeftModel): |
| | signature = inspect.signature(model.model.generate) |
| | else: |
| | signature = inspect.signature(model.generate) |
| | if 'use_model_defaults' in signature.parameters and 'use_model_defaults' not in kwargs: |
| | kwargs['use_model_defaults'] = False |
| | return model.generate(*args, **kwargs) |
| |
|
| | def _skip_stop_decode(self, generate_ids: List[int], is_finished: bool, **decode_kwargs) -> Any: |
| | |
| | |
| | tokenizer = self.tokenizer |
| |
|
| | if len(generate_ids) > 0 and generate_ids[-1] == tokenizer.eos_token_id: |
| | generate_ids = generate_ids[:-1] |
| | |
| | template_suffix = self.template_meta.suffix[-1] |
| | if isinstance(template_suffix, str): |
| | |
| | template_suffix = tokenizer.encode(template_suffix, add_special_tokens=False)[-1:] |
| | generate_ids = self._skip_stop_tokens(generate_ids, template_suffix, is_finished) |
| | if 'spaces_between_special_tokens' not in decode_kwargs: |
| | decode_kwargs['spaces_between_special_tokens'] = False |
| | return tokenizer.decode(generate_ids, **decode_kwargs) |
| |
|
| | def prepare_generate_kwargs(self, generate_kwargs: Dict[str, Any], *, model=None) -> Dict[str, Any]: |
| | generation_config = generate_kwargs['generation_config'] |
| | stop_words = getattr(generation_config, 'stop_words', None) or self.template_meta.stop_words |
| | generate_kwargs['stopping_criteria'] = StoppingCriteriaList([StopWordsCriteria(self.tokenizer, stop_words)]) |
| | return generate_kwargs |
| |
|
| | @staticmethod |
| | def _save_pil_image(image: Image.Image) -> str: |
| | img_bytes = image.tobytes() |
| | img_hash = hashlib.sha256(img_bytes).hexdigest() |
| | tmp_dir = os.path.join(get_cache_dir(), 'tmp', 'images') |
| | logger.info_once(f'create tmp_dir: {tmp_dir}') |
| | os.makedirs(tmp_dir, exist_ok=True) |
| | img_path = os.path.join(tmp_dir, f'{img_hash}.png') |
| | if not os.path.exists(img_path): |
| | image.save(img_path) |
| | return img_path |
| |
|
| | @staticmethod |
| | def _concat_context_list( |
| | context_list: List[Context], |
| | res_context_list: List[Context], |
| | res_context_type: List[ContextType], |
| | system: Optional[str] = None, |
| | query: Optional[str] = None, |
| | response: Optional[str] = None, |
| | round0: Optional[int] = None) -> None: |
| | """Concat context list and replace placeholder""" |
| | round1 = None |
| | if round0 is not None: |
| | round1 = str(round0 + 1) |
| | round0 = str(round0) |
| | for context in context_list: |
| | if isinstance(context, str): |
| | if '{{RESPONSE}}' == context: |
| | assert response is not None |
| | res_context_list.append(response) |
| | res_context_type.append(ContextType.RESPONSE) |
| | continue |
| | old_str_list = ['{{SYSTEM}}', '{{QUERY}}', '{{ROUND0}}', '{{ROUND1}}'] |
| | new_str_list = [system, query, round0, round1] |
| | for (old_str, new_str) in zip(old_str_list, new_str_list): |
| | if new_str is not None and old_str in context: |
| | assert isinstance(new_str, str), f'new_str: {new_str}' |
| | context = context.replace(old_str, new_str) |
| | if len(context) == 0: |
| | continue |
| | res_context_list.append(context) |
| | res_context_type.append(ContextType.OTHER) |
| |
|
    def _simplify_context_list(self, context_list: List[Context], loss_scale_list: List[float],
                               inputs: StdTemplateInputs) -> Tuple[List[Context], List[float]]:
        """Merge anything in the context to simplify the inputs.

        Adjacent string contexts with the SAME loss scale are joined into a single
        string; non-string contexts (token-id lists) are passed through untouched.
        """
        context_list, loss_scale_list = self._split_special_tokens(context_list, loss_scale_list)
        context_list, loss_scale_list = self._pre_tokenize(context_list, loss_scale_list, inputs)

        res: List[Context] = []
        res_loss_scale: List[float] = []
        temp: List[str] = []  # buffer of adjacent strings sharing temp_loss_scale
        temp_loss_scale = 0.
        for i, (context, loss_scale) in enumerate(zip(context_list, loss_scale_list)):
            if isinstance(context, str) and (loss_scale == temp_loss_scale):
                temp.append(context)
            else:
                # Scale changed or context is non-str: flush the buffer first.
                if len(temp) > 0:
                    res.append(''.join(temp))
                    res_loss_scale.append(temp_loss_scale)
                    temp.clear()
                if isinstance(context, str):
                    temp.append(context)
                else:
                    res.append(context)
                    res_loss_scale.append(loss_scale)
                temp_loss_scale = loss_scale
        if len(temp) > 0:
            # Flush the trailing buffer.
            res.append(''.join(temp))
            res_loss_scale.append(temp_loss_scale)

        return res, res_loss_scale
| |
|
| | @staticmethod |
| | def _split_special_tokens(context_list: List[Context], |
| | loss_scale_list: List[float]) -> Tuple[List[Context], List[float]]: |
| | """Split special tokens, for example `<image>`, `<video>`, this will help the replace_tag operation""" |
| | res: List[Context] = [] |
| | loss_scale_res: List[float] = [] |
| | for context, loss_scale in zip(context_list, loss_scale_list): |
| | contexts = [] |
| | if isinstance(fetch_one(context), str): |
| | for d in split_str_parts_by(context, Template.special_tokens): |
| | contexts.extend([d['key'], d['content']]) |
| | contexts = [c for c in contexts if c] |
| | res.extend(contexts) |
| | loss_scale_res.extend([loss_scale] * len(contexts)) |
| | else: |
| | res.append(context) |
| | loss_scale_res.append(loss_scale) |
| | return res, loss_scale_res |
| |
|
| | def _tokenize(self, context, **tokenizer_kwargs): |
| | return self.tokenizer( |
| | context, return_attention_mask=False, add_special_tokens=False, **tokenizer_kwargs)['input_ids'] |
| |
|
| | def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int, |
| | inputs: StdTemplateInputs) -> List[Context]: |
| | """Override this function to do your own replace operation. |
| | |
| | This method is used to replace standard tags like `<image>` to some tokens that the model needs. |
| | |
| | Args: |
| | media_type: The modal. |
| | index: The index of the medias, for index 0 represents the first elements in `images` |
| | inputs: The inputs |
| | |
| | Returns: |
| | The content or input_ids after replacement. |
| | """ |
| | if media_type == 'image': |
| | if self.mode == 'lmdeploy': |
| | return [[-100]] |
| | return self.image_placeholder |
| | elif media_type == 'video': |
| | return self.video_placeholder |
| | elif media_type == 'audio': |
| | return self.audio_placeholder |
| |
|
| | def replace_ref(self, ref: str, index: int, inputs: StdTemplateInputs) -> List[Context]: |
| | """Replace objects referenced by the bbox to contents or input_ids. This is useful in the grounding task. |
| | Override this function to do your own replace operation. |
| | |
| | Args: |
| | ref: Description of the bbox |
| | index: The index in the `objects` key |
| | inputs: The inputs |
| | |
| | Returns: |
| | The contents or input_ids replaced |
| | """ |
| | return [ref] |
| |
|
| | def replace_cot_process(self, inputs: StdTemplateInputs) -> List[Context]: |
| | """Replace the cot process label for PRM training or inference. |
| | Override this function to do your own replace operation. |
| | |
| | Args: |
| | inputs: The inputs |
| | |
| | Returns: |
| | The contents or input_ids replaced |
| | """ |
| | return [self.cot_process_placeholder] |
| |
|
| | @staticmethod |
| | def _get_bbox_str(bbox: List[int]) -> str: |
| | point = [] |
| | for x, y in zip(bbox[::2], bbox[1::2]): |
| | point.append(f'({x},{y})') |
| | return ','.join(point) |
| |
|
| | def replace_bbox(self, bbox: List[int], index: int, inputs: StdTemplateInputs) -> List[Context]: |
| | """Replace bbox pointing to the objects to contents or input_ids. This is useful in the grounding task. |
| | Override this function to do your own replace operation. |
| | |
| | Args: |
| | bbox: [x, y] or [x1, y1, x2, y2] |
| | index: The index in the `objects` key |
| | inputs: The inputs |
| | |
| | Returns: |
| | The contents or input_ids replaced |
| | """ |
| | return [f'[{self._get_bbox_str(bbox)}]'] |
| |
|
| | def _pre_tokenize_images(self, context_list: List[Context], loss_scale_list: List[float], |
| | inputs: StdTemplateInputs) -> Tuple[List[Context], List[float]]: |
| | |
| | |
| | res: List[Context] = [] |
| | res_loss_scale: List[float] = [] |
| | inputs.image_idx = 0 |
| |
|
| | for context, loss_scale in zip(context_list, loss_scale_list): |
| | if context == '<image>' and inputs.is_multimodal and inputs.image_idx < len(inputs.images): |
| | c_list = self.replace_tag('image', inputs.image_idx, inputs) |
| | inputs.image_idx += 1 |
| | loss_scale = 0. if self.template_backend == 'swift' else 1. |
| | else: |
| | c_list = [context] |
| | res += c_list |
| | res_loss_scale += [loss_scale] * len(c_list) |
| | return res, res_loss_scale |
| |
|
    def _pre_tokenize(self, context_list: List[Context], loss_scale_list: List[float],
                      inputs: StdTemplateInputs) -> Tuple[List[Context], List[float]]:
        """This method happens before tokenization, replace standard tags to the contents or input_ids needed by
        the model.

        Args:
            context_list: The content list
            loss_scale_list: The loss scale list
        Returns:
            The context_list and loss_scale_list after replacement.
        """
        # Images first: they also populate width/height needed for bbox normalization.
        context_list, loss_scale_list = self._pre_tokenize_images(context_list, loss_scale_list, inputs)
        if inputs.images and inputs.objects:
            self.normalize_bbox(inputs)
        res: List[Context] = []
        res_loss_scale: List[float] = []

        # Reset the per-modality consumption cursors.
        for k in ['video', 'audio', 'object', 'box']:
            setattr(inputs, f'{k}_idx', 0)

        for context, loss_scale in zip(context_list, loss_scale_list):
            for k in ['video', 'audio']:
                if context == f'<{k}>' and inputs.is_multimodal and getattr(inputs, f'{k}_idx') < len(
                        getattr(inputs, f'{k}s')):
                    c_list = self.replace_tag(k, getattr(inputs, f'{k}_idx'), inputs)
                    setattr(inputs, f'{k}_idx', getattr(inputs, f'{k}_idx') + 1)
                    loss_scale = 0.
                    break
            else:
                # for/else: no media tag matched above — check grounding / PRM tags.
                ref = inputs.objects.get('ref') or []
                bbox = inputs.objects.get('bbox') or []
                if context == '<ref-object>' and inputs.ref_idx < len(ref):
                    idx = inputs.ref_idx
                    c_list = self.replace_ref(ref[idx], idx, inputs)
                    inputs.ref_idx += 1
                elif context == '<bbox>' and inputs.bbox_idx < len(bbox):
                    idx = inputs.bbox_idx
                    c_list = self.replace_bbox(bbox[idx], idx, inputs)
                    inputs.bbox_idx += 1
                elif context == '<cot-process>' and self.mode == 'prm':
                    c_list = self.replace_cot_process(inputs)
                else:
                    c_list = [context]
            res += c_list
            res_loss_scale += [loss_scale] * len(c_list)
        return res, res_loss_scale
| |
|
| | @staticmethod |
| | def _add_default_tags(inputs: StdTemplateInputs): |
| | total_content = '\n'.join([message['content'] or '' for message in inputs.messages]) |
| | if inputs.rejected_response: |
| | if isinstance(inputs.rejected_response, str): |
| | total_content += inputs.rejected_response |
| | else: |
| | total_content += '\n'.join(inputs.rejected_response) |
| | if inputs.system: |
| | total_content = f'{inputs.system}\n{total_content}' |
| | for media_type in ['image', 'audio', 'video']: |
| | media_key, media_tag = f'{media_type}s', f'<{media_type}>' |
| | medias = getattr(inputs, media_key) |
| | if not isinstance(medias, list): |
| | medias = [medias] |
| | if medias: |
| | num_media_tags = len(re.findall(media_tag, total_content)) |
| | num_media = len(medias) |
| | num_new_tags = num_media - num_media_tags |
| | if num_new_tags > 0: |
| | inputs.messages[0]['content'] = media_tag * num_new_tags + inputs.messages[0]['content'] |
| | elif num_new_tags < 0: |
| | logger.warning( |
| | f'num_media: {num_media}, num_media_tags: {num_media_tags}, total_content: {total_content}. ' |
| | 'We will only replace the frontmost media_tags while keeping the subsequent media_tags.') |
| |
|
| | def _encode_context_list( |
| | self, |
| | context_list: List[Context], |
| | loss_scale_list: Optional[List[float]] = None) -> Tuple[List[int], List[int], List[float], Dict[str, Any]]: |
| | """return: input_ids, labels, tokenizer_kwargs""" |
| | input_ids: List[int] = [] |
| | labels: List[int] = [] |
| | loss_scale: List[float] = [] |
| | tokenizer_kwargs = {} |
| | if loss_scale_list is None: |
| | loss_scale_list = [0.] * len(context_list) |
| | ignore_loss_scale = all(loss_scale in {0, 1} for loss_scale in loss_scale_list) |
| | for i, (context, loss_weight) in enumerate(zip(context_list, loss_scale_list)): |
| | if isinstance(context, str): |
| | |
| | |
| | token_list = self._tokenize(context) |
| | else: |
| | token_list = context |
| | input_ids += token_list |
| | if loss_scale_list[i] > 0.0: |
| | labels += token_list |
| | else: |
| | labels += [-100] * len(token_list) |
| | if not ignore_loss_scale: |
| | loss_scale.extend([loss_weight] * len(token_list)) |
| | if ignore_loss_scale: |
| | loss_scale = None |
| | return input_ids, labels, loss_scale, tokenizer_kwargs |
| |
|
| | @staticmethod |
| | def _add_dynamic_eos(input_ids: List[int], labels: List[int], loss_scale: Optional[List[int]], |
| | suffix_tokens_id: List[int]) -> None: |
| | suffix_len = len(suffix_tokens_id) |
| | start = 0 |
| | for i in range(1, len(labels)): |
| | if labels[i - 1] >= 0 and labels[i] == -100: |
| | start = i |
| | if start > 0 and labels[i - 1] == -100 and labels[i] >= 0: |
| | |
| | length = i - start |
| | if length >= suffix_len and input_ids[start:start + suffix_len] == suffix_tokens_id: |
| | labels[start:start + suffix_len] = suffix_tokens_id |
| | if loss_scale and loss_scale[start:start + suffix_len] == [0] * suffix_len: |
| | loss_scale[start:start + suffix_len] = [1] * suffix_len |
| |
|
| | @staticmethod |
| | def _get_std_messages(messages): |
| | if messages and messages[0]['role'] == 'assistant': |
| | messages.insert(0, {'role': 'user', 'content': ''}) |
| | if len(messages) % 2 == 1: |
| | messages.append({'role': 'assistant', 'content': None}) |
| |
|
| | def _jinja_encode(self, inputs: StdTemplateInputs): |
| | messages = inputs.messages.copy() |
| | if inputs.system is not None: |
| | messages.insert(0, {'role': 'system', 'content': inputs.system}) |
| | if messages[-1]['content'] is None: |
| | messages.pop() |
| | add_generation_prompt = messages[-1]['role'] != 'assistant' |
| | kwargs = {} |
| | if inputs.tools: |
| | kwargs['tools'] = inputs.tools |
| | text = self.tokenizer.apply_chat_template( |
| | messages, tokenize=False, add_generation_prompt=add_generation_prompt, **kwargs) |
| | answer_len = 1 if self.is_training else 0 |
| | return [text], [1.], answer_len |
| |
|
| | def _get_system(self, inputs) -> Optional[str]: |
| | template_meta = self.template_meta |
| | system = inputs.system |
| | tools = inputs.tools |
| | template_meta.check_system(system) |
| | if system is None: |
| | system = template_meta.default_system |
| |
|
| | if tools is not None: |
| | system = self.agent_template._format_tools(tools, system or '', inputs.messages[0]) |
| | return system |
| |
|
    @staticmethod
    def _swift_prepare_function_call(agent_template, messages):
        """Normalize tool-call turns in ``messages`` in place.

        Two rewrites are applied while scanning the list:

        * an assistant turn followed by one or more ``tool`` turns: the run of
          tool responses is rendered through
          ``agent_template._format_tool_responses`` and collapsed into a single
          ``tool`` message (the assistant content may be rewritten as well);
        * two consecutive assistant turns are concatenated into one.
        """
        if len(messages) < 2:
            return
        i = 1
        while i < len(messages):
            pre_message, message = messages[i - 1], messages[i]
            pre_role, pre_content = pre_message['role'], pre_message['content']
            role, content = message['role'], message['content']
            if pre_role == 'assistant' and role == 'tool':
                i_start = i
                # absorb the whole run of consecutive tool responses
                while i + 1 < len(messages) and messages[i + 1]['role'] == 'tool':
                    i += 1
                pre_message['content'], tool_content = agent_template._format_tool_responses(
                    pre_content, messages[i_start:i + 1])
                # replace the run with one merged tool message and resume after it
                messages[i_start:i + 1] = [{'role': 'tool', 'content': tool_content}]
                i = i_start + 1
            elif pre_role == 'assistant' and role == 'assistant':
                # merge consecutive assistant outputs into a single turn
                pre_message['content'] = pre_content + content
                messages.pop(i)
            else:
                i += 1
| |
|
    def _swift_encode(self, inputs: StdTemplateInputs):
        """Render the conversation with the swift (non-jinja) template definition.

        Returns ``(res_context_list, loss_scale_list, answer_len)``:
        ``res_context_list`` mixes literal strings and pre-tokenized id lists,
        ``loss_scale_list`` carries one scale per context, and ``answer_len`` is
        the number of trailing contexts that belong to the answer (used by the
        encoder-decoder prompt/answer split in ``_encode``).
        """
        template_meta = self.template_meta
        system = self._get_system(inputs)
        self._swift_prepare_function_call(self.agent_template, inputs.messages)

        self._get_std_messages(inputs.messages)
        n_round = len(inputs.messages) // 2
        if n_round > 1 and not self.template_meta.support_multi_round:
            logger.warning_once(
                'The template does not support multi-round chat. Only use the last round of the conversation.')
            inputs.messages = inputs.messages[-2:]

        res_context_list: List[Context] = []
        res_context_types: List[ContextType] = []
        sep_token = None
        if template_meta.auto_add_bos:
            # Detect the tokenizer's bos/sep by diffing 'a' encoded with and
            # without special tokens: tokens before the probe are bos, after are sep.
            all_tokens = self.tokenizer.encode('a')
            single_token = self.tokenizer.encode('a', add_special_tokens=False)
            assert len(single_token) == 1
            idx = all_tokens.index(single_token[0])
            bos_token = all_tokens[:idx]
            sep_token = all_tokens[idx + 1:]
            if bos_token:
                res_context_list.append(bos_token)
                res_context_types.append(ContextType.OTHER)

        # the system prompt selects a different prefix when present
        prefix = template_meta.system_prefix if system else template_meta.prefix
        self._concat_context_list(prefix, res_context_list, res_context_types, system=system)

        n_round = len(inputs.messages) // 2
        for i, (query_message, response_message) in enumerate(zip(inputs.messages[::2], inputs.messages[1::2])):
            query_role, query = query_message['role'], query_message['content']
            response_role, response = response_message['role'], response_message['content']
            # after _get_std_messages, even turns are user/tool, odd are assistant
            assert query_role in {'user', 'tool'}, f'query_role: {query_role}'
            assert response_role in {'assistant'}, f'response_role: {response_role}'
            if query_role == 'tool':
                # a tool message already carries its full prompt text
                prompt = query
                query = ''
            elif template_meta.is_post_system and i == n_round - 1:
                prompt = template_meta.system_prompt
            else:
                prompt = template_meta.prompt

            context_list = prompt.copy()
            extra_context_list = []
            extra_context_type = None
            if i < n_round - 1:
                # intermediate round: supervise the response and join rounds with
                # chat_sep (unless the next turn is a tool result)
                context_list.append('{{RESPONSE}}')
                if inputs.messages[2 * (i + 1)]['role'] != 'tool':
                    extra_context_list = template_meta.chat_sep
                    extra_context_type = ContextType.OTHER
            elif response is not None:
                # final round with a target response: close it with the template
                # suffix when training (unless a sep token is appended below)
                context_list.append('{{RESPONSE}}')
                if self.is_training and not sep_token:
                    extra_context_list = template_meta.suffix
                    extra_context_type = ContextType.SUFFIX
            elif template_meta.response_prefix:
                # inference: seed generation with the response prefix
                context_list.append(template_meta.response_prefix)

            self._concat_context_list(
                context_list,
                res_context_list,
                res_context_types,
                query=query,
                response=response,
                system=system,
                round0=i)
            res_context_list += extra_context_list
            res_context_types += [extra_context_type] * len(extra_context_list)
        if template_meta.auto_add_bos and sep_token:
            res_context_list.append(sep_token)
            res_context_types.append(ContextType.SUFFIX)
        from swift.plugin import loss_scale_map
        res_context_list, loss_scale_list = loss_scale_map[self.loss_scale](res_context_list, res_context_types,
                                                                            inputs.messages)
        if self.is_training:
            # NOTE(review): reuses `extra_context_list`/`response` from the last
            # loop iteration; assumes at least one message round exists.
            answer_len = len(extra_context_list) + bool(response is not None)
        else:
            answer_len = 0
        return res_context_list, loss_scale_list, answer_len
| |
|
| | def _encode_truncated(self, inputs): |
| | if self.mode in {'vllm', 'lmdeploy'}: |
| | encoded = Template._encode(self, inputs) |
| | for key in ['images', 'audios', 'videos']: |
| | encoded[key] = getattr(inputs, key) |
| | else: |
| | encoded = self._encode(inputs) |
| |
|
| | input_ids = encoded.get('input_ids') |
| | labels = encoded.get('labels') |
| | loss_scale = encoded.get('loss_scale') |
| | if self.max_length is not None: |
| | if self.truncation_strategy == 'right': |
| | input_ids = input_ids[:self.max_length] |
| | if labels is not None: |
| | labels = labels[:self.max_length] |
| | if loss_scale is not None: |
| | loss_scale = loss_scale[:self.max_length] |
| | elif self.truncation_strategy == 'left': |
| | if len(input_ids) > self.max_length: |
| | logger.warning_once( |
| | 'Input data was left-truncated because its length exceeds `max_length` (input length: ' |
| | f'{len(input_ids)}, max_length: {self.max_length}). ' |
| | 'This may cause loss of important tokens (e.g., image tokens) and lead to errors. ' |
| | 'To avoid this, consider increasing `max_length` or pre-filtering long sequences.', |
| | hash_id='max_length_check') |
| | input_ids = input_ids[-self.max_length:] |
| | if labels is not None: |
| | labels = labels[-self.max_length:] |
| | if loss_scale is not None: |
| | loss_scale = loss_scale[-self.max_length:] |
| | elif self.truncation_strategy == 'raise': |
| | length = len(input_ids or labels or []) |
| | if length > self.max_length: |
| | raise MaxLengthError(f'Current length of row({length}) is larger' |
| | f' than the max_length({self.max_length}).') |
| | encoded['input_ids'] = input_ids |
| | encoded['labels'] = labels |
| | encoded['loss_scale'] = loss_scale |
| | return encoded |
| |
|
    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        """Encode ``inputs`` into model-ready ``input_ids``/``labels``/``loss_scale``.

        Chooses the swift or jinja backend, tokenizes the rendered contexts, and
        applies mode-specific post-processing (encoder-decoder prompt/answer
        split, megatron label shift, dropping labels at inference).
        """
        template_backend = self.template_backend
        if (self.template_meta.template_type == 'dummy' and self.use_chat_template and not self.is_training
                and self.mode != 'seq_cls'):
            # a 'dummy' template has no handwritten chat format; fall back to the
            # tokenizer's own jinja chat template for inference
            template_backend = 'jinja'
            logger.info_once(f'Setting template_backend: {template_backend}')
        res_context_list, loss_scale_list, answer_len = (
            self._swift_encode(inputs) if template_backend == 'swift' else self._jinja_encode(inputs))
        encoded = {}
        if self.is_encoder_decoder:
            # the last `answer_len` contexts form the decoder (answer) side
            total_len = len(res_context_list)
            for key, _slice in zip(['prompt', 'answer'],
                                   [slice(0, total_len - answer_len),
                                    slice(total_len - answer_len, total_len)]):
                context_list, loss_scale = self._simplify_context_list(res_context_list[_slice],
                                                                       loss_scale_list[_slice], inputs)
                input_ids, labels, loss_scale, tokenizer_kwargs = self._encode_context_list(context_list, loss_scale)
                encoded[f'{key}_input_ids'] = input_ids
                if key == 'answer':
                    encoded['labels'] = labels
                    encoded['loss_scale'] = loss_scale
            input_ids = encoded['prompt_input_ids'] + encoded['answer_input_ids']
        else:
            res_context_list, loss_scale_list = self._simplify_context_list(res_context_list, loss_scale_list, inputs)
            input_ids, labels, loss_scale, tokenizer_kwargs = self._encode_context_list(
                res_context_list, loss_scale_list)
        # unmask the template suffix inside masked gaps (see _add_dynamic_eos)
        self._add_dynamic_eos(input_ids, labels, loss_scale, self._encode_context_list(self.template_meta.suffix)[0])

        if tokenizer_kwargs:
            encoded['tokenizer_kwargs'] = tokenizer_kwargs

        encoded['input_ids'] = input_ids
        encoded['labels'] = labels
        encoded['loss_scale'] = loss_scale
        if self.use_megatron:
            self._handle_megatron_cp(encoded)
            # pre-shift labels by one so targets align with next-token prediction
            encoded['labels'] = encoded['labels'][1:] + [-100]
            encoded['position_ids'] = list(range(len(encoded['labels'])))
        elif encoded.get('labels') is not None:
            # mask the first label (there is no preceding token to predict from)
            encoded['labels'][0] = -100
        if not self.is_training:
            # inference needs no supervision targets
            for k in list(encoded.keys()):
                if k.endswith('labels') or k.endswith('loss_scale'):
                    encoded[k] = None
        return encoded
| |
|
| | def _handle_megatron_cp(self, encoded: Dict[str, Any]) -> None: |
| | cp_size = self.sequence_parallel_size |
| | if cp_size == 1: |
| | return |
| | input_ids = encoded['input_ids'] |
| | padding_len = math.ceil(len(input_ids) / (cp_size * 2)) * (cp_size * 2) - len(input_ids) |
| | input_ids += [self.tokenizer.pad_token_id] * padding_len |
| | encoded['labels'] += [-100] * padding_len |
| |
|
| | def debug_logger(self, inputs): |
| | if not strtobool(os.getenv('SWIFT_DEBUG', 'false')): |
| | return |
| | if 'input_ids' in inputs: |
| | k = 'input_ids' |
| | val = inputs['input_ids'] |
| | elif 'generate_ids' in inputs: |
| | k = 'generate_ids' |
| | val = inputs['generate_ids'] |
| | for v in val: |
| | self.print_inputs({k: v.tolist()}) |
| |
|
| | @staticmethod |
| | def _split_list(inputs: List[int], x: int) -> List[List[int]]: |
| | idxs = findall(inputs, x) |
| | idxs.append(len(inputs)) |
| | res = [] |
| | lo = 0 |
| | for idx in idxs: |
| | res.append(inputs[lo:idx]) |
| | lo = idx + 1 |
| | return res |
| |
|
    def replace_video2image(self, load_video_func, inputs, replace_tag: Callable) -> List[Context]:
        """Expand the current video into frames and splice them into ``inputs.images``.

        ``load_video_func`` maps a video to a list of frame images;
        ``replace_tag(i)`` yields the context fragments for the i-th new frame.
        Returns the concatenated context fragments for all frames; advances
        ``inputs.image_idx`` past the inserted frames.
        """
        context_list = []
        if self.mode in {'vllm', 'lmdeploy'}:
            # inference backends consume the video object directly: remove it
            # and step the cursor back so the caller's increment lands right
            video = inputs.videos.pop(inputs.video_idx)
            inputs.video_idx -= 1
        else:
            video = inputs.videos[inputs.video_idx]
        images = inputs.images
        new_images = load_video_func(video)
        # insert the extracted frames at the current image cursor
        inputs.images = images[:inputs.image_idx] + new_images + images[inputs.image_idx:]
        for i in range(len(new_images)):
            context_list += replace_tag(i)
        inputs.image_idx += len(new_images)
        return context_list
| |
|
| | def get_generate_ids(self, generate_ids: Union[torch.Tensor, List[int]], |
| | num_prompt_tokens: int) -> Union[torch.Tensor, List[int]]: |
| | if self.skip_prompt: |
| | generate_ids = generate_ids[..., num_prompt_tokens:] |
| | return generate_ids |
| |
|
| | def post_process_generate_response(self, response: str, inputs: StdTemplateInputs) -> str: |
| | return response |
| |
|
    def pre_forward_hook(self, model: nn.Module, args, kwargs):
        """Forward pre-hook: run ``_post_encode`` (e.g. multimodal embedding) and
        merge its outputs into the forward kwargs on the model's device.

        Keeps the original ids/masks/labels/position_ids unless ``_post_encode``
        replaced them, drops ``input_ids`` once ``inputs_embeds`` exists, and
        removes ``position_ids`` when the wrapped forward cannot accept it.
        Returns the (possibly rewritten) ``(args, kwargs)`` pair.
        """
        from swift.llm import to_device
        old_kwargs = to_device(kwargs, model.device)
        kwargs = to_device(self._post_encode(model, old_kwargs), model.device)
        for k, v in old_kwargs.items():
            # carry over the standard keys _post_encode did not produce itself
            if k in {'input_ids', 'attention_mask', 'labels', 'position_ids'} and k not in kwargs:
                kwargs[k] = v
        if 'inputs_embeds' in kwargs:
            # embeddings supersede raw ids
            kwargs.pop('input_ids', None)

        # peft wraps the real forward; inspect the inner model's signature
        if isinstance(model, PeftModel):
            parameters = inspect.signature(model.model.forward).parameters
        else:
            parameters = inspect.signature(model.forward).parameters
        if 'position_ids' not in parameters:
            kwargs.pop('position_ids', None)
        return args, kwargs
| |
|
| | @property |
| | def is_training(self): |
| | return self.mode not in {'vllm', 'lmdeploy', 'pt'} |
| |
|
| | def set_mode(self, mode: Literal['vllm', 'lmdeploy', 'pt', 'seq_cls', 'train', 'rlhf', 'kto']) -> None: |
| | self.mode = mode |
| |
|
    def register_post_encode_hook(self, models: List[nn.Module]) -> None:
        """This function is important for multi-modal training, as it registers the post_encode method
        as a forward hook, converting input_ids into inputs_embeds.

        Idempotent: does nothing when hooks are already installed.  Under
        deepspeed ZeRO-3, ``deepspeed.initialize`` is additionally patched so our
        pre-hook stays last after deepspeed registers its own hooks.
        """
        if self._handles:
            return

        for model in models:
            handle = model.register_forward_pre_hook(self.pre_forward_hook, with_kwargs=True)
            self._handles.append((model, handle))

        if is_deepspeed_zero3_enabled():
            import deepspeed
            # keep the original so remove_post_encode_hook can restore it
            self._deepspeed_initialize = deepspeed.initialize

            @wraps(self._deepspeed_initialize)
            def _initialize(*args, **kwargs):
                res = self._deepspeed_initialize(*args, **kwargs)
                for model, handle in self._handles:
                    # deepspeed prepends its own pre-hooks at initialize time;
                    # move ours to the end so it sees the final kwargs
                    model._forward_pre_hooks.move_to_end(handle.id)
                return res

            deepspeed.initialize = _initialize
| |
|
| | def remove_post_encode_hook(self): |
| | models = [] |
| | for model, handle in self._handles: |
| | models.append(model) |
| | handle.remove() |
| | self._handles = [] |
| |
|
| | if self._deepspeed_initialize is not None: |
| | import deepspeed |
| | deepspeed.initialize = self._deepspeed_initialize |
| | self._deepspeed_initialize = None |
| | return models |
| |
|
| | def data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]: |
| | if self.mode == 'rlhf': |
| | return self._rlhf_data_collator(batch, padding_to=padding_to) |
| | elif self.mode == 'kto': |
| | return self._kto_data_collator(batch, padding_to=padding_to) |
| | elif self.mode in {'pt', 'train', 'prm'}: |
| | return self._data_collator(batch, padding_to=padding_to) |
| | elif self.mode == 'seq_cls': |
| | return self._seq_cls_data_collator(batch, padding_to=padding_to) |
| | elif self.mode == 'embedding': |
| | return self._embedding_data_collator(batch, padding_to=padding_to) |
| |
|
| | @staticmethod |
| | def _fetch_inputs_startswith(batch: List[Dict[str, Any]], prefix: str) -> List[Dict[str, Any]]: |
| | new_batch = [] |
| | for inputs in batch: |
| | new_inputs = {} |
| | for k, v in inputs.items(): |
| | if k.startswith(prefix): |
| | new_inputs[k[len(prefix):]] = v |
| | new_batch.append(new_inputs) |
| | return new_batch |
| |
|
| | @staticmethod |
| | def fetch_inputs(batch: List[Dict[str, Any]], keys: Optional[List[str]] = None) -> Dict[str, Any]: |
| | from swift.llm import RowPreprocessor |
| | keys = keys or [] |
| | rows = RowPreprocessor.rows_to_batched(batch) |
| | return {k: rows[k] for k in keys if rows.get(k) is not None} |
| |
|
| | @staticmethod |
| | def gather_list(batch: List[Dict[str, Any]], attr_name: str) -> Optional[List[Any]]: |
| | |
| | res = [] |
| | for b in batch: |
| | if b.get(attr_name) is not None: |
| | res += b.pop(attr_name) |
| | return res |
| |
|
| | @staticmethod |
| | def concat_tensor(batch: List[Dict[str, Any]], attr_name: str, dim: int) -> Optional[torch.Tensor]: |
| | res = [] |
| | for b in batch: |
| | if b.get(attr_name) is not None: |
| | res.append(b.pop(attr_name)) |
| | return torch.concat(res, dim=dim) if res else None |
| |
|
| | def _rlhf_data_collator(self, |
| | batch: List[Dict[str, Any]], |
| | *, |
| | chosen_prefix: str = 'chosen_', |
| | rejected_prefix: str = 'rejected_', |
| | padding_to: Optional[int] = None) -> Dict[str, Any]: |
| | new_batch = [] |
| | for prefix in [chosen_prefix, rejected_prefix]: |
| | new_batch += self._fetch_inputs_startswith(batch, prefix) |
| | return self._data_collator(new_batch, padding_to=padding_to) |
| |
|
| | def _kto_data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]: |
| | new_batch = self._fetch_inputs_startswith(batch, 'chosen_') |
| | kl_batch = self._fetch_inputs_startswith(batch, 'rejected_') |
| |
|
| | res = self._data_collator(new_batch, padding_to=padding_to) |
| | kl_res = self._data_collator(kl_batch, padding_to=padding_to) |
| | res = { |
| | **{f'completion_{k}': v |
| | for k, v in res.items()}, |
| | **{f'KL_completion_{k}': v |
| | for k, v in kl_res.items()}, |
| | } |
| | label = [b['label'] for b in batch if b.get('label') is not None] |
| | if label: |
| | res['label'] = label |
| | return res |
| |
|
| | def _embedding_data_collator(self, |
| | batch: List[Dict[str, Any]], |
| | *, |
| | padding_to: Optional[int] = None) -> Dict[str, Any]: |
| | labels = [] |
| | new_batch = [] |
| | for b in batch: |
| | keys = [key for key in b.keys() if 'negative' in key] |
| | max_neg = max([int(re.findall(r'negative(-?\d+)', key)[0]) for key in keys]) if keys else None |
| | indexes = ['anchor_', 'positive_'] |
| | if max_neg is not None: |
| | for i in range(0, max_neg + 1): |
| | indexes.append(f'negative{i}_') |
| | for prefix in indexes: |
| | new_batch += self._fetch_inputs_startswith([b], prefix) |
| | labels.extend(b.get('labels', None)) |
| | res = self._data_collator(new_batch, padding_to=padding_to) |
| | if labels: |
| | res['labels'] = torch.tensor(labels, dtype=torch.float32) |
| | return res |
| |
|
| | def _seq_cls_data_collator(self, |
| | batch: List[Dict[str, Any]], |
| | *, |
| | padding_to: Optional[int] = None) -> Dict[str, Any]: |
| | labels = [b.pop('labels') for b in batch if b.get('labels') is not None] |
| | res = self._data_collator(batch, padding_to=padding_to) |
| | if labels: |
| | problem_type = self._get_problem_type(self.config) |
| | if problem_type == 'regression': |
| | labels = torch.tensor(labels, dtype=torch.float32) |
| | elif problem_type == 'multi_label_classification': |
| | one_hot_labels = torch.zeros((len(labels), self.config.num_labels), dtype=torch.float32) |
| | for i, label in enumerate(labels): |
| | one_hot_labels[i, label] = 1 |
| | labels = one_hot_labels |
| | else: |
| | labels = torch.tensor(labels, dtype=torch.long) |
| | res['labels'] = labels |
| | return res |
| |
|
    def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
        """Collate encoded samples into padded batch tensors.

        Args:
            batch(`List[Dict[str, Any]]`): The input data in batch
            padding_to(`int`, optional): Whether padding the batch to a fixed length, if none, the batch
                will be padded to the `longest`
        """
        assert self.tokenizer.pad_token_id is not None
        # inference always left-pads so generation starts flush with the prompt
        padding_side = self.padding_side if self.is_training else 'left'
        padding_right = padding_side == 'right'
        packing_mode = self.use_megatron or self._packing and 'position_ids' in batch[0]
        res = {}
        if packing_mode:
            # packed samples are concatenated into one long row
            for k in ['input_ids', 'labels', 'position_ids', 'loss_scale']:
                v = self.gather_list(batch, k)
                if v:
                    res[k] = [v]
        else:
            inputs_embeds = [b['inputs_embeds'] for b in batch if b.get('inputs_embeds') is not None]
            input_ids = [b['input_ids'] for b in batch if b.get('input_ids') is not None]
            if inputs_embeds:
                res['inputs_embeds'] = inputs_embeds
            if input_ids:
                res['input_ids'] = input_ids
            for key in ['labels', 'loss_scale', 'position_ids', 'token_type_ids']:
                val = [b[key] for b in batch if b.get(key) is not None]
                if val:
                    res[key] = val

        keys = [
            'input_ids', 'inputs_embeds', 'attention_mask', 'labels', 'loss_scale', 'position_ids', 'token_type_ids'
        ]
        pad_values = [self.tokenizer.pad_token_id, 0., 0, -100, 0., 0., 0]
        # Convert lists to tensors, squeeze leading batch dims of 1, and capture
        # the per-sample sequence lengths from the first present key.
        seq_lens = None
        for key in keys:
            if key not in res:
                continue
            for i, val in enumerate(res[key]):
                if isinstance(val, (list, tuple)):
                    val = torch.tensor(val)
                elif key == 'inputs_embeds' and val.ndim == 3 or key != 'inputs_embeds' and val.ndim == 2:
                    # drop a leading batch dim of 1
                    val = val[0]
                res[key][i] = val
            if not seq_lens:
                seq_lens = [seq.shape[0] for seq in res[key]]
        if not packing_mode and seq_lens and ('input_ids' in res or 'inputs_embeds' in res):
            res['attention_mask'] = [torch.ones(seq_len, dtype=torch.int64) for seq_len in seq_lens]
            if self.is_training and self.padding_side == 'left':
                # left padding shifts tokens, so explicit position ids are needed
                res['position_ids'] = [torch.arange(seq_len, dtype=torch.int64) for seq_len in seq_lens]

        if self.use_megatron:
            # megatron wants sequence lengths rounded up to a multiple of 128
            padding_to = math.ceil(max(seq_lens) / 128) * 128
            cp_size = self.sequence_parallel_size
            if cp_size > 1:
                # extend position ids with a repeating 0..2*cp-1 pattern over the pad
                padding_len = padding_to - seq_lens[0]
                position_ids = res['position_ids'][0].tolist()
                position_ids += list(range(cp_size * 2)) * (padding_len // (cp_size * 2))
                res['position_ids'][0] = torch.tensor(position_ids)

        for key, pad_value in zip(keys, pad_values):
            if key not in res:
                continue
            if self.use_megatron and key == 'position_ids' and self.sequence_parallel_size > 1:
                # already padded above with the cp-aware pattern
                pass
            elif padding_to is not None:
                # pre-pad the first row to `padding_to`; _pad_sequence then pads
                # the rest of the batch to this longest row
                padding_len = padding_to - seq_lens[0]
                if padding_len > 0:
                    res[key][0] = F.pad(res[key][0], (0, padding_len) if padding_right else (padding_len, 0),
                                        'constant', pad_value)
            res[key] = self._pad_sequence(res[key], pad_value)

        # multimodal tensors (pixel values etc.) are concatenated, not padded
        res.update(self._data_collator_mm_data(batch))
        if not self.use_megatron and (use_torchacc() or self.sequence_parallel_size > 1):
            res = self._torchacc_xtuner_data_collator(res, padding_to, self.tokenizer, padding_side)

        return res
| |
|
| | def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]: |
| | |
| | res = {} |
| | pixel_values = [b['pixel_values'] for b in batch if b.get('pixel_values') is not None] |
| | if len(pixel_values) > 0: |
| | res['pixel_values'] = torch.concat(pixel_values) |
| |
|
| | image_sizes = [b['image_sizes'] for b in batch if b.get('image_sizes') is not None] |
| | if len(image_sizes) > 0: |
| | res['image_sizes'] = torch.concat(image_sizes) |
| |
|
| | pixel_values_videos = [b['pixel_values_videos'] for b in batch if b.get('pixel_values_videos') is not None] |
| | if len(pixel_values_videos) > 0: |
| | res['pixel_values_videos'] = torch.concat(pixel_values_videos) |
| | return res |
| |
|
    def _torchacc_xtuner_data_collator(self, res, padding_to, tokenizer, padding_side):
        """Re-pad and split an already-collated batch for torchacc and/or
        sequence parallelism, writing the rewritten tensors back into ``res``."""
        input_ids = res.get('input_ids')
        attention_mask = res.get('attention_mask')
        labels = res.get('labels')
        loss_scale = res.get('loss_scale')
        if use_torchacc():
            from swift.utils.torchacc_utils import pad_and_split_batch
            rank, _, world_size, _ = get_dist_setting()
            input_ids, attention_mask, labels, loss_scale = pad_and_split_batch(
                padding_to,
                input_ids,
                attention_mask,
                labels,
                loss_scale,
                self.max_length,
                tokenizer,
                rank,
                world_size,
                padding_right=padding_side == 'right')
        if self.sequence_parallel_size > 1 and input_ids is not None:
            bs, seq_len = input_ids.shape
            if 'position_ids' not in res:
                # default contiguous position ids, one row per sample
                position_ids = torch.arange(seq_len).unsqueeze(0).long().repeat(bs, 1)
            else:
                position_ids = res['position_ids']
            assert padding_side == 'right' or bs == 1, 'Sequence parallel only support padding_side=right'
            from swift.trainers.sequence_parallel import sequence_parallel
            if sequence_parallel.world_size() > 1:
                from swift.trainers.sequence_parallel import sequence_parallel
                input_ids, _, labels, position_ids, attention_mask, loss_scale = \
                    sequence_parallel.pad_and_split_inputs(
                        tokenizer, input_ids, None, labels, position_ids, attention_mask, loss_scale)
            res['position_ids'] = position_ids
        # write back whichever of the four locals is non-None (locals() snapshot
        # keeps the rebinding above visible by name)
        _local_var = locals()
        for key in ['input_ids', 'attention_mask', 'labels', 'loss_scale']:
            value = _local_var[key]
            if value is not None:
                res[key] = value
        return res
| |
|
| | def print_inputs(self, inputs: Dict[str, Any], tokenizer_kwargs: Optional[Dict[str, Any]] = None) -> None: |
| | if tokenizer_kwargs is None: |
| | tokenizer_kwargs = {} |
| | for key in [ |
| | 'input', 'labels', 'generate', 'chosen_input', 'chosen_labels', 'rejected_input', 'rejected_labels' |
| | ]: |
| | val = inputs.get(key) |
| | if val is None: |
| | val = inputs.get(f'{key}_ids') |
| | if val is not None: |
| | key_upper = key.upper() |
| | logger.info(f'[{key_upper}_IDS] {val}') |
| | if key == 'labels' and self.mode in {'seq_cls', 'embedding'}: |
| | continue |
| | if isinstance(val, (list, tuple, torch.Tensor)): |
| | val_str = self.safe_decode(val, **tokenizer_kwargs) |
| | logger.info(f'[{key_upper}] {val_str}') |
| | if inputs.get('loss_scale') is not None: |
| | val = inputs['loss_scale'] |
| | logger.info(f'[LOSS_SCALE] {val}') |
| |
|
| | async def prepare_lmdeploy_pytorch_inputs(self, inputs) -> None: |
| | images = inputs.pop('images', None) or [] |
| | if len(images) == 0: |
| | return |
| | input_ids = inputs['input_ids'] |
| | idx_list = findall(input_ids, -100) |
| | assert len(idx_list) == len(images), f'len(idx_list): {len(idx_list)}, len(images): {len(images)}' |
| | idx_list.insert(0, -1) |
| | new_input_ids = [] |
| | for i in range(len(idx_list) - 1): |
| | new_input_ids += input_ids[idx_list[i] + 1:idx_list[i + 1]] |
| | images[i]['offset'] = len(new_input_ids) |
| | new_input_ids += [images[i]['image_token_id']] * images[i]['image_tokens'] |
| | new_input_ids += input_ids[idx_list[-1] + 1:] |
| | inputs['input_ids'] = new_input_ids |
| | inputs['multimodal'] = images |
| |
|
| | async def prepare_lmdeploy_turbomind_inputs(self, inputs: Dict[str, Any]) -> None: |
| | images = inputs.pop('images', None) or [] |
| | if len(images) == 0: |
| | return |
| | from lmdeploy.vl.constants import IMAGE_DUMMY_TOKEN_INDEX |
| | input_ids = inputs['input_ids'] |
| | idx_list = findall(input_ids, -100) |
| | assert len(idx_list) == len(images), f'len(idx_list): {len(idx_list)}, len(images): {len(images)}' |
| | idx_list.insert(0, -1) |
| | new_input_ids = [] |
| | ranges = [] |
| | for i in range(len(idx_list) - 1): |
| | _range = [] |
| | new_input_ids += input_ids[idx_list[i] + 1:idx_list[i + 1]] |
| | _range.append(len(new_input_ids)) |
| | new_input_ids += [IMAGE_DUMMY_TOKEN_INDEX] * images[i].shape[0] |
| | _range.append(len(new_input_ids)) |
| | ranges.append(_range) |
| | new_input_ids += input_ids[idx_list[-1] + 1:] |
| | inputs['input_embeddings'] = [image.to('cpu') for image in images] |
| | inputs['input_embedding_ranges'] = ranges |
| | inputs['input_ids'] = new_input_ids |
| |
|
| | def _pad_sequence(self, sequences: List[torch.Tensor], padding_value: float = 0.) -> torch.Tensor: |
| | """Pad sequence by some side |
| | |
| | Args: |
| | sequences: The input sequences in tensor. |
| | padding_value: The padding value |
| | |
| | Returns: |
| | A tensor after padding |
| | """ |
| | padding_side = self.padding_side if self.is_training else 'left' |
| | padding_right = padding_side == 'right' |
| | if padding_right: |
| | return pad_sequence(sequences, batch_first=True, padding_value=padding_value) |
| |
|
| | max_len = max([s.shape[0] for s in sequences]) |
| |
|
| | padded_sequences = [] |
| | for seq in sequences: |
| | pad_length = max_len - seq.shape[0] |
| | pad_tuple = [0] * ((seq.dim() - 1) * 2) + [pad_length, 0] |
| | padded_seq = F.pad(seq, tuple(pad_tuple), 'constant', padding_value) |
| | padded_sequences.append(padded_seq) |
| |
|
| | return torch.stack(padded_sequences) |
| |
|
    def safe_decode(self, input_ids: List[int], **tokenizer_kwargs) -> str:
        """Decode ``input_ids`` while rendering undecodable ids as ``[id * count]``.

        "Special" ids are negative values (e.g. -100 labels), floats, and the
        template's placeholder tokens.  Runs of normal ids are decoded with the
        tokenizer; runs of special ids are summarized as ``[id * run_length]``.
        May also be invoked with a raw tokenizer bound as ``self`` (hence the
        ``isinstance`` check below), in which case no placeholder ids are known.
        """
        if isinstance(self, Template):
            tokenizer = self.tokenizer
            placeholder_tokens = self.placeholder_tokens
        else:
            tokenizer = self
            placeholder_tokens = []

        def _is_special(token: int) -> bool:
            # negative ids (e.g. -100) and floats cannot be decoded
            if isinstance(token, float) or token < 0:
                return True
            return token in placeholder_tokens

        if isinstance(input_ids, torch.Tensor):
            input_ids = input_ids.tolist()
        if len(input_ids) == 0:
            return ''
        result_str = ''
        # s: start index of the current special run; e: start of the current normal run
        for i in range(len(input_ids)):
            if i == 0:
                if _is_special(input_ids[i]):
                    s = 0
                else:
                    e = 0
                continue
            if _is_special(input_ids[i]) and not _is_special(input_ids[i - 1]):
                # normal -> special transition: flush the decoded normal run
                s = i
                result_str += tokenizer.decode(input_ids[e:s], **tokenizer_kwargs)
            if not _is_special(input_ids[i]) and _is_special(input_ids[i - 1]):
                # special -> normal transition: render the finished special run
                e = i
                result_str += f'[{input_ids[i - 1]} * {e - s}]'
        # flush the trailing run (i is the last index from the loop above)
        if _is_special(input_ids[i]):
            result_str += f'[{input_ids[i]} * {len(input_ids) - s}]'
        else:
            result_str += tokenizer.decode(input_ids[e:], **tokenizer_kwargs)
        return result_str
| |
|
    @staticmethod
    @contextmanager
    def _patch_flash_attention_forward(modeling_module, position_ids, use_new_func: bool = False):
        """Temporarily patch ``modeling_module._flash_attention_forward`` so that
        the given ``position_ids`` are always forwarded to flash attention.

        With ``use_new_func=True`` the patched call is routed to the shared
        ``transformers.modeling_flash_attention_utils._flash_attention_forward``
        instead of the module's original implementation.  The original function
        is restored on exit, even if the wrapped block raises.
        """
        _origin_flash_attention_forward = modeling_module._flash_attention_forward

        def _flash_attention_forward(*args, **kwargs):
            if use_new_func:
                from transformers.modeling_flash_attention_utils import (_flash_attention_forward as
                                                                         flash_attention_forward)
                # the shared function is unbound: drop a leading module `self`
                if args and isinstance(args[0], nn.Module):
                    args = args[1:]
                if 'is_causal' not in kwargs:
                    kwargs['is_causal'] = True
            else:
                flash_attention_forward = _origin_flash_attention_forward
            kwargs['position_ids'] = position_ids
            return flash_attention_forward(*args, **kwargs)

        modeling_module._flash_attention_forward = _flash_attention_forward
        try:
            yield
        finally:
            modeling_module._flash_attention_forward = _origin_flash_attention_forward
| |
|