import hashlib
import inspect
import math
import os
import re
from contextlib import contextmanager, nullcontext
from copy import deepcopy
from dataclasses import asdict
from functools import partial, wraps
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from modelscope.hub.utils.utils import get_cache_dir
from peft import PeftModel
from PIL import Image
from torch.nn.utils.rnn import pad_sequence
from transformers import StoppingCriteriaList
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils import strtobool

from swift.utils import get_dist_setting, get_env_args, get_logger, use_torchacc
from ..utils import Processor, ProcessorMixin
from .template_inputs import InferRequest, StdTemplateInputs, TemplateInputs
from .utils import Context, ContextType, StopWordsCriteria, fetch_one, findall, split_str_parts_by
from .vision_utils import load_audio, load_batch, load_image, rescale_image

logger = get_logger()
if TYPE_CHECKING:
    from .template_meta import TemplateMeta


class MaxLengthError(ValueError):
    pass


class Template(ProcessorMixin):
    special_tokens = ['<image>', '<video>', '<audio>', '<bbox>', '<ref-object>', '<cot-process>', '<start-image>']
    special_keys = ['images', 'videos', 'audios', 'objects']

    image_placeholder = ['<image>']
    video_placeholder = ['<video>']
    audio_placeholder = ['<audio>']
    cot_process_placeholder = ['ки']
    placeholder_tokens = []
    load_images = True
    skip_prompt = True
    use_model = False
    norm_bbox = 'norm1000'

    is_encoder_decoder = False

    def __init__(
        self,
        processor: Processor,
        template_meta: 'TemplateMeta',
        default_system: Optional[str] = None,
        max_length: Optional[int] = None,
        *,
        use_chat_template: bool = True,
        truncation_strategy: Literal['raise', 'left', 'right'] = 'raise',
        max_pixels: Optional[int] = None,
        agent_template: Optional[str] = None,
        norm_bbox: Literal['norm1000', 'none', None] = None,
        response_prefix: Optional[str] = None,
        padding_side: Literal['left', 'right'] = 'right',
        loss_scale: str = 'default',
        sequence_parallel_size: int = 1,
        template_backend: Literal['swift', 'jinja'] = 'swift',
    ) -> None:
        """
        default_system: Override the default_system in the template.
        max_length: Max length of the sequence.
        truncation_strategy: The truncation strategy.
        max_pixels: Rescale images to reduce memory usage. The default `None` means no limitation,
            e.g. 512 * 512 (H*W).
        padding_side: The padding side when the training batch_size >= 2.
        loss_scale: The loss scale function to use.
        """
        from .template_meta import TemplateMeta
        from swift.plugin import agent_templates

        self.processor = processor
        self.model_info = processor.model_info
        self.config = self.model_info.config
        self.model_meta = processor.model_meta
        if max_length is None:
            max_length = self.model_info.max_model_len
        tokenizer = self.tokenizer

        if not use_chat_template:
            template_meta = template_meta.to_generate_template_meta()
        else:
            template_meta = deepcopy(template_meta)

        template_meta.check_system(default_system)
        if default_system is not None:
            template_meta.default_system = default_system
        if response_prefix is not None:
            template_meta.response_prefix = response_prefix
        logger.info(f'default_system: {repr(template_meta.default_system)}')
        logger.info(f'response_prefix: {repr(template_meta.response_prefix)}')

        for i, token in enumerate(self.placeholder_tokens):
            if isinstance(token, str):
                self.placeholder_tokens[i] = tokenizer.convert_tokens_to_ids(token)
        template_meta.init(tokenizer)

        self.template_meta: TemplateMeta = template_meta
        self.use_chat_template = use_chat_template
        self.template_backend = template_backend
        self.max_length = max_length
        self.truncation_strategy = truncation_strategy
        self.loss_scale = loss_scale
        self.max_pixels = max_pixels
        self.padding_side = padding_side
        self.sequence_parallel_size = sequence_parallel_size
        agent_template = agent_template or template_meta.agent_template
        logger.info(f'agent_template: {agent_template}')
        self.agent_template = agent_templates[agent_template]()
        self.norm_bbox = norm_bbox or self.norm_bbox
        logger.info(f'max_length: {self.max_length}')
        logger.info(f'norm_bbox: {self.norm_bbox}')
        if self.is_encoder_decoder:
            self.skip_prompt = False
        self.mode: Literal['pt', 'vllm', 'lmdeploy', 'train', 'rlhf', 'kto', 'seq_cls', 'embedding', 'prm'] = 'pt'
        self._packing = False
        self.use_megatron = False
        if self.model_info.task_type != 'causal_lm':
            self.mode = self.model_info.task_type
        self._handles = []
        self._deepspeed_initialize = None

    @staticmethod
    def _load_image(image, load_images: bool):
        if load_images:
            if isinstance(image, dict) and 'bytes' in image:
                image = image['bytes'] or image['path']
            image = load_image(image)
        else:
            if isinstance(image, dict):
                path = image['path']
                if path and (path.startswith('http') or os.path.exists(path)):
                    image = path
                else:
                    image = load_image(image['bytes'])
            elif not isinstance(image, str):
                image = load_image(image)
        return image

    @staticmethod
    def _get_height_width(inputs: StdTemplateInputs) -> None:
        width = []
        height = []
        for image in inputs.images:
            width.append(image.width)
            height.append(image.height)
        inputs.objects['width'] = width
        inputs.objects['height'] = height

    def normalize_bbox(self, inputs: StdTemplateInputs) -> None:
        objects = inputs.objects
        bbox_list = objects['bbox']
        width_list = objects['width']
        height_list = objects['height']
        bbox_type = objects.pop('bbox_type', None) or 'real'
        image_id_list = objects.pop('image_id', None) or []
        image_id_list += [0] * (len(bbox_list) - len(image_id_list))
        for bbox, image_id in zip(bbox_list, image_id_list):
            if bbox_type == 'norm1':
                width, height = 1, 1
            else:
                width, height = width_list[image_id], height_list[image_id]
            for i, (x, y) in enumerate(zip(bbox[::2], bbox[1::2])):
                if self.norm_bbox == 'norm1000':
                    norm_width, norm_height = 1000, 1000
                elif self.norm_bbox == 'none':
                    image = inputs.images[image_id]
                    norm_width, norm_height = image.width, image.height
                bbox[2 * i] = int(round(x / width * norm_width))
                bbox[2 * i + 1] = int(round(y / height * norm_height))
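
    # Illustrative example (assumed values): with norm_bbox='norm1000' and a 1000x500 image,
    # a pixel-space bbox [250, 100, 750, 300] is rescaled in place to [250, 200, 750, 600],
    # i.e. each x maps to round(x / width * 1000) and each y to round(y / height * 1000).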

    def _preprocess_function_call(self, inputs: StdTemplateInputs) -> None:
        agent_template = self.agent_template
        agent_template.template_meta = self.template_meta
        if inputs.tools:
            if isinstance(inputs.tools, str):
                inputs.tools = agent_template._parse_json(inputs.tools)
                if not isinstance(inputs.tools, (list, tuple)):
                    inputs.tools = [inputs.tools]
            elif isinstance(inputs.tools, (list, tuple)):
                inputs.tools = [agent_template._parse_json(tool) for tool in inputs.tools]
            else:
                raise ValueError(f'inputs.tools: {inputs.tools}')
            for i, tool in enumerate(inputs.tools):
                inputs.tools[i] = agent_template.wrap_tool(tool)
        i = 0
        messages = inputs.messages
        while i < len(messages):
            if messages[i]['role'] == 'tool_call':
                i_start = i
                while i + 1 < len(messages) and messages[i + 1]['role'] == 'tool_call':
                    i += 1
                tool_content = self.agent_template._format_tool_calls(messages[i_start:i + 1])
                messages[i_start:i + 1] = [{'role': 'assistant', 'content': tool_content}]
                i = i_start + 1
            else:
                i += 1

    def _preprocess_inputs(
        self,
        inputs: StdTemplateInputs,
    ) -> None:
        self._preprocess_function_call(inputs)
        if self.model_meta.is_multimodal:
            self._replace_image_tags(inputs)
            self._replace_start_image_tags(inputs)
        images = inputs.images
        load_images = self.load_images or self.mode in {'vllm', 'lmdeploy'}
        load_images_origin = load_images
        if self.max_pixels is not None or inputs.objects:
            load_images = True
        if images:
            for i, image in enumerate(images):
                images[i] = self._load_image(images[i], load_images)
        if inputs.objects:
            self._get_height_width(inputs)
        if self.max_pixels is not None:
            images = [rescale_image(img, self.max_pixels) for img in images]
        if images and not load_images_origin:
            for i, image in enumerate(images):
                if isinstance(image, Image.Image):
                    images[i] = self._save_pil_image(image)
        inputs.images = images

        if self.mode == 'vllm' and inputs.audios:
            sampling_rate = get_env_args('sampling_rate', int, None)
            inputs.audios = load_batch(
                inputs.audios, load_func=partial(load_audio, sampling_rate=sampling_rate, return_sr=True))

        if inputs.is_multimodal:
            self._add_default_tags(inputs)

    @staticmethod
    def _replace_image_tags(inputs: StdTemplateInputs):
        if inputs.images:
            return
        images = []
        pattern = r'<img>(.+?)</img>'
        for message in inputs.messages:
            content = message['content']
            if not isinstance(content, str):
                continue
            for image in re.findall(pattern, content):
                if os.path.isfile(image):
                    images.append(image)
                else:
                    logger.warning_once(f'Failed to parse image path: `{content}`.', hash_id='<img></img>')
            message['content'] = re.sub(pattern, '<image>', content)
        inputs.images = images

    @staticmethod
    def _replace_start_image_tags(inputs: StdTemplateInputs):
        generate_mode = False
        message = inputs.messages[-1]
        content = message['content']
        if message['role'] == 'user' and content.endswith('<start-image>'):
            generate_mode = True
            message['content'] = message['content'][:-len('<start-image>')]
        inputs.generate_mode = generate_mode

    @staticmethod
    def _extend_tokens(input_ids: List[int], labels: Optional[List[int]], replace_idx_list: List[int],
                       get_new_tokens: Callable[[int], List[int]]) -> Tuple[List[int], Optional[List[int]]]:
        added_tokens_len = 0
        for i, idx in enumerate(replace_idx_list):
            new_tokens = get_new_tokens(i)
            token_len = len(new_tokens)
            input_ids = input_ids[:idx + added_tokens_len] + new_tokens + input_ids[added_tokens_len + idx + 1:]
            if labels:
                labels = labels[:idx + added_tokens_len] + [-100] * token_len + labels[added_tokens_len + idx + 1:]
            added_tokens_len += token_len - 1
        return input_ids, labels
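
    # Illustrative example (assumed values): with input_ids [1, 9, 2] where the 9 at index 1 is a
    # placeholder, replace_idx_list=[1] and get_new_tokens returning [7, 7, 7] yields
    # input_ids [1, 7, 7, 7, 2]; the corresponding labels span is filled with -100.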

    def compute_loss_context(self, model, inputs):
        return nullcontext()

    def _rlhf_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        chosen_inputs, rejected_inputs = inputs, deepcopy(inputs)
        assert chosen_inputs.rejected_response is not None, f'inputs: {inputs}'
        rejected_inputs.messages[-1]['content'] = chosen_inputs.rejected_response
        chosen_encoded = self._encode_truncated(chosen_inputs)
        rejected_encoded = self._encode_truncated(rejected_inputs)

        encoded = {}
        for prefix in ['chosen', 'rejected']:
            data = locals()[f'{prefix}_encoded']
            for k, v in data.items():
                encoded[f'{prefix}_{k}'] = v
        return encoded

    def _kto_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        label, inputs.label = inputs.label, None
        encoded = self._rlhf_encode(inputs)
        encoded['label'] = bool(label)
        return encoded

    def _embedding_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        _encoded = {}
        labels = []

        def split_multi_medias(_inputs):
            _content = _inputs.messages[-2]['content']
            image_size = len(re.findall('<image>', _content))
            video_size = len(re.findall('<video>', _content))
            audio_size = len(re.findall('<audio>', _content))
            _inputs.images = inputs.images[:image_size]
            assert len(_inputs.images) == image_size
            inputs.images = inputs.images[image_size:]
            _inputs.videos = inputs.videos[:video_size]
            assert len(_inputs.videos) == video_size
            inputs.videos = inputs.videos[video_size:]
            _inputs.audios = inputs.audios[:audio_size]
            assert len(_inputs.audios) == audio_size
            inputs.audios = inputs.audios[audio_size:]

        anchor = deepcopy(inputs)
        anchor.messages[-1]['content'] = ''
        anchor.rejected_response = []
        split_multi_medias(anchor)
        anchor_encoded = self._encode_truncated(anchor)
        for key in anchor_encoded:
            _encoded[f'anchor_{key}'] = anchor_encoded[key]

        positive = deepcopy(inputs)
        positive.messages[-2]['content'] = positive.messages[-1]['content']
        positive.messages[-1]['content'] = ''
        positive.rejected_response = []
        split_multi_medias(positive)
        positive_encoded = self._encode_truncated(positive)
        for key in positive_encoded:
            _encoded[f'positive_{key}'] = positive_encoded[key]
        labels.append(float(inputs.label) if inputs.label is not None else 1.0)

        rejected_len = len(inputs.rejected_response) if inputs.rejected_response else 0
        for i in range(rejected_len):
            negative = deepcopy(inputs)
            negative.messages[-2]['content'] = negative.rejected_response[i]
            negative.messages[-1]['content'] = ''
            negative.rejected_response = []
            split_multi_medias(negative)
            negative_encoded = self._encode_truncated(negative)
            for key in negative_encoded:
                _encoded[f'negative{i}_{key}'] = negative_encoded[key]
            labels.append(0.0)

        _encoded['labels'] = labels
        return _encoded

    def _seq_cls_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        encoded = self._encode_truncated(inputs)
        encoded.pop('labels', None)
        if inputs.label is not None:
            labels = inputs.label
            problem_type = self._get_problem_type(self.config, labels=labels)
            if problem_type == 'single_label_classification':
                labels = int(labels)
            encoded['labels'] = labels
        return encoded

    @torch.inference_mode()
    def encode(self,
               inputs: Union[TemplateInputs, Dict[str, Any], InferRequest],
               return_template_inputs: bool = False) -> Dict[str, Any]:
        """The entry method of Template.

        Returns:
            A dict such as {'input_ids': List[int], 'labels': Optional[List[int]], ...}
        """
        if isinstance(inputs, (InferRequest, TemplateInputs)):
            inputs = asdict(inputs)

        if isinstance(inputs, dict):
            inputs = deepcopy(inputs)
            if not self.is_training:
                InferRequest.remove_response(inputs['messages'])
            inputs = StdTemplateInputs.from_dict(inputs)
        elif isinstance(inputs, StdTemplateInputs):
            inputs = deepcopy(inputs)
        assert isinstance(inputs, StdTemplateInputs)
        self._preprocess_inputs(inputs)

        if self.mode in {'pt', 'train', 'prm', 'vllm', 'lmdeploy'}:
            encoded = self._encode_truncated(inputs)
        elif self.mode == 'seq_cls':
            encoded = self._seq_cls_encode(inputs)
        elif self.mode == 'rlhf':
            encoded = self._rlhf_encode(inputs)
        elif self.mode == 'kto':
            encoded = self._kto_encode(inputs)
        elif self.mode == 'embedding':
            encoded = self._embedding_encode(inputs)
        for key in list(encoded.keys()):
            if encoded[key] is None:
                encoded.pop(key)
        if return_template_inputs:
            encoded['template_inputs'] = inputs
        return encoded
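
    # Illustrative usage (a sketch with assumed names; a template instance normally comes from
    # swift's template factory helpers rather than being constructed by hand):
    #     template.set_mode('train')
    #     encoded = template.encode({'messages': [
    #         {'role': 'user', 'content': 'Hello'},
    #         {'role': 'assistant', 'content': 'Hi!'},
    #     ]})
    #     encoded['input_ids']  # List[int]; with the default loss_scale, encoded['labels']
    #                           # masks the prompt tokens with -100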

    def packing_row(self, row: List[Tuple[Dict[str, Any], int]]) -> Dict[str, Any]:
        packed = {}
        keys = set()
        for r in row:
            keys.update(r[0].keys())
        for key in keys:
            if key in {'input_ids', 'labels', 'loss_scale'}:
                packed[key] = sum((x[0][key] for x in row), start=[])
        if 'position_ids' not in packed:
            packed['position_ids'] = sum((list(range(x[1])) for x in row), start=[])

        packed.update(self._data_collator_mm_data([r[0] for r in row]))
        return packed
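
    # Illustrative example (assumed values): packing two rows of lengths 3 and 2 concatenates
    # their input_ids/labels and yields position_ids [0, 1, 2, 0, 1], so each packed sample
    # restarts its positions at 0.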

    def _post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs

    @staticmethod
    def _skip_stop_tokens(generate_ids: List[int], stop_tokens: List[int], is_finished: bool) -> List[int]:
        len_tokens = len(stop_tokens)
        if is_finished and generate_ids[-len_tokens:] == stop_tokens:
            return generate_ids[:-len_tokens]
        if not is_finished:
            for i in range(len_tokens, 0, -1):
                if generate_ids[-i:] == stop_tokens[:i]:
                    return generate_ids[:-i]
        return generate_ids
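
    # Illustrative example (assumed values): with stop_tokens [5, 6], a finished stream ending in
    # [..., 5, 6] is trimmed; an unfinished stream ending in the partial prefix [..., 5] is also
    # trimmed, so a half-emitted stop word never leaks into streaming output.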

    @staticmethod
    def _get_seq_cls_logprobs(pred: int, logprobs: torch.Tensor, top_logprobs: int):
        idxs = logprobs.argsort(descending=True, dim=-1)[:top_logprobs].tolist()
        logprobs = logprobs.tolist()
        return {
            'content': [{
                'index': pred,
                'logprobs': [logprobs[p] for p in pred] if isinstance(pred, (list, tuple)) else logprobs[pred],
                'top_logprobs': [{
                    'index': idx,
                    'logprob': logprobs[idx]
                } for idx in idxs]
            }]
        }

    @staticmethod
    def _get_problem_type(config, labels=None, logits=None) -> str:
        problem_type = config.problem_type
        if problem_type is not None:
            return problem_type
        if labels is not None:
            if isinstance(labels, (list, tuple)):
                if labels and isinstance(labels[0], float):
                    problem_type = 'regression'
                else:
                    problem_type = 'multi_label_classification'
            else:
                problem_type = 'single_label_classification'
                assert config.num_labels >= labels + 1
        if logits is not None:
            if logits.shape[-1] == 1:
                problem_type = 'regression'
            else:
                problem_type = 'single_label_classification'
        assert problem_type is not None
        config.problem_type = problem_type
        return problem_type
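
    # Inference rules sketch (assumed values): labels=[0.3, 0.7] -> 'regression';
    # labels=[1, 3] -> 'multi_label_classification'; labels=2 -> 'single_label_classification'
    # (requires config.num_labels >= 3); logits of shape (..., 1) -> 'regression'.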

    def decode_seq_cls(self, logits: torch.Tensor, top_logprobs: int):
        assert isinstance(logits, torch.Tensor)
        problem_type = self._get_problem_type(self.config, logits=logits)
        if problem_type == 'regression':
            preds = logits.squeeze(dim=-1).tolist()
            logprobs = [None] * len(preds)
        else:
            if problem_type == 'single_label_classification':
                preds = torch.argmax(logits, dim=-1).tolist()
                logprobs = torch.log_softmax(logits, -1)
            else:
                preds = [(logprob >= 0.5).nonzero(as_tuple=True)[0].tolist() for logprob in torch.sigmoid(logits)]
                logprobs = F.logsigmoid(logits)
            logprobs = [self._get_seq_cls_logprobs(pred, logprobs[i], top_logprobs) for i, pred in enumerate(preds)]
        return preds, logprobs

    def decode(self,
               generate_ids: List[int],
               *,
               is_finished: bool = True,
               tokenizer_kwargs=None,
               first_token=True,
               **kwargs) -> Any:
        tokenizer_kwargs = tokenizer_kwargs or {}
        response = self._skip_stop_decode(generate_ids, is_finished, **tokenizer_kwargs)
        if first_token and self.template_meta.response_prefix:
            response = self.template_meta.response_prefix + response
        return response

    def decode_prm(self, input_ids: torch.Tensor, logits: torch.Tensor) -> Any:
        raise NotImplementedError

    def generate(self, model, *args, **kwargs):
        if isinstance(model, PeftModel):
            signature = inspect.signature(model.model.generate)
        else:
            signature = inspect.signature(model.generate)
        if 'use_model_defaults' in signature.parameters and 'use_model_defaults' not in kwargs:
            kwargs['use_model_defaults'] = False
        return model.generate(*args, **kwargs)

    def _skip_stop_decode(self, generate_ids: List[int], is_finished: bool, **decode_kwargs) -> Any:
        tokenizer = self.tokenizer

        if len(generate_ids) > 0 and generate_ids[-1] == tokenizer.eos_token_id:
            generate_ids = generate_ids[:-1]

        template_suffix = self.template_meta.suffix[-1]
        if isinstance(template_suffix, str):
            template_suffix = tokenizer.encode(template_suffix, add_special_tokens=False)[-1:]
        generate_ids = self._skip_stop_tokens(generate_ids, template_suffix, is_finished)
        if 'spaces_between_special_tokens' not in decode_kwargs:
            decode_kwargs['spaces_between_special_tokens'] = False
        return tokenizer.decode(generate_ids, **decode_kwargs)

    def prepare_generate_kwargs(self, generate_kwargs: Dict[str, Any], *, model=None) -> Dict[str, Any]:
        generation_config = generate_kwargs['generation_config']
        stop_words = getattr(generation_config, 'stop_words', None) or self.template_meta.stop_words
        generate_kwargs['stopping_criteria'] = StoppingCriteriaList([StopWordsCriteria(self.tokenizer, stop_words)])
        return generate_kwargs

    @staticmethod
    def _save_pil_image(image: Image.Image) -> str:
        img_bytes = image.tobytes()
        img_hash = hashlib.sha256(img_bytes).hexdigest()
        tmp_dir = os.path.join(get_cache_dir(), 'tmp', 'images')
        logger.info_once(f'create tmp_dir: {tmp_dir}')
        os.makedirs(tmp_dir, exist_ok=True)
        img_path = os.path.join(tmp_dir, f'{img_hash}.png')
        if not os.path.exists(img_path):
            image.save(img_path)
        return img_path

    @staticmethod
    def _concat_context_list(
            context_list: List[Context],
            res_context_list: List[Context],
            res_context_type: List[ContextType],
            system: Optional[str] = None,
            query: Optional[str] = None,
            response: Optional[str] = None,
            round0: Optional[int] = None) -> None:
        """Concat the context list and replace placeholders."""
        round1 = None
        if round0 is not None:
            round1 = str(round0 + 1)
            round0 = str(round0)
        for context in context_list:
            if isinstance(context, str):
                if '{{RESPONSE}}' == context:
                    assert response is not None
                    res_context_list.append(response)
                    res_context_type.append(ContextType.RESPONSE)
                    continue
                old_str_list = ['{{SYSTEM}}', '{{QUERY}}', '{{ROUND0}}', '{{ROUND1}}']
                new_str_list = [system, query, round0, round1]
                for (old_str, new_str) in zip(old_str_list, new_str_list):
                    if new_str is not None and old_str in context:
                        assert isinstance(new_str, str), f'new_str: {new_str}'
                        context = context.replace(old_str, new_str)
            if len(context) == 0:
                continue
            res_context_list.append(context)
            res_context_type.append(ContextType.OTHER)

    def _simplify_context_list(self, context_list: List[Context], loss_scale_list: List[float],
                               inputs: StdTemplateInputs) -> Tuple[List[Context], List[float]]:
        """Merge adjacent string contexts to simplify the inputs."""
        context_list, loss_scale_list = self._split_special_tokens(context_list, loss_scale_list)
        context_list, loss_scale_list = self._pre_tokenize(context_list, loss_scale_list, inputs)

        res: List[Context] = []
        res_loss_scale: List[float] = []
        temp: List[str] = []
        temp_loss_scale = 0.
        for i, (context, loss_scale) in enumerate(zip(context_list, loss_scale_list)):
            if isinstance(context, str) and (loss_scale == temp_loss_scale):
                temp.append(context)
            else:
                if len(temp) > 0:
                    res.append(''.join(temp))
                    res_loss_scale.append(temp_loss_scale)
                    temp.clear()
                if isinstance(context, str):
                    temp.append(context)
                else:
                    res.append(context)
                    res_loss_scale.append(loss_scale)
                temp_loss_scale = loss_scale
        if len(temp) > 0:
            res.append(''.join(temp))
            res_loss_scale.append(temp_loss_scale)

        return res, res_loss_scale

    @staticmethod
    def _split_special_tokens(context_list: List[Context],
                              loss_scale_list: List[float]) -> Tuple[List[Context], List[float]]:
        """Split special tokens such as `<image>` and `<video>`; this helps the replace_tag operation."""
        res: List[Context] = []
        loss_scale_res: List[float] = []
        for context, loss_scale in zip(context_list, loss_scale_list):
            contexts = []
            if isinstance(fetch_one(context), str):
                for d in split_str_parts_by(context, Template.special_tokens):
                    contexts.extend([d['key'], d['content']])
                contexts = [c for c in contexts if c]
                res.extend(contexts)
                loss_scale_res.extend([loss_scale] * len(contexts))
            else:
                res.append(context)
                loss_scale_res.append(loss_scale)
        return res, loss_scale_res

    def _tokenize(self, context, **tokenizer_kwargs):
        return self.tokenizer(
            context, return_attention_mask=False, add_special_tokens=False, **tokenizer_kwargs)['input_ids']

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        """Override this function to do your own replace operation.

        This method is used to replace standard tags like `<image>` with the tokens the model needs.

        Args:
            media_type: The media modality.
            index: The index of the media; e.g., index 0 refers to the first element in `images`.
            inputs: The inputs.

        Returns:
            The content or input_ids after replacement.
        """
        if media_type == 'image':
            if self.mode == 'lmdeploy':
                return [[-100]]
            return self.image_placeholder
        elif media_type == 'video':
            return self.video_placeholder
        elif media_type == 'audio':
            return self.audio_placeholder

    def replace_ref(self, ref: str, index: int, inputs: StdTemplateInputs) -> List[Context]:
        """Replace an object referenced by a bbox with contents or input_ids. This is useful in grounding tasks.
        Override this function to do your own replace operation.

        Args:
            ref: Description of the bbox.
            index: The index in the `objects` key.
            inputs: The inputs.

        Returns:
            The replacement contents or input_ids.
        """
        return [ref]

    def replace_cot_process(self, inputs: StdTemplateInputs) -> List[Context]:
        """Replace the cot process label for PRM training or inference.
        Override this function to do your own replace operation.

        Args:
            inputs: The inputs.

        Returns:
            The replacement contents or input_ids.
        """
        return [self.cot_process_placeholder]

    @staticmethod
    def _get_bbox_str(bbox: List[int]) -> str:
        point = []
        for x, y in zip(bbox[::2], bbox[1::2]):
            point.append(f'({x},{y})')
        return ','.join(point)
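
    # Illustrative example (assumed values): _get_bbox_str([10, 20, 30, 40]) -> '(10,20),(30,40)',
    # so replace_bbox below renders the bbox as '[(10,20),(30,40)]' by default.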

    def replace_bbox(self, bbox: List[int], index: int, inputs: StdTemplateInputs) -> List[Context]:
        """Replace a bbox pointing to an object with contents or input_ids. This is useful in grounding tasks.
        Override this function to do your own replace operation.

        Args:
            bbox: [x, y] or [x1, y1, x2, y2].
            index: The index in the `objects` key.
            inputs: The inputs.

        Returns:
            The replacement contents or input_ids.
        """
        return [f'[{self._get_bbox_str(bbox)}]']

    def _pre_tokenize_images(self, context_list: List[Context], loss_scale_list: List[float],
                             inputs: StdTemplateInputs) -> Tuple[List[Context], List[float]]:
        res: List[Context] = []
        res_loss_scale: List[float] = []
        inputs.image_idx = 0

        for context, loss_scale in zip(context_list, loss_scale_list):
            if context == '<image>' and inputs.is_multimodal and inputs.image_idx < len(inputs.images):
                c_list = self.replace_tag('image', inputs.image_idx, inputs)
                inputs.image_idx += 1
                loss_scale = 0. if self.template_backend == 'swift' else 1.
            else:
                c_list = [context]
            res += c_list
            res_loss_scale += [loss_scale] * len(c_list)
        return res, res_loss_scale

    def _pre_tokenize(self, context_list: List[Context], loss_scale_list: List[float],
                      inputs: StdTemplateInputs) -> Tuple[List[Context], List[float]]:
        """This method runs before tokenization, replacing standard tags with the contents or input_ids
        needed by the model.

        Args:
            context_list: The content list.
            loss_scale_list: The loss scale list.
        Returns:
            The context_list and loss_scale_list after replacement.
        """
        context_list, loss_scale_list = self._pre_tokenize_images(context_list, loss_scale_list, inputs)
        if inputs.images and inputs.objects:
            self.normalize_bbox(inputs)

        res: List[Context] = []
        res_loss_scale: List[float] = []

        for k in ['video', 'audio', 'object', 'box']:
            setattr(inputs, f'{k}_idx', 0)

        for context, loss_scale in zip(context_list, loss_scale_list):
            for k in ['video', 'audio']:
                if context == f'<{k}>' and inputs.is_multimodal and getattr(inputs, f'{k}_idx') < len(
                        getattr(inputs, f'{k}s')):
                    c_list = self.replace_tag(k, getattr(inputs, f'{k}_idx'), inputs)
                    setattr(inputs, f'{k}_idx', getattr(inputs, f'{k}_idx') + 1)
                    loss_scale = 0.
                    break
            else:
                ref = inputs.objects.get('ref') or []
                bbox = inputs.objects.get('bbox') or []
                if context == '<ref-object>' and inputs.ref_idx < len(ref):
                    idx = inputs.ref_idx
                    c_list = self.replace_ref(ref[idx], idx, inputs)
                    inputs.ref_idx += 1
                elif context == '<bbox>' and inputs.bbox_idx < len(bbox):
                    idx = inputs.bbox_idx
                    c_list = self.replace_bbox(bbox[idx], idx, inputs)
                    inputs.bbox_idx += 1
                elif context == '<cot-process>' and self.mode == 'prm':
                    c_list = self.replace_cot_process(inputs)
                else:
                    c_list = [context]
            res += c_list
            res_loss_scale += [loss_scale] * len(c_list)
        return res, res_loss_scale

    @staticmethod
    def _add_default_tags(inputs: StdTemplateInputs):
        total_content = '\n'.join([message['content'] or '' for message in inputs.messages])
        if inputs.rejected_response:
            if isinstance(inputs.rejected_response, str):
                total_content += inputs.rejected_response
            else:
                total_content += '\n'.join(inputs.rejected_response)
        if inputs.system:
            total_content = f'{inputs.system}\n{total_content}'
        for media_type in ['image', 'audio', 'video']:
            media_key, media_tag = f'{media_type}s', f'<{media_type}>'
            medias = getattr(inputs, media_key)
            if not isinstance(medias, list):
                medias = [medias]
            if medias:
                num_media_tags = len(re.findall(media_tag, total_content))
                num_media = len(medias)
                num_new_tags = num_media - num_media_tags
                if num_new_tags > 0:
                    inputs.messages[0]['content'] = media_tag * num_new_tags + inputs.messages[0]['content']
                elif num_new_tags < 0:
                    logger.warning(
                        f'num_media: {num_media}, num_media_tags: {num_media_tags}, total_content: {total_content}. '
                        'We will only replace the frontmost media_tags while keeping the subsequent media_tags.')

    def _encode_context_list(
            self,
            context_list: List[Context],
            loss_scale_list: Optional[List[float]] = None) -> Tuple[List[int], List[int], List[float], Dict[str, Any]]:
        """return: input_ids, labels, loss_scale, tokenizer_kwargs"""
        input_ids: List[int] = []
        labels: List[int] = []
        loss_scale: List[float] = []
        tokenizer_kwargs = {}
        if loss_scale_list is None:
            loss_scale_list = [0.] * len(context_list)
        ignore_loss_scale = all(loss_scale in {0, 1} for loss_scale in loss_scale_list)
        for i, (context, loss_weight) in enumerate(zip(context_list, loss_scale_list)):
            if isinstance(context, str):
                token_list = self._tokenize(context)
            else:
                token_list = context
            input_ids += token_list
            if loss_scale_list[i] > 0.0:
                labels += token_list
            else:
                labels += [-100] * len(token_list)
            if not ignore_loss_scale:
                loss_scale.extend([loss_weight] * len(token_list))
        if ignore_loss_scale:
            loss_scale = None
        return input_ids, labels, loss_scale, tokenizer_kwargs

    @staticmethod
    def _add_dynamic_eos(input_ids: List[int], labels: List[int], loss_scale: Optional[List[int]],
                         suffix_tokens_id: List[int]) -> None:
        suffix_len = len(suffix_tokens_id)
        start = 0
        for i in range(1, len(labels)):
            if labels[i - 1] >= 0 and labels[i] == -100:
                start = i
            if start > 0 and labels[i - 1] == -100 and labels[i] >= 0:
                length = i - start
                if length >= suffix_len and input_ids[start:start + suffix_len] == suffix_tokens_id:
                    labels[start:start + suffix_len] = suffix_tokens_id
                    if loss_scale and loss_scale[start:start + suffix_len] == [0] * suffix_len:
                        loss_scale[start:start + suffix_len] = [1] * suffix_len

    @staticmethod
    def _get_std_messages(messages):
        if messages and messages[0]['role'] == 'assistant':
            messages.insert(0, {'role': 'user', 'content': ''})
        if len(messages) % 2 == 1:
            messages.append({'role': 'assistant', 'content': None})

    def _jinja_encode(self, inputs: StdTemplateInputs):
        messages = inputs.messages.copy()
        if inputs.system is not None:
            messages.insert(0, {'role': 'system', 'content': inputs.system})
        if messages[-1]['content'] is None:
            messages.pop()
        add_generation_prompt = messages[-1]['role'] != 'assistant'
        kwargs = {}
        if inputs.tools:
            kwargs['tools'] = inputs.tools
        text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=add_generation_prompt, **kwargs)
        answer_len = 1 if self.is_training else 0
        return [text], [1.], answer_len

    def _get_system(self, inputs) -> Optional[str]:
        template_meta = self.template_meta
        system = inputs.system
        tools = inputs.tools
        template_meta.check_system(system)
        if system is None:
            system = template_meta.default_system

        if tools is not None:
            system = self.agent_template._format_tools(tools, system or '', inputs.messages[0])
        return system

    @staticmethod
    def _swift_prepare_function_call(agent_template, messages):
        if len(messages) < 2:
            return
        i = 1
        while i < len(messages):
            pre_message, message = messages[i - 1], messages[i]
            pre_role, pre_content = pre_message['role'], pre_message['content']
            role, content = message['role'], message['content']
            if pre_role == 'assistant' and role == 'tool':
                i_start = i
                while i + 1 < len(messages) and messages[i + 1]['role'] == 'tool':
                    i += 1
                pre_message['content'], tool_content = agent_template._format_tool_responses(
                    pre_content, messages[i_start:i + 1])
                messages[i_start:i + 1] = [{'role': 'tool', 'content': tool_content}]
                i = i_start + 1
            elif pre_role == 'assistant' and role == 'assistant':
                pre_message['content'] = pre_content + content
                messages.pop(i)
            else:
                i += 1

    def _swift_encode(self, inputs: StdTemplateInputs):
        template_meta = self.template_meta
        system = self._get_system(inputs)
        self._swift_prepare_function_call(self.agent_template, inputs.messages)

        self._get_std_messages(inputs.messages)
        n_round = len(inputs.messages) // 2
        if n_round > 1 and not self.template_meta.support_multi_round:
            logger.warning_once(
                'The template does not support multi-round chat. Only use the last round of the conversation.')
            inputs.messages = inputs.messages[-2:]

        res_context_list: List[Context] = []
        res_context_types: List[ContextType] = []
        sep_token = None
        if template_meta.auto_add_bos:
            all_tokens = self.tokenizer.encode('a')
            single_token = self.tokenizer.encode('a', add_special_tokens=False)
            assert len(single_token) == 1
            idx = all_tokens.index(single_token[0])
            bos_token = all_tokens[:idx]
            sep_token = all_tokens[idx + 1:]
            if bos_token:
                res_context_list.append(bos_token)
                res_context_types.append(ContextType.OTHER)

        prefix = template_meta.system_prefix if system else template_meta.prefix
        self._concat_context_list(prefix, res_context_list, res_context_types, system=system)

        n_round = len(inputs.messages) // 2
        for i, (query_message, response_message) in enumerate(zip(inputs.messages[::2], inputs.messages[1::2])):
            query_role, query = query_message['role'], query_message['content']
            response_role, response = response_message['role'], response_message['content']

            assert query_role in {'user', 'tool'}, f'query_role: {query_role}'
            assert response_role in {'assistant'}, f'response_role: {response_role}'
            if query_role == 'tool':
                prompt = query
                query = ''
            elif template_meta.is_post_system and i == n_round - 1:
                prompt = template_meta.system_prompt
            else:
                prompt = template_meta.prompt

            context_list = prompt.copy()
            extra_context_list = []
            extra_context_type = None
            if i < n_round - 1:
                context_list.append('{{RESPONSE}}')
                if inputs.messages[2 * (i + 1)]['role'] != 'tool':
                    extra_context_list = template_meta.chat_sep
                    extra_context_type = ContextType.OTHER
            elif response is not None:
                context_list.append('{{RESPONSE}}')
                if self.is_training and not sep_token:
                    extra_context_list = template_meta.suffix
                    extra_context_type = ContextType.SUFFIX
            elif template_meta.response_prefix:
                context_list.append(template_meta.response_prefix)

            self._concat_context_list(
                context_list,
                res_context_list,
                res_context_types,
                query=query,
                response=response,
                system=system,
                round0=i)
            res_context_list += extra_context_list
            res_context_types += [extra_context_type] * len(extra_context_list)
        if template_meta.auto_add_bos and sep_token:
            res_context_list.append(sep_token)
            res_context_types.append(ContextType.SUFFIX)
        from swift.plugin import loss_scale_map
        res_context_list, loss_scale_list = loss_scale_map[self.loss_scale](res_context_list, res_context_types,
                                                                            inputs.messages)
        if self.is_training:
            answer_len = len(extra_context_list) + bool(response is not None)
        else:
            answer_len = 0
        return res_context_list, loss_scale_list, answer_len

    def _encode_truncated(self, inputs):
        if self.mode in {'vllm', 'lmdeploy'}:
            encoded = Template._encode(self, inputs)
            for key in ['images', 'audios', 'videos']:
                encoded[key] = getattr(inputs, key)
        else:
            encoded = self._encode(inputs)

        input_ids = encoded.get('input_ids')
        labels = encoded.get('labels')
        loss_scale = encoded.get('loss_scale')
        if self.max_length is not None:
            if self.truncation_strategy == 'right':
                input_ids = input_ids[:self.max_length]
                if labels is not None:
                    labels = labels[:self.max_length]
                if loss_scale is not None:
                    loss_scale = loss_scale[:self.max_length]
            elif self.truncation_strategy == 'left':
                if len(input_ids) > self.max_length:
                    logger.warning_once(
                        'Input data was left-truncated because its length exceeds `max_length` (input length: '
                        f'{len(input_ids)}, max_length: {self.max_length}). '
                        'This may cause loss of important tokens (e.g., image tokens) and lead to errors. '
                        'To avoid this, consider increasing `max_length` or pre-filtering long sequences.',
                        hash_id='max_length_check')
                    input_ids = input_ids[-self.max_length:]
                    if labels is not None:
                        labels = labels[-self.max_length:]
                    if loss_scale is not None:
                        loss_scale = loss_scale[-self.max_length:]
            elif self.truncation_strategy == 'raise':
                length = len(input_ids or labels or [])
                if length > self.max_length:
                    raise MaxLengthError(f'Current length of row({length}) is larger'
                                         f' than the max_length({self.max_length}).')
        encoded['input_ids'] = input_ids
        encoded['labels'] = labels
        encoded['loss_scale'] = loss_scale
        return encoded

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        template_backend = self.template_backend
        if (self.template_meta.template_type == 'dummy' and self.use_chat_template and not self.is_training
                and self.mode != 'seq_cls'):
            template_backend = 'jinja'
            logger.info_once(f'Setting template_backend: {template_backend}')
        res_context_list, loss_scale_list, answer_len = (
            self._swift_encode(inputs) if template_backend == 'swift' else self._jinja_encode(inputs))
        encoded = {}
        if self.is_encoder_decoder:
            total_len = len(res_context_list)
            for key, _slice in zip(['prompt', 'answer'],
                                   [slice(0, total_len - answer_len),
                                    slice(total_len - answer_len, total_len)]):
                context_list, loss_scale = self._simplify_context_list(res_context_list[_slice],
                                                                       loss_scale_list[_slice], inputs)
                input_ids, labels, loss_scale, tokenizer_kwargs = self._encode_context_list(context_list, loss_scale)
                encoded[f'{key}_input_ids'] = input_ids
                if key == 'answer':
                    encoded['labels'] = labels
                    encoded['loss_scale'] = loss_scale
            input_ids = encoded['prompt_input_ids'] + encoded['answer_input_ids']
        else:
            res_context_list, loss_scale_list = self._simplify_context_list(res_context_list, loss_scale_list, inputs)
            input_ids, labels, loss_scale, tokenizer_kwargs = self._encode_context_list(
                res_context_list, loss_scale_list)
            self._add_dynamic_eos(input_ids, labels, loss_scale,
                                  self._encode_context_list(self.template_meta.suffix)[0])

        if tokenizer_kwargs:
            encoded['tokenizer_kwargs'] = tokenizer_kwargs

        encoded['input_ids'] = input_ids
        encoded['labels'] = labels
        encoded['loss_scale'] = loss_scale
        if self.use_megatron:
            self._handle_megatron_cp(encoded)
            encoded['labels'] = encoded['labels'][1:] + [-100]
            encoded['position_ids'] = list(range(len(encoded['labels'])))
        elif encoded.get('labels') is not None:
            encoded['labels'][0] = -100
        if not self.is_training:
            for k in list(encoded.keys()):
                if k.endswith('labels') or k.endswith('loss_scale'):
                    encoded[k] = None
        return encoded

    def _handle_megatron_cp(self, encoded: Dict[str, Any]) -> None:
        cp_size = self.sequence_parallel_size
        if cp_size == 1:
            return
        input_ids = encoded['input_ids']
        padding_len = math.ceil(len(input_ids) / (cp_size * 2)) * (cp_size * 2) - len(input_ids)
        input_ids += [self.tokenizer.pad_token_id] * padding_len
        encoded['labels'] += [-100] * padding_len

    def debug_logger(self, inputs):
        if not strtobool(os.getenv('SWIFT_DEBUG', 'false')):
            return
        if 'input_ids' in inputs:
            k = 'input_ids'
            val = inputs['input_ids']
        elif 'generate_ids' in inputs:
            k = 'generate_ids'
            val = inputs['generate_ids']
        for v in val:
            self.print_inputs({k: v.tolist()})

    @staticmethod
    def _split_list(inputs: List[int], x: int) -> List[List[int]]:
        idxs = findall(inputs, x)
        idxs.append(len(inputs))
        res = []
        lo = 0
        for idx in idxs:
            res.append(inputs[lo:idx])
            lo = idx + 1
        return res
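
    # Illustrative example (assumed values): _split_list([1, -100, 2, 3, -100], -100)
    # -> [[1], [2, 3], []], splitting on each occurrence of x and dropping the separator.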

    def replace_video2image(self, load_video_func, inputs, replace_tag: Callable) -> List[Context]:
        context_list = []
        if self.mode in {'vllm', 'lmdeploy'}:
            video = inputs.videos.pop(inputs.video_idx)
            inputs.video_idx -= 1
        else:
            video = inputs.videos[inputs.video_idx]
        images = inputs.images
        new_images = load_video_func(video)
        inputs.images = images[:inputs.image_idx] + new_images + images[inputs.image_idx:]
        for i in range(len(new_images)):
            context_list += replace_tag(i)
        inputs.image_idx += len(new_images)
        return context_list

    def get_generate_ids(self, generate_ids: Union[torch.Tensor, List[int]],
                         num_prompt_tokens: int) -> Union[torch.Tensor, List[int]]:
        if self.skip_prompt:
            generate_ids = generate_ids[..., num_prompt_tokens:]
        return generate_ids

    def post_process_generate_response(self, response: str, inputs: StdTemplateInputs) -> str:
        return response

    def pre_forward_hook(self, model: nn.Module, args, kwargs):
        from swift.llm import to_device
        old_kwargs = to_device(kwargs, model.device)
        kwargs = to_device(self._post_encode(model, old_kwargs), model.device)
        for k, v in old_kwargs.items():
            if k in {'input_ids', 'attention_mask', 'labels', 'position_ids'} and k not in kwargs:
                kwargs[k] = v
        if 'inputs_embeds' in kwargs:
            kwargs.pop('input_ids', None)

        if isinstance(model, PeftModel):
            parameters = inspect.signature(model.model.forward).parameters
        else:
            parameters = inspect.signature(model.forward).parameters
        if 'position_ids' not in parameters:
            kwargs.pop('position_ids', None)
        return args, kwargs

    @property
    def is_training(self):
        return self.mode not in {'vllm', 'lmdeploy', 'pt'}

    def set_mode(self, mode: Literal['vllm', 'lmdeploy', 'pt', 'seq_cls', 'train', 'rlhf', 'kto']) -> None:
        self.mode = mode

    def register_post_encode_hook(self, models: List[nn.Module]) -> None:
        """This function is important for multimodal training, as it registers the _post_encode method
        as a forward pre-hook, converting input_ids into inputs_embeds.
        """
        if self._handles:
            return

        for model in models:
            handle = model.register_forward_pre_hook(self.pre_forward_hook, with_kwargs=True)
            self._handles.append((model, handle))

        if is_deepspeed_zero3_enabled():
            import deepspeed
            self._deepspeed_initialize = deepspeed.initialize

            @wraps(self._deepspeed_initialize)
            def _initialize(*args, **kwargs):
                res = self._deepspeed_initialize(*args, **kwargs)
                for model, handle in self._handles:
                    model._forward_pre_hooks.move_to_end(handle.id)
                return res

            deepspeed.initialize = _initialize

    def remove_post_encode_hook(self):
        models = []
        for model, handle in self._handles:
            models.append(model)
            handle.remove()
        self._handles = []

        if self._deepspeed_initialize is not None:
            import deepspeed
            deepspeed.initialize = self._deepspeed_initialize
            self._deepspeed_initialize = None
        return models

    def data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
        if self.mode == 'rlhf':
            return self._rlhf_data_collator(batch, padding_to=padding_to)
        elif self.mode == 'kto':
            return self._kto_data_collator(batch, padding_to=padding_to)
        elif self.mode in {'pt', 'train', 'prm'}:
            return self._data_collator(batch, padding_to=padding_to)
        elif self.mode == 'seq_cls':
            return self._seq_cls_data_collator(batch, padding_to=padding_to)
        elif self.mode == 'embedding':
            return self._embedding_data_collator(batch, padding_to=padding_to)

    @staticmethod
    def _fetch_inputs_startswith(batch: List[Dict[str, Any]], prefix: str) -> List[Dict[str, Any]]:
        new_batch = []
        for inputs in batch:
            new_inputs = {}
            for k, v in inputs.items():
                if k.startswith(prefix):
                    new_inputs[k[len(prefix):]] = v
            new_batch.append(new_inputs)
        return new_batch

    @staticmethod
    def fetch_inputs(batch: List[Dict[str, Any]], keys: Optional[List[str]] = None) -> Dict[str, Any]:
        from swift.llm import RowPreprocessor
        keys = keys or []
        rows = RowPreprocessor.rows_to_batched(batch)
        return {k: rows[k] for k in keys if rows.get(k) is not None}

    @staticmethod
    def gather_list(batch: List[Dict[str, Any]], attr_name: str) -> Optional[List[Any]]:
        res = []
        for b in batch:
            if b.get(attr_name) is not None:
                res += b.pop(attr_name)
        return res

    @staticmethod
    def concat_tensor(batch: List[Dict[str, Any]], attr_name: str, dim: int) -> Optional[torch.Tensor]:
        res = []
        for b in batch:
            if b.get(attr_name) is not None:
                res.append(b.pop(attr_name))
        return torch.concat(res, dim=dim) if res else None

    def _rlhf_data_collator(self,
                            batch: List[Dict[str, Any]],
                            *,
                            chosen_prefix: str = 'chosen_',
                            rejected_prefix: str = 'rejected_',
                            padding_to: Optional[int] = None) -> Dict[str, Any]:
        new_batch = []
        for prefix in [chosen_prefix, rejected_prefix]:
            new_batch += self._fetch_inputs_startswith(batch, prefix)
        return self._data_collator(new_batch, padding_to=padding_to)

    def _kto_data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
        new_batch = self._fetch_inputs_startswith(batch, 'chosen_')
        kl_batch = self._fetch_inputs_startswith(batch, 'rejected_')

        res = self._data_collator(new_batch, padding_to=padding_to)
        kl_res = self._data_collator(kl_batch, padding_to=padding_to)
        res = {
            **{f'completion_{k}': v
               for k, v in res.items()},
            **{f'KL_completion_{k}': v
               for k, v in kl_res.items()},
        }
        label = [b['label'] for b in batch if b.get('label') is not None]
        if label:
            res['label'] = label
        return res

    def _embedding_data_collator(self,
                                 batch: List[Dict[str, Any]],
                                 *,
                                 padding_to: Optional[int] = None) -> Dict[str, Any]:
        labels = []
        new_batch = []
        for b in batch:
            keys = [key for key in b.keys() if 'negative' in key]
            max_neg = max([int(re.findall(r'negative(-?\d+)', key)[0]) for key in keys]) if keys else None
            indexes = ['anchor_', 'positive_']
            if max_neg is not None:
                for i in range(0, max_neg + 1):
                    indexes.append(f'negative{i}_')
            for prefix in indexes:
                new_batch += self._fetch_inputs_startswith([b], prefix)
            labels.extend(b.get('labels', None))
        res = self._data_collator(new_batch, padding_to=padding_to)
        if labels:
            res['labels'] = torch.tensor(labels, dtype=torch.float32)
        return res

    def _seq_cls_data_collator(self,
                               batch: List[Dict[str, Any]],
                               *,
                               padding_to: Optional[int] = None) -> Dict[str, Any]:
        labels = [b.pop('labels') for b in batch if b.get('labels') is not None]
        res = self._data_collator(batch, padding_to=padding_to)
        if labels:
            problem_type = self._get_problem_type(self.config)
            if problem_type == 'regression':
                labels = torch.tensor(labels, dtype=torch.float32)
            elif problem_type == 'multi_label_classification':
                one_hot_labels = torch.zeros((len(labels), self.config.num_labels), dtype=torch.float32)
                for i, label in enumerate(labels):
                    one_hot_labels[i, label] = 1
                labels = one_hot_labels
            else:
                labels = torch.tensor(labels, dtype=torch.long)
            res['labels'] = labels
        return res

    def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
        """
        Args:
            batch (`List[Dict[str, Any]]`): The input data in a batch.
            padding_to (`int`, optional): Pad the batch to this fixed length; if `None`, the batch
                is padded to the longest sequence in it.
        """
        assert self.tokenizer.pad_token_id is not None
        padding_side = self.padding_side if self.is_training else 'left'
        padding_right = padding_side == 'right'
        packing_mode = self.use_megatron or self._packing and 'position_ids' in batch[0]
        res = {}
        if packing_mode:
            for k in ['input_ids', 'labels', 'position_ids', 'loss_scale']:
                v = self.gather_list(batch, k)
                if v:
                    res[k] = [v]
        else:
            inputs_embeds = [b['inputs_embeds'] for b in batch if b.get('inputs_embeds') is not None]
            input_ids = [b['input_ids'] for b in batch if b.get('input_ids') is not None]
            if inputs_embeds:
                res['inputs_embeds'] = inputs_embeds
            if input_ids:
                res['input_ids'] = input_ids
            for key in ['labels', 'loss_scale', 'position_ids', 'token_type_ids']:
                val = [b[key] for b in batch if b.get(key) is not None]
                if val:
                    res[key] = val

        keys = [
            'input_ids', 'inputs_embeds', 'attention_mask', 'labels', 'loss_scale', 'position_ids', 'token_type_ids'
        ]
        pad_values = [self.tokenizer.pad_token_id, 0., 0, -100, 0., 0., 0]

        seq_lens = None
        for key in keys:
            if key not in res:
                continue
            for i, val in enumerate(res[key]):
                if isinstance(val, (list, tuple)):
                    val = torch.tensor(val)
                elif key == 'inputs_embeds' and val.ndim == 3 or key != 'inputs_embeds' and val.ndim == 2:
                    val = val[0]
                res[key][i] = val
            if not seq_lens:
                seq_lens = [seq.shape[0] for seq in res[key]]
        if not packing_mode and seq_lens and ('input_ids' in res or 'inputs_embeds' in res):
            res['attention_mask'] = [torch.ones(seq_len, dtype=torch.int64) for seq_len in seq_lens]
            if self.is_training and self.padding_side == 'left':
                res['position_ids'] = [torch.arange(seq_len, dtype=torch.int64) for seq_len in seq_lens]

        if self.use_megatron:
            padding_to = math.ceil(max(seq_lens) / 128) * 128
            cp_size = self.sequence_parallel_size
            if cp_size > 1:
                padding_len = padding_to - seq_lens[0]
                position_ids = res['position_ids'][0].tolist()
                position_ids += list(range(cp_size * 2)) * (padding_len // (cp_size * 2))
                res['position_ids'][0] = torch.tensor(position_ids)

        for key, pad_value in zip(keys, pad_values):
            if key not in res:
                continue
            if self.use_megatron and key == 'position_ids' and self.sequence_parallel_size > 1:
                pass
            elif padding_to is not None:
                padding_len = padding_to - seq_lens[0]
                if padding_len > 0:
                    res[key][0] = F.pad(res[key][0], (0, padding_len) if padding_right else (padding_len, 0),
                                        'constant', pad_value)
            res[key] = self._pad_sequence(res[key], pad_value)

        res.update(self._data_collator_mm_data(batch))
        if not self.use_megatron and (use_torchacc() or self.sequence_parallel_size > 1):
            res = self._torchacc_xtuner_data_collator(res, padding_to, self.tokenizer, padding_side)

        return res

    def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        res = {}
        pixel_values = [b['pixel_values'] for b in batch if b.get('pixel_values') is not None]
        if len(pixel_values) > 0:
            res['pixel_values'] = torch.concat(pixel_values)

        image_sizes = [b['image_sizes'] for b in batch if b.get('image_sizes') is not None]
        if len(image_sizes) > 0:
            res['image_sizes'] = torch.concat(image_sizes)

        pixel_values_videos = [b['pixel_values_videos'] for b in batch if b.get('pixel_values_videos') is not None]
        if len(pixel_values_videos) > 0:
            res['pixel_values_videos'] = torch.concat(pixel_values_videos)
        return res

    def _torchacc_xtuner_data_collator(self, res, padding_to, tokenizer, padding_side):
        input_ids = res.get('input_ids')
        attention_mask = res.get('attention_mask')
        labels = res.get('labels')
        loss_scale = res.get('loss_scale')
        if use_torchacc():
            from swift.utils.torchacc_utils import pad_and_split_batch
            rank, _, world_size, _ = get_dist_setting()
            input_ids, attention_mask, labels, loss_scale = pad_and_split_batch(
                padding_to,
                input_ids,
                attention_mask,
                labels,
                loss_scale,
                self.max_length,
                tokenizer,
                rank,
                world_size,
                padding_right=padding_side == 'right')
        if self.sequence_parallel_size > 1 and input_ids is not None:
            bs, seq_len = input_ids.shape
            if 'position_ids' not in res:
                position_ids = torch.arange(seq_len).unsqueeze(0).long().repeat(bs, 1)
            else:
                position_ids = res['position_ids']
            assert padding_side == 'right' or bs == 1, 'Sequence parallel only supports padding_side=right'
            from swift.trainers.sequence_parallel import sequence_parallel
            if sequence_parallel.world_size() > 1:
                input_ids, _, labels, position_ids, attention_mask, loss_scale = \
                    sequence_parallel.pad_and_split_inputs(
                        tokenizer, input_ids, None, labels, position_ids, attention_mask, loss_scale)
            res['position_ids'] = position_ids
        _local_var = locals()
        for key in ['input_ids', 'attention_mask', 'labels', 'loss_scale']:
            value = _local_var[key]
            if value is not None:
                res[key] = value
        return res

    def print_inputs(self, inputs: Dict[str, Any], tokenizer_kwargs: Optional[Dict[str, Any]] = None) -> None:
        if tokenizer_kwargs is None:
            tokenizer_kwargs = {}
        for key in [
                'input', 'labels', 'generate', 'chosen_input', 'chosen_labels', 'rejected_input', 'rejected_labels'
        ]:
            val = inputs.get(key)
            if val is None:
                val = inputs.get(f'{key}_ids')
            if val is not None:
                key_upper = key.upper()
                logger.info(f'[{key_upper}_IDS] {val}')
                if key == 'labels' and self.mode in {'seq_cls', 'embedding'}:
                    continue
                if isinstance(val, (list, tuple, torch.Tensor)):
                    val_str = self.safe_decode(val, **tokenizer_kwargs)
                    logger.info(f'[{key_upper}] {val_str}')
        if inputs.get('loss_scale') is not None:
            val = inputs['loss_scale']
            logger.info(f'[LOSS_SCALE] {val}')

    async def prepare_lmdeploy_pytorch_inputs(self, inputs) -> None:
        images = inputs.pop('images', None) or []
        if len(images) == 0:
            return
        input_ids = inputs['input_ids']
        idx_list = findall(input_ids, -100)
        assert len(idx_list) == len(images), f'len(idx_list): {len(idx_list)}, len(images): {len(images)}'
        idx_list.insert(0, -1)
        new_input_ids = []
        for i in range(len(idx_list) - 1):
            new_input_ids += input_ids[idx_list[i] + 1:idx_list[i + 1]]
            images[i]['offset'] = len(new_input_ids)
            new_input_ids += [images[i]['image_token_id']] * images[i]['image_tokens']
        new_input_ids += input_ids[idx_list[-1] + 1:]
        inputs['input_ids'] = new_input_ids
        inputs['multimodal'] = images

    async def prepare_lmdeploy_turbomind_inputs(self, inputs: Dict[str, Any]) -> None:
        images = inputs.pop('images', None) or []
        if len(images) == 0:
            return
        from lmdeploy.vl.constants import IMAGE_DUMMY_TOKEN_INDEX
        input_ids = inputs['input_ids']
        idx_list = findall(input_ids, -100)
        assert len(idx_list) == len(images), f'len(idx_list): {len(idx_list)}, len(images): {len(images)}'
        idx_list.insert(0, -1)
        new_input_ids = []
        ranges = []
        for i in range(len(idx_list) - 1):
            _range = []
            new_input_ids += input_ids[idx_list[i] + 1:idx_list[i + 1]]
            _range.append(len(new_input_ids))
            new_input_ids += [IMAGE_DUMMY_TOKEN_INDEX] * images[i].shape[0]
            _range.append(len(new_input_ids))
            ranges.append(_range)
        new_input_ids += input_ids[idx_list[-1] + 1:]
        inputs['input_embeddings'] = [image.to('cpu') for image in images]
        inputs['input_embedding_ranges'] = ranges
        inputs['input_ids'] = new_input_ids

    def _pad_sequence(self, sequences: List[torch.Tensor], padding_value: float = 0.) -> torch.Tensor:
        """Pad sequences on the configured side.

        Args:
            sequences: The input sequences as tensors.
            padding_value: The padding value.

        Returns:
            A tensor after padding.
        """
        padding_side = self.padding_side if self.is_training else 'left'
        padding_right = padding_side == 'right'
        if padding_right:
            return pad_sequence(sequences, batch_first=True, padding_value=padding_value)

        max_len = max([s.shape[0] for s in sequences])

        padded_sequences = []
        for seq in sequences:
            pad_length = max_len - seq.shape[0]
            pad_tuple = [0] * ((seq.dim() - 1) * 2) + [pad_length, 0]
            padded_seq = F.pad(seq, tuple(pad_tuple), 'constant', padding_value)
            padded_sequences.append(padded_seq)

        return torch.stack(padded_sequences)
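
    # Illustrative example (assumed values): with left padding and padding_value=0,
    # [tensor([1, 2, 3]), tensor([4])] -> tensor([[1, 2, 3], [0, 0, 4]]),
    # i.e. shorter sequences are padded at the front.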

    def safe_decode(self, input_ids: List[int], **tokenizer_kwargs) -> str:
        if isinstance(self, Template):
            tokenizer = self.tokenizer
            placeholder_tokens = self.placeholder_tokens
        else:
            tokenizer = self
            placeholder_tokens = []

        def _is_special(token: int) -> bool:
            if isinstance(token, float) or token < 0:
                return True
            return token in placeholder_tokens

        if isinstance(input_ids, torch.Tensor):
            input_ids = input_ids.tolist()
        if len(input_ids) == 0:
            return ''
        result_str = ''
        for i in range(len(input_ids)):
            if i == 0:
                if _is_special(input_ids[i]):
                    s = 0
                else:
                    e = 0
                continue
            if _is_special(input_ids[i]) and not _is_special(input_ids[i - 1]):
                s = i
                result_str += tokenizer.decode(input_ids[e:s], **tokenizer_kwargs)
            if not _is_special(input_ids[i]) and _is_special(input_ids[i - 1]):
                e = i
                result_str += f'[{input_ids[i - 1]} * {e - s}]'
        if _is_special(input_ids[i]):
            result_str += f'[{input_ids[i]} * {len(input_ids) - s}]'
        else:
            result_str += tokenizer.decode(input_ids[e:], **tokenizer_kwargs)
        return result_str
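
    # Illustrative example (assumed values): with placeholder token id 9, input_ids
    # [1, 2, 9, 9, 9, 3] renders as 'ab[9 * 3]c' (where 'ab' and 'c' stand for the decoded
    # text), compressing placeholder runs into a readable '[token_id * count]' form.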

    @staticmethod
    @contextmanager
    def _patch_flash_attention_forward(modeling_module, position_ids, use_new_func: bool = False):
        _origin_flash_attention_forward = modeling_module._flash_attention_forward

        def _flash_attention_forward(*args, **kwargs):
            if use_new_func:
                from transformers.modeling_flash_attention_utils import (_flash_attention_forward as
                                                                         flash_attention_forward)
                if args and isinstance(args[0], nn.Module):
                    args = args[1:]
                if 'is_causal' not in kwargs:
                    kwargs['is_causal'] = True
            else:
                flash_attention_forward = _origin_flash_attention_forward
            kwargs['position_ids'] = position_ids
            return flash_attention_forward(*args, **kwargs)

        modeling_module._flash_attention_forward = _flash_attention_forward
        try:
            yield
        finally:
            modeling_module._flash_attention_forward = _origin_flash_attention_forward